# human-gaze-guided-neural-at.../joint_paraphrase_model/libs/fixation_generation/self_attention.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()
        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy
        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()
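

# Usage sketch (an illustrative assumption, not part of the original module):
# the encoding table has shape [1, n_position, d_hid] and is broadcast over the
# batch dimension when added to the input embeddings.
def _example_positional_encoding():
    enc = PositionalEncoding(d_hid=16, n_position=200)
    x = torch.zeros(2, 10, 16)   # [batch_size, sequence_length, d_hid]
    out = enc(x)                 # positions added element-wise; shape is unchanged
    assert out.shape == (2, 10, 16)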


class AttentionLayer(nn.Module):
    def __init__(self):
        super(AttentionLayer, self).__init__()

    def forward(self, Q, K, V):
        # Q: float32[batch_size, n_queries, d_k]
        # K: float32[batch_size, n_keys, d_k]
        # V: float32[batch_size, n_keys, d_v]
        dk = K.shape[-1]
        dv = V.shape[-1]
        KT = torch.transpose(K, -1, -2)
        weight_logits = torch.bmm(Q, KT) / math.sqrt(dk)
        # weight_logits: float32[batch_size, n_queries, n_keys]
        weights = F.softmax(weight_logits, dim=-1)
        # weights: float32[batch_size, n_queries, n_keys]
        return torch.bmm(weights, V)
        # return: float32[batch_size, n_queries, d_v]
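

# Shape-check sketch (illustrative, not from the original repo): the layer above
# computes scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V. Recent
# PyTorch versions also provide F.scaled_dot_product_attention as a fused equivalent.
def _example_attention():
    attn = AttentionLayer()
    Q = torch.randn(2, 5, 8)   # [batch_size, n_queries, d_k]
    K = torch.randn(2, 7, 8)   # [batch_size, n_keys,    d_k]
    V = torch.randn(2, 7, 4)   # [batch_size, n_keys,    d_v]
    out = attn(Q, K, V)        # [batch_size, n_queries, d_v]
    assert out.shape == (2, 5, 4)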


class MultiHeadedSelfAttentionLayer(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadedSelfAttentionLayer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        print('{} {}'.format(d_model, n_heads))
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.d_v = self.d_k
        self.attention_layer = AttentionLayer()
        self.W_Qs = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Ks = nn.ModuleList([
            nn.Linear(d_model, self.d_k, bias=False)
            for _ in range(n_heads)
        ])
        self.W_Vs = nn.ModuleList([
            nn.Linear(d_model, self.d_v, bias=False)
            for _ in range(n_heads)
        ])
        self.W_O = nn.Linear(d_model, d_model, bias=False)
    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        head_outputs = []
        for W_Q, W_K, W_V in zip(self.W_Qs, self.W_Ks, self.W_Vs):
            Q = W_Q(x)
            # Q: float32[batch_size, sequence_length, self.d_k]
            K = W_K(x)
            # K: float32[batch_size, sequence_length, self.d_k]
            V = W_V(x)
            # V: float32[batch_size, sequence_length, self.d_v]
            head_output = self.attention_layer(Q, K, V)
            # head_output: float32[batch_size, sequence_length, self.d_v]
            head_outputs.append(head_output)
        concatenated = torch.cat(head_outputs, dim=-1)
        # concatenated: float32[batch_size, sequence_length, self.d_model]
        out = self.W_O(concatenated)
        # out: float32[batch_size, sequence_length, self.d_model]
        return out
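

# Usage sketch (illustrative assumption): each head projects d_model down to
# d_k = d_model / n_heads, so concatenating the head outputs restores d_model
# before the final W_O projection.
def _example_multi_head_self_attention():
    layer = MultiHeadedSelfAttentionLayer(d_model=16, n_heads=4)
    x = torch.randn(2, 10, 16)   # [batch_size, sequence_length, d_model]
    out = layer(x)               # [batch_size, sequence_length, d_model]
    assert out.shape == x.shape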


class Feedforward(nn.Module):
    def __init__(self, d_model):
        super(Feedforward, self).__init__()
        self.d_model = d_model
        self.W1 = nn.Linear(d_model, d_model)
        self.W2 = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: float32[batch_size, sequence_length, d_model]
        return self.W2(torch.relu(self.W1(x)))
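

# Note: this position-wise feed-forward block keeps its hidden width at d_model,
# whereas the standard Transformer feed-forward sublayer expands to a larger
# inner dimension (d_ff, typically 4 * d_model) before projecting back down.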


class Transformer(nn.Module):
    def __init__(self, d_model, n_heads, n_layers):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.attention_layers = nn.ModuleList([
            MultiHeadedSelfAttentionLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        self.ffs = nn.ModuleList([
            Feedforward(d_model)
            for _ in range(n_layers)
        ])

    def forward(self, x):
        # x: float32[batch_size, sequence_length, self.d_model]
        for attention_layer, ff in zip(self.attention_layers, self.ffs):
            attention_out = attention_layer(x)
            # attention_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + attention_out, x.shape[2:])
            ff_out = ff(x)
            # ff_out: float32[batch_size, sequence_length, self.d_model]
            x = F.layer_norm(x + ff_out, x.shape[2:])
        return x
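

# Minimal smoke test (a sketch under assumed hyperparameters, not part of the
# original training pipeline): combine the sinusoidal positional encoding with
# the self-attention encoder stack and check the output shape.
if __name__ == "__main__":
    d_model, n_heads, n_layers = 16, 4, 2
    pos_enc = PositionalEncoding(d_hid=d_model, n_position=200)
    encoder = Transformer(d_model=d_model, n_heads=n_heads, n_layers=n_layers)
    x = torch.randn(2, 10, d_model)   # [batch_size, sequence_length, d_model]
    out = encoder(pos_enc(x))         # [batch_size, sequence_length, d_model]
    print(out.shape)                  # torch.Size([2, 10, 16])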