mtomnet/tbd/models/base.py

156 lines
6 KiB
Python
Raw Permalink Normal View History

2025-01-10 15:39:20 +01:00
import torch
import torch.nn as nn
from .utils import pose_edge_index
from torch_geometric.nn import GCNConv
class PreNorm(nn.Module):
    """Normalise the input with LayerNorm, then hand it to the wrapped ``fn``.

    Args:
        dim: size of the last dimension, which ``nn.LayerNorm(dim)`` normalises.
        fn: callable applied to the normalised tensor; any keyword arguments
            passed to ``forward`` are forwarded to it unchanged.
    """

    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        # ``fn`` is registered before ``norm``; this order fixes the order of
        # parameters()/state_dict entries, so keep it.
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP on the last dimension: Linear(dim, dim) -> GELU -> Linear(dim, dim)."""

    def __init__(self, dim):
        super(FeedForward, self).__init__()
        # Layers are built in order, so the weight-init RNG sequence and the
        # ``net.<idx>`` state_dict keys stay the same.
        layers = [nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
class CNN(nn.Module):
    """Small ConvNet that maps a 3-channel image to a ``hidden_dim`` feature map.

    Pipeline: conv(3->16) -> ReLU -> 2x2 max-pool -> conv(16->32) -> ReLU ->
    2x2 max-pool -> conv(32->hidden_dim) -> ReLU -> global max-pool.
    For a batched (N, 3, H, W) input the result is (N, hidden_dim, 1, 1).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Registration order (conv1, pool, conv2, conv3) determines parameter
        # order and weight-init RNG consumption; keep it stable.
        self.conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, hidden_dim, 3, stride=1, padding=1)

    def forward(self, x):
        relu = nn.functional.relu
        # Two conv -> ReLU -> halve-resolution stages.
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = relu(self.conv3(x))
        # Global max pooling: one window spanning every trailing dim after dim 1.
        return nn.functional.max_pool2d(x, kernel_size=x.shape[2:])
