Add NLP task models
parent d8beb17dfb
commit 69f6de0ace
46 changed files with 4976 additions and 0 deletions
428  joint_paraphrase_model/.gitignore  vendored  Normal file
@@ -0,0 +1,428 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,latex
# Edit at https://www.toptal.com/developers/gitignore?templates=python,latex

### LaTeX ###
## Core latex/pdflatex auxiliary files:
*.aux
*.lof
*.log
*.lot
*.fls
*.out
*.toc
*.fmt
*.fot
*.cb
*.cb2
.*.lb

## Intermediate documents:
*.dvi
*.xdv
*-converted-to.*
# these rules might exclude image files for figures etc.
# *.ps
# *.eps
# *.pdf

## Generated if empty string is given at "Please type another file name for output:"
.pdf

## Bibliography auxiliary files (bibtex/biblatex/biber):
*.bbl
*.bcf
*.blg
*-blx.aux
*-blx.bib
*.run.xml

## Build tool auxiliary files:
*.fdb_latexmk
*.synctex
*.synctex(busy)
*.synctex.gz
*.synctex.gz(busy)
*.pdfsync

## Build tool directories for auxiliary files
# latexrun
latex.out/

## Auxiliary and intermediate files from other packages:
# algorithms
*.alg
*.loa

# achemso
acs-*.bib

# amsthm
*.thm

# beamer
*.nav
*.pre
*.snm
*.vrb

# changes
*.soc

# comment
*.cut

# cprotect
*.cpt

# elsarticle (documentclass of Elsevier journals)
*.spl

# endnotes
*.ent

# fixme
*.lox

# feynmf/feynmp
*.mf
*.mp
*.t[1-9]
*.t[1-9][0-9]
*.tfm

#(r)(e)ledmac/(r)(e)ledpar
*.end
*.?end
*.[1-9]
*.[1-9][0-9]
*.[1-9][0-9][0-9]
*.[1-9]R
*.[1-9][0-9]R
*.[1-9][0-9][0-9]R
*.eledsec[1-9]
*.eledsec[1-9]R
*.eledsec[1-9][0-9]
*.eledsec[1-9][0-9]R
*.eledsec[1-9][0-9][0-9]
*.eledsec[1-9][0-9][0-9]R

# glossaries
*.acn
*.acr
*.glg
*.glo
*.gls
*.glsdefs
*.lzo
*.lzs

# uncomment this for glossaries-extra (will ignore makeindex's style files!)
# *.ist

# gnuplottex
*-gnuplottex-*

# gregoriotex
*.gaux
*.gtex

# htlatex
*.4ct
*.4tc
*.idv
*.lg
*.trc
*.xref

# hyperref
*.brf

# knitr
*-concordance.tex
# TODO Comment the next line if you want to keep your tikz graphics files
*.tikz
*-tikzDictionary

# listings
*.lol

# luatexja-ruby
*.ltjruby

# makeidx
*.idx
*.ilg
*.ind

# minitoc
*.maf
*.mlf
*.mlt
*.mtc[0-9]*
*.slf[0-9]*
*.slt[0-9]*
*.stc[0-9]*

# minted
_minted*
*.pyg

# morewrites
*.mw

# nomencl
*.nlg
*.nlo
*.nls

# pax
*.pax

# pdfpcnotes
*.pdfpc

# sagetex
*.sagetex.sage
*.sagetex.py
*.sagetex.scmd

# scrwfile
*.wrt

# sympy
*.sout
*.sympy
sympy-plots-for-*.tex/

# pdfcomment
*.upa
*.upb

# pythontex
*.pytxcode
pythontex-files-*/

# tcolorbox
*.listing

# thmtools
*.loe

# TikZ & PGF
*.dpth
*.md5
*.auxlock

# todonotes
*.tdo

# vhistory
*.hst
*.ver

# easy-todo
*.lod

# xcolor
*.xcp

# xmpincl
*.xmpi

# xindy
*.xdy

# xypic precompiled matrices and outlines
*.xyc
*.xyd

# endfloat
*.ttt
*.fff

# Latexian
TSWLatexianTemp*

## Editors:
# WinEdt
*.bak
*.sav

# Texpad
.texpadtmp

# LyX
*.lyx~

# Kile
*.backup

# gummi
.*.swp

# KBibTeX
*~[0-9]*

# TeXnicCenter
*.tps

# auto folder when using emacs and auctex
./auto/*
*.el

# expex forward references with \gathertags
*-tags.tex

# standalone packages
*.sta

# Makeindex log files
*.lpz

# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib
# option is specified. Footnotes are then stored in a file with suffix Notes.bib.
# Uncomment the next line to have this generated file ignored.
#*Notes.bib

### LaTeX Patch ###
# LIPIcs / OASIcs
*.vtc

# glossaries
*.glstex

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# End of https://www.toptal.com/developers/gitignore/api/python,latex

3  joint_paraphrase_model/README.md  Normal file
@@ -0,0 +1,3 @@
# joint_paraphrase_model

joint training paraphrase model --- NeurIPS
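A minimal usage sketch for the click CLI defined in main.py below; the corpus choice and model name here are illustrative:

    python main.py -vv train --corpus qqp_kag --model_name demo_run --bleu nltk
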
112  joint_paraphrase_model/config.py  Normal file
@@ -0,0 +1,112 @@
import os
import torch


# general
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD = "<__PAD__>"
UNK = "<__UNK__>"
NOFIX = "<__NOFIX__>"
SOS = "<__SOS__>"
EOS = "<__EOS__>"

batch_size = 1
teacher_forcing_ratio = 0.5
embedding_dim = 300
fix_hidden_dim = 128
par_hidden_dim = 1024
fix_dropout = 0.5
par_dropout = 0.2
_fix_learning_rate = 0.00001
_par_learning_rate = 0.0001
learning_rate = _par_learning_rate
fix_momentum = 0.9
par_momentum = 0.0
max_length = 121
epochs = 5

# paths
data_path = "./data"
provo_predictability_path = os.path.join(
    data_path, "datasets/provo/Provo_Corpus-Predictability_Norms.csv"
)
provo_eyetracking_path = os.path.join(
    data_path, "datasets/provo/Provo_Corpus-Eyetracking_Data.csv"
)

geco_en_path = os.path.join(data_path, "datasets/geco/EnglishMaterial.csv")
geco_mono_path = os.path.join(data_path, "datasets/geco/MonolingualReadingData.csv")

movieqa_human_path = os.path.join(data_path, "datasets/all_word_scores_fixations")
movieqa_human_path_2 = os.path.join(
    data_path, "datasets/all_word_scores_fixations_exp2"
)
movieqa_human_path_3 = os.path.join(
    data_path, "datasets/all_word_scores_fixations_exp3"
)
movieqa_split_plot_path = os.path.join(data_path, "datasets/split_plot_UNRESOLVED")

cnn_path = os.path.join(
    data_path,
    "projects/2019/fixation_prediction/ez-reader-wrapper/predictability/output_cnn/",
)
dm_path = os.path.join(
    data_path,
    "projects/2019/fixation_prediction/ez-reader-wrapper/predictability/output_dm/",
)

qqp_paws_basedir = os.path.join(data_path, "datasets/paw_google/qqp/paws_qqp/output")
qqp_paws_train_path = os.path.join(qqp_paws_basedir, "train.tsv")
qqp_paws_dev_path = os.path.join(qqp_paws_basedir, "dev.tsv")
qqp_paws_test_path = os.path.join(qqp_paws_basedir, "test.tsv")

qqp_basedir = os.path.join(data_path, "datasets/Quora_question_pair_partition_OG")
qqp_train_path = os.path.join(qqp_basedir, "train.tsv")
qqp_dev_path = os.path.join(qqp_basedir, "dev.tsv")
qqp_test_path = os.path.join(qqp_basedir, "test.tsv")

qqp_kag_basedir = os.path.join(data_path, "datasets/Quora_question_pair_partition_kag")
qqp_kag_train_path = os.path.join(qqp_kag_basedir, "train.tsv")
qqp_kag_dev_path = os.path.join(qqp_kag_basedir, "dev.tsv")
qqp_kag_test_path = os.path.join(qqp_kag_basedir, "test.tsv")

wiki_basedir = os.path.join(data_path, "datasets/paw_google/wiki")
wiki_train_path = os.path.join(wiki_basedir, "train.tsv")
wiki_dev_path = os.path.join(wiki_basedir, "dev.tsv")
wiki_test_path = os.path.join(wiki_basedir, "test.tsv")

msrpc_basedir = os.path.join(data_path, "datasets/MSRPC")
msrpc_train_path = os.path.join(msrpc_basedir, "msr_paraphrase_train.txt")
msrpc_dev_path = os.path.join(msrpc_basedir, "msr_paraphrase_dev.txt")
msrpc_test_path = os.path.join(msrpc_basedir, "msr_paraphrase_test.txt")

sentiment_basedir = os.path.join(data_path, "datasets/sentiment_kag")
sentiment_train_path = os.path.join(sentiment_basedir, "train.tsv")
sentiment_dev_path = os.path.join(sentiment_basedir, "dev.tsv")
sentiment_test_path = os.path.join(sentiment_basedir, "test.tsv")

tamil_basedir = os.path.join(data_path, "datasets/en-ta-parallel-v2")
tamil_train_path = os.path.join(tamil_basedir, "corpus.bcn.train.enta")
tamil_dev_path = os.path.join(tamil_basedir, "corpus.bcn.dev.enta")
tamil_test_path = os.path.join(tamil_basedir, "corpus.bcn.test.enta")

compression_basedir = os.path.join(data_path, "datasets/sentence-compression/data")
compression_train_path = os.path.join(compression_basedir, "train.tsv")
compression_dev_path = os.path.join(compression_basedir, "dev.tsv")
compression_test_path = os.path.join(compression_basedir, "test.tsv")

stanford_basedir = os.path.join(data_path, "datasets/stanfordSentimentTreebank")
stanford_train_path = os.path.join(stanford_basedir, "train.tsv")
stanford_dev_path = os.path.join(stanford_basedir, "dev.tsv")
stanford_test_path = os.path.join(stanford_basedir, "test.tsv")

stanford_sent_basedir = os.path.join(data_path, "datasets/stanfordSentimentTreebank")
stanford_sent_train_path = os.path.join(stanford_basedir, "train_sent.tsv")
stanford_sent_dev_path = os.path.join(stanford_basedir, "dev_sent.tsv")
stanford_sent_test_path = os.path.join(stanford_basedir, "test_sent.tsv")


emb_path = os.path.join(data_path, "Google_word2vec/GoogleNews-vectors-negative300.bin")

glove_path = "glove.840B.300d.txt"

1  joint_paraphrase_model/data  Symbolic link
@@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/
1  joint_paraphrase_model/glove.840B.300d.txt  Symbolic link
@@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/datasets/glove/glove.840B.300d.txt
1  joint_paraphrase_model/glove.cache  Normal file
File diff suppressed because one or more lines are too long
0  joint_paraphrase_model/libs/__init__.py  Normal file
416  joint_paraphrase_model/libs/corpora.py  Normal file
@@ -0,0 +1,416 @@
import logging

import config


def tokenize(sent):
    return sent.split(" ")


class Lang:
    """Represents the vocabulary."""

    def __init__(self, name):
        self.name = name
        self.word2index = {
            config.PAD: 0,
            config.UNK: 1,
            config.NOFIX: 2,
            config.SOS: 3,
            config.EOS: 4,
        }
        self.word2count = {}
        self.index2word = {
            0: config.PAD,
            1: config.UNK,
            2: config.NOFIX,
            3: config.SOS,
            4: config.EOS,
        }
        self.n_words = 5

    def add_sentence(self, sentence):
        assert isinstance(
            sentence, (list, tuple)
        ), "input to add_sentence must be tokenized"
        for word in sentence:
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def __add__(self, other):
        """Returns a new Lang object containing the vocabulary from this and
        the other Lang object.
        """
        new_lang = Lang(f"{self.name}_{other.name}")

        # Add vocabulary from both Langs
        for word in self.word2count.keys():
            new_lang.add_word(word)
        for word in other.word2count.keys():
            new_lang.add_word(word)

        # Fix the counts on the new one
        for word in new_lang.word2count.keys():
            new_lang.word2count[word] = self.word2count.get(
                word, 0
            ) + other.word2count.get(word, 0)

        return new_lang


def load_wiki(split):
    """Load the Wiki paraphrase pairs from PAWS."""
    logger = logging.getLogger(f"{__name__}.load_wiki")
    lang = Lang("wiki")

    if split == "train":
        path = config.wiki_train_path
    elif split == "val":
        path = config.wiki_dev_path
    elif split == "test":
        path = config.wiki_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            _, sent1, sent2, rating = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang


def load_qqp_paws(split):
    """Load the QQP paraphrase pairs from PAWS."""
    logger = logging.getLogger(f"{__name__}.load_qqp_paws")
    lang = Lang("qqp_paws")

    if split == "train":
        path = config.qqp_paws_train_path
    elif split == "val":
        path = config.qqp_paws_dev_path
    elif split == "test":
        path = config.qqp_paws_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            _, sent1, sent2, rating = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang


def load_qqp(split):
    """Load the original QQP split."""
    logger = logging.getLogger(f"{__name__}.load_qqp")
    lang = Lang("qqp")

    if split == "train":
        path = config.qqp_train_path
    elif split == "val":
        path = config.qqp_dev_path
    elif split == "test":
        path = config.qqp_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            rating, sent1, sent2, _ = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang


def load_qqp_kag(split):
    """Load the QQP Kaggle split."""
    # not the original split right now; experimenting with the Kaggle 100K/3K/30K split
    logger = logging.getLogger(f"{__name__}.load_qqp_kag")
    lang = Lang("qqp_kag")

    if split == "train":
        path = config.qqp_kag_train_path
    elif split == "val":
        path = config.qqp_kag_dev_path
    elif split == "test":
        path = config.qqp_kag_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        # the Kaggle version has three tab-separated fields per line, not four
        for line in handle:
            rating, sent1, sent2 = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang


def load_msrpc(split):
    """Load the Microsoft Research Paraphrase Corpus (MSRPC)."""
    logger = logging.getLogger(f"{__name__}.load_msrpc")
    lang = Lang("msrpc")

    if split == "train":
        path = config.msrpc_train_path
    elif split == "val":
        path = config.msrpc_dev_path
    elif split == "test":
        path = config.msrpc_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            rating, _, _, sent1, sent2 = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # return src_lang, dst_lang, pairs
    # MS makes the vocab for paraphrase the same

    return pairs, lang


def load_sentiment(split):
    """Load the sentiment dataset from the Kaggle competition."""
    logger = logging.getLogger(f"{__name__}.load_sentiment")
    lang = Lang("sentiment")

    if split == "train":
        path = config.sentiment_train_path
    elif split == "val":
        path = config.sentiment_dev_path
    elif split == "test":
        path = config.sentiment_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []

    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            _, _, sent1, sent2 = line.strip().split("\t")

            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    return pairs, lang


def load_tamil(split):
    """Load the English-Tamil parallel dataset (current SOTA is about 13 BLEU)."""
    logger = logging.getLogger(f"{__name__}.load_tamil")
    lang = Lang("tamil")

    if split == "train":
        path = config.tamil_train_path
    elif split == "val":
        path = config.tamil_dev_path
    elif split == "test":
        path = config.tamil_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        handle.readline()

        for line in handle:
            sent1, sent2 = line.strip().split("\t")
            # if rating == "0":
            #     continue
            sent1 = tokenize(sent1)
            # unclear how best to tokenize Tamil; plain whitespace splitting is used for now
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            pairs.append([sent1, sent2])

    return pairs, lang


def load_compression(split):
    """Load the Google Sentence Compression dataset."""
    logger = logging.getLogger(f"{__name__}.load_compression")
    lang = Lang("compression")

    if split == "train":
        path = config.compression_train_path
    elif split == "val":
        path = config.compression_dev_path
    elif split == "test":
        path = config.compression_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        handle.readline()

        for line in handle:
            sent1, sent2 = line.strip().split("\t")
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            # print(len(sent1), sent1)
            # print(len(sent2), sent2)
            # print()
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            pairs.append([sent1, sent2])

    return pairs, lang


def load_stanford(split):
    """Load the Stanford Sentiment Treebank phrases."""
    logger = logging.getLogger(f"{__name__}.load_stanford")
    lang = Lang("stanford")

    if split == "train":
        path = config.stanford_train_path
    elif split == "val":
        path = config.stanford_dev_path
    elif split == "test":
        path = config.stanford_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []

    with open(path) as handle:

        # skip header
        # handle.readline()

        for line in handle:
            _, _, sent1, sent2 = line.strip().split("\t")

            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    return pairs, lang


def load_stanford_sent(split):
    """Load the Stanford Sentiment Treebank sentences."""
    logger = logging.getLogger(f"{__name__}.load_stanford_sent")
    lang = Lang("stanford_sent")

    if split == "train":
        path = config.stanford_sent_train_path
    elif split == "val":
        path = config.stanford_sent_dev_path
    elif split == "test":
        path = config.stanford_sent_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []

    with open(path) as handle:

        # skip header
        # handle.readline()

        for line in handle:
            _, _, sent1, sent2 = line.strip().split("\t")

            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    return pairs, lang
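A quick sketch of how these Lang vocabularies combine (load_corpus in main.py merges the per-split vocabularies with +); the names and sentences here are made up:

    a = Lang("qqp")
    a.add_sentence("how do i start ?".split(" "))
    b = Lang("wiki")
    b.add_sentence("how does steam work ?".split(" "))
    merged = a + b  # a new Lang named "qqp_wiki"
    assert merged.word2count["how"] == 2  # counts are summed; special tokens keep ids 0-4
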
1  joint_paraphrase_model/libs/fixation_generation/__init__.py  Normal file
@@ -0,0 +1 @@
from .main import *
131  joint_paraphrase_model/libs/fixation_generation/main.py  Normal file
@@ -0,0 +1,131 @@
from collections import OrderedDict
import logging
import sys

from .self_attention import Transformer

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence


def random_embedding(vocab_size, embedding_dim):
    pretrain_emb = np.empty([vocab_size, embedding_dim])
    scale = np.sqrt(3.0 / embedding_dim)
    for index in range(vocab_size):
        pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
    return pretrain_emb


def neg_log_likelihood_loss(outputs, batch_label, batch_size, seq_len):
    outputs = outputs.view(batch_size * seq_len, -1)
    score = F.log_softmax(outputs, 1)

    loss = nn.NLLLoss(ignore_index=0, reduction="sum")(
        score, batch_label.view(batch_size * seq_len)
    )
    loss = loss / batch_size
    _, tag_seq = torch.max(score, 1)
    tag_seq = tag_seq.view(batch_size, seq_len)

    # print(score[0], tag_seq[0])

    return loss, tag_seq


def mse_loss(outputs, batch_label, batch_size, seq_len, word_seq_length):
    # score = torch.nn.functional.softmax(outputs, 1)
    score = torch.sigmoid(outputs)

    # zero out predictions past each sequence's true length so that
    # padding does not contribute to the loss
    mask = torch.zeros_like(score)
    for i, v in enumerate(word_seq_length):
        mask[i, 0:v] = 1

    score = score * mask

    loss = nn.MSELoss(reduction="sum")(
        score.view(batch_size, seq_len), batch_label.view(batch_size, seq_len)
    )

    loss = loss / batch_size

    return loss, score.view(batch_size, seq_len)


class Network(nn.Module):
    def __init__(
        self,
        embedding_type,
        vocab_size,
        embedding_dim,
        dropout,
        hidden_dim,
        embeddings=None,
        attention=True,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        prelayers = OrderedDict()
        postlayers = OrderedDict()

        if embedding_type in ("w2v", "glove"):
            if embeddings is not None:
                prelayers["embedding_layer"] = nn.Embedding.from_pretrained(embeddings)
            else:
                prelayers["embedding_layer"] = nn.Embedding(vocab_size, embedding_dim)
            prelayers["embedding_dropout_layer"] = nn.Dropout(dropout)
            embedding_dim = 300
        elif embedding_type == "bert":
            embedding_dim = 768

        self.lstm = BiLSTM(embedding_dim, hidden_dim // 2, num_layers=1)
        postlayers["lstm_dropout_layer"] = nn.Dropout(dropout)

        if attention:
            # increased complexity (1024D, 16 heads, 16 layers) was used for the no-attention vs. attention experiments
            # before that, for the initial attention runs and pretraining: 4 heads and 4 layers, 128D
            # then 128D with 4 heads and 1 layer = the results for all IUI experiments
            ###postlayers["position_encodings"] = PositionalEncoding(hidden_dim)
            postlayers["attention_layer"] = Transformer(
                d_model=hidden_dim, n_heads=4, n_layers=1
            )

        postlayers["ff_layer"] = nn.Linear(hidden_dim, hidden_dim // 2)
        postlayers["ff_activation"] = nn.ReLU()
        postlayers["output_layer"] = nn.Linear(hidden_dim // 2, 1)

        self.logger.info(f"prelayers: {prelayers.keys()}")
        self.logger.info(f"postlayers: {postlayers.keys()}")

        self.pre = nn.Sequential(prelayers)
        self.post = nn.Sequential(postlayers)

    def forward(self, x, word_seq_length):
        x = self.pre(x)
        x = self.lstm(x, word_seq_length)
        # MS: printing the fixation model parameters
        # for p in self.parameters():
        #     print(p.data)
        #     break

        return self.post(x.transpose(1, 0))


class BiLSTM(nn.Module):
    def __init__(self, embedding_dim, lstm_hidden, num_layers):
        super().__init__()
        self.net = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_hidden,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x, word_seq_length):
        packed_words = pack_padded_sequence(x, word_seq_length, True, False)
        lstm_out, hidden = self.net(packed_words)
        lstm_out, _ = pad_packed_sequence(lstm_out)
        return lstm_out
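A minimal sketch of the masking behaviour in mse_loss above; the tensor shapes and values are illustrative:

    import torch
    outputs = torch.zeros(2, 3, 1)  # one fixation logit per token, batch of 2, padded to length 3
    labels = torch.tensor([[0.5, 0.2, 0.0], [0.1, 0.0, 0.0]])
    loss, scores = mse_loss(outputs, labels, batch_size=2, seq_len=3, word_seq_length=[3, 1])
    # positions past each sequence's true length are zeroed before the MSE,
    # so padding never contributes to the loss
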
131  joint_paraphrase_model/libs/fixation_generation/self_attention.py  Normal file
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()


class AttentionLayer(nn.Module):
    def __init__(self):
        super(AttentionLayer, self).__init__()

    def forward(self, Q, K, V):
        # Q: float32[batch_size, n_queries, d_k]
        # K: float32[batch_size, n_keys, d_k]
        # V: float32[batch_size, n_keys, d_v]
        dk = K.shape[-1]
        dv = V.shape[-1]
        KT = torch.transpose(K, -1, -2)
        weight_logits = torch.bmm(Q, KT) / math.sqrt(dk)
        # weight_logits: float32[batch_size, n_queries, n_keys]
        weights = F.softmax(weight_logits, dim=-1)
        # weights: float32[batch_size, n_queries, n_keys]
        return torch.bmm(weights, V)
        # returns float32[batch_size, n_queries, d_v]


class MultiHeadedSelfAttentionLayer(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadedSelfAttentionLayer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        print('{} {}'.format(d_model, n_heads))
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.d_v = self.d_k
        self.attention_layer = AttentionLayer()
        self.W_Qs = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Ks = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Vs = nn.ModuleList([
            nn.Linear(d_model, self.d_v, bias=False)
            for _ in range(n_heads)
        ])
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        head_outputs = []
        for W_Q, W_K, W_V in zip(self.W_Qs, self.W_Ks, self.W_Vs):
            Q = W_Q(x)
            # Q: float32[batch_size, sequence_length, self.d_k]
            K = W_K(x)
            # K: float32[batch_size, sequence_length, self.d_k]
            V = W_V(x)
            # V: float32[batch_size, sequence_length, self.d_v]
            head_output = self.attention_layer(Q, K, V)
            # head_output: float32[batch_size, sequence_length, self.d_v]
            head_outputs.append(head_output)
        concatenated = torch.cat(head_outputs, dim=-1)
        # concatenated: float32[batch_size, sequence_length, self.d_model]
        out = self.W_O(concatenated)
        # out: float32[batch_size, sequence_length, self.d_model]
        return out


class Feedforward(nn.Module):
    def __init__(self, d_model):
        super(Feedforward, self).__init__()
        self.d_model = d_model
        self.W1 = nn.Linear(d_model, d_model)
        self.W2 = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, d_model]
        return self.W2(torch.relu(self.W1(x)))


class Transformer(nn.Module):
    def __init__(self, d_model, n_heads, n_layers):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.attention_layers = nn.ModuleList([
            MultiHeadedSelfAttentionLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        self.ffs = nn.ModuleList([
            Feedforward(d_model)
            for _ in range(n_layers)
        ])

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        for attention_layer, ff in zip(self.attention_layers, self.ffs):
            attention_out = attention_layer(x)
            # attention_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + attention_out, x.shape[2:])
            ff_out = ff(x)
            # ff_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + ff_out, x.shape[2:])
        return x
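A minimal shape check for the attention stack above (the dimensions are illustrative):

    import torch
    block = Transformer(d_model=128, n_heads=4, n_layers=1)
    x = torch.randn(2, 10, 128)  # [batch_size, sequence_length, d_model]
    out = block(x)  # self-attention and feedforward, each wrapped in a residual + layer norm
    assert out.shape == (2, 10, 128)
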
1  joint_paraphrase_model/libs/paraphrase_generation/__init__.py  Normal file
@@ -0,0 +1 @@
from .main import *
86  joint_paraphrase_model/libs/paraphrase_generation/main.py  Normal file
@@ -0,0 +1,86 @@
import json
import math
import os

import random
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, embeddings):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding.from_pretrained(embeddings)
        self.gru = nn.GRU(input_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)


class AttnDecoderRNN(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        embeddings,
        dropout_p,
        max_length,
    ):
        super(AttnDecoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding.from_pretrained(embeddings)  # for paraphrase generation
        # self.embedding = nn.Embedding(len(embeddings), 300)  # for NMT with Tamil; trying with sentiment too
        self.attn = nn.Linear(self.input_size + self.hidden_size, self.max_length)
        self.attn_combine = nn.Linear(
            self.input_size + self.hidden_size, self.hidden_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs, fixations):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1
        )

        # modulate the attention weights with the predicted fixation scores,
        # zero-padding the fixations up to max_length
        attn_weights = attn_weights * torch.nn.ConstantPad1d(
            (0, attn_weights.shape[-1] - fixations.shape[-2]), 0
        )(fixations.squeeze().unsqueeze(0))

        # attn_weights = torch.softmax(attn_weights * torch.nn.ConstantPad1d((0, attn_weights.shape[-1] - fixations.shape[-2]), 0)(fixations.squeeze().unsqueeze(0)), dim=1)

        attn_applied = torch.bmm(
            attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0)
        )

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        # output = F.log_softmax(self.out(output[0]), dim=1)
        output = self.out(output[0])
        # output = F.log_softmax(output, dim=1)
        return output, hidden, attn_weights
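A minimal sketch of the fixation-weighting step in AttnDecoderRNN.forward; the sequence lengths here are illustrative:

    import torch
    attn_weights = torch.rand(1, 121)  # [1, max_length]
    fixations = torch.rand(7, 1)  # one predicted fixation score per source token
    pad = torch.nn.ConstantPad1d((0, attn_weights.shape[-1] - fixations.shape[-2]), 0)
    weighted = attn_weights * pad(fixations.squeeze().unsqueeze(0))
    # attention beyond the source length is zeroed out; attention on fixated
    # words is scaled by their fixation score
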
225  joint_paraphrase_model/libs/utils.py  Normal file
@@ -0,0 +1,225 @@
import json
import logging
import math
import os
import random
import re
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from nltk.translate.bleu_score import sentence_bleu
import numpy as np
import torch
import torch.nn as nn

import config


plt.switch_backend("agg")


def load_glove(vocabulary):
    logger = logging.getLogger(f"{__name__}.load_glove")
    logger.info("loading embeddings")
    try:
        with open("glove.cache") as h:
            cache = json.load(h)
    except (OSError, json.JSONDecodeError):
        logger.info("cache doesn't exist")
        cache = {}
        cache[config.PAD] = [0] * 300
        cache[config.SOS] = [0] * 300
        cache[config.EOS] = [0] * 300
        cache[config.UNK] = [0] * 300
        cache[config.NOFIX] = [0] * 300
    else:
        logger.info("cache found")

    cache_miss = False

    if not set(vocabulary) <= set(cache):
        cache_miss = True
        logger.warning("cache miss, loading full embeddings")
        data = {}
        with open("glove.840B.300d.txt") as h:
            for line in h:
                word, *emb = line.strip().split()
                try:
                    data[word] = [float(x) for x in emb]
                except ValueError:
                    continue
        logger.info("finished loading full embeddings")
        for word in vocabulary:
            try:
                cache[word] = data[word]
            except KeyError:
                cache[word] = [0] * 300
        logger.info("cache updated")

    embeddings = []
    for word in vocabulary:
        embeddings.append(torch.tensor(cache[word], dtype=torch.float32))
    embeddings = torch.stack(embeddings)

    if cache_miss:
        with open("glove.cache", "w") as h:
            json.dump(cache, h)
        logger.info("cache saved")

    return embeddings


def tokenize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = s.split(" ")
    return s


def indices_from_sentence(word2index, sentence, unknown_threshold):
    if unknown_threshold:
        # randomly replace a fraction of words with UNK for regularization
        return [
            word2index.get(
                word if random.random() > unknown_threshold else config.UNK,
                word2index[config.UNK],
            )
            for word in sentence
        ]
    else:
        return [
            word2index.get(word, word2index[config.UNK]) for word in sentence
        ]


def tensor_from_sentence(word2index, sentence, unknown_threshold):
    # indices = [config.SOS]
    indices = indices_from_sentence(word2index, sentence, unknown_threshold)
    indices.append(word2index[config.EOS])
    return torch.tensor(indices, dtype=torch.long, device=config.DEV)


def tensors_from_pair(word2index, pair, shuffle, unknown_threshold):
    tensors = [
        tensor_from_sentence(word2index, pair[0], unknown_threshold),
        tensor_from_sentence(word2index, pair[1], unknown_threshold),
    ]
    if shuffle:
        random.shuffle(tensors)
    return tensors


def bleu(reference, hypothesis, n=4):
    # the length of the weights list sets the highest n-gram order used by
    # sentence_bleu; uniform weights over orders 1..n give standard BLEU-n
    if n < 1:
        return 0
    weights = [1 / n] * n
    return sentence_bleu([reference], hypothesis, weights)


def pair_iter(pairs, word2index, shuffle=False, shuffle_pairs=False, unknown_threshold=0.00):
    if shuffle:
        pairs = pairs.copy()
        random.shuffle(pairs)
    for pair in pairs:
        tensor1, tensor2 = tensors_from_pair(word2index, (pair[0], pair[1]), shuffle_pairs, unknown_threshold)
        yield (tensor1,), (tensor2,)


def sent_iter(sents, word2index, unknown_threshold=0.00):
    for sent in sents:
        tensor = tensor_from_sentence(word2index, sent, unknown_threshold)
        yield (tensor,)


def batch_iter(pairs, word2index, batch_size, shuffle=False, unknown_threshold=0.00):
    for i in range(len(pairs) // batch_size):
        batch = pairs[i * batch_size : (i + 1) * batch_size]
        if len(batch) != batch_size:
            continue
        batch_tensors = [
            tensors_from_pair(word2index, (pair[0], pair[1]), shuffle, unknown_threshold)
            for pair in batch
        ]

        tensors1, tensors2 = zip(*batch_tensors)

        # targets = torch.tensor(targets, dtype=torch.long, device=config.DEV)

        # tensors1_lengths = [len(t) for t in tensors1]
        # tensors2_lengths = [len(t) for t in tensors2]

        # tensors1 = nn.utils.rnn.pack_sequence(tensors1, enforce_sorted=False)
        # tensors2 = nn.utils.rnn.pack_sequence(tensors2, enforce_sorted=False)

        yield tensors1, tensors2


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return "%s (- %s)" % (asMinutes(s), asMinutes(rs))


def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)


def showAttention(input_sentence, output_words, attentions):
    # Set up figure with colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence.split(" ") + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    # NOTE: assumes `evaluate`, `encoder1`, and `attn_decoder1` exist in the calling scope
    output_words, attentions = evaluate(encoder1, attn_decoder1, input_sentence)
    print("input =", input_sentence)
    print("output =", " ".join(output_words))
    showAttention(input_sentence, output_words, attentions)


def save_model(model, word2index, path):
    if not path.endswith(".tar"):
        path += ".tar"
    torch.save(
        {"weights": model.state_dict(), "word2index": word2index},
        path,
    )


def load_model(path):
    checkpoint = torch.load(path)
    return checkpoint["weights"], checkpoint["word2index"]


def extend_vocabulary(word2index, langs):
    for lang in langs:
        for word in lang.word2index:
            if word not in word2index:
                word2index[word] = len(word2index)
    return word2index
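A quick sketch of the bleu helper above; in NLTK's sentence_bleu the length of the weights list sets the highest n-gram order:

    ref = "the cat sat on the mat".split(" ")
    hyp = "the cat sat on a mat".split(" ")
    bleu(ref, hyp, n=1)  # unigram precision only
    bleu(ref, hyp, n=4)  # standard BLEU-4 with uniform 1- to 4-gram weights
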
794  joint_paraphrase_model/main.py  Normal file
@@ -0,0 +1,794 @@
import logging
import os
import pathlib
import random
import sys

import click
import sacrebleu
import torch
import torch.nn as nn
import tqdm

import config
from libs import corpora
from libs import utils
from libs.fixation_generation import Network as FixNN
from libs.paraphrase_generation import (
    EncoderRNN as ParEncNN,
    AttnDecoderRNN as ParDecNN,
)


cwd = os.path.dirname(__file__)

logger = logging.getLogger("main")

'''
# qqp_paws sentences:

debug_sentences = [
    'What are the driving rules in Georgia versus Mississippi ?',
    'I want to be a child psychologist , what qualification do i need to become one ? Are there good and reputed psychology Institute or Colleges in India ?',
    'What is deep web and dark web and what are the contents of these sites ?',
    'What is difference between North Indian Brahmins and South Indian Brahmins ?',
    'Is carbon dioxide an ionic bond or a covalent bond ?',
    'How do accounts receivable and accounts payable differ ?',
    'Why did Wikipedia hide its audit history for ( superluminal ) successful speed experiments ?',
    'What makes a person simple , or inversely , complicated ?',
    '`` How do you say `` Miss you , too , `` in Spanish ? Are there multiple ways to say it ? ‘’',
    'What is the difference between dominant trait and recessive trait ?’',
    '`` What is the difference between `` seeing someone , `` `` dating someone , `` and `` having a girlfriend/boyfriend `` ? ‘’',
    'How was the Empire State building built and designed ? How is it used ?',
    'What is the sum of the square roots of first n natural number ?',
    'Why is Roman Saini not so active on Quora now a days ?',
    'If I have someone blocked on Instagram , and see their story , can they view I viewed it ?',
    'Amongst the major IT companies of India which is the best ; Wipro , Capgemini , Infosys , TCS or is Oracle the best ?',
    'How much mass does Saturn lose each year ? How much mass does it gain ?',
    'What is a cheap healthy diet , I can keep the same and eat every day ?',
    ' What is it like to be beautiful ? Not just pretty or hot , but the kind of almost objective beauty that people are sometimes intimidated by ?',
    'Could someone tell of a James Ronsey ( misspelled likely ) , writer and filmmaker , probably of the British Isles ?',
    'How much pressure Is there around the core of Pluto ? is it enough to turn hydrogen/helium gas into a liquid or metallic state ?',
    'How does quality of life in Vancouver compare to that in Melbourne Or Brisbane ?',
]
'''
'''
# wiki sentences:

debug_sentences = [
    'They were there to enjoy us and they were there to pray for us .',
    'Components of elastic potential systems store mechanical energy if they are deformed when forces are applied to the system .',
    'Steam can also be used , and does not need to be pumped .',
    'The solar approach to this requirement is the use of solar panels in a conventional-powered aircraft .',
    'Daudkhali is a village in Barisal Division in the Pirojpur district in southwestern Bangladesh .',
    'Briggs later met Briggs at the 1967 Monterey Pop Festival , where Ravi Shankar was also performing , with Eric Burdon and The Animals .',
    'Brockton is approximately 25 miles northeast of Providence , Rhode Island , and 30 miles south of Boston .',
]
'''

# qqp sentences:

debug_sentences = [
    'How do I get funding for my web based startup idea ?',
    'What do intelligent people do to pass time ?',
    'Which is the best SEO Company in Delhi ?',
    'Why do you waer makeup ?',
    'How do start chatting with a girl ?',
    'What is the meaning of living life ?',
    'Why do my armpits hurt ?',
    'Why does eye color change with age ?',
    'How do you find the standard deviation of a probability distribution ? What are some examples ?',
    'How can I complete my 11 syllabus in one month ?',
    'How do I concentrate better on my studies ?',
    'Which is the best retirement plan in india ?',
    'Should I tell my best friend I love her ?',
    'Which is the best company for Appian Vagrant online job support ?',
    'How can one do for good handwriting ?',
    'What are remedies to get rid of belly fat ?',
    'What is the best way to cook precooked turkey ?',
    'What is the future of e-commerce in India ?',
    'Why do my burps taste like rotten eggs ?',
    'What is an example of chemical weathering ?',
    'What are some of the advantages and disadvantages of cyber schooling ?',
    'How can I increase traffic to my websites by Facebook ?',
    'How do I increase my patience level in life ?',
    'What are the best hospitals for treating cancer in India ?',
    'Will Jio sim work in a 3G phone ? If yes , how ?',
]

debug_sentences = [s.split(" ") for s in debug_sentences]


class Network(nn.Module):
    def __init__(
        self,
        word2index,
        embeddings,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.word2index = word2index
        self.index2word = {i: k for k, i in word2index.items()}
        self.fix_gen = FixNN(
            embedding_type="glove",
            vocab_size=len(word2index),
            embedding_dim=config.embedding_dim,
            embeddings=embeddings,
            dropout=config.fix_dropout,
            hidden_dim=config.fix_hidden_dim,
        )
        self.par_enc = ParEncNN(
            input_size=config.embedding_dim,
            hidden_size=config.par_hidden_dim,
            embeddings=embeddings,
        )
        self.par_dec = ParDecNN(
            input_size=config.embedding_dim,
            hidden_size=config.par_hidden_dim,
            output_size=len(word2index),
            embeddings=embeddings,
            dropout_p=config.par_dropout,
            max_length=config.max_length,
        )

    def forward(self, x, target=None, teacher_forcing_ratio=None):
        teacher_forcing_ratio = teacher_forcing_ratio if teacher_forcing_ratio is not None else config.teacher_forcing_ratio
        x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
        x2 = nn.utils.rnn.pad_sequence(x, batch_first=False)
        fixations = torch.sigmoid(self.fix_gen(x1, [len(_x) for _x in x1]))

        enc_hidden = self.par_enc.initHidden().to(config.DEV)
        enc_outs = torch.zeros(config.max_length, config.par_hidden_dim, device=config.DEV)

        for ei in range(len(x2)):
            enc_out, enc_hidden = self.par_enc(x2[ei], enc_hidden)
            enc_outs[ei] += enc_out[0, 0]

        dec_in = torch.tensor([[self.word2index[config.SOS]]], device=config.DEV)  # SOS
        dec_hidden = enc_hidden
        dec_outs = []
        dec_words = []
        dec_atts = torch.zeros(config.max_length, config.max_length)

        if target is not None:  # training
            target = nn.utils.rnn.pad_sequence(target, batch_first=False)
            use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

            if use_teacher_forcing:
                for di in range(len(target)):
                    dec_out, dec_hidden, dec_att = self.par_dec(
                        dec_in, dec_hidden, enc_outs, fixations
                    )
                    dec_outs.append(dec_out)
                    dec_atts[di] = dec_att.data
                    # feed the gold token as the next decoder input
                    dec_in = target[di]

            else:
                for di in range(len(target)):
                    dec_out, dec_hidden, dec_att = self.par_dec(
                        dec_in, dec_hidden, enc_outs, fixations
                    )
                    dec_outs.append(dec_out)
                    dec_atts[di] = dec_att.data
                    topv, topi = dec_out.data.topk(1)
                    dec_words.append(self.index2word[topi.item()])

                    # feed the model's own prediction as the next decoder input
                    dec_in = topi.squeeze().detach()

        else:  # prediction
            for di in range(config.max_length):
                dec_out, dec_hidden, dec_att = self.par_dec(
                    dec_in, dec_hidden, enc_outs, fixations
                )
                dec_outs.append(dec_out)
                dec_atts[di] = dec_att.data
                topv, topi = dec_out.data.topk(1)
                if topi.item() == self.word2index[config.EOS]:
                    dec_words.append("<__EOS__>")
                    break
                else:
                    dec_words.append(self.index2word[topi.item()])

                dec_in = topi.squeeze().detach()

        return dec_outs, dec_words, dec_atts[: di + 1], fixations
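A minimal sketch of a prediction-mode forward pass through the joint network; it assumes an initialized instance from init_network below, and the token ids are illustrative:

    import torch
    import config
    x = (torch.tensor([5, 6, 7, 4], device=config.DEV),)  # one sentence of token ids, ending in EOS (id 4)
    dec_outs, dec_words, dec_atts, fixations = network(x)  # no target -> decode until EOS or max_length
    # dec_words is the generated paraphrase; fixations holds the per-token fixation scores
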
||||
|
||||
|
||||
def load_corpus(corpus_name, splits):
|
||||
if not splits:
|
||||
return
|
||||
|
||||
logger.info("loading corpus")
|
||||
if corpus_name == "msrpc":
|
||||
load_fn = corpora.load_msrpc
|
||||
elif corpus_name == "qqp":
|
||||
load_fn = corpora.load_qqp
|
||||
elif corpus_name == "wiki":
|
||||
load_fn = corpora.load_wiki
|
||||
elif corpus_name == "qqp_paws":
|
||||
load_fn = corpora.load_qqp_paws
|
||||
elif corpus_name == "qqp_kag":
|
||||
load_fn = corpora.load_qqp_kag
|
||||
elif corpus_name == "sentiment":
|
||||
load_fn = corpora.load_sentiment
|
||||
elif corpus_name == "stanford":
|
||||
load_fn = corpora.load_stanford
|
||||
elif corpus_name == "stanford_sent":
|
||||
load_fn = corpora.load_stanford_sent
|
||||
elif corpus_name == "tamil":
|
||||
load_fn = corpora.load_tamil
|
||||
elif corpus_name == "compression":
|
||||
load_fn = corpora.load_compression
|
||||
|
||||
corpus = {}
|
||||
langs = []
|
||||
|
||||
if "train" in splits:
|
||||
train_pairs, train_lang = load_fn("train")
|
||||
corpus["train"] = train_pairs
|
||||
langs.append(train_lang)
|
||||
if "val" in splits:
|
||||
val_pairs, val_lang = load_fn("val")
|
||||
corpus["val"] = val_pairs
|
||||
langs.append(val_lang)
|
||||
if "test" in splits:
|
||||
test_pairs, test_lang = load_fn("test")
|
||||
corpus["test"] = test_pairs
|
||||
langs.append(test_lang)
|
||||
|
||||
logger.info("creating word index")
|
||||
lang = langs[0]
|
||||
for _lang in langs[1:]:
|
||||
lang += _lang
|
||||
word2index = lang.word2index
|
||||
|
||||
index2word = {i: w for w, i in word2index.items()}
|
||||
|
||||
return corpus, word2index, index2word
|
||||
|
||||
|
||||
def init_network(word2index):
|
||||
logger.info("loading embeddings")
|
||||
vocabulary = sorted(word2index.keys())
|
||||
embeddings = utils.load_glove(vocabulary)
|
||||
|
||||
logger.info("initializing model")
|
||||
network = Network(
|
||||
word2index=word2index,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
network.to(config.DEV)
|
||||
|
||||
print(f"#parameters: {sum(p.numel() for p in network.parameters())}")
|
||||
|
||||
return network
|
||||
|
||||
|
||||
@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-v", "--verbose", count=True)
@click.option("-d", "--debug", is_flag=True)
def main(verbose, debug):
    if verbose == 0:
        loglevel = logging.ERROR
    elif verbose == 1:
        loglevel = logging.WARN
    else:
        loglevel = logging.INFO

    if debug:
        loglevel = logging.DEBUG

    logging.basicConfig(
        format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
        datefmt="%d.%m. %H:%M:%S",
        level=loglevel,
    )

    logger.debug("arguments: %s", sys.argv)

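# NOTE (editorial usage sketch, not part of this commit): example invocations
# of the command group, assuming the module is saved as main.py:
#
#     python main.py -vv train -c qqp -m my_model -b sacrebleu
#     python main.py -d val -c qqp -w models/my_model/my_model_3 -b nltk
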
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-m", "--model_name", required=True)
@click.option("-w", "--fixation_weights", required=False)
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def train(corpus_name, model_name, fixation_weights, freeze_fixations, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
    train_pairs = corpus["train"]
    val_pairs = corpus["val"]
    network = init_network(word2index)

    model_dir = os.path.join("models", model_name)
    logger.debug("creating model dir %s", model_dir)
    pathlib.Path(model_dir).mkdir(parents=True)

    if fixation_weights is not None:
        logger.info("loading fixation prediction checkpoint")
        checkpoint = torch.load(fixation_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
        else:
            weights = checkpoint

        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")}
        network.fix_gen.load_state_dict(weights, strict=False)

    if freeze_fixations:
        logger.info("freezing fixation generation network")
        for p in network.fix_gen.parameters():
            p.requires_grad = False

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)
    # optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=1e-5)

    best_val_loss = None

    epoch = 1
    while True:
        train_batch_iter = utils.pair_iter(pairs=train_pairs, word2index=word2index, shuffle=True, shuffle_pairs=False)
        val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
        # test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

        running_train_loss = 0
        total_train_loss = 0
        total_val_loss = 0

        if bleu == "sacrebleu":
            running_train_bleu = 0
            total_train_bleu = 0
            total_val_bleu = 0
        elif bleu == "nltk":
            running_train_bleu_1 = 0
            running_train_bleu_2 = 0
            running_train_bleu_3 = 0
            running_train_bleu_4 = 0
            total_train_bleu_1 = 0
            total_train_bleu_2 = 0
            total_train_bleu_3 = 0
            total_train_bleu_4 = 0
            total_val_bleu_1 = 0
            total_val_bleu_2 = 0
            total_val_bleu_3 = 0
            total_val_bleu_4 = 0

        network.train()
        for i, batch in enumerate(train_batch_iter, 1):
            optimizer.zero_grad()

            input, target = batch
            prediction, words, attention, fixations = network(input, target)

            loss = loss_fn(
                torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
            )
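
            # Shape note (editorial): torch.stack(prediction) is [T, 1, V];
            # squeeze(1) gives [T, V] step logits, which CrossEntropyLoss
            # scores against the [T] gold index vector target[0] (truncated
            # to the target length by the [: target[0].shape[0], :] slice).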
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
            total_train_loss += loss.item()

            _prediction = " ".join([index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()])
            _target = " ".join([index2word[_x] for _x in target[0].tolist()])

            if bleu == "sacrebleu":
                bleu_score = sacrebleu.sentence_bleu(_prediction, _target).score
                running_train_bleu += bleu_score
                total_train_bleu += bleu_score
            elif bleu == "nltk":
                bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
                bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
                bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
                bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
                running_train_bleu_1 += bleu_1_score
                running_train_bleu_2 += bleu_2_score
                running_train_bleu_3 += bleu_3_score
                running_train_bleu_4 += bleu_4_score
                total_train_bleu_1 += bleu_1_score
                total_train_bleu_2 += bleu_2_score
                total_train_bleu_3 += bleu_3_score
                total_train_bleu_4 += bleu_4_score
                # print(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist())

            if i % 100 == 0:
                if bleu == "sacrebleu":
                    print(f"step {i} avg_train_loss {running_train_loss/100:.4f} avg_train_bleu {running_train_bleu/100:.2f}")
                elif bleu == "nltk":
                    print(f"step {i} avg_train_loss {running_train_loss/100:.4f} avg_train_bleu_1 {running_train_bleu_1/100:.2f} avg_train_bleu_2 {running_train_bleu_2/100:.2f} avg_train_bleu_3 {running_train_bleu_3/100:.2f} avg_train_bleu_4 {running_train_bleu_4/100:.2f}")

                # dump predictions for the fixed debug sentences
                network.eval()
                with open(os.path.join(model_dir, f"debug_{epoch}_{i}.out"), "w") as h:
                    if bleu == "sacrebleu":
                        h.write(f"# avg_train_loss {running_train_loss/100:.4f} avg_train_bleu {running_train_bleu/100:.2f}")
                        running_train_bleu = 0
                    elif bleu == "nltk":
                        h.write(f"# avg_train_loss {running_train_loss/100:.4f} avg_train_bleu_1 {running_train_bleu_1/100:.2f} avg_train_bleu_2 {running_train_bleu_2/100:.2f} avg_train_bleu_3 {running_train_bleu_3/100:.2f} avg_train_bleu_4 {running_train_bleu_4/100:.2f}")
                        running_train_bleu_1 = 0
                        running_train_bleu_2 = 0
                        running_train_bleu_3 = 0
                        running_train_bleu_4 = 0

                    running_train_loss = 0

                    h.write("\n")
                    h.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
                    h.write("\n")
                    for s, input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
                        prediction, words, attentions, fixations = network(input)
                        prediction = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
                        prediction = [index2word.get(x, "<__UNK__>") for x in prediction]
                        attentions = attentions.detach().cpu().squeeze().tolist()
                        fixations = fixations.detach().cpu().squeeze().tolist()
                        h.write(f"{s}\t{prediction}\t{attentions}\t{fixations}")
                        h.write("\n")

                network.train()

        network.eval()
        for i, batch in enumerate(val_batch_iter):
            input, target = batch
            prediction, words, attention, fixations = network(input, target, teacher_forcing_ratio=0)
            loss = loss_fn(
                torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
            )

            _prediction = " ".join([index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()])
            _target = " ".join([index2word[_x] for _x in target[0].tolist()])
            if bleu == "sacrebleu":
                bleu_score = sacrebleu.sentence_bleu(_prediction, _target).score
                total_val_bleu += bleu_score
            elif bleu == "nltk":
                bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
                bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
                bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
                bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
                total_val_bleu_1 += bleu_1_score
                total_val_bleu_2 += bleu_2_score
                total_val_bleu_3 += bleu_3_score
                total_val_bleu_4 += bleu_4_score

            total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_pairs)

        if bleu == "sacrebleu":
            print(f"epoch {epoch} avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu {total_train_bleu/len(train_pairs):.2f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
        elif bleu == "nltk":
            print(f"epoch {epoch} avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu_1 {total_train_bleu_1/len(train_pairs):.2f} avg_train_bleu_2 {total_train_bleu_2/len(train_pairs):.2f} avg_train_bleu_3 {total_train_bleu_3/len(train_pairs):.2f} avg_train_bleu_4 {total_train_bleu_4/len(train_pairs):.2f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")

        with open(os.path.join(model_dir, f"debug_{epoch}_end.out"), "w") as h:
            if bleu == "sacrebleu":
                h.write(f"# avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu {total_train_bleu/len(train_pairs):.2f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
            elif bleu == "nltk":
                h.write(f"# avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu_1 {total_train_bleu_1/len(train_pairs):.2f} avg_train_bleu_2 {total_train_bleu_2/len(train_pairs):.2f} avg_train_bleu_3 {total_train_bleu_3/len(train_pairs):.2f} avg_train_bleu_4 {total_train_bleu_4/len(train_pairs):.2f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")
            h.write("\n")
            h.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
            h.write("\n")
            for s, input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
                prediction, words, attentions, fixations = network(input)
                prediction = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
                prediction = [index2word.get(x, "<__UNK__>") for x in prediction]
                attentions = attentions.detach().cpu().squeeze().tolist()
                fixations = fixations.detach().cpu().squeeze().tolist()
                h.write(f"{s}\t{prediction}\t{attentions}\t{fixations}")
                h.write("\n")

        utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_{epoch}"))

        if best_val_loss is None or avg_val_loss < best_val_loss:
            if best_val_loss is not None:
                logger.info(f"{avg_val_loss} < {best_val_loss} ({avg_val_loss - best_val_loss}): new best model from epoch {epoch}")
            else:
                logger.info(f"{avg_val_loss}: first model from epoch {epoch}")

            best_val_loss = avg_val_loss
            # save_model(model, word2index, model_name + "_epoch_" + str(epoch))
            # utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_best"))

        epoch += 1

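# NOTE (editorial sketch, not part of this commit): the while-loop above runs
# until interrupted; a patience-based stop could be added right after
# avg_val_loss is computed (`patience` and `bad_epochs` are hypothetical):
#
#     if best_val_loss is None or avg_val_loss < best_val_loss:
#         bad_epochs = 0
#     else:
#         bad_epochs += 1
#         if bad_epochs >= patience:
#             break
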
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def val(corpus_name, model_weights, sentence_statistics, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
    val_pairs = corpus["val"]

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint does not contain a word2index mapping")

    network = init_network(word2index)

    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    # make a new output layer to match the weights from the checkpoint;
    # we cannot drop it like the embedding layers because, unlike those,
    # the output layer contains learned parameters
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    loss_fn = nn.CrossEntropyLoss()

    val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    total_val_loss = 0
    if bleu == "sacrebleu":
        total_val_bleu = 0
    elif bleu == "nltk":
        total_val_bleu_1 = 0
        total_val_bleu_2 = 0
        total_val_bleu_3 = 0
        total_val_bleu_4 = 0

    network.eval()
    for i, batch in enumerate(val_batch_iter, 1):
        input, target = batch
        prediction, words, attentions, fixations = network(input, target)

        loss = loss_fn(
            torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
        )
        total_val_loss += loss.item()

        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()]
        _target = [index2word[_x] for _x in target[0].tolist()]
        if bleu == "sacrebleu":
            bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
            total_val_bleu += bleu_score
        elif bleu == "nltk":
            bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
            bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
            bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
            bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
            total_val_bleu_1 += bleu_1_score
            total_val_bleu_2 += bleu_2_score
            total_val_bleu_3 += bleu_3_score
            total_val_bleu_4 += bleu_4_score

        if sentence_statistics:
            s = [index2word[x] for x in input[0].detach().cpu().tolist()]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()

            if bleu == "sacrebleu":
                print(f"{bleu_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")
            elif bleu == "nltk":
                print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")

    if bleu == "sacrebleu":
        print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
    elif bleu == "nltk":
        print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")

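def _load_checkpoint_weights(network, weights):  # NOTE: editorial sketch (hypothetical helper)
    # deduplicates the checkpoint surgery repeated verbatim in val above and
    # in test/predict/predict_file below
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    network.load_state_dict(weights, strict=False)
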
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def test(corpus_name, model_weights, sentence_statistics, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]

    if model_weights is not None:
        logger.info("loading model checkpoint")
        checkpoint = torch.load(model_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
            word2index = checkpoint["word2index"]
            index2word = {i: w for w, i in word2index.items()}
        else:
            raise ValueError("checkpoint does not contain a word2index mapping")

    network = init_network(word2index)

    if model_weights is not None:
        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if "embedding" not in k}
        # make a new output layer to match the weights from the checkpoint;
        # we cannot drop it like the embedding layers because, unlike those,
        # the output layer contains learned parameters
        vocab_size, hidden_size = weights["par_dec.out.weight"].shape
        network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
        # actually load the parameters
        network.load_state_dict(weights, strict=False)

    loss_fn = nn.CrossEntropyLoss()

    test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    total_test_loss = 0
    if bleu == "sacrebleu":
        total_test_bleu = 0
    elif bleu == "nltk":
        total_test_bleu_1 = 0
        total_test_bleu_2 = 0
        total_test_bleu_3 = 0
        total_test_bleu_4 = 0

    network.eval()
    for i, batch in enumerate(test_batch_iter, 1):
        input, target = batch

        prediction, words, attentions, fixations = network(input, target)

        loss = loss_fn(
            torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
        )
        total_test_loss += loss.item()

        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()]
        _target = [index2word[_x] for _x in target[0].tolist()]
        if bleu == "sacrebleu":
            bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
            total_test_bleu += bleu_score
        elif bleu == "nltk":
            bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
            bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
            bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
            bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
            total_test_bleu_1 += bleu_1_score
            total_test_bleu_2 += bleu_2_score
            total_test_bleu_3 += bleu_3_score
            total_test_bleu_4 += bleu_4_score

        if sentence_statistics:
            s = [index2word[x] for x in input[0].detach().cpu().tolist()]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()

            if bleu == "sacrebleu":
                print(f"{bleu_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")
            elif bleu == "nltk":
                print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")

    if bleu == "sacrebleu":
        print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu {total_test_bleu/len(test_pairs):.2f}")
    elif bleu == "nltk":
        print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu_1 {total_test_bleu_1/len(test_pairs):.2f} avg_test_bleu_2 {total_test_bleu_2/len(test_pairs):.2f} avg_test_bleu_3 {total_test_bleu_3/len(test_pairs):.2f} avg_test_bleu_4 {total_test_bleu_4/len(test_pairs):.2f}")

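# NOTE (editorial sketch, not part of this commit): averaging sentence-level
# BLEU as above is not the same as corpus-level BLEU; with collected lists of
# detokenized strings (hypothetical names) sacrebleu can score the whole set:
def _corpus_bleu(all_predictions, all_targets):  # hypothetical helper
    return sacrebleu.corpus_bleu(all_predictions, [all_targets]).score
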
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=False)
def predict(corpus_name, model_weights):
    # predictions are produced for the validation split
    corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
    test_pairs = corpus["val"]

    if model_weights is not None:
        logger.info("loading model checkpoint")
        checkpoint = torch.load(model_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
            word2index = checkpoint["word2index"]
            index2word = {i: w for w, i in word2index.items()}
        else:
            raise ValueError("checkpoint does not contain a word2index mapping")

    network = init_network(word2index)

    logger.info(f"vocab size {len(word2index)}")

    if model_weights is not None:
        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if "embedding" not in k}
        # make a new output layer to match the weights from the checkpoint
        # (see the note in val above)
        vocab_size, hidden_size = weights["par_dec.out.weight"].shape
        network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
        # actually load the parameters
        network.load_state_dict(weights, strict=False)

    test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    network.eval()
    for i, batch in enumerate(test_batch_iter, 1):
        input, target = batch

        prediction, words, attentions, fixations = network(input, target)
        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]

        s = [index2word[x] for x in input[0].detach().cpu().tolist()]
        attentions = attentions.detach().cpu().squeeze().tolist()
        fixations = fixations.detach().cpu().squeeze().tolist()

        print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")

@main.command()
@click.option("-w", "--model_weights", required=True)
@click.argument("path")
def predict_file(model_weights, path):
    logger.info("loading sentences")
    sentences = []
    lang = corpora.Lang("pred")
    with open(path) as h:
        for line in h:
            line = line.strip()
            if line:
                sentence = line.split(" ")
                lang.add_sentence(sentence)
                sentences.append(sentence)
    # fallback vocabulary built from the input file; it is replaced by the
    # checkpoint vocabulary below
    word2index = lang.word2index
    index2word = {i: w for w, i in word2index.items()}

    logger.info(f"{len(sentences)} sentences loaded")

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint does not contain a word2index mapping")

    network = init_network(word2index)

    logger.info(f"vocab size {len(word2index)}")

    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    # make a new output layer to match the weights from the checkpoint
    # (see the note in val above)
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    debug_sentence_iter = utils.sent_iter(sentences, word2index=word2index)

    network.eval()
    for i, input in enumerate(debug_sentence_iter, 1):
        prediction, words, attentions, fixations = network(input)
        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]

        s = [index2word[x] for x in input[0].detach().cpu().tolist()]
        attentions = attentions.detach().cpu().squeeze().tolist()
        fixations = fixations.detach().cpu().squeeze().tolist()

        print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")


if __name__ == "__main__":
    main()

19
joint_paraphrase_model/requirements.txt
Normal file
@@ -0,0 +1,19 @@
click==7.1.2
cycler==0.10.0
dataclasses==0.6
future==0.18.2
joblib==0.17.0
kiwisolver==1.3.1
matplotlib==3.3.3
nltk==3.5
numpy==1.19.4
Pillow==8.0.1
portalocker==2.0.0
pyparsing==2.4.7
python-dateutil==2.8.1
regex==2020.11.13
sacrebleu==1.4.6
six==1.15.0
torch==1.7.0
tqdm==4.54.1
typing-extensions==3.7.4.3
43
joint_paraphrase_model/utils/long_sentence_split.py
Normal file
@@ -0,0 +1,43 @@
import os
import sys

import click


def read(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                b, s, p, a, f = line.split("\t")
            except ValueError:
                print(f"skipping line {line}", file=sys.stderr)
                continue
            else:
                yield b, s, p, a, f


@click.command()
@click.argument("path")
def main(path):
    data = list(read(path))
    # character-based average; sentences longer than this go to the "long" file
    avg_len = sum(len(x[1]) for x in data) / len(data)

    # os.path.splitext keeps the leading dot in ext, so it can be reused as-is
    fname, ext = os.path.splitext(path)
    with open(f"{fname}_long{ext}", "w") as lh, open(f"{fname}_short{ext}", "w") as sh:
        for x in data:
            if len(x[1]) > avg_len:
                lh.write("\t".join(x))
                lh.write("\n")
            else:
                sh.write("\t".join(x))
                sh.write("\n")

    print(f"avg sentence length {avg_len}")


if __name__ == "__main__":
    main()
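
# NOTE (editorial usage sketch): the expected input is the five-column,
# tab-separated output of the test command with -s; the path is hypothetical:
#
#     python utils/long_sentence_split.py models/my_model/test_stats.out
#
# which writes test_stats_long.out and test_stats_short.out next to the input.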
40
joint_paraphrase_model/utils/long_sentence_stats.py
Normal file
@@ -0,0 +1,40 @@
import sys

import click


def read(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                b, s, p, *_ = line.split("\t")
            except ValueError:
                print(f"skipping line {line}", file=sys.stderr)
                continue
            else:
                yield float(b), s, p


@click.command()
@click.argument("path")
def main(path):
    data = list(read(path))
    # character-based average sentence length, used as the long/short cutoff
    avg_len = sum(len(x[1]) for x in data) / len(data)
    long_sentences = []
    short_sentences = []
    for x in data:
        if len(x[1]) > avg_len:
            long_sentences.append(x)
        else:
            short_sentences.append(x)
    print(f"avg sentence length {avg_len}")
    print(f"long sentences {len(long_sentences)}")
    print(f"short sentences {len(short_sentences)}")
    print(f"total bleu {sum(x[0] for x in data)/len(data)}")
    print(f"long bleu {sum(x[0] for x in long_sentences)/len(long_sentences)}")
    print(f"short bleu {sum(x[0] for x in short_sentences)/len(short_sentences)}")


if __name__ == "__main__":
    main()
64
joint_paraphrase_model/utils/plot_attention.py
Normal file
@@ -0,0 +1,64 @@
import ast
import os
import pathlib

import click
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import matplotlib.ticker as ticker
import numpy as np
import tqdm


def plot_attention(input_sentence, output_words, attentions, path):
    # set up the figure with a colorbar; clip attention columns to the input length
    attentions = np.array(attentions)[:, :len(input_sentence)]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions, cmap="bone")
    fig.colorbar(cax)

    # set up axes
    ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # show a label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.savefig(f"{path}.pdf")
    plt.close()


def parse(p):
    with open(p) as h:
        for line in h:
            if not line or line.startswith("#"):
                continue
            _sentence, _prediction, _attention, _fixations = line.strip().split("\t")
            try:
                sentence = ast.literal_eval(_sentence)
                prediction = ast.literal_eval(_prediction)
                attention = ast.literal_eval(_attention)
            except (ValueError, SyntaxError):
                continue

            yield sentence, prediction, attention


@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
    for p in tqdm.tqdm(path):
        out_dir = os.path.splitext(p)[0]
        if out_dir == p:  # the input had no extension; avoid clobbering it
            out_dir = f"{out_dir}_"
        pathlib.Path(out_dir).mkdir(exist_ok=True)
        for i, spa in enumerate(parse(p)):
            plot_attention(*spa, path=os.path.join(out_dir, str(i)))


if __name__ == "__main__":
    main()
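
# NOTE (editorial usage sketch): PATH arguments are the tab-separated debug
# files written during training; the path below is hypothetical:
#
#     python utils/plot_attention.py models/my_model/debug_1_end.out
#
# one heatmap PDF is written per parsed line into models/my_model/debug_1_end/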
70
joint_paraphrase_model/utils/text_attention.py
Executable file
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12


## Convert a text/attention list to LaTeX code, which in turn generates a text heatmap based on attention weights.
import numpy as np

latex_special_token = ["!@#$%^&*()"]


def generate(text_list, attention_list, latex_file, color="red", rescale_value=False):
    assert len(text_list) == len(attention_list)
    if rescale_value:
        attention_list = rescale(attention_list)
    word_num = len(text_list)
    text_list = clean_word(text_list)
    with open(latex_file, "w") as f:
        f.write(r'''\documentclass[varwidth]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}''' + "\n")
        string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{''' + "\n"
        for idx in range(word_num):
            string += "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + text_list[idx] + "} "
        string += "\n}}}"
        f.write(string + "\n")
        f.write(r'''\end{CJK*}
\end{document}''')


def rescale(input_list):
    the_array = np.asarray(input_list)
    the_max = np.max(the_array)
    the_min = np.min(the_array)
    # note: assumes the_max != the_min; uniform weights would divide by zero
    rescale = (the_array - the_min) / (the_max - the_min) * 100
    return rescale.tolist()


def clean_word(word_list):
    # escape LaTeX-sensitive characters so the generated file compiles
    new_word_list = []
    for word in word_list:
        for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
            if latex_sensitive in word:
                word = word.replace(latex_sensitive, "\\" + latex_sensitive)
        new_word_list.append(word)
    return new_word_list


if __name__ == '__main__':
    ## This is a demo:

    sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".
North Korea and the US have ratcheted up tensions in recent weeks and the movement of the strike group had raised the question of a pre-emptive strike by the US.
On Wednesday, Mr Pence described the country as the "most dangerous and urgent threat to peace and security" in the Asia-Pacific.'''
    # the Chinese sample below overwrites the English one (it exercises the
    # CJK path); swap the two assignments to render the English demo instead
    sent = '''我 回忆 起 我 曾经 在 大学 年代 , 我们 经常 喜欢 玩 “ Hawaii guitar ” 。 说起 Guitar , 我 想起 了 西游记 里 的 琵琶精 。
今年 下半年 , 中 美 合拍 的 西游记 即将 正式 开机 , 我 继续 扮演 美猴王 孙悟空 , 我 会 用 美猴王 艺术 形象 努力 创造 一 个 正能量 的 形象 , 文 体 两 开花 , 弘扬 中华 文化 , 希望 大家 能 多多 关注 。'''
    words = sent.split()
    word_num = len(words)
    attention = [(x + 1.) / word_num * 100 for x in range(word_num)]
    import random
    random.seed(42)
    random.shuffle(attention)
    color = "red"
    generate(words, attention, "sample.tex", color)
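
# NOTE (editorial): compiling the generated file requires a LaTeX toolchain
# providing the packages from the preamble above, e.g.:
#
#     pdflatex sample.tex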