Add NLP task models

Ekta Sood 2020-12-08 21:10:52 +01:00
parent d8beb17dfb
commit 69f6de0ace
46 changed files with 4976 additions and 0 deletions

joint_paraphrase_model/.gitignore vendored Normal file

@ -0,0 +1,428 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,latex
# Edit at https://www.toptal.com/developers/gitignore?templates=python,latex
### LaTeX ###
## Core latex/pdflatex auxiliary files:
*.aux
*.lof
*.log
*.lot
*.fls
*.out
*.toc
*.fmt
*.fot
*.cb
*.cb2
.*.lb
## Intermediate documents:
*.dvi
*.xdv
*-converted-to.*
# these rules might exclude image files for figures etc.
# *.ps
# *.eps
# *.pdf
## Generated if empty string is given at "Please type another file name for output:"
.pdf
## Bibliography auxiliary files (bibtex/biblatex/biber):
*.bbl
*.bcf
*.blg
*-blx.aux
*-blx.bib
*.run.xml
## Build tool auxiliary files:
*.fdb_latexmk
*.synctex
*.synctex(busy)
*.synctex.gz
*.synctex.gz(busy)
*.pdfsync
## Build tool directories for auxiliary files
# latexrun
latex.out/
## Auxiliary and intermediate files from other packages:
# algorithms
*.alg
*.loa
# achemso
acs-*.bib
# amsthm
*.thm
# beamer
*.nav
*.pre
*.snm
*.vrb
# changes
*.soc
# comment
*.cut
# cprotect
*.cpt
# elsarticle (documentclass of Elsevier journals)
*.spl
# endnotes
*.ent
# fixme
*.lox
# feynmf/feynmp
*.mf
*.mp
*.t[1-9]
*.t[1-9][0-9]
*.tfm
#(r)(e)ledmac/(r)(e)ledpar
*.end
*.?end
*.[1-9]
*.[1-9][0-9]
*.[1-9][0-9][0-9]
*.[1-9]R
*.[1-9][0-9]R
*.[1-9][0-9][0-9]R
*.eledsec[1-9]
*.eledsec[1-9]R
*.eledsec[1-9][0-9]
*.eledsec[1-9][0-9]R
*.eledsec[1-9][0-9][0-9]
*.eledsec[1-9][0-9][0-9]R
# glossaries
*.acn
*.acr
*.glg
*.glo
*.gls
*.glsdefs
*.lzo
*.lzs
# uncomment this for glossaries-extra (will ignore makeindex's style files!)
# *.ist
# gnuplottex
*-gnuplottex-*
# gregoriotex
*.gaux
*.gtex
# htlatex
*.4ct
*.4tc
*.idv
*.lg
*.trc
*.xref
# hyperref
*.brf
# knitr
*-concordance.tex
# TODO Comment the next line if you want to keep your tikz graphics files
*.tikz
*-tikzDictionary
# listings
*.lol
# luatexja-ruby
*.ltjruby
# makeidx
*.idx
*.ilg
*.ind
# minitoc
*.maf
*.mlf
*.mlt
*.mtc[0-9]*
*.slf[0-9]*
*.slt[0-9]*
*.stc[0-9]*
# minted
_minted*
*.pyg
# morewrites
*.mw
# nomencl
*.nlg
*.nlo
*.nls
# pax
*.pax
# pdfpcnotes
*.pdfpc
# sagetex
*.sagetex.sage
*.sagetex.py
*.sagetex.scmd
# scrwfile
*.wrt
# sympy
*.sout
*.sympy
sympy-plots-for-*.tex/
# pdfcomment
*.upa
*.upb
# pythontex
*.pytxcode
pythontex-files-*/
# tcolorbox
*.listing
# thmtools
*.loe
# TikZ & PGF
*.dpth
*.md5
*.auxlock
# todonotes
*.tdo
# vhistory
*.hst
*.ver
# easy-todo
*.lod
# xcolor
*.xcp
# xmpincl
*.xmpi
# xindy
*.xdy
# xypic precompiled matrices and outlines
*.xyc
*.xyd
# endfloat
*.ttt
*.fff
# Latexian
TSWLatexianTemp*
## Editors:
# WinEdt
*.bak
*.sav
# Texpad
.texpadtmp
# LyX
*.lyx~
# Kile
*.backup
# gummi
.*.swp
# KBibTeX
*~[0-9]*
# TeXnicCenter
*.tps
# auto folder when using emacs and auctex
./auto/*
*.el
# expex forward references with \gathertags
*-tags.tex
# standalone packages
*.sta
# Makeindex log files
*.lpz
# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib
# option is specified. Footnotes are then stored in a file with suffix Notes.bib.
# Uncomment the next line to have this generated file ignored.
#*Notes.bib
### LaTeX Patch ###
# LIPIcs / OASIcs
*.vtc
# glossaries
*.glstex
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# End of https://www.toptal.com/developers/gitignore/api/python,latex


@ -0,0 +1,3 @@
# joint_paraphrase_model
Joint training paraphrase model --- NeurIPS
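
A minimal training run, assuming the click-based CLI defined in this commit (the script name `main.py` is illustrative):

```
python main.py -vv train --corpus qqp --model_name demo_run --bleu nltk
```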


@ -0,0 +1,112 @@
import os
import torch
# general
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PAD = "<__PAD__>"
UNK = "<__UNK__>"
NOFIX = "<__NOFIX__>"
SOS = "<__SOS__>"
EOS = "<__EOS__>"
batch_size = 1
teacher_forcing_ratio = 0.5
embedding_dim = 300
fix_hidden_dim = 128
par_hidden_dim = 1024
fix_dropout = 0.5
par_dropout = 0.2
_fix_learning_rate = 0.00001
_par_learning_rate = 0.0001
learning_rate = _par_learning_rate
fix_momentum = 0.9
par_momentum = 0.0
max_length = 121
epochs = 5
# paths
data_path = "./data"
provo_predictability_path = os.path.join(
data_path, "datasets/provo/Provo_Corpus-Predictability_Norms.csv"
)
provo_eyetracking_path = os.path.join(
data_path, "datasets/provo/Provo_Corpus-Eyetracking_Data.csv"
)
geco_en_path = os.path.join(data_path, "datasets/geco/EnglishMaterial.csv")
geco_mono_path = os.path.join(data_path, "datasets/geco/MonolingualReadingData.csv")
movieqa_human_path = os.path.join(data_path, "datasets/all_word_scores_fixations")
movieqa_human_path_2 = os.path.join(
data_path, "datasets/all_word_scores_fixations_exp2"
)
movieqa_human_path_3 = os.path.join(
data_path, "datasets/all_word_scores_fixations_exp3"
)
movieqa_split_plot_path = os.path.join(data_path, "datasets/split_plot_UNRESOLVED")
cnn_path = os.path.join(
data_path,
"projects/2019/fixation_prediction/ez-reader-wrapper/predictability/output_cnn/",
)
dm_path = os.path.join(
data_path,
"projects/2019/fixation_prediction/ez-reader-wrapper/predictability/output_dm/",
)
qqp_paws_basedir = os.path.join(data_path, "datasets/paw_google/qqp/paws_qqp/output")
qqp_paws_train_path = os.path.join(qqp_paws_basedir, "train.tsv")
qqp_paws_dev_path = os.path.join(qqp_paws_basedir, "dev.tsv")
qqp_paws_test_path = os.path.join(qqp_paws_basedir, "test.tsv")
qqp_basedir = os.path.join(data_path, "datasets/Quora_question_pair_partition_OG")
qqp_train_path = os.path.join(qqp_basedir, "train.tsv")
qqp_dev_path = os.path.join(qqp_basedir, "dev.tsv")
qqp_test_path = os.path.join(qqp_basedir, "test.tsv")
qqp_kag_basedir = os.path.join(data_path, "datasets/Quora_question_pair_partition_kag")
qqp_kag_train_path = os.path.join(qqp_kag_basedir, "train.tsv")
qqp_kag_dev_path = os.path.join(qqp_kag_basedir, "dev.tsv")
qqp_kag_test_path = os.path.join(qqp_kag_basedir, "test.tsv")
wiki_basedir = os.path.join(data_path, "datasets/paw_google/wiki")
wiki_train_path = os.path.join(wiki_basedir, "train.tsv")
wiki_dev_path = os.path.join(wiki_basedir, "dev.tsv")
wiki_test_path = os.path.join(wiki_basedir, "test.tsv")
msrpc_basedir = os.path.join(data_path, "datasets/MSRPC")
msrpc_train_path = os.path.join(msrpc_basedir, "msr_paraphrase_train.txt")
msrpc_dev_path = os.path.join(msrpc_basedir, "msr_paraphrase_dev.txt")
msrpc_test_path = os.path.join(msrpc_basedir, "msr_paraphrase_test.txt")
sentiment_basedir = os.path.join(data_path, "datasets/sentiment_kag")
sentiment_train_path = os.path.join(sentiment_basedir, "train.tsv")
sentiment_dev_path = os.path.join(sentiment_basedir, "dev.tsv")
sentiment_test_path = os.path.join(sentiment_basedir, "test.tsv")
tamil_basedir = os.path.join(data_path, "datasets/en-ta-parallel-v2")
tamil_train_path = os.path.join(tamil_basedir, "corpus.bcn.train.enta")
tamil_dev_path = os.path.join(tamil_basedir, "corpus.bcn.dev.enta")
tamil_test_path = os.path.join(tamil_basedir, "corpus.bcn.test.enta")
compression_basedir = os.path.join(data_path, "datasets/sentence-compression/data")
compression_train_path = os.path.join(compression_basedir, "train.tsv")
compression_dev_path = os.path.join(compression_basedir, "dev.tsv")
compression_test_path = os.path.join(compression_basedir, "test.tsv")
stanford_basedir = os.path.join(data_path, "datasets/stanfordSentimentTreebank")
stanford_train_path = os.path.join(stanford_basedir, "train.tsv")
stanford_dev_path = os.path.join(stanford_basedir, "dev.tsv")
stanford_test_path = os.path.join(stanford_basedir, "test.tsv")
stanford_sent_basedir = os.path.join(data_path, "datasets/stanfordSentimentTreebank")
stanford_sent_train_path = os.path.join(stanford_sent_basedir, "train_sent.tsv")
stanford_sent_dev_path = os.path.join(stanford_sent_basedir, "dev_sent.tsv")
stanford_sent_test_path = os.path.join(stanford_sent_basedir, "test_sent.tsv")
emb_path = os.path.join(data_path, "Google_word2vec/GoogleNews-vectors-negative300.bin")
glove_path = "glove.840B.300d.txt"

joint_paraphrase_model/data Symbolic link

@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/


@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/datasets/glove/glove.840B.300d.txt

File diff suppressed because one or more lines are too long


@ -0,0 +1,416 @@
import logging
import config
def tokenize(sent):
return sent.split(" ")
class Lang:
"""Represents the vocabulary
"""
def __init__(self, name):
self.name = name
self.word2index = {
config.PAD: 0,
config.UNK: 1,
config.NOFIX: 2,
config.SOS: 3,
config.EOS: 4,
}
self.word2count = {}
self.index2word = {
0: config.PAD,
1: config.UNK,
2: config.NOFIX,
3: config.SOS,
4: config.EOS,
}
self.n_words = 5
def add_sentence(self, sentence):
assert isinstance(
sentence, (list, tuple)
), "input to add_sentence must be tokenized"
for word in sentence:
self.add_word(word)
def add_word(self, word):
if word not in self.word2index:
self.word2index[word] = self.n_words
self.word2count[word] = 1
self.index2word[self.n_words] = word
self.n_words += 1
else:
self.word2count[word] += 1
def __add__(self, other):
"""Returns a new Lang object containing the vocabulary from this and
the other Lang object
"""
new_lang = Lang(f"{self.name}_{other.name}")
# Add vocabulary from both Langs
for word in self.word2count.keys():
new_lang.add_word(word)
for word in other.word2count.keys():
new_lang.add_word(word)
# Fix the counts on the new one
for word in new_lang.word2count.keys():
new_lang.word2count[word] = self.word2count.get(
word, 0
) + other.word2count.get(word, 0)
return new_lang
def load_wiki(split):
"""Load the Wiki from PAWs"""
logger = logging.getLogger(f"{__name__}.load_wiki")
lang = Lang("wiki")
if split == "train":
path = config.wiki_train_path
elif split == "val":
path = config.wiki_dev_path
elif split == "test":
path = config.wiki_test_path
logger.info("loading %s from %s" % (split, path))
pairs = []
with open(path) as handle:
# skip header
handle.readline()
for line in handle:
_, sent1, sent2, rating = line.strip().split("\t")
if rating == "0":
continue
sent1 = tokenize(sent1)
sent2 = tokenize(sent2)
lang.add_sentence(sent1)
lang.add_sentence(sent2)
# pairs.append([sent1, sent2, rating])
pairs.append([sent1, sent2])
# note (MS): source and target share a single vocabulary for paraphrase pairs
return pairs, lang
def load_qqp_paws(split):
"""Load the QQP from PAWs"""
logger = logging.getLogger(f"{__name__}.load_qqp_paws")
lang = Lang("qqp_paws")
if split == "train":
path = config.qqp_paws_train_path
elif split == "val":
path = config.qqp_paws_dev_path
elif split == "test":
path = config.qqp_paws_test_path
logger.info("loading %s from %s" % (split, path))
pairs = []
with open(path) as handle:
# skip header
handle.readline()
for line in handle:
_, sent1, sent2, rating = line.strip().split("\t")
if rating == "0":
continue
sent1 = tokenize(sent1)
sent2 = tokenize(sent2)
lang.add_sentence(sent1)
lang.add_sentence(sent2)
# pairs.append([sent1, sent2, rating])
pairs.append([sent1, sent2])
# note (MS): source and target share a single vocabulary for paraphrase pairs
return pairs, lang
def load_qqp(split):
"""Load the QQP from Original"""
logger = logging.getLogger(f"{__name__}.load_qqp")
lang = Lang("qqp")
if split == "train":
path = config.qqp_train_path
elif split == "val":
path = config.qqp_dev_path
elif split == "test":
path = config.qqp_test_path
logger.info("loading %s from %s" % (split, path))
pairs = []
with open(path) as handle:
# skip header
handle.readline()
for line in handle:
rating, sent1, sent2, _ = line.strip().split("\t")
if rating == "0":
continue
sent1 = tokenize(sent1)
sent2 = tokenize(sent2)
lang.add_sentence(sent1)
lang.add_sentence(sent2)
# pairs.append([sent1, sent2, rating])
pairs.append([sent1, sent2])
# note (MS): source and target share a single vocabulary for paraphrase pairs
return pairs, lang
def load_qqp_kag(split):
"""Load the QQP from Kaggle""" #not original right now, expriemnting with kaggle 100K, 3K, 30K split
logger = logging.getLogger(f"{__name__}.load_qqp_kag")
lang = Lang("qqp_kag")
if split == "train":
path = config.qqp_kag_train_path
elif split == "val":
path = config.qqp_kag_dev_path
elif split == "test":
path = config.qqp_kag_test_path
logger.info("loading %s from %s" % (split, path))
pairs = []
with open(path) as handle:
# skip header
handle.readline()
for line in handle:  # the Kaggle version has 3 tab-separated fields per line, not 4
rating, sent1, sent2 = line.strip().split("\t")
if rating == "0":
continue
sent1 = tokenize(sent1)
sent2 = tokenize(sent2)
lang.add_sentence(sent1)
lang.add_sentence(sent2)
# pairs.append([sent1, sent2, rating])
pairs.append([sent1, sent2])
# note (MS): source and target share a single vocabulary for paraphrase pairs
return pairs, lang
def load_msrpc(split):
"""Load the Microsoft Research Paraphrase Corpus (MSRPC)"""
logger = logging.getLogger(f"{__name__}.load_msrpc")
lang = Lang("msrpc")
if split == "train":
path = config.msrpc_train_path
elif split == "val":
path = config.msrpc_dev_path
elif split == "test":
path = config.msrpc_test_path
logger.info("loading %s from %s" % (split, path))
pairs = []
with open(path) as handle:
# skip header
handle.readline()
for line in handle:
rating, _, _, sent1, sent2 = line.strip().split("\t")
if rating == "0":
continue
sent1 = tokenize(sent1)
sent2 = tokenize(sent2)
lang.add_sentence(sent1)
lang.add_sentence(sent2)
# pairs.append([sent1, sent2, rating])
pairs.append([sent1, sent2])
# return src_lang, dst_lang, pairs
# note (MS): source and target share a single vocabulary for paraphrase pairs
return pairs, lang
def load_sentiment(split):
"""Load the Sentiment Kaggle Comp Dataset"""
logger = logging.getLogger(f"{__name__}.load_sentiment")
lang = Lang("sentiment")
if split == "train":
path = config.sentiment_train_path
elif split == "val":
path = config.sentiment_dev_path
elif split == "test":
path = config.sentiment_test_path
logger.info("loading %s from %s" % (split, path))
pairs = []
with open(path) as handle:
# skip header
handle.readline()
for line in handle:
_, _, sent1, sent2 = line.strip().split("\t")
sent1 = tokenize(sent1)
sent2 = tokenize(sent2)
lang.add_sentence(sent1)
lang.add_sentence(sent2)
# pairs.append([sent1, sent2, rating])
pairs.append([sent1, sent2])
return pairs, lang
def load_tamil(split):
"""Load the En to Tamil dataset, current SOTA ~13 bleu"""
logger = logging.getLogger(f"{__name__}.load_tamil")
lang = Lang("tamil")
if split == "train":
path = config.tamil_train_path
elif split == "val":
path = config.tamil_dev_path
elif split == "test":
path = config.tamil_test_path
logger.info("loading %s from %s" % (split, path))
pairs = []
with open(path) as handle:
handle.readline()
for line in handle:
sent1, sent2 = line.strip().split("\t")
#if rating == "0":
# continue
sent1 = tokenize(sent1)
# TODO: Tamil likely needs a language-specific tokenizer; whitespace splitting for now
sent2 = tokenize(sent2)
lang.add_sentence(sent1)
lang.add_sentence(sent2)
pairs.append([sent1, sent2])
return pairs, lang
def load_compression(split):
"""Load the Google Sentence Compression Dataset"""
logger = logging.getLogger(f"{__name__}.load_compression")
lang = Lang("compression")
if split == "train":
path = config.compression_train_path
elif split == "val":
path = config.compression_dev_path
elif split == "test":
path = config.compression_test_path
logger.info("loading %s from %s" % (split, path))
pairs = []
with open(path) as handle:
handle.readline()
for line in handle:
sent1, sent2 = line.strip().split("\t")
sent1 = tokenize(sent1)
sent2 = tokenize(sent2)
# print(len(sent1), sent1)
# print(len(sent2), sent2)
# print()
lang.add_sentence(sent1)
lang.add_sentence(sent2)
pairs.append([sent1, sent2])
return pairs, lang
def load_stanford(split):
"""Load the Stanford Sentiment Dataset phrases"""
logger = logging.getLogger(f"{__name__}.load_stanford")
lang = Lang("stanford")
if split == "train":
path = config.stanford_train_path
elif split == "val":
path = config.stanford_dev_path
elif split == "test":
path = config.stanford_test_path
logger.info("loading %s from %s" % (split, path))
pairs = []
with open(path) as handle:
# header skipping disabled; these files are read from the first line
#handle.readline()
for line in handle:
_, _, sent1, sent2 = line.strip().split("\t")
sent1 = tokenize(sent1)
sent2 = tokenize(sent2)
lang.add_sentence(sent1)
lang.add_sentence(sent2)
# pairs.append([sent1, sent2, rating])
pairs.append([sent1, sent2])
return pairs, lang
def load_stanford_sent(split):
"""Load the Stanford Sentiment Dataset sentences"""
logger = logging.getLogger(f"{__name__}.load_stanford_sent")
lang = Lang("stanford_sent")
if split == "train":
path = config.stanford_sent_train_path
elif split == "val":
path = config.stanford_sent_dev_path
elif split == "test":
path = config.stanford_sent_test_path
logger.info("loading %s from %s" % (split, path))
pairs = []
with open(path) as handle:
# header skipping disabled; these files are read from the first line
#handle.readline()
for line in handle:
_, _, sent1, sent2 = line.strip().split("\t")
sent1 = tokenize(sent1)
sent2 = tokenize(sent2)
lang.add_sentence(sent1)
lang.add_sentence(sent2)
# pairs.append([sent1, sent2, rating])
pairs.append([sent1, sent2])
return pairs, lang
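# --- Usage sketch (illustrative only; assumes the QQP TSVs from config exist) ---
# Each loader returns (pairs, lang); per-split vocabularies can be merged with
# Lang.__add__, which also sums the word counts.
if __name__ == "__main__":
    train_pairs, train_lang = load_qqp("train")
    val_pairs, val_lang = load_qqp("val")
    lang = train_lang + val_lang
    print(len(train_pairs), len(val_pairs), lang.n_words)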


@ -0,0 +1 @@
from .main import *


@ -0,0 +1,131 @@
from collections import OrderedDict
import logging
import sys
from .self_attention import Transformer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence
def random_embedding(vocab_size, embedding_dim):
pretrain_emb = np.empty([vocab_size, embedding_dim])
scale = np.sqrt(3.0 / embedding_dim)
for index in range(vocab_size):
pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
return pretrain_emb
def neg_log_likelihood_loss(outputs, batch_label, batch_size, seq_len):
outputs = outputs.view(batch_size * seq_len, -1)
score = F.log_softmax(outputs, 1)
loss = nn.NLLLoss(ignore_index=0, reduction="sum")(
score, batch_label.view(batch_size * seq_len)
)
loss = loss / batch_size
_, tag_seq = torch.max(score, 1)
tag_seq = tag_seq.view(batch_size, seq_len)
# print(score[0], tag_seq[0])
return loss, tag_seq
def mse_loss(outputs, batch_label, batch_size, seq_len, word_seq_length):
# score = torch.nn.functional.softmax(outputs, 1)
score = torch.sigmoid(outputs)
mask = torch.zeros_like(score)
for i, v in enumerate(word_seq_length):
mask[i, 0:v] = 1
score = score * mask
loss = nn.MSELoss(reduction="sum")(
score.view(batch_size, seq_len), batch_label.view(batch_size, seq_len)
)
loss = loss / batch_size
return loss, score.view(batch_size, seq_len)
class Network(nn.Module):
def __init__(
self,
embedding_type,
vocab_size,
embedding_dim,
dropout,
hidden_dim,
embeddings=None,
attention=True,
):
super().__init__()
self.logger = logging.getLogger(f"{__name__}")
prelayers = OrderedDict()
postlayers = OrderedDict()
if embedding_type in ("w2v", "glove"):
if embeddings is not None:
prelayers["embedding_layer"] = nn.Embedding.from_pretrained(embeddings)
else:
prelayers["embedding_layer"] = nn.Embedding(vocab_size, embedding_dim)
prelayers["embedding_dropout_layer"] = nn.Dropout(dropout)
embedding_dim = 300
elif embedding_type == "bert":
embedding_dim = 768
self.lstm = BiLSTM(embedding_dim, hidden_dim // 2, num_layers=1)
postlayers["lstm_dropout_layer"] = nn.Dropout(dropout)
if attention:
# increased complexity (1024-D, 16 heads, 16 layers) for the no-attention vs. attention experiments
# earlier, for the initial attention runs and pretraining: 4 heads, 4 layers, 128-D
# then 128-D with 4 heads, 1 layer = the configuration behind all IUI results
###postlayers["position_encodings"] = PositionalEncoding(hidden_dim)
postlayers["attention_layer"] = Transformer(
d_model=hidden_dim, n_heads=4, n_layers=1
)
postlayers["ff_layer"] = nn.Linear(hidden_dim, hidden_dim // 2)
postlayers["ff_activation"] = nn.ReLU()
postlayers["output_layer"] = nn.Linear(hidden_dim // 2, 1)
self.logger.info(f"prelayers: {prelayers.keys()}")
self.logger.info(f"postlayers: {postlayers.keys()}")
self.pre = nn.Sequential(prelayers)
self.post = nn.Sequential(postlayers)
def forward(self, x, word_seq_length):
x = self.pre(x)
x = self.lstm(x, word_seq_length)
# MS: debug printing of fixation-model parameters
#for p in self.parameters():
# print(p.data)
# break
return self.post(x.transpose(1, 0))
class BiLSTM(nn.Module):
def __init__(self, embedding_dim, lstm_hidden, num_layers):
super().__init__()
self.net = nn.LSTM(
input_size=embedding_dim,
hidden_size=lstm_hidden,
num_layers=num_layers,
batch_first=True,
bidirectional=True,
)
def forward(self, x, word_seq_length):
packed_words = pack_padded_sequence(x, word_seq_length, True, False)
lstm_out, hidden = self.net(packed_words)
lstm_out, _ = pad_packed_sequence(lstm_out)
return lstm_out
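# --- Instantiation sketch (illustrative only; dims follow config.py) ---
# Runs a dummy batch through the fixation network and prints the per-token
# output shape: [batch, seq_len, 1].
if __name__ == "__main__":
    net = Network(
        embedding_type="glove",
        vocab_size=1000,
        embedding_dim=300,
        dropout=0.5,
        hidden_dim=128,
    )
    net.eval()
    tokens = torch.randint(0, 1000, (1, 7))
    out = net(tokens, [7])
    print(out.shape)  # torch.Size([1, 7, 1])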


@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()
# Not a parameter
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
def _get_sinusoid_encoding_table(self, n_position, d_hid):
''' Sinusoid position encoding table '''
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def forward(self, x):
return x + self.pos_table[:, :x.size(1)].clone().detach()
class AttentionLayer(nn.Module):
def __init__(self):
super(AttentionLayer, self).__init__()
def forward(self, Q, K, V):
# Q: float32:[batch_size, n_queries, d_k]
# K: float32:[batch_size, n_keys, d_k]
# V: float32:[batch_size, n_keys, d_v]
dk = K.shape[-1]
dv = V.shape[-1]
KT = torch.transpose(K, -1, -2)
weight_logits = torch.bmm(Q, KT) / math.sqrt(dk)
# weight_logits: float32[batch_size, n_queries, n_keys]
weights = F.softmax(weight_logits, dim=-1)
# weight: float32[batch_size, n_queries, n_keys]
return torch.bmm(weights, V)
# return float32[batch_size, n_queries, dv]
class MultiHeadedSelfAttentionLayer(nn.Module):
def __init__(self, d_model, n_heads):
super(MultiHeadedSelfAttentionLayer, self).__init__()
self.d_model = d_model
self.n_heads = n_heads
print('{} {}'.format(d_model, n_heads))
assert d_model % n_heads == 0
self.d_k = d_model // n_heads
self.d_v = self.d_k
self.attention_layer = AttentionLayer()
self.W_Qs = nn.ModuleList([
nn.Linear(d_model, self.d_k, bias=False)
for _ in range(n_heads)
])
self.W_Ks = nn.ModuleList([
nn.Linear(d_model, self.d_k, bias=False)
for _ in range(n_heads)
])
self.W_Vs = nn.ModuleList([
nn.Linear(d_model, self.d_v, bias=False)
for _ in range(n_heads)
])
self.W_O = nn.Linear(d_model, d_model, bias=False)
def forward(self, x):
# x:float32[batch_size, sequence_length, self.d_model]
head_outputs = []
for W_Q, W_K, W_V in zip(self.W_Qs, self.W_Ks, self.W_Vs):
Q = W_Q(x)
# Q float32:[batch_size, sequence_length, self.d_k]
K = W_K(x)
# K float32:[batch_size, sequence_length, self.d_k]
V = W_V(x)
# V float32:[batch_size, sequence_length, self.d_v]
head_output = self.attention_layer(Q, K, V)
# float32:[batch_size, sequence_length, self.d_v]
head_outputs.append(head_output)
concatenated = torch.cat(head_outputs, dim=-1)
# concatenated float32:[batch_size, sequence_length, self.d_model]
out = self.W_O(concatenated)
# out float32:[batch_size, sequence_length, self.d_model]
return out
class Feedforward(nn.Module):
def __init__(self, d_model):
super(Feedforward, self).__init__()
self.d_model = d_model
self.W1 = nn.Linear(d_model, d_model)
self.W2 = nn.Linear(d_model, d_model)
def forward(self, x):
# x: float32[batch_size, sequence_length, d_model]
return self.W2(torch.relu(self.W1(x)))
class Transformer(nn.Module):
def __init__(self, d_model, n_heads, n_layers):
super(Transformer, self).__init__()
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.attention_layers = nn.ModuleList([
MultiHeadedSelfAttentionLayer(d_model, n_heads)
for _ in range(n_layers)
])
self.ffs = nn.ModuleList([
Feedforward(d_model)
for _ in range(n_layers)
])
def forward(self, x):
# x: float32[batch_size, sequence_length, self.d_model]
for attention_layer, ff in zip(self.attention_layers, self.ffs):
attention_out = attention_layer(x)
# attention_out: float32[batch_size, sequence_length, self.d_model]
x = F.layer_norm(x + attention_out, x.shape[2:])
ff_out = ff(x)
# ff_out: float32[batch_size, sequence_length, self.d_model]
x = F.layer_norm(x + ff_out, x.shape[2:])
return x
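# --- Shape sanity check (illustrative only) ---
# The residual + layer-norm stack should preserve [batch, seq_len, d_model].
if __name__ == "__main__":
    model = Transformer(d_model=128, n_heads=4, n_layers=1)
    x = torch.randn(2, 10, 128)
    assert model(x).shape == (2, 10, 128)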


@ -0,0 +1 @@
from .main import *


@ -0,0 +1,86 @@
import json
import math
import os
import random
import time
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class EncoderRNN(nn.Module):
def __init__(self, input_size, hidden_size, embeddings):
super(EncoderRNN, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding.from_pretrained(embeddings)
self.gru = nn.GRU(input_size, hidden_size)
def forward(self, input, hidden):
embedded = self.embedding(input).view(1, 1, -1)
output = embedded
output, hidden = self.gru(output, hidden)
return output, hidden
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size)
class AttnDecoderRNN(nn.Module):
def __init__(
self,
input_size,
hidden_size,
output_size,
embeddings,
dropout_p,
max_length,
):
super(AttnDecoderRNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.dropout_p = dropout_p
self.max_length = max_length
self.embedding = nn.Embedding.from_pretrained(embeddings)  # for paraphrase generation
#self.embedding = nn.Embedding(len(embeddings), 300)  # for NMT with Tamil; trying with sentiment too
self.attn = nn.Linear(self.input_size + self.hidden_size, self.max_length)
self.attn_combine = nn.Linear(
self.input_size + self.hidden_size, self.hidden_size
)
self.dropout = nn.Dropout(self.dropout_p)
self.gru = nn.GRU(self.hidden_size, self.hidden_size)
self.out = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input, hidden, encoder_outputs, fixations):
embedded = self.embedding(input).view(1, 1, -1)
embedded = self.dropout(embedded)
attn_weights = F.softmax(
self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1
)
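# reweight attention with the predicted fixations: pad the fixation vector
# with zeros up to max_length, then scale the attention weights elementwise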
attn_weights = attn_weights * torch.nn.ConstantPad1d((0, attn_weights.shape[-1] - fixations.shape[-2]), 0)(fixations.squeeze().unsqueeze(0))
# attn_weights = torch.softmax(attn_weights * torch.nn.ConstantPad1d((0, attn_weights.shape[-1] - fixations.shape[-2]), 0)(fixations.squeeze().unsqueeze(0)), dim=1)
attn_applied = torch.bmm(
attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0)
)
output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
output = F.relu(output)
output, hidden = self.gru(output, hidden)
# output = F.log_softmax(self.out(output[0]), dim=1)
output = self.out(output[0])
# output = F.log_softmax(output, dim=1)
return output, hidden, attn_weights


@ -0,0 +1,225 @@
import json
import logging
import math
import os
import random
import re
import time
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from nltk.translate.bleu_score import sentence_bleu
import numpy as np
import torch
import torch.nn as nn
import config
plt.switch_backend("agg")
def load_glove(vocabulary):
logger = logging.getLogger(f"{__name__}.load_glove")
logger.info("loading embeddings")
try:
with open(f"glove.cache") as h:
cache = json.load(h)
except FileNotFoundError:
logger.info("cache doesn't exist")
cache = {}
cache[config.PAD] = [0] * 300
cache[config.SOS] = [0] * 300
cache[config.EOS] = [0] * 300
cache[config.UNK] = [0] * 300
cache[config.NOFIX] = [0] * 300
else:
logger.info("cache found")
cache_miss = False
if not set(vocabulary) <= set(cache):
cache_miss = True
logger.warning("cache miss, loading full embeddings")
data = {}
with open("glove.840B.300d.txt") as h:
for line in h:
word, *emb = line.strip().split()
try:
data[word] = [float(x) for x in emb]
except ValueError:
continue
logger.info("finished loading full embeddings")
for word in vocabulary:
try:
cache[word] = data[word]
except KeyError:
cache[word] = [0] * 300
logger.info("cache updated")
embeddings = []
for word in vocabulary:
embeddings.append(torch.tensor(cache[word], dtype=torch.float32))
embeddings = torch.stack(embeddings)
if cache_miss:
with open(f"glove.cache", "w") as h:
json.dump(cache, h)
logger.info("cache saved")
return embeddings
def tokenize(s):
s = s.lower().strip()
s = re.sub(r"([.!?])", r" \1", s)
s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
s = s.split(" ")
return s
def indices_from_sentence(word2index, sentence, unknown_threshold):
if unknown_threshold:
return [
word2index.get(
word if random.random() > unknown_threshold else config.UNK,
word2index[config.UNK],
)
for word in sentence
]
else:
return [
word2index.get(word, word2index[config.UNK]) for word in sentence
]
def tensor_from_sentence(word2index, sentence, unknown_threshold):
# indices = [config.SOS]
indices = indices_from_sentence(word2index, sentence, unknown_threshold)
indices.append(word2index[config.EOS])
return torch.tensor(indices, dtype=torch.long, device=config.DEV)
def tensors_from_pair(word2index, pair, shuffle, unknown_threshold):
tensors = [
tensor_from_sentence(word2index, pair[0], unknown_threshold),
tensor_from_sentence(word2index, pair[1], unknown_threshold),
]
if shuffle:
random.shuffle(tensors)
return tensors
def bleu(reference, hypothesis, n=4):  # uniform weights over 1..n-grams, i.e. BLEU-n
if n < 1:
return 0
weights = [1/n]*n
return sentence_bleu([reference], hypothesis, weights)
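# Example: identical token sequences score 1.0 for any n, e.g.
#   bleu([1, 2, 3, 4], [1, 2, 3, 4], n=2) == 1.0  # uniform (0.5, 0.5) 1-/2-gram weights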
def pair_iter(pairs, word2index, shuffle=False, shuffle_pairs=False, unknown_threshold=0.00):
if shuffle:
pairs = pairs.copy()
random.shuffle(pairs)
for pair in pairs:
tensor1, tensor2 = tensors_from_pair(word2index, (pair[0], pair[1]), shuffle_pairs, unknown_threshold)
yield (tensor1,), (tensor2,)
def sent_iter(sents, word2index, unknown_threshold=0.00):
for sent in sents:
tensor = tensor_from_sentence(word2index, sent, unknown_threshold)
yield (tensor,)
def batch_iter(pairs, word2index, batch_size, shuffle=False, unknown_threshold=0.00):
for i in range(len(pairs) // batch_size):
batch = pairs[i * batch_size : (i + 1) * batch_size]
if len(batch) != batch_size:
continue
batch_tensors = [
tensors_from_pair(word2index, (pair[0], pair[1]), shuffle, unknown_threshold)
for pair in batch
]
tensors1, tensors2 = zip(*batch_tensors)
# targets = torch.tensor(targets, dtype=torch.long, device=config.DEV)
# tensors1_lengths = [len(t) for t in tensors1]
# tensors2_lengths = [len(t) for t in tensors2]
# tensors1 = nn.utils.rnn.pack_sequence(tensors1, enforce_sorted=False)
# tensors2 = nn.utils.rnn.pack_sequence(tensors2, enforce_sorted=False)
yield tensors1, tensors2
def asMinutes(s):
m = math.floor(s / 60)
s -= m * 60
return "%dm %ds" % (m, s)
def timeSince(since, percent):
now = time.time()
s = now - since
es = s / (percent)
rs = es - s
return "%s (- %s)" % (asMinutes(s), asMinutes(rs))
def showPlot(points):
plt.figure()
fig, ax = plt.subplots()
# this locator puts ticks at regular intervals
loc = ticker.MultipleLocator(base=0.2)
ax.yaxis.set_major_locator(loc)
plt.plot(points)
def showAttention(input_sentence, output_words, attentions):
# Set up figure with colorbar
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(attentions.numpy(), cmap="bone")
fig.colorbar(cax)
# Set up axes
ax.set_xticklabels([""] + input_sentence.split(" ") + ["<__EOS__>"], rotation=90)
ax.set_yticklabels([""] + output_words)
# Show label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.show()
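# note: evaluateAndShowAttention below expects globals (encoder1, attn_decoder1,
# evaluate) from the original PyTorch seq2seq tutorial; they are not defined in
# this module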
def evaluateAndShowAttention(input_sentence):
output_words, attentions = evaluate(encoder1, attn_decoder1, input_sentence)
print("input =", input_sentence)
print("output =", " ".join(output_words))
showAttention(input_sentence, output_words, attentions)
def save_model(model, word2index, path):
if not path.endswith(".tar"):
path += ".tar"
torch.save(
{"weights": model.state_dict(), "word2index": word2index},
path,
)
def load_model(path):
checkpoint = torch.load(path)
return checkpoint["weights"], checkpoint["word2index"]
def extend_vocabulary(word2index, langs):
for lang in langs:
for word in lang.word2index:
if word not in word2index:
word2index[word] = len(word2index)
return word2index
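# --- Iteration sketch (illustrative only) ---
# pair_iter yields 1-tuples of index tensors (with EOS appended) per pair,
# matching the batch_size=1 training loop in main.py.
if __name__ == "__main__":
    w2i = {config.PAD: 0, config.UNK: 1, config.NOFIX: 2, config.SOS: 3, config.EOS: 4, "hello": 5, "world": 6}
    pairs = [[["hello", "world"], ["world", "hello"]]]
    for (src,), (dst,) in pair_iter(pairs, w2i):
        print(src.tolist(), dst.tolist())  # e.g. [5, 6, 4] [6, 5, 4]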


@ -0,0 +1,794 @@
import logging
import os
import pathlib
import random
import sys
import click
import sacrebleu
import torch
import torch.nn as nn
import tqdm
import config
from libs import corpora
from libs import utils
from libs.fixation_generation import Network as FixNN
from libs.paraphrase_generation import (
EncoderRNN as ParEncNN,
AttnDecoderRNN as ParDecNN,
)
cwd = os.path.dirname(__file__)
logger = logging.getLogger("main")
'''
#qqp_paw sentences:
debug_sentences = [
'What are the driving rules in Georgia versus Mississippi ?',
'I want to be a child psychologist , what qualification do i need to become one ? Are there good and reputed psychology Institute or Colleges in India ?',
'What is deep web and dark web and what are the contents of these sites ?',
'What is difference between North Indian Brahmins and South Indian Brahmins ?',
'Is carbon dioxide an ionic bond or a covalent bond ?',
'How do accounts receivable and accounts payable differ ?',
'Why did Wikipedia hide its audit history for ( superluminal ) successful speed experiments ?',
'What makes a person simple , or inversely , complicated ?',
'`` How do you say `` Miss you , too , `` in Spanish ? Are there multiple ways to say it ? ',
'What is the difference between dominant trait and recessive trait ?',
'`` What is the difference between `` seeing someone , `` `` dating someone , `` and `` having a girlfriend/boyfriend `` ? ',
'How was the Empire State building built and designed ? How is it used ?',
'What is the sum of the square roots of first n natural number ?',
'Why is Roman Saini not so active on Quora now a days ?',
'If I have someone blocked on Instagram , and see their story , can they view I viewed it ?',
'Amongst the major IT companies of India which is the best ; Wipro , Capgemini , Infosys , TCS or is Oracle the best ?',
'How much mass does Saturn lose each year ? How much mass does it gain ?',
'What is a cheap healthy diet , I can keep the same and eat every day ?',
' What is it like to be beautiful ? Not just pretty or hot , but the kind of almost objective beauty that people are sometimes intimidated by ?',
'Could someone tell of a James Ronsey ( misspelled likely ) , writer and filmmaker , probably of the British Isles ?',
'How much pressure Is there around the core of Pluto ? is it enough to turn hydrogen/helium gas into a liquid or metallic state ?',
'How does quality of life in Vancouver compare to that in Melbourne Or Brisbane ?',
]
'''
'''
#wiki sentences:
debug_sentences = [
'They were there to enjoy us and they were there to pray for us .',
'Components of elastic potential systems store mechanical energy if they are deformed when forces are applied to the system .',
'Steam can also be used , and does not need to be pumped .',
'The solar approach to this requirement is the use of solar panels in a conventional-powered aircraft .',
'Daudkhali is a village in Barisal Division in the Pirojpur district in southwestern Bangladesh .',
'Briggs later met Briggs at the 1967 Monterey Pop Festival , where Ravi Shankar was also performing , with Eric Burdon and The Animals .',
'Brockton is approximately 25 miles northeast of Providence , Rhode Island , and 30 miles south of Boston .',
]
'''
#qqp sentences:
debug_sentences = [
'How do I get funding for my web based startup idea ?',
'What do intelligent people do to pass time ?',
'Which is the best SEO Company in Delhi ?',
'Why do you waer makeup ?',
'How do start chatting with a girl ?',
'What is the meaning of living life ?',
'Why do my armpits hurt ?',
'Why does eye color change with age ?',
'How do you find the standard deviation of a probability distribution ? What are some examples ?',
'How can I complete my 11 syllabus in one month ?',
'How do I concentrate better on my studies ?',
'Which is the best retirement plan in india ?',
'Should I tell my best friend I love her ?',
'Which is the best company for Appian Vagrant online job support ?',
'How can one do for good handwriting ?',
'What are remedies to get rid of belly fat ?',
'What is the best way to cook precooked turkey ?',
'What is the future of e-commerce in India ?',
'Why do my burps taste like rotten eggs ?',
'What is an example of chemical weathering ?',
'What are some of the advantages and disadvantages of cyber schooling ?',
'How can I increase traffic to my websites by Facebook ?',
'How do I increase my patience level in life ?',
'What are the best hospitals for treating cancer in India ?',
'Will Jio sim work in a 3G phone ? If yes , how ?',
]
debug_sentences = [s.split(" ") for s in debug_sentences]
class Network(nn.Module):
def __init__(
self,
word2index,
embeddings,
):
super().__init__()
self.logger = logging.getLogger(f"{__name__}")
self.word2index = word2index
self.index2word = {i: k for k, i in word2index.items()}
self.fix_gen = FixNN(
embedding_type="glove",
vocab_size=len(word2index),
embedding_dim=config.embedding_dim,
embeddings=embeddings,
dropout=config.fix_dropout,
hidden_dim=config.fix_hidden_dim,
)
self.par_enc = ParEncNN(
input_size=config.embedding_dim,
hidden_size=config.par_hidden_dim,
embeddings=embeddings,
)
self.par_dec = ParDecNN(
input_size=config.embedding_dim,
hidden_size=config.par_hidden_dim,
output_size=len(word2index),
embeddings=embeddings,
dropout_p=config.par_dropout,
max_length=config.max_length,
)
def forward(self, x, target=None, teacher_forcing_ratio=None):
teacher_forcing_ratio = teacher_forcing_ratio if teacher_forcing_ratio is not None else config.teacher_forcing_ratio
x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
x2 = nn.utils.rnn.pad_sequence(x, batch_first=False)
fixations = torch.sigmoid(self.fix_gen(x1, [len(_x) for _x in x1]))
enc_hidden = self.par_enc.initHidden().to(config.DEV)
enc_outs = torch.zeros(config.max_length, config.par_hidden_dim, device=config.DEV)
for ei in range(len(x2)):
enc_out, enc_hidden = self.par_enc(x2[ei], enc_hidden)
enc_outs[ei] += enc_out[0, 0]
dec_in = torch.tensor([[self.word2index[config.SOS]]], device=config.DEV) # SOS
dec_hidden = enc_hidden
dec_outs = []
dec_words = []
dec_atts = torch.zeros(config.max_length, config.max_length)
if target is not None: # training
target = nn.utils.rnn.pad_sequence(target, batch_first=False)
use_teacher_forcing = random.random() < teacher_forcing_ratio
if use_teacher_forcing:
for di in range(len(target)):
dec_out, dec_hidden, dec_att = self.par_dec(
dec_in, dec_hidden, enc_outs, fixations
)
dec_outs.append(dec_out)
dec_atts[di] = dec_att.data
dec_in = target[di]
else:
for di in range(len(target)):
dec_out, dec_hidden, dec_att = self.par_dec(
dec_in, dec_hidden, enc_outs, fixations
)
dec_outs.append(dec_out)
dec_atts[di] = dec_att.data
topv, topi = dec_out.data.topk(1)
dec_words.append(self.index2word[topi.item()])
dec_in = topi.squeeze().detach()
else: # prediction
for di in range(config.max_length):
dec_out, dec_hidden, dec_att = self.par_dec(
dec_in, dec_hidden, enc_outs, fixations
)
dec_outs.append(dec_out)
dec_atts[di] = dec_att.data
topv, topi = dec_out.data.topk(1)
if topi.item() == self.word2index[config.EOS]:
dec_words.append("<__EOS__>")
break
else:
dec_words.append(self.index2word[topi.item()])
dec_in = topi.squeeze().detach()
return dec_outs, dec_words, dec_atts[: di + 1], fixations
def load_corpus(corpus_name, splits):
if not splits:
return
logger.info("loading corpus")
if corpus_name == "msrpc":
load_fn = corpora.load_msrpc
elif corpus_name == "qqp":
load_fn = corpora.load_qqp
elif corpus_name == "wiki":
load_fn = corpora.load_wiki
elif corpus_name == "qqp_paws":
load_fn = corpora.load_qqp_paws
elif corpus_name == "qqp_kag":
load_fn = corpora.load_qqp_kag
elif corpus_name == "sentiment":
load_fn = corpora.load_sentiment
elif corpus_name == "stanford":
load_fn = corpora.load_stanford
elif corpus_name == "stanford_sent":
load_fn = corpora.load_stanford_sent
elif corpus_name == "tamil":
load_fn = corpora.load_tamil
elif corpus_name == "compression":
load_fn = corpora.load_compression
corpus = {}
langs = []
if "train" in splits:
train_pairs, train_lang = load_fn("train")
corpus["train"] = train_pairs
langs.append(train_lang)
if "val" in splits:
val_pairs, val_lang = load_fn("val")
corpus["val"] = val_pairs
langs.append(val_lang)
if "test" in splits:
test_pairs, test_lang = load_fn("test")
corpus["test"] = test_pairs
langs.append(test_lang)
logger.info("creating word index")
lang = langs[0]
for _lang in langs[1:]:
lang += _lang
word2index = lang.word2index
index2word = {i: w for w, i in word2index.items()}
return corpus, word2index, index2word
def init_network(word2index):
logger.info("loading embeddings")
vocabulary = sorted(word2index.keys())
embeddings = utils.load_glove(vocabulary)
logger.info("initializing model")
network = Network(
word2index=word2index,
embeddings=embeddings,
)
network.to(config.DEV)
print(f"#parameters: {sum(p.numel() for p in network.parameters())}")
return network
@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-v", "--verbose", count=True)
@click.option("-d", "--debug", is_flag=True)
def main(verbose, debug):
if verbose == 0:
loglevel = logging.ERROR
elif verbose == 1:
loglevel = logging.WARN
elif verbose >= 2:
loglevel = logging.INFO
if debug:
loglevel = logging.DEBUG
logging.basicConfig(
format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
datefmt="%d.%m. %H:%M:%S",
level=loglevel,
)
logger.debug("arguments: %s" % str(sys.argv))
@main.command()
@click.option(
"-c",
"--corpus",
"corpus_name",
required=True,
type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-m", "--model_name", required=True)
@click.option("-w", "--fixation_weights", required=False)
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def train(corpus_name, model_name, fixation_weights, freeze_fixations, bleu):
corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
train_pairs = corpus["train"]
val_pairs = corpus["val"]
network = init_network(word2index)
model_dir = os.path.join("models", model_name)
logger.debug("creating model dir %s" % model_dir)
pathlib.Path(model_dir).mkdir(parents=True)
if fixation_weights is not None:
logger.info("loading fixation prediction checkpoint")
checkpoint = torch.load(fixation_weights, map_location=config.DEV)
if "word2index" in checkpoint:
weights = checkpoint["weights"]
else:
weights = checkpoint
# remove the embedding layer before loading
weights = {k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")}
network.fix_gen.load_state_dict(weights, strict=False)
if freeze_fixations:
logger.info("freezing fixation generation network")
for p in network.fix_gen.parameters():
p.requires_grad = False
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)
#optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=1e-5)
best_val_loss = None
epoch = 1
while True:
train_batch_iter = utils.pair_iter(pairs=train_pairs, word2index=word2index, shuffle=True, shuffle_pairs=False)
val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
# test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
running_train_loss = 0
total_train_loss = 0
total_val_loss = 0
if bleu == "sacrebleu":
running_train_bleu = 0
total_train_bleu = 0
total_val_bleu = 0
elif bleu == "nltk":
running_train_bleu_1 = 0
running_train_bleu_2 = 0
running_train_bleu_3 = 0
running_train_bleu_4 = 0
total_train_bleu_1 = 0
total_train_bleu_2 = 0
total_train_bleu_3 = 0
total_train_bleu_4 = 0
total_val_bleu_1 = 0
total_val_bleu_2 = 0
total_val_bleu_3 = 0
total_val_bleu_4 = 0
network.train()
for i, batch in enumerate(train_batch_iter, 1):
optimizer.zero_grad()
input, target = batch
prediction, words, attention, fixations = network(input, target)
loss = loss_fn(
torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
)
loss.backward()
optimizer.step()
running_train_loss += loss.item()
total_train_loss += loss.item()
_prediction = " ".join([index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()])
_target = " ".join([index2word[_x] for _x in target[0].tolist()])
if bleu == "sacrebleu":
bleu_score = sacrebleu.sentence_bleu(_prediction, [_target]).score
running_train_bleu += bleu_score
total_train_bleu += bleu_score
elif bleu == "nltk":
bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
running_train_bleu_1 += bleu_1_score
running_train_bleu_2 += bleu_2_score
running_train_bleu_3 += bleu_3_score
running_train_bleu_4 += bleu_4_score
total_train_bleu_1 += bleu_1_score
total_train_bleu_2 += bleu_2_score
total_train_bleu_3 += bleu_3_score
total_train_bleu_4 += bleu_4_score
# print(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist())
if i % 100 == 0:
if bleu == "sacrebleu":
print(f"step {i} avg_train_loss {running_train_loss/100:.4f} avg_train_bleu {running_train_bleu/100:.2f}")
elif bleu == "nltk":
print(f"step {i} avg_train_loss {running_train_loss/100:.4f} avg_train_bleu_1 {running_train_bleu_1/100:.2f} avg_train_bleu_2 {running_train_bleu_2/100:.2f} avg_train_bleu_3 {running_train_bleu_3/100:.2f} avg_train_bleu_4 {running_train_bleu_4/100:.2f}")
network.eval()
with open(os.path.join(model_dir, f"debug_{epoch}_{i}.out"), "w") as h:
if bleu == "sacrebleu":
h.write(f"# avg_train_loss {running_train_loss/100:.4f} avg_train_bleu {running_train_bleu/100:.2f}")
running_train_bleu = 0
elif bleu == "nltk":
h.write(f"# avg_train_loss {running_train_loss/100:.4f} avg_train_bleu_1 {running_train_bleu_1/100:.2f} avg_train_bleu_2 {running_train_bleu_2/100:.2f} avg_train_bleu_3 {running_train_bleu_3/100:.2f} avg_train_bleu_4 {running_train_bleu_4/100:.2f}")
running_train_bleu_1 = 0
running_train_bleu_2 = 0
running_train_bleu_3 = 0
running_train_bleu_4 = 0
running_train_loss = 0
h.write("\n")
h.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
h.write("\n")
for s, input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
prediction, words, attentions, fixations = network(input)
prediction = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
prediction = [index2word.get(x, "<__UNK__>") for x in prediction]
attentions = attentions.detach().cpu().squeeze().tolist()
fixations = fixations.detach().cpu().squeeze().tolist()
h.write(f"{s}\t{prediction}\t{attentions}\t{fixations}")
h.write("\n")
network.train()
network.eval()
for i, batch in enumerate(val_batch_iter):
input, target = batch
prediction, words, attention, fixations = network(input, target, teacher_forcing_ratio=0)
loss = loss_fn(
torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
)
_prediction = " ".join([index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()])
_target = " ".join([index2word[_x] for _x in target[0].tolist()])
if bleu == "sacrebleu":
bleu_score = sacrebleu.sentence_bleu(_prediction, [_target]).score
elif bleu == "nltk":
bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
total_val_bleu_1 += bleu_1_score
total_val_bleu_2 += bleu_2_score
total_val_bleu_3 += bleu_3_score
total_val_bleu_4 += bleu_4_score
total_val_loss += loss.item()
avg_val_loss = total_val_loss/len(val_pairs)
if bleu == "sacrebleu":
print(f"epoch {epoch} avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu {total_train_bleu/len(train_pairs):.2f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
elif bleu == "nltk":
print(f"epoch {epoch} avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu_1 {total_train_bleu_1/len(train_pairs):.2f} avg_train_bleu_2 {total_train_bleu_2/len(train_pairs):.2f} avg_train_bleu_3 {total_train_bleu_3/len(train_pairs):.2f} avg_train_bleu_4 {total_train_bleu_4/len(train_pairs):.2f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")
with open(os.path.join(model_dir, f"debug_{epoch}_end.out"), "w") as h:
if bleu == "sacrebleu":
h.write(f"# avg_train_loss {total_train_loss/len(train_pairs)} avg_val_loss {total_val_loss/len(val_pairs)} avg_train_bleu {total_train_bleu/len(train_pairs)} avg_val_bleu {total_val_bleu/len(val_pairs)}")
elif bleu == "nltk":
h.write(f"# avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu_1 {total_train_bleu_1/len(train_pairs):.2f} avg_train_bleu_2 {total_train_bleu_2/len(train_pairs):.2f} avg_train_bleu_3 {total_train_bleu_3/len(train_pairs):.2f} avg_train_bleu_4 {total_train_bleu_4/len(train_pairs):.2f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")
h.write("\n")
h.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
h.write("\n")
for s, input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
prediction, words, attentions, fixations = network(input)
prediction = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
prediction = [index2word.get(x, "<__UNK__>") for x in prediction]
attentions = attentions.detach().cpu().squeeze().tolist()
fixations = fixations.detach().cpu().squeeze().tolist()
h.write(f"{s}\t{prediction}\t{attentions}\t{fixations}")
h.write("\n")
utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_{epoch}"))
if best_val_loss is None or avg_val_loss < best_val_loss:
if best_val_loss is not None:
logger.info(f"{avg_val_loss} < {best_val_loss} ({avg_val_loss-best_val_loss}): new best model from epoch {epoch}")
else:
logger.info(f"{avg_val_loss} < {best_val_loss}: new best model from epoch {epoch}")
best_val_loss = avg_val_loss
# save_model(model, word2index, model_name + "_epoch_" + str(epoch))
# utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_best"))
epoch += 1
@main.command()
@click.option(
"-c",
"--corpus",
"corpus_name",
required=True,
type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def val(corpus_name, model_weights, sentence_statistics, bleu):
corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
val_pairs = corpus["val"]
logger.info("loading model checkpoint")
checkpoint = torch.load(model_weights, map_location=config.DEV)
if "word2index" in checkpoint:
weights = checkpoint["weights"]
word2index = checkpoint["word2index"]
index2word = {i: w for w, i in word2index.items()}
else:
raise KeyError("checkpoint is missing 'word2index'")
network = init_network(word2index)
# remove the embedding layer before loading
weights = {k: v for k, v in weights.items() if "embedding" not in k}
# make a new output layer to match the weights from the checkpoint
# we cannot remove it like we did with the embedding layers because
# unlike those the output layer actually contains learned parameters
vocab_size, hidden_size = weights["par_dec.out.weight"].shape
network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
# actually load the parameters
network.load_state_dict(weights, strict=False)
loss_fn = nn.CrossEntropyLoss()
val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
total_val_loss = 0
if bleu == "sacrebleu":
total_val_bleu = 0
elif bleu == "nltk":
total_val_bleu_1 = 0
total_val_bleu_2 = 0
total_val_bleu_3 = 0
total_val_bleu_4 = 0
network.eval()
for i, batch in enumerate(val_batch_iter, 1):
input, target = batch
prediction, words, attentions, fixations = network(input, target)
loss = loss_fn(
torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
)
total_val_loss += loss.item()
_prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()]
_target = [index2word[_x] for _x in target[0].tolist()]
if bleu == "sacrebleu":
bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), [" ".join(_target)]).score
total_val_bleu += bleu_score
elif bleu == "nltk":
bleu_1_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=1)
bleu_2_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=2)
bleu_3_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=3)
bleu_4_score = utils.bleu(target[0].tolist(), torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist(), n=4)
total_val_bleu_1 += bleu_1_score
total_val_bleu_2 += bleu_2_score
total_val_bleu_3 += bleu_3_score
total_val_bleu_4 += bleu_4_score
if sentence_statistics:
s = [index2word[x] for x in input[0].detach().cpu().tolist()]
attentions = attentions.detach().cpu().squeeze().tolist()
fixations = fixations.detach().cpu().squeeze().tolist()
if bleu == "sacrebleu":
print(f"{bleu_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")
elif bleu == "nltk":
print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")
if bleu == "sacrebleu":
print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
elif bleu == "nltk":
print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")
@main.command()
@click.option(
"-c",
"--corpus",
"corpus_name",
required=True,
type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def test(corpus_name, model_weights, sentence_statistics, bleu):
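# Same evaluation as `val`, but on the test split.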
corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
test_pairs = corpus["test"]
if model_weights is not None:
logger.info("loading model checkpoint")
checkpoint = torch.load(model_weights, map_location=config.DEV)
if "word2index" in checkpoint:
weights = checkpoint["weights"]
word2index = checkpoint["word2index"]
index2word = {i: w for w, i in word2index.items()}
else:
# checkpoint without a stored vocabulary: assume it is a bare state dict
weights = checkpoint
network = init_network(word2index)
if model_weights is not None:
# remove the embedding layer before loading
weights = {k: v for k, v in weights.items() if "embedding" not in k}
# make a new output layer to match the weights from the checkpoint
# we cannot remove it like we did with the embedding layers because
# unlike those the output layer actually contains learned parameters
vocab_size, hidden_size = weights["par_dec.out.weight"].shape
network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
# actually load the parameters
network.load_state_dict(weights, strict=False)
loss_fn = nn.CrossEntropyLoss()
test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
total_test_loss = 0
if bleu == "sacrebleu":
total_test_bleu = 0
elif bleu == "nltk":
total_test_bleu_1 = 0
total_test_bleu_2 = 0
total_test_bleu_3 = 0
total_test_bleu_4 = 0
network.eval()
for i, batch in enumerate(test_batch_iter, 1):
input, target = batch
prediction, words, attentions, fixations = network(input, target)
loss = loss_fn(
torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
)
total_test_loss += loss.item()
predicted_indices = torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()
_prediction = [index2word[_x] for _x in predicted_indices]
_target = [index2word[_x] for _x in target[0].tolist()]
if bleu == "sacrebleu":
bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
total_test_bleu += bleu_score
elif bleu == "nltk":
bleu_1_score = utils.bleu(target[0].tolist(), predicted_indices, n=1)
bleu_2_score = utils.bleu(target[0].tolist(), predicted_indices, n=2)
bleu_3_score = utils.bleu(target[0].tolist(), predicted_indices, n=3)
bleu_4_score = utils.bleu(target[0].tolist(), predicted_indices, n=4)
total_test_bleu_1 += bleu_1_score
total_test_bleu_2 += bleu_2_score
total_test_bleu_3 += bleu_3_score
total_test_bleu_4 += bleu_4_score
if sentence_statistics:
s = [index2word[x] for x in input[0].detach().cpu().tolist()]
attentions = attentions.detach().cpu().squeeze().tolist()
fixations = fixations.detach().cpu().squeeze().tolist()
if bleu == "sacrebleu":
print(f"{bleu_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")
elif bleu == "nltk":
print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")
if bleu == "sacrebleu":
print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu {total_test_bleu/len(test_pairs):.2f}")
elif bleu == "nltk":
print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu_1 {total_test_bleu_1/len(test_pairs):.2f} avg_test_bleu_2 {total_test_bleu_2/len(test_pairs):.2f} avg_test_bleu_3 {total_test_bleu_3/len(test_pairs):.2f} avg_test_bleu_4 {total_test_bleu_4/len(test_pairs):.2f}")
@main.command()
@click.option(
"-c",
"--corpus",
"corpus_name",
required=True,
type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=False)
def predict(corpus_name, model_weights):
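# Greedy decoding on the validation split, printing sentence, prediction,
# attention, and fixation values per line (tab-separated).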
corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
test_pairs = corpus["val"]
if model_weights is not None:
logger.info("loading model checkpoint")
checkpoint = torch.load(model_weights, map_location=config.DEV)
if "word2index" in checkpoint:
weights = checkpoint["weights"]
word2index = checkpoint["word2index"]
index2word = {i: w for w, i in word2index.items()}
else:
# checkpoint without a stored vocabulary: assume it is a bare state dict
weights = checkpoint
network = init_network(word2index)
logger.info(f"vocab size {len(word2index)}")
if model_weights is not None:
# remove the embedding layer before loading
weights = {k: v for k, v in weights.items() if "embedding" not in k}
# make a new output layer to match the weights from the checkpoint
# we cannot remove it like we did with the embedding layers because
# unlike those the output layer actually contains learned parameters
vocab_size, hidden_size = weights["par_dec.out.weight"].shape
network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
# actually load the parameters
network.load_state_dict(weights, strict=False)
test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
network.eval()
for i, batch in enumerate(test_batch_iter, 1):
input, target = batch
prediction, words, attentions, fixations = network(input, target)
_prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]
s = [index2word[x] for x in input[0].detach().cpu().tolist()]
attentions = attentions.detach().cpu().squeeze().tolist()
fixations = fixations.detach().cpu().squeeze().tolist()
print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")
@main.command()
@click.option("-w", "--model_weights", required=True)
@click.argument("path")
def predict_file(model_weights, path):
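# Like `predict`, but reads whitespace-tokenized sentences from a plain text
# file; the file's vocabulary is replaced by the checkpoint's if one is stored.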
logger.info("loading sentences")
sentences = []
lang = corpora.Lang("pred")
with open(path) as h:
for line in h:
line = line.strip()
if line:
sentence = line.split(" ")
lang.add_sentence(sentence)
sentences.append(sentence)
word2index = lang.word2index
index2word = {i: w for w, i in word2index.items()}
logger.info(f"{len(sentences)} sentences loaded")
logger.info("loading model checkpoint")
checkpoint = torch.load(model_weights, map_location=config.DEV)
if "word2index" in checkpoint:
weights = checkpoint["weights"]
word2index = checkpoint["word2index"]
index2word = {i: w for w, i in word2index.items()}
else:
# checkpoint without a stored vocabulary: assume it is a bare state dict
weights = checkpoint
network = init_network(word2index)
logger.info(f"vocab size {len(word2index)}")
# remove the embedding layer before loading
weights = {k: v for k, v in weights.items() if "embedding" not in k}
# make a new output layer to match the weights from the checkpoint
# we cannot remove it like we did with the embedding layers because
# unlike those the output layer actually contains learned parameters
vocab_size, hidden_size = weights["par_dec.out.weight"].shape
network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
# actually load the parameters
network.load_state_dict(weights, strict=False)
debug_sentence_iter = utils.sent_iter(sentences, word2index=word2index)
network.eval()
for i, input in enumerate(debug_sentence_iter, 1):
prediction, words, attentions, fixations = network(input)
_prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]
s = [index2word[x] for x in input[0].detach().cpu().tolist()]
attentions = attentions.detach().cpu().squeeze().tolist()
fixations = fixations.detach().cpu().squeeze().tolist()
print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,19 @@
click==7.1.2
cycler==0.10.0
dataclasses==0.6
future==0.18.2
joblib==0.17.0
kiwisolver==1.3.1
matplotlib==3.3.3
nltk==3.5
numpy==1.19.4
Pillow==8.0.1
portalocker==2.0.0
pyparsing==2.4.7
python-dateutil==2.8.1
regex==2020.11.13
sacrebleu==1.4.6
six==1.15.0
torch==1.7.0
tqdm==4.54.1
typing-extensions==3.7.4.3

View file

@ -0,0 +1,43 @@
import os
import sys
import click
def read(path):
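# Each input line is expected to hold five tab-separated fields:
# BLEU score, sentence, prediction, attentions, fixations.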
with open(path) as h:
for line in h:
line = line.strip()
try:
b, s, p, a, f = line.split("\t")
except ValueError:
print(f"skipping line {line}", file=sys.stderr)
continue
else:
yield b, s, p, a, f
@click.command()
@click.argument("path")
def main(path):
data = list(read(path))
avg_len = sum(len(x[1]) for x in data)/len(data)
fname, ext = os.path.splitext(path)  # ext already includes the leading dot
with open(f"{fname}_long{ext}", "w") as lh, open(f"{fname}_short{ext}", "w") as sh:
for x in data:
if len(x[1]) > avg_len:
lh.write("\t".join(x))
lh.write("\n")
else:
sh.write("\t".join(x))
sh.write("\n")
print(f"avg sentence length {avg_len}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,40 @@
import sys
import click
def read(path):
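# Parse lines of tab-separated (BLEU score, sentence, prediction, ...),
# keeping only the first three fields.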
with open(path) as h:
for line in h:
line = line.strip()
try:
b, s, p, *_ = line.split("\t")
except ValueError:
print(f"skipping line {line}", file=sys.stderr)
continue
else:
yield float(b), s, p
@click.command()
@click.argument("path")
def main(path):
data = list(read(path))
avg_len = sum(len(x[1]) for x in data)/len(data)
filtered_data = []
filtered_data2 = []
for x in data:
if len(x[1]) > avg_len:
filtered_data.append(x)
else:
filtered_data2.append(x)
print(f"avg sentence length {avg_len}")
print(f"long sentences {len(filtered_data)}")
print(f"short sentences {len(filtered_data2)}")
print(f"total bleu {sum(x[0] for x in data)/len(data)}")
print(f"longest bleu {sum(x[0] for x in filtered_data)/len(filtered_data)}")
print(f"shortest bleu {sum(x[0] for x in filtered_data2)/len(filtered_data2)}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,64 @@
import ast
import os
import pathlib
import text_attention
import click
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import matplotlib.ticker as ticker
import numpy as np
import tqdm
def plot_attention(input_sentence, output_words, attentions, path):
# Set up figure with colorbar
attentions = np.array(attentions)[:,:len(input_sentence)]
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(attentions, cmap="bone")
fig.colorbar(cax)
# Set up axes
ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
ax.set_yticklabels([""] + output_words)
# Show label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.savefig(f"{path}.pdf")
plt.close()
def parse(p):
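# Parse the tab-separated output of the prediction commands; each line holds
# the sentence, prediction, attention, and fixation lists as Python literals.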
with open(p) as h:
for line in h:
if not line or line.startswith("#"):
continue
_sentence, _prediction, _attention, _fixations = line.strip().split("\t")
try:
sentence = ast.literal_eval(_sentence)
prediction = ast.literal_eval(_prediction)
attention = ast.literal_eval(_attention)
except (ValueError, SyntaxError):
continue
yield sentence, prediction, attention
@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
for p in tqdm.tqdm(path):
out_dir = os.path.splitext(p)[0]
if out_dir == p:  # input had no extension; avoid clobbering it with the output dir
out_dir = f"{out_dir}_"
pathlib.Path(out_dir).mkdir(exist_ok=True)
for i, spa in enumerate(parse(p)):
plot_attention(*spa, path=os.path.join(out_dir, str(i)))
if __name__ == "__main__":
main()

View file

@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12
## Convert a text/attention list to LaTeX code, which in turn generates a text heatmap based on attention weights.
import numpy as np
latex_special_token = ["!@#$%^&*()"]
def generate(text_list, attention_list, latex_file, color='red', rescale_value = False):
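# Write a standalone LaTeX file in which every word is wrapped in a \colorbox
# whose intensity is the word's attention weight (expected in [0, 100]).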
assert(len(text_list) == len(attention_list))
if rescale_value:
attention_list = rescale(attention_list)
word_num = len(text_list)
text_list = clean_word(text_list)
with open(latex_file,'w') as f:
f.write(r'''\documentclass[varwidth]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}'''+'\n')
string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{'''+"\n"
for idx in range(word_num):
string += "\\colorbox{%s!%s}{"%(color, attention_list[idx])+"\\strut " + text_list[idx]+"} "
string += "\n}}}"
f.write(string+'\n')
f.write(r'''\end{CJK*}
\end{document}''')
def rescale(input_list):
the_array = np.asarray(input_list)
the_max = np.max(the_array)
the_min = np.min(the_array)
rescale = (the_array - the_min)/(the_max-the_min)*100
return rescale.tolist()
def clean_word(word_list):
new_word_list = []
for word in word_list:
for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
if latex_sensitive in word:
word = word.replace(latex_sensitive, '\\'+latex_sensitive)
new_word_list.append(word)
return new_word_list
if __name__ == '__main__':
## This is a demo:
sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".
North Korea and the US have ratcheted up tensions in recent weeks and the movement of the strike group had raised the question of a pre-emptive strike by the US.
On Wednesday, Mr Pence described the country as the "most dangerous and urgent threat to peace and security" in the Asia-Pacific.'''
sent = '''我 回忆 起 我 曾经 在 大学 年代 我们 经常 喜欢 玩 “ Hawaii guitar ” 。 说起 Guitar 我 想起 了 西游记 里 的 琵琶精 。
今年 下半年 合拍 西游记 即将 正式 开机 继续 扮演 美猴王 孙悟空 美猴王 艺术 形象 努力 创造 正能量 形象 开花 弘扬 中华 文化 希望 大家 多多 关注 '''
words = sent.split()
word_num = len(words)
attention = [(x+1.)/word_num*100 for x in range(word_num)]
import random
random.seed(42)
random.shuffle(attention)
color = 'red'
generate(words, attention, "sample.tex", color)

View file

@ -0,0 +1,428 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,latex
# Edit at https://www.toptal.com/developers/gitignore?templates=python,latex
### LaTeX ###
## Core latex/pdflatex auxiliary files:
*.aux
*.lof
*.log
*.lot
*.fls
*.out
*.toc
*.fmt
*.fot
*.cb
*.cb2
.*.lb
## Intermediate documents:
*.dvi
*.xdv
*-converted-to.*
# these rules might exclude image files for figures etc.
# *.ps
# *.eps
# *.pdf
## Generated if empty string is given at "Please type another file name for output:"
.pdf
## Bibliography auxiliary files (bibtex/biblatex/biber):
*.bbl
*.bcf
*.blg
*-blx.aux
*-blx.bib
*.run.xml
## Build tool auxiliary files:
*.fdb_latexmk
*.synctex
*.synctex(busy)
*.synctex.gz
*.synctex.gz(busy)
*.pdfsync
## Build tool directories for auxiliary files
# latexrun
latex.out/
## Auxiliary and intermediate files from other packages:
# algorithms
*.alg
*.loa
# achemso
acs-*.bib
# amsthm
*.thm
# beamer
*.nav
*.pre
*.snm
*.vrb
# changes
*.soc
# comment
*.cut
# cprotect
*.cpt
# elsarticle (documentclass of Elsevier journals)
*.spl
# endnotes
*.ent
# fixme
*.lox
# feynmf/feynmp
*.mf
*.mp
*.t[1-9]
*.t[1-9][0-9]
*.tfm
#(r)(e)ledmac/(r)(e)ledpar
*.end
*.?end
*.[1-9]
*.[1-9][0-9]
*.[1-9][0-9][0-9]
*.[1-9]R
*.[1-9][0-9]R
*.[1-9][0-9][0-9]R
*.eledsec[1-9]
*.eledsec[1-9]R
*.eledsec[1-9][0-9]
*.eledsec[1-9][0-9]R
*.eledsec[1-9][0-9][0-9]
*.eledsec[1-9][0-9][0-9]R
# glossaries
*.acn
*.acr
*.glg
*.glo
*.gls
*.glsdefs
*.lzo
*.lzs
# uncomment this for glossaries-extra (will ignore makeindex's style files!)
# *.ist
# gnuplottex
*-gnuplottex-*
# gregoriotex
*.gaux
*.gtex
# htlatex
*.4ct
*.4tc
*.idv
*.lg
*.trc
*.xref
# hyperref
*.brf
# knitr
*-concordance.tex
# TODO Comment the next line if you want to keep your tikz graphics files
*.tikz
*-tikzDictionary
# listings
*.lol
# luatexja-ruby
*.ltjruby
# makeidx
*.idx
*.ilg
*.ind
# minitoc
*.maf
*.mlf
*.mlt
*.mtc[0-9]*
*.slf[0-9]*
*.slt[0-9]*
*.stc[0-9]*
# minted
_minted*
*.pyg
# morewrites
*.mw
# nomencl
*.nlg
*.nlo
*.nls
# pax
*.pax
# pdfpcnotes
*.pdfpc
# sagetex
*.sagetex.sage
*.sagetex.py
*.sagetex.scmd
# scrwfile
*.wrt
# sympy
*.sout
*.sympy
sympy-plots-for-*.tex/
# pdfcomment
*.upa
*.upb
# pythontex
*.pytxcode
pythontex-files-*/
# tcolorbox
*.listing
# thmtools
*.loe
# TikZ & PGF
*.dpth
*.md5
*.auxlock
# todonotes
*.tdo
# vhistory
*.hst
*.ver
# easy-todo
*.lod
# xcolor
*.xcp
# xmpincl
*.xmpi
# xindy
*.xdy
# xypic precompiled matrices and outlines
*.xyc
*.xyd
# endfloat
*.ttt
*.fff
# Latexian
TSWLatexianTemp*
## Editors:
# WinEdt
*.bak
*.sav
# Texpad
.texpadtmp
# LyX
*.lyx~
# Kile
*.backup
# gummi
.*.swp
# KBibTeX
*~[0-9]*
# TeXnicCenter
*.tps
# auto folder when using emacs and auctex
./auto/*
*.el
# expex forward references with \gathertags
*-tags.tex
# standalone packages
*.sta
# Makeindex log files
*.lpz
# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib
# option is specified. Footnotes are the stored in a file with suffix Notes.bib.
# Uncomment the next line to have this generated file ignored.
#*Notes.bib
### LaTeX Patch ###
# LIPIcs / OASIcs
*.vtc
# glossaries
*.glstex
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# End of https://www.toptal.com/developers/gitignore/api/python,latex

View file

@ -0,0 +1,3 @@
# joint_sentence_compression
Joint training for sentence compression -- NeurIPS submission.

View file

@ -0,0 +1,39 @@
import os
import torch
# general
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PAD = "<__PAD__>"
UNK = "<__UNK__>"
NOFIX = "<__NOFIX__>"
SOS = "<__SOS__>"
EOS = "<__EOS__>"
batch_size = 1
teacher_forcing_ratio = 0.5
embedding_dim = 300
fix_hidden_dim = 128
sem_hidden_dim = 1024
fix_dropout = 0.5
par_dropout = 0.2
_fix_learning_rate = 0.00001
_par_learning_rate = 0.0001
learning_rate = _par_learning_rate
fix_momentum = 0.9
par_momentum = 0.0
max_length = 851
epochs = 5
# paths
data_path = "./data"
emb_path = os.path.join(data_path, "Google_word2vec/GoogleNews-vectors-negative300.bin")
glove_path = "glove.840B.300d.txt"
google_path = os.path.join(data_path, "datasets/sentence-compression/data")
google_train_path = os.path.join(google_path, "train_mask_token.tsv")
google_dev_path = os.path.join(google_path, "dev_mask_token.tsv")
google_test_path = os.path.join(google_path, "test_mask_token.tsv")

View file

@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/

View file

@ -0,0 +1 @@
/netpool/work/gpu-2/users/soodea/datasets/glove/glove.840B.300d.txt

View file

@ -0,0 +1,97 @@
import logging
import config
def tokenize(sent):
return sent.split(" ")
class Lang:
"""Represents the vocabulary
"""
def __init__(self, name):
self.name = name
self.word2index = {
config.PAD: 0,
config.UNK: 1,
}
self.word2count = {}
self.index2word = {
0: config.PAD,
1: config.UNK,
}
self.n_words = 2
def add_sentence(self, sentence):
assert isinstance(
sentence, (list, tuple)
), "input to add_sentence must be tokenized"
for word in sentence:
self.add_word(word)
def add_word(self, word):
if word not in self.word2index:
self.word2index[word] = self.n_words
self.word2count[word] = 1
self.index2word[self.n_words] = word
self.n_words += 1
else:
self.word2count[word] += 1
def __add__(self, other):
"""Returns a new Lang object containing the vocabulary from this and
the other Lang object
"""
new_lang = Lang(f"{self.name}_{other.name}")
# Add vocabulary from both Langs
for word in self.word2count.keys():
new_lang.add_word(word)
for word in other.word2count.keys():
new_lang.add_word(word)
# Fix the counts on the new one
for word in new_lang.word2count.keys():
new_lang.word2count[word] = self.word2count.get(
word, 0
) + other.word2count.get(word, 0)
return new_lang
def load_google(split, max_len=None):
"""Load the Google Sentence Compression Dataset"""
logger = logging.getLogger(f"{__name__}.load_compression")
lang = Lang("compression")
if split == "train":
path = config.google_train_path
elif split == "val":
path = config.google_dev_path
elif split == "test":
path = config.google_test_path
logger.info("loading %s from %s" % (split, path))
data = []
sent = []
mask = []
with open(path) as handle:
for line in handle:
line = line.strip()
if line:
w, d = line.split("\t")
sent.append(w)
mask.append(int(d))
else:
if sent and (max_len is None or len(sent) <= max_len):
data.append([sent, mask])
lang.add_sentence(sent)
sent = []
mask = []
if sent and (max_len is None or len(sent) <= max_len):
data.append([sent, mask])
lang.add_sentence(sent)
return data, lang

View file

@ -0,0 +1 @@
from .main import *

View file

@ -0,0 +1,125 @@
from collections import OrderedDict
import logging
import sys
from .self_attention import Transformer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence, pad_sequence
def random_embedding(vocab_size, embedding_dim):
pretrain_emb = np.empty([vocab_size, embedding_dim])
scale = np.sqrt(3.0 / embedding_dim)
for index in range(vocab_size):
pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
return pretrain_emb
def neg_log_likelihood_loss(outputs, batch_label, batch_size, seq_len):
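# Token-level negative log-likelihood averaged over the batch; index 0 is
# treated as padding and ignored. Also returns the argmax tag sequence.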
outputs = outputs.view(batch_size * seq_len, -1)
score = F.log_softmax(outputs, 1)
loss = nn.NLLLoss(ignore_index=0, size_average=False)(
score, batch_label.view(batch_size * seq_len)
)
loss = loss / batch_size
_, tag_seq = torch.max(score, 1)
tag_seq = tag_seq.view(batch_size, seq_len)
return loss, tag_seq
def mse_loss(outputs, batch_label, batch_size, seq_len, word_seq_length):
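# Sum-of-squares loss between sigmoid scores and labels, with positions beyond
# each true sequence length masked out, averaged over the batch.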
score = torch.sigmoid(outputs)
mask = torch.zeros_like(score)
for i, v in enumerate(word_seq_length):
mask[i, 0:v] = 1
score = score * mask
loss = nn.MSELoss(reduction="sum")(
score.view(batch_size, seq_len), batch_label.view(batch_size, seq_len)
)
loss = loss / batch_size
return loss, score.view(batch_size, seq_len)
class Network(nn.Module):
def __init__(
self,
embedding_type,
vocab_size,
embedding_dim,
dropout,
hidden_dim,
embeddings=None,
attention=True,
):
super().__init__()
self.logger = logging.getLogger(f"{__name__}")
self.attention = attention
prelayers = OrderedDict()
postlayers = OrderedDict()
if embedding_type in ("w2v", "glove"):
if embeddings is not None:
prelayers["embedding_layer"] = nn.Embedding.from_pretrained(embeddings, freeze=True)
else:
prelayers["embedding_layer"] = nn.Embedding(vocab_size, embedding_dim)
prelayers["embedding_dropout_layer"] = nn.Dropout(dropout)
embedding_dim = 300
elif embedding_type == "bert":
embedding_dim = 768
self.lstm = BiLSTM(embedding_dim, hidden_dim // 2, num_layers=1)
postlayers["lstm_dropout_layer"] = nn.Dropout(dropout)
if self.attention:
postlayers["attention_layer"] = Transformer(
d_model=hidden_dim, n_heads=4, n_layers=1
)
postlayers["ff_layer"] = nn.Linear(hidden_dim, hidden_dim // 2)
postlayers["ff_activation"] = nn.ReLU()
postlayers["output_layer"] = nn.Linear(hidden_dim // 2, 1)
self.logger.info(f"prelayers: {prelayers.keys()}")
self.logger.info(f"postlayers: {postlayers.keys()}")
self.pre = nn.Sequential(prelayers)
self.post = nn.Sequential(postlayers)
def forward(self, x, word_seq_length):
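# Embed, run the BiLSTM, then apply the post layers to each sequence
# truncated to its true length; re-pad the results into a batch.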
x = self.pre(x)
x = self.lstm(x, word_seq_length)
output = []
for _x, l in zip(x.transpose(1, 0), word_seq_length):
output.append(self.post(_x[:l].unsqueeze(0))[0])
return pad_sequence(output, batch_first=True)
class BiLSTM(nn.Module):
def __init__(self, embedding_dim, lstm_hidden, num_layers):
super().__init__()
self.net = nn.LSTM(
input_size=embedding_dim,
hidden_size=lstm_hidden,
num_layers=num_layers,
batch_first=True,
bidirectional=True,
)
def forward(self, x, word_seq_length):
packed_words = pack_padded_sequence(x, word_seq_length, True, False)
lstm_out, hidden = self.net(packed_words)
lstm_out, _ = pad_packed_sequence(lstm_out)
return lstm_out

View file

@ -0,0 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
def _get_sinusoid_encoding_table(self, n_position, d_hid):
''' Sinusoid position encoding table '''
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def forward(self, x):
return x + self.pos_table[:, :x.size(1)].clone().detach()
class AttentionLayer(nn.Module):
def __init__(self):
super(AttentionLayer, self).__init__()
def forward(self, Q, K, V):
# Q: float32:[batch_size, n_queries, d_k]
# K: float32:[batch_size, n_keys, d_k]
# V: float32:[batch_size, n_keys, d_v]
dk = K.shape[-1]
dv = V.shape[-1]
KT = torch.transpose(K, -1, -2)
weight_logits = torch.bmm(Q, KT) / math.sqrt(dk)
# weight_logits: float32[batch_size, n_queries, n_keys]
weights = F.softmax(weight_logits, dim=-1)
# weight: float32[batch_size, n_queries, n_keys]
return torch.bmm(weights, V)
# return float32[batch_size, n_queries, dv]
class MultiHeadedSelfAttentionLayer(nn.Module):
def __init__(self, d_model, n_heads):
super(MultiHeadedSelfAttentionLayer, self).__init__()
self.d_model = d_model
self.n_heads = n_heads
print('{} {}'.format(d_model, n_heads))
assert d_model % n_heads == 0
self.d_k = d_model // n_heads
self.d_v = self.d_k
self.attention_layer = AttentionLayer()
self.W_Qs = nn.ModuleList([
nn.Linear(d_model, self.d_k, bias=False)
for _ in range(n_heads)
])
self.W_Ks = nn.ModuleList([
nn.Linear(d_model, self.d_k, bias=False)
for _ in range(n_heads)
])
self.W_Vs = nn.ModuleList([
nn.Linear(d_model, self.d_v, bias=False)
for _ in range(n_heads)
])
self.W_O = nn.Linear(d_model, d_model, bias=False)
def forward(self, x):
# x:float32[batch_size, sequence_length, self.d_model]
head_outputs = []
for W_Q, W_K, W_V in zip(self.W_Qs, self.W_Ks, self.W_Vs):
Q = W_Q(x)
# Q float32:[batch_size, sequence_length, self.d_k]
K = W_K(x)
# K float32:[batch_size, sequence_length, self.d_k]
V = W_V(x)
# V float32:[batch_size, sequence_length, self.d_v]
head_output = self.attention_layer(Q, K, V)
# float32:[batch_size, sequence_length, self.d_v]
head_outputs.append(head_output)
concatenated = torch.cat(head_outputs, dim=-1)
# concatenated float32:[batch_size, sequence_length, self.d_model]
out = self.W_O(concatenated)
# out float32:[batch_size, sequence_length, self.d_model]
return out
class Feedforward(nn.Module):
def __init__(self, d_model):
super(Feedforward, self).__init__()
self.d_model = d_model
self.W1 = nn.Linear(d_model, d_model)
self.W2 = nn.Linear(d_model, d_model)
def forward(self, x):
# x: float32[batch_size, sequence_length, d_model]
return self.W2(torch.relu(self.W1(x)))
class Transformer(nn.Module):
def __init__(self, d_model, n_heads, n_layers):
super(Transformer, self).__init__()
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.attention_layers = nn.ModuleList([
MultiHeadedSelfAttentionLayer(d_model, n_heads)
for _ in range(n_layers)
])
self.ffs = nn.ModuleList([
Feedforward(d_model)
for _ in range(n_layers)
])
def forward(self, x):
# x: float32[batch_size, sequence_length, self.d_model]
for attention_layer, ff in zip(self.attention_layers, self.ffs):
attention_out = attention_layer(x)
# attention_out: float32[batch_size, sequence_length, self.d_model]
x = F.layer_norm(x + attention_out, x.shape[2:])
ff_out = ff(x)
# ff_out: float32[batch_size, sequence_length, self.d_model]
x = F.layer_norm(x + ff_out, x.shape[2:])
return x

View file

@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2018, Tatsuya Aoki
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,31 @@
# Simple Model for Sentence Compression
3-layered BiLSTM model for sentence compression, referred to as Baseline in [Klerke et al., NAACL 2016](http://aclweb.org/anthology/N/N16/N16-1179.pdf).
## Requirements
### Framework
- python (<= 3.6)
- pytorch (<= 0.3.0)
### Packages
- torchtext
## How to run
```
./getdata
python main.py
```
To run the scripts on a GPU, use the command `python main.py --gpu-id ID`, where ID is an integer from 0 up to the number of GPUs available.
## Reference
```
@InProceedings{klerke-goldberg-sogaard:2016:N16-1,
author = {Klerke, Sigrid and Goldberg, Yoav and S{\o}gaard, Anders},
title = {Improving sentence compression by learning to predict gaze},
booktitle = {Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
month = {June},
year = {2016},
address = {San Diego, California},
publisher = {Association for Computational Linguistics},
pages = {1528--1533},
url = {http://www.aclweb.org/anthology/N16-1179}
}
```

View file

@ -0,0 +1 @@
from .main import *

View file

@ -0,0 +1,95 @@
from torchtext import data
from const import Phase
def create_dataset(data: dict, batch_size: int, device: int):
train = Dataset(data[Phase.TRAIN]['tokens'],
data[Phase.TRAIN]['labels'],
vocab=None,
batch_size=batch_size,
device=device,
phase=Phase.TRAIN)
dev = Dataset(data[Phase.DEV]['tokens'],
data[Phase.DEV]['labels'],
vocab=train.vocab,
batch_size=batch_size,
device=device,
phase=Phase.DEV)
test = Dataset(data[Phase.TEST]['tokens'],
data[Phase.TEST]['labels'],
vocab=train.vocab,
batch_size=batch_size,
device=device,
phase=Phase.TEST)
return train, dev, test
class Dataset:
def __init__(self,
tokens: list,
label_list: list,
vocab: list,
batch_size: int,
device: int,
phase: Phase):
assert len(tokens) == len(label_list), \
'the number of sentences and the number of POS/head sequences \
should be the same length'
self.pad_token = '<PAD>'
# self.unk_token = '<UNK>'
self.tokens = tokens
self.label_list = label_list
self.sentence_id = [[i] for i in range(len(tokens))]
self.device = device
self.token_field = data.Field(use_vocab=True,
# unk_token=self.unk_token,
pad_token=self.pad_token,
batch_first=True)
self.label_field = data.Field(use_vocab=False, pad_token=-1, batch_first=True)
self.sentence_id_field = data.Field(use_vocab=False, batch_first=True)
self.dataset = self._create_dataset()
if vocab is None:
self.token_field.build_vocab(self.tokens)
self.vocab = self.token_field.vocab
else:
self.token_field.vocab = vocab
self.vocab = vocab
self.pad_index = self.token_field.vocab.stoi[self.pad_token]
self._set_batch_iter(batch_size, phase)
def get_raw_sentence(self, sentences):
return [[self.vocab.itos[idx] for idx in sentence]
for sentence in sentences]
def _create_dataset(self):
_fields = [('token', self.token_field),
('label', self.label_field),
('sentence_id', self.sentence_id_field)]
return data.Dataset(self._get_examples(_fields), _fields)
def _get_examples(self, fields: list):
ex = []
for sentence, label, sentence_id in zip(self.tokens, self.label_list, self.sentence_id):
ex.append(data.Example.fromlist([sentence, label, sentence_id], fields))
return ex
def _set_batch_iter(self, batch_size: int, phase: Phase):
def sort(data: data.Dataset) -> int:
return len(getattr(data, 'token'))
train = True if phase == Phase.TRAIN else False
self.batch_iter = data.BucketIterator(dataset=self.dataset,
batch_size=batch_size,
sort_key=sort,
train=train,
repeat=False,
device=self.device)

View file

@ -0,0 +1,8 @@
from enum import Enum, unique
@unique
class Phase(Enum):
TRAIN = 'train'
DEV = 'dev'
TEST = 'test'

View file

@ -0,0 +1,92 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
class Network(nn.Module):
def __init__(self,
embeddings,
hidden_size: int,
prior,
device: torch.device):
super(Network, self).__init__()
self.device = device
self.priors = torch.log(torch.tensor([prior, 1-prior])).to(device)
self.hidden_size = hidden_size
self.bilstm_layers = 3
self.bilstm_input_size = 300
self.bilstm_output_size = 2 * hidden_size
self.word_emb = nn.Embedding.from_pretrained(embeddings, freeze=False)
self.bilstm = nn.LSTM(self.bilstm_input_size,
self.hidden_size,
num_layers=self.bilstm_layers,
batch_first=True,
dropout=0.1,  # 0.1 performed best in tuning
bidirectional=True)
self.dropout = nn.Dropout(p=0.35)
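# Note: self.attention resolves to the bound attention() method defined below,
# so this flag-style check is always truthy and the attention branch always runs.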
if self.attention:
self.attention_size = self.bilstm_output_size * 2
self.u_a = nn.Linear(self.bilstm_output_size, self.bilstm_output_size)
self.w_a = nn.Linear(self.bilstm_output_size, self.bilstm_output_size)
self.v_a_inv = nn.Linear(self.bilstm_output_size, 1, bias=False)
self.linear_attn = nn.Linear(self.attention_size, self.bilstm_output_size)
self.linear = nn.Linear(self.bilstm_output_size, self.hidden_size)
self.pred = nn.Linear(self.hidden_size, 2)
self.softmax = nn.LogSoftmax(dim=1)
self.criterion = nn.NLLLoss(ignore_index=-1)
def forward(self, input_tokens, labels, fixations=None):
loss = 0.0
preds = []
atts = []
batch_size, seq_len = input_tokens.size()
self.init_hidden(batch_size, device=self.device)
x_i = self.word_emb(input_tokens)
x_i = self.dropout(x_i)
hidden, (self.h_n, self.c_n) = self.bilstm(x_i, (self.h_n, self.c_n))
_, _, hidden_size = hidden.size()
for i in range(seq_len):
nth_hidden = hidden[:, i, :]
if self.attention:
target = nth_hidden.expand(seq_len, batch_size, -1).transpose(0, 1)
mask = hidden.eq(target)[:, :, 0].unsqueeze(2)
attn_weight = self.attention(hidden, target, fixations, mask)
context_vector = torch.bmm(attn_weight.transpose(1, 2), hidden).squeeze(1)
nth_hidden = torch.tanh(self.linear_attn(torch.cat((nth_hidden, context_vector), -1)))
atts.append(attn_weight.detach().cpu())
logits = self.pred(self.linear(nth_hidden))
if not self.training:
logits = logits + self.priors
output = self.softmax(logits)
loss += self.criterion(output, labels[:, i])
_, topi = output.topk(k=1, dim=1)
pred = topi.squeeze(-1)
preds.append(pred)
preds = torch.stack(torch.cat(preds, dim=0).split(batch_size), dim=1)
if atts:
atts = torch.stack(torch.cat(atts, dim=0).split(batch_size), dim=1)
return loss, preds, atts
def attention(self, source, target, fixations=None, mask=None):
function_g = \
self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
if mask is not None:
function_g.masked_fill_(mask, -1e4)
if fixations is not None:
function_g = function_g*fixations
return nn.functional.softmax(function_g, dim=1)
def init_hidden(self, batch_size, device):
zeros = Variable(torch.zeros(2*self.bilstm_layers, batch_size, self.hidden_size))
self.h_n = zeros.to(device)
self.c_n = zeros.to(device)
return self.h_n, self.c_n

View file

@ -0,0 +1,183 @@
import torch
from torch import optim
import tqdm
from const import Phase
from batch import create_dataset
from models import Baseline
from sklearn.metrics import classification_report
def run(dataset_train,
dataset_dev,
dataset_test,
model_type,
word_embed_size,
hidden_size,
batch_size,
device,
n_epochs):
if model_type == 'base':
model = Baseline(vocab=dataset_train.vocab,
word_embed_size=word_embed_size,
hidden_size=hidden_size,
device=device,
inference=False)
else:
raise NotImplementedError
model = model.to(device)
optim_params = model.parameters()
optimizer = optim.Adam(optim_params, lr=10**-3)
print('start training')
for epoch in range(n_epochs):
train_loss, tokens, preds, golds = train(dataset_train,
model,
optimizer,
batch_size,
epoch,
Phase.TRAIN,
device)
dev_loss, tokens, preds, golds = train(dataset_dev,
model,
optimizer,
batch_size,
epoch,
Phase.DEV,
device)
logger = '\t'.join(['epoch {}'.format(epoch+1),
'TRAIN Loss: {:.9f}'.format(train_loss),
'DEV Loss: {:.9f}'.format(dev_loss)])
# print('\r'+logger, end='')
print(logger)
test_loss, tokens, preds, golds = train(dataset_test,
model,
optimizer,
batch_size,
epoch,
Phase.TEST,
device)
print('====', 'TEST', '=====')
print_scores(preds, golds)
output_results(tokens, preds, golds)
def train(dataset,
model,
optimizer,
batch_size,
n_epoch,
phase,
device):
total_loss = 0.0
tokens = []
preds = []
labels = []
if phase == Phase.TRAIN:
model.train()
else:
model.eval()
for batch in tqdm.tqdm(dataset.batch_iter):
token = getattr(batch, 'token')
label = getattr(batch, 'label')
raw_sentences = dataset.get_raw_sentence(token.data.detach().cpu().numpy())
loss, pred = \
model(token, raw_sentences, label, phase)
if phase == Phase.TRAIN:
optimizer.zero_grad()
loss.backward()
# clip only after backward() has populated the gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
optimizer.step()
# remove PAD from input sentences/labels and results
mask = (token != dataset.pad_index)
length_tensor = mask.sum(1)
length_tensor = length_tensor.data.detach().cpu().numpy()
for index, n_tokens_in_the_sentence in enumerate(length_tensor):
if n_tokens_in_the_sentence > 0:
tokens.append(raw_sentences[index][:n_tokens_in_the_sentence])
_label = label[index][:n_tokens_in_the_sentence]
_pred = pred[index][:n_tokens_in_the_sentence]
_label = _label.data.detach().cpu().numpy()
_pred = _pred.data.detach().cpu().numpy()
labels.append(_label)
preds.append(_pred)
total_loss += torch.mean(loss).item()
return total_loss, tokens, preds, labels
def read_two_cols_data(fname, max_len=None):
data = {}
tokens = []
labels = []
token = []
label = []
with open(fname, mode='r') as f:
for line in f:
line = line.strip().lower().split()
if line:
_token, _label = line
token.append(_token)
if _label == '0' or _label == '1':
label.append(int(_label))
else:
if _label == 'del':
label.append(1)
else:
label.append(0)
else:
if max_len is None or len(token) <= max_len:
tokens.append(token)
labels.append(label)
token = []
label = []
data['tokens'] = tokens
data['labels'] = labels
return data
def load(train_path, dev_path, test_path, batch_size, max_len, device):
train = read_two_cols_data(train_path, max_len)
dev = read_two_cols_data(dev_path)
test = read_two_cols_data(test_path)
data = {Phase.TRAIN: train, Phase.DEV: dev, Phase.TEST: test}
return create_dataset(data, batch_size=batch_size, device=device)
def print_scores(preds, golds):
_preds = [label for sublist in preds for label in sublist]
_golds = [label for sublist in golds for label in sublist]
target_names = ['not_del', 'del']
print(classification_report(_golds, _preds, target_names=target_names, digits=5))
def output_results(tokens, preds, golds, path='./result/sentcomp'):
with open(path+'.original.txt', mode='w') as w, \
open(path+'.gold.txt', mode='w') as w_gold, \
open(path+'.pred.txt', mode='w') as w_pred:
for _tokens, _golds, _preds in zip(tokens, golds, preds):
for token, gold, pred in zip(_tokens, _golds, _preds):
w.write(token + ' ')
if gold == 0:
w_gold.write(token + ' ')
# 0 -> keep, 1 -> delete
if pred == 0:
w_pred.write(token + ' ')
w.write('\n')
w_gold.write('\n')
w_pred.write('\n')

View file

@ -0,0 +1,218 @@
import json
import logging
import math
import os
import random
import re
import time
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from nltk.translate.bleu_score import sentence_bleu
import numpy as np
import torch
import torch.nn as nn
import config
plt.switch_backend("agg")
def load_glove(vocabulary):
logger = logging.getLogger(f"{__name__}.load_glove")
logger.info("loading embeddings")
try:
with open(f"glove.cache") as h:
cache = json.load(h)
except (FileNotFoundError, json.JSONDecodeError):
logger.info("cache doesn't exist")
cache = {}
cache[config.PAD] = [0] * 300
cache[config.SOS] = [0] * 300
cache[config.EOS] = [0] * 300
cache[config.UNK] = [0] * 300
cache[config.NOFIX] = [0] * 300
else:
logger.info("cache found")
cache_miss = False
if not set(vocabulary) <= set(cache):
cache_miss = True
logger.warning("cache miss, loading full embeddings")
data = {}
with open("glove.840B.300d.txt") as h:
for line in h:
word, *emb = line.strip().split()
try:
data[word] = [float(x) for x in emb]
except ValueError:
continue
logger.info("finished loading full embeddings")
for word in vocabulary:
try:
cache[word] = data[word]
except KeyError:
cache[word] = [0] * 300
logger.info("cache updated")
embeddings = []
for word in vocabulary:
embeddings.append(torch.tensor(cache[word], dtype=torch.float32))
embeddings = torch.stack(embeddings)
if cache_miss:
with open(f"glove.cache", "w") as h:
json.dump(cache, h)
logger.info("cache saved")
return embeddings
def tokenize(s):
s = s.lower().strip()
s = re.sub(r"([.!?])", r" \1", s)
s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
s = s.split(" ")
return s
def indices_from_sentence(word2index, sentence, unknown_threshold):
if unknown_threshold:
return [
word2index.get(
word if random.random() > unknown_threshold else config.UNK,
word2index[config.UNK],
)
for word in sentence
]
else:
return [
word2index.get(word, word2index[config.UNK]) for word in sentence
]
def tensor_from_sentence(word2index, sentence, unknown_threshold):
indices = indices_from_sentence(word2index, sentence, unknown_threshold)
return torch.tensor(indices, dtype=torch.long, device=config.DEV)
def tensors_from_pair(word2index, pair, shuffle, unknown_threshold):
tensors = [
tensor_from_sentence(word2index, pair[0], unknown_threshold),
tensor_from_sentence(word2index, pair[1], unknown_threshold),
]
if shuffle:
random.shuffle(tensors)
return tensors
def bleu(reference, hypothesis, n=4):
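# BLEU-n with uniform weights, i.e. 1/n for each k-gram order k = 1..n.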
if n < 1:
return 0
weights = [1/n]*n
return sentence_bleu([reference], hypothesis, weights)
def pair_iter(pairs, word2index, shuffle=False, shuffle_pairs=False, unknown_threshold=0.00):
if shuffle:
pairs = pairs.copy()
random.shuffle(pairs)
for pair in pairs:
tensor1, tensor2 = tensors_from_pair(word2index, (pair[0], pair[1]), shuffle_pairs, unknown_threshold)
yield (tensor1,), (tensor2,)
def sent_iter(sents, word2index, batch_size, unknown_threshold=0.00):
for i in range(len(sents)//batch_size+1):
raw_sents = [x[0] for x in sents[i*batch_size:i*batch_size+batch_size]]
_sents = [tensor_from_sentence(word2index, sent, unknown_threshold) for sent, target in sents[i*batch_size:i*batch_size+batch_size]]
_targets = [torch.tensor(target, dtype=torch.long).to(config.DEV) for sent, target in sents[i*batch_size:i*batch_size+batch_size]]
if raw_sents and _sents and _targets:
yield raw_sents, _sents, _targets
def batch_iter(pairs, word2index, batch_size, shuffle=False, unknown_threshold=0.00):
for i in range(len(pairs) // batch_size):
batch = pairs[i * batch_size : (i + 1) * batch_size]
if len(batch) != batch_size:
continue
batch_tensors = [
tensors_from_pair(word2index, (pair[0], pair[1]), shuffle, unknown_threshold)
for pair in batch
]
tensors1, tensors2 = zip(*batch_tensors)
yield tensors1, tensors2
def asMinutes(s):
m = math.floor(s / 60)
s -= m * 60
return "%dm %ds" % (m, s)
def timeSince(since, percent):
now = time.time()
s = now - since
es = s / (percent)
rs = es - s
return "%s (- %s)" % (asMinutes(s), asMinutes(rs))
def showPlot(points):
plt.figure()
fig, ax = plt.subplots()
# this locator puts ticks at regular intervals
loc = ticker.MultipleLocator(base=0.2)
ax.yaxis.set_major_locator(loc)
plt.plot(points)
def showAttention(input_sentence, output_words, attentions):
# Set up figure with colorbar
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(attentions.numpy(), cmap="bone")
fig.colorbar(cax)
# Set up axes
ax.set_xticklabels([""] + input_sentence.split(" ") + ["<__EOS__>"], rotation=90)
ax.set_yticklabels([""] + output_words)
# Show label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.show()
def evaluateAndShowAttention(input_sentence):
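# Note: relies on encoder1, attn_decoder1, and evaluate() being provided by the
# calling scope; they are not defined in this module.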
output_words, attentions = evaluate(encoder1, attn_decoder1, input_sentence)
print("input =", input_sentence)
print("output =", " ".join(output_words))
showAttention(input_sentence, output_words, attentions)
def save_model(model, word2index, path):
if not path.endswith(".tar"):
path += ".tar"
torch.save(
{"weights": model.state_dict(), "word2index": word2index},
path,
)
def load_model(path):
checkpoint = torch.load(path)
return checkpoint["weights"], checkpoint["word2index"]
def extend_vocabulary(word2index, langs):
for lang in langs:
for word in lang.word2index:
if word not in word2index:
word2index[word] = len(word2index)
return word2index

View file

@ -0,0 +1,495 @@
import logging
import os
import pathlib
import random
import re
import sys
import click
import sacrebleu
import torch
import torch.nn as nn
import tqdm
import config
from libs import corpora
from libs import utils
from libs.fixation_generation import Network as FixNN
from libs.sentence_compression import Network as ComNN
from sklearn.metrics import classification_report, precision_recall_fscore_support
cwd = os.path.dirname(__file__)
logger = logging.getLogger("main")
class Network(nn.Module):
def __init__(
self, word2index, embeddings, prior,
):
super().__init__()
self.logger = logging.getLogger(f"{__name__}")
self.word2index = word2index
self.index2word = {i: k for k, i in word2index.items()}
self.fix_gen = FixNN(
embedding_type="glove",
vocab_size=len(word2index),
embedding_dim=config.embedding_dim,
embeddings=embeddings,
dropout=config.fix_dropout,
hidden_dim=config.fix_hidden_dim,
)
self.com_nn = ComNN(
embeddings=embeddings, hidden_size=config.sem_hidden_dim, prior=prior, device=config.DEV
)
def forward(self, x, target, seq_lens):
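# Pad the variable-length batch, predict per-token fixation probabilities with
# the fixation network, and pass them to the compression network, which uses
# them to rescale its attention weights.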
x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
target = nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=-1)
fixations = torch.sigmoid(self.fix_gen(x1, seq_lens))
# fixations = None
loss, pred, atts = self.com_nn(x1, target, fixations)
return loss, pred, atts, fixations
def load_corpus(corpus_name, splits):
if not splits:
return
logger.info("loading corpus")
if corpus_name == "google":
load_fn = corpora.load_google
corpus = {}
langs = []
if "train" in splits:
train_pairs, train_lang = load_fn("train", max_len=200)
corpus["train"] = train_pairs
langs.append(train_lang)
if "val" in splits:
val_pairs, val_lang = load_fn("val")
corpus["val"] = val_pairs
langs.append(val_lang)
if "test" in splits:
test_pairs, test_lang = load_fn("test")
corpus["test"] = test_pairs
langs.append(test_lang)
logger.info("creating word index")
lang = langs[0]
for _lang in langs[1:]:
lang += _lang
word2index = lang.word2index
index2word = {i: w for w, i in word2index.items()}
return corpus, word2index, index2word
def init_network(word2index, prior):
logger.info("loading embeddings")
vocabulary = sorted(word2index.keys())
embeddings = utils.load_glove(vocabulary)
logger.info("initializing model")
network = Network(word2index=word2index, embeddings=embeddings, prior=prior)
network.to(config.DEV)
print(f"#parameters: {sum(p.numel() for p in network.parameters())}")
return network
@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-v", "--verbose", count=True)
@click.option("-d", "--debug", is_flag=True)
def main(verbose, debug):
if verbose == 0:
loglevel = logging.ERROR
elif verbose == 1:
loglevel = logging.WARN
elif verbose >= 2:
loglevel = logging.INFO
if debug:
loglevel = logging.DEBUG
logging.basicConfig(
format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
datefmt="%d.%m. %H:%M:%S",
level=loglevel,
)
logger.debug("arguments: %s" % str(sys.argv))
@main.command()
@click.option(
"-c",
"--corpus",
"corpus_name",
required=True,
type=click.Choice(sorted(["google",])),
)
@click.option("-m", "--model_name", required=True)
@click.option("-w", "--fixation_weights", required=False)
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
@click.option("-d", "--debug", is_flag=True, default=False)
@click.option("-p", "--prior", type=float, default=.5)
def train(corpus_name, model_name, fixation_weights, freeze_fixations, debug, prior):
corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
train_pairs = corpus["train"]
val_pairs = corpus["val"]
network = init_network(word2index, prior)
model_dir = os.path.join("models", model_name)
logger.debug("creating model dir %s" % model_dir)
pathlib.Path(model_dir).mkdir(parents=True)
if fixation_weights is not None:
logger.info("loading fixation prediction checkpoint")
checkpoint = torch.load(fixation_weights, map_location=config.DEV)
if "word2index" in checkpoint:
weights = checkpoint["weights"]
else:
weights = checkpoint
# remove the embedding layer before loading
weights = {
k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")
}
network.fix_gen.load_state_dict(weights, strict=False)
if freeze_fixations:
logger.info("freezing fixation generation network")
for p in network.fix_gen.parameters():
p.requires_grad = False
optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)
best_val_loss = None
epoch = 1
batch_size = 20
while True:
train_batch_iter = utils.sent_iter(
sents=train_pairs, word2index=word2index, batch_size=batch_size
)
val_batch_iter = utils.sent_iter(
sents=val_pairs, word2index=word2index, batch_size=batch_size
)
total_train_loss = 0
total_val_loss = 0
network.train()
for i, batch in tqdm.tqdm(
enumerate(train_batch_iter, 1), total=len(train_pairs) // batch_size + 1
):
optimizer.zero_grad()
raw_sent, sent, target = batch
seq_lens = [len(x) for x in sent]
loss, prediction, attention, fixations = network(sent, target, seq_lens)
prediction = prediction.detach().cpu().numpy()
loss.backward()
# clip only after backward() has populated the gradients
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=5)
optimizer.step()
total_train_loss += loss.item()
avg_train_loss = total_train_loss / len(train_pairs)
val_sents = []
val_preds = []
val_targets = []
network.eval()
for i, batch in tqdm.tqdm(
enumerate(val_batch_iter), total=len(val_pairs) // batch_size + 1
):
raw_sent, sent, target = batch
seq_lens = [len(x) for x in sent]
loss, prediction, attention, fixations = network(sent, target, seq_lens)
prediction = prediction.detach().cpu().numpy()
for i, l in enumerate(seq_lens):
val_sents.append(raw_sent[i][:l])
val_preds.append(prediction[i][:l].tolist())
val_targets.append(target[i][:l].tolist())
total_val_loss += loss.item()
avg_val_loss = total_val_loss / len(val_pairs)
print(
f"epoch {epoch} train_loss {avg_train_loss:.4f} val_loss {avg_val_loss:.4f}"
)
print(
classification_report(
[x for y in val_targets for x in y],
[x for y in val_preds for x in y],
target_names=["not_del", "del"],
digits=5,
)
)
with open(f"models/{model_name}/val_original_{epoch}.txt", "w") as oh, open(
f"models/{model_name}/val_pred_{epoch}.txt", "w"
) as ph, open(f"models/{model_name}/val_gold_{epoch}.txt", "w") as gh:
for sent, preds, golds in zip(val_sents, val_preds, val_targets):
pred_compressed = [
word for word, delete in zip(sent, preds) if not delete
]
gold_compressed = [
word for word, delete in zip(sent, golds) if not delete
]
oh.write(" ".join(sent))
ph.write(" ".join(pred_compressed))
gh.write(" ".join(gold_compressed))
oh.write("\n")
ph.write("\n")
gh.write("\n")
if best_val_loss is None or avg_val_loss < best_val_loss:
delta = avg_val_loss - best_val_loss if best_val_loss is not None else 0.0
best_val_loss = avg_val_loss
print(
f"new best model epoch {epoch} val loss {avg_val_loss:.4f} ({delta:.4f})"
)
utils.save_model(
network, word2index, f"models/{model_name}/{model_name}_{epoch}"
)
epoch += 1
@main.command()
@click.option(
"-c",
"--corpus",
"corpus_name",
required=True,
type=click.Choice(sorted(["google",])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-p", "--prior", type=float, default=.5)
@click.option("-l", "--longest", is_flag=True)
@click.option("-s", "--shortest", is_flag=True)
@click.option("-d", "--detailed", is_flag=True)
def test(corpus_name, model_weights, prior, longest, shortest, detailed):
if longest and shortest:
print("longest and shortest are mutually exclusive", file=sys.stderr)
sys.exit()
corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
test_pairs = corpus["test"]
model_name = os.path.basename(os.path.dirname(model_weights))
epoch = re.search(r"_(\d+)\.tar", model_weights).group(1)
logger.info("loading model checkpoint")
checkpoint = torch.load(model_weights, map_location=config.DEV)
if "word2index" in checkpoint:
weights = checkpoint["weights"]
word2index = checkpoint["word2index"]
index2word = {i: w for w, i in word2index.items()}
else:
# checkpoint without a stored vocabulary: assume it is a bare state dict
weights = checkpoint
network = init_network(word2index, prior)
network.eval()
# remove the embedding layer before loading
# weights = {k: v for k, v in weights.items() if not "embedding" in k}
# actually load the parameters
network.load_state_dict(weights, strict=False)
total_test_loss = 0
batch_size = 20
test_batch_iter = utils.sent_iter(
sents=test_pairs, word2index=word2index, batch_size=batch_size
)
test_sents = []
test_preds = []
test_targets = []
for i, batch in tqdm.tqdm(
enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
):
raw_sent, sent, target = batch
seq_lens = [len(x) for x in sent]
loss, prediction, attention, fixations = network(sent, target, seq_lens)
prediction = prediction.detach().cpu().numpy()
for i, l in enumerate(
seq_lens
):
test_sents.append(raw_sent[i][:l])
test_preds.append(prediction[i][:l].tolist())
test_targets.append(target[i][:l].tolist())
total_test_loss += loss.item()
avg_test_loss = total_test_loss / len(test_pairs)
print(f"test_loss {avg_test_loss:.4f}")
if longest or shortest:
avg_len = sum(len(s) for s in test_sents) / len(test_sents)
keep = (lambda n: n > avg_len) if longest else (lambda n: n <= avg_len)
# filter the three parallel lists jointly to keep them aligned
kept = [
(s, p, t)
for s, p, t in zip(test_sents, test_preds, test_targets)
if keep(len(s))
]
test_sents = [s for s, _, _ in kept]
test_preds = [p for _, p, _ in kept]
test_targets = [t for _, _, t in kept]
if detailed:
for test_sent, test_target, test_pred in zip(test_sents, test_targets, test_preds):
f1 = precision_recall_fscore_support(test_target, test_pred, average="weighted")[2]
print(f1, test_sent, test_target, test_pred)
else:
print(
classification_report(
[x for y in test_targets for x in y],
[x for y in test_preds for x in y],
target_names=["not_del", "del"],
digits=5,
)
)
with open(f"models/{model_name}/test_original_{epoch}.txt", "w") as oh, open(
f"models/{model_name}/test_pred_{epoch}.txt", "w"
) as ph, open(f"models/{model_name}/test_gold_{epoch}.txt", "w") as gh:
for sent, preds, golds in zip(test_sents, test_preds, test_targets):
pred_compressed = [word for word, delete in zip(sent, preds) if not delete]
gold_compressed = [word for word, delete in zip(sent, golds) if not delete]
oh.write(" ".join(sent))
ph.write(" ".join(pred_compressed))
gh.write(" ".join(gold_compressed))
oh.write("\n")
ph.write("\n")
gh.write("\n")
@main.command()
@click.option(
"-c",
"--corpus",
"corpus_name",
required=True,
type=click.Choice(["google"]),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-p", "--prior", type=float, default=.5)
@click.option("-l", "--longest", is_flag=True)
@click.option("-s", "--shortest", is_flag=True)
def predict(corpus_name, model_weights, prior, longest, shortest):
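"""Run a trained checkpoint over the test split and print one TSV record
(sentence, prediction, attention, fixations) per sentence."""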
if longest and shortest:
print("longest and shortest are mutually exclusive", file=sys.stderr)
sys.exit(1)
corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
test_pairs = corpus["test"]
model_name = os.path.basename(os.path.dirname(model_weights))
epoch = re.search(r"_(\d+)\.tar", model_weights).group(1)
logger.info("loading model checkpoint")
checkpoint = torch.load(model_weights, map_location=config.DEV)
if "word2index" in checkpoint:
weights = checkpoint["weights"]
word2index = checkpoint["word2index"]
index2word = {i: w for w, i in word2index.items()}
else:
# older checkpoints without a stored vocabulary are not supported
print("checkpoint does not contain 'word2index'", file=sys.stderr)
sys.exit(1)
network = init_network(word2index, prior)
network.eval()
# optionally drop the embedding layer before loading:
# weights = {k: v for k, v in weights.items() if "embedding" not in k}
network.load_state_dict(weights, strict=False)
total_test_loss = 0
batch_size = 20
test_batch_iter = utils.sent_iter(
sents=test_pairs, word2index=word2index, batch_size=batch_size
)
test_sents = []
test_preds = []
test_attentions = []
test_fixations = []
for batch in tqdm.tqdm(
test_batch_iter, total=len(test_pairs) // batch_size + 1
):
raw_sent, sent, target = batch
seq_lens = [len(x) for x in sent]
loss, prediction, attention, fixations = network(sent, target, seq_lens)
prediction = prediction.detach().cpu().numpy()
attention = attention.detach().cpu().numpy()
if fixations is not None:
fixations = fixations.detach().cpu().numpy()
for i, l in enumerate(seq_lens):
test_sents.append(raw_sent[i][:l])
test_preds.append(prediction[i][:l].tolist())
test_attentions.append(attention[i][:l].tolist())
if fixations is not None:
test_fixations.append(fixations[i][:l].tolist())
else:
test_fixations.append([])
total_test_loss += loss.item()
avg_test_loss = total_test_loss / len(test_pairs)
if longest or shortest:
avg_len = sum(len(s) for s in test_sents) / len(test_sents)
keep = (lambda n: n > avg_len) if longest else (lambda n: n <= avg_len)
# filter all four lists jointly: fixations may be empty placeholders, so
# filtering each list by its own length would break the alignment
kept = [
(s, p, a, f)
for s, p, a, f in zip(test_sents, test_preds, test_attentions, test_fixations)
if keep(len(s))
]
test_sents = [s for s, _, _, _ in kept]
test_preds = [p for _, p, _, _ in kept]
test_attentions = [a for _, _, a, _ in kept]
test_fixations = [f for _, _, _, f in kept]
print(f"sentence\tprediction\tattentions\tfixations")
for s, p, a, f in zip(test_sents, test_preds, test_attentions, test_fixations):
a = [x[:len(a)] for x in a]
print(f"{s}\t{p}\t{a}\t{f}")
if __name__ == "__main__":
main()
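# Example invocations (script name and checkpoint path are illustrative only):
# python main.py test -c google -w models/<name>/<name>_5.tar --detailed
# python main.py predict -c google -w models/<name>/<name>_5.tar > predictions.tsv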

View file

@ -0,0 +1,23 @@
click==7.1.2
cycler==0.10.0
dataclasses==0.6
future==0.18.2
joblib==0.17.0
kiwisolver==1.3.1
matplotlib==3.3.3
nltk==3.5
numpy==1.19.4
Pillow==8.0.1
portalocker==2.0.0
pyparsing==2.4.7
python-dateutil==2.8.1
regex==2020.11.13
sacrebleu==1.4.6
scikit-learn==0.23.2
scipy==1.5.4
six==1.15.0
sklearn==0.0
threadpoolctl==2.1.0
torch==1.7.0
tqdm==4.54.1
typing-extensions==3.7.4.3

View file

@ -0,0 +1,43 @@
import ast
import sys
import click
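# Each input line is expected to be a TSV record as printed by the predict
# command: sentence <TAB> prediction <TAB> attention <TAB> fixations.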
def reader(path):
with open(path) as h:
for line in h:
line = line.strip()
try:
s, p, a, f = line.split("\t")
except ValueError:
print(f"skipping line: {line}", file=sys.stderr)
else:
try:
yield ast.literal_eval(s), ast.literal_eval(p), ast.literal_eval(a), ast.literal_eval(f)
except (ValueError, SyntaxError):
print(f"skipping malformed line: {s}", file=sys.stderr)
def get_stats(seq):
for s, p, a, f in seq:
print(s)
print(p)
print(len(s), len(p), len(a), len(f))
for x in a:
print(len(x))
print()
@click.command()
@click.argument("path")
def main(path):
get_stats(reader(path))
if __name__ == "__main__":
main()

View file

@ -0,0 +1,34 @@
import click
def reader(path):
with open(path) as h:
for line in h:
line = line.strip()
yield line.split()
@click.command()
@click.argument("original")
@click.argument("compressed")
def main(original, compressed):
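"""Report the average compression ratio len(compressed) / len(original)
over aligned line pairs of the two files."""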
ratio = 0
total = 0
for o, c in zip(reader(original), reader(compressed)):
ratio += len(c)/len(o)
total += 1
print(f"cr: {ratio/total:.4f}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,88 @@
import ast
from math import log2
import os
from statistics import mean, pstdev
import sys
import click
import numpy as np
from matplotlib import pyplot as plt
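# Reads the TSV emitted by the predict command and yields one attention
# matrix per sentence, unwrapping the single-element inner weight lists.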
def attention_reader(path):
with open(path) as h:
for line in h:
line = line.strip()
try:
s, p, a, f = line.split("\t")
except ValueError:
print(f"skipping line: {line}", file=sys.stderr)
else:
try:
yield [[x[0] for x in y] for y in ast.literal_eval(a)]
except (ValueError, SyntaxError):
print(f"skipping malformed line: {s}", file=sys.stderr)
def _kl_divergence(p, q):
# normalise to proper distributions; dtype=float so in-place division works on int input
p = np.asarray(p, dtype=float)
q = np.asarray(q, dtype=float)
p /= p.sum()
q /= q.sum()
return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))
def kl_divergence(ps, qs):
kl = 0
count = 0
for p, q in zip(ps, qs):
kl += _kl_divergence(p, q)
count += 1
return kl / count
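# Jensen-Shannon divergence: JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m)
# with m = (p + q) / 2; unlike KL it is symmetric and always finite.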
def _js_divergence(p, q):
p = np.asarray(p, dtype=float)
q = np.asarray(q, dtype=float)
p /= p.sum()
q /= q.sum()
m = 0.5 * (p + q)
return 0.5 * _kl_divergence(p, m) + 0.5 * _kl_divergence(q, m)
def js_divergence(ps, qs):
js = 0
count = 0
for p, q in zip(ps, qs):
js += _js_divergence(p, q)
count += 1
return js/count
def get_kl_div(seq1, seq2):
# JS divergence is used; the original, unreachable KL variant is kept below as a comment
return [js_divergence(x1, x2) for x1, x2 in zip(seq1, seq2)]
# return [kl_divergence(x1, x2) for x1, x2 in zip(seq1, seq2)]
@click.command()
@click.argument("ref")
@click.argument("path", nargs=-1)
def main(ref, path):
kls = []
labels = []
for p in path:
labels.append(os.path.basename(p))
kl = get_kl_div(attention_reader(ref), attention_reader(p))
print(f"{labels[-1]} mean: {mean(kl)}")
print(f"{labels[-1]} pstdev: {pstdev(kl)}")
kls.append(kl)
plt.boxplot(kls, labels=labels)
plt.show()
plt.clf()
if __name__ == "__main__":
main()

View file

@ -0,0 +1,112 @@
'''
Compares attention/fixation distributions within a file or across two files.
Note: similarity is computed with sklearn's mutual_info_score, not an actual
KL divergence as the original comments suggested.
'''
from math import log2
import pandas as pd
import ast
import collections.abc
from sklearn import metrics
import sys
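# Plain KL divergence over two aligned, strictly positive distributions;
# kept for reference, the comparisons below use mutual information instead.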
def kl_divergence(p, q):
return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))
def flatten(x):
if isinstance(x, collections.abc.Iterable):
return [a for i in x for a in flatten(i)]
else:
return [x]
def get_data(file):
names = ["sentence", "prediction", "attentions", "fixations"]
df = pd.read_csv(file, sep='\t', names=names)
df = df[2:]
attentions = df.loc[:, "attentions"].tolist()
fixations = df.loc[:, "fixations"].tolist()
return attentions, fixations
def attention_attention(attentions1, attentions2):
divergence = []
for att1, att2 in zip(attentions1, attentions2):
lst_att1 = flatten(ast.literal_eval(att1))
lst_att2 = flatten(ast.literal_eval(att2))
try:
divergence.append(metrics.mutual_info_score(lst_att1, lst_att2))
except ValueError:
# skip pairs that cannot be scored instead of poisoning the average with None
continue
avg = sum(divergence) / len(divergence)
return avg
def fixation_fixation(fixation1, fixation2):
divergence = []
for fix1, fix2 in zip(fixation1, fixation2):
lst_fixation1 = flatten(ast.literal_eval(fix1))
lst_fixation2 = flatten(ast.literal_eval(fix2))
try:
divergence.append(metrics.mutual_info_score(lst_fixation1, lst_fixation2))
except ValueError:
continue
avg = sum(divergence) / len(divergence)
return avg
def attention_fixation(attentions, fixations):
divergence = []
for attention, fixation in zip(attentions, fixations):
current_attention = ast.literal_eval(attention)
current_fixation = ast.literal_eval(fixation)
# average each attention row so the attention sequence aligns with the fixations
lst_attention = []
for t in current_attention:
attention_lst = flatten(t)
lst_attention.append(sum(attention_lst) / len(attention_lst))
lst_fixation = flatten(current_fixation)
try:
divergence.append(metrics.mutual_info_score(lst_attention, lst_fixation))
except ValueError:
continue
avg = sum(divergence) / len(divergence)
return avg
def divergent_calculations(file1, file2=None, val1=None):
attentions, fixations = get_data(file1)
if file2:
# only load the second file when it is actually given
attentions2, fixations2 = get_data(file2)
if val1 == "attention":
divergence = attention_attention(attentions, attentions2)
else:
divergence = fixation_fixation(fixations, fixations2)
else:
divergence = attention_fixation(attentions, fixations)
print("divergence (mutual information):", divergence)
if __name__ == "__main__":
divergent_calculations(sys.argv[1], sys.argv[2], sys.argv[3])
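# usage (all three arguments are positional):
# python <this_script>.py predictions_a.tsv predictions_b.tsv attention
# the third argument selects attention-vs-attention; any other value compares fixations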

View file

@ -0,0 +1,64 @@
import ast
import os
import pathlib
import click
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import matplotlib.ticker as ticker
import numpy as np
import tqdm
def plot_attention(input_sentence, output_words, attentions, path):
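"""Save a heatmap of attention weights (output steps x input tokens) as <path>.pdf."""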
# Set up figure with colorbar
attentions = np.array(attentions)[:,:len(input_sentence)]
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(attentions, cmap="bone")
fig.colorbar(cax)
# Set up axes
ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
ax.set_yticklabels([""] + output_words)
# Show label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.savefig(f"{path}.pdf")
plt.close()
def parse(p):
with open(p) as h:
for line in h:
# strip first so blank lines are actually skipped
line = line.strip()
if not line or line.startswith("#"):
continue
_sentence, _prediction, _attention, _fixations = line.split("\t")
try:
sentence = ast.literal_eval(_sentence)
prediction = ast.literal_eval(_prediction)
attention = ast.literal_eval(_attention)
except (ValueError, SyntaxError):
continue
yield sentence, prediction, attention
@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
for p in tqdm.tqdm(path):
out_dir = os.path.splitext(p)[0]
# if the input path has no extension, avoid clobbering it with the output directory
if out_dir == p:
out_dir = f"{out_dir}_"
pathlib.Path(out_dir).mkdir(exist_ok=True)
for i, spa in enumerate(parse(p)):
plot_attention(*spa, path=os.path.join(out_dir, str(i)))
if __name__ == "__main__":
main()

View file

@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12
## Convert a text/attention list to LaTeX code, which in turn generates a text heat map based on the attention weights.
import numpy as np
latex_special_token = ["!@#$%^&*()"]
def generate(text_list, attention_list, latex_file, color='red', rescale_value=False):
assert len(text_list) == len(attention_list)
if rescale_value:
attention_list = rescale(attention_list)
word_num = len(text_list)
text_list = clean_word(text_list)
with open(latex_file,'w') as f:
f.write(r'''\documentclass[varwidth]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}'''+'\n')
string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{'''+"\n"
for idx in range(word_num):
string += "\\colorbox{%s!%s}{"%(color, attention_list[idx])+"\\strut " + text_list[idx]+"} "
string += "\n}}}"
f.write(string+'\n')
f.write(r'''\end{CJK*}
\end{document}''')
def rescale(input_list):
the_array = np.asarray(input_list)
the_max = np.max(the_array)
the_min = np.min(the_array)
rescale = (the_array - the_min)/(the_max-the_min)*100
return rescale.tolist()
def clean_word(word_list):
new_word_list = []
for word in word_list:
for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
if latex_sensitive in word:
word = word.replace(latex_sensitive, '\\'+latex_sensitive)
new_word_list.append(word)
return new_word_list
if __name__ == '__main__':
## This is a demo:
sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".
North Korea and the US have ratcheted up tensions in recent weeks and the movement of the strike group had raised the question of a pre-emptive strike by the US.
On Wednesday, Mr Pence described the country as the "most dangerous and urgent threat to peace and security" in the Asia-Pacific.'''
sent = '''我 回忆 起 我 曾经 在 大学 年代 我们 经常 喜欢 玩 “ Hawaii guitar ” 。 说起 Guitar 我 想起 了 西游记 里 的 琵琶精 。
今年 下半年 合拍 西游记 即将 正式 开机 继续 扮演 美猴王 孙悟空 美猴王 艺术 形象 努力 创造 正能量 形象 开花 弘扬 中华 文化 希望 大家 多多 关注 '''
words = sent.split()
word_num = len(words)
attention = [(x+1.)/word_num*100 for x in range(word_num)]
import random
random.seed(42)
random.shuffle(attention)
color = 'red'
generate(words, attention, "sample.tex", color)