import torch
from torch import optim
import tqdm

from const import Phase
from batch import create_dataset
from models import Baseline
from sklearn.metrics import classification_report


def run(dataset_train,
        dataset_dev,
        dataset_test,
        model_type,
        word_embed_size,
        hidden_size,
        batch_size,
        device,
        n_epochs):

    if model_type == 'base':
        model = Baseline(vocab=dataset_train.vocab,
                         word_embed_size=word_embed_size,
                         hidden_size=hidden_size,
                         device=device,
                         inference=False)
    else:
        raise NotImplementedError
    model = model.to(device)

    optim_params = model.parameters()
    optimizer = optim.Adam(optim_params, lr=10**-3)

    print('start training')
    for epoch in range(n_epochs):
        train_loss, tokens, preds, golds = train(dataset_train,
                                                 model,
                                                 optimizer,
                                                 batch_size,
                                                 epoch,
                                                 Phase.TRAIN,
                                                 device)

        dev_loss, tokens, preds, golds = train(dataset_dev,
                                               model,
                                               optimizer,
                                               batch_size,
                                               epoch,
                                               Phase.DEV,
                                               device)
        logger = '\t'.join(['epoch {}'.format(epoch+1),
                            'TRAIN Loss: {:.9f}'.format(train_loss),
                            'DEV Loss: {:.9f}'.format(dev_loss)])
        # print('\r'+logger, end='')
        print(logger)

    # evaluate on the test set once training has finished
    test_loss, tokens, preds, golds = train(dataset_test,
                                            model,
                                            optimizer,
                                            batch_size,
                                            epoch,
                                            Phase.TEST,
                                            device)
    print('====', 'TEST', '=====')
    print_scores(preds, golds)
    output_results(tokens, preds, golds)


def train(dataset,
          model,
          optimizer,
          batch_size,
          n_epoch,
          phase,
          device):

    total_loss = 0.0
    tokens = []
    preds = []
    labels = []
    if phase == Phase.TRAIN:
        model.train()
    else:
        model.eval()

    for batch in tqdm.tqdm(dataset.batch_iter):
        token = getattr(batch, 'token')
        label = getattr(batch, 'label')
        raw_sentences = dataset.get_raw_sentence(token.data.detach().cpu().numpy())

        loss, pred = \
            model(token, raw_sentences, label, phase)

        if phase == Phase.TRAIN:
            optimizer.zero_grad()
            loss.backward()
            # clip after backward() so the freshly computed gradients are clipped
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()

        # remove PAD from input sentences/labels and results
        mask = (token != dataset.pad_index)
        length_tensor = mask.sum(1)
        length_tensor = length_tensor.data.detach().cpu().numpy()

        for index, n_tokens_in_the_sentence in enumerate(length_tensor):
            if n_tokens_in_the_sentence > 0:
                tokens.append(raw_sentences[index][:n_tokens_in_the_sentence])
                _label = label[index][:n_tokens_in_the_sentence]
                _pred = pred[index][:n_tokens_in_the_sentence]
                _label = _label.data.detach().cpu().numpy()
                _pred = _pred.data.detach().cpu().numpy()
                labels.append(_label)
                preds.append(_pred)

        total_loss += torch.mean(loss).item()

    return total_loss, tokens, preds, labels


def read_two_cols_data(fname, max_len=None):
    data = {}
    tokens = []
    labels = []
    token = []
    label = []
    with open(fname, mode='r') as f:
        for line in f:
            line = line.strip().lower().split()
            if line:
                try:
                    _token, _label = line
                except ValueError:
                    raise
                token.append(_token)
                if _label == '0' or _label == '1':
                    label.append(int(_label))
                else:
                    if _label == 'del':
                        label.append(1)
                    else:
                        label.append(0)
            else:
                if max_len is None or len(token) <= max_len:
                    tokens.append(token)
                    labels.append(label)
                token = []
                label = []

    data['tokens'] = tokens
    data['labels'] = labels
    return data
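
# Illustrative input for read_two_cols_data(), inferred from the parsing logic
# above rather than copied from the actual corpus files. Each non-blank line is
# "<token> <label>" and a blank line ends a sentence; labels '0'/'1' are used
# directly, 'del' maps to 1 (delete), and any other label maps to 0 (keep).
# Note that, as written, the last sentence is only stored when the file ends
# with a blank line.
#
#     the     0
#     talks   del
#     ended   0
#
#     prices  1
#     fell    0

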
def load(train_path, dev_path, test_path, batch_size, max_len, device):
    train = read_two_cols_data(train_path, max_len)
    dev = read_two_cols_data(dev_path)
    test = read_two_cols_data(test_path)
    data = {Phase.TRAIN: train, Phase.DEV: dev, Phase.TEST: test}
    return create_dataset(data, batch_size=batch_size, device=device)


def print_scores(preds, golds):
    _preds = [label for sublist in preds for label in sublist]
    _golds = [label for sublist in golds for label in sublist]
    target_names = ['not_del', 'del']
    print(classification_report(_golds, _preds, target_names=target_names, digits=5))


def output_results(tokens, preds, golds, path='./result/sentcomp'):
    with open(path+'.original.txt', mode='w') as w, \
            open(path+'.gold.txt', mode='w') as w_gold, \
            open(path+'.pred.txt', mode='w') as w_pred:

        for _tokens, _golds, _preds in zip(tokens, golds, preds):
            for token, gold, pred in zip(_tokens, _golds, _preds):
                w.write(token + ' ')
                if gold == 0:
                    w_gold.write(token + ' ')
                # 0 -> keep, 1 -> delete
                if pred == 0:
                    w_pred.write(token + ' ')
            w.write('\n')
            w_gold.write('\n')
            w_pred.write('\n')
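

# Minimal usage sketch tying load() and run() together. The file paths and
# hyperparameters below are placeholders, and it is assumed here that
# batch.create_dataset() returns the (train, dev, test) dataset objects in that
# order; adjust to the actual interface if it differs.
if __name__ == '__main__':
    _device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _train, _dev, _test = load('data/train.tsv',   # hypothetical paths
                               'data/dev.tsv',
                               'data/test.tsv',
                               batch_size=32,
                               max_len=100,
                               device=_device)
    run(_train, _dev, _test,
        model_type='base',
        word_embed_size=100,
        hidden_size=200,
        batch_size=32,
        device=_device,
        n_epochs=10)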