# human-gaze-guided-neural-at.../joint_paraphrase_model/main.py
#
# Captured from a repository web view (795 lines, 36 KiB, Python). The viewer
# warned that this file contains Unicode characters that might be confused with
# other characters; if that is intentional, the warning can be ignored.

import logging
import os
import pathlib
import random
import sys
import click
import sacrebleu
import torch
import torch.nn as nn
import tqdm
import config
from libs import corpora
from libs import utils
from libs.fixation_generation import Network as FixNN
from libs.paraphrase_generation import (
EncoderRNN as ParEncNN,
AttnDecoderRNN as ParDecNN,
)
# Directory containing this script (not referenced anywhere else in this file).
cwd = os.path.dirname(__file__)
# Module-level logger; its output is governed by the logging.basicConfig call in `main`.
logger = logging.getLogger("main")
# Alternative debug-sentence sets for other corpora. They are disabled by being
# wrapped in bare string literals; to use one, move it out of its triple quotes
# (and disable the active QQP list below).
'''
#qqp_paw sentences:
debug_sentences = [
'What are the driving rules in Georgia versus Mississippi ?',
'I want to be a child psychologist , what qualification do i need to become one ? Are there good and reputed psychology Institute or Colleges in India ?',
'What is deep web and dark web and what are the contents of these sites ?',
'What is difference between North Indian Brahmins and South Indian Brahmins ?',
'Is carbon dioxide an ionic bond or a covalent bond ?',
'How do accounts receivable and accounts payable differ ?',
'Why did Wikipedia hide its audit history for ( superluminal ) successful speed experiments ?',
'What makes a person simple , or inversely , complicated ?',
'`` How do you say `` Miss you , too , `` in Spanish ? Are there multiple ways to say it ? ',
'What is the difference between dominant trait and recessive trait ?',
'`` What is the difference between `` seeing someone , `` `` dating someone , `` and `` having a girlfriend/boyfriend `` ? ',
'How was the Empire State building built and designed ? How is it used ?',
'What is the sum of the square roots of first n natural number ?',
'Why is Roman Saini not so active on Quora now a days ?',
'If I have someone blocked on Instagram , and see their story , can they view I viewed it ?',
'Amongst the major IT companies of India which is the best ; Wipro , Capgemini , Infosys , TCS or is Oracle the best ?',
'How much mass does Saturn lose each year ? How much mass does it gain ?',
'What is a cheap healthy diet , I can keep the same and eat every day ?',
' What is it like to be beautiful ? Not just pretty or hot , but the kind of almost objective beauty that people are sometimes intimidated by ?',
'Could someone tell of a James Ronsey ( misspelled likely ) , writer and filmmaker , probably of the British Isles ?',
'How much pressure Is there around the core of Pluto ? is it enough to turn hydrogen/helium gas into a liquid or metallic state ?',
'How does quality of life in Vancouver compare to that in Melbourne Or Brisbane ?',
]
'''
'''
#wiki sentences:
debug_sentences = [
'They were there to enjoy us and they were there to pray for us .',
'Components of elastic potential systems store mechanical energy if they are deformed when forces are applied to the system .',
'Steam can also be used , and does not need to be pumped .',
'The solar approach to this requirement is the use of solar panels in a conventional-powered aircraft .',
'Daudkhali is a village in Barisal Division in the Pirojpur district in southwestern Bangladesh .',
'Briggs later met Briggs at the 1967 Monterey Pop Festival , where Ravi Shankar was also performing , with Eric Burdon and The Animals .',
'Brockton is approximately 25 miles northeast of Providence , Rhode Island , and 30 miles south of Boston .',
]
'''
# Active debug sentences. `train` decodes these every 100 training steps and at
# the end of every epoch, and writes the results into the model directory.
#qqp sentences:
debug_sentences = [
    'How do I get funding for my web based startup idea ?',
    'What do intelligent people do to pass time ?',
    'Which is the best SEO Company in Delhi ?',
    'Why do you waer makeup ?',
    'How do start chatting with a girl ?',
    'What is the meaning of living life ?',
    'Why do my armpits hurt ?',
    'Why does eye color change with age ?',
    'How do you find the standard deviation of a probability distribution ? What are some examples ?',
    'How can I complete my 11 syllabus in one month ?',
    'How do I concentrate better on my studies ?',
    'Which is the best retirement plan in india ?',
    'Should I tell my best friend I love her ?',
    'Which is the best company for Appian Vagrant online job support ?',
    'How can one do for good handwriting ?',
    'What are remedies to get rid of belly fat ?',
    'What is the best way to cook precooked turkey ?',
    'What is the future of e-commerce in India ?',
    'Why do my burps taste like rotten eggs ?',
    'What is an example of chemical weathering ?',
    'What are some of the advantages and disadvantages of cyber schooling ?',
    'How can I increase traffic to my websites by Facebook ?',
    'How do I increase my patience level in life ?',
    'What are the best hospitals for treating cancer in India ?',
    'Will Jio sim work in a 3G phone ? If yes , how ?',
]
# The sentences are pre-tokenised (tokens separated by single spaces); split them
# into token lists, which is the form later handed to utils.sent_iter.
debug_sentences = [s.split(" ") for s in debug_sentences]
class Network(nn.Module):
    """Joint model: fixation predictor plus attention-based paraphrase encoder/decoder.

    ``fix_gen`` produces a score per input token (squashed with a sigmoid in
    ``forward``). Those scores are handed to the attention decoder ``par_dec``
    at every decoding step. The encoder/decoder loops handle one sequence at a
    time: they read ``enc_out[0, 0]`` and start from a single SOS token. Callers
    are therefore expected to pass batches of one example.
    """

    def __init__(
        self,
        word2index,
        embeddings,
    ):
        """Build the three sub-networks.

        Args:
            word2index: token -> index mapping; its size is used as the
                vocabulary size and as the decoder's output size.
            embeddings: pretrained embeddings, handed to all three sub-networks.
        """
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.word2index = word2index
        self.index2word = {i: k for k, i in word2index.items()}
        self.fix_gen = FixNN(
            embedding_type="glove",
            vocab_size=len(word2index),
            embedding_dim=config.embedding_dim,
            embeddings=embeddings,
            dropout=config.fix_dropout,
            hidden_dim=config.fix_hidden_dim,
        )
        self.par_enc = ParEncNN(
            input_size=config.embedding_dim,
            hidden_size=config.par_hidden_dim,
            embeddings=embeddings,
        )
        self.par_dec = ParDecNN(
            input_size=config.embedding_dim,
            hidden_size=config.par_hidden_dim,
            output_size=len(word2index),
            embeddings=embeddings,
            dropout_p=config.par_dropout,
            max_length=config.max_length,
        )

    def forward(self, x, target=None, teacher_forcing_ratio=None):
        """Predict fixations, encode ``x`` and decode a paraphrase.

        Args:
            x: sequence of 1-D token-id tensors, padded here with pad_sequence.
                Assumes a single sequence (see the class docstring).
            target: optional sequence of 1-D token-id tensors. When given,
                exactly ``len`` (padded target) steps are decoded (training /
                scoring). Otherwise decoding is greedy until EOS or
                ``config.max_length`` steps.
            teacher_forcing_ratio: probability that the whole target sequence
                is teacher forced. ``None`` means ``config.teacher_forcing_ratio``.

        Returns:
            ``(dec_outs, dec_words, attentions, fixations)``:
            - the list of per-step decoder outputs;
            - the decoded words (empty when teacher forcing);
            - the attention rows actually produced;
            - the sigmoid fixation scores.
        """
        if teacher_forcing_ratio is None:
            teacher_forcing_ratio = config.teacher_forcing_ratio
        x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
        x2 = nn.utils.rnn.pad_sequence(x, batch_first=False)
        # NOTE(review): lengths are measured on the *padded* rows, so they are all
        # equal. That is correct for a batch of one, which the loops below assume.
        fixations = torch.sigmoid(self.fix_gen(x1, [len(_x) for _x in x1]))

        # Encode one time step at a time, collecting outputs for the attention decoder.
        enc_hidden = self.par_enc.initHidden().to(config.DEV)
        enc_outs = torch.zeros(config.max_length, config.par_hidden_dim, device=config.DEV)
        for ei in range(len(x2)):
            enc_out, enc_hidden = self.par_enc(x2[ei], enc_hidden)
            enc_outs[ei] += enc_out[0, 0]

        dec_in = torch.tensor([[self.word2index[config.SOS]]], device=config.DEV)  # SOS, shape (1, 1)
        dec_hidden = enc_hidden
        dec_outs = []
        dec_words = []
        dec_atts = torch.zeros(config.max_length, config.max_length)
        if target is not None:  # training / scoring: decode exactly len(target) steps
            target = nn.utils.rnn.pad_sequence(target, batch_first=False)
            # One draw per sequence: the whole sequence is either teacher forced or not.
            use_teacher_forcing = random.random() < teacher_forcing_ratio
            for di in range(len(target)):
                dec_out, dec_hidden, dec_att = self.par_dec(
                    dec_in, dec_hidden, enc_outs, fixations
                )
                dec_outs.append(dec_out)
                dec_atts[di] = dec_att.data
                if use_teacher_forcing:
                    # Feed the gold token next, in the same (1, 1) shape as the SOS input.
                    dec_in = target[di].reshape(1, -1)
                else:
                    _, topi = dec_out.data.topk(1)
                    dec_words.append(self.index2word[topi.item()])
                    # Feed the model's own (detached) prediction next.
                    dec_in = topi.reshape(1, -1).detach()
        else:  # prediction: greedy decoding until EOS or config.max_length steps
            for di in range(config.max_length):
                dec_out, dec_hidden, dec_att = self.par_dec(
                    dec_in, dec_hidden, enc_outs, fixations
                )
                dec_outs.append(dec_out)
                dec_atts[di] = dec_att.data
                _, topi = dec_out.data.topk(1)
                if topi.item() == self.word2index[config.EOS]:
                    dec_words.append("<__EOS__>")
                    break
                dec_words.append(self.index2word[topi.item()])
                dec_in = topi.reshape(1, -1).detach()
        return dec_outs, dec_words, dec_atts[: di + 1], fixations
