initial commit

commit a82bbc593e

129 changed files with 33981 additions and 0 deletions

models/backbones/moes_huggingface.py (normal file, 234 lines)

@@ -0,0 +1,234 @@
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from timm.models.layers import DropPath
import warnings
from torch import Tensor
from typing import Optional, Tuple


from .bert.xbert import BertLayer, BertAttention, BertIntermediate, BertOutput, BertConfig

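# MoELayer: a BERT-style transformer block with a single shared self-attention
# sub-layer and several expert feed-forward networks (BertIntermediate +
# BertOutput pairs). With expert_type='modalities' it keeps separate experts for
# spatial, temporal, visual, caption and (when config.stage != 'stage_1')
# history tokens; with expert_type='fusion' a single fusion expert processes
# the whole sequence.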
class MoELayer(nn.Module):
    def __init__(self, config, expert_type):
        super(MoELayer, self).__init__()
        self.config = config
        self.expert_type = expert_type
        self.bert_config = BertConfig.from_pretrained('bert-large-uncased')

        # Shared across all experts
        self.attention = BertAttention(self.bert_config)

        # One for each expert
        if expert_type == 'modalities':
            # Spatial expert
            self.intermediate_spatial = BertIntermediate(self.bert_config)
            self.output_spatial = BertOutput(self.bert_config)

            # Temporal expert
            self.intermediate_temporal = BertIntermediate(self.bert_config)
            self.output_temporal = BertOutput(self.bert_config)

            # Vis expert
            self.intermediate_vis = BertIntermediate(self.bert_config)
            self.output_vis = BertOutput(self.bert_config)

            # Caption expert
            self.intermediate_caption = BertIntermediate(self.bert_config)
            self.output_caption = BertOutput(self.bert_config)

            if config.stage != 'stage_1':
                # History expert
                self.intermediate_history = BertIntermediate(self.bert_config)
                self.output_history = BertOutput(self.bert_config)

        # Fusion expert
        elif expert_type == 'fusion':
            self.intermediate_fusion = BertIntermediate(self.bert_config)
            self.output_fusion = BertOutput(self.bert_config)
        else:
            raise ValueError(f"Unknown expert_type: {expert_type}")

        self._init_weights()

    def _init_weights(self):
        for _, m in dict(self.named_modules()).items():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
            elif isinstance(m, nn.LayerNorm):
                m.bias.data.zero_()
                m.weight.data.fill_(1.0)
            if isinstance(m, nn.Linear) and m.bias is not None:
                m.bias.data.zero_()

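    # The mask helper below appears to be adapted from HuggingFace Transformers'
    # ModuleUtilsMixin.get_extended_attention_mask. Note that `self.dtype` is only
    # defined on PreTrainedModel subclasses, so callers in this file always pass
    # `dtype` explicitly (forward uses torch.float32).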
    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
        """
        if dtype is None:
            dtype = self.dtype

        if not (attention_mask.dim() == 2 and self.bert_config.is_decoder):
            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
            if device is not None:
                warnings.warn(
                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
                )
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            # if self.config.is_decoder:
            #     extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
            #         input_shape, attention_mask, device
            #     )
            # else:
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask

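    # Routing note (inferred from the slicing in forward): in 'modalities' mode the
    # sequence is expected to be laid out as
    #   [prefix | <vis> spatial | <temporal> temporal | <caption> caption | <history> history ... </s>]
    # with special_toks_indices mapping each special token to its position, so each
    # segment is dispatched to its own expert and the expert outputs are concatenated
    # back in the original order.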
    def forward(self, hidden_states, special_toks_indices, expert_flag, mask=None, only_text=False, output_attentions=False):

        input_shape = hidden_states.size()[:-1]
        # dtype = mask.dtype
        # device = mask.device
        extended_attention_mask = self.get_extended_attention_mask(mask, input_shape, dtype=torch.float32)

        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            head_mask=None
        )
        attention_output = self_attention_outputs[0]
        # outputs = self_attention_outputs[1:]

        len_init = attention_output.size(1)
        # bs, h_dim = x.size(0), x.size(-1)
        # device = x.device

        if expert_flag == 'modalities':
            if only_text:
                intermediate_output = self.intermediate_caption(attention_output)
                layer_output = self.output_caption(intermediate_output, attention_output)
            else:
                # split the input first into different parts/modalities
                unchanged = attention_output[:, :special_toks_indices['<vis>'], :]
                end_idx_spatial = special_toks_indices.get('<temporal>', special_toks_indices['<caption>'])
                attention_spatial = attention_output[:, special_toks_indices['<vis>']:end_idx_spatial, :]

                end_idx_caption = special_toks_indices.get('<history>', special_toks_indices['</s>'] + 1)
                attention_caption = attention_output[:, special_toks_indices['<caption>']:end_idx_caption, :]

                attention_temporal, attention_history = None, None

                if '<temporal>' in special_toks_indices:
                    end_idx_temporal = special_toks_indices['<caption>']
                    attention_temporal = attention_output[:, special_toks_indices['<temporal>']:end_idx_temporal, :]

                if '<history>' in special_toks_indices:
                    end_idx_history = special_toks_indices['</s>'] + 1
                    attention_history = attention_output[:, special_toks_indices['<history>']:end_idx_history, :]

                # Expert activation
                # 1- Spatial
                intermediate_spatial = self.intermediate_spatial(attention_spatial)
                output_spatial = self.output_spatial(intermediate_spatial, attention_spatial)

                output_vis = output_spatial

                # 2- Temporal
                if attention_temporal is not None:
                    intermediate_temporal = self.intermediate_temporal(attention_temporal)
                    output_temporal = self.output_temporal(intermediate_temporal, attention_temporal)

                    attention_vis = torch.concat([output_spatial, output_temporal], dim=1)
                    intermediate_vis = self.intermediate_vis(attention_vis)
                    output_vis = self.output_vis(intermediate_vis, attention_vis)

                # 3- Caption
                intermediate_caption = self.intermediate_caption(attention_caption)
                output_caption = self.output_caption(intermediate_caption, attention_caption)

                # 4- History
                if attention_history is not None:
                    intermediate_history = self.intermediate_history(attention_history)
                    output_history = self.output_history(intermediate_history, attention_history)

                output_list = [unchanged, output_vis, output_caption]

                if attention_history is not None:
                    output_list.append(output_history)

                # Concat the features back
                layer_output = torch.concat(output_list, dim=1)
                assert layer_output.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format(
                    layer_output.size(1), len_init
                )

        elif expert_flag == 'fusion':
            intermediate_output = self.intermediate_fusion(attention_output)
            layer_output = self.output_fusion(intermediate_output, attention_output)
        else:
            raise ValueError(f"Unknown expert_flag: {expert_flag}")

        return layer_output


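# MoEPooler: a standard BERT-style pooler; it selects the hidden state at token
# position `idx` and projects it through a Linear + Tanh head.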
class MoEPooler(nn.Module):
    def __init__(self):
        super(MoEPooler, self).__init__()

        self.bert_config = BertConfig.from_pretrained('bert-large-uncased')
        hidden_size = self.bert_config.hidden_size

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self._init_weights()

    def _init_weights(self):
        for _, m in dict(self.named_modules()).items():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
            elif isinstance(m, nn.LayerNorm):
                m.bias.data.zero_()
                m.weight.data.fill_(1.0)
            if isinstance(m, nn.Linear) and m.bias is not None:
                m.bias.data.zero_()

    def forward(self, hidden_states, idx):
        pooled_states = hidden_states[:, idx]
        pooled_output = self.dense(pooled_states)
        pooled_output = self.activation(pooled_output)
        return pooled_output
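
if __name__ == '__main__':
    # Illustrative smoke test (editorial addition, not part of the original commit).
    # It assumes the custom .bert.xbert BertConfig/BertAttention behave like their
    # HuggingFace counterparts (the same assumptions the classes above already make)
    # and that the 'bert-large-uncased' config can be fetched. Because of the
    # relative import at the top, run it from the repo root with
    # `python -m models.backbones.moes_huggingface`.
    from types import SimpleNamespace

    cfg = SimpleNamespace(stage='stage_1')  # `stage` is the only config field MoELayer reads
    layer = MoELayer(cfg, expert_type='fusion')
    pooler = MoEPooler()

    hidden = torch.randn(2, 16, layer.bert_config.hidden_size)
    mask = torch.ones(2, 16, dtype=torch.long)

    out = layer(hidden, special_toks_indices={}, expert_flag='fusion', mask=mask)
    pooled = pooler(out, idx=0)
    print(out.shape, pooled.shape)  # expected: (2, 16, hidden_size) and (2, hidden_size)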