# human-gaze-guided-neural-at.../joint_sentence_compression_.../utils/kl.py

import ast
import os
import sys
from math import log2
from statistics import mean, pstdev

import click
import numpy as np
from matplotlib import pyplot as plt

def attention_reader(path):
    """Yield per-sentence attention matrices parsed from a tab-separated file."""
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                s, p, a, f = line.split("\t")
            except ValueError:
                print(f"skipping line: {line}", file=sys.stderr)
            else:
                try:
                    # The third field is a Python literal; keep only the first
                    # element (the score) of each inner tuple.
                    yield [[x[0] for x in y] for y in ast.literal_eval(a)]
                except (ValueError, SyntaxError):
                    print(f"skipping malformed line: {s}", file=sys.stderr)

def _kl_divergence(p, q):
    # Normalise both inputs to probability distributions, then compute
    # KL(p || q) = sum_i p_i * log2(p_i / q_i).
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p /= p.sum()
    q /= q.sum()
    return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))

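# A doctest-style sanity check with made-up distributions (not from the
# repo): KL is zero for identical inputs and positive otherwise.
#
#   >>> _kl_divergence([0.5, 0.5], [0.5, 0.5])
#   0.0
#   >>> round(_kl_divergence([0.9, 0.1], [0.5, 0.5]), 4)
#   0.531
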
def kl_divergence(ps, qs):
    # Mean KL divergence over aligned pairs of distributions.
    kl = 0.0
    count = 0
    for p, q in zip(ps, qs):
        kl += _kl_divergence(p, q)
        count += 1
    return kl / count

def _js_divergence(p, q):
    # Jensen-Shannon divergence: the mean of KL(p || m) and KL(q || m),
    # where m is the midpoint distribution 0.5 * (p + q).
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p /= p.sum()
    q /= q.sum()
    m = 0.5 * (p + q)
    return 0.5 * _kl_divergence(p, m) + 0.5 * _kl_divergence(q, m)

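# Unlike KL, the JS divergence is symmetric (and, with log2, bounded by 1).
# A made-up check, not from the repo:
#
#   >>> round(_js_divergence([0.9, 0.1], [0.5, 0.5]), 4)
#   0.1468
#   >>> round(_js_divergence([0.5, 0.5], [0.9, 0.1]), 4)
#   0.1468
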
def js_divergence(ps, qs):
    # Mean JS divergence over aligned pairs of distributions.
    js = 0.0
    count = 0
    for p, q in zip(ps, qs):
        js += _js_divergence(p, q)
        count += 1
    return js / count

def get_kl_div(seq1, seq2):
    # NOTE: despite its name, this reports the per-sentence JS divergence.
    return [js_divergence(x1, x2) for x1, x2 in zip(seq1, seq2)]

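# Assumed data shape, inferred from attention_reader: each element of seq1
# and seq2 is one sentence's attention matrix (a list of per-token score
# rows), and js_divergence averages the row-wise divergences. Hypothetical
# example:
#
#   >>> a = [[0.7, 0.3], [0.2, 0.8]]   # reference attention (made up)
#   >>> b = [[0.6, 0.4], [0.3, 0.7]]   # model attention (made up)
#   >>> len(get_kl_div([a], [b]))      # one sentence in, one score out
#   1
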
@click.command()
@click.argument("ref")
@click.argument("path", nargs=-1)
def main(ref, path):
    kls = []
    labels = []
    for p in path:
        labels.append(os.path.basename(p))
        # Per-sentence divergence between the reference attention and this
        # file's attention; report its mean and population stdev.
        kl = get_kl_div(attention_reader(ref), attention_reader(p))
        print(mean(kl))
        print(pstdev(kl))
        kls.append(kl)
    plt.boxplot(kls, labels=labels)
    plt.show()
    plt.clf()

if __name__ == "__main__":
    main()
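
# Assumed invocation (argument names taken from the click decorators above):
# the first argument is the reference attention file, the rest are model
# attention files to compare against it, e.g.
#
#   python kl.py human_attention.tsv model_a.tsv model_b.tsv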