Add NLP task models
This commit is contained in:
parent d8beb17dfb
commit 69f6de0ace
46 changed files with 4976 additions and 0 deletions
43 joint_paraphrase_model/utils/long_sentence_split.py Normal file
@@ -0,0 +1,43 @@
import os
import sys

import click


def read(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                b, s, p, a, f = line.split("\t")
            except ValueError:
                print(f"skipping line {line}", file=sys.stderr)
                continue
            else:
                yield b, s, p, a, f


@click.command()
@click.argument("path")
def main(path):
    data = list(read(path))
    # sentence length is measured in characters, not tokens
    avg_len = sum(len(x[1]) for x in data) / len(data)

    # os.path.splitext already returns the extension with its leading dot
    fname, ext = os.path.splitext(path)
    with open(f"{fname}_long{ext}", "w") as lh, open(f"{fname}_short{ext}", "w") as sh:
        for x in data:
            if len(x[1]) > avg_len:
                lh.write("\t".join(x))
                lh.write("\n")
            else:
                sh.write("\t".join(x))
                sh.write("\n")

    print(f"avg sentence length {avg_len}")


if __name__ == "__main__":
    main()
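For reference, a minimal sketch of exercising the splitter with click's test runner; the five-column schema (bleu, sentence, paraphrase, attention, fixations) is an assumption read off the unpacking in read(), and scores.tsv is a made-up file name:

# Sketch: drive long_sentence_split.py via click's CliRunner.
# Column meanings are assumed, not documented in this commit.
from click.testing import CliRunner
from long_sentence_split import main

rows = [
    ("0.5", "a short sentence", "p", "a", "f"),
    ("0.7", "a considerably longer sentence than the average one", "p", "a", "f"),
]
with open("scores.tsv", "w") as h:
    for row in rows:
        h.write("\t".join(row) + "\n")

result = CliRunner().invoke(main, ["scores.tsv"])
print(result.output)  # "avg sentence length ..."
# scores_long.tsv and scores_short.tsv now hold the above- and below-average rows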
40 joint_paraphrase_model/utils/long_sentence_stats.py Normal file
@@ -0,0 +1,40 @@
import sys

import click


def read(path):
    with open(path) as h:
        for line in h:
            line = line.strip()
            try:
                b, s, p, *_ = line.split("\t")
            except ValueError:
                print(f"skipping line {line}", file=sys.stderr)
                continue
            else:
                yield float(b), s, p


@click.command()
@click.argument("path")
def main(path):
    data = list(read(path))
    # sentence length is measured in characters, not tokens
    avg_len = sum(len(x[1]) for x in data) / len(data)
    long_sentences = []
    short_sentences = []
    for x in data:
        if len(x[1]) > avg_len:
            long_sentences.append(x)
        else:
            short_sentences.append(x)
    print(f"avg sentence length {avg_len}")
    print(f"long sentences {len(long_sentences)}")
    print(f"short sentences {len(short_sentences)}")
    print(f"total bleu {sum(x[0] for x in data)/len(data)}")
    print(f"long bleu {sum(x[0] for x in long_sentences)/len(long_sentences)}")
    print(f"short bleu {sum(x[0] for x in short_sentences)/len(short_sentences)}")


if __name__ == "__main__":
    main()
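A companion sketch for the stats script; the first three columns (bleu, sentence, prediction) are again inferred from the unpacking in read(), not documented in this commit:

# Sketch: run long_sentence_stats.py on a small assumed TSV.
from click.testing import CliRunner
from long_sentence_stats import main

with open("scores.tsv", "w") as h:
    h.write("0.42\tshort one\tpred\n")
    h.write("0.81\ta much longer reference sentence than the first\tpred\n")

print(CliRunner().invoke(main, ["scores.tsv"]).output)
# prints avg sentence length, long/short counts, and the three bleu averages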
64 joint_paraphrase_model/utils/plot_attention.py Normal file
@@ -0,0 +1,64 @@
import ast
import os
import pathlib

import click
import matplotlib.pyplot as plt
plt.switch_backend("agg")  # render to files; no display needed
import matplotlib.ticker as ticker
import numpy as np
import tqdm


def plot_attention(input_sentence, output_words, attentions, path):
    # Set up figure with colorbar
    attentions = np.array(attentions)[:, :len(input_sentence)]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions, cmap="bone")
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([""] + input_sentence + ["<__EOS__>"], rotation=90)
    ax.set_yticklabels([""] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.savefig(f"{path}.pdf")
    plt.close()


def parse(p):
    with open(p) as h:
        for line in h:
            if not line.strip() or line.startswith("#"):
                continue
            try:
                _sentence, _prediction, _attention, _fixations = line.strip().split("\t")
                sentence = ast.literal_eval(_sentence)
                prediction = ast.literal_eval(_prediction)
                attention = ast.literal_eval(_attention)
            except (ValueError, SyntaxError):
                continue

            yield sentence, prediction, attention


@click.command()
@click.argument("path", nargs=-1, required=True)
def main(path):
    for p in tqdm.tqdm(path):
        out_dir = os.path.splitext(p)[0]
        if out_dir == p:  # input had no extension; avoid colliding with the input file name
            out_dir = f"{out_dir}_"
        pathlib.Path(out_dir).mkdir(exist_ok=True)
        for i, spa in enumerate(parse(p)):
            plot_attention(*spa, path=os.path.join(out_dir, str(i)))


if __name__ == "__main__":
    main()
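A minimal sketch of the log format plot_attention.py appears to expect: four tab-separated Python literals per line, as inferred from parse() (the fixations column is read but unused); run.log is a made-up name:

# Sketch: write one parseable log line, then render it to run/0.pdf.
from click.testing import CliRunner
from plot_attention import main

line = "\t".join([
    "['a', 'cat']",               # input sentence tokens
    "['une', 'chatte']",          # predicted output tokens
    "[[0.9, 0.1], [0.2, 0.8]]",   # attention: one row per output word
    "[]",                         # fixations, ignored by the plotter
])
with open("run.log", "w") as h:
    h.write(line + "\n")

CliRunner().invoke(main, ["run.log"])  # writes run/0.pdf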
70 joint_paraphrase_model/utils/text_attention.py Executable file
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12


## convert the text/attention list to latex code, which further generates the text heatmap based on attention weights.
import numpy as np

latex_special_token = ["!@#$%^&*()"]


def generate(text_list, attention_list, latex_file, color='red', rescale_value=False):
    assert len(text_list) == len(attention_list)
    if rescale_value:
        attention_list = rescale(attention_list)
    word_num = len(text_list)
    text_list = clean_word(text_list)
    with open(latex_file, 'w') as f:
        f.write(r'''\documentclass[varwidth]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}''' + '\n')
        string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{''' + "\n"
        for idx in range(word_num):
            string += "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + text_list[idx] + "} "
        string += "\n}}}"
        f.write(string + '\n')
        f.write(r'''\end{CJK*}
\end{document}''')


def rescale(input_list):
    the_array = np.asarray(input_list)
    the_max = np.max(the_array)
    the_min = np.min(the_array)
    rescaled = (the_array - the_min) / (the_max - the_min) * 100
    return rescaled.tolist()


def clean_word(word_list):
    new_word_list = []
    for word in word_list:
        for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
            if latex_sensitive in word:
                word = word.replace(latex_sensitive, '\\' + latex_sensitive)
        new_word_list.append(word)
    return new_word_list


if __name__ == '__main__':
    ## This is a demo:
    sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".
North Korea and the US have ratcheted up tensions in recent weeks and the movement of the strike group had raised the question of a pre-emptive strike by the US.
On Wednesday, Mr Pence described the country as the "most dangerous and urgent threat to peace and security" in the Asia-Pacific.'''
    # the CJK sample below replaces the English one; comment it out to render the English demo
    sent = '''我 回忆 起 我 曾经 在 大学 年代 , 我们 经常 喜欢 玩 “ Hawaii guitar ” 。 说起 Guitar , 我 想起 了 西游记 里 的 琵琶精 。
今年 下半年 , 中 美 合拍 的 西游记 即将 正式 开机 , 我 继续 扮演 美猴王 孙悟空 , 我 会 用 美猴王 艺术 形象 努力 创造 一 个 正能量 的 形象 , 文 体 两 开花 , 弘扬 中华 文化 , 希望 大家 能 多多 关注 。'''
    words = sent.split()
    word_num = len(words)
    attention = [(x + 1.) / word_num * 100 for x in range(word_num)]
    import random
    random.seed(42)
    random.shuffle(attention)
    color = 'red'
    generate(words, attention, "sample.tex", color)
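Beyond the built-in demo, a minimal sketch of calling generate directly on English tokens; the xelatex compile step is an assumption about the reader's toolchain, not something this commit ships:

# Sketch: render a word-level heatmap for a handful of tokens.
# Weights must land in [0, 100] because they feed \colorbox{red!<weight>};
# pass rescale_value=True to let rescale() normalise arbitrary scores.
from text_attention import generate

words = "attention makes the heatmap glow".split()
weights = [10, 95, 25, 70, 40]
generate(words, weights, "demo.tex", color="red")
# compile with, e.g., xelatex demo.tex to get the PDF heatmap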