126 lines
3.8 KiB
Python
126 lines
3.8 KiB
Python
from collections import OrderedDict
|
|
import logging
|
|
import sys
|
|
|
|
from .self_attention import Transformer
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence, pad_sequence
|
|
|
|
|
|
def random_embedding(vocab_size, embedding_dim):
    """Build a randomly initialized embedding matrix.

    Each entry is drawn uniformly from [-scale, scale] with
    scale = sqrt(3 / embedding_dim), a common variance-preserving
    initialization for embedding tables.

    Args:
        vocab_size: number of rows (vocabulary entries).
        embedding_dim: embedding width per entry.

    Returns:
        np.ndarray of shape (vocab_size, embedding_dim).
    """
    scale = np.sqrt(3.0 / embedding_dim)
    # Single vectorized draw replaces the original per-row Python loop;
    # the distribution of every element is unchanged.
    return np.random.uniform(-scale, scale, (vocab_size, embedding_dim))
|
|
|
|
|
|
def neg_log_likelihood_loss(outputs, batch_label, batch_size, seq_len):
    """Token-level negative log-likelihood loss plus greedy decode.

    Args:
        outputs: raw scores, viewable as (batch_size * seq_len, n_classes).
        batch_label: gold labels, viewable as (batch_size * seq_len,);
            label 0 is treated as padding and ignored.
        batch_size: number of sequences in the batch.
        seq_len: (padded) sequence length.

    Returns:
        (loss, tag_seq): summed NLL averaged over the batch, and the
        argmax tag sequence of shape (batch_size, seq_len).
    """
    outputs = outputs.view(batch_size * seq_len, -1)
    score = F.log_softmax(outputs, 1)

    # reduction="sum" is the modern spelling of the deprecated
    # size_average=False, and matches the reduction used in mse_loss.
    loss = nn.NLLLoss(ignore_index=0, reduction="sum")(
        score, batch_label.view(batch_size * seq_len)
    )
    loss = loss / batch_size
    _, tag_seq = torch.max(score, 1)
    tag_seq = tag_seq.view(batch_size, seq_len)

    return loss, tag_seq
|
|
|
|
|
|
def mse_loss(outputs, batch_label, batch_size, seq_len, word_seq_length):
    """Length-masked mean-squared-error loss over sigmoid scores.

    Args:
        outputs: raw scores, viewable as (batch_size, seq_len).
        batch_label: target values, viewable as (batch_size, seq_len).
        batch_size: number of sequences in the batch.
        seq_len: (padded) sequence length.
        word_seq_length: per-sequence true lengths; positions beyond a
            sequence's length are zeroed before the loss is computed.

    Returns:
        (loss, score): summed MSE averaged over the batch, and the masked
        sigmoid scores of shape (batch_size, seq_len).
    """
    probs = torch.sigmoid(outputs)

    # Zero out every padded position so it contributes nothing extra
    # beyond the (zero-vs-label) residual the original also produced.
    valid = torch.zeros_like(probs)
    for row, length in enumerate(word_seq_length):
        valid[row, :length] = 1
    probs = probs * valid

    criterion = nn.MSELoss(reduction="sum")
    loss = criterion(
        probs.view(batch_size, seq_len), batch_label.view(batch_size, seq_len)
    )

    return loss / batch_size, probs.view(batch_size, seq_len)
