import logging
import os
import pathlib
import random
import sys

import click
import sacrebleu
import torch
import torch.nn as nn
import tqdm

import config
from libs import corpora
from libs import utils
from libs.fixation_generation import Network as FixNN
from libs.paraphrase_generation import (
    EncoderRNN as ParEncNN,
    AttnDecoderRNN as ParDecNN,
)


cwd = os.path.dirname(__file__)

logger = logging.getLogger("main")

'''
#qqp_paws sentences:

debug_sentences = [
    'What are the driving rules in Georgia versus Mississippi ?',
    'I want to be a child psychologist , what qualification do i need to become one ? Are there good and reputed psychology Institute or Colleges in India ?',
    'What is deep web and dark web and what are the contents of these sites ?',
    'What is difference between North Indian Brahmins and South Indian Brahmins ?',
    'Is carbon dioxide an ionic bond or a covalent bond ?',
    'How do accounts receivable and accounts payable differ ?',
    'Why did Wikipedia hide its audit history for ( superluminal ) successful speed experiments ?',
    'What makes a person simple , or inversely , complicated ?',
    '`` How do you say `` Miss you , too , `` in Spanish ? Are there multiple ways to say it ? ‘’',
    'What is the difference between dominant trait and recessive trait ?’',
    '`` What is the difference between `` seeing someone , `` `` dating someone , `` and `` having a girlfriend/boyfriend `` ? ‘’',
    'How was the Empire State building built and designed ? How is it used ?',
    'What is the sum of the square roots of first n natural number ?',
    'Why is Roman Saini not so active on Quora now a days ?',
    'If I have someone blocked on Instagram , and see their story , can they view I viewed it ?',
    'Amongst the major IT companies of India which is the best ; Wipro , Capgemini , Infosys , TCS or is Oracle the best ?',
    'How much mass does Saturn lose each year ? How much mass does it gain ?',
    'What is a cheap healthy diet , I can keep the same and eat every day ?',
    'What is it like to be beautiful ? Not just pretty or hot , but the kind of almost objective beauty that people are sometimes intimidated by ?',
    'Could someone tell of a James Ronsey ( misspelled likely ) , writer and filmmaker , probably of the British Isles ?',
    'How much pressure Is there around the core of Pluto ? is it enough to turn hydrogen/helium gas into a liquid or metallic state ?',
    'How does quality of life in Vancouver compare to that in Melbourne Or Brisbane ?',
]
'''
'''
#wiki sentences:

debug_sentences = [
    'They were there to enjoy us and they were there to pray for us .',
    'Components of elastic potential systems store mechanical energy if they are deformed when forces are applied to the system .',
    'Steam can also be used , and does not need to be pumped .',
    'The solar approach to this requirement is the use of solar panels in a conventional-powered aircraft .',
    'Daudkhali is a village in Barisal Division in the Pirojpur district in southwestern Bangladesh .',
    'Briggs later met Briggs at the 1967 Monterey Pop Festival , where Ravi Shankar was also performing , with Eric Burdon and The Animals .',
    'Brockton is approximately 25 miles northeast of Providence , Rhode Island , and 30 miles south of Boston .',
]
'''

#qqp sentences:

debug_sentences = [
    'How do I get funding for my web based startup idea ?',
    'What do intelligent people do to pass time ?',
    'Which is the best SEO Company in Delhi ?',
    'Why do you waer makeup ?',
    'How do start chatting with a girl ?',
    'What is the meaning of living life ?',
    'Why do my armpits hurt ?',
    'Why does eye color change with age ?',
    'How do you find the standard deviation of a probability distribution ? What are some examples ?',
    'How can I complete my 11 syllabus in one month ?',
    'How do I concentrate better on my studies ?',
    'Which is the best retirement plan in india ?',
    'Should I tell my best friend I love her ?',
    'Which is the best company for Appian Vagrant online job support ?',
    'How can one do for good handwriting ?',
    'What are remedies to get rid of belly fat ?',
    'What is the best way to cook precooked turkey ?',
    'What is the future of e-commerce in India ?',
    'Why do my burps taste like rotten eggs ?',
    'What is an example of chemical weathering ?',
    'What are some of the advantages and disadvantages of cyber schooling ?',
    'How can I increase traffic to my websites by Facebook ?',
    'How do I increase my patience level in life ?',
    'What are the best hospitals for treating cancer in India ?',
    'Will Jio sim work in a 3G phone ? If yes , how ?',
]

debug_sentences = [s.split(" ") for s in debug_sentences]


class Network(nn.Module):
    def __init__(
        self,
        word2index,
        embeddings,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.word2index = word2index
        self.index2word = {i: k for k, i in word2index.items()}
        self.fix_gen = FixNN(
            embedding_type="glove",
            vocab_size=len(word2index),
            embedding_dim=config.embedding_dim,
            embeddings=embeddings,
            dropout=config.fix_dropout,
            hidden_dim=config.fix_hidden_dim,
        )
        self.par_enc = ParEncNN(
            input_size=config.embedding_dim,
            hidden_size=config.par_hidden_dim,
            embeddings=embeddings,
        )
        self.par_dec = ParDecNN(
            input_size=config.embedding_dim,
            hidden_size=config.par_hidden_dim,
            output_size=len(word2index),
            embeddings=embeddings,
            dropout_p=config.par_dropout,
            max_length=config.max_length,
        )

    def forward(self, x, target=None, teacher_forcing_ratio=None):
        if teacher_forcing_ratio is None:
            teacher_forcing_ratio = config.teacher_forcing_ratio
        x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
        x2 = nn.utils.rnn.pad_sequence(x, batch_first=False)
        # pass the unpadded sequence lengths; x1 is already padded, so its rows
        # would all report the maximal length
        fixations = torch.sigmoid(self.fix_gen(x1, [len(_x) for _x in x]))

        enc_hidden = self.par_enc.initHidden().to(config.DEV)
        enc_outs = torch.zeros(config.max_length, config.par_hidden_dim, device=config.DEV)

        for ei in range(len(x2)):
            enc_out, enc_hidden = self.par_enc(x2[ei], enc_hidden)
            enc_outs[ei] += enc_out[0, 0]

        dec_in = torch.tensor([[self.word2index[config.SOS]]], device=config.DEV)  # SOS
        dec_hidden = enc_hidden
        dec_outs = []
        dec_words = []
        dec_atts = torch.zeros(config.max_length, config.max_length)

        if target is not None:  # training
            target = nn.utils.rnn.pad_sequence(target, batch_first=False)
            use_teacher_forcing = random.random() < teacher_forcing_ratio

            if use_teacher_forcing:
                for di in range(len(target)):
                    dec_out, dec_hidden, dec_att = self.par_dec(
                        dec_in, dec_hidden, enc_outs, fixations
                    )
                    dec_outs.append(dec_out)
                    dec_atts[di] = dec_att.data
                    dec_in = target[di]  # feed the ground-truth token back in

            else:
                for di in range(len(target)):
                    dec_out, dec_hidden, dec_att = self.par_dec(
                        dec_in, dec_hidden, enc_outs, fixations
                    )
                    dec_outs.append(dec_out)
                    dec_atts[di] = dec_att.data
                    topv, topi = dec_out.data.topk(1)
                    dec_words.append(self.index2word[topi.item()])

                    dec_in = topi.squeeze().detach()  # feed the prediction back in

        else:  # prediction
            for di in range(config.max_length):
                dec_out, dec_hidden, dec_att = self.par_dec(
                    dec_in, dec_hidden, enc_outs, fixations
                )
                dec_outs.append(dec_out)
                dec_atts[di] = dec_att.data
                topv, topi = dec_out.data.topk(1)
                if topi.item() == self.word2index[config.EOS]:
                    dec_words.append("<__EOS__>")
                    break
                else:
                    dec_words.append(self.index2word[topi.item()])

                dec_in = topi.squeeze().detach()

        return dec_outs, dec_words, dec_atts[: di + 1], fixations


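# A minimal usage sketch for Network (hypothetical, mirroring the loops further
# below; utils.sent_iter yields one list of index tensors per sentence):
#
#   network = init_network(word2index)
#   for input in utils.sent_iter(debug_sentences, word2index=word2index):
#       dec_outs, dec_words, dec_atts, fixations = network(input)
#       print(dec_words)

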
def load_corpus(corpus_name, splits):
    if not splits:
        return

    logger.info("loading corpus")
    if corpus_name == "msrpc":
        load_fn = corpora.load_msrpc
    elif corpus_name == "qqp":
        load_fn = corpora.load_qqp
    elif corpus_name == "wiki":
        load_fn = corpora.load_wiki
    elif corpus_name == "qqp_paws":
        load_fn = corpora.load_qqp_paws
    elif corpus_name == "qqp_kag":
        load_fn = corpora.load_qqp_kag
    elif corpus_name == "sentiment":
        load_fn = corpora.load_sentiment
    elif corpus_name == "stanford":
        load_fn = corpora.load_stanford
    elif corpus_name == "stanford_sent":
        load_fn = corpora.load_stanford_sent
    elif corpus_name == "tamil":
        load_fn = corpora.load_tamil
    elif corpus_name == "compression":
        load_fn = corpora.load_compression
    else:
        raise ValueError(f"unknown corpus {corpus_name}")

    corpus = {}
    langs = []

    if "train" in splits:
        train_pairs, train_lang = load_fn("train")
        corpus["train"] = train_pairs
        langs.append(train_lang)
    if "val" in splits:
        val_pairs, val_lang = load_fn("val")
        corpus["val"] = val_pairs
        langs.append(val_lang)
    if "test" in splits:
        test_pairs, test_lang = load_fn("test")
        corpus["test"] = test_pairs
        langs.append(test_lang)

    logger.info("creating word index")
    lang = langs[0]
    for _lang in langs[1:]:
        lang += _lang
    word2index = lang.word2index

    index2word = {i: w for w, i in word2index.items()}

    return corpus, word2index, index2word


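# Example call (a sketch; the loaders live in libs.corpora):
#
#   corpus, word2index, index2word = load_corpus("qqp", ["train", "val"])
#   train_pairs, val_pairs = corpus["train"], corpus["val"]

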
def init_network(word2index):
    logger.info("loading embeddings")
    vocabulary = sorted(word2index.keys())
    embeddings = utils.load_glove(vocabulary)

    logger.info("initializing model")
    network = Network(
        word2index=word2index,
        embeddings=embeddings,
    )
    network.to(config.DEV)

    print(f"#parameters: {sum(p.numel() for p in network.parameters())}")

    return network


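# The val/test/predict/predict_file commands below repeat the same
# checkpoint-surgery steps. A helper like this could factor them out; it is a
# sketch only and is not called anywhere in this file.
def load_checkpoint_weights(network, weights):
    """Hypothetical helper mirroring the repeated loading blocks below."""
    # drop the embedding layers; they are rebuilt from GloVe instead
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    # resize the decoder output layer to the checkpoint's vocabulary
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # load the remaining parameters
    network.load_state_dict(weights, strict=False)
    return network

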
@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
|
||
@click.option("-v", "--verbose", count=True)
|
||
@click.option("-d", "--debug", is_flag=True)
|
||
def main(verbose, debug):
|
||
if verbose == 0:
|
||
loglevel = logging.ERROR
|
||
elif verbose == 1:
|
||
loglevel = logging.WARN
|
||
elif verbose >= 2:
|
||
loglevel = logging.INFO
|
||
|
||
if debug:
|
||
loglevel = logging.DEBUG
|
||
|
||
logging.basicConfig(
|
||
format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
|
||
datefmt="%d.%m. %H:%M:%S",
|
||
level=loglevel,
|
||
)
|
||
|
||
logger.debug("arguments: %s" % str(sys.argv))
|
||
|
||
|
||
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-m", "--model_name", required=True)
@click.option("-w", "--fixation_weights", required=False)
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def train(corpus_name, model_name, fixation_weights, freeze_fixations, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
    train_pairs = corpus["train"]
    val_pairs = corpus["val"]
    network = init_network(word2index)

    model_dir = os.path.join("models", model_name)
    logger.debug("creating model dir %s", model_dir)
    pathlib.Path(model_dir).mkdir(parents=True)

    if fixation_weights is not None:
        logger.info("loading fixation prediction checkpoint")
        checkpoint = torch.load(fixation_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
        else:
            weights = checkpoint

        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")}
        network.fix_gen.load_state_dict(weights, strict=False)

    if freeze_fixations:
        logger.info("freezing fixation generation network")
        for p in network.fix_gen.parameters():
            p.requires_grad = False

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)
    #optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=1e-5)

    best_val_loss = None

    epoch = 1
    while True:
        train_batch_iter = utils.pair_iter(pairs=train_pairs, word2index=word2index, shuffle=True, shuffle_pairs=False)
        val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)
        # test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

        running_train_loss = 0
        total_train_loss = 0
        total_val_loss = 0

        if bleu == "sacrebleu":
            running_train_bleu = 0
            total_train_bleu = 0
            total_val_bleu = 0
        elif bleu == "nltk":
            running_train_bleu_1 = 0
            running_train_bleu_2 = 0
            running_train_bleu_3 = 0
            running_train_bleu_4 = 0
            total_train_bleu_1 = 0
            total_train_bleu_2 = 0
            total_train_bleu_3 = 0
            total_train_bleu_4 = 0
            total_val_bleu_1 = 0
            total_val_bleu_2 = 0
            total_val_bleu_3 = 0
            total_val_bleu_4 = 0

        network.train()
        for i, batch in enumerate(train_batch_iter, 1):
            optimizer.zero_grad()

            input, target = batch
            prediction, words, attention, fixations = network(input, target)

            loss = loss_fn(
                torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
            )

            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
            total_train_loss += loss.item()

            pred_ids = torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()
            _prediction = " ".join([index2word[_x] for _x in pred_ids])
            _target = " ".join([index2word[_x] for _x in target[0].tolist()])

            if bleu == "sacrebleu":
                # note: newer sacrebleu versions expect a list of references,
                # i.e. sentence_bleu(_prediction, [_target])
                bleu_score = sacrebleu.sentence_bleu(_prediction, _target).score
                running_train_bleu += bleu_score
                total_train_bleu += bleu_score
            elif bleu == "nltk":
                bleu_1_score = utils.bleu(target[0].tolist(), pred_ids, n=1)
                bleu_2_score = utils.bleu(target[0].tolist(), pred_ids, n=2)
                bleu_3_score = utils.bleu(target[0].tolist(), pred_ids, n=3)
                bleu_4_score = utils.bleu(target[0].tolist(), pred_ids, n=4)
                running_train_bleu_1 += bleu_1_score
                running_train_bleu_2 += bleu_2_score
                running_train_bleu_3 += bleu_3_score
                running_train_bleu_4 += bleu_4_score
                total_train_bleu_1 += bleu_1_score
                total_train_bleu_2 += bleu_2_score
                total_train_bleu_3 += bleu_3_score
                total_train_bleu_4 += bleu_4_score
                # print(target[0].tolist(), pred_ids)

            if i % 100 == 0:
                if bleu == "sacrebleu":
                    print(f"step {i} avg_train_loss {running_train_loss/100:.4f} avg_train_bleu {running_train_bleu/100:.2f}")
                elif bleu == "nltk":
                    print(f"step {i} avg_train_loss {running_train_loss/100:.4f} avg_train_bleu_1 {running_train_bleu_1/100:.2f} avg_train_bleu_2 {running_train_bleu_2/100:.2f} avg_train_bleu_3 {running_train_bleu_3/100:.2f} avg_train_bleu_4 {running_train_bleu_4/100:.2f}")

                network.eval()
                with open(os.path.join(model_dir, f"debug_{epoch}_{i}.out"), "w") as h:
                    if bleu == "sacrebleu":
                        h.write(f"# avg_train_loss {running_train_loss/100:.4f} avg_train_bleu {running_train_bleu/100:.2f}")
                        running_train_bleu = 0
                    elif bleu == "nltk":
                        h.write(f"# avg_train_loss {running_train_loss/100:.4f} avg_train_bleu_1 {running_train_bleu_1/100:.2f} avg_train_bleu_2 {running_train_bleu_2/100:.2f} avg_train_bleu_3 {running_train_bleu_3/100:.2f} avg_train_bleu_4 {running_train_bleu_4/100:.2f}")
                        running_train_bleu_1 = 0
                        running_train_bleu_2 = 0
                        running_train_bleu_3 = 0
                        running_train_bleu_4 = 0

                    running_train_loss = 0

                    h.write("\n")
                    h.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
                    h.write("\n")
                    for s, input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
                        prediction, words, attentions, fixations = network(input)
                        prediction = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
                        prediction = [index2word.get(x, "<__UNK__>") for x in prediction]
                        attentions = attentions.detach().cpu().squeeze().tolist()
                        fixations = fixations.detach().cpu().squeeze().tolist()
                        h.write(f"{s}\t{prediction}\t{attentions}\t{fixations}")
                        h.write("\n")

                network.train()

        network.eval()
        for i, batch in enumerate(val_batch_iter):
            input, target = batch
            prediction, words, attention, fixations = network(input, target, teacher_forcing_ratio=0)
            loss = loss_fn(
                torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
            )

            pred_ids = torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()
            _prediction = " ".join([index2word[_x] for _x in pred_ids])
            _target = " ".join([index2word[_x] for _x in target[0].tolist()])
            if bleu == "sacrebleu":
                bleu_score = sacrebleu.sentence_bleu(_prediction, _target).score
                total_val_bleu += bleu_score  # was missing: the epoch summary reads total_val_bleu
            elif bleu == "nltk":
                bleu_1_score = utils.bleu(target[0].tolist(), pred_ids, n=1)
                bleu_2_score = utils.bleu(target[0].tolist(), pred_ids, n=2)
                bleu_3_score = utils.bleu(target[0].tolist(), pred_ids, n=3)
                bleu_4_score = utils.bleu(target[0].tolist(), pred_ids, n=4)
                total_val_bleu_1 += bleu_1_score
                total_val_bleu_2 += bleu_2_score
                total_val_bleu_3 += bleu_3_score
                total_val_bleu_4 += bleu_4_score

            total_val_loss += loss.item()

        avg_val_loss = total_val_loss/len(val_pairs)

        if bleu == "sacrebleu":
            print(f"epoch {epoch} avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu {total_train_bleu/len(train_pairs):.2f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
        elif bleu == "nltk":
            print(f"epoch {epoch} avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu_1 {total_train_bleu_1/len(train_pairs):.2f} avg_train_bleu_2 {total_train_bleu_2/len(train_pairs):.2f} avg_train_bleu_3 {total_train_bleu_3/len(train_pairs):.2f} avg_train_bleu_4 {total_train_bleu_4/len(train_pairs):.2f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")

        with open(os.path.join(model_dir, f"debug_{epoch}_end.out"), "w") as h:
            if bleu == "sacrebleu":
                h.write(f"# avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu {total_train_bleu/len(train_pairs):.2f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
            elif bleu == "nltk":
                h.write(f"# avg_train_loss {total_train_loss/len(train_pairs):.4f} avg_val_loss {avg_val_loss:.4f} avg_train_bleu_1 {total_train_bleu_1/len(train_pairs):.2f} avg_train_bleu_2 {total_train_bleu_2/len(train_pairs):.2f} avg_train_bleu_3 {total_train_bleu_3/len(train_pairs):.2f} avg_train_bleu_4 {total_train_bleu_4/len(train_pairs):.2f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")
            h.write("\n")
            h.write("\t".join(["sentence", "prediction", "attention", "fixations"]))
            h.write("\n")
            for s, input in zip(debug_sentences, utils.sent_iter(debug_sentences, word2index=word2index)):
                prediction, words, attentions, fixations = network(input)
                prediction = torch.argmax(torch.stack(prediction).squeeze(1), -1).detach().cpu().tolist()
                prediction = [index2word.get(x, "<__UNK__>") for x in prediction]
                attentions = attentions.detach().cpu().squeeze().tolist()
                fixations = fixations.detach().cpu().squeeze().tolist()
                h.write(f"{s}\t{prediction}\t{attentions}\t{fixations}")
                h.write("\n")

        utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_{epoch}"))

        if best_val_loss is None or avg_val_loss < best_val_loss:
            if best_val_loss is not None:
                logger.info(f"{avg_val_loss} < {best_val_loss} ({avg_val_loss-best_val_loss}): new best model from epoch {epoch}")
            else:
                logger.info(f"{avg_val_loss}: first model from epoch {epoch}")

            best_val_loss = avg_val_loss
            # save_model(model, word2index, model_name + "_epoch_" + str(epoch))
            # utils.save_model(network, word2index, os.path.join(model_dir, f"{model_name}_best"))

        epoch += 1


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def val(corpus_name, model_weights, sentence_statistics, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
    val_pairs = corpus["val"]

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint does not contain a word2index mapping")

    network = init_network(word2index)

    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    # make a new output layer to match the weights from the checkpoint;
    # we cannot remove it like we did with the embedding layers because,
    # unlike those, the output layer actually contains learned parameters
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    loss_fn = nn.CrossEntropyLoss()

    val_batch_iter = utils.pair_iter(pairs=val_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    total_val_loss = 0
    if bleu == "sacrebleu":
        total_val_bleu = 0
    elif bleu == "nltk":
        total_val_bleu_1 = 0
        total_val_bleu_2 = 0
        total_val_bleu_3 = 0
        total_val_bleu_4 = 0

    network.eval()
    for i, batch in enumerate(val_batch_iter, 1):
        input, target = batch
        prediction, words, attentions, fixations = network(input, target)

        loss = loss_fn(
            torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
        )
        total_val_loss += loss.item()

        pred_ids = torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()
        _prediction = [index2word[_x] for _x in pred_ids]
        _target = [index2word[_x] for _x in target[0].tolist()]
        if bleu == "sacrebleu":
            bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
            total_val_bleu += bleu_score
        elif bleu == "nltk":
            bleu_1_score = utils.bleu(target[0].tolist(), pred_ids, n=1)
            bleu_2_score = utils.bleu(target[0].tolist(), pred_ids, n=2)
            bleu_3_score = utils.bleu(target[0].tolist(), pred_ids, n=3)
            bleu_4_score = utils.bleu(target[0].tolist(), pred_ids, n=4)
            total_val_bleu_1 += bleu_1_score
            total_val_bleu_2 += bleu_2_score
            total_val_bleu_3 += bleu_3_score
            total_val_bleu_4 += bleu_4_score

        if sentence_statistics:
            s = [index2word[x] for x in input[0].detach().cpu().tolist()]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()

            if bleu == "sacrebleu":
                print(f"{bleu_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")
            elif bleu == "nltk":
                print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{_target}\t{attentions}\t{fixations}")

    if bleu == "sacrebleu":
        print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu {total_val_bleu/len(val_pairs):.2f}")
    elif bleu == "nltk":
        print(f"avg_val_loss {total_val_loss/len(val_pairs):.4f} avg_val_bleu_1 {total_val_bleu_1/len(val_pairs):.2f} avg_val_bleu_2 {total_val_bleu_2/len(val_pairs):.2f} avg_val_bleu_3 {total_val_bleu_3/len(val_pairs):.2f} avg_val_bleu_4 {total_val_bleu_4/len(val_pairs):.2f}")


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-s", "--sentence_statistics", is_flag=True)
@click.option("-b", "--bleu", type=click.Choice(["sacrebleu", "nltk"]), required=True)
def test(corpus_name, model_weights, sentence_statistics, bleu):
    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]

    if model_weights is not None:
        logger.info("loading model checkpoint")
        checkpoint = torch.load(model_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
            word2index = checkpoint["word2index"]
            index2word = {i: w for w, i in word2index.items()}
        else:
            raise ValueError("checkpoint does not contain a word2index mapping")

    network = init_network(word2index)

    if model_weights is not None:
        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if "embedding" not in k}
        # make a new output layer to match the weights from the checkpoint;
        # we cannot remove it like we did with the embedding layers because,
        # unlike those, the output layer actually contains learned parameters
        vocab_size, hidden_size = weights["par_dec.out.weight"].shape
        network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
        # actually load the parameters
        network.load_state_dict(weights, strict=False)

    loss_fn = nn.CrossEntropyLoss()

    test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    total_test_loss = 0
    if bleu == "sacrebleu":
        total_test_bleu = 0
    elif bleu == "nltk":
        total_test_bleu_1 = 0
        total_test_bleu_2 = 0
        total_test_bleu_3 = 0
        total_test_bleu_4 = 0

    network.eval()
    for i, batch in enumerate(test_batch_iter, 1):
        input, target = batch

        prediction, words, attentions, fixations = network(input, target)

        loss = loss_fn(
            torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], target[0]
        )
        total_test_loss += loss.item()

        pred_ids = torch.argmax(torch.stack(prediction).squeeze(1)[: target[0].shape[0], :], -1).tolist()
        _prediction = [index2word[_x] for _x in pred_ids]
        _target = [index2word[_x] for _x in target[0].tolist()]
        if bleu == "sacrebleu":
            bleu_score = sacrebleu.sentence_bleu(" ".join(_prediction), " ".join(_target)).score
            total_test_bleu += bleu_score
        elif bleu == "nltk":
            bleu_1_score = utils.bleu(target[0].tolist(), pred_ids, n=1)
            bleu_2_score = utils.bleu(target[0].tolist(), pred_ids, n=2)
            bleu_3_score = utils.bleu(target[0].tolist(), pred_ids, n=3)
            bleu_4_score = utils.bleu(target[0].tolist(), pred_ids, n=4)
            total_test_bleu_1 += bleu_1_score
            total_test_bleu_2 += bleu_2_score
            total_test_bleu_3 += bleu_3_score
            total_test_bleu_4 += bleu_4_score

        if sentence_statistics:
            s = [index2word[x] for x in input[0].detach().cpu().tolist()]
            attentions = attentions.detach().cpu().squeeze().tolist()
            fixations = fixations.detach().cpu().squeeze().tolist()

            if bleu == "sacrebleu":
                print(f"{bleu_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")
            elif bleu == "nltk":
                print(f"{bleu_1_score}\t{bleu_2_score}\t{bleu_3_score}\t{bleu_4_score}\t{s}\t{_prediction}\t{attentions}\t{fixations}")

    if bleu == "sacrebleu":
        print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu {total_test_bleu/len(test_pairs):.2f}")
    elif bleu == "nltk":
        print(f"avg_test_loss {total_test_loss/len(test_pairs):.4f} avg_test_bleu_1 {total_test_bleu_1/len(test_pairs):.2f} avg_test_bleu_2 {total_test_bleu_2/len(test_pairs):.2f} avg_test_bleu_3 {total_test_bleu_3/len(test_pairs):.2f} avg_test_bleu_4 {total_test_bleu_4/len(test_pairs):.2f}")


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["wiki", "qqp", "qqp_kag", "msrpc", "qqp_paws", "sentiment", "stanford", "stanford_sent", "tamil", "compression"])),
)
@click.option("-w", "--model_weights", required=False)
def predict(corpus_name, model_weights):
    corpus, word2index, index2word = load_corpus(corpus_name, ["val"])
    test_pairs = corpus["val"]

    if model_weights is not None:
        logger.info("loading model checkpoint")
        checkpoint = torch.load(model_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
            word2index = checkpoint["word2index"]
            index2word = {i: w for w, i in word2index.items()}
        else:
            raise ValueError("checkpoint does not contain a word2index mapping")

    network = init_network(word2index)

    logger.info(f"vocab size {len(word2index)}")

    if model_weights is not None:
        # remove the embedding layer before loading
        weights = {k: v for k, v in weights.items() if "embedding" not in k}
        # make a new output layer to match the weights from the checkpoint;
        # we cannot remove it like we did with the embedding layers because,
        # unlike those, the output layer actually contains learned parameters
        vocab_size, hidden_size = weights["par_dec.out.weight"].shape
        network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
        # actually load the parameters
        network.load_state_dict(weights, strict=False)

    test_batch_iter = utils.pair_iter(pairs=test_pairs, word2index=word2index, shuffle=False, shuffle_pairs=False)

    network.eval()
    for i, batch in enumerate(test_batch_iter, 1):
        input, target = batch

        prediction, words, attentions, fixations = network(input, target)
        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]

        s = [index2word[x] for x in input[0].detach().cpu().tolist()]
        attentions = attentions.detach().cpu().squeeze().tolist()
        fixations = fixations.detach().cpu().squeeze().tolist()

        print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")


@main.command()
@click.option("-w", "--model_weights", required=True)
@click.argument("path")
def predict_file(model_weights, path):
    logger.info("loading sentences")
    sentences = []
    lang = corpora.Lang("pred")
    with open(path) as h:
        for line in h:
            line = line.strip()
            if line:
                sentence = line.split(" ")
                lang.add_sentence(sentence)
                sentences.append(sentence)
    word2index = lang.word2index
    index2word = {i: w for w, i in word2index.items()}

    logger.info(f"{len(sentences)} sentences loaded")

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint does not contain a word2index mapping")

    network = init_network(word2index)

    logger.info(f"vocab size {len(word2index)}")

    # remove the embedding layer before loading
    weights = {k: v for k, v in weights.items() if "embedding" not in k}
    # make a new output layer to match the weights from the checkpoint;
    # we cannot remove it like we did with the embedding layers because,
    # unlike those, the output layer actually contains learned parameters
    vocab_size, hidden_size = weights["par_dec.out.weight"].shape
    network.par_dec.out = nn.Linear(hidden_size, vocab_size).to(config.DEV)
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    sentence_iter = utils.sent_iter(sentences, word2index=word2index)

    network.eval()
    for i, input in enumerate(sentence_iter, 1):
        prediction, words, attentions, fixations = network(input)
        _prediction = [index2word[_x] for _x in torch.argmax(torch.stack(prediction).squeeze(1), -1).tolist()]

        s = [index2word[x] for x in input[0].detach().cpu().tolist()]
        attentions = attentions.detach().cpu().squeeze().tolist()
        fixations = fixations.detach().cpu().squeeze().tolist()

        print(f"{s}\t{_prediction}\t{attentions}\t{fixations}")


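# Typical invocations (a sketch; the script filename, model names, and paths
# are assumptions, and depending on the click version the predict_file command
# may be exposed as "predict-file"):
#
#   python main.py -vv train -c qqp -m demo_model -b sacrebleu
#   python main.py -vv val -c qqp -w models/demo_model/demo_model_1 -b nltk
#   python main.py -vv predict_file -w models/demo_model/demo_model_1 input.txt

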
if __name__ == "__main__":
|
||
main()
|