import logging import os import pathlib import random import re import sys import click import sacrebleu import torch import torch.nn as nn import tqdm import config from libs import corpora from libs import utils from libs.fixation_generation import Network as FixNN from libs.sentence_compression import Network as ComNN from sklearn.metrics import classification_report, precision_recall_fscore_support cwd = os.path.dirname(__file__) logger = logging.getLogger("main") class Network(nn.Module): def __init__( self, word2index, embeddings, prior, ): super().__init__() self.logger = logging.getLogger(f"{__name__}") self.word2index = word2index self.index2word = {i: k for k, i in word2index.items()} self.fix_gen = FixNN( embedding_type="glove", vocab_size=len(word2index), embedding_dim=config.embedding_dim, embeddings=embeddings, dropout=config.fix_dropout, hidden_dim=config.fix_hidden_dim, ) self.com_nn = ComNN( embeddings=embeddings, hidden_size=config.sem_hidden_dim, prior=prior, device=config.DEV ) def forward(self, x, target, seq_lens): x1 = nn.utils.rnn.pad_sequence(x, batch_first=True) target = nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=-1) fixations = torch.sigmoid(self.fix_gen(x1, seq_lens)) # fixations = None loss, pred, atts = self.com_nn(x1, target, fixations) return loss, pred, atts, fixations def load_corpus(corpus_name, splits): if not splits: return logger.info("loading corpus") if corpus_name == "google": load_fn = corpora.load_google corpus = {} langs = [] if "train" in splits: train_pairs, train_lang = load_fn("train", max_len=200) corpus["train"] = train_pairs langs.append(train_lang) if "val" in splits: val_pairs, val_lang = load_fn("val") corpus["val"] = val_pairs langs.append(val_lang) if "test" in splits: test_pairs, test_lang = load_fn("test") corpus["test"] = test_pairs langs.append(test_lang) logger.info("creating word index") lang = langs[0] for _lang in langs[1:]: lang += _lang word2index = lang.word2index index2word = {i: w for w, i in word2index.items()} return corpus, word2index, index2word def init_network(word2index, prior): logger.info("loading embeddings") vocabulary = sorted(word2index.keys()) embeddings = utils.load_glove(vocabulary) logger.info("initializing model") network = Network(word2index=word2index, embeddings=embeddings, prior=prior) network.to(config.DEV) print(f"#parameters: {sum(p.numel() for p in network.parameters())}") return network @click.group(context_settings=dict(help_option_names=["-h", "--help"])) @click.option("-v", "--verbose", count=True) @click.option("-d", "--debug", is_flag=True) def main(verbose, debug): if verbose == 0: loglevel = logging.ERROR elif verbose == 1: loglevel = logging.WARN elif verbose >= 2: loglevel = logging.INFO if debug: loglevel = logging.DEBUG logging.basicConfig( format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s", datefmt="%d.%m. %H:%M:%S", level=loglevel, ) logger.debug("arguments: %s" % str(sys.argv)) @main.command() @click.option( "-c", "--corpus", "corpus_name", required=True, type=click.Choice(sorted(["google",])), ) @click.option("-m", "--model_name", required=True) @click.option("-w", "--fixation_weights", required=False) @click.option("-f", "--freeze_fixations", is_flag=True, default=False) @click.option("-d", "--debug", is_flag=True, default=False) @click.option("-p", "--prior", type=float, default=.5) def train(corpus_name, model_name, fixation_weights, freeze_fixations, debug, prior): corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"]) train_pairs = corpus["train"] val_pairs = corpus["val"] network = init_network(word2index, prior) model_dir = os.path.join("models", model_name) logger.debug("creating model dir %s" % model_dir) pathlib.Path(model_dir).mkdir(parents=True) if fixation_weights is not None: logger.info("loading fixation prediction checkpoint") checkpoint = torch.load(fixation_weights, map_location=config.DEV) if "word2index" in checkpoint: weights = checkpoint["weights"] else: weights = checkpoint # remove the embedding layer before loading weights = { k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer") } network.fix_gen.load_state_dict(weights, strict=False) if freeze_fixations: logger.info("freezing fixation generation network") for p in network.fix_gen.parameters(): p.requires_grad = False optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate) best_val_loss = None epoch = 1 batch_size = 20 while True: train_batch_iter = utils.sent_iter( sents=train_pairs, word2index=word2index, batch_size=batch_size ) val_batch_iter = utils.sent_iter( sents=val_pairs, word2index=word2index, batch_size=batch_size ) total_train_loss = 0 total_val_loss = 0 network.train() for i, batch in tqdm.tqdm( enumerate(train_batch_iter, 1), total=len(train_pairs) // batch_size + 1 ): optimizer.zero_grad() raw_sent, sent, target = batch seq_lens = [len(x) for x in sent] loss, prediction, attention, fixations = network(sent, target, seq_lens) prediction = prediction.detach().cpu().numpy() torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5) loss.backward() optimizer.step() total_train_loss += loss.item() avg_train_loss = total_train_loss / len(train_pairs) val_sents = [] val_preds = [] val_targets = [] network.eval() for i, batch in tqdm.tqdm( enumerate(val_batch_iter), total=len(val_pairs) // batch_size + 1 ): raw_sent, sent, target = batch seq_lens = [len(x) for x in sent] loss, prediction, attention, fixations = network(sent, target, seq_lens) prediction = prediction.detach().cpu().numpy() for i, l in enumerate(seq_lens): val_sents.append(raw_sent[i][:l]) val_preds.append(prediction[i][:l].tolist()) val_targets.append(target[i][:l].tolist()) total_val_loss += loss.item() avg_val_loss = total_val_loss / len(val_pairs) print( f"epoch {epoch} train_loss {avg_train_loss:.4f} val_loss {avg_val_loss:.4f}" ) print( classification_report( [x for y in val_targets for x in y], [x for y in val_preds for x in y], target_names=["not_del", "del"], digits=5, ) ) with open(f"models/{model_name}/val_original_{epoch}.txt", "w") as oh, open( f"models/{model_name}/val_pred_{epoch}.txt", "w" ) as ph, open(f"models/{model_name}/val_gold_{epoch}.txt", "w") as gh: for sent, preds, golds in zip(val_sents, val_preds, val_targets): pred_compressed = [ word for word, delete in zip(sent, preds) if not delete ] gold_compressed = [ word for word, delete in zip(sent, golds) if not delete ] oh.write(" ".join(sent)) ph.write(" ".join(pred_compressed)) gh.write(" ".join(gold_compressed)) oh.write("\n") ph.write("\n") gh.write("\n") if best_val_loss is None or avg_val_loss < best_val_loss: delta = avg_val_loss - best_val_loss if best_val_loss is not None else 0.0 best_val_loss = avg_val_loss print( f"new best model epoch {epoch} val loss {avg_val_loss:.4f} ({delta:.4f})" ) utils.save_model( network, word2index, f"models/{model_name}/{model_name}_{epoch}" ) epoch += 1 @main.command() @click.option( "-c", "--corpus", "corpus_name", required=True, type=click.Choice(sorted(["google",])), ) @click.option("-w", "--model_weights", required=True) @click.option("-p", "--prior", type=float, default=.5) @click.option("-l", "--longest", is_flag=True) @click.option("-s", "--shortest", is_flag=True) @click.option("-d", "--detailed", is_flag=True) def test(corpus_name, model_weights, prior, longest, shortest, detailed): if longest and shortest: print("longest and shortest are mutually exclusive", file=sys.stderr) sys.exit() corpus, word2index, index2word = load_corpus(corpus_name, ["test"]) test_pairs = corpus["test"] model_name = os.path.basename(os.path.dirname(model_weights)) epoch = re.search("_(\d+).tar", model_weights).group(1) logger.info("loading model checkpoint") checkpoint = torch.load(model_weights, map_location=config.DEV) if "word2index" in checkpoint: weights = checkpoint["weights"] word2index = checkpoint["word2index"] index2word = {i: w for w, i in word2index.items()} else: asdf network = init_network(word2index, prior) network.eval() # remove the embedding layer before loading # weights = {k: v for k, v in weights.items() if not "embedding" in k} # actually load the parameters network.load_state_dict(weights, strict=False) total_test_loss = 0 batch_size = 20 test_batch_iter = utils.sent_iter( sents=test_pairs, word2index=word2index, batch_size=batch_size ) test_sents = [] test_preds = [] test_targets = [] for i, batch in tqdm.tqdm( enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1 ): raw_sent, sent, target = batch seq_lens = [len(x) for x in sent] loss, prediction, attention, fixations = network(sent, target, seq_lens) prediction = prediction.detach().cpu().numpy() for i, l in enumerate( seq_lens ): test_sents.append(raw_sent[i][:l]) test_preds.append(prediction[i][:l].tolist()) test_targets.append(target[i][:l].tolist()) total_test_loss += loss.item() avg_test_loss = total_test_loss / len(test_pairs) print(f"test_loss {avg_test_loss:.4f}") if longest: avg_len = sum(len(s) for s in test_sents)/len(test_sents) test_sents = list(filter(lambda x: len(x) > avg_len, test_sents)) test_preds = list(filter(lambda x: len(x) > avg_len, test_preds)) test_targets = list(filter(lambda x: len(x) > avg_len, test_targets)) elif shortest: avg_len = sum(len(s) for s in test_sents)/len(test_sents) test_sents = list(filter(lambda x: len(x) <= avg_len, test_sents)) test_preds = list(filter(lambda x: len(x) <= avg_len, test_preds)) test_targets = list(filter(lambda x: len(x) <= avg_len, test_targets)) if detailed: for test_sent, test_target, test_pred in zip(test_sents, test_targets, test_preds): print(precision_recall_fscore_support(test_target, test_pred, average="weighted")[2], test_sent, test_target, test_pred) else: print( classification_report( [x for y in test_targets for x in y], [x for y in test_preds for x in y], target_names=["not_del", "del"], digits=5, ) ) with open(f"models/{model_name}/test_original_{epoch}.txt", "w") as oh, open( f"models/{model_name}/test_pred_{epoch}.txt", "w" ) as ph, open(f"models/{model_name}/test_gold_{epoch}.txt", "w") as gh: for sent, preds, golds in zip(test_sents, test_preds, test_targets): pred_compressed = [word for word, delete in zip(sent, preds) if not delete] gold_compressed = [word for word, delete in zip(sent, golds) if not delete] oh.write(" ".join(sent)) ph.write(" ".join(pred_compressed)) gh.write(" ".join(gold_compressed)) oh.write("\n") ph.write("\n") gh.write("\n") @main.command() @click.option( "-c", "--corpus", "corpus_name", required=True, type=click.Choice(sorted(["google",])), ) @click.option("-w", "--model_weights", required=True) @click.option("-p", "--prior", type=float, default=.5) @click.option("-l", "--longest", is_flag=True) @click.option("-s", "--shortest", is_flag=True) def predict(corpus_name, model_weights, prior, longest, shortest): if longest and shortest: print("longest and shortest are mutually exclusive", file=sys.stderr) sys.exit() corpus, word2index, index2word = load_corpus(corpus_name, ["test"]) test_pairs = corpus["test"] model_name = os.path.basename(os.path.dirname(model_weights)) epoch = re.search("_(\d+).tar", model_weights).group(1) logger.info("loading model checkpoint") checkpoint = torch.load(model_weights, map_location=config.DEV) if "word2index" in checkpoint: weights = checkpoint["weights"] word2index = checkpoint["word2index"] index2word = {i: w for w, i in word2index.items()} else: asdf network = init_network(word2index, prior) network.eval() # remove the embedding layer before loading # weights = {k: v for k, v in weights.items() if not "embedding" in k} # actually load the parameters network.load_state_dict(weights, strict=False) total_test_loss = 0 batch_size = 20 test_batch_iter = utils.sent_iter( sents=test_pairs, word2index=word2index, batch_size=batch_size ) test_sents = [] test_preds = [] test_attentions = [] test_fixations = [] for i, batch in tqdm.tqdm( enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1 ): raw_sent, sent, target = batch seq_lens = [len(x) for x in sent] loss, prediction, attention, fixations = network(sent, target, seq_lens) prediction = prediction.detach().cpu().numpy() attention = attention.detach().cpu().numpy() if fixations is not None: fixations = fixations.detach().cpu().numpy() for i, l in enumerate( seq_lens ): test_sents.append(raw_sent[i][:l]) test_preds.append(prediction[i][:l].tolist()) test_attentions.append(attention[i][:l].tolist()) if fixations is not None: test_fixations.append(fixations[i][:l].tolist()) else: test_fixations.append([]) total_test_loss += loss.item() avg_test_loss = total_test_loss / len(test_pairs) if longest: avg_len = sum(len(s) for s in test_sents)/len(test_sents) test_sents = list(filter(lambda x: len(x) > avg_len, test_sents)) test_preds = list(filter(lambda x: len(x) > avg_len, test_preds)) test_attentions = list(filter(lambda x: len(x) > avg_len, test_attentions)) test_fixations = list(filter(lambda x: len(x) > avg_len, test_fixations)) elif shortest: avg_len = sum(len(s) for s in test_sents)/len(test_sents) test_sents = list(filter(lambda x: len(x) <= avg_len, test_sents)) test_preds = list(filter(lambda x: len(x) <= avg_len, test_preds)) test_attentions = list(filter(lambda x: len(x) <= avg_len, test_attentions)) test_fixations = list(filter(lambda x: len(x) <= avg_len, test_fixations)) print(f"sentence\tprediction\tattentions\tfixations") for s, p, a, f in zip(test_sents, test_preds, test_attentions, test_fixations): a = [x[:len(a)] for x in a] print(f"{s}\t{p}\t{a}\t{f}") if __name__ == "__main__": main()