def load_corpus(corpus_name, splits):
    """Load the requested splits of a corpus and build one joint vocabulary.

    Args:
        corpus_name: one of the keys of ``loaders`` below (the CLI ``--corpus`` choices).
        splits: collection containing any of "train", "val", "test". Splits are
            always loaded in train/val/test order; other entries are ignored.

    Returns:
        ``(corpus, word2index, index2word)``, where ``corpus`` maps each loaded
        split name to its pairs. Returns ``None`` when ``splits`` is empty.

    Raises:
        ValueError: for an unknown ``corpus_name``, or when ``splits`` names no known split.
    """
    if not splits:
        return
    # corpus name -> loader function name in libs.corpora (resolved lazily below)
    loaders = {
        "msrpc": "load_msrpc",
        "qqp": "load_qqp",
        "wiki": "load_wiki",
        "qqp_paws": "load_qqp_paws",
        "qqp_kag": "load_qqp_kag",
        "sentiment": "load_sentiment",
        "stanford": "load_stanford",
        "stanford_sent": "load_stanford_sent",
        "tamil": "load_tamil",
        "compression": "load_compression",
    }
    if corpus_name not in loaders:
        raise ValueError(f"unknown corpus {corpus_name!r}; expected one of {sorted(loaders)}")
    logger.info("loading corpus")
    load_fn = getattr(corpora, loaders[corpus_name])

    corpus = {}
    langs = []
    for split in ("train", "val", "test"):
        if split in splits:
            pairs, split_lang = load_fn(split)
            corpus[split] = pairs
            langs.append(split_lang)
    if not langs:
        raise ValueError(f"no known split in {splits!r}; expected 'train', 'val' and/or 'test'")

    logger.info("creating word index")
    # Merge the per-split vocabularies into the first one.
    lang = langs[0]
    for other in langs[1:]:
        lang += other
    word2index = lang.word2index
    index2word = {i: w for w, i in word2index.items()}
    return corpus, word2index, index2word
