Matteo Bortoletto 2025-01-10 15:39:20 +01:00
parent d4aaf7f4ad
commit 25b8b3f343
55 changed files with 7592 additions and 4 deletions

196
tbd/.gitignore vendored Normal file

@@ -0,0 +1,196 @@
experiments
wandb
predictions
# Created by https://www.toptal.com/developers/gitignore/api/python,linux
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python,linux

16
tbd/README.md Normal file

@@ -0,0 +1,16 @@
# TBD
## Data
The original code can be found [here](https://github.com/LifengFan/Triadic-Belief-Dynamics). The dataset is not directly downloadable; it must be requested via the Google form linked in the [README](https://github.com/LifengFan/Triadic-Belief-Dynamics?tab=readme-ov-file#dataset).
## Installing Dependencies
Run `conda env create -f environment.yml`.
## Train
Run `source run_train.sh`.
## Test
Run `source run_test.sh`. **Make sure to use the same random seed that was used for training**; otherwise the splits will differ and you will likely get data leakage.
## Visualisations
The plots are made using `utils/fb_scores_err.py` (false belief analysis) and `utils/similarity.py` (PCA of latent representations).
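The splits are not stored on disk: `tbd_dataloader.TBDDataset` shuffles the samples with Python's global `random` state and then takes a 60/20/20 train/val/test split, so the test split is a pure function of the seed (assuming the entry points seed `random` from `--seed`, as the run scripts suggest). A minimal sketch of why matching seeds matters, with hypothetical toy data:

```python
import random

def make_splits(samples, seed):
    # Mirrors the split logic in TBDDataset: shuffle, then cut 60/20/20.
    random.seed(seed)
    samples = list(samples)
    random.shuffle(samples)
    n_train = int(len(samples) * 0.6)
    n_val = int(len(samples) * 0.2)
    return samples[:n_train], samples[n_train:n_train + n_val], samples[n_train + n_val:]

_, _, test_a = make_splits(range(100), seed=1)
_, _, test_b = make_splits(range(100), seed=1)
assert test_a == test_b  # same seed -> same test split; a different seed would leak training clips
```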

100
tbd/environment.yml Normal file

@@ -0,0 +1,100 @@
name: tbd
channels:
- conda-forge
- defaults
- pytorch
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- ca-certificates=2023.01.10=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- ncurses=6.4=h6a678d5_0
- openssl=1.1.1t=h7f8727e_0
- pip=23.0.1=py38h06a4308_0
- python=3.8.10=h12debd9_8
- readline=8.2=h5eee18b_0
- setuptools=66.0.0=py38h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.38.4=py38h06a4308_0
- xz=5.4.2=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- appdirs==1.4.4
- beautifulsoup4==4.12.2
- certifi==2023.5.7
- charset-normalizer==3.1.0
- click==8.1.3
- cmake==3.26.4
- contourpy==1.1.0
- cycler==0.11.0
- docker-pycreds==0.4.0
- einops==0.6.1
- filelock==3.12.0
- fonttools==4.40.0
- gdown==4.7.1
- gitdb==4.0.10
- gitpython==3.1.31
- idna==3.4
- importlib-resources==5.12.0
- jinja2==3.1.2
- joblib==1.3.1
- kiwisolver==1.4.4
- lit==16.0.6
- markupsafe==2.1.3
- matplotlib==3.7.1
- memory-efficient-attention-pytorch==0.1.2
- mpmath==1.3.0
- networkx==3.1
- numpy==1.24.4
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cuda-cupti-cu11==11.7.101
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cudnn-cu11==8.5.0.96
- nvidia-cufft-cu11==10.9.0.58
- nvidia-curand-cu11==10.2.10.91
- nvidia-cusolver-cu11==11.4.0.1
- nvidia-cusparse-cu11==11.7.4.91
- nvidia-nccl-cu11==2.14.3
- nvidia-nvtx-cu11==11.7.91
- opencv-python==4.8.0.74
- packaging==23.1
- pandas==2.0.3
- pathtools==0.1.2
- pillow==9.5.0
- protobuf==4.23.3
- psutil==5.9.5
- pyparsing==3.1.0
- pysocks==1.7.1
- python-dateutil==2.8.2
- pytz==2023.3
- pyyaml==6.0
- requests==2.30.0
- scikit-learn==1.3.0
- scipy==1.10.1
- seaborn==0.12.2
- sentry-sdk==1.27.0
- setproctitle==1.3.2
- six==1.16.0
- smmap==5.0.0
- soupsieve==2.4.1
- sympy==1.12
- threadpoolctl==3.1.0
- torch==2.0.1
- torch-geometric==2.3.1
- torchaudio==2.0.2
- torchsampler==0.1.2
- torchvision==0.15.2
- tqdm==4.65.0
- triton==2.0.0
- typing-extensions==4.7.0
- tzdata==2023.3
- urllib3==2.0.2
- wandb==0.15.5
- zipp==3.15.0
prefix: /opt/anaconda3/envs/tbd

156
tbd/models/base.py Normal file

@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
from .utils import pose_edge_index
from torch_geometric.nn import GCNConv
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim),
nn.GELU(),
nn.Linear(dim, dim))
def forward(self, x):
return self.net(x)
class CNN(nn.Module):
def __init__(self, hidden_dim):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv3(x)
x = nn.functional.relu(x)
x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling
return x
class MindNetLSTM(nn.Module):
"""
Basic MindNet for model-based ToM, just LSTM on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetLSTM, self).__init__()
self.mods = mods
if 'rgb_1' in mods:
self.img_emb = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
if 'gaze' in mods:
self.gaze_emb = nn.Linear(2, hidden_dim)
if 'pose' in mods:
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.LSTM = PreNorm(
hidden_dim*len(mods),
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
feats = []
if 'rgb_3' in self.mods:
feats.append(rgb_3rd_pov_feats)
if 'rgb_1' in self.mods:
rgb_feat = []
for i in range(rgb_1st_pov.shape[1]):
images_i = rgb_1st_pov[:,i]
img_i_feat = self.img_emb(images_i)
img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
feats.append(rgb_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats.mean(2))
lstm_inp = torch.cat(feats, 2)
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
return self.act(self.proj(lstm_out)), c_n, feats
class MindNetSL(nn.Module):
"""
Basic MindNet for SL ToM, just LSTM on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetSL, self).__init__()
self.mods = mods
if 'rgb_1' in mods:
self.img_emb = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
if 'gaze' in mods:
self.gaze_emb = nn.Linear(2, hidden_dim)
if 'pose' in mods:
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.LSTM = PreNorm(
hidden_dim*len(mods),
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
feats = []
if 'rgb_3' in self.mods:
feats.append(rgb_3rd_pov_feats)
if 'rgb_1' in self.mods:
rgb_feat = []
for i in range(rgb_1st_pov.shape[1]):
images_i = rgb_1st_pov[:,i]
img_i_feat = self.img_emb(images_i)
img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
feats.append(rgb_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats.mean(2))
lstm_inp = torch.cat(feats, 2)
lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
return self.act(self.proj(lstm_out)), feats
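A minimal shape check for `MindNetLSTM`, assuming it is run from the `tbd/` directory with the dependencies from `environment.yml` (dummy tensors only, mirroring the shapes used by the ToMnet wrappers):

```python
import torch
from models.base import MindNetLSTM

bs, seq, hidden = 2, 5, 64
net = MindNetLSTM(hidden_dim=hidden, dropout=0.1,
                  mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])

rgb_3rd_pov_feats = torch.randn(bs, seq, hidden)  # already projected by the wrapper model
bbox_feats = torch.randn(bs, seq, 13, hidden)     # per-object features, averaged over dim 2
rgb_1st_pov = torch.randn(bs, seq, 3, 128, 128)   # raw first-person frames
pose = torch.randn(bs, seq, 26, 3)                # 26 joints in 3D
gaze = torch.randn(bs, seq, 2)                    # 2D gaze coordinates

out, cell, feats = net(rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze)
print(out.shape, cell.shape, len(feats))          # (2, 5, 64), (2, 1, 64), 5
```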

157
tbd/models/common_mind.py Normal file

@@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention
class CommonMindToMnet(nn.Module):
"""
img: bs, seq_len, 3, 128, 128
pose: bs, seq_len, 26, 3
gaze: bs, seq_len, 2 NOTE: only the tracker has gaze
bbox: bs, seq_len, num_obj, 4
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_1', 'rgb_3', 'pose', 'gaze', 'bbox']):
super(CommonMindToMnet, self).__init__()
self.aggr = aggr
self.mods = mods
# ---- 3rd POV Images, object and bbox ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
#for param in self.cnn.parameters():
# param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
self.bbox_ff = nn.Linear(4, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
if aggr != 'no_tom': self.cm_proj = nn.Linear(hidden_dim*2, hidden_dim)
self.ln_1 = nn.LayerNorm(hidden_dim)
self.ln_2 = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.m1 = nn.Linear(hidden_dim, 4)
self.m2 = nn.Linear(hidden_dim, 4)
self.m12 = nn.Linear(hidden_dim, 4)
self.m21 = nn.Linear(hidden_dim, 4)
self.mc = nn.Linear(hidden_dim, 4)
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
if 'bbox' in self.mods:
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
else:
bbox_feat = None
if 'rgb_3' in self.mods:
rgb_feat = []
for i in range(sequence_len):
images_i = img_3rd_pov[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
else:
rgb_feat_3rd_pov = None
if tracker_id == 'skele1':
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
else:
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
if self.aggr == 'no_tom':
m1 = self.m1(out_1).mean(1)
m2 = self.m2(out_2).mean(1)
m12 = self.m12(out_1).mean(1)
m21 = self.m21(out_2).mean(1)
mc = self.mc(out_1*out_2).mean(1) # NOTE: with no_tom, mc is computed from the element-wise product of out_1 and out_2
return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
common_mind = self.cm_proj(torch.cat([cell_1, cell_2], -1)) # (bs, 1, h)
if self.aggr == 'attn':
p1 = self.attn_left(x=out_1, context=common_mind)
p2 = self.attn_right(x=out_2, context=common_mind)
elif self.aggr == 'mult':
p1 = out_1 * common_mind
p2 = out_2 * common_mind
elif self.aggr == 'sum':
p1 = out_1 + common_mind
p2 = out_2 + common_mind
elif self.aggr == 'concat':
p1 = torch.cat([out_1, common_mind], 1)
p2 = torch.cat([out_2, common_mind], 1)
else: raise ValueError
p1 = self.act(p1)
p1 = self.ln_1(p1)
p2 = self.act(p2)
p2 = self.ln_2(p2)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
m1 = self.m1(p1).mean(1)
m2 = self.m2(p2).mean(1)
m12 = self.m12(p1).mean(1)
m21 = self.m21(p2).mean(1)
mc = self.mc(p1*p2).mean(1)
if self.aggr == 'concat':
m1 = self.m1(p1).mean(1)
m2 = self.m2(p2).mean(1)
m12 = self.m12(p1).mean(1)
m21 = self.m21(p2).mean(1)
mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2
return m1, m2, m12, m21, mc, [out_1, out_2, common_mind] + feats_1 + feats_2
if __name__ == "__main__":
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
img_tracker = torch.ones(3, 5, 3, 128, 128)
img_battery = torch.ones(3, 5, 3, 128, 128)
pose1 = torch.ones(3, 5, 26, 3)
pose2 = torch.ones(3, 5, 26, 3)
bbox = torch.ones(3, 5, 13, 4)
tracker_id = 'skele1'
gaze = torch.ones(3, 5, 2)
mods = ['pose', 'bbox', 'rgb_3']
for agg in ['no_tom']:
model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
print(out[0].shape)
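The self-test above only exercises the `no_tom` path. A hedged extension that also runs the remaining aggregation modes of `CommonMindToMnet` on the same dummy tensors:

```python
# Hypothetical extension of the smoke test above (reuses the dummy inputs defined there).
for agg in ['sum', 'mult', 'concat', 'attn']:
    model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False,
                             dropout=0.5, aggr=agg, mods=mods)
    m1, m2, m12, m21, mc, reprs = model(img_3rd_pov, img_tracker, img_battery,
                                        pose1, pose2, bbox, tracker_id, gaze)
    print(agg, m1.shape, mc.shape)  # each head outputs (batch_size, 4)
```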

151
tbd/models/implicit.py Normal file

@@ -0,0 +1,151 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention
class ImplicitToMnet(nn.Module):
"""
Implicit ToM net. Supports any subset of modalities
Possible aggregations: sum, mult, attn, concat
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
super(ImplicitToMnet, self).__init__()
self.aggr = aggr
self.mods = mods
# ---- 3rd POV Images, object and bbox ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
self.bbox_ff = nn.Linear(4, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
self.ln_1 = nn.LayerNorm(hidden_dim)
self.ln_2 = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.m1 = nn.Linear(hidden_dim, 4)
self.m2 = nn.Linear(hidden_dim, 4)
self.m12 = nn.Linear(hidden_dim, 4)
self.m21 = nn.Linear(hidden_dim, 4)
self.mc = nn.Linear(hidden_dim, 4)
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
if 'bbox' in self.mods:
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
else:
bbox_feat = None
if 'rgb_3' in self.mods:
rgb_feat = []
for i in range(sequence_len):
images_i = img_3rd_pov[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
else:
rgb_feat_3rd_pov = None
if tracker_id == 'skele1':
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
else:
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
if self.aggr == 'no_tom':
m1 = self.m1(out_1).mean(1)
m2 = self.m2(out_2).mean(1)
m12 = self.m12(out_1).mean(1)
m21 = self.m21(out_2).mean(1)
mc = self.mc(out_1*out_2).mean(1) # NOTE: with no_tom, mc is computed from the element-wise product of out_1 and out_2
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
if self.aggr == 'attn':
p1 = self.attn_left(x=out_1, context=cell_2)
p2 = self.attn_right(x=out_2, context=cell_1)
elif self.aggr == 'mult':
p1 = out_1 * cell_2
p2 = out_2 * cell_1
elif self.aggr == 'sum':
p1 = out_1 + cell_2
p2 = out_2 + cell_1
elif self.aggr == 'concat':
p1 = torch.cat([out_1, cell_2], 1)
p2 = torch.cat([out_2, cell_1], 1)
else: raise ValueError
p1 = self.act(p1)
p1 = self.ln_1(p1)
p2 = self.act(p2)
p2 = self.ln_2(p2)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
m1 = self.m1(p1).mean(1)
m2 = self.m2(p2).mean(1)
m12 = self.m12(p1).mean(1)
m21 = self.m21(p2).mean(1)
mc = self.mc(p1*p2).mean(1)
if self.aggr == 'concat':
m1 = self.m1(p1).mean(1)
m2 = self.m2(p2).mean(1)
m12 = self.m12(p1).mean(1)
m21 = self.m21(p2).mean(1)
mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
if __name__ == "__main__":
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
img_tracker = torch.ones(3, 5, 3, 128, 128)
img_battery = torch.ones(3, 5, 3, 128, 128)
pose1 = torch.ones(3, 5, 26, 3)
pose2 = torch.ones(3, 5, 26, 3)
bbox = torch.ones(3, 5, 13, 4)
tracker_id = 'skele1'
gaze = torch.ones(3, 5, 2)
for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg)
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
print(agg, out[0].shape)

112
tbd/models/sl.py Normal file

@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetSL
class SLToMnet(nn.Module):
"""
Speaker-Listener ToMnet
"""
def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
super(SLToMnet, self).__init__()
self.tom_weight = tom_weight
self.mods = mods
# ---- 3rd POV Images, object and bbox ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
self.bbox_ff = nn.Linear(4, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_1 = MindNetSL(hidden_dim, dropout, mods=mods)
self.mind_net_2 = MindNetSL(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
self.m1 = nn.Linear(hidden_dim, 4)
self.m2 = nn.Linear(hidden_dim, 4)
self.m12 = nn.Linear(hidden_dim, 4)
self.m21 = nn.Linear(hidden_dim, 4)
self.mc = nn.Linear(hidden_dim, 4)
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
if 'bbox' in self.mods:
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
else:
bbox_feat = None
if 'rgb_3' in self.mods:
rgb_feat = []
for i in range(sequence_len):
images_i = img_3rd_pov[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
else:
rgb_feat_3rd_pov = None
if tracker_id == 'skele1':
out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
else:
out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
m1_logits = self.m1(out_1).mean(1)
m2_logits = self.m2(out_2).mean(1)
m12_logits = self.m12(out_1).mean(1)
m21_logits = self.m21(out_2).mean(1)
mc_logits = self.mc(out_1*out_2).mean(1)
m1_ranking = torch.log_softmax(m1_logits, dim=-1)
m2_ranking = torch.log_softmax(m2_logits, dim=-1)
m12_ranking = torch.log_softmax(m12_logits, dim=-1)
m21_ranking = torch.log_softmax(m21_logits, dim=-1)
mc_ranking = torch.log_softmax(mc_logits, dim=-1)
# NOTE: does this make sense?
m1 = m1_ranking + self.tom_weight * m2_ranking
m2 = m2_ranking + self.tom_weight * m1_ranking
m12 = m12_ranking + self.tom_weight * m21_ranking
m21 = m21_ranking + self.tom_weight * m12_ranking
mc = mc_ranking + self.tom_weight * mc_ranking
return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
if __name__ == "__main__":
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
img_tracker = torch.ones(3, 5, 3, 128, 128)
img_battery = torch.ones(3, 5, 3, 128, 128)
pose1 = torch.ones(3, 5, 26, 3)
pose2 = torch.ones(3, 5, 26, 3)
bbox = torch.ones(3, 5, 13, 4)
tracker_id = 'skele1'
gaze = torch.ones(3, 5, 2)
model = SLToMnet(hidden_dim=64, device='cpu', tom_weight=2.0, resnet=False, dropout=0.5)
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
print(out[0].shape)
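On the `NOTE: does this make sense?` in the speaker-listener combination: since `mc` is combined with itself, the last term reduces to `(1 + tom_weight) * mc_ranking`, which rescales the log-probabilities but cannot change the predicted class. A tiny check of that identity (hypothetical values):

```python
import torch

tom_weight = 2.0
mc_ranking = torch.log_softmax(torch.randn(3, 4), dim=-1)

mc = mc_ranking + tom_weight * mc_ranking                 # as written in SLToMnet.forward
assert torch.allclose(mc, (1 + tom_weight) * mc_ranking)
assert torch.equal(mc.argmax(-1), mc_ranking.argmax(-1))  # a positive rescaling keeps the argmax
```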

112
tbd/models/tom_base.py Normal file

@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
import numpy as np
class ImplicitToMnet(nn.Module):
"""
Implicit ToM net. Supports any subset of modalities
Possible aggregations: sum, mult, attn, concat
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
super(ImplicitToMnet, self).__init__()
self.mods = mods
# ---- 3rd POV Images, object and bbox ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
self.bbox_ff = nn.Linear(4, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
self.m1 = nn.Linear(hidden_dim, 4)
self.m2 = nn.Linear(hidden_dim, 4)
self.m12 = nn.Linear(hidden_dim, 4)
self.m21 = nn.Linear(hidden_dim, 4)
self.mc = nn.Linear(hidden_dim, 4)
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
if 'bbox' in self.mods:
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
else:
bbox_feat = None
if 'rgb_3' in self.mods:
rgb_feat = []
for i in range(sequence_len):
images_i = img_3rd_pov[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
else:
rgb_feat_3rd_pov = None
if tracker_id == 'skele1':
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
else:
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
# Base variant: no ToM aggregation is defined (there is no self.aggr), so the
# prediction heads are applied directly to the two mind-net outputs.
m1 = self.m1(out_1).mean(1)
m2 = self.m2(out_2).mean(1)
m12 = self.m12(out_1).mean(1)
m21 = self.m21(out_2).mean(1)
mc = self.mc(out_1*out_2).mean(1) # NOTE: mc is computed from the element-wise product of out_1 and out_2
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
def count_parameters(model):
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
if __name__ == "__main__":
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
img_tracker = torch.ones(3, 5, 3, 128, 128)
img_battery = torch.ones(3, 5, 3, 128, 128)
pose1 = torch.ones(3, 5, 26, 3)
pose2 = torch.ones(3, 5, 26, 3)
bbox = torch.ones(3, 5, 13, 4)
tracker_id = 'skele1'
gaze = torch.ones(3, 5, 2)
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5)
print(count_parameters(model))
breakpoint()
# This base variant takes no aggregation argument, so a single forward pass is enough.
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
print(out[0].shape)

7
tbd/models/utils.py Normal file

@@ -0,0 +1,7 @@
import torch
def pose_edge_index():
start = [15, 14, 13, 12, 19, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 3, 4, 5, 6, 8, 8, 4, 20, 21, 21, 22, 24, 22]
end = [14, 13, 12, 0, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 4, 20, 20, 21, 22, 24, 23, 25, 24]
return torch.tensor([start+end, end+start])
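A quick sanity check of the returned edge index (a sketch; the 26 joint indices match the pose tensors used by the mind nets):

```python
import torch
from models.utils import pose_edge_index

edge_index = pose_edge_index()
print(edge_index.shape)                # torch.Size([2, 56]): 28 skeleton edges, both directions
assert edge_index.max().item() < 26    # node ids index the 26 pose joints
assert edge_index.dtype == torch.long  # integer indices, as expected by GCNConv
```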

186
tbd/results/abl.json Normal file

@@ -0,0 +1,186 @@
{
"all": [
{
"m1": 0.3803337111550836,
"m2": 0.3900899763574355,
"m12": 0.4441281276628709,
"m21": 0.4818757648120031,
"mc": 0.4485177767702456
},
{
"m1": 0.5186066842992191,
"m2": 0.521895750052127,
"m12": 0.49294626980529677,
"m21": 0.4810118034501327,
"mc": 0.6097300398369058
},
{
"m1": 0.4965589148309122,
"m2": 0.5094894309980568,
"m12": 0.4615136302786905,
"m21": 0.4554005550423429,
"mc": 0.6258118710785031
}
],
"rgb_3_pose_gaze_bbox": [
{
"m1": 0.3776045061727805,
"m2": 0.3996776745150713,
"m12": 0.4762772810038159,
"m21": 0.48643178296718503,
"mc": 0.4575207273412474
},
{
"m1": 0.5176564423560418,
"m2": 0.5109344883698214,
"m12": 0.4630213122846928,
"m21": 0.4826608133674547,
"mc": 0.5979415365779003
},
{
"m1": 0.5114692300997931,
"m2": 0.5027048375802656,
"m12": 0.47527894405588544,
"m21": 0.45223985157847546,
"mc": 0.6054099305712209
}
],
"rgb_3_pose_gaze": [
{
"m1": 0.403207421026191,
"m2": 0.3833413122398237,
"m12": 0.4602455224198077,
"m21": 0.47181798537346287,
"mc": 0.4603675297898878
},
{
"m1": 0.49484810149311514,
"m2": 0.5060275976807422,
"m12": 0.4610412452830618,
"m21": 0.46869095956564044,
"mc": 0.6040674897817755
},
{
"m1": 0.5160598186177866,
"m2": 0.5309683014233921,
"m12": 0.47227245803060636,
"m21": 0.46953974307035984,
"mc": 0.6014771460423635
}
],
"rgb_3_pose": [
{
"m1": 0.4057149181928123,
"m2": 0.4002233785689204,
"m12": 0.46794813614607333,
"m21": 0.4690365183933033,
"mc": 0.4591530208921514
},
{
"m1": 0.5362792166212834,
"m2": 0.5290656046231254,
"m12": 0.4569419683345858,
"m21": 0.4530255281497826,
"mc": 0.4554252731371068
},
{
"m1": 0.49570625763169085,
"m2": 0.5146503967646507,
"m12": 0.4567936139893578,
"m21": 0.45918214877096325,
"mc": 0.5962397441246001
}
],
"rgb_3_gaze": [
{
"m1": 0.40135106828655215,
"m2": 0.38453470155825614,
"m12": 0.4989742833725901,
"m21": 0.47369273992079175,
"mc": 0.48430622854433986
},
{
"m1": 0.508038122818153,
"m2": 0.4875748099051746,
"m12": 0.46665443622698555,
"m21": 0.46635808547742913,
"mc": 0.47936993226840163
},
{
"m1": 0.49795853039610977,
"m2": 0.5028666890527814,
"m12": 0.44176709237564815,
"m21": 0.4483898274665582,
"mc": 0.5867527750929912
}
],
"rgb_3_bbox": [
{
"m1": 0.3951383898241492,
"m2": 0.3818794542844425,
"m12": 0.44108151735270384,
"m21": 0.46539754196523303,
"mc": 0.43982185797713114
},
{
"m1": 0.5093846655989521,
"m2": 0.4923439212866733,
"m12": 0.4598003475323884,
"m21": 0.47647640659290746,
"mc": 0.6349953712994137
},
{
"m1": 0.5325224862402295,
"m2": 0.5092319973570975,
"m12": 0.4435807136490263,
"m21": 0.4576911633624616,
"mc": 0.6282064277856357
}
],
"rgb_3_rgb_1": [
{
"m1": 0.39189391736691903,
"m2": 0.3739995635963588,
"m12": 0.4792392731637056,
"m21": 0.4592726043789752,
"mc": 0.4468645255652386
},
{
"m1": 0.4827892482357646,
"m2": 0.48042899735042716,
"m12": 0.45932653547051094,
"m21": 0.48430209616318126,
"mc": 0.4506104344435269
},
{
"m1": 0.4820247145474279,
"m2": 0.3667553358192628,
"m12": 0.44503028688537,
"m21": 0.45984906207471654,
"mc": 0.465120658971623
}
],
"rgb_3": [
{
"m1": 0.40725462165126114,
"m2": 0.38737351624656846,
"m12": 0.46230461548252094,
"m21": 0.4829312519709871,
"mc": 0.4492175856929955
},
{
"m1": 0.5286274183685061,
"m2": 0.5081429492163979,
"m12": 0.4610256989472217,
"m21": 0.4733487634477733,
"mc": 0.4655243312197501
},
{
"m1": 0.5217968210271873,
"m2": 0.5103780571157844,
"m12": 0.4431266771306429,
"m21": 0.48398542131284883,
"mc": 0.6122314353959392
}
]
}
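Each key in the results JSON maps to a list of per-run score dicts over the five mind variables. A hedged sketch for aggregating them into mean and standard deviation per variable (assumes it is run from the `tbd/` directory):

```python
import json
import numpy as np

with open("results/abl.json") as f:
    results = json.load(f)

for config, runs in results.items():
    summary = {k: (float(np.mean([r[k] for r in runs])),
                   float(np.std([r[k] for r in runs]))) for k in runs[0]}
    print(config, {k: f"{m:.3f}±{s:.3f}" for k, (m, s) in summary.items()})
```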

232
tbd/results/all.json Normal file

@@ -0,0 +1,232 @@
{
"cm_concat": [
{
"m1": 0.38921744471949393,
"m2": 0.38557137008494935,
"m12": 0.44699534554593756,
"m21": 0.4747474437468054,
"mc": 0.4918107834016411
},
{
"m1": 0.5402415140026018,
"m2": 0.48833721513836786,
"m12": 0.4631512445419047,
"m21": 0.4740880083492652,
"mc": 0.6375070925808958
},
{
"m1": 0.5012543523713172,
"m2": 0.5068694866895836,
"m12": 0.4451537834591627,
"m21": 0.45215784721598673,
"mc": 0.6201022576104379
}
],
"cm_sum": [
{
"m1": 0.39403894801783246,
"m2": 0.38541918219411786,
"m12": 0.4600376974144952,
"m21": 0.471919704007463,
"mc": 0.43950812310207055
},
{
"m1": 0.48497621104052574,
"m2": 0.5295044689855949,
"m12": 0.4502949472343065,
"m21": 0.47823492553894387,
"mc": 0.6028290833617195
},
{
"m1": 0.503386104373653,
"m2": 0.49983127146477085,
"m12": 0.46782817568218116,
"m21": 0.45484578845116075,
"mc": 0.5905749126722909
}
],
"cm_mult": [
{
"m1": 0.39070820515470606,
"m2": 0.3996851353903932,
"m12": 0.4455704586852128,
"m21": 0.4713517869738811,
"mc": 0.4450907029478458
},
{
"m1": 0.5066540697731119,
"m2": 0.526507445454099,
"m12": 0.462643008560599,
"m21": 0.48263054309565334,
"mc": 0.6438566476782207
},
{
"m1": 0.48868811674304546,
"m2": 0.5074635877653536,
"m12": 0.44597405775819876,
"m21": 0.45445350963025877,
"mc": 0.5884265473527218
}
],
"cm_attn": [
{
"m1": 0.3949557687114269,
"m2": 0.3919385900921811,
"m12": 0.4850081112466773,
"m21": 0.4849575556679713,
"mc": 0.4516870089239762
},
{
"m1": 0.4925989821370256,
"m2": 0.49409170532242247,
"m12": 0.4664647278240569,
"m21": 0.46783863397462533,
"mc": 0.6398721139927354
},
{
"m1": 0.4945636568169018,
"m2": 0.5049812790749876,
"m12": 0.454359577718189,
"m21": 0.4712184012093268,
"mc": 0.5992735441011302
}
],
"no_tom": [
{
"m1": 0.2570551317,
"m2": 0.375350929686332,
"m12": 0.312451988649724,
"m21": 0.4631371031641,
"mc": 0.457486278214567
},
{
"m1": 0.233046800382043,
"m2": 0.522609755931958,
"m12": 0.326821758467328,
"m21": 0.474338898013257,
"mc": 0.604439456291308
},
{
"m1": 0.33774852598382,
"m2": 0.520943544364353,
"m12": 0.298617214416867,
"m21": 0.482175301427192,
"mc": 0.634948478570852
}
],
"sl": [
{
"m1": 0.365205706591741,
"m2": 0.255259363011619,
"m12": 0.421227579844245,
"m21": 0.376143327741882,
"mc": 0.45614515353718
},
{
"m1": 0.493046934143676,
"m2": 0.331798174804139,
"m12": 0.422821548330913,
"m21": 0.399768928780549,
"mc": 0.450957023549231
},
{
"m1": 0.466266787709392,
"m2": 0.350962671130227,
"m12": 0.431694150269919,
"m21": 0.378863431433258,
"mc": 0.470284405744656
}
],
"impl_concat": [
{
"m1": 0.38427302094644894,
"m2": 0.38673879043767634,
"m12": 0.45694337561663145,
"m21": 0.4737891562722213,
"mc": 0.4502976351448088
},
{
"m1": 0.49951068243751173,
"m2": 0.5084945752383908,
"m12": 0.4604721097809549,
"m21": 0.4826884970930907,
"mc": 0.6200443272625361
},
{
"m1": 0.5013244243339088,
"m2": 0.49476495726495723,
"m12": 0.4596701406290429,
"m21": 0.4554742441542813,
"mc": 0.5988949378402535
}
],
"impl_sum": [
{
"m1": 0.3803337111550836,
"m2": 0.3900899763574355,
"m12": 0.4441281276628709,
"m21": 0.4818757648120031,
"mc": 0.4485177767702456
},
{
"m1": 0.5186066842992191,
"m2": 0.521895750052127,
"m12": 0.49294626980529677,
"m21": 0.4810118034501327,
"mc": 0.6097300398369058
},
{
"m1": 0.4965589148309122,
"m2": 0.5094894309980568,
"m12": 0.4615136302786905,
"m21": 0.4554005550423429,
"mc": 0.6258118710785031
}
],
"impl_mult": [
{
"m1": 0.3789421413006731,
"m2": 0.3818053844554785,
"m12": 0.46402717346945177,
"m21": 0.4903726261039529,
"mc": 0.4461443806398687
},
{
"m1": 0.3789421413006731,
"m2": 0.3818053844554785,
"m12": 0.46402717346945177,
"m21": 0.4903726261039529,
"mc": 0.4461443806398687
},
{
"m1": 0.49338554196342077,
"m2": 0.5066817652688608,
"m12": 0.46253374461930613,
"m21": 0.47782311190445825,
"mc": 0.4581608719646799
}
],
"impl_attn": [
{
"m1": 0.37413691393147924,
"m2": 0.2546966838007244,
"m12": 0.429390512693598,
"m21": 0.292401773870023,
"mc": 0.45706325836224465
},
{
"m1": 0.513917904196177,
"m2": 0.25802580258025803,
"m12": 0.49272662664765543,
"m21": 0.27041556176385584,
"mc": 0.6041394755857196
},
{
"m1": 0.47720445038981674,
"m2": 0.25839328537170264,
"m12": 0.46505055463781547,
"m21": 0.260276985433943,
"mc": 0.6021811271770562
}
]
}

Binary file not shown.

87
tbd/results/fb_ttest.txt Normal file

@@ -0,0 +1,87 @@
========================================================= m1_m2_m12_m21
Model: Base -> yes
Model: DB -> yes
Model: CG$\oplus$ -> no
Model: CG$\otimes$ -> no
Model: CG$\odot$ -> no
Model: IC$\parallel$ -> no
Model: IC$\oplus$ -> no
Model: IC$\otimes$ -> no
Model: IC$\odot$ -> yes
========================================================= m1_m2
Model: Base -> yes
Model: DB -> yes
Model: CG$\oplus$ -> no
Model: CG$\otimes$ -> no
Model: CG$\odot$ -> no
Model: IC$\parallel$ -> no
Model: IC$\oplus$ -> no
Model: IC$\otimes$ -> no
Model: IC$\odot$ -> yes
========================================================= m12_m21
Model: Base -> yes
Model: DB -> yes
Model: CG$\oplus$ -> no
Model: CG$\otimes$ -> no
Model: CG$\odot$ -> no
Model: IC$\parallel$ -> no
Model: IC$\oplus$ -> no
Model: IC$\otimes$ -> no
Model: IC$\odot$ -> yes


@@ -0,0 +1,59 @@
mc =====================================================================
precision recall f1-score support
0 0.000 0.500 0.001 50
1 0.000 0.000 0.000 4
2 0.004 0.038 0.007 238
3 0.999 0.795 0.885 290788
accuracy 0.794 291080
macro avg 0.251 0.333 0.223 291080
weighted avg 0.998 0.794 0.884 291080
m1 =====================================================================
precision recall f1-score support
0 0.000 0.000 0.000 147
1 0.000 0.000 0.000 2
2 0.025 0.051 0.033 1714
3 0.994 0.988 0.991 289217
accuracy 0.982 291080
macro avg 0.255 0.260 0.256 291080
weighted avg 0.988 0.982 0.985 291080
m2 =====================================================================
precision recall f1-score support
0 0.001 0.013 0.001 151
2 0.031 0.084 0.045 2394
3 0.992 0.970 0.981 288535
accuracy 0.962 291080
macro avg 0.341 0.355 0.342 291080
weighted avg 0.983 0.962 0.972 291080
m12 =====================================================================
precision recall f1-score support
0 0.000 0.000 0.000 93
1 0.000 0.000 0.000 8
2 0.015 0.056 0.023 676
3 0.997 0.990 0.994 290303
accuracy 0.988 291080
macro avg 0.253 0.262 0.254 291080
weighted avg 0.995 0.988 0.991 291080
m21 =====================================================================
precision recall f1-score support
0 0.002 0.012 0.003 86
1 0.000 0.000 0.000 12
2 0.010 0.040 0.016 658
3 0.997 0.989 0.993 290324
accuracy 0.987 291080
macro avg 0.252 0.260 0.253 291080
weighted avg 0.995 0.987 0.991 291080

Binary file not shown.

12
tbd/run_test.sh Normal file

@@ -0,0 +1,12 @@
#!/bin/bash
python -m test \
--gpu_id 1 \
--seed 1 \
--non_blocking \
--pin_memory \
--model_type tom_cm \
--aggr no_tom \
--hidden_dim 64 \
--batch_size 64 \
--load_model_path /PATH/TO/model

16
tbd/run_train.sh Normal file

@@ -0,0 +1,16 @@
#!/bin/bash
python -m train \
--gpu_id 2 \
--seed 123 \
--logger \
--non_blocking \
--pin_memory \
--batch_size 64 \
--num_workers 16 \
--num_epoch 300 \
--lr 5e-4 \
--dropout 0.1 \
--model_type tom_cm \
--aggr no_tom \
--hidden_dim 64

568
tbd/tbd_dataloader.py Normal file

@@ -0,0 +1,568 @@
from __future__ import annotations
from typing import Optional, Union
import torch
import pickle
import time
import glob
import random
import os
import numpy as np
import pandas as pd
import cv2
from itertools import product
import csv
import torchvision.transforms as T
from utils.helpers import tracker_skeID, CLIPS_IDS_88, ALL_IDS, UNIQUE_OBJ_IDS
def collate_fn(batch):
# Unpack the batch into individual elements
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep = zip(*batch)
# Determine the maximum number of objects in any batch
max_n_obj = max(bbox.shape[1] for bbox in bboxes)
# Pad the bounding box tensors
bboxes_pad = []
for bbox in bboxes:
pad_size = max_n_obj - bbox.shape[1]
pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
padded_bbox = torch.cat((bbox, pad), dim=1)
bboxes_pad.append(padded_bbox)
# Stack the padded tensors into a batch tensor
bboxes_batch = torch.stack(bboxes_pad, dim=0)
img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
img_tracker = torch.stack(img_tracker, dim=0)
img_battery = torch.stack(img_battery, dim=0)
pose1 = torch.stack(pose1, dim=0)
pose2 = torch.stack(pose2, dim=0)
gaze = torch.stack(gaze, dim=0)
labels = torch.tensor(labels, dtype=torch.long)
# Return the batched tensors
return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep
def collate_fn_test(batch):
# Unpack the batch into individual elements
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep, false_beliefs = zip(*batch)
# Determine the maximum number of objects in any batch
max_n_obj = max(bbox.shape[1] for bbox in bboxes)
# Pad the bounding box tensors
bboxes_pad = []
for bbox in bboxes:
pad_size = max_n_obj - bbox.shape[1]
pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
padded_bbox = torch.cat((bbox, pad), dim=1)
bboxes_pad.append(padded_bbox)
# Stack the padded tensors into a batch tensor
bboxes_batch = torch.stack(bboxes_pad, dim=0)
img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
img_tracker = torch.stack(img_tracker, dim=0)
img_battery = torch.stack(img_battery, dim=0)
pose1 = torch.stack(pose1, dim=0)
pose2 = torch.stack(pose2, dim=0)
gaze = torch.stack(gaze, dim=0)
labels = torch.tensor(labels, dtype=torch.long)
# Return the batched tensors
return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep, false_beliefs
class TBDDataset(torch.utils.data.Dataset):
def __init__(
self,
path: str = "/scratch/bortoletto/data/tbd",
mode: str = "train",
tbd_data_path: str = "/scratch/bortoletto/data/tbd/mind_lstm_training_cnn_att/",
list_of_ids_to_consider: list = ALL_IDS,
use_preprocessed_img: bool = True,
resize_img: Optional[Union[tuple, int]] = (128,128),
):
"""TBD Dataset based on the 88 clip version of the TBD data.
Expects the following folder structure:
- path
- tracker_gt_smooth <- These are eye tracking from POV, 2D coordinates
- images/*/ <- These are the images by experiment id
- battery <- These are the images, 1st Person
- tracker <- These are the images, 1st other Person w/ eye fixation
- kinect <- These are the images, 3rd Person
- skeleton <- Pose estimation, 3D coordinates
- annotation <- These are the labels, i.e. [0,3] (see below)
Labels are structured as follows:
{
"O1": [ <- Object with id O1
{
"m1": {
"fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange
"loc": null
},
"m2": {
"fluent": 3,
"loc": null
},
"m12": {
"fluent": 3,
"loc": null
},
"m21": {
"fluent": 3,
"loc": null
},
"mc": {
"fluent": 3,
"loc": null
},
"mg": {
"fluent": 3,
"loc": [
22,
9
]
}
}, ...
], ...
}
This corresponds to a strict subset of the raw dataset collected
by the TBD authors for their paper "Learning Triadic Belief Dynamics
in Nonverbal Communication from Videos" (CVPR 2021, Oral).
We keep small amounts of data in memory (everything <100MB).
Otherwise we read from disk on the fly. This dataset applies normalization.
Args:
path (str, optional): Where the folders lie.
Defaults to "/scratch/ruhdorfer/triadic_beleif_data_v2".
list_of_ids_to_consider (list, optional): List of ids to consider.
Defaults to ALL_IDS. Otherwise specify a list,
e.g. ["test_94342_23", "test_boelter_21", ...].
resize_img (Optional[Union[tuple, int]], optional): Resize images to
this size if required. Defaults to (128, 128).
"""
print(f"Loading TBD Dataset in mode {mode}...")
self.mode = mode
start = time.time()
self.skeleton_3D_path = f"{path}/skeleton"
self.tracker_2D_path = f"{path}/tracker_gt_smooth"
self.bbox_csv_path = f"{path}/annotations_with_bbox.csv"
if use_preprocessed_img:
self.img_path = f"{path}/images_norm"
else:
self.img_path = f"{path}/images"
self.obj_ids_path = f"{path}/mind_lstm_training_cnn_att_shu.pkl"
self.label_map = list(product([0, 1, 2, 3], repeat=5))
clips = os.listdir(tbd_data_path)
data = []
labels = []
for clip in clips:
with open(tbd_data_path + clip, 'rb') as f:
vec_input, label_ = pickle.load(f, encoding='latin1')
data = data + vec_input
labels = labels + label_
c = list(zip(data, labels))
random.shuffle(c)
train_ratio = int(len(c) * 0.6)
validate_ratio = int(len(c) * 0.2)
data, label = zip(*c)
train_x, train_y = data[:train_ratio], label[:train_ratio]
validate_x, validate_y = data[train_ratio:train_ratio + validate_ratio], label[train_ratio:train_ratio + validate_ratio]
test_x, test_y = data[train_ratio + validate_ratio:], label[train_ratio + validate_ratio:]
self.mind_count = np.zeros(1024) # used for CE weights
if mode == "train":
self.data, self.labels = train_x, train_y
elif mode == "val":
self.data, self.labels = validate_x, validate_y
elif mode == "test":
self.data, self.labels = test_x, test_y
self.false_beliefs_path = f"{path}/store_mind_set"
# keep small amounts of data in memory
self.skeleton_3D = self.load_skeleton_3D(self.skeleton_3D_path, list_of_ids_to_consider)
self.tracker_2D = self.load_tracker_2D(self.tracker_2D_path, list_of_ids_to_consider)
self.bbox_df = pd.read_csv(self.bbox_csv_path, header=0)
self.obj_ids = self.load_obj_ids(self.obj_ids_path)
if not use_preprocessed_img:
normalisation_steps = [
T.ToTensor(),
T.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
]
if resize_img is not None:
normalisation_steps.insert(1, T.Resize(resize_img))
self.preprocess_img = T.Compose(normalisation_steps)
else:
self.preprocess_img = None
self.use_preprocessed_img = use_preprocessed_img
print(f"Done loading in {time.time() - start}s.")
def __len__(self):
return len(self.data)
def __getitem__(
self, idx: int
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict, str, int, str
]:
"""Given an index, return the corresponding experiment_id and timestep in the experiment.
Then pick the appropriate data and labels from these.
Args:
idx (int): Index of the sample.
Returns:
tuple: torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict
Returns the following:
- img_kinect: torch.Tensor of shape (T, C, H, W) (Default is [T, 3, 720, 1280])
- img_tracker: torch.Tensor of shape (T, H, W, C)
- img_battery: torch.Tensor of shape (T, H, W, C)
- skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 1)
- skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 2)
- bbox: torch.Tensor of shape (T, num_obj, 5)
- tracker to skeleton ID: str (either skeleton 1 or 2)
- tracker_2D: torch.Tensor of shape (T, 2)
- labels: dict (see below)
"""
labels = self.label_map[self.labels[idx]]
experiment_id = self.data[idx][1][0].split('/')[6]
img_data_path = f"{self.img_path}/{experiment_id}"
frame_ids = [int(os.path.basename(self.data[idx][1][i]).split('_')[0]) for i in range(len(self.data[idx][1]))]
if self.use_preprocessed_img:
kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.pt")))
tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.pt")))
battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.pt")))
kinect_img_paths = [kinect[id] for id in frame_ids]
tracker_img_paths = [tracker[id] for id in frame_ids]
battery_img_paths = [battery[id] for id in frame_ids]
else:
kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.jpg")))
tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.jpg")))
battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.jpg")))
kinect_img_paths = [kinect[id] for id in frame_ids]
tracker_img_paths = [tracker[id] for id in frame_ids]
battery_img_paths = [battery[id] for id in frame_ids]
# load images
kinect_imgs = [
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
for img_path in kinect_img_paths
]
kinect_imgs = torch.stack(kinect_imgs, axis=0)
tracker_imgs = [
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
for img_path in tracker_img_paths
]
tracker_imgs = torch.stack(tracker_imgs, axis=0)
battery_imgs = [
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
for img_path in battery_img_paths
]
battery_imgs = torch.stack(battery_imgs, axis=0)
# load object id to check for false beliefs - only for testing
if self.mode == "test": #or self.mode == "train":
if f"{experiment_id}.txt" in os.listdir(self.false_beliefs_path):
obj_id = self.obj_ids[experiment_id][frame_ids[-1]]
obj_id = next(x for x in obj_id if x is not None)
false_belief = next((line.strip().split(',')[2] for line in open(f"{self.false_beliefs_path}/{experiment_id}.txt") if line.startswith(str(frame_ids[-1]) + ',' + obj_id + ',')), "no")
#if experiment_id in ['test_boelter4_0', 'test_boelter4_7', 'test_boelter4_6', 'test_boelter4_8', 'test_boelter2_3',
# 'test_94342_20', 'test_94342_18', 'test_94342_11', 'test_94342_17', 'test_boelter3_8', 'test_94342_2',
# 'test_boelter2_17', 'test_boelter3_7', 'test_94342_4', 'test_boelter3_9', 'test_boelter_10',
# 'test_boelter2_6', 'test_boelter4_10', 'test_boelter4_2', 'test_boelter4_5', 'test_94342_24',
# 'test_94342_15', 'test_boelter3_5', 'test_94342_8', 'test2', 'test_boelter3_12']:
# print('here!')
# with open(os.path.join(f'results/hgm_test_fb.csv'), mode='a') as file:
# writer = csv.writer(file)
# writer.writerow([experiment_id, obj_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
else:
false_belief = "no"
#with open(os.path.join(f'results/test_fb.csv'), mode='a') as file:
# writer = csv.writer(file)
# writer.writerow([experiment_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
df = self.bbox_df[
(self.bbox_df.experiment_name == experiment_id)
#& (self.bbox_df.name == obj_id) # NOTE: load the bounding boxes for all objects
& (self.bbox_df.name != 'P1')
& (self.bbox_df.name != 'P2')
& (self.bbox_df.frame.isin(frame_ids))
]
bboxes = []
for f in frame_ids:
bbox = torch.tensor(df.loc[df['frame'] == f, ["x_min", "y_min", "x_max", "y_max"]].to_numpy(), dtype=torch.float32)
bbox[:, 0] = bbox[:, 0] / 1280.0
bbox[:, 1] = bbox[:, 1] / 720.0
bbox[:, 2] = bbox[:, 2] / 1280.0
bbox[:, 3] = bbox[:, 3] / 720.0
bboxes.append(bbox)
bboxes = torch.stack(bboxes) # NOTE: this needs a custom collate function because not every video has the same number of objects
skele1 = self.skeleton_3D[experiment_id]["skele1"][frame_ids]
skele2 = self.skeleton_3D[experiment_id]["skele2"][frame_ids]
gaze = self.tracker_2D[experiment_id][frame_ids]
if self.mode == "test":
return (
kinect_imgs,
tracker_imgs,
battery_imgs,
skele1,
skele2,
bboxes,
tracker_skeID[experiment_id], # <- This is the tracker skeleton ID
gaze,
labels, # <- per object "m1", "m2", "m12", "m21", "mc"
experiment_id,
frame_ids,
#self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding
false_belief
)
else:
return (
kinect_imgs,
tracker_imgs,
battery_imgs,
skele1,
skele2,
bboxes,
tracker_skeID[experiment_id], # <- This is the tracker skeleton ID
gaze,
labels, # <- per object "m1", "m2", "m12", "m21", "mc"
experiment_id,
frame_ids
#self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding
)
def onehot(self, x, n=len(UNIQUE_OBJ_IDS)):
retval = torch.zeros(n)
if x > 0:
retval[x-1] = 1
return retval
def load_obj_ids(self, path: str):
with open(path, "rb") as f:
ids = pickle.load(f)
return ids
def extract_labels(self):
"""TODO: Converts index label to [m1, m2, m12, m21, mc] format.
"""
return
def _flatten_mind_obj_timestep(self, mind_obj_dict: dict) -> list:
"""Flattens the mind object dict to a list. I.e. takes
{
"m1": {
"fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange
"loc": null
},
"m2": {
"fluent": 3,
"loc": null
},
"m12": {
"fluent": 3,
"loc": null
},
"m21": {
"fluent": 3,
"loc": null
},
"mc": {
"fluent": 3,
"loc": null
},
"mg": {
"fluent": 3,
"loc": [
22,
9
]
}
}
and returns [3, 3, 3, 3, 3, 3]
Args:
mind_obj_dict (dict): Mind object dict as described in the __init__ docstring.
Returns:
list: List of mind object labels.
"""
return np.array([mind_obj["fluent"] for key, mind_obj in mind_obj_dict.items() if key != "mg"])
def load_skeleton_3D(self, path: str, list_of_ids_to_consider: list):
"""Load skeleton 3D data from disk.
- path
- * <- list of ids
- skele1.p <- 3D coord per id and timestep
- skele2.p <-
Args:
path (str): Where the skeleton 3D data lie.
list_of_ids_to_consider (list): List of ids to consider.
Defaults to None which means all ids. Otherwise specify a list,
e.g. ["test_94342_23", "test_boelter_21", ...].
Returns:
dict: skeleton 3D data as described in the __init__ docstring.
"""
skeleton_3D = {}
for experiment_id in list_of_ids_to_consider:
skeleton_3D[experiment_id] = {}
with open(f"{path}/{experiment_id}/skele1.p", "rb") as f:
skeleton_3D[experiment_id]["skele1"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
with open(f"{path}/{experiment_id}/skele2.p", "rb") as f:
skeleton_3D[experiment_id]["skele2"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
return skeleton_3D
def load_tracker_2D(self, path: str, list_of_ids_to_consider: list):
"""Load tracker 2D data from disk.
- path
- *.p <- 2D coord per id and timestep
Args:
path (str): Where the tracker 2D data lie.
list_of_ids_to_consider (list): List of ids to consider.
Defaults to None which means all ids. Otherwise specify a list,
e.g. ["test_94342_23", "test_boelter_21", ...].
Returns:
dict: tracker 2D data.
"""
tracker_2D = {}
for experiment_id in list_of_ids_to_consider:
with open(f"{path}/{experiment_id}.p", "rb") as f:
tracker_2D[experiment_id] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
return tracker_2D
def load_bbox(self, path: str, list_of_ids_to_consider: list):
"""Load bbox data from disk.
- bbox_tensors.pickle <- bbox per experiment id one tensor
Args:
path (str): Where the bbox data lie.
list_of_ids_to_consider (list): List of ids to consider.
Returns:
dict: bbox data.
"""
with open(path, "rb") as f:
pickle_data = pickle.load(f)
for key in CLIPS_IDS_88:
if key not in list_of_ids_to_consider:
pickle_data.pop(key, None)
return pickle_data
if __name__ == '__main__':
# os.environ['PYTHONHASHSEED'] = str(42)
# torch.manual_seed(42)
# np.random.seed(42)
# random.seed(42)
data = TBDDataset(use_preprocessed_img=True, mode="test")
from tqdm import tqdm
for i in tqdm(range(data.__len__())):
data[i]
breakpoint()
from torch.utils.data import DataLoader
# Just for guessing time
data_0=data[0]
data_last=data[len(data)-1]
idx = np.random.randint(1, len(data)-1) # Something in between.
start = time.time()
(
kinect_imgs, # <- len x 720 x 1280 x 3 originally, likely smaller now
tracker_imgs,
battery_imgs,
skele1,
skele2,
bbox,
tracker_skeID_sample, # <- This is the tracker skeleton ID
tracker2d,
label,
experiment_id, # From here for debugging
timestep,
#obj_id, # <- This is the object ID as a one-hot
false_belief
) = data[idx]
end = time.time()
print(f"Time for one sample: {end-start}")
print('kinect:', kinect_imgs.shape)
print('tracker:', tracker_imgs.shape)
print('battery:', battery_imgs.shape)
print('skele1:', skele1.shape)
print('skele2:', skele2.shape)
print('gaze:', tracker2d.shape)
print('bbox:', bbox.shape)
print('label:', label)
#breakpoint()
dl = DataLoader(
data,
batch_size=4,
shuffle=False,
collate_fn=collate_fn
)
for j, batch in tqdm(enumerate(dl)):
#print(j, end='\r')
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
#breakpoint()
#print(img_3rd_pov.shape)
#print(img_tracker.shape)
#print(img_battery.shape)
#print(pose1.shape, pose2.shape)
#print(bbox.shape)
#print(gaze.shape)
breakpoint()
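Each sample's label is stored as a single integer and decoded through `label_map = list(product([0, 1, 2, 3], repeat=5))` into one fluent per mind variable, in the order (m1, m2, m12, m21, mc). A standalone sketch of the decoding and its inverse:

```python
# Standalone sketch of the label encoding used by TBDDataset.
# Fluents: 0: enter, 1: disappear, 2: update, 3: unchange.
from itertools import product

label_map = list(product([0, 1, 2, 3], repeat=5))  # 4**5 = 1024 classes
minds = ("m1", "m2", "m12", "m21", "mc")

idx = 1023                                          # hypothetical class index
print(dict(zip(minds, label_map[idx])))             # {'m1': 3, 'm2': 3, 'm12': 3, 'm21': 3, 'mc': 3}

# Inverse direction: recover the integer index from a 5-tuple of fluents.
assert label_map.index((3, 3, 3, 3, 3)) == 1023
```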

196
tbd/test.py Normal file

@@ -0,0 +1,196 @@
import torch
import csv
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
import random
import os
import numpy as np
from tbd_dataloader import TBDDataset, collate_fn_test
from models.common_mind import CommonMindToMnet
from models.sl import SLToMnet
from models.implicit import ImplicitToMnet
from utils.helpers import compute_f1_scores
def test(args):
test_dataset = TBDDataset(
path=args.data_path,
mode="test",
use_preprocessed_img=True
)
test_dataloader = DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
collate_fn=collate_fn_test
)
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
# model
if args.model_type == 'tom_cm':
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_sl':
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_impl':
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
else: raise NotImplementedError
model.load_state_dict(torch.load(args.load_model_path, map_location=device))
model.device = device
model.eval()
if args.save_preds:
# Define the output file path
folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}'
if not os.path.exists(folder_path):
os.makedirs(folder_path)
print(f'Saving predictions in {folder_path}.')
print('Testing...')
m1_pred_list = []
m2_pred_list = []
m12_pred_list = []
m21_pred_list = []
mc_pred_list = []
m1_label_list = []
m2_label_list = []
m12_label_list = []
m21_label_list = []
mc_label_list = []
with torch.no_grad():
for j, batch in tqdm(enumerate(test_dataloader)):
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep, false_belief = batch
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, reprs = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
m1_pred = m1_pred.reshape(-1, 4)
m2_pred = m2_pred.reshape(-1, 4)
m12_pred = m12_pred.reshape(-1, 4)
m21_pred = m21_pred.reshape(-1, 4)
mc_pred = mc_pred.reshape(-1, 4)
m1_label = labels[:, 0].reshape(-1).to(device)
m2_label = labels[:, 1].reshape(-1).to(device)
m12_label = labels[:, 2].reshape(-1).to(device)
m21_label = labels[:, 3].reshape(-1).to(device)
mc_label = labels[:, 4].reshape(-1).to(device)
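# Label layout (from the indexing above): columns 0..4 of `labels` are m1 and m2
# (each participant's first-order belief), m12 and m21 (second-order beliefs about
# the other participant) and mc (common mind). Every head predicts one of 4 classes,
# which presumably correspond to the belief fluents used by the annotation utilities
# in this commit (enter / disappear / update / unchanged).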
m1_pred_list.append(m1_pred)
m2_pred_list.append(m2_pred)
m12_pred_list.append(m12_pred)
m21_pred_list.append(m21_pred)
mc_pred_list.append(mc_pred)
m1_label_list.append(m1_label)
m2_label_list.append(m2_label)
m12_label_list.append(m12_label)
m21_label_list.append(m21_label)
mc_label_list.append(mc_label)
if args.save_preds:
torch.save([r.cpu() for r in reprs], os.path.join(folder_path, f"{j}.pt"))
data = [(
i,
torch.argmax(m1_pred[i]).cpu().numpy(),
torch.argmax(m2_pred[i]).cpu().numpy(),
torch.argmax(m12_pred[i]).cpu().numpy(),
torch.argmax(m21_pred[i]).cpu().numpy(),
torch.argmax(mc_pred[i]).cpu().numpy(),
m1_label[i].cpu().numpy(),
m2_label[i].cpu().numpy(),
m12_label[i].cpu().numpy(),
m21_label[i].cpu().numpy(),
mc_label[i].cpu().numpy(),
false_belief[i]) for i in range(len(labels))
]
header = ['frame', 'm1_pred', 'm2_pred', 'm12_pred', 'm21_pred', 'mc_pred', 'm1_label', 'm2_label', 'm12_label', 'm21_label', 'mc_label', 'false_belief']
with open(os.path.join(folder_path, f'{j}.csv'), mode='w', newline='') as file:
writer = csv.writer(file)
writer.writerow(header) # Write the header row
writer.writerows(data) # Write the data rows
#np.savetxt('m1_label_bs1.txt', torch.cat(m1_label_list).cpu().numpy())
test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1 = compute_f1_scores(
torch.cat(m1_pred_list),
torch.cat(m1_label_list),
torch.cat(m2_pred_list),
torch.cat(m2_label_list),
torch.cat(m12_pred_list),
torch.cat(m12_label_list),
torch.cat(m21_pred_list),
torch.cat(m21_label_list),
torch.cat(mc_pred_list),
torch.cat(mc_label_list)
)
print("Test m1 F1: {}".format(test_m1_f1))
print("Test m2 F1: {}".format(test_m2_f1))
print("Test m12 F1: {}".format(test_m12_f1))
print("Test m21 F1: {}".format(test_m21_f1))
print("Test mc F1: {}".format(test_mc_f1))
with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f:
f.write(f"Test data:\n {[data[1] for data in test_dataset.data]}")
f.write(f"m1 f1: {test_m1_f1}")
f.write(f"m2 f1: {test_m2_f1}")
f.write(f"m12 f1: {test_m12_f1}")
f.write(f"m21 f1: {test_m21_f1}")
f.write(f"mc f1: {test_mc_f1}")
f.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Define the command-line arguments
parser.add_argument('--gpu_id', type=int)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--presaved', type=int, default=128)
parser.add_argument('--non_blocking', action='store_true')
parser.add_argument('--num_workers', type=int, default=16)
parser.add_argument('--pin_memory', action='store_true')
parser.add_argument('--model_type', type=str)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--aggr', type=str, default='concat', required=False)
parser.add_argument('--use_resnet', action='store_true')
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
parser.add_argument('--save_path', type=str, default='experiments/')
parser.add_argument('--test_frames', type=str, default=None)
parser.add_argument('--median', type=int, default=None)
parser.add_argument('--load_model_path', type=str)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--save_preds', action='store_true')
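# Example invocation (hypothetical paths, for illustration only):
#   python test.py --gpu_id 0 --model_type tom_cm --aggr concat \
#       --load_model_path experiments/<run_id>/model --save_preds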
# Parse the command-line arguments
args = parser.parse_args()
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
if not args.aggr:
parser.error("The choosen --model_type requires --aggr")
if args.model_type == 'tom_sl' and not args.tom_weight:
parser.error("The choosen --model_type requires --tom_weight")
os.environ['PYTHONHASHSEED'] = str(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
print('###########################################################################')
print('TESTING: MAKE SURE YOU ARE USING THE SAME RANDOM SEED USED DURING TRAINING!')
print('###########################################################################')
test(args)

474
tbd/train.py Normal file
View file

@ -0,0 +1,474 @@
import torch
import os
import argparse
import numpy as np
import random
import datetime
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from tbd_dataloader import TBDDataset, collate_fn
from models.common_mind import CommonMindToMnet
from models.sl import SLToMnet
from models.implicit import ImplicitToMnet
from utils.helpers import count_parameters, compute_f1_scores
def main(args):
train_dataset = TBDDataset(
path=args.data_path,
mode="train",
use_preprocessed_img=True
)
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
collate_fn=collate_fn
)
val_dataset = TBDDataset(
path=args.data_path,
mode="val",
use_preprocessed_img=True
)
val_dataloader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
collate_fn=collate_fn
)
train_data = [data[1] for data in train_dataset.data]
val_data = [data[1] for data in val_dataset.data]
if args.logger:
wandb.config.update({"train_data": train_data})
wandb.config.update({"val_data": val_data})
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
# model
if args.model_type == 'tom_cm':
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_sl':
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_impl':
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
else: raise NotImplementedError
if args.resume_from_checkpoint is not None:
model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# scheduler
if args.scheduler is None:
scheduler = None
else:
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
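# Note: any non-None --scheduler value selects cosine annealing; the learning rate
# then decays from --lr towards eta_min=3e-5 along a cosine curve over T_max=100
# scheduler steps (the scheduler is stepped once per epoch further below).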
# loss function
if args.model_type == 'tom_sl':
ce_loss_m1 = nn.NLLLoss()
ce_loss_m2 = nn.NLLLoss()
ce_loss_m12 = nn.NLLLoss()
ce_loss_m21 = nn.NLLLoss()
ce_loss_mc = nn.NLLLoss()
else:
ce_loss_m1 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
ce_loss_m2 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
ce_loss_m12 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
ce_loss_m21 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
ce_loss_mc = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
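# NLLLoss expects log-probabilities as input, so the speaker-listener model (tom_sl)
# is presumably expected to output log-softmax scores, while the other models output
# raw logits and therefore use CrossEntropyLoss (with optional label smoothing).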
stats = {
'train': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []},
'val': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []}
}
max_val_f1 = 0
max_val_classification_epoch = None
counter = 0
print(f'Number of parameters: {count_parameters(model)}')
for i in range(args.num_epoch):
# training
print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
epoch_train_loss_m1 = 0.0
epoch_train_loss_m2 = 0.0
epoch_train_loss_m12 = 0.0
epoch_train_loss_m21 = 0.0
epoch_train_loss_mc = 0.0
m1_train_batch_pred_list = []
m2_train_batch_pred_list = []
m12_train_batch_pred_list = []
m21_train_batch_pred_list = []
mc_train_batch_pred_list = []
m1_train_batch_label_list = []
m2_train_batch_label_list = []
m12_train_batch_label_list = []
m21_train_batch_label_list = []
mc_train_batch_label_list = []
model.train()
for j, batch in tqdm(enumerate(train_dataloader)):
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
m1_pred = m1_pred.reshape(-1, 4)
m2_pred = m2_pred.reshape(-1, 4)
m12_pred = m12_pred.reshape(-1, 4)
m21_pred = m21_pred.reshape(-1, 4)
mc_pred = mc_pred.reshape(-1, 4)
m1_label = labels[:, 0].reshape(-1).to(device)
m2_label = labels[:, 1].reshape(-1).to(device)
m12_label = labels[:, 2].reshape(-1).to(device)
m21_label = labels[:, 3].reshape(-1).to(device)
mc_label = labels[:, 4].reshape(-1).to(device)
loss_m1 = ce_loss_m1(m1_pred, m1_label)
loss_m2 = ce_loss_m2(m2_pred, m2_label)
loss_m12 = ce_loss_m12(m12_pred, m12_label)
loss_m21 = ce_loss_m21(m21_pred, m21_label)
loss_mc = ce_loss_mc(mc_pred, mc_label)
loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc
epoch_train_loss_m1 += loss_m1.data.item()
epoch_train_loss_m2 += loss_m2.data.item()
epoch_train_loss_m12 += loss_m12.data.item()
epoch_train_loss_m21 += loss_m21.data.item()
epoch_train_loss_mc += loss_mc.data.item()
m1_train_batch_pred_list.append(m1_pred)
m2_train_batch_pred_list.append(m2_pred)
m12_train_batch_pred_list.append(m12_pred)
m21_train_batch_pred_list.append(m21_pred)
mc_train_batch_pred_list.append(mc_pred)
m1_train_batch_label_list.append(m1_label)
m2_train_batch_label_list.append(m2_label)
m12_train_batch_label_list.append(m12_label)
m21_train_batch_label_list.append(m21_label)
mc_train_batch_label_list.append(mc_label)
optimizer.zero_grad()
loss.backward()
# Clip gradients after backward() and before the optimiser step.
if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
optimizer.step()
if args.logger: wandb.log({
'batch_train_loss': loss.data.item(),
'lr': optimizer.param_groups[-1]['lr']
})
print("Epoch {}/{} batch {}/{} training done with loss={}".format(
i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item())
)
if scheduler: scheduler.step()
train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score = compute_f1_scores(
torch.cat(m1_train_batch_pred_list),
torch.cat(m1_train_batch_label_list),
torch.cat(m2_train_batch_pred_list),
torch.cat(m2_train_batch_label_list),
torch.cat(m12_train_batch_pred_list),
torch.cat(m12_train_batch_label_list),
torch.cat(m21_train_batch_pred_list),
torch.cat(m21_train_batch_label_list),
torch.cat(mc_train_batch_pred_list),
torch.cat(mc_train_batch_label_list)
)
print("Epoch {}/{} OVERALL train m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}.\n".format(
i+1,
args.num_epoch,
epoch_train_loss_m1/len(train_dataloader),
epoch_train_loss_m2/len(train_dataloader),
epoch_train_loss_m12/len(train_dataloader),
epoch_train_loss_m21/len(train_dataloader),
epoch_train_loss_mc/len(train_dataloader),
train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score
)
)
stats['train']['loss_m1'].append(epoch_train_loss_m1/len(train_dataloader))
stats['train']['loss_m2'].append(epoch_train_loss_m2/len(train_dataloader))
stats['train']['loss_m12'].append(epoch_train_loss_m12/len(train_dataloader))
stats['train']['loss_m21'].append(epoch_train_loss_m21/len(train_dataloader))
stats['train']['loss_mc'].append(epoch_train_loss_mc/len(train_dataloader))
stats['train']['m1_f1'].append(train_m1_f1_score)
stats['train']['m2_f1'].append(train_m2_f1_score)
stats['train']['m12_f1'].append(train_m12_f1_score)
stats['train']['m21_f1'].append(train_m21_f1_score)
stats['train']['mc_f1'].append(train_mc_f1_score)
if args.logger: wandb.log(
{
'train_m1_loss': epoch_train_loss_m1/len(train_dataloader),
'train_m2_loss': epoch_train_loss_m2/len(train_dataloader),
'train_m12_loss': epoch_train_loss_m12/len(train_dataloader),
'train_m21_loss': epoch_train_loss_m21/len(train_dataloader),
'train_mc_loss': epoch_train_loss_mc/len(train_dataloader),
'train_loss': epoch_train_loss_m1/len(train_dataloader) + \
epoch_train_loss_m2/len(train_dataloader) + \
epoch_train_loss_m12/len(train_dataloader) + \
epoch_train_loss_m21/len(train_dataloader) + \
epoch_train_loss_mc/len(train_dataloader),
'train_m1_f1_score': train_m1_f1_score,
'train_m2_f1_score': train_m2_f1_score,
'train_m12_f1_score': train_m12_f1_score,
'train_m21_f1_score': train_m21_f1_score,
'train_mc_f1_score': train_mc_f1_score
}
)
# validation
print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
epoch_val_loss_m1 = 0.0
epoch_val_loss_m2 = 0.0
epoch_val_loss_m12 = 0.0
epoch_val_loss_m21 = 0.0
epoch_val_loss_mc = 0.0
m1_val_batch_pred_list = []
m2_val_batch_pred_list = []
m12_val_batch_pred_list = []
m21_val_batch_pred_list = []
mc_val_batch_pred_list = []
m1_val_batch_label_list = []
m2_val_batch_label_list = []
m12_val_batch_label_list = []
m21_val_batch_label_list = []
mc_val_batch_label_list = []
model.eval()
with torch.no_grad():
for j, batch in tqdm(enumerate(val_dataloader)):
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
m1_pred = m1_pred.reshape(-1, 4)
m2_pred = m2_pred.reshape(-1, 4)
m12_pred = m12_pred.reshape(-1, 4)
m21_pred = m21_pred.reshape(-1, 4)
mc_pred = mc_pred.reshape(-1, 4)
m1_label = labels[:, 0].reshape(-1).to(device)
m2_label = labels[:, 1].reshape(-1).to(device)
m12_label = labels[:, 2].reshape(-1).to(device)
m21_label = labels[:, 3].reshape(-1).to(device)
mc_label = labels[:, 4].reshape(-1).to(device)
loss_m1 = ce_loss_m1(m1_pred, m1_label)
loss_m2 = ce_loss_m2(m2_pred, m2_label)
loss_m12 = ce_loss_m12(m12_pred, m12_label)
loss_m21 = ce_loss_m21(m21_pred, m21_label)
loss_mc = ce_loss_mc(mc_pred, mc_label)
loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc
epoch_val_loss_m1 += loss_m1.data.item()
epoch_val_loss_m2 += loss_m2.data.item()
epoch_val_loss_m12 += loss_m12.data.item()
epoch_val_loss_m21 += loss_m21.data.item()
epoch_val_loss_mc += loss_mc.data.item()
m1_val_batch_pred_list.append(m1_pred)
m2_val_batch_pred_list.append(m2_pred)
m12_val_batch_pred_list.append(m12_pred)
m21_val_batch_pred_list.append(m21_pred)
mc_val_batch_pred_list.append(mc_pred)
m1_val_batch_label_list.append(m1_label)
m2_val_batch_label_list.append(m2_label)
m12_val_batch_label_list.append(m12_label)
m21_val_batch_label_list.append(m21_label)
mc_val_batch_label_list.append(mc_label)
if args.logger: wandb.log({'batch_val_loss': loss.data.item()})
print("Epoch {}/{} batch {}/{} validation done with loss={}".format(
i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item())
)
val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score = compute_f1_scores(
torch.cat(m1_val_batch_pred_list),
torch.cat(m1_val_batch_label_list),
torch.cat(m2_val_batch_pred_list),
torch.cat(m2_val_batch_label_list),
torch.cat(m12_val_batch_pred_list),
torch.cat(m12_val_batch_label_list),
torch.cat(m21_val_batch_pred_list),
torch.cat(m21_val_batch_label_list),
torch.cat(mc_val_batch_pred_list),
torch.cat(mc_val_batch_label_list)
)
print("Epoch {}/{} OVERALL validation m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format(
i+1,
args.num_epoch,
epoch_val_loss_m1/len(val_dataloader),
epoch_val_loss_m2/len(val_dataloader),
epoch_val_loss_m12/len(val_dataloader),
epoch_val_loss_m21/len(val_dataloader),
epoch_val_loss_mc/len(val_dataloader),
val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score
)
)
stats['val']['loss_m1'].append(epoch_val_loss_m1/len(val_dataloader))
stats['val']['loss_m2'].append(epoch_val_loss_m2/len(val_dataloader))
stats['val']['loss_m12'].append(epoch_val_loss_m12/len(val_dataloader))
stats['val']['loss_m21'].append(epoch_val_loss_m21/len(val_dataloader))
stats['val']['loss_mc'].append(epoch_val_loss_mc/len(val_dataloader))
stats['val']['m1_f1'].append(val_m1_f1_score)
stats['val']['m2_f1'].append(val_m2_f1_score)
stats['val']['m12_f1'].append(val_m12_f1_score)
stats['val']['m21_f1'].append(val_m21_f1_score)
stats['val']['mc_f1'].append(val_mc_f1_score)
if args.logger: wandb.log(
{
'val_m1_loss': epoch_val_loss_m1/len(val_dataloader),
'val_m2_loss': epoch_val_loss_m2/len(val_dataloader),
'val_m12_loss': epoch_val_loss_m12/len(val_dataloader),
'val_m21_loss': epoch_val_loss_m21/len(val_dataloader),
'val_mc_loss': epoch_val_loss_mc/len(val_dataloader),
'val_loss': epoch_val_loss_m1/len(val_dataloader) + \
epoch_val_loss_m2/len(val_dataloader) + \
epoch_val_loss_m12/len(val_dataloader) + \
epoch_val_loss_m21/len(val_dataloader) + \
epoch_val_loss_mc/len(val_dataloader),
'val_m1_f1_score': val_m1_f1_score,
'val_m2_f1_score': val_m2_f1_score,
'val_m12_f1_score': val_m12_f1_score,
'val_m21_f1_score': val_m21_f1_score,
'val_mc_f1_score': val_mc_f1_score
}
)
# check for best stats/model using the summed validation F1 scores
if stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1] >= max_val_f1:
max_val_f1 = stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1]
max_val_classification_epoch = i+1
torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
counter = 0
else:
counter += 1
print(f'EarlyStopping counter: {counter} out of {args.patience}.')
if counter >= args.patience:
break
with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
f.write('{}\n'.format(CFG))
f.write('{}\n'.format(train_data))
f.write('{}\n'.format(val_data))
f.write('{}\n'.format(stats))
f.write('Max val classification F1 (sum over minds): epoch {}, {}\n'.format(max_val_classification_epoch, max_val_f1))
f.close()
print(f'Results saved in {experiment_save_path}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Define the command-line arguments
parser.add_argument('--gpu_id', type=int)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--logger', action='store_true')
parser.add_argument('--presaved', type=int, default=128)
parser.add_argument('--clip_grad_norm', type=float, default=0.5)
parser.add_argument('--use_mixup', action='store_true')
parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
parser.add_argument('--non_blocking', action='store_true')
parser.add_argument('--patience', type=int, default=99)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--pin_memory', action='store_true')
parser.add_argument('--num_epoch', type=int, default=300)
parser.add_argument('--lr', type=float, default=4e-4)
parser.add_argument('--scheduler', type=str, default=None)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--weight_decay', type=float, default=0.005)
parser.add_argument('--label_smoothing', type=float, default=0.1)
parser.add_argument('--model_type', type=str)
parser.add_argument('--aggr', type=str, default='concat', required=False)
parser.add_argument('--use_resnet', action='store_true')
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
parser.add_argument('--save_path', type=str, default='experiments/')
parser.add_argument('--resume_from_checkpoint', type=str, default=None)
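# Example invocation (hypothetical; adjust GPU id and paths as needed):
#   python train.py --gpu_id 0 --model_type tom_impl --aggr concat --logger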
# Parse the command-line arguments
args = parser.parse_args()
if args.use_mixup and not args.mixup_alpha:
parser.error("--use_mixup requires --mixup_alpha")
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
if not args.aggr:
parser.error("The choosen --model_type requires --aggr")
if args.model_type == 'tom_sl' and not args.tom_weight:
parser.error("The choosen --model_type requires --tom_weight")
# get experiment ID
experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
if not os.path.exists(args.save_path):
os.makedirs(args.save_path, exist_ok=True)
experiment_save_path = os.path.join(args.save_path, experiment_id)
os.makedirs(experiment_save_path, exist_ok=True)
CFG = {
'use_ocr_custom_loss': 0,
'presaved': args.presaved,
'batch_size': args.batch_size,
'num_epoch': args.num_epoch,
'lr': args.lr,
'scheduler': args.scheduler,
'weight_decay': args.weight_decay,
'model_type': args.model_type,
'use_resnet': args.use_resnet,
'hidden_dim': args.hidden_dim,
'tom_weight': args.tom_weight,
'dropout': args.dropout,
'label_smoothing': args.label_smoothing,
'clip_grad_norm': args.clip_grad_norm,
'use_mixup': args.use_mixup,
'mixup_alpha': args.mixup_alpha,
'non_blocking_tensors': args.non_blocking,
'patience': args.patience,
'pin_memory': args.pin_memory,
'resume_from_checkpoint': args.resume_from_checkpoint,
'aggr': args.aggr,
'mods': args.mods,
'save_path': experiment_save_path,
'seed': args.seed
}
print(CFG)
print(f'Saving results in {experiment_save_path}')
# set seed values
if args.logger:
wandb.init(project="tbd", config=CFG)
os.environ['PYTHONHASHSEED'] = str(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
main(args)

224
tbd/utils/fb_scores_err.py Normal file
View file

@ -0,0 +1,224 @@
import os
import csv
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import f1_score
import numpy as np
from tqdm import tqdm
ALPHA = 0.7
BAR_WIDTH = 0.27
sns.set_theme(style='whitegrid')
#sns.set_palette('mako')
MTOM_COLORS = {
"MN1": (110/255, 117/255, 161/255),
"MN2": (179/255, 106/255, 98/255),
"Base": (193/255, 198/255, 208/255),
"CG": (170/255, 129/255, 42/255),
"IC": (97/255, 112/255, 83/255),
"DB": (144/255, 63/255, 110/255)
}
model_to_subdir = {
"IC$\parallel$": ["2023-07-16_10-34-32_train", "2023-07-18_13-49-57_train", "2023-07-19_12-17-46_train"],
"IC$\oplus$": ["2023-07-16_10-35-02_train", "2023-07-18_13-50-32_train", "2023-07-19_12-18-18_train"],
"IC$\otimes$": ["2023-07-16_10-35-41_train", "2023-07-18_13-52-26_train", "2023-07-19_12-18-49_train"],
"IC$\odot$": ["2023-07-16_10-36-04_train", "2023-07-18_13-53-03_train", "2023-07-19_12-19-50_train"],
"CG$\parallel$": ["2023-07-15_14-12-36_train", "2023-07-17_11-54-28_train", "2023-07-19_00-30-05_train"],
"CG$\oplus$": ["2023-07-15_14-14-08_train", "2023-07-17_11-56-05_train", "2023-07-19_00-30-47_train"],
"CG$\otimes$": ["2023-07-15_14-14-53_train", "2023-07-17_11-56-39_train", "2023-07-19_00-31-36_train"],
"CG$\odot$": ["2023-07-15_14-10-05_train", "2023-07-17_11-57-30_train", "2023-07-19_00-32-10_train"],
"DB": ["2023-08-08_12-56-02_train", "2023-08-08_19-07-43_train", "2023-08-08_19-08-47_train"],
"Base": ["2023-08-08_12-53-38_train", "2023-08-08_19-10-02_train", "2023-08-08_19-10-51_train"]
}
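# Maps each model label (as used in the plot legend) to the prediction folders of
# three independent training runs (presumably three random seeds); the means and
# standard deviations plotted below are computed across these three runs.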
def read_data_from_csv(subdirectory_path):
print(subdirectory_path)
data = []
csv_files = [file for file in os.listdir(subdirectory_path) if file.endswith('.csv')]
for csv_file in csv_files:
file_path = os.path.join(subdirectory_path, csv_file)
with open(file_path, 'r') as file:
reader = csv.reader(file)
header_skipped = False
for row in reader:
if not header_skipped:
header_skipped = True
continue
frame, m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, m1_label, m2_label, m12_label, m21_label, mc_label, false_belief = row
data.append({
'frame': int(frame),
'm1_pred': int(m1_pred),
'm2_pred': int(m2_pred),
'm12_pred': int(m12_pred),
'm21_pred': int(m21_pred),
'mc_pred': int(mc_pred),
'm1_label': int(m1_label),
'm2_label': int(m2_label),
'm12_label': int(m12_label),
'm21_label': int(m21_label),
'mc_label': int(mc_label),
'false_belief': false_belief,
})
return data
def compute_correct_false_belief(data, mind="all", folder=None):
total_false_belief = 0
correct_false_belief = 0
for item in data:
if 'false' in item['false_belief']:
false_belief_type = item['false_belief'].split('_')[0]
if mind == "all" or false_belief_type in mind:
total_false_belief += 1
if item[f"{false_belief_type}_pred"] == item[f"{false_belief_type}_label"]:
if folder is not None:
with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
f.write(f"{str(1)}\n")
correct_false_belief += 1
else:
if folder is not None:
with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
f.write(f"{str(0)}\n")
if total_false_belief == 0:
accuracy = 0.0
else:
accuracy = correct_false_belief / total_false_belief
return accuracy
def compute_macro_f1_score(data, mind="all"):
y_true = []
y_pred = []
for item in data:
if 'false' in item['false_belief']:
false_belief_type = item['false_belief'].split('_')[0]
if mind == "all" or false_belief_type in mind:
y_true.append(int(item[f"{false_belief_type}_label"]))
y_pred.append(int(item[f"{false_belief_type}_pred"]))
if not y_true or not y_pred:
macro_f1 = 0.0
else:
macro_f1 = f1_score(y_true, y_pred, average='macro')
return macro_f1
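# Macro F1 above is the unweighted mean of the per-class F1 scores over the belief
# classes present in the collected labels/predictions (sklearn's f1_score with
# average='macro'), so rare classes count as much as frequent ones.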
def delete_files_in_subfolders(folder_path, file_names_to_delete):
"""
Delete specified files in all subfolders of a given folder.
Parameters:
folder_path: The path to the folder containing subfolders.
file_names_to_delete: A list of file names to be deleted.
Returns:
None
"""
for root, _, _ in os.walk(folder_path):
for file_name in file_names_to_delete:
file_path = os.path.join(root, file_name)
if os.path.exists(file_path):
os.remove(file_path)
print(f"Deleted: {file_path}")
if __name__ == "__main__":
folder_path = "predictions"
files_to_delete = ["fb_m1_m2_m12_m21.txt", "fb_m1_m2.txt", "fb_m12_m21.txt"]
delete_files_in_subfolders(folder_path, files_to_delete)
metric = "Accuracy"
if metric == "Macro F1":
score_function = compute_macro_f1_score
elif metric == "Accuracy":
score_function = compute_correct_false_belief
else:
raise ValueError
models = [
'Base', 'DB',
'CG$\parallel$', 'CG$\oplus$', 'CG$\otimes$', 'CG$\odot$',
'IC$\parallel$', 'IC$\oplus$', 'IC$\otimes$', 'IC$\odot$'
]
parent_dir = 'predictions'
minds = categories = ['m1', 'm2', 'm12', 'm21']
score_m1_m2 = []
score_m12_m21 = []
score_all = []
std_m1_m2 = []
std_m12_m21 = []
std_all = []
for model in models:
model_scores_m1_m2 = []
model_scores_m12_m21 = []
model_scores_all = []
for s in range(3):
subdir_path = os.path.join(parent_dir, model_to_subdir[model][s])
data = read_data_from_csv(subdir_path)
model_scores_m1_m2.append(score_function(data, ['m1', 'm2'], model_to_subdir[model][s]))
model_scores_m12_m21.append(score_function(data, ['m12', 'm21'], model_to_subdir[model][s]))
model_scores_all.append(score_function(data, ['m1', 'm2', 'm12', 'm21'], model_to_subdir[model][s]))
score_m1_m2.append(np.mean(model_scores_m1_m2))
std_m1_m2.append(np.std(model_scores_m1_m2))
score_m12_m21.append(np.mean(model_scores_m12_m21))
std_m12_m21.append(np.std(model_scores_m12_m21))
score_all.append(np.mean(model_scores_all))
std_all.append(np.std(model_scores_all))
# Create a dataframe to use with sns.catplot
data = {
'Model': [m for m in models],
'FO_FB_mean': score_m1_m2,
'FO_FB_std': std_m1_m2,
'SO_FB_mean': score_m12_m21,
'SO_FB_std': std_m12_m21,
'Both_mean': score_all,
'Both_std': std_all
}
models = data['Model']
fo_fb_mean = data['FO_FB_mean']
fo_fb_std = data['FO_FB_std']
so_fb_mean = data['SO_FB_mean']
so_fb_std = data['SO_FB_std']
both_mean = data['Both_mean']
both_std = data['Both_std']
bar_width = BAR_WIDTH
x = np.arange(len(models))
plt.figure(figsize=(13, 3.5))
fo_fb_bars = plt.bar(x - bar_width, fo_fb_mean, width=bar_width, yerr=fo_fb_std, capsize=4, label='First-order false belief', alpha=ALPHA)
so_fb_bars = plt.bar(x, so_fb_mean, width=bar_width, yerr=so_fb_std, capsize=4, label='Second-order false belief', alpha=ALPHA)
both_bars = plt.bar(x + bar_width, both_mean, width=bar_width, yerr=both_std, capsize=4, label='Both', alpha=ALPHA)
def add_labels(bars, std_values):
cnt = 0
for bar, std in zip(bars, std_values):
height = bar.get_height()
offset = std + 0.01
if cnt == 0 or cnt == 1 or cnt == 9:
plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}*', ha='center', va='bottom', fontsize=10)
else:
plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}', ha='center', va='bottom', fontsize=10)
cnt = cnt + 1
add_labels(fo_fb_bars, fo_fb_std)
add_labels(so_fb_bars, so_fb_std)
add_labels(both_bars, both_std)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.xlabel('MToMnet', fontsize=14)
plt.ylabel('Macro F1 Score' if metric == "Macro F1" else 'Accuracy', fontsize=14)
plt.xticks(x, models, rotation=0, fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=14, loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=3)
plt.tight_layout()
plt.savefig('results/false_belief_first_vs_second.pdf')

210
tbd/utils/helpers.py Normal file

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,37 @@
import glob
import cv2
import torchvision.transforms as T
import torch
import os
from tqdm import tqdm
PATH_IN = "/scratch/bortoletto/data/tbd/images"
PATH_OUT = "/scratch/bortoletto/data/tbd/images_norm"
normalisation_steps = [
T.ToTensor(),
T.Resize((128,128)),
T.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
]
preprocess_img = T.Compose(normalisation_steps)
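# The mean/std above are the standard ImageNet normalisation statistics, applied after
# resizing every frame to 128x128. Note that cv2.imread returns BGR images, so these
# (RGB) statistics are applied with swapped channels; presumably the downstream models
# are trained consistently with this convention.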
def main():
print(f"{PATH_IN}/*/*/*.jpg")
all_img = glob.glob(f"{PATH_IN}/*/*/*.jpg")
print(len(all_img))
for img_path in tqdm(all_img):
new_img = preprocess_img(cv2.imread(img_path)).numpy()
img_path_split = img_path.split("/")
os.makedirs(f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}", exist_ok=True)
out_img = f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}/{img_path_split[-1][:-4]}.pt"
torch.save(new_img, out_img)
if __name__ == '__main__':
main()

View file

@ -0,0 +1,106 @@
import pandas as pd
import os
import glob
import pickle
DATASET_LOCATION = "YOUR_PATH_HERE"
def reframe_annotation():
annotation_path = f'{DATASET_LOCATION}/retrieve_annotation/all/'
save_path = f'{DATASET_LOCATION}/reformat_annotation/'
if not os.path.exists(save_path):
os.makedirs(save_path)
tasks = glob.glob(annotation_path + '*.txt')
id_map = pd.read_csv('id_map.csv')
for task in tasks:
if not task.split('/')[-1].split('_')[2] == '1.txt':
continue
with open(task, 'r') as f:
lines = f.readlines()
task_id = int(task.split('/')[-1].split('_')[1]) + 1
clip = id_map.loc[id_map['ID'] == task_id].folder
print(task_id, len(clip))
if len(clip) == 0:
continue
with open(save_path + clip.item() + '.txt', 'w') as f:
for line in lines:
words = line.split()
f.write(words[0] + ',' + words[1] + ',' + words[2] + ',' + words[3] + ',' + words[4] + ',' + words[5] +
',' + words[6] + ',' + words[7] + ',' + words[8] + ',' + words[9] + ',' + ' '.join(words[10:]) + '\n')
f.close()
def get_grid_location(obj_frame):
x_min = obj_frame['x_min']#.item()
y_min = obj_frame['y_min']#.item()
x_max = obj_frame['x_max']#.item()
y_max = obj_frame['y_max']#.item()
gridLW = 1280 / 25.
gridLH = 720 / 15.
center_x, center_y = (x_min + x_max)/2, (y_min + y_max)/2
X, Y = int(center_x / gridLW), int(center_y / gridLH)
return X, Y
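# Worked example: the 1280x720 frame is split into a 25x15 grid, so each cell is
# 51.2 x 48 px; a bounding box centred at (640, 360) maps to grid cell
# (int(640 / 51.2), int(360 / 48)) = (12, 7).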
def regenerate_annotation():
annotation_path = f'{DATASET_LOCATION}/reformat_annotation/'
save_path=f'{DATASET_LOCATION}/regenerate_annotation/'
if not os.path.exists(save_path):
os.makedirs(save_path)
tasks = glob.glob(annotation_path + '*.txt')
for task in tasks:
print(task)
annt = pd.read_csv(task, sep=",", header=None)
annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name", "label"]
obj_records = {}
for index, obj_frame in annt.iterrows():
if obj_frame['name'].startswith('P'):
continue
else:
assert obj_frame['name'].startswith('O')
obj_name = obj_frame['name']
# 0: enter, 1: disappear, 2: update, 3: unchanged
frame_id = obj_frame['frame']
curr_loc = get_grid_location(obj_frame)
mind_dict = {'m1': {'fluent': 3, 'loc': None}, 'm2': {'fluent': 3, 'loc': None},
'm12': {'fluent': 3, 'loc': None},
'm21': {'fluent': 3, 'loc': None}, 'mc': {'fluent': 3, 'loc': None},
'mg': {'fluent': 3, 'loc': curr_loc}}
mind_dict['mg']['loc'] = curr_loc
if not type(obj_frame['label']) == float:
mind_labels = obj_frame['label'].split()
for mind_label in mind_labels:
if mind_label == 'in_m1' or mind_label == 'in_m2' or mind_label == 'in_m12' \
or mind_label == 'in_m21' or mind_label == 'in_mc' or mind_label == '"in_m1"' or mind_label == '"in_m2"'\
or mind_label == '"in_m12"' or mind_label == '"in_m21"' or mind_label == '"in_mc"':
mind_name = mind_label.split('_')[1].split('"')[0]
mind_dict[mind_name]['loc'] = curr_loc
else:
mind_name = mind_label.split('_')[0].split('"')
if len(mind_name) > 1:
mind_name = mind_name[1]
else:
mind_name = mind_name[0]
last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
mind_dict[mind_name]['loc'] = last_loc
for mind_name in mind_dict.keys():
if frame_id > 0:
curr_loc = mind_dict[mind_name]['loc']
last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
if last_loc is None and curr_loc is not None:
mind_dict[mind_name]['fluent'] = 0
elif last_loc is not None and curr_loc is None:
mind_dict[mind_name]['fluent'] = 1
elif not last_loc == curr_loc:
mind_dict[mind_name]['fluent'] = 2
if obj_name not in obj_records:
obj_records[obj_name] = [mind_dict]
else:
obj_records[obj_name].append(mind_dict)
with open(save_path + task.split('/')[-1].split('.')[0] + '.p', 'wb') as f:
pickle.dump(obj_records, f)
if __name__ == '__main__':
reframe_annotation()
regenerate_annotation()

75
tbd/utils/similarity.py Normal file
View file

@ -0,0 +1,75 @@
import os
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import seaborn as sns
FOLDER_PATH = 'PATH_TO_FOLDER'
print(FOLDER_PATH)
MTOM_COLORS = {
"MN1": (110/255, 117/255, 161/255),
"MN2": (179/255, 106/255, 98/255),
"Base": (193/255, 198/255, 208/255),
"CG": (170/255, 129/255, 42/255),
"IC": (97/255, 112/255, 83/255),
"DB": (144/255, 63/255, 110/255)
}
COLORS = sns.color_palette()
sns.set_theme(style='white')
out_left_main_mods_full_test = []
out_right_main_mods_full_test = []
cell_left_main_mods_full_test = []
cell_right_main_mods_full_test = []
cm_left_main_mods_full_test = []
cm_right_main_mods_full_test = []
for i in range(len([filename for filename in os.listdir(FOLDER_PATH) if filename.endswith('.pt')])):
print(f'Computing analysis for test video {i}...', end='\r')
emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
data = torch.load(emb_file)
if len(data) == 13: # implicit
model = 'impl'
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
elif len(data) == 12: # common mind
model = 'cm'
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
elif len(data) == 11: # speaker-listener
model = 'sl'
out_left, out_right, feats = data[0], data[1], data[2:]
else: raise ValueError("Data should have 13 (impl), others are not implemented")
# ====== PCA for left and right embeddings ====== #
out_left_pca = out_left[0].reshape(-1, 64)
out_right_pca = out_right[0].reshape(-1, 64)
out_left_and_right = np.concatenate((out_left_pca, out_right_pca), axis=0)
pca = PCA(n_components=2)
pca_result = pca.fit_transform(out_left_and_right)
# Separate the PCA results for each tensor
pca_result_left = pca_result[:out_left_pca.shape[0]]
pca_result_right = pca_result[out_left_pca.shape[0]:]  # split at the size of the left block
plt.figure(figsize=(6.8,6))
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
plt.xlabel('Principal Component 1', fontsize=32)
plt.ylabel('Principal Component 2', fontsize=32)
plt.xticks(fontsize=24)
plt.xticks([-0.4, -0.2, 0.0, 0.2, 0.4])
plt.yticks(fontsize=24)
plt.legend(fontsize=32)
plt.tight_layout()
plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
plt.close()

View file

@ -0,0 +1,96 @@
import os
import pandas as pd
import pickle
from tqdm import tqdm
def check_append(obj_name, m1, mind_name, obj_frame, flags, label):
if label:
if not obj_name in m1:
m1[obj_name] = []
m1[obj_name].append(
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
flags[mind_name] = 1
elif not flags[mind_name]:
m1[obj_name].append(
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
flags[mind_name] = 1
else: # false belief
if obj_name in m1:
if flags[mind_name]:
m1[obj_name].append(
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 1])
flags[mind_name] = 0
return flags, m1
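# Rough semantics of check_append (as the logic above suggests): each mind dict maps
# an object name to a list of [frame, bbox, flag] events, where flag 0 marks the
# object (re-)entering that mind with a true belief and flag 1 marks the onset of a
# false belief; `flags` tracks the current state so that consecutive identical labels
# are not appended twice.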
def store_mind_set(clip, annotation_path, save_path):
if not os.path.exists(save_path):
os.makedirs(save_path)
annt = pd.read_csv(annotation_path + clip, sep=",", header=None)
annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name",
"label"]
obj_names = annt.name.unique()
m1, m2, m12, m21, mc = {}, {}, {}, {}, {}
flags = {'m1':0, 'm2':0, 'm12':0, 'm21':0, 'mc':0}
for obj_name in obj_names:
if obj_name == 'P1' or obj_name == 'P2':
continue
obj_frames = annt.loc[annt.name == obj_name]
for index, obj_frame in obj_frames.iterrows():
if type(obj_frame.label) == float:
continue
labels = obj_frame.label.split()
for label in labels:
if label == 'in_m1' or label == '"in_m1"':
flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 1)
elif label == 'in_m2' or label == '"in_m2"':
flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 1)
elif label == 'in_m12'or label == '"in_m12"':
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 1)
elif label == 'in_m21' or label == '"in_m21"':
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 1)
elif label == 'in_mc'or label == '"in_mc"':
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 1)
elif label == 'm1_false' or label == '"m1_false"':
flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 0)
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
false_belief = 'm1_false'
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
elif label == 'm2_false' or label == '"m2_false"':
flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 0)
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
false_belief = 'm2_false'
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
elif label == 'm12_false' or label == '"m12_false"':
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
false_belief = 'm12_false'
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
elif label == 'm21_false' or label == '"m21_false"':
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
false_belief = 'm21_false'
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
# print('m1', m1)
# print('m2', m2)
# print('m12', m12)
# print('m21', m21)
# print('mc', mc)
#with open(save_path + clip.split('.')[0] + '.p', 'wb') as f:
# pickle.dump([m1, m2, m12, m21, mc], f)
if __name__ == "__main__":
annotation_path = '/scratch/bortoletto/data/tbd/reformat_annotation/'
save_path = '/scratch/bortoletto/data/tbd/store_mind_set/'
for clip in tqdm(os.listdir(annotation_path), desc="Processing videos", unit="item"):
store_mind_set(clip, annotation_path, save_path)

View file

@ -0,0 +1,95 @@
import time
from tbd_dataloader import TBDv2Dataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def point2screen(points):
K = [607.13232421875, 0.0, 638.6468505859375, 0.0, 607.1067504882812, 367.1607360839844, 0.0, 0.0, 1.0]
K = np.reshape(np.array(K), [3, 3])
rot_points = np.array(points) + np.array([0, 0.2, 0])
points_camera = rot_points.reshape(3, 1)
project_matrix = np.array(K).reshape(3, 3)
points_prj = project_matrix.dot(points_camera)
points_prj = points_prj.transpose()
if not points_prj[:, 2][0] == 0.0:
points_prj[:, 0] = points_prj[:, 0] / points_prj[:, 2]
points_prj[:, 1] = points_prj[:, 1] / points_prj[:, 2]
points_screen = points_prj[:, :2]
assert points_screen.shape == (1, 2)
points_screen = points_screen.reshape(-1)
return points_screen
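# point2screen is a standard pinhole projection: with intrinsics K,
# K @ [X, Y, Z]^T = [fx*X + cx*Z, fy*Y + cy*Z, Z]^T, and dividing the first two
# entries by Z gives pixel coordinates (u, v) = (fx*X/Z + cx, fy*Y/Z + cy).
# The +[0, 0.2, 0] offset appears to be a fixed height correction applied to the
# skeleton points before projection (assumption; not documented here).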
if __name__ == '__main__':
data = TBDv2Dataset(number_frames_to_sample=1, resize_img=None)
index = np.random.randint(0, len(data))
start = time.time()
(
kinect_imgs, # <- len x 720 x 1280 x 3
tracker_imgs,
battery_imgs,
skele1,
skele2,
bbox,
tracker_skeID_sample, # <- This is the tracker skeleton ID
tracker2d,
label,
experiment_id, # From here for debugging
timestep,
obj_id, # <- This is the object ID as a string
) = data[index]
end = time.time()
print(f"Time for one sample: {end-start}")
img = kinect_imgs[-1]
bbox = bbox[-1]
print(label.shape)
print(skele1.shape)
print(skele2.shape)
skele1 = skele1[-1, :,:]
skele2 = skele2[-1, :,:]
print(skele1.shape)
# reshape img from c, h, w to h, w, c
img = img.permute(1, 2, 0)
fig, ax = plt.subplots(1)
ax.imshow(img)
print(bbox[0], bbox[1], bbox[2], bbox[3]) # (top left x, top left y, width, height)
top_left_x, top_left_y, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
for i in range(26):
print(skele1[i,0], skele1[i,1])
print(skele1[i,:].shape)
print(point2screen(skele1[i,:]))
x, y = point2screen(skele1[i,:])[0], point2screen(skele1[i,:])[1]
ax.text(x, y, f"{i}", fontsize=5, color='w')
wedge = patches.Wedge((x,y), 10, 0, 360, width=10, color='b')
ax.add_patch(wedge)
for i in range(26):
x, y = point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]
ax.text(x, y, f"{i}", fontsize=5, color='w')
wedge = patches.Wedge((point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]), 10, 0, 360, width=10, color='r')
ax.add_patch(wedge)
# Create a Rectangle patch
# rect = patches.Rectangle((top_left_x, top_left_y-height), width, height, linewidth=1, edgecolor='r', facecolor='none')
# ax.add_patch(rect)
# rect = patches.Rectangle((x_min, y_max), x_max-x_min, y_max-y_min, linewidth=1, edgecolor='g', facecolor='none')
# ax.add_patch(rect)
fig.savefig(f"bbox_{obj_id}_{index}_{experiment_id}.png")