Add NLP task models
This commit is contained in:
parent d8beb17dfb
commit 69f6de0ace
46 changed files with 4976 additions and 0 deletions
428 joint_paraphrase_model/.gitignore vendored Normal file
@@ -0,0 +1,428 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,latex
# Edit at https://www.toptal.com/developers/gitignore?templates=python,latex

### LaTeX ###
## Core latex/pdflatex auxiliary files:
*.aux
*.lof
*.log
*.lot
*.fls
*.out
*.toc
*.fmt
*.fot
*.cb
*.cb2
.*.lb

## Intermediate documents:
*.dvi
*.xdv
*-converted-to.*
# these rules might exclude image files for figures etc.
# *.ps
# *.eps
# *.pdf

## Generated if empty string is given at "Please type another file name for output:"
.pdf

## Bibliography auxiliary files (bibtex/biblatex/biber):
*.bbl
*.bcf
*.blg
*-blx.aux
*-blx.bib
*.run.xml

## Build tool auxiliary files:
*.fdb_latexmk
*.synctex
*.synctex(busy)
*.synctex.gz
*.synctex.gz(busy)
*.pdfsync

## Build tool directories for auxiliary files
# latexrun
latex.out/

## Auxiliary and intermediate files from other packages:
# algorithms
*.alg
*.loa

# achemso
acs-*.bib

# amsthm
*.thm

# beamer
*.nav
*.pre
*.snm
*.vrb

# changes
*.soc

# comment
*.cut

# cprotect
*.cpt

# elsarticle (documentclass of Elsevier journals)
*.spl

# endnotes
*.ent

# fixme
*.lox

# feynmf/feynmp
*.mf
*.mp
*.t[1-9]
*.t[1-9][0-9]
*.tfm

#(r)(e)ledmac/(r)(e)ledpar
*.end
*.?end
*.[1-9]
*.[1-9][0-9]
*.[1-9][0-9][0-9]
*.[1-9]R
*.[1-9][0-9]R
*.[1-9][0-9][0-9]R
*.eledsec[1-9]
*.eledsec[1-9]R
*.eledsec[1-9][0-9]
*.eledsec[1-9][0-9]R
*.eledsec[1-9][0-9][0-9]
*.eledsec[1-9][0-9][0-9]R

# glossaries
*.acn
*.acr
*.glg
*.glo
*.gls
*.glsdefs
*.lzo
*.lzs

# uncomment this for glossaries-extra (will ignore makeindex's style files!)
# *.ist

# gnuplottex
*-gnuplottex-*

# gregoriotex
*.gaux
*.gtex

# htlatex
*.4ct
*.4tc
*.idv
*.lg
*.trc
*.xref

# hyperref
*.brf

# knitr
*-concordance.tex
# TODO Comment the next line if you want to keep your tikz graphics files
*.tikz
*-tikzDictionary

# listings
*.lol

# luatexja-ruby
*.ltjruby

# makeidx
*.idx
*.ilg
*.ind

# minitoc
*.maf
*.mlf
*.mlt
*.mtc[0-9]*
*.slf[0-9]*
*.slt[0-9]*
*.stc[0-9]*

# minted
_minted*
*.pyg

# morewrites
*.mw

# nomencl
*.nlg
*.nlo
*.nls

# pax
*.pax

# pdfpcnotes
*.pdfpc

# sagetex
*.sagetex.sage
*.sagetex.py
*.sagetex.scmd

# scrwfile
*.wrt

# sympy
*.sout
*.sympy
sympy-plots-for-*.tex/

# pdfcomment
*.upa
*.upb

# pythontex
*.pytxcode
pythontex-files-*/

# tcolorbox
*.listing

# thmtools
*.loe

# TikZ & PGF
*.dpth
*.md5
*.auxlock

# todonotes
*.tdo

# vhistory
*.hst
*.ver

# easy-todo
*.lod

# xcolor
*.xcp

# xmpincl
*.xmpi

# xindy
*.xdy

# xypic precompiled matrices and outlines
*.xyc
*.xyd

# endfloat
*.ttt
*.fff

# Latexian
TSWLatexianTemp*

## Editors:
# WinEdt
*.bak
*.sav

# Texpad
.texpadtmp

# LyX
*.lyx~

# Kile
*.backup

# gummi
.*.swp

# KBibTeX
*~[0-9]*

# TeXnicCenter
*.tps

# auto folder when using emacs and auctex
./auto/*
*.el

# expex forward references with \gathertags
*-tags.tex

# standalone packages
*.sta

# Makeindex log files
*.lpz

# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib
# option is specified. Footnotes are then stored in a file with suffix Notes.bib.
# Uncomment the next line to have this generated file ignored.
#*Notes.bib

### LaTeX Patch ###
# LIPIcs / OASIcs
*.vtc

# glossaries
*.glstex

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# End of https://www.toptal.com/developers/gitignore/api/python,latex
3 joint_paraphrase_model/README.md Normal file
@@ -0,0 +1,3 @@
# joint_paraphrase_model

joint training paraphrase model --- neurips
112 joint_paraphrase_model/config.py Normal file
@@ -0,0 +1,112 @@
import os
import torch


# general
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD = "<__PAD__>"
UNK = "<__UNK__>"
NOFIX = "<__NOFIX__>"
SOS = "<__SOS__>"
EOS = "<__EOS__>"

batch_size = 1
teacher_forcing_ratio = 0.5
embedding_dim = 300
fix_hidden_dim = 128
par_hidden_dim = 1024
fix_dropout = 0.5
par_dropout = 0.2
_fix_learning_rate = 0.00001
_par_learning_rate = 0.0001
learning_rate = _par_learning_rate
fix_momentum = 0.9
par_momentum = 0.0
max_length = 121
epochs = 5

# paths
data_path = "./data"
provo_predictability_path = os.path.join(
    data_path, "datasets/provo/Provo_Corpus-Predictability_Norms.csv"
)
provo_eyetracking_path = os.path.join(
    data_path, "datasets/provo/Provo_Corpus-Eyetracking_Data.csv"
)

geco_en_path = os.path.join(data_path, "datasets/geco/EnglishMaterial.csv")
geco_mono_path = os.path.join(data_path, "datasets/geco/MonolingualReadingData.csv")

movieqa_human_path = os.path.join(data_path, "datasets/all_word_scores_fixations")
movieqa_human_path_2 = os.path.join(
    data_path, "datasets/all_word_scores_fixations_exp2"
)
movieqa_human_path_3 = os.path.join(
    data_path, "datasets/all_word_scores_fixations_exp3"
)
movieqa_split_plot_path = os.path.join(data_path, "datasets/split_plot_UNRESOLVED")

cnn_path = os.path.join(
    data_path,
    "projects/2019/fixation_prediction/ez-reader-wrapper/predictability/output_cnn/",
)
dm_path = os.path.join(
    data_path,
    "projects/2019/fixation_prediction/ez-reader-wrapper/predictability/output_dm/",
)

qqp_paws_basedir = os.path.join(data_path, "datasets/paw_google/qqp/paws_qqp/output")
qqp_paws_train_path = os.path.join(qqp_paws_basedir, "train.tsv")
qqp_paws_dev_path = os.path.join(qqp_paws_basedir, "dev.tsv")
qqp_paws_test_path = os.path.join(qqp_paws_basedir, "test.tsv")

qqp_basedir = os.path.join(data_path, "datasets/Quora_question_pair_partition_OG")
qqp_train_path = os.path.join(qqp_basedir, "train.tsv")
qqp_dev_path = os.path.join(qqp_basedir, "dev.tsv")
qqp_test_path = os.path.join(qqp_basedir, "test.tsv")

qqp_kag_basedir = os.path.join(data_path, "datasets/Quora_question_pair_partition_kag")
qqp_kag_train_path = os.path.join(qqp_kag_basedir, "train.tsv")
qqp_kag_dev_path = os.path.join(qqp_kag_basedir, "dev.tsv")
qqp_kag_test_path = os.path.join(qqp_kag_basedir, "test.tsv")

wiki_basedir = os.path.join(data_path, "datasets/paw_google/wiki")
wiki_train_path = os.path.join(wiki_basedir, "train.tsv")
wiki_dev_path = os.path.join(wiki_basedir, "dev.tsv")
wiki_test_path = os.path.join(wiki_basedir, "test.tsv")

msrpc_basedir = os.path.join(data_path, "datasets/MSRPC")
msrpc_train_path = os.path.join(msrpc_basedir, "msr_paraphrase_train.txt")
msrpc_dev_path = os.path.join(msrpc_basedir, "msr_paraphrase_dev.txt")
msrpc_test_path = os.path.join(msrpc_basedir, "msr_paraphrase_test.txt")

sentiment_basedir = os.path.join(data_path, "datasets/sentiment_kag")
sentiment_train_path = os.path.join(sentiment_basedir, "train.tsv")
sentiment_dev_path = os.path.join(sentiment_basedir, "dev.tsv")
sentiment_test_path = os.path.join(sentiment_basedir, "test.tsv")

tamil_basedir = os.path.join(data_path, "datasets/en-ta-parallel-v2")
tamil_train_path = os.path.join(tamil_basedir, "corpus.bcn.train.enta")
tamil_dev_path = os.path.join(tamil_basedir, "corpus.bcn.dev.enta")
tamil_test_path = os.path.join(tamil_basedir, "corpus.bcn.test.enta")

compression_basedir = os.path.join(data_path, "datasets/sentence-compression/data")
compression_train_path = os.path.join(compression_basedir, "train.tsv")
compression_dev_path = os.path.join(compression_basedir, "dev.tsv")
compression_test_path = os.path.join(compression_basedir, "test.tsv")

stanford_basedir = os.path.join(data_path, "datasets/stanfordSentimentTreebank")
stanford_train_path = os.path.join(stanford_basedir, "train.tsv")
stanford_dev_path = os.path.join(stanford_basedir, "dev.tsv")
stanford_test_path = os.path.join(stanford_basedir, "test.tsv")

stanford_sent_basedir = os.path.join(data_path, "datasets/stanfordSentimentTreebank")
stanford_sent_train_path = os.path.join(stanford_basedir, "train_sent.tsv")
stanford_sent_dev_path = os.path.join(stanford_basedir, "dev_sent.tsv")
stanford_sent_test_path = os.path.join(stanford_basedir, "test_sent.tsv")


emb_path = os.path.join(data_path, "Google_word2vec/GoogleNews-vectors-negative300.bin")

glove_path = "glove.840B.300d.txt"
1 joint_paraphrase_model/data Symbolic link
@@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/
1 joint_paraphrase_model/glove.840B.300d.txt Symbolic link
@@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/datasets/glove/glove.840B.300d.txt
1 joint_paraphrase_model/glove.cache Normal file
File diff suppressed because one or more lines are too long
0 joint_paraphrase_model/libs/__init__.py Normal file
416 joint_paraphrase_model/libs/corpora.py Normal file
@@ -0,0 +1,416 @@
import logging

import config


def tokenize(sent):
    return sent.split(" ")


class Lang:
    """Represents the vocabulary
    """
    def __init__(self, name):
        self.name = name
        self.word2index = {
            config.PAD: 0,
            config.UNK: 1,
            config.NOFIX: 2,
            config.SOS: 3,
            config.EOS: 4,
        }
        self.word2count = {}
        self.index2word = {
            0: config.PAD,
            1: config.UNK,
            2: config.NOFIX,
            3: config.SOS,
            4: config.EOS,
        }
        self.n_words = 5

    def add_sentence(self, sentence):
        assert isinstance(
            sentence, (list, tuple)
        ), "input to add_sentence must be tokenized"
        for word in sentence:
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def __add__(self, other):
        """Returns a new Lang object containing the vocabulary from this and
        the other Lang object
        """
        new_lang = Lang(f"{self.name}_{other.name}")

        # Add vocabulary from both Langs
        for word in self.word2count.keys():
            new_lang.add_word(word)
        for word in other.word2count.keys():
            new_lang.add_word(word)

        # Fix the counts on the new one
        for word in new_lang.word2count.keys():
            new_lang.word2count[word] = self.word2count.get(
                word, 0
            ) + other.word2count.get(word, 0)

        return new_lang
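
# A minimal usage sketch of Lang merging via __add__ (names hypothetical):
#   src = Lang("src"); src.add_sentence(["a", "b"])
#   tgt = Lang("tgt"); tgt.add_sentence(["b", "c"])
#   merged = src + tgt  # a new Lang named "src_tgt"
#   assert merged.word2count["b"] == 2  # counts are summed across both vocabularies

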
def load_wiki(split):
    """Load the Wiki from PAWs"""
    logger = logging.getLogger(f"{__name__}.load_wiki")
    lang = Lang("wiki")

    if split == "train":
        path = config.wiki_train_path
    elif split == "val":
        path = config.wiki_dev_path
    elif split == "test":
        path = config.wiki_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            _, sent1, sent2, rating = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang
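
# The PAWS files read above are headered TSVs whose rows are, in order,
# (id, sentence1, sentence2, label); only rows labeled "1" (paraphrases) survive
# the rating filter. An illustrative, hypothetical row that would be skipped:
#   "17\tFlights from New York to Florida\tFlights from Florida to New York\t0"

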
def load_qqp_paws(split):
    """Load the QQP from PAWs"""
    logger = logging.getLogger(f"{__name__}.load_qqp_paws")
    lang = Lang("qqp_paws")

    if split == "train":
        path = config.qqp_paws_train_path
    elif split == "val":
        path = config.qqp_paws_dev_path
    elif split == "test":
        path = config.qqp_paws_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            _, sent1, sent2, rating = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang

def load_qqp(split):
    """Load the QQP from the original partition"""
    logger = logging.getLogger(f"{__name__}.load_qqp")
    lang = Lang("qqp")

    if split == "train":
        path = config.qqp_train_path
    elif split == "val":
        path = config.qqp_dev_path
    elif split == "test":
        path = config.qqp_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            rating, sent1, sent2, _ = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang


def load_qqp_kag(split):
    """Load the QQP from Kaggle"""  # not the original split right now; experimenting with the Kaggle 100K, 3K, 30K split
    logger = logging.getLogger(f"{__name__}.load_qqp_kag")
    lang = Lang("qqp_kag")

    if split == "train":
        path = config.qqp_kag_train_path
    elif split == "val":
        path = config.qqp_kag_dev_path
    elif split == "test":
        path = config.qqp_kag_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:  # the Kaggle version has 3 fields rather than 4
            rating, sent1, sent2 = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang


def load_msrpc(split):
    """Load the Microsoft Research Paraphrase Corpus (MSRPC)"""
    logger = logging.getLogger(f"{__name__}.load_msrpc")
    lang = Lang("msrpc")

    if split == "train":
        path = config.msrpc_train_path
    elif split == "val":
        path = config.msrpc_dev_path
    elif split == "test":
        path = config.msrpc_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            rating, _, _, sent1, sent2 = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # return src_lang, dst_lang, pairs
    # MS makes the vocab for paraphrase the same

    return pairs, lang

def load_sentiment(split):
    """Load the Sentiment Kaggle Comp Dataset"""
    logger = logging.getLogger(f"{__name__}.load_sentiment")
    lang = Lang("sentiment")

    if split == "train":
        path = config.sentiment_train_path
    elif split == "val":
        path = config.sentiment_dev_path
    elif split == "test":
        path = config.sentiment_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []

    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            _, _, sent1, sent2 = line.strip().split("\t")

            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    return pairs, lang


def load_tamil(split):
    """Load the En to Tamil dataset, current SOTA ~13 BLEU"""
    logger = logging.getLogger(f"{__name__}.load_tamil")
    lang = Lang("tamil")

    if split == "train":
        path = config.tamil_train_path
    elif split == "val":
        path = config.tamil_dev_path
    elif split == "test":
        path = config.tamil_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        handle.readline()

        for line in handle:
            sent1, sent2 = line.strip().split("\t")
            # if rating == "0":
            #     continue
            sent1 = tokenize(sent1)
            # TODO: unclear how best to tokenize Tamil; whitespace splitting for now
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            pairs.append([sent1, sent2])

    return pairs, lang

def load_compression(split):
    """Load the Google Sentence Compression Dataset"""
    logger = logging.getLogger(f"{__name__}.load_compression")
    lang = Lang("compression")

    if split == "train":
        path = config.compression_train_path
    elif split == "val":
        path = config.compression_dev_path
    elif split == "test":
        path = config.compression_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        handle.readline()

        for line in handle:
            sent1, sent2 = line.strip().split("\t")
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            # print(len(sent1), sent1)
            # print(len(sent2), sent2)
            # print()
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            pairs.append([sent1, sent2])

    return pairs, lang

def load_stanford(split):
    """Load the Stanford Sentiment Dataset phrases"""
    logger = logging.getLogger(f"{__name__}.load_stanford")
    lang = Lang("stanford")

    if split == "train":
        path = config.stanford_train_path
    elif split == "val":
        path = config.stanford_dev_path
    elif split == "test":
        path = config.stanford_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []

    with open(path) as handle:

        # skip header
        # handle.readline()

        for line in handle:
            _, _, sent1, sent2 = line.strip().split("\t")

            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    return pairs, lang

def load_stanford_sent(split):
    """Load the Stanford Sentiment Dataset sentences"""
    logger = logging.getLogger(f"{__name__}.load_stanford_sent")
    lang = Lang("stanford_sent")

    if split == "train":
        path = config.stanford_sent_train_path
    elif split == "val":
        path = config.stanford_sent_dev_path
    elif split == "test":
        path = config.stanford_sent_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []

    with open(path) as handle:

        # skip header
        # handle.readline()

        for line in handle:
            _, _, sent1, sent2 = line.strip().split("\t")

            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    return pairs, lang
1 joint_paraphrase_model/libs/fixation_generation/__init__.py Normal file
@@ -0,0 +1 @@
from .main import *
131 joint_paraphrase_model/libs/fixation_generation/main.py Normal file
@@ -0,0 +1,131 @@
from collections import OrderedDict
import logging
import sys

from .self_attention import Transformer

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence


def random_embedding(vocab_size, embedding_dim):
    pretrain_emb = np.empty([vocab_size, embedding_dim])
    scale = np.sqrt(3.0 / embedding_dim)
    for index in range(vocab_size):
        pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
    return pretrain_emb


def neg_log_likelihood_loss(outputs, batch_label, batch_size, seq_len):
    outputs = outputs.view(batch_size * seq_len, -1)
    score = F.log_softmax(outputs, 1)

    loss = nn.NLLLoss(ignore_index=0, reduction="sum")(
        score, batch_label.view(batch_size * seq_len)
    )
    loss = loss / batch_size
    _, tag_seq = torch.max(score, 1)
    tag_seq = tag_seq.view(batch_size, seq_len)

    # print(score[0], tag_seq[0])

    return loss, tag_seq


def mse_loss(outputs, batch_label, batch_size, seq_len, word_seq_length):
    # score = torch.nn.functional.softmax(outputs, 1)
    score = torch.sigmoid(outputs)

    mask = torch.zeros_like(score)
    for i, v in enumerate(word_seq_length):
        mask[i, 0:v] = 1

    score = score * mask

    loss = nn.MSELoss(reduction="sum")(
        score.view(batch_size, seq_len), batch_label.view(batch_size, seq_len)
    )

    loss = loss / batch_size

    return loss, score.view(batch_size, seq_len)
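
# A toy check of the masking above (values hypothetical; batch 2, seq_len 3):
#   out = torch.zeros(2, 3)  # sigmoid turns the raw scores into 0.5 everywhere
#   gold = torch.tensor([[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
#   loss, score = mse_loss(out, gold, 2, 3, word_seq_length=[2, 3])
#   assert score[0, 2] == 0  # positions past each sentence length are masked out

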
class Network(nn.Module):
    def __init__(
        self,
        embedding_type,
        vocab_size,
        embedding_dim,
        dropout,
        hidden_dim,
        embeddings=None,
        attention=True,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        prelayers = OrderedDict()
        postlayers = OrderedDict()

        if embedding_type in ("w2v", "glove"):
            if embeddings is not None:
                prelayers["embedding_layer"] = nn.Embedding.from_pretrained(embeddings)
            else:
                prelayers["embedding_layer"] = nn.Embedding(vocab_size, embedding_dim)
            prelayers["embedding_dropout_layer"] = nn.Dropout(dropout)
            embedding_dim = 300
        elif embedding_type == "bert":
            embedding_dim = 768

        self.lstm = BiLSTM(embedding_dim, hidden_dim // 2, num_layers=1)
        postlayers["lstm_dropout_layer"] = nn.Dropout(dropout)

        if attention:
            # increased compl with 1024D, and 16,16: for no att and att experiments
            # before: for the initial att and pretraining: heads 4 and layers 4, 128D
            # then was 128 D with heads 4 layer 1 = results for all IUI
            ### postlayers["position_encodings"] = PositionalEncoding(hidden_dim)
            postlayers["attention_layer"] = Transformer(
                d_model=hidden_dim, n_heads=4, n_layers=1
            )

        postlayers["ff_layer"] = nn.Linear(hidden_dim, hidden_dim // 2)
        postlayers["ff_activation"] = nn.ReLU()
        postlayers["output_layer"] = nn.Linear(hidden_dim // 2, 1)

        self.logger.info(f"prelayers: {prelayers.keys()}")
        self.logger.info(f"postlayers: {postlayers.keys()}")

        self.pre = nn.Sequential(prelayers)
        self.post = nn.Sequential(postlayers)

    def forward(self, x, word_seq_length):
        x = self.pre(x)
        x = self.lstm(x, word_seq_length)
        # MS: printing fix model params
        # for p in self.parameters():
        #     print(p.data)
        #     break

        return self.post(x.transpose(1, 0))


class BiLSTM(nn.Module):
    def __init__(self, embedding_dim, lstm_hidden, num_layers):
        super().__init__()
        self.net = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_hidden,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x, word_seq_length):
        packed_words = pack_padded_sequence(x, word_seq_length, True, False)
        lstm_out, hidden = self.net(packed_words)
        lstm_out, _ = pad_packed_sequence(lstm_out)
        return lstm_out
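

# Net effect of the model above: Embedding -> Dropout -> BiLSTM -> (optional)
# Transformer self-attention -> Linear + ReLU -> Linear, producing one scalar
# score per token; callers squash these scores with a sigmoid to obtain per-word
# fixation probabilities (see main.py below).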
131 joint_paraphrase_model/libs/fixation_generation/self_attention.py Normal file
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()
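
# The table built above is the standard sinusoidal encoding:
#   PE[pos, 2i]   = sin(pos / 10000^(2i / d_hid))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_hid))

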
class AttentionLayer(nn.Module):
    def __init__(self):
        super(AttentionLayer, self).__init__()

    def forward(self, Q, K, V):
        # Q: float32:[batch_size, n_queries, d_k]
        # K: float32:[batch_size, n_keys, d_k]
        # V: float32:[batch_size, n_keys, d_v]
        dk = K.shape[-1]
        dv = V.shape[-1]
        KT = torch.transpose(K, -1, -2)
        weight_logits = torch.bmm(Q, KT) / math.sqrt(dk)
        # weight_logits: float32[batch_size, n_queries, n_keys]
        weights = F.softmax(weight_logits, dim=-1)
        # weights: float32[batch_size, n_queries, n_keys]
        return torch.bmm(weights, V)
        # return float32[batch_size, n_queries, dv]
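
# In equation form, the layer above is scaled dot-product attention:
#   Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

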
class MultiHeadedSelfAttentionLayer(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadedSelfAttentionLayer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        print('{} {}'.format(d_model, n_heads))
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.d_v = self.d_k
        self.attention_layer = AttentionLayer()
        self.W_Qs = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Ks = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Vs = nn.ModuleList([
            nn.Linear(d_model, self.d_v, bias=False)
            for _ in range(n_heads)
        ])
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        head_outputs = []
        for W_Q, W_K, W_V in zip(self.W_Qs, self.W_Ks, self.W_Vs):
            Q = W_Q(x)
            # Q float32:[batch_size, sequence_length, self.d_k]
            K = W_K(x)
            # K float32:[batch_size, sequence_length, self.d_k]
            V = W_V(x)
            # V float32:[batch_size, sequence_length, self.d_v]
            head_output = self.attention_layer(Q, K, V)
            # float32:[batch_size, sequence_length, self.d_v]
            head_outputs.append(head_output)
        concatenated = torch.cat(head_outputs, dim=-1)
        # concatenated float32:[batch_size, sequence_length, self.d_model]
        out = self.W_O(concatenated)
        # out float32:[batch_size, sequence_length, self.d_model]
        return out

class Feedforward(nn.Module):
    def __init__(self, d_model):
        super(Feedforward, self).__init__()
        self.d_model = d_model
        self.W1 = nn.Linear(d_model, d_model)
        self.W2 = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, d_model]
        return self.W2(torch.relu(self.W1(x)))

class Transformer(nn.Module):
    def __init__(self, d_model, n_heads, n_layers):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.attention_layers = nn.ModuleList([
            MultiHeadedSelfAttentionLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        self.ffs = nn.ModuleList([
            Feedforward(d_model)
            for _ in range(n_layers)
        ])

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        for attention_layer, ff in zip(self.attention_layers, self.ffs):
            attention_out = attention_layer(x)
            # attention_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + attention_out, x.shape[2:])
            ff_out = ff(x)
            # ff_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + ff_out, x.shape[2:])
        return x
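

# Each layer is a post-norm residual block: x = LayerNorm(x + SelfAttn(x)),
# then x = LayerNorm(x + FF(x)). Note that the PositionalEncoding defined above
# is not wired in here; its use is commented out in fixation_generation/main.py.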
1 joint_paraphrase_model/libs/paraphrase_generation/__init__.py Normal file
@@ -0,0 +1 @@
from .main import *
86 joint_paraphrase_model/libs/paraphrase_generation/main.py Normal file
@@ -0,0 +1,86 @@
import json
import math
import os

import random
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, embeddings):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding.from_pretrained(embeddings)
        self.gru = nn.GRU(input_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)

class AttnDecoderRNN(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        embeddings,
        dropout_p,
        max_length,
    ):
        super(AttnDecoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding.from_pretrained(embeddings)  # for paragen
        # self.embedding = nn.Embedding(len(embeddings), 300)  # for NMT with Tamil; trying with sentiment too
        self.attn = nn.Linear(self.input_size + self.hidden_size, self.max_length)
        self.attn_combine = nn.Linear(
            self.input_size + self.hidden_size, self.hidden_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs, fixations):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1
        )

        attn_weights = attn_weights * torch.nn.ConstantPad1d((0, attn_weights.shape[-1] - fixations.shape[-2]), 0)(fixations.squeeze().unsqueeze(0))

        # attn_weights = torch.softmax(attn_weights * torch.nn.ConstantPad1d((0, attn_weights.shape[-1] - fixations.shape[-2]), 0)(fixations.squeeze().unsqueeze(0)), dim=1)

        attn_applied = torch.bmm(
            attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0)
        )

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        # output = F.log_softmax(self.out(output[0]), dim=1)
        output = self.out(output[0])
        # output = F.log_softmax(output, dim=1)
        return output, hidden, attn_weights
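

# The one departure from standard attention decoding above: the softmaxed
# attention weights are multiplied elementwise by the fixation scores, which are
# first zero-padded on the right from the source length out to max_length, so
# source words with a higher predicted fixation probability receive more
# attention mass.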
225 joint_paraphrase_model/libs/utils.py Normal file
@@ -0,0 +1,225 @@
import json
import logging
import math
import os
import random
import re
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from nltk.translate.bleu_score import sentence_bleu
import numpy as np
import torch
import torch.nn as nn

import config


plt.switch_backend("agg")


def load_glove(vocabulary):
    logger = logging.getLogger(f"{__name__}.load_glove")
    logger.info("loading embeddings")
    try:
        with open("glove.cache") as h:
            cache = json.load(h)
    except (FileNotFoundError, json.JSONDecodeError):
        logger.info("cache doesn't exist")
        cache = {}
        cache[config.PAD] = [0] * 300
        cache[config.SOS] = [0] * 300
        cache[config.EOS] = [0] * 300
        cache[config.UNK] = [0] * 300
        cache[config.NOFIX] = [0] * 300
    else:
        logger.info("cache found")

    cache_miss = False

    if not set(vocabulary) <= set(cache):
        cache_miss = True
        logger.warning("cache miss, loading full embeddings")
        data = {}
        with open("glove.840B.300d.txt") as h:
            for line in h:
                word, *emb = line.strip().split()
                try:
                    data[word] = [float(x) for x in emb]
                except ValueError:
                    continue
        logger.info("finished loading full embeddings")
        for word in vocabulary:
            try:
                cache[word] = data[word]
            except KeyError:
                cache[word] = [0] * 300
        logger.info("cache updated")

    embeddings = []
    for word in vocabulary:
        embeddings.append(torch.tensor(cache[word], dtype=torch.float32))
    embeddings = torch.stack(embeddings)

    if cache_miss:
        with open("glove.cache", "w") as h:
            json.dump(cache, h)
        logger.info("cache saved")

    return embeddings
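
# glove.cache is a plain JSON dict mapping word -> 300-dim float list; words
# missing from GloVe are cached as zero vectors, so load_glove always returns a
# tensor of shape [len(vocabulary), 300].

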
def tokenize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = s.split(" ")
    return s


def indices_from_sentence(word2index, sentence, unknown_threshold):
    if unknown_threshold:
        return [
            word2index.get(
                word if random.random() > unknown_threshold else config.UNK,
                word2index[config.UNK],
            )
            for word in sentence
        ]
    else:
        return [
            word2index.get(word, word2index[config.UNK]) for word in sentence
        ]


def tensor_from_sentence(word2index, sentence, unknown_threshold):
    # indices = [config.SOS]
    indices = indices_from_sentence(word2index, sentence, unknown_threshold)
    indices.append(word2index[config.EOS])
    return torch.tensor(indices, dtype=torch.long, device=config.DEV)


def tensors_from_pair(word2index, pair, shuffle, unknown_threshold):
    tensors = [
        tensor_from_sentence(word2index, pair[0], unknown_threshold),
        tensor_from_sentence(word2index, pair[1], unknown_threshold),
    ]
    if shuffle:
        random.shuffle(tensors)
    return tensors


def bleu(reference, hypothesis, n=4):  # the explicit weights do set the max n-gram order
    if n < 1:
        return 0
    weights = [1 / n] * n
    return sentence_bleu([reference], hypothesis, weights)
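
# Example: bleu(ref, hyp, n=2) calls sentence_bleu with weights [0.5, 0.5], i.e.
# the geometric mean of unigram and bigram precision times the brevity penalty:
#   ref = "the cat sat".split(); hyp = "the cat sat".split()
#   bleu(ref, hyp, n=2)  # -> 1.0

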
def pair_iter(pairs, word2index, shuffle=False, shuffle_pairs=False, unknown_threshold=0.00):
    if shuffle:
        pairs = pairs.copy()
        random.shuffle(pairs)
    for pair in pairs:
        tensor1, tensor2 = tensors_from_pair(word2index, (pair[0], pair[1]), shuffle_pairs, unknown_threshold)
        yield (tensor1,), (tensor2,)


def sent_iter(sents, word2index, unknown_threshold=0.00):
    for sent in sents:
        tensor = tensor_from_sentence(word2index, sent, unknown_threshold)
        yield (tensor,)


def batch_iter(pairs, word2index, batch_size, shuffle=False, unknown_threshold=0.00):
    for i in range(len(pairs) // batch_size):
        # step by batch_size so consecutive batches do not overlap
        batch = pairs[i * batch_size : (i + 1) * batch_size]
        if len(batch) != batch_size:
            continue
        batch_tensors = [
            tensors_from_pair(word2index, (pair[0], pair[1]), shuffle, unknown_threshold)
            for pair in batch
        ]

        tensors1, tensors2 = zip(*batch_tensors)

        # targets = torch.tensor(targets, dtype=torch.long, device=config.DEV)

        # tensors1_lengths = [len(t) for t in tensors1]
        # tensors2_lengths = [len(t) for t in tensors2]

        # tensors1 = nn.utils.rnn.pack_sequence(tensors1, enforce_sorted=False)
        # tensors2 = nn.utils.rnn.pack_sequence(tensors2, enforce_sorted=False)

        yield tensors1, tensors2


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return "%s (- %s)" % (asMinutes(s), asMinutes(rs))


def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)


def showAttention(input_sentence, output_words, attentions):
    # Set up figure with colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence.split(" ") + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    output_words, attentions = evaluate(encoder1, attn_decoder1, input_sentence)
    print("input =", input_sentence)
    print("output =", " ".join(output_words))
    showAttention(input_sentence, output_words, attentions)


def save_model(model, word2index, path):
    if not path.endswith(".tar"):
        path += ".tar"
    torch.save(
        {"weights": model.state_dict(), "word2index": word2index},
        path,
    )


def load_model(path):
    checkpoint = torch.load(path)
    return checkpoint["weights"], checkpoint["word2index"]


def extend_vocabulary(word2index, langs):
    for lang in langs:
        for word in lang.word2index:
            if word not in word2index:
                word2index[word] = len(word2index)
    return word2index
794 joint_paraphrase_model/main.py Normal file
@@ -0,0 +1,794 @@
import logging
import os
import pathlib
import random
import sys

import click
import sacrebleu
import torch
import torch.nn as nn
import tqdm

import config
from libs import corpora
from libs import utils
from libs.fixation_generation import Network as FixNN
from libs.paraphrase_generation import (
    EncoderRNN as ParEncNN,
    AttnDecoderRNN as ParDecNN,
)


cwd = os.path.dirname(__file__)

logger = logging.getLogger("main")

'''
# qqp_paws sentences:

debug_sentences = [
    'What are the driving rules in Georgia versus Mississippi ?',
    'I want to be a child psychologist , what qualification do i need to become one ? Are there good and reputed psychology Institute or Colleges in India ?',
    'What is deep web and dark web and what are the contents of these sites ?',
    'What is difference between North Indian Brahmins and South Indian Brahmins ?',
    'Is carbon dioxide an ionic bond or a covalent bond ?',
    'How do accounts receivable and accounts payable differ ?',
    'Why did Wikipedia hide its audit history for ( superluminal ) successful speed experiments ?',
    'What makes a person simple , or inversely , complicated ?',
    '`` How do you say `` Miss you , too , `` in Spanish ? Are there multiple ways to say it ? ‘’',
    'What is the difference between dominant trait and recessive trait ?’',
    '`` What is the difference between `` seeing someone , `` `` dating someone , `` and `` having a girlfriend/boyfriend `` ? ‘’',
    'How was the Empire State building built and designed ? How is it used ?',
    'What is the sum of the square roots of first n natural number ?',
    'Why is Roman Saini not so active on Quora now a days ?',
    'If I have someone blocked on Instagram , and see their story , can they view I viewed it ?',
    'Amongst the major IT companies of India which is the best ; Wipro , Capgemini , Infosys , TCS or is Oracle the best ?',
    'How much mass does Saturn lose each year ? How much mass does it gain ?',
    'What is a cheap healthy diet , I can keep the same and eat every day ?',
    ' What is it like to be beautiful ? Not just pretty or hot , but the kind of almost objective beauty that people are sometimes intimidated by ?',
    'Could someone tell of a James Ronsey ( misspelled likely ) , writer and filmmaker , probably of the British Isles ?',
    'How much pressure Is there around the core of Pluto ? is it enough to turn hydrogen/helium gas into a liquid or metallic state ?',
    'How does quality of life in Vancouver compare to that in Melbourne Or Brisbane ?',
]
'''
'''
# wiki sentences:

debug_sentences = [
    'They were there to enjoy us and they were there to pray for us .',
    'Components of elastic potential systems store mechanical energy if they are deformed when forces are applied to the system .',
    'Steam can also be used , and does not need to be pumped .',
    'The solar approach to this requirement is the use of solar panels in a conventional-powered aircraft .',
    'Daudkhali is a village in Barisal Division in the Pirojpur district in southwestern Bangladesh .',
    'Briggs later met Briggs at the 1967 Monterey Pop Festival , where Ravi Shankar was also performing , with Eric Burdon and The Animals .',
    'Brockton is approximately 25 miles northeast of Providence , Rhode Island , and 30 miles south of Boston .',
]
'''

# qqp sentences:

debug_sentences = [
    'How do I get funding for my web based startup idea ?',
    'What do intelligent people do to pass time ?',
    'Which is the best SEO Company in Delhi ?',
    'Why do you waer makeup ?',
    'How do start chatting with a girl ?',
    'What is the meaning of living life ?',
    'Why do my armpits hurt ?',
    'Why does eye color change with age ?',
    'How do you find the standard deviation of a probability distribution ? What are some examples ?',
    'How can I complete my 11 syllabus in one month ?',
    'How do I concentrate better on my studies ?',
    'Which is the best retirement plan in india ?',
    'Should I tell my best friend I love her ?',
    'Which is the best company for Appian Vagrant online job support ?',
    'How can one do for good handwriting ?',
    'What are remedies to get rid of belly fat ?',
    'What is the best way to cook precooked turkey ?',
    'What is the future of e-commerce in India ?',
    'Why do my burps taste like rotten eggs ?',
    'What is an example of chemical weathering ?',
    'What are some of the advantages and disadvantages of cyber schooling ?',
    'How can I increase traffic to my websites by Facebook ?',
    'How do I increase my patience level in life ?',
    'What are the best hospitals for treating cancer in India ?',
    'Will Jio sim work in a 3G phone ? If yes , how ?',
]

debug_sentences = [s.split(" ") for s in debug_sentences]


class Network(nn.Module):
    def __init__(
        self,
        word2index,
        embeddings,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.word2index = word2index
        self.index2word = {i: k for k, i in word2index.items()}
        self.fix_gen = FixNN(
            embedding_type="glove",
            vocab_size=len(word2index),
            embedding_dim=config.embedding_dim,
            embeddings=embeddings,
            dropout=config.fix_dropout,
            hidden_dim=config.fix_hidden_dim,
        )
        self.par_enc = ParEncNN(
            input_size=config.embedding_dim,
            hidden_size=config.par_hidden_dim,
            embeddings=embeddings,
        )
        self.par_dec = ParDecNN(
            input_size=config.embedding_dim,
            hidden_size=config.par_hidden_dim,
            output_size=len(word2index),
            embeddings=embeddings,
            dropout_p=config.par_dropout,
            max_length=config.max_length,
        )

    def forward(self, x, target=None, teacher_forcing_ratio=None):
        teacher_forcing_ratio = teacher_forcing_ratio if teacher_forcing_ratio is not None else config.teacher_forcing_ratio
        x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
        x2 = nn.utils.rnn.pad_sequence(x, batch_first=False)
        fixations = torch.sigmoid(self.fix_gen(x1, [len(_x) for _x in x1]))

        enc_hidden = self.par_enc.initHidden().to(config.DEV)
        enc_outs = torch.zeros(config.max_length, config.par_hidden_dim, device=config.DEV)

        for ei in range(len(x2)):
            enc_out, enc_hidden = self.par_enc(x2[ei], enc_hidden)
            enc_outs[ei] += enc_out[0, 0]

        dec_in = torch.tensor([[self.word2index[config.SOS]]], device=config.DEV)  # SOS
        dec_hidden = enc_hidden
        dec_outs = []
        dec_words = []
        dec_atts = torch.zeros(config.max_length, config.max_length)

        if target is not None:  # training
            target = nn.utils.rnn.pad_sequence(target, batch_first=False)
            use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

            if use_teacher_forcing:
                for di in range(len(target)):
                    dec_out, dec_hidden, dec_att = self.par_dec(
                        dec_in, dec_hidden, enc_outs, fixations
                    )
                    dec_outs.append(dec_out)
                    dec_atts[di] = dec_att.data
                    dec_in = target[di]  # teacher forcing: feed the gold token back in

            else:
                for di in range(len(target)):
                    dec_out, dec_hidden, dec_att = self.par_dec(
                        dec_in, dec_hidden, enc_outs, fixations
                    )
                    dec_outs.append(dec_out)
                    dec_atts[di] = dec_att.data
                    topv, topi = dec_out.data.topk(1)
                    dec_words.append(self.index2word[topi.item()])

                    dec_in = topi.squeeze().detach()

        else:  # prediction
            for di in range(config.max_length):
                dec_out, dec_hidden, dec_att = self.par_dec(
                    dec_in, dec_hidden, enc_outs, fixations
                )
                dec_outs.append(dec_out)
                dec_atts[di] = dec_att.data
                topv, topi = dec_out.data.topk(1)
                if topi.item() == self.word2index[config.EOS]:
                    dec_words.append("<__EOS__>")
                    break
                else:
                    dec_words.append(self.index2word[topi.item()])

                dec_in = topi.squeeze().detach()

        return dec_outs, dec_words, dec_atts[: di + 1], fixations
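
# End to end, forward() runs: token ids -> per-word fixation probabilities
# (sigmoid over the fixation network's scores) -> GRU encoder over the same ids
# -> fixation-reweighted attention decoder; it returns the decoder logits, the
# decoded words, the attention matrix, and the fixations themselves.

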
def load_corpus(corpus_name, splits):
    if not splits:
        return

    logger.info("loading corpus")
    if corpus_name == "msrpc":
        load_fn = corpora.load_msrpc
    elif corpus_name == "qqp":
        load_fn = corpora.load_qqp
    elif corpus_name == "wiki":
        load_fn = corpora.load_wiki
    elif corpus_name == "qqp_paws":
        load_fn = corpora.load_qqp_paws
    elif corpus_name == "qqp_kag":
        load_fn = corpora.load_qqp_kag
    elif corpus_name == "sentiment":
        load_fn = corpora.load_sentiment
    elif corpus_name == "stanford":
        load_fn = corpora.load_stanford
    elif corpus_name == "stanford_sent":
        load_fn = corpora.load_stanford_sent
    elif corpus_name == "tamil":
        load_fn = corpora.load_tamil
    elif corpus_name == "compression":
        load_fn = corpora.load_compression

    corpus = {}
    langs = []

    if "train" in splits:
        train_pairs, train_lang = load_fn("train")
        corpus["train"] = train_pairs
        langs.append(train_lang)
    if "val" in splits:
        val_pairs, val_lang = load_fn("val")
        corpus["val"] = val_pairs
        langs.append(val_lang)
    if "test" in splits:
        test_pairs, test_lang = load_fn("test")
        corpus["test"] = test_pairs
        langs.append(test_lang)

    logger.info("creating word index")
    lang = langs[0]
    for _lang in langs[1:]:
        lang += _lang
    word2index = lang.word2index

    index2word = {i: w for w, i in word2index.items()}

    return corpus, word2index, index2word


def init_network(word2index):
    logger.info("loading embeddings")
    vocabulary = sorted(word2index.keys())
    embeddings = utils.load_glove(vocabulary)

    logger.info("initializing model")
    network = Network(
        word2index=word2index,
        embeddings=embeddings,
    )
    network.to(config.DEV)

    print(f"#parameters: {sum(p.numel() for p in network.parameters())}")

    return network


@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
|
||||
@click.option("-v", "--verbose", count=True)
|
||||
@click.option("-d", "--debug", is_flag=True)
|
||||
def main(verbose, debug):
|
||||
if verbose == 0:
|
||||
loglevel = logging.ERROR
|
||||
elif verbose == 1:
|
||||
loglevel = logging.WARN
|
||||
elif verbose >= 2:
|
||||
loglevel = logging.INFO
|
||||
|
||||
if debug:
|
||||
loglevel = logging.DEBUG
|
||||
|
||||
logging.basicConfig(
|
||||
format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
|
||||
datefmt="%d.%m. %H:%M:%S",
|
||||
level=loglevel,
|
||||
)
|
||||
|
||||
logger.debug("arguments: %s" % str(sys.argv))
|
||||
|
||||
|


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-m", "--model_name", required=True)
@click.option("-w", "--fixation_weights", required=False)
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def train(corpus_name, model_name, fixation_weights, freeze_fixations, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
    train_pairs = corpus["train"]
    val_pairs = corpus["val"]
    network = init_network(word2index)

    model_dir = os.path.join("models", model_name)
    logger.debug("creating model dir %s" % model_dir)
    pathlib.Path(model_dir).mkdir(parents=True)

    if fixation_weights is not None:
        logger.info("loading fixation prediction checkpoint")
        checkpoint = torch.load(fixation_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
        else:
            weights = checkpoint

        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")}
        network.fix_gen.load_state_dict(weights, strict=False)

    if freeze_fixations:
        logger.info("freezing fixation generation network")
        for p in network.fix_gen.parameters():
            p.requires_grad = False

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)
    # optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=1e-5)

    best_val_loss = None

    epoch = 1
    while True:
        train_batch_iter = utils.pair_iter(pairs=train_pairs, word2index=word2index, shuffle=True, shuffle_pairs=False)
        val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
        # test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

        running_train_loss = 0
        total_train_loss = 0
        total_val_loss = 0

        if bleu == "sacrebleu":
            running_train_bleu = 0
            total_train_bleu = 0
            total_val_bleu = 0
        elif bleu == "nltk":
            running_train_bleu_1 = 0
            running_train_bleu_2 = 0
            running_train_bleu_3 = 0
            running_train_bleu_4 = 0
            total_train_bleu_1 = 0
            total_train_bleu_2 = 0
            total_train_bleu_3 = 0
            total_train_bleu_4 = 0
            total_val_bleu_1 = 0
            total_val_bleu_2 = 0
            total_val_bleu_3 = 0
            total_val_bleu_4 = 0

        network.train()
        for i, batch in enumerate(train_batch_iter, 1):
            optimizer.zero_grad()

            input, target = batch
            prediction, words, attention, fixations = network(input, target)

            loss = loss_fn(
                torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
            )

            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
            total_train_loss += loss.item()

            _prediction = " ".join([index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()])
            _target = " ".join([index2word[_x] for _x in target[0].tolist()])

            if bleu == "sacrebleu":
                bleu_score = sacrebleu.sentence_bleu(_prediction, _target).score
                running_train_bleu += bleu_score
                total_train_bleu += bleu_score
            elif bleu == "nltk":
                bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
                bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
                bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
                bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
                running_train_bleu_1 += bleu_1_score
                running_train_bleu_2 += bleu_2_score
                running_train_bleu_3 += bleu_3_score
                running_train_bleu_4 += bleu_4_score
                total_train_bleu_1 += bleu_1_score
                total_train_bleu_2 += bleu_2_score
                total_train_bleu_3 += bleu_3_score
                total_train_bleu_4 += bleu_4_score
                # print(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist())

            if i % 100 == 0:
                if bleu == "sacrebleu":
                    print(f"step {i} avg_train_loss {running_train_loss/100:.4f} avg_train_bleu {running_train_bleu/100:.2f}")
                elif bleu == "nltk":
                    print(f"step {i} avg_train_loss {running_train_loss/100:.4f} avg_train_bleu_1 {running_train_bleu_1/100:.2f} avg_train_bleu_2 {running_train_bleu_2/100:.2f} avg_train_bleu_3 {running_train_bleu_3/100:.2f} avg_train_bleu_4 {running_train_bleu_4/100:.2f}")

                network.eval()
                with open(os.path.join(model_dir, f"debug_{epoch}_{i}.out"), "w") as h:
                    if bleu == "sacrebleu":
                        h.write(f"# avg_train_loss {running_train_loss/100:.4f} avg_train_bleu {running_train_bleu/100:.2f}")
                        running_train_bleu = 0
                    elif bleu == "nltk":
                        h.write(f"# avg_train_loss {running_train_loss/100:.4f} avg_train_bleu_1 {running_train_bleu_1/100:.2f} avg_train_bleu_2 {running_train_bleu_2/100:.2f} avg_train_bleu_3 {running_train_bleu_3/100:.2f} avg_train_bleu_4 {running_train_bleu_4/100:.2f}")
                        running_train_bleu_1 = 0
                        running_train_bleu_2 = 0
                        running_train_bleu_3 = 0
                        running_train_bleu_4 = 0

                    running_train_loss = 0

                    h.write("\n")
                    h.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
                    h.write("\n")
                    for s, input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
                        prediction, words, attentions, fixations = network(input)
                        prediction = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
                        prediction = [index2word.get(x, "<__UNK__>") for x in prediction]
                        attentions = attentions.detach().cpu().squeeze().tolist()
                        fixations = fixations.detach().cpu().squeeze().tolist()
                        h.write(f"{s}\t{prediction}\t{attentions}\t{fixations}")
                        h.write("\n")

                network.train()

        network.eval()
        for i, batch in enumerate(val_batch_iter):
            input, target = batch
            prediction, words, attention, fixations = network(input, target, teacher_forcing_ratio=0)
            loss = loss_fn(
                torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
            )

            _prediction = " ".join([index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()])
            _target = " ".join([index2word[_x] for _x in target[0].tolist()])
            if bleu == "sacrebleu":
                bleu_score = sacrebleu.sentence_bleu(_prediction, _target).score
                total_val_bleu += bleu_score
            elif bleu == "nltk":
                bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
                bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
                bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
                bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
                total_val_bleu_1 += bleu_1_score
                total_val_bleu_2 += bleu_2_score
                total_val_bleu_3 += bleu_3_score
                total_val_bleu_4 += bleu_4_score

            total_val_loss += loss.item()

        avg_val_loss = total_val_loss/len(val_pairs)

        if bleu == "sacrebleu":
            print(f"epoch {epoch} avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu {total_train_bleu/len(train_pairs):.2f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
        elif bleu == "nltk":
            print(f"epoch {epoch} avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu_1 {total_train_bleu_1/len(train_pairs):.2f} avg_train_bleu_2 {total_train_bleu_2/len(train_pairs):.2f} avg_train_bleu_3 {total_train_bleu_3/len(train_pairs):.2f} avg_train_bleu_4 {total_train_bleu_4/len(train_pairs):.2f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")

        with open(os.path.join(model_dir, f"debug_{epoch}_end.out"), "w") as h:
            if bleu == "sacrebleu":
                h.write(f"# avg_train_loss {total_train_loss/len(train_pairs)} avg_val_loss {total_val_loss/len(val_pairs)} avg_train_bleu {total_train_bleu/len(train_pairs)} avg_val_bleu {total_val_bleu/len(val_pairs)}")
            elif bleu == "nltk":
                h.write(f"# avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu_1 {total_train_bleu_1/len(train_pairs):.2f} avg_train_bleu_2 {total_train_bleu_2/len(train_pairs):.2f} avg_train_bleu_3 {total_train_bleu_3/len(train_pairs):.2f} avg_train_bleu_4 {total_train_bleu_4/len(train_pairs):.2f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")
            h.write("\n")
            h.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
            h.write("\n")
            for s, input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
                prediction, words, attentions, fixations = network(input)
                prediction = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
                prediction = [index2word.get(x, "<__UNK__>") for x in prediction]
                attentions = attentions.detach().cpu().squeeze().tolist()
                fixations = fixations.detach().cpu().squeeze().tolist()
                h.write(f"{s}\t{prediction}\t{attentions}\t{fixations}")
                h.write("\n")

        utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_{epoch}"))

        if best_val_loss is None or avg_val_loss < best_val_loss:
            if best_val_loss is not None:
                logger.info(f"{avg_val_loss} < {best_val_loss} ({avg_val_loss-best_val_loss}): new best model from epoch {epoch}")
            else:
                logger.info(f"{avg_val_loss}: first validation loss, new best model from epoch {epoch}")

            best_val_loss = avg_val_loss
            # save_model(model, word2index, model_name + "_epoch_" + str(epoch))
            # utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_best"))

        epoch += 1
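
# Example invocations (illustrative only; the script file name is an
# assumption, substitute the actual entry point of this module):
#
#   python main.py -vv train -c compression -m run1 -b nltk
#   python main.py train -c qqp -m run2 -b sacrebleu -w fixations.pt -f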


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def val(corpus_name, model_weights, sentence_statistics, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
    val_pairs = corpus["val"]

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint does not contain a word index")

    network = init_network(word2index)

    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    # make a new output layer to match the weights from the checkpoint;
    # we cannot remove it like we did with the embedding layers because,
    # unlike those, the output layer actually contains learned parameters
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    loss_fn = nn.CrossEntropyLoss()

    val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    total_val_loss = 0
    if bleu == "sacrebleu":
        total_val_bleu = 0
    elif bleu == "nltk":
        total_val_bleu_1 = 0
        total_val_bleu_2 = 0
        total_val_bleu_3 = 0
        total_val_bleu_4 = 0

    network.eval()
    for i, batch in enumerate(val_batch_iter, 1):
        input, target = batch
        prediction, words, attentions, fixations = network(input, target)

        loss = loss_fn(
            torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
        )
        total_val_loss += loss.item()

        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()]
        _target = [index2word[_x] for _x in target[0].tolist()]
        if bleu == "sacrebleu":
            bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
            total_val_bleu += bleu_score
        elif bleu == "nltk":
            bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
            bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
            bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
            bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
            total_val_bleu_1 += bleu_1_score
            total_val_bleu_2 += bleu_2_score
            total_val_bleu_3 += bleu_3_score
            total_val_bleu_4 += bleu_4_score

        if sentence_statistics:
            s = [index2word[x] for x in input[0].detach().cpu().tolist()]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()

            if bleu == "sacrebleu":
                print(f"{bleu_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")
            elif bleu == "nltk":
                print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")

    if bleu == "sacrebleu":
        print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
    elif bleu == "nltk":
        print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def test(corpus_name, model_weights, sentence_statistics, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]

    if model_weights is not None:
        logger.info("loading model checkpoint")
        checkpoint = torch.load(model_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
            word2index = checkpoint["word2index"]
            index2word = {i: w for w, i in word2index.items()}
        else:
            raise ValueError("checkpoint does not contain a word index")

    network = init_network(word2index)

    if model_weights is not None:
        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if "embedding" not in k}
        # make a new output layer to match the weights from the checkpoint;
        # we cannot remove it like we did with the embedding layers because,
        # unlike those, the output layer actually contains learned parameters
        vocab_size, hidden_size = weights["par_dec.out.weight"].shape
        network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
        # actually load the parameters
        network.load_state_dict(weights, strict=False)

    loss_fn = nn.CrossEntropyLoss()

    test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    total_test_loss = 0
    if bleu == "sacrebleu":
        total_test_bleu = 0
    elif bleu == "nltk":
        total_test_bleu_1 = 0
        total_test_bleu_2 = 0
        total_test_bleu_3 = 0
        total_test_bleu_4 = 0

    network.eval()
    for i, batch in enumerate(test_batch_iter, 1):
        input, target = batch

        prediction, words, attentions, fixations = network(input, target)

        loss = loss_fn(
            torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
        )
        total_test_loss += loss.item()

        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()]
        _target = [index2word[_x] for _x in target[0].tolist()]
        if bleu == "sacrebleu":
            bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
            total_test_bleu += bleu_score
        elif bleu == "nltk":
            bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
            bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
            bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
            bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
            total_test_bleu_1 += bleu_1_score
            total_test_bleu_2 += bleu_2_score
            total_test_bleu_3 += bleu_3_score
            total_test_bleu_4 += bleu_4_score

        if sentence_statistics:
            s = [index2word[x] for x in input[0].detach().cpu().tolist()]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()

            if bleu == "sacrebleu":
                print(f"{bleu_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")
            elif bleu == "nltk":
                print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")

    if bleu == "sacrebleu":
        print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu {total_test_bleu/len(test_pairs):.2f}")
    elif bleu == "nltk":
        print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu_1 {total_test_bleu_1/len(test_pairs):.2f} avg_test_bleu_2 {total_test_bleu_2/len(test_pairs):.2f} avg_test_bleu_3 {total_test_bleu_3/len(test_pairs):.2f} avg_test_bleu_4 {total_test_bleu_4/len(test_pairs):.2f}")


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=False)
def predict(corpus_name, model_weights):
    corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
    test_pairs = corpus["val"]

    if model_weights is not None:
        logger.info("loading model checkpoint")
        checkpoint = torch.load(model_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
            word2index = checkpoint["word2index"]
            index2word = {i: w for w, i in word2index.items()}
        else:
            raise ValueError("checkpoint does not contain a word index")

    network = init_network(word2index)

    logger.info(f"vocab size {len(word2index)}")

    if model_weights is not None:
        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if "embedding" not in k}
        # make a new output layer to match the weights from the checkpoint;
        # we cannot remove it like we did with the embedding layers because,
        # unlike those, the output layer actually contains learned parameters
        vocab_size, hidden_size = weights["par_dec.out.weight"].shape
        network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
        # actually load the parameters
        network.load_state_dict(weights, strict=False)

    test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    network.eval()
    for i, batch in enumerate(test_batch_iter, 1):
        input, target = batch

        prediction, words, attentions, fixations = network(input, target)
        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]

        s = [index2word[x] for x in input[0].detach().cpu().tolist()]
        attentions = attentions.detach().cpu().squeeze().tolist()
        fixations = fixations.detach().cpu().squeeze().tolist()

        print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")


@main.command()
@click.option("-w", "--model_weights", required=True)
@click.argument("path")
def predict_file(model_weights, path):
    logger.info("loading sentences")
    sentences = []
    lang = corpora.Lang("pred")
    with open(path) as h:
        for line in h:
            line = line.strip()
            if line:
                sentence = line.split(" ")
                lang.add_sentence(sentence)
                sentences.append(sentence)
    word2index = lang.word2index
    index2word = {i: w for w, i in word2index.items()}

    logger.info(f"{len(sentences)} sentences loaded")

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint does not contain a word index")

    network = init_network(word2index)

    logger.info(f"vocab size {len(word2index)}")

    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    # make a new output layer to match the weights from the checkpoint;
    # we cannot remove it like we did with the embedding layers because,
    # unlike those, the output layer actually contains learned parameters
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    debug_sentence_iter = utils.sent_iter(sentences, word2index=word2index)

    network.eval()
    for i, input in enumerate(debug_sentence_iter, 1):
        prediction, words, attentions, fixations = network(input)
        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]

        s = [index2word[x] for x in input[0].detach().cpu().tolist()]
        attentions = attentions.detach().cpu().squeeze().tolist()
        fixations = fixations.detach().cpu().squeeze().tolist()

        print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")


if __name__ == "__main__":
    main()
19
joint_paraphrase_model/requirements.txt
Normal file
@@ -0,0 +1,19 @@
click==7.1.2
cycler==0.10.0
dataclasses==0.6
future==0.18.2
joblib==0.17.0
kiwisolver==1.3.1
matplotlib==3.3.3
nltk==3.5
numpy==1.19.4
Pillow==8.0.1
portalocker==2.0.0
pyparsing==2.4.7
python-dateutil==2.8.1
regex==2020.11.13
sacrebleu==1.4.6
six==1.15.0
torch==1.7.0
tqdm==4.54.1
typing-extensions==3.7.4.3
43
joint_paraphrase_model/utils/long_sentence_split.py
Normal file
@@ -0,0 +1,43 @@
import os
import sys

import click


def read(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                b, s, p, a, f = line.split("\t")
            except ValueError:
                print(f"skipping line {line}", file=sys.stderr)
                continue
            else:
                yield b, s, p, a, f


@click.command()
@click.argument("path")
def main(path):
    data = list(read(path))
    avg_len = sum(len(x[1]) for x in data) / len(data)

    fname, ext = os.path.splitext(path)
    # os.path.splitext already keeps the leading dot in ext
    with open(f"{fname}_long{ext}", "w") as lh, open(f"{fname}_short{ext}", "w") as sh:
        for x in data:
            if len(x[1]) > avg_len:
                lh.write("\t".join(x))
                lh.write("\n")
            else:
                sh.write("\t".join(x))
                sh.write("\n")

    print(f"avg sentence length {avg_len}")


if __name__ == "__main__":
    main()
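
# Usage sketch (illustrative file name): splits a sentence-statistics TSV,
# e.g. one produced by the val command with -s, into <name>_long<ext> and
# <name>_short<ext> around the average sentence length:
#
#   python long_sentence_split.py stats.tsv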
40
joint_paraphrase_model/utils/long_sentence_stats.py
Normal file
@@ -0,0 +1,40 @@
import sys

import click


def read(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                b, s, p, *_ = line.split("\t")
            except ValueError:
                print(f"skipping line {line}", file=sys.stderr)
                continue
            else:
                yield float(b), s, p


@click.command()
@click.argument("path")
def main(path):
    data = list(read(path))
    avg_len = sum(len(x[1]) for x in data) / len(data)
    long_sentences = []
    short_sentences = []
    for x in data:
        if len(x[1]) > avg_len:
            long_sentences.append(x)
        else:
            short_sentences.append(x)
    print(f"avg sentence length {avg_len}")
    print(f"long sentences {len(long_sentences)}")
    print(f"short sentences {len(short_sentences)}")
    print(f"total bleu {sum(x[0] for x in data)/len(data)}")
    print(f"long bleu {sum(x[0] for x in long_sentences)/len(long_sentences)}")
    print(f"short bleu {sum(x[0] for x in short_sentences)/len(short_sentences)}")


if __name__ == "__main__":
    main()
64
joint_paraphrase_model/utils/plot_attention.py
Normal file
@@ -0,0 +1,64 @@
import ast
import os
import pathlib

import text_attention

import click
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import matplotlib.ticker as ticker
import numpy as np
import tqdm


def plot_attention(input_sentence, output_words, attentions, path):
    # Set up figure with colorbar
    attentions = np.array(attentions)[:, :len(input_sentence)]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions, cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.savefig(f"{path}.pdf")
    plt.close()


def parse(p):
    with open(p) as h:
        for line in h:
            if not line or line.startswith("#"):
                continue
            _sentence, _prediction, _attention, _fixations = line.strip().split("\t")
            try:
                sentence = ast.literal_eval(_sentence)
                prediction = ast.literal_eval(_prediction)
                attention = ast.literal_eval(_attention)
            except (ValueError, SyntaxError):
                continue

            yield sentence, prediction, attention


@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
    for p in tqdm.tqdm(path):
        out_dir = os.path.splitext(p)[0]
        if out_dir == p:
            out_dir = f"{out_dir}_"
        pathlib.Path(out_dir).mkdir(exist_ok=True)
        for i, spa in enumerate(parse(p)):
            plot_attention(*spa, path=os.path.join(out_dir, str(i)))


if __name__ == "__main__":
    main()
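
# Usage sketch (illustrative file name): renders one attention heatmap PDF
# per parseable line of a debug_*.out file written during training:
#
#   python plot_attention.py models/run1/debug_1_100.out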
70
joint_paraphrase_model/utils/text_attention.py
Executable file
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12


## convert the text/attention list to latex code, which further generates the text heatmap based on attention weights.
import numpy as np

latex_special_token = ["!@#$%^&*()"]


def generate(text_list, attention_list, latex_file, color='red', rescale_value=False):
    assert len(text_list) == len(attention_list)
    if rescale_value:
        attention_list = rescale(attention_list)
    word_num = len(text_list)
    text_list = clean_word(text_list)
    with open(latex_file, 'w') as f:
        f.write(r'''\documentclass[varwidth]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}''' + '\n')
        string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{''' + "\n"
        for idx in range(word_num):
            string += "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + text_list[idx] + "} "
        string += "\n}}}"
        f.write(string + '\n')
        f.write(r'''\end{CJK*}
\end{document}''')


def rescale(input_list):
    the_array = np.asarray(input_list)
    the_max = np.max(the_array)
    the_min = np.min(the_array)
    rescale = (the_array - the_min) / (the_max - the_min) * 100
    return rescale.tolist()


def clean_word(word_list):
    new_word_list = []
    for word in word_list:
        for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
            if latex_sensitive in word:
                word = word.replace(latex_sensitive, '\\' + latex_sensitive)
        new_word_list.append(word)
    return new_word_list


if __name__ == '__main__':
    ## This is a demo:

    sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".
North Korea and the US have ratcheted up tensions in recent weeks and the movement of the strike group had raised the question of a pre-emptive strike by the US.
On Wednesday, Mr Pence described the country as the "most dangerous and urgent threat to peace and security" in the Asia-Pacific.'''
    # the Chinese sample below deliberately replaces the English one to exercise the CJK setup
    sent = '''我 回忆 起 我 曾经 在 大学 年代 , 我们 经常 喜欢 玩 “ Hawaii guitar ” 。 说起 Guitar , 我 想起 了 西游记 里 的 琵琶精 。
今年 下半年 , 中 美 合拍 的 西游记 即将 正式 开机 , 我 继续 扮演 美猴王 孙悟空 , 我 会 用 美猴王 艺术 形象 努力 创造 一 个 正能量 的 形象 , 文 体 两 开花 , 弘扬 中华 文化 , 希望 大家 能 多多 关注 。'''
    words = sent.split()
    word_num = len(words)
    attention = [(x + 1.) / word_num * 100 for x in range(word_num)]
    import random
    random.seed(42)
    random.shuffle(attention)
    color = 'red'
    generate(words, attention, "sample.tex", color)
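
# The demo above writes sample.tex; compiling it, e.g. with pdflatex sample.tex,
# renders the per-word attention heatmap. The preamble assumes the color,
# tcolorbox, CJK, and adjustbox LaTeX packages are installed.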
428
joint_sentence_compression_model/.gitignore
vendored
Normal file
3
joint_sentence_compression_model/README.md
Normal file
@@ -0,0 +1,3 @@
# joint_sentence_compression

joint training for sentence compression -- NeurIPS submission
39
joint_sentence_compression_model/config.py
Normal file
@@ -0,0 +1,39 @@
import os
import torch


# general
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD = "<__PAD__>"
UNK = "<__UNK__>"
NOFIX = "<__NOFIX__>"
SOS = "<__SOS__>"
EOS = "<__EOS__>"

batch_size = 1
teacher_forcing_ratio = 0.5
embedding_dim = 300
fix_hidden_dim = 128
sem_hidden_dim = 1024
fix_dropout = 0.5
par_dropout = 0.2
_fix_learning_rate = 0.00001
_par_learning_rate = 0.0001
learning_rate = _par_learning_rate
fix_momentum = 0.9
par_momentum = 0.0
max_length = 851
epochs = 5

# paths
data_path = "./data"

emb_path = os.path.join(data_path, "Google_word2vec/GoogleNews-vectors-negative300.bin")

glove_path = "glove.840B.300d.txt"

google_path = os.path.join(data_path, "datasets/sentence-compression/data")
google_train_path = os.path.join(google_path, "train_mask_token.tsv")
google_dev_path = os.path.join(google_path, "dev_mask_token.tsv")
google_test_path = os.path.join(google_path, "test_mask_token.tsv")
1
joint_sentence_compression_model/data
Symbolic link
@@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/
1
joint_sentence_compression_model/glove.840B.300d.txt
Symbolic link
@@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/datasets/glove/glove.840B.300d.txt
0
joint_sentence_compression_model/libs/__init__.py
Normal file
97
joint_sentence_compression_model/libs/corpora.py
Normal file
@@ -0,0 +1,97 @@
import logging

import config


def tokenize(sent):
    return sent.split(" ")


class Lang:
    """Represents the vocabulary."""

    def __init__(self, name):
        self.name = name
        self.word2index = {
            config.PAD: 0,
            config.UNK: 1,
        }
        self.word2count = {}
        self.index2word = {
            0: config.PAD,
            1: config.UNK,
        }
        self.n_words = 2

    def add_sentence(self, sentence):
        assert isinstance(
            sentence, (list, tuple)
        ), "input to add_sentence must be tokenized"
        for word in sentence:
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def __add__(self, other):
        """Returns a new Lang object containing the vocabulary from this
        and the other Lang object.
        """
        new_lang = Lang(f"{self.name}_{other.name}")

        # Add vocabulary from both Langs
        for word in self.word2count.keys():
            new_lang.add_word(word)
        for word in other.word2count.keys():
            new_lang.add_word(word)

        # Fix the counts on the new one
        for word in new_lang.word2count.keys():
            new_lang.word2count[word] = self.word2count.get(
                word, 0
            ) + other.word2count.get(word, 0)

        return new_lang


def load_google(split, max_len=None):
    """Load the Google Sentence Compression Dataset."""
    logger = logging.getLogger(f"{__name__}.load_compression")
    lang = Lang("compression")

    if split == "train":
        path = config.google_train_path
    elif split == "val":
        path = config.google_dev_path
    elif split == "test":
        path = config.google_test_path
    else:
        raise ValueError(f"unknown split {split}")

    logger.info("loading %s from %s" % (split, path))

    data = []
    sent = []
    mask = []
    with open(path) as handle:
        for line in handle:
            line = line.strip()
            if line:
                w, d = line.split("\t")
                sent.append(w)
                mask.append(int(d))
            else:
                if sent and (max_len is None or len(sent) <= max_len):
                    data.append([sent, mask])
                    lang.add_sentence(sent)
                sent = []
                mask = []
    if sent:
        data.append([tuple(sent), tuple(mask)])
        lang.add_sentence(sent)

    return data, lang
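
# Usage sketch (illustrative, not part of the original file):
#
#   data, lang = load_google("train", max_len=100)
#   sent, mask = data[0]     # tokens and their binary keep/delete labels
#   print(lang.n_words)      # vocabulary size, including PAD and UNK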
@@ -0,0 +1 @@
from .main import *
@@ -0,0 +1,125 @@
from collections import OrderedDict
import logging
import sys

from .self_attention import Transformer

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence, pad_sequence


def random_embedding(vocab_size, embedding_dim):
    pretrain_emb = np.empty([vocab_size, embedding_dim])
    scale = np.sqrt(3.0 / embedding_dim)
    for index in range(vocab_size):
        pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
    return pretrain_emb


def neg_log_likelihood_loss(outputs, batch_label, batch_size, seq_len):
    outputs = outputs.view(batch_size * seq_len, -1)
    score = F.log_softmax(outputs, 1)

    loss = nn.NLLLoss(ignore_index=0, size_average=False)(
        score, batch_label.view(batch_size * seq_len)
    )
    loss = loss / batch_size
    _, tag_seq = torch.max(score, 1)
    tag_seq = tag_seq.view(batch_size, seq_len)

    return loss, tag_seq


def mse_loss(outputs, batch_label, batch_size, seq_len, word_seq_length):
    score = torch.sigmoid(outputs)

    # zero out scores at padded positions beyond each sequence's true length
    mask = torch.zeros_like(score)
    for i, v in enumerate(word_seq_length):
        mask[i, 0:v] = 1

    score = score * mask

    loss = nn.MSELoss(reduction="sum")(
        score.view(batch_size, seq_len), batch_label.view(batch_size, seq_len)
    )

    loss = loss / batch_size

    return loss, score.view(batch_size, seq_len)
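
# Quick shape check for mse_loss (illustrative example, not part of the
# original file; the tensor sizes are made up):
#
#   outputs = torch.randn(2, 5)            # [batch_size, seq_len] logits
#   labels = torch.rand(2, 5)              # per-token targets in [0, 1]
#   loss, scores = mse_loss(outputs, labels, batch_size=2, seq_len=5,
#                           word_seq_length=[5, 3])
#   # padded positions are zeroed by the mask, so scores[1, 3:] is all zeros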


class Network(nn.Module):
    def __init__(
        self,
        embedding_type,
        vocab_size,
        embedding_dim,
        dropout,
        hidden_dim,
        embeddings=None,
        attention=True,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.attention = attention
        prelayers = OrderedDict()
        postlayers = OrderedDict()

        if embedding_type in ("w2v", "glove"):
            if embeddings is not None:
                prelayers["embedding_layer"] = nn.Embedding.from_pretrained(embeddings, freeze=True)
            else:
                prelayers["embedding_layer"] = nn.Embedding(vocab_size, embedding_dim)
            prelayers["embedding_dropout_layer"] = nn.Dropout(dropout)
            embedding_dim = 300
        elif embedding_type == "bert":
            embedding_dim = 768

        self.lstm = BiLSTM(embedding_dim, hidden_dim // 2, num_layers=1)
        postlayers["lstm_dropout_layer"] = nn.Dropout(dropout)

        if self.attention:
            postlayers["attention_layer"] = Transformer(
                d_model=hidden_dim, n_heads=4, n_layers=1
            )

        postlayers["ff_layer"] = nn.Linear(hidden_dim, hidden_dim // 2)
        postlayers["ff_activation"] = nn.ReLU()
        postlayers["output_layer"] = nn.Linear(hidden_dim // 2, 1)

        self.logger.info(f"prelayers: {prelayers.keys()}")
        self.logger.info(f"postlayers: {postlayers.keys()}")

        self.pre = nn.Sequential(prelayers)
        self.post = nn.Sequential(postlayers)

    def forward(self, x, word_seq_length):
        x = self.pre(x)
        x = self.lstm(x, word_seq_length)

        # run the post layers on each sequence individually, cut to its true length
        output = []
        for _x, l in zip(x.transpose(1, 0), word_seq_length):
            output.append(self.post(_x[:l].unsqueeze(0))[0])

        return pad_sequence(output, batch_first=True)


class BiLSTM(nn.Module):
    def __init__(self, embedding_dim, lstm_hidden, num_layers):
        super().__init__()
        self.net = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_hidden,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x, word_seq_length):
        packed_words = pack_padded_sequence(x, word_seq_length, True, False)
        lstm_out, hidden = self.net(packed_words)
        lstm_out, _ = pad_packed_sequence(lstm_out)
        return lstm_out
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()


class AttentionLayer(nn.Module):
    def __init__(self):
        super(AttentionLayer, self).__init__()

    def forward(self, Q, K, V):
        # Q: float32[batch_size, n_queries, d_k]
        # K: float32[batch_size, n_keys, d_k]
        # V: float32[batch_size, n_keys, d_v]
        dk = K.shape[-1]
        dv = V.shape[-1]
        KT = torch.transpose(K, -1, -2)
        weight_logits = torch.bmm(Q, KT) / math.sqrt(dk)
        # weight_logits: float32[batch_size, n_queries, n_keys]
        weights = F.softmax(weight_logits, dim=-1)
        # weights: float32[batch_size, n_queries, n_keys]
        return torch.bmm(weights, V)
        # returns float32[batch_size, n_queries, dv]


class MultiHeadedSelfAttentionLayer(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadedSelfAttentionLayer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        print('{} {}'.format(d_model, n_heads))
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.d_v = self.d_k
        self.attention_layer = AttentionLayer()
        self.W_Qs = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Ks = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Vs = nn.ModuleList([
            nn.Linear(d_model, self.d_v, bias=False)
            for _ in range(n_heads)
        ])
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        head_outputs = []
        for W_Q, W_K, W_V in zip(self.W_Qs, self.W_Ks, self.W_Vs):
            Q = W_Q(x)
            # Q: float32[batch_size, sequence_length, self.d_k]
            K = W_K(x)
            # K: float32[batch_size, sequence_length, self.d_k]
            V = W_V(x)
            # V: float32[batch_size, sequence_length, self.d_v]
            head_output = self.attention_layer(Q, K, V)
            # head_output: float32[batch_size, sequence_length, self.d_v]
            head_outputs.append(head_output)
        concatenated = torch.cat(head_outputs, dim=-1)
        # concatenated: float32[batch_size, sequence_length, self.d_model]
        out = self.W_O(concatenated)
        # out: float32[batch_size, sequence_length, self.d_model]
        return out


class Feedforward(nn.Module):
    def __init__(self, d_model):
        super(Feedforward, self).__init__()
        self.d_model = d_model
        self.W1 = nn.Linear(d_model, d_model)
        self.W2 = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, d_model]
        return self.W2(torch.relu(self.W1(x)))


class Transformer(nn.Module):
    def __init__(self, d_model, n_heads, n_layers):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.attention_layers = nn.ModuleList([
            MultiHeadedSelfAttentionLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        self.ffs = nn.ModuleList([
            Feedforward(d_model)
            for _ in range(n_layers)
        ])

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        for attention_layer, ff in zip(self.attention_layers, self.ffs):
            attention_out = attention_layer(x)
            # attention_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + attention_out, x.shape[2:])
            ff_out = ff(x)
            # ff_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + ff_out, x.shape[2:])
        return x
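
# Shape smoke test (illustrative example, not part of the original file):
#
#   net = Transformer(d_model=128, n_heads=4, n_layers=1)
#   x = torch.randn(2, 10, 128)              # [batch, seq, d_model]
#   assert net(x).shape == (2, 10, 128)      # self-attention keeps the shape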
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2018, Tatsuya Aoki
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
@@ -0,0 +1,31 @@
# Simple Model for Sentence Compression

3-layered BiLSTM model for sentence compression, referred to as "Baseline" in [Klerke et al., NAACL 2016](http://aclweb.org/anthology/N/N16/N16-1179.pdf).

## Requirements

### Framework

- python (<= 3.6)
- pytorch (<= 0.3.0)

### Packages

- torchtext

## How to run

```
./getdata
python main.py
```

To run the scripts on a GPU, use `python main.py --gpu-id ID`, where ID is an integer from 0 up to one less than the number of GPUs you have.
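
For example, assuming a machine with two GPUs (the indices below are illustrative):

```
python main.py --gpu-id 0
python main.py --gpu-id 1
```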

## Reference

```
@InProceedings{klerke-goldberg-sogaard:2016:N16-1,
  author    = {Klerke, Sigrid and Goldberg, Yoav and S{\o}gaard, Anders},
  title     = {Improving sentence compression by learning to predict gaze},
  booktitle = {Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
  month     = {June},
  year      = {2016},
  address   = {San Diego, California},
  publisher = {Association for Computational Linguistics},
  pages     = {1528--1533},
  url       = {http://www.aclweb.org/anthology/N16-1179}
}
```

@@ -0,0 +1 @@
from .main import *
@@ -0,0 +1,95 @@
from torchtext import data

from const import Phase


def create_dataset(data: dict, batch_size: int, device: int):

    train = Dataset(data[Phase.TRAIN]['tokens'],
                    data[Phase.TRAIN]['labels'],
                    vocab=None,
                    batch_size=batch_size,
                    device=device,
                    phase=Phase.TRAIN)

    dev = Dataset(data[Phase.DEV]['tokens'],
                  data[Phase.DEV]['labels'],
                  vocab=train.vocab,
                  batch_size=batch_size,
                  device=device,
                  phase=Phase.DEV)

    test = Dataset(data[Phase.TEST]['tokens'],
                   data[Phase.TEST]['labels'],
                   vocab=train.vocab,
                   batch_size=batch_size,
                   device=device,
                   phase=Phase.TEST)
    return train, dev, test


class Dataset:

    def __init__(self,
                 tokens: list,
                 label_list: list,
                 vocab: list,
                 batch_size: int,
                 device: int,
                 phase: Phase):
        assert len(tokens) == len(label_list), \
            'the number of sentences and the number of label sequences ' \
            'should be the same'

        self.pad_token = '<PAD>'
        # self.unk_token = '<UNK>'
        self.tokens = tokens
        self.label_list = label_list
        self.sentence_id = [[i] for i in range(len(tokens))]
        self.device = device

        self.token_field = data.Field(use_vocab=True,
                                      # unk_token=self.unk_token,
                                      pad_token=self.pad_token,
                                      batch_first=True)
        self.label_field = data.Field(use_vocab=False, pad_token=-1, batch_first=True)
        self.sentence_id_field = data.Field(use_vocab=False, batch_first=True)
        self.dataset = self._create_dataset()

        if vocab is None:
            self.token_field.build_vocab(self.tokens)
            self.vocab = self.token_field.vocab
        else:
            self.token_field.vocab = vocab
            self.vocab = vocab
        self.pad_index = self.token_field.vocab.stoi[self.pad_token]

        self._set_batch_iter(batch_size, phase)

    def get_raw_sentence(self, sentences):
        return [[self.vocab.itos[idx] for idx in sentence]
                for sentence in sentences]

    def _create_dataset(self):
        _fields = [('token', self.token_field),
                   ('label', self.label_field),
                   ('sentence_id', self.sentence_id_field)]
        return data.Dataset(self._get_examples(_fields), _fields)

    def _get_examples(self, fields: list):
        ex = []
        for sentence, label, sentence_id in zip(self.tokens, self.label_list, self.sentence_id):
            ex.append(data.Example.fromlist([sentence, label, sentence_id], fields))
        return ex

    def _set_batch_iter(self, batch_size: int, phase: Phase):

        def sort(example: data.Dataset) -> int:
            # sort key: sentence length, so buckets hold similar lengths
            return len(getattr(example, 'token'))

        train = phase == Phase.TRAIN

        self.batch_iter = data.BucketIterator(dataset=self.dataset,
                                              batch_size=batch_size,
                                              sort_key=sort,
                                              train=train,
                                              repeat=False,
                                              device=self.device)
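
# Hedged usage sketch (editor's addition, not part of the original file):
# builds the three iterators from a toy split dict shaped like the one
# produced by load() in main.py, then walks one epoch of training batches.
# device=-1 selects the CPU in this torchtext generation.
if __name__ == "__main__":
    toy = {phase: {'tokens': [['a', 'b'], ['c', 'd']], 'labels': [[0, 1], [1, 0]]}
           for phase in (Phase.TRAIN, Phase.DEV, Phase.TEST)}
    train, dev, test = create_dataset(toy, batch_size=2, device=-1)
    for batch in train.batch_iter:
        print(batch.token)   # int64[batch, max_len], PAD-filled
        print(batch.label)   # int64[batch, max_len], PAD positions are -1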
@@ -0,0 +1,8 @@
from enum import Enum, unique


@unique
class Phase(Enum):
    TRAIN = 'train'
    DEV = 'dev'
    TEST = 'test'
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
from torch.autograd import Variable


class Network(nn.Module):

    def __init__(self,
                 embeddings,
                 hidden_size: int,
                 prior,
                 device: torch.device):

        super(Network, self).__init__()
        self.device = device
        # log class priors, added to the logits at evaluation time
        self.priors = torch.log(torch.tensor([prior, 1 - prior])).to(device)
        self.hidden_size = hidden_size
        self.bilstm_layers = 3
        self.bilstm_input_size = 300
        self.bilstm_output_size = 2 * hidden_size
        self.word_emb = nn.Embedding.from_pretrained(embeddings, freeze=False)
        self.bilstm = nn.LSTM(self.bilstm_input_size,
                              self.hidden_size,
                              num_layers=self.bilstm_layers,
                              batch_first=True,
                              dropout=0.1,
                              bidirectional=True)
        self.dropout = nn.Dropout(p=0.35)
        # note: self.attention is the bound method defined below, so this test
        # is always true and the attention parameters are always created
        if self.attention:
            self.attention_size = self.bilstm_output_size * 2
            self.u_a = nn.Linear(self.bilstm_output_size, self.bilstm_output_size)
            self.w_a = nn.Linear(self.bilstm_output_size, self.bilstm_output_size)
            self.v_a_inv = nn.Linear(self.bilstm_output_size, 1, bias=False)
            self.linear_attn = nn.Linear(self.attention_size, self.bilstm_output_size)
        self.linear = nn.Linear(self.bilstm_output_size, self.hidden_size)
        self.pred = nn.Linear(self.hidden_size, 2)
        self.softmax = nn.LogSoftmax(dim=1)
        self.criterion = nn.NLLLoss(ignore_index=-1)

    def forward(self, input_tokens, labels, fixations=None):
        loss = 0.0
        preds = []
        atts = []
        batch_size, seq_len = input_tokens.size()
        self.init_hidden(batch_size, device=self.device)

        x_i = self.word_emb(input_tokens)
        x_i = self.dropout(x_i)

        hidden, (self.h_n, self.c_n) = self.bilstm(x_i, (self.h_n, self.c_n))
        _, _, hidden_size = hidden.size()

        for i in range(seq_len):
            nth_hidden = hidden[:, i, :]
            if self.attention:
                target = nth_hidden.expand(seq_len, batch_size, -1).transpose(0, 1)
                # mask out the position attending to itself
                mask = hidden.eq(target)[:, :, 0].unsqueeze(2)
                attn_weight = self.attention(hidden, target, fixations, mask)
                context_vector = torch.bmm(attn_weight.transpose(1, 2), hidden).squeeze(1)

                nth_hidden = torch.tanh(self.linear_attn(torch.cat((nth_hidden, context_vector), -1)))
                atts.append(attn_weight.detach().cpu())
            logits = self.pred(self.linear(nth_hidden))
            if not self.training:
                logits = logits + self.priors
            output = self.softmax(logits)
            loss += self.criterion(output, labels[:, i])

            _, topi = output.topk(k=1, dim=1)
            pred = topi.squeeze(-1)
            preds.append(pred)

        preds = torch.stack(torch.cat(preds, dim=0).split(batch_size), dim=1)

        if atts:
            atts = torch.stack(torch.cat(atts, dim=0).split(batch_size), dim=1)

        return loss, preds, atts

    def attention(self, source, target, fixations=None, mask=None):
        # additive (Bahdanau-style) scoring
        function_g = \
            self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
        if mask is not None:
            function_g.masked_fill_(mask, -1e4)
        if fixations is not None:
            # scale the pre-softmax scores by the predicted fixations
            function_g = function_g * fixations
        return nn.functional.softmax(function_g, dim=1)

    def init_hidden(self, batch_size, device):
        zeros = Variable(torch.zeros(2 * self.bilstm_layers, batch_size, self.hidden_size))
        self.h_n = zeros.to(device)
        self.c_n = zeros.to(device)
        return self.h_n, self.c_n
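
# Hedged illustration (editor's addition, not part of the original file) of
# the eval-time adjustment above: adding log class priors to the logits before
# the log-softmax rescales the class posteriors by those priors. Values are
# illustrative.
if __name__ == "__main__":
    logits = torch.tensor([[0.2, -0.1]])
    priors = torch.log(torch.tensor([0.7, 0.3]))
    print(torch.log_softmax(logits + priors, dim=1).exp())  # rows sum to 1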
@@ -0,0 +1,183 @@
import torch
from torch import optim
import tqdm

from const import Phase
from batch import create_dataset
from models import Baseline
from sklearn.metrics import classification_report


def run(dataset_train,
        dataset_dev,
        dataset_test,
        model_type,
        word_embed_size,
        hidden_size,
        batch_size,
        device,
        n_epochs):

    if model_type == 'base':
        model = Baseline(vocab=dataset_train.vocab,
                         word_embed_size=word_embed_size,
                         hidden_size=hidden_size,
                         device=device,
                         inference=False)
    else:
        raise NotImplementedError
    model = model.to(device)

    optim_params = model.parameters()
    optimizer = optim.Adam(optim_params, lr=10**-3)

    print('start training')
    for epoch in range(n_epochs):
        train_loss, tokens, preds, golds = train(dataset_train,
                                                 model,
                                                 optimizer,
                                                 batch_size,
                                                 epoch,
                                                 Phase.TRAIN,
                                                 device)

        dev_loss, tokens, preds, golds = train(dataset_dev,
                                               model,
                                               optimizer,
                                               batch_size,
                                               epoch,
                                               Phase.DEV,
                                               device)
        logger = '\t'.join(['epoch {}'.format(epoch + 1),
                            'TRAIN Loss: {:.9f}'.format(train_loss),
                            'DEV Loss: {:.9f}'.format(dev_loss)])
        print(logger)
    test_loss, tokens, preds, golds = train(dataset_test,
                                            model,
                                            optimizer,
                                            batch_size,
                                            epoch,
                                            Phase.TEST,
                                            device)
    print('====', 'TEST', '=====')
    print_scores(preds, golds)
    output_results(tokens, preds, golds)


def train(dataset,
          model,
          optimizer,
          batch_size,
          n_epoch,
          phase,
          device):

    total_loss = 0.0
    tokens = []
    preds = []
    labels = []
    if phase == Phase.TRAIN:
        model.train()
    else:
        model.eval()

    for batch in tqdm.tqdm(dataset.batch_iter):
        token = getattr(batch, 'token')
        label = getattr(batch, 'label')
        raw_sentences = dataset.get_raw_sentence(token.data.detach().cpu().numpy())

        loss, pred = \
            model(token, raw_sentences, label, phase)

        if phase == Phase.TRAIN:
            optimizer.zero_grad()
            loss.backward()
            # clip after backward, once the gradients exist
            torch.nn.utils.clip_grad_norm(model.parameters(), max_norm=5)
            optimizer.step()

        # remove PAD from input sentences/labels and results
        mask = (token != dataset.pad_index)
        length_tensor = mask.sum(1)
        length_tensor = length_tensor.data.detach().cpu().numpy()

        for index, n_tokens_in_the_sentence in enumerate(length_tensor):
            if n_tokens_in_the_sentence > 0:
                tokens.append(raw_sentences[index][:n_tokens_in_the_sentence])
                _label = label[index][:n_tokens_in_the_sentence]
                _pred = pred[index][:n_tokens_in_the_sentence]
                _label = _label.data.detach().cpu().numpy()
                _pred = _pred.data.detach().cpu().numpy()
                labels.append(_label)
                preds.append(_pred)

        total_loss += torch.mean(loss).item()

    return total_loss, tokens, preds, labels


def read_two_cols_data(fname, max_len=None):
    data = {}
    tokens = []
    labels = []
    token = []
    label = []
    with open(fname, mode='r') as f:
        for line in f:
            line = line.strip().lower().split()
            if line:
                # each non-empty line holds one token and its label
                _token, _label = line
                token.append(_token)
                if _label == '0' or _label == '1':
                    label.append(int(_label))
                else:
                    # string labels: 'del' marks a deleted token
                    label.append(1 if _label == 'del' else 0)
            else:
                if max_len is None or len(token) <= max_len:
                    tokens.append(token)
                    labels.append(label)
                token = []
                label = []

    data['tokens'] = tokens
    data['labels'] = labels
    return data


def load(train_path, dev_path, test_path, batch_size, max_len, device):
    train = read_two_cols_data(train_path, max_len)
    dev = read_two_cols_data(dev_path)
    test = read_two_cols_data(test_path)
    data = {Phase.TRAIN: train, Phase.DEV: dev, Phase.TEST: test}
    return create_dataset(data, batch_size=batch_size, device=device)


def print_scores(preds, golds):
    _preds = [label for sublist in preds for label in sublist]
    _golds = [label for sublist in golds for label in sublist]
    target_names = ['not_del', 'del']
    print(classification_report(_golds, _preds, target_names=target_names, digits=5))


def output_results(tokens, preds, golds, path='./result/sentcomp'):
    with open(path + '.original.txt', mode='w') as w, \
            open(path + '.gold.txt', mode='w') as w_gold, \
            open(path + '.pred.txt', mode='w') as w_pred:

        for _tokens, _golds, _preds in zip(tokens, golds, preds):
            for token, gold, pred in zip(_tokens, _golds, _preds):
                w.write(token + ' ')
                # 0 -> keep, 1 -> delete
                if gold == 0:
                    w_gold.write(token + ' ')
                if pred == 0:
                    w_pred.write(token + ' ')
            w.write('\n')
            w_gold.write('\n')
            w_pred.write('\n')
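
# Hedged illustration (editor's addition, not part of the original file) of
# the masking used in train() above: comparing against the PAD index recovers
# the true sentence lengths of a padded batch. The pad index is illustrative.
if __name__ == "__main__":
    pad_index = 1
    token = torch.tensor([[5, 9, 2, 1, 1],
                          [7, 1, 1, 1, 1]])
    print((token != pad_index).sum(1))  # tensor([3, 1])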
218
joint_sentence_compression_model/libs/utils.py
Normal file

@@ -0,0 +1,218 @@
import json
import logging
import math
import os
import random
import re
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from nltk.translate.bleu_score import sentence_bleu
import numpy as np
import torch
import torch.nn as nn

import config


plt.switch_backend("agg")


def load_glove(vocabulary):
    logger = logging.getLogger(f"{__name__}.load_glove")
    logger.info("loading embeddings")
    try:
        with open("glove.cache") as h:
            cache = json.load(h)
    except (FileNotFoundError, json.JSONDecodeError):
        logger.info("cache doesn't exist")
        cache = {}
        cache[config.PAD] = [0] * 300
        cache[config.SOS] = [0] * 300
        cache[config.EOS] = [0] * 300
        cache[config.UNK] = [0] * 300
        cache[config.NOFIX] = [0] * 300
    else:
        logger.info("cache found")

    cache_miss = False

    if not set(vocabulary) <= set(cache):
        cache_miss = True
        logger.warning("cache miss, loading full embeddings")
        data = {}
        with open("glove.840B.300d.txt") as h:
            for line in h:
                word, *emb = line.strip().split()
                try:
                    data[word] = [float(x) for x in emb]
                except ValueError:
                    # skip malformed lines (e.g. multi-token keys)
                    continue
        logger.info("finished loading full embeddings")
        for word in vocabulary:
            try:
                cache[word] = data[word]
            except KeyError:
                cache[word] = [0] * 300
        logger.info("cache updated")

    embeddings = []
    for word in vocabulary:
        embeddings.append(torch.tensor(cache[word], dtype=torch.float32))
    embeddings = torch.stack(embeddings)

    if cache_miss:
        with open("glove.cache", "w") as h:
            json.dump(cache, h)
        logger.info("cache saved")

    return embeddings


def tokenize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = s.split(" ")
    return s


def indices_from_sentence(word2index, sentence, unknown_threshold):
    if unknown_threshold:
        # randomly replace a fraction of words by UNK for regularisation
        return [
            word2index.get(
                word if random.random() > unknown_threshold else config.UNK,
                word2index[config.UNK],
            )
            for word in sentence
        ]
    else:
        return [
            word2index.get(word, word2index[config.UNK]) for word in sentence
        ]


def tensor_from_sentence(word2index, sentence, unknown_threshold):
    indices = indices_from_sentence(word2index, sentence, unknown_threshold)
    return torch.tensor(indices, dtype=torch.long, device=config.DEV)


def tensors_from_pair(word2index, pair, shuffle, unknown_threshold):
    tensors = [
        tensor_from_sentence(word2index, pair[0], unknown_threshold),
        tensor_from_sentence(word2index, pair[1], unknown_threshold),
    ]
    if shuffle:
        random.shuffle(tensors)
    return tensors


def bleu(reference, hypothesis, n=4):
    if n < 1:
        return 0
    weights = [1 / n] * n
    return sentence_bleu([reference], hypothesis, weights)


def pair_iter(pairs, word2index, shuffle=False, shuffle_pairs=False, unknown_threshold=0.00):
    if shuffle:
        pairs = pairs.copy()
        random.shuffle(pairs)
    for pair in pairs:
        tensor1, tensor2 = tensors_from_pair(word2index, (pair[0], pair[1]), shuffle_pairs, unknown_threshold)
        yield (tensor1,), (tensor2,)


def sent_iter(sents, word2index, batch_size, unknown_threshold=0.00):
    for i in range(len(sents) // batch_size + 1):
        raw_sents = [x[0] for x in sents[i * batch_size:(i + 1) * batch_size]]
        _sents = [tensor_from_sentence(word2index, sent, unknown_threshold)
                  for sent, target in sents[i * batch_size:(i + 1) * batch_size]]
        _targets = [torch.tensor(target, dtype=torch.long).to(config.DEV)
                    for sent, target in sents[i * batch_size:(i + 1) * batch_size]]
        if raw_sents and _sents and _targets:
            yield raw_sents, _sents, _targets


def batch_iter(pairs, word2index, batch_size, shuffle=False, unknown_threshold=0.00):
    for i in range(len(pairs) // batch_size):
        # step in units of batch_size (the original stepped by a single pair,
        # which yielded overlapping batches)
        batch = pairs[i * batch_size:(i + 1) * batch_size]
        if len(batch) != batch_size:
            continue
        batch_tensors = [
            tensors_from_pair(word2index, (pair[0], pair[1]), shuffle, unknown_threshold)
            for pair in batch
        ]

        tensors1, tensors2 = zip(*batch_tensors)

        yield tensors1, tensors2


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return "%s (- %s)" % (asMinutes(s), asMinutes(rs))


def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)


def showAttention(input_sentence, output_words, attentions):
    # Set up figure with colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence.split(" ") + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    # note: expects `evaluate`, `encoder1` and `attn_decoder1` to be supplied
    # by the calling module; they are not defined in this file
    output_words, attentions = evaluate(encoder1, attn_decoder1, input_sentence)
    print("input =", input_sentence)
    print("output =", " ".join(output_words))
    showAttention(input_sentence, output_words, attentions)


def save_model(model, word2index, path):
    if not path.endswith(".tar"):
        path += ".tar"
    torch.save(
        {"weights": model.state_dict(), "word2index": word2index},
        path,
    )


def load_model(path):
    checkpoint = torch.load(path)
    return checkpoint["weights"], checkpoint["word2index"]


def extend_vocabulary(word2index, langs):
    for lang in langs:
        for word in lang.word2index:
            if word not in word2index:
                word2index[word] = len(word2index)
    return word2index
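
# Hedged usage sketch (editor's addition, not part of the original file):
# with n=2 the weights (0.5, 0.5) average unigram and bigram precision.
# The sentences are illustrative.
if __name__ == "__main__":
    ref = "the cat sat on the mat".split()
    hyp = "the cat sat on a mat".split()
    print(bleu(ref, hyp, n=2))  # a score in [0, 1]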
495
joint_sentence_compression_model/main.py
Normal file

@@ -0,0 +1,495 @@
import logging
import os
import pathlib
import random
import re
import sys

import click
import sacrebleu
import torch
import torch.nn as nn
import tqdm

import config
from libs import corpora
from libs import utils
from libs.fixation_generation import Network as FixNN
from libs.sentence_compression import Network as ComNN
from sklearn.metrics import classification_report, precision_recall_fscore_support


cwd = os.path.dirname(__file__)

logger = logging.getLogger("main")


class Network(nn.Module):
    def __init__(
        self, word2index, embeddings, prior,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.word2index = word2index
        self.index2word = {i: k for k, i in word2index.items()}
        self.fix_gen = FixNN(
            embedding_type="glove",
            vocab_size=len(word2index),
            embedding_dim=config.embedding_dim,
            embeddings=embeddings,
            dropout=config.fix_dropout,
            hidden_dim=config.fix_hidden_dim,
        )
        self.com_nn = ComNN(
            embeddings=embeddings, hidden_size=config.sem_hidden_dim, prior=prior, device=config.DEV
        )

    def forward(self, x, target, seq_lens):
        x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
        target = nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=-1)

        # predicted per-token fixation durations, squashed to (0, 1)
        fixations = torch.sigmoid(self.fix_gen(x1, seq_lens))
        # fixations = None  # (ablation: disable the fixation input)

        loss, pred, atts = self.com_nn(x1, target, fixations)
        return loss, pred, atts, fixations
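
# Hedged illustration (editor's addition, not part of the original file) of
# how the predicted fixations enter the compressor's attention (see
# libs.sentence_compression.Network.attention): the pre-softmax scores are
# scaled per token, so highly fixated tokens receive proportionally more
# attention mass. Sketch with illustrative values:
#
#   scores = torch.ones(1, 4, 1)                      # uniform raw scores
#   fixations = torch.tensor([[[0.9], [0.1], [0.5], [0.5]]])
#   nn.functional.softmax(scores * fixations, dim=1)  # mass shifts to token 0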

def load_corpus(corpus_name, splits):
    if not splits:
        return

    logger.info("loading corpus")
    if corpus_name == "google":
        load_fn = corpora.load_google

    corpus = {}
    langs = []

    if "train" in splits:
        train_pairs, train_lang = load_fn("train", max_len=200)
        corpus["train"] = train_pairs
        langs.append(train_lang)
    if "val" in splits:
        val_pairs, val_lang = load_fn("val")
        corpus["val"] = val_pairs
        langs.append(val_lang)
    if "test" in splits:
        test_pairs, test_lang = load_fn("test")
        corpus["test"] = test_pairs
        langs.append(test_lang)

    logger.info("creating word index")
    lang = langs[0]
    for _lang in langs[1:]:
        lang += _lang
    word2index = lang.word2index

    index2word = {i: w for w, i in word2index.items()}

    return corpus, word2index, index2word


def init_network(word2index, prior):
    logger.info("loading embeddings")
    vocabulary = sorted(word2index.keys())
    embeddings = utils.load_glove(vocabulary)

    logger.info("initializing model")
    network = Network(word2index=word2index, embeddings=embeddings, prior=prior)
    network.to(config.DEV)

    print(f"#parameters: {sum(p.numel() for p in network.parameters())}")

    return network


@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-v", "--verbose", count=True)
@click.option("-d", "--debug", is_flag=True)
def main(verbose, debug):
    if verbose == 0:
        loglevel = logging.ERROR
    elif verbose == 1:
        loglevel = logging.WARN
    elif verbose >= 2:
        loglevel = logging.INFO

    if debug:
        loglevel = logging.DEBUG

    logging.basicConfig(
        format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
        datefmt="%d.%m. %H:%M:%S",
        level=loglevel,
    )

    logger.debug("arguments: %s" % str(sys.argv))


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(["google"]),
)
@click.option("-m", "--model_name", required=True)
@click.option("-w", "--fixation_weights", required=False)
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
@click.option("-d", "--debug", is_flag=True, default=False)
@click.option("-p", "--prior", type=float, default=.5)
def train(corpus_name, model_name, fixation_weights, freeze_fixations, debug, prior):
    corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
    train_pairs = corpus["train"]
    val_pairs = corpus["val"]
    network = init_network(word2index, prior)

    model_dir = os.path.join("models", model_name)
    logger.debug("creating model dir %s" % model_dir)
    pathlib.Path(model_dir).mkdir(parents=True)

    if fixation_weights is not None:
        logger.info("loading fixation prediction checkpoint")
        checkpoint = torch.load(fixation_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
        else:
            weights = checkpoint

        # remove the embedding layer before loading
        weights = {
            k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")
        }
        network.fix_gen.load_state_dict(weights, strict=False)

    if freeze_fixations:
        logger.info("freezing fixation generation network")
        for p in network.fix_gen.parameters():
            p.requires_grad = False

    optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)

    best_val_loss = None

    epoch = 1
    batch_size = 20

    # train until interrupted; every epoch that improves the validation loss
    # is saved as a checkpoint
    while True:
        train_batch_iter = utils.sent_iter(
            sents=train_pairs, word2index=word2index, batch_size=batch_size
        )
        val_batch_iter = utils.sent_iter(
            sents=val_pairs, word2index=word2index, batch_size=batch_size
        )

        total_train_loss = 0
        total_val_loss = 0

        network.train()
        for i, batch in tqdm.tqdm(
            enumerate(train_batch_iter, 1), total=len(train_pairs) // batch_size + 1
        ):
            optimizer.zero_grad()

            raw_sent, sent, target = batch
            seq_lens = [len(x) for x in sent]
            loss, prediction, attention, fixations = network(sent, target, seq_lens)

            prediction = prediction.detach().cpu().numpy()

            loss.backward()
            # clip after backward, once the gradients exist
            torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=5)
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_pairs)

        val_sents = []
        val_preds = []
        val_targets = []

        network.eval()
        for i, batch in tqdm.tqdm(
            enumerate(val_batch_iter), total=len(val_pairs) // batch_size + 1
        ):
            raw_sent, sent, target = batch
            seq_lens = [len(x) for x in sent]
            loss, prediction, attention, fixations = network(sent, target, seq_lens)

            prediction = prediction.detach().cpu().numpy()

            for i, l in enumerate(seq_lens):
                val_sents.append(raw_sent[i][:l])
                val_preds.append(prediction[i][:l].tolist())
                val_targets.append(target[i][:l].tolist())

            total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_pairs)

        print(
            f"epoch {epoch} train_loss {avg_train_loss:.4f} val_loss {avg_val_loss:.4f}"
        )

        print(
            classification_report(
                [x for y in val_targets for x in y],
                [x for y in val_preds for x in y],
                target_names=["not_del", "del"],
                digits=5,
            )
        )

        with open(f"models/{model_name}/val_original_{epoch}.txt", "w") as oh, open(
            f"models/{model_name}/val_pred_{epoch}.txt", "w"
        ) as ph, open(f"models/{model_name}/val_gold_{epoch}.txt", "w") as gh:
            for sent, preds, golds in zip(val_sents, val_preds, val_targets):
                pred_compressed = [
                    word for word, delete in zip(sent, preds) if not delete
                ]
                gold_compressed = [
                    word for word, delete in zip(sent, golds) if not delete
                ]

                oh.write(" ".join(sent))
                ph.write(" ".join(pred_compressed))
                gh.write(" ".join(gold_compressed))

                oh.write("\n")
                ph.write("\n")
                gh.write("\n")

        if best_val_loss is None or avg_val_loss < best_val_loss:
            delta = avg_val_loss - best_val_loss if best_val_loss is not None else 0.0
            best_val_loss = avg_val_loss
            print(
                f"new best model epoch {epoch} val loss {avg_val_loss:.4f} ({delta:.4f})"
            )

            utils.save_model(
                network, word2index, f"models/{model_name}/{model_name}_{epoch}"
            )

        epoch += 1
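
# Hedged sketch (editor's addition, not part of the original file) of the
# checkpoint layout written by utils.save_model above and read back by the
# `test` and `predict` commands below: a dict holding the state_dict and the
# word index. The path is illustrative.
#
#   ckpt = torch.load("models/demo/demo_3.tar", map_location="cpu")
#   weights, word2index = ckpt["weights"], ckpt["word2index"]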

@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(["google"]),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-p", "--prior", type=float, default=.5)
@click.option("-l", "--longest", is_flag=True)
@click.option("-s", "--shortest", is_flag=True)
@click.option("-d", "--detailed", is_flag=True)
def test(corpus_name, model_weights, prior, longest, shortest, detailed):
    if longest and shortest:
        print("longest and shortest are mutually exclusive", file=sys.stderr)
        sys.exit()

    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]

    model_name = os.path.basename(os.path.dirname(model_weights))
    epoch = re.search(r"_(\d+)\.tar", model_weights).group(1)

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint has no word2index; cannot restore the vocabulary")

    network = init_network(word2index, prior)
    network.eval()

    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    total_test_loss = 0

    batch_size = 20

    test_batch_iter = utils.sent_iter(
        sents=test_pairs, word2index=word2index, batch_size=batch_size
    )

    test_sents = []
    test_preds = []
    test_targets = []

    for i, batch in tqdm.tqdm(
        enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
    ):
        raw_sent, sent, target = batch
        seq_lens = [len(x) for x in sent]
        loss, prediction, attention, fixations = network(sent, target, seq_lens)

        prediction = prediction.detach().cpu().numpy()

        for i, l in enumerate(seq_lens):
            test_sents.append(raw_sent[i][:l])
            test_preds.append(prediction[i][:l].tolist())
            test_targets.append(target[i][:l].tolist())

        total_test_loss += loss.item()

    avg_test_loss = total_test_loss / len(test_pairs)

    print(f"test_loss {avg_test_loss:.4f}")

    if longest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = list(filter(lambda x: len(x) > avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) > avg_len, test_preds))
        test_targets = list(filter(lambda x: len(x) > avg_len, test_targets))
    elif shortest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = list(filter(lambda x: len(x) <= avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) <= avg_len, test_preds))
        test_targets = list(filter(lambda x: len(x) <= avg_len, test_targets))

    if detailed:
        for test_sent, test_target, test_pred in zip(test_sents, test_targets, test_preds):
            print(precision_recall_fscore_support(test_target, test_pred, average="weighted")[2], test_sent, test_target, test_pred)
    else:
        print(
            classification_report(
                [x for y in test_targets for x in y],
                [x for y in test_preds for x in y],
                target_names=["not_del", "del"],
                digits=5,
            )
        )

    with open(f"models/{model_name}/test_original_{epoch}.txt", "w") as oh, open(
        f"models/{model_name}/test_pred_{epoch}.txt", "w"
    ) as ph, open(f"models/{model_name}/test_gold_{epoch}.txt", "w") as gh:
        for sent, preds, golds in zip(test_sents, test_preds, test_targets):
            pred_compressed = [word for word, delete in zip(sent, preds) if not delete]
            gold_compressed = [word for word, delete in zip(sent, golds) if not delete]

            oh.write(" ".join(sent))
            ph.write(" ".join(pred_compressed))
            gh.write(" ".join(gold_compressed))

            oh.write("\n")
            ph.write("\n")
            gh.write("\n")


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(["google"]),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-p", "--prior", type=float, default=.5)
@click.option("-l", "--longest", is_flag=True)
@click.option("-s", "--shortest", is_flag=True)
def predict(corpus_name, model_weights, prior, longest, shortest):
    if longest and shortest:
        print("longest and shortest are mutually exclusive", file=sys.stderr)
        sys.exit()

    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]

    model_name = os.path.basename(os.path.dirname(model_weights))
    epoch = re.search(r"_(\d+)\.tar", model_weights).group(1)

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint has no word2index; cannot restore the vocabulary")

    network = init_network(word2index, prior)
    network.eval()

    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    total_test_loss = 0

    batch_size = 20

    test_batch_iter = utils.sent_iter(
        sents=test_pairs, word2index=word2index, batch_size=batch_size
    )

    test_sents = []
    test_preds = []
    test_attentions = []
    test_fixations = []

    for i, batch in tqdm.tqdm(
        enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
    ):
        raw_sent, sent, target = batch
        seq_lens = [len(x) for x in sent]
        loss, prediction, attention, fixations = network(sent, target, seq_lens)

        prediction = prediction.detach().cpu().numpy()
        attention = attention.detach().cpu().numpy()
        if fixations is not None:
            fixations = fixations.detach().cpu().numpy()

        for i, l in enumerate(seq_lens):
            test_sents.append(raw_sent[i][:l])
            test_preds.append(prediction[i][:l].tolist())
            test_attentions.append(attention[i][:l].tolist())
            if fixations is not None:
                test_fixations.append(fixations[i][:l].tolist())
            else:
                test_fixations.append([])

        total_test_loss += loss.item()

    avg_test_loss = total_test_loss / len(test_pairs)

    if longest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = list(filter(lambda x: len(x) > avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) > avg_len, test_preds))
        test_attentions = list(filter(lambda x: len(x) > avg_len, test_attentions))
        test_fixations = list(filter(lambda x: len(x) > avg_len, test_fixations))
    elif shortest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = list(filter(lambda x: len(x) <= avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) <= avg_len, test_preds))
        test_attentions = list(filter(lambda x: len(x) <= avg_len, test_attentions))
        test_fixations = list(filter(lambda x: len(x) <= avg_len, test_fixations))

    print("sentence\tprediction\tattentions\tfixations")
    for s, p, a, f in zip(test_sents, test_preds, test_attentions, test_fixations):
        # truncate each attention row to the kept sequence length
        a = [x[:len(a)] for x in a]
        print(f"{s}\t{p}\t{a}\t{f}")


if __name__ == "__main__":
    main()
23
joint_sentence_compression_model/requirements.txt
Normal file

@@ -0,0 +1,23 @@
click==7.1.2
cycler==0.10.0
dataclasses==0.6
future==0.18.2
joblib==0.17.0
kiwisolver==1.3.1
matplotlib==3.3.3
nltk==3.5
numpy==1.19.4
Pillow==8.0.1
portalocker==2.0.0
pyparsing==2.4.7
python-dateutil==2.8.1
regex==2020.11.13
sacrebleu==1.4.6
scikit-learn==0.23.2
scipy==1.5.4
six==1.15.0
sklearn==0.0
threadpoolctl==2.1.0
torch==1.7.0
tqdm==4.54.1
typing-extensions==3.7.4.3
43
joint_sentence_compression_model/utils/check_stats.py
Normal file

@@ -0,0 +1,43 @@
import ast
from statistics import mean, pstdev
import sys

import click
from scipy.stats import entropy

from matplotlib import pyplot as plt


def reader(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                s, p, a, f = line.split("\t")
            except ValueError:
                print(f"skipping line: {line}", file=sys.stderr)
            else:
                try:
                    yield ast.literal_eval(s), ast.literal_eval(p), ast.literal_eval(a), ast.literal_eval(f)
                except (ValueError, SyntaxError):
                    print(f"malformed line: {s}")


def get_stats(seq):
    for s, p, a, f in seq:
        print(s)
        print(p)
        print(len(s), len(p), len(a), len(f))
        for x in a:
            print(len(x))
        print()


@click.command()
@click.argument("path")
def main(path):
    get_stats(reader(path))


if __name__ == "__main__":
    main()
34
joint_sentence_compression_model/utils/cr.py
Normal file

@@ -0,0 +1,34 @@
import ast
from math import log2
import os
from statistics import mean, pstdev
import sys

import click
import numpy as np
from scipy.special import kl_div

from matplotlib import pyplot as plt


def reader(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            yield line.split()


@click.command()
@click.argument("original")
@click.argument("compressed")
def main(original, compressed):
    # macro-average of per-sentence compression ratios
    ratio = 0
    total = 0
    for o, c in zip(reader(original), reader(compressed)):
        ratio += len(c) / len(o)
        total += 1
    print(f"cr: {ratio/total:.4f}")


if __name__ == "__main__":
    main()
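
# Hedged worked example (editor's addition, not part of the original file):
# main() macro-averages per-sentence ratios, so originals of 8 and 4 tokens
# compressed to 4 and 3 tokens give (4/8 + 3/4) / 2 = 0.625 rather than the
# micro-average 7/12 ~ 0.583.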
88
joint_sentence_compression_model/utils/kl.py
Normal file

@@ -0,0 +1,88 @@
import ast
from math import log2
import os
from statistics import mean, pstdev
import sys

import click
import numpy as np
from scipy.special import kl_div

from matplotlib import pyplot as plt


def attention_reader(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                s, p, a, f = line.split("\t")
            except ValueError:
                print(f"skipping line: {line}", file=sys.stderr)
            else:
                try:
                    yield [[x[0] for x in y] for y in ast.literal_eval(a)]
                except (ValueError, SyntaxError):
                    print(f"skipping malformed line: {s}", file=sys.stderr)


def _kl_divergence(p, q):
    # cast to float so the in-place normalisation below is well defined
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p /= sum(p)
    q /= sum(q)
    return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))


def kl_divergence(ps, qs):
    kl = 0
    count = 0
    for p, q in zip(ps, qs):
        kl += _kl_divergence(p, q)
        count += 1
    return kl / count


def _js_divergence(p, q):
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p /= sum(p)
    q /= sum(q)
    m = 0.5 * (p + q)
    return 0.5 * _kl_divergence(p, m) + 0.5 * _kl_divergence(q, m)


def js_divergence(ps, qs):
    js = 0
    count = 0
    for p, q in zip(ps, qs):
        js += _js_divergence(p, q)
        count += 1
    return js / count


def get_kl_div(seq1, seq2):
    # Jensen-Shannon divergence: the symmetric variant is used here
    return [js_divergence(x1, x2) for x1, x2 in zip(seq1, seq2)]


@click.command()
@click.argument("ref")
@click.argument("path", nargs=-1)
def main(ref, path):
    kls = []
    labels = []
    for p in path:
        labels.append(os.path.basename(p))
        kl = get_kl_div(attention_reader(ref), attention_reader(p))
        print(mean(kl))
        print(pstdev(kl))
        kls.append(kl)

    plt.boxplot(kls, labels=labels)
    plt.show()
    plt.clf()


if __name__ == "__main__":
    main()
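
# Hedged worked example (editor's addition, not part of the original file):
# for p = (0.5, 0.5) and q = (0.9, 0.1), the mixture is m = (0.7, 0.3) and
# JS(p, q) = 0.5*KL(p||m) + 0.5*KL(q||m) ~ 0.15 bits. Unlike plain KL, JS is
# symmetric and stays finite even where p and q have disjoint support.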
112
joint_sentence_compression_model/utils/kl_divergence.py
Normal file

@@ -0,0 +1,112 @@
'''
Calculates the divergence between specified columns in a file or across
columns of different files.
'''

from math import log2
import pandas as pd
import ast
import collections.abc
from sklearn import metrics
import sys


def kl_divergence(p, q):
    return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))


def flatten(x):
    if isinstance(x, collections.abc.Iterable):
        return [a for i in x for a in flatten(i)]
    else:
        return [x]


def get_data(file):
    names = ["sentence", "prediction", "attentions", "fixations"]
    df = pd.read_csv(file, sep='\t', names=names)
    df = df[2:]
    attentions = df.loc[:, "attentions"].tolist()
    fixations = df.loc[:, "fixations"].tolist()

    return attentions, fixations


def attention_attention(attentions1, attentions2):
    # note: despite the file name, the score computed here is sklearn's
    # mutual_info_score, not a KL divergence
    divergence = []

    for att1, att2 in zip(attentions1, attentions2):
        current_att1 = ast.literal_eval(att1)
        current_att2 = ast.literal_eval(att2)

        lst_att1 = flatten(current_att1)
        lst_att2 = flatten(current_att2)

        try:
            score = metrics.mutual_info_score(lst_att1, lst_att2)
            divergence.append(score)
        except ValueError:
            # skip pairs the metric cannot handle instead of averaging None
            continue

    avg = sum(divergence) / len(divergence)
    return avg


def fixation_fixation(fixation1, fixation2):
    divergence = []

    for fix1, fix2 in zip(fixation1, fixation2):
        current_fixation1 = ast.literal_eval(fix1)
        current_fixation2 = ast.literal_eval(fix2)

        lst_fixation1 = flatten(current_fixation1)
        lst_fixation2 = flatten(current_fixation2)

        try:
            score = metrics.mutual_info_score(lst_fixation1, lst_fixation2)
            divergence.append(score)
        except ValueError:
            continue

    avg = sum(divergence) / len(divergence)
    return avg


def attention_fixation(attentions, fixations):
    divergence = []

    for attention, fixation in zip(attentions, fixations):
        current_attention = ast.literal_eval(attention)
        current_fixation = ast.literal_eval(fixation)

        lst_attention = []

        for t in current_attention:
            attention_lst = flatten(t)
            lst_attention.append(sum(attention_lst) / len(attention_lst))

        lst_fixation = flatten(current_fixation)

        try:
            score = metrics.mutual_info_score(lst_attention, lst_fixation)
            divergence.append(score)
        except ValueError:
            continue

    avg = sum(divergence) / len(divergence)
    return avg


def divergent_calculations(file1, file2=None, val1=None):
    attentions, fixations = get_data(file1)
    attentions2, fixations2 = get_data(file2)

    if file2:
        if val1 == "attention":
            divergence = attention_attention(attentions, attentions2)
        else:
            divergence = fixation_fixation(fixations, fixations2)

    else:
        divergence = attention_fixation(attentions, fixations)

    print("Divergence: ", divergence)


divergent_calculations(sys.argv[1], sys.argv[2], sys.argv[3])
64
joint_sentence_compression_model/utils/plot_attention.py
Normal file

@@ -0,0 +1,64 @@
import ast
import os
import pathlib

import text_attention

import click
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import matplotlib.ticker as ticker
import numpy as np
import tqdm


def plot_attention(input_sentence, output_words, attentions, path):
    # Set up figure with colorbar
    attentions = np.array(attentions)[:, :len(input_sentence)]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions, cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.savefig(f"{path}.pdf")
    plt.close()


def parse(p):
    with open(p) as h:
        for line in h:
            if not line or line.startswith("#"):
                continue
            _sentence, _prediction, _attention, _fixations = line.strip().split("\t")
            try:
                sentence = ast.literal_eval(_sentence)
                prediction = ast.literal_eval(_prediction)
                attention = ast.literal_eval(_attention)
            except (ValueError, SyntaxError):
                continue

            yield sentence, prediction, attention


@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
    for p in tqdm.tqdm(path):
        out_dir = os.path.splitext(p)[0]
        if out_dir == p:  # input had no extension; avoid clobbering it
            out_dir = f"{out_dir}_"
        pathlib.Path(out_dir).mkdir(exist_ok=True)
        for i, spa in enumerate(parse(p)):
            plot_attention(*spa, path=os.path.join(out_dir, str(i)))


if __name__ == "__main__":
    main()
70
joint_sentence_compression_model/utils/text_attention.py
Executable file

@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12


## Convert a text/attention list to LaTeX code, which generates a text
## heatmap based on the attention weights.
import numpy as np

latex_special_token = ["!@#$%^&*()"]


def generate(text_list, attention_list, latex_file, color='red', rescale_value=False):
    assert(len(text_list) == len(attention_list))
    if rescale_value:
        attention_list = rescale(attention_list)
    word_num = len(text_list)
    text_list = clean_word(text_list)
    with open(latex_file, 'w') as f:
        f.write(r'''\documentclass[varwidth]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}''' + '\n')
        string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{''' + "\n"
        for idx in range(word_num):
            string += "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + text_list[idx] + "} "
        string += "\n}}}"
        f.write(string + '\n')
        f.write(r'''\end{CJK*}
\end{document}''')


def rescale(input_list):
    # rescale attention weights to [0, 100] for \colorbox percentages
    the_array = np.asarray(input_list)
    the_max = np.max(the_array)
    the_min = np.min(the_array)
    rescale = (the_array - the_min) / (the_max - the_min) * 100
    return rescale.tolist()


def clean_word(word_list):
    # escape LaTeX-sensitive characters
    new_word_list = []
    for word in word_list:
        for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
            if latex_sensitive in word:
                word = word.replace(latex_sensitive, '\\' + latex_sensitive)
        new_word_list.append(word)
    return new_word_list


if __name__ == '__main__':
    ## This is a demo:

    # English sample (immediately overridden by the CJK sample below):
    sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".
North Korea and the US have ratcheted up tensions in recent weeks and the movement of the strike group had raised the question of a pre-emptive strike by the US.
On Wednesday, Mr Pence described the country as the "most dangerous and urgent threat to peace and security" in the Asia-Pacific.'''
    # CJK sample used by the demo (kept in Chinese: it exercises the CJK path):
    sent = '''我 回忆 起 我 曾经 在 大学 年代 , 我们 经常 喜欢 玩 “ Hawaii guitar ” 。 说起 Guitar , 我 想起 了 西游记 里 的 琵琶精 。
今年 下半年 , 中 美 合拍 的 西游记 即将 正式 开机 , 我 继续 扮演 美猴王 孙悟空 , 我 会 用 美猴王 艺术 形象 努力 创造 一 个 正能量 的 形象 , 文 体 两 开花 , 弘扬 中华 文化 , 希望 大家 能 多多 关注 。'''
    words = sent.split()
    word_num = len(words)
    attention = [(x + 1.) / word_num * 100 for x in range(word_num)]
    import random
    random.seed(42)
    random.shuffle(attention)
    color = 'red'
    generate(words, attention, "sample.tex", color)
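
# Editor's note (addition, hedged): the demo writes sample.tex; assuming a TeX
# installation that provides the packages loaded above, compiling it with
# `pdflatex sample.tex` renders the text heatmap as a PDF.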