def init_network(word2index):
    """Create a ``Network`` for ``word2index`` on ``config.DEV``.

    GloVe embeddings are loaded for the alphabetically sorted vocabulary, and
    the total parameter count is printed to stdout.
    """
    logger.info("loading embeddings")
    embeddings = utils.load_glove(sorted(word2index))
    logger.info("initializing model")
    network = Network(word2index=word2index, embeddings=embeddings)
    network.to(config.DEV)
    n_parameters = sum(param.numel() for param in network.parameters())
    print(f"#parameters: {n_parameters}")
    return network
@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-v", "--verbose", count=True)
@click.option("-d", "--debug", is_flag=True)
def main(verbose, debug):
    # CLI group shared by all sub-commands; it only configures logging.
    # (No docstring on purpose: click would show it as the --help text.)
    # Verbosity: none -> errors, -v -> warnings, -vv or more -> info;
    # --debug overrides the count and enables debug output.
    if debug:
        loglevel = logging.DEBUG
    elif verbose >= 2:
        loglevel = logging.INFO
    elif verbose == 1:
        loglevel = logging.WARN
    else:
        loglevel = logging.ERROR
    logging.basicConfig(
        format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
        datefmt="%d.%m. %H:%M:%S",
        level=loglevel,
    )
    logger.debug("arguments: %s", sys.argv)
