import torch
import torch.nn as nn


class Network(nn.Module):
    """Token-level binary classifier: a 3-layer BiLSTM over word embeddings,
    optional additive self-attention per token (optionally re-weighted by
    fixation values), and log-prior correction of the logits at inference time."""

    def __init__(self, embeddings, hidden_size: int, prior, device: torch.device,
                 use_attention: bool = True):
        super().__init__()
        self.device = device
        # Log class priors, added to the logits at inference time.
        self.priors = torch.log(torch.tensor([prior, 1 - prior])).to(device)

        self.hidden_size = hidden_size
        self.bilstm_layers = 3
        self.bilstm_input_size = 300
        self.bilstm_output_size = 2 * hidden_size  # bidirectional: forward + backward states
        self.use_attention = use_attention

        self.word_emb = nn.Embedding.from_pretrained(embeddings, freeze=False)
        self.bilstm = nn.LSTM(self.bilstm_input_size,
                              self.hidden_size,
                              num_layers=self.bilstm_layers,
                              batch_first=True,
                              dropout=0.1,  # ms best mod 0.1
                              bidirectional=True)
        self.dropout = nn.Dropout(p=0.35)

        if self.use_attention:
            # Additive (Bahdanau-style) attention over the BiLSTM states.
            self.attention_size = self.bilstm_output_size * 2
            self.u_a = nn.Linear(self.bilstm_output_size, self.bilstm_output_size)
            self.w_a = nn.Linear(self.bilstm_output_size, self.bilstm_output_size)
            self.v_a_inv = nn.Linear(self.bilstm_output_size, 1, bias=False)
            # Projects [token state; context vector] back to the BiLSTM output size.
            self.linear_attn = nn.Linear(self.attention_size, self.bilstm_output_size)

        self.linear = nn.Linear(self.bilstm_output_size, self.hidden_size)
        self.pred = nn.Linear(self.hidden_size, 2)
        self.softmax = nn.LogSoftmax(dim=1)
        self.criterion = nn.NLLLoss(ignore_index=-1)

    def forward(self, input_tokens, labels, fixations=None):
        loss = 0.0
        preds = []
        atts = []
        batch_size, seq_len = input_tokens.size()
        self.init_hidden(batch_size, device=self.device)

        x_i = self.word_emb(input_tokens)
        x_i = self.dropout(x_i)

        # hidden: (batch, seq_len, 2 * hidden_size)
        hidden, (self.h_n, self.c_n) = self.bilstm(x_i, (self.h_n, self.c_n))

        # Classify each token in turn, optionally attending over the whole sequence.
        for i in range(seq_len):
            nth_hidden = hidden[:, i, :]

            if self.use_attention:
                # Broadcast the current token's state over all positions ...
                target = nth_hidden.expand(seq_len, batch_size, -1).transpose(0, 1)
                # ... and mask the position where a state attends to itself.
                mask = hidden.eq(target)[:, :, 0].unsqueeze(2)
                attn_weight = self.attention(hidden, target, fixations, mask)
                context_vector = torch.bmm(attn_weight.transpose(1, 2), hidden).squeeze(1)
                nth_hidden = torch.tanh(
                    self.linear_attn(torch.cat((nth_hidden, context_vector), -1)))
                atts.append(attn_weight.detach().cpu())

            logits = self.pred(self.linear(nth_hidden))
            if not self.training:
                # Bias predictions by the log class priors at inference time.
                logits = logits + self.priors
            output = self.softmax(logits)
            loss += self.criterion(output, labels[:, i])

            _, topi = output.topk(k=1, dim=1)
            pred = topi.squeeze(-1)
            preds.append(pred)

        # Re-arrange per-step predictions into (batch, seq_len).
        preds = torch.stack(torch.cat(preds, dim=0).split(batch_size), dim=1)
        if atts:
            atts = torch.stack(torch.cat(atts, dim=0).split(batch_size), dim=1)

        return loss, preds, atts

    def attention(self, source, target, fixations=None, mask=None):
        # Additive attention score of every source position w.r.t. the target token.
        function_g = self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
        if mask is not None:
            function_g.masked_fill_(mask, -1e4)
        if fixations is not None:
            # Re-weight the attention scores with (eye-tracking) fixation values.
            function_g = function_g * fixations
        return nn.functional.softmax(function_g, dim=1)

    def init_hidden(self, batch_size, device):
        # Zero initial LSTM states: (num_directions * num_layers, batch, hidden_size).
        shape = (2 * self.bilstm_layers, batch_size, self.hidden_size)
        self.h_n = torch.zeros(shape, device=device)
        self.c_n = torch.zeros(shape, device=device)
        return self.h_n, self.c_n
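

# ---------------------------------------------------------------------------
# Minimal usage sketch. The vocabulary size, sequence length, batch size and
# class prior below are illustrative assumptions, not values from the original
# experiments; a real run would load pretrained 300-d embeddings, token-level
# labels and fixation data from the actual pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device("cpu")

    vocab_size, emb_dim = 1000, 300      # emb_dim must match bilstm_input_size
    batch_size, seq_len = 4, 12

    embeddings = torch.randn(vocab_size, emb_dim)    # stand-in for pretrained vectors
    model = Network(embeddings, hidden_size=128, prior=0.5, device=device).to(device)

    input_tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
    labels = torch.randint(0, 2, (batch_size, seq_len))      # one binary label per token
    fixations = torch.rand(batch_size, seq_len, 1)           # optional gaze weights

    # Single training step: forward, backward, parameter update.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    loss, preds, atts = model(input_tokens, labels, fixations)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"loss={loss.item():.4f}  preds={tuple(preds.shape)}  atts={tuple(atts.shape)}")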