import logging
import os
import pathlib
import random
import re
import sys

import click
import sacrebleu
import torch
import torch.nn as nn
import tqdm
from sklearn.metrics import classification_report, precision_recall_fscore_support

import config
from libs import corpora
from libs import utils
from libs.fixation_generation import Network as FixNN
from libs.sentence_compression import Network as ComNN

cwd = os.path.dirname(__file__)

logger = logging.getLogger("main")


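# Combined compression model: a fixation-generation network (FixNN) predicts
# per-token fixation scores, which are passed through a sigmoid and fed into
# the sentence-compression network (ComNN) together with the padded input.
# The hyperparameters referenced below (embedding_dim, fix_dropout,
# fix_hidden_dim, sem_hidden_dim, DEV, learning_rate) are assumed to be
# defined in the local config module.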
class Network(nn.Module):
    def __init__(self, word2index, embeddings, prior):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.word2index = word2index
        self.index2word = {i: k for k, i in word2index.items()}
        self.fix_gen = FixNN(
            embedding_type="glove",
            vocab_size=len(word2index),
            embedding_dim=config.embedding_dim,
            embeddings=embeddings,
            dropout=config.fix_dropout,
            hidden_dim=config.fix_hidden_dim,
        )
        self.com_nn = ComNN(
            embeddings=embeddings,
            hidden_size=config.sem_hidden_dim,
            prior=prior,
            device=config.DEV,
        )

    def forward(self, x, target, seq_lens):
        x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
        # pad targets with -1 so padded positions can be masked out of the loss
        target = nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=-1)

        fixations = torch.sigmoid(self.fix_gen(x1, seq_lens))
        # fixations = None  # uncomment to ablate the fixation signal

        loss, pred, atts = self.com_nn(x1, target, fixations)
        return loss, pred, atts, fixations


def load_corpus(corpus_name, splits):
    if not splits:
        return

    logger.info("loading corpus")
    if corpus_name == "google":
        load_fn = corpora.load_google

    corpus = {}
    langs = []

    if "train" in splits:
        train_pairs, train_lang = load_fn("train", max_len=200)
        corpus["train"] = train_pairs
        langs.append(train_lang)
    if "val" in splits:
        val_pairs, val_lang = load_fn("val")
        corpus["val"] = val_pairs
        langs.append(val_lang)
    if "test" in splits:
        test_pairs, test_lang = load_fn("test")
        corpus["test"] = test_pairs
        langs.append(test_lang)

    logger.info("creating word index")
    # merge the per-split vocabularies into a single shared word index
    lang = langs[0]
    for _lang in langs[1:]:
        lang += _lang
    word2index = lang.word2index

    index2word = {i: w for w, i in word2index.items()}

    return corpus, word2index, index2word


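# init_network builds the combined model on config.DEV. utils.load_glove is
# assumed to return GloVe vectors for the given (sorted) vocabulary in an
# order compatible with word2index.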
def init_network(word2index, prior):
    logger.info("loading embeddings")
    vocabulary = sorted(word2index.keys())
    embeddings = utils.load_glove(vocabulary)

    logger.info("initializing model")
    network = Network(word2index=word2index, embeddings=embeddings, prior=prior)
    network.to(config.DEV)

    print(f"#parameters: {sum(p.numel() for p in network.parameters())}")

    return network


@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-v", "--verbose", count=True)
@click.option("-d", "--debug", is_flag=True)
def main(verbose, debug):
    # -v/-vv raise the log level step by step; -d forces DEBUG
    if verbose == 0:
        loglevel = logging.ERROR
    elif verbose == 1:
        loglevel = logging.WARN
    else:
        loglevel = logging.INFO

    if debug:
        loglevel = logging.DEBUG

    logging.basicConfig(
        format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
        datefmt="%d.%m. %H:%M:%S",
        level=loglevel,
    )

    logger.debug("arguments: %s" % str(sys.argv))


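# Usage sketch (entry-point file name assumed):
#   python main.py -vv train -c google -m my_model -p 0.4
#   python main.py test -c google -w models/my_model/my_model_3.tar
#   python main.py predict -c google -w models/my_model/my_model_3.tar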
@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["google"])),
)
@click.option("-m", "--model_name", required=True)
@click.option("-w", "--fixation_weights", required=False)
@click.option("-f", "--freeze_fixations", is_flag=True, default=False)
@click.option("-d", "--debug", is_flag=True, default=False)
@click.option("-p", "--prior", type=float, default=0.5)
def train(corpus_name, model_name, fixation_weights, freeze_fixations, debug, prior):
    corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
    train_pairs = corpus["train"]
    val_pairs = corpus["val"]
    network = init_network(word2index, prior)

    model_dir = os.path.join("models", model_name)
    logger.debug("creating model dir %s" % model_dir)
    pathlib.Path(model_dir).mkdir(parents=True)

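    # Optionally warm-start the fixation generator from a pretrained
    # checkpoint. A checkpoint may be a bare state dict or a dict with
    # "weights"/"word2index" entries; the embedding layer is dropped so the
    # embeddings initialized above are kept.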
    if fixation_weights is not None:
        logger.info("loading fixation prediction checkpoint")
        checkpoint = torch.load(fixation_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
        else:
            weights = checkpoint

        # remove the embedding layer before loading
        weights = {
            k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")
        }
        network.fix_gen.load_state_dict(weights, strict=False)

    if freeze_fixations:
        logger.info("freezing fixation generation network")
        for p in network.fix_gen.parameters():
            p.requires_grad = False

    optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)

    best_val_loss = None

    epoch = 1
    batch_size = 20

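    # Training runs until interrupted; whenever the validation loss improves,
    # the model is saved as models/<name>/<name>_<epoch>.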
    while True:
        train_batch_iter = utils.sent_iter(
            sents=train_pairs, word2index=word2index, batch_size=batch_size
        )
        val_batch_iter = utils.sent_iter(
            sents=val_pairs, word2index=word2index, batch_size=batch_size
        )

        total_train_loss = 0
        total_val_loss = 0

        network.train()
        for i, batch in tqdm.tqdm(
            enumerate(train_batch_iter, 1), total=len(train_pairs) // batch_size + 1
        ):
            optimizer.zero_grad()

            raw_sent, sent, target = batch
            seq_lens = [len(x) for x in sent]
            loss, prediction, attention, fixations = network(sent, target, seq_lens)

            prediction = prediction.detach().cpu().numpy()

            loss.backward()
            # clip after backward() so there are gradients to clip; the original
            # called the deprecated clip_grad_norm before backward(), which is
            # effectively a no-op
            torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=5)
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_pairs)

        val_sents = []
        val_preds = []
        val_targets = []

        network.eval()
        with torch.no_grad():
            for i, batch in tqdm.tqdm(
                enumerate(val_batch_iter, 1), total=len(val_pairs) // batch_size + 1
            ):
                raw_sent, sent, target = batch
                seq_lens = [len(x) for x in sent]
                loss, prediction, attention, fixations = network(sent, target, seq_lens)

                prediction = prediction.detach().cpu().numpy()

                # strip padding: keep the first seq_len entries of each sequence
                for j, l in enumerate(seq_lens):
                    val_sents.append(raw_sent[j][:l])
                    val_preds.append(prediction[j][:l].tolist())
                    val_targets.append(target[j][:l].tolist())

                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_pairs)

        print(
            f"epoch {epoch} train_loss {avg_train_loss:.4f} val_loss {avg_val_loss:.4f}"
        )

        print(
            classification_report(
                [x for y in val_targets for x in y],
                [x for y in val_preds for x in y],
                target_names=["not_del", "del"],
                digits=5,
            )
        )

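        # Dump original, predicted, and gold compressions for this epoch; a
        # label of 1 ("del") drops the token, so a compression keeps the
        # tokens labeled 0 ("not_del").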
        with open(f"models/{model_name}/val_original_{epoch}.txt", "w") as oh, open(
            f"models/{model_name}/val_pred_{epoch}.txt", "w"
        ) as ph, open(f"models/{model_name}/val_gold_{epoch}.txt", "w") as gh:
            for sent, preds, golds in zip(val_sents, val_preds, val_targets):
                pred_compressed = [
                    word for word, delete in zip(sent, preds) if not delete
                ]
                gold_compressed = [
                    word for word, delete in zip(sent, golds) if not delete
                ]

                oh.write(" ".join(sent))
                ph.write(" ".join(pred_compressed))
                gh.write(" ".join(gold_compressed))

                oh.write("\n")
                ph.write("\n")
                gh.write("\n")

        if best_val_loss is None or avg_val_loss < best_val_loss:
            delta = avg_val_loss - best_val_loss if best_val_loss is not None else 0.0
            best_val_loss = avg_val_loss
            print(
                f"new best model epoch {epoch} val loss {avg_val_loss:.4f} ({delta:.4f})"
            )

            utils.save_model(
                network, word2index, f"models/{model_name}/{model_name}_{epoch}"
            )

        epoch += 1


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["google"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-p", "--prior", type=float, default=0.5)
@click.option("-l", "--longest", is_flag=True)
@click.option("-s", "--shortest", is_flag=True)
@click.option("-d", "--detailed", is_flag=True)
def test(corpus_name, model_weights, prior, longest, shortest, detailed):
    if longest and shortest:
        print("longest and shortest are mutually exclusive", file=sys.stderr)
        sys.exit(1)

    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]

    # the weights path is expected to look like models/<name>/<name>_<epoch>.tar
    model_name = os.path.basename(os.path.dirname(model_weights))
    epoch = re.search(r"_(\d+)\.tar", model_weights).group(1)

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        # the original crashed here via an undefined name; fail loudly instead
        raise ValueError(f"checkpoint {model_weights} is missing 'word2index'")

    network = init_network(word2index, prior)
    network.eval()

    # load the parameters; strict=False tolerates keys missing from the
    # checkpoint (e.g. the stripped embedding layer)
    network.load_state_dict(weights, strict=False)

    total_test_loss = 0

    batch_size = 20

    test_batch_iter = utils.sent_iter(
        sents=test_pairs, word2index=word2index, batch_size=batch_size
    )

    test_sents = []
    test_preds = []
    test_targets = []

    with torch.no_grad():
        for i, batch in tqdm.tqdm(
            enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
        ):
            raw_sent, sent, target = batch
            seq_lens = [len(x) for x in sent]
            loss, prediction, attention, fixations = network(sent, target, seq_lens)

            prediction = prediction.detach().cpu().numpy()

            for j, l in enumerate(seq_lens):
                test_sents.append(raw_sent[j][:l])
                test_preds.append(prediction[j][:l].tolist())
                test_targets.append(target[j][:l].tolist())

            total_test_loss += loss.item()

    avg_test_loss = total_test_loss / len(test_pairs)

    print(f"test_loss {avg_test_loss:.4f}")

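    # Optional length-based evaluation subsets: --longest keeps sentences above
    # the mean length, --shortest the rest. Each list is filtered by its own
    # element lengths, which match the sentence lengths by construction, so the
    # three lists stay aligned.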
    if longest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = [s for s in test_sents if len(s) > avg_len]
        test_preds = [p for p in test_preds if len(p) > avg_len]
        test_targets = [t for t in test_targets if len(t) > avg_len]
    elif shortest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = [s for s in test_sents if len(s) <= avg_len]
        test_preds = [p for p in test_preds if len(p) <= avg_len]
        test_targets = [t for t in test_targets if len(t) <= avg_len]

    if detailed:
        # per-sentence weighted F1 next to the raw sequences
        for test_sent, test_target, test_pred in zip(test_sents, test_targets, test_preds):
            f1 = precision_recall_fscore_support(
                test_target, test_pred, average="weighted"
            )[2]
            print(f1, test_sent, test_target, test_pred)
    else:
        print(
            classification_report(
                [x for y in test_targets for x in y],
                [x for y in test_preds for x in y],
                target_names=["not_del", "del"],
                digits=5,
            )
        )

    with open(f"models/{model_name}/test_original_{epoch}.txt", "w") as oh, open(
        f"models/{model_name}/test_pred_{epoch}.txt", "w"
    ) as ph, open(f"models/{model_name}/test_gold_{epoch}.txt", "w") as gh:
        for sent, preds, golds in zip(test_sents, test_preds, test_targets):
            pred_compressed = [word for word, delete in zip(sent, preds) if not delete]
            gold_compressed = [word for word, delete in zip(sent, golds) if not delete]

            oh.write(" ".join(sent))
            ph.write(" ".join(pred_compressed))
            gh.write(" ".join(gold_compressed))

            oh.write("\n")
            ph.write("\n")
            gh.write("\n")


@main.command()
@click.option(
    "-c",
    "--corpus",
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["google"])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-p", "--prior", type=float, default=0.5)
@click.option("-l", "--longest", is_flag=True)
@click.option("-s", "--shortest", is_flag=True)
def predict(corpus_name, model_weights, prior, longest, shortest):
    if longest and shortest:
        print("longest and shortest are mutually exclusive", file=sys.stderr)
        sys.exit(1)

    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]

    model_name = os.path.basename(os.path.dirname(model_weights))
    epoch = re.search(r"_(\d+)\.tar", model_weights).group(1)

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        # the original crashed here via an undefined name; fail loudly instead
        raise ValueError(f"checkpoint {model_weights} is missing 'word2index'")

    network = init_network(word2index, prior)
    network.eval()

    # load the parameters; strict=False tolerates keys missing from the
    # checkpoint (e.g. the stripped embedding layer)
    network.load_state_dict(weights, strict=False)

    total_test_loss = 0

    batch_size = 20

    test_batch_iter = utils.sent_iter(
        sents=test_pairs, word2index=word2index, batch_size=batch_size
    )

    test_sents = []
    test_preds = []
    test_attentions = []
    test_fixations = []

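    # In addition to predictions, collect the attention matrices and fixation
    # scores per sentence so they can be printed alongside the output.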
    with torch.no_grad():
        for i, batch in tqdm.tqdm(
            enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
        ):
            raw_sent, sent, target = batch
            seq_lens = [len(x) for x in sent]
            loss, prediction, attention, fixations = network(sent, target, seq_lens)

            prediction = prediction.detach().cpu().numpy()
            attention = attention.detach().cpu().numpy()
            if fixations is not None:
                fixations = fixations.detach().cpu().numpy()

            for j, l in enumerate(seq_lens):
                test_sents.append(raw_sent[j][:l])
                test_preds.append(prediction[j][:l].tolist())
                test_attentions.append(attention[j][:l].tolist())
                if fixations is not None:
                    test_fixations.append(fixations[j][:l].tolist())
                else:
                    test_fixations.append([])

            total_test_loss += loss.item()

    avg_test_loss = total_test_loss / len(test_pairs)

    if longest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = [s for s in test_sents if len(s) > avg_len]
        test_preds = [p for p in test_preds if len(p) > avg_len]
        test_attentions = [a for a in test_attentions if len(a) > avg_len]
        # caveat: if fixations were disabled, test_fixations holds empty lists
        # and length filtering would drop them out of step with the other lists
        test_fixations = [f for f in test_fixations if len(f) > avg_len]
    elif shortest:
        avg_len = sum(len(s) for s in test_sents) / len(test_sents)
        test_sents = [s for s in test_sents if len(s) <= avg_len]
        test_preds = [p for p in test_preds if len(p) <= avg_len]
        test_attentions = [a for a in test_attentions if len(a) <= avg_len]
        test_fixations = [f for f in test_fixations if len(f) <= avg_len]

print(f"sentence\tprediction\tattentions\tfixations")
|
|
for s, p, a, f in zip(test_sents, test_preds, test_attentions, test_fixations):
|
|
a = [x[:len(a)] for x in a]
|
|
print(f"{s}\t{p}\t{a}\t{f}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|