269 lines
10 KiB
Python
269 lines
10 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torchvision.models as models
|
||
|
from torch_geometric.nn.conv import GCNConv
|
||
|
from .utils import left_bias, right_bias, build_ocr_graph
|
||
|
from .base import CNN, MindNetLSTM, MindNetLSTMXL
|
||
|
from memory_efficient_attention_pytorch import Attention
|
||
|
|
||
|
|
||
|
class CommonMindToMnet(nn.Module):
    """
    Two-stream belief model with a shared "common mind" representation.

    One ``MindNetLSTM`` is built per stream (left / right). Each receives the
    same RGB, OCR and bbox features plus its own pose and gaze slice. The two
    cell outputs are concatenated and projected back to ``hidden_dim`` to form
    the common-mind vector, which is then combined with each stream's output
    according to ``aggr`` before a linear head emits 27 values per stream.

    Args:
        hidden_dim: Width of every internal feature projection.
        device: Forwarded to ``build_ocr_graph`` and stored on the module.
        resnet: If truthy, use an ImageNet-pretrained ResNet-34 (last child
            module removed) as the frame encoder; otherwise use ``CNN``.
        dropout: Dropout probability, also forwarded to the mind nets.
        aggr: How the common-mind vector is merged with each stream:
            ``'sum'``, ``'mult'``, ``'attn'``, ``'concat'`` or ``'no_tom'``.
        mods: Modalities forwarded to the mind nets. ``None`` selects
            ``['rgb', 'pose', 'gaze', 'ocr', 'bbox']``.
    """

    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=None):
        super(CommonMindToMnet, self).__init__()

        # Build a fresh list per instance instead of sharing one mutable
        # default object across every construction of the class.
        if mods is None:
            mods = ['rgb', 'pose', 'gaze', 'ocr', 'bbox']

        self.aggr = aggr

        # ---- Images ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            # Keep every child except the last one; the backbone stays trainable here.
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)

        # ---- OCR and bbox -----#
        # NOTE(review): these are plain attributes, not registered buffers, so
        # they presumably do not follow ``module.to(...)`` -- confirm that
        # ``device`` always matches the device the model runs on.
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        # -1 leaves the input width to be resolved by GCNConv on first use.
        self.ocr_gnn = GCNConv(-1, hidden_dim)
        self.bbox_ff = nn.Linear(108, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
        self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.ln_left = nn.LayerNorm(hidden_dim)
        self.ln_right = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            # The attention modules only exist when attention aggregation is requested.
            self.attn_left = Attention(
                dim = hidden_dim,
                dim_head = hidden_dim // 4,
                heads = 4,
                memory_efficient = True,
                q_bucket_size = hidden_dim,
                k_bucket_size = hidden_dim)
            self.attn_right = Attention(
                dim = hidden_dim,
                dim_head = hidden_dim // 4,
                heads = 4,
                memory_efficient = True,
                q_bucket_size = hidden_dim,
                k_bucket_size = hidden_dim)
        self.left = nn.Linear(hidden_dim, 27)
        self.right = nn.Linear(hidden_dim, 27)
        # Start each head's bias at the log of the imported per-class values
        # (presumably label priors -- verify against ``utils``).
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
        """
        Compute belief outputs for the left and right streams.

        Args:
            images: 5-D tensor unpacked as
                ``(batch, sequence, channels, height, width)``.
            poses: Indexed as ``poses[:, :, k, :]`` with ``k = 0`` for the
                left stream and ``k = 1`` for the right stream.
            gazes: Indexed the same way as ``poses``.
            bboxes: Tensor whose last dimension is 108 (input of ``bbox_ff``).
            ocr_tensor: Accepted for interface compatibility; not read here,
                the OCR graph built in ``__init__`` is used instead.

        Returns:
            ``(left_beliefs, right_beliefs, feats)`` where ``feats`` is
            ``[out_left, out_right, common_mind] + feats_left + feats_right``.
            When ``aggr == 'no_tom'`` only the 2-tuple
            ``(left_beliefs, right_beliefs)`` is returned.

        Raises:
            ValueError: If ``self.aggr`` is not a supported aggregation mode.
        """
        assert images is not None
        assert poses is not None
        assert gazes is not None
        assert bboxes is not None

        # The full unpack doubles as a check that ``images`` is 5-D.
        batch_size, sequence_len, _, _, _ = images.shape

        bbox_feat = self.act(self.bbox_ff(bboxes))

        # Encode the sequence frame by frame and flatten each frame's features.
        rgb_feat = torch.stack(
            [self.cnn(images[:, i]).view(batch_size, -1) for i in range(sequence_len)],
            1,
        )
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))

        ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
                                                      self.ocr_edge_index,
                                                      self.ocr_edge_attr)))
        # Average over dim 0 of the GNN output, then tile to every batch item and time step.
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)

        out_left, cell_left, feats_left = self.mind_net_left(
            rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        out_right, cell_right, feats_right = self.mind_net_right(
            rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)

        common_mind = self.proj(torch.cat([cell_left, cell_right], -1))

        if self.aggr == 'no_tom':
            # Baseline without the shared representation; note the shorter return.
            return self.left(out_left), self.right(out_right)

        if self.aggr == 'attn':
            l = self.attn_left(x=out_left, context=common_mind)
            r = self.attn_right(x=out_right, context=common_mind)
        elif self.aggr == 'mult':
            l = out_left * common_mind
            r = out_right * common_mind
        elif self.aggr == 'sum':
            l = out_left + common_mind
            r = out_right + common_mind
        elif self.aggr == 'concat':
            # Appended along dim 1; the extra position is removed after the heads.
            l = torch.cat([out_left, common_mind], 1)
            r = torch.cat([out_right, common_mind], 1)
        else:
            raise ValueError(
                f"Unknown aggr {self.aggr!r}; expected one of "
                "'no_tom', 'attn', 'mult', 'sum', 'concat'"
            )

        l = self.ln_left(self.act(l))
        r = self.ln_right(self.act(r))

        left_beliefs = self.left(l)
        right_beliefs = self.right(r)
        if self.aggr == 'concat':
            # Drop the last position along dim 1 that the concatenation added.
            left_beliefs = left_beliefs[:, :-1, :]
            right_beliefs = right_beliefs[:, :-1, :]

        return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
|
||
|
|
||
|
|
||
|
|
||
|
class CommonMindToMnetXL(nn.Module):
    """
    XL model.

    Same two-stream layout as the base model but built on ``MindNetLSTMXL``,
    with two-layer (Linear -> GELU -> Linear) output heads and, when the
    ResNet encoder is selected, a frozen backbone.

    Args:
        hidden_dim: Width of every internal feature projection.
        device: Forwarded to ``build_ocr_graph`` and stored on the module.
        resnet: If truthy, use an ImageNet-pretrained ResNet-34 (last child
            module removed, parameters frozen) as the frame encoder;
            otherwise use ``CNN``.
        dropout: Dropout probability, also forwarded to the mind nets.
        aggr: How the common-mind vector is merged with each stream:
            ``'sum'``, ``'mult'``, ``'attn'``, ``'concat'`` or ``'no_tom'``.
        mods: Modalities forwarded to the mind nets. ``None`` selects
            ``['rgb', 'pose', 'gaze', 'ocr', 'bbox']``.
    """

    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=None):
        super(CommonMindToMnetXL, self).__init__()

        # Build a fresh list per instance instead of sharing one mutable
        # default object across every construction of the class.
        if mods is None:
            mods = ['rgb', 'pose', 'gaze', 'ocr', 'bbox']

        self.aggr = aggr

        # ---- Images ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            # The pretrained backbone is frozen; only ``rgb_ff`` adapts its features.
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)

        # ---- OCR and bbox -----#
        # NOTE(review): these are plain attributes, not registered buffers, so
        # they presumably do not follow ``module.to(...)`` -- confirm that
        # ``device`` always matches the device the model runs on.
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        # -1 leaves the input width to be resolved by GCNConv on first use.
        self.ocr_gnn = GCNConv(-1, hidden_dim)
        self.bbox_ff = nn.Linear(108, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_left = MindNetLSTMXL(hidden_dim, dropout, mods)
        self.mind_net_right = MindNetLSTMXL(hidden_dim, dropout, mods)
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.ln_left = nn.LayerNorm(hidden_dim)
        self.ln_right = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            # The attention modules only exist when attention aggregation is requested.
            self.attn_left = Attention(
                dim = hidden_dim,
                dim_head = hidden_dim // 4,
                heads = 4,
                memory_efficient = True,
                q_bucket_size = hidden_dim,
                k_bucket_size = hidden_dim)
            self.attn_right = Attention(
                dim = hidden_dim,
                dim_head = hidden_dim // 4,
                heads = 4,
                memory_efficient = True,
                q_bucket_size = hidden_dim,
                k_bucket_size = hidden_dim)
        self.left = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 27),
        )
        self.right = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 27)
        )
        # Start the final layer's bias at the log of the imported per-class
        # values (presumably label priors -- verify against ``utils``).
        self.left[-1].bias.data = torch.tensor(left_bias).log()
        self.right[-1].bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
        """
        Compute belief outputs for the left and right streams.

        Args:
            images: 5-D tensor unpacked as
                ``(batch, sequence, channels, height, width)``.
            poses: Indexed as ``poses[:, :, k, :]`` with ``k = 0`` for the
                left stream and ``k = 1`` for the right stream.
            gazes: Indexed the same way as ``poses``.
            bboxes: Tensor whose last dimension is 108 (input of ``bbox_ff``).
            ocr_tensor: Accepted for interface compatibility; not read here,
                the OCR graph built in ``__init__`` is used instead.

        Returns:
            ``(left_beliefs, right_beliefs, feats)`` where ``feats`` is
            ``[out_left, out_right, common_mind] + feats_left + feats_right``.
            When ``aggr == 'no_tom'`` only the 2-tuple
            ``(left_beliefs, right_beliefs)`` is returned.

        Raises:
            ValueError: If ``self.aggr`` is not a supported aggregation mode.
        """
        assert images is not None
        assert poses is not None
        assert gazes is not None
        assert bboxes is not None

        # The full unpack doubles as a check that ``images`` is 5-D.
        batch_size, sequence_len, _, _, _ = images.shape

        bbox_feat = self.act(self.bbox_ff(bboxes))

        # Encode the sequence frame by frame and flatten each frame's features.
        rgb_feat = torch.stack(
            [self.cnn(images[:, i]).view(batch_size, -1) for i in range(sequence_len)],
            1,
        )
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))

        ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
                                                      self.ocr_edge_index,
                                                      self.ocr_edge_attr)))
        # Average over dim 0 of the GNN output, then tile to every batch item and time step.
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)

        out_left, cell_left, feats_left = self.mind_net_left(
            rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        out_right, cell_right, feats_right = self.mind_net_right(
            rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)

        common_mind = self.proj(torch.cat([cell_left, cell_right], -1))

        if self.aggr == 'no_tom':
            # Baseline without the shared representation; note the shorter return.
            return self.left(out_left), self.right(out_right)

        if self.aggr == 'attn':
            l = self.attn_left(x=out_left, context=common_mind)
            r = self.attn_right(x=out_right, context=common_mind)
        elif self.aggr == 'mult':
            l = out_left * common_mind
            r = out_right * common_mind
        elif self.aggr == 'sum':
            l = out_left + common_mind
            r = out_right + common_mind
        elif self.aggr == 'concat':
            # Appended along dim 1; the extra position is removed after the heads.
            l = torch.cat([out_left, common_mind], 1)
            r = torch.cat([out_right, common_mind], 1)
        else:
            raise ValueError(
                f"Unknown aggr {self.aggr!r}; expected one of "
                "'no_tom', 'attn', 'mult', 'sum', 'concat'"
            )

        l = self.ln_left(self.act(l))
        r = self.ln_right(self.act(r))

        left_beliefs = self.left(l)
        right_beliefs = self.right(r)
        if self.aggr == 'concat':
            # Drop the last position along dim 1 that the concatenation added.
            left_beliefs = left_beliefs[:, :-1, :]
            right_beliefs = right_beliefs[:, :-1, :]

        return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Smoke test: push all-ones dummy inputs through the XL model with
    # attention aggregation and print the shape of the first output.
    n_batch, n_frames = 3, 22

    images = torch.ones(n_batch, n_frames, 3, 128, 128)
    poses = torch.ones(n_batch, n_frames, 2, 75)
    gazes = torch.ones(n_batch, n_frames, 2, 3)
    bboxes = torch.ones(n_batch, n_frames, 108)

    model = CommonMindToMnetXL(
        hidden_dim=64,
        device='cpu',
        resnet=False,
        dropout=0.5,
        aggr='attn',
    )
    out = model(images, poses, gazes, bboxes, None)
    first_output = out[0]
    print(first_output.shape)
|