human-gaze-guided-neural-at.../joint_sentence_compression_model/libs/sentence_compression/main.py
2020-12-08 21:10:52 +01:00

92 lines
3.7 KiB
Python

import torch
import torch.nn as nn
from torch.autograd import Variable
class Network(nn.Module):
    """BiLSTM token tagger with optional additive attention.

    Every input token is classified into one of two classes. When attention
    is enabled, each position attends over all *other* positions of the same
    sequence, and the attention scores can optionally be scaled by an
    externally supplied ``fixations`` tensor (presumably per-token gaze
    features -- confirm against the caller).

    Args:
        embeddings: 2-D tensor of pretrained word vectors (vocab, dim); the
            embedding layer is trainable (``freeze=False``).
        hidden_size: Hidden units per LSTM direction.
        prior: Value used to build ``log([prior, 1 - prior])``, which is
            added to the logits in evaluation mode only.
        device: Device on which the priors and the initial LSTM state are
            created.
        use_attention: Build and use the attention layers. Defaults to
            ``True``; the previous code always used attention.
    """

    def __init__(self,
                 embeddings,
                 hidden_size: int,
                 prior,
                 device: torch.device,
                 use_attention: bool = True):
        super(Network, self).__init__()
        self.device = device
        # Log class priors, added to the logits in eval mode (see forward()).
        self.priors = torch.log(torch.tensor([prior, 1 - prior])).to(device)
        self.hidden_size = hidden_size
        # Explicit switch: testing the bound method ``self.attention`` for
        # truthiness is always True and therefore could never disable it.
        self.use_attention = use_attention
        self.bilstm_layers = 3
        self.bilstm_output_size = 2 * hidden_size  # forward + backward states
        self.word_emb = nn.Embedding.from_pretrained(embeddings, freeze=False)
        # Derive the LSTM input width from the embeddings rather than
        # hard-coding 300, so embeddings of any dimension work.
        self.bilstm_input_size = self.word_emb.embedding_dim
        self.bilstm = nn.LSTM(self.bilstm_input_size,
                              self.hidden_size,
                              num_layers=self.bilstm_layers,
                              batch_first=True,
                              dropout=0.1,  # inter-layer dropout (author note: "ms best mod 0.1")
                              bidirectional=True)
        # Dropout applied to the word embeddings before the LSTM.
        self.dropout = nn.Dropout(p=0.35)
        if self.use_attention:
            # Input of linear_attn is [query state ; context vector].
            self.attention_size = self.bilstm_output_size * 2
            self.u_a = nn.Linear(self.bilstm_output_size, self.bilstm_output_size)
            self.w_a = nn.Linear(self.bilstm_output_size, self.bilstm_output_size)
            self.v_a_inv = nn.Linear(self.bilstm_output_size, 1, bias=False)
            self.linear_attn = nn.Linear(self.attention_size, self.bilstm_output_size)
        self.linear = nn.Linear(self.bilstm_output_size, self.hidden_size)
        self.pred = nn.Linear(self.hidden_size, 2)
        self.softmax = nn.LogSoftmax(dim=1)
        # Labels equal to -1 are excluded from the loss.
        self.criterion = nn.NLLLoss(ignore_index=-1)

    def forward(self, input_tokens, labels, fixations=None):
        """Tag every token of the batch and accumulate the loss.

        Args:
            input_tokens: Integer token ids of shape (batch, seq_len).
            labels: Integer class labels, indexed per position as
                ``labels[:, i]``; entries equal to -1 are ignored.
            fixations: Optional tensor multiplied into the attention scores;
                it must broadcast against (batch, seq_len, 1). Unused when
                attention is disabled.

        Returns:
            ``(loss, preds, atts)``:
            - ``loss``: sum over positions of the NLL loss, each term being
              the mean over the non-ignored labels at that position.
            - ``preds``: (batch, seq_len) tensor of predicted class indices.
            - ``atts``: detached CPU tensor of attention weights with shape
              (batch, seq_len, seq_len, 1), where ``atts[b, i, j]`` is the
              weight query position ``i`` puts on position ``j``; an empty
              list when attention is disabled.
        """
        loss = 0.0
        preds = []
        atts = []
        batch_size, seq_len = input_tokens.size()
        self.init_hidden(batch_size, device=self.device)
        x_i = self.dropout(self.word_emb(input_tokens))
        hidden, (self.h_n, self.c_n) = self.bilstm(x_i, (self.h_n, self.c_n))
        # Row i is True only at column i, so each query position can be
        # excluded from its own attention distribution. An index-based mask
        # avoids comparing float hidden values, which could also match other
        # positions whose first feature happens to be identical (e.g. a
        # saturated unit).
        self_mask = (torch.eye(seq_len, dtype=torch.bool, device=hidden.device)
                     if self.use_attention else None)
        for i in range(seq_len):
            nth_hidden = hidden[:, i, :]
            if self.use_attention:
                # Broadcast the query state to every position:
                # (batch, seq_len, 2 * hidden_size).
                target = nth_hidden.unsqueeze(1).expand(-1, seq_len, -1)
                mask = self_mask[i].view(1, seq_len, 1).expand(batch_size, -1, -1)
                attn_weight = self.attention(hidden, target, fixations, mask)
                # (batch, 1, seq) @ (batch, seq, 2H) -> (batch, 2H)
                context_vector = torch.bmm(attn_weight.transpose(1, 2), hidden).squeeze(1)
                nth_hidden = torch.tanh(
                    self.linear_attn(torch.cat((nth_hidden, context_vector), -1)))
                atts.append(attn_weight.detach().cpu())
            logits = self.pred(self.linear(nth_hidden))
            if not self.training:
                # Prior correction is applied at inference time only.
                logits = logits + self.priors
            output = self.softmax(logits)
            loss += self.criterion(output, labels[:, i])
            _, topi = output.topk(k=1, dim=1)
            preds.append(topi.squeeze(-1))
        # Per-position tensors -> (batch, seq_len[, ...]).
        preds = torch.stack(preds, dim=1)
        if atts:
            atts = torch.stack(atts, dim=1)
        return loss, preds, atts

    def attention(self, source, target, fixations=None, mask=None):
        """Compute additive attention weights over the positions of ``source``.

        Scores are ``v_a_inv(tanh(u_a(source) + w_a(target)))``. They are
        optionally multiplied by ``fixations``. Positions where ``mask`` is
        True are then set to -1e4 so they receive (almost) zero weight.
        The mask is applied last on purpose: scaling after masking would let
        a zero fixation turn -1e4 back into 0 and leak weight onto a masked
        position.

        Returns:
            Softmax of the scores over dim 1 (the position axis).
        """
        scores = self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
        if fixations is not None:
            scores = scores * fixations
        if mask is not None:
            scores = scores.masked_fill(mask, -1e4)
        return nn.functional.softmax(scores, dim=1)

    def init_hidden(self, batch_size, device):
        """Reset the LSTM state to zeros and return ``(h_n, c_n)``.

        Both tensors have shape (2 * bilstm_layers, batch_size, hidden_size),
        i.e. one slice per layer and direction, as ``nn.LSTM`` expects for a
        bidirectional stack.
        """
        shape = (2 * self.bilstm_layers, batch_size, self.hidden_size)
        self.h_n = torch.zeros(shape, device=device)
        self.c_n = torch.zeros(shape, device=device)
        return self.h_n, self.c_n