# mtomnet/tbd/models/common_mind.py
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention
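
# NOTE: the relative import of CNN/MindNetLSTM means the __main__ demo below
# must be run as a module from the package root (e.g. `python -m
# tbd.models.common_mind`; the exact module path is an assumption based on
# this file's location), not as a standalone script.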


class CommonMindToMnet(nn.Module):
    """
    Expected input shapes (see the demo under __main__):
        img:  (bs, seq_len, 3, 128, 128)
        pose: (bs, seq_len, 26, 3)
        gaze: (bs, seq_len, 2)   NOTE: only the tracker has gaze
        bbox: (bs, seq_len, num_objects, 4)
    """
    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_1', 'rgb_3', 'pose', 'gaze', 'bbox']):
        super().__init__()

        self.aggr = aggr
        self.mods = mods

        # ---- 3rd POV images, objects and bbox ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            # Drop the final FC layer; the remaining stack pools to a 512-d feature.
            self.cnn = nn.Sequential(*(list(resnet.children())[:-1]))
            # for param in self.cnn.parameters():
            #     param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        self.bbox_ff = nn.Linear(4, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
        # mind_net_2 never receives gaze: per the docstring, only the tracker has gaze.
        self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
        if aggr != 'no_tom':
            self.cm_proj = nn.Linear(hidden_dim * 2, hidden_dim)
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            self.attn_left = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
            self.attn_right = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
        self.m1 = nn.Linear(hidden_dim, 4)
        self.m2 = nn.Linear(hidden_dim, 4)
        self.m12 = nn.Linear(hidden_dim, 4)
        self.m21 = nn.Linear(hidden_dim, 4)
        self.mc = nn.Linear(hidden_dim, 4)
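
    # Reading of the five heads, inferred from their use in forward() (the
    # original file does not document them): m1/m2 look like each agent's own
    # mental-state head, m12/m21 each agent's belief about the other, and mc
    # a shared "common mind" head; each emits 4 logits per time step, which
    # forward() averages over time.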
    def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
        batch_size, sequence_len, channels, height, width = img_3rd_pov.shape

        if 'bbox' in self.mods:
            bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
        else:
            bbox_feat = None

        if 'rgb_3' in self.mods:
            # Encode each frame of the 3rd-person view independently, then
            # stack along the time dimension.
            rgb_feat = []
            for i in range(sequence_len):
                images_i = img_3rd_pov[:, i]
                img_i_feat = self.cnn(images_i)
                img_i_feat = img_i_feat.view(batch_size, -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
        else:
            rgb_feat_3rd_pov = None

        # mind_net_1 models the tracker (it receives img_tracker and gaze);
        # tracker_id selects which skeleton's pose belongs to the tracker.
        if tracker_id == 'skele1':
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
        else:
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)

        if self.aggr == 'no_tom':
            m1 = self.m1(out_1).mean(1)
            m2 = self.m2(out_2).mean(1)
            m12 = self.m12(out_1).mean(1)
            m21 = self.m21(out_2).mean(1)
            mc = self.mc(out_1 * out_2).mean(1)  # NOTE: with no_tom there is no common mind; mc comes from the element-wise product of out_1 and out_2
            return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2

        # Fuse the two mind nets' final LSTM cell states into a common mind.
        common_mind = self.cm_proj(torch.cat([cell_1, cell_2], -1))  # (bs, 1, h)

        if self.aggr == 'attn':
            p1 = self.attn_left(x=out_1, context=common_mind)
            p2 = self.attn_right(x=out_2, context=common_mind)
        elif self.aggr == 'mult':
            p1 = out_1 * common_mind
            p2 = out_2 * common_mind
        elif self.aggr == 'sum':
            p1 = out_1 + common_mind
            p2 = out_2 + common_mind
        elif self.aggr == 'concat':
            p1 = torch.cat([out_1, common_mind], 1)
            p2 = torch.cat([out_2, common_mind], 1)
        else:
            raise ValueError(f'Unknown aggregation mode: {self.aggr}')
        p1 = self.act(p1)
        p1 = self.ln_1(p1)
        p2 = self.act(p2)
        p2 = self.ln_2(p2)

        # Identical head computation for all remaining aggregation modes
        # ('attn', 'mult', 'sum', 'concat').
        m1 = self.m1(p1).mean(1)
        m2 = self.m2(p2).mean(1)
        m12 = self.m12(p1).mean(1)
        m21 = self.m21(p2).mean(1)
        mc = self.mc(p1 * p2).mean(1)  # NOTE: p1 and p2 are multiplied element-wise here

        return m1, m2, m12, m21, mc, [out_1, out_2, common_mind] + feats_1 + feats_2
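

# A minimal training sketch, kept as a comment so importing this module stays
# side-effect free. It assumes one integer class target in [0, 4) per head;
# the target tensor `targets` below is illustrative, not part of this repo:
#
#     criterion = nn.CrossEntropyLoss()
#     m1, m2, m12, m21, mc, _ = model(img_3rd_pov, img_tracker, img_battery,
#                                     pose1, pose2, bbox, tracker_id, gaze)
#     preds = [m1, m2, m12, m21, mc]            # each: (bs, 4) logits
#     loss = sum(criterion(p, t) for p, t in zip(preds, targets))
#     loss.backward()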


if __name__ == "__main__":
    img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
    img_tracker = torch.ones(3, 5, 3, 128, 128)
    img_battery = torch.ones(3, 5, 3, 128, 128)
    pose1 = torch.ones(3, 5, 26, 3)
    pose2 = torch.ones(3, 5, 26, 3)
    bbox = torch.ones(3, 5, 13, 4)
    tracker_id = 'skele1'
    gaze = torch.ones(3, 5, 2)
    mods = ['pose', 'bbox', 'rgb_3']
    for agg in ['no_tom']:
        model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
        out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
        print(out[0].shape)
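
    # A minimal sketch exercising the remaining aggregation modes on the same
    # toy inputs (the demo above only tests 'no_tom'); 'attn' exercises the
    # memory_efficient_attention_pytorch path imported at the top.
    for agg in ['sum', 'mult', 'concat', 'attn']:
        model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
        m1, m2, m12, m21, mc, feats = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
        print(agg, m1.shape, mc.shape)  # each head: torch.Size([3, 4])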