mtomnet/tbd/models/implicit.py
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention
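
# Implicit ToM network over two interacting agents: one MindNetLSTM per agent
# encodes the available modalities, and the two latent streams are aggregated
# (sum / mult / attn / concat) before five 4-way belief heads.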


class ImplicitToMnet(nn.Module):
    """
    Implicit ToM net. Supports any subset of modalities.
    Possible aggregations: sum, mult, attn, concat (pass 'no_tom' to skip
    cross-agent aggregation and predict from each mind net alone).
    """
    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum',
                 mods=('rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox')):
        super(ImplicitToMnet, self).__init__()
        self.aggr = aggr
        self.mods = list(mods)
        # ---- 3rd POV images, objects and bboxes ----#
        if resnet:
            # frozen ImageNet-pretrained ResNet-34 backbone, classification head removed
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(*(list(resnet.children())[:-1]))
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        self.bbox_ff = nn.Linear(4, hidden_dim)
        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device
        # ---- Mind nets ----#
        self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=list(mods))
        # the second agent's mind net never receives gaze
        self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            # cross-attention blocks: each agent's outputs attend over the other
            # agent's LSTM cell states (see forward)
            self.attn_left = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
            self.attn_right = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
        # belief heads (4-way each): m1, m2, m12, m21 and mc
        self.m1 = nn.Linear(hidden_dim, 4)
        self.m2 = nn.Linear(hidden_dim, 4)
        self.m12 = nn.Linear(hidden_dim, 4)
        self.m21 = nn.Linear(hidden_dim, 4)
        self.mc = nn.Linear(hidden_dim, 4)
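
    # Expected input shapes (inferred from the __main__ smoke test below; an
    # assumption, not a checked contract):
    #   img_3rd_pov / img_tracker / img_battery: (B, T, C, H, W)
    #   pose1 / pose2: (B, T, 26, 3); bbox: (B, T, 13, 4); gaze: (B, T, 2)
    # Each returned m* head has shape (B, 4).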
    def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
        batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
        if 'bbox' in self.mods:
            bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
        else:
            bbox_feat = None
        if 'rgb_3' in self.mods:
            # encode each 3rd-person frame independently, then restack along time
            rgb_feat = []
            for i in range(sequence_len):
                images_i = img_3rd_pov[:, i]
                img_i_feat = self.cnn(images_i)
                img_i_feat = img_i_feat.view(batch_size, -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
        else:
            rgb_feat_3rd_pov = None
        # tracker_id decides which skeleton belongs to the tracked agent: its pose
        # (plus gaze) goes to mind_net_1, the other skeleton to mind_net_2
        if tracker_id == 'skele1':
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
        else:
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
        if self.aggr == 'no_tom':
            # no cross-agent aggregation: each head reads one mind net directly
            m1 = self.m1(out_1).mean(1)
            m2 = self.m2(out_2).mean(1)
            m12 = self.m12(out_1).mean(1)
            m21 = self.m21(out_2).mean(1)
            mc = self.mc(out_1 * out_2).mean(1)  # NOTE: mc is computed from the element-wise product of out_1 and out_2
            return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
        if self.aggr == 'attn':
            p1 = self.attn_left(x=out_1, context=cell_2)
            p2 = self.attn_right(x=out_2, context=cell_1)
        elif self.aggr == 'mult':
            p1 = out_1 * cell_2
            p2 = out_2 * cell_1
        elif self.aggr == 'sum':
            p1 = out_1 + cell_2
            p2 = out_2 + cell_1
        elif self.aggr == 'concat':
            p1 = torch.cat([out_1, cell_2], 1)
            p2 = torch.cat([out_2, cell_1], 1)
        else:
            raise ValueError(f'Unknown aggregation: {self.aggr}')
        p1 = self.ln_1(self.act(p1))
        p2 = self.ln_2(self.act(p2))
        # the head computation is identical for all remaining aggregations
        # (attn, mult, sum, concat)
        m1 = self.m1(p1).mean(1)
        m2 = self.m2(p2).mean(1)
        m12 = self.m12(p1).mean(1)
        m21 = self.m21(p2).mean(1)
        mc = self.mc(p1 * p2).mean(1)  # NOTE: here p1 and p2 are multiplied element-wise
        return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2


if __name__ == "__main__":
    # smoke test: dummy tensors through every aggregation mode
    img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
    img_tracker = torch.ones(3, 5, 3, 128, 128)
    img_battery = torch.ones(3, 5, 3, 128, 128)
    pose1 = torch.ones(3, 5, 26, 3)
    pose2 = torch.ones(3, 5, 26, 3)
    bbox = torch.ones(3, 5, 13, 4)
    tracker_id = 'skele1'
    gaze = torch.ones(3, 5, 2)
    for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
        model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg)
        out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
        print(agg, out[0].shape)
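        # hypothetical extra check (not in the original script): all five belief
        # heads should share the same (batch, 4) shape under every aggregation
        assert all(o.shape == (3, 4) for o in out[:5])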