89 lines
1.9 KiB
Python
89 lines
1.9 KiB
Python
import ast
|
|
from math import log2
|
|
import os
|
|
from statistics import mean, pstdev
|
|
import sys
|
|
|
|
import click
|
|
import numpy as np
|
|
from scipy.special import kl_div
|
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
|
|
def attention_reader(path):
    """Yield attention weight matrices parsed from a tab-separated file.

    Each line must contain exactly four tab-separated fields; the third
    field is a Python literal: a list of rows, each row a list of tuples
    whose first element is the attention weight (the remaining tuple
    elements are discarded).  For every well-formed line, yields the
    matrix of weights as a list of lists of floats.

    Lines with the wrong number of fields, or an unparseable/ill-shaped
    third field, are skipped with a message on stderr (best-effort read).
    """
    with open(path) as handle:
        for raw in handle:
            line = raw.strip()
            try:
                s, _p, a, _f = line.split("\t")
            except ValueError:
                # Wrong field count; was a bare `except:` which would
                # also swallow KeyboardInterrupt/SystemExit.
                print(f"skipping line: {line}", file=sys.stderr)
                continue
            try:
                # literal_eval raises ValueError/SyntaxError on bad input;
                # TypeError/IndexError cover a parsed value of the wrong shape.
                matrix = [[cell[0] for cell in row]
                          for row in ast.literal_eval(a)]
            except (ValueError, SyntaxError, TypeError, IndexError):
                print(f"skipping malformed line: {s}", file=sys.stderr)
            else:
                yield matrix
|
|
|
|
|
|
def _kl_divergence(p, q):
|
|
p = np.asarray(p)
|
|
q = np.asarray(q)
|
|
p /= sum(p)
|
|
q /= sum(q)
|
|
return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))
|
|
|
|
def kl_divergence(ps, qs):
    """Mean pairwise KL divergence over two aligned sequences of distributions.

    ps, qs: iterables of equal-length weight vectors; pairs are consumed
    in lockstep (extra items in the longer iterable are ignored, as with
    zip).

    Raises ZeroDivisionError if the iterables are empty.
    """
    total = 0.0
    count = 0
    for p, q in zip(ps, qs):
        # Removed stray debug `print(p, q)` that polluted stdout.
        total += _kl_divergence(p, q)
        count += 1
    return total / count
|
|
|
|
|
|
def _js_divergence(p, q):
    """Jensen-Shannon divergence (base 2) between two distributions.

    p, q: equal-length sequences of non-negative weights; both are
    normalized to sum to 1, then JSD(p, q) = 0.5*D(p||m) + 0.5*D(q||m)
    with m the midpoint distribution.
    """
    # Float copies: the original np.asarray + in-place `/=` raised on
    # int input and could mutate a caller-owned float array.
    p = np.array(p, dtype=float)
    q = np.array(q, dtype=float)
    p /= p.sum()
    q /= q.sum()
    # Removed stray debug `print(p, q)` that polluted stdout.
    m = 0.5 * (p + q)
    return 0.5 * _kl_divergence(p, m) + 0.5 * _kl_divergence(q, m)
|
|
|
|
def js_divergence(ps, qs):
    """Average pairwise Jensen-Shannon divergence over aligned sequences.

    ps, qs: iterables of weight vectors, compared pairwise in lockstep.
    Raises ZeroDivisionError when the iterables are empty.
    """
    divergences = [_js_divergence(p, q) for p, q in zip(ps, qs)]
    return sum(divergences) / len(divergences)
|
|
|
|
|
|
def get_kl_div(seq1, seq2):
    """Divergence score for each aligned pair of attention matrices.

    NOTE(review): despite the name, this returns *Jensen-Shannon*
    divergences -- the KL variant was unreachable dead code after the
    first return and has been removed.  The name is kept so existing
    callers keep working.
    """
    return [js_divergence(x1, x2) for x1, x2 in zip(seq1, seq2)]
|
|
|
|
|
|
@click.command()
@click.argument("ref")
@click.argument("path", nargs=-1)
def main(ref, path):
    """Compare attention files against a reference.

    REF is the reference attention file; each PATH is compared to it.
    For every path, prints the mean and population standard deviation of
    the per-sentence divergences, then shows a boxplot of all paths
    (one box per file, labelled by basename).
    """
    kls = []
    labels = []
    for p in path:
        labels.append(os.path.basename(p))
        kl = get_kl_div(attention_reader(ref), attention_reader(p))
        print(mean(kl))
        print(pstdev(kl))
        kls.append(kl)

    plt.boxplot(kls, labels=labels)
    plt.show()
    # BUG FIX: pyplot has no `clear()`; plt.clear() raised AttributeError
    # after the plot window closed.  clf() is the call that clears the
    # current figure.
    plt.clf()
|
|
|
|
|
|
# Script entry point: hand control to the click-decorated CLI.
if __name__ == "__main__":
    main()
|