# V2Dial/models/v2dial.py
import json
import glog as logging
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from .backbones.blip2 import Blip2Base, disabled_train
from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration
from .backbones.moes import MoELayer, Pooler
from .modules.temporal_modelling import SpatialAttention, TemporalAttention
from .common.dist_utils import concat_all_gather, all_gather_with_grad
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from transformers import BitsAndBytesConfig
from transformers.models.bert.modeling_bert import BertConfig, BertEmbeddings
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_kbit_training,
set_peft_model_state_dict,
)
import time
import numpy as np
class V2DialAbstract(Blip2Base):
def __init__(self):
super(V2DialAbstract, self).__init__()
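    # Shift decoder input ids one position to the right for encoder-decoder LMs
    # (T5-style): prepend decoder_start_token_id, drop the last token, and map
    # any -100 label placeholders back to the pad token.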
def shift_right(self, input_ids):
decoder_start_token_id = self.llm.config.decoder_start_token_id
pad_token_id = self.llm.config.pad_token_id
if decoder_start_token_id is None:
raise ValueError(
"self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
"See T5 docs for more information."
)
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
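    # Encode images/videos with the ViT backbone: drop the CLS token, optionally pool
    # groups of 4 patch tokens, project to the expert hidden size, then apply spatial
    # attention (and temporal attention for videos). Token-type embeddings are added
    # unless the text embeddings come from the LLM. Returns the spatial/temporal
    # features with their attention masks (the temporal ones are None for images).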
def encode_vis(self, image, device, is_vid=True):
num_frames = image.size(1)
bs_pre_reshape = image.size(0)
if len(image.shape) > 4:
image = image.view(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224)
# with self.maybe_autocast(): # inherited from Blip2Base
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408)
image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408)
bs, pn, hs = image_embeds.shape
        if self.vit_token_pooling: # concatenate every 4 tokens into one token (200,64,5632)
            image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632)
        vis_embed = self.vit_proj(image_embeds) # project to the expert hidden size (200,64,5632) -> (200,64,d_hidden)
# reshape the video features
vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1))
        # Perform spatial-temporal attention
vis_embed_spatial = self.spatial_att(vis_embed)
vis_feat_len = vis_embed_spatial.size(1)
if not self.config.embed_from_llm:
vis_embed_spatial = vis_embed_spatial + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device))
vis_spatial_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device)
vis_embed_temporal, vis_temporal_mask = None, None
if is_vid:
vis_embed_temporal = self.temporal_att(vis_embed)
if not self.config.embed_from_llm:
vis_embed_temporal = vis_embed_temporal + self.token_type_embedding(torch.ones(bs_pre_reshape, vis_feat_len).long().to(device))
vis_temporal_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device)
return vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask
def tokenize_text(self, text, device, add_bos=False, add_eos=False, max_len=None):
if max_len:
text_tokenized = self.tokenizer(
text,
return_tensors='pt',
padding='max_length',
max_length=max_len,
truncation=True,
add_special_tokens=False,
return_special_tokens_mask=True
).to(device)
else:
text_tokenized = self.tokenizer(
text,
return_tensors='pt',
padding='longest',
add_special_tokens=False,
return_special_tokens_mask=True
).to(device)
text_ids = text_tokenized.input_ids
text_attention_mask = text_tokenized.attention_mask
if add_bos:
bos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.bos_token_id).to(device)
bos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device)
text_ids = torch.cat([bos_ids, text_ids], dim=1)
text_attention_mask = torch.cat([bos_att, text_attention_mask], dim=1)
if add_eos:
eos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.eos_token_id).to(device)
eos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device)
text_ids = torch.cat([text_ids, eos_ids], dim=1)
text_attention_mask = torch.cat([text_attention_mask, eos_att], dim=1)
return text_ids, text_attention_mask
def get_extended_attention_mask(self, attention_mask=None):
if attention_mask.dim() == 2:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
elif attention_mask.dim() == 3:
extended_attention_mask = attention_mask.unsqueeze(1)
else:
raise NotImplementedError
return extended_attention_mask
@staticmethod
def init_weights(module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
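# Pre-training model: couples the ViT backbone with the MoE expert stack and is trained
# with the vision-caption contrastive (vcc), vision-caption matching (vcm),
# spatial-temporal contrastive (stc), spatial-temporal matching (stm) and masked
# language modeling (mlm) objectives (see forward()).
#
# Usage sketch (assuming a config object that provides the fields read below,
# e.g. expert_size, vit_model, num_moe_layers, loss_dict, ...):
#   model = V2DialBase(config)
#   losses = model(vis, cap, neg_vis, media_type='webvid')
#   total_loss = sum(losses.values())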
class V2DialBase(V2DialAbstract):
def __init__(self, config):
super(V2DialBase, self).__init__()
self.config = config
################## 1. Select Tokenizer -- We use BERT tokenizer ##################
bert_config = BertConfig.from_pretrained('bert-{}-uncased'.format(config.expert_size))
tokenizer = AutoTokenizer.from_pretrained('bert-{}-uncased'.format(config.expert_size))
text_embedding = BertEmbeddings(bert_config)
text_embedding.apply(self.init_weights)
token_type_embedding = nn.Embedding(3, bert_config.hidden_size) # Number of modality types (temp/spa/text)
token_type_embedding.apply(self.init_weights)
# Define the masking strategy
mlm_collactor = DataCollatorForLanguageModeling(
tokenizer, mlm=True, mlm_probability=config.masking_prob, return_tensors='pt')
################## 2. Select the backbone ViT ##################
logging.info('[INFO] Loading ViT in progress')
if config.freeze_vit:
# vit_precision = 'fp16' if config.fp16 else 'fp32'
logging.info(f'[INFO] ViT precision: {config.vit_precision}')
visual_encoder, ln_vision = self.init_vision_encoder(
config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, precision=config.vit_precision
)
for name, param in visual_encoder.named_parameters():
param.requires_grad = False
visual_encoder = visual_encoder.eval()
visual_encoder.train = disabled_train
for name, param in ln_vision.named_parameters():
param.requires_grad = False
ln_vision = ln_vision.eval()
ln_vision.train = disabled_train
logging.info('[INFO] ViT frozen')
else:
vit_precision = 'fp32'
visual_encoder, ln_vision = self.init_vision_encoder(
config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, vit_precision=vit_precision
)
logging.info('[INFO] ViT hot')
logging.info('[INFO] ViT successfully loaded')
################## 3. Define the ViT-Expert communication Interface ##################
self.system_prompt = False
self.vit_token_pooling = config.vit_token_pooling
if self.vit_token_pooling:
vit_proj = nn.Linear(
1408*4, bert_config.hidden_size
)
else:
vit_proj = nn.Linear(
1408, bert_config.hidden_size
)
vit_proj.apply(self.init_weights)
spatial_att = SpatialAttention(input_dim=bert_config.hidden_size)
temporal_att = TemporalAttention(input_dim=bert_config.hidden_size)
spatial_att.apply(self.init_weights)
temporal_att.apply(self.init_weights)
################## 4. Define the Expert layers ##################
moe_layers = []
for moe_layer_idx in range(config.num_moe_layers):
if moe_layer_idx < self.config.num_moe_modality_layers:
expert_flag = 'modalities'
else:
expert_flag = 'fusion'
moe_layer = MoELayer(
bert_config.hidden_size,
bert_config.num_attention_heads,
expert_flag,
use_sep_spatial_temp_experts=config.use_sep_spatial_temp_experts
)
moe_layer.apply(self.init_weights)
moe_layers.append(moe_layer)
logging.info(f'[INFO] {moe_layer_idx+1}/{config.num_moe_layers} MoE layers successfully loaded')
moe_layers = nn.ModuleList(moe_layers)
moe_norm = nn.LayerNorm(bert_config.hidden_size)
################## 5. Define the projection layers for contrastive learning ##################
temp_proj = nn.Linear(bert_config.hidden_size, config.joint_dim)
spatial_proj = nn.Linear(bert_config.hidden_size, config.joint_dim)
vision_proj = nn.Linear(bert_config.hidden_size, config.joint_dim)
cap_proj = nn.Linear(bert_config.hidden_size, config.joint_dim)
temp_proj.apply(self.init_weights)
spatial_proj.apply(self.init_weights)
vision_proj.apply(self.init_weights)
cap_proj.apply(self.init_weights)
################## 6. Define the pooler for matching loss ##################
pooler = Pooler(bert_config.hidden_size)
pooler.apply(self.init_weights)
        ################## 7. Attach the matching heads ##################
stm_head = nn.Linear(bert_config.hidden_size, 2)
vcm_head = nn.Linear(bert_config.hidden_size, 2)
lm_head = nn.Linear(bert_config.hidden_size, len(tokenizer))
stm_head.apply(self.init_weights)
vcm_head.apply(self.init_weights)
lm_head.apply(self.init_weights)
temp = nn.Parameter(0.07 * torch.ones([]))
# temp = 0.07
# Attach the components to self
self.tokenizer = tokenizer
self.mlm_collactor = mlm_collactor
self.text_embedding = text_embedding
self.token_type_embedding = token_type_embedding
self.visual_encoder = visual_encoder
self.ln_vision = ln_vision
self.vit_proj = vit_proj
self.moe_layers = moe_layers
self.moe_norm = moe_norm
self.spatial_att = spatial_att
self.temporal_att = temporal_att
self.temp_proj = temp_proj
self.spatial_proj = spatial_proj
self.vision_proj = vision_proj
self.cap_proj = cap_proj
self.pooler = pooler
self.stm_head = stm_head
self.vcm_head = vcm_head
self.lm_head = lm_head
self.temp = temp
@staticmethod
def init_weights(module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def build_query_embeds(self, num_query_tokens, dim_query_tokens):
query_embeds = nn.Parameter(
torch.zeros(1, num_query_tokens, dim_query_tokens)
)
query_embeds.data.normal_(mean=0.0, std=0.02)
return query_embeds
def encode_caption(self, cap):
cap_output = self.cap_expert(
input_ids=cap.input_ids,
attention_mask=cap.attention_mask,
return_dict=True,
)
cap_embeds = cap_output.last_hidden_state
pooled_cap_embeds = cap_embeds[:, 0]
return cap_embeds, pooled_cap_embeds
def encode_queries(self, query_embeds, vis_embeds, vis_mode):
if vis_mode == 'spatial':
expert = self.spatial_expert
layer_norm = self.spatial_layernorm
elif vis_mode == 'temporal':
expert = self.temporal_expert
layer_norm = self.temporal_layernorm
else:
raise ValueError(f'[ERROR] {vis_mode} not implemented!')
attention_mask = torch.ones(
query_embeds.size()[:-1], dtype=torch.long).to(vis_embeds.device)
vis_attention_mask = torch.ones(
vis_embeds.size()[:-1], dtype=torch.long).to(vis_embeds.device)
if self.config['expert_layer_type'] == 'bert':
output_dict = expert(
encoder_embeds=query_embeds,
encoder_hidden_states=vis_embeds,
encoder_attention_mask=vis_attention_mask,
)
query_embeds = layer_norm(output_dict.last_hidden_state)
pooled_query_embeds = output_dict.pooler_output
elif self.config['expert_layer_type'] == 'bart':
output_dict = expert(
inputs_embeds=query_embeds,
attention_mask=attention_mask,
cross_embeds=vis_embeds,
cross_attention_mask=vis_attention_mask,
)
query_embeds = layer_norm(output_dict.last_hidden_state)
pooled_query_embeds = query_embeds[:, 0]
return query_embeds, pooled_query_embeds
def encode_vis_with_seq_spa_temp_att(self, image, device, is_vid=True):
num_frames = image.size(1)
bs_pre_reshape = image.size(0)
if len(image.shape) > 4:
image = image.view(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224)
# with self.maybe_autocast(): # inherited from Blip2Base
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408)
image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408)
bs, pn, hs = image_embeds.shape
        if self.vit_token_pooling: # concatenate every 4 tokens into one token (200,64,5632)
            image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632)
        vis_embed = self.vit_proj(image_embeds) # project to the expert hidden size (200,64,5632) -> (200,64,d_hidden)
# reshape the video features
vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1))
size_orig = vis_embed.size()
        # Perform spatial-temporal attention
vis_embed = self.spatial_att(vis_embed)
if is_vid:
vis_embed = vis_embed.view(size_orig)
vis_embed = self.temporal_att(vis_embed)
vis_feat_len = vis_embed.size(1)
vis_embed = vis_embed + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device))
vis_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device)
return vis_embed, vis_mask
def tokenize_text(self, text, device, add_bos=False, add_eos=False, max_len=None):
if max_len:
text_tokenized = self.tokenizer(
text,
return_tensors='pt',
padding='max_length',
max_length=max_len,
truncation=True,
add_special_tokens=False,
return_special_tokens_mask=True
).to(device)
else:
text_tokenized = self.tokenizer(
text,
return_tensors='pt',
padding='longest',
add_special_tokens=False,
return_special_tokens_mask=True
).to(device)
text_ids = text_tokenized.input_ids
text_attention_mask = text_tokenized.attention_mask
if add_bos:
bos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.bos_token_id).to(device)
bos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device)
text_ids = torch.cat([bos_ids, text_ids], dim=1)
text_attention_mask = torch.cat([bos_att, text_attention_mask], dim=1)
if add_eos:
eos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.eos_token_id).to(device)
eos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device)
text_ids = torch.cat([text_ids, eos_ids], dim=1)
text_attention_mask = torch.cat([text_attention_mask, eos_att], dim=1)
return text_ids, text_attention_mask
def encode_text(self, text, max_len, device):
text_tokenized = self.tokenizer(
text,
return_tensors='pt',
padding='max_length',
max_length=max_len,
truncation=True,
add_special_tokens=False
).to(device)
text_ids = text_tokenized.input_ids
text_embeds = self.embed(text_ids)
text_attention_mask = text_tokenized.attention_mask
return text_embeds, text_ids, text_attention_mask
def construct_global_input(self, cap_ids, cap_attention_mask, vid_feat_len, media_type, device):
        # for video: <s><vis><spatial>[spatial_features]<temporal>[temporal_features]<caption>[caption_features]</s>
        # for image: <s><vis><spatial>[spatial_features]<caption>[caption_features]</s>
batch_size = cap_ids.size(0)
special_toks_indices = {
'<s>': 0,
'<vis>': 1,
'<spatial>': 2,
}
ids = [self.added_vocab['<s>']] + [self.added_vocab['<vis>']] + [self.added_vocab['<spatial>']]
ids += vid_feat_len * [self.added_vocab['<pad>']]
if media_type == 'webvid':
ids += [self.added_vocab['<temporal>']]
special_toks_indices['<temporal>'] = len(ids) - 1
ids += vid_feat_len * [self.added_vocab['<pad>']]
ids += [self.added_vocab['<caption>']]
special_toks_indices['<caption>'] = len(ids) - 1
ids += cap_ids.size(1) * [self.added_vocab['<pad>']]
ids += [self.added_vocab['</s>']]
special_toks_indices['</s>'] = len(ids) - 1
total_len = len(ids)
ids = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
ids[:, special_toks_indices['<caption>'] + 1: special_toks_indices['</s>']] = cap_ids
mask = torch.ones((batch_size, total_len), device=device)
mask[:, special_toks_indices['<caption>'] + 1: special_toks_indices['</s>']] = cap_attention_mask
return ids, mask, special_toks_indices
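    # Symmetric InfoNCE loss: x/y are the local features, x_all/y_all are the features
    # gathered from all GPUs; the positive targets are the diagonal entries offset by
    # rank * batch_size.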
def compute_contrastive_loss(self, x, y_all, y, x_all):
sim_x2y = torch.mm(x, y_all.t()) # (bs, bs*ngpus)
sim_x2y = sim_x2y / self.temp
sim_y2x = torch.mm(y, x_all.t()) # (bs, bs*ngpus)
sim_y2x = sim_y2x / self.temp
rank = dist.get_rank() if self.config['distributed'] else 0
bs = x.size(0)
targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
x.device
)
loss_contrastive = (
F.cross_entropy(sim_x2y, targets, label_smoothing=0.1)
+ F.cross_entropy(sim_y2x, targets, label_smoothing=0.1)
) / 2
return loss_contrastive, sim_x2y, sim_y2x
def get_extended_attention_mask(self, attention_mask=None):
if attention_mask.dim() == 2:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
elif attention_mask.dim() == 3:
extended_attention_mask = attention_mask.unsqueeze(1)
else:
raise NotImplementedError
return extended_attention_mask
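    # Core MoE forward pass: concatenate the [spatial | temporal (videos only) | caption]
    # token streams, run them through the MoE layers (modality experts first, fusion
    # experts afterwards), normalize, and slice the sequence back into per-stream
    # features plus a pooled CLS feature taken from the caption stream.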
def shared_forward(
self,
vis_spatial, vis_spatial_mask, vis_temporal, vis_temporal_mask,
cap_ids, cap_mask, is_vid, device):
# is_vid = media_type == 'webvid'
# batch_size = len(cap)
vis_feat_len = vis_spatial.size(1)
input_embeds = []
input_masks = []
input_embeds.append(vis_spatial)
input_masks.append(vis_spatial_mask)
if is_vid:
input_embeds.append(vis_temporal)
input_masks.append(vis_temporal_mask)
cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2))
cap_feat_len = cap_embeds.size(1)
input_embeds.append(cap_embeds)
input_masks.append(cap_mask)
input_embeds = torch.cat(input_embeds, dim=1)
input_masks = torch.cat(input_masks, dim=1)
# expand the mask
input_masks = self.get_extended_attention_mask(attention_mask=input_masks)
# MoEs feed-forward
for moe_layer_idx, moe_layer in enumerate(self.moe_layers):
if moe_layer_idx < self.config.num_moe_modality_layers:
expert_flag = 'modalities'
else:
expert_flag = 'fusion'
input_embeds = moe_layer(input_embeds, vis_feat_len, cap_feat_len, expert_flag, is_vid=is_vid, mask=input_masks)
        # TODO: normalize the output
input_embeds = self.moe_norm(input_embeds)
# return the features
spatial_feats = input_embeds[:, :vis_feat_len]
temporal_feats = input_embeds[:, vis_feat_len:2*vis_feat_len] if is_vid else None
cap_feats = input_embeds[:, -cap_feat_len:]
cls_feats = self.pooler(cap_feats)
moe_outputs = {
'spatial_feats': spatial_feats,
'temporal_feats': temporal_feats,
'cap_feats': cap_feats,
'cls_feats': cls_feats,
}
return moe_outputs
def shared_forward_no_sep_spatial_temporal_experts(
self,
vis, vis_mask,
cap_ids, cap_mask, is_vid, device):
vis_feat_len = vis.size(1)
input_embeds = []
input_masks = []
input_embeds.append(vis)
input_masks.append(vis_mask)
cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2))
cap_feat_len = cap_embeds.size(1)
input_embeds.append(cap_embeds)
input_masks.append(cap_mask)
input_embeds = torch.cat(input_embeds, dim=1)
input_masks = torch.cat(input_masks, dim=1)
# expand the mask
input_masks = self.get_extended_attention_mask(attention_mask=input_masks)
# MoEs feed-forward
for moe_layer_idx, moe_layer in enumerate(self.moe_layers):
if moe_layer_idx < self.config.num_moe_modality_layers:
expert_flag = 'modalities'
else:
expert_flag = 'fusion'
input_embeds = moe_layer(input_embeds, vis_feat_len, cap_feat_len, expert_flag, is_vid=is_vid, mask=input_masks)
input_embeds = self.moe_norm(input_embeds)
# return the features
vis_feats = input_embeds[:, :vis_feat_len]
cap_feats = input_embeds[:, -cap_feat_len:]
cls_feats = self.pooler(cap_feats)
moe_outputs = {
'vis_feats': vis_feats,
'cap_feats': cap_feats,
'cls_feats': cls_feats,
}
return moe_outputs
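    # Vision-caption matching (VCM): half of the batch keeps its true visual input and
    # the other half is replaced by negative visuals; a binary classifier on the pooled
    # CLS feature predicts whether vision and caption match.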
def vcm_iteration(self, vis, cap, neg_vis, is_vid, device):
# Prepare the vis data
# is_vid = media_type == 'webvid'
num_positive_samples = len(cap) // 2
num_negative_samples = len(cap) - num_positive_samples
vcm_labels = torch.cat([torch.ones(num_positive_samples), torch.zeros(num_negative_samples)]).to(device)
vcm_labels = vcm_labels[torch.randperm(vcm_labels.size(0))].long()
# now get the mixed vis data
vis_mixed = [p if vcm_labels[i] == 1 else n for i, (p, n) in enumerate(zip(vis, neg_vis))]
vis_mixed = torch.stack(vis_mixed, dim=0)
cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len)
if self.config.use_sep_spatial_temp_experts:
vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis_mixed, device, is_vid=is_vid)
moe_outputs = self.shared_forward(
vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device)
else:
vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)
moe_outputs = self.shared_forward_no_sep_spatial_temporal_experts(
vis_embed, vis_mask, cap_ids, cap_mask, is_vid, device)
vcm_logits = self.vcm_head(moe_outputs['cls_feats'])
loss_vcm = F.cross_entropy(vcm_logits, vcm_labels)
return loss_vcm
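    # Spatial-temporal matching (STM): for negative samples either the spatial or the
    # temporal stream (50/50) is swapped with features from a negative video, and a
    # binary classifier predicts whether the two streams belong to the same video.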
def stm_iteration(self, vis, cap, neg_vis, is_vid, device):
num_positive_samples = len(cap) // 2
num_negative_samples = len(cap) - num_positive_samples
stm_labels = torch.cat([torch.ones(num_positive_samples), torch.zeros(num_negative_samples)]).to(device)
stm_labels = stm_labels[torch.randperm(stm_labels.size(0))].long()
vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid)
neg_vis_embed_spatial, _ , neg_vis_embed_temporal, _ = self.encode_vis(neg_vis, device, is_vid=is_vid)
# now get the mixed vis data
vis_embed_spatial_mixed = []
vis_embed_temporal_mixed = []
        for i, (pos_spatial, pos_temporal, neg_spatial, neg_temporal) in enumerate(
                zip(vis_embed_spatial, vis_embed_temporal, neg_vis_embed_spatial, neg_vis_embed_temporal)):
if stm_labels[i] == 1:
vis_embed_spatial_mixed.append(pos_spatial)
vis_embed_temporal_mixed.append(pos_temporal)
else:
# 50% negative spatial / 50% negative temporal
if torch.rand(1).item() < 0.5:
vis_embed_spatial_mixed.append(pos_spatial)
vis_embed_temporal_mixed.append(neg_temporal)
else:
vis_embed_spatial_mixed.append(neg_spatial)
vis_embed_temporal_mixed.append(pos_temporal)
vis_embed_spatial_mixed = torch.stack(vis_embed_spatial_mixed, dim=0)
vis_embed_temporal_mixed = torch.stack(vis_embed_temporal_mixed, dim=0)
cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len)
moe_outputs = self.shared_forward(
vis_embed_spatial_mixed, vis_spatial_mask, vis_embed_temporal_mixed, vis_temporal_mask, cap_ids, cap_mask, is_vid, device)
        stm_logits = self.stm_head(moe_outputs['cls_feats'])  # use the dedicated spatial-temporal matching head
loss_stm = F.cross_entropy(stm_logits, stm_labels)
return loss_stm
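    # Masked language modeling (MLM) on the caption, conditioned on the visual streams;
    # the masking itself is delegated to the HuggingFace DataCollatorForLanguageModeling.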
def mlm_iteration(self, vis, cap, is_vid, device):
if self.config.use_sep_spatial_temp_experts:
vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid)
else:
vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)
cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len)
cap_ids = cap_ids.tolist()
# NOTE We make sure to mask some tokens here to avoid nan loss later
mlm_output = self.mlm_collactor(cap_ids)
cap_ids = mlm_output['input_ids'].to(device)
labels_cap = mlm_output['labels'].to(device)
if self.config.use_sep_spatial_temp_experts:
moe_outputs = self.shared_forward(
vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device)
else:
moe_outputs = self.shared_forward_no_sep_spatial_temporal_experts(
vis_embed, vis_mask, cap_ids, cap_mask, is_vid, device)
mlm_logits = self.lm_head(moe_outputs['cap_feats'])
loss_mlm = F.cross_entropy(mlm_logits.view(-1, mlm_logits.size(-1)), labels_cap.view(-1))
return loss_mlm
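    # Vision-caption contrastive (VCC): token-level visual features vs. the pooled
    # caption feature, with a max over visual tokens and in-batch negatives gathered
    # across GPUs (label-smoothed symmetric cross-entropy).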
def vcc_iteration(self, vis, cap, is_vid, device):
vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid)
cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len)
if self.config.use_sep_spatial_temp_experts:
moe_outputs = self.shared_forward(
vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device)
vis_feats = moe_outputs['spatial_feats']
if is_vid:
vis_feats = torch.cat([moe_outputs['spatial_feats'], moe_outputs['temporal_feats']], dim=1)
else:
vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)
moe_outputs = self.shared_forward_no_sep_spatial_temporal_experts(
vis_embed, vis_mask, cap_ids, cap_mask, is_vid, device)
vis_feats = moe_outputs['vis_feats']
cap_feats = F.normalize(self.cap_proj(moe_outputs['cls_feats']), dim=-1)
vis_feats = F.normalize(self.vision_proj(vis_feats), dim=-1)
vis_feats_all = concat_all_gather(vis_feats)
cap_feats_all = concat_all_gather(cap_feats)
sim_v2c = torch.matmul(
vis_feats.unsqueeze(1), cap_feats_all.unsqueeze(-1)
).squeeze()
sim_v2c, _ = sim_v2c.max(-1)
sim_v2c = sim_v2c / self.temp
sim_c2v = torch.matmul(
cap_feats.unsqueeze(1).unsqueeze(1), vis_feats_all.permute(0, 2, 1)
).squeeze()
sim_c2v, _ = sim_c2v.max(-1)
sim_c2v = sim_c2v / self.temp
rank = dist.get_rank() if self.config['distributed'] else 0
bs = vis_feats.size(0)
targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
device
)
loss_vcc = (
F.cross_entropy(sim_v2c, targets, label_smoothing=0.1)
+ F.cross_entropy(sim_c2v, targets, label_smoothing=0.1)
) / 2
return loss_vcc
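    # Spatial-temporal contrastive (STC): aligns the spatial and temporal streams of the
    # same video with a token-wise max similarity, again using cross-GPU negatives.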
def stc_iteration(self, vis, cap, is_vid, device):
vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid)
cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len)
moe_outputs = self.shared_forward(
vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device)
spatial_feats = F.normalize(self.spatial_proj(moe_outputs['spatial_feats']), dim=-1)
temporal_feats = F.normalize(self.temp_proj(moe_outputs['temporal_feats']), dim=-1)
spatial_feats_all = concat_all_gather(spatial_feats)
temporal_feats_all = concat_all_gather(temporal_feats)
        sim_s2t = torch.matmul(
            spatial_feats.unsqueeze(1), temporal_feats_all.permute(0, 2, 1)
        )  # (bs, 1, L, d) x (bs*ngpus, d, L) -> (bs, bs*ngpus, L, L)
sim_s2t, _ = sim_s2t.max(-1)
sim_s2t, _ = sim_s2t.max(-1)
sim_s2t = sim_s2t / self.temp
        sim_t2s = torch.matmul(
            temporal_feats.unsqueeze(1), spatial_feats_all.permute(0, 2, 1)
        )  # (bs, 1, L, d) x (bs*ngpus, d, L) -> (bs, bs*ngpus, L, L)
sim_t2s, _ = sim_t2s.max(-1)
sim_t2s, _ = sim_t2s.max(-1)
sim_t2s = sim_t2s / self.temp
rank = dist.get_rank() if self.config['distributed'] else 0
bs = vis.size(0)
targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
device
)
loss_stc = (
F.cross_entropy(sim_s2t, targets, label_smoothing=0.1)
+ F.cross_entropy(sim_t2s, targets, label_smoothing=0.1)
) / 2
return loss_stc
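    # Pre-training forward pass: only the objectives with a non-zero weight in
    # config.loss_dict are computed; stm and stc are video-only.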
def forward(self, vis, cap, neg_vis, media_type):
device = vis.device
is_vid = media_type == 'webvid'
loss_stc = torch.tensor(0).to(device)
loss_stm = torch.tensor(0).to(device)
loss_vcc = torch.tensor(0).to(device)
loss_vcm = torch.tensor(0).to(device)
loss_mlm = torch.tensor(0).to(device)
if self.config.loss_dict['vcm'] != 0:
loss_vcm = self.vcm_iteration(vis, cap, neg_vis, is_vid, device)
if self.config.loss_dict['vcc'] != 0:
loss_vcc = self.vcc_iteration(vis, cap, is_vid, device)
if self.config.loss_dict['stm'] != 0 and is_vid:
loss_stm = self.stm_iteration(vis, cap, neg_vis, is_vid, device)
if self.config.loss_dict['stc'] != 0 and is_vid:
loss_stc = self.stc_iteration(vis, cap, is_vid, device)
if self.config.loss_dict['mlm'] != 0:
loss_mlm = self.mlm_iteration(vis, cap, is_vid, device)
return dict(
loss_stc = loss_stc * self.config.loss_dict['stc'],
loss_stm = loss_stm * self.config.loss_dict['stm'],
loss_vcc = loss_vcc * self.config.loss_dict['vcc'],
loss_vcm = loss_vcm * self.config.loss_dict['vcm'],
loss_mlm = loss_mlm * self.config.loss_dict['mlm'],
)
def get_vis_enc_for_eval(self, vis, media_type):
# First get the visual features depending on the media type
vis_spatial_embed, vis_temporal_embed = self.encode_vis(vis, media_type)
# Expand the query tokens
spatial_query_embeds = self.spatial_query_embeds.expand(vis_spatial_embed.size(0), -1, -1)
# Run the spatial expert
spatial_query_embeds, pooled_spatial_query_embeds = self.encode_queries(
spatial_query_embeds, vis_spatial_embed, vis_mode='spatial')
        temporal_query_embeds = self.temporal_query_embeds.expand(vis_temporal_embed.size(0), -1, -1)
temporal_query_embeds, pooled_temporal_query_embeds = self.encode_queries(
temporal_query_embeds, vis_temporal_embed, vis_mode='temporal')
vis_pooled = torch.cat((pooled_spatial_query_embeds, pooled_temporal_query_embeds), dim=1)
vis_embeds = torch.cat((spatial_query_embeds, temporal_query_embeds), dim=1)
return vis_embeds, vis_pooled
def get_expert_encoder(self, expert):
"""get text encoder, used for text and cross-modal encoding"""
encoder = None
if expert == 'cap':
encoder = self.cap_expert
if expert == 'spatial':
encoder = self.spatial_expert
if expert == 'temporal':
encoder = self.temporal_expert
if expert == 'sap_att_grounding':
encoder = self.spa_temp_grounding_expert
if expert == 'vis_cap_grounding':
encoder = self.vis_cap_grounding_expert
assert encoder is not None
return encoder.bert if hasattr(encoder, "bert") else encoder
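# Dialog/generation model: the MoE expert stack is bridged to a (frozen or LoRA-tuned)
# seq2seq LLM (Flan-T5 here) via the llm_to_moe / moe_to_llm linear projections, and the
# answer is generated conditioned on the visual streams, the caption and the dialog history.
#
# Usage sketch (assuming a config object with the fields read below, e.g. llm_family,
# llm_name, use_moes, loss_dict, ...):
#   model = V2Dial(config)
#   losses = model(vis, cap, hist, ans, media_type='avsd')
#   losses['loss_gen'].backward()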
class V2Dial(V2DialAbstract):
def __init__(self, config):
super(V2Dial, self).__init__()
self.config = config
################## 1. Select Tokenizer -- We use BERT tokenizer ##################
bert_config = BertConfig.from_pretrained('bert-{}-uncased'.format(config.expert_size))
tokenizer = AutoTokenizer.from_pretrained('bert-{}-uncased'.format(config.expert_size))
text_embedding = BertEmbeddings(bert_config)
text_embedding.apply(self.init_weights)
        token_type_embedding = nn.Embedding(3, bert_config.hidden_size) # 3 modality type ids: spatial (0), temporal (1), text (2) -- caption and history share id 2
token_type_embedding.apply(self.init_weights)
        ################## 2. Select the LLM ##################
if config.llm_family == 'flan_t5':
logging.info('[INFO] LLM: Flan T5')
llm_model = T5ForConditionalGeneration
else:
raise ValueError
llm_tokenizer = AutoTokenizer.from_pretrained(
config.llm_name,
use_fast=False,
token='your_token'
)
# set the padding token to eos token for llama
if config.llm_family == 'llama':
llm_tokenizer.pad_token = llm_tokenizer.eos_token
#________________________________ LLM Quantization ________________________________#
if config.llm_family in ['mistral', 'llama']:
dtype=None
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type='nf4',
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16
)
else:
if config.fp16:
dtype = torch.float16
if config.llm_family == 'flan_t5':
dtype = torch.bfloat16
else:
dtype = torch.float32
quantization_config = None
# llm_model.generate()
llm = llm_model.from_pretrained(
config.llm_name,
token='your_token',
torch_dtype=dtype,
quantization_config=quantization_config
)
if config.llm_family == 'llama':
llm_embed = llm.model.embed_tokens
elif config.llm_family == 'flan_t5':
llm_embed = llm.shared
elif config.llm_family == 'mistral':
llm_embed = llm.model.embed_tokens
elif config.llm_family == 'bart':
llm_embed = llm.model.shared
else:
raise ValueError
# llm.resize_token_embeddings(len(self.tokenizer))
if quantization_config is not None:
# Gradient checkpointing is not compatible with DDP!!
llm = prepare_model_for_kbit_training(llm, use_gradient_checkpointing=True)
if config.freeze_llm:
for _, param in llm.named_parameters():
param.requires_grad = False
logging.info('[INFO] LLM frozen')
else:
if config.use_lora_llm:
# load the lora config
with open(config.lora_config, 'r') as f:
lora_config = json.load(f)
if config.llm_family in ['flan_t5']:
lora_config['target_modules'] = ['q', 'v']
lora_config = LoraConfig(**lora_config)
llm = get_peft_model(llm, lora_config)
logging.info('[INFO] LLM hot with lora')
else:
logging.info('[INFO] LLM hot')
logging.info('[INFO] LLM successfully loaded')
for _, param in llm_embed.named_parameters():
param.data = param.data.float()
param.requires_grad = True
llm_to_moe = nn.Linear(llm.config.hidden_size, bert_config.hidden_size)
llm_to_moe.apply(self.init_weights)
moe_to_llm = nn.Linear(bert_config.hidden_size, llm.config.hidden_size)
moe_to_llm.apply(self.init_weights)
        ################## 3. Select the backbone ViT ##################
logging.info('[INFO] Loading ViT in progress')
if config.freeze_vit:
# vit_precision = 'fp16' if config.fp16 else 'fp32'
logging.info(f'[INFO] ViT precision: {config.vit_precision}')
visual_encoder, ln_vision = self.init_vision_encoder(
config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, precision=config.vit_precision
)
for name, param in visual_encoder.named_parameters():
param.requires_grad = False
visual_encoder = visual_encoder.eval()
visual_encoder.train = disabled_train
for name, param in ln_vision.named_parameters():
param.requires_grad = False
ln_vision = ln_vision.eval()
ln_vision.train = disabled_train
logging.info('[INFO] ViT frozen')
else:
vit_precision = 'fp32'
visual_encoder, ln_vision = self.init_vision_encoder(
config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, vit_precision=vit_precision
)
logging.info('[INFO] ViT hot')
logging.info('[INFO] ViT successfully loaded')
        ################## 4. Define the ViT-Expert communication Interface ##################
self.system_prompt = False
self.vit_token_pooling = config.vit_token_pooling
if self.vit_token_pooling:
vit_proj = nn.Linear(
1408*4, bert_config.hidden_size
)
else:
vit_proj = nn.Linear(
1408, bert_config.hidden_size
)
vit_proj.apply(self.init_weights)
spatial_att = SpatialAttention(input_dim=bert_config.hidden_size)
temporal_att = TemporalAttention(input_dim=bert_config.hidden_size)
spatial_att.apply(self.init_weights)
temporal_att.apply(self.init_weights)
        ################## 5. Define the Expert layers ##################
moe_layers = None
moe_norm = None
if config.use_moes:
moe_layers = []
for moe_layer_idx in range(config.num_moe_layers):
if moe_layer_idx < self.config.num_moe_modality_layers:
expert_flag = 'modalities'
else:
expert_flag = 'fusion'
moe_layer = MoELayer(
bert_config.hidden_size,
bert_config.num_attention_heads,
expert_flag,
has_hist=True,
use_sep_spatial_temp_experts=config.use_sep_spatial_temp_experts
)
moe_layer.apply(self.init_weights)
moe_layers.append(moe_layer)
logging.info(f'[INFO] {moe_layer_idx+1}/{config.num_moe_layers} MoE layers successfully loaded')
moe_layers = nn.ModuleList(moe_layers)
moe_norm = nn.LayerNorm(bert_config.hidden_size)
temp = nn.Parameter(0.07 * torch.ones([]))
# temp = 0.07
# Attach the components to self
if self.config.embed_from_llm:
self.tokenizer = llm_tokenizer
self.text_embedding = llm_embed
else:
self.tokenizer = tokenizer
self.text_embedding = text_embedding
self.token_type_embedding = token_type_embedding
self.llm = llm
self.llm_to_moe = llm_to_moe
self.moe_to_llm = moe_to_llm
self.visual_encoder = visual_encoder
self.ln_vision = ln_vision
self.vit_proj = vit_proj
self.moe_layers = moe_layers
self.moe_norm = moe_norm
self.spatial_att = spatial_att
self.temporal_att = temporal_att
self.temp = temp
def construct_global_input(self, cap_ids, cap_attention_mask, hist_ids, hist_attention_mask, vid_feat_len, device):
# for video: <s><vis><spatial>[spatial_feats]<temporal>[temp_feats]<caption>[cap_feats]<history>[hist_feats]</s>
batch_size = cap_ids.size(0)
special_toks_indices = {
'<s>': 0,
'<vis>': 1,
'<spatial>': 2,
}
ids = [self.added_vocab['<s>']] + [self.added_vocab['<vis>']] + [self.added_vocab['<spatial>']]
ids += vid_feat_len * [self.added_vocab['<pad>']]
ids += [self.added_vocab['<temporal>']]
special_toks_indices['<temporal>'] = len(ids) - 1
ids += vid_feat_len * [self.added_vocab['<pad>']]
ids += [self.added_vocab['<caption>']]
special_toks_indices['<caption>'] = len(ids) - 1
ids += cap_ids.size(1) * [self.added_vocab['<pad>']]
ids += [self.added_vocab['<history>']]
special_toks_indices['<history>'] = len(ids) - 1
ids += hist_ids.size(1) * [self.added_vocab['<pad>']]
ids += [self.added_vocab['</s>']]
special_toks_indices['</s>'] = len(ids) - 1
total_len = len(ids)
ids = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
ids[:, special_toks_indices['<caption>'] + 1: special_toks_indices['<history>']] = cap_ids
ids[:, special_toks_indices['<history>'] + 1: special_toks_indices['</s>']] = hist_ids
mask = torch.ones((batch_size, total_len), device=device)
mask[:, special_toks_indices['<caption>'] + 1: special_toks_indices['<history>']] = cap_attention_mask
mask[:, special_toks_indices['<history>'] + 1: special_toks_indices['</s>']] = hist_attention_mask
return ids, mask, special_toks_indices
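    # Re-packs a padded [caption | history] pair so that all padding ends up on the right,
    # e.g. [c1 c2 <pad> | h1 <pad>] -> [c1 c2 h1 <pad> <pad>], which is what the
    # encoder-decoder LLM expects as a single encoder sequence.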
def pad_to_right_enc_dec(self, cap_embeds, cap_masks, hist_embeds, hist_masks, device):
"""
pushes all in-between pad tokens to the right
"""
res_embeds = []
res_mask = []
for cap_embed, cap_mask, hist_embed, hist_mask in zip(cap_embeds, cap_masks, hist_embeds, hist_masks):
len_cap = sum(cap_mask)
len_hist = sum(hist_mask)
batch_embed = torch.cat([cap_embed[:len_cap], hist_embed[:len_hist], cap_embed[len_cap:], hist_embed[len_hist:]], dim=0)
batch_mask = torch.zeros(batch_embed.size(0)).long().to(device)
batch_mask[:len_cap+len_hist] = 1
res_embeds.append(batch_embed)
res_mask.append(batch_mask)
res_embeds = torch.stack(res_embeds, dim=0)
res_mask = torch.stack(res_mask, dim=0)
return res_embeds, res_mask
def encode_vis_with_seq_spa_temp_att(self, image, device, is_vid=True):
num_frames = image.size(1)
bs_pre_reshape = image.size(0)
if len(image.shape) > 4:
image = image.view(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224)
# with self.maybe_autocast(): # inherited from Blip2Base
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408)
image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408)
bs, pn, hs = image_embeds.shape
        if self.vit_token_pooling: # concatenate every 4 tokens into one token (200,64,5632)
            image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632)
        vis_embed = self.vit_proj(image_embeds) # project to the expert hidden size (200,64,5632) -> (200,64,d_hidden)
# reshape the video features
vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1))
size_orig = vis_embed.size()
        # Perform spatial-temporal attention
vis_embed = self.spatial_att(vis_embed)
if is_vid:
vis_embed = vis_embed.view(size_orig)
vis_embed = self.temporal_att(vis_embed)
vis_feat_len = vis_embed.size(1)
# vis_embed = vis_embed + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device))
vis_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device)
return vis_embed, vis_mask
def moe_forward_no_sep_spatial_temporal(
self,
vis, vis_mask,
cap_ids, cap_mask, hist_ids, hist_mask,
is_vid, device):
# is_vid = media_type == 'webvid'
# batch_size = len(cap)
vis_feat_len = vis.size(1)
input_embeds = []
input_masks = []
input_embeds.append(vis)
input_masks.append(vis_mask)
# if is_vid:
# input_embeds.append(vis_temporal)
# input_masks.append(vis_temporal_mask)
if self.config.embed_from_llm:
cap_embeds = self.llm_to_moe(self.text_embedding(cap_ids))
else:
cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2))
cap_feat_len = cap_embeds.size(1)
input_embeds.append(cap_embeds)
input_masks.append(cap_mask)
if self.config.embed_from_llm:
hist_embeds = self.llm_to_moe(self.text_embedding(hist_ids))
else:
hist_embeds = self.text_embedding(hist_ids) + self.token_type_embedding(torch.ones_like(hist_ids).long().fill_(2))
hist_feat_len = hist_embeds.size(1)
input_embeds.append(hist_embeds)
input_masks.append(hist_mask)
input_embeds = torch.cat(input_embeds, dim=1)
input_masks = torch.cat(input_masks, dim=1)
# expand the mask
input_masks = self.get_extended_attention_mask(attention_mask=input_masks)
# MoEs feed-forward
for moe_layer_idx, moe_layer in enumerate(self.moe_layers):
if moe_layer_idx < self.config.num_moe_modality_layers:
expert_flag = 'modalities'
else:
expert_flag = 'fusion'
input_embeds = moe_layer(input_embeds, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len, is_vid=is_vid, mask=input_masks)
        # TODO: normalize the output
input_embeds = self.moe_norm(input_embeds)
# return the features
vis_embeds = input_embeds[:, :vis_feat_len]
# temporal_embeds = input_embeds[:, vis_feat_len:2*vis_feat_len] if is_vid else None
cap_embeds = input_embeds[:, -(cap_feat_len + hist_feat_len): -hist_feat_len]
hist_embeds = input_embeds[:, -hist_feat_len:]
# cls_feats = self.pooler(cap_feats)
moe_outputs = {
'vis_embeds': vis_embeds,
# 'temporal_embeds': temporal_embeds,
'cap_embeds': cap_embeds,
'hist_embeds': hist_embeds,
# 'cls_feats': cls_feats,
# 'last_hidden': input_embeds
}
return moe_outputs
def moe_forward(
self,
vis_spatial, vis_spatial_mask, vis_temporal, vis_temporal_mask,
cap_ids, cap_mask, hist_ids, hist_mask,
is_vid, device):
# is_vid = media_type == 'webvid'
# batch_size = len(cap)
vis_feat_len = vis_spatial.size(1)
input_embeds = []
input_masks = []
input_embeds.append(vis_spatial)
input_masks.append(vis_spatial_mask)
if is_vid:
input_embeds.append(vis_temporal)
input_masks.append(vis_temporal_mask)
if self.config.embed_from_llm:
cap_embeds = self.llm_to_moe(self.text_embedding(cap_ids))
else:
cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2))
cap_feat_len = cap_embeds.size(1)
input_embeds.append(cap_embeds)
input_masks.append(cap_mask)
if self.config.embed_from_llm:
hist_embeds = self.llm_to_moe(self.text_embedding(hist_ids))
else:
hist_embeds = self.text_embedding(hist_ids) + self.token_type_embedding(torch.ones_like(hist_ids).long().fill_(2))
hist_feat_len = hist_embeds.size(1)
input_embeds.append(hist_embeds)
input_masks.append(hist_mask)
input_embeds = torch.cat(input_embeds, dim=1)
input_masks = torch.cat(input_masks, dim=1)
# expand the mask
input_masks = self.get_extended_attention_mask(attention_mask=input_masks)
# MoEs feed-forward
for moe_layer_idx, moe_layer in enumerate(self.moe_layers):
if moe_layer_idx < self.config.num_moe_modality_layers:
expert_flag = 'modalities'
else:
expert_flag = 'fusion'
input_embeds = moe_layer(
input_embeds, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len,
is_vid=is_vid,
mask=input_masks,
expert_permutation=self.config.expert_permutation
)
        # TODO: normalize the output
input_embeds = self.moe_norm(input_embeds)
# return the features
spatial_embeds = input_embeds[:, :vis_feat_len]
temporal_embeds = input_embeds[:, vis_feat_len:2*vis_feat_len] if is_vid else None
cap_embeds = input_embeds[:, -(cap_feat_len + hist_feat_len): -hist_feat_len]
hist_embeds = input_embeds[:, -hist_feat_len:]
# cls_feats = self.pooler(cap_feats)
moe_outputs = {
'spatial_embeds': spatial_embeds,
'temporal_embeds': temporal_embeds,
'cap_embeds': cap_embeds,
'hist_embeds': hist_embeds,
# 'cls_feats': cls_feats,
# 'last_hidden': input_embeds
}
return moe_outputs
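    # Generation forward pass: tokenize caption and history, encode the visual streams,
    # (optionally) run everything through the MoE layers, project the expert outputs into
    # the LLM embedding space, concatenate [vis | caption+history], and train the
    # encoder-decoder LLM on the answer with teacher forcing (padding labels set to -100).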
def forward(self, vis, cap, hist, ans, media_type):
device = vis.device
is_vid = media_type in ['webvid', 'champagne', 'avsd', 'nextqa']
loss_stc = torch.tensor(0)
loss_stm = torch.tensor(0)
loss_vhc = torch.tensor(0)
loss_vhm = torch.tensor(0)
loss_gen = torch.tensor(0)
# construct the global input tensor --> use place holder for vis features
cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=None)
hist_ids, hist_mask = self.tokenize_text(hist, device, max_len=None)
if self.config.use_moes:
# First get the visual features depending on the media type
if self.config.use_sep_spatial_temp_experts:
vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid)
spatial_feat_len = vis_embed_spatial.size(1)
else:
vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)
if self.config.use_sep_spatial_temp_experts:
moe_outputs = self.moe_forward(
vis_embed_spatial, vis_spatial_mask,
vis_embed_temporal, vis_temporal_mask,
cap_ids, cap_mask,
hist_ids, hist_mask,
is_vid, device
)
spatial_embeds = self.moe_to_llm(moe_outputs['spatial_embeds'])
temporal_embeds = self.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None
else:
moe_outputs = self.moe_forward_no_sep_spatial_temporal(
vis_embed, vis_mask,
cap_ids, cap_mask,
hist_ids, hist_mask,
is_vid, device
)
vis_embeds = self.moe_to_llm(moe_outputs['vis_embeds'])
# temporal_embeds = self.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None
cap_embeds = self.moe_to_llm(moe_outputs['cap_embeds'])
hist_embeds = self.moe_to_llm(moe_outputs['hist_embeds'])
else:
cap_embeds = self.llm_to_moe(self.text_embedding(cap_ids))
hist_embeds = self.llm_to_moe(self.text_embedding(hist_ids))
vis_embeds, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)
ans = [a + self.tokenizer.eos_token for a in ans]
inputs_embeds, attention_mask = self.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device)
# now merge the multi-modal inputs
if self.config.use_moes:
if self.config.use_sep_spatial_temp_experts:
if is_vid:
inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
else:
inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1)
else:
inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
else:
inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
        decoder_ids, decoder_mask = self.tokenize_text(ans, device, max_len=None) # pad to the longest sequence
labels = decoder_ids.masked_fill(decoder_ids == self.tokenizer.pad_token_id, -100)
decoder_ids = self.shift_right(labels)
decoder_inputs_embeds = self.text_embedding(decoder_ids)
lm_outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_attention_mask=decoder_mask,
labels=labels,
return_dict=True
)
loss_gen = lm_outputs.loss
return dict(
loss_stc = loss_stc * self.config.loss_dict['stc'],
loss_stm = loss_stm * self.config.loss_dict['stm'],
loss_vhc = loss_vhc * self.config.loss_dict['vhc'],
loss_vhm = loss_vhm * self.config.loss_dict['vhm'],
loss_gen = loss_gen * self.config.loss_dict['gen'],
)