vlcn/code/core/model/mca.py

209 lines
5.5 KiB
Python

# --------------------------------------------------------
# mcan-vqa (Deep Modular Co-Attention Networks)
# Licensed under The MIT License [see LICENSE for details]
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------
from core.model.net_utils import FC, MLP, LayerNorm
from core.model.dnc_improved import DNC, SharedMemDNC
from core.model.dnc_improved import FeedforwardController
import torch.nn as nn
import torch.nn.functional as F
import torch, math
import time
# ------------------------------
# ---- Multi-Head Attention ----
# ------------------------------
class MHAtt(nn.Module):
    """Multi-head scaled dot-product attention.

    The config object ``__C`` must expose ``HIDDEN_SIZE``, ``MULTI_HEAD``,
    ``HIDDEN_SIZE_HEAD`` and ``DROPOUT_R``.
    NOTE(review): the reshapes in ``forward`` presumably rely on
    ``MULTI_HEAD * HIDDEN_SIZE_HEAD == HIDDEN_SIZE`` -- confirm in the config.
    """

    def __init__(self, __C):
        super().__init__()
        self.__C = __C

        width = __C.HIDDEN_SIZE
        # Creation order is kept (v, k, q, merge) so that seeded parameter
        # initialisation stays reproducible.
        self.linear_v = nn.Linear(width, width)
        self.linear_k = nn.Linear(width, width)
        self.linear_q = nn.Linear(width, width)
        self.linear_merge = nn.Linear(width, width)

        self.dropout = nn.Dropout(__C.DROPOUT_R)

    def forward(self, v, k, q, mask):
        """Project ``v``/``k``/``q``, attend per head, and merge the heads.

        The batch size is taken from ``q.size(0)``. ``mask`` is forwarded
        unchanged to :meth:`att`. The result is viewed as
        ``(batch, -1, HIDDEN_SIZE)`` and passed through ``linear_merge``.
        """
        batch = q.size(0)
        heads = self.__C.MULTI_HEAD
        head_dim = self.__C.HIDDEN_SIZE_HEAD

        def split_heads(projected):
            # (batch, -1, heads, head_dim) -> (batch, heads, -1, head_dim)
            return projected.view(batch, -1, heads, head_dim).transpose(1, 2)

        v = split_heads(self.linear_v(v))
        k = split_heads(self.linear_k(k))
        q = split_heads(self.linear_q(q))

        per_head = self.att(v, k, q, mask)

        # Undo the head split; contiguous() is needed before view() because
        # transpose() returns a non-contiguous tensor.
        merged = per_head.transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.__C.HIDDEN_SIZE)
        return self.linear_merge(merged)

    def att(self, value, key, query, mask):
        """Return ``dropout(softmax(query @ key^T / sqrt(d_k))) @ value``.

        ``d_k`` is the size of the last dimension of ``query``. When ``mask``
        is not None, score entries where ``mask`` is True are replaced by
        ``-1e9`` before the softmax.
        """
        scale = math.sqrt(query.size(-1))
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, value)
# ---------------------------
# ---- Feed Forward Nets ----
# ---------------------------
class FFN(nn.Module):
    """Feed-forward sub-layer: a thin wrapper around the project ``MLP``.

    The wrapped MLP is configured with ``in_size`` and ``out_size`` equal to
    ``__C.HIDDEN_SIZE``, ``mid_size`` equal to ``__C.FF_SIZE``, dropout rate
    ``__C.DROPOUT_R`` and ``use_relu=True``.
    """

    def __init__(self, __C):
        super().__init__()
        width = __C.HIDDEN_SIZE
        self.mlp = MLP(
            in_size=width,
            mid_size=__C.FF_SIZE,
            out_size=width,
            dropout_r=__C.DROPOUT_R,
            use_relu=True,
        )

    def forward(self, x):
        """Apply the wrapped MLP to ``x`` and return its output."""
        out = self.mlp(x)
        return out
