'''
Calculates the KL divergence between specified columns in a file,
or across columns of different files.
'''
# Standard library
import ast
import collections
import sys
from collections.abc import Iterable
from math import log2

# Third-party
import pandas as pd
from sklearn import metrics


def kl_divergence(p, q):
    """Compute the Kullback-Leibler divergence D(p || q) in bits.

    Args:
        p: sequence of probabilities for distribution P.
        q: sequence of probabilities for distribution Q, same length as p.

    Returns:
        The KL divergence as a float. Terms with p[i] == 0 contribute 0
        (the standard 0 * log(0) = 0 convention); q[i] must be non-zero
        wherever p[i] > 0, otherwise ZeroDivisionError is raised.
    """
    # zip() pairs the distributions directly instead of indexing with
    # range(len(p)); skipping pi == 0 avoids log2(0) domain errors.
    return sum(pi * log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def flatten(x):
    """Recursively flatten an arbitrarily nested iterable into a flat list.

    Strings and bytes are treated as atoms (returned whole, not split into
    characters), as is any non-iterable value.

    Args:
        x: a value, or an arbitrarily nested iterable of values.

    Returns:
        A flat list of the leaf values.
    """
    # collections.Iterable was removed in Python 3.10; collections.abc is
    # the correct home. Excluding str/bytes prevents infinite recursion
    # (a 1-char string is an iterable whose element is itself).
    if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
        return [a for i in x for a in flatten(i)]
    return [x]


def get_data(file):
    """Load a tab-separated results file and extract two columns.

    The file is read without a header, the first two rows are skipped,
    and the "attentions" and "fixations" columns are returned as plain
    Python lists (of raw cell strings).

    Args:
        file: path to the tab-separated input file.

    Returns:
        A (attentions, fixations) tuple of lists.
    """
    columns = ["sentence", "prediction", "attentions", "fixations"]
    frame = pd.read_csv(file, sep='\t', names=columns).iloc[2:]
    return frame["attentions"].tolist(), frame["fixations"].tolist()


def attention_attention(attentions1, attentions2):
    """Average pairwise mutual information between two attention columns.

    Each input element is the string repr of a (possibly nested) list of
    attention scores; each pair is parsed with ast.literal_eval, flattened,
    and scored with sklearn's mutual_info_score. Pairs that fail to score
    are skipped.

    NOTE(review): despite the module docstring this computes mutual
    information, not KL divergence — confirm which metric is intended.

    Args:
        attentions1: list of attention-list strings from the first file.
        attentions2: list of attention-list strings from the second file.

    Returns:
        The mean score over successfully scored pairs, or float('nan')
        when no pair could be scored.
    """
    divergence = []

    for att1, att2 in zip(attentions1, attentions2):
        lst_att1 = flatten(ast.literal_eval(att1))
        lst_att2 = flatten(ast.literal_eval(att2))

        # Narrow except: a bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit. Skipping failed pairs (instead of
        # appending None) keeps the average below computable — the
        # original raised TypeError when summing a None entry.
        try:
            divergence.append(metrics.mutual_info_score(lst_att1, lst_att2))
        except ValueError:
            continue

    if not divergence:
        return float('nan')
    return sum(divergence) / len(divergence)


def fixation_fixation(fixation1, fixation2):
    """Average pairwise mutual information between two fixation columns.

    Each input element is the string repr of a (possibly nested) list of
    fixation values; each pair is parsed with ast.literal_eval, flattened,
    and scored with sklearn's mutual_info_score. Pairs that fail to score
    are skipped.

    Args:
        fixation1: list of fixation-list strings from the first file.
        fixation2: list of fixation-list strings from the second file.

    Returns:
        The mean score over successfully scored pairs, or float('nan')
        when no pair could be scored.
    """
    divergence = []

    for fix1, fix2 in zip(fixation1, fixation2):
        lst_fixation1 = flatten(ast.literal_eval(fix1))
        lst_fixation2 = flatten(ast.literal_eval(fix2))

        # Narrow except + skip-on-failure: the original's bare `except:`
        # appended None, which made the final sum() raise TypeError.
        try:
            divergence.append(metrics.mutual_info_score(lst_fixation1, lst_fixation2))
        except ValueError:
            continue

    if not divergence:
        return float('nan')
    return sum(divergence) / len(divergence)


def attention_fixation(attentions, fixations):
    """Average mutual information between attention and fixation columns.

    For each sentence, every per-token attention entry is flattened and
    mean-pooled to a single value so the attention list lines up with the
    flattened fixation list; the pair is then scored with sklearn's
    mutual_info_score. Sentences that fail to score are skipped.

    Args:
        attentions: list of attention-list strings (one per sentence).
        fixations: list of fixation-list strings (one per sentence).

    Returns:
        The mean score over successfully scored sentences, or
        float('nan') when none could be scored.
    """
    divergence = []

    for attention, fixation in zip(attentions, fixations):
        current_attention = ast.literal_eval(attention)
        current_fixation = ast.literal_eval(fixation)

        # Mean-pool each token's (possibly nested) attention scores to
        # one number per token.
        lst_attention = []
        for token_scores in current_attention:
            flat = flatten(token_scores)
            lst_attention.append(sum(flat) / len(flat))

        lst_fixation = flatten(current_fixation)

        # Narrow except + skip-on-failure: the original's bare `except:`
        # appended None, which made the final sum() raise TypeError.
        try:
            divergence.append(metrics.mutual_info_score(lst_attention, lst_fixation))
        except ValueError:
            continue

    if not divergence:
        return float('nan')
    return sum(divergence) / len(divergence)


def divergent_calculations(file1, file2=None, val1=None):
    """Dispatch the requested divergence computation and print the result.

    Args:
        file1: path to the primary TSV results file.
        file2: optional path to a second file. When given, columns are
            compared across the two files; when None, attention vs.
            fixation is compared within file1.
        val1: "attention" to compare attention columns across files;
            any other value compares the fixation columns.
    """
    attentions, fixations = get_data(file1)

    if file2:
        # Only read the second file when one was actually supplied: the
        # original called get_data(file2) unconditionally, so the
        # documented single-file mode (file2=None) always crashed.
        attentions2, fixations2 = get_data(file2)
        if val1 == "attention":
            divergence = attention_attention(attentions, attentions2)
        else:
            divergence = fixation_fixation(fixations, fixations2)
    else:
        divergence = attention_fixation(attentions, fixations)

    # "KL" fixes the original's "DL Divergence" typo.
    print("KL Divergence: ", divergence)


if __name__ == "__main__":
    # Guard the CLI entry point so importing this module has no side
    # effects. file2 and val1 are optional on the command line, matching
    # the callee's defaults — the original indexed sys.argv[2]/[3]
    # unconditionally and raised IndexError in single-file mode.
    divergent_calculations(
        sys.argv[1],
        sys.argv[2] if len(sys.argv) > 2 else None,
        sys.argv[3] if len(sys.argv) > 3 else None,
    )