# --------------------------------------------------------
# mcan-vqa (Deep Modular Co-Attention Networks)
# Licensed under The MIT License [see LICENSE for details]
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------

from core.model.net_utils import FC, MLP, LayerNorm
from core.model.dnc_improved import DNC, SharedMemDNC
from core.model.dnc_improved import FeedforwardController

import torch.nn as nn
import torch.nn.functional as F
import torch, math
import time


# ------------------------------
# ---- Multi-Head Attention ----
# ------------------------------

class MHAtt(nn.Module):
    """Multi-head scaled dot-product attention."""

    def __init__(self, __C):
        super(MHAtt, self).__init__()
        self.__C = __C

        self.linear_v = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE)
        self.linear_k = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE)
        self.linear_q = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE)
        self.linear_merge = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE)

        self.dropout = nn.Dropout(__C.DROPOUT_R)

    def forward(self, v, k, q, mask):
        n_batches = q.size(0)

        # Project and split into heads:
        # (batch, seq, HIDDEN_SIZE) -> (batch, MULTI_HEAD, seq, HIDDEN_SIZE_HEAD)
        v = self.linear_v(v).view(
            n_batches,
            -1,
            self.__C.MULTI_HEAD,
            self.__C.HIDDEN_SIZE_HEAD
        ).transpose(1, 2)

        k = self.linear_k(k).view(
            n_batches,
            -1,
            self.__C.MULTI_HEAD,
            self.__C.HIDDEN_SIZE_HEAD
        ).transpose(1, 2)

        q = self.linear_q(q).view(
            n_batches,
            -1,
            self.__C.MULTI_HEAD,
            self.__C.HIDDEN_SIZE_HEAD
        ).transpose(1, 2)

        atted = self.att(v, k, q, mask)

        # Merge heads back: (batch, MULTI_HEAD, seq, HIDDEN_SIZE_HEAD) -> (batch, seq, HIDDEN_SIZE)
        atted = atted.transpose(1, 2).contiguous().view(
            n_batches,
            -1,
            self.__C.HIDDEN_SIZE
        )

        atted = self.linear_merge(atted)

        return atted

    def att(self, value, key, query, mask):
        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        d_k = query.size(-1)

        scores = torch.matmul(
            query, key.transpose(-2, -1)
        ) / math.sqrt(d_k)

        if mask is not None:
            # Masked (True) positions get -1e9 so softmax assigns them ~zero weight
            scores = scores.masked_fill(mask, -1e9)

        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)

        return torch.matmul(att_map, value)
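
# Usage sketch (illustrative, not part of the original file): the config values
# below are assumptions, chosen so that MULTI_HEAD * HIDDEN_SIZE_HEAD == HIDDEN_SIZE,
# which the head split/merge above relies on.
#
#   class _Cfg:
#       HIDDEN_SIZE = 512
#       MULTI_HEAD = 8
#       HIDDEN_SIZE_HEAD = 64
#       DROPOUT_R = 0.1
#
#   mhatt = MHAtt(_Cfg())
#   x = torch.randn(2, 14, 512)          # (batch, seq, hidden)
#   out = mhatt(x, x, x, mask=None)      # self-attention; mask=None disables masking
#   assert out.shape == (2, 14, 512)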


# ---------------------------
# ---- Feed Forward Nets ----
# ---------------------------

class FFN(nn.Module):
    """Position-wise feed-forward net: HIDDEN_SIZE -> FF_SIZE -> HIDDEN_SIZE,
    with ReLU and dropout inside the MLP."""

    def __init__(self, __C):
        super(FFN, self).__init__()

        self.mlp = MLP(
            in_size=__C.HIDDEN_SIZE,
            mid_size=__C.FF_SIZE,
            out_size=__C.HIDDEN_SIZE,
            dropout_r=__C.DROPOUT_R,
            use_relu=True
        )

    def forward(self, x):
        return self.mlp(x)
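
# Usage sketch (illustrative): FFN only needs HIDDEN_SIZE, FF_SIZE and DROPOUT_R
# from the config; the values below are assumptions.
#
#   class _Cfg:
#       HIDDEN_SIZE = 512
#       FF_SIZE = 2048
#       DROPOUT_R = 0.1
#
#   ffn = FFN(_Cfg())
#   x = torch.randn(2, 14, 512)
#   assert ffn(x).shape == (2, 14, 512)  # applied position-wise, shape preserved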


# ------------------------
# ---- Self Attention ----
# ------------------------

class SA(nn.Module):
    """Self-attention block: multi-head self-attention followed by an FFN,
    each sub-layer wrapped with dropout, a residual connection and LayerNorm."""

    def __init__(self, __C):
        super(SA, self).__init__()

        self.mhatt = MHAtt(__C)
        self.ffn = FFN(__C)

        self.dropout1 = nn.Dropout(__C.DROPOUT_R)
        self.norm1 = LayerNorm(__C.HIDDEN_SIZE)

        self.dropout2 = nn.Dropout(__C.DROPOUT_R)
        self.norm2 = LayerNorm(__C.HIDDEN_SIZE)

    def forward(self, x, x_mask):
        # x attends to itself
        x = self.norm1(x + self.dropout1(
            self.mhatt(x, x, x, x_mask)
        ))

        x = self.norm2(x + self.dropout2(
            self.ffn(x)
        ))

        return x
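
# Usage sketch (illustrative; assumes a _Cfg with HIDDEN_SIZE, MULTI_HEAD,
# HIDDEN_SIZE_HEAD, FF_SIZE and DROPOUT_R as sketched above):
#
#   sa = SA(_Cfg())
#   x = torch.randn(2, 14, 512)
#   out = sa(x, x_mask=None)       # or a bool mask (True = hide) broadcastable over scores
#   assert out.shape == x.shape    # residual sub-layers preserve the hidden size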


# -------------------------------
# ---- Self Guided Attention ----
# -------------------------------

class SGA(nn.Module):
    """Self-guided attention block: self-attention over x, guided (cross) attention
    of x over y, then an FFN; each sub-layer has dropout, a residual and LayerNorm."""

    def __init__(self, __C):
        super(SGA, self).__init__()

        self.mhatt1 = MHAtt(__C)
        self.mhatt2 = MHAtt(__C)
        self.ffn = FFN(__C)

        self.dropout1 = nn.Dropout(__C.DROPOUT_R)
        self.norm1 = LayerNorm(__C.HIDDEN_SIZE)

        self.dropout2 = nn.Dropout(__C.DROPOUT_R)
        self.norm2 = LayerNorm(__C.HIDDEN_SIZE)

        self.dropout3 = nn.Dropout(__C.DROPOUT_R)
        self.norm3 = LayerNorm(__C.HIDDEN_SIZE)

    def forward(self, x, y, x_mask, y_mask):
        # Self-attention over x
        x = self.norm1(x + self.dropout1(
            self.mhatt1(x, x, x, x_mask)
        ))

        # Guided attention: x queries attend over y (keys/values)
        x = self.norm2(x + self.dropout2(
            self.mhatt2(y, y, x, y_mask)
        ))

        x = self.norm3(x + self.dropout3(
            self.ffn(x)
        ))

        return x
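
# Usage sketch (illustrative): x (e.g. visual features) is first self-attended and
# then attends over y (e.g. encoded language features); the shape of x is preserved.
#
#   sga = SGA(_Cfg())
#   x = torch.randn(2, 36, 512)    # e.g. image-region / frame features
#   y = torch.randn(2, 14, 512)    # e.g. question features
#   out = sga(x, y, x_mask=None, y_mask=None)
#   assert out.shape == x.shape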


# ------------------------------------------------
# ---- MCA Layers Cascaded by Encoder-Decoder ----
# ------------------------------------------------

class MCA_ED(nn.Module):
    """Modular co-attention layers cascaded in encoder-decoder style: a stack of SA
    encoders over the language features x, then a stack of SGA decoders in which the
    visual features y attend to the fully encoded x."""

    def __init__(self, __C):
        super(MCA_ED, self).__init__()

        self.enc_list = nn.ModuleList([SA(__C) for _ in range(__C.LAYER)])
        self.dec_list = nn.ModuleList([SGA(__C) for _ in range(__C.LAYER)])

    def forward(self, x, y, x_mask, y_mask):
        # Encode the language features x with the self-attention stack
        for enc in self.enc_list:
            x = enc(x, x_mask)

        # Decode the visual features y, guided by the encoded x
        for dec in self.dec_list:
            y = dec(y, x, y_mask, x_mask)

        return x, y
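
# Usage sketch (illustrative; assumes the _Cfg sketched above, extended with LAYER
# for the stack depth): the language input is encoded completely first, and every
# SGA decoder layer re-uses that final encoded output.
#
#   lang = torch.randn(2, 14, 512)     # e.g. question features
#   img = torch.randn(2, 36, 512)      # e.g. image-region features
#   lang_out, img_out = MCA_ED(_Cfg())(lang, img, None, None)
#   # lang_out: (2, 14, 512), img_out: (2, 36, 512)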


class VLC(nn.Module):
    """Encoder-decoder cascade with two guided-attention branches: frame-level
    features y and clip-level features z are each guided by the encoded language
    features x."""

    def __init__(self, __C):
        super(VLC, self).__init__()

        self.enc_list = nn.ModuleList([SA(__C) for _ in range(__C.LAYER)])
        self.dec_lang_frames_list = nn.ModuleList([SGA(__C) for _ in range(__C.LAYER)])
        self.dec_lang_clips_list = nn.ModuleList([SGA(__C) for _ in range(__C.LAYER)])

    def forward(self, x, y, z, x_mask, y_mask, z_mask):
        # Encode the language features x with the self-attention stack
        for enc in self.enc_list:
            x = enc(x, x_mask)

        # Frame-level features y attend to the encoded language x
        for dec in self.dec_lang_frames_list:
            y = dec(y, x, y_mask, x_mask)

        # Clip-level features z attend to the encoded language x
        for dec in self.dec_lang_clips_list:
            z = dec(z, x, z_mask, x_mask)

        return x, y, z
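

# Minimal smoke test (illustrative sketch; the hyper-parameter values below are
# assumptions, not the project's official configuration). It assumes this file is
# executed in an environment where the core.model.* imports at the top resolve.
if __name__ == '__main__':
    class _Cfg:
        HIDDEN_SIZE = 512
        MULTI_HEAD = 8
        HIDDEN_SIZE_HEAD = 64  # HIDDEN_SIZE // MULTI_HEAD
        FF_SIZE = 2048
        DROPOUT_R = 0.1
        LAYER = 2

    cfg = _Cfg()
    lang = torch.randn(2, 14, cfg.HIDDEN_SIZE)    # e.g. question tokens
    frames = torch.randn(2, 36, cfg.HIDDEN_SIZE)  # e.g. frame-level features
    clips = torch.randn(2, 8, cfg.HIDDEN_SIZE)    # e.g. clip-level features

    # None masks disable masking; real masks are bool tensors broadcastable
    # over the (batch, head, seq_q, seq_k) attention scores, True = masked.
    x, y = MCA_ED(cfg)(lang, frames, None, None)
    print('MCA_ED:', x.shape, y.shape)

    x, y, z = VLC(cfg)(lang, frames, clips, None, None, None)
    print('VLC:', x.shape, y.shape, z.shape)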