Add NLP task models
parent d8beb17dfb
commit 69f6de0ace

46 changed files with 4976 additions and 0 deletions
43 joint_sentence_compression_model/utils/check_stats.py Normal file

@@ -0,0 +1,43 @@
import ast
from statistics import mean, pstdev
import sys

import click
from scipy.stats import entropy

from matplotlib import pyplot as plt


def reader(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                s, p, a, f = line.split("\t")
            except:
                print(f"skipping line: {line}", file=sys.stderr)
            else:
                try:
                    yield ast.literal_eval(s), ast.literal_eval(p), ast.literal_eval(a), ast.literal_eval(f)
                except:
                    print(f"malformed line: {s}")


def get_stats(seq):
    for s, p, a, f in seq:
        print(s)
        print(p)
        print(len(s), len(p), len(a), len(f))
        for x in a:
            print(len(x))
        print()


@click.command()
@click.argument("path")
def main(path):
    get_stats(reader(path))


if __name__ == "__main__":
    main()
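For reference, a minimal sketch of the line format reader() expects: four tab-separated fields, each a Python literal. The field values below are illustrative toy data, not taken from the repository, and the import assumes the utils directory is on the path.

from check_stats import reader

# one illustrative line: sentence tokens, per-token predictions, attention rows, fixations
row = "\t".join([
    "['the', 'cat', 'sat']",
    "[1, 0, 1]",
    "[[0.2, 0.5, 0.3], [0.1, 0.8, 0.1], [0.4, 0.3, 0.3]]",
    "[0.3, 0.1, 0.6]",
])
with open("toy.tsv", "w") as h:
    h.write(row + "\n")

s, p, a, f = next(reader("toy.tsv"))
print(len(s), len(p), len(a), len(f))  # 3 3 3 3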
34 joint_sentence_compression_model/utils/cr.py Normal file

@@ -0,0 +1,34 @@
import ast
from math import log2
import os
from statistics import mean, pstdev
import sys

import click
import numpy as np
from scipy.special import kl_div

from matplotlib import pyplot as plt


def reader(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            yield line.split()


@click.command()
@click.argument("original")
@click.argument("compressed")
def main(original, compressed):
    ratio = 0
    total = 0
    for o, c in zip(reader(original), reader(compressed)):
        ratio += len(c)/len(o)
        total += 1
    print(f"cr: {ratio/total:.4f}")


if __name__ == "__main__":
    main()
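The script averages, over aligned sentence pairs, the token count of the compressed sentence divided by the token count of the original. A one-pair worked sketch with toy sentences (hypothetical data):

original = "the cat sat on the mat".split()        # 6 tokens
compressed = "cat sat on mat".split()              # 4 tokens
print(f"cr: {len(compressed)/len(original):.4f}")  # cr: 0.6667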
88 joint_sentence_compression_model/utils/kl.py Normal file

@@ -0,0 +1,88 @@
import ast
from math import log2
import os
from statistics import mean, pstdev
import sys

import click
import numpy as np
from scipy.special import kl_div

from matplotlib import pyplot as plt


def attention_reader(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                s, p, a, f = line.split("\t")
            except:
                print(f"skipping line: {line}", file=sys.stderr)
            else:
                try:
                    yield [[x[0] for x in y] for y in ast.literal_eval(a)]
                except:
                    print(f"skipping malformed line: {s}", file=sys.stderr)


def _kl_divergence(p, q):
    p = np.asarray(p)
    q = np.asarray(q)
    p /= sum(p)
    q /= sum(q)
    return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))


def kl_divergence(ps, qs):
    kl = 0
    count = 0
    for p, q in zip(ps, qs):
        print(p, q)  # debug output
        kl += _kl_divergence(p, q)
        count += 1
    return kl/count


def _js_divergence(p, q):
    p = np.asarray(p)
    q = np.asarray(q)
    p /= sum(p)
    q /= sum(q)
    print(p, q)  # debug output
    m = 0.5 * (p + q)
    return 0.5 * _kl_divergence(p, m) + 0.5 * _kl_divergence(q, m)


def js_divergence(ps, qs):
    js = 0
    count = 0
    for p, q in zip(ps, qs):
        js += _js_divergence(p, q)
        count += 1
    return js/count


def get_kl_div(seq1, seq2):
    return [js_divergence(x1, x2) for x1, x2 in zip(seq1, seq2)]
    # return [kl_divergence(x1, x2) for x1, x2 in zip(seq1, seq2)]  # unreachable KL variant kept for reference


@click.command()
@click.argument("ref")
@click.argument("path", nargs=-1)
def main(ref, path):
    kls = []
    labels = []
    for p in path:
        labels.append(os.path.basename(p))
        kl = get_kl_div(attention_reader(ref), attention_reader(p))
        print(mean(kl))
        print(pstdev(kl))
        kls.append(kl)

    plt.boxplot(kls, labels=labels)
    plt.show()
    plt.clf()  # pyplot has no clear(); clf() resets the current figure


if __name__ == "__main__":
    main()
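A quick sanity check of the divergence helpers on two toy distributions, assuming they are imported from this module (e.g. from the utils directory). With base-2 logarithms the Jensen-Shannon value is symmetric and bounded by 1.0, while plain KL is not symmetric in general.

from kl import _kl_divergence, _js_divergence

p = [0.9, 0.1]
q = [0.1, 0.9]
print(_kl_divergence(p, q))  # ~2.536 bits
print(_js_divergence(p, q))  # ~0.531 bits (also echoes the normalized inputs via its debug print)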
112 joint_sentence_compression_model/utils/kl_divergence.py Normal file

@@ -0,0 +1,112 @@
'''
Calculates the KL divergence between specified columns in a file or across columns of different files
'''

from math import log2
import pandas as pd
import ast
import collections
from sklearn import metrics
import sys


def kl_divergence(p, q):
    return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))


def flatten(x):
    if isinstance(x, collections.abc.Iterable):  # collections.Iterable was removed in Python 3.10
        return [a for i in x for a in flatten(i)]
    else:
        return [x]


def get_data(file):
    names = ["sentence", "prediction", "attentions", "fixations"]
    df = pd.read_csv(file, sep='\t', names=names)
    df = df[2:]
    attentions = df.loc[:, "attentions"].tolist()
    fixations = df.loc[:, "fixations"].tolist()

    return attentions, fixations

def attention_attention(attentions1, attentions2):
    divergence = []

    for att1, att2 in zip(attentions1, attentions2):
        current_att1 = ast.literal_eval(att1)
        current_att2 = ast.literal_eval(att2)

        lst_att1 = flatten(current_att1)
        lst_att2 = flatten(current_att2)

        try:
            kl_pq = metrics.mutual_info_score(lst_att1, lst_att2)
            divergence.append(kl_pq)
        except:
            divergence.append(None)

    avg = sum(divergence) / len(divergence)
    return avg

def fixation_fixation(fixation1, fixation2):
    divergence = []

    for fix1, fix2 in zip(fixation1, fixation2):

        current_fixation1 = ast.literal_eval(fix1)
        current_fixation2 = ast.literal_eval(fix2)

        lst_fixation1 = flatten(current_fixation1)
        lst_fixation2 = flatten(current_fixation2)

        try:
            kl_pq = metrics.mutual_info_score(lst_fixation1, lst_fixation2)
            divergence.append(kl_pq)
        except:
            divergence.append(None)

    avg = sum(divergence) / len(divergence)
    return avg

def attention_fixation(attentions, fixations):
    divergence = []

    for attention, fixation in zip(attentions, fixations):
        current_attention = ast.literal_eval(attention)
        current_fixation = ast.literal_eval(fixation)

        lst_attention = []

        for t in current_attention:
            attention_lst = flatten(t)
            lst_attention.append(sum(attention_lst)/len(attention_lst))

        lst_fixation = flatten(current_fixation)

        try:
            kl_pq = metrics.mutual_info_score(lst_attention, lst_fixation)
            divergence.append(kl_pq)
        except:
            divergence.append(None)

    avg = sum(divergence)/len(divergence)
    return avg


def divergent_calculations(file1, file2=None, val1=None):
    attentions, fixations = get_data(file1)

    if file2:
        attentions2, fixations2 = get_data(file2)  # only read the second file when one is given
        if val1 == "attention":
            divergence = attention_attention(attentions, attentions2)
        else:
            divergence = fixation_fixation(fixations, fixations2)

    else:
        divergence = attention_fixation(attentions, fixations)

    print("KL Divergence: ", divergence)

divergent_calculations(sys.argv[1], sys.argv[2], sys.argv[3])
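The script is invoked as python kl_divergence.py <file1> <file2> <attention|fixation>. Note that the pairwise scores it averages come from sklearn's mutual_info_score over the flattened value lists, not from the kl_divergence() helper defined above, which is never called. A toy contrast of the two, with illustrative numbers only:

from math import log2
from sklearn import metrics

p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]
# closed-form KL divergence, as computed by the unused kl_divergence() helper
print(sum(p[i] * log2(p[i] / q[i]) for i in range(len(p))))  # ~0.123 bits
# what attention_attention()/fixation_fixation() actually average
print(metrics.mutual_info_score(p, q))  # ~1.099 nats, treats the values as discrete labels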
64 joint_sentence_compression_model/utils/plot_attention.py Normal file

@@ -0,0 +1,64 @@
import ast
import os
import pathlib

import text_attention

import click
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import matplotlib.ticker as ticker
import numpy as np
import tqdm


def plot_attention(input_sentence, output_words, attentions, path):
    # Set up figure with colorbar
    attentions = np.array(attentions)[:, :len(input_sentence)]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions, cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.savefig(f"{path}.pdf")
    plt.close()


def parse(p):
    with open(p) as h:
        for line in h:
            if not line or line.startswith("#"):
                continue
            _sentence, _prediction, _attention, _fixations = line.strip().split("\t")
            try:
                sentence = ast.literal_eval(_sentence)
                prediction = ast.literal_eval(_prediction)
                attention = ast.literal_eval(_attention)
            except:
                continue

            yield sentence, prediction, attention


@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
    for p in tqdm.tqdm(path):
        out_dir = os.path.splitext(p)[0]
        if out_dir == p:  # input had no extension, so avoid reusing its name directly
            out_dir = f"{out_dir}_"
        pathlib.Path(out_dir).mkdir(exist_ok=True)
        for i, spa in enumerate(parse(p)):
            plot_attention(*spa, path=os.path.join(out_dir, str(i)))


if __name__ == "__main__":
    main()
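plot_attention() can also be exercised directly, without a predictions file; a minimal sketch with a toy 2x3 attention matrix (all values hypothetical, and the import assumes the utils directory is the working directory):

from plot_attention import plot_attention

toy_sentence = ["the", "cat", "sat"]
toy_output = ["cat", "sat"]
toy_attention = [[0.1, 0.7, 0.2],
                 [0.2, 0.2, 0.6]]
plot_attention(toy_sentence, toy_output, toy_attention, path="toy_example")
# writes toy_example.pdf with input tokens on the x-axis and output tokens on the y-axis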
70 joint_sentence_compression_model/utils/text_attention.py Executable file

@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12


## convert the text/attention list to latex code, which will further generate the text heatmap based on attention weights.
import numpy as np

latex_special_token = ["!@#$%^&*()"]

def generate(text_list, attention_list, latex_file, color='red', rescale_value = False):
    assert(len(text_list) == len(attention_list))
    if rescale_value:
        attention_list = rescale(attention_list)
    word_num = len(text_list)
    text_list = clean_word(text_list)
    with open(latex_file,'w') as f:
        f.write(r'''\documentclass[varwidth]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}'''+'\n')
        string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{'''+"\n"
        for idx in range(word_num):
            string += "\\colorbox{%s!%s}{"%(color, attention_list[idx])+"\\strut " + text_list[idx]+"} "
        string += "\n}}}"
        f.write(string+'\n')
        f.write(r'''\end{CJK*}
\end{document}''')

def rescale(input_list):
    the_array = np.asarray(input_list)
    the_max = np.max(the_array)
    the_min = np.min(the_array)
    rescale = (the_array - the_min)/(the_max-the_min)*100
    return rescale.tolist()


def clean_word(word_list):
    new_word_list = []
    for word in word_list:
        for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
            if latex_sensitive in word:
                word = word.replace(latex_sensitive, '\\'+latex_sensitive)
        new_word_list.append(word)
    return new_word_list


if __name__ == '__main__':
    ## This is a demo:

    sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".
North Korea and the US have ratcheted up tensions in recent weeks and the movement of the strike group had raised the question of a pre-emptive strike by the US.
On Wednesday, Mr Pence described the country as the "most dangerous and urgent threat to peace and security" in the Asia-Pacific.'''
    sent = '''我 回忆 起 我 曾经 在 大学 年代 , 我们 经常 喜欢 玩 “ Hawaii guitar ” 。 说起 Guitar , 我 想起 了 西游记 里 的 琵琶精 。
今年 下半年 , 中 美 合拍 的 西游记 即将 正式 开机 , 我 继续 扮演 美猴王 孙悟空 , 我 会 用 美猴王 艺术 形象 努力 创造 一 个 正能量 的 形象 , 文 体 两 开花 , 弘扬 中华 文化 , 希望 大家 能 多多 关注 。'''
    words = sent.split()
    word_num = len(words)
    attention = [(x+1.)/word_num*100 for x in range(word_num)]
    import random
    random.seed(42)
    random.shuffle(attention)
    color = 'red'
    generate(words, attention, "sample.tex", color)
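Beyond the built-in demo, generate() can be called directly; a small sketch assuming the attention scores are already on the 0-100 scale used in \colorbox{color!value} (pass rescale_value=True to map arbitrary scores onto that range):

from text_attention import generate

words = ["sentence", "compression", "keeps", "the", "gist"]
scores = [80, 95, 60, 10, 70]  # hypothetical saliency percentages
generate(words, scores, "demo.tex", color="red")
# compiling demo.tex (e.g. with pdflatex) yields a heatmap over the tokens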