class MindNetLSTM(nn.Module):
    """
    Basic MindNet for model-based ToM, just LSTM on input concatenation.

    Each enabled modality is turned into a per-time-step feature tensor. The
    tensors are concatenated on the feature axis, layer-normalised and fed to
    a single-layer bidirectional LSTM, whose outputs are projected back to
    ``hidden_dim``.

    Args:
        hidden_dim: embedding width of each modality and LSTM hidden size
            per direction.
        dropout: dropout probability used on the modality embeddings and on
            the concatenated LSTM input.
        mods: enabled modality names. ``forward`` recognises 'rgb_3', 'rgb_1',
            'pose', 'gaze' and 'bbox'. The LSTM input width is
            ``hidden_dim * len(mods)``, so each entry is expected to be one of
            those names, listed once, and to contribute ``hidden_dim`` features.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super(MindNetLSTM, self).__init__()
        self.mods = mods
        # Registration order below determines parameters()/state_dict order and
        # weight-init RNG consumption; keep it stable for checkpoint and
        # optimizer-state compatibility.
        if 'rgb_1' in mods:
            self.img_emb = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        if 'gaze' in mods:
            self.gaze_emb = nn.Linear(2, hidden_dim)
        if 'pose' in mods:
            # Stored as a plain attribute, not a registered buffer, so
            # Module.to() does not move it; _encode_pose moves it to the input's
            # device on each call.
            self.pose_edge_index = pose_edge_index()
            self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim * len(mods),
            nn.LSTM(input_size=hidden_dim * len(mods), hidden_size=hidden_dim,
                    batch_first=True, bidirectional=True))
        # The bidirectional LSTM emits 2 * hidden_dim features per step.
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
        """
        Fuse the enabled modalities with the BiLSTM.

        Inputs of disabled modalities are never touched, so they may be None.
        The shapes below follow from the indexing used here and are not
        validated.

        Args:
            rgb_3rd_pov_feats: precomputed features, presumably
                (batch, seq, hidden_dim) — TODO confirm against the caller.
            bbox_feats: 4-D tensor whose dim 2 is averaged away, presumably
                (batch, seq, n_boxes, hidden_dim) — TODO confirm.
            rgb_1st_pov: first-person frames, (batch, seq, 3, H, W).
            pose: keypoints reshapeable to (batch * seq, n_joints, 3).
            gaze: (batch, seq, 2).

        Returns:
            Tuple ``(out, c_n, feats)``:
              out: GELU(proj(lstm_out)), shape (batch, seq, hidden_dim).
              c_n: final LSTM cell state averaged over the two directions,
                  shape (batch, 1, hidden_dim).
              feats: the per-modality tensors that were concatenated, in the
                  order rgb_3, rgb_1, pose, gaze, bbox (enabled ones only).
        """
        feats = self._encode_modalities(rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze)
        lstm_inp = torch.cat(feats, 2)
        lstm_out, (_, c_n) = self.LSTM(self.dropout(lstm_inp))
        # c_n: (2 directions, batch, hidden) -> (1, batch, hidden) -> (batch, 1, hidden)
        c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
        return self.act(self.proj(lstm_out)), c_n, feats

    def _encode_modalities(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
        """Return per-modality features in the fixed order rgb_3, rgb_1, pose, gaze, bbox."""
        # The concatenation order is fixed here, regardless of the order of
        # names in self.mods.
        feats = []
        if 'rgb_3' in self.mods:
            feats.append(rgb_3rd_pov_feats)
        if 'rgb_1' in self.mods:
            feats.append(self._encode_rgb_1st_pov(rgb_1st_pov))
        if 'pose' in self.mods:
            feats.append(self._encode_pose(pose))
        if 'gaze' in self.mods:
            feats.append(self.dropout(self.act(self.gaze_emb(gaze))))
        if 'bbox' in self.mods:
            feats.append(bbox_feats.mean(2))
        return feats

    def _encode_rgb_1st_pov(self, rgb_1st_pov):
        """Embed first-person frames with the CNN one time step at a time, then project."""
        batch_size = rgb_1st_pov.shape[0]
        frame_feats = [
            self.img_emb(rgb_1st_pov[:, t]).view(batch_size, -1)
            for t in range(rgb_1st_pov.shape[1])
        ]
        rgb_feat = torch.stack(frame_feats, 1)  # (batch, seq, hidden_dim)
        return self.dropout(self.act(self.rgb_ff(rgb_feat)))

    def _encode_pose(self, pose):
        """Embed each (sample, time step) skeleton with the GCN and average over joints."""
        bs, seq_len = pose.size(0), pose.size(1)
        self.pose_edge_index = self.pose_edge_index.to(pose.device)
        # reshape (unlike view) also accepts non-contiguous input. The joint count
        # is inferred (the expected layout has 26 joints); the edge index from
        # pose_edge_index() must refer to the same joints.
        joints = pose.reshape(bs * seq_len, -1, 3)
        pose_emb = self.dropout(self.act(self.pose_emb(joints, self.pose_edge_index)))
        pose_emb = pose_emb.mean(dim=1)  # (bs * seq_len, hidden_dim)
        return pose_emb.view(bs, seq_len, pose_emb.size(-1))
class MindNetSL(nn.Module):
    """
    Basic MindNet for SL ToM, just LSTM on input concatenation.

    Each enabled modality is turned into a per-time-step feature tensor. The
    tensors are concatenated on the feature axis, layer-normalised and fed to
    a single-layer bidirectional LSTM, whose outputs are projected back to
    ``hidden_dim``.

    Args:
        hidden_dim: embedding width of each modality and LSTM hidden size
            per direction.
        dropout: dropout probability used on the modality embeddings and on
            the concatenated LSTM input.
        mods: enabled modality names. ``forward`` recognises 'rgb_3', 'rgb_1',
            'pose', 'gaze' and 'bbox'. The LSTM input width is
            ``hidden_dim * len(mods)``, so each entry is expected to be one of
            those names, listed once, and to contribute ``hidden_dim`` features.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super(MindNetSL, self).__init__()
        self.mods = mods
        # Registration order below determines parameters()/state_dict order and
        # weight-init RNG consumption; keep it stable for checkpoint and
        # optimizer-state compatibility.
        if 'rgb_1' in mods:
            self.img_emb = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        if 'gaze' in mods:
            self.gaze_emb = nn.Linear(2, hidden_dim)
        if 'pose' in mods:
            # Stored as a plain attribute, not a registered buffer, so
            # Module.to() does not move it; _encode_pose moves it to the input's
            # device on each call.
            self.pose_edge_index = pose_edge_index()
            self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim * len(mods),
            nn.LSTM(input_size=hidden_dim * len(mods), hidden_size=hidden_dim,
                    batch_first=True, bidirectional=True))
        # The bidirectional LSTM emits 2 * hidden_dim features per step.
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
        """
        Fuse the enabled modalities with the BiLSTM.

        Inputs of disabled modalities are never touched, so they may be None.
        The shapes below follow from the indexing used here and are not
        validated.

        Args:
            rgb_3rd_pov_feats: precomputed features, presumably
                (batch, seq, hidden_dim) — TODO confirm against the caller.
            bbox_feats: 4-D tensor whose dim 2 is averaged away, presumably
                (batch, seq, n_boxes, hidden_dim) — TODO confirm.
            rgb_1st_pov: first-person frames, (batch, seq, 3, H, W).
            pose: keypoints reshapeable to (batch * seq, n_joints, 3).
            gaze: (batch, seq, 2).

        Returns:
            Tuple ``(out, feats)``:
              out: GELU(proj(lstm_out)), shape (batch, seq, hidden_dim).
              feats: the per-modality tensors that were concatenated, in the
                  order rgb_3, rgb_1, pose, gaze, bbox (enabled ones only).
        """
        feats = self._encode_modalities(rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze)
        lstm_inp = torch.cat(feats, 2)
        lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
        return self.act(self.proj(lstm_out)), feats

    def _encode_modalities(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
        """Return per-modality features in the fixed order rgb_3, rgb_1, pose, gaze, bbox."""
        # The concatenation order is fixed here, regardless of the order of
        # names in self.mods.
        feats = []
        if 'rgb_3' in self.mods:
            feats.append(rgb_3rd_pov_feats)
        if 'rgb_1' in self.mods:
            feats.append(self._encode_rgb_1st_pov(rgb_1st_pov))
        if 'pose' in self.mods:
            feats.append(self._encode_pose(pose))
        if 'gaze' in self.mods:
            feats.append(self.dropout(self.act(self.gaze_emb(gaze))))
        if 'bbox' in self.mods:
            feats.append(bbox_feats.mean(2))
        return feats

    def _encode_rgb_1st_pov(self, rgb_1st_pov):
        """Embed first-person frames with the CNN one time step at a time, then project."""
        batch_size = rgb_1st_pov.shape[0]
        frame_feats = [
            self.img_emb(rgb_1st_pov[:, t]).view(batch_size, -1)
            for t in range(rgb_1st_pov.shape[1])
        ]
        rgb_feat = torch.stack(frame_feats, 1)  # (batch, seq, hidden_dim)
        return self.dropout(self.act(self.rgb_ff(rgb_feat)))

    def _encode_pose(self, pose):
        """Embed each (sample, time step) skeleton with the GCN and average over joints."""
        bs, seq_len = pose.size(0), pose.size(1)
        self.pose_edge_index = self.pose_edge_index.to(pose.device)
        # reshape (unlike view) also accepts non-contiguous input. The joint count
        # is inferred (the expected layout has 26 joints); the edge index from
        # pose_edge_index() must refer to the same joints.
        joints = pose.reshape(bs * seq_len, -1, 3)
        pose_emb = self.dropout(self.act(self.pose_emb(joints, self.pose_edge_index)))
        pose_emb = pose_emb.mean(dim=1)  # (bs * seq_len, hidden_dim)
        return pose_emb.view(bs, seq_len, pose_emb.size(-1))