import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Adds the standard sinusoidal positional encoding to a sequence of embeddings."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the (max_len, 1, d_model) table of sine/cosine encodings.
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        # Register as a buffer so it moves with the module (e.g. .to(device)) but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (seq_len, batch_size, embedding_dim)
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        transformer_dropout: float,
        transformer_activation: str,
        num_encoder_layers: int,
        max_input_len: int,
        transformer_norm_input: bool,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_layer = num_encoder_layers
        self.max_input_len = max_input_len

        # Build the stack of Transformer encoder layers with a final LayerNorm.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=transformer_dropout,
            activation=transformer_activation,
        )
        encoder_norm = nn.LayerNorm(d_model)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        # Optionally normalize the inputs before they enter the encoder.
        self.norm_input = None
        if transformer_norm_input:
            self.norm_input = nn.LayerNorm(d_model)

    def forward(self, padded_h_node: torch.Tensor, src_padding_mask: torch.Tensor):
        """
        Args:
            padded_h_node: Tensor of shape (n_b, B, h_d), e.g. (63, 257, 128)
            src_padding_mask: Tensor of shape (B, n_b), e.g. (257, 63)
        """
        if self.norm_input is not None:
            padded_h_node = self.norm_input(padded_h_node)

        transformer_out = self.transformer(
            padded_h_node, src_key_padding_mask=src_padding_mask
        )  # (S, B, h_d)

        return transformer_out, src_padding_mask


if __name__ == '__main__':
    model = TransformerEncoder(
        d_model=12,
        nhead=4,
        dim_feedforward=32,
        transformer_dropout=0.0,
        transformer_activation='gelu',
        num_encoder_layers=4,
        max_input_len=34,
        transformer_norm_input=False,
    )
    print(model.norm_input)
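
    # Illustrative smoke test (an assumed example, not part of the original script):
    # run a forward pass with dummy inputs to show the expected shape conventions.
    # The sequence length, batch size, and masked positions below are arbitrary.
    seq_len, batch_size = 10, 3
    padded_h_node = torch.randn(seq_len, batch_size, 12)                   # (S, B, d_model)
    src_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # (B, S)
    src_padding_mask[:, 7:] = True  # treat the last three positions of every sequence as padding

    out, mask = model(padded_h_node, src_padding_mask)
    print(out.shape)   # torch.Size([10, 3, 12])
    print(mask.shape)  # torch.Size([3, 10])

    # PositionalEncoding is defined above but not used by TransformerEncoder itself;
    # a quick shape check on the same dummy batch (again an assumed example):
    pos_enc = PositionalEncoding(d_model=12, dropout=0.0, max_len=34)
    print(pos_enc(padded_h_node).shape)  # torch.Size([10, 3, 12])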