import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv

from memory_efficient_attention_pytorch import Attention

from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM


class ImplicitToMnet(nn.Module):
    """
    Implicit ToM net. Supports any subset of the modalities
    ('rgb', 'pose', 'gaze', 'ocr', 'bbox').
    Possible aggregations: sum, mult, attn, concat
    (forward() also accepts 'no_tom' to skip the exchange between mind nets).
    """
    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=None):
        super(ImplicitToMnet, self).__init__()
        if mods is None:  # avoid a mutable default argument
            mods = ['rgb', 'pose', 'gaze', 'ocr', 'bbox']
        self.aggr = aggr

        # ---- Images ----#
        if resnet:
            # Frozen ImageNet-pretrained ResNet-34 backbone, classifier head removed
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(*list(resnet.children())[:-1])
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)

        # ---- OCR and bbox ----#
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        self.ocr_gnn = GCNConv(-1, hidden_dim)  # -1: input size inferred lazily
        self.bbox_ff = nn.Linear(108, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
        self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
        self.ln_left = nn.LayerNorm(hidden_dim)
        self.ln_right = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            # Cross-attention between the two agents' latent states
            self.attn_left = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
            self.attn_right = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
        self.left = nn.Linear(hidden_dim, 27)
        self.right = nn.Linear(hidden_dim, 27)
        # Initialize the output biases with the log of precomputed label priors
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
        assert images is not None
        assert poses is not None
        assert gazes is not None
        assert bboxes is not None

        batch_size, sequence_len, channels, height, width = images.shape

        bbox_feat = self.act(self.bbox_ff(bboxes))

        # Per-frame CNN features, stacked back into a sequence
        rgb_feat = []
        for i in range(sequence_len):
            images_i = images[:, i]
            img_i_feat = self.cnn(images_i)
            img_i_feat = img_i_feat.view(batch_size, -1)
            rgb_feat.append(img_i_feat)
        rgb_feat = torch.stack(rgb_feat, 1)
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))

        # OCR graph features: edge attributes are passed as GCN edge weights;
        # node features are mean-pooled, then broadcast over batch and sequence
        ocr_feat = self.dropout(self.act(self.ocr_gnn(
            self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr)))
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)

        out_left, cell_left, feats_left = self.mind_net_left(
            rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        out_right, cell_right, feats_right = self.mind_net_right(
            rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)

        if self.aggr == 'no_tom':
            return self.left(out_left), self.right(out_right), \
                [out_left, cell_left, out_right, cell_right] + feats_left + feats_right

        # Exchange information between the two mind nets: each agent's output
        # is combined with the other agent's LSTM cell state
        if self.aggr == 'attn':
            l = self.attn_left(x=out_left, context=cell_right)
            r = self.attn_right(x=out_right, context=cell_left)
        elif self.aggr == 'mult':
            l = out_left * cell_right
            r = out_right * cell_left
        elif self.aggr == 'sum':
            l = out_left + cell_right
            r = out_right + cell_left
        elif self.aggr == 'concat':
            l = torch.cat([out_left, cell_right], 1)
            r = torch.cat([out_right, cell_left], 1)
        else:
            raise ValueError(f'Unknown aggregation: {self.aggr}')
        l = self.act(l)
        l = self.ln_left(l)
        r = self.act(r)
        r = self.ln_right(r)

        if self.aggr in ('mult', 'sum', 'attn'):
            left_beliefs = self.left(l)
            right_beliefs = self.right(r)
        elif self.aggr == 'concat':
            # Drop the extra time step appended by concatenation along dim 1
            left_beliefs = self.left(l)[:, :-1, :]
            right_beliefs = self.right(r)[:, :-1, :]

        return left_beliefs, right_beliefs, \
            [out_left, cell_left, out_right, cell_right] + feats_left + feats_right


if __name__ == "__main__":
    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 2, 75)
    gazes = torch.ones(3, 22, 2, 3)
    bboxes = torch.ones(3, 22, 108)
    model = ImplicitToMnet(64, 'cpu', False, 0.5, aggr='attn')
    out = model(images, poses, gazes, bboxes, None)
    print(out[0].shape)
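
    # --- Usage sketch (added for illustration, not part of the original file):
    # exercise the remaining aggregation strategies with the same dummy tensors.
    # 'no_tom' returns per-agent beliefs without exchanging information between
    # the mind nets; 'concat' is trimmed back to the input sequence length
    # inside forward(). Assuming the mind-net cell states broadcast over the
    # sequence dimension, each output should have shape (3, 22, 27).
    for aggr in ('sum', 'mult', 'concat', 'no_tom'):
        model = ImplicitToMnet(64, 'cpu', False, 0.5, aggr=aggr)
        left_beliefs, right_beliefs, extras = model(images, poses, gazes, bboxes, None)
        print(aggr, left_beliefs.shape, right_beliefs.shape)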