# --------------------------------------------------------
# mcan-vqa (Deep Modular Co-Attention Networks)
# Licensed under The MIT License [see LICENSE for details]
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------

from core.model.net_utils import FC, MLP, LayerNorm
from core.model.mca import SA, MCA_ED, VLC
from core.model.dnc import DNC

import torch.nn as nn
import torch.nn.functional as F
import torch


# ------------------------------
# ---- Flatten the sequence ----
# ------------------------------

class AttFlat(nn.Module):
    def __init__(self, __C):
        super(AttFlat, self).__init__()
        self.__C = __C

        self.mlp = MLP(
            in_size=__C.HIDDEN_SIZE,
            mid_size=__C.FLAT_MLP_SIZE,
            out_size=__C.FLAT_GLIMPSES,
            dropout_r=__C.DROPOUT_R,
            use_relu=True
        )

        self.linear_merge = nn.Linear(
            __C.HIDDEN_SIZE * __C.FLAT_GLIMPSES,
            __C.FLAT_OUT_SIZE
        )

    def forward(self, x, x_mask):
        att = self.mlp(x)
        att = att.masked_fill(
            x_mask.squeeze(1).squeeze(1).unsqueeze(2),
            -1e9
        )
        att = F.softmax(att, dim=1)

        att_list = []
        for i in range(self.__C.FLAT_GLIMPSES):
            att_list.append(
                torch.sum(att[:, :, i: i + 1] * x, dim=1)
            )

        x_atted = torch.cat(att_list, dim=1)
        x_atted = self.linear_merge(x_atted)

        return x_atted


class AttFlatMem(AttFlat):
    def __init__(self, __C):
        super(AttFlatMem, self).__init__(__C)
        self.__C = __C

    def forward(self, x_mem, x, x_mask):
        att = self.mlp(x_mem)
        att = att.masked_fill(
            x_mask.squeeze(1).squeeze(1).unsqueeze(2),
            float('-inf')
        )
        att = F.softmax(att, dim=1)

        att_list = []
        for i in range(self.__C.FLAT_GLIMPSES):
            att_list.append(
                torch.sum(att[:, :, i: i + 1] * x, dim=1)
            )

        x_atted = torch.cat(att_list, dim=1)
        x_atted = self.linear_merge(x_atted)

        return x_atted


# -------------------------
# ---- Main MCAN Model ----
# -------------------------

class Net1(nn.Module):
    def __init__(self, __C, pretrained_emb, token_size, answer_size):
        super(Net1, self).__init__()
        print('Training with Network type 1: VLCN')

        self.pretrained_path = __C.PRETRAINED_PATH

        self.embedding = nn.Embedding(
            num_embeddings=token_size,
            embedding_dim=__C.WORD_EMBED_SIZE
        )

        # Loading the GloVe embedding weights
        if __C.USE_GLOVE:
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))

        self.lstm = nn.LSTM(
            input_size=__C.WORD_EMBED_SIZE,
            hidden_size=__C.HIDDEN_SIZE,
            num_layers=1,
            batch_first=True
        )

        self.frame_feat_linear = nn.Linear(
            __C.FRAME_FEAT_SIZE,
            __C.HIDDEN_SIZE
        )

        self.clip_feat_linear = nn.Linear(
            __C.CLIP_FEAT_SIZE,
            __C.HIDDEN_SIZE
        )

        self.backbone = VLC(__C)

        self.attflat_lang = AttFlat(__C)
        self.attflat_frame = AttFlat(__C)
        self.attflat_clip = AttFlat(__C)

        self.dnc = DNC(
            __C.FLAT_OUT_SIZE,
            __C.FLAT_OUT_SIZE,
            rnn_type='lstm',
            num_layers=2,
            num_hidden_layers=2,
            bias=True,
            batch_first=True,
            dropout=0,
            bidirectional=True,
            nr_cells=__C.CELL_COUNT_DNC,
            read_heads=__C.N_READ_HEADS_DNC,
            cell_size=__C.WORD_LENGTH_DNC,
            nonlinearity='tanh',
            gpu_id=0,
            independent_linears=False,
            share_memory=False,
            debug=False,
            clip=20,
        )

        self.proj_norm = LayerNorm(__C.FLAT_OUT_SIZE)
        self.proj_norm_dnc = LayerNorm(__C.FLAT_OUT_SIZE + __C.N_READ_HEADS_DNC * __C.WORD_LENGTH_DNC)
        self.linear_dnc = FC(__C.FLAT_OUT_SIZE + __C.N_READ_HEADS_DNC * __C.WORD_LENGTH_DNC, __C.FLAT_OUT_SIZE, dropout_r=0.2)
        self.proj = nn.Linear(__C.FLAT_OUT_SIZE, answer_size)

    def forward(self, frame_feat, clip_feat, ques_ix):
        # Make mask
        lang_feat_mask = self.make_mask(ques_ix.unsqueeze(2))
        frame_feat_mask = self.make_mask(frame_feat)
        clip_feat_mask = self.make_mask(clip_feat)

        # Pre-process Language Feature
        lang_feat = self.embedding(ques_ix)
        lang_feat, _ = self.lstm(lang_feat)

        # Pre-process Video Feature
        frame_feat = self.frame_feat_linear(frame_feat)
        clip_feat = self.clip_feat_linear(clip_feat)

        # Backbone Framework
        lang_feat, frame_feat, clip_feat = self.backbone(
            lang_feat,
            frame_feat,
            clip_feat,
            lang_feat_mask,
            frame_feat_mask,
            clip_feat_mask
        )

        lang_feat = self.attflat_lang(
            lang_feat,
            lang_feat_mask
        )

        frame_feat = self.attflat_frame(
            frame_feat,
            frame_feat_mask
        )

        clip_feat = self.attflat_clip(
            clip_feat,
            clip_feat_mask
        )

        proj_feat_0 = lang_feat + frame_feat + clip_feat
        proj_feat_0 = self.proj_norm(proj_feat_0)

        proj_feat_1 = torch.stack([lang_feat, frame_feat, clip_feat], dim=1)
        proj_feat_1, (_, _, rv), _ = self.dnc(proj_feat_1, (None, None, None), reset_experience=True, pass_through_memory=True)
        proj_feat_1 = proj_feat_1.sum(1)
        proj_feat_1 = torch.cat([proj_feat_1, rv], dim=-1)
        proj_feat_1 = self.proj_norm_dnc(proj_feat_1)
        proj_feat_1 = self.linear_dnc(proj_feat_1)
        # proj_feat_1 = self.proj_norm(proj_feat_1)

        proj_feat = torch.sigmoid(self.proj(proj_feat_0 + proj_feat_1))

        return proj_feat

    def load_pretrained_weights(self):
        pretrained_msvd = torch.load(self.pretrained_path)['state_dict']
        for n_pretrained, p_pretrained in pretrained_msvd.items():
            if 'dnc' in n_pretrained:
                self.state_dict()[n_pretrained].copy_(p_pretrained)
        print('Pre-trained dnc-weights successfully loaded!')

    # Masking
    def make_mask(self, feature):
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0).unsqueeze(1).unsqueeze(2)


class Net2(nn.Module):
    def __init__(self, __C, pretrained_emb, token_size, answer_size):
        super(Net2, self).__init__()
        print('Training with Network type 2: VLCN-FLF')

        self.embedding = nn.Embedding(
            num_embeddings=token_size,
            embedding_dim=__C.WORD_EMBED_SIZE
        )

        # Loading the GloVe embedding weights
        if __C.USE_GLOVE:
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))

        self.lstm = nn.LSTM(
            input_size=__C.WORD_EMBED_SIZE,
            hidden_size=__C.HIDDEN_SIZE,
            num_layers=1,
            batch_first=True
        )

        self.frame_feat_linear = nn.Linear(
            __C.FRAME_FEAT_SIZE,
            __C.HIDDEN_SIZE
        )

        self.clip_feat_linear = nn.Linear(
            __C.CLIP_FEAT_SIZE,
            __C.HIDDEN_SIZE
        )

        self.backbone = VLC(__C)

        self.attflat_lang = AttFlat(__C)
        self.attflat_frame = AttFlat(__C)
        self.attflat_clip = AttFlat(__C)

        self.proj_norm = LayerNorm(__C.FLAT_OUT_SIZE)
        self.proj = nn.Linear(__C.FLAT_OUT_SIZE, answer_size)

    def forward(self, frame_feat, clip_feat, ques_ix):
        # Make mask
        lang_feat_mask = self.make_mask(ques_ix.unsqueeze(2))
        frame_feat_mask = self.make_mask(frame_feat)
        clip_feat_mask = self.make_mask(clip_feat)

        # Pre-process Language Feature
        lang_feat = self.embedding(ques_ix)
        lang_feat, _ = self.lstm(lang_feat)

        # Pre-process Video Feature
        frame_feat = self.frame_feat_linear(frame_feat)
        clip_feat = self.clip_feat_linear(clip_feat)

        # Backbone Framework
        lang_feat, frame_feat, clip_feat = self.backbone(
            lang_feat,
            frame_feat,
            clip_feat,
            lang_feat_mask,
            frame_feat_mask,
            clip_feat_mask
        )

        lang_feat = self.attflat_lang(
            lang_feat,
            lang_feat_mask
        )

        frame_feat = self.attflat_frame(
            frame_feat,
            frame_feat_mask
        )

        clip_feat = self.attflat_clip(
            clip_feat,
            clip_feat_mask
        )

        proj_feat = lang_feat + frame_feat + clip_feat
        proj_feat = self.proj_norm(proj_feat)
        proj_feat = torch.sigmoid(self.proj(proj_feat))

        return proj_feat

    # Masking
    def make_mask(self, feature):
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0).unsqueeze(1).unsqueeze(2)


class Net3(nn.Module):
    def __init__(self, __C, pretrained_emb, token_size, answer_size):
        super(Net3, self).__init__()
        print('Training with Network type 3: VLCN+LSTM')

        self.embedding = nn.Embedding(
            num_embeddings=token_size,
            embedding_dim=__C.WORD_EMBED_SIZE
        )

        # Loading the GloVe embedding weights
        if __C.USE_GLOVE:
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))

        self.lstm = nn.LSTM(
            input_size=__C.WORD_EMBED_SIZE,
            hidden_size=__C.HIDDEN_SIZE,
            num_layers=1,
            batch_first=True
        )

        self.frame_feat_linear = nn.Linear(
            __C.FRAME_FEAT_SIZE,
            __C.HIDDEN_SIZE
        )

        self.clip_feat_linear = nn.Linear(
            __C.CLIP_FEAT_SIZE,
            __C.HIDDEN_SIZE
        )

        self.backbone = VLC(__C)

        self.attflat_lang = AttFlat(__C)
        self.attflat_frame = AttFlat(__C)
        self.attflat_clip = AttFlat(__C)

        self.lstm_fusion = nn.LSTM(
            input_size=__C.FLAT_OUT_SIZE,
            hidden_size=__C.FLAT_OUT_SIZE,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )

        self.proj_norm = LayerNorm(__C.FLAT_OUT_SIZE)
        self.proj_feat_1 = nn.Linear(__C.FLAT_OUT_SIZE * 2, __C.FLAT_OUT_SIZE)
        self.proj_norm_lstm = LayerNorm(__C.FLAT_OUT_SIZE)
        self.proj = nn.Linear(__C.FLAT_OUT_SIZE, answer_size)

    def forward(self, frame_feat, clip_feat, ques_ix):
        # Make mask
        lang_feat_mask = self.make_mask(ques_ix.unsqueeze(2))
        frame_feat_mask = self.make_mask(frame_feat)
        clip_feat_mask = self.make_mask(clip_feat)

        # Pre-process Language Feature
        lang_feat = self.embedding(ques_ix)
        lang_feat, _ = self.lstm(lang_feat)

        # Pre-process Video Feature
        frame_feat = self.frame_feat_linear(frame_feat)
        clip_feat = self.clip_feat_linear(clip_feat)

        # Backbone Framework
        lang_feat, frame_feat, clip_feat = self.backbone(
            lang_feat,
            frame_feat,
            clip_feat,
            lang_feat_mask,
            frame_feat_mask,
            clip_feat_mask
        )

        lang_feat = self.attflat_lang(
            lang_feat,
            lang_feat_mask
        )

        frame_feat = self.attflat_frame(
            frame_feat,
            frame_feat_mask
        )

        clip_feat = self.attflat_clip(
            clip_feat,
            clip_feat_mask
        )

        proj_feat_0 = lang_feat + frame_feat + clip_feat
        proj_feat_0 = self.proj_norm(proj_feat_0)

        proj_feat_1 = torch.stack([lang_feat, frame_feat, clip_feat], dim=1)
        proj_feat_1, _ = self.lstm_fusion(proj_feat_1)
        proj_feat_1 = proj_feat_1.sum(1)
        proj_feat_1 = self.proj_feat_1(proj_feat_1)
        proj_feat_1 = self.proj_norm_lstm(proj_feat_1)

        proj_feat = torch.sigmoid(self.proj(proj_feat_0 + proj_feat_1))

        return proj_feat

    # Masking
    def make_mask(self, feature):
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0).unsqueeze(1).unsqueeze(2)


class Net4(nn.Module):
    def __init__(self, __C, pretrained_emb, token_size, answer_size):
        super(Net4, self).__init__()
        print('Training with Network type 4: MCAN')

        self.embedding = nn.Embedding(
            num_embeddings=token_size,
            embedding_dim=__C.WORD_EMBED_SIZE
        )

        # Loading the GloVe embedding weights
        if __C.USE_GLOVE:
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))

        self.lstm = nn.LSTM(
            input_size=__C.WORD_EMBED_SIZE,
            hidden_size=__C.HIDDEN_SIZE,
            num_layers=1,
            batch_first=True
        )

        self.frame_feat_linear = nn.Linear(
            __C.FRAME_FEAT_SIZE,
            __C.HIDDEN_SIZE
        )

        self.clip_feat_linear = nn.Linear(
            __C.CLIP_FEAT_SIZE,
            __C.HIDDEN_SIZE
        )

        self.backbone = MCA_ED(__C)

        self.attflat_lang = AttFlat(__C)
        self.attflat_vid = AttFlat(__C)

        self.proj_norm = LayerNorm(__C.FLAT_OUT_SIZE)
        self.proj = nn.Linear(__C.FLAT_OUT_SIZE, answer_size)

    def forward(self, frame_feat, clip_feat, ques_ix):
        # Make mask
        lang_feat_mask = self.make_mask(ques_ix.unsqueeze(2))
        frame_feat_mask = self.make_mask(frame_feat)
        clip_feat_mask = self.make_mask(clip_feat)

        # Pre-process Language Feature
        lang_feat = self.embedding(ques_ix)
        lang_feat, _ = self.lstm(lang_feat)

        # Pre-process Video Feature
        frame_feat = self.frame_feat_linear(frame_feat)
        clip_feat = self.clip_feat_linear(clip_feat)

        # concat frame and clip features
        vid_feat = torch.cat([frame_feat, clip_feat], dim=1)
        vid_feat_mask = torch.cat([frame_feat_mask, clip_feat_mask], dim=-1)

        # Backbone Framework
        lang_feat, vid_feat = self.backbone(
            lang_feat,
            vid_feat,
            lang_feat_mask,
            vid_feat_mask,
        )

        lang_feat = self.attflat_lang(
            lang_feat,
            lang_feat_mask
        )

        vid_feat = self.attflat_vid(
            vid_feat,
            vid_feat_mask
        )

        proj_feat = lang_feat + vid_feat
        proj_feat = self.proj_norm(proj_feat)
        proj_feat = torch.sigmoid(self.proj(proj_feat))

        return proj_feat

    # Masking
    def make_mask(self, feature):
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0).unsqueeze(1).unsqueeze(2)
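

# --------------------------------------------------------
# Minimal smoke test (illustrative sketch only): exercises the
# make_mask pattern and the AttFlat glimpse pooling shared by the
# networks above. The hyper-parameter values in the SimpleNamespace
# below are placeholder assumptions for shape-checking, not values
# taken from this project's actual config files.
# --------------------------------------------------------

if __name__ == '__main__':
    from types import SimpleNamespace

    # Assumed hyper-parameters; only the attributes that AttFlat reads are set
    __C = SimpleNamespace(
        HIDDEN_SIZE=512,
        FLAT_MLP_SIZE=512,
        FLAT_GLIMPSES=1,
        FLAT_OUT_SIZE=1024,
        DROPOUT_R=0.1,
    )

    batch_size, seq_len = 2, 8
    x = torch.randn(batch_size, seq_len, __C.HIDDEN_SIZE)
    x[:, -2:, :] = 0  # pretend the last two positions are zero-padded

    # Same masking rule as Net*.make_mask: True where the feature vector is all zeros
    x_mask = (torch.sum(torch.abs(x), dim=-1) == 0).unsqueeze(1).unsqueeze(2)

    flat = AttFlat(__C)
    out = flat(x, x_mask)
    print(out.shape)  # expected: torch.Size([2, 1024])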