''' Calculates the KL divergence between specified columns in a file
or across columns of different files '''
from math import log2
import ast
import collections.abc
import sys

import pandas as pd
from sklearn import metrics


def kl_divergence(p, q):
    """Discrete KL divergence between two probability distributions."""
    return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))


def flatten(x):
    """Recursively flatten arbitrarily nested lists into a flat list."""
    # collections.abc.Iterable: the bare collections.Iterable alias was
    # removed in Python 3.10. Strings are excluded to avoid infinite
    # recursion over their characters.
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        return [a for i in x for a in flatten(i)]
    return [x]


def get_data(file):
    """Read a tab-separated results file and return the attention and
    fixation columns as lists of stringified nested lists."""
    names = ["sentence", "prediction", "attentions", "fixations"]
    df = pd.read_csv(file, sep='\t', names=names)
    df = df[2:]  # drop the two leading header/metadata rows
    attentions = df.loc[:, "attentions"].tolist()
    fixations = df.loc[:, "fixations"].tolist()
    return attentions, fixations


def attention_attention(attentions1, attentions2):
    """Average score between the attention columns of two files."""
    divergence = []
    for att1, att2 in zip(attentions1, attentions2):
        lst_att1 = flatten(ast.literal_eval(att1))
        lst_att2 = flatten(ast.literal_eval(att2))
        try:
            # mutual_info_score computes mutual information between the
            # two value sequences (treated as label assignments).
            kl_pq = metrics.mutual_info_score(lst_att1, lst_att2)
            divergence.append(kl_pq)
        except ValueError:
            pass  # skip rows that cannot be scored
    return sum(divergence) / len(divergence)


def fixation_fixation(fixations1, fixations2):
    """Average score between the fixation columns of two files."""
    divergence = []
    for fix1, fix2 in zip(fixations1, fixations2):
        lst_fixation1 = flatten(ast.literal_eval(fix1))
        lst_fixation2 = flatten(ast.literal_eval(fix2))
        try:
            kl_pq = metrics.mutual_info_score(lst_fixation1, lst_fixation2)
            divergence.append(kl_pq)
        except ValueError:
            pass  # skip rows that cannot be scored
    return sum(divergence) / len(divergence)


def attention_fixation(attentions, fixations):
    """Average score between the attention and fixation columns of one
    file. Attention values are averaged per token before comparison."""
    divergence = []
    for attention, fixation in zip(attentions, fixations):
        current_attention = ast.literal_eval(attention)
        current_fixation = ast.literal_eval(fixation)
        lst_attention = []
        for t in current_attention:
            attention_lst = flatten(t)
            lst_attention.append(sum(attention_lst) / len(attention_lst))
        lst_fixation = flatten(current_fixation)
        try:
            kl_pq = metrics.mutual_info_score(lst_attention, lst_fixation)
            divergence.append(kl_pq)
        except ValueError:
            pass  # skip rows that cannot be scored
    return sum(divergence) / len(divergence)


def divergent_calculations(file1, file2=None, val1=None):
    attentions, fixations = get_data(file1)
    if file2:
        # Only read the second file when one was actually supplied.
        attentions2, fixations2 = get_data(file2)
        if val1 == "attention":
            divergence = attention_attention(attentions, attentions2)
        else:
            divergence = fixation_fixation(fixations, fixations2)
    else:
        divergence = attention_fixation(attentions, fixations)
    print("KL Divergence:", divergence)


if __name__ == "__main__":
    # file2 and the column selector are optional: with a single file the
    # script compares its attention column against its fixation column.
    file1 = sys.argv[1]
    file2 = sys.argv[2] if len(sys.argv) > 2 else None
    val1 = sys.argv[3] if len(sys.argv) > 3 else None
    divergent_calculations(file1, file2, val1)
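
# Example invocations (a sketch; the script filename and the TSV file
# names below are hypothetical, assuming files with the four columns
# read by get_data):
#
#   Compare the attention columns of two result files:
#       python kl_divergence.py results_a.tsv results_b.tsv attention
#
#   Compare the fixation columns of two result files:
#       python kl_divergence.py results_a.tsv results_b.tsv fixation
#
#   Compare attention against fixation within a single file:
#       python kl_divergence.py results_a.tsv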