import logging
import os
import pathlib
import random
import re
import sys

import click
import sacrebleu
import torch
import torch.nn as nn
import tqdm
from sklearn.metrics import classification_report, precision_recall_fscore_support

import config
from libs import corpora
from libs import utils
from libs.fixation_generation import Network as FixNN
from libs.sentence_compression import Network as ComNN

cwd = os.path.dirname(__file__)
logger = logging.getLogger("main")
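
# This script exposes three click subcommands -- train, test, and predict --
# defined below. Each one loads the "google" sentence-compression corpus via
# libs.corpora and GloVe embeddings via libs.utils before building the model.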

class Network(nn.Module):
    """Joint model: the fixation-generation network's (sigmoid) outputs are
    passed to the sentence-compression network together with the input."""

    def __init__(
        self, word2index, embeddings, prior,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.word2index = word2index
        self.index2word = {i: k for k, i in word2index.items()}
        self.fix_gen = FixNN(
            embedding_type="glove",
            vocab_size=len(word2index),
            embedding_dim=config.embedding_dim,
            embeddings=embeddings,
            dropout=config.fix_dropout,
            hidden_dim=config.fix_hidden_dim,
        )
        self.com_nn = ComNN(
            embeddings=embeddings, hidden_size=config.sem_hidden_dim, prior=prior, device=config.DEV
        )

    def forward(self, x, target, seq_lens):
        x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
        # pad targets with -1 so padded positions are distinguishable from real labels
        target = nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=-1)
        fixations = torch.sigmoid(self.fix_gen(x1, seq_lens))
        # fixations = None
        loss, pred, atts = self.com_nn(x1, target, fixations)
        return loss, pred, atts, fixations

def load_corpus(corpus_name, splits):
    if not splits:
        return
    logger.info("loading corpus")
    if corpus_name == "google":
        load_fn = corpora.load_google
    corpus = {}
    langs = []
    if "train" in splits:
        train_pairs, train_lang = load_fn("train", max_len=200)
        corpus["train"] = train_pairs
        langs.append(train_lang)
    if "val" in splits:
        val_pairs, val_lang = load_fn("val")
        corpus["val"] = val_pairs
        langs.append(val_lang)
    if "test" in splits:
        test_pairs, test_lang = load_fn("test")
        corpus["test"] = test_pairs
        langs.append(test_lang)
    logger.info("creating word index")
    lang = langs[0]
    for _lang in langs[1:]:
        lang += _lang
    word2index = lang.word2index
    index2word = {i: w for w, i in word2index.items()}
    return corpus, word2index, index2word
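
# load_corpus returns (corpus dict keyed by split, word2index, index2word);
# the vocabularies of all requested splits are merged into a single index.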

def init_network(word2index, prior):
    logger.info("loading embeddings")
    vocabulary = sorted(word2index.keys())
    embeddings = utils.load_glove(vocabulary)
    logger.info("initializing model")
    network = Network(word2index=word2index, embeddings=embeddings, prior=prior)
    network.to(config.DEV)
    print(f"#parameters: {sum(p.numel() for p in network.parameters())}")
    return network

@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-v", "--verbose", count=True)
@click.option("-d", "--debug", is_flag=True)
def main(verbose, debug):
    if verbose == 0:
        loglevel = logging.ERROR
    elif verbose == 1:
        loglevel = logging.WARN
    elif verbose >= 2:
        loglevel = logging.INFO
    if debug:
        loglevel = logging.DEBUG
    logging.basicConfig(
        format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
        datefmt="%d.%m. %H:%M:%S",
        level=loglevel,
    )
    logger.debug("arguments: %s" % str(sys.argv))

@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["google",])),
)
@click.option("-m", "--model_name", required=True)
@click.option("-w", "--fixation_weights", required=False)
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
@click.option("-d", "--debug", is_flag=True, default=False)
@click.option("-p", "--prior", type=float, default=.5)
def train(corpus_name, model_name, fixation_weights, freeze_fixations, debug, prior):
    """Train the joint model; validation outputs and checkpoints are written
    to models/<model_name>/. Training runs until interrupted."""
    corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
    train_pairs = corpus["train"]
    val_pairs = corpus["val"]
    network = init_network(word2index, prior)
    model_dir = os.path.join("models", model_name)
    logger.debug("creating model dir %s" % model_dir)
    pathlib.Path(model_dir).mkdir(parents=True)
    if fixation_weights is not None:
        logger.info("loading fixation prediction checkpoint")
        checkpoint = torch.load(fixation_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
        else:
            weights = checkpoint
        # remove the embedding layer before loading
        weights = {
            k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")
        }
        network.fix_gen.load_state_dict(weights, strict=False)
    if freeze_fixations:
        logger.info("freezing fixation generation network")
        for p in network.fix_gen.parameters():
            p.requires_grad = False
    optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)
    best_val_loss = None
    epoch = 1
    batch_size = 20
    while True:
        train_batch_iter = utils.sent_iter(
            sents=train_pairs, word2index=word2index, batch_size=batch_size
        )
        val_batch_iter = utils.sent_iter(
            sents=val_pairs, word2index=word2index, batch_size=batch_size
        )
        total_train_loss = 0
        total_val_loss = 0
        network.train()
        for i, batch in tqdm.tqdm(
            enumerate(train_batch_iter, 1), total=len(train_pairs) // batch_size + 1
        ):
            optimizer.zero_grad()
            raw_sent, sent, target = batch
            seq_lens = [len(x) for x in sent]
            loss, prediction, attention, fixations = network(sent, target, seq_lens)
            prediction = prediction.detach().cpu().numpy()
            loss.backward()
            # clip gradients after the backward pass, before the optimizer step
            torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=5)
            optimizer.step()
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_pairs)
        val_sents = []
        val_preds = []
        val_targets = []
        network.eval()
        for i, batch in tqdm.tqdm(
            enumerate(val_batch_iter), total=len(val_pairs) // batch_size + 1
        ):
            raw_sent, sent, target = batch
            seq_lens = [len(x) for x in sent]
            loss, prediction, attention, fixations = network(sent, target, seq_lens)
            prediction = prediction.detach().cpu().numpy()
            for j, l in enumerate(seq_lens):
                val_sents.append(raw_sent[j][:l])
                val_preds.append(prediction[j][:l].tolist())
                val_targets.append(target[j][:l].tolist())
            total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(val_pairs)
        print(
            f"epoch {epoch} train_loss {avg_train_loss:.4f} val_loss {avg_val_loss:.4f}"
        )
        print(
            classification_report(
                [x for y in val_targets for x in y],
                [x for y in val_preds for x in y],
                target_names=["not_del", "del"],
                digits=5,
            )
        )
        with open(f"models/{model_name}/val_original_{epoch}.txt", "w") as oh, open(
            f"models/{model_name}/val_pred_{epoch}.txt", "w"
        ) as ph, open(f"models/{model_name}/val_gold_{epoch}.txt", "w") as gh:
            for sent, preds, golds in zip(val_sents, val_preds, val_targets):
                pred_compressed = [
                    word for word, delete in zip(sent, preds) if not delete
                ]
                gold_compressed = [
                    word for word, delete in zip(sent, golds) if not delete
                ]
                oh.write(" ".join(sent))
                ph.write(" ".join(pred_compressed))
                gh.write(" ".join(gold_compressed))
                oh.write("\n")
                ph.write("\n")
                gh.write("\n")
        if best_val_loss is None or avg_val_loss < best_val_loss:
            delta = avg_val_loss - best_val_loss if best_val_loss is not None else 0.0
            best_val_loss = avg_val_loss
            print(
                f"new best model epoch {epoch} val loss {avg_val_loss:.4f} ({delta:.4f})"
            )
            utils.save_model(
                network, word2index, f"models/{model_name}/{model_name}_{epoch}"
            )
        epoch += 1
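
# Illustrative invocations (paths and model names are placeholders, not files
# shipped with the repo):
#
#   python main.py -vv train -c google -m my_model -p 0.5
#   python main.py -vv train -c google -m my_model_gaze \
#       -w path/to/fixation_checkpoint.tar -f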

@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["google",])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-p", "--prior", type=float, default=.5)
@click.option("-l", "--longest", is_flag=True)
@click.option("-s", "--shortest", is_flag=True)
@click.option("-d", "--detailed", is_flag=True)
def test(corpus_name, model_weights, prior, longest, shortest, detailed):
    """Evaluate a trained checkpoint on the test split and write the
    original/predicted/gold compressions to models/<model_name>/."""
    if longest and shortest:
        print("longest and shortest are mutually exclusive", file=sys.stderr)
        sys.exit()
    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]
    model_name = os.path.basename(os.path.dirname(model_weights))
    epoch = re.search(r"_(\d+)\.tar", model_weights).group(1)
    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint does not contain a word2index mapping")
    network = init_network(word2index, prior)
    network.eval()
    # remove the embedding layer before loading
    # weights = {k: v for k, v in weights.items() if not "embedding" in k}
    # actually load the parameters
    network.load_state_dict(weights, strict=False)
    total_test_loss = 0
    batch_size = 20
    test_batch_iter = utils.sent_iter(
        sents=test_pairs, word2index=word2index, batch_size=batch_size
    )
    test_sents = []
    test_preds = []
    test_targets = []
    for i, batch in tqdm.tqdm(
        enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
    ):
        raw_sent, sent, target = batch
        seq_lens = [len(x) for x in sent]
        loss, prediction, attention, fixations = network(sent, target, seq_lens)
        prediction = prediction.detach().cpu().numpy()
        for j, l in enumerate(seq_lens):
            test_sents.append(raw_sent[j][:l])
            test_preds.append(prediction[j][:l].tolist())
            test_targets.append(target[j][:l].tolist())
        total_test_loss += loss.item()
    avg_test_loss = total_test_loss / len(test_pairs)
    print(f"test_loss {avg_test_loss:.4f}")
    if longest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = list(filter(lambda x: len(x) > avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) > avg_len, test_preds))
        test_targets = list(filter(lambda x: len(x) > avg_len, test_targets))
    elif shortest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = list(filter(lambda x: len(x) <= avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) <= avg_len, test_preds))
        test_targets = list(filter(lambda x: len(x) <= avg_len, test_targets))
    if detailed:
        for test_sent, test_target, test_pred in zip(test_sents, test_targets, test_preds):
            print(
                precision_recall_fscore_support(test_target, test_pred, average="weighted")[2],
                test_sent,
                test_target,
                test_pred,
            )
    else:
        print(
            classification_report(
                [x for y in test_targets for x in y],
                [x for y in test_preds for x in y],
                target_names=["not_del", "del"],
                digits=5,
            )
        )
    with open(f"models/{model_name}/test_original_{epoch}.txt", "w") as oh, open(
        f"models/{model_name}/test_pred_{epoch}.txt", "w"
    ) as ph, open(f"models/{model_name}/test_gold_{epoch}.txt", "w") as gh:
        for sent, preds, golds in zip(test_sents, test_preds, test_targets):
            pred_compressed = [word for word, delete in zip(sent, preds) if not delete]
            gold_compressed = [word for word, delete in zip(sent, golds) if not delete]
            oh.write(" ".join(sent))
            ph.write(" ".join(pred_compressed))
            gh.write(" ".join(gold_compressed))
            oh.write("\n")
            ph.write("\n")
            gh.write("\n")
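
# Illustrative invocation (the weights path is a placeholder; the epoch number is
# parsed from an "_<epoch>.tar" suffix, so this assumes utils.save_model stores
# checkpoints with that extension):
#
#   python main.py -vv test -c google -w models/my_model/my_model_3.tar --detailed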

@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["google",])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-p", "--prior", type=float, default=.5)
@click.option("-l", "--longest", is_flag=True)
@click.option("-s", "--shortest", is_flag=True)
def predict(corpus_name, model_weights, prior, longest, shortest):
    """Run a trained checkpoint on the test split and print one tab-separated
    row per sentence: tokens, predictions, attention weights, fixation values."""
    if longest and shortest:
        print("longest and shortest are mutually exclusive", file=sys.stderr)
        sys.exit()
    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]
    model_name = os.path.basename(os.path.dirname(model_weights))
    epoch = re.search(r"_(\d+)\.tar", model_weights).group(1)
    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        raise ValueError("checkpoint does not contain a word2index mapping")
    network = init_network(word2index, prior)
    network.eval()
    # remove the embedding layer before loading
    # weights = {k: v for k, v in weights.items() if not "embedding" in k}
    # actually load the parameters
    network.load_state_dict(weights, strict=False)
    total_test_loss = 0
    batch_size = 20
    test_batch_iter = utils.sent_iter(
        sents=test_pairs, word2index=word2index, batch_size=batch_size
    )
    test_sents = []
    test_preds = []
    test_attentions = []
    test_fixations = []
    for i, batch in tqdm.tqdm(
        enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
    ):
        raw_sent, sent, target = batch
        seq_lens = [len(x) for x in sent]
        loss, prediction, attention, fixations = network(sent, target, seq_lens)
        prediction = prediction.detach().cpu().numpy()
        attention = attention.detach().cpu().numpy()
        if fixations is not None:
            fixations = fixations.detach().cpu().numpy()
        for j, l in enumerate(seq_lens):
            test_sents.append(raw_sent[j][:l])
            test_preds.append(prediction[j][:l].tolist())
            test_attentions.append(attention[j][:l].tolist())
            if fixations is not None:
                test_fixations.append(fixations[j][:l].tolist())
            else:
                test_fixations.append([])
        total_test_loss += loss.item()
    avg_test_loss = total_test_loss / len(test_pairs)
    if longest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = list(filter(lambda x: len(x) > avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) > avg_len, test_preds))
        test_attentions = list(filter(lambda x: len(x) > avg_len, test_attentions))
        test_fixations = list(filter(lambda x: len(x) > avg_len, test_fixations))
    elif shortest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = list(filter(lambda x: len(x) <= avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) <= avg_len, test_preds))
        test_attentions = list(filter(lambda x: len(x) <= avg_len, test_attentions))
        test_fixations = list(filter(lambda x: len(x) <= avg_len, test_fixations))
    print("sentence\tprediction\tattentions\tfixations")
    for s, p, a, f in zip(test_sents, test_preds, test_attentions, test_fixations):
        # truncate each attention row to the sentence length
        a = [x[:len(a)] for x in a]
        print(f"{s}\t{p}\t{a}\t{f}")

if __name__ == "__main__":
    main()