import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

from .utils import pose_edge_index


class PreNorm(nn.Module):
    """Applies LayerNorm to the input before calling the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP with GELU activation; preserves the feature dimension."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


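# Usage sketch (illustrative only): PreNorm can wrap any module whose first
# argument is a tensor with trailing dimension `dim`, e.g. the FeedForward
# block above. The dimensions below are arbitrary examples.
#
#     block = PreNorm(64, FeedForward(64))
#     out = block(torch.randn(8, 10, 64))  # shape preserved: (8, 10, 64)

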
class CNN(nn.Module):
    """Small convolutional encoder: three conv blocks followed by global max
    pooling, mapping a (B, 3, H, W) image to (B, hidden_dim, 1, 1)."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv3(x)
        x = nn.functional.relu(x)
        # Global max pooling over the remaining spatial dimensions.
        x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:])
        return x


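# Shape sketch (illustrative only; the 128x128 input size is an assumption,
# the encoder accepts any spatial size that survives two 2x2 poolings):
#
#     enc = CNN(hidden_dim=64)
#     feat = enc(torch.randn(8, 3, 128, 128))  # (8, 64, 1, 1) after global pooling
#     feat = feat.view(8, -1)                  # (8, 64), as done by the callers below

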
class MindNetLSTM(nn.Module):
    """
    Basic MindNet for model-based ToM: embeds each active modality to
    hidden_dim, concatenates the embeddings along the feature axis, and
    runs a bidirectional LSTM over the resulting sequence.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        if 'rgb_1' in mods:
            self.img_emb = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        if 'gaze' in mods:
            self.gaze_emb = nn.Linear(2, hidden_dim)
        if 'pose' in mods:
            self.pose_edge_index = pose_edge_index()
            self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim * len(mods),
            nn.LSTM(input_size=hidden_dim * len(mods),
                    hidden_size=hidden_dim,
                    batch_first=True,
                    bidirectional=True))
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
        feats = []
        if 'rgb_3' in self.mods:
            feats.append(rgb_3rd_pov_feats)
        if 'rgb_1' in self.mods:
            # Encode each first-person frame independently with the CNN.
            rgb_feat = []
            for i in range(rgb_1st_pov.shape[1]):
                images_i = rgb_1st_pov[:, i]
                img_i_feat = self.img_emb(images_i)
                img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)  # (B, T, hidden_dim)
            rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
            feats.append(rgb_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            # GCN over the 26-joint skeleton, batched as a static graph
            # (shared edge_index); joint embeddings are mean-pooled per frame.
            pose_emb = self.pose_emb(pose.view(bs * seq_len, 26, 3),
                                     self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats.mean(2))
        lstm_inp = torch.cat(feats, 2)
        lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
        # Average the cell state over the two directions: (B, 1, hidden_dim).
        c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
        return self.act(self.proj(lstm_out)), c_n, feats


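# Usage sketch (illustrative only; batch size 8, sequence length 5, and the
# 128x128 frame size are assumptions inferred from the forward pass above):
#
#     net = MindNetLSTM(hidden_dim=64, dropout=0.1,
#                       mods=['rgb_1', 'pose', 'gaze'])
#     out, c_n, feats = net(
#         rgb_3rd_pov_feats=None,                       # unused: 'rgb_3' not in mods
#         bbox_feats=None,                              # unused: 'bbox' not in mods
#         rgb_1st_pov=torch.randn(8, 5, 3, 128, 128),
#         pose=torch.randn(8, 5, 26, 3),
#         gaze=torch.randn(8, 5, 2))
#     # out: (8, 5, 64), c_n: (8, 1, 64), feats: one (8, 5, 64) tensor per modality

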
class MindNetSL(nn.Module):
    """
    Basic MindNet for SL ToM: same architecture as MindNetLSTM, but the
    forward pass returns only the projected LSTM output and the
    per-modality features (no cell state).
    """

    def __init__(self, hidden_dim, dropout, mods):
        super().__init__()
        self.mods = mods
        if 'rgb_1' in mods:
            self.img_emb = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        if 'gaze' in mods:
            self.gaze_emb = nn.Linear(2, hidden_dim)
        if 'pose' in mods:
            self.pose_edge_index = pose_edge_index()
            self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim * len(mods),
            nn.LSTM(input_size=hidden_dim * len(mods),
                    hidden_size=hidden_dim,
                    batch_first=True,
                    bidirectional=True))
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
        feats = []
        if 'rgb_3' in self.mods:
            feats.append(rgb_3rd_pov_feats)
        if 'rgb_1' in self.mods:
            # Encode each first-person frame independently, as in MindNetLSTM.
            rgb_feat = []
            for i in range(rgb_1st_pov.shape[1]):
                images_i = rgb_1st_pov[:, i]
                img_i_feat = self.img_emb(images_i)
                img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)  # (B, T, hidden_dim)
            rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
            feats.append(rgb_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs * seq_len, 26, 3),
                                     self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats.mean(2))
        lstm_inp = torch.cat(feats, 2)
        lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
        return self.act(self.proj(lstm_out)), feats
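

# Usage sketch (illustrative only): MindNetSL takes the same inputs as
# MindNetLSTM but returns only the (output, feats) pair. A minimal
# single-modality example with assumed shapes:
#
#     net = MindNetSL(hidden_dim=64, dropout=0.1, mods=['gaze'])
#     out, feats = net(None, None, None, None, gaze=torch.randn(8, 5, 2))
#     # out: (8, 5, 64)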