import json
import re
import glog as logging
import random
import os

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

# from minigpt4.common.registry import registry
from .backbones.blip2 import Blip2Base, disabled_train
from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration
from transformers.models.bart.modeling_bart import BartForConditionalGeneration
# from .backbones.encoder_decoder.xflan_t5 import T5ForConditionalGeneration
from .backbones.modeling_mistral import MistralForCausalLM
from .backbones.modeling_llama_v2 import LlamaForCausalLM
from .backbones.moes import MoELayer, Pooler
# from .backbones.moes_huggingface import MoEPooler
# from .backbones.moes_huggingface import MoELayer, MoEPooler
from .modules.temporal_modelling import SpatialAttention, TemporalAttention
from .common.dist_utils import concat_all_gather, all_gather_with_grad
from .utils import MLM
from utils.dist import is_main_process
# from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model
# from minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model
# from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from transformers import BitsAndBytesConfig
from transformers.models.bert.modeling_bert import BertConfig, BertEmbeddings
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training,
    set_peft_model_state_dict,
)
import time
import numpy as np
# from minigpt4.models import policies


class V2DialAbstract(Blip2Base):
    def __init__(self):
        super(V2DialAbstract, self).__init__()

    def shift_right(self, input_ids):
        decoder_start_token_id = self.llm.config.decoder_start_token_id
        pad_token_id = self.llm.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
                "See T5 docs for more information."
            )

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids

    def encode_vis(self, image, device, is_vid=True):
        num_frames = image.size(1)
        bs_pre_reshape = image.size(0)
        if len(image.shape) > 4:
            image = image.view(-1, *image.shape[-3:])  # for video input, flatten the batch and time dimensions: (4,50,3,224,224) -> (200,3,224,224)
        # with self.maybe_autocast():  # inherited from Blip2Base
        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)  # (200,3,224,224) -> (200,257,1408)
        image_embeds = image_embeds[:, 1:, :]  # remove the first (CLS) token: (200,256,1408)
        bs, pn, hs = image_embeds.shape
        if self.vit_token_pooling:  # concatenate each group of 4 tokens into one token: (200,64,5632)
            image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4))
        vis_embed = self.vit_proj(image_embeds)  # project to LLM input size: (200,64,5632) -> (200,64,d_hidden)

        # reshape the video features
        vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1))

        # Perform spatial-temporal attention
        vis_embed_spatial = self.spatial_att(vis_embed)
        vis_feat_len = vis_embed_spatial.size(1)
        if not self.config.embed_from_llm:
            vis_embed_spatial = vis_embed_spatial + self.token_type_embedding(
                torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device))
        vis_spatial_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device)

        vis_embed_temporal, vis_temporal_mask = None, None
        if is_vid:
            vis_embed_temporal = self.temporal_att(vis_embed)
            if not self.config.embed_from_llm:
                vis_embed_temporal = vis_embed_temporal + self.token_type_embedding(
                    torch.ones(bs_pre_reshape, vis_feat_len).long().to(device))
            vis_temporal_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device)

        return vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask

    def tokenize_text(self, text, device, add_bos=False, add_eos=False, max_len=None):
        if max_len:
            text_tokenized = self.tokenizer(
                text, return_tensors='pt', padding='max_length', max_length=max_len,
                truncation=True, add_special_tokens=False, return_special_tokens_mask=True
            ).to(device)
        else:
            text_tokenized = self.tokenizer(
                text, return_tensors='pt', padding='longest',
                add_special_tokens=False, return_special_tokens_mask=True
            ).to(device)

        text_ids = text_tokenized.input_ids
        text_attention_mask = text_tokenized.attention_mask

        if add_bos:
            bos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.bos_token_id).to(device)
            bos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device)
            text_ids = torch.cat([bos_ids, text_ids], dim=1)
            text_attention_mask = torch.cat([bos_att, text_attention_mask], dim=1)

        if add_eos:
            eos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.eos_token_id).to(device)
            eos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device)
            text_ids = torch.cat([text_ids, eos_ids], dim=1)
            text_attention_mask = torch.cat([text_attention_mask, eos_att], dim=1)

        return text_ids, text_attention_mask

    def get_extended_attention_mask(self, attention_mask=None):
        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 3:
            extended_attention_mask = attention_mask.unsqueeze(1)
        else:
            raise NotImplementedError
return extended_attention_mask @staticmethod def init_weights(module): if isinstance(module, (nn.Linear, nn.Embedding)): module.weight.data.normal_(mean=0.0, std=0.02) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() class V2DialBase(V2DialAbstract): def __init__(self, config): super(V2DialBase, self).__init__() self.config = config ################## 1. Select Tokenizer -- We use BERT tokenizer ################## bert_config = BertConfig.from_pretrained('bert-{}-uncased'.format(config.expert_size)) tokenizer = AutoTokenizer.from_pretrained('bert-{}-uncased'.format(config.expert_size)) text_embedding = BertEmbeddings(bert_config) text_embedding.apply(self.init_weights) token_type_embedding = nn.Embedding(3, bert_config.hidden_size) # Number of modality types (temp/spa/text) token_type_embedding.apply(self.init_weights) # Define the masking strategy mlm_collactor = DataCollatorForLanguageModeling( tokenizer, mlm=True, mlm_probability=config.masking_prob, return_tensors='pt') ################## 2. Select the backbone ViT ################## logging.info('[INFO] Loading ViT in progress') if config.freeze_vit: # vit_precision = 'fp16' if config.fp16 else 'fp32' logging.info(f'[INFO] ViT precision: {config.vit_precision}') visual_encoder, ln_vision = self.init_vision_encoder( config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, precision=config.vit_precision ) for name, param in visual_encoder.named_parameters(): param.requires_grad = False visual_encoder = visual_encoder.eval() visual_encoder.train = disabled_train for name, param in ln_vision.named_parameters(): param.requires_grad = False ln_vision = ln_vision.eval() ln_vision.train = disabled_train logging.info('[INFO] ViT frozen') else: vit_precision = 'fp32' visual_encoder, ln_vision = self.init_vision_encoder( config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, vit_precision=vit_precision ) logging.info('[INFO] ViT hot') logging.info('[INFO] ViT successfully loaded') ################## 3. Define the ViT-Expert communication Interface ################## self.system_prompt = False self.vit_token_pooling = config.vit_token_pooling if self.vit_token_pooling: vit_proj = nn.Linear( 1408*4, bert_config.hidden_size ) else: vit_proj = nn.Linear( 1408, bert_config.hidden_size ) vit_proj.apply(self.init_weights) spatial_att = SpatialAttention(input_dim=bert_config.hidden_size) temporal_att = TemporalAttention(input_dim=bert_config.hidden_size) spatial_att.apply(self.init_weights) temporal_att.apply(self.init_weights) ################## 4. Define the Expert layers ################## moe_layers = [] for moe_layer_idx in range(config.num_moe_layers): if moe_layer_idx < self.config.num_moe_modality_layers: expert_flag = 'modalities' else: expert_flag = 'fusion' moe_layer = MoELayer( bert_config.hidden_size, bert_config.num_attention_heads, expert_flag, use_sep_spatial_temp_experts=config.use_sep_spatial_temp_experts ) moe_layer.apply(self.init_weights) moe_layers.append(moe_layer) logging.info(f'[INFO] {moe_layer_idx+1}/{config.num_moe_layers} MoE layers successfully loaded') moe_layers = nn.ModuleList(moe_layers) moe_norm = nn.LayerNorm(bert_config.hidden_size) ################## 5. 
Define the projection layers for contrastive learning ##################
        temp_proj = nn.Linear(bert_config.hidden_size, config.joint_dim)
        spatial_proj = nn.Linear(bert_config.hidden_size, config.joint_dim)
        vision_proj = nn.Linear(bert_config.hidden_size, config.joint_dim)
        cap_proj = nn.Linear(bert_config.hidden_size, config.joint_dim)

        temp_proj.apply(self.init_weights)
        spatial_proj.apply(self.init_weights)
        vision_proj.apply(self.init_weights)
        cap_proj.apply(self.init_weights)

        ################## 6. Define the pooler for matching loss ##################
        pooler = Pooler(bert_config.hidden_size)
        pooler.apply(self.init_weights)

        ################## 7. Attach the matching heads ##################
        stm_head = nn.Linear(bert_config.hidden_size, 2)
        vcm_head = nn.Linear(bert_config.hidden_size, 2)
        lm_head = nn.Linear(bert_config.hidden_size, len(tokenizer))

        stm_head.apply(self.init_weights)
        vcm_head.apply(self.init_weights)
        lm_head.apply(self.init_weights)

        temp = nn.Parameter(0.07 * torch.ones([]))
        # temp = 0.07

        # Attach the components to self
        self.tokenizer = tokenizer
        self.mlm_collactor = mlm_collactor
        self.text_embedding = text_embedding
        self.token_type_embedding = token_type_embedding
        self.visual_encoder = visual_encoder
        self.ln_vision = ln_vision
        self.vit_proj = vit_proj
        self.moe_layers = moe_layers
        self.moe_norm = moe_norm
        self.spatial_att = spatial_att
        self.temporal_att = temporal_att
        self.temp_proj = temp_proj
        self.spatial_proj = spatial_proj
        self.vision_proj = vision_proj
        self.cap_proj = cap_proj
        self.pooler = pooler
        self.stm_head = stm_head
        self.vcm_head = vcm_head
        self.lm_head = lm_head
        self.temp = temp

    @staticmethod
    def init_weights(module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def build_query_embeds(self, num_query_tokens, dim_query_tokens):
        query_embeds = nn.Parameter(torch.zeros(1, num_query_tokens, dim_query_tokens))
        query_embeds.data.normal_(mean=0.0, std=0.02)
        return query_embeds

    def encode_caption(self, cap):
        cap_output = self.cap_expert(
            input_ids=cap.input_ids,
            attention_mask=cap.attention_mask,
            return_dict=True,
        )
        cap_embeds = cap_output.last_hidden_state
        pooled_cap_embeds = cap_embeds[:, 0]
        return cap_embeds, pooled_cap_embeds

    def encode_vis_old(self, vis, media_type):
        # if media_type == 'webvid':
        #     bs, num_frames, c, h, w = vis.size()
        #     # reshape
        #     vis = vis.view(bs * num_frames, c, h, w)
        vis_embed = self.beit(vis).last_hidden_state
        # vis_embed = self.beit_layernorm(vis_output.last_hidden_state)

        # remove cls token embedding
        vis_embed = vis_embed[:, :, 1:, :]
        vis_embed = self.beit_lin(vis_embed)

        # perform spatial attention
        vis_spatial_embed = self.spatial_att(vis_embed)
        vis_temp_embed = self.temporal_att(vis_embed) if media_type in ['webvid', 'msrvtt', 'champagne', 'avsd'] else None

        return vis_spatial_embed, vis_temp_embed

    def encode_queries(self, query_embeds, vis_embeds, vis_mode):
        if vis_mode == 'spatial':
            expert = self.spatial_expert
            layer_norm = self.spatial_layernorm
        elif vis_mode == 'temporal':
            expert = self.temporal_expert
            layer_norm = self.temporal_layernorm
        else:
            raise ValueError(f'[ERROR] {vis_mode} not implemented!')

        attention_mask = torch.ones(query_embeds.size()[:-1], dtype=torch.long).to(vis_embeds.device)
        vis_attention_mask = torch.ones(vis_embeds.size()[:-1], dtype=torch.long).to(vis_embeds.device)

        if
self.config['expert_layer_type'] == 'bert': output_dict = expert( encoder_embeds=query_embeds, encoder_hidden_states=vis_embeds, encoder_attention_mask=vis_attention_mask, ) query_embeds = layer_norm(output_dict.last_hidden_state) pooled_query_embeds = output_dict.pooler_output elif self.config['expert_layer_type'] == 'bart': output_dict = expert( inputs_embeds=query_embeds, attention_mask=attention_mask, cross_embeds=vis_embeds, cross_attention_mask=vis_attention_mask, ) query_embeds = layer_norm(output_dict.last_hidden_state) pooled_query_embeds = query_embeds[:, 0] return query_embeds, pooled_query_embeds # def encode_vis(self, image, device, is_vid=True): # num_frames = image.size(1) # bs_pre_reshape = image.size(0) # if len(image.shape) > 4: # image = image.view(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224) # # with self.maybe_autocast(): # inherited from Blip2Base # image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408) # image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408) # bs, pn, hs = image_embeds.shape # if self.vit_token_pooling: # concat the each 4 tokens into one token (200,64,5632) # image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632) # vis_embed = self.vit_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096) # # reshape the video features # vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1)) # # Perfrom spatial temporal attention # vis_embed_spatial = self.spatial_att(vis_embed) # vis_feat_len = vis_embed_spatial.size(1) # vis_embed_spatial = vis_embed_spatial + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device)) # vis_spatial_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device) # vis_embed_temporal, vis_temporal_mask = None, None # if is_vid: # vis_embed_temporal = self.temporal_att(vis_embed) + self.token_type_embedding(torch.ones(bs_pre_reshape, vis_feat_len).long().to(device)) # vis_temporal_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device) # return vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask def encode_vis_with_seq_spa_temp_att(self, image, device, is_vid=True): num_frames = image.size(1) bs_pre_reshape = image.size(0) if len(image.shape) > 4: image = image.view(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224) # with self.maybe_autocast(): # inherited from Blip2Base image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408) image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408) bs, pn, hs = image_embeds.shape if self.vit_token_pooling: # concat the each 4 tokens into one token (200,64,5632) image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632) vis_embed = self.vit_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096) # reshape the video features vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1)) size_orig = vis_embed.size() # Perfrom spatial temporal attention vis_embed = self.spatial_att(vis_embed) if is_vid: vis_embed = vis_embed.view(size_orig) vis_embed = self.temporal_att(vis_embed) vis_feat_len = vis_embed.size(1) vis_embed = vis_embed + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device)) vis_mask = 
torch.ones((bs_pre_reshape, vis_feat_len)).to(device) return vis_embed, vis_mask def tokenize_text(self, text, device, add_bos=False, add_eos=False, max_len=None): if max_len: text_tokenized = self.tokenizer( text, return_tensors='pt', padding='max_length', max_length=max_len, truncation=True, add_special_tokens=False, return_special_tokens_mask=True ).to(device) else: text_tokenized = self.tokenizer( text, return_tensors='pt', padding='longest', add_special_tokens=False, return_special_tokens_mask=True ).to(device) text_ids = text_tokenized.input_ids text_attention_mask = text_tokenized.attention_mask if add_bos: bos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.bos_token_id).to(device) bos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device) text_ids = torch.cat([bos_ids, text_ids], dim=1) text_attention_mask = torch.cat([bos_att, text_attention_mask], dim=1) if add_eos: eos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.eos_token_id).to(device) eos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device) text_ids = torch.cat([text_ids, eos_ids], dim=1) text_attention_mask = torch.cat([text_attention_mask, eos_att], dim=1) return text_ids, text_attention_mask def encode_text(self, text, max_len, device): text_tokenized = self.tokenizer( text, return_tensors='pt', padding='max_length', max_length=max_len, truncation=True, add_special_tokens=False ).to(device) text_ids = text_tokenized.input_ids text_embeds = self.embed(text_ids) text_attention_mask = text_tokenized.attention_mask return text_embeds, text_ids, text_attention_mask def encode_spatial_toks(self, batch_size, device): # ['', '', '', '', ''] special_toks_ids = self.tokenizer( '', return_tensors='pt', padding='longest', truncation=True, add_special_tokens=False ).to(device) special_toks_embeds = self.embed(special_toks_ids.input_ids) special_toks_embeds = special_toks_embeds.repeat(batch_size, 1, 1) return special_toks_embeds def construt_input_embeds_stage_1(self, vis_embed, cap_embed, special_toks_embeds, cap_attention_mask, media_type, device): batch_size = vis_embed.size(0) embed_dim = vis_embed.size(-1) vis_embed = vis_embed.view(batch_size, -1, embed_dim) input_embeds = [] input_attention_mask = [] special_toks_indices = { '': 0, '': 1, '': 2, } # special_toks_embeds = # for video: [spatial_featurres][temporal_featurres][caption_features] # for image: [spatial_featurres][caption_features] input_embeds.append(special_toks_embeds[:, 0:3, :]) # input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) input_embeds.append(vis_embed.clone()) # [spatial_features] input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) if media_type == 'webvid': # here we copy the original vis_embeds twice and will apply spatial and temporal attention later input_embeds.append(special_toks_embeds[:, 3:4, :]) # input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) special_toks_indices[''] = special_toks_indices[''] + input_embeds[-2].size(1) + 1 input_embeds.append(vis_embed.clone()) # [temporal_features] input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) input_embeds.append(special_toks_embeds[:, 4:5, :]) # input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) if media_type == 'webvid': special_toks_indices[''] = special_toks_indices[''] + input_embeds[-2].size(1) + 1 elif 
media_type == 'cc3m': special_toks_indices[''] = special_toks_indices[''] + input_embeds[-2].size(1) + 1 input_embeds.append(cap_embed) # [caption_features] input_attention_mask.append(cap_attention_mask) input_embeds.append(special_toks_embeds[:, 6:7, :]) # input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) special_toks_indices[''] = special_toks_indices[''] + input_embeds[-2].size(1) + 1 input_embeds = torch.cat(input_embeds, dim=1) input_attention_mask = torch.cat(input_attention_mask, dim=1) assert input_embeds.size()[:-1] == input_attention_mask.size() return input_embeds, input_attention_mask, special_toks_indices def construct_global_input(self, cap_ids, cap_attention_mask, vid_feat_len, media_type, device): # for video: [spatial_featurres][temporal_features][caption_features] # for image: [spatial_featurres][caption_features] batch_size = cap_ids.size(0) special_toks_indices = { '': 0, '': 1, '': 2, } ids = [self.added_vocab['']] + [self.added_vocab['']] + [self.added_vocab['']] ids += vid_feat_len * [self.added_vocab['']] if media_type == 'webvid': ids += [self.added_vocab['']] special_toks_indices[''] = len(ids) - 1 ids += vid_feat_len * [self.added_vocab['']] ids += [self.added_vocab['']] special_toks_indices[''] = len(ids) - 1 ids += cap_ids.size(1) * [self.added_vocab['']] ids += [self.added_vocab['']] special_toks_indices[''] = len(ids) - 1 total_len = len(ids) ids = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) ids[:, special_toks_indices[''] + 1: special_toks_indices['']] = cap_ids mask = torch.ones((batch_size, total_len), device=device) mask[:, special_toks_indices[''] + 1: special_toks_indices['']] = cap_attention_mask return ids, mask, special_toks_indices def compute_contrastive_loss(self, x, y_all, y, x_all): sim_x2y = torch.mm(x, y_all.t()) # (bs, bs*ngpus) sim_x2y = sim_x2y / self.temp sim_y2x = torch.mm(y, x_all.t()) # (bs, bs*ngpus) sim_y2x = sim_y2x / self.temp rank = dist.get_rank() if self.config['distributed'] else 0 bs = x.size(0) targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( x.device ) loss_contrastive = ( F.cross_entropy(sim_x2y, targets, label_smoothing=0.1) + F.cross_entropy(sim_y2x, targets, label_smoothing=0.1) ) / 2 return loss_contrastive, sim_x2y, sim_y2x def get_extended_attention_mask(self, attention_mask=None): if attention_mask.dim() == 2: extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) elif attention_mask.dim() == 3: extended_attention_mask = attention_mask.unsqueeze(1) else: raise NotImplementedError return extended_attention_mask def shared_forward( self, vis_spatial, vis_spatial_mask, vis_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device): # is_vid = media_type == 'webvid' # batch_size = len(cap) vis_feat_len = vis_spatial.size(1) input_embeds = [] input_masks = [] input_embeds.append(vis_spatial) input_masks.append(vis_spatial_mask) if is_vid: input_embeds.append(vis_temporal) input_masks.append(vis_temporal_mask) cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2)) cap_feat_len = cap_embeds.size(1) input_embeds.append(cap_embeds) input_masks.append(cap_mask) input_embeds = torch.cat(input_embeds, dim=1) input_masks = torch.cat(input_masks, dim=1) # expand the mask input_masks = self.get_extended_attention_mask(attention_mask=input_masks) # MoEs feed-forward for moe_layer_idx, moe_layer in enumerate(self.moe_layers): if moe_layer_idx < 
self.config.num_moe_modality_layers: expert_flag = 'modalities' else: expert_flag = 'fusion' input_embeds = moe_layer(input_embeds, vis_feat_len, cap_feat_len, expert_flag, is_vid=is_vid, mask=input_masks) #TODO normalize the output () !!!!!! input_embeds = self.moe_norm(input_embeds) # return the features spatial_feats = input_embeds[:, :vis_feat_len] temporal_feats = input_embeds[:, vis_feat_len:2*vis_feat_len] if is_vid else None cap_feats = input_embeds[:, -cap_feat_len:] cls_feats = self.pooler(cap_feats) moe_outputs = { 'spatial_feats': spatial_feats, 'temporal_feats': temporal_feats, 'cap_feats': cap_feats, 'cls_feats': cls_feats, } return moe_outputs def shared_forward_no_sep_spatial_temporal_experts( self, vis, vis_mask, cap_ids, cap_mask, is_vid, device): # is_vid = media_type == 'webvid' # batch_size = len(cap) vis_feat_len = vis.size(1) input_embeds = [] input_masks = [] input_embeds.append(vis) input_masks.append(vis_mask) # if is_vid: # input_embeds.append(vis_temporal) # input_masks.append(vis_temporal_mask) cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2)) cap_feat_len = cap_embeds.size(1) input_embeds.append(cap_embeds) input_masks.append(cap_mask) input_embeds = torch.cat(input_embeds, dim=1) input_masks = torch.cat(input_masks, dim=1) # expand the mask input_masks = self.get_extended_attention_mask(attention_mask=input_masks) # MoEs feed-forward for moe_layer_idx, moe_layer in enumerate(self.moe_layers): if moe_layer_idx < self.config.num_moe_modality_layers: expert_flag = 'modalities' else: expert_flag = 'fusion' input_embeds = moe_layer(input_embeds, vis_feat_len, cap_feat_len, expert_flag, is_vid=is_vid, mask=input_masks) #TODO normalize the output () !!!!!! input_embeds = self.moe_norm(input_embeds) # return the features vis_feats = input_embeds[:, :vis_feat_len] cap_feats = input_embeds[:, -cap_feat_len:] cls_feats = self.pooler(cap_feats) moe_outputs = { 'vis_feats': vis_feats, 'cap_feats': cap_feats, 'cls_feats': cls_feats, } return moe_outputs def vcm_iteration(self, vis, cap, neg_vis, is_vid, device): # Prepare the vis data # is_vid = media_type == 'webvid' num_positive_samples = len(cap) // 2 num_negative_samples = len(cap) - num_positive_samples vcm_labels = torch.cat([torch.ones(num_positive_samples), torch.zeros(num_negative_samples)]).to(device) vcm_labels = vcm_labels[torch.randperm(vcm_labels.size(0))].long() # now get the mixed vis data vis_mixed = [p if vcm_labels[i] == 1 else n for i, (p, n) in enumerate(zip(vis, neg_vis))] vis_mixed = torch.stack(vis_mixed, dim=0) cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len) if self.config.use_sep_spatial_temp_experts: vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis_mixed, device, is_vid=is_vid) moe_outputs = self.shared_forward( vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device) else: vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid) moe_outputs = self.shared_forward_no_sep_spatial_temporal_experts( vis_embed, vis_mask, cap_ids, cap_mask, is_vid, device) vcm_logits = self.vcm_head(moe_outputs['cls_feats']) loss_vcm = F.cross_entropy(vcm_logits, vcm_labels) return loss_vcm def stm_iteration(self, vis, cap, neg_vis, is_vid, device): num_positive_samples = len(cap) // 2 num_negative_samples = len(cap) - num_positive_samples stm_labels = 
torch.cat(
            [torch.ones(num_positive_samples), torch.zeros(num_negative_samples)]).to(device)
        stm_labels = stm_labels[torch.randperm(stm_labels.size(0))].long()

        vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid)
        neg_vis_embed_spatial, _, neg_vis_embed_temporal, _ = self.encode_vis(neg_vis, device, is_vid=is_vid)

        # now get the mixed vis data
        vis_embed_spatial_mixed = []
        vis_embed_temporal_mixed = []

        for i, (pos_spatial, pos_temporal, neg_spatial, neg_temporal) in enumerate(
                zip(vis_embed_spatial, vis_embed_temporal, neg_vis_embed_spatial, neg_vis_embed_temporal)):
            if stm_labels[i] == 1:
                vis_embed_spatial_mixed.append(pos_spatial)
                vis_embed_temporal_mixed.append(pos_temporal)
            else:
                # 50% negative spatial / 50% negative temporal
                if torch.rand(1).item() < 0.5:
                    vis_embed_spatial_mixed.append(pos_spatial)
                    vis_embed_temporal_mixed.append(neg_temporal)
                else:
                    vis_embed_spatial_mixed.append(neg_spatial)
                    vis_embed_temporal_mixed.append(pos_temporal)

        vis_embed_spatial_mixed = torch.stack(vis_embed_spatial_mixed, dim=0)
        vis_embed_temporal_mixed = torch.stack(vis_embed_temporal_mixed, dim=0)

        cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len)

        moe_outputs = self.shared_forward(
            vis_embed_spatial_mixed, vis_spatial_mask, vis_embed_temporal_mixed, vis_temporal_mask,
            cap_ids, cap_mask, is_vid, device)

        stm_logits = self.stm_head(moe_outputs['cls_feats'])  # spatial-temporal matching head
        loss_stm = F.cross_entropy(stm_logits, stm_labels)

        return loss_stm

    def mlm_iteration(self, vis, cap, is_vid, device):
        if self.config.use_sep_spatial_temp_experts:
            vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid)
        else:
            vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)

        cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len)
        cap_ids = cap_ids.tolist()

        # NOTE We make sure to mask some tokens here to avoid nan loss later
        mlm_output = self.mlm_collactor(cap_ids)
        cap_ids = mlm_output['input_ids'].to(device)
        labels_cap = mlm_output['labels'].to(device)

        if self.config.use_sep_spatial_temp_experts:
            moe_outputs = self.shared_forward(
                vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask,
                cap_ids, cap_mask, is_vid, device)
        else:
            moe_outputs = self.shared_forward_no_sep_spatial_temporal_experts(
                vis_embed, vis_mask, cap_ids, cap_mask, is_vid, device)

        mlm_logits = self.lm_head(moe_outputs['cap_feats'])
        loss_mlm = F.cross_entropy(mlm_logits.view(-1, mlm_logits.size(-1)), labels_cap.view(-1))

        return loss_mlm

    def vcc_iteration(self, vis, cap, is_vid, device):
        vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid)
        cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len)

        if self.config.use_sep_spatial_temp_experts:
            moe_outputs = self.shared_forward(
                vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask,
                cap_ids, cap_mask, is_vid, device)
            vis_feats = moe_outputs['spatial_feats']
            if is_vid:
                vis_feats = torch.cat([moe_outputs['spatial_feats'], moe_outputs['temporal_feats']], dim=1)
        else:
            vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)
            moe_outputs = self.shared_forward_no_sep_spatial_temporal_experts(
                vis_embed, vis_mask, cap_ids, cap_mask, is_vid, device)
            vis_feats = moe_outputs['vis_feats']

        cap_feats = F.normalize(self.cap_proj(moe_outputs['cls_feats']), dim=-1)
        vis_feats =
F.normalize(self.vision_proj(vis_feats), dim=-1) vis_feats_all = concat_all_gather(vis_feats) cap_feats_all = concat_all_gather(cap_feats) sim_v2c = torch.matmul( vis_feats.unsqueeze(1), cap_feats_all.unsqueeze(-1) ).squeeze() sim_v2c, _ = sim_v2c.max(-1) sim_v2c = sim_v2c / self.temp sim_c2v = torch.matmul( cap_feats.unsqueeze(1).unsqueeze(1), vis_feats_all.permute(0, 2, 1) ).squeeze() sim_c2v, _ = sim_c2v.max(-1) sim_c2v = sim_c2v / self.temp rank = dist.get_rank() if self.config['distributed'] else 0 bs = vis_feats.size(0) targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( device ) loss_vcc = ( F.cross_entropy(sim_v2c, targets, label_smoothing=0.1) + F.cross_entropy(sim_c2v, targets, label_smoothing=0.1) ) / 2 return loss_vcc def stc_iteration(self, vis, cap, is_vid, device): vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid) cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len) moe_outputs = self.shared_forward( vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device) spatial_feats = F.normalize(self.spatial_proj(moe_outputs['spatial_feats']), dim=-1) temporal_feats = F.normalize(self.temp_proj(moe_outputs['temporal_feats']), dim=-1) spatial_feats_all = concat_all_gather(spatial_feats) temporal_feats_all = concat_all_gather(temporal_feats) sim_s2t = torch.matmul( spatial_feats.unsqueeze(1), temporal_feats_all ) sim_s2t, _ = sim_s2t.max(-1) sim_s2t, _ = sim_s2t.max(-1) sim_s2t = sim_s2t / self.temp sim_t2s = torch.matmul( temporal_feats.unsqueeze(1), spatial_feats_all ) sim_t2s, _ = sim_t2s.max(-1) sim_t2s, _ = sim_t2s.max(-1) sim_t2s = sim_t2s / self.temp rank = dist.get_rank() if self.config['distributed'] else 0 bs = vis.size(0) targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( device ) loss_stc = ( F.cross_entropy(sim_s2t, targets, label_smoothing=0.1) + F.cross_entropy(sim_t2s, targets, label_smoothing=0.1) ) / 2 return loss_stc def forward(self, vis, cap, neg_vis, media_type): device = vis.device is_vid = media_type == 'webvid' loss_stc = torch.tensor(0).to(device) loss_stm = torch.tensor(0).to(device) loss_vcc = torch.tensor(0).to(device) loss_vcm = torch.tensor(0).to(device) loss_mlm = torch.tensor(0).to(device) if self.config.loss_dict['vcm'] != 0: loss_vcm = self.vcm_iteration(vis, cap, neg_vis, is_vid, device) if self.config.loss_dict['vcc'] != 0: loss_vcc = self.vcc_iteration(vis, cap, is_vid, device) if self.config.loss_dict['stm'] != 0 and is_vid: loss_stm = self.stm_iteration(vis, cap, neg_vis, is_vid, device) if self.config.loss_dict['stc'] != 0 and is_vid: loss_stc = self.stc_iteration(vis, cap, is_vid, device) if self.config.loss_dict['mlm'] != 0: loss_mlm = self.mlm_iteration(vis, cap, is_vid, device) return dict( loss_stc = loss_stc * self.config.loss_dict['stc'], loss_stm = loss_stm * self.config.loss_dict['stm'], loss_vcc = loss_vcc * self.config.loss_dict['vcc'], loss_vcm = loss_vcm * self.config.loss_dict['vcm'], loss_mlm = loss_mlm * self.config.loss_dict['mlm'], ) def forward__(self, vis, cap, neg_vis, media_type): device = vis.device self.vcm_matching(vis, cap, neg_vis, media_type, device) self.shared_forward(vis, cap, media_type, device) # First init all losses to zeros loss_stc = torch.tensor(0).to(device) loss_stm = torch.tensor(0).to(device) loss_vcc = torch.tensor(0).to(device) loss_vcm = torch.tensor(0).to(device) loss_mlm = torch.tensor(0).to(device) 
batch_size = len(cap) # First get the visual features depending on the media type vis_embed = self.encode_vis(vis) neg_vis_embed = self.encode_vis(neg_vis) embed_dim = vis_embed.size(-1) num_frames = vis.size(1) # reshape the video features vis_embed = vis_embed.view(batch_size, num_frames, -1, embed_dim) neg_vis_embed = neg_vis_embed.view(batch_size, num_frames, -1, embed_dim) # Perfrom spatial temporal attention and reshape vis_embed_spatial = self.spatial_att(vis_embed) # vis_embed_spatial = vis_embed_spatial.view(batch_size, -1, embed_dim) neg_vis_embed_spatial = self.spatial_att(neg_vis_embed) # neg_vis_embed_spatial = neg_vis_embed_spatial.view(batch_size, -1, embed_dim) if media_type == 'webvid': vis_embed_temporal = self.temporal_att(vis_embed) # vis_embed_temporal = vis_embed_temporal.view(batch_size, -1, embed_dim) neg_vis_embed_temporal = self.temporal_att(neg_vis_embed) # neg_vis_embed_temporal = neg_vis_embed_temporal.view(batch_size, -1, embed_dim) spatial_feat_len = vis_embed_spatial.size(1) # construct the global input tensor --> use place holder for vis features cap_ids, cap_attention_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len) input_ids, input_mask, special_toks_indices = self.construct_global_input(cap_ids, cap_attention_mask, spatial_feat_len, media_type, device) input_embeds = self.embed(input_ids) if media_type == 'webvid': input_embeds[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = vis_embed_spatial input_embeds[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = vis_embed_temporal elif media_type == 'cc3m': input_embeds[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = vis_embed_spatial # LLM --> MoEs input_embeds = self.moe_llm_bottleneck(input_embeds) input_embeds_orig = input_embeds.clone() neg_vis_embed_spatial = self.moe_llm_bottleneck(neg_vis_embed_spatial) if media_type == 'webvid': neg_vis_embed_temporal = self.moe_llm_bottleneck(neg_vis_embed_temporal) for moe_layer_idx, moe_layer in enumerate(self.moe_layers): if moe_layer_idx < self.config.num_moe_modality_layers: expert_flag = 'modalities' else: expert_flag = 'fusion' input_embeds = moe_layer(input_embeds, special_toks_indices, expert_flag, mask=input_mask) #TODO normalize the output () !!!!!! 
#-------------------- Contrastive losses --------------------# cap_proj_feats = F.normalize(self.cap_proj(input_embeds[:, special_toks_indices[''], :]), dim=-1) # (bs*gpus, H) vis_proj_feats = F.normalize(self.vision_proj(input_embeds[:, special_toks_indices[''], :]), dim=-1) # (bs*gpus, H) if media_type == 'webvid': spatial_proj_feats = F.normalize(self.spatial_proj(input_embeds[:, special_toks_indices[''], :]), dim=-1) # (bs*gpus, H) temp_proj_feats = F.normalize(self.temp_proj(input_embeds[:, special_toks_indices[''], :]), dim=-1) # (bs*gpus, H) if self.config.loss_dict['vcc'] != 0: vis_proj_feats_all = concat_all_gather(vis_proj_feats) # (bs*gpus, H) cap_proj_feats_all = concat_all_gather(cap_proj_feats) # (bs*gpus, H) loss_vcc, _, _ = self.compute_contrastive_loss(vis_proj_feats, cap_proj_feats_all, cap_proj_feats, vis_proj_feats_all) # 1- Spatial-Temporal if media_type == 'webvid': if self.config.loss_dict['stc'] != 0: spatial_proj_feats_all = concat_all_gather(spatial_proj_feats) # (bs*gpus, H) temp_proj_feats_all = concat_all_gather(temp_proj_feats) # (bs*gpus, H) loss_stc, _, _ = self.compute_contrastive_loss(temp_proj_feats, spatial_proj_feats_all, spatial_proj_feats, temp_proj_feats_all) #-------------------- Matching losses --------------------# if self.config.loss_dict['vcm'] != 0: # Negative caption with positive visual neg_cap_ids, neg_cap_attention_mask, = self.tokenize_text(neg_cap, device, max_len=self.config.max_cap_len) neg_cap_embed = self.moe_llm_bottleneck(self.embed(neg_cap_ids)) input_embeds_neg_cap = input_embeds_orig.clone().detach() input_embeds_neg_cap[:, special_toks_indices[''] + 1:special_toks_indices['']] = neg_cap_embed input_mask_neg_cap = input_mask.clone().detach() input_mask_neg_cap[:, special_toks_indices[''] + 1:special_toks_indices['']] = neg_cap_attention_mask # Negative visual with positive caption input_embeds_neg_vis = input_embeds_orig.clone().detach() input_mask_neg_vis = input_mask.clone().detach() # neg_vis_embed = self.encode_vis(neg_vis) # # reshape video features # neg_vis_embed = neg_vis_embed.reshape(batch_size, num_frames, -1, embed_dim) # # Perfrom spatial temporal attention and reshape # neg_vis_embed_spatial = self.spatial_att(neg_vis_embed) # neg_vis_embed_spatial = neg_vis_embed_spatial.reshape(batch_size, -1, embed_dim) if media_type == 'webvid': # neg_vis_embed_temporal = self.temporal_att(neg_vis_embed) # neg_vis_embed_temporal = neg_vis_embed_temporal.reshape(batch_size, -1, embed_dim) input_embeds_neg_vis[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = neg_vis_embed_spatial input_embeds_neg_vis[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = neg_vis_embed_temporal elif media_type == 'cc3m': # neg_vis_embed_spatial = self.moe_llm_bottleneck(neg_vis_embed_spatial) input_embeds_neg_vis[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = neg_vis_embed_spatial # Construct the input of VCM final_input_embeds_vcm = torch.cat([input_embeds_orig, input_embeds_neg_cap, input_embeds_neg_vis], dim=0) final_input_mask_vcm = torch.cat([input_mask, input_mask_neg_cap, input_mask_neg_vis], dim=0) for moe_layer_idx, moe_layer in enumerate(self.moe_layers): if moe_layer_idx < self.config.num_moe_modality_layers: expert_flag = 'modalities' else: expert_flag = 'fusion' final_input_embeds_vcm = moe_layer(final_input_embeds_vcm, special_toks_indices, expert_flag, mask=final_input_mask_vcm) pooled_caption = self.caption_pooler(final_input_embeds_vcm, special_toks_indices['']) pooled_vis = 
self.vis_pooler(final_input_embeds_vcm, special_toks_indices['']) vcm_feats = torch.mul(pooled_caption, pooled_vis) vcm_logits = self.vcm_head(vcm_feats) vcm_labels = torch.cat( [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)], dim=0, ).to(device) # random permutation of the logits and labels --> make the task not trivial to learn # perm_idx = torch.randperm(vcm_logits.size(0), device=device) # perm_idx_extended = perm_idx.unsqueeze(-1).repeat(1, vcm_logits.size(-1)) # # Shuffle # vcm_logits = vcm_logits.scatter(0, perm_idx_extended, vcm_logits) # vcm_labels = vcm_labels.scatter(0, perm_idx, vcm_labels) # class_weight = torch.FloatTensor([1.0, 1.0/3]).to(device) loss_vcm = F.cross_entropy(vcm_logits, vcm_labels) # , weight=class_weight) if media_type == 'webvid': if self.config.loss_dict['stm'] != 0: # Negative spatial with positive temporal input_embeds_neg_spatial = input_embeds_orig.clone().detach() input_mask_neg_spatial = input_mask.clone().detach() input_embeds_neg_spatial[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = neg_vis_embed_spatial # Positive spatial with negative temporal input_embeds_neg_temporal = input_embeds_orig.clone().detach() input_mask_neg_temporal = input_mask.clone().detach() input_embeds_neg_temporal[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = neg_vis_embed_temporal # Construct the input of STM final_input_embeds_stm = torch.cat([input_embeds_orig, input_embeds_neg_spatial, input_embeds_neg_temporal], dim=0) final_input_mask_stm = torch.cat([input_mask, input_mask_neg_spatial, input_mask_neg_temporal], dim=0) for moe_layer_idx, moe_layer in enumerate(self.moe_layers): if moe_layer_idx < self.config.num_moe_modality_layers: expert_flag = 'modalities' else: expert_flag = 'fusion' final_input_embeds_stm = moe_layer(final_input_embeds_stm, special_toks_indices, expert_flag, mask=final_input_mask_stm) pooled_spatial = self.spatial_pooler(final_input_embeds_stm, special_toks_indices['']) pooled_temporal = self.temporal_pooler(final_input_embeds_stm, special_toks_indices['']) stm_feats = torch.mul(pooled_spatial, pooled_temporal) stm_logits = self.stm_head(stm_feats) stm_labels = torch.cat( [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)], dim=0, ).to(device) # random permutation of the logits and labels --> make the task not trivial to learn # perm_idx = torch.randperm(stm_logits.size(0), device=device) # perm_idx_extended = perm_idx.unsqueeze(-1).repeat(1, stm_logits.size(-1)) # # Shuffle # stm_logits = stm_logits.scatter(0, perm_idx_extended, stm_logits) # stm_labels = stm_labels.scatter(0, perm_idx, stm_labels) # class_weight = torch.FloatTensor([1.0, 1.0/3]).to(device) loss_stm = F.cross_entropy(stm_logits, stm_labels) # , weight=class_weight) if self.config.loss_dict['mlm'] != 0: masked_cap_ids, labels = self.mlm(cap_ids.clone()) masked_cap_embeds = self.moe_llm_bottleneck(self.embed(masked_cap_ids)) # inject the masked embeddings instead of the original ones # input_embeds_mlm[:, special_toks_indices['']+1 : special_toks_indices[''], :] = masked_cap_embeds for moe_layer_idx, moe_layer in enumerate(self.moe_layers): if moe_layer_idx < self.config.num_moe_modality_layers: expert_flag = 'modalities' else: expert_flag = 'fusion' masked_cap_embeds = moe_layer(masked_cap_embeds, special_toks_indices, expert_flag, mask=cap_attention_mask, only_text=True) # extract the caption last hidden states # masked_cap_embeds_last = input_embeds_mlm[:, 
special_toks_indices['']+1 : special_toks_indices[''], :] lm_logits = self.lm_head(masked_cap_embeds) loss_mlm = F.cross_entropy( lm_logits.view(-1, len(self.tokenizer)), labels.view(-1), ignore_index=self.mlm.padding_token ) return dict( loss_stc = loss_stc * self.config.loss_dict['stc'], loss_stm = loss_stm * self.config.loss_dict['stm'], loss_vcc = loss_vcc * self.config.loss_dict['vcc'], loss_vcm = loss_vcm * self.config.loss_dict['vcm'], loss_mlm = loss_mlm * self.config.loss_dict['mlm'], ) def get_vis_enc_for_eval(self, vis, media_type): # First get the visual features depending on the media type vis_spatial_embed, vis_temporal_embed = self.encode_vis(vis, media_type) # Expand the query tokens spatial_query_embeds = self.spatial_query_embeds.expand(vis_spatial_embed.size(0), -1, -1) # Run the spatial expert spatial_query_embeds, pooled_spatial_query_embeds = self.encode_queries( spatial_query_embeds, vis_spatial_embed, vis_mode='spatial') temporal_query_embeds = self.spatial_query_embeds.expand(vis_temporal_embed.size(0), -1, -1) temporal_query_embeds, pooled_temporal_query_embeds = self.encode_queries( temporal_query_embeds, vis_temporal_embed, vis_mode='temporal') vis_pooled = torch.cat((pooled_spatial_query_embeds, pooled_temporal_query_embeds), dim=1) vis_embeds = torch.cat((spatial_query_embeds, temporal_query_embeds), dim=1) return vis_embeds, vis_pooled def get_expert_encoder(self, expert): """get text encoder, used for text and cross-modal encoding""" encoder = None if expert == 'cap': encoder = self.cap_expert if expert == 'spatial': encoder = self.spatial_expert if expert == 'temporal': encoder = self.temporal_expert if expert == 'sap_att_grounding': encoder = self.spa_temp_grounding_expert if expert == 'vis_cap_grounding': encoder = self.vis_cap_grounding_expert assert encoder is not None return encoder.bert if hasattr(encoder, "bert") else encoder class V2Dial(V2DialAbstract): def __init__(self, config): super(V2Dial, self).__init__() self.config = config ################## 1. Select Tokenizer -- We use BERT tokenizer ################## bert_config = BertConfig.from_pretrained('bert-{}-uncased'.format(config.expert_size)) tokenizer = AutoTokenizer.from_pretrained('bert-{}-uncased'.format(config.expert_size)) text_embedding = BertEmbeddings(bert_config) text_embedding.apply(self.init_weights) token_type_embedding = nn.Embedding(3, bert_config.hidden_size) # Number of modalities (temp/spa/cap/hist-ques-ans) token_type_embedding.apply(self.init_weights) ################## 1. 
Select LLM ##################
        if config.llm_family == 'llama':
            logging.info('[INFO] LLM: LLAMA v2')
            llm_model = LlamaForCausalLM
        elif config.llm_family == 'mistral':
            logging.info('[INFO] LLM: Mistral')
            llm_model = MistralForCausalLM
        elif config.llm_family == 'flan_t5':
            logging.info('[INFO] LLM: Flan T5')
            llm_model = T5ForConditionalGeneration
        elif config.llm_family == 'bart':
            logging.info('[INFO] LLM: BART')
            llm_model = BartForConditionalGeneration
        else:
            raise ValueError

        llm_tokenizer = AutoTokenizer.from_pretrained(config.llm_name, use_fast=False, token='your_token')

        # set the padding token to eos token for llama
        if config.llm_family == 'llama':
            llm_tokenizer.pad_token = llm_tokenizer.eos_token

        #________________________________ LLM Quantization ________________________________#
        if config.llm_family in ['mistral', 'llama']:
            dtype = None
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type='nf4',
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        else:
            if config.fp16:
                dtype = torch.float16
                if config.llm_family == 'flan_t5':
                    dtype = torch.bfloat16
            else:
                dtype = torch.float32
            quantization_config = None

        # llm_model.generate()
        llm = llm_model.from_pretrained(
            config.llm_name,
            token='your_token',
            torch_dtype=dtype,
            quantization_config=quantization_config
        )

        if config.llm_family == 'llama':
            llm_embed = llm.model.embed_tokens
        elif config.llm_family == 'flan_t5':
            llm_embed = llm.shared
        elif config.llm_family == 'mistral':
            llm_embed = llm.model.embed_tokens
        elif config.llm_family == 'bart':
            llm_embed = llm.model.shared
        else:
            raise ValueError

        # llm.resize_token_embeddings(len(self.tokenizer))
        if quantization_config is not None:
            # Gradient checkpointing is not compatible with DDP!!
            llm = prepare_model_for_kbit_training(llm, use_gradient_checkpointing=True)

        if config.freeze_llm:
            for _, param in llm.named_parameters():
                param.requires_grad = False
            logging.info('[INFO] LLM frozen')
        else:
            if config.use_lora_llm:
                # load the lora config
                with open(config.lora_config, 'r') as f:
                    lora_config = json.load(f)
                if config.llm_family in ['llama', 'mistral']:
                    lora_config['target_modules'] = ['q_proj', 'v_proj']
                elif config.llm_family in ['flan_t5']:
                    lora_config['target_modules'] = ['q', 'v']
                lora_config = LoraConfig(**lora_config)
                llm = get_peft_model(llm, lora_config)
                logging.info('[INFO] LLM hot with lora')
            else:
                logging.info('[INFO] LLM hot')

        logging.info('[INFO] LLM successfully loaded')

        for _, param in llm_embed.named_parameters():
            param.data = param.data.float()
            param.requires_grad = True

        llm_to_moe = nn.Linear(llm.config.hidden_size, bert_config.hidden_size)
        llm_to_moe.apply(self.init_weights)

        moe_to_llm = nn.Linear(bert_config.hidden_size, llm.config.hidden_size)
        moe_to_llm.apply(self.init_weights)

        ################## 2.
Select the backbone ViT ################## logging.info('[INFO] Loading ViT in progress') if config.freeze_vit: # vit_precision = 'fp16' if config.fp16 else 'fp32' logging.info(f'[INFO] ViT precision: {config.vit_precision}') visual_encoder, ln_vision = self.init_vision_encoder( config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, precision=config.vit_precision ) for name, param in visual_encoder.named_parameters(): param.requires_grad = False visual_encoder = visual_encoder.eval() visual_encoder.train = disabled_train for name, param in ln_vision.named_parameters(): param.requires_grad = False ln_vision = ln_vision.eval() ln_vision.train = disabled_train logging.info('[INFO] ViT frozen') else: vit_precision = 'fp32' visual_encoder, ln_vision = self.init_vision_encoder( config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, vit_precision=vit_precision ) logging.info('[INFO] ViT hot') logging.info('[INFO] ViT successfully loaded') ################## 3. Define the ViT-Expert communication Interface ################## self.system_prompt = False self.vit_token_pooling = config.vit_token_pooling if self.vit_token_pooling: vit_proj = nn.Linear( 1408*4, bert_config.hidden_size ) else: vit_proj = nn.Linear( 1408, bert_config.hidden_size ) vit_proj.apply(self.init_weights) spatial_att = SpatialAttention(input_dim=bert_config.hidden_size) temporal_att = TemporalAttention(input_dim=bert_config.hidden_size) spatial_att.apply(self.init_weights) temporal_att.apply(self.init_weights) ################## 4. Define the Expert layers ################## moe_layers = None moe_norm = None if config.use_moes: moe_layers = [] for moe_layer_idx in range(config.num_moe_layers): if moe_layer_idx < self.config.num_moe_modality_layers: expert_flag = 'modalities' else: expert_flag = 'fusion' moe_layer = MoELayer( bert_config.hidden_size, bert_config.num_attention_heads, expert_flag, has_hist=True, use_sep_spatial_temp_experts=config.use_sep_spatial_temp_experts ) moe_layer.apply(self.init_weights) moe_layers.append(moe_layer) logging.info(f'[INFO] {moe_layer_idx+1}/{config.num_moe_layers} MoE layers successfully loaded') moe_layers = nn.ModuleList(moe_layers) moe_norm = nn.LayerNorm(bert_config.hidden_size) ################## 5. Define the projection layers for contrastive learning ################## # temp_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) # spatial_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) # vision_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) # cap_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) # temp_proj.apply(self.init_weights) # spatial_proj.apply(self.init_weights) # vision_proj.apply(self.init_weights) # cap_proj.apply(self.init_weights) ################## 6. Define the pooler for matching loss ################## # pooler = Pooler(bert_config.hidden_size) # pooler.apply(self.init_weights) ################## 5. 
Attach the matching heads ################## # stm_head = nn.Linear(bert_config.hidden_size, 2) # vcm_head = nn.Linear(bert_config.hidden_size, 2) # lm_head = nn.Linear(bert_config.hidden_size, len(tokenizer)) # stm_head.apply(self.init_weights) # vcm_head.apply(self.init_weights) # lm_head.apply(self.init_weights) temp = nn.Parameter(0.07 * torch.ones([])) # temp = 0.07 # Attach the components to self if self.config.embed_from_llm: self.tokenizer = llm_tokenizer self.text_embedding = llm_embed else: self.tokenizer = tokenizer self.text_embedding = text_embedding self.token_type_embedding = token_type_embedding self.llm = llm self.llm_to_moe = llm_to_moe self.moe_to_llm = moe_to_llm self.visual_encoder = visual_encoder self.ln_vision = ln_vision self.vit_proj = vit_proj self.moe_layers = moe_layers self.moe_norm = moe_norm self.spatial_att = spatial_att self.temporal_att = temporal_att # self.temp_proj = temp_proj # self.spatial_proj = spatial_proj # self.vision_proj = vision_proj # self.cap_proj = cap_proj # self.pooler = pooler # self.stm_head = stm_head # self.vcm_head = vcm_head # self.lm_head = lm_head self.temp = temp def construct_global_input(self, cap_ids, cap_attention_mask, hist_ids, hist_attention_mask, vid_feat_len, device): # for video: [spatial_feats][temp_feats][cap_feats][hist_feats] batch_size = cap_ids.size(0) special_toks_indices = { '': 0, '': 1, '': 2, } ids = [self.added_vocab['']] + [self.added_vocab['']] + [self.added_vocab['']] ids += vid_feat_len * [self.added_vocab['']] ids += [self.added_vocab['']] special_toks_indices[''] = len(ids) - 1 ids += vid_feat_len * [self.added_vocab['']] ids += [self.added_vocab['']] special_toks_indices[''] = len(ids) - 1 ids += cap_ids.size(1) * [self.added_vocab['']] ids += [self.added_vocab['']] special_toks_indices[''] = len(ids) - 1 ids += hist_ids.size(1) * [self.added_vocab['']] ids += [self.added_vocab['']] special_toks_indices[''] = len(ids) - 1 total_len = len(ids) ids = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) ids[:, special_toks_indices[''] + 1: special_toks_indices['']] = cap_ids ids[:, special_toks_indices[''] + 1: special_toks_indices['']] = hist_ids mask = torch.ones((batch_size, total_len), device=device) mask[:, special_toks_indices[''] + 1: special_toks_indices['']] = cap_attention_mask mask[:, special_toks_indices[''] + 1: special_toks_indices['']] = hist_attention_mask return ids, mask, special_toks_indices def construct_reg_labels(self, regress_ids, start_regress_idx, full_embeds, device): full_labels = torch.LongTensor(full_embeds.size(0), full_embeds.size(1)).fill_(-100).to(device) for i in range(regress_ids.size(0)): full_labels[i, start_regress_idx[i]: start_regress_idx[i] + regress_ids[i].size(-1)] = regress_ids[i] # Add to the labels -- just before the response starts full_labels[i, start_regress_idx[i] - 1] = self.tokenizer.eos_token_id # labels = regress_ids.masked_fill( # regress_ids == self.tokenizer.pad_token_id, -100 # ).to(device) # eos_from_cond = torch.LongTensor(labels.size(0), 1).fill_(self.tokenizer.eos_token_id).to(device) # labels = torch.concat([eos_from_cond, labels], dim=1) # full_labels = torch.LongTensor(labels.size(0), full_len).fill_(-100).to(device) # full_labels[:, len_cond-1:] = labels return full_labels def rearrange_llm_input_decoder_only(self, input_embeds, output_emebds, input_mask, cap_mask, hist_mask, output_mask, spatial_feat_len): ''' Push all pads to the right ''' # full_embeds = [...][...][...][pad][...][pad][ans ...][pad] # 
------------> [...][...][...][...][ans ...][-----pad-----] init_len = input_embeds.size(1) + output_emebds.size(1) # First, we compute the initial offset of the visual features offset = 3 + spatial_feat_len + 1 + spatial_feat_len # --> input_embeds[offset] = h_ offset_embeds = input_embeds[:, :offset, :] offset_mask = input_mask[:, :offset] rest_input_embdes = input_embeds[:, offset:, :] rest_input_mask = input_mask[:, offset:] start_output_idx = [] full_embeds = [] full_masks = [] for i in range(input_embeds.size(0)): output_emebd_i = output_emebds[i] output_mask_i = output_mask[i] cap_mask_i = cap_mask[i] len_cap_i = cap_mask_i.sum() end_cap_i = len_cap_i + 1 # +1 for the token cap_embdes_i_to_keep = rest_input_embdes[i, :end_cap_i, :] cap_mask_i_to_keep = rest_input_mask[i, :end_cap_i,] cap_embeds_i_to_push = rest_input_embdes[i, end_cap_i:cap_mask_i.size(-1) + 1, :] # +1 for the token cap_mask_i_to_push = rest_input_mask[i, end_cap_i: cap_mask_i.size(-1) + 1] # +1 for the token hist_mask_i = hist_mask[i] len_hist_i = hist_mask_i.sum() start_hist_i = cap_mask_i.size(-1) + 1 end_hist_i = start_hist_i + len_hist_i + 1 # +1 for token # fianl token to keep is which is the last in input_embdes/rest_input_embdes final_tok_embedding_i = rest_input_embdes[i, -1, :].unsqueeze(0) final_tok_mask_i = rest_input_mask[i, -1].unsqueeze(0) hist_embdes_i_to_keep = rest_input_embdes[i, start_hist_i:end_hist_i, :] hist_mask_i_to_keep = rest_input_mask[i, start_hist_i:end_hist_i] # these two do not consider the last token --> we don't need to extra remove it from them hist_embdes_i_to_push = rest_input_embdes[i, end_hist_i: cap_mask_i.size(-1) + 1 + hist_mask_i.size(-1) + 1, :] hist_mask_i_to_push = rest_input_mask[i, end_hist_i: cap_mask_i.size(-1) + 1 + hist_mask_i.size(-1) + 1] full_embed_i = torch.cat( [cap_embdes_i_to_keep, hist_embdes_i_to_keep, final_tok_embedding_i, output_emebd_i, cap_embeds_i_to_push, hist_embdes_i_to_push], dim=0 ) full_mask_i = torch.cat( [cap_mask_i_to_keep, hist_mask_i_to_keep, final_tok_mask_i, output_mask_i, cap_mask_i_to_push, hist_mask_i_to_push], dim=0 ) start_output_idx.append(offset + cap_embdes_i_to_keep.size(0) + hist_embdes_i_to_keep.size(0) + 1 - 1) full_embeds.append(full_embed_i) full_masks.append(full_mask_i) # Now stack to get the batch full_embeds = torch.stack(full_embeds, dim=0) full_masks = torch.stack(full_masks, dim=0) # Add the offset visual features full_embeds = torch.cat([offset_embeds, full_embeds], dim=1) full_masks = torch.cat([offset_mask, full_masks], dim=1) final_len = full_embeds.size(1) # Sanity check assert init_len == final_len, 'The reconstructed embeds have length ({}) which is not the same as the length of initial embeds ({})'.format( final_len, init_len ) return full_embeds, full_masks, start_output_idx def pad_to_right_enc_dec(self, cap_embeds, cap_masks, hist_embeds, hist_masks, device): """ pushes all in-between pad tokens to the right """ res_embeds = [] res_mask = [] for cap_embed, cap_mask, hist_embed, hist_mask in zip(cap_embeds, cap_masks, hist_embeds, hist_masks): len_cap = sum(cap_mask) len_hist = sum(hist_mask) batch_embed = torch.cat([cap_embed[:len_cap], hist_embed[:len_hist], cap_embed[len_cap:], hist_embed[len_hist:]], dim=0) batch_mask = torch.zeros(batch_embed.size(0)).long().to(device) batch_mask[:len_cap+len_hist] = 1 res_embeds.append(batch_embed) res_mask.append(batch_mask) res_embeds = torch.stack(res_embeds, dim=0) res_mask = torch.stack(res_mask, dim=0) return res_embeds, res_mask def pad_to_right_dec_only(self, 
                              cap_embeds, cap_masks, hist_embeds, hist_masks,
                              regress_embeds, regress_masks, device):
        """ pushes all in-between pad tokens to the right """
        res_embeds = []
        res_mask = []
        regress_limits_txt_input = []
        for cap_embed, cap_mask, hist_embed, hist_mask, regress_emebd, regress_mask in zip(
                cap_embeds, cap_masks, hist_embeds, hist_masks, regress_embeds, regress_masks):
            len_cap = sum(cap_mask)
            len_hist = sum(hist_mask)
            len_ans = sum(regress_mask)
            regress_limits_txt_input.append((len_cap+len_hist, len_cap+len_hist+len_ans))
            batch_embed = torch.cat([cap_embed[:len_cap], hist_embed[:len_hist], regress_emebd, cap_embed[len_cap:], hist_embed[len_hist:]], dim=0)
            batch_mask = torch.zeros(batch_embed.size(0)).long().to(device)
            batch_mask[:len_cap+len_hist+len_ans] = 1
            res_embeds.append(batch_embed)
            res_mask.append(batch_mask)

        res_embeds = torch.stack(res_embeds, dim=0)
        res_mask = torch.stack(res_mask, dim=0)
        return res_embeds, res_mask, regress_limits_txt_input

    def pad_to_right_dec_only_gen_mode(self, cap_embeds, cap_masks, hist_embeds, hist_masks, device):
        """ pushes all in-between pad tokens to the right """
        res_embeds = []
        res_mask = []
        for cap_embed, cap_mask, hist_embed, hist_mask in zip(cap_embeds, cap_masks, hist_embeds, hist_masks):
            len_cap = sum(cap_mask)
            len_hist = sum(hist_mask)
            batch_embed = torch.cat([cap_embed[:len_cap], hist_embed[:len_hist], cap_embed[len_cap:], hist_embed[len_hist:]], dim=0)
            batch_mask = torch.zeros(batch_embed.size(0)).long().to(device)
            batch_mask[:len_cap+len_hist] = 1
            res_embeds.append(batch_embed)
            res_mask.append(batch_mask)

        res_embeds = torch.stack(res_embeds, dim=0)
        res_mask = torch.stack(res_mask, dim=0)
        return res_embeds, res_mask

    def encode_vis_with_seq_spa_temp_att(self, image, device, is_vid=True):
        num_frames = image.size(1)
        bs_pre_reshape = image.size(0)
        if len(image.shape) > 4:
            image = image.view(-1, *image.shape[-3:])  # for video input, flatten the batch and time dimensions (4,50,3,224,224) -> (200,3,224,224)

        # with self.maybe_autocast():  # inherited from Blip2Base
        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)  # (200,3,224,224) -> (200,257,1408)
        image_embeds = image_embeds[:, 1:, :]  # remove the first (CLS) token (200,256,1408)
        bs, pn, hs = image_embeds.shape
        if self.vit_token_pooling:  # concatenate each 4 tokens into one token (200,64,5632)
            image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4))  # (200,64,5632)

        vis_embed = self.vit_proj(image_embeds)  # project to llama input size (200,64,5632) -> (200,64,4096)

        # reshape the video features
        vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1))
        size_orig = vis_embed.size()

        # Perform spatial-temporal attention
        vis_embed = self.spatial_att(vis_embed)
        if is_vid:
            vis_embed = vis_embed.view(size_orig)
            vis_embed = self.temporal_att(vis_embed)

        vis_feat_len = vis_embed.size(1)
        # vis_embed = vis_embed + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device))
        vis_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device)

        return vis_embed, vis_mask

    def moe_forward_no_sep_spatial_temporal(
            self, vis, vis_mask, cap_ids, cap_mask, hist_ids, hist_mask, is_vid, device):

        # is_vid = media_type == 'webvid'
        # batch_size = len(cap)
        vis_feat_len = vis.size(1)

        input_embeds = []
        input_masks = []

        input_embeds.append(vis)
        input_masks.append(vis_mask)

        # if is_vid:
        #     input_embeds.append(vis_temporal)
        #     input_masks.append(vis_temporal_mask)

        if self.config.embed_from_llm:
            cap_embeds = self.llm_to_moe(self.text_embedding(cap_ids))
        else:
            cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2))
        cap_feat_len = cap_embeds.size(1)
        input_embeds.append(cap_embeds)
        input_masks.append(cap_mask)

        if self.config.embed_from_llm:
            hist_embeds = self.llm_to_moe(self.text_embedding(hist_ids))
        else:
            hist_embeds = self.text_embedding(hist_ids) + self.token_type_embedding(torch.ones_like(hist_ids).long().fill_(2))
        hist_feat_len = hist_embeds.size(1)
        input_embeds.append(hist_embeds)
        input_masks.append(hist_mask)

        input_embeds = torch.cat(input_embeds, dim=1)
        input_masks = torch.cat(input_masks, dim=1)

        # expand the mask
        input_masks = self.get_extended_attention_mask(attention_mask=input_masks)

        # MoEs feed-forward
        for moe_layer_idx, moe_layer in enumerate(self.moe_layers):
            if moe_layer_idx < self.config.num_moe_modality_layers:
                expert_flag = 'modalities'
            else:
                expert_flag = 'fusion'
            input_embeds = moe_layer(input_embeds, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len, is_vid=is_vid, mask=input_masks)

        # TODO: normalize the output
        input_embeds = self.moe_norm(input_embeds)

        # return the features
        vis_embeds = input_embeds[:, :vis_feat_len]
        # temporal_embeds = input_embeds[:, vis_feat_len:2*vis_feat_len] if is_vid else None
        cap_embeds = input_embeds[:, -(cap_feat_len + hist_feat_len): -hist_feat_len]
        hist_embeds = input_embeds[:, -hist_feat_len:]

        # cls_feats = self.pooler(cap_feats)

        moe_outputs = {
            'vis_embeds': vis_embeds,
            # 'temporal_embeds': temporal_embeds,
            'cap_embeds': cap_embeds,
            'hist_embeds': hist_embeds,
            # 'cls_feats': cls_feats,
            # 'last_hidden': input_embeds
        }
        return moe_outputs

    def moe_forward(
            self, vis_spatial, vis_spatial_mask, vis_temporal, vis_temporal_mask,
            cap_ids, cap_mask, hist_ids, hist_mask, is_vid, device):

        # is_vid = media_type == 'webvid'
        # batch_size = len(cap)
        vis_feat_len = vis_spatial.size(1)

        input_embeds = []
        input_masks = []

        input_embeds.append(vis_spatial)
        input_masks.append(vis_spatial_mask)

        if is_vid:
            input_embeds.append(vis_temporal)
            input_masks.append(vis_temporal_mask)

        if self.config.embed_from_llm:
            cap_embeds = self.llm_to_moe(self.text_embedding(cap_ids))
        else:
            cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2))
        cap_feat_len = cap_embeds.size(1)
        input_embeds.append(cap_embeds)
        input_masks.append(cap_mask)

        if self.config.embed_from_llm:
            hist_embeds = self.llm_to_moe(self.text_embedding(hist_ids))
        else:
            hist_embeds = self.text_embedding(hist_ids) + self.token_type_embedding(torch.ones_like(hist_ids).long().fill_(2))
        hist_feat_len = hist_embeds.size(1)
        input_embeds.append(hist_embeds)
        input_masks.append(hist_mask)

        input_embeds = torch.cat(input_embeds, dim=1)
        input_masks = torch.cat(input_masks, dim=1)

        # expand the mask
        input_masks = self.get_extended_attention_mask(attention_mask=input_masks)

        # MoEs feed-forward
        for moe_layer_idx, moe_layer in enumerate(self.moe_layers):
            if moe_layer_idx < self.config.num_moe_modality_layers:
                expert_flag = 'modalities'
            else:
                expert_flag = 'fusion'
            input_embeds = moe_layer(
                input_embeds, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len,
                is_vid=is_vid, mask=input_masks,
                expert_permutation=self.config.expert_permutation
            )

        # TODO: normalize the output
        input_embeds = self.moe_norm(input_embeds)

        # return the features
        spatial_embeds = input_embeds[:, :vis_feat_len]
        temporal_embeds = input_embeds[:, vis_feat_len:2*vis_feat_len] if is_vid else None
        cap_embeds = input_embeds[:, -(cap_feat_len + hist_feat_len): -hist_feat_len]
        hist_embeds = input_embeds[:, -hist_feat_len:]

        # cls_feats = self.pooler(cap_feats)

        moe_outputs = {
            'spatial_embeds': spatial_embeds,
            'temporal_embeds': temporal_embeds,
            'cap_embeds': cap_embeds,
            'hist_embeds': hist_embeds,
            # 'cls_feats': cls_feats,
            # 'last_hidden': input_embeds
        }
        return moe_outputs

    def forward(self, vis, cap, hist, ans, media_type):
        device = vis.device
        is_vid = media_type in ['webvid', 'champagne', 'avsd', 'nextqa']

        loss_stc = torch.tensor(0)
        loss_stm = torch.tensor(0)
        loss_vhc = torch.tensor(0)
        loss_vhm = torch.tensor(0)
        loss_gen = torch.tensor(0)

        # construct the global input tensor --> use place holder for vis features
        cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=None)
        hist_ids, hist_mask = self.tokenize_text(hist, device, max_len=None)

        if self.config.use_moes:
            # First get the visual features depending on the media type
            if self.config.use_sep_spatial_temp_experts:
                vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid)
                spatial_feat_len = vis_embed_spatial.size(1)
            else:
                vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)

            if self.config.use_sep_spatial_temp_experts:
                moe_outputs = self.moe_forward(
                    vis_embed_spatial, vis_spatial_mask,
                    vis_embed_temporal, vis_temporal_mask,
                    cap_ids, cap_mask,
                    hist_ids, hist_mask,
                    is_vid, device
                )
                spatial_embeds = self.moe_to_llm(moe_outputs['spatial_embeds'])
                temporal_embeds = self.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None
                # cap_embeds = self.moe_to_llm(moe_outputs['cap_embeds'])
                # hist_embeds = self.moe_to_llm(moe_outputs['hist_embeds'])
            else:
                moe_outputs = self.moe_forward_no_sep_spatial_temporal(
                    vis_embed, vis_mask,
                    cap_ids, cap_mask,
                    hist_ids, hist_mask,
                    is_vid, device
                )
                vis_embeds = self.moe_to_llm(moe_outputs['vis_embeds'])
                # temporal_embeds = self.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None

            cap_embeds = self.moe_to_llm(moe_outputs['cap_embeds'])
            hist_embeds = self.moe_to_llm(moe_outputs['hist_embeds'])
        else:
            cap_embeds = self.llm_to_moe(self.text_embedding(cap_ids))
            hist_embeds = self.llm_to_moe(self.text_embedding(hist_ids))
            vis_embeds, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)

        ans = [a + self.tokenizer.eos_token for a in ans]

        if self.config.llm_family in ['llama', 'mistral']:
            bos = torch.ones_like(cap_ids[:, :1]) * self.tokenizer.bos_token_id
            bos_embeds = self.text_embedding(bos)
            bos_mask = cap_mask[:, :1]

            # add corresponding eos
            regress_ids, regress_mask = self.tokenize_text(ans, device, max_len=None)  # pad to the longest
            regress_embeds = self.text_embedding(regress_ids)

            inputs_embeds, attention_mask, regress_limits_txt_input = self.pad_to_right_dec_only(
                cap_embeds, cap_mask, hist_embeds, hist_mask, regress_embeds, regress_mask, device)

            if is_vid:
                inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
            else:
                inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1)
                attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1)

            labels = torch.zeros(inputs_embeds.size()[:-1]).fill_(-100).long().to(device)
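            # Worked example (illustrative numbers, not from the original code): with
            # spatial_feat_len = 64 and a video input (is_vid = True), the LLM input for
            # sample i is laid out as
            #   [bos (1)] [spatial (64)] [temporal (64)] [cap + hist (valid tokens)] [answer] [pads]
            # so the answer span in `labels` starts at 1 + 64 + 64 + regress_limits_txt_input[i][0]
            # and ends at 1 + 64 + 64 + regress_limits_txt_input[i][1]; every other position
            # stays at -100 and is ignored by the LM loss.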
            for i in range(labels.size(0)):
                start_regress = regress_limits_txt_input[i][0] + 1 + spatial_feat_len + spatial_feat_len * int(is_vid)  # offset (bos + spatial + temporal)
                end_regress = regress_limits_txt_input[i][1] + 1 + spatial_feat_len + spatial_feat_len * int(is_vid)  # offset (bos + spatial + temporal)
                labels[i, start_regress:end_regress] = regress_ids[i, :regress_mask[i].sum()]

            # get causal attention mask
            # Compute the regression embeds

            # Now we need to right-pad the input to the LLM (at least for llama) to avoid nan loss values
            # This means all pad tokens have to be placed to the right
            # full_embeds = [...][...][...][pad][...][pad][ans ...][pad]
            # ------------> [...][...][...][...][ans ...][-----pad-----]
            # full_embeds, full_masks, start_output_idx = self.rearrange_llm_input_dec_only(cond_embeds, regress_embeds, cond_mask, cap_mask, hist_mask, regress_mask, spatial_feat_len)

            # labels = self.construct_reg_labels(regress_ids, start_output_idx, full_embeds, device)

            lm_outputs = self.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True
            )
            loss_gen = lm_outputs.loss

        # Encoder Decoder
        else:
            inputs_embeds, attention_mask = self.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device)

            # now merge the multi-modal inputs
            if self.config.use_moes:
                if self.config.use_sep_spatial_temp_experts:
                    if is_vid:
                        inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1)
                else:
                    inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
                    attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
            else:
                inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
                attention_mask = torch.cat([vis_mask, attention_mask], dim=1)

            decoder_ids, decoder_mask = self.tokenize_text(ans, device, max_len=None)  # pad to the longest
            labels = decoder_ids.masked_fill(decoder_ids == self.tokenizer.pad_token_id, -100)
            decoder_ids = self.shift_right(labels)
            decoder_inputs_embeds = self.text_embedding(decoder_ids)

            lm_outputs = self.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_inputs_embeds=decoder_inputs_embeds,
                decoder_attention_mask=decoder_mask,
                labels=labels,
                return_dict=True
            )
            loss_gen = lm_outputs.loss

        return dict(
            loss_stc=loss_stc * self.config.loss_dict['stc'],
            loss_stm=loss_stm * self.config.loss_dict['stm'],
            loss_vhc=loss_vhc * self.config.loss_dict['vhc'],
            loss_vhm=loss_vhm * self.config.loss_dict['vhm'],
            loss_gen=loss_gen * self.config.loss_dict['gen'],
        )


class V2DialNoMoes(V2Dial):
    def __init__(self, config):
        super(V2DialNoMoes, self).__init__(config)

    def encode_vis(self, image, device, is_vid=True):
        num_frames = image.size(1)
        bs_pre_reshape = image.size(0)
        if len(image.shape) > 4:
            image = image.view(-1, *image.shape[-3:])  # for video input, flatten the batch and time dimensions (4,50,3,224,224) -> (200,3,224,224)

        # with self.maybe_autocast():  # inherited from Blip2Base
        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)  # (200,3,224,224) -> (200,257,1408)
        image_embeds = image_embeds[:, 1:, :]  # remove the first (CLS) token (200,256,1408)
        bs, pn, hs = image_embeds.shape
        if self.vit_token_pooling:  # concatenate each 4 tokens into one token (200,64,5632)
            image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4))  # (200,64,5632)

        vis_embed = self.vit_proj(image_embeds)  # project to LLM input size (200,64,5632) -> (200,64,d_hidden)

        # reshape the video features
        vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1))

        # Perform spatial-temporal attention (temporal first, then spatial)
        if is_vid:
            vis_embed = self.temporal_att(vis_embed)
            if not self.config.embed_from_llm:
                vis_embed = vis_embed + self.token_type_embedding(torch.ones(bs_pre_reshape, vis_embed.size(1)).long().to(device))
            # vis_temporal_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device)

        vis_embed = self.spatial_att(vis_embed)
        vis_feat_len = vis_embed.size(1)
        if not self.config.embed_from_llm:
            vis_embed = vis_embed + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device))

        vis_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device)

        return vis_embed, vis_mask

    def forward(self, vis, cap, hist, ans, media_type):
        device = vis.device
        is_vid = media_type in ['webvid', 'champagne', 'avsd', 'nextqa']

        loss_stc = torch.tensor(0)
        loss_stm = torch.tensor(0)
        loss_vhc = torch.tensor(0)
        loss_vhm = torch.tensor(0)
        loss_gen = torch.tensor(0)

        # First get the visual features depending on the media type
        vis_embed, vis_mask = self.encode_vis(vis, device, is_vid=is_vid)
        # spatial_feat_len = vis_embed_spatial.size(1)

        # construct the global input tensor --> use place holder for vis features
        text = [c + h for c, h in zip(cap, hist)]
        # cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=None)
        # hist_ids, hist_mask = self.tokenize_text(hist, device, max_len=None)
        text_ids, text_mask = self.tokenize_text(text, device, max_len=None)
        text_embeds = self.text_embedding(text_ids)

        # moe_outputs = self.moe_forward(
        #     vis_embed_spatial, vis_spatial_mask,
        #     vis_embed_temporal, vis_temporal_mask,
        #     cap_ids, cap_mask,
        #     hist_ids, hist_mask,
        #     is_vid, device
        # )
        # spatial_embeds = self.moe_to_llm(moe_outputs['spatial_embeds'])
        # temporal_embeds = self.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None
        # cap_embeds = self.moe_to_llm(moe_outputs['cap_embeds'])
        # hist_embeds = self.moe_to_llm(moe_outputs['hist_embeds'])

        ans = [a + self.tokenizer.eos_token for a in ans]

        # NOTE: this decoder-only branch still references tensors (cap_ids, cap_embeds, hist_embeds,
        # spatial_embeds, vis_spatial_mask, ...) that are only computed in V2Dial.forward,
        # so it is not usable as-is in this class.
        if self.config.llm_family in ['llama', 'mistral']:
            bos = torch.ones_like(cap_ids[:, :1]) * self.tokenizer.bos_token_id
            bos_embeds = self.text_embedding(bos)
            bos_mask = cap_mask[:, :1]

            # add corresponding eos
            regress_ids, regress_mask = self.tokenize_text(ans, device, max_len=None)  # pad to the longest
            regress_embeds = self.text_embedding(regress_ids)

            inputs_embeds, attention_mask, regress_limits_txt_input = self.pad_to_right_dec_only(
                cap_embeds, cap_mask, hist_embeds, hist_mask, regress_embeds, regress_mask, device)

            if is_vid:
                inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
            else:
                inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1)
                attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1)

            labels = torch.zeros(inputs_embeds.size()[:-1]).fill_(-100).long().to(device)
            for i in range(labels.size(0)):
                start_regress = regress_limits_txt_input[i][0] + 1 + spatial_feat_len + spatial_feat_len * int(is_vid)  # offset (bos + spatial + temporal)
                end_regress = regress_limits_txt_input[i][1] + 1 + spatial_feat_len + spatial_feat_len * int(is_vid)  # offset (bos + spatial + temporal)
                labels[i, start_regress:end_regress] = regress_ids[i, :regress_mask[i].sum()]

            # get causal attention mask
            # Compute the regression embeds

            # Now we need to right-pad the input to the LLM (at least for llama) to avoid nan loss values
            # This means all pad tokens have to be placed to the right
            # full_embeds = [...][...][...][pad][...][pad][ans ...][pad]
            # ------------> [...][...][...][...][ans ...][-----pad-----]
            # full_embeds, full_masks, start_output_idx = self.rearrange_llm_input_dec_only(cond_embeds, regress_embeds, cond_mask, cap_mask, hist_mask, regress_mask, spatial_feat_len)

            # labels = self.construct_reg_labels(regress_ids, start_output_idx, full_embeds, device)

            lm_outputs = self.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True
            )
            loss_gen = lm_outputs.loss

        # Encoder Decoder
        else:
            # inputs_embeds, attention_mask = self.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device)
            # now merge the multi-modal inputs
            # if is_vid:
            #     inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
            #     attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
            # else:
            inputs_embeds = torch.cat([vis_embed, text_embeds], dim=1)
            attention_mask = torch.cat([vis_mask, text_mask], dim=1)

            decoder_ids, decoder_mask = self.tokenize_text(ans, device, max_len=None)  # pad to the longest
            labels = decoder_ids.masked_fill(decoder_ids == self.tokenizer.pad_token_id, -100)
            decoder_ids = self.shift_right(labels)
            decoder_inputs_embeds = self.text_embedding(decoder_ids)

            lm_outputs = self.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_inputs_embeds=decoder_inputs_embeds,
                decoder_attention_mask=decoder_mask,
                labels=labels,
                return_dict=True
            )
            loss_gen = lm_outputs.loss

        return dict(
            loss_stc=loss_stc * self.config.loss_dict['stc'],
            loss_stm=loss_stm * self.config.loss_dict['stm'],
            loss_vhc=loss_vhc * self.config.loss_dict['vhc'],
            loss_vhm=loss_vhm * self.config.loss_dict['vhm'],
            loss_gen=loss_gen * self.config.loss_dict['gen'],
        )
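

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): a stand-alone, minimal
# re-implementation of the `pad_to_right_enc_dec` idea on dummy tensors,
# assuming each segment is right-padded. All shapes and values are made up
# purely for demonstration; it relies on the module-level `torch` import.
# ---------------------------------------------------------------------------
def _demo_pad_to_right_enc_dec():
    d_model = 4
    # one caption of 3 valid tokens + 2 pads, one history of 3 valid tokens + 1 pad
    cap_embeds = torch.arange(1, 6).float().view(1, 5, 1).repeat(1, 1, d_model)    # (1, 5, d_model)
    cap_masks = torch.tensor([[1, 1, 1, 0, 0]])
    hist_embeds = torch.arange(6, 10).float().view(1, 4, 1).repeat(1, 1, d_model)  # (1, 4, d_model)
    hist_masks = torch.tensor([[1, 1, 1, 0]])

    res_embeds, res_mask = [], []
    for cap_embed, cap_mask, hist_embed, hist_mask in zip(cap_embeds, cap_masks, hist_embeds, hist_masks):
        len_cap = int(cap_mask.sum())
        len_hist = int(hist_mask.sum())
        # valid caption tokens, then valid history tokens, then all pad rows pushed to the right
        merged = torch.cat([cap_embed[:len_cap], hist_embed[:len_hist],
                            cap_embed[len_cap:], hist_embed[len_hist:]], dim=0)
        mask = torch.zeros(merged.size(0)).long()
        mask[:len_cap + len_hist] = 1
        res_embeds.append(merged)
        res_mask.append(mask)

    res_embeds = torch.stack(res_embeds, dim=0)
    res_mask = torch.stack(res_mask, dim=0)
    print(res_embeds[0, :, 0])  # tensor([1., 2., 3., 6., 7., 8., 4., 5., 9.]) -- pad rows (4, 5, 9) at the end
    print(res_mask)             # tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0]])


if __name__ == '__main__':
    _demo_pad_to_right_enc_dec()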