mtomnet/tbd/models/sl.py

112 lines
4.2 KiB
Python
Raw Permalink Normal View History

2025-01-10 15:39:20 +01:00
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetSL
class SLToMnet(nn.Module):
    """
    Speaker-Listener ToMnet.

    Two mind networks (``MindNetSL``) share the third-person-view RGB and
    bounding-box features. ``mind_net_1`` additionally receives the tracker
    view, one skeleton pose and the gaze. ``mind_net_2`` receives the battery
    view and the other pose, and no gaze. Five 4-way linear heads
    (m1, m2, m12, m21, mc) read the mind-net outputs.
    """

    # Modalities used when the caller does not pass ``mods``.
    DEFAULT_MODS = ('rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox')

    def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=None):
        """
        Args:
            hidden_dim: output width of the RGB/bbox projections; also the
                input width of the five 4-way heads.
            device: stored as ``self.device``; not otherwise used by this class.
            tom_weight: weight used when adding a head's log-softmax scores to
                its counterpart's (see ``forward``).
            resnet: if True, encode frames with an ImageNet-pretrained
                ResNet-34 whose final fc layer is dropped and whose parameters
                are frozen. Otherwise use the project ``CNN(hidden_dim)``.
            dropout: probability for this module's ``nn.Dropout``; also passed
                to both ``MindNetSL`` instances.
            mods: iterable of modality names. ``'bbox'`` and ``'rgb_3'``
                enable the corresponding encoders in ``forward``. The full
                list goes to ``mind_net_1``, and the list without ``'gaze'``
                goes to ``mind_net_2``. Defaults to ``DEFAULT_MODS``.
        """
        super().__init__()
        self.tom_weight = tom_weight
        # Previously a mutable list default shared by every instance (and
        # aliased into the mind nets). Build a fresh list per model instead.
        self.mods = list(self.DEFAULT_MODS if mods is None else mods)

        # ---- 3rd POV images, object and bbox ---- #
        if resnet:
            backbone = models.resnet34(weights="IMAGENET1K_V1")
            # Keep every child up to the global average pool; drop the final
            # classification (fc) layer.
            self.cnn = nn.Sequential(*list(backbone.children())[:-1])
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)  # ResNet-34 pooled features are 512-d
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        self.bbox_ff = nn.Linear(4, hidden_dim)

        # ---- Others ---- #
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ---- #
        self.mind_net_1 = MindNetSL(hidden_dim, dropout, mods=self.mods)
        self.mind_net_2 = MindNetSL(hidden_dim, dropout, mods=[m for m in self.mods if m != 'gaze'])

        # ---- Belief heads (4 classes each) ---- #
        self.m1 = nn.Linear(hidden_dim, 4)
        self.m2 = nn.Linear(hidden_dim, 4)
        self.m12 = nn.Linear(hidden_dim, 4)
        self.m21 = nn.Linear(hidden_dim, 4)
        self.mc = nn.Linear(hidden_dim, 4)

    def _encode_rgb(self, images):
        """Encode a (batch, seq, C, H, W) clip into projected per-frame features.

        Returns a (batch, seq, hidden_dim) tensor after ``rgb_ff``, GELU and
        dropout.
        """
        batch_size, sequence_len, _, _, _ = images.shape
        # Frames are fed to the CNN one time step at a time rather than
        # flattening batch and time together. This way any BatchNorm layers
        # (ResNet-34 has them; possibly the project CNN too) see exactly the
        # same per-step batches in training mode as before.
        frame_feats = [
            self.cnn(images[:, step]).reshape(batch_size, -1)
            for step in range(sequence_len)
        ]
        rgb_feat = torch.stack(frame_feats, 1)
        return self.dropout(self.act(self.rgb_ff(rgb_feat)))

    def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
        """
        Run both mind nets and the five belief heads.

        Args:
            img_3rd_pov: third-person frames, unpacked as (batch, seq, C, H, W).
                Only read when ``'rgb_3'`` is in ``self.mods``.
            img_tracker: passed unchanged to ``mind_net_1``.
            img_battery: passed unchanged to ``mind_net_2``.
            pose1, pose2: skeleton poses. ``tracker_id`` decides which one
                goes to ``mind_net_1`` (``'skele1'`` -> ``pose1``, anything
                else -> ``pose2``); the other goes to ``mind_net_2``.
            bbox: boxes with a last dimension of size 4 (fed to ``bbox_ff``).
                Only read when ``'bbox'`` is in ``self.mods``.
            tracker_id: selects the pose assignment described above.
            gaze: passed to ``mind_net_1`` only; ``mind_net_2`` gets ``None``.

        Returns:
            ``(m1, m2, m12, m21, mc, feats)``.
            - The first five are tensors of mixed log-softmax scores over
              4 classes.
            - ``feats`` is ``[out_1, out_2] + feats_1 + feats_2`` from the two
              mind nets.
        """
        # bbox is encoded before rgb (as originally), so dropout consumes the
        # RNG in the same order.
        bbox_feat = self.dropout(self.act(self.bbox_ff(bbox))) if 'bbox' in self.mods else None
        rgb_feat_3rd_pov = self._encode_rgb(img_3rd_pov) if 'rgb_3' in self.mods else None

        if tracker_id == 'skele1':
            tracker_pose, battery_pose = pose1, pose2
        else:
            tracker_pose, battery_pose = pose2, pose1
        out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, tracker_pose, gaze)
        out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, battery_pose, gaze=None)

        # Head logits are averaged over dim 1 before the log-softmax.
        # NOTE(review): dim 1 is presumably the sequence axis of MindNetSL's
        # output; confirm against MindNetSL.
        m1_ranking = torch.log_softmax(self.m1(out_1).mean(1), dim=-1)
        m2_ranking = torch.log_softmax(self.m2(out_2).mean(1), dim=-1)
        m12_ranking = torch.log_softmax(self.m12(out_1).mean(1), dim=-1)
        m21_ranking = torch.log_softmax(self.m21(out_2).mean(1), dim=-1)
        mc_ranking = torch.log_softmax(self.mc(out_1 * out_2).mean(1), dim=-1)

        # NOTE: does this make sense?
        # Each head is mixed with its counterpart (m1<->m2, m12<->m21). mc has
        # no counterpart, so this effectively scales it by (1 + tom_weight).
        m1 = m1_ranking + self.tom_weight * m2_ranking
        m2 = m2_ranking + self.tom_weight * m1_ranking
        m12 = m12_ranking + self.tom_weight * m21_ranking
        m21 = m21_ranking + self.tom_weight * m12_ranking
        mc = mc_ranking + self.tom_weight * mc_ranking
        return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
if __name__ == "__main__":
    # Smoke test: push constant dummy tensors through a small model that uses
    # the project CNN, then print the shape of the first output (m1).
    batch, steps = 3, 5
    frame_shape = (batch, steps, 3, 128, 128)
    pose_shape = (batch, steps, 26, 3)

    dummy_inputs = {
        'img_3rd_pov': torch.ones(frame_shape),
        'img_tracker': torch.ones(frame_shape),
        'img_battery': torch.ones(frame_shape),
        'pose1': torch.ones(pose_shape),
        'pose2': torch.ones(pose_shape),
        'bbox': torch.ones(batch, steps, 13, 4),
        'tracker_id': 'skele1',
        'gaze': torch.ones(batch, steps, 2),
    }
    net = SLToMnet(hidden_dim=64, device='cpu', tom_weight=2.0, resnet=False, dropout=0.5)
    outputs = net(**dummy_inputs)
    print(outputs[0].shape)