mtomnet/boss/models/base.py

230 lines
8.6 KiB
Python
Raw Normal View History

2025-01-10 15:39:20 +01:00
import torch
import torch.nn as nn
from .utils import pose_edge_index
from torch_geometric.nn import Sequential, GCNConv
from x_transformers import ContinuousTransformerWrapper, Decoder
class PreNorm(nn.Module):
    """
    Apply LayerNorm over the last ``dim`` features, then call ``fn``.

    Args:
        dim: size of the feature axis that ``nn.LayerNorm`` normalises.
        fn: callable (typically an ``nn.Module``) invoked on the normalised
            input; whatever it returns is returned unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        # Extra keyword arguments are forwarded to the wrapped callable as-is.
        return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
    """
    Two-layer MLP (Linear -> GELU -> Linear) that keeps the feature width.

    Args:
        dim: input, hidden and output feature size of both linear layers.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
class CNN(nn.Module):
    """
    Small three-layer convolutional encoder with global max pooling.

    The first convolution expects 3 input channels; the last one produces
    ``hidden_dim`` channels. The two spatial dimensions are collapsed to
    1x1 by a max pool whose kernel spans the whole feature map.

    Args:
        hidden_dim: number of output channels of the final convolution.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        relu = nn.functional.relu
        # Two conv -> ReLU -> 2x2 max-pool stages, each halving the resolution.
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        # Final conv -> ReLU, no local pooling.
        x = relu(self.conv3(x))
        # Global max pooling: the kernel covers everything after dim 1.
        spatial_size = x.shape[2:]
        return nn.functional.max_pool2d(x, kernel_size=spatial_size)
class MindNetLSTM(nn.Module):
    """
    Basic MindNet for model-based ToM: a bidirectional LSTM over the
    concatenation of the enabled modality features.

    Args:
        hidden_dim: embedding size of the pose/gaze encoders and LSTM hidden size.
        dropout: dropout probability used on embeddings and on the LSTM input.
        mods: collection of enabled modality names; the ones read here are
            'rgb', 'ocr', 'pose', 'gaze' and 'bbox'.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        # Width of the concatenated input: one hidden_dim slice per modality.
        fused_dim = hidden_dim * len(mods)
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        recurrent = nn.LSTM(
            input_size=fused_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.LSTM = PreNorm(fused_dim, recurrent)
        # The bidirectional output is 2 * hidden_dim wide; project it back.
        self.proj = nn.Linear(2 * hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def _embed_pose(self, pose):
        """Run the GCN on every frame and average the 25 node embeddings."""
        batch, steps = pose.size(0), pose.size(1)
        # Keep the edge index on the same device as the incoming poses.
        self.pose_edge_index = self.pose_edge_index.to(pose.device)
        nodes = pose.view(batch * steps, 25, 3)
        node_emb = self.pose_emb(nodes, self.pose_edge_index)
        node_emb = self.dropout(self.act(node_emb))
        frame_emb = torch.mean(node_emb, dim=1)
        return frame_emb.view(batch, steps, frame_emb.size(-1))

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        """
        Fuse the enabled modalities and run them through the LSTM.

        rgb_feats, ocr_feats and bbox_feats are used as given; they are
        assumed to be (batch, seq_len, hidden_dim) so that the concatenation
        matches the LSTM input width — TODO confirm against the caller.

        Returns:
            Tuple of (GELU-activated projection of the LSTM outputs,
            cell state averaged over its first axis and permuted to
            batch-first, list of per-modality features that were fused).
        """
        enabled = self.mods
        feats = []
        if 'rgb' in enabled:
            feats.append(rgb_feats)
        if 'ocr' in enabled:
            feats.append(ocr_feats)
        if 'pose' in enabled:
            feats.append(self._embed_pose(pose))
        if 'gaze' in enabled:
            feats.append(self.dropout(self.act(self.gaze_emb(gaze))))
        if 'bbox' in enabled:
            feats.append(bbox_feats)

        fused = torch.cat(feats, 2)
        lstm_out, (_, cell) = self.LSTM(self.dropout(fused))
        # Average the final cell states of all directions, then go batch-first.
        cell = cell.mean(0, keepdim=True).permute(1, 0, 2)
        hidden = self.act(self.proj(lstm_out))
        return hidden, cell, feats
class MindNetSL(nn.Module):
    """
    Basic MindNet for SL ToM: a bidirectional LSTM over the concatenation
    of the enabled modality features.

    Args:
        hidden_dim: embedding size of the pose/gaze encoders and LSTM hidden size.
        dropout: dropout probability used on embeddings and on the LSTM input.
        mods: collection of enabled modality names; the ones read here are
            'rgb', 'ocr', 'pose', 'gaze' and 'bbox'.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super(MindNetSL, self).__init__()
        self.mods = mods
        # forward() concatenates one feature block per enabled modality, so
        # the LayerNorm/LSTM input width must follow len(mods). It used to be
        # hard-coded to hidden_dim*5, which only worked with five modalities.
        # With five modalities the layer shapes are unchanged.
        input_dim = hidden_dim * len(mods)
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            input_dim,
            nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
        )
        # The bidirectional output is 2 * hidden_dim wide; project it back.
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        """
        Fuse the enabled modalities and run them through the LSTM.

        rgb_feats, ocr_feats and bbox_feats are used as given; they are
        assumed to be (batch, seq_len, hidden_dim) so that the concatenation
        matches the LSTM input width — TODO confirm against the caller.

        Returns:
            Tuple of (GELU-activated projection of the LSTM outputs,
            list of per-modality features that were fused).
        """
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            # Keep the edge index on the same device as the incoming poses.
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            # Each frame is treated as a graph of 25 nodes with 3 features.
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            # Average over the node axis to get one embedding per frame.
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats)
        lstm_inp = torch.cat(feats, 2)
        # The final hidden/cell states are not used by this variant.
        lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
        return self.act(self.proj(lstm_out)), feats
class MindNetTF(nn.Module):
    """
    Basic MindNet for model-based ToM: a Transformer decoder stack over the
    concatenation of the enabled modality features.

    Args:
        hidden_dim: embedding size of the pose/gaze encoders and output
            width of the transformer wrapper.
        dropout: dropout probability used on embeddings and on the
            transformer input.
        mods: collection of enabled modality names; the ones read here are
            'rgb', 'ocr', 'pose', 'gaze' and 'bbox'.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        # The attention stack width (512) is fixed and independent of hidden_dim.
        attn_stack = Decoder(dim=512, depth=6, heads=8)
        # NOTE(review): max_seq_len is hard-coded to 747 — presumably the
        # longest sequence in the dataset; confirm before reusing elsewhere.
        self.tf = ContinuousTransformerWrapper(
            dim_in=hidden_dim * len(mods),
            dim_out=hidden_dim,
            max_seq_len=747,
            attn_layers=attn_stack,
        )
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def _embed_pose(self, pose):
        """Run the GCN on every frame and average the 25 node embeddings."""
        batch, steps = pose.size(0), pose.size(1)
        # Keep the edge index on the same device as the incoming poses.
        self.pose_edge_index = self.pose_edge_index.to(pose.device)
        per_node = self.pose_emb(pose.view(batch * steps, 25, 3), self.pose_edge_index)
        per_node = self.dropout(self.act(per_node))
        per_frame = torch.mean(per_node, dim=1)
        return per_frame.view(batch, steps, per_frame.size(-1))

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        """
        Fuse the enabled modalities and run them through the transformer.

        rgb_feats, ocr_feats and bbox_feats are used as given; they are
        assumed to be (batch, seq_len, hidden_dim) so that the concatenation
        matches the transformer input width — TODO confirm against the caller.

        Returns:
            Tuple of (transformer output, list of per-modality features
            that were fused).
        """
        active = self.mods
        feats = []
        if 'rgb' in active:
            feats.append(rgb_feats)
        if 'ocr' in active:
            feats.append(ocr_feats)
        if 'pose' in active:
            feats.append(self._embed_pose(pose))
        if 'gaze' in active:
            feats.append(self.dropout(self.act(self.gaze_emb(gaze))))
        if 'bbox' in active:
            feats.append(bbox_feats)

        fused = self.dropout(torch.cat(feats, 2))
        return self.tf(fused), feats
class MindNetLSTMXL(nn.Module):
    """
    Larger MindNet for model-based ToM: a two-layer bidirectional LSTM over
    the concatenation of the enabled modality features, with a two-layer
    MLP gaze encoder and a two-layer GCN pose encoder.

    Args:
        hidden_dim: embedding size of the pose/gaze encoders and LSTM hidden size.
        dropout: dropout probability used on embeddings and on the LSTM input.
        mods: collection of enabled modality names; the ones read here are
            'rgb', 'ocr', 'pose', 'gaze' and 'bbox'.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        # Width of the concatenated input: one hidden_dim slice per modality.
        fused_dim = hidden_dim * len(mods)
        self.gaze_emb = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.pose_edge_index = pose_edge_index()
        gcn_stack = [
            (GCNConv(3, hidden_dim), 'x, edge_index -> x'),
            nn.GELU(),
            (GCNConv(hidden_dim, hidden_dim), 'x, edge_index -> x'),
        ]
        self.pose_emb = Sequential('x, edge_index', gcn_stack)
        recurrent = nn.LSTM(
            input_size=fused_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.LSTM = PreNorm(fused_dim, recurrent)
        # The bidirectional output is 2 * hidden_dim wide; project it back.
        self.proj = nn.Linear(2 * hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def _embed_pose(self, pose):
        """Run the GCN stack on every frame and average the 25 node embeddings."""
        batch, steps = pose.size(0), pose.size(1)
        # Keep the edge index on the same device as the incoming poses.
        self.pose_edge_index = self.pose_edge_index.to(pose.device)
        graph_nodes = pose.view(batch * steps, 25, 3)
        node_emb = self.pose_emb(graph_nodes, self.pose_edge_index)
        node_emb = self.dropout(self.act(node_emb))
        pooled = torch.mean(node_emb, dim=1)
        return pooled.view(batch, steps, pooled.size(-1))

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        """
        Fuse the enabled modalities and run them through the LSTM.

        rgb_feats, ocr_feats and bbox_feats are used as given; they are
        assumed to be (batch, seq_len, hidden_dim) so that the concatenation
        matches the LSTM input width — TODO confirm against the caller.

        Returns:
            Tuple of (GELU-activated projection of the LSTM outputs,
            cell state averaged over its first axis and permuted to
            batch-first, list of per-modality features that were fused).
        """
        selected = self.mods
        feats = []
        if 'rgb' in selected:
            feats.append(rgb_feats)
        if 'ocr' in selected:
            feats.append(ocr_feats)
        if 'pose' in selected:
            feats.append(self._embed_pose(pose))
        if 'gaze' in selected:
            feats.append(self.dropout(self.act(self.gaze_emb(gaze))))
        if 'bbox' in selected:
            feats.append(bbox_feats)

        fused = torch.cat(feats, 2)
        lstm_out, (_, cell) = self.LSTM(self.dropout(fused))
        # Average the final cell states of all layers and directions,
        # then move the batch axis first.
        cell = cell.mean(0, keepdim=True).permute(1, 0, 2)
        projected = self.act(self.proj(lstm_out))
        return projected, cell, feats