import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import math


class PositionalEncoding(nn.Module):

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()


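# A possible torch-only version of the sinusoid table above, in the spirit of the
# TODO in _get_sinusoid_encoding_table. This is my own untested sketch, not part of
# the original module, and the helper name `sinusoid_encoding_table_torch` is made up.
def sinusoid_encoding_table_torch(n_position, d_hid):
    position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)  # [n_position, 1]
    two_i = (torch.arange(d_hid) // 2).float()                             # [0, 0, 1, 1, 2, 2, ...]
    # Same angle rates as the numpy version: position / 10000^(2*(j//2)/d_hid)
    angles = position / torch.pow(10000.0, 2 * two_i / d_hid)              # [n_position, d_hid]
    angles[:, 0::2] = torch.sin(angles[:, 0::2])  # dim 2i
    angles[:, 1::2] = torch.cos(angles[:, 1::2])  # dim 2i+1
    return angles.unsqueeze(0)                    # [1, n_position, d_hid]

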
class AttentionLayer(nn.Module):

    def __init__(self):
        super(AttentionLayer, self).__init__()

    def forward(self, Q, K, V):
        # Q: float32[batch_size, n_queries, d_k]
        # K: float32[batch_size, n_keys, d_k]
        # V: float32[batch_size, n_keys, d_v]
        dk = K.shape[-1]
        dv = V.shape[-1]
        KT = torch.transpose(K, -1, -2)
        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        weight_logits = torch.bmm(Q, KT) / math.sqrt(dk)
        # weight_logits: float32[batch_size, n_queries, n_keys]
        weights = F.softmax(weight_logits, dim=-1)
        # weights: float32[batch_size, n_queries, n_keys]
        return torch.bmm(weights, V)
        # returns float32[batch_size, n_queries, d_v]


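# A minimal shape check for AttentionLayer (my own illustration, not part of the
# original file); the sizes 2, 5, 7, 16, 32 below are arbitrary.
# >>> attn = AttentionLayer()
# >>> Q = torch.randn(2, 5, 16)   # [batch_size, n_queries, d_k]
# >>> K = torch.randn(2, 7, 16)   # [batch_size, n_keys, d_k]
# >>> V = torch.randn(2, 7, 32)   # [batch_size, n_keys, d_v]
# >>> attn(Q, K, V).shape         # torch.Size([2, 5, 32])

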
class MultiHeadedSelfAttentionLayer(nn.Module):

    def __init__(self, d_model, n_heads):
        super(MultiHeadedSelfAttentionLayer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        assert d_model % n_heads == 0, 'd_model ({}) must be divisible by n_heads ({})'.format(d_model, n_heads)
        self.d_k = d_model // n_heads
        self.d_v = self.d_k
        self.attention_layer = AttentionLayer()
        # One projection per head for queries, keys and values
        self.W_Qs = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Ks = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Vs = nn.ModuleList([
            nn.Linear(d_model, self.d_v, bias=False)
            for _ in range(n_heads)
        ])
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        head_outputs = []
        for W_Q, W_K, W_V in zip(self.W_Qs, self.W_Ks, self.W_Vs):
            Q = W_Q(x)
            # Q: float32[batch_size, sequence_length, self.d_k]
            K = W_K(x)
            # K: float32[batch_size, sequence_length, self.d_k]
            V = W_V(x)
            # V: float32[batch_size, sequence_length, self.d_v]
            head_output = self.attention_layer(Q, K, V)
            # head_output: float32[batch_size, sequence_length, self.d_v]
            head_outputs.append(head_output)
        concatenated = torch.cat(head_outputs, dim=-1)
        # concatenated: float32[batch_size, sequence_length, self.d_model]
        out = self.W_O(concatenated)
        # out: float32[batch_size, sequence_length, self.d_model]
        return out


class Feedforward(nn.Module):

    def __init__(self, d_model):
        super(Feedforward, self).__init__()
        self.d_model = d_model
        # Position-wise feed-forward network; note the hidden width here equals
        # d_model (no expansion factor).
        self.W1 = nn.Linear(d_model, d_model)
        self.W2 = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, d_model]
        return self.W2(torch.relu(self.W1(x)))


class Transformer(nn.Module):

    def __init__(self, d_model, n_heads, n_layers):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.attention_layers = nn.ModuleList([
            MultiHeadedSelfAttentionLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        self.ffs = nn.ModuleList([
            Feedforward(d_model)
            for _ in range(n_layers)
        ])

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        for attention_layer, ff in zip(self.attention_layers, self.ffs):
            attention_out = attention_layer(x)
            # attention_out: float32[batch_size, sequence_length, self.d_model]
            # Residual connection followed by layer normalization
            x = F.layer_norm(x + attention_out, x.shape[2:])
            ff_out = ff(x)
            # ff_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + ff_out, x.shape[2:])
        return x
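

# A quick smoke test (my own addition, not in the original file). Transformer.forward
# does not apply PositionalEncoding itself, so the example adds it to the input
# explicitly; all sizes below are arbitrary.
if __name__ == '__main__':
    batch_size, seq_len, d_model = 2, 10, 64
    model = Transformer(d_model=d_model, n_heads=8, n_layers=2)
    pos_enc = PositionalEncoding(d_model, n_position=seq_len)
    x = torch.randn(batch_size, seq_len, d_model)
    out = model(pos_enc(x))
    print(out.shape)  # expected: torch.Size([2, 10, 64])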