151 lines
5.8 KiB
Python
151 lines
5.8 KiB
Python
|
import torch
|
|||
|
import torch.nn as nn
|
|||
|
import torchvision.models as models
|
|||
|
from .base import CNN, MindNetLSTM
|
|||
|
from memory_efficient_attention_pytorch import Attention
|
|||
|
|
|||
|
|
|||
|
class ImplicitToMnet(nn.Module):
    """Implicit ToM net. Supports any subset of modalities.

    Two ``MindNetLSTM`` streams are built: ``mind_net_1`` receives every
    modality listed in ``mods``; ``mind_net_2`` receives the same list with
    ``'gaze'`` removed. Their outputs are combined according to ``aggr`` and
    decoded by five linear heads (``m1``, ``m2``, ``m12``, ``m21``, ``mc``),
    each producing 4 values.

    Possible aggregations:
        * ``'sum'``    - ``out_1 + cell_2`` / ``out_2 + cell_1``
        * ``'mult'``   - ``out_1 * cell_2`` / ``out_2 * cell_1``
        * ``'attn'``   - ``out_1`` attends to ``cell_2`` (and vice versa)
        * ``'concat'`` - ``torch.cat([out_1, cell_2], 1)`` (and vice versa)
        * ``'no_tom'`` - no cross-stream mixing; heads read the raw outputs
    """

    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum',
                 mods=('rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox')):
        """Build encoders, the two mind nets, optional attention and heads.

        Args:
            hidden_dim: Width of every intermediate feature.
            device: Stored as ``self.device``; not otherwise used here.
            resnet: If true, encode 3rd-person frames with a frozen
                ImageNet-pretrained ResNet-34 whose last child module (the
                classifier) is removed; otherwise use the project ``CNN``.
            dropout: Dropout probability, used here and by both mind nets.
            aggr: Aggregation mode (see class docstring). An unsupported
                value raises ``ValueError`` from ``forward``.
            mods: Iterable of modality names. Copied into a new list, so
                neither the default nor the caller's object is shared.
        """
        super().__init__()

        self.aggr = aggr
        # Fresh list: avoids the shared-mutable-default pitfall and aliasing
        # the caller's list.
        self.mods = list(mods)

        # ---- 3rd POV Images, object and bbox ----#
        if resnet:
            backbone = models.resnet34(weights="IMAGENET1K_V1")
            # Drop the final child (classifier); the remaining stack ends in
            # global pooling, whose 512 channels feed rgb_ff below.
            self.cnn = nn.Sequential(*list(backbone.children())[:-1])
            # Pretrained backbone is used as a frozen feature extractor.
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        self.bbox_ff = nn.Linear(4, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=self.mods)
        # The second stream never sees gaze.
        self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in self.mods if m != 'gaze'])
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.ln_2 = nn.LayerNorm(hidden_dim)

        if aggr == 'attn':
            def make_attention():
                return Attention(
                    dim=hidden_dim,
                    dim_head=hidden_dim // 4,
                    heads=4,
                    memory_efficient=True,
                    q_bucket_size=hidden_dim,
                    k_bucket_size=hidden_dim,
                )

            # Created in this order (left, then right) to keep parameter
            # initialisation identical to the original layout.
            self.attn_left = make_attention()
            self.attn_right = make_attention()

        # ---- Output heads (4 values each) ----#
        self.m1 = nn.Linear(hidden_dim, 4)
        self.m2 = nn.Linear(hidden_dim, 4)
        self.m12 = nn.Linear(hidden_dim, 4)
        self.m21 = nn.Linear(hidden_dim, 4)
        self.mc = nn.Linear(hidden_dim, 4)

    def _encode_bbox(self, bbox):
        """Project boxes (last dim 4) to ``hidden_dim``, then GELU + dropout."""
        return self.dropout(self.act(self.bbox_ff(bbox)))

    def _encode_rgb_3rd_pov(self, img_3rd_pov):
        """Encode 3rd-person frames, one time step (dim 1) at a time.

        Frames go through ``self.cnn`` step by step instead of as a single
        (batch * seq) batch, so batch-dependent layers (e.g. BatchNorm in
        training mode) see exactly the same batches as before.
        """
        batch_size = img_3rd_pov.shape[0]
        per_step = [
            self.cnn(frames).reshape(batch_size, -1)
            for frames in img_3rd_pov.unbind(1)
        ]
        rgb_feat = torch.stack(per_step, 1)
        return self.dropout(self.act(self.rgb_ff(rgb_feat)))

    def _aggregate(self, out_1, cell_1, out_2, cell_2):
        """Combine each stream's ``out`` with the other stream's ``cell``.

        Raises:
            ValueError: If ``self.aggr`` is not a supported mode.
        """
        if self.aggr == 'attn':
            return (self.attn_left(x=out_1, context=cell_2),
                    self.attn_right(x=out_2, context=cell_1))
        if self.aggr == 'mult':
            return out_1 * cell_2, out_2 * cell_1
        if self.aggr == 'sum':
            return out_1 + cell_2, out_2 + cell_1
        if self.aggr == 'concat':
            # Joins along dim 1, so the last dim stays hidden_dim and the
            # shared LayerNorms / heads still apply. If outputs are
            # (batch, seq, hidden) this is a concat over time - TODO confirm.
            return torch.cat([out_1, cell_2], 1), torch.cat([out_2, cell_1], 1)
        raise ValueError(
            f"Unknown aggregation {self.aggr!r}; expected one of "
            "'no_tom', 'sum', 'mult', 'attn', 'concat'"
        )

    def _heads(self, p1, p2):
        """Apply the five linear heads and average each result over dim 1."""
        m1 = self.m1(p1).mean(1)
        m2 = self.m2(p2).mean(1)
        m12 = self.m12(p1).mean(1)
        m21 = self.m21(p2).mean(1)
        # Joint head reads the element-wise product of the two streams.
        mc = self.mc(p1 * p2).mean(1)
        return m1, m2, m12, m21, mc

    def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
        """Run both mind nets, aggregate them and decode the five heads.

        Args:
            img_3rd_pov: 3rd-person frames indexed as (batch, seq, ...); only
                read when ``'rgb_3'`` is in ``self.mods`` (may be ``None``
                otherwise).
            img_tracker: Forwarded to ``mind_net_1``.
            img_battery: Forwarded to ``mind_net_2``.
            pose1, pose2: If ``tracker_id == 'skele1'``, ``pose1`` goes to
                ``mind_net_1`` and ``pose2`` to ``mind_net_2``; for any other
                id they are swapped.
            bbox: Boxes with last dim 4; only read when ``'bbox'`` is in
                ``self.mods`` (may be ``None`` otherwise).
            tracker_id: Selects the pose routing described above.
            gaze: Forwarded to ``mind_net_1`` only; ``mind_net_2`` gets None.

        Returns:
            ``(m1, m2, m12, m21, mc, extras)`` where each ``m*`` is a head
            output averaged over dim 1 and ``extras`` is
            ``[out_1, cell_1, out_2, cell_2] + feats_1 + feats_2``.

        Raises:
            ValueError: If ``self.aggr`` is not a supported mode.
        """
        # Order matters for RNG (dropout): bbox first, then rgb.
        bbox_feat = self._encode_bbox(bbox) if 'bbox' in self.mods else None
        rgb_feat_3rd_pov = self._encode_rgb_3rd_pov(img_3rd_pov) if 'rgb_3' in self.mods else None

        if tracker_id == 'skele1':
            pose_a, pose_b = pose1, pose2
        else:
            pose_a, pose_b = pose2, pose1

        out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose_a, gaze)
        out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose_b, gaze=None)
        extras = [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2

        if self.aggr == 'no_tom':
            # No cross-stream mixing: heads read the raw mind-net outputs.
            return (*self._heads(out_1, out_2), extras)

        p1, p2 = self._aggregate(out_1, cell_1, out_2, cell_2)
        p1 = self.ln_1(self.act(p1))
        p2 = self.ln_2(self.act(p2))
        return (*self._heads(p1, p2), extras)
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Smoke test: push all-ones dummy inputs through a freshly built model
    # for every aggregation mode and print the shape of the first head.
    n_batch, n_steps = 3, 5
    frame_shape = (n_batch, n_steps, 3, 128, 128)
    pose_shape = (n_batch, n_steps, 26, 3)

    dummy_inputs = (
        torch.ones(*frame_shape),               # img_3rd_pov
        torch.ones(*frame_shape),               # img_tracker
        torch.ones(*frame_shape),               # img_battery
        torch.ones(*pose_shape),                # pose1
        torch.ones(*pose_shape),                # pose2
        torch.ones(n_batch, n_steps, 13, 4),    # bbox
        'skele1',                               # tracker_id
        torch.ones(n_batch, n_steps, 2),        # gaze
    )

    for aggregation in ('no_tom', 'concat', 'sum', 'mult', 'attn'):
        net = ImplicitToMnet(
            hidden_dim=64,
            device='cpu',
            resnet=False,
            dropout=0.5,
            aggr=aggregation,
        )
        outputs = net(*dummy_inputs)
        print(aggregation, outputs[0].shape)
|