mtomnet/boss/models/tom_base.py

104 lines
3.6 KiB
Python
Raw Normal View History

2025-01-10 15:39:20 +01:00
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
import numpy as np
class BaseToMnet(nn.Module):
    """
    Base ToM net. Supports any subset of modalities.

    Builds shared encoders for RGB frames, a constant OCR graph and bounding
    boxes, then runs two ``MindNetLSTM`` heads -- one per person (pose/gaze
    index 0 feeds the left head, index 1 the right head) -- each followed by
    a 27-way linear output layer.

    Args:
        hidden_dim: Width of every projected feature stream.
        device: Device passed to ``build_ocr_graph`` for the OCR graph
            tensors. They are assigned as plain attributes (not registered
            buffers), so ``model.to(...)`` does not move them.
        resnet: If truthy, use a frozen ImageNet ResNet-34 trunk as the image
            encoder; otherwise use the project ``CNN``.
        dropout: Dropout probability for the RGB and OCR feature streams;
            also forwarded to both mind nets.
        mods: Modalities forwarded to both mind nets. ``None`` (default)
            means ``['rgb', 'pose', 'gaze', 'ocr', 'bbox']``.
    """

    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=None):
        super().__init__()
        # Build the default per call instead of sharing one mutable default list.
        if mods is None:
            mods = ['rgb', 'pose', 'gaze', 'ocr', 'bbox']
        # ---- Images ----#
        if resnet:
            backbone = models.resnet34(weights="IMAGENET1K_V1")
            # Drop the last child (the classifier); the remaining trunk feeds rgb_ff (512 in).
            self.cnn = nn.Sequential(*list(backbone.children())[:-1])
            # The pretrained trunk is used as a frozen feature extractor.
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        # ---- OCR and bbox -----#
        self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
        # in_channels=-1 lets GCNConv infer the input width lazily on first call.
        self.ocr_gnn = GCNConv(-1, hidden_dim)
        self.bbox_ff = nn.Linear(108, hidden_dim)
        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device
        # ---- Mind nets ----#
        self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
        self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
        self.left = nn.Linear(hidden_dim, 27)
        self.right = nn.Linear(hidden_dim, 27)
        # Initialise each output bias to log(bias vector from .utils) -- presumably
        # class priors; confirm in .utils. copy_ keeps the Parameter's dtype/shape
        # (and fails loudly on a length mismatch) instead of replacing .data.
        with torch.no_grad():
            self.left.bias.copy_(torch.as_tensor(left_bias, dtype=self.left.bias.dtype).log())
            self.right.bias.copy_(torch.as_tensor(right_bias, dtype=self.right.bias.dtype).log())

    def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
        """
        Run both mind nets over a sequence.

        Args:
            images: 5-D tensor unpacked as (batch, seq, channels, height, width).
            poses: Indexed ``[:, :, person, :]``; person 0 -> left, 1 -> right.
            gazes: Indexed the same way as ``poses``.
            bboxes: Fed to ``bbox_ff``, so its last dim must be 108.
            ocr_tensor: Unused; the constant OCR graph built in ``__init__``
                is used instead. Kept for interface compatibility.

        Returns:
            ``(left_logits, right_logits, extras)`` where ``extras`` is
            ``[out_left, cell_left, out_right, cell_right] + feats_left + feats_right``.
        """
        assert images is not None
        assert poses is not None
        assert gazes is not None
        assert bboxes is not None
        # Full 5-way unpack keeps the implicit check that images is 5-D.
        batch_size, sequence_len, _, _, _ = images.shape
        bbox_feat = self.act(self.bbox_ff(bboxes))
        # Encode one time step at a time, flatten each to (batch, -1), restack on dim 1.
        rgb_feat = torch.stack(
            [self.cnn(images[:, t]).view(batch_size, -1) for t in range(sequence_len)],
            dim=1,
        )
        rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
        # NOTE(review): GCNConv's third positional argument is edge_weight (1-D per edge);
        # confirm build_ocr_graph's edge_attr has that shape.
        ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
                                                      self.ocr_edge_index,
                                                      self.ocr_edge_attr)))
        # Mean-pool over graph nodes (dim 0) and tile to (batch, seq, hidden).
        ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
        out_left, cell_left, feats_left = self.mind_net_left(
            rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
        out_right, cell_right, feats_right = self.mind_net_right(
            rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
        extras = [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
        return self.left(out_left), self.right(out_right), extras
def count_parameters(model):
    """
    Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted, so frozen weights
    (e.g. a frozen pretrained trunk) are excluded.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        int: Total element count of the trainable parameters (a plain ``int``;
        0-d parameters count as 1).
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
if __name__ == "__main__":
    # Smoke test: batch of 3, 22 time steps, two people per frame.
    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 2, 75)
    gazes = torch.ones(3, 22, 2, 3)
    bboxes = torch.ones(3, 22, 108)
    model = BaseToMnet(64, 'cpu', False, 0.5)
    print(count_parameters(model))
    # No leftover breakpoint(): the script runs straight through the forward pass.
    out = model(images, poses, gazes, bboxes, None)
    print(out[0].shape)