Matteo Bortoletto 2025-01-10 15:39:20 +01:00
parent d4aaf7f4ad
commit 25b8b3f343
55 changed files with 7592 additions and 4 deletions

0  boss/models/__init__.py  Normal file

230  boss/models/base.py  Normal file
@@ -0,0 +1,230 @@
import torch
import torch.nn as nn
from .utils import pose_edge_index
from torch_geometric.nn import Sequential, GCNConv
from x_transformers import ContinuousTransformerWrapper, Decoder
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim),
nn.GELU(),
nn.Linear(dim, dim))
def forward(self, x):
return self.net(x)
class CNN(nn.Module):
def __init__(self, hidden_dim):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv3(x)
x = nn.functional.relu(x)
x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling
return x
class MindNetLSTM(nn.Module):
"""
Basic MindNet for model-based ToM: a bidirectional LSTM over the concatenated modality features.
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetLSTM, self).__init__()
self.mods = mods
self.gaze_emb = nn.Linear(3, hidden_dim)
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.LSTM = PreNorm(
hidden_dim*len(mods),
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
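# the bidirectional LSTM doubles the feature size, so proj maps 2*hidden_dim back to hidden_dim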
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
lstm_inp = torch.cat(feats, 2)
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
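# c_n has shape (2, batch, hidden_dim) from the two LSTM directions; average the directions and reshape to (batch, 1, hidden_dim)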
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
return self.act(self.proj(lstm_out)), c_n, feats
class MindNetSL(nn.Module):
"""
Basic MindNet for SL (speaker-listener) ToM: a bidirectional LSTM over the concatenated modality features.
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetSL, self).__init__()
self.mods = mods
self.gaze_emb = nn.Linear(3, hidden_dim)
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.LSTM = PreNorm(
hidden_dim*5,
nn.LSTM(input_size=hidden_dim*5, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
)
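# NOTE: the input size hidden_dim*5 assumes all five modalities are enabled (cf. hidden_dim*len(mods) in MindNetLSTM)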
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
lstm_inp = torch.cat(feats, 2)
lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
return self.act(self.proj(lstm_out)), feats
class MindNetTF(nn.Module):
"""
Basic MindNet for model-based ToM: a Transformer decoder over the concatenated modality features.
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetTF, self).__init__()
self.mods = mods
self.gaze_emb = nn.Linear(3, hidden_dim)
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.tf = ContinuousTransformerWrapper(
dim_in=hidden_dim*len(mods),
dim_out=hidden_dim,
max_seq_len=747,
attn_layers=Decoder(
dim=512,
depth=6,
heads=8
)
)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
tf_inp = torch.cat(feats, 2)
tf_out = self.tf(self.dropout(tf_inp))
return tf_out, feats
class MindNetLSTMXL(nn.Module):
"""
XL MindNet for model-based ToM: two-layer GCN pose encoder, MLP gaze encoder, and a two-layer bidirectional LSTM over the concatenated modality features.
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetLSTMXL, self).__init__()
self.mods = mods
self.gaze_emb = nn.Sequential(
nn.Linear(3, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim)
)
self.pose_edge_index = pose_edge_index()
self.pose_emb = Sequential('x, edge_index', [
(GCNConv(3, hidden_dim), 'x, edge_index -> x'),
nn.GELU(),
(GCNConv(hidden_dim, hidden_dim), 'x, edge_index -> x'),
])
self.LSTM = PreNorm(
hidden_dim*len(mods),
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
)
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
lstm_inp = torch.cat(feats, 2)
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
return self.act(self.proj(lstm_out)), c_n, feats
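# Illustrative smoke test (a sketch, not part of the original file); input shapes mirror the
# __main__ demos in the other model files of this commit (batch 3, 22 frames, 25x3 pose keypoints).
if __name__ == "__main__":
    hid = 64
    rgb = torch.ones(3, 22, hid)
    ocr = torch.ones(3, 22, hid)
    pose = torch.ones(3, 22, 75)   # flattened 25 joints x 3 coordinates per frame
    gaze = torch.ones(3, 22, 3)
    bbox = torch.ones(3, 22, hid)
    net = MindNetLSTM(hid, dropout=0.1, mods=['rgb', 'ocr', 'pose', 'gaze', 'bbox'])
    out, cell, _ = net(rgb, ocr, pose, gaze, bbox)
    print(out.shape, cell.shape)   # expected: (3, 22, 64) and (3, 1, 64)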

249  boss/models/resnet.py  Normal file
@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .utils import left_bias, right_bias
class ResNet(nn.Module):
def __init__(self, input_dim, device):
super(ResNet, self).__init__()
# Conv
resnet = models.resnet34(pretrained=True)
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
# FFs
self.left = nn.Linear(input_dim, 27)
self.right = nn.Linear(input_dim, 27)
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
image_feats = []
for i in range(sequence_len):
images_i = images[:,i].to(self.device)
image_i_feat = self.resnet(images_i)
image_i_feat = image_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
image_i_feat = torch.cat([image_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
image_i_feat = torch.cat([image_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
image_i_feat = torch.cat([image_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
image_i_feat = torch.cat([image_i_feat, ocr_tensor_feat], 1)
image_feats.append(image_i_feat)
image_feats = torch.permute(torch.stack(image_feats), (1,0,2))
left_beliefs = self.left(self.dropout(image_feats))
right_beliefs = self.right(self.dropout(image_feats))
return left_beliefs, right_beliefs, None
class ResNetGRU(nn.Module):
def __init__(self, input_dim, device):
super(ResNetGRU, self).__init__()
resnet = models.resnet34(pretrained=True)
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
self.gru = nn.GRU(input_dim, 512, batch_first=True)
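# input_dim must equal the concatenated per-frame feature size: 512 image features plus the enabled modality features (150 pose, 6 gaze, 108 bbox, 64 ocr)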
for name, param in self.gru.named_parameters():
if "weight" in name:
nn.init.orthogonal_(param)
elif "bias" in name:
nn.init.constant_(param, 0)
# FFs
self.left = nn.Linear(512, 27)
self.right = nn.Linear(512, 27)
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.dropout = nn.Dropout(p=0.2)
self.relu = nn.ReLU()
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
rnn_inp = []
for i in range(sequence_len):
images_i = images[:,i]
rnn_i_feat = self.resnet(images_i)
rnn_i_feat = rnn_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
rnn_inp.append(rnn_i_feat)
rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
rnn_out, _ = self.gru(rnn_inp)
left_beliefs = self.left(self.dropout(rnn_out))
right_beliefs = self.right(self.dropout(rnn_out))
return left_beliefs, right_beliefs, None
class ResNetConv1D(nn.Module):
def __init__(self, input_dim, device):
super(ResNetConv1D, self).__init__()
resnet = models.resnet34(pretrained=True)
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=5, padding=4)
# FFs
self.left = nn.Linear(512, 27)
self.right = nn.Linear(512, 27)
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
conv1d_inp = []
for i in range(sequence_len):
images_i = images[:,i]
images_i_feat = self.resnet(images_i)
images_i_feat = images_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
images_i_feat = torch.cat([images_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
images_i_feat = torch.cat([images_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
images_i_feat = torch.cat([images_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
images_i_feat = torch.cat([images_i_feat, ocr_tensor_feat], 1)
conv1d_inp.append(images_i_feat)
conv1d_inp = torch.permute(torch.stack(conv1d_inp), (1,2,0))
conv1d_out = self.conv1d(conv1d_inp)
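# kernel_size=5 with padding=4 pads both ends; dropping the last 4 steps makes the convolution causal (each output depends only on current and past frames)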
conv1d_out = conv1d_out[:,:,:-4]
conv1d_out = self.relu(torch.permute(conv1d_out, (0,2,1)))
left_beliefs = self.left(self.dropout(conv1d_out))
right_beliefs = self.right(self.dropout(conv1d_out))
return left_beliefs, right_beliefs, None
class ResNetLSTM(nn.Module):
def __init__(self, input_dim, device):
super(ResNetLSTM, self).__init__()
resnet = models.resnet34(pretrained=True)
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
self.lstm = nn.LSTM(input_size=input_dim, hidden_size=512, batch_first=True)
# FFs
self.left = nn.Linear(512, 27)
self.right = nn.Linear(512, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
rnn_inp = []
for i in range(sequence_len):
images_i = images[:,i]
rnn_i_feat = self.resnet(images_i)
rnn_i_feat = rnn_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
rnn_inp.append(rnn_i_feat)
rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
rnn_out, _ = self.lstm(rnn_inp)
left_beliefs = self.left(self.dropout(rnn_out))
right_beliefs = self.right(self.dropout(rnn_out))
return left_beliefs, right_beliefs, None
if __name__ == '__main__':
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 150)
gazes = torch.ones(3, 22, 6)
bboxes = torch.ones(3, 22, 108)
model = ResNet(776, 'cpu')  # input_dim = 512 img + 150 pose + 6 gaze + 108 bbox features
print(model(images, poses, gazes, bboxes, None)[0].shape)

@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
import torchvision.models as models
class SingleMindNet(nn.Module):
"""
Single-mind ToM net: one shared MindNetLSTM processes both agents. Supports any subset of modalities
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(SingleMindNet, self).__init__()
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net = MindNetLSTM(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
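# pool the OCR-graph node embeddings into a single vector and broadcast it across batch and time steps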
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
if __name__ == "__main__":
def count_parameters(model):
import numpy as np
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = SingleMindNet(64, 'cpu', False, 0.5)
print(count_parameters(model))
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

104  boss/models/tom_base.py  Normal file
@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
import numpy as np
class BaseToMnet(nn.Module):
"""
Base ToM net. Supports any subset of modalities
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(BaseToMnet, self).__init__()
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
def count_parameters(model):
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = BaseToMnet(64, 'cpu', False, 0.5)
print(count_parameters(model))
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

@@ -0,0 +1,269 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM, MindNetLSTMXL
from memory_efficient_attention_pytorch import Attention
class CommonMindToMnet(nn.Module):
"""
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(CommonMindToMnet, self).__init__()
self.aggr = aggr
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
#for param in self.cnn.parameters():
# param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.ln_left = nn.LayerNorm(hidden_dim)
self.ln_right = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
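# fuse the two agents' final cell states into a shared 'common mind' vector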
common_mind = self.proj(torch.cat([cell_left, cell_right], -1))
if self.aggr == 'no_tom':
return self.left(out_left), self.right(out_right)
if self.aggr == 'attn':
l = self.attn_left(x=out_left, context=common_mind)
r = self.attn_right(x=out_right, context=common_mind)
elif self.aggr == 'mult':
l = out_left * common_mind
r = out_right * common_mind
elif self.aggr == 'sum':
l = out_left + common_mind
r = out_right + common_mind
elif self.aggr == 'concat':
l = torch.cat([out_left, common_mind], 1)
r = torch.cat([out_right, common_mind], 1)
else: raise ValueError
l = self.act(l)
l = self.ln_left(l)
r = self.act(r)
r = self.ln_right(r)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
left_beliefs = self.left(l)
right_beliefs = self.right(r)
if self.aggr == 'concat':
left_beliefs = self.left(l)[:, :-1, :]
right_beliefs = self.right(r)[:, :-1, :]
return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
class CommonMindToMnetXL(nn.Module):
"""
XL model.
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(CommonMindToMnetXL, self).__init__()
self.aggr = aggr
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTMXL(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTMXL(hidden_dim, dropout, mods)
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.ln_left = nn.LayerNorm(hidden_dim)
self.ln_right = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.left = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, 27),
)
self.right = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, 27)
)
self.left[-1].bias.data = torch.tensor(left_bias).log()
self.right[-1].bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
common_mind = self.proj(torch.cat([cell_left, cell_right], -1))
if self.aggr == 'no_tom':
return self.left(out_left), self.right(out_right)
if self.aggr == 'attn':
l = self.attn_left(x=out_left, context=common_mind)
r = self.attn_right(x=out_right, context=common_mind)
elif self.aggr == 'mult':
l = out_left * common_mind
r = out_right * common_mind
elif self.aggr == 'sum':
l = out_left + common_mind
r = out_right + common_mind
elif self.aggr == 'concat':
l = torch.cat([out_left, common_mind], 1)
r = torch.cat([out_right, common_mind], 1)
else: raise ValueError
l = self.act(l)
l = self.ln_left(l)
r = self.act(r)
r = self.ln_right(r)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
left_beliefs = self.left(l)
right_beliefs = self.right(r)
if self.aggr == 'concat':
left_beliefs = self.left(l)[:, :-1, :]
right_beliefs = self.right(r)[:, :-1, :]
return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = CommonMindToMnetXL(64, 'cpu', False, 0.5, aggr='attn')
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

144  boss/models/tom_implicit.py  Normal file
@@ -0,0 +1,144 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention
class ImplicitToMnet(nn.Module):
"""
Implicit ToM net. Supports any subset of modalities
Possible aggregations: sum, mult, attn, concat
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(ImplicitToMnet, self).__init__()
self.aggr = aggr
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
self.ln_left = nn.LayerNorm(hidden_dim)
self.ln_right = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
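# implicit ToM: below, each agent's output is modulated by the *other* agent's final cell state (via attention, product, sum, or concat)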
if self.aggr == 'no_tom':
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
if self.aggr == 'attn':
l = self.attn_left(x=out_left, context=cell_right)
r = self.attn_right(x=out_right, context=cell_left)
elif self.aggr == 'mult':
l = out_left * cell_right
r = out_right * cell_left
elif self.aggr == 'sum':
l = out_left + cell_right
r = out_right + cell_left
elif self.aggr == 'concat':
l = torch.cat([out_left, cell_right], 1)
r = torch.cat([out_right, cell_left], 1)
else: raise ValueError
l = self.act(l)
l = self.ln_left(l)
r = self.act(r)
r = self.ln_right(r)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
left_beliefs = self.left(l)
right_beliefs = self.right(r)
if self.aggr == 'concat':
left_beliefs = self.left(l)[:, :-1, :]
right_beliefs = self.right(r)[:, :-1, :]
return left_beliefs, right_beliefs, [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = ImplicitToMnet(64, 'cpu', False, 0.5, aggr='attn')
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

98  boss/models/tom_sl.py  Normal file
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph, pose_edge_index
from .base import CNN, MindNetSL
class SLToMnet(nn.Module):
"""
Speaker-Listener ToMnet
"""
def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(SLToMnet, self).__init__()
self.tom_weight = tom_weight
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetSL(hidden_dim, dropout, mods)
self.mind_net_right = MindNetSL(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
left_logits = self.left(out_left)
out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
right_logits = self.right(out_right)
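# speaker-listener combination: each agent's belief scores are a tom_weight-weighted mix of the two agents' log-softmax rankings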
left_ranking = torch.log_softmax(left_logits, dim=-1)
right_ranking = torch.log_softmax(right_logits, dim=-1)
right_beliefs = left_ranking + self.tom_weight * right_ranking
left_beliefs = right_ranking + self.tom_weight * left_ranking
return left_beliefs, right_beliefs, [out_left, out_right] + feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = SLToMnet(64, 'cpu', 2.0, False, 0.5)
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

104  boss/models/tom_tf.py  Normal file
@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetTF
from memory_efficient_attention_pytorch import Attention
class TFToMnet(nn.Module):
"""
Transformer-based ToM net: one MindNetTF per agent. Supports any subset of modalities.
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(TFToMnet, self).__init__()
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetTF(hidden_dim, dropout, mods)
self.mind_net_right = MindNetTF(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
l = self.dropout(self.act(out_left))
r = self.dropout(self.act(out_right))
left_beliefs = self.left(l)
right_beliefs = self.right(r)
return left_beliefs, right_beliefs, feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = TFToMnet(64, 'cpu', False, 0.5)
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

95  boss/models/utils.py  Normal file
@@ -0,0 +1,95 @@
import torch
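# Per-class bias priors for the 27 left/right belief classes; the model files use their logs to initialise the final classifier biases.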
left_bias = [0.10659290976303822,
0.025158905348262015,
0.02811095449589107,
0.026342384511050237,
0.025318475572458178,
0.02283183957873461,
0.021581872822531316,
0.08062285577511237,
0.03824366373234754,
0.04853594319300018,
0.09653998563867983,
0.02961357410707162,
0.02961357410707162,
0.03172787957767081,
0.029985904630196004,
0.02897529321028696,
0.06602218026116327,
0.015345336560197867,
0.026900880295736816,
0.024879657455918726,
0.028669450280577644,
0.01936118720246802,
0.02341693040078721,
0.014707055663413206,
0.027007260445200926,
0.04146166325363687,
0.04243238211749688]
right_bias = [0.13147256721895695,
0.012433179968617855,
0.01623627031195979,
0.013683146724821148,
0.015252253929416771,
0.012579452674131008,
0.03127576394244834,
0.10325523257360177,
0.041155820323927554,
0.06563655221935587,
0.12684503071726816,
0.016156485199861705,
0.0176989973670913,
0.020238823435546928,
0.01918831945958884,
0.01791175766601952,
0.08768383819579266,
0.019002154198026647,
0.029600276588388607,
0.01578415467673732,
0.0176989973670913,
0.011834791627882237,
0.014919815962341426,
0.007552990611951809,
0.029759846812584773,
0.04981250498656951,
0.05533097524002021]
def build_ocr_graph(device):
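# Each entry below is [source_node, [target_node, edge_weight], ...]; nodes are the 27 class indices with one-hot features.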
ocr_graph = [
[15, [10, 4], [17, 2]],
[13, [16, 7], [18, 4]],
[11, [16, 4], [7, 10]],
[14, [10, 11], [7, 1]],
[12, [10, 9], [16, 3]],
[1, [7, 2], [9, 9], [10, 2]],
[5, [8, 8], [6, 8]],
[4, [9, 8], [7, 6]],
[3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
[2, [10, 1], [7, 7], [9, 3]],
[19, [10, 2], [26, 6]],
[20, [10, 7], [26, 5]],
[22, [25, 4], [10, 8]],
[23, [25, 15]],
[21, [16, 5], [24, 8]]
]
edge_index = []
edge_attr = []
for i in range(len(ocr_graph)):
for j in range(1, len(ocr_graph[i])):
source_node = ocr_graph[i][0]
target_node = ocr_graph[i][j][0]
edge_index.append([source_node, target_node])
edge_attr.append(ocr_graph[i][j][1])
ocr_edge_index = torch.tensor(edge_index).t().long()
ocr_edge_attr = torch.tensor(edge_attr).to(torch.float).unsqueeze(1)
x = torch.arange(0, 27)
ocr_x = torch.nn.functional.one_hot(x, num_classes=27).to(torch.float)
return ocr_x.to(device), ocr_edge_index.to(device), ocr_edge_attr.to(device)
def pose_edge_index():
return torch.tensor(
[[17, 15, 15, 0, 0, 16, 16, 18, 0, 1, 4, 3, 3, 2, 2, 1, 1, 5, 5, 6, 6, 7, 1, 8, 8, 9, 9, 10, 10, 11, 11, 24, 11, 23, 23, 22, 8, 12, 12, 13, 13, 14, 14, 21, 14, 19, 19, 20],
[15, 17, 0, 15, 16, 0, 18, 16, 1, 0, 3, 4, 2, 3, 1, 2, 5, 1, 6, 5, 7, 6, 8, 1, 9, 8, 10, 9, 11, 10, 24, 11, 23, 11, 22, 23, 12, 8, 13, 12, 14, 13, 21, 14, 19, 14, 20, 19]],
dtype=torch.long)
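# Quick shape check (illustrative sketch, not part of the original file):
if __name__ == "__main__":
    ocr_x, ocr_edge_index, ocr_edge_attr = build_ocr_graph('cpu')
    print(ocr_x.shape, ocr_edge_index.shape, ocr_edge_attr.shape)  # (27, 27), (2, 34), (34, 1)
    print(pose_edge_index().shape)                                  # (2, 48)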