Add NLP task models

This commit is contained in:
Ekta Sood 2020-12-08 21:10:52 +01:00
parent d8beb17dfb
commit 69f6de0ace
46 changed files with 4976 additions and 0 deletions

View file

@ -0,0 +1,43 @@
import os
import sys
import click
def read(path):
with open(path) as h:
for line in h:
line = line.strip()
try:
b, s, p, a, f = line.split("\t")
except:
print(f"skipping line {line}", file=sys.stderr)
continue
else:
yield b, s, p, a, f
@click.command()
@click.argument("path")
def main(path):
data = list(read(path))
avg_len = sum(len(x[1]) for x in data)/len(data)
filtered_data = []
filtered_data2 = []
fname, ext = os.path.splitext(path)
ext = f".{ext}" if ext else ext
with open(f"{fname}_long{ext}", "w") as lh, open(f"{fname}_short{ext}", "w") as sh:
for x in data:
if len(x[1]) > avg_len:
lh.write("\t".join(x))
lh.write("\n")
else:
sh.write("\t".join(x))
sh.write("\n")
print(f"avg sentence length {avg_len}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,40 @@
import sys
import click
def read(path):
with open(path) as h:
for line in h:
line = line.strip()
try:
b, s, p, *_ = line.split("\t")
except:
print(f"skipping line {line}", file=sys.stderr)
continue
else:
yield float(b), s, p
@click.command()
@click.argument("path")
def main(path):
data = list(read(path))
avg_len = sum(len(x[1]) for x in data)/len(data)
filtered_data = []
filtered_data2 = []
for x in data:
if len(x[1]) > avg_len:
filtered_data.append(x)
else:
filtered_data2.append(x)
print(f"avg sentence length {avg_len}")
print(f"long sentences {len(filtered_data)}")
print(f"short sentences {len(filtered_data2)}")
print(f"total bleu {sum(x[0] for x in data)/len(data)}")
print(f"longest bleu {sum(x[0] for x in filtered_data)/len(filtered_data)}")
print(f"shortest bleu {sum(x[0] for x in filtered_data2)/len(filtered_data2)}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,64 @@
import ast
import os
import pathlib
import text_attention
import click
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import matplotlib.ticker as ticker
import numpy as np
import tqdm
def plot_attention(input_sentence, output_words, attentions, path):
# Set up figure with colorbar
attentions = np.array(attentions)[:,:len(input_sentence)]
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(attentions, cmap="bone")
fig.colorbar(cax)
# Set up axes
ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
ax.set_yticklabels([""] + output_words)
# Show label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.savefig(f"{path}.pdf")
plt.close()
def parse(p):
with open(p) as h:
for line in h:
if not line or line.startswith("#"):
continue
_sentence, _prediction, _attention, _fixations = line.strip().split("\t")
try:
sentence = ast.literal_eval(_sentence)
prediction = ast.literal_eval(_prediction)
attention = ast.literal_eval(_attention)
except:
continue
yield sentence, prediction, attention
@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
for p in tqdm.tqdm(path):
out_dir = os.path.splitext(p)[0]
if out_dir == path:
out_dir = f"{out_dir}_"
pathlib.Path(out_dir).mkdir(exist_ok=True)
for i, spa in enumerate(parse(p)):
plot_attention(*spa, path=os.path.join(out_dir, str(i)))
if __name__ == "__main__":
main()

View file

@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12
## convert the text/attention list to latex code, which will further generates the text heatmap based on attention weights.
import numpy as np
latex_special_token = ["!@#$%^&*()"]
def generate(text_list, attention_list, latex_file, color='red', rescale_value = False):
assert(len(text_list) == len(attention_list))
if rescale_value:
attention_list = rescale(attention_list)
word_num = len(text_list)
text_list = clean_word(text_list)
with open(latex_file,'w') as f:
f.write(r'''\documentclass[varwidth]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}'''+'\n')
string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{'''+"\n"
for idx in range(word_num):
string += "\\colorbox{%s!%s}{"%(color, attention_list[idx])+"\\strut " + text_list[idx]+"} "
string += "\n}}}"
f.write(string+'\n')
f.write(r'''\end{CJK*}
\end{document}''')
def rescale(input_list):
the_array = np.asarray(input_list)
the_max = np.max(the_array)
the_min = np.min(the_array)
rescale = (the_array - the_min)/(the_max-the_min)*100
return rescale.tolist()
def clean_word(word_list):
new_word_list = []
for word in word_list:
for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
if latex_sensitive in word:
word = word.replace(latex_sensitive, '\\'+latex_sensitive)
new_word_list.append(word)
return new_word_list
if __name__ == '__main__':
## This is a demo:
sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".
North Korea and the US have ratcheted up tensions in recent weeks and the movement of the strike group had raised the question of a pre-emptive strike by the US.
On Wednesday, Mr Pence described the country as the "most dangerous and urgent threat to peace and security" in the Asia-Pacific.'''
sent = '''我 回忆 起 我 曾经 在 大学 年代 我们 经常 喜欢 玩 “ Hawaii guitar ” 。 说起 Guitar 我 想起 了 西游记 里 的 琵琶精 。
今年 下半年 合拍 西游记 即将 正式 开机 继续 扮演 美猴王 孙悟空 美猴王 艺术 形象 努力 创造 正能量 形象 开花 弘扬 中华 文化 希望 大家 多多 关注 '''
words = sent.split()
word_num = len(words)
attention = [(x+1.)/word_num*100 for x in range(word_num)]
import random
random.seed(42)
random.shuffle(attention)
color = 'red'
generate(words, attention, "sample.tex", color)