65 lines
1.7 KiB
Python
65 lines
1.7 KiB
Python
|
import ast
|
||
|
import os
|
||
|
import pathlib
|
||
|
|
||
|
import text_attention
|
||
|
|
||
|
import click
|
||
|
import matplotlib.pyplot as plt
|
||
|
plt.switch_backend("agg")
|
||
|
import matplotlib.ticker as ticker
|
||
|
import numpy as np
|
||
|
import tqdm
|
||
|
|
||
|
|
||
|
def plot_attention(input_sentence, output_words, attentions, path):
    """Draw an attention matrix as a heat map and save it to ``<path>.pdf``.

    Args:
        input_sentence: List of input tokens. Used for the x-axis labels and
            to limit how many attention columns are drawn.
        output_words: List of output tokens used for the y-axis labels.
        attentions: Nested sequence (or array) of attention weights;
            presumably one row per output word -- confirm against the caller.
        path: Destination path without the ``.pdf`` suffix.
    """
    # Drop any columns beyond the length of the input sentence.
    weights = np.array(attentions)[:, : len(input_sentence)]

    # Figure with a single axes holding the heat map and its colorbar.
    figure = plt.figure()
    axes = figure.add_subplot(111)
    heatmap = axes.matshow(weights, cmap="bone")
    figure.colorbar(heatmap)

    # Axis labels. The leading "" appears to pad the tick that the locator
    # places before the first cell -- NOTE(review): confirm with matplotlib.
    x_labels = [""] + input_sentence + ["<__EOS__>"]
    axes.set_xticklabels(x_labels, rotation=90)
    y_labels = [""] + output_words
    axes.set_yticklabels(y_labels)

    # Put a tick (and therefore a label) at every integer position.
    for axis in (axes.xaxis, axes.yaxis):
        axis.set_major_locator(ticker.MultipleLocator(1))

    plt.savefig(f"{path}.pdf")
    plt.close()
|
||
|
|
||
|
|
||
|
def parse(p):
    """Yield ``(sentence, prediction, attention)`` for each record in file *p*.

    Every record is one line holding four tab-separated fields: sentence,
    prediction, attention and fixations. The first three are decoded with
    ``ast.literal_eval``; the fixations field is read but not used.

    Blank lines and lines starting with ``#`` are skipped. Records whose
    fields are not valid Python literals are skipped as well.

    Args:
        p: Path of the file to read.

    Yields:
        Tuples ``(sentence, prediction, attention)`` of the decoded literals.

    Raises:
        ValueError: If a non-blank, non-comment line does not consist of
            exactly four tab-separated fields.
    """
    with open(p) as h:
        for raw_line in h:
            # Strip first: lines read from a file always carry their newline,
            # so an unstripped "blank" line is never falsy.
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            _sentence, _prediction, _attention, _fixations = line.split("\t")
            try:
                sentence = ast.literal_eval(_sentence)
                prediction = ast.literal_eval(_prediction)
                attention = ast.literal_eval(_attention)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Malformed literal: skip the record. Only the errors that
                # literal_eval is documented to raise are caught, so
                # KeyboardInterrupt / SystemExit still propagate.
                continue

            yield sentence, prediction, attention
|
||
|
|
||
|
|
||
|
@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
    """Plot every attention record of each PATH as a PDF.

    For each input file, a directory named after the file without its
    extension is created if missing, and record number i is written to
    <directory>/<i>.pdf.
    """
    for p in tqdm.tqdm(path):
        out_dir = os.path.splitext(p)[0]
        # A file without an extension would share its name with its output
        # directory, so add a suffix. The comparison must be against the
        # current file `p`; `path` is the tuple of all arguments and never
        # equals a string.
        if out_dir == p:
            out_dir = f"{out_dir}_"
        pathlib.Path(out_dir).mkdir(exist_ok=True)
        for i, spa in enumerate(parse(p)):
            plot_attention(*spa, path=os.path.join(out_dir, str(i)))
|
||
|
|
||
|
|
||
|
# Invoke the command-line entry point only when run as a script.
if __name__ == "__main__":
    main()
|