"""Render per-sentence attention heatmaps (PDF) from tab-separated result files."""
import ast
import os
import pathlib

import text_attention

import click
import matplotlib.pyplot as plt

# Non-interactive backend: figures are only written to disk, never shown.
plt.switch_backend("agg")
import matplotlib.ticker as ticker
import numpy as np
import tqdm


def plot_attention(input_sentence, output_words, attentions, path):
    """Draw an attention matrix as a heatmap and save it to ``f"{path}.pdf"``.

    Args:
        input_sentence: list of input tokens, used as x-axis labels.
        output_words: list of output tokens, used as y-axis labels.
        attentions: 2-D array-like of attention weights; presumably one row
            per output word (TODO confirm against the producer). Only the first
            ``len(input_sentence)`` columns are plotted.
        path: output path without extension; ``.pdf`` is appended.
    """
    # Keep only the columns that correspond to input tokens.
    attentions = np.array(attentions)[:, :len(input_sentence)]

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        cax = ax.matshow(attentions, cmap="bone")
        fig.colorbar(cax)

        # Leading "" compensates for the tick placed before the first cell.
        ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
        ax.set_yticklabels([""] + output_words)

        # Show a label at every tick.
        ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

        plt.savefig(f"{path}.pdf")
    finally:
        # Always release the figure; otherwise a failed save inside the
        # per-sentence loop would accumulate open figures.
        plt.close(fig)


def parse(p):
    """Yield ``(sentence, prediction, attention)`` tuples from a result file.

    Each non-empty, non-comment line must hold four tab-separated fields:
    sentence, prediction, attention, fixations. The first three are parsed
    with ``ast.literal_eval``; the fixations field is ignored. Blank lines,
    lines starting with ``#``, lines with the wrong number of fields, and
    lines whose fields are not valid Python literals are skipped.

    Args:
        p: path of the file to read.
    """
    with open(p) as h:
        for line in h:
            # Strip first: raw lines keep their "\n", so the emptiness check
            # on the unstripped line never fired for blank lines.
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            fields = line.split("\t")
            if len(fields) != 4:
                # Malformed row; skip it, like unparsable literals below.
                continue
            _sentence, _prediction, _attention, _fixations = fields
            try:
                sentence = ast.literal_eval(_sentence)
                prediction = ast.literal_eval(_prediction)
                attention = ast.literal_eval(_attention)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Exceptions ast.literal_eval may raise on bad input; a bare
                # except would also swallow KeyboardInterrupt/SystemExit.
                continue
            yield sentence, prediction, attention


@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
    """Plot every record of each PATH into a directory named after the file.

    For an input ``foo.tsv``, plots go to ``foo/0.pdf``, ``foo/1.pdf``, ...
    """
    for p in tqdm.tqdm(path):
        out_dir = os.path.splitext(p)[0]
        # Compare with the current file `p` (not the tuple `path`): a file
        # without an extension would otherwise collide with its own output
        # directory and mkdir would raise FileExistsError.
        if out_dir == p:
            out_dir = f"{out_dir}_"
        pathlib.Path(out_dir).mkdir(exist_ok=True)
        for i, spa in enumerate(parse(p)):
            plot_attention(*spa, path=os.path.join(out_dir, str(i)))


if __name__ == "__main__":
    main()