parent d4aaf7f4ad
commit 25b8b3f343
55 changed files with 7592 additions and 4 deletions
0    boss/models/__init__.py    Normal file
230  boss/models/base.py        Normal file
@@ -0,0 +1,230 @@
import torch
import torch.nn as nn
from .utils import pose_edge_index
from torch_geometric.nn import Sequential, GCNConv
from x_transformers import ContinuousTransformerWrapper, Decoder


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


class CNN(nn.Module):
    def __init__(self, hidden_dim):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv3(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:])  # global max pooling
        return x


class MindNetLSTM(nn.Module):
    """
    Basic MindNet for model-based ToM, just LSTM on input concatenation
    """
    def __init__(self, hidden_dim, dropout, mods):
        super(MindNetLSTM, self).__init__()
        self.mods = mods
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim*len(mods),
            nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats)
        lstm_inp = torch.cat(feats, 2)
        lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
        c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
        return self.act(self.proj(lstm_out)), c_n, feats


class MindNetSL(nn.Module):
    """
    Basic MindNet for SL ToM, just LSTM on input concatenation
    """
    def __init__(self, hidden_dim, dropout, mods):
        super(MindNetSL, self).__init__()
        self.mods = mods
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim*5,
            nn.LSTM(input_size=hidden_dim*5, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
        )
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats)
        lstm_inp = torch.cat(feats, 2)
        lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
        return self.act(self.proj(lstm_out)), feats


class MindNetTF(nn.Module):
    """
    Basic MindNet for model-based ToM, Transformer on input concatenation
    """
    def __init__(self, hidden_dim, dropout, mods):
        super(MindNetTF, self).__init__()
        self.mods = mods
        self.gaze_emb = nn.Linear(3, hidden_dim)
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = GCNConv(3, hidden_dim)
        self.tf = ContinuousTransformerWrapper(
            dim_in=hidden_dim*len(mods),
            dim_out=hidden_dim,
            max_seq_len=747,
            attn_layers=Decoder(
                dim=512,
                depth=6,
                heads=8
            )
        )
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats)
        tf_inp = torch.cat(feats, 2)
        tf_out = self.tf(self.dropout(tf_inp))
        return tf_out, feats


class MindNetLSTMXL(nn.Module):
    """
    Basic MindNet for model-based ToM, just LSTM on input concatenation
    """
    def __init__(self, hidden_dim, dropout, mods):
        super(MindNetLSTMXL, self).__init__()
        self.mods = mods
        self.gaze_emb = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.pose_edge_index = pose_edge_index()
        self.pose_emb = Sequential('x, edge_index', [
            (GCNConv(3, hidden_dim), 'x, edge_index -> x'),
            nn.GELU(),
            (GCNConv(hidden_dim, hidden_dim), 'x, edge_index -> x'),
        ])
        self.LSTM = PreNorm(
            hidden_dim*len(mods),
            nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
        )
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
        feats = []
        if 'rgb' in self.mods:
            feats.append(rgb_feats)
        if 'ocr' in self.mods:
            feats.append(ocr_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats)
        lstm_inp = torch.cat(feats, 2)
        lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
        c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
        return self.act(self.proj(lstm_out)), c_n, feats
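For reference, a minimal smoke-test sketch of MindNetLSTM (not part of the commit; it assumes the package is importable as boss.models and that torch_geometric is installed). Per-frame pose is 25 joints x 3 coordinates, matching the view(bs*seq_len, 25, 3) above, and gaze is a 3-vector per frame:

# Hypothetical usage sketch, not part of this commit.
import torch
from boss.models.base import MindNetLSTM

net = MindNetLSTM(hidden_dim=64, dropout=0.1, mods=['rgb', 'ocr', 'pose', 'gaze', 'bbox'])
B, T, H = 2, 5, 64
rgb  = torch.randn(B, T, H)        # pre-embedded RGB features
ocr  = torch.randn(B, T, H)        # pre-embedded OCR graph features
pose = torch.randn(B, T, 25 * 3)   # 25 joints x (x, y, z), flattened per frame
gaze = torch.randn(B, T, 3)        # 3-D gaze vector per frame
bbox = torch.randn(B, T, H)        # pre-embedded bounding-box features
out, cell, feats = net(rgb, ocr, pose, gaze, bbox)
# out: (B, T, 64); cell: (B, 1, 64) from the averaged bidirectional cell state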
249  boss/models/resnet.py  Normal file
@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .utils import left_bias, right_bias


class ResNet(nn.Module):
    def __init__(self, input_dim, device):
        super(ResNet, self).__init__()
        # Conv
        resnet = models.resnet34(pretrained=True)
        self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
        # FFs
        self.left = nn.Linear(input_dim, 27)
        self.right = nn.Linear(input_dim, 27)
        # modality FFs
        self.pose_ff = nn.Linear(150, 150)
        self.gaze_ff = nn.Linear(6, 6)
        self.bbox_ff = nn.Linear(108, 108)
        self.ocr_ff = nn.Linear(729, 64)
        # others
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.device = device

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        batch_size, sequence_len, channels, height, width = images.shape
        left_beliefs = []
        right_beliefs = []
        image_feats = []

        for i in range(sequence_len):
            images_i = images[:,i].to(self.device)
            image_i_feat = self.resnet(images_i)
            image_i_feat = image_i_feat.view(batch_size, 512)
            if poses is not None:
                poses_i = poses[:,i].float()
                poses_i_feat = self.relu(self.pose_ff(poses_i))
                image_i_feat = torch.cat([image_i_feat, poses_i_feat], 1)
            if gazes is not None:
                gazes_i = gazes[:,i].float()
                gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
                image_i_feat = torch.cat([image_i_feat, gazes_i_feat], 1)
            if bboxes is not None:
                bboxes_i = bboxes[:,i].float()
                bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
                image_i_feat = torch.cat([image_i_feat, bboxes_i_feat], 1)
            if ocr_tensor is not None:
                ocr_tensor = ocr_tensor
                ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
                image_i_feat = torch.cat([image_i_feat, ocr_tensor_feat], 1)
            image_feats.append(image_i_feat)

        image_feats = torch.permute(torch.stack(image_feats), (1,0,2))
        left_beliefs = self.left(self.dropout(image_feats))
        right_beliefs = self.right(self.dropout(image_feats))

        return left_beliefs, right_beliefs, None


class ResNetGRU(nn.Module):
    def __init__(self, input_dim, device):
        super(ResNetGRU, self).__init__()
        resnet = models.resnet34(pretrained=True)
        self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
        self.gru = nn.GRU(input_dim, 512, batch_first=True)
        for name, param in self.gru.named_parameters():
            if "weight" in name:
                nn.init.orthogonal_(param)
            elif "bias" in name:
                nn.init.constant_(param, 0)
        # FFs
        self.left = nn.Linear(512, 27)
        self.right = nn.Linear(512, 27)
        # modality FFs
        self.pose_ff = nn.Linear(150, 150)
        self.gaze_ff = nn.Linear(6, 6)
        self.bbox_ff = nn.Linear(108, 108)
        self.ocr_ff = nn.Linear(729, 64)
        # others
        self.dropout = nn.Dropout(p=0.2)
        self.relu = nn.ReLU()
        self.device = device

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        batch_size, sequence_len, channels, height, width = images.shape
        left_beliefs = []
        right_beliefs = []
        rnn_inp = []

        for i in range(sequence_len):
            images_i = images[:,i]
            rnn_i_feat = self.resnet(images_i)
            rnn_i_feat = rnn_i_feat.view(batch_size, 512)
            if poses is not None:
                poses_i = poses[:,i].float()
                poses_i_feat = self.relu(self.pose_ff(poses_i))
                rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
            if gazes is not None:
                gazes_i = gazes[:,i].float()
                gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
                rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
            if bboxes is not None:
                bboxes_i = bboxes[:,i].float()
                bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
                rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
            if ocr_tensor is not None:
                ocr_tensor = ocr_tensor
                ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
                rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
            rnn_inp.append(rnn_i_feat)

        rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
        rnn_out, _ = self.gru(rnn_inp)
        left_beliefs = self.left(self.dropout(rnn_out))
        right_beliefs = self.right(self.dropout(rnn_out))

        return left_beliefs, right_beliefs, None


class ResNetConv1D(nn.Module):
    def __init__(self, input_dim, device):
        super(ResNetConv1D, self).__init__()
        resnet = models.resnet34(pretrained=True)
        self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=5, padding=4)
        # FFs
        self.left = nn.Linear(512, 27)
        self.right = nn.Linear(512, 27)
        # modality FFs
        self.pose_ff = nn.Linear(150, 150)
        self.gaze_ff = nn.Linear(6, 6)
        self.bbox_ff = nn.Linear(108, 108)
        self.ocr_ff = nn.Linear(729, 64)
        # others
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.device = device

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        batch_size, sequence_len, channels, height, width = images.shape
        left_beliefs = []
        right_beliefs = []
        conv1d_inp = []

        for i in range(sequence_len):
            images_i = images[:,i]
            images_i_feat = self.resnet(images_i)
            images_i_feat = images_i_feat.view(batch_size, 512)
            if poses is not None:
                poses_i = poses[:,i].float()
                poses_i_feat = self.relu(self.pose_ff(poses_i))
                images_i_feat = torch.cat([images_i_feat, poses_i_feat], 1)
            if gazes is not None:
                gazes_i = gazes[:,i].float()
                gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
                images_i_feat = torch.cat([images_i_feat, gazes_i_feat], 1)
            if bboxes is not None:
                bboxes_i = bboxes[:,i].float()
                bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
                images_i_feat = torch.cat([images_i_feat, bboxes_i_feat], 1)
            if ocr_tensor is not None:
                ocr_tensor = ocr_tensor
                ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
                images_i_feat = torch.cat([images_i_feat, ocr_tensor_feat], 1)
            conv1d_inp.append(images_i_feat)

        conv1d_inp = torch.permute(torch.stack(conv1d_inp), (1,2,0))
        conv1d_out = self.conv1d(conv1d_inp)
        conv1d_out = conv1d_out[:,:,:-4]
        conv1d_out = self.relu(torch.permute(conv1d_out, (0,2,1)))
        left_beliefs = self.left(self.dropout(conv1d_out))
        right_beliefs = self.right(self.dropout(conv1d_out))

        return left_beliefs, right_beliefs, None


class ResNetLSTM(nn.Module):
    def __init__(self, input_dim, device):
        super(ResNetLSTM, self).__init__()
        resnet = models.resnet34(pretrained=True)
        self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=512, batch_first=True)
        # FFs
        self.left = nn.Linear(512, 27)
        self.right = nn.Linear(512, 27)
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()
        # modality FFs
        self.pose_ff = nn.Linear(150, 150)
        self.gaze_ff = nn.Linear(6, 6)
        self.bbox_ff = nn.Linear(108, 108)
        self.ocr_ff = nn.Linear(729, 64)
        # others
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.device = device

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        batch_size, sequence_len, channels, height, width = images.shape
        left_beliefs = []
        right_beliefs = []
        rnn_inp = []

        for i in range(sequence_len):
            images_i = images[:,i]
            rnn_i_feat = self.resnet(images_i)
            rnn_i_feat = rnn_i_feat.view(batch_size, 512)
            if poses is not None:
                poses_i = poses[:,i].float()
                poses_i_feat = self.relu(self.pose_ff(poses_i))
                rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
            if gazes is not None:
                gazes_i = gazes[:,i].float()
                gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
                rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
            if bboxes is not None:
                bboxes_i = bboxes[:,i].float()
                bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
                rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
            if ocr_tensor is not None:
                ocr_tensor = ocr_tensor
                ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
                rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
            rnn_inp.append(rnn_i_feat)

        rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
        rnn_out, _ = self.lstm(rnn_inp)
        left_beliefs = self.left(self.dropout(rnn_out))
        right_beliefs = self.right(self.dropout(rnn_out))

        return left_beliefs, right_beliefs, None


if __name__ == '__main__':
    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 150)
    gazes = torch.ones(3, 22, 6)
    bboxes = torch.ones(3, 22, 108)
    model = ResNet(32, 'cpu')
    print(model(images, poses, gazes, bboxes, None)[0].shape)
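A note on the input_dim argument of these baselines: the per-frame feature fed to the heads (ResNet) or to the GRU/Conv1d/LSTM (other variants) is the ResNet embedding concatenated with whichever modality projections are present. A small sketch of that arithmetic (hypothetical helper, not part of the commit):

# Per-frame feature width when all modalities are provided.
FEAT_DIMS = {'resnet': 512, 'pose': 150, 'gaze': 6, 'bbox': 108, 'ocr': 64}
input_dim = sum(FEAT_DIMS.values())   # 512 + 150 + 6 + 108 + 64 = 840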
95  boss/models/single_mindnet.py  Normal file
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
import torchvision.models as models


class SingleMindNet(nn.Module):
    """
    Base ToM net. Supports any subset of modalities
    """
    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
        super(SingleMindNet, self).__init__()

        # ---- Images ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)

        # ---- OCR and bbox -----#
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        self.ocr_gnn = GCNConv(-1, hidden_dim)
        self.bbox_ff = nn.Linear(108, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net = MindNetLSTM(hidden_dim, dropout, mods)

        self.left = nn.Linear(hidden_dim, 27)
        self.right = nn.Linear(hidden_dim, 27)
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):

        assert images is not None
        assert poses is not None
        assert gazes is not None
        assert bboxes is not None

        batch_size, sequence_len, channels, height, width = images.shape

        bbox_feat = self.act(self.bbox_ff(bboxes))

        rgb_feat = []
        for i in range(sequence_len):
            images_i = images[:,i]
            img_i_feat = self.cnn(images_i)
            img_i_feat = img_i_feat.view(batch_size, -1)
            rgb_feat.append(img_i_feat)
        rgb_feat = torch.stack(rgb_feat, 1)
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))

        ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
                                                      self.ocr_edge_index,
                                                      self.ocr_edge_attr)))
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)

        out_left, cell_left, feats_left = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        out_right, cell_right, feats_right = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)

        return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right


if __name__ == "__main__":

    def count_parameters(model):
        import numpy as np
        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
        return sum([np.prod(p.size()) for p in model_parameters])

    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 2, 75)
    gazes = torch.ones(3, 22, 2, 3)
    bboxes = torch.ones(3, 22, 108)
    model = SingleMindNet(64, 'cpu', False, 0.5)
    print(count_parameters(model))
    breakpoint()
    out = model(images, poses, gazes, bboxes, None)
    print(out[0].shape)
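Because MindNetLSTM sizes its LSTM input as hidden_dim*len(mods) and only concatenates the modalities listed in mods, the same class can be built on a subset of modalities; the raw inputs still have to be passed to forward, the unused ones are simply ignored by the mind net. A hedged sketch (constructor signature as defined above):

# Hypothetical usage, not part of this commit: pose + gaze only.
model = SingleMindNet(hidden_dim=64, device='cpu', resnet=False, dropout=0.1,
                      mods=['pose', 'gaze'])
# forward(images, poses, gazes, bboxes) is still called with all tensors;
# rgb/ocr/bbox features are computed but dropped by the mind net.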
104  boss/models/tom_base.py  Normal file
@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
import numpy as np


class BaseToMnet(nn.Module):
    """
    Base ToM net. Supports any subset of modalities
    """
    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
        super(BaseToMnet, self).__init__()

        # ---- Images ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)

        # ---- OCR and bbox -----#
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        self.ocr_gnn = GCNConv(-1, hidden_dim)
        self.bbox_ff = nn.Linear(108, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
        self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)

        self.left = nn.Linear(hidden_dim, 27)
        self.right = nn.Linear(hidden_dim, 27)
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):

        assert images is not None
        assert poses is not None
        assert gazes is not None
        assert bboxes is not None

        batch_size, sequence_len, channels, height, width = images.shape

        bbox_feat = self.act(self.bbox_ff(bboxes))

        rgb_feat = []
        for i in range(sequence_len):
            images_i = images[:,i]
            img_i_feat = self.cnn(images_i)
            img_i_feat = img_i_feat.view(batch_size, -1)
            rgb_feat.append(img_i_feat)
        rgb_feat = torch.stack(rgb_feat, 1)
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))

        ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
                                                      self.ocr_edge_index,
                                                      self.ocr_edge_attr)))
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)

        out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)

        return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right


def count_parameters(model):
    #return sum(p.numel() for p in model.parameters() if p.requires_grad)
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])


if __name__ == "__main__":

    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 2, 75)
    gazes = torch.ones(3, 22, 2, 3)
    bboxes = torch.ones(3, 22, 108)
    model = BaseToMnet(64, 'cpu', False, 0.5)
    print(count_parameters(model))
    breakpoint()
    out = model(images, poses, gazes, bboxes, None)
    print(out[0].shape)
269  boss/models/tom_common_mind.py  Normal file
@@ -0,0 +1,269 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM, MindNetLSTMXL
from memory_efficient_attention_pytorch import Attention


class CommonMindToMnet(nn.Module):
    """
    Common-mind ToM net. Supports any subset of modalities.
    Possible aggregations: no_tom, sum, mult, attn, concat
    """
    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
        super(CommonMindToMnet, self).__init__()

        self.aggr = aggr

        # ---- Images ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            #for param in self.cnn.parameters():
            #    param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)

        # ---- OCR and bbox -----#
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        self.ocr_gnn = GCNConv(-1, hidden_dim)
        self.bbox_ff = nn.Linear(108, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
        self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.ln_left = nn.LayerNorm(hidden_dim)
        self.ln_right = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            self.attn_left = Attention(
                dim = hidden_dim,
                dim_head = hidden_dim // 4,
                heads = 4,
                memory_efficient = True,
                q_bucket_size = hidden_dim,
                k_bucket_size = hidden_dim)
            self.attn_right = Attention(
                dim = hidden_dim,
                dim_head = hidden_dim // 4,
                heads = 4,
                memory_efficient = True,
                q_bucket_size = hidden_dim,
                k_bucket_size = hidden_dim)
        self.left = nn.Linear(hidden_dim, 27)
        self.right = nn.Linear(hidden_dim, 27)
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):

        assert images is not None
        assert poses is not None
        assert gazes is not None
        assert bboxes is not None

        batch_size, sequence_len, channels, height, width = images.shape

        bbox_feat = self.act(self.bbox_ff(bboxes))

        rgb_feat = []
        for i in range(sequence_len):
            images_i = images[:,i]
            img_i_feat = self.cnn(images_i)
            img_i_feat = img_i_feat.view(batch_size, -1)
            rgb_feat.append(img_i_feat)
        rgb_feat = torch.stack(rgb_feat, 1)
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))

        ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
                                                      self.ocr_edge_index,
                                                      self.ocr_edge_attr)))
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)

        out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)

        common_mind = self.proj(torch.cat([cell_left, cell_right], -1))

        if self.aggr == 'no_tom':
            return self.left(out_left), self.right(out_right)

        if self.aggr == 'attn':
            l = self.attn_left(x=out_left, context=common_mind)
            r = self.attn_right(x=out_right, context=common_mind)
        elif self.aggr == 'mult':
            l = out_left * common_mind
            r = out_right * common_mind
        elif self.aggr == 'sum':
            l = out_left + common_mind
            r = out_right + common_mind
        elif self.aggr == 'concat':
            l = torch.cat([out_left, common_mind], 1)
            r = torch.cat([out_right, common_mind], 1)
        else: raise ValueError
        l = self.act(l)
        l = self.ln_left(l)
        r = self.act(r)
        r = self.ln_right(r)
        if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
            left_beliefs = self.left(l)
            right_beliefs = self.right(r)
        if self.aggr == 'concat':
            left_beliefs = self.left(l)[:, :-1, :]
            right_beliefs = self.right(r)[:, :-1, :]

        return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right


class CommonMindToMnetXL(nn.Module):
    """
    XL model.
    """
    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
        super(CommonMindToMnetXL, self).__init__()

        self.aggr = aggr

        # ---- Images ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)

        # ---- OCR and bbox -----#
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        self.ocr_gnn = GCNConv(-1, hidden_dim)
        self.bbox_ff = nn.Linear(108, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_left = MindNetLSTMXL(hidden_dim, dropout, mods)
        self.mind_net_right = MindNetLSTMXL(hidden_dim, dropout, mods)
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.ln_left = nn.LayerNorm(hidden_dim)
        self.ln_right = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            self.attn_left = Attention(
                dim = hidden_dim,
                dim_head = hidden_dim // 4,
                heads = 4,
                memory_efficient = True,
                q_bucket_size = hidden_dim,
                k_bucket_size = hidden_dim)
            self.attn_right = Attention(
                dim = hidden_dim,
                dim_head = hidden_dim // 4,
                heads = 4,
                memory_efficient = True,
                q_bucket_size = hidden_dim,
                k_bucket_size = hidden_dim)
        self.left = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 27),
        )
        self.right = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 27)
        )
        self.left[-1].bias.data = torch.tensor(left_bias).log()
        self.right[-1].bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):

        assert images is not None
        assert poses is not None
        assert gazes is not None
        assert bboxes is not None

        batch_size, sequence_len, channels, height, width = images.shape

        bbox_feat = self.act(self.bbox_ff(bboxes))

        rgb_feat = []
        for i in range(sequence_len):
            images_i = images[:,i]
            img_i_feat = self.cnn(images_i)
            img_i_feat = img_i_feat.view(batch_size, -1)
            rgb_feat.append(img_i_feat)
        rgb_feat = torch.stack(rgb_feat, 1)
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))

        ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
                                                      self.ocr_edge_index,
                                                      self.ocr_edge_attr)))
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)

        out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)

        common_mind = self.proj(torch.cat([cell_left, cell_right], -1))

        if self.aggr == 'no_tom':
            return self.left(out_left), self.right(out_right)

        if self.aggr == 'attn':
            l = self.attn_left(x=out_left, context=common_mind)
            r = self.attn_right(x=out_right, context=common_mind)
        elif self.aggr == 'mult':
            l = out_left * common_mind
            r = out_right * common_mind
        elif self.aggr == 'sum':
            l = out_left + common_mind
            r = out_right + common_mind
        elif self.aggr == 'concat':
            l = torch.cat([out_left, common_mind], 1)
            r = torch.cat([out_right, common_mind], 1)
        else: raise ValueError
        l = self.act(l)
        l = self.ln_left(l)
        r = self.act(r)
        r = self.ln_right(r)
        if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
            left_beliefs = self.left(l)
            right_beliefs = self.right(r)
        if self.aggr == 'concat':
            left_beliefs = self.left(l)[:, :-1, :]
            right_beliefs = self.right(r)[:, :-1, :]

        return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right


if __name__ == "__main__":

    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 2, 75)
    gazes = torch.ones(3, 22, 2, 3)
    bboxes = torch.ones(3, 22, 108)
    model = CommonMindToMnetXL(64, 'cpu', False, 0.5, aggr='attn')
    out = model(images, poses, gazes, bboxes, None)
    print(out[0].shape)
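For orientation, the "common mind" is a projection of the two agents' final LSTM cell states, broadcast back over each agent's per-frame outputs before the belief heads. A shape sketch of the 'sum' path, inferred from the code above:

# cell_left, cell_right: (B, 1, H) -> common_mind = proj(cat([...], -1)): (B, 1, H)
# out_left,  out_right:  (B, T, H)
# l = out_left  + common_mind   # broadcast add over the time axis -> (B, T, H)
# r = out_right + common_mind   # then GELU, LayerNorm, and the 27-way heads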
144  boss/models/tom_implicit.py  Normal file
@@ -0,0 +1,144 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention


class ImplicitToMnet(nn.Module):
    """
    Implicit ToM net. Supports any subset of modalities
    Possible aggregations: sum, mult, attn, concat
    """
    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
        super(ImplicitToMnet, self).__init__()

        self.aggr = aggr

        # ---- Images ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)

        # ---- OCR and bbox -----#
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        self.ocr_gnn = GCNConv(-1, hidden_dim)
        self.bbox_ff = nn.Linear(108, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
        self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
        self.ln_left = nn.LayerNorm(hidden_dim)
        self.ln_right = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            self.attn_left = Attention(
                dim = hidden_dim,
                dim_head = hidden_dim // 4,
                heads = 4,
                memory_efficient = True,
                q_bucket_size = hidden_dim,
                k_bucket_size = hidden_dim)
            self.attn_right = Attention(
                dim = hidden_dim,
                dim_head = hidden_dim // 4,
                heads = 4,
                memory_efficient = True,
                q_bucket_size = hidden_dim,
                k_bucket_size = hidden_dim)
        self.left = nn.Linear(hidden_dim, 27)
        self.right = nn.Linear(hidden_dim, 27)
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):

        assert images is not None
        assert poses is not None
        assert gazes is not None
        assert bboxes is not None

        batch_size, sequence_len, channels, height, width = images.shape

        bbox_feat = self.act(self.bbox_ff(bboxes))

        rgb_feat = []
        for i in range(sequence_len):
            images_i = images[:,i]
            img_i_feat = self.cnn(images_i)
            img_i_feat = img_i_feat.view(batch_size, -1)
            rgb_feat.append(img_i_feat)
        rgb_feat = torch.stack(rgb_feat, 1)
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))

        ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
                                                      self.ocr_edge_index,
                                                      self.ocr_edge_attr)))
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)

        out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)

        if self.aggr == 'no_tom':
            return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right

        if self.aggr == 'attn':
            l = self.attn_left(x=out_left, context=cell_right)
            r = self.attn_right(x=out_right, context=cell_left)
        elif self.aggr == 'mult':
            l = out_left * cell_right
            r = out_right * cell_left
        elif self.aggr == 'sum':
            l = out_left + cell_right
            r = out_right + cell_left
        elif self.aggr == 'concat':
            l = torch.cat([out_left, cell_right], 1)
            r = torch.cat([out_right, cell_left], 1)
        else: raise ValueError
        l = self.act(l)
        l = self.ln_left(l)
        r = self.act(r)
        r = self.ln_right(r)
        if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
            left_beliefs = self.left(l)
            right_beliefs = self.right(r)
        if self.aggr == 'concat':
            left_beliefs = self.left(l)[:, :-1, :]
            right_beliefs = self.right(r)[:, :-1, :]

        return left_beliefs, right_beliefs, [out_left, cell_left, out_right, cell_right] + feats_left + feats_right


if __name__ == "__main__":

    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 2, 75)
    gazes = torch.ones(3, 22, 2, 3)
    bboxes = torch.ones(3, 22, 108)
    model = ImplicitToMnet(64, 'cpu', False, 0.5, aggr='attn')
    out = model(images, poses, gazes, bboxes, None)
    print(out[0].shape)
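The implicit variant differs from the common-mind one in how the two agents interact: there is no shared projection; each agent's per-frame output is aggregated directly with the *other* agent's final cell state. Sketch of the 'sum' path under the shapes used above:

# out_left: (B, T, H), cell_right: (B, 1, H) -> broadcast over time
# l = out_left  + cell_right
# r = out_right + cell_left
# (with aggr='attn', cell_right/cell_left are used as the attention context instead)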
98  boss/models/tom_sl.py  Normal file
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph, pose_edge_index
from .base import CNN, MindNetSL


class SLToMnet(nn.Module):
    """
    Speaker-Listener ToMnet
    """
    def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
        super(SLToMnet, self).__init__()

        self.tom_weight = tom_weight

        # ---- Images ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)

        # ---- OCR and bbox -----#
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        self.ocr_gnn = GCNConv(-1, hidden_dim)
        self.bbox_ff = nn.Linear(108, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_left = MindNetSL(hidden_dim, dropout, mods)
        self.mind_net_right = MindNetSL(hidden_dim, dropout, mods)
        self.left = nn.Linear(hidden_dim, 27)
        self.right = nn.Linear(hidden_dim, 27)
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):

        batch_size, sequence_len, channels, height, width = images.shape

        bbox_feat = self.act(self.bbox_ff(bboxes))

        rgb_feat = []
        for i in range(sequence_len):
            images_i = images[:,i]
            img_i_feat = self.cnn(images_i)
            img_i_feat = img_i_feat.view(batch_size, -1)
            rgb_feat.append(img_i_feat)
        rgb_feat = torch.stack(rgb_feat, 1)
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))

        ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
                                                      self.ocr_edge_index,
                                                      self.ocr_edge_attr)))
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)

        out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        left_logits = self.left(out_left)
        out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
        right_logits = self.right(out_right)  # fixed: the original applied self.left here, leaving self.right unused

        left_ranking = torch.log_softmax(left_logits, dim=-1)
        right_ranking = torch.log_softmax(right_logits, dim=-1)

        right_beliefs = left_ranking + self.tom_weight * right_ranking
        left_beliefs = right_ranking + self.tom_weight * left_ranking

        return left_beliefs, right_beliefs, [out_left, out_right] + feats_left + feats_right


if __name__ == "__main__":

    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 2, 75)
    gazes = torch.ones(3, 22, 2, 3)
    bboxes = torch.ones(3, 22, 108)
    model = SLToMnet(64, 'cpu', 2.0, False, 0.5)
    out = model(images, poses, gazes, bboxes, None)
    print(out[0].shape)
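The speaker-listener combination used above, written out with w = tom_weight (per the code; each agent's own log-softmax ranking is scaled by w and added to the other agent's):

# left_beliefs  = log_softmax(right_logits) + w * log_softmax(left_logits)
# right_beliefs = log_softmax(left_logits)  + w * log_softmax(right_logits)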
104  boss/models/tom_tf.py  Normal file
@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetTF
from memory_efficient_attention_pytorch import Attention


class TFToMnet(nn.Module):
    """
    Transformer-based ToM net. Supports any subset of modalities
    """
    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
        super(TFToMnet, self).__init__()

        # ---- Images ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)

        # ---- OCR and bbox -----#
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        self.ocr_gnn = GCNConv(-1, hidden_dim)
        self.bbox_ff = nn.Linear(108, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_left = MindNetTF(hidden_dim, dropout, mods)
        self.mind_net_right = MindNetTF(hidden_dim, dropout, mods)
        self.left = nn.Linear(hidden_dim, 27)
        self.right = nn.Linear(hidden_dim, 27)
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):

        assert images is not None
        assert poses is not None
        assert gazes is not None
        assert bboxes is not None

        batch_size, sequence_len, channels, height, width = images.shape

        bbox_feat = self.act(self.bbox_ff(bboxes))

        rgb_feat = []
        for i in range(sequence_len):
            images_i = images[:,i]
            img_i_feat = self.cnn(images_i)
            img_i_feat = img_i_feat.view(batch_size, -1)
            rgb_feat.append(img_i_feat)
        rgb_feat = torch.stack(rgb_feat, 1)
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))

        ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
                                                      self.ocr_edge_index,
                                                      self.ocr_edge_attr)))
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)

        out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)

        l = self.dropout(self.act(out_left))
        r = self.dropout(self.act(out_right))
        left_beliefs = self.left(l)
        right_beliefs = self.right(r)

        return left_beliefs, right_beliefs, feats_left + feats_right


if __name__ == "__main__":

    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 2, 75)
    gazes = torch.ones(3, 22, 2, 3)
    bboxes = torch.ones(3, 22, 108)
    model = TFToMnet(64, 'cpu', False, 0.5)
    out = model(images, poses, gazes, bboxes, None)
    print(out[0].shape)
95  boss/models/utils.py  Normal file
@@ -0,0 +1,95 @@
import torch

left_bias = [0.10659290976303822,
             0.025158905348262015,
             0.02811095449589107,
             0.026342384511050237,
             0.025318475572458178,
             0.02283183957873461,
             0.021581872822531316,
             0.08062285577511237,
             0.03824366373234754,
             0.04853594319300018,
             0.09653998563867983,
             0.02961357410707162,
             0.02961357410707162,
             0.03172787957767081,
             0.029985904630196004,
             0.02897529321028696,
             0.06602218026116327,
             0.015345336560197867,
             0.026900880295736816,
             0.024879657455918726,
             0.028669450280577644,
             0.01936118720246802,
             0.02341693040078721,
             0.014707055663413206,
             0.027007260445200926,
             0.04146166325363687,
             0.04243238211749688]

right_bias = [0.13147256721895695,
              0.012433179968617855,
              0.01623627031195979,
              0.013683146724821148,
              0.015252253929416771,
              0.012579452674131008,
              0.03127576394244834,
              0.10325523257360177,
              0.041155820323927554,
              0.06563655221935587,
              0.12684503071726816,
              0.016156485199861705,
              0.0176989973670913,
              0.020238823435546928,
              0.01918831945958884,
              0.01791175766601952,
              0.08768383819579266,
              0.019002154198026647,
              0.029600276588388607,
              0.01578415467673732,
              0.0176989973670913,
              0.011834791627882237,
              0.014919815962341426,
              0.007552990611951809,
              0.029759846812584773,
              0.04981250498656951,
              0.05533097524002021]

def build_ocr_graph(device):
    ocr_graph = [
        [15, [10, 4], [17, 2]],
        [13, [16, 7], [18, 4]],
        [11, [16, 4], [7, 10]],
        [14, [10, 11], [7, 1]],
        [12, [10, 9], [16, 3]],
        [1, [7, 2], [9, 9], [10, 2]],
        [5, [8, 8], [6, 8]],
        [4, [9, 8], [7, 6]],
        [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
        [2, [10, 1], [7, 7], [9, 3]],
        [19, [10, 2], [26, 6]],
        [20, [10, 7], [26, 5]],
        [22, [25, 4], [10, 8]],
        [23, [25, 15]],
        [21, [16, 5], [24, 8]]
    ]
    edge_index = []
    edge_attr = []
    for i in range(len(ocr_graph)):
        for j in range(1, len(ocr_graph[i])):
            source_node = ocr_graph[i][0]
            target_node = ocr_graph[i][j][0]
            edge_index.append([source_node, target_node])
            edge_attr.append(ocr_graph[i][j][1])
    ocr_edge_index = torch.tensor(edge_index).t().long()
    ocr_edge_attr = torch.tensor(edge_attr).to(torch.float).unsqueeze(1)
    x = torch.arange(0, 27)
    ocr_x = torch.nn.functional.one_hot(x, num_classes=27).to(torch.float)
    return ocr_x.to(device), ocr_edge_index.to(device), ocr_edge_attr.to(device)

def pose_edge_index():
    return torch.tensor(
        [[17, 15, 15, 0, 0, 16, 16, 18, 0, 1, 4, 3, 3, 2, 2, 1, 1, 5, 5, 6, 6, 7, 1, 8, 8, 9, 9, 10, 10, 11, 11, 24, 11, 23, 23, 22, 8, 12, 12, 13, 13, 14, 14, 21, 14, 19, 19, 20],
         [15, 17, 0, 15, 16, 0, 18, 16, 1, 0, 3, 4, 2, 3, 1, 2, 5, 1, 6, 5, 7, 6, 8, 1, 9, 8, 10, 9, 11, 10, 24, 11, 23, 11, 22, 23, 12, 8, 13, 12, 14, 13, 21, 14, 19, 14, 20, 19]],
        dtype=torch.long)
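A brief orientation on these helpers: build_ocr_graph turns the nested [source, [target, count], ...] list into a directed graph over 27 object classes, with one-hot node features and the co-occurrence counts as 1-D edge attributes; pose_edge_index encodes a 25-joint skeleton as 24 undirected edges listed in both directions. A hedged inspection snippet (not part of the commit):

# x, edge_index, edge_attr = build_ocr_graph('cpu')
# x.shape == (27, 27); edge_index.shape == (2, 34); edge_attr.shape == (34, 1)
# pose_edge_index().shape == (2, 48)   # 24 skeleton edges, both directions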