def _sentence_bleus(bleu, ref_ids, hyp_ids, index2word):
    """Return the sentence BLEU score(s) of ``hyp_ids`` against ``ref_ids``.

    For ``bleu == "sacrebleu"`` this is a one-element list (sacrebleu on the
    space-joined words). Otherwise it is a four-element list of ``utils.bleu``
    on the token ids, for n = 1..4.
    """
    if bleu == "sacrebleu":
        hypothesis = " ".join([index2word[x] for x in hyp_ids])
        reference = " ".join([index2word[x] for x in ref_ids])
        # NOTE(review): sacrebleu >= 1.4 documents `references` as a list of
        # strings, and a bare string would be iterated character by character.
        # Kept as-is so scores stay comparable with earlier runs; confirm the
        # pinned sacrebleu version.
        return [sacrebleu.sentence_bleu(hypothesis, reference).score]
    # Fresh list copies per call, in case utils.bleu mutates its arguments
    # (its implementation is not visible from this file).
    return [utils.bleu(list(ref_ids), list(hyp_ids), n=n) for n in range(1, 5)]


def _format_bleus(prefix, totals, denom):
    """Render averaged BLEU totals.

    One score renders as ``"<prefix> X"``; four scores render as
    ``"<prefix>_1 X ... <prefix>_4 W"``. Every value is formatted with ``:.2f``.
    """
    if len(totals) == 1:
        return f"{prefix} {totals[0]/denom:.2f}"
    return " ".join(f"{prefix}_{n} {total/denom:.2f}" for n, total in enumerate(totals, 1))


