import json
import logging
import math
import os
import random
import re
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from nltk.translate.bleu_score import sentence_bleu
import numpy as np
import torch
import torch.nn as nn

import config

plt.switch_backend("agg")


def load_glove(vocabulary):
    """Return a (len(vocabulary), 300) tensor of GloVe embeddings, cached on disk."""
    logger = logging.getLogger(f"{__name__}.load_glove")
    logger.info("loading embeddings")
    try:
        with open("glove.cache") as h:
            cache = json.load(h)
    except FileNotFoundError:
        logger.info("cache doesn't exist")
        cache = {}
        cache[config.PAD] = [0] * 300
        cache[config.SOS] = [0] * 300
        cache[config.EOS] = [0] * 300
        cache[config.UNK] = [0] * 300
        cache[config.NOFIX] = [0] * 300
    else:
        logger.info("cache found")

    cache_miss = False
    if not set(vocabulary) <= set(cache):
        cache_miss = True
        logger.warning("cache miss, loading full embeddings")
        data = {}
        with open("glove.840B.300d.txt") as h:
            for line in h:
                word, *emb = line.strip().split()
                try:
                    data[word] = [float(x) for x in emb]
                except ValueError:
                    # skip malformed lines in the raw GloVe file
                    continue
        logger.info("finished loading full embeddings")
        for word in vocabulary:
            try:
                cache[word] = data[word]
            except KeyError:
                # out-of-vocabulary words get a zero vector
                cache[word] = [0] * 300
        logger.info("cache updated")

    embeddings = []
    for word in vocabulary:
        embeddings.append(torch.tensor(cache[word], dtype=torch.float32))
    embeddings = torch.stack(embeddings)

    if cache_miss:
        with open("glove.cache", "w") as h:
            json.dump(cache, h)
        logger.info("cache saved")

    return embeddings


def tokenize(s):
    """Lowercase, separate sentence-final punctuation, and split on whitespace."""
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = s.split(" ")
    return s


def indices_from_sentence(word2index, sentence, unknown_threshold):
    """Map tokens to indices; optionally replace tokens with UNK at the given rate."""
    if unknown_threshold:
        return [
            word2index.get(
                word if random.random() > unknown_threshold else config.UNK,
                word2index[config.UNK],
            )
            for word in sentence
        ]
    else:
        return [word2index.get(word, word2index[config.UNK]) for word in sentence]


def tensor_from_sentence(word2index, sentence, unknown_threshold):
    # indices = [config.SOS]
    indices = indices_from_sentence(word2index, sentence, unknown_threshold)
    indices.append(word2index[config.EOS])
    return torch.tensor(indices, dtype=torch.long, device=config.DEV)


def tensors_from_pair(word2index, pair, shuffle, unknown_threshold):
    tensors = [
        tensor_from_sentence(word2index, pair[0], unknown_threshold),
        tensor_from_sentence(word2index, pair[1], unknown_threshold),
    ]
    if shuffle:
        random.shuffle(tensors)
    return tensors


def bleu(reference, hypothesis, n=4):
    """Sentence-level BLEU with uniform weights over 1..n-grams."""
    if n < 1:
        return 0
    # the weights passed to sentence_bleu determine the maximum n-gram order
    weights = [1 / n] * n
    return sentence_bleu([reference], hypothesis, weights)


def pair_iter(pairs, word2index, shuffle=False, shuffle_pairs=False, unknown_threshold=0.00):
    if shuffle:
        pairs = pairs.copy()
        random.shuffle(pairs)
    for pair in pairs:
        tensor1, tensor2 = tensors_from_pair(
            word2index, (pair[0], pair[1]), shuffle_pairs, unknown_threshold
        )
        yield (tensor1,), (tensor2,)


def sent_iter(sents, word2index, unknown_threshold=0.00):
    for sent in sents:
        tensor = tensor_from_sentence(word2index, sent, unknown_threshold)
        yield (tensor,)


def batch_iter(pairs, word2index, batch_size, shuffle=False, unknown_threshold=0.00):
    for i in range(len(pairs) // batch_size):
        batch = pairs[i * batch_size : (i + 1) * batch_size]
        if len(batch) != batch_size:
            continue
        batch_tensors = [
            tensors_from_pair(word2index, (pair[0], pair[1]), shuffle, unknown_threshold)
            for pair in batch
        ]
        tensors1, tensors2 = zip(*batch_tensors)
        # targets = torch.tensor(targets, dtype=torch.long, device=config.DEV)
        # tensors1_lengths = [len(t) for t in tensors1]
        # tensors2_lengths = [len(t) for t in tensors2]
        # tensors1 = nn.utils.rnn.pack_sequence(tensors1, enforce_sorted=False)
        # tensors2 = nn.utils.rnn.pack_sequence(tensors2, enforce_sorted=False)
        yield tensors1, tensors2


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return "%s (- %s)" % (asMinutes(s), asMinutes(rs))


def showPlot(points):
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)


def showAttention(input_sentence, output_words, attentions):
    # Set up figure with colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence.split(" ") + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    # expects `evaluate`, `encoder1`, and `attn_decoder1` to be provided by the calling script
    output_words, attentions = evaluate(encoder1, attn_decoder1, input_sentence)
    print("input =", input_sentence)
    print("output =", " ".join(output_words))
    showAttention(input_sentence, output_words, attentions)


def save_model(model, word2index, path):
    if not path.endswith(".tar"):
        path += ".tar"
    torch.save(
        {"weights": model.state_dict(), "word2index": word2index},
        path,
    )


def load_model(path):
    checkpoint = torch.load(path)
    return checkpoint["weights"], checkpoint["word2index"]


def extend_vocabulary(word2index, langs):
    for lang in langs:
        for word in lang.word2index:
            if word not in word2index:
                word2index[word] = len(word2index)
    return word2index
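

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes, as the helpers above do, that `config` defines the special
# tokens PAD/SOS/EOS/UNK and the device DEV; the toy vocabulary and sentence
# pair below are made up purely for this demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # build a tiny vocabulary from two tokenized sentences
    word2index = {config.PAD: 0, config.SOS: 1, config.EOS: 2, config.UNK: 3}
    pair = (tokenize("A simple example sentence."), tokenize("Another sentence!"))
    for sentence in pair:
        for word in sentence:
            word2index.setdefault(word, len(word2index))

    # per-pair iteration: each yield is a ((tensor1,), (tensor2,)) tuple
    for (src,), (tgt,) in pair_iter([pair], word2index):
        print("pair tensors:", src.shape, tgt.shape)

    # batched iteration over a list of pairs
    for tensors1, tensors2 in batch_iter([pair, pair], word2index, batch_size=2):
        print("batch size:", len(tensors1), len(tensors2))

    # sentence-level BLEU sanity check (identical sentences score 1.0)
    print("BLEU-2 (identity):", bleu(pair[0], pair[0], n=2))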