import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

from .utils import pose_edge_index


class PreNorm(nn.Module):
    """Applies LayerNorm to the input before calling the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP with a GELU in between; preserves the feature dimension."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


class CNN(nn.Module):
    """Small convolutional encoder mapping an RGB image to a hidden_dim vector."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv3(x)
        x = nn.functional.relu(x)
        # Global max pooling: collapse the remaining spatial grid to 1x1
        x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:])
        return x


class MindNetLSTM(nn.Module):
    """
    Basic MindNet for model-based ToM: a bidirectional LSTM over the
    concatenation of the per-modality input embeddings.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        if 'rgb_1' in mods:
            self.img_emb = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        if 'gaze' in mods:
            self.gaze_emb = nn.Linear(2, hidden_dim)
        if 'pose' in mods:
            self.pose_edge_index = pose_edge_index()
            self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim*len(mods),
            nn.LSTM(input_size=hidden_dim*len(mods),
                    hidden_size=hidden_dim,
                    batch_first=True,
                    bidirectional=True))
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
        feats = []
        if 'rgb_3' in self.mods:
            feats.append(rgb_3rd_pov_feats)
        if 'rgb_1' in self.mods:
            # Embed each first-person frame independently, then stack over time
            rgb_feat = []
            for i in range(rgb_1st_pov.shape[1]):
                images_i = rgb_1st_pov[:, i]
                img_i_feat = self.img_emb(images_i)
                img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
            feats.append(rgb_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            # Graph convolution over the 26 skeleton joints, then mean-pool the joints
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            # Average object-level bbox features over the object dimension
            feats.append(bbox_feats.mean(2))
        lstm_inp = torch.cat(feats, 2)
        lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
        # Average the cell state over the two directions: (2, bs, h) -> (bs, 1, h)
        c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
        return self.act(self.proj(lstm_out)), c_n, feats
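# A minimal, commented usage sketch for MindNetLSTM. All shapes and values here
# are illustrative assumptions (batch 2, 4 timesteps, 64x64 egocentric frames,
# 5 object boxes), not prescribed by the repo; `pose_edge_index()` is resolved
# internally from .utils.
#
#   mods = ['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']
#   net = MindNetLSTM(hidden_dim=64, dropout=0.1, mods=mods)
#   out, c_n, feats = net(
#       rgb_3rd_pov_feats=torch.randn(2, 4, 64),   # precomputed 3rd-person features
#       bbox_feats=torch.randn(2, 4, 5, 64),       # per-frame object box features
#       rgb_1st_pov=torch.randn(2, 4, 3, 64, 64),  # raw 1st-person frames
#       pose=torch.randn(2, 4, 26, 3),             # 26 joints in 3-D
#       gaze=torch.randn(2, 4, 2))                 # 2-D gaze coordinates
#   # out: (2, 4, 64), c_n: (2, 1, 64), len(feats) == len(mods)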
class MindNetSL(nn.Module):
    """
    Basic MindNet for SL ToM: same per-modality encoders as MindNetLSTM,
    with a bidirectional LSTM over the concatenated embeddings.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        if 'rgb_1' in mods:
            self.img_emb = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        if 'gaze' in mods:
            self.gaze_emb = nn.Linear(2, hidden_dim)
        if 'pose' in mods:
            self.pose_edge_index = pose_edge_index()
            self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim*len(mods),
            nn.LSTM(input_size=hidden_dim*len(mods),
                    hidden_size=hidden_dim,
                    batch_first=True,
                    bidirectional=True))
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
        feats = []
        if 'rgb_3' in self.mods:
            feats.append(rgb_3rd_pov_feats)
        if 'rgb_1' in self.mods:
            # Embed each first-person frame independently, then stack over time
            rgb_feat = []
            for i in range(rgb_1st_pov.shape[1]):
                images_i = rgb_1st_pov[:, i]
                img_i_feat = self.img_emb(images_i)
                img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
            feats.append(rgb_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            # Graph convolution over the 26 skeleton joints, then mean-pool the joints
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            # Average object-level bbox features over the object dimension
            feats.append(bbox_feats.mean(2))
        lstm_inp = torch.cat(feats, 2)
        lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
        return self.act(self.proj(lstm_out)), feats
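# A minimal smoke test for MindNetSL under the same illustrative assumptions as
# the MindNetLSTM sketch above. Because of the relative import at the top of
# the file, run it as a module (python -m <package>.<this_module>) rather than
# as a script.
if __name__ == "__main__":
    mods = ['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']
    net = MindNetSL(hidden_dim=64, dropout=0.1, mods=mods)
    out, feats = net(
        rgb_3rd_pov_feats=torch.randn(2, 4, 64),   # precomputed 3rd-person features
        bbox_feats=torch.randn(2, 4, 5, 64),       # per-frame object box features
        rgb_1st_pov=torch.randn(2, 4, 3, 64, 64),  # raw 1st-person frames
        pose=torch.randn(2, 4, 26, 3),             # 26 joints in 3-D
        gaze=torch.randn(2, 4, 2))                 # 2-D gaze coordinates
    assert out.shape == (2, 4, 64) and len(feats) == len(mods)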