from functools import lru_cache

import torch
import torch.nn.functional as F
from torch import nn

from models.utils import allgather_wgrad
from utils.dist import get_rank, get_world_size
from utils.easydict import EasyDict


def get_sim(
    x_proj: torch.Tensor,
    y_proj: torch.Tensor,
    temp=1.0,
):
    """Calculate pair-wise similarity between two modalities x and y.

    Args:
        x_proj (torch.Tensor): The representation of modality x. Shape: [B,T,C] or [B,C].
        y_proj (torch.Tensor): The representation of modality y. Shape: [B,C].
        temp (float or torch.Tensor): The temperature. Shape: [].

    Returns: The similarity between modality x and y as (sim_x2y, sim_y2x). Shape: [B,B] each.

    """
    x_proj = F.normalize(x_proj, dim=-1)
    y_proj = F.normalize(y_proj, dim=-1)
    assert x_proj.dim() in [2, 3]
    assert y_proj.dim() == 2
    if x_proj.dim() == 2:
        sim_x2y = torch.einsum("md,nd->mn", x_proj, y_proj) / temp  # (B,B)
    else:
        # Average the per-frame similarities over the temporal dimension.
        sim_x2y = torch.einsum("mld,nd->mln", x_proj, y_proj).mean(1) / temp  # (B,B)
    sim_y2x = sim_x2y.T
    return sim_x2y, sim_y2x
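

# Illustrative usage of `get_sim` (hypothetical shapes and temperature, not
# part of the training code):
#
#     x = torch.randn(4, 8, 256)  # e.g. a video representation: B=4, T=8, C=256
#     y = torch.randn(4, 256)     # e.g. a text representation:  B=4, C=256
#     sim_x2y, sim_y2x = get_sim(x, y, temp=0.07)
#     assert sim_x2y.shape == (4, 4) and sim_y2x.shape == (4, 4)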


class ContMatchLoss(nn.Module):
    """Base class with shared helpers for the contrastive and matching losses."""

    def __init__(self):
        super(ContMatchLoss, self).__init__()

    @torch.no_grad()
    def get_mask(self, sim, idx=None, normalize=False):
        """
        Args:
            sim (torch.Tensor): The similarity between videos and texts. Shape: (B, B).
            idx (torch.Tensor): The index for each video. Shape: [B].
            normalize (bool): If True, make each row sum to 1.
        """
        if idx is not None:
            idx = idx.view(-1, 1)
            mask = torch.eq(idx, idx.T).to(sim.dtype)
            if normalize:
                mask = mask / mask.sum(1, keepdim=True)
        else:
            mask = torch.zeros_like(sim)
            mask.fill_diagonal_(1)
        return mask  # `1` marks a valid/matched location
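
    # Illustrative behaviour of `get_mask` (hypothetical values): with
    # idx = tensor([0, 0, 1]), rows 0 and 1 refer to the same video, so
    #
    #     get_mask(sim, idx) -> [[1, 1, 0],
    #                            [1, 1, 0],
    #                            [0, 0, 1]]
    #
    # and with normalize=True each row is divided by its sum (e.g. the first
    # two rows become [0.5, 0.5, 0]).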

    @lru_cache(maxsize=16)
    def get_gather_args(self):
        """Obtain the args for all_gather.

        Returns: dict with the world size and the rank of the current process
        (cached, since neither changes during training).

        """
        return EasyDict({"world_size": get_world_size(), "rank": get_rank()})


class STC_STM_Loss(ContMatchLoss):
    """Spatial-temporal contrastive (STC) and matching (STM) losses."""

    def __init__(self):
        super(STC_STM_Loss, self).__init__()

    def stc_loss(
        self,
        temporal_proj: torch.Tensor,
        spatial_proj: torch.Tensor,
        idx: torch.Tensor,
        temp=1.0,
        all_gather=True,
    ):
        """Forward to calculate the contrastive loss.

        Args:
            temporal_proj (torch.Tensor): The temporal representation. Shape: [B,T,C] or [B,C].
            spatial_proj (torch.Tensor): The spatial representation. Shape: [B,C].
            idx (torch.Tensor): The index for each example. Shape: [B,].
            temp (float or torch.Tensor): The temperature. Shape: [].
            all_gather (bool): If true, gather samples across all the GPUs and calculate the loss across the gathered samples.

        Returns: loss_stc (torch.Tensor): The spatial-temporal contrastive loss. Shape: [].

        """
        if all_gather:
            gather_args = self.get_gather_args()
            temporal_proj = allgather_wgrad(temporal_proj, gather_args)
            spatial_proj = allgather_wgrad(spatial_proj, gather_args)
            if idx is not None:
                idx = allgather_wgrad(idx, gather_args)

        sim_t2s, sim_s2t = get_sim(temporal_proj, spatial_proj, temp)

        with torch.no_grad():
            sim_t2s_targets = self.get_mask(sim_t2s, idx=idx, normalize=True)
            # The target mask is symmetric, so both directions share it.
            sim_s2t_targets = sim_t2s_targets

        # Soft-label cross-entropy in both directions.
        loss_t2s = -torch.sum(F.log_softmax(sim_t2s, dim=1) * sim_t2s_targets, dim=1).mean()
        loss_s2t = -torch.sum(F.log_softmax(sim_s2t, dim=1) * sim_s2t_targets, dim=1).mean()

        loss_stc = (loss_t2s + loss_s2t) / 2
        return loss_stc
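
    # Minimal call sketch (hypothetical tensors and temperature; single
    # process, so all_gather=False and no distributed setup is needed):
    #
    #     loss_fn = STC_STM_Loss()
    #     temporal_proj = torch.randn(8, 4, 256)  # [B, T, C]
    #     spatial_proj = torch.randn(8, 256)      # [B, C]
    #     loss = loss_fn.stc_loss(temporal_proj, spatial_proj, idx=None,
    #                             temp=0.07, all_gather=False)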

    def stm_loss(
        self,
        grounding_expert,
        stm_head,
        # temp,
        spatial_embeds_orig,
        temporal_embeds_orig,
        temporal_proj,
        spatial_proj,
        idx,
        generation=False,
        temp=1.0,
    ):
        """Spatial-temporal matching loss with in-batch hard negative mining.

        If `generation` is True, no loss is computed and only the fused
        features of the positive pairs are returned; otherwise returns
        (loss_stm, pos_feats).
        """
        spatial_embeds = spatial_embeds_orig.clone()
        temporal_embeds = temporal_embeds_orig.clone()
        with torch.no_grad():
            sim_t2s, sim_s2t = get_sim(temporal_proj, spatial_proj, temp)
            spatial_atts = torch.ones(
                spatial_embeds.size()[:-1], dtype=torch.long, device=spatial_embeds.device
            )
            temporal_atts = torch.ones(
                temporal_embeds.size()[:-1], dtype=torch.long, device=temporal_embeds.device
            )
            weights_t2s = F.softmax(sim_t2s + 1e-4, dim=1)  # (N, N)
            weights_s2t = F.softmax(sim_s2t + 1e-4, dim=1)

            # Zero out matched (positive) locations so they cannot be drawn
            # as negatives.
            mask = self.get_mask(sim_t2s, idx=idx).bool()
            weights_t2s.masked_fill_(mask, 0)
            weights_s2t.masked_fill_(mask, 0)
            weights_t2s = torch.nan_to_num_(weights_t2s, nan=1e-2, posinf=1e-2, neginf=1e-2)
            weights_s2t = torch.nan_to_num_(weights_s2t, nan=1e-2, posinf=1e-2, neginf=1e-2)

        if generation:
            with torch.no_grad():
                output = grounding_expert(
                    encoder_embeds=temporal_embeds,
                    attention_mask=temporal_atts,
                    encoder_hidden_states=spatial_embeds,
                    encoder_attention_mask=spatial_atts,
                    return_dict=True,
                )
            pos_feats = output.last_hidden_state
            return pos_feats

        else:
            # Select hard negatives within the batch: a spatial negative for
            # each temporal query, and vice versa.
            spatial_neg_indices = torch.multinomial(weights_t2s, 1).squeeze()
            temporal_neg_indices = torch.multinomial(weights_s2t, 1).squeeze()

            spatial_embeds_neg = spatial_embeds[spatial_neg_indices]  # [B, L, c]
            temporal_embeds_neg = temporal_embeds[temporal_neg_indices]  # [B, L, d]
            # temporal_atts_neg = temporal_atts[temporal_neg_indices]

            # Concat embeddings: [positives; spatial negatives; temporal negatives].
            spatial_embeds_all = torch.cat([spatial_embeds, spatial_embeds_neg, spatial_embeds], dim=0)
            temporal_embeds_all = torch.cat([temporal_embeds, temporal_embeds, temporal_embeds_neg], dim=0)
            spatial_atts_all = torch.cat([spatial_atts, spatial_atts, spatial_atts], dim=0)
            temporal_atts_all = torch.cat([temporal_atts, temporal_atts, temporal_atts], dim=0)

            output = grounding_expert(
                inputs_embeds=temporal_embeds_all,
                attention_mask=temporal_atts_all,
                cross_embeds=spatial_embeds_all,
                cross_attention_mask=spatial_atts_all,
            )

            stm_embeds = output.last_hidden_state[:, 0]  # pos (N, d) + neg (2N, d)

            stm_logits = stm_head(stm_embeds)  # [3*B, 2]

            bs = stm_logits.shape[0] // 3
            stm_labels = stm_logits.new_ones(3 * bs, dtype=torch.long)
            stm_labels[bs:] = 0
            loss_stm = F.cross_entropy(stm_logits, stm_labels)
            pos_feats = output.last_hidden_state[:bs]

            return loss_stm, pos_feats
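

# The matching losses above and below all follow the same recipe: the
# grounding expert scores a 3B-row batch laid out as [B matched pairs;
# B pairs with a mined negative on one side; B pairs with a mined negative
# on the other side], so the binary match labels are built as
#
#     labels = logits.new_ones(3 * bs, dtype=torch.long)
#     labels[bs:] = 0  # only the first third are true matches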


class VCC_VCM_Loss(ContMatchLoss):
    """Vision-caption contrastive (VCC) and matching (VCM) losses."""

    def __init__(self):
        super(VCC_VCM_Loss, self).__init__()

    def vcc_loss(
        self,
        vis_proj: torch.Tensor,
        cap_proj: torch.Tensor,
        idx: torch.Tensor,
        temp=1.0,
        all_gather=True,
    ):
        """Forward to calculate the contrastive loss.

        Args:
            vis_proj (torch.Tensor): The vision representation. Shape: [B,T,C] or [B,C].
            cap_proj (torch.Tensor): The caption representation. Shape: [B,C].
            idx (torch.Tensor): The index for each example. Shape: [B,].
            temp (float or torch.Tensor): The temperature. Shape: [].
            all_gather (bool): If true, gather samples across all the GPUs and calculate the loss across the gathered samples.

        Returns: loss_vcc (torch.Tensor): The vision-caption contrastive loss. Shape: [].

        """
        if all_gather:
            gather_args = self.get_gather_args()
            vis_proj = allgather_wgrad(vis_proj, gather_args)
            cap_proj = allgather_wgrad(cap_proj, gather_args)
            if idx is not None:
                idx = allgather_wgrad(idx, gather_args)

        sim_v2c, sim_c2v = get_sim(vis_proj, cap_proj, temp)

        with torch.no_grad():
            sim_v2c_targets = self.get_mask(sim_v2c, idx=idx, normalize=True)
            # The target mask is symmetric, so both directions share it.
            sim_c2v_targets = sim_v2c_targets

        loss_v2c = -torch.sum(F.log_softmax(sim_v2c, dim=1) * sim_v2c_targets, dim=1).mean()
        loss_c2v = -torch.sum(F.log_softmax(sim_c2v, dim=1) * sim_c2v_targets, dim=1).mean()

        loss_vcc = (loss_v2c + loss_c2v) / 2
        return loss_vcc

    def vcm_loss(
        self,
        grounding_expert,
        vcm_head,
        vis_embeds_orig,
        cap_embeds_orig,
        vis_proj,
        cap_proj,
        cap_atts,
        idx,
        generation=False,
        temp=1.0,
    ):
        """Vision-caption matching loss with in-batch hard negative mining.

        If `generation` is True, no loss is computed and only the fused
        features of the positive pairs are returned; otherwise returns
        (loss_vcm, pos_feats).
        """
        vis_embeds = vis_embeds_orig.clone()
        cap_embeds = cap_embeds_orig.clone()

        with torch.no_grad():
            sim_v2c, sim_c2v = get_sim(vis_proj, cap_proj, temp)
            vis_atts = torch.ones(
                vis_embeds.size()[:-1], dtype=torch.long, device=vis_embeds.device
            )

            weights_v2c = F.softmax(sim_v2c + 1e-4, dim=1)  # (N, N)
            weights_c2v = F.softmax(sim_c2v + 1e-4, dim=1)

            # Zero out matched (positive) locations so they cannot be drawn
            # as negatives.
            mask = self.get_mask(weights_v2c, idx=idx).bool()
            weights_v2c.masked_fill_(mask, 0)
            weights_c2v.masked_fill_(mask, 0)
            weights_v2c = torch.nan_to_num_(weights_v2c, nan=1e-2, posinf=1e-2, neginf=1e-2)
            weights_c2v = torch.nan_to_num_(weights_c2v, nan=1e-2, posinf=1e-2, neginf=1e-2)

        if generation:
            with torch.no_grad():
                output = grounding_expert(
                    encoder_embeds=cap_embeds,
                    attention_mask=cap_atts,
                    encoder_hidden_states=vis_embeds,
                    encoder_attention_mask=vis_atts,
                    return_dict=True,
                )
            pos_feats = output.last_hidden_state
            return pos_feats

        else:
            # Select hard negatives within the batch.
            vis_neg_indices = torch.multinomial(weights_v2c, 1).squeeze()
            cap_neg_indices = torch.multinomial(weights_c2v, 1).squeeze()

            vis_embeds_neg = vis_embeds[vis_neg_indices]  # [B, L, c]
            cap_embeds_neg = cap_embeds[cap_neg_indices]  # [B, L, d]
            cap_atts_neg = cap_atts[cap_neg_indices]

            # Concat embeddings: [positives; vision negatives; caption negatives].
            vis_embeds_all = torch.cat([vis_embeds, vis_embeds_neg, vis_embeds], dim=0)
            cap_embeds_all = torch.cat([cap_embeds, cap_embeds, cap_embeds_neg], dim=0)
            vis_atts_all = torch.cat([vis_atts, vis_atts, vis_atts], dim=0)
            cap_atts_all = torch.cat([cap_atts, cap_atts, cap_atts_neg], dim=0)

            output = grounding_expert(
                inputs_embeds=cap_embeds_all,
                attention_mask=cap_atts_all,
                cross_embeds=vis_embeds_all,
                cross_attention_mask=vis_atts_all,
            )

            vcm_embeds = output.last_hidden_state[:, 0]  # pos (N, d) + neg (2N, d)

            vcm_logits = vcm_head(vcm_embeds)  # [3*B, 2]

            bs = vcm_logits.shape[0] // 3
            vcm_labels = vcm_logits.new_ones(3 * bs, dtype=torch.long)
            vcm_labels[bs:] = 0
            loss_vcm = F.cross_entropy(vcm_logits, vcm_labels)
            pos_feats = output.last_hidden_state[:bs]
            return loss_vcm, pos_feats


class VHC_VHM_Loss(ContMatchLoss):
    """Vision-history contrastive (VHC) and matching (VHM) losses."""

    def __init__(self):
        super(VHC_VHM_Loss, self).__init__()

    def vhc_loss(
        self,
        vis_proj: torch.Tensor,
        hist_proj: torch.Tensor,
        idx: torch.Tensor,
        temp=1.0,
        all_gather=True,
    ):
        """Forward to calculate the contrastive loss.

        Args:
            vis_proj (torch.Tensor): The vision representation. Shape: [B,T,C] or [B,C].
            hist_proj (torch.Tensor): The dialog-history representation. Shape: [B,C].
            idx (torch.Tensor): The index for each example. Shape: [B,].
            temp (float or torch.Tensor): The temperature. Shape: [].
            all_gather (bool): If true, gather samples across all the GPUs and calculate the loss across the gathered samples.

        Returns: loss_vhc (torch.Tensor): The vision-history contrastive loss. Shape: [].

        """
        if all_gather:
            gather_args = self.get_gather_args()
            vis_proj = allgather_wgrad(vis_proj, gather_args)
            hist_proj = allgather_wgrad(hist_proj, gather_args)
            if idx is not None:
                idx = allgather_wgrad(idx, gather_args)

        sim_v2h, sim_h2v = get_sim(vis_proj, hist_proj, temp)

        with torch.no_grad():
            sim_v2h_targets = self.get_mask(sim_v2h, idx=idx, normalize=True)
            # The target mask is symmetric, so both directions share it.
            sim_h2v_targets = sim_v2h_targets

        loss_v2h = -torch.sum(F.log_softmax(sim_v2h, dim=1) * sim_v2h_targets, dim=1).mean()
        loss_h2v = -torch.sum(F.log_softmax(sim_h2v, dim=1) * sim_h2v_targets, dim=1).mean()

        loss_vhc = (loss_v2h + loss_h2v) / 2
        return loss_vhc

    def vhm_loss(
        self,
        grounding_expert,
        vhm_head,
        vis_embeds_orig,
        hist_embeds_orig,
        vis_proj,
        hist_proj,
        hist_atts,
        idx,
        generation=False,
        temp=1.0,
    ):
        """Vision-history matching loss with in-batch hard negative mining.

        If `generation` is True, no loss is computed and only the fused
        features of the positive pairs are returned; otherwise returns
        (loss_vhm, pos_feats).
        """
        vis_embeds = vis_embeds_orig.clone()
        hist_embeds = hist_embeds_orig.clone()
        with torch.no_grad():
            sim_v2h, sim_h2v = get_sim(vis_proj, hist_proj, temp)
            vis_atts = torch.ones(
                vis_embeds.size()[:-1], dtype=torch.long, device=vis_embeds.device
            )

            weights_v2h = F.softmax(sim_v2h + 1e-4, dim=1)  # (N, N)
            weights_h2v = F.softmax(sim_h2v + 1e-4, dim=1)

            # Zero out matched (positive) locations so they cannot be drawn
            # as negatives.
            mask = self.get_mask(weights_v2h, idx=idx).bool()
            weights_v2h.masked_fill_(mask, 0)
            weights_h2v.masked_fill_(mask, 0)
            weights_v2h = torch.nan_to_num_(weights_v2h, nan=1e-2, posinf=1e-2, neginf=1e-2)
            weights_h2v = torch.nan_to_num_(weights_h2v, nan=1e-2, posinf=1e-2, neginf=1e-2)

        if generation:
            with torch.no_grad():
                output = grounding_expert(
                    encoder_embeds=hist_embeds,
                    attention_mask=hist_atts,
                    encoder_hidden_states=vis_embeds,
                    encoder_attention_mask=vis_atts,
                    return_dict=True,
                    # mode="fusion",
                )
            pos_feats = output.last_hidden_state
            return pos_feats

        else:
            # Select hard negatives within the batch.
            vis_neg_indices = torch.multinomial(weights_v2h, 1).squeeze()
            hist_neg_indices = torch.multinomial(weights_h2v, 1).squeeze()

            vis_embeds_neg = vis_embeds[vis_neg_indices]  # [B, L, c]
            hist_embeds_neg = hist_embeds[hist_neg_indices]  # [B, L, d]
            hist_atts_neg = hist_atts[hist_neg_indices]

            # Concat embeddings: [positives; vision negatives; history negatives].
            vis_embeds_all = torch.cat([vis_embeds, vis_embeds_neg, vis_embeds], dim=0)
            hist_embeds_all = torch.cat([hist_embeds, hist_embeds, hist_embeds_neg], dim=0)
            vis_atts_all = torch.cat([vis_atts, vis_atts, vis_atts], dim=0)
            hist_atts_all = torch.cat([hist_atts, hist_atts, hist_atts_neg], dim=0)

            output = grounding_expert(
                inputs_embeds=hist_embeds_all,
                attention_mask=hist_atts_all,
                cross_embeds=vis_embeds_all,
                cross_attention_mask=vis_atts_all,
            )

            vhm_embeds = output.last_hidden_state[:, 0]  # pos (N, d) + neg (2N, d)

            vhm_logits = vhm_head(vhm_embeds)  # [3*B, 2]

            bs = vhm_logits.shape[0] // 3
            vhm_labels = vhm_logits.new_ones(3 * bs, dtype=torch.long)
            vhm_labels[bs:] = 0
            loss_vhm = F.cross_entropy(vhm_logits, vhm_labels)
            pos_feats = output.last_hidden_state[:bs]

            return loss_vhm, pos_feats


class CHC_CHM_Loss(ContMatchLoss):
    """Caption-history contrastive (CHC) and matching (CHM) losses."""

    def __init__(self):
        super(CHC_CHM_Loss, self).__init__()

    def chc_loss(
        self,
        cap_proj: torch.Tensor,
        hist_proj: torch.Tensor,
        idx: torch.Tensor,
        temp=1.0,
        all_gather=True,
    ):
        """Forward to calculate the contrastive loss.

        Args:
            cap_proj (torch.Tensor): The caption representation. Shape: [B,T,C] or [B,C].
            hist_proj (torch.Tensor): The dialog-history representation. Shape: [B,C].
            idx (torch.Tensor): The index for each example. Shape: [B,].
            temp (float or torch.Tensor): The temperature. Shape: [].
            all_gather (bool): If true, gather samples across all the GPUs and calculate the loss across the gathered samples.

        Returns: loss_chc (torch.Tensor): The caption-history contrastive loss. Shape: [].

        """
        if all_gather:
            gather_args = self.get_gather_args()
            cap_proj = allgather_wgrad(cap_proj, gather_args)
            hist_proj = allgather_wgrad(hist_proj, gather_args)
            if idx is not None:
                idx = allgather_wgrad(idx, gather_args)

        sim_c2h, sim_h2c = get_sim(cap_proj, hist_proj, temp)

        with torch.no_grad():
            sim_c2h_targets = self.get_mask(sim_c2h, idx=idx, normalize=True)
            # The target mask is symmetric, so both directions share it.
            sim_h2c_targets = sim_c2h_targets

        loss_c2h = -torch.sum(F.log_softmax(sim_c2h, dim=1) * sim_c2h_targets, dim=1).mean()
        loss_h2c = -torch.sum(F.log_softmax(sim_h2c, dim=1) * sim_h2c_targets, dim=1).mean()

        loss_chc = (loss_c2h + loss_h2c) / 2
        return loss_chc

    def chm_loss(
        self,
        grounding_expert,
        chm_head,
        cap_embeds_orig,
        hist_embeds_orig,
        cap_proj,
        hist_proj,
        cap_atts,
        hist_atts,
        idx,
        generation=False,
        temp=1.0,
    ):
        """Caption-history matching loss with in-batch hard negative mining.

        If `generation` is True, no loss is computed and only the fused
        features of the positive pairs are returned; otherwise returns
        (loss_chm, pos_feats).
        """
        cap_embeds = cap_embeds_orig.clone()
        hist_embeds = hist_embeds_orig.clone()
        with torch.no_grad():
            sim_c2h, sim_h2c = get_sim(cap_proj, hist_proj, temp)

            weights_c2h = F.softmax(sim_c2h + 1e-4, dim=1)  # (N, N)
            weights_h2c = F.softmax(sim_h2c + 1e-4, dim=1)

            # Zero out matched (positive) locations so they cannot be drawn
            # as negatives.
            mask = self.get_mask(weights_c2h, idx=idx).bool()
            weights_c2h.masked_fill_(mask, 0)
            weights_h2c.masked_fill_(mask, 0)
            weights_c2h = torch.nan_to_num_(weights_c2h, nan=1e-2, posinf=1e-2, neginf=1e-2)
            weights_h2c = torch.nan_to_num_(weights_h2c, nan=1e-2, posinf=1e-2, neginf=1e-2)

        if generation:
            with torch.no_grad():
                output = grounding_expert(
                    encoder_embeds=hist_embeds,
                    attention_mask=hist_atts,
                    encoder_hidden_states=cap_embeds,
                    encoder_attention_mask=cap_atts,
                    return_dict=True,
                )
            pos_feats = output.last_hidden_state
            return pos_feats

        else:
            # Select hard negatives within the batch.
            cap_neg_indices = torch.multinomial(weights_c2h, 1).squeeze()
            hist_neg_indices = torch.multinomial(weights_h2c, 1).squeeze()

            cap_embeds_neg = cap_embeds[cap_neg_indices]  # [B, L, c]
            cap_atts_neg = cap_atts[cap_neg_indices]
            hist_embeds_neg = hist_embeds[hist_neg_indices]  # [B, L, d]
            hist_atts_neg = hist_atts[hist_neg_indices]

            # Concat embeddings: [positives; caption negatives; history negatives].
            cap_embeds_all = torch.cat([cap_embeds, cap_embeds_neg, cap_embeds], dim=0)
            hist_embeds_all = torch.cat([hist_embeds, hist_embeds, hist_embeds_neg], dim=0)
            cap_atts_all = torch.cat([cap_atts, cap_atts_neg, cap_atts], dim=0)
            hist_atts_all = torch.cat([hist_atts, hist_atts, hist_atts_neg], dim=0)

            output = grounding_expert(
                inputs_embeds=hist_embeds_all,
                attention_mask=hist_atts_all,
                cross_embeds=cap_embeds_all,
                cross_attention_mask=cap_atts_all,
            )

            chm_embeds = output.last_hidden_state[:, 0]  # pos (N, d) + neg (2N, d)

            chm_logits = chm_head(chm_embeds)  # [3*B, 2]

            bs = chm_logits.shape[0] // 3
            chm_labels = chm_logits.new_ones(3 * bs, dtype=torch.long)
            chm_labels[bs:] = 0
            loss_chm = F.cross_entropy(chm_logits, chm_labels)
            pos_feats = output.last_hidden_state[:bs]
            return loss_chm, pos_feats


class MLMLoss(nn.Module):
    """Masked language modeling loss."""

    def __init__(self, masking_prob, tokenizer):
        super(MLMLoss, self).__init__()
        self.tokenizer = tokenizer
        self.masking_prob = masking_prob

    def mlm_loss(
        self,
        text_encoder,
        text,
        text_embeds,
        vision_embeds,
        vision_atts,
    ):
        input_ids = text.input_ids.clone()
        labels = input_ids.clone()
        probability_matrix = torch.full(labels.shape, self.masking_prob)
        input_ids, labels = self.mask(
            input_ids,
            text_encoder.config.vocab_size,
            input_ids.device,
            targets=labels,
            probability_matrix=probability_matrix,
        )

        # intermediate_mlm_output = text_encoder.bert(
        #     input_ids,
        #     attention_mask=text.attention_mask,
        #     encoder_hidden_states=vision_embeds,
        #     encoder_attention_mask=vision_atts,
        #     return_dict=True,
        #     # mode="text",
        # )
        # text_embeds = intermediate_mlm_output.last_hidden_state

        mlm_output = text_encoder(
            encoder_embeds=text_embeds,
            attention_mask=text.attention_mask,
            encoder_hidden_states=vision_embeds,
            encoder_attention_mask=vision_atts,
            return_dict=True,
            labels=labels,
            soft_labels=None,
            # mode="fusion",
        )
        return mlm_output.loss

    def mask(
        self,
        input_ids,
        vocab_size,
        device,
        targets=None,
        masked_indices=None,
        probability_matrix=None,
    ):
        if masked_indices is None:
            masked_indices = torch.bernoulli(probability_matrix).bool()

        # Never mask padding or [CLS] tokens.
        masked_indices[input_ids == self.tokenizer.pad_token_id] = False
        masked_indices[input_ids == self.tokenizer.cls_token_id] = False

        if targets is not None:
            # We only compute the loss on masked tokens.
            targets[~masked_indices] = -100

        # 80% of the time, replace masked input tokens with tokenizer.mask_token ([MASK]).
        indices_replaced = (
            torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
        )
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, replace masked input tokens with a random word
        # (i.e. half of the remaining 20%).
        indices_random = (
            torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
        input_ids[indices_random] = random_words[indices_random]
        # The rest of the time (10%), keep the masked input tokens unchanged.

        if targets is not None:
            return input_ids, targets
        else:
            return input_ids
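

# A quick sanity check of the 80/10/10 masking split used in `MLMLoss.mask`.
# Illustrative only: `_StubTokenizer` is a hypothetical stand-in with
# made-up special-token ids, not the real tokenizer used in training.
if __name__ == "__main__":
    class _StubTokenizer:
        pad_token_id = 0
        cls_token_id = 101
        mask_token_id = 103

    mlm = MLMLoss(masking_prob=0.15, tokenizer=_StubTokenizer())
    ids = torch.randint(500, 1000, (64, 32))  # fake token ids, no pad/cls
    labels = ids.clone()
    probability_matrix = torch.full(labels.shape, mlm.masking_prob)
    masked_ids, targets = mlm.mask(
        ids.clone(),
        vocab_size=1000,
        device=ids.device,
        targets=labels,
        probability_matrix=probability_matrix,
    )
    n_selected = (targets != -100).sum().item()
    n_mask_tok = (masked_ids == _StubTokenizer.mask_token_id).sum().item()
    # Roughly 15% of tokens are selected; ~80% of those become [MASK].
    print(f"selected: {n_selected}, replaced with [MASK]: {n_mask_tok}")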