import torch
import torch.nn as nn

from .utils import pose_edge_index
from torch_geometric.nn import Sequential, GCNConv
from x_transformers import ContinuousTransformerWrapper, Decoder


class PreNorm(nn.Module):
    """Applies LayerNorm to the input before calling the wrapped module."""
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


class CNN(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv3(x)
        x = nn.functional.relu(x)
        # Global max pooling over the spatial dimensions -> (B, hidden_dim, 1, 1)
        x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:])
        return x


class MindNetLSTM(nn.Module):
    """
    Basic MindNet for model-based ToM: a bidirectional LSTM over the
    concatenation of the active modality features.
    """
    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim * len(mods),
            nn.LSTM(input_size=hidden_dim * len(mods),
                    hidden_size=hidden_dim,
                    batch_first=True,
                    bidirectional=True))
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            # GCN over the 25 skeleton joints; the edge index is shared
            # across the flattened batch and time dimensions (static graph).
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)  # mean over joints
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats)
        lstm_inp = torch.cat(feats, 2)
        lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
        # Average the final cell state over the two directions -> (B, 1, hidden_dim)
        c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
        return self.act(self.proj(lstm_out)), c_n, feats
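
# A minimal usage sketch for MindNetLSTM. The shapes and hidden_dim=64 below
# are illustrative assumptions, not values fixed by this module: rgb/ocr/bbox
# features are expected pre-embedded to hidden_dim elsewhere, pose as 25
# three-dimensional joints per frame, gaze as a 3-D vector per frame.
#
#   net = MindNetLSTM(hidden_dim=64, dropout=0.1,
#                     mods=['rgb', 'ocr', 'pose', 'gaze', 'bbox'])
#   rgb = torch.randn(2, 10, 64)       # (batch, seq_len, hidden_dim)
#   ocr = torch.randn(2, 10, 64)
#   pose = torch.randn(2, 10, 25, 3)   # (batch, seq_len, joints, xyz)
#   gaze = torch.randn(2, 10, 3)
#   bbox = torch.randn(2, 10, 64)
#   out, c_n, feats = net(rgb, ocr, pose, gaze, bbox)
#   # out: (2, 10, 64), c_n: (2, 1, 64), len(feats) == 5
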

class MindNetSL(nn.Module):
    """
    Basic MindNet for SL ToM: a bidirectional LSTM over the
    concatenation of the active modality features.
    """
    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        # One hidden_dim-sized feature per active modality is concatenated,
        # so the LSTM input is sized by len(mods).
        self.LSTM = PreNorm(
            hidden_dim * len(mods),
            nn.LSTM(input_size=hidden_dim * len(mods),
                    hidden_size=hidden_dim,
                    batch_first=True,
                    bidirectional=True))
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)  # mean over joints
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats)
        lstm_inp = torch.cat(feats, 2)
        lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
        return self.act(self.proj(lstm_out)), feats


class MindNetTF(nn.Module):
    """
    Basic MindNet for model-based ToM: a Transformer decoder over the
    concatenation of the active modality features.
    """
    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        self.tf = ContinuousTransformerWrapper(
            dim_in=hidden_dim * len(mods),
            dim_out=hidden_dim,
            max_seq_len=747,  # dataset-specific maximum sequence length
            attn_layers=Decoder(
                dim=512,  # internal attention width; the wrapper projects in/out
                depth=6,
                heads=8))
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)  # mean over joints
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats)
        tf_inp = torch.cat(feats, 2)
        tf_out = self.tf(self.dropout(tf_inp))
        return tf_out, feats


class MindNetLSTMXL(nn.Module):
    """
    Larger MindNet for model-based ToM: a two-layer gaze MLP, a two-layer
    pose GCN, and a two-layer bidirectional LSTM over the concatenated
    modality features.
    """
    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        self.gaze_emb = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim))
        self.pose_edge_index = pose_edge_index()
        # PyG Sequential: chain two GCNConv layers over the same edge index
        self.pose_emb = Sequential('x, edge_index', [
            (GCNConv(3, hidden_dim), 'x, edge_index -> x'),
            nn.GELU(),
            (GCNConv(hidden_dim, hidden_dim), 'x, edge_index -> x'),
        ])
        self.LSTM = PreNorm(
            hidden_dim * len(mods),
            nn.LSTM(input_size=hidden_dim * len(mods),
                    hidden_size=hidden_dim,
                    num_layers=2,
                    batch_first=True,
                    bidirectional=True))
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)  # mean over joints
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats)
        lstm_inp = torch.cat(feats, 2)
        lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
        # Average the final cell state over layers and directions -> (B, 1, hidden_dim)
        c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
        return self.act(self.proj(lstm_out)), c_n, feats
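
# A matching sketch for the Transformer and XL variants, under the same
# assumed shapes as the MindNetLSTM example above (hidden_dim=64 and the
# batch/sequence sizes are illustrative):
#
#   tf_net = MindNetTF(hidden_dim=64, dropout=0.1,
#                      mods=['rgb', 'ocr', 'pose', 'gaze', 'bbox'])
#   tf_out, feats = tf_net(rgb, ocr, pose, gaze, bbox)
#   # tf_out: (2, 10, 64); the wrapper projects 5*64 -> 512 (attention) -> 64
#
#   xl_net = MindNetLSTMXL(hidden_dim=64, dropout=0.1,
#                          mods=['rgb', 'ocr', 'pose', 'gaze', 'bbox'])
#   out, c_n, feats = xl_net(rgb, ocr, pose, gaze, bbox)
#   # out: (2, 10, 64); c_n averages the 2-layer x 2-direction cell states
#   # into (2, 1, 64)
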