from functools import lru_cache import torch import torch.nn.functional as F from torch import nn from models.utils import allgather_wgrad from utils.dist import get_rank, get_world_size from utils.easydict import EasyDict def get_sim( x_proj: torch.Tensor, y_proj: torch.Tensor, temp=1.0, ): """calculate pair-wise similarity between two modalities x and y. Args: x_proj (torch.Tensor): The representation of modality x. Shape: [B,T,C] or [B,C]. y_proj (torch.Tensor): The representation of modality y. Shape: [B,C]. temp (torch.Tensor): The temperature. Shape: []. Returns: The similarity between modality x and y. Shape: [B,B]. """ x_proj = F.normalize(x_proj, dim=-1) y_proj = F.normalize(y_proj, dim=-1) assert x_proj.dim() in [2, 3] assert y_proj.dim() == 2 if x_proj.dim() == 2: sim_x2y = torch.einsum("md,nd->mn", x_proj, y_proj) / temp # (B,B) else: sim_x2y = torch.einsum("mld,nd->mln", x_proj, y_proj).mean(1) / temp # (B,B) sim_y2x = sim_x2y.T return sim_x2y, sim_y2x class ContMatchLoss(nn.Module): def __init__(self): super(ContMatchLoss, self).__init__() @torch.no_grad() def get_mask(self, sim, idx=None, normalize=False): """ Args: sim (torch.Tensor): The similarity between videos and texts. shape: (B, B). idx (torch.Tensor): The index for each video. Shape: [B]. normalize (bool): If true, make row sum equal to 1 """ if idx is not None: idx = idx.view(-1, 1) mask = torch.eq(idx, idx.T).to(sim.dtype) if normalize: mask = mask / mask.sum(1, keepdim=True) else: mask = torch.zeros_like(sim) mask.fill_diagonal_(1) return mask # `1` mark valid/matched location @lru_cache(maxsize=16) def get_gather_args(self): """obtain the args for all_gather Returns: dict. """ return EasyDict({"world_size": get_world_size(), "rank": get_rank()}) class STC_STM_Loss(ContMatchLoss): """Contrastive and matching losses""" def __init__(self): super(STC_STM_Loss, self).__init__() def stc_loss( self, temporal_proj: torch.Tensor, spatial_proj: torch.Tensor, idx: torch.Tensor, temp=1.0, all_gather=True ): """forward to calculate the loss Args: vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. text_proj (torch.Tensor): The text representation. Shape: [B,C]. idx (torch.Tensor): The index for each example. Shape: [B,]. temp (torch.Tensor): The temperature. Shape: []. all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples. Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: []. """ if all_gather: gather_args = self.get_gather_args() temporal_proj = allgather_wgrad(temporal_proj, gather_args) spatial_proj = allgather_wgrad(spatial_proj, gather_args) if idx is not None: idx = allgather_wgrad(idx, gather_args) sim_t2s, sim_s2t = get_sim(temporal_proj, spatial_proj, temp) with torch.no_grad(): sim_t2s_targets = self.get_mask(sim_t2s, idx=idx, normalize=True) sim_s2t_targets = sim_t2s_targets loss_t2s = -torch.sum(F.log_softmax(sim_t2s, dim=1) * sim_t2s_targets, dim=1).mean() loss_s2t = -torch.sum(F.log_softmax(sim_s2t, dim=1) * sim_s2t_targets, dim=1).mean() loss_stc = (loss_t2s + loss_s2t) / 2 return loss_stc def stm_loss( self, grounding_expert, stm_head, # temp, spatial_embeds_orig, temporal_embeds_orig, temporal_proj, spatial_proj, idx, generation=False, temp=1.0 ): spatial_embeds = spatial_embeds_orig.clone() temporal_embeds = temporal_embeds_orig.clone() with torch.no_grad(): sim_s2t, sim_t2s = get_sim(temporal_proj, spatial_proj, temp) spatial_atts = torch.ones( spatial_embeds.size()[:-1], dtype=torch.long, device=spatial_embeds.device ) temporal_atts = torch.ones( temporal_embeds.size()[:-1], dtype=torch.long, device=temporal_embeds.device ) weights_s2t = F.softmax(sim_s2t + 1e-4, dim=1) # (N, N) weights_t2s = F.softmax(sim_t2s + 1e-4, dim=1) mask = self.get_mask(sim_s2t, idx=idx).bool() weights_s2t.masked_fill_(mask, 0) weights_t2s.masked_fill_(mask, 0) weights_s2t = torch.nan_to_num_(weights_s2t, nan=1e-2, posinf=1e-2, neginf=1e-2) weights_t2s = torch.nan_to_num_(weights_t2s, nan=1e-2, posinf=1e-2, neginf=1e-2) if generation: with torch.no_grad(): output = grounding_expert( encoder_embeds=temporal_embeds, attention_mask=temporal_atts, encoder_hidden_states=spatial_embeds, encoder_attention_mask=spatial_atts, return_dict=True, ) pos_feats = output.last_hidden_state return pos_feats else: # select a hard negatives within the batch spatial_neg_indices = torch.multinomial(weights_s2t, 1).squeeze() temporal_neg_indices = torch.multinomial(weights_t2s, 1).squeeze() spatial_embeds_neg = spatial_embeds[spatial_neg_indices] # [B, L, c] temporal_embeds_neg = temporal_embeds[temporal_neg_indices] # [B, L, d] # temporal_atts_neg = temporal_atts[temporal_neg_indices] # concat embeddings spatial_embeds_all = torch.cat([spatial_embeds, spatial_embeds_neg, spatial_embeds], dim=0) temporal_embeds_all = torch.cat([temporal_embeds, temporal_embeds, temporal_embeds_neg], dim=0) spatial_atts_all = torch.cat([spatial_atts, spatial_atts, spatial_atts], dim=0) temporal_atts_all = torch.cat([temporal_atts, temporal_atts, temporal_atts], dim=0) output = grounding_expert( inputs_embeds=temporal_embeds_all, attention_mask=temporal_atts_all, cross_embeds=spatial_embeds_all, cross_attention_mask=spatial_atts_all, ) stm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d) stm_logits = stm_head(stm_embeds) # [3*B, 2] bs = stm_logits.shape[0] // 3 stm_labels = stm_logits.new_ones(3 * bs, dtype=torch.long) stm_labels[bs:] = 0 loss_stm = F.cross_entropy(stm_logits, stm_labels) pos_feats = output.last_hidden_state[:bs] return loss_stm, pos_feats class VCC_VCM_Loss(ContMatchLoss): """Contrastive and matching losses""" def __init__(self): super(VCC_VCM_Loss, self).__init__() def vcc_loss( self, vis_proj: torch.Tensor, cap_proj: torch.Tensor, idx: torch.Tensor, temp=1.0, all_gather=True ): """forward to calculate the loss Args: vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. text_proj (torch.Tensor): The text representation. Shape: [B,C]. idx (torch.Tensor): The index for each example. Shape: [B,]. temp (torch.Tensor): The temperature. Shape: []. all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples. Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: []. """ if all_gather: gather_args = self.get_gather_args() vis_proj = allgather_wgrad(vis_proj, gather_args) cap_proj = allgather_wgrad(cap_proj, gather_args) if idx is not None: idx = allgather_wgrad(idx, gather_args) sim_v2c, sim_c2v = get_sim(vis_proj, cap_proj, temp) with torch.no_grad(): sim_v2c_targets = self.get_mask(sim_v2c, idx=idx, normalize=True) sim_c2v_targets = sim_v2c_targets loss_v2c = -torch.sum(F.log_softmax(sim_v2c, dim=1) * sim_v2c_targets, dim=1).mean() loss_c2v = -torch.sum(F.log_softmax(sim_c2v, dim=1) * sim_c2v_targets, dim=1).mean() loss_vcc = (loss_v2c + loss_c2v) / 2 return loss_vcc def vcm_loss( self, grounding_expert, vcm_head, vis_embeds_orig, cap_embeds_orig, vis_proj, cap_proj, cap_atts, idx, generation=False, temp=1.0 ): vis_embeds = vis_embeds_orig.clone() cap_embeds = cap_embeds_orig.clone() with torch.no_grad(): sim_v2c, sim_c2v = get_sim(vis_proj, cap_proj, temp) vis_atts = torch.ones( vis_embeds.size()[:-1], dtype=torch.long, device=vis_embeds.device ) weights_v2c = F.softmax(sim_v2c + 1e-4, dim=1) # (N, N) weights_c2v = F.softmax(sim_c2v + 1e-4, dim=1) mask = self.get_mask(weights_v2c, idx=idx).bool() weights_v2c.masked_fill_(mask, 0) weights_c2v.masked_fill_(mask, 0) weights_v2c = torch.nan_to_num_(weights_v2c, nan=1e-2, posinf=1e-2, neginf=1e-2) weights_c2v = torch.nan_to_num_(weights_c2v, nan=1e-2, posinf=1e-2, neginf=1e-2) if generation: with torch.no_grad(): output = grounding_expert( encoder_embeds=cap_embeds, attention_mask=cap_atts, encoder_hidden_states=vis_embeds, encoder_attention_mask=vis_atts, return_dict=True, ) pos_feats = output.last_hidden_state return pos_feats else: # select a hard negatives within the batch vis_neg_indices = torch.multinomial(weights_v2c, 1).squeeze() cap_neg_indices = torch.multinomial(weights_c2v, 1).squeeze() vis_embeds_neg = vis_embeds[vis_neg_indices] # [B, L, c] cap_embeds_neg = cap_embeds[cap_neg_indices] # [B, L, d] cap_atts_neg = cap_atts[cap_neg_indices] # concat embeddings vis_embeds_all = torch.cat([vis_embeds, vis_embeds_neg, vis_embeds], dim=0) cap_embeds_all = torch.cat([cap_embeds, cap_embeds, cap_embeds_neg], dim=0) vis_atts_all = torch.cat([vis_atts, vis_atts, vis_atts], dim=0) cap_atts_all = torch.cat([cap_atts, cap_atts, cap_atts_neg], dim=0) output = grounding_expert( inputs_embeds=cap_embeds_all, attention_mask=cap_atts_all, cross_embeds=vis_embeds_all, cross_attention_mask=vis_atts_all, ) vcm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d) vcm_logits = vcm_head(vcm_embeds) # [3*B, 2] bs = vcm_logits.shape[0] // 3 vcm_labels = vcm_logits.new_ones(3 * bs, dtype=torch.long) vcm_labels[bs:] = 0 loss_vcm = F.cross_entropy(vcm_logits, vcm_labels) pos_feats = output.last_hidden_state[:bs] return loss_vcm, pos_feats class VHC_VHM_Loss(ContMatchLoss): """Contrastive and matching losses""" def __init__(self): super(VHC_VHM_Loss, self).__init__() def vhc_loss( self, vis_proj: torch.Tensor, hist_proj: torch.Tensor, idx: torch.Tensor, temp=1.0, all_gather=True ): """forward to calculate the loss Args: vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. text_proj (torch.Tensor): The text representation. Shape: [B,C]. idx (torch.Tensor): The index for each example. Shape: [B,]. temp (torch.Tensor): The temperature. Shape: []. all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples. Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: []. """ if all_gather: gather_args = self.get_gather_args() vis_proj = allgather_wgrad(vis_proj, gather_args) hist_proj = allgather_wgrad(hist_proj, gather_args) if idx is not None: idx = allgather_wgrad(idx, gather_args) sim_v2h, sim_h2v = get_sim(vis_proj, hist_proj, temp) with torch.no_grad(): sim_v2h_targets = self.get_mask(sim_v2h, idx=idx, normalize=True) sim_h2v_targets = sim_v2h_targets loss_v2h = -torch.sum(F.log_softmax(sim_v2h, dim=1) * sim_v2h_targets, dim=1).mean() loss_h2v = -torch.sum(F.log_softmax(sim_h2v, dim=1) * sim_h2v_targets, dim=1).mean() loss_vhc = (loss_v2h + loss_h2v) / 2 return loss_vhc def vhm_loss( self, grounding_expert, vhm_head, vis_embeds_orig, hist_embeds_orig, vis_proj, hist_proj, hist_atts, idx, generation=False, temp=1.0, ): vis_embeds = vis_embeds_orig.clone() hist_embeds = hist_embeds_orig.clone() with torch.no_grad(): sim_v2h, sim_h2v = get_sim(vis_proj, hist_proj, temp) vis_atts = torch.ones( vis_embeds.size()[:-1], dtype=torch.long, device=vis_embeds.device ) weights_v2h = F.softmax(sim_v2h + 1e-4, dim=1) # (N, N) weights_h2v = F.softmax(sim_h2v + 1e-4, dim=1) mask = self.get_mask(weights_v2h, idx=idx).bool() weights_v2h.masked_fill_(mask, 0) weights_h2v.masked_fill_(mask, 0) weights_v2h = torch.nan_to_num_(weights_v2h, nan=1e-2, posinf=1e-2, neginf=1e-2) weights_h2v = torch.nan_to_num_(weights_h2v, nan=1e-2, posinf=1e-2, neginf=1e-2) if generation: with torch.no_grad(): output = grounding_expert( encoder_embeds=hist_embeds, attention_mask=hist_atts, encoder_hidden_states=vis_embeds, encoder_attention_mask=vis_atts, return_dict=True, # mode="fusion", ) pos_feats = output.last_hidden_state return pos_feats else: # select a hard negatives within the batch vis_neg_indices = torch.multinomial(weights_v2h, 1).squeeze() hist_neg_indices = torch.multinomial(weights_h2v, 1).squeeze() vis_embeds_neg = vis_embeds[vis_neg_indices] # [B, L, c] hist_embeds_neg = hist_embeds[hist_neg_indices] # [B, L, d] hist_atts_neg = hist_atts[hist_neg_indices] # concat embeddings vis_embeds_all = torch.cat([vis_embeds, vis_embeds_neg, vis_embeds], dim=0) hist_embeds_all = torch.cat([hist_embeds, hist_embeds, hist_embeds_neg], dim=0) vis_atts_all = torch.cat([vis_atts, vis_atts, vis_atts], dim=0) hist_atts_all = torch.cat([hist_atts, hist_atts, hist_atts_neg], dim=0) output = grounding_expert( inputs_embeds=hist_embeds_all, attention_mask=hist_atts_all, cross_embeds=vis_embeds_all, cross_attention_mask=vis_atts_all, ) vhm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d) vhm_logits = vhm_head(vhm_embeds) # [3*B, 2] bs = vhm_logits.shape[0] // 3 vhm_labels = vhm_logits.new_ones(3 * bs, dtype=torch.long) vhm_labels[bs:] = 0 loss_vhm = F.cross_entropy(vhm_logits, vhm_labels) pos_feats = output.last_hidden_state[:bs] return loss_vhm, pos_feats class CHC_CHM_Loss(ContMatchLoss): """Contrastive and matching losses""" def __init__(self): super(CHC_CHM_Loss, self).__init__() def chc_loss( self, cap_proj: torch.Tensor, hist_proj: torch.Tensor, idx: torch.Tensor, temp=1.0, all_gather=True ): """forward to calculate the loss Args: vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. text_proj (torch.Tensor): The text representation. Shape: [B,C]. idx (torch.Tensor): The index for each example. Shape: [B,]. temp (torch.Tensor): The temperature. Shape: []. all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples. Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: []. """ if all_gather: gather_args = self.get_gather_args() cap_proj = allgather_wgrad(cap_proj, gather_args) hist_proj = allgather_wgrad(hist_proj, gather_args) if idx is not None: idx = allgather_wgrad(idx, gather_args) sim_c2h, sim_h2c = get_sim(cap_proj, hist_proj, temp) with torch.no_grad(): sim_c2h_targets = self.get_mask(sim_c2h, idx=idx, normalize=True) sim_h2c_targets = sim_c2h_targets loss_c2h = -torch.sum(F.log_softmax(sim_c2h, dim=1) * sim_c2h_targets, dim=1).mean() loss_h2c = -torch.sum(F.log_softmax(sim_h2c, dim=1) * sim_h2c_targets, dim=1).mean() loss_chc = (loss_c2h + loss_h2c) / 2 return loss_chc def chm_loss( self, grounding_expert, chm_head, cap_embeds_orig, hist_embeds_orig, cap_proj, hist_proj, cap_atts, hist_atts, idx, generation=False, temp=1.0 ): cap_embeds = cap_embeds_orig.clone() hist_embeds = hist_embeds_orig.clone() with torch.no_grad(): sim_c2h, sim_h2c = get_sim(cap_proj, hist_proj, temp) weights_c2h = F.softmax(sim_c2h + 1e-4, dim=1) # (N, N) weights_h2c = F.softmax(sim_h2c + 1e-4, dim=1) mask = self.get_mask(weights_c2h, idx=idx).bool() weights_c2h.masked_fill_(mask, 0) weights_h2c.masked_fill_(mask, 0) weights_c2h = torch.nan_to_num_(weights_c2h, nan=1e-2, posinf=1e-2, neginf=1e-2) weights_h2c = torch.nan_to_num_(weights_h2c, nan=1e-2, posinf=1e-2, neginf=1e-2) if generation: with torch.no_grad(): output = grounding_expert( encoder_embeds=hist_embeds, attention_mask=hist_atts, encoder_hidden_states=cap_embeds, encoder_attention_mask=cap_atts, return_dict=True, ) pos_feats = output.last_hidden_state return pos_feats else: # select a hard negatives within the batch cap_neg_indices = torch.multinomial(weights_c2h, 1).squeeze() hist_neg_indices = torch.multinomial(weights_h2c, 1).squeeze() cap_embeds_neg = cap_embeds[cap_neg_indices] # [B, L, c] cap_atts_neg = cap_atts[cap_neg_indices] hist_embeds_neg = hist_embeds[hist_neg_indices] # [B, L, d] hist_atts_neg = hist_atts[hist_neg_indices] # concat embeddings cap_embeds_all = torch.cat([cap_embeds, cap_embeds_neg, cap_embeds], dim=0) hist_embeds_all = torch.cat([hist_embeds, hist_embeds, hist_embeds_neg], dim=0) cap_atts_all = torch.cat([cap_atts, cap_atts_neg, cap_atts], dim=0) hist_atts_all = torch.cat([hist_atts, hist_atts, hist_atts_neg], dim=0) output = grounding_expert( inputs_embeds=hist_embeds_all, attention_mask=hist_atts_all, cross_embeds=cap_embeds_all, cross_attention_mask=cap_atts_all, ) chm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d) chm_logits = chm_head(chm_embeds) # [3*B, 2] bs = chm_logits.shape[0] // 3 chm_labels = chm_logits.new_ones(3 * bs, dtype=torch.long) chm_labels[bs:] = 0 loss_chm = F.cross_entropy(chm_logits, chm_labels) pos_feats = output.last_hidden_state[:bs] return loss_chm, pos_feats class MLMLoss(nn.Module): """masked language modeling loss.""" def __init__(self, masking_prob, tokenizer): super(MLMLoss, self).__init__() self.tokenizer = tokenizer self.masking_prob = masking_prob def mlm_loss( self, text_encoder, text, text_embeds, vision_embeds, vision_atts, ): input_ids = text.input_ids.clone() labels = input_ids.clone() probability_matrix = torch.full(labels.shape, self.masking_prob) input_ids, labels = self.mask( input_ids, text_encoder.config.vocab_size, input_ids.device, targets=labels, probability_matrix=probability_matrix, ) # intermediate_mlm_output = text_encoder.bert( # input_ids, # attention_mask=text.attention_mask, # encoder_hidden_states=vision_embeds, # encoder_attention_mask=vision_atts, # return_dict=True, # # mode="text", # ) # text_embeds = intermediate_mlm_output.last_hidden_state mlm_output = text_encoder( encoder_embeds=text_embeds, attention_mask=text.attention_mask, encoder_hidden_states=vision_embeds, encoder_attention_mask=vision_atts, return_dict=True, labels=labels, soft_labels=None, # mode="fusion", ) return mlm_output.loss def mask( self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None, ): if masked_indices is None: masked_indices = torch.bernoulli(probability_matrix).bool() masked_indices[input_ids == self.tokenizer.pad_token_id] = False masked_indices[input_ids == self.tokenizer.cls_token_id] = False if targets is not None: # We only compute loss on masked tokens targets[~masked_indices] = -100 # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) indices_replaced = ( torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices ) input_ids[indices_replaced] = self.tokenizer.mask_token_id # 10% of the time, we replace masked input tokens with random word indices_random = ( torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced ) random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) input_ids[indices_random] = random_words[indices_random] # The rest of the time (10% of the time) we keep the masked input tokens unchanged if targets is not None: return input_ids, targets else: return input_ids