Add NLP task models
parent d8beb17dfb · commit 69f6de0ace
46 changed files with 4976 additions and 0 deletions
428 joint_paraphrase_model/.gitignore (vendored, new file)
@@ -0,0 +1,428 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,latex
# Edit at https://www.toptal.com/developers/gitignore?templates=python,latex

### LaTeX ###
## Core latex/pdflatex auxiliary files:
*.aux
*.lof
*.log
*.lot
*.fls
*.out
*.toc
*.fmt
*.fot
*.cb
*.cb2
.*.lb

## Intermediate documents:
*.dvi
*.xdv
*-converted-to.*
# these rules might exclude image files for figures etc.
# *.ps
# *.eps
# *.pdf

## Generated if empty string is given at "Please type another file name for output:"
.pdf

## Bibliography auxiliary files (bibtex/biblatex/biber):
*.bbl
*.bcf
*.blg
*-blx.aux
*-blx.bib
*.run.xml

## Build tool auxiliary files:
*.fdb_latexmk
*.synctex
*.synctex(busy)
*.synctex.gz
*.synctex.gz(busy)
*.pdfsync

## Build tool directories for auxiliary files
# latexrun
latex.out/

## Auxiliary and intermediate files from other packages:
# algorithms
*.alg
*.loa

# achemso
acs-*.bib

# amsthm
*.thm

# beamer
*.nav
*.pre
*.snm
*.vrb

# changes
*.soc

# comment
*.cut

# cprotect
*.cpt

# elsarticle (documentclass of Elsevier journals)
*.spl

# endnotes
*.ent

# fixme
*.lox

# feynmf/feynmp
*.mf
*.mp
*.t[1-9]
*.t[1-9][0-9]
*.tfm

#(r)(e)ledmac/(r)(e)ledpar
*.end
*.?end
*.[1-9]
*.[1-9][0-9]
*.[1-9][0-9][0-9]
*.[1-9]R
*.[1-9][0-9]R
*.[1-9][0-9][0-9]R
*.eledsec[1-9]
*.eledsec[1-9]R
*.eledsec[1-9][0-9]
*.eledsec[1-9][0-9]R
*.eledsec[1-9][0-9][0-9]
*.eledsec[1-9][0-9][0-9]R

# glossaries
*.acn
*.acr
*.glg
*.glo
*.gls
*.glsdefs
*.lzo
*.lzs

# uncomment this for glossaries-extra (will ignore makeindex's style files!)
# *.ist

# gnuplottex
*-gnuplottex-*

# gregoriotex
*.gaux
*.gtex

# htlatex
*.4ct
*.4tc
*.idv
*.lg
*.trc
*.xref

# hyperref
*.brf

# knitr
*-concordance.tex
# TODO Comment the next line if you want to keep your tikz graphics files
*.tikz
*-tikzDictionary

# listings
*.lol

# luatexja-ruby
*.ltjruby

# makeidx
*.idx
*.ilg
*.ind

# minitoc
*.maf
*.mlf
*.mlt
*.mtc[0-9]*
*.slf[0-9]*
*.slt[0-9]*
*.stc[0-9]*

# minted
_minted*
*.pyg

# morewrites
*.mw

# nomencl
*.nlg
*.nlo
*.nls

# pax
*.pax

# pdfpcnotes
*.pdfpc

# sagetex
*.sagetex.sage
*.sagetex.py
*.sagetex.scmd

# scrwfile
*.wrt

# sympy
*.sout
*.sympy
sympy-plots-for-*.tex/

# pdfcomment
*.upa
*.upb

# pythontex
*.pytxcode
pythontex-files-*/

# tcolorbox
*.listing

# thmtools
*.loe

# TikZ & PGF
*.dpth
*.md5
*.auxlock

# todonotes
*.tdo

# vhistory
*.hst
*.ver

# easy-todo
*.lod

# xcolor
*.xcp

# xmpincl
*.xmpi

# xindy
*.xdy

# xypic precompiled matrices and outlines
*.xyc
*.xyd

# endfloat
*.ttt
*.fff

# Latexian
TSWLatexianTemp*

## Editors:
# WinEdt
*.bak
*.sav

# Texpad
.texpadtmp

# LyX
*.lyx~

# Kile
*.backup

# gummi
.*.swp

# KBibTeX
*~[0-9]*

# TeXnicCenter
*.tps

# auto folder when using emacs and auctex
./auto/*
*.el

# expex forward references with \gathertags
*-tags.tex

# standalone packages
*.sta

# Makeindex log files
*.lpz

# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib
# option is specified. Footnotes are then stored in a file with suffix Notes.bib.
# Uncomment the next line to have this generated file ignored.
#*Notes.bib

### LaTeX Patch ###
# LIPIcs / OASIcs
*.vtc

# glossaries
*.glstex

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# End of https://www.toptal.com/developers/gitignore/api/python,latex
3 joint_paraphrase_model/README.md (new file)
@@ -0,0 +1,3 @@
# joint_paraphrase_model

joint training paraphrase model --- NeurIPS
112 joint_paraphrase_model/config.py (new file)
@@ -0,0 +1,112 @@
import os
import torch


# general
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD = "<__PAD__>"
UNK = "<__UNK__>"
NOFIX = "<__NOFIX__>"
SOS = "<__SOS__>"
EOS = "<__EOS__>"

batch_size = 1
teacher_forcing_ratio = 0.5
embedding_dim = 300
fix_hidden_dim = 128
par_hidden_dim = 1024
fix_dropout = 0.5
par_dropout = 0.2
_fix_learning_rate = 0.00001
_par_learning_rate = 0.0001
learning_rate = _par_learning_rate
fix_momentum = 0.9
par_momentum = 0.0
max_length = 121
epochs = 5

# paths
data_path = "./data"
provo_predictability_path = os.path.join(
    data_path, "datasets/provo/Provo_Corpus-Predictability_Norms.csv"
)
provo_eyetracking_path = os.path.join(
    data_path, "datasets/provo/Provo_Corpus-Eyetracking_Data.csv"
)

geco_en_path = os.path.join(data_path, "datasets/geco/EnglishMaterial.csv")
geco_mono_path = os.path.join(data_path, "datasets/geco/MonolingualReadingData.csv")

movieqa_human_path = os.path.join(data_path, "datasets/all_word_scores_fixations")
movieqa_human_path_2 = os.path.join(
    data_path, "datasets/all_word_scores_fixations_exp2"
)
movieqa_human_path_3 = os.path.join(
    data_path, "datasets/all_word_scores_fixations_exp3"
)
movieqa_split_plot_path = os.path.join(data_path, "datasets/split_plot_UNRESOLVED")

cnn_path = os.path.join(
    data_path,
    "projects/2019/fixation_prediction/ez-reader-wrapper/predictability/output_cnn/",
)
dm_path = os.path.join(
    data_path,
    "projects/2019/fixation_prediction/ez-reader-wrapper/predictability/output_dm/",
)

qqp_paws_basedir = os.path.join(data_path, "datasets/paw_google/qqp/paws_qqp/output")
qqp_paws_train_path = os.path.join(qqp_paws_basedir, "train.tsv")
qqp_paws_dev_path = os.path.join(qqp_paws_basedir, "dev.tsv")
qqp_paws_test_path = os.path.join(qqp_paws_basedir, "test.tsv")

qqp_basedir = os.path.join(data_path, "datasets/Quora_question_pair_partition_OG")
qqp_train_path = os.path.join(qqp_basedir, "train.tsv")
qqp_dev_path = os.path.join(qqp_basedir, "dev.tsv")
qqp_test_path = os.path.join(qqp_basedir, "test.tsv")

qqp_kag_basedir = os.path.join(data_path, "datasets/Quora_question_pair_partition_kag")
qqp_kag_train_path = os.path.join(qqp_kag_basedir, "train.tsv")
qqp_kag_dev_path = os.path.join(qqp_kag_basedir, "dev.tsv")
qqp_kag_test_path = os.path.join(qqp_kag_basedir, "test.tsv")

wiki_basedir = os.path.join(data_path, "datasets/paw_google/wiki")
wiki_train_path = os.path.join(wiki_basedir, "train.tsv")
wiki_dev_path = os.path.join(wiki_basedir, "dev.tsv")
wiki_test_path = os.path.join(wiki_basedir, "test.tsv")

msrpc_basedir = os.path.join(data_path, "datasets/MSRPC")
msrpc_train_path = os.path.join(msrpc_basedir, "msr_paraphrase_train.txt")
msrpc_dev_path = os.path.join(msrpc_basedir, "msr_paraphrase_dev.txt")
msrpc_test_path = os.path.join(msrpc_basedir, "msr_paraphrase_test.txt")

sentiment_basedir = os.path.join(data_path, "datasets/sentiment_kag")
sentiment_train_path = os.path.join(sentiment_basedir, "train.tsv")
sentiment_dev_path = os.path.join(sentiment_basedir, "dev.tsv")
sentiment_test_path = os.path.join(sentiment_basedir, "test.tsv")

tamil_basedir = os.path.join(data_path, "datasets/en-ta-parallel-v2")
tamil_train_path = os.path.join(tamil_basedir, "corpus.bcn.train.enta")
tamil_dev_path = os.path.join(tamil_basedir, "corpus.bcn.dev.enta")
tamil_test_path = os.path.join(tamil_basedir, "corpus.bcn.test.enta")

compression_basedir = os.path.join(data_path, "datasets/sentence-compression/data")
compression_train_path = os.path.join(compression_basedir, "train.tsv")
compression_dev_path = os.path.join(compression_basedir, "dev.tsv")
compression_test_path = os.path.join(compression_basedir, "test.tsv")

stanford_basedir = os.path.join(data_path, "datasets/stanfordSentimentTreebank")
stanford_train_path = os.path.join(stanford_basedir, "train.tsv")
stanford_dev_path = os.path.join(stanford_basedir, "dev.tsv")
stanford_test_path = os.path.join(stanford_basedir, "test.tsv")

stanford_sent_basedir = os.path.join(data_path, "datasets/stanfordSentimentTreebank")
stanford_sent_train_path = os.path.join(stanford_basedir, "train_sent.tsv")
stanford_sent_dev_path = os.path.join(stanford_basedir, "dev_sent.tsv")
stanford_sent_test_path = os.path.join(stanford_basedir, "test_sent.tsv")


emb_path = os.path.join(data_path, "Google_word2vec/GoogleNews-vectors-negative300.bin")

glove_path = "glove.840B.300d.txt"
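Aside (not part of this commit): downstream modules read these settings with a plain import; a minimal sketch, assuming the package root is on sys.path:

import config

print(config.DEV)                 # cuda if a GPU is visible, else cpu
print(config.learning_rate)       # 0.0001, aliased from _par_learning_rate
print(config.qqp_kag_train_path)  # ./data/datasets/Quora_question_pair_partition_kag/train.tsv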
1 joint_paraphrase_model/data (new symbolic link)
@@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/
1 joint_paraphrase_model/glove.840B.300d.txt (new symbolic link)
@@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/datasets/glove/glove.840B.300d.txt
1 joint_paraphrase_model/glove.cache (new file)
File diff suppressed because one or more lines are too long.
0 joint_paraphrase_model/libs/__init__.py (new, empty file)
416 joint_paraphrase_model/libs/corpora.py (new file)
@@ -0,0 +1,416 @@
import logging

import config


def tokenize(sent):
    return sent.split(" ")


class Lang:
    """Represents the vocabulary."""

    def __init__(self, name):
        self.name = name
        self.word2index = {
            config.PAD: 0,
            config.UNK: 1,
            config.NOFIX: 2,
            config.SOS: 3,
            config.EOS: 4,
        }
        self.word2count = {}
        self.index2word = {
            0: config.PAD,
            1: config.UNK,
            2: config.NOFIX,
            3: config.SOS,
            4: config.EOS,
        }
        self.n_words = 5

    def add_sentence(self, sentence):
        assert isinstance(
            sentence, (list, tuple)
        ), "input to add_sentence must be tokenized"
        for word in sentence:
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def __add__(self, other):
        """Returns a new Lang object containing the vocabulary from this and
        the other Lang object.
        """
        new_lang = Lang(f"{self.name}_{other.name}")

        # Add vocabulary from both Langs
        for word in self.word2count.keys():
            new_lang.add_word(word)
        for word in other.word2count.keys():
            new_lang.add_word(word)

        # Fix the counts on the new one
        for word in new_lang.word2count.keys():
            new_lang.word2count[word] = self.word2count.get(
                word, 0
            ) + other.word2count.get(word, 0)

        return new_lang


def load_wiki(split):
    """Load the Wiki pairs from PAWS"""
    logger = logging.getLogger(f"{__name__}.load_wiki")
    lang = Lang("wiki")

    if split == "train":
        path = config.wiki_train_path
    elif split == "val":
        path = config.wiki_dev_path
    elif split == "test":
        path = config.wiki_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            _, sent1, sent2, rating = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang


def load_qqp_paws(split):
    """Load the QQP pairs from PAWS"""
    logger = logging.getLogger(f"{__name__}.load_qqp_paws")
    lang = Lang("qqp_paws")

    if split == "train":
        path = config.qqp_paws_train_path
    elif split == "val":
        path = config.qqp_paws_dev_path
    elif split == "test":
        path = config.qqp_paws_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            _, sent1, sent2, rating = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang


def load_qqp(split):
    """Load the QQP from the original partition"""
    logger = logging.getLogger(f"{__name__}.load_qqp")
    lang = Lang("qqp")

    if split == "train":
        path = config.qqp_train_path
    elif split == "val":
        path = config.qqp_dev_path
    elif split == "test":
        path = config.qqp_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            rating, sent1, sent2, _ = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang


def load_qqp_kag(split):
    """Load the QQP from Kaggle"""  # not the original partition right now; experimenting with the Kaggle 100K/3K/30K split
    logger = logging.getLogger(f"{__name__}.load_qqp_kag")
    lang = Lang("qqp_kag")

    if split == "train":
        path = config.qqp_kag_train_path
    elif split == "val":
        path = config.qqp_kag_dev_path
    elif split == "test":
        path = config.qqp_kag_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:  # the Kaggle version has 3 fields rather than 4
            rating, sent1, sent2 = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # MS makes the vocab for paraphrase the same
    return pairs, lang


def load_msrpc(split):
    """Load the Microsoft Research Paraphrase Corpus (MSRPC)"""
    logger = logging.getLogger(f"{__name__}.load_msrpc")
    lang = Lang("msrpc")

    if split == "train":
        path = config.msrpc_train_path
    elif split == "val":
        path = config.msrpc_dev_path
    elif split == "test":
        path = config.msrpc_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            rating, _, _, sent1, sent2 = line.strip().split("\t")
            if rating == "0":
                continue
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    # return src_lang, dst_lang, pairs
    # MS makes the vocab for paraphrase the same

    return pairs, lang


def load_sentiment(split):
    """Load the sentiment dataset from the Kaggle competition"""
    logger = logging.getLogger(f"{__name__}.load_sentiment")
    lang = Lang("sentiment")

    if split == "train":
        path = config.sentiment_train_path
    elif split == "val":
        path = config.sentiment_dev_path
    elif split == "test":
        path = config.sentiment_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []

    with open(path) as handle:

        # skip header
        handle.readline()

        for line in handle:
            _, _, sent1, sent2 = line.strip().split("\t")

            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    return pairs, lang


def load_tamil(split):
    """Load the English-to-Tamil dataset (current SOTA is roughly 13 BLEU)"""
    logger = logging.getLogger(f"{__name__}.load_tamil")
    lang = Lang("tamil")

    if split == "train":
        path = config.tamil_train_path
    elif split == "val":
        path = config.tamil_dev_path
    elif split == "test":
        path = config.tamil_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        handle.readline()

        for line in handle:
            sent1, sent2 = line.strip().split("\t")
            # if rating == "0":
            #     continue
            sent1 = tokenize(sent1)
            # TODO: whitespace splitting is a placeholder; unclear how to tokenize Tamil properly
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            pairs.append([sent1, sent2])

    return pairs, lang


def load_compression(split):
    """Load the Google Sentence Compression Dataset"""
    logger = logging.getLogger(f"{__name__}.load_compression")
    lang = Lang("compression")

    if split == "train":
        path = config.compression_train_path
    elif split == "val":
        path = config.compression_dev_path
    elif split == "test":
        path = config.compression_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []
    with open(path) as handle:

        handle.readline()

        for line in handle:
            sent1, sent2 = line.strip().split("\t")
            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            # print(len(sent1), sent1)
            # print(len(sent2), sent2)
            # print()
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            pairs.append([sent1, sent2])

    return pairs, lang


def load_stanford(split):
    """Load the Stanford Sentiment Dataset phrases"""
    logger = logging.getLogger(f"{__name__}.load_stanford")
    lang = Lang("stanford")

    if split == "train":
        path = config.stanford_train_path
    elif split == "val":
        path = config.stanford_dev_path
    elif split == "test":
        path = config.stanford_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []

    with open(path) as handle:

        # skip header
        # handle.readline()

        for line in handle:
            _, _, sent1, sent2 = line.strip().split("\t")

            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    return pairs, lang


def load_stanford_sent(split):
    """Load the Stanford Sentiment Dataset sentences"""
    logger = logging.getLogger(f"{__name__}.load_stanford_sent")
    lang = Lang("stanford_sent")

    if split == "train":
        path = config.stanford_sent_train_path
    elif split == "val":
        path = config.stanford_sent_dev_path
    elif split == "test":
        path = config.stanford_sent_test_path

    logger.info("loading %s from %s" % (split, path))

    pairs = []

    with open(path) as handle:

        # skip header
        # handle.readline()

        for line in handle:
            _, _, sent1, sent2 = line.strip().split("\t")

            sent1 = tokenize(sent1)
            sent2 = tokenize(sent2)
            lang.add_sentence(sent1)
            lang.add_sentence(sent2)

            # pairs.append([sent1, sent2, rating])
            pairs.append([sent1, sent2])

    return pairs, lang
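Aside (not part of this commit): every loader above returns the same (pairs, lang) shape, so corpora can be swapped freely downstream. A minimal sketch, assuming the symlinked ./data tree is in place:

from libs import corpora

train_pairs, train_lang = corpora.load_qqp_kag("train")
val_pairs, val_lang = corpora.load_qqp_kag("val")

# Lang.__add__ unions both vocabularies and sums the per-word counts.
vocab = train_lang + val_lang
print(vocab.n_words, len(train_pairs))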
1 joint_paraphrase_model/libs/fixation_generation/__init__.py (new file)
@@ -0,0 +1 @@
from .main import *
131 joint_paraphrase_model/libs/fixation_generation/main.py (new file)
@@ -0,0 +1,131 @@
from collections import OrderedDict
import logging
import sys

from .self_attention import Transformer

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence


def random_embedding(vocab_size, embedding_dim):
    pretrain_emb = np.empty([vocab_size, embedding_dim])
    scale = np.sqrt(3.0 / embedding_dim)
    for index in range(vocab_size):
        pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
    return pretrain_emb


def neg_log_likelihood_loss(outputs, batch_label, batch_size, seq_len):
    outputs = outputs.view(batch_size * seq_len, -1)
    score = F.log_softmax(outputs, 1)

    loss = nn.NLLLoss(ignore_index=0, reduction="sum")(
        score, batch_label.view(batch_size * seq_len)
    )
    loss = loss / batch_size
    _, tag_seq = torch.max(score, 1)
    tag_seq = tag_seq.view(batch_size, seq_len)

    # print(score[0], tag_seq[0])

    return loss, tag_seq


def mse_loss(outputs, batch_label, batch_size, seq_len, word_seq_length):
    # score = torch.nn.functional.softmax(outputs, 1)
    score = torch.sigmoid(outputs)

    # mask out the padded positions so they do not contribute to the loss
    mask = torch.zeros_like(score)
    for i, v in enumerate(word_seq_length):
        mask[i, 0:v] = 1

    score = score * mask

    loss = nn.MSELoss(reduction="sum")(
        score.view(batch_size, seq_len), batch_label.view(batch_size, seq_len)
    )

    loss = loss / batch_size

    return loss, score.view(batch_size, seq_len)


class Network(nn.Module):
    def __init__(
        self,
        embedding_type,
        vocab_size,
        embedding_dim,
        dropout,
        hidden_dim,
        embeddings=None,
        attention=True,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        prelayers = OrderedDict()
        postlayers = OrderedDict()

        if embedding_type in ("w2v", "glove"):
            if embeddings is not None:
                prelayers["embedding_layer"] = nn.Embedding.from_pretrained(embeddings)
            else:
                prelayers["embedding_layer"] = nn.Embedding(vocab_size, embedding_dim)
            prelayers["embedding_dropout_layer"] = nn.Dropout(dropout)
            embedding_dim = 300
        elif embedding_type == "bert":
            embedding_dim = 768

        self.lstm = BiLSTM(embedding_dim, hidden_dim // 2, num_layers=1)
        postlayers["lstm_dropout_layer"] = nn.Dropout(dropout)

        if attention:
            # increased complexity with 1024D and 16 heads/16 layers: for the no-att and att experiments
            # before: for the initial att and pretraining: heads 4 and layers 4, 128D
            # then was 128D with heads 4, layers 1 = results for all IUI
            ### postlayers["position_encodings"] = PositionalEncoding(hidden_dim)
            postlayers["attention_layer"] = Transformer(
                d_model=hidden_dim, n_heads=4, n_layers=1
            )

        postlayers["ff_layer"] = nn.Linear(hidden_dim, hidden_dim // 2)
        postlayers["ff_activation"] = nn.ReLU()
        postlayers["output_layer"] = nn.Linear(hidden_dim // 2, 1)

        self.logger.info(f"prelayers: {prelayers.keys()}")
        self.logger.info(f"postlayers: {postlayers.keys()}")

        self.pre = nn.Sequential(prelayers)
        self.post = nn.Sequential(postlayers)

    def forward(self, x, word_seq_length):
        x = self.pre(x)
        x = self.lstm(x, word_seq_length)
        # MS: printing fix model params
        # for p in self.parameters():
        #     print(p.data)
        #     break

        return self.post(x.transpose(1, 0))


class BiLSTM(nn.Module):
    def __init__(self, embedding_dim, lstm_hidden, num_layers):
        super().__init__()
        self.net = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_hidden,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x, word_seq_length):
        packed_words = pack_padded_sequence(x, word_seq_length, True, False)
        lstm_out, hidden = self.net(packed_words)
        lstm_out, _ = pad_packed_sequence(lstm_out)
        return lstm_out
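Aside (not part of this commit): a minimal sketch of one forward pass through the fixation network above, assuming a toy vocabulary of 1000 words and no pretrained embedding matrix:

import torch
from libs.fixation_generation import Network

net = Network(
    embedding_type="glove",
    vocab_size=1000,
    embedding_dim=300,
    dropout=0.5,
    hidden_dim=128,
)
tokens = torch.randint(0, 1000, (1, 12))  # (batch, seq_len) word indices
scores = net(tokens, [12])                # one unnormalized fixation score per token
print(scores.shape)                       # torch.Size([1, 12, 1])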
131 joint_paraphrase_model/libs/fixation_generation/self_attention.py (new file)
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()


class AttentionLayer(nn.Module):
    def __init__(self):
        super(AttentionLayer, self).__init__()

    def forward(self, Q, K, V):
        # Q: float32[batch_size, n_queries, d_k]
        # K: float32[batch_size, n_keys, d_k]
        # V: float32[batch_size, n_keys, d_v]
        dk = K.shape[-1]
        dv = V.shape[-1]
        KT = torch.transpose(K, -1, -2)
        weight_logits = torch.bmm(Q, KT) / math.sqrt(dk)
        # weight_logits: float32[batch_size, n_queries, n_keys]
        weights = F.softmax(weight_logits, dim=-1)
        # weights: float32[batch_size, n_queries, n_keys]
        return torch.bmm(weights, V)
        # returns float32[batch_size, n_queries, d_v]


class MultiHeadedSelfAttentionLayer(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadedSelfAttentionLayer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        print('{} {}'.format(d_model, n_heads))
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.d_v = self.d_k
        self.attention_layer = AttentionLayer()
        self.W_Qs = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Ks = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Vs = nn.ModuleList([
            nn.Linear(d_model, self.d_v, bias=False)
            for _ in range(n_heads)
        ])
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        head_outputs = []
        for W_Q, W_K, W_V in zip(self.W_Qs, self.W_Ks, self.W_Vs):
            Q = W_Q(x)
            # Q: float32[batch_size, sequence_length, self.d_k]
            K = W_K(x)
            # K: float32[batch_size, sequence_length, self.d_k]
            V = W_V(x)
            # V: float32[batch_size, sequence_length, self.d_v]
            head_output = self.attention_layer(Q, K, V)
            # head_output: float32[batch_size, sequence_length, self.d_v]
            head_outputs.append(head_output)
        concatenated = torch.cat(head_outputs, dim=-1)
        # concatenated: float32[batch_size, sequence_length, self.d_model]
        out = self.W_O(concatenated)
        # out: float32[batch_size, sequence_length, self.d_model]
        return out


class Feedforward(nn.Module):
    def __init__(self, d_model):
        super(Feedforward, self).__init__()
        self.d_model = d_model
        self.W1 = nn.Linear(d_model, d_model)
        self.W2 = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, d_model]
        return self.W2(torch.relu(self.W1(x)))


class Transformer(nn.Module):
    def __init__(self, d_model, n_heads, n_layers):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.attention_layers = nn.ModuleList([
            MultiHeadedSelfAttentionLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        self.ffs = nn.ModuleList([
            Feedforward(d_model)
            for _ in range(n_layers)
        ])

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        for attention_layer, ff in zip(self.attention_layers, self.ffs):
            attention_out = attention_layer(x)
            # attention_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + attention_out, x.shape[2:])
            ff_out = ff(x)
            # ff_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + ff_out, x.shape[2:])
        return x
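Aside (not part of this commit): the transformer block above is shape-preserving, which is what lets the fixation model drop it into an nn.Sequential; a minimal sketch:

import torch
from libs.fixation_generation.self_attention import Transformer

block = Transformer(d_model=128, n_heads=4, n_layers=1)
x = torch.randn(2, 10, 128)  # (batch, seq_len, d_model)
print(block(x).shape)        # torch.Size([2, 10, 128])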
1 joint_paraphrase_model/libs/paraphrase_generation/__init__.py (new file)
@@ -0,0 +1 @@
from .main import *
86 joint_paraphrase_model/libs/paraphrase_generation/main.py (new file)
@@ -0,0 +1,86 @@
import json
import math
import os

import random
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, embeddings):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding.from_pretrained(embeddings)
        self.gru = nn.GRU(input_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)


class AttnDecoderRNN(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        embeddings,
        dropout_p,
        max_length,
    ):
        super(AttnDecoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding.from_pretrained(embeddings)  # for paraphrase generation
        # self.embedding = nn.Embedding(len(embeddings), 300)  # for NMT with Tamil; trying with sentiment too
        self.attn = nn.Linear(self.input_size + self.hidden_size, self.max_length)
        self.attn_combine = nn.Linear(
            self.input_size + self.hidden_size, self.hidden_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs, fixations):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1
        )

        # scale the attention weights by the (zero-padded) fixation scores
        attn_weights = attn_weights * torch.nn.ConstantPad1d((0, attn_weights.shape[-1] - fixations.shape[-2]), 0)(fixations.squeeze().unsqueeze(0))

        # attn_weights = torch.softmax(attn_weights * torch.nn.ConstantPad1d((0, attn_weights.shape[-1] - fixations.shape[-2]), 0)(fixations.squeeze().unsqueeze(0)), dim=1)

        attn_applied = torch.bmm(
            attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0)
        )

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        # output = F.log_softmax(self.out(output[0]), dim=1)
        output = self.out(output[0])
        # output = F.log_softmax(output, dim=1)
        return output, hidden, attn_weights
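Aside (not part of this commit): the encoder is driven one token at a time, as main.py does below; a minimal sketch with a stand-in embedding matrix:

import torch
from libs.paraphrase_generation import EncoderRNN

embeddings = torch.randn(100, 300)  # stand-in for the real GloVe matrix
enc = EncoderRNN(input_size=300, hidden_size=1024, embeddings=embeddings)
hidden = enc.initHidden()
out, hidden = enc(torch.tensor([[7]]), hidden)  # a single token index
print(out.shape, hidden.shape)  # torch.Size([1, 1, 1024]) twice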
225 joint_paraphrase_model/libs/utils.py (new file)
@@ -0,0 +1,225 @@
import json
import logging
import math
import os
import random
import re
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from nltk.translate.bleu_score import sentence_bleu
import numpy as np
import torch
import torch.nn as nn

import config


plt.switch_backend("agg")


def load_glove(vocabulary):
    logger = logging.getLogger(f"{__name__}.load_glove")
    logger.info("loading embeddings")
    try:
        with open("glove.cache") as h:
            cache = json.load(h)
    except FileNotFoundError:
        logger.info("cache doesn't exist")
        cache = {}
        cache[config.PAD] = [0] * 300
        cache[config.SOS] = [0] * 300
        cache[config.EOS] = [0] * 300
        cache[config.UNK] = [0] * 300
        cache[config.NOFIX] = [0] * 300
    else:
        logger.info("cache found")

    cache_miss = False

    if not set(vocabulary) <= set(cache):
        cache_miss = True
        logger.warning("cache miss, loading full embeddings")
        data = {}
        with open("glove.840B.300d.txt") as h:
            for line in h:
                word, *emb = line.strip().split()
                try:
                    data[word] = [float(x) for x in emb]
                except ValueError:
                    continue
        logger.info("finished loading full embeddings")
        for word in vocabulary:
            try:
                cache[word] = data[word]
            except KeyError:
                cache[word] = [0] * 300
        logger.info("cache updated")

    embeddings = []
    for word in vocabulary:
        embeddings.append(torch.tensor(cache[word], dtype=torch.float32))
    embeddings = torch.stack(embeddings)

    if cache_miss:
        with open("glove.cache", "w") as h:
            json.dump(cache, h)
        logger.info("cache saved")

    return embeddings


def tokenize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = s.split(" ")
    return s


def indices_from_sentence(word2index, sentence, unknown_threshold):
    if unknown_threshold:
        return [
            word2index.get(
                word if random.random() > unknown_threshold else config.UNK,
                word2index[config.UNK],
            )
            for word in sentence
        ]
    else:
        return [
            word2index.get(word, word2index[config.UNK]) for word in sentence
        ]


def tensor_from_sentence(word2index, sentence, unknown_threshold):
    # indices = [config.SOS]
    indices = indices_from_sentence(word2index, sentence, unknown_threshold)
    indices.append(word2index[config.EOS])
    return torch.tensor(indices, dtype=torch.long, device=config.DEV)


def tensors_from_pair(word2index, pair, shuffle, unknown_threshold):
    tensors = [
        tensor_from_sentence(word2index, pair[0], unknown_threshold),
        tensor_from_sentence(word2index, pair[1], unknown_threshold),
    ]
    if shuffle:
        random.shuffle(tensors)
    return tensors


def bleu(reference, hypothesis, n=4):  # not sure whether this actually changes the n-gram order
    if n < 1:
        return 0
    weights = [1 / n] * n
    return sentence_bleu([reference], hypothesis, weights)


def pair_iter(pairs, word2index, shuffle=False, shuffle_pairs=False, unknown_threshold=0.00):
    if shuffle:
        pairs = pairs.copy()
        random.shuffle(pairs)
    for pair in pairs:
        tensor1, tensor2 = tensors_from_pair(word2index, (pair[0], pair[1]), shuffle_pairs, unknown_threshold)
        yield (tensor1,), (tensor2,)


def sent_iter(sents, word2index, unknown_threshold=0.00):
    for sent in sents:
        tensor = tensor_from_sentence(word2index, sent, unknown_threshold)
        yield (tensor,)


def batch_iter(pairs, word2index, batch_size, shuffle=False, unknown_threshold=0.00):
    for i in range(len(pairs) // batch_size):
        # fixed: was pairs[i : i + batch_size], which yielded overlapping batches
        batch = pairs[i * batch_size : (i + 1) * batch_size]
        if len(batch) != batch_size:
            continue
        batch_tensors = [
            tensors_from_pair(word2index, (pair[0], pair[1]), shuffle, unknown_threshold)
            for pair in batch
        ]

        tensors1, tensors2 = zip(*batch_tensors)

        # targets = torch.tensor(targets, dtype=torch.long, device=config.DEV)

        # tensors1_lengths = [len(t) for t in tensors1]
        # tensors2_lengths = [len(t) for t in tensors2]

        # tensors1 = nn.utils.rnn.pack_sequence(tensors1, enforce_sorted=False)
        # tensors2 = nn.utils.rnn.pack_sequence(tensors2, enforce_sorted=False)

        yield tensors1, tensors2


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return "%s (- %s)" % (asMinutes(s), asMinutes(rs))


def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)


def showAttention(input_sentence, output_words, attentions):
    # Set up figure with colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence.split(" ") + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    # NOTE: assumes evaluate, encoder1 and attn_decoder1 exist in the calling scope
    output_words, attentions = evaluate(encoder1, attn_decoder1, input_sentence)
    print("input =", input_sentence)
    print("output =", " ".join(output_words))
    showAttention(input_sentence, output_words, attentions)


def save_model(model, word2index, path):
    if not path.endswith(".tar"):
        path += ".tar"
    torch.save(
        {"weights": model.state_dict(), "word2index": word2index},
        path,
    )


def load_model(path):
    checkpoint = torch.load(path)
    return checkpoint["weights"], checkpoint["word2index"]


def extend_vocabulary(word2index, langs):
    for lang in langs:
        for word in lang.word2index:
            if word not in word2index:
                word2index[word] = len(word2index)
    return word2index
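Aside (not part of this commit): a minimal sketch of the iteration and scoring helpers above, using a toy vocabulary whose special-token indices mirror corpora.Lang:

from libs import utils

word2index = {"<__PAD__>": 0, "<__UNK__>": 1, "<__NOFIX__>": 2,
              "<__SOS__>": 3, "<__EOS__>": 4, "hello": 5, "there": 6}
pairs = [(["hello", "there"], ["there", "hello"])]

for (src,), (tgt,) in utils.pair_iter(pairs, word2index):
    print(src, tgt)  # EOS-terminated index tensors on config.DEV

print(utils.bleu(["a", "b", "c", "d"], ["a", "b", "c", "d"]))  # 1.0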
794
joint_paraphrase_model/main.py
Normal file
794
joint_paraphrase_model/main.py
Normal file
|
@ -0,0 +1,794 @@
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import click
|
||||||
|
import sacrebleu
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
import config
|
||||||
|
from libs import corpora
|
||||||
|
from libs import utils
|
||||||
|
from libs.fixation_generation import Network as FixNN
|
||||||
|
from libs.paraphrase_generation import (
|
||||||
|
EncoderRNN as ParEncNN,
|
||||||
|
AttnDecoderRNN as ParDecNN,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
cwd = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
logger = logging.getLogger("main")
|
||||||
|
|
||||||
|
'''
|
||||||
|
#qqp_paw sentences:
|
||||||
|
|
||||||
|
debug_sentences = [
|
||||||
|
'What are the driving rules in Georgia versus Mississippi ?',
|
||||||
|
'I want to be a child psychologist , what qualification do i need to become one ? Are there good and reputed psychology Institute or Colleges in India ?',
|
||||||
|
'What is deep web and dark web and what are the contents of these sites ?',
|
||||||
|
'What is difference between North Indian Brahmins and South Indian Brahmins ?',
|
||||||
|
    'Is carbon dioxide an ionic bond or a covalent bond ?',
    'How do accounts receivable and accounts payable differ ?',
    'Why did Wikipedia hide its audit history for ( superluminal ) successful speed experiments ?',
    'What makes a person simple , or inversely , complicated ?',
    '`` How do you say `` Miss you , too , `` in Spanish ? Are there multiple ways to say it ? ‘’',
    'What is the difference between dominant trait and recessive trait ?’',
    '`` What is the difference between `` seeing someone , `` `` dating someone , `` and `` having a girlfriend/boyfriend `` ? ‘’',
    'How was the Empire State building built and designed ? How is it used ?',
    'What is the sum of the square roots of first n natural number ?',
    'Why is Roman Saini not so active on Quora now a days ?',
    'If I have someone blocked on Instagram , and see their story , can they view I viewed it ?',
    'Amongst the major IT companies of India which is the best ; Wipro , Capgemini , Infosys , TCS or is Oracle the best ?',
    'How much mass does Saturn lose each year ? How much mass does it gain ?',
    'What is a cheap healthy diet , I can keep the same and eat every day ?',
    ' What is it like to be beautiful ? Not just pretty or hot , but the kind of almost objective beauty that people are sometimes intimidated by ?',
    'Could someone tell of a James Ronsey ( misspelled likely ) , writer and filmmaker , probably of the British Isles ?',
    'How much pressure Is there around the core of Pluto ? is it enough to turn hydrogen/helium gas into a liquid or metallic state ?',
    'How does quality of life in Vancouver compare to that in Melbourne Or Brisbane ?',
]
'''

'''
#wiki sentences:

debug_sentences = [
    'They were there to enjoy us and they were there to pray for us .',
    'Components of elastic potential systems store mechanical energy if they are deformed when forces are applied to the system .',
    'Steam can also be used , and does not need to be pumped .',
    'The solar approach to this requirement is the use of solar panels in a conventional-powered aircraft .',
    'Daudkhali is a village in Barisal Division in the Pirojpur district in southwestern Bangladesh .',
    'Briggs later met Briggs at the 1967 Monterey Pop Festival , where Ravi Shankar was also performing , with Eric Burdon and The Animals .',
    'Brockton is approximately 25 miles northeast of Providence , Rhode Island , and 30 miles south of Boston .',
]
'''

#qqp sentences:

debug_sentences = [
    'How do I get funding for my web based startup idea ?',
    'What do intelligent people do to pass time ?',
    'Which is the best SEO Company in Delhi ?',
    'Why do you waer makeup ?',
    'How do start chatting with a girl ?',
    'What is the meaning of living life ?',
    'Why do my armpits hurt ?',
    'Why does eye color change with age ?',
    'How do you find the standard deviation of a probability distribution ? What are some examples ?',
    'How can I complete my 11 syllabus in one month ?',
    'How do I concentrate better on my studies ?',
    'Which is the best retirement plan in india ?',
    'Should I tell my best friend I love her ?',
    'Which is the best company for Appian Vagrant online job support ?',
    'How can one do for good handwriting ?',
    'What are remedies to get rid of belly fat ?',
    'What is the best way to cook precooked turkey ?',
    'What is the future of e-commerce in India ?',
    'Why do my burps taste like rotten eggs ?',
    'What is an example of chemical weathering ?',
    'What are some of the advantages and disadvantages of cyber schooling ?',
    'How can I increase traffic to my websites by Facebook ?',
    'How do I increase my patience level in life ?',
    'What are the best hospitals for treating cancer in India ?',
    'Will Jio sim work in a 3G phone ? If yes , how ?',
]

debug_sentences = [s.split(" ") for s in debug_sentences]
class Network(nn.Module):
    def __init__(
        self,
        word2index,
        embeddings,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.word2index = word2index
        self.index2word = {i: k for k, i in word2index.items()}
        # fixation generator: predicts a fixation weight for every input token
        self.fix_gen = FixNN(
            embedding_type="glove",
            vocab_size=len(word2index),
            embedding_dim=config.embedding_dim,
            embeddings=embeddings,
            dropout=config.fix_dropout,
            hidden_dim=config.fix_hidden_dim,
        )
        # paraphrase encoder/decoder pair
        self.par_enc = ParEncNN(
            input_size=config.embedding_dim,
            hidden_size=config.par_hidden_dim,
            embeddings=embeddings,
        )
        self.par_dec = ParDecNN(
            input_size=config.embedding_dim,
            hidden_size=config.par_hidden_dim,
            output_size=len(word2index),
            embeddings=embeddings,
            dropout_p=config.par_dropout,
            max_length=config.max_length,
        )

    def forward(self, x, target=None, teacher_forcing_ratio=None):
        teacher_forcing_ratio = teacher_forcing_ratio if teacher_forcing_ratio is not None else config.teacher_forcing_ratio
        x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
        x2 = nn.utils.rnn.pad_sequence(x, batch_first=False)
        fixations = torch.sigmoid(self.fix_gen(x1, [len(_x) for _x in x1]))

        enc_hidden = self.par_enc.initHidden().to(config.DEV)
        enc_outs = torch.zeros(config.max_length, config.par_hidden_dim, device=config.DEV)

        # encode the input sequence token by token
        for ei in range(len(x2)):
            enc_out, enc_hidden = self.par_enc(x2[ei], enc_hidden)
            enc_outs[ei] += enc_out[0, 0]

        dec_in = torch.tensor([[self.word2index[config.SOS]]], device=config.DEV)  # SOS
        dec_hidden = enc_hidden
        dec_outs = []
        dec_words = []
        dec_atts = torch.zeros(config.max_length, config.max_length)

        if target is not None:  # training
            target = nn.utils.rnn.pad_sequence(target, batch_first=False)
            use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

            if use_teacher_forcing:
                for di in range(len(target)):
                    dec_out, dec_hidden, dec_att = self.par_dec(
                        dec_in, dec_hidden, enc_outs, fixations
                    )
                    dec_outs.append(dec_out)
                    dec_atts[di] = dec_att.data
                    # teacher forcing: feed the ground-truth token to the next step
                    dec_in = target[di]

            else:
                for di in range(len(target)):
                    dec_out, dec_hidden, dec_att = self.par_dec(
                        dec_in, dec_hidden, enc_outs, fixations
                    )
                    dec_outs.append(dec_out)
                    dec_atts[di] = dec_att.data
                    topv, topi = dec_out.data.topk(1)
                    dec_words.append(self.index2word[topi.item()])

                    # feed the model's own prediction to the next step
                    dec_in = topi.squeeze().detach()

        else:  # prediction
            for di in range(config.max_length):
                dec_out, dec_hidden, dec_att = self.par_dec(
                    dec_in, dec_hidden, enc_outs, fixations
                )
                dec_outs.append(dec_out)
                dec_atts[di] = dec_att.data
                topv, topi = dec_out.data.topk(1)
                if topi.item() == self.word2index[config.EOS]:
                    dec_words.append("<__EOS__>")
                    break
                else:
                    dec_words.append(self.index2word[topi.item()])

                dec_in = topi.squeeze().detach()

        return dec_outs, dec_words, dec_atts[: di + 1], fixations
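# Illustrative usage sketch (not part of the original training code; it assumes
# `word2index` and `embeddings` are obtained as in `load_corpus`/`init_network`
# below, and the UNK fallback is an assumption about out-of-vocabulary tokens):
#
#   network = Network(word2index=word2index, embeddings=embeddings).to(config.DEV)
#   tokens = torch.tensor(
#       [word2index.get(w, word2index[config.UNK]) for w in "how are you ?".split(" ")],
#       device=config.DEV,
#   )
#   dec_outs, dec_words, dec_atts, fixations = network([tokens])  # no target: inference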
def load_corpus(corpus_name, splits):
    if not splits:
        return

    logger.info("loading corpus")
    if corpus_name == "msrpc":
        load_fn = corpora.load_msrpc
    elif corpus_name == "qqp":
        load_fn = corpora.load_qqp
    elif corpus_name == "wiki":
        load_fn = corpora.load_wiki
    elif corpus_name == "qqp_paws":
        load_fn = corpora.load_qqp_paws
    elif corpus_name == "qqp_kag":
        load_fn = corpora.load_qqp_kag
    elif corpus_name == "sentiment":
        load_fn = corpora.load_sentiment
    elif corpus_name == "stanford":
        load_fn = corpora.load_stanford
    elif corpus_name == "stanford_sent":
        load_fn = corpora.load_stanford_sent
    elif corpus_name == "tamil":
        load_fn = corpora.load_tamil
    elif corpus_name == "compression":
        load_fn = corpora.load_compression

    corpus = {}
    langs = []

    if "train" in splits:
        train_pairs, train_lang = load_fn("train")
        corpus["train"] = train_pairs
        langs.append(train_lang)
    if "val" in splits:
        val_pairs, val_lang = load_fn("val")
        corpus["val"] = val_pairs
        langs.append(val_lang)
    if "test" in splits:
        test_pairs, test_lang = load_fn("test")
        corpus["test"] = test_pairs
        langs.append(test_lang)

    logger.info("creating word index")
    lang = langs[0]
    for _lang in langs[1:]:
        lang += _lang
    word2index = lang.word2index

    index2word = {i: w for w, i in word2index.items()}

    return corpus, word2index, index2word
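# Side note (sketch only, same behaviour as the if/elif chain above): the
# corpus dispatch could also be written as a lookup table, which fails loudly
# on unknown names:
#
#   load_fn = {
#       "msrpc": corpora.load_msrpc,
#       "qqp": corpora.load_qqp,
#       "wiki": corpora.load_wiki,
#       # ... remaining corpora as above ...
#   }[corpus_name]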
def init_network(word2index):
    logger.info("loading embeddings")
    vocabulary = sorted(word2index.keys())
    embeddings = utils.load_glove(vocabulary)

    logger.info("initializing model")
    network = Network(
        word2index=word2index,
        embeddings=embeddings,
    )
    network.to(config.DEV)

    print(f"#parameters: {sum(p.numel() for p in network.parameters())}")

    return network
@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-v", "--verbose", count=True)
@click.option("-d", "--debug", is_flag=True)
def main(verbose, debug):
    if verbose == 0:
        loglevel = logging.ERROR
    elif verbose == 1:
        loglevel = logging.WARN
    elif verbose >= 2:
        loglevel = logging.INFO

    if debug:
        loglevel = logging.DEBUG

    logging.basicConfig(
        format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
        datefmt="%d.%m. %H:%M:%S",
        level=loglevel,
    )

    logger.debug("arguments: %s" % str(sys.argv))
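# Example invocations (script, corpus, and model names are placeholders):
#   python main.py -v train -c qqp -m my_model -b nltk
#   python main.py -v val -c qqp -w models/my_model/my_model_1 -b sacrebleu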
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-m", "--model_name", required=True)
@click.option("-w", "--fixation_weights", required=False)
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def train(corpus_name, model_name, fixation_weights, freeze_fixations, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
    train_pairs = corpus["train"]
    val_pairs = corpus["val"]
    network = init_network(word2index)

    model_dir = os.path.join("models", model_name)
    logger.debug("creating model dir %s" % model_dir)
    pathlib.Path(model_dir).mkdir(parents=True)

    if fixation_weights is not None:
        logger.info("loading fixation prediction checkpoint")
        checkpoint = torch.load(fixation_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
        else:
            weights = checkpoint

        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")}
        network.fix_gen.load_state_dict(weights, strict=False)

    if freeze_fixations:
        logger.info("freezing fixation generation network")
        for p in network.fix_gen.parameters():
            p.requires_grad = False

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)
    #optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=1e-5)

    best_val_loss = None

    epoch = 1
    while True:
        train_batch_iter = utils.pair_iter(pairs=train_pairs, word2index=word2index, shuffle=True, shuffle_pairs=False)
        val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
        # test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

        running_train_loss = 0
        total_train_loss = 0
        total_val_loss = 0

        if bleu == "sacrebleu":
            running_train_bleu = 0
            total_train_bleu = 0
            total_val_bleu = 0
        elif bleu == "nltk":
            running_train_bleu_1 = 0
            running_train_bleu_2 = 0
            running_train_bleu_3 = 0
            running_train_bleu_4 = 0
            total_train_bleu_1 = 0
            total_train_bleu_2 = 0
            total_train_bleu_3 = 0
            total_train_bleu_4 = 0
            total_val_bleu_1 = 0
            total_val_bleu_2 = 0
            total_val_bleu_3 = 0
            total_val_bleu_4 = 0

        network.train()
        for i, batch in enumerate(train_batch_iter, 1):
            optimizer.zero_grad()

            input, target = batch
            prediction, words, attention, fixations = network(input, target)

            loss = loss_fn(
                torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
            )

            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
            total_train_loss += loss.item()

            _prediction = " ".join([index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()])
            _target = " ".join([index2word[_x] for _x in target[0].tolist()])

            if bleu == "sacrebleu":
                bleu_score = sacrebleu.sentence_bleu(_prediction, _target).score
                running_train_bleu += bleu_score
                total_train_bleu += bleu_score
            elif bleu == "nltk":
                bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
                bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
                bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
                bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
                running_train_bleu_1 += bleu_1_score
                running_train_bleu_2 += bleu_2_score
                running_train_bleu_3 += bleu_3_score
                running_train_bleu_4 += bleu_4_score
                total_train_bleu_1 += bleu_1_score
                total_train_bleu_2 += bleu_2_score
                total_train_bleu_3 += bleu_3_score
                total_train_bleu_4 += bleu_4_score
                # print(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist())

            if i % 100 == 0:
                if bleu == "sacrebleu":
                    print(f"step {i} avg_train_loss {running_train_loss/100:.4f} avg_train_bleu {running_train_bleu/100:.2f}")
                elif bleu == "nltk":
                    print(f"step {i} avg_train_loss {running_train_loss/100:.4f} avg_train_bleu_1 {running_train_bleu_1/100:.2f} avg_train_bleu_2 {running_train_bleu_2/100:.2f} avg_train_bleu_3 {running_train_bleu_3/100:.2f} avg_train_bleu_4 {running_train_bleu_4/100:.2f}")

                # periodically dump predictions for the debug sentences
                network.eval()
                with open(os.path.join(model_dir, f"debug_{epoch}_{i}.out"), "w") as h:
                    if bleu == "sacrebleu":
                        h.write(f"# avg_train_loss {running_train_loss/100:.4f} avg_train_bleu {running_train_bleu/100:.2f}")
                        running_train_bleu = 0
                    elif bleu == "nltk":
                        h.write(f"# avg_train_loss {running_train_loss/100:.4f} avg_train_bleu_1 {running_train_bleu_1/100:.2f} avg_train_bleu_2 {running_train_bleu_2/100:.2f} avg_train_bleu_3 {running_train_bleu_3/100:.2f} avg_train_bleu_4 {running_train_bleu_4/100:.2f}")
                        running_train_bleu_1 = 0
                        running_train_bleu_2 = 0
                        running_train_bleu_3 = 0
                        running_train_bleu_4 = 0

                    running_train_loss = 0

                    h.write("\n")
                    h.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
                    h.write("\n")
                    for s, input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
                        prediction, words, attentions, fixations = network(input)
                        prediction = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
                        prediction = [index2word.get(x, "<__UNK__>") for x in prediction]
                        attentions = attentions.detach().cpu().squeeze().tolist()
                        fixations = fixations.detach().cpu().squeeze().tolist()
                        h.write(f"{s}\t{prediction}\t{attentions}\t{fixations}")
                        h.write("\n")

                network.train()

        network.eval()
        for i, batch in enumerate(val_batch_iter):
            input, target = batch
            prediction, words, attention, fixations = network(input, target, teacher_forcing_ratio=0)
            loss = loss_fn(
                torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
            )

            _prediction = " ".join([index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()])
            _target = " ".join([index2word[_x] for _x in target[0].tolist()])
            if bleu == "sacrebleu":
                bleu_score = sacrebleu.sentence_bleu(_prediction, _target).score
                total_val_bleu += bleu_score
            elif bleu == "nltk":
                bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
                bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
                bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
                bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
                total_val_bleu_1 += bleu_1_score
                total_val_bleu_2 += bleu_2_score
                total_val_bleu_3 += bleu_3_score
                total_val_bleu_4 += bleu_4_score

            total_val_loss += loss.item()

        avg_val_loss = total_val_loss/len(val_pairs)

        if bleu == "sacrebleu":
            print(f"epoch {epoch} avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu {total_train_bleu/len(train_pairs):.2f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
        elif bleu == "nltk":
            print(f"epoch {epoch} avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu_1 {total_train_bleu_1/len(train_pairs):.2f} avg_train_bleu_2 {total_train_bleu_2/len(train_pairs):.2f} avg_train_bleu_3 {total_train_bleu_3/len(train_pairs):.2f} avg_train_bleu_4 {total_train_bleu_4/len(train_pairs):.2f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")

        with open(os.path.join(model_dir, f"debug_{epoch}_end.out"), "w") as h:
            if bleu == "sacrebleu":
                h.write(f"# avg_train_loss {total_train_loss/len(train_pairs)} avg_val_loss {total_val_loss/len(val_pairs)} avg_train_bleu {total_train_bleu/len(train_pairs)} avg_val_bleu {total_val_bleu/len(val_pairs)}")
            elif bleu == "nltk":
                h.write(f"# avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu_1 {total_train_bleu_1/len(train_pairs):.2f} avg_train_bleu_2 {total_train_bleu_2/len(train_pairs):.2f} avg_train_bleu_3 {total_train_bleu_3/len(train_pairs):.2f} avg_train_bleu_4 {total_train_bleu_4/len(train_pairs):.2f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")
            h.write("\n")
            h.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
            h.write("\n")
            for s, input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
                prediction, words, attentions, fixations = network(input)
                prediction = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
                prediction = [index2word.get(x, "<__UNK__>") for x in prediction]
                attentions = attentions.detach().cpu().squeeze().tolist()
                fixations = fixations.detach().cpu().squeeze().tolist()
                h.write(f"{s}\t{prediction}\t{attentions}\t{fixations}")
                h.write("\n")

        utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_{epoch}"))

        if best_val_loss is None or avg_val_loss < best_val_loss:
            if best_val_loss is not None:
                logger.info(f"{avg_val_loss} < {best_val_loss} ({avg_val_loss-best_val_loss}): new best model from epoch {epoch}")
            else:
                logger.info(f"{avg_val_loss}: first validation loss, best model from epoch {epoch}")

            best_val_loss = avg_val_loss
            # save_model(model, word2index, model_name + "_epoch_" + str(epoch))
            # utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_best"))

        epoch += 1
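# Hypothetical example: warm-start the fixation network from a pretrained
# checkpoint (the path is a placeholder) and keep it frozen while training:
#   python main.py train -c qqp -m qqp_frozen_fix -w fix_checkpoint.pt -f -b nltk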
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def val(corpus_name, model_weights, sentence_statistics, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
    val_pairs = corpus["val"]

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint does not contain a word2index vocabulary")

    network = init_network(word2index)

    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if not "embedding" in k}
    # make a new output layer to match the weights from the checkpoint
    # we cannot remove it like we did with the embedding layers because
    # unlike those the output layer actually contains learned parameters
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    loss_fn = nn.CrossEntropyLoss()

    val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    total_val_loss = 0
    if bleu == "sacrebleu":
        total_val_bleu = 0
    elif bleu == "nltk":
        total_val_bleu_1 = 0
        total_val_bleu_2 = 0
        total_val_bleu_3 = 0
        total_val_bleu_4 = 0

    network.eval()
    for i, batch in enumerate(val_batch_iter, 1):
        input, target = batch
        prediction, words, attentions, fixations = network(input, target)

        loss = loss_fn(
            torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
        )
        total_val_loss += loss.item()

        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()]
        _target = [index2word[_x] for _x in target[0].tolist()]
        if bleu == "sacrebleu":
            bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
            total_val_bleu += bleu_score
        elif bleu == "nltk":
            bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
            bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
            bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
            bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
            total_val_bleu_1 += bleu_1_score
            total_val_bleu_2 += bleu_2_score
            total_val_bleu_3 += bleu_3_score
            total_val_bleu_4 += bleu_4_score

        if sentence_statistics:
            s = [index2word[x] for x in input[0].detach().cpu().tolist()]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()

            if bleu == "sacrebleu":
                print(f"{bleu_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")
            elif bleu == "nltk":
                print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")

    if bleu == "sacrebleu":
        print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
    elif bleu == "nltk":
        print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def test(corpus_name, model_weights, sentence_statistics, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]

    if model_weights is not None:
        logger.info("loading model checkpoint")
        checkpoint = torch.load(model_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
            word2index = checkpoint["word2index"]
            index2word = {i: w for w, i in word2index.items()}
        else:
            raise ValueError("checkpoint does not contain a word2index vocabulary")

    network = init_network(word2index)

    if model_weights is not None:
        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if not "embedding" in k}
        # make a new output layer to match the weights from the checkpoint
        # we cannot remove it like we did with the embedding layers because
        # unlike those the output layer actually contains learned parameters
        vocab_size, hidden_size = weights["par_dec.out.weight"].shape
        network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
        # actually load the parameters
        network.load_state_dict(weights, strict=False)

    loss_fn = nn.CrossEntropyLoss()

    test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    total_test_loss = 0
    if bleu == "sacrebleu":
        total_test_bleu = 0
    elif bleu == "nltk":
        total_test_bleu_1 = 0
        total_test_bleu_2 = 0
        total_test_bleu_3 = 0
        total_test_bleu_4 = 0

    network.eval()
    for i, batch in enumerate(test_batch_iter, 1):
        input, target = batch

        prediction, words, attentions, fixations = network(input, target)

        loss = loss_fn(
            torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
        )
        total_test_loss += loss.item()

        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()]
        _target = [index2word[_x] for _x in target[0].tolist()]
        if bleu == "sacrebleu":
            bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
            total_test_bleu += bleu_score
        elif bleu == "nltk":
            bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
            bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
            bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
            bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
            total_test_bleu_1 += bleu_1_score
            total_test_bleu_2 += bleu_2_score
            total_test_bleu_3 += bleu_3_score
            total_test_bleu_4 += bleu_4_score

        if sentence_statistics:
            s = [index2word[x] for x in input[0].detach().cpu().tolist()]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()

            if bleu == "sacrebleu":
                print(f"{bleu_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")
            elif bleu == "nltk":
                print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")

    if bleu == "sacrebleu":
        print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu {total_test_bleu/len(test_pairs):.2f}")
    elif bleu == "nltk":
        print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu_1 {total_test_bleu_1/len(test_pairs):.2f} avg_test_bleu_2 {total_test_bleu_2/len(test_pairs):.2f} avg_test_bleu_3 {total_test_bleu_3/len(test_pairs):.2f} avg_test_bleu_4 {total_test_bleu_4/len(test_pairs):.2f}")
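# Note: with -s/--sentence_statistics, `test` prints one tab-separated row per
# sentence (for sacrebleu: bleu, sentence, prediction, attentions, fixations).
# Redirected to a file, this appears to be the format the helper scripts under
# utils/ below consume.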
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=False)
def predict(corpus_name, model_weights):
    corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
    test_pairs = corpus["val"]

    if model_weights is not None:
        logger.info("loading model checkpoint")
        checkpoint = torch.load(model_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
            word2index = checkpoint["word2index"]
            index2word = {i: w for w, i in word2index.items()}
        else:
            raise ValueError("checkpoint does not contain a word2index vocabulary")

    network = init_network(word2index)

    logger.info(f"vocab size {len(word2index)}")

    if model_weights is not None:
        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if not "embedding" in k}
        # make a new output layer to match the weights from the checkpoint
        # we cannot remove it like we did with the embedding layers because
        # unlike those the output layer actually contains learned parameters
        vocab_size, hidden_size = weights["par_dec.out.weight"].shape
        network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
        # actually load the parameters
        network.load_state_dict(weights, strict=False)

    test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    network.eval()
    for i, batch in enumerate(test_batch_iter, 1):
        input, target = batch

        prediction, words, attentions, fixations = network(input, target)
        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]

        s = [index2word[x] for x in input[0].detach().cpu().tolist()]
        attentions = attentions.detach().cpu().squeeze().tolist()
        fixations = fixations.detach().cpu().squeeze().tolist()

        print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")
@main.command()
@click.option("-w", "--model_weights", required=True)
@click.argument("path")
def predict_file(model_weights, path):
    logger.info("loading sentences")
    sentences = []
    lang = corpora.Lang("pred")
    with open(path) as h:
        for line in h:
            line = line.strip()
            if line:
                sentence = line.split(" ")
                lang.add_sentence(sentence)
                sentences.append(sentence)
    word2index = lang.word2index
    index2word = {i: w for w, i in word2index.items()}

    logger.info(f"{len(sentences)} sentences loaded")

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint does not contain a word2index vocabulary")

    network = init_network(word2index)

    logger.info(f"vocab size {len(word2index)}")

    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if not "embedding" in k}
    # make a new output layer to match the weights from the checkpoint
    # we cannot remove it like we did with the embedding layers because
    # unlike those the output layer actually contains learned parameters
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    debug_sentence_iter = utils.sent_iter(sentences, word2index=word2index)

    network.eval()
    for i, input in enumerate(debug_sentence_iter, 1):
        prediction, words, attentions, fixations = network(input)
        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]

        s = [index2word[x] for x in input[0].detach().cpu().tolist()]
        attentions = attentions.detach().cpu().squeeze().tolist()
        fixations = fixations.detach().cpu().squeeze().tolist()

        print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")


if __name__ == "__main__":
    main()
19
joint_paraphrase_model/requirements.txt
Normal file

@@ -0,0 +1,19 @@
click==7.1.2
cycler==0.10.0
dataclasses==0.6
future==0.18.2
joblib==0.17.0
kiwisolver==1.3.1
matplotlib==3.3.3
nltk==3.5
numpy==1.19.4
Pillow==8.0.1
portalocker==2.0.0
pyparsing==2.4.7
python-dateutil==2.8.1
regex==2020.11.13
sacrebleu==1.4.6
six==1.15.0
torch==1.7.0
tqdm==4.54.1
typing-extensions==3.7.4.3
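# The usual pip workflow applies (not part of the original file):
#   pip install -r requirements.txt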
43
joint_paraphrase_model/utils/long_sentence_split.py
Normal file

@@ -0,0 +1,43 @@
import os
import sys

import click


def read(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                b, s, p, a, f = line.split("\t")
            except ValueError:
                print(f"skipping line {line}", file=sys.stderr)
                continue
            else:
                yield b, s, p, a, f


@click.command()
@click.argument("path")
def main(path):
    data = list(read(path))
    avg_len = sum(len(x[1]) for x in data)/len(data)

    fname, ext = os.path.splitext(path)
    with open(f"{fname}_long{ext}", "w") as lh, open(f"{fname}_short{ext}", "w") as sh:
        for x in data:
            if len(x[1]) > avg_len:
                lh.write("\t".join(x))
                lh.write("\n")
            else:
                sh.write("\t".join(x))
                sh.write("\n")

    print(f"avg sentence length {avg_len}")


if __name__ == "__main__":
    main()
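# Example (assuming stats.tsv holds the 5-column `test -s` output redirected
# to a file):
#   python long_sentence_split.py stats.tsv
# which writes stats_long.tsv and stats_short.tsv, split at the average
# sentence length.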
40
joint_paraphrase_model/utils/long_sentence_stats.py
Normal file

@@ -0,0 +1,40 @@
import sys

import click


def read(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                b, s, p, *_ = line.split("\t")
            except ValueError:
                print(f"skipping line {line}", file=sys.stderr)
                continue
            else:
                yield float(b), s, p


@click.command()
@click.argument("path")
def main(path):
    data = list(read(path))
    avg_len = sum(len(x[1]) for x in data)/len(data)
    filtered_data = []
    filtered_data2 = []
    for x in data:
        if len(x[1]) > avg_len:
            filtered_data.append(x)
        else:
            filtered_data2.append(x)
    print(f"avg sentence length {avg_len}")
    print(f"long sentences {len(filtered_data)}")
    print(f"short sentences {len(filtered_data2)}")
    print(f"total bleu {sum(x[0] for x in data)/len(data)}")
    print(f"longest bleu {sum(x[0] for x in filtered_data)/len(filtered_data)}")
    print(f"shortest bleu {sum(x[0] for x in filtered_data2)/len(filtered_data2)}")


if __name__ == "__main__":
    main()
64
joint_paraphrase_model/utils/plot_attention.py
Normal file

@@ -0,0 +1,64 @@
import ast
import os
import pathlib

import text_attention

import click
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import matplotlib.ticker as ticker
import numpy as np
import tqdm


def plot_attention(input_sentence, output_words, attentions, path):
    # Set up figure with colorbar
    attentions = np.array(attentions)[:, :len(input_sentence)]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions, cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.savefig(f"{path}.pdf")
    plt.close()


def parse(p):
    with open(p) as h:
        for line in h:
            if not line or line.startswith("#"):
                continue
            _sentence, _prediction, _attention, _fixations = line.strip().split("\t")
            try:
                sentence = ast.literal_eval(_sentence)
                prediction = ast.literal_eval(_prediction)
                attention = ast.literal_eval(_attention)
            except (ValueError, SyntaxError):
                continue

            yield sentence, prediction, attention


@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
    for p in tqdm.tqdm(path):
        out_dir = os.path.splitext(p)[0]
        if out_dir == p:
            out_dir = f"{out_dir}_"
        pathlib.Path(out_dir).mkdir(exist_ok=True)
        for i, spa in enumerate(parse(p)):
            plot_attention(*spa, path=os.path.join(out_dir, str(i)))


if __name__ == "__main__":
    main()
70
joint_paraphrase_model/utils/text_attention.py
Executable file

@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12


## convert the text/attention list to latex code, which will further generate the text heatmap based on attention weights.
import numpy as np

latex_special_token = ["!@#$%^&*()"]


def generate(text_list, attention_list, latex_file, color='red', rescale_value=False):
    assert(len(text_list) == len(attention_list))
    if rescale_value:
        attention_list = rescale(attention_list)
    word_num = len(text_list)
    text_list = clean_word(text_list)
    with open(latex_file, 'w') as f:
        f.write(r'''\documentclass[varwidth]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}''' + '\n')
        string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{''' + "\n"
        for idx in range(word_num):
            string += "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + text_list[idx] + "} "
        string += "\n}}}"
        f.write(string + '\n')
        f.write(r'''\end{CJK*}
\end{document}''')


def rescale(input_list):
    the_array = np.asarray(input_list)
    the_max = np.max(the_array)
    the_min = np.min(the_array)
    rescale = (the_array - the_min)/(the_max - the_min)*100
    return rescale.tolist()


def clean_word(word_list):
    new_word_list = []
    for word in word_list:
        for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
            if latex_sensitive in word:
                word = word.replace(latex_sensitive, '\\' + latex_sensitive)
        new_word_list.append(word)
    return new_word_list


if __name__ == '__main__':
    ## This is a demo:

    sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".
North Korea and the US have ratcheted up tensions in recent weeks and the movement of the strike group had raised the question of a pre-emptive strike by the US.
On Wednesday, Mr Pence described the country as the "most dangerous and urgent threat to peace and security" in the Asia-Pacific.'''
    sent = '''我 回忆 起 我 曾经 在 大学 年代 , 我们 经常 喜欢 玩 “ Hawaii guitar ” 。 说起 Guitar , 我 想起 了 西游记 里 的 琵琶精 。
今年 下半年 , 中 美 合拍 的 西游记 即将 正式 开机 , 我 继续 扮演 美猴王 孙悟空 , 我 会 用 美猴王 艺术 形象 努力 创造 一 个 正能量 的 形象 , 文 体 两 开花 , 弘扬 中华 文化 , 希望 大家 能 多多 关注 。'''
    words = sent.split()
    word_num = len(words)
    attention = [(x + 1.)/word_num*100 for x in range(word_num)]
    import random
    random.seed(42)
    random.shuffle(attention)
    color = 'red'
    generate(words, attention, "sample.tex", color)
428
joint_sentence_compression_model/.gitignore
vendored
Normal file
3
joint_sentence_compression_model/README.md
Normal file

@@ -0,0 +1,3 @@
# joint_sentence_compression

joint training for sentence compression -- NeurIPS submission
39
joint_sentence_compression_model/config.py
Normal file

@@ -0,0 +1,39 @@
import os
import torch


# general
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD = "<__PAD__>"
UNK = "<__UNK__>"
NOFIX = "<__NOFIX__>"
SOS = "<__SOS__>"
EOS = "<__EOS__>"

batch_size = 1
teacher_forcing_ratio = 0.5
embedding_dim = 300
fix_hidden_dim = 128
sem_hidden_dim = 1024
fix_dropout = 0.5
par_dropout = 0.2
_fix_learning_rate = 0.00001
_par_learning_rate = 0.0001
learning_rate = _par_learning_rate
fix_momentum = 0.9
par_momentum = 0.0
max_length = 851
epochs = 5

# paths
data_path = "./data"

emb_path = os.path.join(data_path, "Google_word2vec/GoogleNews-vectors-negative300.bin")

glove_path = "glove.840B.300d.txt"

google_path = os.path.join(data_path, "datasets/sentence-compression/data")
google_train_path = os.path.join(google_path, "train_mask_token.tsv")
google_dev_path = os.path.join(google_path, "dev_mask_token.tsv")
google_test_path = os.path.join(google_path, "test_mask_token.tsv")
1
joint_sentence_compression_model/data
Symbolic link

@@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/
1
joint_sentence_compression_model/glove.840B.300d.txt
Symbolic link

@@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/datasets/glove/glove.840B.300d.txt
0
joint_sentence_compression_model/libs/__init__.py
Normal file
97
joint_sentence_compression_model/libs/corpora.py
Normal file
@@ -0,0 +1,97 @@
import logging

import config


def tokenize(sent):
    return sent.split(" ")


class Lang:
    """Represents the vocabulary"""

    def __init__(self, name):
        self.name = name
        self.word2index = {
            config.PAD: 0,
            config.UNK: 1,
        }
        self.word2count = {}
        self.index2word = {
            0: config.PAD,
            1: config.UNK,
        }
        self.n_words = 2

    def add_sentence(self, sentence):
        assert isinstance(
            sentence, (list, tuple)
        ), "input to add_sentence must be tokenized"
        for word in sentence:
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def __add__(self, other):
        """Returns a new Lang object containing the vocabulary from this and
        the other Lang object
        """
        new_lang = Lang(f"{self.name}_{other.name}")

        # Add vocabulary from both Langs
        for word in self.word2count.keys():
            new_lang.add_word(word)
        for word in other.word2count.keys():
            new_lang.add_word(word)

        # Fix the counts on the new one
        for word in new_lang.word2count.keys():
            new_lang.word2count[word] = self.word2count.get(
                word, 0
            ) + other.word2count.get(word, 0)

        return new_lang


def load_google(split, max_len=None):
    """Load the Google Sentence Compression Dataset"""
    logger = logging.getLogger(f"{__name__}.load_compression")
    lang = Lang("compression")

    if split == "train":
        path = config.google_train_path
    elif split == "val":
        path = config.google_dev_path
    elif split == "test":
        path = config.google_test_path
    else:
        raise ValueError(f"unknown split: {split}")

    logger.info("loading %s from %s" % (split, path))

    data = []
    sent = []
    mask = []
    with open(path) as handle:
        for line in handle:
            line = line.strip()
            if line:
                w, d = line.split("\t")
                sent.append(w)
                mask.append(int(d))
            else:
                if sent and (max_len is None or len(sent) <= max_len):
                    data.append([sent, mask])
                    lang.add_sentence(sent)
                sent = []
                mask = []
    if sent:
        data.append([sent, mask])
        lang.add_sentence(sent)

    return data, lang
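A minimal usage sketch for `Lang` and `load_google` follows; the variable names are illustrative, and the `config.google_*` paths are assumed to exist.

```python
from libs.corpora import load_google

train_data, train_lang = load_google("train", max_len=200)
val_data, val_lang = load_google("val")

merged = train_lang + val_lang      # Lang.__add__ merges vocabularies and sums counts
sent, mask = train_data[0]          # token list and per-token 0/1 deletion mask
indices = [merged.word2index[w] for w in sent]
```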
@@ -0,0 +1 @@
from .main import *
@@ -0,0 +1,125 @@
from collections import OrderedDict
import logging
import sys

from .self_attention import Transformer

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence, pad_sequence


def random_embedding(vocab_size, embedding_dim):
    pretrain_emb = np.empty([vocab_size, embedding_dim])
    scale = np.sqrt(3.0 / embedding_dim)
    for index in range(vocab_size):
        pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
    return pretrain_emb


def neg_log_likelihood_loss(outputs, batch_label, batch_size, seq_len):
    outputs = outputs.view(batch_size * seq_len, -1)
    score = F.log_softmax(outputs, 1)

    # reduction="sum" replaces the deprecated size_average=False
    loss = nn.NLLLoss(ignore_index=0, reduction="sum")(
        score, batch_label.view(batch_size * seq_len)
    )
    loss = loss / batch_size
    _, tag_seq = torch.max(score, 1)
    tag_seq = tag_seq.view(batch_size, seq_len)

    return loss, tag_seq


def mse_loss(outputs, batch_label, batch_size, seq_len, word_seq_length):
    score = torch.sigmoid(outputs)

    # mask out padding positions beyond each sequence's true length
    mask = torch.zeros_like(score)
    for i, v in enumerate(word_seq_length):
        mask[i, 0:v] = 1

    score = score * mask

    loss = nn.MSELoss(reduction="sum")(
        score.view(batch_size, seq_len), batch_label.view(batch_size, seq_len)
    )

    loss = loss / batch_size

    return loss, score.view(batch_size, seq_len)


class Network(nn.Module):
    def __init__(
        self,
        embedding_type,
        vocab_size,
        embedding_dim,
        dropout,
        hidden_dim,
        embeddings=None,
        attention=True,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.attention = attention
        prelayers = OrderedDict()
        postlayers = OrderedDict()

        if embedding_type in ("w2v", "glove"):
            if embeddings is not None:
                prelayers["embedding_layer"] = nn.Embedding.from_pretrained(embeddings, freeze=True)
            else:
                prelayers["embedding_layer"] = nn.Embedding(vocab_size, embedding_dim)
            prelayers["embedding_dropout_layer"] = nn.Dropout(dropout)
            embedding_dim = 300
        elif embedding_type == "bert":
            embedding_dim = 768

        self.lstm = BiLSTM(embedding_dim, hidden_dim // 2, num_layers=1)
        postlayers["lstm_dropout_layer"] = nn.Dropout(dropout)

        if self.attention:
            postlayers["attention_layer"] = Transformer(
                d_model=hidden_dim, n_heads=4, n_layers=1
            )

        postlayers["ff_layer"] = nn.Linear(hidden_dim, hidden_dim // 2)
        postlayers["ff_activation"] = nn.ReLU()
        postlayers["output_layer"] = nn.Linear(hidden_dim // 2, 1)

        self.logger.info(f"prelayers: {prelayers.keys()}")
        self.logger.info(f"postlayers: {postlayers.keys()}")

        self.pre = nn.Sequential(prelayers)
        self.post = nn.Sequential(postlayers)

    def forward(self, x, word_seq_length):
        x = self.pre(x)
        x = self.lstm(x, word_seq_length)

        # run the post layers on each sequence individually, without padding
        output = []
        for _x, l in zip(x.transpose(1, 0), word_seq_length):
            output.append(self.post(_x[:l].unsqueeze(0))[0])

        return pad_sequence(output, batch_first=True)


class BiLSTM(nn.Module):
    def __init__(self, embedding_dim, lstm_hidden, num_layers):
        super().__init__()
        self.net = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_hidden,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x, word_seq_length):
        packed_words = pack_padded_sequence(x, word_seq_length, True, False)
        lstm_out, hidden = self.net(packed_words)
        lstm_out, _ = pad_packed_sequence(lstm_out)
        return lstm_out
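A shape sketch for this fixation `Network`, with hypothetical sizes (randomly initialized embeddings stand in for GloVe):

```python
import torch

net = Network(
    embedding_type="glove",
    vocab_size=50,
    embedding_dim=300,
    dropout=0.5,
    hidden_dim=128,
)
x = torch.randint(0, 50, (2, 9))   # batch of 2 padded sentences, 9 token ids each
lengths = [9, 7]                   # true sequence lengths
scores = net(x, lengths)           # (2, 9, 1): one fixation logit per token
```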
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()


class AttentionLayer(nn.Module):
    def __init__(self):
        super(AttentionLayer, self).__init__()

    def forward(self, Q, K, V):
        # Q: float32:[batch_size, n_queries, d_k]
        # K: float32:[batch_size, n_keys, d_k]
        # V: float32:[batch_size, n_keys, d_v]
        dk = K.shape[-1]
        dv = V.shape[-1]
        KT = torch.transpose(K, -1, -2)
        weight_logits = torch.bmm(Q, KT) / math.sqrt(dk)
        # weight_logits: float32[batch_size, n_queries, n_keys]
        weights = F.softmax(weight_logits, dim=-1)
        # weights: float32[batch_size, n_queries, n_keys]
        return torch.bmm(weights, V)
        # return float32[batch_size, n_queries, dv]


class MultiHeadedSelfAttentionLayer(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadedSelfAttentionLayer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        print('{} {}'.format(d_model, n_heads))
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.d_v = self.d_k
        self.attention_layer = AttentionLayer()
        self.W_Qs = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Ks = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Vs = nn.ModuleList([
            nn.Linear(d_model, self.d_v, bias=False)
            for _ in range(n_heads)
        ])
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        head_outputs = []
        for W_Q, W_K, W_V in zip(self.W_Qs, self.W_Ks, self.W_Vs):
            Q = W_Q(x)
            # Q: float32[batch_size, sequence_length, self.d_k]
            K = W_K(x)
            # K: float32[batch_size, sequence_length, self.d_k]
            V = W_V(x)
            # V: float32[batch_size, sequence_length, self.d_v]
            head_output = self.attention_layer(Q, K, V)
            # head_output: float32[batch_size, sequence_length, self.d_v]
            head_outputs.append(head_output)
        concatenated = torch.cat(head_outputs, dim=-1)
        # concatenated: float32[batch_size, sequence_length, self.d_model]
        out = self.W_O(concatenated)
        # out: float32[batch_size, sequence_length, self.d_model]
        return out


class Feedforward(nn.Module):
    def __init__(self, d_model):
        super(Feedforward, self).__init__()
        self.d_model = d_model
        self.W1 = nn.Linear(d_model, d_model)
        self.W2 = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, d_model]
        return self.W2(torch.relu(self.W1(x)))


class Transformer(nn.Module):
    def __init__(self, d_model, n_heads, n_layers):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.attention_layers = nn.ModuleList([
            MultiHeadedSelfAttentionLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        self.ffs = nn.ModuleList([
            Feedforward(d_model)
            for _ in range(n_layers)
        ])

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        for attention_layer, ff in zip(self.attention_layers, self.ffs):
            attention_out = attention_layer(x)
            # attention_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + attention_out, x.shape[2:])
            ff_out = ff(x)
            # ff_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + ff_out, x.shape[2:])
        return x
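A quick shape check for the `Transformer` block above (sizes chosen arbitrarily for illustration):

```python
import torch

model = Transformer(d_model=128, n_heads=4, n_layers=1)
x = torch.randn(2, 7, 128)          # 2 sequences, 7 tokens, width 128
out = model(x)
assert out.shape == (2, 7, 128)     # self-attention preserves the input shape
```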
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2018, Tatsuya Aoki
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,31 @@
# Simple Model for Sentence Compression
3-layered BILSTM model for sentence compression, referred to as Baseline in [Klerke et al., NAACL 2016](http://aclweb.org/anthology/N/N16/N16-1179.pdf).
## Requirements
### Framework
- python (<= 3.6)
- pytorch (<= 0.3.0)

### Packages
- torchtext

## How to run
```
./getdata
python main.py
```
To run the scripts with a GPU, use `python main.py --gpu-id ID`, where ID is an integer from 0 to the number of GPUs you have.
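For example, on a machine with a single GPU the invocation would be `python main.py --gpu-id 0`.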
## Reference

```
@InProceedings{klerke-goldberg-sogaard:2016:N16-1,
  author    = {Klerke, Sigrid and Goldberg, Yoav and S{\o}gaard, Anders},
  title     = {Improving sentence compression by learning to predict gaze},
  booktitle = {Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
  month     = {June},
  year      = {2016},
  address   = {San Diego, California},
  publisher = {Association for Computational Linguistics},
  pages     = {1528--1533},
  url       = {http://www.aclweb.org/anthology/N16-1179}
}
```
@@ -0,0 +1 @@
from .main import *
@@ -0,0 +1,95 @@
from torchtext import data
from const import Phase


def create_dataset(data: dict, batch_size: int, device: int):

    train = Dataset(data[Phase.TRAIN]['tokens'],
                    data[Phase.TRAIN]['labels'],
                    vocab=None,
                    batch_size=batch_size,
                    device=device,
                    phase=Phase.TRAIN)

    dev = Dataset(data[Phase.DEV]['tokens'],
                  data[Phase.DEV]['labels'],
                  vocab=train.vocab,
                  batch_size=batch_size,
                  device=device,
                  phase=Phase.DEV)

    test = Dataset(data[Phase.TEST]['tokens'],
                   data[Phase.TEST]['labels'],
                   vocab=train.vocab,
                   batch_size=batch_size,
                   device=device,
                   phase=Phase.TEST)
    return train, dev, test


class Dataset:
    def __init__(self,
                 tokens: list,
                 label_list: list,
                 vocab: list,
                 batch_size: int,
                 device: int,
                 phase: Phase):
        assert len(tokens) == len(label_list), \
            'the number of sentences and the number of POS/head sequences \
            should be the same length'

        self.pad_token = '<PAD>'
        # self.unk_token = '<UNK>'
        self.tokens = tokens
        self.label_list = label_list
        self.sentence_id = [[i] for i in range(len(tokens))]
        self.device = device

        self.token_field = data.Field(use_vocab=True,
                                      # unk_token=self.unk_token,
                                      pad_token=self.pad_token,
                                      batch_first=True)
        self.label_field = data.Field(use_vocab=False, pad_token=-1, batch_first=True)
        self.sentence_id_field = data.Field(use_vocab=False, batch_first=True)
        self.dataset = self._create_dataset()

        if vocab is None:
            self.token_field.build_vocab(self.tokens)
            self.vocab = self.token_field.vocab
        else:
            self.token_field.vocab = vocab
            self.vocab = vocab
        self.pad_index = self.token_field.vocab.stoi[self.pad_token]

        self._set_batch_iter(batch_size, phase)

    def get_raw_sentence(self, sentences):
        return [[self.vocab.itos[idx] for idx in sentence]
                for sentence in sentences]

    def _create_dataset(self):
        _fields = [('token', self.token_field),
                   ('label', self.label_field),
                   ('sentence_id', self.sentence_id_field)]
        return data.Dataset(self._get_examples(_fields), _fields)

    def _get_examples(self, fields: list):
        ex = []
        for sentence, label, sentence_id in zip(self.tokens, self.label_list, self.sentence_id):
            ex.append(data.Example.fromlist([sentence, label, sentence_id], fields))
        return ex

    def _set_batch_iter(self, batch_size: int, phase: Phase):

        # renamed the argument so it no longer shadows the torchtext `data` module
        def sort(example) -> int:
            return len(getattr(example, 'token'))

        train = phase == Phase.TRAIN

        self.batch_iter = data.BucketIterator(dataset=self.dataset,
                                              batch_size=batch_size,
                                              sort_key=sort,
                                              train=train,
                                              repeat=False,
                                              device=self.device)
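A minimal driving sketch for `create_dataset` (toy data; with the legacy torchtext API assumed here, `device=-1` selects the CPU):

```python
from batch import create_dataset
from const import Phase

data = {
    Phase.TRAIN: {"tokens": [["a", "man", "walks"]], "labels": [[0, 0, 1]]},
    Phase.DEV:   {"tokens": [["a", "dog", "runs"]],  "labels": [[0, 1, 0]]},
    Phase.TEST:  {"tokens": [["a", "cat", "sits"]],  "labels": [[0, 0, 0]]},
}
train, dev, test = create_dataset(data, batch_size=1, device=-1)
for batch in train.batch_iter:
    tokens, labels = batch.token, batch.label   # padded LongTensors
```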
@@ -0,0 +1,8 @@
from enum import Enum, unique


@unique
class Phase(Enum):
    TRAIN = 'train'
    DEV = 'dev'
    TEST = 'test'
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn


class Network(nn.Module):
    def __init__(self,
                 embeddings,
                 hidden_size: int,
                 prior,
                 device: torch.device):

        super(Network, self).__init__()
        self.device = device
        self.priors = torch.log(torch.tensor([prior, 1-prior])).to(device)
        self.hidden_size = hidden_size
        self.bilstm_layers = 3
        self.bilstm_input_size = 300
        self.bilstm_output_size = 2 * hidden_size
        self.word_emb = nn.Embedding.from_pretrained(embeddings, freeze=False)
        self.bilstm = nn.LSTM(self.bilstm_input_size,
                              self.hidden_size,
                              num_layers=self.bilstm_layers,
                              batch_first=True,
                              dropout=0.1,  # ms best mod 0.1
                              bidirectional=True)
        self.dropout = nn.Dropout(p=0.35)
        # NOTE: self.attention resolves to the bound method defined below,
        # so this condition is always truthy and the attention sublayers
        # are always built
        if self.attention:
            self.attention_size = self.bilstm_output_size * 2
            self.u_a = nn.Linear(self.bilstm_output_size, self.bilstm_output_size)
            self.w_a = nn.Linear(self.bilstm_output_size, self.bilstm_output_size)
            self.v_a_inv = nn.Linear(self.bilstm_output_size, 1, bias=False)
            self.linear_attn = nn.Linear(self.attention_size, self.bilstm_output_size)
        self.linear = nn.Linear(self.bilstm_output_size, self.hidden_size)
        self.pred = nn.Linear(self.hidden_size, 2)
        self.softmax = nn.LogSoftmax(dim=1)
        self.criterion = nn.NLLLoss(ignore_index=-1)

    def forward(self, input_tokens, labels, fixations=None):
        loss = 0.0
        preds = []
        atts = []
        batch_size, seq_len = input_tokens.size()
        self.init_hidden(batch_size, device=self.device)

        x_i = self.word_emb(input_tokens)
        x_i = self.dropout(x_i)

        hidden, (self.h_n, self.c_n) = self.bilstm(x_i, (self.h_n, self.c_n))
        _, _, hidden_size = hidden.size()

        for i in range(seq_len):
            nth_hidden = hidden[:, i, :]
            if self.attention:
                target = nth_hidden.expand(seq_len, batch_size, -1).transpose(0, 1)
                mask = hidden.eq(target)[:, :, 0].unsqueeze(2)
                attn_weight = self.attention(hidden, target, fixations, mask)
                context_vector = torch.bmm(attn_weight.transpose(1, 2), hidden).squeeze(1)

                nth_hidden = torch.tanh(self.linear_attn(torch.cat((nth_hidden, context_vector), -1)))
                atts.append(attn_weight.detach().cpu())
            logits = self.pred(self.linear(nth_hidden))
            if not self.training:
                # at evaluation time, shift the logits by the class log-priors
                logits = logits + self.priors
            output = self.softmax(logits)
            loss += self.criterion(output, labels[:, i])

            _, topi = output.topk(k=1, dim=1)
            pred = topi.squeeze(-1)
            preds.append(pred)

        preds = torch.stack(torch.cat(preds, dim=0).split(batch_size), dim=1)

        if atts:
            atts = torch.stack(torch.cat(atts, dim=0).split(batch_size), dim=1)

        return loss, preds, atts

    def attention(self, source, target, fixations=None, mask=None):
        function_g = \
            self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
        if mask is not None:
            function_g.masked_fill_(mask, -1e4)
        if fixations is not None:
            function_g = function_g * fixations
        return nn.functional.softmax(function_g, dim=1)

    def init_hidden(self, batch_size, device):
        # plain tensors replace the deprecated torch.autograd.Variable
        zeros = torch.zeros(2 * self.bilstm_layers, batch_size, self.hidden_size)
        self.h_n = zeros.to(device)
        self.c_n = zeros.to(device)
        return self.h_n, self.c_n
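A toy instantiation of this compression network (random embeddings stand in for pretrained ones; `prior` is the expected deletion rate):

```python
import torch

emb = torch.randn(50, 300)                       # 50-word toy vocabulary
net = Network(embeddings=emb, hidden_size=64, prior=0.5,
              device=torch.device("cpu"))
tokens = torch.randint(0, 50, (2, 6))            # batch of 2, 6 token ids each
labels = torch.randint(0, 2, (2, 6))             # 0 = keep, 1 = delete
loss, preds, atts = net(tokens, labels)
```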
@@ -0,0 +1,183 @@
import torch
from torch import optim
import tqdm

from const import Phase
from batch import create_dataset
from models import Baseline
from sklearn.metrics import classification_report


def run(dataset_train,
        dataset_dev,
        dataset_test,
        model_type,
        word_embed_size,
        hidden_size,
        batch_size,
        device,
        n_epochs):

    if model_type == 'base':
        model = Baseline(vocab=dataset_train.vocab,
                         word_embed_size=word_embed_size,
                         hidden_size=hidden_size,
                         device=device,
                         inference=False)
    else:
        raise NotImplementedError
    model = model.to(device)

    optim_params = model.parameters()
    optimizer = optim.Adam(optim_params, lr=10**-3)

    print('start training')
    for epoch in range(n_epochs):
        train_loss, tokens, preds, golds = train(dataset_train,
                                                 model,
                                                 optimizer,
                                                 batch_size,
                                                 epoch,
                                                 Phase.TRAIN,
                                                 device)

        dev_loss, tokens, preds, golds = train(dataset_dev,
                                               model,
                                               optimizer,
                                               batch_size,
                                               epoch,
                                               Phase.DEV,
                                               device)
        logger = '\t'.join(['epoch {}'.format(epoch+1),
                            'TRAIN Loss: {:.9f}'.format(train_loss),
                            'DEV Loss: {:.9f}'.format(dev_loss)])
        # print('\r'+logger, end='')
        print(logger)
    test_loss, tokens, preds, golds = train(dataset_test,
                                            model,
                                            optimizer,
                                            batch_size,
                                            epoch,
                                            Phase.TEST,
                                            device)
    print('====', 'TEST', '=====')
    print_scores(preds, golds)
    output_results(tokens, preds, golds)


def train(dataset,
          model,
          optimizer,
          batch_size,
          n_epoch,
          phase,
          device):

    total_loss = 0.0
    tokens = []
    preds = []
    labels = []
    if phase == Phase.TRAIN:
        model.train()
    else:
        model.eval()

    for batch in tqdm.tqdm(dataset.batch_iter):
        token = getattr(batch, 'token')
        label = getattr(batch, 'label')
        raw_sentences = dataset.get_raw_sentence(token.data.detach().cpu().numpy())

        loss, pred = \
            model(token, raw_sentences, label, phase)

        if phase == Phase.TRAIN:
            # gradients exist only after backward(), so clip afterwards
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()

        # remove PAD from input sentences/labels and results
        mask = (token != dataset.pad_index)
        length_tensor = mask.sum(1)
        length_tensor = length_tensor.data.detach().cpu().numpy()

        for index, n_tokens_in_the_sentence in enumerate(length_tensor):
            if n_tokens_in_the_sentence > 0:
                tokens.append(raw_sentences[index][:n_tokens_in_the_sentence])
                _label = label[index][:n_tokens_in_the_sentence]
                _pred = pred[index][:n_tokens_in_the_sentence]
                _label = _label.data.detach().cpu().numpy()
                _pred = _pred.data.detach().cpu().numpy()
                labels.append(_label)
                preds.append(_pred)

        total_loss += torch.mean(loss).item()

    return total_loss, tokens, preds, labels


def read_two_cols_data(fname, max_len=None):
    data = {}
    tokens = []
    labels = []
    token = []
    label = []
    with open(fname, mode='r') as f:
        for line in f:
            line = line.strip().lower().split()
            if line:
                _token, _label = line
                token.append(_token)
                if _label == '0' or _label == '1':
                    label.append(int(_label))
                else:
                    # map textual labels: 'del' -> delete (1), anything else -> keep (0)
                    label.append(1 if _label == 'del' else 0)
            else:
                if max_len is None or len(token) <= max_len:
                    tokens.append(token)
                    labels.append(label)
                token = []
                label = []
    # flush the last sentence if the file does not end with a blank line
    if token and (max_len is None or len(token) <= max_len):
        tokens.append(token)
        labels.append(label)

    data['tokens'] = tokens
    data['labels'] = labels
    return data


def load(train_path, dev_path, test_path, batch_size, max_len, device):
    train = read_two_cols_data(train_path, max_len)
    dev = read_two_cols_data(dev_path)
    test = read_two_cols_data(test_path)
    data = {Phase.TRAIN: train, Phase.DEV: dev, Phase.TEST: test}
    return create_dataset(data, batch_size=batch_size, device=device)


def print_scores(preds, golds):
    _preds = [label for sublist in preds for label in sublist]
    _golds = [label for sublist in golds for label in sublist]
    target_names = ['not_del', 'del']
    print(classification_report(_golds, _preds, target_names=target_names, digits=5))


def output_results(tokens, preds, golds, path='./result/sentcomp'):
    with open(path+'.original.txt', mode='w') as w, \
            open(path+'.gold.txt', mode='w') as w_gold, \
            open(path+'.pred.txt', mode='w') as w_pred:

        for _tokens, _golds, _preds in zip(tokens, golds, preds):
            for token, gold, pred in zip(_tokens, _golds, _preds):
                w.write(token + ' ')
                # 0 -> keep, 1 -> delete
                if gold == 0:
                    w_gold.write(token + ' ')
                if pred == 0:
                    w_pred.write(token + ' ')
            w.write('\n')
            w_gold.write('\n')
            w_pred.write('\n')
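The input format expected by `read_two_cols_data` is one token and one label per line, with blank lines separating sentences; a hypothetical file and the resulting structure:

```python
# contents of example.tsv:
#   The    0
#   very   1
#   big    1
#   dog    0
#   barked 0
data = read_two_cols_data("example.tsv")
print(data["tokens"][0])   # ['the', 'very', 'big', 'dog', 'barked']
print(data["labels"][0])   # [0, 1, 1, 0, 0]
```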
218
joint_sentence_compression_model/libs/utils.py
Normal file
@@ -0,0 +1,218 @@
import json
import logging
import math
import os
import random
import re
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from nltk.translate.bleu_score import sentence_bleu
import numpy as np
import torch
import torch.nn as nn

import config


plt.switch_backend("agg")


def load_glove(vocabulary):
    logger = logging.getLogger(f"{__name__}.load_glove")
    logger.info("loading embeddings")
    try:
        with open("glove.cache") as h:
            cache = json.load(h)
    except FileNotFoundError:
        logger.info("cache doesn't exist")
        cache = {}
        cache[config.PAD] = [0] * 300
        cache[config.SOS] = [0] * 300
        cache[config.EOS] = [0] * 300
        cache[config.UNK] = [0] * 300
        cache[config.NOFIX] = [0] * 300
    else:
        logger.info("cache found")

    cache_miss = False

    if not set(vocabulary) <= set(cache):
        cache_miss = True
        logger.warning("cache miss, loading full embeddings")
        data = {}
        with open("glove.840B.300d.txt") as h:
            for line in h:
                word, *emb = line.strip().split()
                try:
                    data[word] = [float(x) for x in emb]
                except ValueError:
                    continue
        logger.info("finished loading full embeddings")
        for word in vocabulary:
            try:
                cache[word] = data[word]
            except KeyError:
                cache[word] = [0] * 300
        logger.info("cache updated")

    embeddings = []
    for word in vocabulary:
        embeddings.append(torch.tensor(cache[word], dtype=torch.float32))
    embeddings = torch.stack(embeddings)

    if cache_miss:
        with open("glove.cache", "w") as h:
            json.dump(cache, h)
        logger.info("cache saved")

    return embeddings


def tokenize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = s.split(" ")
    return s


def indices_from_sentence(word2index, sentence, unknown_threshold):
    if unknown_threshold:
        # randomly replace a fraction of words with UNK for regularization
        return [
            word2index.get(
                word if random.random() > unknown_threshold else config.UNK,
                word2index[config.UNK],
            )
            for word in sentence
        ]
    else:
        return [
            word2index.get(word, word2index[config.UNK]) for word in sentence
        ]


def tensor_from_sentence(word2index, sentence, unknown_threshold):
    indices = indices_from_sentence(word2index, sentence, unknown_threshold)
    return torch.tensor(indices, dtype=torch.long, device=config.DEV)


def tensors_from_pair(word2index, pair, shuffle, unknown_threshold):
    tensors = [
        tensor_from_sentence(word2index, pair[0], unknown_threshold),
        tensor_from_sentence(word2index, pair[1], unknown_threshold),
    ]
    if shuffle:
        random.shuffle(tensors)
    return tensors


def bleu(reference, hypothesis, n=4):
    if n < 1:
        return 0
    weights = [1/n]*n
    return sentence_bleu([reference], hypothesis, weights)


def pair_iter(pairs, word2index, shuffle=False, shuffle_pairs=False, unknown_threshold=0.00):
    if shuffle:
        pairs = pairs.copy()
        random.shuffle(pairs)
    for pair in pairs:
        tensor1, tensor2 = tensors_from_pair(word2index, (pair[0], pair[1]), shuffle_pairs, unknown_threshold)
        yield (tensor1,), (tensor2,)


def sent_iter(sents, word2index, batch_size, unknown_threshold=0.00):
    for i in range(len(sents)//batch_size+1):
        raw_sents = [x[0] for x in sents[i*batch_size:i*batch_size+batch_size]]
        _sents = [tensor_from_sentence(word2index, sent, unknown_threshold) for sent, target in sents[i*batch_size:i*batch_size+batch_size]]
        _targets = [torch.tensor(target, dtype=torch.long).to(config.DEV) for sent, target in sents[i*batch_size:i*batch_size+batch_size]]
        if raw_sents and _sents and _targets:
            yield raw_sents, _sents, _targets


def batch_iter(pairs, word2index, batch_size, shuffle=False, unknown_threshold=0.00):
    for i in range(len(pairs) // batch_size):
        # step through pairs in batch_size chunks rather than one at a time
        batch = pairs[i * batch_size : (i + 1) * batch_size]
        if len(batch) != batch_size:
            continue
        batch_tensors = [
            tensors_from_pair(word2index, (pair[0], pair[1]), shuffle, unknown_threshold)
            for pair in batch
        ]

        tensors1, tensors2 = zip(*batch_tensors)

        yield tensors1, tensors2


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return "%s (- %s)" % (asMinutes(s), asMinutes(rs))


def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)


def showAttention(input_sentence, output_words, attentions):
    # Set up figure with colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence.split(" ") + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    # NOTE: expects `evaluate`, `encoder1`, and `attn_decoder1` to exist in
    # the caller's globals; kept as-is from the original tutorial-style code
    output_words, attentions = evaluate(encoder1, attn_decoder1, input_sentence)
    print("input =", input_sentence)
    print("output =", " ".join(output_words))
    showAttention(input_sentence, output_words, attentions)


def save_model(model, word2index, path):
    if not path.endswith(".tar"):
        path += ".tar"
    torch.save(
        {"weights": model.state_dict(), "word2index": word2index},
        path,
    )


def load_model(path):
    checkpoint = torch.load(path)
    return checkpoint["weights"], checkpoint["word2index"]


def extend_vocabulary(word2index, langs):
    for lang in langs:
        for word in lang.word2index:
            if word not in word2index:
                word2index[word] = len(word2index)
    return word2index
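For instance, `bleu` and `batch_iter` can be exercised in isolation (toy data, illustrative only; `batch_iter` reads `config.DEV` internally):

```python
from libs import utils

ref = "the cat sat on the mat".split()
hyp = "the cat sat on a mat".split()
print(utils.bleu(ref, hyp, n=2))           # bigram-weighted sentence BLEU

word2index = {"<__UNK__>": 0, "the": 1, "cat": 2}
pairs = [(["the", "cat"], ["the"])]
for t1, t2 in utils.batch_iter(pairs, word2index, batch_size=1):
    print(t1, t2)                          # tuples of LongTensors
```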
495
joint_sentence_compression_model/main.py
Normal file
495
joint_sentence_compression_model/main.py
Normal file
|
@ -0,0 +1,495 @@
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import click
|
||||||
|
import sacrebleu
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
import config
|
||||||
|
from libs import corpora
|
||||||
|
from libs import utils
|
||||||
|
from libs.fixation_generation import Network as FixNN
|
||||||
|
from libs.sentence_compression import Network as ComNN
|
||||||
|
from sklearn.metrics import classification_report, precision_recall_fscore_support
|
||||||
|
|
||||||
|
|
||||||
|
cwd = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
logger = logging.getLogger("main")
|
||||||
|
|
||||||
|
|
||||||
|
class Network(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, word2index, embeddings, prior,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.logger = logging.getLogger(f"{__name__}")
|
||||||
|
self.word2index = word2index
|
||||||
|
self.index2word = {i: k for k, i in word2index.items()}
|
||||||
|
self.fix_gen = FixNN(
|
||||||
|
embedding_type="glove",
|
||||||
|
vocab_size=len(word2index),
|
||||||
|
embedding_dim=config.embedding_dim,
|
||||||
|
embeddings=embeddings,
|
||||||
|
dropout=config.fix_dropout,
|
||||||
|
hidden_dim=config.fix_hidden_dim,
|
||||||
|
)
|
||||||
|
self.com_nn = ComNN(
|
||||||
|
embeddings=embeddings, hidden_size=config.sem_hidden_dim, prior=prior, device=config.DEV
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, target, seq_lens):
|
||||||
|
x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
|
||||||
|
target = nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=-1)
|
||||||
|
|
||||||
|
fixations = torch.sigmoid(self.fix_gen(x1, seq_lens))
|
||||||
|
# fixations = None
|
||||||
|
|
||||||
|
loss, pred, atts = self.com_nn(x1, target, fixations)
|
||||||
|
return loss, pred, atts, fixations
|
||||||
|
|
||||||
|
|
||||||
|
def load_corpus(corpus_name, splits):
|
||||||
|
if not splits:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("loading corpus")
|
||||||
|
if corpus_name == "google":
|
||||||
|
load_fn = corpora.load_google
|
||||||
|
|
||||||
|
corpus = {}
|
||||||
|
langs = []
|
||||||
|
|
||||||
|
if "train" in splits:
|
||||||
|
train_pairs, train_lang = load_fn("train", max_len=200)
|
||||||
|
corpus["train"] = train_pairs
|
||||||
|
langs.append(train_lang)
|
||||||
|
if "val" in splits:
|
||||||
|
val_pairs, val_lang = load_fn("val")
|
||||||
|
corpus["val"] = val_pairs
|
||||||
|
langs.append(val_lang)
|
||||||
|
if "test" in splits:
|
||||||
|
test_pairs, test_lang = load_fn("test")
|
||||||
|
corpus["test"] = test_pairs
|
||||||
|
langs.append(test_lang)
|
||||||
|
|
||||||
|
logger.info("creating word index")
|
||||||
|
lang = langs[0]
|
||||||
|
for _lang in langs[1:]:
|
||||||
|
lang += _lang
|
||||||
|
word2index = lang.word2index
|
||||||
|
|
||||||
|
index2word = {i: w for w, i in word2index.items()}
|
||||||
|
|
||||||
|
return corpus, word2index, index2word
|
||||||
|
|
||||||
|
|
||||||
|
def init_network(word2index, prior):
|
||||||
|
logger.info("loading embeddings")
|
||||||
|
vocabulary = sorted(word2index.keys())
|
||||||
|
embeddings = utils.load_glove(vocabulary)
|
||||||
|
|
||||||
|
logger.info("initializing model")
|
||||||
|
network = Network(word2index=word2index, embeddings=embeddings, prior=prior)
|
||||||
|
network.to(config.DEV)
|
||||||
|
|
||||||
|
print(f"#parameters: {sum(p.numel() for p in network.parameters())}")
|
||||||
|
|
||||||
|
return network
|
||||||
|
|
||||||
|
|
||||||
|
@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
|
||||||
|
@click.option("-v", "--verbose", count=True)
|
||||||
|
@click.option("-d", "--debug", is_flag=True)
|
||||||
|
def main(verbose, debug):
|
||||||
|
if verbose == 0:
|
||||||
|
loglevel = logging.ERROR
|
||||||
|
elif verbose == 1:
|
||||||
|
loglevel = logging.WARN
|
||||||
|
elif verbose >= 2:
|
||||||
|
loglevel = logging.INFO
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
loglevel = logging.DEBUG
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
|
||||||
|
datefmt="%d.%m. %H:%M:%S",
|
||||||
|
level=loglevel,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("arguments: %s" % str(sys.argv))
|
||||||
|
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@click.option(
|
||||||
|
"-c",
|
||||||
|
"--corpus",
|
||||||
|
"corpus_name",
|
||||||
|
required=True,
|
||||||
|
type=click.Choice(sorted(["google",])),
|
||||||
|
)
|
||||||
|
@click.option("-m", "--model_name", required=True)
|
||||||
|
@click.option("-w", "--fixation_weights", required=False)
|
||||||
|
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
|
||||||
|
@click.option("-d", "--debug", is_flag=True, default=False)
|
||||||
|
@click.option("-p", "--prior", type=float, default=.5)
|
||||||
|
def train(corpus_name, model_name, fixation_weights, freeze_fixations, debug, prior):
|
||||||
|
corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
|
||||||
|
train_pairs = corpus["train"]
|
||||||
|
val_pairs = corpus["val"]
|
||||||
|
network = init_network(word2index, prior)
|
||||||
|
|
||||||
|
model_dir = os.path.join("models", model_name)
|
||||||
|
logger.debug("creating model dir %s" % model_dir)
|
||||||
|
pathlib.Path(model_dir).mkdir(parents=True)
|
||||||
|
|
||||||
|
if fixation_weights is not None:
|
||||||
|
logger.info("loading fixation prediction checkpoint")
|
||||||
|
checkpoint = torch.load(fixation_weights, map_location=config.DEV)
|
||||||
|
if "word2index" in checkpoint:
|
||||||
|
weights = checkpoint["weights"]
|
||||||
|
else:
|
||||||
|
weights = checkpoint
|
||||||
|
|
||||||
|
# remove the embedding layer before loading
|
||||||
|
weights = {
|
||||||
|
k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")
|
||||||
|
}
|
||||||
|
network.fix_gen.load_state_dict(weights, strict=False)
|
||||||
|
|
||||||
|
if freeze_fixations:
|
||||||
|
logger.info("freezing fixation generation network")
|
||||||
|
for p in network.fix_gen.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
|
||||||
|
optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)
|
||||||
|
|
||||||
|
best_val_loss = None
|
||||||
|
|
||||||
|
epoch = 1
|
||||||
|
batch_size = 20
|
||||||
|
|
||||||
|
while True:
|
||||||
|
train_batch_iter = utils.sent_iter(
|
||||||
|
sents=train_pairs, word2index=word2index, batch_size=batch_size
|
||||||
|
)
|
||||||
|
val_batch_iter = utils.sent_iter(
|
||||||
|
sents=val_pairs, word2index=word2index, batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
total_train_loss = 0
|
||||||
|
total_val_loss = 0
|
||||||
|
|
||||||
|
network.train()
|
||||||
|
for i, batch in tqdm.tqdm(
|
||||||
|
enumerate(train_batch_iter, 1), total=len(train_pairs) // batch_size + 1
|
||||||
|
):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
raw_sent, sent, target = batch
|
||||||
|
seq_lens = [len(x) for x in sent]
|
||||||
|
loss, prediction, attention, fixations = network(sent, target, seq_lens)
|
||||||
|
|
||||||
|
prediction = prediction.detach().cpu().numpy()
|
||||||
|
|
||||||
|
torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
total_train_loss += loss.item()
|
||||||
|
|
||||||
|
avg_train_loss = total_train_loss / len(train_pairs)
|
||||||
|
|
||||||
|
val_sents = []
|
||||||
|
val_preds = []
|
||||||
|
val_targets = []
|
||||||
|
|
||||||
|
network.eval()
|
||||||
|
for i, batch in tqdm.tqdm(
|
||||||
|
enumerate(val_batch_iter), total=len(val_pairs) // batch_size + 1
|
||||||
|
):
|
||||||
|
raw_sent, sent, target = batch
|
||||||
|
seq_lens = [len(x) for x in sent]
|
||||||
|
loss, prediction, attention, fixations = network(sent, target, seq_lens)
|
||||||
|
|
||||||
|
prediction = prediction.detach().cpu().numpy()
|
||||||
|
|
||||||
|
for i, l in enumerate(seq_lens):
|
||||||
|
val_sents.append(raw_sent[i][:l])
|
||||||
|
val_preds.append(prediction[i][:l].tolist())
|
||||||
|
val_targets.append(target[i][:l].tolist())
|
||||||
|
|
||||||
|
total_val_loss += loss.item()
|
||||||
|
|
||||||
|
avg_val_loss = total_val_loss / len(val_pairs)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"epoch {epoch} train_loss {avg_train_loss:.4f} val_loss {avg_val_loss:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
classification_report(
|
||||||
|
[x for y in val_targets for x in y],
|
||||||
|
[x for y in val_preds for x in y],
|
||||||
|
target_names=["not_del", "del"],
|
||||||
|
digits=5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(f"models/{model_name}/val_original_{epoch}.txt", "w") as oh, open(
|
||||||
|
f"models/{model_name}/val_pred_{epoch}.txt", "w"
|
||||||
|
) as ph, open(f"models/{model_name}/val_gold_{epoch}.txt", "w") as gh:
|
||||||
|
for sent, preds, golds in zip(val_sents, val_preds, val_targets):
|
||||||
|
pred_compressed = [
|
||||||
|
word for word, delete in zip(sent, preds) if not delete
|
||||||
|
]
|
||||||
|
gold_compressed = [
|
||||||
|
word for word, delete in zip(sent, golds) if not delete
|
||||||
|
]
|
||||||
|
|
||||||
|
oh.write(" ".join(sent))
|
||||||
|
ph.write(" ".join(pred_compressed))
|
||||||
|
gh.write(" ".join(gold_compressed))
|
||||||
|
|
||||||
|
oh.write("\n")
|
||||||
|
ph.write("\n")
|
||||||
|
gh.write("\n")
|
||||||
|
|
||||||
|
if best_val_loss is None or avg_val_loss < best_val_loss:
|
||||||
|
delta = avg_val_loss - best_val_loss if best_val_loss is not None else 0.0
|
||||||
|
best_val_loss = avg_val_loss
|
||||||
|
print(
|
||||||
|
f"new best model epoch {epoch} val loss {avg_val_loss:.4f} ({delta:.4f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
utils.save_model(
|
||||||
|
network, word2index, f"models/{model_name}/{model_name}_{epoch}"
|
||||||
|
)
|
||||||
|
|
||||||
|
epoch += 1
|
||||||
|
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@click.option(
|
||||||
|
"-c",
|
||||||
|
"--corpus",
|
||||||
|
"corpus_name",
|
||||||
|
required=True,
|
||||||
|
type=click.Choice(sorted(["google",])),
|
||||||
|
)
|
||||||
|
@click.option("-w", "--model_weights", required=True)
|
||||||
|
@click.option("-p", "--prior", type=float, default=.5)
|
||||||
|
@click.option("-l", "--longest", is_flag=True)
|
||||||
|
@click.option("-s", "--shortest", is_flag=True)
|
||||||
|
@click.option("-d", "--detailed", is_flag=True)
|
||||||
|
def test(corpus_name, model_weights, prior, longest, shortest, detailed):
|
||||||
|
if longest and shortest:
|
||||||
|
print("longest and shortest are mutually exclusive", file=sys.stderr)
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
|
||||||
|
test_pairs = corpus["test"]
|
||||||
|
|
||||||
|
model_name = os.path.basename(os.path.dirname(model_weights))
|
||||||
|
epoch = re.search("_(\d+).tar", model_weights).group(1)
|
||||||
|
|
||||||
|
logger.info("loading model checkpoint")
|
||||||
|
checkpoint = torch.load(model_weights, map_location=config.DEV)
|
||||||
|
if "word2index" in checkpoint:
|
||||||
|
weights = checkpoint["weights"]
|
||||||
|
word2index = checkpoint["word2index"]
|
||||||
|
index2word = {i: w for w, i in word2index.items()}
|
||||||
|
else:
|
||||||
|
asdf
|
||||||
|
|
||||||
|
network = init_network(word2index, prior)
|
||||||
|
network.eval()
|
||||||
|
|
||||||
|
# remove the embedding layer before loading
|
||||||
|
# weights = {k: v for k, v in weights.items() if not "embedding" in k}
|
||||||
|
# actually load the parameters
|
||||||
|
network.load_state_dict(weights, strict=False)
|
||||||
|
|
||||||
|
total_test_loss = 0
|
||||||
|
|
||||||
|
batch_size = 20
|
||||||
|
|
||||||
|
test_batch_iter = utils.sent_iter(
|
||||||
|
sents=test_pairs, word2index=word2index, batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
test_sents = []
|
||||||
|
test_preds = []
|
||||||
|
test_targets = []
|
||||||
|
|
||||||
|
for i, batch in tqdm.tqdm(
|
||||||
|
enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
|
||||||
|
):
|
||||||
|
raw_sent, sent, target = batch
|
||||||
|
seq_lens = [len(x) for x in sent]
|
||||||
|
loss, prediction, attention, fixations = network(sent, target, seq_lens)
|
||||||
|
|
||||||
|
prediction = prediction.detach().cpu().numpy()
|
||||||
|
|
||||||
|
for i, l in enumerate(
|
||||||
|
seq_lens
|
||||||
|
):
|
||||||
|
test_sents.append(raw_sent[i][:l])
|
||||||
|
test_preds.append(prediction[i][:l].tolist())
|
||||||
|
test_targets.append(target[i][:l].tolist())
|
||||||
|
|
||||||
|
total_test_loss += loss.item()
|
||||||
|
|
||||||
|
avg_test_loss = total_test_loss / len(test_pairs)
|
||||||
|
|
||||||
|
print(f"test_loss {avg_test_loss:.4f}")
|
||||||
|
|
||||||
|
if longest:
|
||||||
|
avg_len = sum(len(s) for s in test_sents)/len(test_sents)
|
||||||
|
test_sents = list(filter(lambda x: len(x) > avg_len, test_sents))
|
||||||
|
test_preds = list(filter(lambda x: len(x) > avg_len, test_preds))
|
||||||
|
test_targets = list(filter(lambda x: len(x) > avg_len, test_targets))
|
||||||
|
elif shortest:
|
||||||
|
avg_len = sum(len(s) for s in test_sents)/len(test_sents)
|
||||||
|
test_sents = list(filter(lambda x: len(x) <= avg_len, test_sents))
|
||||||
|
test_preds = list(filter(lambda x: len(x) <= avg_len, test_preds))
|
||||||
|
test_targets = list(filter(lambda x: len(x) <= avg_len, test_targets))
|
||||||
|
|
||||||
|
if detailed:
|
||||||
|
for test_sent, test_target, test_pred in zip(test_sents, test_targets, test_preds):
|
||||||
|
print(precision_recall_fscore_support(test_target, test_pred, average="weighted")[2], test_sent, test_target, test_pred)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
classification_report(
|
||||||
|
[x for y in test_targets for x in y],
|
||||||
|
[x for y in test_preds for x in y],
|
||||||
|
target_names=["not_del", "del"],
|
||||||
|
digits=5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(f"models/{model_name}/test_original_{epoch}.txt", "w") as oh, open(
|
||||||
|
f"models/{model_name}/test_pred_{epoch}.txt", "w"
|
||||||
|
) as ph, open(f"models/{model_name}/test_gold_{epoch}.txt", "w") as gh:
|
||||||
|
for sent, preds, golds in zip(test_sents, test_preds, test_targets):
|
||||||
|
pred_compressed = [word for word, delete in zip(sent, preds) if not delete]
|
||||||
|
gold_compressed = [word for word, delete in zip(sent, golds) if not delete]
|
||||||
|
|
||||||
|
oh.write(" ".join(sent))
|
||||||
|
ph.write(" ".join(pred_compressed))
|
||||||
|
gh.write(" ".join(gold_compressed))
|
||||||
|
|
||||||
|
oh.write("\n")
|
||||||
|
ph.write("\n")
|
||||||
|
gh.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@click.option(
|
||||||
|
"-c",
|
||||||
|
"--corpus",
|
||||||
|
"corpus_name",
|
||||||
|
required=True,
|
||||||
|
type=click.Choice(sorted(["google",])),
|
||||||
|
)
|
||||||
|
@click.option("-w", "--model_weights", required=True)
|
||||||
|
@click.option("-p", "--prior", type=float, default=.5)
|
||||||
|
@click.option("-l", "--longest", is_flag=True)
|
||||||
|
@click.option("-s", "--shortest", is_flag=True)
|
||||||
|
def predict(corpus_name, model_weights, prior, longest, shortest):
|
||||||
|
if longest and shortest:
|
||||||
|
print("longest and shortest are mutually exclusive", file=sys.stderr)
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
|
||||||
|
test_pairs = corpus["test"]
|
||||||
|
|
||||||
|
model_name = os.path.basename(os.path.dirname(model_weights))
|
||||||
|
epoch = re.search("_(\d+).tar", model_weights).group(1)
|
||||||
|
|
||||||
|
logger.info("loading model checkpoint")
|
||||||
|
checkpoint = torch.load(model_weights, map_location=config.DEV)
|
||||||
|
if "word2index" in checkpoint:
|
||||||
|
weights = checkpoint["weights"]
|
||||||
|
word2index = checkpoint["word2index"]
|
||||||
|
index2word = {i: w for w, i in word2index.items()}
|
||||||
|
else:
|
||||||
|
asdf
|
||||||
|
|
||||||
|
network = init_network(word2index, prior)
|
||||||
|
network.eval()
|
||||||
|
|
||||||
|
# remove the embedding layer before loading
|
||||||
|
# weights = {k: v for k, v in weights.items() if not "embedding" in k}
|
||||||
|
# actually load the parameters
|
||||||
|
network.load_state_dict(weights, strict=False)
|
||||||
|
|
||||||
|
total_test_loss = 0
|
||||||
|
|
||||||
|
batch_size = 20
|
||||||
|
|
||||||
|
test_batch_iter = utils.sent_iter(
|
||||||
|
sents=test_pairs, word2index=word2index, batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
test_sents = []
|
||||||
|
test_preds = []
|
||||||
|
test_attentions = []
|
||||||
|
test_fixations = []
|
||||||
|
|
||||||
|
for i, batch in tqdm.tqdm(
|
||||||
|
enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
|
||||||
|
):
|
||||||
|
raw_sent, sent, target = batch
|
||||||
|
seq_lens = [len(x) for x in sent]
|
||||||
|
loss, prediction, attention, fixations = network(sent, target, seq_lens)
|
||||||
|
|
||||||
|
prediction = prediction.detach().cpu().numpy()
|
||||||
|
attention = attention.detach().cpu().numpy()
|
||||||
|
if fixations is not None:
|
||||||
|
fixations = fixations.detach().cpu().numpy()
|
||||||
|
|
||||||
|
for i, l in enumerate(
|
||||||
|
seq_lens
|
||||||
|
):
|
||||||
|
test_sents.append(raw_sent[i][:l])
|
||||||
|
test_preds.append(prediction[i][:l].tolist())
|
||||||
|
test_attentions.append(attention[i][:l].tolist())
|
||||||
|
if fixations is not None:
|
||||||
|
test_fixations.append(fixations[i][:l].tolist())
|
||||||
|
else:
|
||||||
|
test_fixations.append([])
|
||||||
|
|
||||||
|
total_test_loss += loss.item()
|
||||||
|
|
||||||
|
avg_test_loss = total_test_loss / len(test_pairs)
|
||||||
|
|
||||||
|
if longest:
|
||||||
|
avg_len = sum(len(s) for s in test_sents)/len(test_sents)
|
||||||
|
test_sents = list(filter(lambda x: len(x) > avg_len, test_sents))
|
||||||
|
test_preds = list(filter(lambda x: len(x) > avg_len, test_preds))
|
||||||
|
test_attentions = list(filter(lambda x: len(x) > avg_len, test_attentions))
|
||||||
|
test_fixations = list(filter(lambda x: len(x) > avg_len, test_fixations))
|
||||||
|
elif shortest:
|
||||||
|
avg_len = sum(len(s) for s in test_sents)/len(test_sents)
|
||||||
|
test_sents = list(filter(lambda x: len(x) <= avg_len, test_sents))
|
||||||
|
test_preds = list(filter(lambda x: len(x) <= avg_len, test_preds))
|
||||||
|
test_attentions = list(filter(lambda x: len(x) <= avg_len, test_attentions))
|
||||||
|
test_fixations = list(filter(lambda x: len(x) <= avg_len, test_fixations))
|
||||||
|
|
||||||
|
print(f"sentence\tprediction\tattentions\tfixations")
|
||||||
|
for s, p, a, f in zip(test_sents, test_preds, test_attentions, test_fixations):
|
||||||
|
a = [x[:len(a)] for x in a]
|
||||||
|
print(f"{s}\t{p}\t{a}\t{f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
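The script writes one tab-separated row per test sentence: the token list, the per-token predictions, the attention rows, and the fixation scores. This is the format the utility scripts below parse with ast.literal_eval. An illustrative row (values invented):

['the', 'cat', 'sat']\t[1, 0, 1]\t[[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]]\t[0.3, 0.4, 0.3]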
23
joint_sentence_compression_model/requirements.txt
Normal file
@@ -0,0 +1,23 @@
click==7.1.2
cycler==0.10.0
dataclasses==0.6
future==0.18.2
joblib==0.17.0
kiwisolver==1.3.1
matplotlib==3.3.3
nltk==3.5
numpy==1.19.4
Pillow==8.0.1
portalocker==2.0.0
pyparsing==2.4.7
python-dateutil==2.8.1
regex==2020.11.13
sacrebleu==1.4.6
scikit-learn==0.23.2
scipy==1.5.4
six==1.15.0
sklearn==0.0
threadpoolctl==2.1.0
torch==1.7.0
tqdm==4.54.1
typing-extensions==3.7.4.3
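These pins can be installed into a fresh environment in the usual way:

pip install -r joint_sentence_compression_model/requirements.txt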
43
joint_sentence_compression_model/utils/check_stats.py
Normal file
@@ -0,0 +1,43 @@
import ast
import sys

import click


def reader(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                s, p, a, f = line.split("\t")
            except ValueError:
                print(f"skipping line: {line}", file=sys.stderr)
            else:
                try:
                    yield (
                        ast.literal_eval(s),
                        ast.literal_eval(p),
                        ast.literal_eval(a),
                        ast.literal_eval(f),
                    )
                except (ValueError, SyntaxError):
                    print(f"malformed line: {s}", file=sys.stderr)


def get_stats(seq):
    # print each sentence, its prediction, the lengths of all four fields,
    # and the width of every attention row
    for s, p, a, f in seq:
        print(s)
        print(p)
        print(len(s), len(p), len(a), len(f))
        for x in a:
            print(len(x))
        print()


@click.command()
@click.argument("path")
def main(path):
    get_stats(reader(path))


if __name__ == "__main__":
    main()
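A quick sanity check over a predictions file produced by the predict command above (file name hypothetical):

python joint_sentence_compression_model/utils/check_stats.py predictions.tsv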
34
joint_sentence_compression_model/utils/cr.py
Normal file
@@ -0,0 +1,34 @@
import click


def reader(path):
    with open(path) as h:
        for line in h:
            yield line.strip().split()


@click.command()
@click.argument("original")
@click.argument("compressed")
def main(original, compressed):
    # macro-averaged compression ratio: the mean over sentence pairs of
    # len(compressed) / len(original), measured in tokens
    ratio = 0
    total = 0
    for o, c in zip(reader(original), reader(compressed)):
        ratio += len(c) / len(o)
        total += 1
    print(f"cr: {ratio/total:.4f}")


if __name__ == "__main__":
    main()
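For instance (hypothetical files), an original line "the quick brown fox jumps" (5 tokens) compressed to "fox jumps" (2 tokens) contributes 0.4 to the average:

python joint_sentence_compression_model/utils/cr.py originals.txt compressed.txt
cr: 0.4000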
88
joint_sentence_compression_model/utils/kl.py
Normal file
@@ -0,0 +1,88 @@
import ast
from math import log2
import os
from statistics import mean, pstdev
import sys

import click
import numpy as np

from matplotlib import pyplot as plt


def attention_reader(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                s, p, a, f = line.split("\t")
            except ValueError:
                print(f"skipping line: {line}", file=sys.stderr)
            else:
                try:
                    yield [[x[0] for x in y] for y in ast.literal_eval(a)]
                except (ValueError, SyntaxError):
                    print(f"skipping malformed line: {s}", file=sys.stderr)


def _kl_divergence(p, q):
    p = np.asarray(p)
    q = np.asarray(q)
    p = p / p.sum()
    q = q / q.sum()
    return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))


def kl_divergence(ps, qs):
    kl = 0
    count = 0
    for p, q in zip(ps, qs):
        kl += _kl_divergence(p, q)
        count += 1
    return kl / count


def _js_divergence(p, q):
    p = np.asarray(p)
    q = np.asarray(q)
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)
    return 0.5 * _kl_divergence(p, m) + 0.5 * _kl_divergence(q, m)


def js_divergence(ps, qs):
    js = 0
    count = 0
    for p, q in zip(ps, qs):
        js += _js_divergence(p, q)
        count += 1
    return js / count


def get_kl_div(seq1, seq2):
    # NOTE: despite the name, this reports the symmetric JS divergence;
    # substitute kl_divergence below for the raw KL divergence
    return [js_divergence(x1, x2) for x1, x2 in zip(seq1, seq2)]


@click.command()
@click.argument("ref")
@click.argument("path", nargs=-1)
def main(ref, path):
    kls = []
    labels = []
    for p in path:
        labels.append(os.path.basename(p))
        kl = get_kl_div(attention_reader(ref), attention_reader(p))
        print(mean(kl))
        print(pstdev(kl))
        kls.append(kl)

    plt.boxplot(kls, labels=labels)
    plt.show()
    plt.clf()


if __name__ == "__main__":
    main()
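The helpers implement KL(P||Q) = sum_i p_i * log2(p_i / q_i) and JSD(P||Q) = 0.5 * KL(P||M) + 0.5 * KL(Q||M) with M = 0.5 * (P + Q). A minimal sanity check, assuming the helpers above are importable (the import path is hypothetical):

from utils.kl import _kl_divergence, _js_divergence

p, q = [0.5, 0.5], [0.9, 0.1]
print(_kl_divergence(p, q))  # ~0.737 bits; KL is asymmetric
print(_js_divergence(p, q))  # ~0.147 bits; JS is symmetric and bounded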
112
joint_sentence_compression_model/utils/kl_divergence.py
Normal file
@@ -0,0 +1,112 @@
'''
Calculates the divergence between specified columns in a file, or across
the columns of two different files.
'''

from math import log2
import pandas as pd
import ast
import collections.abc
from sklearn import metrics
import sys


def kl_divergence(p, q):
    return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))


def flatten(x):
    if isinstance(x, collections.abc.Iterable):
        return [a for i in x for a in flatten(i)]
    else:
        return [x]


def get_data(file):
    names = ["sentence", "prediction", "attentions", "fixations"]
    df = pd.read_csv(file, sep='\t', names=names)
    df = df[2:]
    attentions = df.loc[:, "attentions"].tolist()
    fixations = df.loc[:, "fixations"].tolist()

    return attentions, fixations


def _average(divergence):
    # ignore pairs for which no score could be computed
    scores = [d for d in divergence if d is not None]
    return sum(scores) / len(scores)


def attention_attention(attentions1, attentions2):
    # NOTE: mutual_info_score computes mutual information, not a true KL
    # divergence; scores are comparable across runs but are not KL values
    divergence = []

    for att1, att2 in zip(attentions1, attentions2):
        lst_att1 = flatten(ast.literal_eval(att1))
        lst_att2 = flatten(ast.literal_eval(att2))

        try:
            divergence.append(metrics.mutual_info_score(lst_att1, lst_att2))
        except ValueError:
            divergence.append(None)

    return _average(divergence)


def fixation_fixation(fixation1, fixation2):
    divergence = []

    for fix1, fix2 in zip(fixation1, fixation2):
        lst_fixation1 = flatten(ast.literal_eval(fix1))
        lst_fixation2 = flatten(ast.literal_eval(fix2))

        try:
            divergence.append(metrics.mutual_info_score(lst_fixation1, lst_fixation2))
        except ValueError:
            divergence.append(None)

    return _average(divergence)


def attention_fixation(attentions, fixations):
    divergence = []

    for attention, fixation in zip(attentions, fixations):
        current_attention = ast.literal_eval(attention)
        current_fixation = ast.literal_eval(fixation)

        # average each token's attention row so both sequences are flat
        lst_attention = []
        for t in current_attention:
            attention_lst = flatten(t)
            lst_attention.append(sum(attention_lst) / len(attention_lst))

        lst_fixation = flatten(current_fixation)

        try:
            divergence.append(metrics.mutual_info_score(lst_attention, lst_fixation))
        except ValueError:
            divergence.append(None)

    return _average(divergence)


def divergent_calculations(file1, file2=None, val1=None):
    attentions, fixations = get_data(file1)

    if file2:
        attentions2, fixations2 = get_data(file2)
        if val1 == "attention":
            divergence = attention_attention(attentions, attentions2)
        else:
            divergence = fixation_fixation(fixations, fixations2)
    else:
        divergence = attention_fixation(attentions, fixations)

    print("divergence:", divergence)


divergent_calculations(sys.argv[1], sys.argv[2], sys.argv[3])
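As invoked above, the script expects three positional arguments: two prediction files and the column selector ("attention" compares attention columns, anything else compares fixation columns). A hypothetical run:

python joint_sentence_compression_model/utils/kl_divergence.py run_a.tsv run_b.tsv attention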
64
joint_sentence_compression_model/utils/plot_attention.py
Normal file
@@ -0,0 +1,64 @@
import ast
import os
import pathlib

import click
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import matplotlib.ticker as ticker
import numpy as np
import tqdm


def plot_attention(input_sentence, output_words, attentions, path):
    # set up figure with colorbar
    attentions = np.array(attentions)[:, :len(input_sentence)]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions, cmap="bone")
    fig.colorbar(cax)

    # set up axes
    ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.savefig(f"{path}.pdf")
    plt.close()


def parse(p):
    with open(p) as h:
        for line in h:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            _sentence, _prediction, _attention, _fixations = line.split("\t")
            try:
                sentence = ast.literal_eval(_sentence)
                prediction = ast.literal_eval(_prediction)
                attention = ast.literal_eval(_attention)
            except (ValueError, SyntaxError):
                continue

            yield sentence, prediction, attention


@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
    for p in tqdm.tqdm(path):
        out_dir = os.path.splitext(p)[0]
        if out_dir == p:
            # the input file had no extension; avoid clobbering it with a
            # directory of the same name
            out_dir = f"{out_dir}_"
        pathlib.Path(out_dir).mkdir(exist_ok=True)
        for i, spa in enumerate(parse(p)):
            plot_attention(*spa, path=os.path.join(out_dir, str(i)))


if __name__ == "__main__":
    main()
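Each input file yields a sibling directory of per-sentence heatmap PDFs, e.g. (hypothetical file name):

python joint_sentence_compression_model/utils/plot_attention.py predictions.tsv
# writes predictions/0.pdf, predictions/1.pdf, ...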
70
joint_sentence_compression_model/utils/text_attention.py
Executable file
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12


## convert the text/attention list to latex code, which further generates the text heatmap based on attention weights.
import numpy as np

latex_special_token = ["!@#$%^&*()"]


def generate(text_list, attention_list, latex_file, color='red', rescale_value=False):
    assert len(text_list) == len(attention_list)
    if rescale_value:
        attention_list = rescale(attention_list)
    word_num = len(text_list)
    text_list = clean_word(text_list)
    with open(latex_file, 'w') as f:
        f.write(r'''\documentclass[varwidth]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}''' + '\n')
        string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{''' + "\n"
        for idx in range(word_num):
            string += "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + text_list[idx] + "} "
        string += "\n}}}"
        f.write(string + '\n')
        f.write(r'''\end{CJK*}
\end{document}''')


def rescale(input_list):
    the_array = np.asarray(input_list)
    the_max = np.max(the_array)
    the_min = np.min(the_array)
    rescale = (the_array - the_min) / (the_max - the_min) * 100
    return rescale.tolist()


def clean_word(word_list):
    new_word_list = []
    for word in word_list:
        for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
            if latex_sensitive in word:
                word = word.replace(latex_sensitive, '\\' + latex_sensitive)
        new_word_list.append(word)
    return new_word_list


if __name__ == '__main__':
    ## This is a demo:

    sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".
North Korea and the US have ratcheted up tensions in recent weeks and the movement of the strike group had raised the question of a pre-emptive strike by the US.
On Wednesday, Mr Pence described the country as the "most dangerous and urgent threat to peace and security" in the Asia-Pacific.'''
    # the Chinese sample below overrides the English one above
    sent = '''我 回忆 起 我 曾经 在 大学 年代 , 我们 经常 喜欢 玩 “ Hawaii guitar ” 。 说起 Guitar , 我 想起 了 西游记 里 的 琵琶精 。
今年 下半年 , 中 美 合拍 的 西游记 即将 正式 开机 , 我 继续 扮演 美猴王 孙悟空 , 我 会 用 美猴王 艺术 形象 努力 创造 一 个 正能量 的 形象 , 文 体 两 开花 , 弘扬 中华 文化 , 希望 大家 能 多多 关注 。'''
    words = sent.split()
    word_num = len(words)
    attention = [(x + 1.) / word_num * 100 for x in range(word_num)]
    import random
    random.seed(42)
    random.shuffle(attention)
    color = 'red'
    generate(words, attention, "sample.tex", color)
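The demo writes a standalone LaTeX document; compiling it (assuming pdflatex and the CJK package are available) renders the heatmap:

pdflatex sample.tex   # produces sample.pdf with one colour box per token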