# Adapted from the PyTorch reference implementation:
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer

from typing import Optional, Union, Callable

import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn import Dropout, LayerNorm, Linear, Module, MultiheadAttention


class CustomTransformerEncoderLayer(Module):
    # Attributes TorchScript treats as compile-time constants when scripting.
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Resolve a string activation name to its functional form and store it.
        # The original snippet accepted `activation` but never assigned it, so
        # _ff_block's use of self.activation would raise an AttributeError.
        if isinstance(activation, str):
            if activation == "relu":
                activation = F.relu
            elif activation == "gelu":
                activation = F.gelu
            else:
                raise ValueError(f"activation should be 'relu' or 'gelu', not {activation!r}")
        self.activation = activation

    def __setstate__(self, state):
        # Checkpoints pickled before `activation` was stored fall back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in the Transformer class.
        """

        # See Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf: norm_first=True
        # is the pre-LN variant (LayerNorm before each sublayer), norm_first=False
        # the original post-LN Transformer (LayerNorm after the residual add).
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        # The attention-weight matrix is discarded here ([0] keeps only the
        # attention output), so need_weights=False skips computing it.
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
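

# A minimal smoke-test sketch (not part of the original snippet); the sizes
# below are illustrative. With batch_first=False, inputs are (S, N, d_model).
if __name__ == "__main__":
    torch.manual_seed(0)
    S, N, d_model = 10, 4, 512  # sequence length, batch size, model width

    layer = CustomTransformerEncoderLayer(d_model=d_model, nhead=8, norm_first=True)

    src = torch.rand(S, N, d_model)
    # Causal bool mask: True above the diagonal means "may not attend".
    src_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
    # Mark the last two positions of every sequence as padding.
    src_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
    src_key_padding_mask[:, -2:] = True

    out = layer(src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
    print(out.shape)  # torch.Size([10, 4, 512])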