def _write_debug_predictions(handle, network, word2index, index2word):
    """Write a TSV header plus one row per ``debug_sentences`` entry to ``handle``.

    Each row holds the free-running decoding of the sentence: sentence,
    prediction, attention, fixations. Runs without gradient tracking; the caller
    is responsible for putting ``network`` in eval mode.
    """
    handle.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
    handle.write("\n")
    with torch.no_grad():
        for sentence, sent_input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
            prediction, _words, attentions, fixations = network(sent_input)
            predicted_ids = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
            predicted_words = [index2word.get(x, "<__UNK__>") for x in predicted_ids]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()
            handle.write(f"{sentence}\t{predicted_words}\t{attentions}\t{fixations}")
            handle.write("\n")


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-m", "--model_name", required=True)
@click.option("-w", "--fixation_weights", required=False)
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def train(corpus_name, model_name, fixation_weights, freeze_fixations, bleu):
    """Train the joint fixation/paraphrase model until interrupted.

    Every epoch trains on the shuffled training pairs, then evaluates on the
    validation pairs with free-running decoding. It then saves a checkpoint in
    models/MODEL_NAME, which must not exist yet. Every 100 steps, and at the end
    of each epoch, the decodings of a fixed set of debug sentences are written
    there as well.
    """
    corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
    train_pairs = corpus["train"]
    val_pairs = corpus["val"]
    network = init_network(word2index)
    model_dir = os.path.join("models", model_name)
    logger.debug("creating model dir %s", model_dir)
    # No exist_ok: an existing run directory is an error rather than being overwritten.
    pathlib.Path(model_dir).mkdir(parents=True)

    if fixation_weights is not None:
        logger.info("loading fixation prediction checkpoint")
        checkpoint = torch.load(fixation_weights, map_location=config.DEV)
        # Checkpoints that carry a vocabulary wrap the state dict under "weights".
        weights = checkpoint["weights"] if "word2index" in checkpoint else checkpoint
        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")}
        network.fix_gen.load_state_dict(weights, strict=False)
    if freeze_fixations:
        logger.info("freezing fixation generation network")
        for p in network.fix_gen.parameters():
            p.requires_grad = False

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)
    n_scores = 1 if bleu == "sacrebleu" else 4
    n_train = len(train_pairs)
    n_val = len(val_pairs)
    best_val_loss = None
    epoch = 1
    while True:
        train_batch_iter = utils.pair_iter(pairs=train_pairs, word2index=word2index, shuffle=True, shuffle_pairs=False)
        val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
        running_train_loss = 0
        total_train_loss = 0
        total_val_loss = 0
        running_train_bleus = [0] * n_scores
        total_train_bleus = [0] * n_scores
        total_val_bleus = [0] * n_scores

        # ---- training pass ----
        network.train()
        for i, (source, target) in enumerate(train_batch_iter, 1):
            optimizer.zero_grad()
            prediction, _words, _attention, _fixations = network(source, target)
            target_ids = target[0]
            logits = torch.stack(prediction).squeeze(1)[: target_ids.shape[0], :]
            loss = loss_fn(logits, target_ids)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()
            total_train_loss += loss.item()
            scores = _sentence_bleus(bleu, target_ids.tolist(), torch.argmax(logits, -1).tolist(), index2word)
            for k, score in enumerate(scores):
                running_train_bleus[k] += score
                total_train_bleus[k] += score

            if i % 100 == 0:
                stats = f"avg_train_loss {running_train_loss/100:.4f} {_format_bleus('avg_train_bleu', running_train_bleus, 100)}"
                print(f"step {i} {stats}")
                network.eval()
                with open(os.path.join(model_dir, f"debug_{epoch}_{i}.out"), "w") as h:
                    h.write(f"# {stats}")
                    h.write("\n")
                    _write_debug_predictions(h, network, word2index, index2word)
                running_train_loss = 0
                running_train_bleus = [0] * n_scores
                network.train()

        # ---- validation pass (free-running decoding, no autograd graph) ----
        network.eval()
        with torch.no_grad():
            for source, target in val_batch_iter:
                prediction, _words, _attention, _fixations = network(source, target, teacher_forcing_ratio=0)
                target_ids = target[0]
                logits = torch.stack(prediction).squeeze(1)[: target_ids.shape[0], :]
                loss = loss_fn(logits, target_ids)
                scores = _sentence_bleus(bleu, target_ids.tolist(), torch.argmax(logits, -1).tolist(), index2word)
                # Accumulated in both modes (sacrebleu's validation total used to stay at 0).
                for k, score in enumerate(scores):
                    total_val_bleus[k] += score
                total_val_loss += loss.item()
        avg_val_loss = total_val_loss / n_val

        epoch_stats = (
            f"avg_train_loss {total_train_loss/n_train:.4f} avg_val_loss {avg_val_loss:.4f} "
            f"{_format_bleus('avg_train_bleu', total_train_bleus, n_train)} "
            f"{_format_bleus('avg_val_bleu', total_val_bleus, n_val)}"
        )
        print(f"epoch {epoch} {epoch_stats}")
        with open(os.path.join(model_dir, f"debug_{epoch}_end.out"), "w") as h:
            if bleu == "sacrebleu":
                # Unformatted floats here (unlike every other summary line), kept for output compatibility.
                h.write(f"# avg_train_loss {total_train_loss/n_train} avg_val_loss {total_val_loss/n_val} avg_train_bleu {total_train_bleus[0]/n_train} avg_val_bleu {total_val_bleus[0]/n_val}")
            else:
                h.write(f"# {epoch_stats}")
            h.write("\n")
            _write_debug_predictions(h, network, word2index, index2word)

        # A checkpoint is saved every epoch; the best epoch is only logged.
        utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_{epoch}"))
        if best_val_loss is None or avg_val_loss < best_val_loss:
            if best_val_loss is not None:
                logger.info(f"{avg_val_loss} < {best_val_loss} ({avg_val_loss-best_val_loss}): new best model from epoch {epoch}")
            else:
                logger.info(f"{avg_val_loss} < {best_val_loss}: new best model from epoch {epoch}")
            best_val_loss = avg_val_loss
        epoch += 1
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def val(corpus_name, model_weights, sentence_statistics, bleu):
    """Evaluate a checkpoint on the validation split of CORPUS.

    Prints the average loss and sentence BLEU. With --sentence_statistics it also
    prints one tab-separated line per example: score(s), input, prediction,
    target, attention and fixations.
    """
    corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
    val_pairs = corpus["val"]
    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" not in checkpoint:
        # Only checkpoints that bundle their vocabulary are supported.
        raise click.ClickException(f"checkpoint {model_weights} has no 'word2index' entry")
    weights = checkpoint["weights"]
    # The checkpoint's vocabulary replaces the one built from the corpus.
    word2index = checkpoint["word2index"]
    index2word = {i: w for w, i in word2index.items()}
    network = init_network(word2index)
    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    # make a new output layer to match the weights from the checkpoint
    # we cannot remove it like we did with the embedding layers because
    # unlike those the output layer actually contains learned parameters
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    loss_fn = nn.CrossEntropyLoss()
    val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
    total_val_loss = 0
    if bleu == "sacrebleu":
        total_val_bleu = 0
    elif bleu == "nltk":
        total_val_bleu_1 = 0
        total_val_bleu_2 = 0
        total_val_bleu_3 = 0
        total_val_bleu_4 = 0
    network.eval()
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for source, target in val_batch_iter:
            # teacher_forcing_ratio=0 scores free-running decoding, as train()'s
            # validation loop does, instead of a random config-dependent mix.
            prediction, _words, attentions, fixations = network(source, target, teacher_forcing_ratio=0)
            target_ids = target[0]
            logits = torch.stack(prediction).squeeze(1)[: target_ids.shape[0], :]
            loss = loss_fn(logits, target_ids)
            total_val_loss += loss.item()
            ref_ids = target_ids.tolist()
            hyp_ids = torch.argmax(logits, -1).tolist()
            _prediction = [index2word[_x] for _x in hyp_ids]
            _target = [index2word[_x] for _x in ref_ids]
            if bleu == "sacrebleu":
                # NOTE(review): sacrebleu >= 1.4 expects a list of reference
                # strings; confirm the pinned version before changing this.
                bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
                total_val_bleu += bleu_score
            elif bleu == "nltk":
                # Fresh list copies per call, in case utils.bleu mutates its inputs.
                bleu_1_score, bleu_2_score, bleu_3_score, bleu_4_score = [
                    utils.bleu(list(ref_ids), list(hyp_ids), n=n) for n in range(1, 5)
                ]
                total_val_bleu_1 += bleu_1_score
                total_val_bleu_2 += bleu_2_score
                total_val_bleu_3 += bleu_3_score
                total_val_bleu_4 += bleu_4_score
            if sentence_statistics:
                s = [index2word[x] for x in source[0].detach().cpu().tolist()]
                attentions = attentions.detach().cpu().squeeze().tolist()
                fixations = fixations.detach().cpu().squeeze().tolist()
                if bleu == "sacrebleu":
                    print(f"{bleu_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")
                elif bleu == "nltk":
                    print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")
    if bleu == "sacrebleu":
        print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
    elif bleu == "nltk":
        print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def test(corpus_name, model_weights, sentence_statistics, bleu):
    """Evaluate a checkpoint on the test split of CORPUS.

    Prints the average loss and sentence BLEU. With --sentence_statistics it also
    prints one tab-separated line per example: score(s), input, prediction,
    attention and fixations.
    """
    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]
    # --model_weights is required, so a checkpoint is always loaded.
    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" not in checkpoint:
        # Only checkpoints that bundle their vocabulary are supported.
        raise click.ClickException(f"checkpoint {model_weights} has no 'word2index' entry")
    weights = checkpoint["weights"]
    # The checkpoint's vocabulary replaces the one built from the corpus.
    word2index = checkpoint["word2index"]
    index2word = {i: w for w, i in word2index.items()}
    network = init_network(word2index)
    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    # make a new output layer to match the weights from the checkpoint
    # we cannot remove it like we did with the embedding layers because
    # unlike those the output layer actually contains learned parameters
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    loss_fn = nn.CrossEntropyLoss()
    test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
    total_test_loss = 0
    if bleu == "sacrebleu":
        total_test_bleu = 0
    elif bleu == "nltk":
        total_test_bleu_1 = 0
        total_test_bleu_2 = 0
        total_test_bleu_3 = 0
        total_test_bleu_4 = 0
    network.eval()
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for source, target in test_batch_iter:
            # teacher_forcing_ratio=0 scores free-running decoding, as train()'s
            # validation loop does, instead of a random config-dependent mix.
            prediction, _words, attentions, fixations = network(source, target, teacher_forcing_ratio=0)
            target_ids = target[0]
            logits = torch.stack(prediction).squeeze(1)[: target_ids.shape[0], :]
            loss = loss_fn(logits, target_ids)
            total_test_loss += loss.item()
            ref_ids = target_ids.tolist()
            hyp_ids = torch.argmax(logits, -1).tolist()
            _prediction = [index2word[_x] for _x in hyp_ids]
            _target = [index2word[_x] for _x in ref_ids]
            if bleu == "sacrebleu":
                # NOTE(review): sacrebleu >= 1.4 expects a list of reference
                # strings; confirm the pinned version before changing this.
                bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
                total_test_bleu += bleu_score
            elif bleu == "nltk":
                # Fresh list copies per call, in case utils.bleu mutates its inputs.
                bleu_1_score, bleu_2_score, bleu_3_score, bleu_4_score = [
                    utils.bleu(list(ref_ids), list(hyp_ids), n=n) for n in range(1, 5)
                ]
                total_test_bleu_1 += bleu_1_score
                total_test_bleu_2 += bleu_2_score
                total_test_bleu_3 += bleu_3_score
                total_test_bleu_4 += bleu_4_score
            if sentence_statistics:
                s = [index2word[x] for x in source[0].detach().cpu().tolist()]
                attentions = attentions.detach().cpu().squeeze().tolist()
                fixations = fixations.detach().cpu().squeeze().tolist()
                if bleu == "sacrebleu":
                    print(f"{bleu_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")
                elif bleu == "nltk":
                    print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")
    if bleu == "sacrebleu":
        print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu {total_test_bleu/len(test_pairs):.2f}")
    elif bleu == "nltk":
        print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu_1 {total_test_bleu_1/len(test_pairs):.2f} avg_test_bleu_2 {total_test_bleu_2/len(test_pairs):.2f} avg_test_bleu_3 {total_test_bleu_3/len(test_pairs):.2f} avg_test_bleu_4 {total_test_bleu_4/len(test_pairs):.2f}")
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=False)
def predict(corpus_name, model_weights):
    """Print the model's decoding of every validation pair of CORPUS.

    Without --model_weights the freshly initialised (untrained) network is used.
    Each output line is tab-separated: input tokens, predicted tokens,
    attention and fixations. Decoding runs for as many steps as the reference
    has tokens.
    """
    corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
    val_pairs = corpus["val"]
    weights = None
    if model_weights is not None:
        logger.info("loading model checkpoint")
        checkpoint = torch.load(model_weights, map_location=config.DEV)
        if "word2index" not in checkpoint:
            # Only checkpoints that bundle their vocabulary are supported.
            raise click.ClickException(f"checkpoint {model_weights} has no 'word2index' entry")
        weights = checkpoint["weights"]
        # The checkpoint's vocabulary replaces the one built from the corpus.
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    network = init_network(word2index)
    logger.info(f"vocab size {len(word2index)}")
    if weights is not None:
        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if "embedding" not in k}
        # make a new output layer to match the weights from the checkpoint
        # we cannot remove it like we did with the embedding layers because
        # unlike those the output layer actually contains learned parameters
        vocab_size, hidden_size = weights["par_dec.out.weight"].shape
        network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
        # actually load the parameters
        network.load_state_dict(weights, strict=False)

    val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
    network.eval()
    with torch.no_grad():
        for source, target in val_batch_iter:
            # The target only bounds the number of decoding steps; decoding is
            # free-running (teacher_forcing_ratio=0), not a random gold/predicted mix.
            prediction, _words, attentions, fixations = network(source, target, teacher_forcing_ratio=0)
            _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]
            s = [index2word[x] for x in source[0].detach().cpu().tolist()]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()
            print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")
