import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention


class CommonMindToMnet(nn.Module):
    """
    Two mind networks (one per agent) whose final LSTM cell states are fused
    into a common-mind representation, which is then aggregated back into
    each agent's features.

    Per-frame input shapes (forward() takes sequences, i.e. an extra
    sequence dimension after the batch):
        img:  bs, 3, 128, 128
        pose: bs, 26, 3
        gaze: bs, 2   NOTE: only the tracker has gaze
        bbox: bs, 4
    """
    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum',
                 mods=['rgb_1', 'rgb_3', 'pose', 'gaze', 'bbox']):
        super(CommonMindToMnet, self).__init__()

        self.aggr = aggr
        self.mods = mods

        # ---- 3rd POV images, object and bbox ----#
        if resnet:
            # ImageNet-pretrained ResNet-34 backbone, minus its final FC layer.
            backbone = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(*list(backbone.children())[:-1])
            #for param in self.cnn.parameters():
            #    param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        self.bbox_ff = nn.Linear(4, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
        # The second agent (battery) has no gaze modality.
        self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
        if aggr != 'no_tom':
            self.cm_proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            self.attn_left = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
            self.attn_right = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
        # Output heads (4 classes each): each agent's own mind (m1, m2),
        # cross-agent beliefs (m12, m21) and the common mind (mc).
        self.m1 = nn.Linear(hidden_dim, 4)
        self.m2 = nn.Linear(hidden_dim, 4)
        self.m12 = nn.Linear(hidden_dim, 4)
        self.m21 = nn.Linear(hidden_dim, 4)
        self.mc = nn.Linear(hidden_dim, 4)

    def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):

        batch_size, sequence_len, channels, height, width = img_3rd_pov.shape

        if 'bbox' in self.mods:
            bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
        else:
            bbox_feat = None

        if 'rgb_3' in self.mods:
            # Encode each frame of the 3rd-person view independently, then
            # stack along the sequence dimension: (bs, seq, hidden_dim).
            rgb_feat = []
            for i in range(sequence_len):
                images_i = img_3rd_pov[:, i]
                img_i_feat = self.cnn(images_i)
                img_i_feat = img_i_feat.view(batch_size, -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
        else:
            rgb_feat_3rd_pov = None

        # Route the poses so that mind_net_1 always processes the tracker's skeleton.
        if tracker_id == 'skele1':
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
        else:
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)

        if self.aggr == 'no_tom':
            m1 = self.m1(out_1).mean(1)
            m2 = self.m2(out_2).mean(1)
            m12 = self.m12(out_1).mean(1)
            m21 = self.m21(out_2).mean(1)
            # NOTE: with no_tom, mc is computed from the elementwise product of out_1 and out_2
            mc = self.mc(out_1 * out_2).mean(1)
            return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2

        # Fuse the two LSTM cell states into the common-mind representation.
        common_mind = self.cm_proj(torch.cat([cell_1, cell_2], -1))  # (bs, 1, h)

        if self.aggr == 'attn':
            p1 = self.attn_left(x=out_1, context=common_mind)
            p2 = self.attn_right(x=out_2, context=common_mind)
        elif self.aggr == 'mult':
            p1 = out_1 * common_mind
            p2 = out_2 * common_mind
        elif self.aggr == 'sum':
            p1 = out_1 + common_mind
            p2 = out_2 + common_mind
        elif self.aggr == 'concat':
            # Concatenate along the sequence dimension.
            p1 = torch.cat([out_1, common_mind], 1)
            p2 = torch.cat([out_2, common_mind], 1)
        else:
            raise ValueError(f"Unknown aggregation mode: {self.aggr}")
        p1 = self.act(p1)
        p1 = self.ln_1(p1)
        p2 = self.act(p2)
        p2 = self.ln_2(p2)

        # Same readout for every remaining aggregation mode.
        m1 = self.m1(p1).mean(1)
        m2 = self.m2(p2).mean(1)
        m12 = self.m12(p1).mean(1)
        m21 = self.m21(p2).mean(1)
        mc = self.mc(p1 * p2).mean(1)  # NOTE: here p1 and p2 are multiplied

        return m1, m2, m12, m21, mc, [out_1, out_2, common_mind] + feats_1 + feats_2


if __name__ == "__main__":

    img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
    img_tracker = torch.ones(3, 5, 3, 128, 128)
    img_battery = torch.ones(3, 5, 3, 128, 128)
    pose1 = torch.ones(3, 5, 26, 3)
    pose2 = torch.ones(3, 5, 26, 3)
    bbox = torch.ones(3, 5, 13, 4)
    tracker_id = 'skele1'
    gaze = torch.ones(3, 5, 2)
    mods = ['pose', 'bbox', 'rgb_3']

    for agg in ['no_tom']:
        model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
        out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
        print(out[0].shape)
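
    # Optional broader smoke test: a minimal sketch exercising the remaining
    # aggregation modes on the same dummy tensors, assuming they all accept
    # these inputs; 'attn' also runs the memory-efficient cross attention.
    for agg in ['sum', 'mult', 'concat', 'attn']:
        model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
        out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
        print(agg, out[0].shape)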