230 lines
8.6 KiB
Python
230 lines
8.6 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from .utils import pose_edge_index
|
||
|
from torch_geometric.nn import Sequential, GCNConv
|
||
|
from x_transformers import ContinuousTransformerWrapper, Decoder
|
||
|
|
||
|
|
||
|
class PreNorm(nn.Module):
    """Apply LayerNorm over the last ``dim`` features, then call ``fn``.

    Args:
        dim: size of the trailing feature axis handed to ``nn.LayerNorm``.
        fn: callable (typically an ``nn.Module``) invoked on the normalised
            tensor; any keyword arguments given to ``forward`` are passed
            straight through to it.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Keep the registration order (fn, then norm) so the state_dict
        # layout stays the same.
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
|
||
|
|
||
|
|
||
|
class FeedForward(nn.Module):
    """Two-layer MLP that keeps the feature width: Linear -> GELU -> Linear.

    Args:
        dim: input, hidden and output feature size.
    """

    def __init__(self, dim):
        super().__init__()
        # Built in the same order as before so parameter names
        # (net.0.*, net.2.*) and initialisation order are unchanged.
        layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        out = self.net(x)
        return out
|
||
|
|
||
|
|
||
|
class CNN(nn.Module):
    """Small three-stage convolutional encoder with a global max pool.

    Channel progression is 3 -> 16 -> 32 -> ``hidden_dim``; every
    convolution is 3x3, stride 1, padding 1, followed by ReLU. The first two
    stages are each followed by a 2x2 max pool with stride 2.

    Args:
        hidden_dim: number of output channels of the last convolution.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Encode a 4-D image batch into one value per output channel.

        The final pooling window spans the whole remaining spatial extent
        (``x.shape[2:]``), so the two trailing axes of the result have size 1.
        """
        relu = nn.functional.relu
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = relu(self.conv3(x))
        # Global max pooling: kernel covers every remaining spatial position.
        spatial = x.shape[2:]
        return nn.functional.max_pool2d(x, kernel_size=spatial)
|
||
|
|
||
|
|
||
|
class MindNetLSTM(nn.Module):
    """
    Basic MindNet for model-based ToM: a bidirectional LSTM over the
    concatenation of the selected modality features.

    Args:
        hidden_dim: width of each embedded modality and of the LSTM hidden
            state.
        dropout: dropout probability used on embeddings and the LSTM input.
        mods: collection of modality names; the ones checked in ``forward``
            are 'rgb', 'ocr', 'pose', 'gaze' and 'bbox'. The LSTM input width
            is ``hidden_dim * len(mods)``.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        fused_dim = hidden_dim * len(mods)
        self.mods = mods
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        recurrent = nn.LSTM(
            input_size=fused_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.LSTM = PreNorm(fused_dim, recurrent)
        # Bidirectional output is 2 * hidden_dim wide; project it back.
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        """Fuse the enabled modalities and run them through the LSTM.

        ``pose`` is reshaped to ``(batch * seq_len, 25, 3)`` and ``gaze`` goes
        through a ``Linear(3, hidden_dim)``. The rgb/ocr/bbox features are
        used as given -- presumably already ``(batch, seq_len, hidden_dim)``;
        confirm against the caller.

        Returns:
            A tuple ``(out, c_n, feats)``: the activated projection of the
            LSTM output sequence, the final cell state averaged over its
            leading axis and moved to ``(batch, 1, hidden)`` layout, and the
            list of per-modality tensors that were concatenated.
        """
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            n_batch = pose.size(0)
            n_steps = pose.size(1)
            # Cache the edge index on the input's device for later calls.
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            joints = pose.view(n_batch * n_steps, 25, 3)
            node_emb = self.pose_emb(joints, self.pose_edge_index)
            node_emb = self.dropout(self.act(node_emb))
            # Average over the joint axis to get one vector per frame.
            frame_emb = node_emb.mean(dim=1)
            width = frame_emb.size(-1)
            feats.append(frame_emb.view(n_batch, n_steps, width))
        if 'gaze' in self.mods:
            feats.append(self.dropout(self.act(self.gaze_emb(gaze))))
        if 'bbox' in self.mods:
            feats.append(bbox_feats)

        fused = torch.cat(feats, dim=2)
        seq_out, (_, cell) = self.LSTM(self.dropout(fused))
        cell = cell.mean(dim=0, keepdim=True).permute(1, 0, 2)
        return self.act(self.proj(seq_out)), cell, feats
|
||
|
|
||
|
|
||
|
class MindNetSL(nn.Module):
    """
    Basic MindNet for SL ToM, just LSTM on input concatenation.

    Args:
        hidden_dim: width of each embedded modality and of the LSTM hidden
            state.
        dropout: dropout probability used on embeddings and the LSTM input.
        mods: collection of modality names; the ones checked in ``forward``
            are 'rgb', 'ocr', 'pose', 'gaze' and 'bbox'.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        # forward() concatenates one hidden_dim-wide tensor per entry of
        # `mods`, so the normalisation / LSTM input width must follow
        # len(mods). A fixed factor of 5 only works when all five modalities
        # are enabled and fails with a shape mismatch otherwise. For the
        # five-modality case the parameter shapes are unchanged.
        fused_dim = hidden_dim * len(mods)
        self.LSTM = PreNorm(
            fused_dim,
            nn.LSTM(input_size=fused_dim, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
        )
        # Bidirectional output is 2 * hidden_dim wide; project it back.
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        """Fuse the enabled modalities and run them through the LSTM.

        ``pose`` is reshaped to ``(batch * seq_len, 25, 3)`` and ``gaze`` goes
        through a ``Linear(3, hidden_dim)``. The rgb/ocr/bbox features are
        used as given -- presumably already ``(batch, seq_len, hidden_dim)``;
        confirm against the caller.

        Returns:
            A tuple ``(out, feats)``: the activated projection of the LSTM
            output sequence and the list of per-modality tensors that were
            concatenated.
        """
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            # Cache the edge index on the input's device for later calls.
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs * seq_len, 25, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            # Average over the joint axis to get one vector per frame.
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats)
        lstm_inp = torch.cat(feats, 2)
        lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
        return self.act(self.proj(lstm_out)), feats
|
||
|
|
||
|
|
||
|
class MindNetTF(nn.Module):
    """
    Basic MindNet for model-based ToM: a Transformer over the concatenation
    of the selected modality features.

    The backbone is an ``x_transformers`` ``ContinuousTransformerWrapper``
    around a ``Decoder`` with dim 512, depth 6 and 8 heads, configured with
    ``max_seq_len=747``, input width ``hidden_dim * len(mods)`` and output
    width ``hidden_dim``.

    Args:
        hidden_dim: width of each embedded modality and of the output.
        dropout: dropout probability used on embeddings and the fused input.
        mods: collection of modality names; the ones checked in ``forward``
            are 'rgb', 'ocr', 'pose', 'gaze' and 'bbox'.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        decoder = Decoder(dim=512, depth=6, heads=8)
        self.tf = ContinuousTransformerWrapper(
            dim_in=hidden_dim * len(mods),
            dim_out=hidden_dim,
            max_seq_len=747,
            attn_layers=decoder,
        )
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        """Fuse the enabled modalities and run them through the Transformer.

        ``pose`` is reshaped to ``(batch * seq_len, 25, 3)`` and ``gaze`` goes
        through a ``Linear(3, hidden_dim)``. The rgb/ocr/bbox features are
        used as given -- presumably already ``(batch, seq_len, hidden_dim)``;
        confirm against the caller.

        Returns:
            A tuple ``(tf_out, feats)``: the Transformer output and the list
            of per-modality tensors that were concatenated.
        """
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            n_batch = pose.size(0)
            n_steps = pose.size(1)
            # Cache the edge index on the input's device for later calls.
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            joints = pose.view(n_batch * n_steps, 25, 3)
            node_emb = self.pose_emb(joints, self.pose_edge_index)
            node_emb = self.dropout(self.act(node_emb))
            # Average over the joint axis to get one vector per frame.
            frame_emb = node_emb.mean(dim=1)
            width = frame_emb.size(-1)
            feats.append(frame_emb.view(n_batch, n_steps, width))
        if 'gaze' in self.mods:
            feats.append(self.dropout(self.act(self.gaze_emb(gaze))))
        if 'bbox' in self.mods:
            feats.append(bbox_feats)

        fused = torch.cat(feats, dim=2)
        encoded = self.tf(self.dropout(fused))
        return encoded, feats
|
||
|
|
||
|
|
||
|
class MindNetLSTMXL(nn.Module):
    """
    Larger LSTM MindNet for model-based ToM: a 2-layer bidirectional LSTM
    over the concatenation of the selected modality features.

    Gaze is embedded with a Linear -> GELU -> Linear stack and pose with two
    stacked ``GCNConv`` layers separated by a GELU.

    Args:
        hidden_dim: width of each embedded modality and of the LSTM hidden
            state.
        dropout: dropout probability used on embeddings and the LSTM input.
        mods: collection of modality names; the ones checked in ``forward``
            are 'rgb', 'ocr', 'pose', 'gaze' and 'bbox'. The LSTM input width
            is ``hidden_dim * len(mods)``.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        fused_dim = hidden_dim * len(mods)
        self.mods = mods
        self.gaze_emb = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.pose_edge_index = pose_edge_index()
        gcn_stack = [
            (GCNConv(3, hidden_dim), 'x, edge_index -> x'),
            nn.GELU(),
            (GCNConv(hidden_dim, hidden_dim), 'x, edge_index -> x'),
        ]
        self.pose_emb = Sequential('x, edge_index', gcn_stack)
        recurrent = nn.LSTM(
            input_size=fused_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.LSTM = PreNorm(fused_dim, recurrent)
        # Bidirectional output is 2 * hidden_dim wide; project it back.
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        """Fuse the enabled modalities and run them through the LSTM.

        ``pose`` is reshaped to ``(batch * seq_len, 25, 3)`` and ``gaze`` goes
        through the gaze MLP. The rgb/ocr/bbox features are used as given --
        presumably already ``(batch, seq_len, hidden_dim)``; confirm against
        the caller.

        Returns:
            A tuple ``(out, c_n, feats)``: the activated projection of the
            LSTM output sequence, the final cell state averaged over its
            leading axis (all layers and directions) and moved to
            ``(batch, 1, hidden)`` layout, and the list of per-modality
            tensors that were concatenated.
        """
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            n_batch = pose.size(0)
            n_steps = pose.size(1)
            # Cache the edge index on the input's device for later calls.
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            joints = pose.view(n_batch * n_steps, 25, 3)
            node_emb = self.pose_emb(joints, self.pose_edge_index)
            node_emb = self.dropout(self.act(node_emb))
            # Average over the joint axis to get one vector per frame.
            frame_emb = node_emb.mean(dim=1)
            width = frame_emb.size(-1)
            feats.append(frame_emb.view(n_batch, n_steps, width))
        if 'gaze' in self.mods:
            feats.append(self.dropout(self.act(self.gaze_emb(gaze))))
        if 'bbox' in self.mods:
            feats.append(bbox_feats)

        fused = torch.cat(feats, dim=2)
        seq_out, (_, cell) = self.LSTM(self.dropout(fused))
        cell = cell.mean(dim=0, keepdim=True).permute(1, 0, 2)
        return self.act(self.proj(seq_out)), cell, feats
|