Commit 25b8b3f343 (parent d4aaf7f4ad): 55 changed files with 7592 additions and 4 deletions
156  tbd/models/base.py  Normal file
@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
from .utils import pose_edge_index
from torch_geometric.nn import GCNConv


class PreNorm(nn.Module):
    """Applies LayerNorm to the input before the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP with GELU that preserves the feature dimension."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


class CNN(nn.Module):
    """Small conv encoder mapping a 3-channel image to a hidden_dim descriptor."""

    def __init__(self, hidden_dim):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv3(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:])  # global max pooling
        return x


class MindNetLSTM(nn.Module):
    """
    Basic MindNet for model-based ToM: an LSTM over the concatenation
    of the per-modality features.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super(MindNetLSTM, self).__init__()
        self.mods = mods
        if 'rgb_1' in mods:
            self.img_emb = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        if 'gaze' in mods:
            self.gaze_emb = nn.Linear(2, hidden_dim)
        if 'pose' in mods:
            self.pose_edge_index = pose_edge_index()
            self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim*len(mods),
            nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim,
                    batch_first=True, bidirectional=True))
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)  # the bidirectional LSTM doubles the feature dim
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
        feats = []
        if 'rgb_3' in self.mods:
            feats.append(rgb_3rd_pov_feats)
        if 'rgb_1' in self.mods:
            rgb_feat = []
            for i in range(rgb_1st_pov.shape[1]):  # embed each frame independently
                images_i = rgb_1st_pov[:, i]
                img_i_feat = self.img_emb(images_i)
                img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
            feats.append(rgb_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)  # mean over the 26 joints
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats.mean(2))  # mean over the object boxes
        lstm_inp = torch.cat(feats, 2)
        lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
        c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)  # average the two directions -> (bs, 1, hidden_dim)
        return self.act(self.proj(lstm_out)), c_n, feats


class MindNetSL(nn.Module):
    """
    Basic MindNet for SL ToM: an LSTM over the concatenation
    of the per-modality features.
    """

    def __init__(self, hidden_dim, dropout, mods):
        super(MindNetSL, self).__init__()
        self.mods = mods
        if 'rgb_1' in mods:
            self.img_emb = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        if 'gaze' in mods:
            self.gaze_emb = nn.Linear(2, hidden_dim)
        if 'pose' in mods:
            self.pose_edge_index = pose_edge_index()
            self.pose_emb = GCNConv(3, hidden_dim)
        self.LSTM = PreNorm(
            hidden_dim*len(mods),
            nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim,
                    batch_first=True, bidirectional=True))
        self.proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
        feats = []
        if 'rgb_3' in self.mods:
            feats.append(rgb_3rd_pov_feats)
        if 'rgb_1' in self.mods:
            rgb_feat = []
            for i in range(rgb_1st_pov.shape[1]):
                images_i = rgb_1st_pov[:, i]
                img_i_feat = self.img_emb(images_i)
                img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
            feats.append(rgb_feats)
        if 'pose' in self.mods:
            bs, seq_len = pose.size(0), pose.size(1)
            self.pose_edge_index = self.pose_edge_index.to(pose.device)
            pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
            pose_emb = self.dropout(self.act(pose_emb))
            pose_emb = torch.mean(pose_emb, dim=1)
            hd = pose_emb.size(-1)
            feats.append(pose_emb.view(bs, seq_len, hd))
        if 'gaze' in self.mods:
            gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
            feats.append(gaze_feats)
        if 'bbox' in self.mods:
            feats.append(bbox_feats.mean(2))
        lstm_inp = torch.cat(feats, 2)
        lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
        return self.act(self.proj(lstm_out)), feats
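
A minimal smoke test for MindNetLSTM (a sketch, not part of this commit; the dummy shapes are assumptions taken from the __main__ blocks of the model files below, with a 26-joint pose matching pose_edge_index in tbd/models/utils.py):

    import torch
    from tbd.models.base import MindNetLSTM

    net = MindNetLSTM(hidden_dim=64, dropout=0.1, mods=['rgb_3', 'pose', 'gaze'])
    rgb_3rd_pov_feats = torch.ones(2, 5, 64)  # (bs, seq_len, hidden_dim) precomputed 3rd-POV features
    pose = torch.ones(2, 5, 26, 3)            # (bs, seq_len, joints, xyz)
    gaze = torch.ones(2, 5, 2)                # (bs, seq_len, gaze xy)
    out, c_n, feats = net(rgb_3rd_pov_feats, None, None, pose, gaze)
    print(out.shape, c_n.shape)               # torch.Size([2, 5, 64]) torch.Size([2, 1, 64])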
157  tbd/models/common_mind.py  Normal file
@@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention


class CommonMindToMnet(nn.Module):
    """
    Two MindNetLSTMs whose final cell states are fused into a common mind.

    Expected input shapes (see the __main__ block below):
        img:  (bs, seq_len, 3, 128, 128)
        pose: (bs, seq_len, 26, 3)
        gaze: (bs, seq_len, 2)   NOTE: only the tracker has gaze
        bbox: (bs, seq_len, num_boxes, 4)
    """

    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum',
                 mods=['rgb_1', 'rgb_3', 'pose', 'gaze', 'bbox']):
        super(CommonMindToMnet, self).__init__()

        self.aggr = aggr
        self.mods = mods

        # ---- 3rd POV images, objects and bboxes ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            #for param in self.cnn.parameters():
            #    param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        self.bbox_ff = nn.Linear(4, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
        self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
        if aggr != 'no_tom':
            self.cm_proj = nn.Linear(hidden_dim*2, hidden_dim)
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            self.attn_left = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
            self.attn_right = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
        self.m1 = nn.Linear(hidden_dim, 4)
        self.m2 = nn.Linear(hidden_dim, 4)
        self.m12 = nn.Linear(hidden_dim, 4)
        self.m21 = nn.Linear(hidden_dim, 4)
        self.mc = nn.Linear(hidden_dim, 4)

    def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):

        batch_size, sequence_len, channels, height, width = img_3rd_pov.shape

        if 'bbox' in self.mods:
            bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
        else:
            bbox_feat = None

        if 'rgb_3' in self.mods:
            rgb_feat = []
            for i in range(sequence_len):
                images_i = img_3rd_pov[:, i]
                img_i_feat = self.cnn(images_i)
                img_i_feat = img_i_feat.view(batch_size, -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
        else:
            rgb_feat_3rd_pov = None

        # mind_net_1 models the tracker (who has gaze), mind_net_2 the battery
        if tracker_id == 'skele1':
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
        else:
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)

        if self.aggr == 'no_tom':
            m1 = self.m1(out_1).mean(1)
            m2 = self.m2(out_2).mean(1)
            m12 = self.m12(out_1).mean(1)
            m21 = self.m21(out_2).mean(1)
            mc = self.mc(out_1*out_2).mean(1)  # NOTE: with no_tom, mc is computed from the element-wise product of out_1 and out_2

            return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2

        common_mind = self.cm_proj(torch.cat([cell_1, cell_2], -1))  # (bs, 1, h)

        if self.aggr == 'attn':
            p1 = self.attn_left(x=out_1, context=common_mind)
            p2 = self.attn_right(x=out_2, context=common_mind)
        elif self.aggr == 'mult':
            p1 = out_1 * common_mind
            p2 = out_2 * common_mind
        elif self.aggr == 'sum':
            p1 = out_1 + common_mind
            p2 = out_2 + common_mind
        elif self.aggr == 'concat':
            p1 = torch.cat([out_1, common_mind], 1)
            p2 = torch.cat([out_2, common_mind], 1)
        else:
            raise ValueError(f'Unknown aggregation: {self.aggr}')
        p1 = self.act(p1)
        p1 = self.ln_1(p1)
        p2 = self.act(p2)
        p2 = self.ln_2(p2)
        # identical heads for every aggregation
        m1 = self.m1(p1).mean(1)
        m2 = self.m2(p2).mean(1)
        m12 = self.m12(p1).mean(1)
        m21 = self.m21(p2).mean(1)
        mc = self.mc(p1*p2).mean(1)  # NOTE: here I multiply p1 and p2

        return m1, m2, m12, m21, mc, [out_1, out_2, common_mind] + feats_1 + feats_2


if __name__ == "__main__":

    img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
    img_tracker = torch.ones(3, 5, 3, 128, 128)
    img_battery = torch.ones(3, 5, 3, 128, 128)
    pose1 = torch.ones(3, 5, 26, 3)
    pose2 = torch.ones(3, 5, 26, 3)
    bbox = torch.ones(3, 5, 13, 4)
    tracker_id = 'skele1'
    gaze = torch.ones(3, 5, 2)
    mods = ['pose', 'bbox', 'rgb_3']

    for agg in ['no_tom']:
        model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
        out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)

        print(out[0].shape)
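
The __main__ block above only smoke-tests the no_tom path. A hedged variant that covers every aggregation (same dummy inputs; the attn path additionally exercises memory_efficient_attention_pytorch):

    for agg in ['no_tom', 'sum', 'mult', 'concat', 'attn']:
        model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
        out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
        print(agg, out[0].shape)  # torch.Size([3, 4]) for every aggregation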
151  tbd/models/implicit.py  Normal file
@@ -0,0 +1,151 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention


class ImplicitToMnet(nn.Module):
    """
    Implicit ToM net. Supports any subset of modalities.
    Possible aggregations: no_tom, sum, mult, attn, concat; each agent's
    LSTM output is fused with the other agent's final cell state.
    """

    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum',
                 mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
        super(ImplicitToMnet, self).__init__()

        self.aggr = aggr
        self.mods = mods

        # ---- 3rd POV images, objects and bboxes ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            for param in self.cnn.parameters():  # freeze the pretrained backbone
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        self.bbox_ff = nn.Linear(4, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
        self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        if aggr == 'attn':
            self.attn_left = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
            self.attn_right = Attention(
                dim=hidden_dim,
                dim_head=hidden_dim // 4,
                heads=4,
                memory_efficient=True,
                q_bucket_size=hidden_dim,
                k_bucket_size=hidden_dim)
        self.m1 = nn.Linear(hidden_dim, 4)
        self.m2 = nn.Linear(hidden_dim, 4)
        self.m12 = nn.Linear(hidden_dim, 4)
        self.m21 = nn.Linear(hidden_dim, 4)
        self.mc = nn.Linear(hidden_dim, 4)

    def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):

        batch_size, sequence_len, channels, height, width = img_3rd_pov.shape

        if 'bbox' in self.mods:
            bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
        else:
            bbox_feat = None

        if 'rgb_3' in self.mods:
            rgb_feat = []
            for i in range(sequence_len):
                images_i = img_3rd_pov[:, i]
                img_i_feat = self.cnn(images_i)
                img_i_feat = img_i_feat.view(batch_size, -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
        else:
            rgb_feat_3rd_pov = None

        # mind_net_1 models the tracker (who has gaze), mind_net_2 the battery
        if tracker_id == 'skele1':
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
        else:
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)

        if self.aggr == 'no_tom':
            m1 = self.m1(out_1).mean(1)
            m2 = self.m2(out_2).mean(1)
            m12 = self.m12(out_1).mean(1)
            m21 = self.m21(out_2).mean(1)
            mc = self.mc(out_1*out_2).mean(1)  # NOTE: with no_tom, mc is computed from the element-wise product of out_1 and out_2

            return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2

        if self.aggr == 'attn':
            p1 = self.attn_left(x=out_1, context=cell_2)
            p2 = self.attn_right(x=out_2, context=cell_1)
        elif self.aggr == 'mult':
            p1 = out_1 * cell_2
            p2 = out_2 * cell_1
        elif self.aggr == 'sum':
            p1 = out_1 + cell_2
            p2 = out_2 + cell_1
        elif self.aggr == 'concat':
            p1 = torch.cat([out_1, cell_2], 1)
            p2 = torch.cat([out_2, cell_1], 1)
        else:
            raise ValueError(f'Unknown aggregation: {self.aggr}')
        p1 = self.act(p1)
        p1 = self.ln_1(p1)
        p2 = self.act(p2)
        p2 = self.ln_2(p2)
        # identical heads for every aggregation
        m1 = self.m1(p1).mean(1)
        m2 = self.m2(p2).mean(1)
        m12 = self.m12(p1).mean(1)
        m21 = self.m21(p2).mean(1)
        mc = self.mc(p1*p2).mean(1)  # NOTE: here I multiply p1 and p2

        return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2


if __name__ == "__main__":

    img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
    img_tracker = torch.ones(3, 5, 3, 128, 128)
    img_battery = torch.ones(3, 5, 3, 128, 128)
    pose1 = torch.ones(3, 5, 26, 3)
    pose2 = torch.ones(3, 5, 26, 3)
    bbox = torch.ones(3, 5, 13, 4)
    tracker_id = 'skele1'
    gaze = torch.ones(3, 5, 2)

    for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
        model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg)
        out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)

        print(agg, out[0].shape)
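
For orientation, the tensor shapes through the sum aggregation above, using the dummy inputs from the __main__ block (bs=3, seq_len=5, hidden_dim=64):

    # out_1, out_2:   (3, 5, 64)  per-agent LSTM outputs after proj
    # cell_1, cell_2: (3, 1, 64)  direction-averaged LSTM cell states
    # p1 = out_1 + cell_2         broadcasts over the sequence dim -> (3, 5, 64)
    # m1 = self.m1(p1).mean(1)    -> (3, 4) logits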
112  tbd/models/sl.py  Normal file
@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetSL


class SLToMnet(nn.Module):
    """
    Speaker-Listener ToMnet: each mind's log-softmax ranking is combined
    with the other mind's ranking, weighted by tom_weight.
    """

    def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1,
                 mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
        super(SLToMnet, self).__init__()

        self.tom_weight = tom_weight
        self.mods = mods

        # ---- 3rd POV images, objects and bboxes ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            for param in self.cnn.parameters():  # freeze the pretrained backbone
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        self.bbox_ff = nn.Linear(4, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_1 = MindNetSL(hidden_dim, dropout, mods=mods)
        self.mind_net_2 = MindNetSL(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
        self.m1 = nn.Linear(hidden_dim, 4)
        self.m2 = nn.Linear(hidden_dim, 4)
        self.m12 = nn.Linear(hidden_dim, 4)
        self.m21 = nn.Linear(hidden_dim, 4)
        self.mc = nn.Linear(hidden_dim, 4)

    def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):

        batch_size, sequence_len, channels, height, width = img_3rd_pov.shape

        if 'bbox' in self.mods:
            bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
        else:
            bbox_feat = None

        if 'rgb_3' in self.mods:
            rgb_feat = []
            for i in range(sequence_len):
                images_i = img_3rd_pov[:, i]
                img_i_feat = self.cnn(images_i)
                img_i_feat = img_i_feat.view(batch_size, -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
        else:
            rgb_feat_3rd_pov = None

        if tracker_id == 'skele1':
            out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
            out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
        else:
            out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
            out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)

        m1_logits = self.m1(out_1).mean(1)
        m2_logits = self.m2(out_2).mean(1)
        m12_logits = self.m12(out_1).mean(1)
        m21_logits = self.m21(out_2).mean(1)
        mc_logits = self.mc(out_1*out_2).mean(1)

        m1_ranking = torch.log_softmax(m1_logits, dim=-1)
        m2_ranking = torch.log_softmax(m2_logits, dim=-1)
        m12_ranking = torch.log_softmax(m12_logits, dim=-1)
        m21_ranking = torch.log_softmax(m21_logits, dim=-1)
        mc_ranking = torch.log_softmax(mc_logits, dim=-1)

        # NOTE: does this make sense?
        m1 = m1_ranking + self.tom_weight * m2_ranking
        m2 = m2_ranking + self.tom_weight * m1_ranking
        m12 = m12_ranking + self.tom_weight * m21_ranking
        m21 = m21_ranking + self.tom_weight * m12_ranking
        mc = mc_ranking + self.tom_weight * mc_ranking

        return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2


if __name__ == "__main__":

    img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
    img_tracker = torch.ones(3, 5, 3, 128, 128)
    img_battery = torch.ones(3, 5, 3, 128, 128)
    pose1 = torch.ones(3, 5, 26, 3)
    pose2 = torch.ones(3, 5, 26, 3)
    bbox = torch.ones(3, 5, 13, 4)
    tracker_id = 'skele1'
    gaze = torch.ones(3, 5, 2)

    model = SLToMnet(hidden_dim=64, device='cpu', tom_weight=2.0, resnet=False, dropout=0.5)
    out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)

    print(out[0].shape)
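
One observation on the NOTE in the forward pass above: mc is mixed with itself, so its line reduces to (1 + tom_weight) * mc_ranking, a rescale that leaves the ranking order unchanged, unlike the m1/m2 and m12/m21 pairs. A quick check (pure arithmetic, not part of the commit):

    import torch
    r = torch.log_softmax(torch.tensor([[2.0, 0.0, 1.0, -1.0]]), dim=-1)
    print(torch.allclose(r + 2.0 * r, 3.0 * r))  # True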
112  tbd/models/tom_base.py  Normal file
@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
import numpy as np


class ImplicitToMnet(nn.Module):
    """
    Baseline ToM net. Supports any subset of modalities; the two minds are
    never fused (the aggregation is fixed to no_tom).
    """

    def __init__(self, hidden_dim, device, resnet=False, dropout=0.1,
                 mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
        super(ImplicitToMnet, self).__init__()

        self.aggr = 'no_tom'  # this baseline takes no aggr argument; fixed so forward() works
        self.mods = mods

        # ---- 3rd POV images, objects and bboxes ----#
        if resnet:
            resnet = models.resnet34(weights="IMAGENET1K_V1")
            self.cnn = nn.Sequential(
                *(list(resnet.children())[:-1])
            )
            for param in self.cnn.parameters():  # freeze the pretrained backbone
                param.requires_grad = False
            self.rgb_ff = nn.Linear(512, hidden_dim)
        else:
            self.cnn = CNN(hidden_dim)
            self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
        self.bbox_ff = nn.Linear(4, hidden_dim)

        # ---- Others ----#
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.device = device

        # ---- Mind nets ----#
        self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
        self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])

        self.m1 = nn.Linear(hidden_dim, 4)
        self.m2 = nn.Linear(hidden_dim, 4)
        self.m12 = nn.Linear(hidden_dim, 4)
        self.m21 = nn.Linear(hidden_dim, 4)
        self.mc = nn.Linear(hidden_dim, 4)

    def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):

        batch_size, sequence_len, channels, height, width = img_3rd_pov.shape

        if 'bbox' in self.mods:
            bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
        else:
            bbox_feat = None

        if 'rgb_3' in self.mods:
            rgb_feat = []
            for i in range(sequence_len):
                images_i = img_3rd_pov[:, i]
                img_i_feat = self.cnn(images_i)
                img_i_feat = img_i_feat.view(batch_size, -1)
                rgb_feat.append(img_i_feat)
            rgb_feat = torch.stack(rgb_feat, 1)
            rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
        else:
            rgb_feat_3rd_pov = None

        if tracker_id == 'skele1':
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
        else:
            out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
            out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)

        if self.aggr == 'no_tom':
            m1 = self.m1(out_1).mean(1)
            m2 = self.m2(out_2).mean(1)
            m12 = self.m12(out_1).mean(1)
            m21 = self.m21(out_2).mean(1)
            mc = self.mc(out_1*out_2).mean(1)  # NOTE: with no_tom, mc is computed from the element-wise product of out_1 and out_2

            return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2


def count_parameters(model):
    #return sum(p.numel() for p in model.parameters() if p.requires_grad)
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])


if __name__ == "__main__":

    img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
    img_tracker = torch.ones(3, 5, 3, 128, 128)
    img_battery = torch.ones(3, 5, 3, 128, 128)
    pose1 = torch.ones(3, 5, 26, 3)
    pose2 = torch.ones(3, 5, 26, 3)
    bbox = torch.ones(3, 5, 13, 4)
    tracker_id = 'skele1'
    gaze = torch.ones(3, 5, 2)

    model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5)
    print(count_parameters(model))
    # breakpoint()  # debugging leftover

    # this baseline has no aggregation choice, so a single run suffices
    out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
    print(out[0].shape)
7  tbd/models/utils.py  Normal file
@@ -0,0 +1,7 @@
import torch


def pose_edge_index():
    """Edge index for the 26-joint skeleton: 28 edges, stored in both directions."""
    start = [15, 14, 13, 12, 19, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 3, 4, 5, 6, 8, 8, 4, 20, 21, 21, 22, 24, 22]
    end = [14, 13, 12, 0, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 4, 20, 20, 21, 22, 24, 23, 25, 24]
    return torch.tensor([start+end, end+start])
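
A small sanity check of the edge index (a sketch, not part of the commit): the two lists above define 28 skeleton edges, and both directions are stacked so the GCN treats the graph as undirected:

    import torch
    from tbd.models.utils import pose_edge_index

    ei = pose_edge_index()
    print(ei.shape)       # torch.Size([2, 56])
    print(int(ei.max()))  # 25, consistent with the 26-joint poses used by the models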