# ------------------------
# ---- Self Attention ----
# ------------------------
class SA(nn.Module):
    """Self-attention layer: multi-head self-attention followed by an FFN.

    Each of the two sub-layers is wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, __C):
        super().__init__()

        self.mhatt = MHAtt(__C)
        self.ffn = FFN(__C)

        drop_rate = __C.DROPOUT_R
        width = __C.HIDDEN_SIZE

        # Residual wrapper for the attention sub-layer.
        self.dropout1 = nn.Dropout(drop_rate)
        self.norm1 = LayerNorm(width)

        # Residual wrapper for the feed-forward sub-layer.
        self.dropout2 = nn.Dropout(drop_rate)
        self.norm2 = LayerNorm(width)

    def forward(self, x, x_mask):
        """Run self-attention over ``x`` (masked by ``x_mask``), then the FFN.

        ``x`` is used as value, key and query of the attention call.
        """
        attended = self.mhatt(x, x, x, x_mask)
        x = self.norm1(x + self.dropout1(attended))

        transformed = self.ffn(x)
        return self.norm2(x + self.dropout2(transformed))
# -------------------------------
# ---- Self Guided Attention ----
# -------------------------------
class SGA(nn.Module):
    """Self-guided attention layer.

    Three sub-layers, each wrapped as ``LayerNorm(x + Dropout(sublayer))``:
    self-attention over ``x``, attention from ``x`` (query) onto ``y``
    (key and value), and a feed-forward network.
    """

    def __init__(self, __C):
        super().__init__()

        self.mhatt1 = MHAtt(__C)
        self.mhatt2 = MHAtt(__C)
        self.ffn = FFN(__C)

        drop_rate = __C.DROPOUT_R
        width = __C.HIDDEN_SIZE

        # Residual wrapper for self-attention.
        self.dropout1 = nn.Dropout(drop_rate)
        self.norm1 = LayerNorm(width)

        # Residual wrapper for guided attention.
        self.dropout2 = nn.Dropout(drop_rate)
        self.norm2 = LayerNorm(width)

        # Residual wrapper for the feed-forward network.
        self.dropout3 = nn.Dropout(drop_rate)
        self.norm3 = LayerNorm(width)

    def forward(self, x, y, x_mask, y_mask):
        """Update ``x`` using itself and the guiding input ``y``.

        ``x_mask`` masks the self-attention step; ``y_mask`` masks the
        guided step, where ``y`` supplies value and key and ``x`` the query.
        Returns the updated ``x``.
        """
        self_attended = self.mhatt1(x, x, x, x_mask)
        x = self.norm1(x + self.dropout1(self_attended))

        guided = self.mhatt2(y, y, x, y_mask)
        x = self.norm2(x + self.dropout2(guided))

        transformed = self.ffn(x)
        return self.norm3(x + self.dropout3(transformed))
# ------------------------------------------------
# ---- MCA Layers Cascaded by Encoder-Decoder ----
# ------------------------------------------------
class MCA_ED(nn.Module):
    """Encoder-decoder stack of ``__C.LAYER`` SA layers and ``__C.LAYER`` SGA layers."""

    def __init__(self, __C):
        super().__init__()

        depth = __C.LAYER
        self.enc_list = nn.ModuleList([SA(__C) for _ in range(depth)])
        self.dec_list = nn.ModuleList([SGA(__C) for _ in range(depth)])

    def forward(self, x, y, x_mask, y_mask):
        """Encode ``x``, then decode ``y`` guided by the fully encoded ``x``.

        Every decoder layer receives the output of the *last* encoder layer.
        Returns the tuple ``(x, y)`` of updated inputs.
        """
        # Encoder: stacked self-attention over x.
        for encoder in self.enc_list:
            x = encoder(x, x_mask)

        # Decoder: y attends to itself and to the encoded x.
        for decoder in self.dec_list:
            y = decoder(y, x, y_mask, x_mask)

        return x, y
class VLC(nn.Module):
    """One SA encoder stack shared by two independent SGA decoder stacks.

    All three stacks contain ``__C.LAYER`` layers.
    NOTE(review): judging by the attribute names, ``x`` is presumably the
    language input, ``y`` the frame features and ``z`` the clip features --
    confirm against the caller.
    """

    def __init__(self, __C):
        super().__init__()

        depth = __C.LAYER
        self.enc_list = nn.ModuleList([SA(__C) for _ in range(depth)])
        self.dec_lang_frames_list = nn.ModuleList(
            [SGA(__C) for _ in range(depth)]
        )
        self.dec_lang_clips_list = nn.ModuleList(
            [SGA(__C) for _ in range(depth)]
        )

    def forward(self, x, y, z, x_mask, y_mask, z_mask):
        """Encode ``x``, then decode ``y`` and ``z`` separately against it.

        Both decoder stacks are guided by the output of the last encoder
        layer. Returns the tuple ``(x, y, z)`` of updated inputs.
        """
        # Encoder: stacked self-attention over x.
        for encoder in self.enc_list:
            x = encoder(x, x_mask)

        # First decoder stack: y guided by the encoded x.
        for frame_decoder in self.dec_lang_frames_list:
            y = frame_decoder(y, x, y_mask, x_mask)

        # Second decoder stack: z guided by the encoded x.
        for clip_decoder in self.dec_lang_clips_list:
            z = clip_decoder(z, x, z_mask, x_mask)

        return x, y, z