import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from timm.models.layers import DropPath
import warnings
from torch import Tensor
from typing import Optional, Tuple

from .bert.xbert import BertLayer, BertAttention, BertIntermediate, BertOutput, BertConfig


class MoELayer(nn.Module):
    def __init__(self, config, expert_type):
        super(MoELayer, self).__init__()
        self.config = config
        self.expert_type = expert_type
        self.bert_config = BertConfig.from_pretrained('bert-large-uncased')

        # Shared across all experts
        self.attention = BertAttention(self.bert_config)

        # One FFN (intermediate + output) per expert
        if expert_type == 'modalities':
            # Spatial expert
            self.intermediate_spatial = BertIntermediate(self.bert_config)
            self.output_spatial = BertOutput(self.bert_config)

            # Temporal expert
            self.intermediate_temporal = BertIntermediate(self.bert_config)
            self.output_temporal = BertOutput(self.bert_config)

            # Vis expert
            self.intermediate_vis = BertIntermediate(self.bert_config)
            self.output_vis = BertOutput(self.bert_config)

            # Caption expert
            self.intermediate_caption = BertIntermediate(self.bert_config)
            self.output_caption = BertOutput(self.bert_config)

            if config.stage != 'stage_1':
                # History expert
                self.intermediate_history = BertIntermediate(self.bert_config)
                self.output_history = BertOutput(self.bert_config)

        # Fusion expert
        elif expert_type == 'fusion':
            self.intermediate_fusion = BertIntermediate(self.bert_config)
            self.output_fusion = BertOutput(self.bert_config)
        else:
            raise ValueError(f'Unknown expert_type: {expert_type}')

        self._init_weights()

    def _init_weights(self):
        for _, m in dict(self.named_modules()).items():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
            elif isinstance(m, nn.LayerNorm):
                m.bias.data.zero_()
                m.weight.data.fill_(1.0)
            if isinstance(m, nn.Linear) and m.bias is not None:
                m.bias.data.zero_()

    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.dtype = None
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
        """
        if dtype is None:
            dtype = self.dtype

        if not (attention_mask.dim() == 2 and self.bert_config.is_decoder):
            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
            if device is not None:
                warnings.warn(
                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
                )
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            # if self.config.is_decoder:
            #     extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
            #         input_shape, attention_mask, device
            #     )
            # else:
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask

    def forward(self, hidden_states, special_toks_indices, expert_flag, mask=None, only_text=False, output_attentions=False):
        input_shape = hidden_states.size()[:-1]
        # dtype = mask.dtype
        # device = mask.device
        extended_attention_mask = self.get_extended_attention_mask(mask, input_shape, dtype=torch.float32)

        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            head_mask=None
        )
        attention_output = self_attention_outputs[0]
        # outputs = self_attention_outputs[1:]
        len_init = attention_output.size(1)
        # bs, h_dim = x.size(0), x.size(-1)
        # device = x.device

        if expert_flag == 'modalities':
            if only_text:
                intermediate_output = self.intermediate_caption(attention_output)
                layer_output = self.output_caption(intermediate_output, attention_output)
            else:
                # Split the input first into its different parts/modalities.
                # NOTE: the placeholder keys '<spatial>', '<temporal>', '<caption>', '<history>'
                # and '<end>' below are assumed names for the segment-boundary entries of
                # `special_toks_indices`; substitute the actual special tokens used by the tokenizer.
                unchanged = attention_output[:, :special_toks_indices['<spatial>'], :]

                end_idx_spatial = special_toks_indices.get('<temporal>', special_toks_indices['<caption>'])
                attention_spatial = attention_output[:, special_toks_indices['<spatial>']:end_idx_spatial, :]

                end_idx_caption = special_toks_indices.get('<history>', special_toks_indices['<end>'] + 1)
                attention_caption = attention_output[:, special_toks_indices['<caption>']:end_idx_caption, :]

                attention_temporal, attention_history = None, None
                if '<temporal>' in special_toks_indices:
                    end_idx_temporal = special_toks_indices['<caption>']
                    attention_temporal = attention_output[:, special_toks_indices['<temporal>']:end_idx_temporal, :]

                if '<history>' in special_toks_indices:
                    end_idx_history = special_toks_indices['<end>'] + 1
                    attention_history = attention_output[:, special_toks_indices['<history>']:end_idx_history, :]

                # Expert activation
                # 1- Spatial
                intermediate_spatial = self.intermediate_spatial(attention_spatial)
                output_spatial = self.output_spatial(intermediate_spatial, attention_spatial)
                output_vis = output_spatial

                # 2- Temporal
                if attention_temporal is not None:
                    intermediate_temporal = self.intermediate_temporal(attention_temporal)
                    output_temporal = self.output_temporal(intermediate_temporal, attention_temporal)
                    attention_vis = torch.concat([output_spatial, output_temporal], dim=1)
                    intermediate_vis = self.intermediate_vis(attention_vis)
                    output_vis = self.output_vis(intermediate_vis, attention_vis)

                # 3- Caption
                intermediate_caption = self.intermediate_caption(attention_caption)
                output_caption = self.output_caption(intermediate_caption, attention_caption)

                # 4- History
                if attention_history is not None:
                    intermediate_history = self.intermediate_history(attention_history)
                    output_history = self.output_history(intermediate_history, attention_history)

                output_list = [unchanged, output_vis, output_caption]
                if attention_history is not None:
                    output_list.append(output_history)

                # Concat the features back
                layer_output = torch.concat(output_list, dim=1)
                assert layer_output.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format(
                    layer_output.size(1), len_init
                )

        elif expert_flag == 'fusion':
            intermediate_output = self.intermediate_fusion(attention_output)
            layer_output = self.output_fusion(intermediate_output, attention_output)
        else:
            raise ValueError(f'Unknown expert_flag: {expert_flag}')

        return layer_output


class MoEPooler(nn.Module):
    def __init__(self):
        super(MoEPooler, self).__init__()
        self.bert_config = BertConfig.from_pretrained('bert-large-uncased')
        hidden_size = self.bert_config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self._init_weights()

    def _init_weights(self):
        for _, m in dict(self.named_modules()).items():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
            elif isinstance(m, nn.LayerNorm):
                m.bias.data.zero_()
                m.weight.data.fill_(1.0)
            if isinstance(m, nn.Linear) and m.bias is not None:
                m.bias.data.zero_()

    def forward(self, hidden_states, idx):
        pooled_states = hidden_states[:, idx]
        pooled_output = self.dense(pooled_states)
        pooled_output = self.activation(pooled_output)
        return pooled_output
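

# Minimal usage sketch: drives MoELayer and MoEPooler with dummy tensors through the
# 'fusion' expert path (which needs no special-token indices). The tiny `config` object,
# the chosen shapes, and the pooled index are illustrative assumptions, not taken from
# the original training pipeline; `bert-large-uncased` (hidden size 1024) must be
# downloadable for BertConfig.from_pretrained to succeed.
if __name__ == '__main__':
    from types import SimpleNamespace

    config = SimpleNamespace(stage='stage_1')          # hypothetical minimal config
    layer = MoELayer(config, expert_type='fusion')     # fusion expert: single FFN path
    pooler = MoEPooler()

    bs, seq_len, hidden = 2, 16, layer.bert_config.hidden_size
    hidden_states = torch.randn(bs, seq_len, hidden)
    mask = torch.ones(bs, seq_len, dtype=torch.long)   # attend to every token

    with torch.no_grad():
        out = layer(hidden_states, special_toks_indices={}, expert_flag='fusion', mask=mask)
        pooled = pooler(out, idx=0)                    # pool the representation of the first token

    print(out.shape, pooled.shape)                     # (2, 16, 1024) and (2, 1024)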