|
|
|
|
|
|
class Network(nn.Module):
    """Token scorer: embedding -> BiLSTM -> optional Transformer -> MLP head.

    Produces one scalar score per token, padded back into a
    (batch, max_len, 1) tensor by ``forward``.
    """

    def __init__(
        self,
        embedding_type,
        vocab_size,
        embedding_dim,
        dropout,
        hidden_dim,
        embeddings=None,
        attention=True,
    ):
        """
        Args:
            embedding_type: "w2v"/"glove" for a trainable/pretrained lookup
                table, "bert" when 768-dim features are fed in directly.
            vocab_size: vocabulary size for the fallback embedding table.
            embedding_dim: embedding width for the fallback table.
            dropout: dropout probability used after embedding and LSTM.
            hidden_dim: total BiLSTM output width (split across directions).
            embeddings: optional pretrained weight tensor
                (vocab_size, dim); frozen when given.
            attention: whether to insert a Transformer block after the LSTM.
        """
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.attention = attention
        prelayers = OrderedDict()
        postlayers = OrderedDict()

        if embedding_type in ("w2v", "glove"):
            if embeddings is not None:
                prelayers["embedding_layer"] = nn.Embedding.from_pretrained(embeddings, freeze=True)
                # BUGFIX: derive the LSTM input width from the actual table.
                # The original hard-coded 300 here, which crashed at forward
                # time for any pretrained table of a different width.
                embedding_dim = embeddings.shape[1]
            else:
                prelayers["embedding_layer"] = nn.Embedding(vocab_size, embedding_dim)
                # BUGFIX: keep the caller's embedding_dim. The original
                # overwrote it with 300 after building the table above,
                # mismatching the LSTM input whenever embedding_dim != 300.
            prelayers["embedding_dropout_layer"] = nn.Dropout(dropout)
        elif embedding_type == "bert":
            # BERT base hidden size; inputs are expected pre-embedded.
            embedding_dim = 768

        self.lstm = BiLSTM(embedding_dim, hidden_dim // 2, num_layers=1)
        postlayers["lstm_dropout_layer"] = nn.Dropout(dropout)

        if self.attention:
            postlayers["attention_layer"] = Transformer(
                d_model=hidden_dim, n_heads=4, n_layers=1
            )

        postlayers["ff_layer"] = nn.Linear(hidden_dim, hidden_dim // 2)
        postlayers["ff_activation"] = nn.ReLU()
        postlayers["output_layer"] = nn.Linear(hidden_dim // 2, 1)

        self.logger.info(f"prelayers: {prelayers.keys()}")
        self.logger.info(f"postlayers: {postlayers.keys()}")

        self.pre = nn.Sequential(prelayers)
        self.post = nn.Sequential(postlayers)

    def forward(self, x, word_seq_length):
        """Score each token of each sequence.

        Args:
            x: token ids (w2v/glove) or pre-embedded features (bert),
               batch-first.
            word_seq_length: true length of each sequence in the batch.

        Returns:
            (batch, max_len, 1) tensor of scores, zero-padded per sequence.
        """
        x = self.pre(x)
        # BiLSTM unpads to (seq_len, batch, hidden); transpose below
        # restores batch-first before the per-sequence head.
        x = self.lstm(x, word_seq_length)

        output = []
        for _x, l in zip(x.transpose(1, 0), word_seq_length):
            # Run the head only on the real (unpadded) prefix of each row.
            output.append(self.post(_x[:l].unsqueeze(0))[0])

        return pad_sequence(output, batch_first=True)
|
|
|
|
|
|
class BiLSTM(nn.Module):
    """Bidirectional LSTM over padded batches.

    Accepts batch-first input, packs it by true sequence length, and
    returns the unpacked outputs. NOTE: the output is unpacked WITHOUT
    batch_first, i.e. it is (max_len, batch, 2 * lstm_hidden) — callers
    are expected to transpose if they need batch-first.
    """

    def __init__(self, embedding_dim, lstm_hidden, num_layers):
        """
        Args:
            embedding_dim: input feature width per token.
            lstm_hidden: hidden size per direction (output is doubled).
            num_layers: number of stacked LSTM layers.
        """
        super().__init__()
        self.net = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_hidden,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x, word_seq_length):
        """Run the LSTM over the packed batch and return padded outputs."""
        # Keyword form of the original positional flags:
        # batch_first=True, enforce_sorted=False.
        packed = pack_padded_sequence(
            x, word_seq_length, batch_first=True, enforce_sorted=False
        )
        packed_out, _hidden = self.net(packed)
        # Deliberately unpacked seq-first, matching the original contract.
        padded_out, _lengths = pad_packed_sequence(packed_out)
        return padded_out
|