"""Compare attention distributions between a reference file and other files.

Each input file is read line by line. Every line is expected to hold four
tab-separated fields; the third is a Python literal (nested lists) from which
one attention row per inner list is extracted. For every compared file the
per-line divergences against the reference are computed, their mean and
population standard deviation are printed, and all files are shown together
as a box plot.
"""
import ast
from math import log2
import os
from statistics import mean, pstdev
import sys

import click
import numpy as np
from scipy.special import kl_div
from matplotlib import pyplot as plt


def attention_reader(path, placeholders=False):
    """Yield the attention matrix parsed from each line of *path*.

    A matrix is a list of rows: row ``i`` holds the first element of every
    item in the ``i``-th inner list of the line's third tab-separated field.

    Lines that do not split into exactly four fields, or whose third field
    cannot be parsed, are reported on stderr. By default they are skipped.
    With ``placeholders=True`` a ``None`` is yielded in their place, so two
    files read in parallel stay aligned line by line.
    """
    with open(path, encoding="utf-8") as h:
        for line in h:
            line = line.strip()
            try:
                # Only ValueError is possible here: wrong number of fields.
                s, _, a, _ = line.split("\t")
            except ValueError:
                print(f"skipping line: {line}", file=sys.stderr)
                if placeholders:
                    yield None
                continue
            try:
                # literal_eval evaluates literals only (no code execution).
                # Parsing is best effort: any failure in it or in the row
                # extraction (ValueError, SyntaxError, TypeError, IndexError,
                # RecursionError, ...) just skips the line.
                matrix = [[x[0] for x in y] for y in ast.literal_eval(a)]
            except Exception:
                print(f"skipping malformed line: {s}", file=sys.stderr)
                if placeholders:
                    yield None
                continue
            yield matrix


def _normalize(v):
    """Return *v* as a new float array scaled to sum to 1.

    A copy is always made, so the caller's data is never modified. The
    float dtype also lets integer inputs be normalized.
    """
    arr = np.array(v, dtype=float)
    return arr / arr.sum()


def _kl_divergence(p, q):
    """Return KL(p || q) in bits, after normalizing both vectors to sum to 1.

    Terms with ``p[i] == 0`` contribute 0, following the ``0 * log 0 = 0``
    convention. A ``q[i] == 0`` where ``p[i] > 0`` makes the result ``inf``.
    """
    p = _normalize(p)
    q = _normalize(q)
    support = p > 0
    # p/0 -> inf is the mathematically correct limit; silence the warning.
    with np.errstate(divide="ignore"):
        ratio = p[support] / q[support]
    return float(np.sum(p[support] * np.log2(ratio)))


def _mean_over_rows(divergence, ps, qs):
    """Average ``divergence(p, q)`` over the rows of two matrices.

    Rows are paired with ``zip``, so extra rows in the longer matrix are
    ignored. Raises ZeroDivisionError when there are no row pairs.
    """
    values = [divergence(p, q) for p, q in zip(ps, qs)]
    return sum(values) / len(values)


def kl_divergence(ps, qs):
    """Return the mean row-wise KL divergence (bits) of matrix *ps* from *qs*."""
    return _mean_over_rows(_kl_divergence, ps, qs)


def _js_divergence(p, q):
    """Return the Jensen-Shannon divergence (bits) between vectors *p* and *q*.

    Both vectors are normalized to sum to 1 first.
    """
    p = _normalize(p)
    q = _normalize(q)
    m = 0.5 * (p + q)
    return 0.5 * _kl_divergence(p, m) + 0.5 * _kl_divergence(q, m)


def js_divergence(ps, qs):
    """Return the mean row-wise Jensen-Shannon divergence of two matrices."""
    return _mean_over_rows(_js_divergence, ps, qs)


def get_kl_div(seq1, seq2, divergence=js_divergence):
    """Return one divergence value per pair of attention matrices.

    Despite its name, this function has always computed the Jensen-Shannon
    divergence (a KL ``return`` after it was unreachable). JS therefore stays
    the default; pass ``divergence=kl_divergence`` to get KL explicitly.
    Pairs in which either side is ``None`` (a placeholder yielded by
    ``attention_reader(..., placeholders=True)``) are skipped.
    """
    return [
        divergence(x1, x2)
        for x1, x2 in zip(seq1, seq2)
        if x1 is not None and x2 is not None
    ]


@click.command()
@click.argument("ref")
@click.argument("path", nargs=-1)
def main(ref, path):
    """Compare the attention in each PATH against the reference file REF."""
    # Parse the reference once instead of re-reading it for every PATH.
    # Placeholders keep an unparsable line in one file from shifting the
    # line-by-line pairing with the other.
    reference = list(attention_reader(ref, placeholders=True))
    kls = []
    labels = []
    for p in path:
        labels.append(os.path.basename(p))
        kl = get_kl_div(reference, attention_reader(p, placeholders=True))
        print(mean(kl))
        print(pstdev(kl))
        kls.append(kl)
    # NOTE(review): Matplotlib 3.9 renamed boxplot's ``labels`` parameter to
    # ``tick_labels`` (``labels`` is deprecated there). Confirm the installed
    # version before switching.
    plt.boxplot(kls, labels=labels)
    plt.show()
    # pyplot has no ``clear()`` (the old call raised AttributeError);
    # close the figure to release it.
    plt.close()


if __name__ == "__main__":
    main()