@main.command()
@click.option("-w", "--model_weights", required=True)
@click.argument("path")
def predict_file(model_weights, path):
    """Decode every non-empty line of the file at PATH with a checkpoint.

    Lines are split into tokens on single spaces. For each sentence, one
    tab-separated line is printed: input tokens, predicted tokens, attention
    and fixations.
    """
    logger.info("loading sentences")
    with open(path, encoding="utf-8") as h:
        stripped = (line.strip() for line in h)
        sentences = [line.split(" ") for line in stripped if line]
    logger.info(f"{len(sentences)} sentences loaded")

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" not in checkpoint:
        # Only checkpoints that bundle their vocabulary are supported; the model
        # can only be fed indices from the vocabulary it was trained with.
        raise click.ClickException(f"checkpoint {model_weights} has no 'word2index' entry")
    weights = checkpoint["weights"]
    word2index = checkpoint["word2index"]
    index2word = {i: w for w, i in word2index.items()}
    network = init_network(word2index)
    logger.info(f"vocab size {len(word2index)}")
    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    # make a new output layer to match the weights from the checkpoint
    # we cannot remove it like we did with the embedding layers because
    # unlike those the output layer actually contains learned parameters
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    sentence_iter = utils.sent_iter(sentences, word2index=word2index)
    network.eval()
    with torch.no_grad():
        for sent_input in sentence_iter:
            # No target: greedy decoding until EOS or config.max_length steps.
            prediction, _words, attentions, fixations = network(sent_input)
            _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]
            s = [index2word[x] for x in sent_input[0].detach().cpu().tolist()]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()
            print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")
# Script entry point: hand control to the click command group defined above,
# which parses sys.argv and dispatches to the selected sub-command.
if __name__ == "__main__":
    main()