'''
Calculates the divergence between specified columns within a single file or
between the corresponding columns of two files.

Note: the pairwise scores are computed with sklearn's mutual_info_score,
i.e. mutual information rather than the KL divergence implemented by the
kl_divergence() helper below.
'''
import ast
import sys
from collections.abc import Iterable  # collections.Iterable was removed in Python 3.10
from math import log2

import pandas as pd
from sklearn import metrics


def kl_divergence(p, q):
    """KL divergence D(p || q), in bits, between two discrete distributions."""
    return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))
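
# A minimal usage sketch (the values are hypothetical, not from the data):
#   p = [0.10, 0.40, 0.50]
#   q = [0.80, 0.15, 0.05]
#   kl_divergence(p, q)  # ~1.927 bits; assumes p and q are same-length,
#                        # strictly positive probability vectors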


def flatten(x):
    """Recursively flatten an arbitrarily nested iterable into a flat list."""
    if isinstance(x, Iterable):
        return [a for i in x for a in flatten(i)]
    return [x]
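
# For example, flatten([[0.2, 0.3], [0.1, [0.4]]]) returns [0.2, 0.3, 0.1, 0.4].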


def get_data(file):
    """Read a tab-separated results file and return its attention and fixation columns."""
    names = ["sentence", "prediction", "attentions", "fixations"]
    df = pd.read_csv(file, sep='\t', names=names)
    df = df[2:]  # skip the first two rows (header/metadata lines)
    attentions = df.loc[:, "attentions"].tolist()
    fixations = df.loc[:, "fixations"].tolist()
    return attentions, fixations
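
# get_data assumes a 4-column TSV (sentence, prediction, attentions,
# fixations); a hypothetical row, for illustration only:
#   the cat sat<TAB>the cat<TAB>[[0.1, 0.2], [0.3, 0.4]]<TAB>[0.5, 0.1, 0.2]
# The attentions/fixations columns hold Python-literal lists that are parsed
# with ast.literal_eval in the functions below.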


def attention_attention(attentions1, attentions2):
    """Average divergence between the attention columns of two files."""
    divergence = []
    for att1, att2 in zip(attentions1, attentions2):
        lst_att1 = flatten(ast.literal_eval(att1))
        lst_att2 = flatten(ast.literal_eval(att2))
        try:
            # NOTE: mutual_info_score computes mutual information, not the
            # KL divergence implemented above.
            kl_pq = metrics.mutual_info_score(lst_att1, lst_att2)
            divergence.append(kl_pq)
        except Exception:
            # Skip unscorable pairs (e.g. length mismatch) rather than
            # appending None, which would break the average below.
            continue
    return sum(divergence) / len(divergence)


def fixation_fixation(fixation1, fixation2):
    """Average divergence between the fixation columns of two files."""
    divergence = []
    for fix1, fix2 in zip(fixation1, fixation2):
        lst_fixation1 = flatten(ast.literal_eval(fix1))
        lst_fixation2 = flatten(ast.literal_eval(fix2))
        try:
            kl_pq = metrics.mutual_info_score(lst_fixation1, lst_fixation2)
            divergence.append(kl_pq)
        except Exception:
            # Skip unscorable pairs rather than appending None.
            continue
    return sum(divergence) / len(divergence)


def attention_fixation(attentions, fixations):
    """Average divergence between the attention and fixation columns of one file."""
    divergence = []
    for attention, fixation in zip(attentions, fixations):
        current_attention = ast.literal_eval(attention)
        # Collapse each entry's nested attention values to their mean, so each
        # entry contributes a single score comparable to the fixation values.
        lst_attention = []
        for t in current_attention:
            attention_lst = flatten(t)
            lst_attention.append(sum(attention_lst) / len(attention_lst))
        lst_fixation = flatten(ast.literal_eval(fixation))
        try:
            kl_pq = metrics.mutual_info_score(lst_attention, lst_fixation)
            divergence.append(kl_pq)
        except Exception:
            continue
    return sum(divergence) / len(divergence)


def divergent_calculations(file1, file2=None, val1=None):
    attentions, fixations = get_data(file1)
    if file2:
        # Compare the same column across two files. (Reading file2 only here
        # avoids the crash the original code hit when file2 was None.)
        attentions2, fixations2 = get_data(file2)
        if val1 == "attention":
            divergence = attention_attention(attentions, attentions2)
        else:
            divergence = fixation_fixation(fixations, fixations2)
    else:
        # Compare attentions against fixations within a single file.
        divergence = attention_fixation(attentions, fixations)
    print("KL Divergence:", divergence)


if __name__ == "__main__":
    # Pass through up to three positional arguments; file2 and val1 are optional.
    divergent_calculations(*sys.argv[1:4])
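
# Example invocations (the file names are hypothetical):
#   python kl_divergence.py run_a.tsv run_b.tsv attention   # attention vs. attention
#   python kl_divergence.py run_a.tsv run_b.tsv fixation    # fixation vs. fixation
#   python kl_divergence.py run_a.tsv                       # attention vs. fixation, one file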