up
This commit is contained in:
parent
d4aaf7f4ad
commit
25b8b3f343
55 changed files with 7592 additions and 4 deletions
196
tbd/.gitignore
vendored
Normal file
196
tbd/.gitignore
vendored
Normal file
|
@ -0,0 +1,196 @@
|
|||
experiments
|
||||
wandb
|
||||
predictions
|
||||
|
||||
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/python,linux
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
|
||||
|
||||
### Linux ###
|
||||
*~
|
||||
|
||||
# temporary files which can be created if a process still has a handle open of a deleted file
|
||||
.fuse_hidden*
|
||||
|
||||
# KDE directory preferences
|
||||
.directory
|
||||
|
||||
# Linux trash folder which might appear on any partition or disk
|
||||
.Trash-*
|
||||
|
||||
# .nfs files are created when an open file is removed but is still being accessed
|
||||
.nfs*
|
||||
|
||||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
### Python Patch ###
|
||||
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
||||
poetry.toml
|
||||
|
||||
# ruff
|
||||
.ruff_cache/
|
||||
|
||||
# LSP config files
|
||||
pyrightconfig.json
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/python,linux
|
16
tbd/README.md
Normal file
16
tbd/README.md
Normal file
|
@ -0,0 +1,16 @@
|
|||
# TBD
|
||||
|
||||
# Data
|
||||
The original code can be found [here](https://github.com/LifengFan/Triadic-Belief-Dynamics). The dataset is not directly available but must be requested using the link to the Google form provided in the [README](https://github.com/LifengFan/Triadic-Belief-Dynamics?tab=readme-ov-file#dataset).
|
||||
|
||||
## Installing Dependencies
|
||||
Run `conda env create -f environment.yml`.
|
||||
|
||||
## Train
|
||||
`source run_train.sh`.
|
||||
|
||||
## Test
|
||||
`source run_test.sh`. **Make sure to use the same random seed used for training**, otherwise the splits will be different and you will likely have a data leakage.
|
||||
|
||||
## Visualisations
|
||||
The plots are made using `utils/fb_scores_err.py` (false belief analysis) and `utils/similarity.py` (PCA of latent representations).
|
100
tbd/environment.yml
Normal file
100
tbd/environment.yml
Normal file
|
@ -0,0 +1,100 @@
|
|||
name: tbd
|
||||
channels:
|
||||
- conda-forge
|
||||
- defaults
|
||||
- pytorch
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=5.1=1_gnu
|
||||
- ca-certificates=2023.01.10=h06a4308_0
|
||||
- ld_impl_linux-64=2.38=h1181459_1
|
||||
- libffi=3.3=he6710b0_2
|
||||
- libgcc-ng=11.2.0=h1234567_1
|
||||
- libgomp=11.2.0=h1234567_1
|
||||
- libstdcxx-ng=11.2.0=h1234567_1
|
||||
- ncurses=6.4=h6a678d5_0
|
||||
- openssl=1.1.1t=h7f8727e_0
|
||||
- pip=23.0.1=py38h06a4308_0
|
||||
- python=3.8.10=h12debd9_8
|
||||
- readline=8.2=h5eee18b_0
|
||||
- setuptools=66.0.0=py38h06a4308_0
|
||||
- sqlite=3.41.2=h5eee18b_0
|
||||
- tk=8.6.12=h1ccaba5_0
|
||||
- wheel=0.38.4=py38h06a4308_0
|
||||
- xz=5.4.2=h5eee18b_0
|
||||
- zlib=1.2.13=h5eee18b_0
|
||||
- pip:
|
||||
- appdirs==1.4.4
|
||||
- beautifulsoup4==4.12.2
|
||||
- certifi==2023.5.7
|
||||
- charset-normalizer==3.1.0
|
||||
- click==8.1.3
|
||||
- cmake==3.26.4
|
||||
- contourpy==1.1.0
|
||||
- cycler==0.11.0
|
||||
- docker-pycreds==0.4.0
|
||||
- einops==0.6.1
|
||||
- filelock==3.12.0
|
||||
- fonttools==4.40.0
|
||||
- gdown==4.7.1
|
||||
- gitdb==4.0.10
|
||||
- gitpython==3.1.31
|
||||
- idna==3.4
|
||||
- importlib-resources==5.12.0
|
||||
- jinja2==3.1.2
|
||||
- joblib==1.3.1
|
||||
- kiwisolver==1.4.4
|
||||
- lit==16.0.6
|
||||
- markupsafe==2.1.3
|
||||
- matplotlib==3.7.1
|
||||
- memory-efficient-attention-pytorch==0.1.2
|
||||
- mpmath==1.3.0
|
||||
- networkx==3.1
|
||||
- numpy==1.24.4
|
||||
- nvidia-cublas-cu11==11.10.3.66
|
||||
- nvidia-cuda-cupti-cu11==11.7.101
|
||||
- nvidia-cuda-nvrtc-cu11==11.7.99
|
||||
- nvidia-cuda-runtime-cu11==11.7.99
|
||||
- nvidia-cudnn-cu11==8.5.0.96
|
||||
- nvidia-cufft-cu11==10.9.0.58
|
||||
- nvidia-curand-cu11==10.2.10.91
|
||||
- nvidia-cusolver-cu11==11.4.0.1
|
||||
- nvidia-cusparse-cu11==11.7.4.91
|
||||
- nvidia-nccl-cu11==2.14.3
|
||||
- nvidia-nvtx-cu11==11.7.91
|
||||
- opencv-python==4.8.0.74
|
||||
- packaging==23.1
|
||||
- pandas==2.0.3
|
||||
- pathtools==0.1.2
|
||||
- pillow==9.5.0
|
||||
- protobuf==4.23.3
|
||||
- psutil==5.9.5
|
||||
- pyparsing==3.1.0
|
||||
- pysocks==1.7.1
|
||||
- python-dateutil==2.8.2
|
||||
- pytz==2023.3
|
||||
- pyyaml==6.0
|
||||
- requests==2.30.0
|
||||
- scikit-learn==1.3.0
|
||||
- scipy==1.10.1
|
||||
- seaborn==0.12.2
|
||||
- sentry-sdk==1.27.0
|
||||
- setproctitle==1.3.2
|
||||
- six==1.16.0
|
||||
- smmap==5.0.0
|
||||
- soupsieve==2.4.1
|
||||
- sympy==1.12
|
||||
- threadpoolctl==3.1.0
|
||||
- torch==2.0.1
|
||||
- torch-geometric==2.3.1
|
||||
- torchaudio==2.0.2
|
||||
- torchsampler==0.1.2
|
||||
- torchvision==0.15.2
|
||||
- tqdm==4.65.0
|
||||
- triton==2.0.0
|
||||
- typing-extensions==4.7.0
|
||||
- tzdata==2023.3
|
||||
- urllib3==2.0.2
|
||||
- wandb==0.15.5
|
||||
- zipp==3.15.0
|
||||
prefix: /opt/anaconda3/envs/tbd
|
156
tbd/models/base.py
Normal file
156
tbd/models/base.py
Normal file
|
@ -0,0 +1,156 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from .utils import pose_edge_index
|
||||
from torch_geometric.nn import GCNConv
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
def forward(self, x, **kwargs):
|
||||
x = self.norm(x)
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim, dim))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CNN(nn.Module):
|
||||
def __init__(self, hidden_dim):
|
||||
super(CNN, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = self.pool(x)
|
||||
x = self.conv2(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = self.pool(x)
|
||||
x = self.conv3(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling
|
||||
return x
|
||||
|
||||
|
||||
class MindNetLSTM(nn.Module):
|
||||
"""
|
||||
Basic MindNet for model-based ToM, just LSTM on input concatenation
|
||||
"""
|
||||
def __init__(self, hidden_dim, dropout, mods):
|
||||
super(MindNetLSTM, self).__init__()
|
||||
self.mods = mods
|
||||
if 'rgb_1' in mods:
|
||||
self.img_emb = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
if 'gaze' in mods:
|
||||
self.gaze_emb = nn.Linear(2, hidden_dim)
|
||||
if 'pose' in mods:
|
||||
self.pose_edge_index = pose_edge_index()
|
||||
self.pose_emb = GCNConv(3, hidden_dim)
|
||||
self.LSTM = PreNorm(
|
||||
hidden_dim*len(mods),
|
||||
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
|
||||
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.act = nn.GELU()
|
||||
|
||||
def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
|
||||
feats = []
|
||||
if 'rgb_3' in self.mods:
|
||||
feats.append(rgb_3rd_pov_feats)
|
||||
if 'rgb_1' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(rgb_1st_pov.shape[1]):
|
||||
images_i = rgb_1st_pov[:,i]
|
||||
img_i_feat = self.img_emb(images_i)
|
||||
img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
feats.append(rgb_feats)
|
||||
if 'pose' in self.mods:
|
||||
bs, seq_len = pose.size(0), pose.size(1)
|
||||
self.pose_edge_index = self.pose_edge_index.to(pose.device)
|
||||
pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
|
||||
pose_emb = self.dropout(self.act(pose_emb))
|
||||
pose_emb = torch.mean(pose_emb, dim=1)
|
||||
hd = pose_emb.size(-1)
|
||||
feats.append(pose_emb.view(bs, seq_len, hd))
|
||||
if 'gaze' in self.mods:
|
||||
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
|
||||
feats.append(gaze_feats)
|
||||
if 'bbox' in self.mods:
|
||||
feats.append(bbox_feats.mean(2))
|
||||
lstm_inp = torch.cat(feats, 2)
|
||||
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
|
||||
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
|
||||
return self.act(self.proj(lstm_out)), c_n, feats
|
||||
|
||||
|
||||
class MindNetSL(nn.Module):
|
||||
"""
|
||||
Basic MindNet for SL ToM, just LSTM on input concatenation
|
||||
"""
|
||||
def __init__(self, hidden_dim, dropout, mods):
|
||||
super(MindNetSL, self).__init__()
|
||||
self.mods = mods
|
||||
if 'rgb_1' in mods:
|
||||
self.img_emb = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
if 'gaze' in mods:
|
||||
self.gaze_emb = nn.Linear(2, hidden_dim)
|
||||
if 'pose' in mods:
|
||||
self.pose_edge_index = pose_edge_index()
|
||||
self.pose_emb = GCNConv(3, hidden_dim)
|
||||
self.LSTM = PreNorm(
|
||||
hidden_dim*len(mods),
|
||||
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
|
||||
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.act = nn.GELU()
|
||||
|
||||
def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
|
||||
feats = []
|
||||
if 'rgb_3' in self.mods:
|
||||
feats.append(rgb_3rd_pov_feats)
|
||||
if 'rgb_1' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(rgb_1st_pov.shape[1]):
|
||||
images_i = rgb_1st_pov[:,i]
|
||||
img_i_feat = self.img_emb(images_i)
|
||||
img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
feats.append(rgb_feats)
|
||||
if 'pose' in self.mods:
|
||||
bs, seq_len = pose.size(0), pose.size(1)
|
||||
self.pose_edge_index = self.pose_edge_index.to(pose.device)
|
||||
pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
|
||||
pose_emb = self.dropout(self.act(pose_emb))
|
||||
pose_emb = torch.mean(pose_emb, dim=1)
|
||||
hd = pose_emb.size(-1)
|
||||
feats.append(pose_emb.view(bs, seq_len, hd))
|
||||
if 'gaze' in self.mods:
|
||||
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
|
||||
feats.append(gaze_feats)
|
||||
if 'bbox' in self.mods:
|
||||
feats.append(bbox_feats.mean(2))
|
||||
lstm_inp = torch.cat(feats, 2)
|
||||
lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
|
||||
return self.act(self.proj(lstm_out)), feats
|
157
tbd/models/common_mind.py
Normal file
157
tbd/models/common_mind.py
Normal file
|
@ -0,0 +1,157 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from .base import CNN, MindNetLSTM
|
||||
from memory_efficient_attention_pytorch import Attention
|
||||
|
||||
|
||||
class CommonMindToMnet(nn.Module):
|
||||
"""
|
||||
img: bs, 3, 128, 128
|
||||
pose: bs, 26, 3
|
||||
gaze: bs, 2 NOTE: only tracker has gaze
|
||||
bbox: bs, 4
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_1', 'rgb_3', 'pose', 'gaze', 'bbox']):
|
||||
super(CommonMindToMnet, self).__init__()
|
||||
|
||||
self.aggr = aggr
|
||||
self.mods = mods
|
||||
|
||||
# ---- 3rd POV Images, object and bbox ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
#for param in self.cnn.parameters():
|
||||
# param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(4, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
|
||||
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
|
||||
if aggr != 'no_tom': self.cm_proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.ln_1 = nn.LayerNorm(hidden_dim)
|
||||
self.ln_2 = nn.LayerNorm(hidden_dim)
|
||||
if aggr == 'attn':
|
||||
self.attn_left = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.attn_right = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.m1 = nn.Linear(hidden_dim, 4)
|
||||
self.m2 = nn.Linear(hidden_dim, 4)
|
||||
self.m12 = nn.Linear(hidden_dim, 4)
|
||||
self.m21 = nn.Linear(hidden_dim, 4)
|
||||
self.mc = nn.Linear(hidden_dim, 4)
|
||||
|
||||
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
|
||||
|
||||
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
|
||||
|
||||
if 'bbox' in self.mods:
|
||||
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
|
||||
else:
|
||||
bbox_feat = None
|
||||
|
||||
if 'rgb_3' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = img_3rd_pov[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
else:
|
||||
rgb_feat_3rd_pov = None
|
||||
|
||||
if tracker_id == 'skele1':
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
|
||||
else:
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
|
||||
|
||||
if self.aggr == 'no_tom':
|
||||
m1 = self.m1(out_1).mean(1)
|
||||
m2 = self.m2(out_2).mean(1)
|
||||
m12 = self.m12(out_1).mean(1)
|
||||
m21 = self.m21(out_2).mean(1)
|
||||
mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
|
||||
|
||||
common_mind = self.cm_proj(torch.cat([cell_1, cell_2], -1)) # (bs, 1, h)
|
||||
|
||||
if self.aggr == 'attn':
|
||||
p1 = self.attn_left(x=out_1, context=common_mind)
|
||||
p2 = self.attn_right(x=out_2, context=common_mind)
|
||||
elif self.aggr == 'mult':
|
||||
p1 = out_1 * common_mind
|
||||
p2 = out_2 * common_mind
|
||||
elif self.aggr == 'sum':
|
||||
p1 = out_1 + common_mind
|
||||
p2 = out_2 + common_mind
|
||||
elif self.aggr == 'concat':
|
||||
p1 = torch.cat([out_1, common_mind], 1)
|
||||
p2 = torch.cat([out_2, common_mind], 1)
|
||||
else: raise ValueError
|
||||
p1 = self.act(p1)
|
||||
p1 = self.ln_1(p1)
|
||||
p2 = self.act(p2)
|
||||
p2 = self.ln_2(p2)
|
||||
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
|
||||
m1 = self.m1(p1).mean(1)
|
||||
m2 = self.m2(p2).mean(1)
|
||||
m12 = self.m12(p1).mean(1)
|
||||
m21 = self.m21(p2).mean(1)
|
||||
mc = self.mc(p1*p2).mean(1)
|
||||
if self.aggr == 'concat':
|
||||
m1 = self.m1(p1).mean(1)
|
||||
m2 = self.m2(p2).mean(1)
|
||||
m12 = self.m12(p1).mean(1)
|
||||
m21 = self.m21(p2).mean(1)
|
||||
mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, out_2, common_mind] + feats_1 + feats_2
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
|
||||
img_tracker = torch.ones(3, 5, 3, 128, 128)
|
||||
img_battery = torch.ones(3, 5, 3, 128, 128)
|
||||
pose1 = torch.ones(3, 5, 26, 3)
|
||||
pose2 = torch.ones(3, 5, 26, 3)
|
||||
bbox = torch.ones(3, 5, 13, 4)
|
||||
tracker_id = 'skele1'
|
||||
gaze = torch.ones(3, 5, 2)
|
||||
mods = ['pose', 'bbox', 'rgb_3']
|
||||
|
||||
for agg in ['no_tom']:
|
||||
model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
|
||||
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
|
||||
print(out[0].shape)
|
151
tbd/models/implicit.py
Normal file
151
tbd/models/implicit.py
Normal file
|
@ -0,0 +1,151 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from .base import CNN, MindNetLSTM
|
||||
from memory_efficient_attention_pytorch import Attention
|
||||
|
||||
|
||||
class ImplicitToMnet(nn.Module):
|
||||
"""
|
||||
Implicit ToM net. Supports any subset of modalities
|
||||
Possible aggregations: sum, mult, attn, concat
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
|
||||
super(ImplicitToMnet, self).__init__()
|
||||
|
||||
self.aggr = aggr
|
||||
self.mods = mods
|
||||
|
||||
# ---- 3rd POV Images, object and bbox ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(4, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
|
||||
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
|
||||
self.ln_1 = nn.LayerNorm(hidden_dim)
|
||||
self.ln_2 = nn.LayerNorm(hidden_dim)
|
||||
if aggr == 'attn':
|
||||
self.attn_left = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.attn_right = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.m1 = nn.Linear(hidden_dim, 4)
|
||||
self.m2 = nn.Linear(hidden_dim, 4)
|
||||
self.m12 = nn.Linear(hidden_dim, 4)
|
||||
self.m21 = nn.Linear(hidden_dim, 4)
|
||||
self.mc = nn.Linear(hidden_dim, 4)
|
||||
|
||||
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
|
||||
|
||||
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
|
||||
|
||||
if 'bbox' in self.mods:
|
||||
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
|
||||
else:
|
||||
bbox_feat = None
|
||||
|
||||
if 'rgb_3' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = img_3rd_pov[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
else:
|
||||
rgb_feat_3rd_pov = None
|
||||
|
||||
if tracker_id == 'skele1':
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
|
||||
else:
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
|
||||
|
||||
if self.aggr == 'no_tom':
|
||||
m1 = self.m1(out_1).mean(1)
|
||||
m2 = self.m2(out_2).mean(1)
|
||||
m12 = self.m12(out_1).mean(1)
|
||||
m21 = self.m21(out_2).mean(1)
|
||||
mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
|
||||
|
||||
if self.aggr == 'attn':
|
||||
p1 = self.attn_left(x=out_1, context=cell_2)
|
||||
p2 = self.attn_right(x=out_2, context=cell_1)
|
||||
elif self.aggr == 'mult':
|
||||
p1 = out_1 * cell_2
|
||||
p2 = out_2 * cell_1
|
||||
elif self.aggr == 'sum':
|
||||
p1 = out_1 + cell_2
|
||||
p2 = out_2 + cell_1
|
||||
elif self.aggr == 'concat':
|
||||
p1 = torch.cat([out_1, cell_2], 1)
|
||||
p2 = torch.cat([out_2, cell_1], 1)
|
||||
else: raise ValueError
|
||||
p1 = self.act(p1)
|
||||
p1 = self.ln_1(p1)
|
||||
p2 = self.act(p2)
|
||||
p2 = self.ln_2(p2)
|
||||
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
|
||||
m1 = self.m1(p1).mean(1)
|
||||
m2 = self.m2(p2).mean(1)
|
||||
m12 = self.m12(p1).mean(1)
|
||||
m21 = self.m21(p2).mean(1)
|
||||
mc = self.mc(p1*p2).mean(1)
|
||||
if self.aggr == 'concat':
|
||||
m1 = self.m1(p1).mean(1)
|
||||
m2 = self.m2(p2).mean(1)
|
||||
m12 = self.m12(p1).mean(1)
|
||||
m21 = self.m21(p2).mean(1)
|
||||
mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
|
||||
img_tracker = torch.ones(3, 5, 3, 128, 128)
|
||||
img_battery = torch.ones(3, 5, 3, 128, 128)
|
||||
pose1 = torch.ones(3, 5, 26, 3)
|
||||
pose2 = torch.ones(3, 5, 26, 3)
|
||||
bbox = torch.ones(3, 5, 13, 4)
|
||||
tracker_id = 'skele1'
|
||||
gaze = torch.ones(3, 5, 2)
|
||||
|
||||
for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
|
||||
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg)
|
||||
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
|
||||
print(agg, out[0].shape)
|
112
tbd/models/sl.py
Normal file
112
tbd/models/sl.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from .base import CNN, MindNetSL
|
||||
|
||||
|
||||
class SLToMnet(nn.Module):
|
||||
"""
|
||||
Speaker-Listener ToMnet
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
|
||||
super(SLToMnet, self).__init__()
|
||||
|
||||
self.tom_weight = tom_weight
|
||||
self.mods = mods
|
||||
|
||||
# ---- 3rd POV Images, object and bbox ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(4, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_1 = MindNetSL(hidden_dim, dropout, mods=mods)
|
||||
self.mind_net_2 = MindNetSL(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
|
||||
self.m1 = nn.Linear(hidden_dim, 4)
|
||||
self.m2 = nn.Linear(hidden_dim, 4)
|
||||
self.m12 = nn.Linear(hidden_dim, 4)
|
||||
self.m21 = nn.Linear(hidden_dim, 4)
|
||||
self.mc = nn.Linear(hidden_dim, 4)
|
||||
|
||||
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
|
||||
|
||||
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
|
||||
|
||||
if 'bbox' in self.mods:
|
||||
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
|
||||
else:
|
||||
bbox_feat = None
|
||||
|
||||
if 'rgb_3' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = img_3rd_pov[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
else:
|
||||
rgb_feat_3rd_pov = None
|
||||
|
||||
if tracker_id == 'skele1':
|
||||
out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
|
||||
out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
|
||||
else:
|
||||
out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
|
||||
out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
|
||||
|
||||
m1_logits = self.m1(out_1).mean(1)
|
||||
m2_logits = self.m2(out_2).mean(1)
|
||||
m12_logits = self.m12(out_1).mean(1)
|
||||
m21_logits = self.m21(out_2).mean(1)
|
||||
mc_logits = self.mc(out_1*out_2).mean(1)
|
||||
|
||||
m1_ranking = torch.log_softmax(m1_logits, dim=-1)
|
||||
m2_ranking = torch.log_softmax(m2_logits, dim=-1)
|
||||
m12_ranking = torch.log_softmax(m12_logits, dim=-1)
|
||||
m21_ranking = torch.log_softmax(m21_logits, dim=-1)
|
||||
mc_ranking = torch.log_softmax(mc_logits, dim=-1)
|
||||
|
||||
# NOTE: does this make sense?
|
||||
m1 = m1_ranking + self.tom_weight * m2_ranking
|
||||
m2 = m2_ranking + self.tom_weight * m1_ranking
|
||||
m12 = m12_ranking + self.tom_weight * m21_ranking
|
||||
m21 = m21_ranking + self.tom_weight * m12_ranking
|
||||
mc = mc_ranking + self.tom_weight * mc_ranking
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
|
||||
img_tracker = torch.ones(3, 5, 3, 128, 128)
|
||||
img_battery = torch.ones(3, 5, 3, 128, 128)
|
||||
pose1 = torch.ones(3, 5, 26, 3)
|
||||
pose2 = torch.ones(3, 5, 26, 3)
|
||||
bbox = torch.ones(3, 5, 13, 4)
|
||||
tracker_id = 'skele1'
|
||||
gaze = torch.ones(3, 5, 2)
|
||||
|
||||
model = SLToMnet(hidden_dim=64, device='cpu', tom_weight=2.0, resnet=False, dropout=0.5)
|
||||
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
|
||||
print(out[0].shape)
|
112
tbd/models/tom_base.py
Normal file
112
tbd/models/tom_base.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from .base import CNN, MindNetLSTM
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ImplicitToMnet(nn.Module):
|
||||
"""
|
||||
Implicit ToM net. Supports any subset of modalities
|
||||
Possible aggregations: sum, mult, attn, concat
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
|
||||
super(ImplicitToMnet, self).__init__()
|
||||
|
||||
self.mods = mods
|
||||
|
||||
# ---- 3rd POV Images, object and bbox ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(4, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
|
||||
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
|
||||
|
||||
self.m1 = nn.Linear(hidden_dim, 4)
|
||||
self.m2 = nn.Linear(hidden_dim, 4)
|
||||
self.m12 = nn.Linear(hidden_dim, 4)
|
||||
self.m21 = nn.Linear(hidden_dim, 4)
|
||||
self.mc = nn.Linear(hidden_dim, 4)
|
||||
|
||||
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
|
||||
|
||||
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
|
||||
|
||||
if 'bbox' in self.mods:
|
||||
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
|
||||
else:
|
||||
bbox_feat = None
|
||||
|
||||
if 'rgb_3' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = img_3rd_pov[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
else:
|
||||
rgb_feat_3rd_pov = None
|
||||
|
||||
if tracker_id == 'skele1':
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
|
||||
else:
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
|
||||
|
||||
if self.aggr == 'no_tom':
|
||||
m1 = self.m1(out_1).mean(1)
|
||||
m2 = self.m2(out_2).mean(1)
|
||||
m12 = self.m12(out_1).mean(1)
|
||||
m21 = self.m21(out_2).mean(1)
|
||||
mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
|
||||
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
return sum([np.prod(p.size()) for p in model_parameters])
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
|
||||
img_tracker = torch.ones(3, 5, 3, 128, 128)
|
||||
img_battery = torch.ones(3, 5, 3, 128, 128)
|
||||
pose1 = torch.ones(3, 5, 26, 3)
|
||||
pose2 = torch.ones(3, 5, 26, 3)
|
||||
bbox = torch.ones(3, 5, 13, 4)
|
||||
tracker_id = 'skele1'
|
||||
gaze = torch.ones(3, 5, 2)
|
||||
|
||||
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5)
|
||||
print(count_parameters(model))
|
||||
breakpoint()
|
||||
|
||||
for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
|
||||
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5)
|
||||
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
|
||||
print(agg, out[0].shape)
|
7
tbd/models/utils.py
Normal file
7
tbd/models/utils.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
import torch
|
||||
|
||||
|
||||
def pose_edge_index():
|
||||
start = [15, 14, 13, 12, 19, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 3, 4, 5, 6, 8, 8, 4, 20, 21, 21, 22, 24, 22]
|
||||
end = [14, 13, 12, 0, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 4, 20, 20, 21, 22, 24, 23, 25, 24]
|
||||
return torch.tensor([start+end, end+start])
|
186
tbd/results/abl.json
Normal file
186
tbd/results/abl.json
Normal file
|
@ -0,0 +1,186 @@
|
|||
{
|
||||
"all": [
|
||||
{
|
||||
"m1": 0.3803337111550836,
|
||||
"m2": 0.3900899763574355,
|
||||
"m12": 0.4441281276628709,
|
||||
"m21": 0.4818757648120031,
|
||||
"mc": 0.4485177767702456
|
||||
},
|
||||
{
|
||||
"m1": 0.5186066842992191,
|
||||
"m2": 0.521895750052127,
|
||||
"m12": 0.49294626980529677,
|
||||
"m21": 0.4810118034501327,
|
||||
"mc": 0.6097300398369058
|
||||
},
|
||||
{
|
||||
"m1": 0.4965589148309122,
|
||||
"m2": 0.5094894309980568,
|
||||
"m12": 0.4615136302786905,
|
||||
"m21": 0.4554005550423429,
|
||||
"mc": 0.6258118710785031
|
||||
}
|
||||
],
|
||||
"rgb_3_pose_gaze_bbox": [
|
||||
{
|
||||
"m1": 0.3776045061727805,
|
||||
"m2": 0.3996776745150713,
|
||||
"m12": 0.4762772810038159,
|
||||
"m21": 0.48643178296718503,
|
||||
"mc": 0.4575207273412474
|
||||
},
|
||||
{
|
||||
"m1": 0.5176564423560418,
|
||||
"m2": 0.5109344883698214,
|
||||
"m12": 0.4630213122846928,
|
||||
"m21": 0.4826608133674547,
|
||||
"mc": 0.5979415365779003
|
||||
},
|
||||
{
|
||||
"m1": 0.5114692300997931,
|
||||
"m2": 0.5027048375802656,
|
||||
"m12": 0.47527894405588544,
|
||||
"m21": 0.45223985157847546,
|
||||
"mc": 0.6054099305712209
|
||||
}
|
||||
],
|
||||
"rgb_3_pose_gaze": [
|
||||
{
|
||||
"m1": 0.403207421026191,
|
||||
"m2": 0.3833413122398237,
|
||||
"m12": 0.4602455224198077,
|
||||
"m21": 0.47181798537346287,
|
||||
"mc": 0.4603675297898878
|
||||
},
|
||||
{
|
||||
"m1": 0.49484810149311514,
|
||||
"m2": 0.5060275976807422,
|
||||
"m12": 0.4610412452830618,
|
||||
"m21": 0.46869095956564044,
|
||||
"mc": 0.6040674897817755
|
||||
},
|
||||
{
|
||||
"m1": 0.5160598186177866,
|
||||
"m2": 0.5309683014233921,
|
||||
"m12": 0.47227245803060636,
|
||||
"m21": 0.46953974307035984,
|
||||
"mc": 0.6014771460423635
|
||||
}
|
||||
],
|
||||
"rgb_3_pose": [
|
||||
{
|
||||
"m1": 0.4057149181928123,
|
||||
"m2": 0.4002233785689204,
|
||||
"m12": 0.46794813614607333,
|
||||
"m21": 0.4690365183933033,
|
||||
"mc": 0.4591530208921514
|
||||
},
|
||||
{
|
||||
"m1": 0.5362792166212834,
|
||||
"m2": 0.5290656046231254,
|
||||
"m12": 0.4569419683345858,
|
||||
"m21": 0.4530255281497826,
|
||||
"mc": 0.4554252731371068
|
||||
},
|
||||
{
|
||||
"m1": 0.49570625763169085,
|
||||
"m2": 0.5146503967646507,
|
||||
"m12": 0.4567936139893578,
|
||||
"m21": 0.45918214877096325,
|
||||
"mc": 0.5962397441246001
|
||||
}
|
||||
],
|
||||
"rgb_3_gaze": [
|
||||
{
|
||||
"m1": 0.40135106828655215,
|
||||
"m2": 0.38453470155825614,
|
||||
"m12": 0.4989742833725901,
|
||||
"m21": 0.47369273992079175,
|
||||
"mc": 0.48430622854433986
|
||||
},
|
||||
{
|
||||
"m1": 0.508038122818153,
|
||||
"m2": 0.4875748099051746,
|
||||
"m12": 0.46665443622698555,
|
||||
"m21": 0.46635808547742913,
|
||||
"mc": 0.47936993226840163
|
||||
},
|
||||
{
|
||||
"m1": 0.49795853039610977,
|
||||
"m2": 0.5028666890527814,
|
||||
"m12": 0.44176709237564815,
|
||||
"m21": 0.4483898274665582,
|
||||
"mc": 0.5867527750929912
|
||||
}
|
||||
],
|
||||
"rgb_3_bbox": [
|
||||
{
|
||||
"m1": 0.3951383898241492,
|
||||
"m2": 0.3818794542844425,
|
||||
"m12": 0.44108151735270384,
|
||||
"m21": 0.46539754196523303,
|
||||
"mc": 0.43982185797713114
|
||||
},
|
||||
{
|
||||
"m1": 0.5093846655989521,
|
||||
"m2": 0.4923439212866733,
|
||||
"m12": 0.4598003475323884,
|
||||
"m21": 0.47647640659290746,
|
||||
"mc": 0.6349953712994137
|
||||
},
|
||||
{
|
||||
"m1": 0.5325224862402295,
|
||||
"m2": 0.5092319973570975,
|
||||
"m12": 0.4435807136490263,
|
||||
"m21": 0.4576911633624616,
|
||||
"mc": 0.6282064277856357
|
||||
}
|
||||
],
|
||||
"rgb_3_rgb_1": [
|
||||
{
|
||||
"m1": 0.39189391736691903,
|
||||
"m2": 0.3739995635963588,
|
||||
"m12": 0.4792392731637056,
|
||||
"m21": 0.4592726043789752,
|
||||
"mc": 0.4468645255652386
|
||||
},
|
||||
{
|
||||
"m1": 0.4827892482357646,
|
||||
"m2": 0.48042899735042716,
|
||||
"m12": 0.45932653547051094,
|
||||
"m21": 0.48430209616318126,
|
||||
"mc": 0.4506104344435269
|
||||
},
|
||||
{
|
||||
"m1": 0.4820247145474279,
|
||||
"m2": 0.3667553358192628,
|
||||
"m12": 0.44503028688537,
|
||||
"m21": 0.45984906207471654,
|
||||
"mc": 0.465120658971623
|
||||
}
|
||||
],
|
||||
"rgb_3": [
|
||||
{
|
||||
"m1": 0.40725462165126114,
|
||||
"m2": 0.38737351624656846,
|
||||
"m12": 0.46230461548252094,
|
||||
"m21": 0.4829312519709871,
|
||||
"mc": 0.4492175856929955
|
||||
},
|
||||
{
|
||||
"m1": 0.5286274183685061,
|
||||
"m2": 0.5081429492163979,
|
||||
"m12": 0.4610256989472217,
|
||||
"m21": 0.4733487634477733,
|
||||
"mc": 0.4655243312197501
|
||||
},
|
||||
{
|
||||
"m1": 0.5217968210271873,
|
||||
"m2": 0.5103780571157844,
|
||||
"m12": 0.4431266771306429,
|
||||
"m21": 0.48398542131284883,
|
||||
"mc": 0.6122314353959392
|
||||
}
|
||||
]
|
||||
}
|
232
tbd/results/all.json
Normal file
232
tbd/results/all.json
Normal file
|
@ -0,0 +1,232 @@
|
|||
{
|
||||
"cm_concat": [
|
||||
{
|
||||
"m1": 0.38921744471949393,
|
||||
"m2": 0.38557137008494935,
|
||||
"m12": 0.44699534554593756,
|
||||
"m21": 0.4747474437468054,
|
||||
"mc": 0.4918107834016411
|
||||
},
|
||||
{
|
||||
"m1": 0.5402415140026018,
|
||||
"m2": 0.48833721513836786,
|
||||
"m12": 0.4631512445419047,
|
||||
"m21": 0.4740880083492652,
|
||||
"mc": 0.6375070925808958
|
||||
},
|
||||
{
|
||||
"m1": 0.5012543523713172,
|
||||
"m2": 0.5068694866895836,
|
||||
"m12": 0.4451537834591627,
|
||||
"m21": 0.45215784721598673,
|
||||
"mc": 0.6201022576104379
|
||||
}
|
||||
],
|
||||
"cm_sum": [
|
||||
{
|
||||
"m1": 0.39403894801783246,
|
||||
"m2": 0.38541918219411786,
|
||||
"m12": 0.4600376974144952,
|
||||
"m21": 0.471919704007463,
|
||||
"mc": 0.43950812310207055
|
||||
},
|
||||
{
|
||||
"m1": 0.48497621104052574,
|
||||
"m2": 0.5295044689855949,
|
||||
"m12": 0.4502949472343065,
|
||||
"m21": 0.47823492553894387,
|
||||
"mc": 0.6028290833617195
|
||||
},
|
||||
{
|
||||
"m1": 0.503386104373653,
|
||||
"m2": 0.49983127146477085,
|
||||
"m12": 0.46782817568218116,
|
||||
"m21": 0.45484578845116075,
|
||||
"mc": 0.5905749126722909
|
||||
}
|
||||
],
|
||||
"cm_mult": [
|
||||
{
|
||||
"m1": 0.39070820515470606,
|
||||
"m2": 0.3996851353903932,
|
||||
"m12": 0.4455704586852128,
|
||||
"m21": 0.4713517869738811,
|
||||
"mc": 0.4450907029478458
|
||||
},
|
||||
{
|
||||
"m1": 0.5066540697731119,
|
||||
"m2": 0.526507445454099,
|
||||
"m12": 0.462643008560599,
|
||||
"m21": 0.48263054309565334,
|
||||
"mc": 0.6438566476782207
|
||||
},
|
||||
{
|
||||
"m1": 0.48868811674304546,
|
||||
"m2": 0.5074635877653536,
|
||||
"m12": 0.44597405775819876,
|
||||
"m21": 0.45445350963025877,
|
||||
"mc": 0.5884265473527218
|
||||
}
|
||||
],
|
||||
"cm_attn": [
|
||||
{
|
||||
"m1": 0.3949557687114269,
|
||||
"m2": 0.3919385900921811,
|
||||
"m12": 0.4850081112466773,
|
||||
"m21": 0.4849575556679713,
|
||||
"mc": 0.4516870089239762
|
||||
},
|
||||
{
|
||||
"m1": 0.4925989821370256,
|
||||
"m2": 0.49409170532242247,
|
||||
"m12": 0.4664647278240569,
|
||||
"m21": 0.46783863397462533,
|
||||
"mc": 0.6398721139927354
|
||||
},
|
||||
{
|
||||
"m1": 0.4945636568169018,
|
||||
"m2": 0.5049812790749876,
|
||||
"m12": 0.454359577718189,
|
||||
"m21": 0.4712184012093268,
|
||||
"mc": 0.5992735441011302
|
||||
}
|
||||
],
|
||||
"no_tom": [
|
||||
{
|
||||
"m1": 0.2570551317,
|
||||
"m2": 0.375350929686332,
|
||||
"m12": 0.312451988649724,
|
||||
"m21": 0.4631371031641,
|
||||
"mc": 0.457486278214567
|
||||
},
|
||||
{
|
||||
"m1": 0.233046800382043,
|
||||
"m2": 0.522609755931958,
|
||||
"m12": 0.326821758467328,
|
||||
"m21": 0.474338898013257,
|
||||
"mc": 0.604439456291308
|
||||
},
|
||||
{
|
||||
"m1": 0.33774852598382,
|
||||
"m2": 0.520943544364353,
|
||||
"m12": 0.298617214416867,
|
||||
"m21": 0.482175301427192,
|
||||
"mc": 0.634948478570852
|
||||
}
|
||||
],
|
||||
"sl": [
|
||||
{
|
||||
"m1": 0.365205706591741,
|
||||
"m2": 0.255259363011619,
|
||||
"m12": 0.421227579844245,
|
||||
"m21": 0.376143327741882,
|
||||
"mc": 0.45614515353718
|
||||
},
|
||||
{
|
||||
"m1": 0.493046934143676,
|
||||
"m2": 0.331798174804139,
|
||||
"m12": 0.422821548330913,
|
||||
"m21": 0.399768928780549,
|
||||
"mc": 0.450957023549231
|
||||
},
|
||||
{
|
||||
"m1": 0.466266787709392,
|
||||
"m2": 0.350962671130227,
|
||||
"m12": 0.431694150269919,
|
||||
"m21": 0.378863431433258,
|
||||
"mc": 0.470284405744656
|
||||
}
|
||||
],
|
||||
"impl_concat": [
|
||||
{
|
||||
"m1": 0.38427302094644894,
|
||||
"m2": 0.38673879043767634,
|
||||
"m12": 0.45694337561663145,
|
||||
"m21": 0.4737891562722213,
|
||||
"mc": 0.4502976351448088
|
||||
},
|
||||
{
|
||||
"m1": 0.49951068243751173,
|
||||
"m2": 0.5084945752383908,
|
||||
"m12": 0.4604721097809549,
|
||||
"m21": 0.4826884970930907,
|
||||
"mc": 0.6200443272625361
|
||||
},
|
||||
{
|
||||
"m1": 0.5013244243339088,
|
||||
"m2": 0.49476495726495723,
|
||||
"m12": 0.4596701406290429,
|
||||
"m21": 0.4554742441542813,
|
||||
"mc": 0.5988949378402535
|
||||
}
|
||||
],
|
||||
"impl_sum": [
|
||||
{
|
||||
"m1": 0.3803337111550836,
|
||||
"m2": 0.3900899763574355,
|
||||
"m12": 0.4441281276628709,
|
||||
"m21": 0.4818757648120031,
|
||||
"mc": 0.4485177767702456
|
||||
},
|
||||
{
|
||||
"m1": 0.5186066842992191,
|
||||
"m2": 0.521895750052127,
|
||||
"m12": 0.49294626980529677,
|
||||
"m21": 0.4810118034501327,
|
||||
"mc": 0.6097300398369058
|
||||
},
|
||||
{
|
||||
"m1": 0.4965589148309122,
|
||||
"m2": 0.5094894309980568,
|
||||
"m12": 0.4615136302786905,
|
||||
"m21": 0.4554005550423429,
|
||||
"mc": 0.6258118710785031
|
||||
}
|
||||
],
|
||||
"impl_mult": [
|
||||
{
|
||||
"m1": 0.3789421413006731,
|
||||
"m2": 0.3818053844554785,
|
||||
"m12": 0.46402717346945177,
|
||||
"m21": 0.4903726261039529,
|
||||
"mc": 0.4461443806398687
|
||||
},
|
||||
{
|
||||
"m1": 0.3789421413006731,
|
||||
"m2": 0.3818053844554785,
|
||||
"m12": 0.46402717346945177,
|
||||
"m21": 0.4903726261039529,
|
||||
"mc": 0.4461443806398687
|
||||
},
|
||||
{
|
||||
"m1": 0.49338554196342077,
|
||||
"m2": 0.5066817652688608,
|
||||
"m12": 0.46253374461930613,
|
||||
"m21": 0.47782311190445825,
|
||||
"mc": 0.4581608719646799
|
||||
}
|
||||
],
|
||||
"impl_attn": [
|
||||
{
|
||||
"m1": 0.37413691393147924,
|
||||
"m2": 0.2546966838007244,
|
||||
"m12": 0.429390512693598,
|
||||
"m21": 0.292401773870023,
|
||||
"mc": 0.45706325836224465
|
||||
},
|
||||
{
|
||||
"m1": 0.513917904196177,
|
||||
"m2": 0.25802580258025803,
|
||||
"m12": 0.49272662664765543,
|
||||
"m21": 0.27041556176385584,
|
||||
"mc": 0.6041394755857196
|
||||
},
|
||||
{
|
||||
"m1": 0.47720445038981674,
|
||||
"m2": 0.25839328537170264,
|
||||
"m12": 0.46505055463781547,
|
||||
"m21": 0.260276985433943,
|
||||
"mc": 0.6021811271770562
|
||||
}
|
||||
]
|
||||
}
|
BIN
tbd/results/false_belief_first_vs_second.pdf
Normal file
BIN
tbd/results/false_belief_first_vs_second.pdf
Normal file
Binary file not shown.
87
tbd/results/fb_ttest.txt
Normal file
87
tbd/results/fb_ttest.txt
Normal file
|
@ -0,0 +1,87 @@
|
|||
|
||||
========================================================= m1_m2_m12_m21
|
||||
|
||||
|
||||
Model: Base -> yes
|
||||
|
||||
|
||||
Model: DB -> yes
|
||||
|
||||
|
||||
Model: CG$\oplus$ -> no
|
||||
|
||||
|
||||
Model: CG$\otimes$ -> no
|
||||
|
||||
|
||||
Model: CG$\odot$ -> no
|
||||
|
||||
|
||||
Model: IC$\parallel$ -> no
|
||||
|
||||
|
||||
Model: IC$\oplus$ -> no
|
||||
|
||||
|
||||
Model: IC$\otimes$ -> no
|
||||
|
||||
|
||||
Model: IC$\odot$ -> yes
|
||||
|
||||
========================================================= m1_m2
|
||||
|
||||
|
||||
Model: Base -> yes
|
||||
|
||||
|
||||
Model: DB -> yes
|
||||
|
||||
|
||||
Model: CG$\oplus$ -> no
|
||||
|
||||
|
||||
Model: CG$\otimes$ -> no
|
||||
|
||||
|
||||
Model: CG$\odot$ -> no
|
||||
|
||||
|
||||
Model: IC$\parallel$ -> no
|
||||
|
||||
|
||||
Model: IC$\oplus$ -> no
|
||||
|
||||
|
||||
Model: IC$\otimes$ -> no
|
||||
|
||||
|
||||
Model: IC$\odot$ -> yes
|
||||
|
||||
========================================================= m12_m21
|
||||
|
||||
|
||||
Model: Base -> yes
|
||||
|
||||
|
||||
Model: DB -> yes
|
||||
|
||||
|
||||
Model: CG$\oplus$ -> no
|
||||
|
||||
|
||||
Model: CG$\otimes$ -> no
|
||||
|
||||
|
||||
Model: CG$\odot$ -> no
|
||||
|
||||
|
||||
Model: IC$\parallel$ -> no
|
||||
|
||||
|
||||
Model: IC$\oplus$ -> no
|
||||
|
||||
|
||||
Model: IC$\otimes$ -> no
|
||||
|
||||
|
||||
Model: IC$\odot$ -> yes
|
59
tbd/results/hgm_scores.txt
Normal file
59
tbd/results/hgm_scores.txt
Normal file
|
@ -0,0 +1,59 @@
|
|||
mc =====================================================================
|
||||
precision recall f1-score support
|
||||
|
||||
0 0.000 0.500 0.001 50
|
||||
1 0.000 0.000 0.000 4
|
||||
2 0.004 0.038 0.007 238
|
||||
3 0.999 0.795 0.885 290788
|
||||
|
||||
accuracy 0.794 291080
|
||||
macro avg 0.251 0.333 0.223 291080
|
||||
weighted avg 0.998 0.794 0.884 291080
|
||||
|
||||
m1 =====================================================================
|
||||
precision recall f1-score support
|
||||
|
||||
0 0.000 0.000 0.000 147
|
||||
1 0.000 0.000 0.000 2
|
||||
2 0.025 0.051 0.033 1714
|
||||
3 0.994 0.988 0.991 289217
|
||||
|
||||
accuracy 0.982 291080
|
||||
macro avg 0.255 0.260 0.256 291080
|
||||
weighted avg 0.988 0.982 0.985 291080
|
||||
|
||||
m2 =====================================================================
|
||||
precision recall f1-score support
|
||||
|
||||
0 0.001 0.013 0.001 151
|
||||
2 0.031 0.084 0.045 2394
|
||||
3 0.992 0.970 0.981 288535
|
||||
|
||||
accuracy 0.962 291080
|
||||
macro avg 0.341 0.355 0.342 291080
|
||||
weighted avg 0.983 0.962 0.972 291080
|
||||
|
||||
m12 =====================================================================
|
||||
precision recall f1-score support
|
||||
|
||||
0 0.000 0.000 0.000 93
|
||||
1 0.000 0.000 0.000 8
|
||||
2 0.015 0.056 0.023 676
|
||||
3 0.997 0.990 0.994 290303
|
||||
|
||||
accuracy 0.988 291080
|
||||
macro avg 0.253 0.262 0.254 291080
|
||||
weighted avg 0.995 0.988 0.991 291080
|
||||
|
||||
m21 =====================================================================
|
||||
precision recall f1-score support
|
||||
|
||||
0 0.002 0.012 0.003 86
|
||||
1 0.000 0.000 0.000 12
|
||||
2 0.010 0.040 0.016 658
|
||||
3 0.997 0.989 0.993 290324
|
||||
|
||||
accuracy 0.987 291080
|
||||
macro avg 0.252 0.260 0.253 291080
|
||||
weighted avg 0.995 0.987 0.991 291080
|
||||
|
BIN
tbd/results/tbd_abl_avg_only.pdf
Normal file
BIN
tbd/results/tbd_abl_avg_only.pdf
Normal file
Binary file not shown.
12
tbd/run_test.sh
Normal file
12
tbd/run_test.sh
Normal file
|
@ -0,0 +1,12 @@
|
|||
#!/bin/bash
|
||||
|
||||
python -m test \
|
||||
--gpu_id 1 \
|
||||
--seed 1 \
|
||||
--non_blocking \
|
||||
--pin_memory \
|
||||
--model_type tom_cm \
|
||||
--aggr no_tom \
|
||||
--hidden_dim 64 \
|
||||
--batch_size 64 \
|
||||
--load_model_path /PATH/TO/model
|
16
tbd/run_train.sh
Normal file
16
tbd/run_train.sh
Normal file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
|
||||
|
||||
python -m train \
|
||||
--gpu_id 2 \
|
||||
--seed 123 \
|
||||
--logger \
|
||||
--non_blocking \
|
||||
--pin_memory \
|
||||
--batch_size 64 \
|
||||
--num_workers 16 \
|
||||
--num_epoch 300 \
|
||||
--lr 5e-4 \
|
||||
--dropout 0.1 \
|
||||
--model_type tom_cm \
|
||||
--aggr no_tom \
|
||||
--hidden_dim 64
|
568
tbd/tbd_dataloader.py
Normal file
568
tbd/tbd_dataloader.py
Normal file
|
@ -0,0 +1,568 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import pickle
|
||||
import torch
|
||||
import time
|
||||
import glob
|
||||
import random
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import cv2
|
||||
from itertools import product
|
||||
import csv
|
||||
import torchvision.transforms as T
|
||||
|
||||
from utils.helpers import tracker_skeID, CLIPS_IDS_88, ALL_IDS, UNIQUE_OBJ_IDS
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
# Unpack the batch into individual elements
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep = zip(*batch)
|
||||
|
||||
# Determine the maximum number of objects in any batch
|
||||
max_n_obj = max(bbox.shape[1] for bbox in bboxes)
|
||||
|
||||
# Pad the bounding box tensors
|
||||
bboxes_pad = []
|
||||
for bbox in bboxes:
|
||||
pad_size = max_n_obj - bbox.shape[1]
|
||||
pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
|
||||
padded_bbox = torch.cat((bbox, pad), dim=1)
|
||||
bboxes_pad.append(padded_bbox)
|
||||
|
||||
# Stack the padded tensors into a batch tensor
|
||||
bboxes_batch = torch.stack(bboxes_pad, dim=0)
|
||||
|
||||
img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
|
||||
img_tracker = torch.stack(img_tracker, dim=0)
|
||||
img_battery = torch.stack(img_battery, dim=0)
|
||||
pose1 = torch.stack(pose1, dim=0)
|
||||
pose2 = torch.stack(pose2, dim=0)
|
||||
gaze = torch.stack(gaze, dim=0)
|
||||
labels = torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
# Return the batched tensors
|
||||
return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep
|
||||
|
||||
def collate_fn_test(batch):
|
||||
# Unpack the batch into individual elements
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep, false_beliefs = zip(*batch)
|
||||
|
||||
# Determine the maximum number of objects in any batch
|
||||
max_n_obj = max(bbox.shape[1] for bbox in bboxes)
|
||||
|
||||
# Pad the bounding box tensors
|
||||
bboxes_pad = []
|
||||
for bbox in bboxes:
|
||||
pad_size = max_n_obj - bbox.shape[1]
|
||||
pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
|
||||
padded_bbox = torch.cat((bbox, pad), dim=1)
|
||||
bboxes_pad.append(padded_bbox)
|
||||
|
||||
# Stack the padded tensors into a batch tensor
|
||||
bboxes_batch = torch.stack(bboxes_pad, dim=0)
|
||||
|
||||
img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
|
||||
img_tracker = torch.stack(img_tracker, dim=0)
|
||||
img_battery = torch.stack(img_battery, dim=0)
|
||||
pose1 = torch.stack(pose1, dim=0)
|
||||
pose2 = torch.stack(pose2, dim=0)
|
||||
gaze = torch.stack(gaze, dim=0)
|
||||
labels = torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
# Return the batched tensors
|
||||
return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep, false_beliefs
|
||||
|
||||
|
||||
class TBDDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str = "/scratch/bortoletto/data/tbd",
|
||||
mode: str = "train",
|
||||
tbd_data_path: str = "/scratch/bortoletto/data/tbd/mind_lstm_training_cnn_att/",
|
||||
list_of_ids_to_consider: list = ALL_IDS,
|
||||
use_preprocessed_img: bool = True,
|
||||
resize_img: Optional[Union[tuple, int]] = (128,128),
|
||||
):
|
||||
"""TBD Dataset based on the 88 clip version of the TBD data.
|
||||
|
||||
Expects the following folder structure:
|
||||
- path
|
||||
- tracker_gt_smooth <- These are eye tracking from POV, 2D coordinates
|
||||
- images/*/ <- These are the images by experiment id
|
||||
- battery <- These are the images, 1st Person
|
||||
- tracker <- These are the images, 1st other Person w/ eye fixation
|
||||
- kinect <- These are the images, 3rd Person
|
||||
- skeleton <- Pose estimation, 3D coordinates
|
||||
- annotation <- These are the labels, i.e. [0,3] (see below)
|
||||
|
||||
|
||||
Labels are strcutured as follows:
|
||||
{
|
||||
"O1": [ <- Object with id O1
|
||||
{
|
||||
"m1": {
|
||||
"fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange
|
||||
"loc": null
|
||||
},
|
||||
"m2": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"m12": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"m21": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"mc": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"mg": {
|
||||
"fluent": 3,
|
||||
"loc": [
|
||||
22,
|
||||
9
|
||||
]
|
||||
}
|
||||
}, ...
|
||||
], ...
|
||||
}
|
||||
|
||||
This corresponds to a strict subset of the raw dataset collected
|
||||
by the TBD people in their paper "Learning Traidic Belief Dynamics
|
||||
in Nonverbal Communication from Videos" (CVPR2021, Oral).
|
||||
|
||||
We keep small amounts of data in memory (everything <100MB).
|
||||
Otherwise we read from disk on the fly. This dataset applies normalization.
|
||||
|
||||
Args:
|
||||
path (str, optional): Where the folders lie.
|
||||
Defaults to "/scratch/ruhdorfer/triadic_beleif_data_v2".
|
||||
list_of_ids_to_consider (list, optional): List of ids to consider.
|
||||
Defaults to ALL_IDS. Otherwise specify a list,
|
||||
e.g. ["test_94342_23", "test_boelter_21", ...].
|
||||
resize_img (Optional[Union[tuple, int]], optional): Resize image to
|
||||
this size if required. Defaults to None.
|
||||
"""
|
||||
print(f"Loading TBD Dataset in mode {mode}...")
|
||||
|
||||
self.mode = mode
|
||||
|
||||
start = time.time()
|
||||
|
||||
self.skeleton_3D_path = f"{path}/skeleton"
|
||||
self.tracker_2D_path = f"{path}/tracker_gt_smooth"
|
||||
self.bbox_csv_path = f"{path}/annotations_with_bbox.csv"
|
||||
if use_preprocessed_img:
|
||||
self.img_path = f"{path}/images_norm"
|
||||
else:
|
||||
self.img_path = f"{path}/images"
|
||||
self.obj_ids_path = f"{path}/mind_lstm_training_cnn_att_shu.pkl"
|
||||
|
||||
self.label_map = list(product([0, 1, 2, 3], repeat=5))
|
||||
|
||||
clips = os.listdir(tbd_data_path)
|
||||
data = []
|
||||
labels = []
|
||||
for clip in clips:
|
||||
with open(tbd_data_path + clip, 'rb') as f:
|
||||
vec_input, label_ = pickle.load(f, encoding='latin1')
|
||||
data = data + vec_input
|
||||
labels = labels + label_
|
||||
c = list(zip(data, labels))
|
||||
random.shuffle(c)
|
||||
train_ratio = int(len(c) * 0.6)
|
||||
validate_ratio = int(len(c) * 0.2)
|
||||
data, label = zip(*c)
|
||||
train_x, train_y = data[:train_ratio], label[:train_ratio]
|
||||
validate_x, validate_y = data[train_ratio:train_ratio + validate_ratio], label[train_ratio:train_ratio + validate_ratio]
|
||||
test_x, test_y = data[train_ratio + validate_ratio:], label[train_ratio + validate_ratio:]
|
||||
self.mind_count = np.zeros(1024) # used for CE weights
|
||||
|
||||
if mode == "train":
|
||||
self.data, self.labels = train_x, train_y
|
||||
elif mode == "val":
|
||||
self.data, self.labels = validate_x, validate_y
|
||||
elif mode == "test":
|
||||
self.data, self.labels = test_x, test_y
|
||||
|
||||
self.false_beliefs_path = f"{path}/store_mind_set"
|
||||
|
||||
# keep small amouts of data in memory
|
||||
self.skeleton_3D = self.load_skeleton_3D(self.skeleton_3D_path, list_of_ids_to_consider)
|
||||
self.tracker_2D = self.load_tracker_2D(self.tracker_2D_path, list_of_ids_to_consider)
|
||||
self.bbox_df = pd.read_csv(self.bbox_csv_path, header=0)
|
||||
self.obj_ids = self.load_obj_ids(self.obj_ids_path)
|
||||
|
||||
if not use_preprocessed_img:
|
||||
normalisation_steps = [
|
||||
T.ToTensor(),
|
||||
T.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
]
|
||||
if resize_img is not None:
|
||||
normalisation_steps.insert(1, T.Resize(resize_img))
|
||||
self.preprocess_img = T.Compose(normalisation_steps)
|
||||
else:
|
||||
self.preprocess_img = None
|
||||
|
||||
self.use_preprocessed_img = use_preprocessed_img
|
||||
print(f"Done loading in {time.time() - start}s.")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(
|
||||
self, idx: int
|
||||
) -> tuple[
|
||||
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict, str, int, str
|
||||
]:
|
||||
"""Given an index, return the corresponding experiment_id and timestep in the experiment.
|
||||
Then picky the appropriate data and labels from these.
|
||||
|
||||
Args:
|
||||
idx (int): _description_
|
||||
|
||||
Returns:
|
||||
tuple: torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict
|
||||
Returns the following:
|
||||
- img_kinect: torch.Tensor of shape (T, C, H, W) (Default is [T, 3, 720, 1280])
|
||||
- img_tracker: torch.Tensor of shape (T, H, W, C)
|
||||
- img_battery: torch.Tensor of shape (T, H, W, C)
|
||||
- skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 1)
|
||||
- skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 2)
|
||||
- bbox: torch.Tensor of shape (T, num_obj, 5)
|
||||
- tracker to skeleton ID: str (either skeleton 1 or 2)
|
||||
- tracker_2D: torch.Tensor of shape (T, 2)
|
||||
- labels: dict (see below)
|
||||
"""
|
||||
labels = self.label_map[self.labels[idx]]
|
||||
experiment_id = self.data[idx][1][0].split('/')[6]
|
||||
img_data_path = f"{self.img_path}/{experiment_id}"
|
||||
frame_ids = [int(os.path.basename(self.data[idx][1][i]).split('_')[0]) for i in range(len(self.data[idx][1]))]
|
||||
|
||||
if self.use_preprocessed_img:
|
||||
kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.pt")))
|
||||
tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.pt")))
|
||||
battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.pt")))
|
||||
kinect_img_paths = [kinect[id] for id in frame_ids]
|
||||
tracker_img_paths = [tracker[id] for id in frame_ids]
|
||||
battery_img_paths = [battery[id] for id in frame_ids]
|
||||
else:
|
||||
kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.jpg")))
|
||||
tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.jpg")))
|
||||
battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.jpg")))
|
||||
kinect_img_paths = [kinect[id] for id in frame_ids]
|
||||
tracker_img_paths = [tracker[id] for id in frame_ids]
|
||||
battery_img_paths = [battery[id] for id in frame_ids]
|
||||
|
||||
# load images
|
||||
kinect_imgs = [
|
||||
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
|
||||
for img_path in kinect_img_paths
|
||||
]
|
||||
kinect_imgs = torch.stack(kinect_imgs, axis=0)
|
||||
|
||||
tracker_imgs = [
|
||||
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
|
||||
for img_path in tracker_img_paths
|
||||
]
|
||||
tracker_imgs = torch.stack(tracker_imgs, axis=0)
|
||||
|
||||
battery_imgs = [
|
||||
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
|
||||
for img_path in battery_img_paths
|
||||
]
|
||||
battery_imgs = torch.stack(battery_imgs, axis=0)
|
||||
|
||||
# load object id to check for false beliefs - only for testing
|
||||
if self.mode == "test": #or self.mode == "train":
|
||||
if f"{experiment_id}.txt" in os.listdir(self.false_beliefs_path):
|
||||
obj_id = self.obj_ids[experiment_id][frame_ids[-1]]
|
||||
obj_id = next(x for x in obj_id if x is not None)
|
||||
false_belief = next((line.strip().split(',')[2] for line in open(f"{self.false_beliefs_path}/{experiment_id}.txt") if line.startswith(str(frame_ids[-1]) + ',' + obj_id + ',')), "no")
|
||||
#if experiment_id in ['test_boelter4_0', 'test_boelter4_7', 'test_boelter4_6', 'test_boelter4_8', 'test_boelter2_3',
|
||||
# 'test_94342_20', 'test_94342_18', 'test_94342_11', 'test_94342_17', 'test_boelter3_8', 'test_94342_2',
|
||||
# 'test_boelter2_17', 'test_boelter3_7', 'test_94342_4', 'test_boelter3_9', 'test_boelter_10',
|
||||
# 'test_boelter2_6', 'test_boelter4_10', 'test_boelter4_2', 'test_boelter4_5', 'test_94342_24',
|
||||
# 'test_94342_15', 'test_boelter3_5', 'test_94342_8', 'test2', 'test_boelter3_12']:
|
||||
# print('here!')
|
||||
# with open(os.path.join(f'results/hgm_test_fb.csv'), mode='a') as file:
|
||||
# writer = csv.writer(file)
|
||||
# writer.writerow([experiment_id, obj_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
|
||||
else:
|
||||
false_belief = "no"
|
||||
#with open(os.path.join(f'results/test_fb.csv'), mode='a') as file:
|
||||
# writer = csv.writer(file)
|
||||
# writer.writerow([experiment_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
|
||||
|
||||
df = self.bbox_df[
|
||||
(self.bbox_df.experiment_name == experiment_id)
|
||||
#& (self.bbox_df.name == obj_id) # NOTE: load the bounding boxes for all objects
|
||||
& (self.bbox_df.name != 'P1')
|
||||
& (self.bbox_df.name != 'P2')
|
||||
& (self.bbox_df.frame.isin(frame_ids))
|
||||
]
|
||||
|
||||
bboxes = []
|
||||
for f in frame_ids:
|
||||
bbox = torch.tensor(df.loc[df['frame'] == f, ["x_min", "y_min", "x_max", "y_max"]].to_numpy(), dtype=torch.float32)
|
||||
bbox[:, 0] = bbox[:, 0] / 1280.0
|
||||
bbox[:, 1] = bbox[:, 1] / 720.0
|
||||
bbox[:, 2] = bbox[:, 2] / 1280.0
|
||||
bbox[:, 3] = bbox[:, 3] / 720.0
|
||||
bboxes.append(bbox)
|
||||
bboxes = torch.stack(bboxes) # NOTE: this will need a collate function bc not every video has the same number of objects
|
||||
|
||||
skele1 = self.skeleton_3D[experiment_id]["skele1"][frame_ids]
|
||||
skele2 = self.skeleton_3D[experiment_id]["skele2"][frame_ids]
|
||||
|
||||
gaze = self.tracker_2D[experiment_id][frame_ids]
|
||||
|
||||
if self.mode == "test":
|
||||
return (
|
||||
kinect_imgs,
|
||||
tracker_imgs,
|
||||
battery_imgs,
|
||||
skele1,
|
||||
skele2,
|
||||
bboxes,
|
||||
tracker_skeID[experiment_id], # <- This is the tracker skeleton ID
|
||||
gaze,
|
||||
labels, # <- per object "m1", "m2", "m12", "m21", "mc"
|
||||
experiment_id,
|
||||
frame_ids,
|
||||
#self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding
|
||||
false_belief
|
||||
)
|
||||
else:
|
||||
return (
|
||||
kinect_imgs,
|
||||
tracker_imgs,
|
||||
battery_imgs,
|
||||
skele1,
|
||||
skele2,
|
||||
bboxes,
|
||||
tracker_skeID[experiment_id], # <- This is the tracker skeleton ID
|
||||
gaze,
|
||||
labels, # <- per object "m1", "m2", "m12", "m21", "mc"
|
||||
experiment_id,
|
||||
frame_ids
|
||||
#self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding
|
||||
)
|
||||
|
||||
def onehot(self, x, n=len(UNIQUE_OBJ_IDS)):
|
||||
retval = torch.zeros(n)
|
||||
if x > 0:
|
||||
retval[x-1] = 1
|
||||
return retval
|
||||
|
||||
def load_obj_ids(self, path: str):
|
||||
with open(path, "rb") as f:
|
||||
ids = pickle.load(f)
|
||||
return ids
|
||||
|
||||
def extract_labels(self):
|
||||
"""TODO: Converts index label to [m1, m2, m12, m21, mc] format.
|
||||
|
||||
"""
|
||||
return
|
||||
|
||||
def _flatten_mind_obj_timestep(self, mind_obj_dict: dict) -> list:
|
||||
"""Flattens the mind object dict to a list. I.e. takes
|
||||
|
||||
{
|
||||
"m1": {
|
||||
"fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange
|
||||
"loc": null
|
||||
},
|
||||
"m2": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"m12": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"m21": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"mc": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"mg": {
|
||||
"fluent": 3,
|
||||
"loc": [
|
||||
22,
|
||||
9
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
and returns [3, 3, 3, 3, 3, 3]
|
||||
|
||||
Args:
|
||||
mind_obj_dict (dict): Mind object dict as described in __init__.doctstring.
|
||||
|
||||
Returns:
|
||||
list: List of mind object labels.
|
||||
"""
|
||||
return np.array([mind_obj["fluent"] for key, mind_obj in mind_obj_dict.items() if key != "mg"])
|
||||
|
||||
def load_skeleton_3D(self, path: str, list_of_ids_to_consider: list):
|
||||
"""Load skeleton 3D data from disk.
|
||||
|
||||
- path
|
||||
- * <- list of ids
|
||||
- skele1.p <- 3D coord per id and timestep
|
||||
- skele2.p <-
|
||||
|
||||
Args:
|
||||
path (str): Where the skeleton 3D data lie.
|
||||
list_of_ids_to_consider (list): List of ids to consider.
|
||||
Defaults to None which means all ids. Otherwise specify a list,
|
||||
e.g. ["test_94342_23", "test_boelter_21", ...].
|
||||
|
||||
Returns:
|
||||
dict: skeleton 3D data as described above in __init__.doctstring.
|
||||
"""
|
||||
skeleton_3D = {}
|
||||
for experiment_id in list_of_ids_to_consider:
|
||||
skeleton_3D[experiment_id] = {}
|
||||
with open(f"{path}/{experiment_id}/skele1.p", "rb") as f:
|
||||
skeleton_3D[experiment_id]["skele1"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
|
||||
with open(f"{path}/{experiment_id}/skele2.p", "rb") as f:
|
||||
skeleton_3D[experiment_id]["skele2"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
|
||||
return skeleton_3D
|
||||
|
||||
def load_tracker_2D(self, path: str, list_of_ids_to_consider: list):
|
||||
"""Load tracker 2D data from disk.
|
||||
|
||||
- path
|
||||
- *.p <- 2D coord per id and timestep
|
||||
|
||||
Args:
|
||||
path (str): Where the tracker 2D data lie.
|
||||
list_of_ids_to_consider (list): List of ids to consider.
|
||||
Defaults to None which means all ids. Otherwise specify a list,
|
||||
e.g. ["test_94342_23", "test_boelter_21", ...].
|
||||
|
||||
Returns:
|
||||
dict: tracker 2D data.
|
||||
"""
|
||||
tracker_2D = {}
|
||||
for experiment_id in list_of_ids_to_consider:
|
||||
with open(f"{path}/{experiment_id}.p", "rb") as f:
|
||||
tracker_2D[experiment_id] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
|
||||
return tracker_2D
|
||||
|
||||
def load_bbox(self, path: str, list_of_ids_to_consider: list):
|
||||
"""Load bbox data from disk.
|
||||
|
||||
- bbox_tensors.pickle <- bbox per experiment id one tensor
|
||||
|
||||
Args:
|
||||
path (str): Where the bbox data lie.
|
||||
list_of_ids_to_consider (list): List of ids to consider.
|
||||
|
||||
Returns:
|
||||
dict: bbox data.
|
||||
"""
|
||||
with open(path, "rb") as f:
|
||||
pickle_data = pickle.load(f)
|
||||
for key in CLIPS_IDS_88:
|
||||
if key not in list_of_ids_to_consider:
|
||||
pickle_data.pop(key, None)
|
||||
return pickle_data
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# os.environ['PYTHONHASHSEED'] = str(42)
|
||||
# torch.manual_seed(42)
|
||||
# np.random.seed(42)
|
||||
# random.seed(42)
|
||||
|
||||
data = TBDDataset(use_preprocessed_img=True, mode="test")
|
||||
|
||||
from tqdm import tqdm
|
||||
for i in tqdm(range(data.__len__())):
|
||||
data[i]
|
||||
|
||||
breakpoint()
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Just for guessing time
|
||||
data_0=data[0]
|
||||
data_last=data[len(data)-1]
|
||||
idx = np.random.randint(1, len(data)-1) # Something in between.
|
||||
start = time.time()
|
||||
(
|
||||
kinect_imgs, # <- len x 720 x 1280 x 3 originally, likely smaller now
|
||||
tracker_imgs,
|
||||
battery_imgs,
|
||||
skele1,
|
||||
skele2,
|
||||
bbox,
|
||||
tracker_skeID_sample, # <- This is the tracker skeleton ID
|
||||
tracker2d,
|
||||
label,
|
||||
experiment_id, # From here for debugging
|
||||
timestep,
|
||||
#obj_id, # <- This is the object ID as a one-hot
|
||||
false_belief
|
||||
) = data[idx]
|
||||
end = time.time()
|
||||
print(f"Time for one sample: {end-start}")
|
||||
|
||||
print('kinect:', kinect_imgs.shape)
|
||||
print('tracker:', tracker_imgs.shape)
|
||||
print('battery:', battery_imgs.shape)
|
||||
print('skele1:', skele1.shape)
|
||||
print('skele2:', skele2.shape)
|
||||
print('gaze:', tracker2d.shape)
|
||||
print('bbox:', bbox.shape)
|
||||
print('label:', label)
|
||||
|
||||
#breakpoint()
|
||||
|
||||
dl = DataLoader(
|
||||
data,
|
||||
batch_size=4,
|
||||
shuffle=False,
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
for j, batch in tqdm(enumerate(dl)):
|
||||
#print(j, end='\r')
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
|
||||
#breakpoint()
|
||||
#print(img_3rd_pov.shape)
|
||||
#print(img_tracker.shape)
|
||||
#print(img_battery.shape)
|
||||
#print(pose1.shape, pose2.shape)
|
||||
#print(bbox.shape)
|
||||
#print(gaze.shape)
|
||||
|
||||
|
||||
breakpoint()
|
196
tbd/test.py
Normal file
196
tbd/test.py
Normal file
|
@ -0,0 +1,196 @@
|
|||
import torch
|
||||
import csv
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
import random
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from tbd_dataloader import TBDDataset, collate_fn_test
|
||||
from models.common_mind import CommonMindToMnet
|
||||
from models.sl import SLToMnet
|
||||
from models.implicit import ImplicitToMnet
|
||||
from utils.helpers import compute_f1_scores
|
||||
|
||||
|
||||
def test(args):
|
||||
|
||||
test_dataset = TBDDataset(
|
||||
path=args.data_path,
|
||||
mode="test",
|
||||
use_preprocessed_img=True
|
||||
)
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
collate_fn=collate_fn_test
|
||||
)
|
||||
|
||||
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# model
|
||||
if args.model_type == 'tom_cm':
|
||||
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
elif args.model_type == 'tom_sl':
|
||||
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
|
||||
elif args.model_type == 'tom_impl':
|
||||
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
else: raise NotImplementedError
|
||||
|
||||
model.load_state_dict(torch.load(args.load_model_path, map_location=device))
|
||||
model.device = device
|
||||
|
||||
model.eval()
|
||||
|
||||
if args.save_preds:
|
||||
# Define the output file path
|
||||
folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}'
|
||||
if not os.path.exists(folder_path):
|
||||
os.makedirs(folder_path)
|
||||
print(f'Saving predictions in {folder_path}.')
|
||||
|
||||
print('Testing...')
|
||||
m1_pred_list = []
|
||||
m2_pred_list = []
|
||||
m12_pred_list = []
|
||||
m21_pred_list = []
|
||||
mc_pred_list = []
|
||||
m1_label_list = []
|
||||
m2_label_list = []
|
||||
m12_label_list = []
|
||||
m21_label_list = []
|
||||
mc_label_list = []
|
||||
with torch.no_grad():
|
||||
for j, batch in tqdm(enumerate(test_dataloader)):
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep, false_belief = batch
|
||||
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
|
||||
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
|
||||
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
|
||||
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
|
||||
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
|
||||
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
|
||||
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
|
||||
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, repr = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
m1_pred = m1_pred.reshape(-1, 4)
|
||||
m2_pred = m2_pred.reshape(-1, 4)
|
||||
m12_pred = m12_pred.reshape(-1, 4)
|
||||
m21_pred = m21_pred.reshape(-1, 4)
|
||||
mc_pred = mc_pred.reshape(-1, 4)
|
||||
m1_label = labels[:, 0].reshape(-1).to(device)
|
||||
m2_label = labels[:, 1].reshape(-1).to(device)
|
||||
m12_label = labels[:, 2].reshape(-1).to(device)
|
||||
m21_label = labels[:, 3].reshape(-1).to(device)
|
||||
mc_label = labels[:, 4].reshape(-1).to(device)
|
||||
|
||||
m1_pred_list.append(m1_pred)
|
||||
m2_pred_list.append(m2_pred)
|
||||
m12_pred_list.append(m12_pred)
|
||||
m21_pred_list.append(m21_pred)
|
||||
mc_pred_list.append(mc_pred)
|
||||
m1_label_list.append(m1_label)
|
||||
m2_label_list.append(m2_label)
|
||||
m12_label_list.append(m12_label)
|
||||
m21_label_list.append(m21_label)
|
||||
mc_label_list.append(mc_label)
|
||||
|
||||
if args.save_preds:
|
||||
torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt"))
|
||||
data = [(
|
||||
i,
|
||||
torch.argmax(m1_pred[i]).cpu().numpy(),
|
||||
torch.argmax(m2_pred[i]).cpu().numpy(),
|
||||
torch.argmax(m12_pred[i]).cpu().numpy(),
|
||||
torch.argmax(m21_pred[i]).cpu().numpy(),
|
||||
torch.argmax(mc_pred[i]).cpu().numpy(),
|
||||
m1_label[i].cpu().numpy(),
|
||||
m2_label[i].cpu().numpy(),
|
||||
m12_label[i].cpu().numpy(),
|
||||
m21_label[i].cpu().numpy(),
|
||||
mc_label[i].cpu().numpy(),
|
||||
false_belief[i]) for i in range(len(labels))
|
||||
]
|
||||
header = ['frame', 'm1_pred', 'm2_pred', 'm12_pred', 'm21_pred', 'mc_pred', 'm1_label', 'm2_label', 'm12_label', 'm21_label', 'mc_label', 'false_belief']
|
||||
with open(os.path.join(folder_path, f'{j}.csv'), mode='w', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
writer.writerow(header) # Write the header row
|
||||
writer.writerows(data) # Write the data rows
|
||||
|
||||
#np.savetxt('m1_label_bs1.txt', torch.cat(m1_label_list).cpu().numpy())
|
||||
test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1 = compute_f1_scores(
|
||||
torch.cat(m1_pred_list),
|
||||
torch.cat(m1_label_list),
|
||||
torch.cat(m2_pred_list),
|
||||
torch.cat(m2_label_list),
|
||||
torch.cat(m12_pred_list),
|
||||
torch.cat(m12_label_list),
|
||||
torch.cat(m21_pred_list),
|
||||
torch.cat(m21_label_list),
|
||||
torch.cat(mc_pred_list),
|
||||
torch.cat(mc_label_list)
|
||||
)
|
||||
|
||||
print("Test m1 F1: {}".format(test_m1_f1))
|
||||
print("Test m2 F1: {}".format(test_m2_f1))
|
||||
print("Test m12 F1: {}".format(test_m12_f1))
|
||||
print("Test m21 F1: {}".format(test_m21_f1))
|
||||
print("Test mc F1: {}".format(test_mc_f1))
|
||||
|
||||
with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f:
|
||||
f.write(f"Test data:\n {[data[1] for data in test_dataset.data]}")
|
||||
f.write(f"m1 f1: {test_m1_f1}")
|
||||
f.write(f"m2 f1: {test_m2_f1}")
|
||||
f.write(f"m12 f1: {test_m12_f1}")
|
||||
f.write(f"m21 f1: {test_m21_f1}")
|
||||
f.write(f"mc f1: {test_mc_f1}")
|
||||
f.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Define the command-line arguments
|
||||
parser.add_argument('--gpu_id', type=int)
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--presaved', type=int, default=128)
|
||||
parser.add_argument('--non_blocking', action='store_true')
|
||||
parser.add_argument('--num_workers', type=int, default=16)
|
||||
parser.add_argument('--pin_memory', action='store_true')
|
||||
parser.add_argument('--model_type', type=str)
|
||||
parser.add_argument('--batch_size', type=int, default=64)
|
||||
parser.add_argument('--aggr', type=str, default='concat', required=False)
|
||||
parser.add_argument('--use_resnet', action='store_true')
|
||||
parser.add_argument('--hidden_dim', type=int, default=64)
|
||||
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
|
||||
parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
|
||||
parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
|
||||
parser.add_argument('--save_path', type=str, default='experiments/')
|
||||
parser.add_argument('--test_frames', type=str, default=None)
|
||||
parser.add_argument('--median', type=int, default=None)
|
||||
parser.add_argument('--load_model_path', type=str)
|
||||
parser.add_argument('--dropout', type=float, default=0.0)
|
||||
parser.add_argument('--save_preds', action='store_true')
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
|
||||
if not args.aggr:
|
||||
parser.error("The choosen --model_type requires --aggr")
|
||||
if args.model_type == 'tom_sl' and not args.tom_weight:
|
||||
parser.error("The choosen --model_type requires --tom_weight")
|
||||
|
||||
os.environ['PYTHONHASHSEED'] = str(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
random.seed(args.seed)
|
||||
|
||||
print('###########################################################################')
|
||||
print('TESTING: MAKE SURE YOU ARE USING THE SAME RANDOM SEED USED DURING TRAINING!')
|
||||
print('###########################################################################')
|
||||
|
||||
test(args)
|
474
tbd/train.py
Normal file
474
tbd/train.py
Normal file
|
@ -0,0 +1,474 @@
|
|||
import torch
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import random
|
||||
import datetime
|
||||
import wandb
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn as nn
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from tbd_dataloader import TBDDataset, collate_fn
|
||||
from models.common_mind import CommonMindToMnet
|
||||
from models.sl import SLToMnet
|
||||
from models.implicit import ImplicitToMnet
|
||||
from utils.helpers import count_parameters, compute_f1_scores
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
train_dataset = TBDDataset(
|
||||
path=args.data_path,
|
||||
mode="train",
|
||||
use_preprocessed_img=True
|
||||
)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
val_dataset = TBDDataset(
|
||||
path=args.data_path,
|
||||
mode="val",
|
||||
use_preprocessed_img=True
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
|
||||
train_data = [data[1] for data in train_dataset.data]
|
||||
val_data = [data[1] for data in val_dataset.data]
|
||||
if args.logger:
|
||||
wandb.config.update({"train_data": train_data})
|
||||
wandb.config.update({"val_data": val_data})
|
||||
|
||||
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# model
|
||||
if args.model_type == 'tom_cm':
|
||||
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
elif args.model_type == 'tom_sl':
|
||||
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
|
||||
elif args.model_type == 'tom_impl':
|
||||
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
else: raise NotImplementedError
|
||||
if args.resume_from_checkpoint is not None:
|
||||
model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
|
||||
# optimizer
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
# scheduler
|
||||
if args.scheduler == None:
|
||||
scheduler = None
|
||||
else:
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
|
||||
# loss function
|
||||
if args.model_type == 'tom_sl':
|
||||
ce_loss_m1 = nn.NLLLoss()
|
||||
ce_loss_m2 = nn.NLLLoss()
|
||||
ce_loss_m12 = nn.NLLLoss()
|
||||
ce_loss_m21 = nn.NLLLoss()
|
||||
ce_loss_mc = nn.NLLLoss()
|
||||
else:
|
||||
ce_loss_m1 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
ce_loss_m2 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
ce_loss_m12 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
ce_loss_m21 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
ce_loss_mc = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
|
||||
stats = {
|
||||
'train': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []},
|
||||
'val': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []}
|
||||
}
|
||||
max_val_f1 = 0
|
||||
max_val_classification_epoch = None
|
||||
counter = 0
|
||||
|
||||
print(f'Number of parameters: {count_parameters(model)}')
|
||||
|
||||
for i in range(args.num_epoch):
|
||||
# training
|
||||
print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
|
||||
epoch_train_loss_m1 = 0.0
|
||||
epoch_train_loss_m2 = 0.0
|
||||
epoch_train_loss_m12 = 0.0
|
||||
epoch_train_loss_m21 = 0.0
|
||||
epoch_train_loss_mc = 0.0
|
||||
m1_train_batch_pred_list = []
|
||||
m2_train_batch_pred_list = []
|
||||
m12_train_batch_pred_list = []
|
||||
m21_train_batch_pred_list = []
|
||||
mc_train_batch_pred_list = []
|
||||
m1_train_batch_label_list = []
|
||||
m2_train_batch_label_list = []
|
||||
m12_train_batch_label_list = []
|
||||
m21_train_batch_label_list = []
|
||||
mc_train_batch_label_list = []
|
||||
model.train()
|
||||
for j, batch in tqdm(enumerate(train_dataloader)):
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
|
||||
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
|
||||
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
|
||||
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
|
||||
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
|
||||
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
|
||||
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
|
||||
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
|
||||
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
m1_pred = m1_pred.reshape(-1, 4)
|
||||
m2_pred = m2_pred.reshape(-1, 4)
|
||||
m12_pred = m12_pred.reshape(-1, 4)
|
||||
m21_pred = m21_pred.reshape(-1, 4)
|
||||
mc_pred = mc_pred.reshape(-1, 4)
|
||||
m1_label = labels[:, 0].reshape(-1).to(device)
|
||||
m2_label = labels[:, 1].reshape(-1).to(device)
|
||||
m12_label = labels[:, 2].reshape(-1).to(device)
|
||||
m21_label = labels[:, 3].reshape(-1).to(device)
|
||||
mc_label = labels[:, 4].reshape(-1).to(device)
|
||||
|
||||
loss_m1 = ce_loss_m1(m1_pred, m1_label)
|
||||
loss_m2 = ce_loss_m2(m2_pred, m2_label)
|
||||
loss_m12 = ce_loss_m12(m12_pred, m12_label)
|
||||
loss_m21 = ce_loss_m21(m21_pred, m21_label)
|
||||
loss_mc = ce_loss_mc(mc_pred, mc_label)
|
||||
loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc
|
||||
|
||||
epoch_train_loss_m1 += loss_m1.data.item()
|
||||
epoch_train_loss_m2 += loss_m2.data.item()
|
||||
epoch_train_loss_m12 += loss_m12.data.item()
|
||||
epoch_train_loss_m21 += loss_m21.data.item()
|
||||
epoch_train_loss_mc += loss_mc.data.item()
|
||||
|
||||
m1_train_batch_pred_list.append(m1_pred)
|
||||
m2_train_batch_pred_list.append(m2_pred)
|
||||
m12_train_batch_pred_list.append(m12_pred)
|
||||
m21_train_batch_pred_list.append(m21_pred)
|
||||
mc_train_batch_pred_list.append(mc_pred)
|
||||
m1_train_batch_label_list.append(m1_label)
|
||||
m2_train_batch_label_list.append(m2_label)
|
||||
m12_train_batch_label_list.append(m12_label)
|
||||
m21_train_batch_label_list.append(m21_label)
|
||||
mc_train_batch_label_list.append(mc_label)
|
||||
|
||||
optimizer.zero_grad()
|
||||
if args.clip_grad_norm is not None:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if args.logger: wandb.log({
|
||||
'batch_train_loss': loss.data.item(),
|
||||
'lr': optimizer.param_groups[-1]['lr']
|
||||
})
|
||||
|
||||
print("Epoch {}/{} batch {}/{} training done with loss={}".format(
|
||||
i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item())
|
||||
)
|
||||
|
||||
if scheduler: scheduler.step()
|
||||
|
||||
train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score = compute_f1_scores(
|
||||
torch.cat(m1_train_batch_pred_list),
|
||||
torch.cat(m1_train_batch_label_list),
|
||||
torch.cat(m2_train_batch_pred_list),
|
||||
torch.cat(m2_train_batch_label_list),
|
||||
torch.cat(m12_train_batch_pred_list),
|
||||
torch.cat(m12_train_batch_label_list),
|
||||
torch.cat(m21_train_batch_pred_list),
|
||||
torch.cat(m21_train_batch_label_list),
|
||||
torch.cat(mc_train_batch_pred_list),
|
||||
torch.cat(mc_train_batch_label_list)
|
||||
)
|
||||
|
||||
print("Epoch {}/{} OVERALL train m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}.\n".format(
|
||||
i+1,
|
||||
args.num_epoch,
|
||||
epoch_train_loss_m1/len(train_dataloader),
|
||||
epoch_train_loss_m2/len(train_dataloader),
|
||||
epoch_train_loss_m12/len(train_dataloader),
|
||||
epoch_train_loss_m21/len(train_dataloader),
|
||||
epoch_train_loss_mc/len(train_dataloader),
|
||||
train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score
|
||||
)
|
||||
)
|
||||
stats['train']['loss_m1'].append(epoch_train_loss_m1/len(train_dataloader))
|
||||
stats['train']['loss_m2'].append(epoch_train_loss_m2/len(train_dataloader))
|
||||
stats['train']['loss_m12'].append(epoch_train_loss_m12/len(train_dataloader))
|
||||
stats['train']['loss_m21'].append(epoch_train_loss_m21/len(train_dataloader))
|
||||
stats['train']['loss_mc'].append(epoch_train_loss_mc/len(train_dataloader))
|
||||
stats['train']['m1_f1'].append(train_m1_f1_score)
|
||||
stats['train']['m2_f1'].append(train_m2_f1_score)
|
||||
stats['train']['m12_f1'].append(train_m12_f1_score)
|
||||
stats['train']['m21_f1'].append(train_m21_f1_score)
|
||||
stats['train']['mc_f1'].append(train_mc_f1_score)
|
||||
|
||||
if args.logger: wandb.log(
|
||||
{
|
||||
'train_m1_loss': epoch_train_loss_m1/len(train_dataloader),
|
||||
'train_m2_loss': epoch_train_loss_m2/len(train_dataloader),
|
||||
'train_m12_loss': epoch_train_loss_m12/len(train_dataloader),
|
||||
'train_m21_loss': epoch_train_loss_m21/len(train_dataloader),
|
||||
'train_mc_loss': epoch_train_loss_mc/len(train_dataloader),
|
||||
'train_loss': epoch_train_loss_m1/len(train_dataloader) + \
|
||||
epoch_train_loss_m2/len(train_dataloader) + \
|
||||
epoch_train_loss_m12/len(train_dataloader) + \
|
||||
epoch_train_loss_m21/len(train_dataloader) + \
|
||||
epoch_train_loss_mc/len(train_dataloader),
|
||||
'train_m1_f1_score': train_m1_f1_score,
|
||||
'train_m2_f1_score': train_m2_f1_score,
|
||||
'train_m12_f1_score': train_m12_f1_score,
|
||||
'train_m21_f1_score': train_m21_f1_score,
|
||||
'train_mc_f1_score': train_mc_f1_score
|
||||
}
|
||||
)
|
||||
|
||||
# validation
|
||||
print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
|
||||
epoch_val_loss_m1 = 0.0
|
||||
epoch_val_loss_m2 = 0.0
|
||||
epoch_val_loss_m12 = 0.0
|
||||
epoch_val_loss_m21 = 0.0
|
||||
epoch_val_loss_mc = 0.0
|
||||
m1_val_batch_pred_list = []
|
||||
m2_val_batch_pred_list = []
|
||||
m12_val_batch_pred_list = []
|
||||
m21_val_batch_pred_list = []
|
||||
mc_val_batch_pred_list = []
|
||||
m1_val_batch_label_list = []
|
||||
m2_val_batch_label_list = []
|
||||
m12_val_batch_label_list = []
|
||||
m21_val_batch_label_list = []
|
||||
mc_val_batch_label_list = []
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for j, batch in tqdm(enumerate(val_dataloader)):
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
|
||||
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
|
||||
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
|
||||
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
|
||||
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
|
||||
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
|
||||
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
|
||||
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
|
||||
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
m1_pred = m1_pred.reshape(-1, 4)
|
||||
m2_pred = m2_pred.reshape(-1, 4)
|
||||
m12_pred = m12_pred.reshape(-1, 4)
|
||||
m21_pred = m21_pred.reshape(-1, 4)
|
||||
mc_pred = mc_pred.reshape(-1, 4)
|
||||
m1_label = labels[:, 0].reshape(-1).to(device)
|
||||
m2_label = labels[:, 1].reshape(-1).to(device)
|
||||
m12_label = labels[:, 2].reshape(-1).to(device)
|
||||
m21_label = labels[:, 3].reshape(-1).to(device)
|
||||
mc_label = labels[:, 4].reshape(-1).to(device)
|
||||
|
||||
loss_m1 = ce_loss_m1(m1_pred, m1_label)
|
||||
loss_m2 = ce_loss_m2(m2_pred, m2_label)
|
||||
loss_m12 = ce_loss_m12(m12_pred, m12_label)
|
||||
loss_m21 = ce_loss_m21(m21_pred, m21_label)
|
||||
loss_mc = ce_loss_mc(mc_pred, mc_label)
|
||||
loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc
|
||||
|
||||
epoch_val_loss_m1 += loss_m1.data.item()
|
||||
epoch_val_loss_m2 += loss_m2.data.item()
|
||||
epoch_val_loss_m12 += loss_m12.data.item()
|
||||
epoch_val_loss_m21 += loss_m21.data.item()
|
||||
epoch_val_loss_mc += loss_mc.data.item()
|
||||
|
||||
m1_val_batch_pred_list.append(m1_pred)
|
||||
m2_val_batch_pred_list.append(m2_pred)
|
||||
m12_val_batch_pred_list.append(m12_pred)
|
||||
m21_val_batch_pred_list.append(m21_pred)
|
||||
mc_val_batch_pred_list.append(mc_pred)
|
||||
m1_val_batch_label_list.append(m1_label)
|
||||
m2_val_batch_label_list.append(m2_label)
|
||||
m12_val_batch_label_list.append(m12_label)
|
||||
m21_val_batch_label_list.append(m21_label)
|
||||
mc_val_batch_label_list.append(mc_label)
|
||||
|
||||
if args.logger: wandb.log({'batch_val_loss': loss.data.item()})
|
||||
print("Epoch {}/{} batch {}/{} validation done with loss={}".format(
|
||||
i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item())
|
||||
)
|
||||
|
||||
val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score = compute_f1_scores(
|
||||
torch.cat(m1_val_batch_pred_list),
|
||||
torch.cat(m1_val_batch_label_list),
|
||||
torch.cat(m2_val_batch_pred_list),
|
||||
torch.cat(m2_val_batch_label_list),
|
||||
torch.cat(m12_val_batch_pred_list),
|
||||
torch.cat(m12_val_batch_label_list),
|
||||
torch.cat(m21_val_batch_pred_list),
|
||||
torch.cat(m21_val_batch_label_list),
|
||||
torch.cat(mc_val_batch_pred_list),
|
||||
torch.cat(mc_val_batch_label_list)
|
||||
)
|
||||
|
||||
print("Epoch {}/{} OVERALL validation m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format(
|
||||
i+1,
|
||||
args.num_epoch,
|
||||
epoch_val_loss_m1/len(val_dataloader),
|
||||
epoch_val_loss_m2/len(val_dataloader),
|
||||
epoch_val_loss_m12/len(val_dataloader),
|
||||
epoch_val_loss_m21/len(val_dataloader),
|
||||
epoch_val_loss_mc/len(val_dataloader),
|
||||
val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score
|
||||
)
|
||||
)
|
||||
|
||||
stats['val']['loss_m1'].append(epoch_val_loss_m1/len(val_dataloader))
|
||||
stats['val']['loss_m2'].append(epoch_val_loss_m2/len(val_dataloader))
|
||||
stats['val']['loss_m12'].append(epoch_val_loss_m12/len(val_dataloader))
|
||||
stats['val']['loss_m21'].append(epoch_val_loss_m21/len(val_dataloader))
|
||||
stats['val']['loss_mc'].append(epoch_val_loss_mc/len(val_dataloader))
|
||||
stats['val']['m1_f1'].append(val_m1_f1_score)
|
||||
stats['val']['m2_f1'].append(val_m2_f1_score)
|
||||
stats['val']['m12_f1'].append(val_m12_f1_score)
|
||||
stats['val']['m21_f1'].append(val_m21_f1_score)
|
||||
stats['val']['mc_f1'].append(val_mc_f1_score)
|
||||
|
||||
if args.logger: wandb.log(
|
||||
{
|
||||
'val_m1_loss': epoch_val_loss_m1/len(val_dataloader),
|
||||
'val_m2_loss': epoch_val_loss_m2/len(val_dataloader),
|
||||
'val_m12_loss': epoch_val_loss_m12/len(val_dataloader),
|
||||
'val_m21_loss': epoch_val_loss_m21/len(val_dataloader),
|
||||
'val_mc_loss': epoch_val_loss_mc/len(val_dataloader),
|
||||
'val_loss': epoch_val_loss_m1/len(val_dataloader) + \
|
||||
epoch_val_loss_m2/len(val_dataloader) + \
|
||||
epoch_val_loss_m12/len(val_dataloader) + \
|
||||
epoch_val_loss_m21/len(val_dataloader) + \
|
||||
epoch_val_loss_mc/len(val_dataloader),
|
||||
'val_m1_f1_score': val_m1_f1_score,
|
||||
'val_m2_f1_score': val_m2_f1_score,
|
||||
'val_m12_f1_score': val_m12_f1_score,
|
||||
'val_m21_f1_score': val_m21_f1_score,
|
||||
'val_mc_f1_score': val_mc_f1_score
|
||||
}
|
||||
)
|
||||
|
||||
# check for best stat/model using validation accuracy
|
||||
if stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1] >= max_val_f1:
|
||||
max_val_f1 = stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1]
|
||||
max_val_classification_epoch = i+1
|
||||
torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
|
||||
counter = 0
|
||||
else:
|
||||
counter += 1
|
||||
print(f'EarlyStopping counter: {counter} out of {args.patience}.')
|
||||
if counter >= args.patience:
|
||||
break
|
||||
|
||||
with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
|
||||
f.write('{}\n'.format(CFG))
|
||||
f.write('{}\n'.format(train_data))
|
||||
f.write('{}\n'.format(val_data))
|
||||
f.write('{}\n'.format(stats))
|
||||
f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_f1))
|
||||
f.close()
|
||||
|
||||
print(f'Results saved in {experiment_save_path}')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Define the command-line arguments
|
||||
parser.add_argument('--gpu_id', type=int)
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--logger', action='store_true')
|
||||
parser.add_argument('--presaved', type=int, default=128)
|
||||
parser.add_argument('--clip_grad_norm', type=float, default=0.5)
|
||||
parser.add_argument('--use_mixup', action='store_true')
|
||||
parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
|
||||
parser.add_argument('--non_blocking', action='store_true')
|
||||
parser.add_argument('--patience', type=int, default=99)
|
||||
parser.add_argument('--batch_size', type=int, default=4)
|
||||
parser.add_argument('--num_workers', type=int, default=8)
|
||||
parser.add_argument('--pin_memory', action='store_true')
|
||||
parser.add_argument('--num_epoch', type=int, default=300)
|
||||
parser.add_argument('--lr', type=float, default=4e-4)
|
||||
parser.add_argument('--scheduler', type=str, default=None)
|
||||
parser.add_argument('--dropout', type=float, default=0.1)
|
||||
parser.add_argument('--weight_decay', type=float, default=0.005)
|
||||
parser.add_argument('--label_smoothing', type=float, default=0.1)
|
||||
parser.add_argument('--model_type', type=str)
|
||||
parser.add_argument('--aggr', type=str, default='concat', required=False)
|
||||
parser.add_argument('--use_resnet', action='store_true')
|
||||
parser.add_argument('--hidden_dim', type=int, default=64)
|
||||
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
|
||||
parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
|
||||
parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
|
||||
parser.add_argument('--save_path', type=str, default='experiments/')
|
||||
parser.add_argument('--resume_from_checkpoint', type=str, default=None)
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.use_mixup and not args.mixup_alpha:
|
||||
parser.error("--use_mixup requires --mixup_alpha")
|
||||
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
|
||||
if not args.aggr:
|
||||
parser.error("The choosen --model_type requires --aggr")
|
||||
if args.model_type == 'tom_sl' and not args.tom_weight:
|
||||
parser.error("The choosen --model_type requires --tom_weight")
|
||||
|
||||
# get experiment ID
|
||||
experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
|
||||
if not os.path.exists(args.save_path):
|
||||
os.makedirs(args.save_path, exist_ok=True)
|
||||
experiment_save_path = os.path.join(args.save_path, experiment_id)
|
||||
os.makedirs(experiment_save_path, exist_ok=True)
|
||||
|
||||
CFG = {
|
||||
'use_ocr_custom_loss': 0,
|
||||
'presaved': args.presaved,
|
||||
'batch_size': args.batch_size,
|
||||
'num_epoch': args.num_epoch,
|
||||
'lr': args.lr,
|
||||
'scheduler': args.scheduler,
|
||||
'weight_decay': args.weight_decay,
|
||||
'model_type': args.model_type,
|
||||
'use_resnet': args.use_resnet,
|
||||
'hidden_dim': args.hidden_dim,
|
||||
'tom_weight': args.tom_weight,
|
||||
'dropout': args.dropout,
|
||||
'label_smoothing': args.label_smoothing,
|
||||
'clip_grad_norm': args.clip_grad_norm,
|
||||
'use_mixup': args.use_mixup,
|
||||
'mixup_alpha': args.mixup_alpha,
|
||||
'non_blocking_tensors': args.non_blocking,
|
||||
'patience': args.patience,
|
||||
'pin_memory': args.pin_memory,
|
||||
'resume_from_checkpoint': args.resume_from_checkpoint,
|
||||
'aggr': args.aggr,
|
||||
'mods': args.mods,
|
||||
'save_path': experiment_save_path ,
|
||||
'seed': args.seed
|
||||
}
|
||||
|
||||
print(CFG)
|
||||
print(f'Saving results in {experiment_save_path}')
|
||||
|
||||
# set seed values
|
||||
if args.logger:
|
||||
wandb.init(project="tbd", config=CFG)
|
||||
os.environ['PYTHONHASHSEED'] = str(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
random.seed(args.seed)
|
||||
|
||||
main(args)
|
224
tbd/utils/fb_scores_err.py
Normal file
224
tbd/utils/fb_scores_err.py
Normal file
|
@ -0,0 +1,224 @@
|
|||
import os
|
||||
import csv
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from sklearn.metrics import f1_score
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
ALPHA = 0.7
|
||||
BAR_WIDTH = 0.27
|
||||
sns.set_theme(style='whitegrid')
|
||||
#sns.set_palette('mako')
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
model_to_subdir = {
|
||||
"IC$\parallel$": ["2023-07-16_10-34-32_train", "2023-07-18_13-49-57_train", "2023-07-19_12-17-46_train"],
|
||||
"IC$\oplus$": ["2023-07-16_10-35-02_train", "2023-07-18_13-50-32_train", "2023-07-19_12-18-18_train"],
|
||||
"IC$\otimes$": ["2023-07-16_10-35-41_train", "2023-07-18_13-52-26_train", "2023-07-19_12-18-49_train"],
|
||||
"IC$\odot$": ["2023-07-16_10-36-04_train", "2023-07-18_13-53-03_train", "2023-07-19_12-19-50_train"],
|
||||
"CG$\parallel$": ["2023-07-15_14-12-36_train", "2023-07-17_11-54-28_train", "2023-07-19_00-30-05_train"],
|
||||
"CG$\oplus$": ["2023-07-15_14-14-08_train", "2023-07-17_11-56-05_train", "2023-07-19_00-30-47_train"],
|
||||
"CG$\otimes$": ["2023-07-15_14-14-53_train", "2023-07-17_11-56-39_train", "2023-07-19_00-31-36_train"],
|
||||
"CG$\odot$": ["2023-07-15_14-10-05_train", "2023-07-17_11-57-30_train", "2023-07-19_00-32-10_train"],
|
||||
"DB": ["2023-08-08_12-56-02_train", "2023-08-08_19-07-43_train", "2023-08-08_19-08-47_train"],
|
||||
"Base": ["2023-08-08_12-53-38_train", "2023-08-08_19-10-02_train", "2023-08-08_19-10-51_train"]
|
||||
}
|
||||
|
||||
def read_data_from_csv(subdirectory_path):
|
||||
print(subdirectory_path)
|
||||
data = []
|
||||
csv_files = [file for file in os.listdir(subdirectory_path) if file.endswith('.csv')]
|
||||
for csv_file in csv_files:
|
||||
file_path = os.path.join(subdirectory_path, csv_file)
|
||||
with open(file_path, 'r') as file:
|
||||
reader = csv.reader(file)
|
||||
header_skipped = False
|
||||
for row in reader:
|
||||
if not header_skipped:
|
||||
header_skipped = True
|
||||
continue
|
||||
frame, m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, m1_label, m2_label, m12_label, m21_label, mc_label, false_belief = row
|
||||
data.append({
|
||||
'frame': int(frame),
|
||||
'm1_pred': int(m1_pred),
|
||||
'm2_pred': int(m2_pred),
|
||||
'm12_pred': int(m12_pred),
|
||||
'm21_pred': int(m21_pred),
|
||||
'mc_pred': int(mc_pred),
|
||||
'm1_label': int(m1_label),
|
||||
'm2_label': int(m2_label),
|
||||
'm12_label': int(m12_label),
|
||||
'm21_label': int(m21_label),
|
||||
'mc_label': int(mc_label),
|
||||
'false_belief': false_belief,
|
||||
})
|
||||
return data
|
||||
|
||||
def compute_correct_false_belief(data, mind="all", folder=None):
|
||||
total_false_belief = 0
|
||||
correct_false_belief = 0
|
||||
for item in data:
|
||||
if 'false' in item['false_belief']:
|
||||
false_belief_type = item['false_belief'].split('_')[0]
|
||||
if mind == "all" or false_belief_type in mind:
|
||||
total_false_belief += 1
|
||||
if item[f"{false_belief_type}_pred"] == item[f"{false_belief_type}_label"]:
|
||||
if folder is not None:
|
||||
with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
|
||||
f.write(f"{str(1)}\n")
|
||||
correct_false_belief += 1
|
||||
else:
|
||||
if folder is not None:
|
||||
with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
|
||||
f.write(f"{str(0)}\n")
|
||||
if total_false_belief == 0:
|
||||
accuracy = 0.0
|
||||
else:
|
||||
accuracy = correct_false_belief / total_false_belief
|
||||
return accuracy
|
||||
|
||||
def compute_macro_f1_score(data, mind="all"):
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for item in data:
|
||||
if 'false' in item['false_belief']:
|
||||
false_belief_type = item['false_belief'].split('_')[0]
|
||||
if mind == "all" or false_belief_type in mind:
|
||||
y_true.append(int(item[f"{false_belief_type}_label"]))
|
||||
y_pred.append(int(item[f"{false_belief_type}_pred"]))
|
||||
|
||||
if not y_true or not y_pred:
|
||||
macro_f1 = 0.0
|
||||
else:
|
||||
macro_f1 = f1_score(y_true, y_pred, average='macro')
|
||||
|
||||
return macro_f1
|
||||
|
||||
def delete_files_in_subfolders(folder_path, file_names_to_delete):
|
||||
"""
|
||||
Delete specified files in all subfolders of a given folder.
|
||||
|
||||
Parameters:
|
||||
folder_path: The path to the folder containing subfolders.
|
||||
file_names_to_delete: A list of file names to be deleted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for root, _, _ in os.walk(folder_path):
|
||||
for file_name in file_names_to_delete:
|
||||
file_path = os.path.join(root, file_name)
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
print(f"Deleted: {file_path}")
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
folder_path = "predictions"
|
||||
files_to_delete = ["fb_m1_m2_m12_m21.txt", "fb_m1_m2.txt", "fb_m12_m21.txt"]
|
||||
delete_files_in_subfolders(folder_path, files_to_delete)
|
||||
|
||||
metric = "Accuracy"
|
||||
if metric == "Macro F1":
|
||||
score_function = compute_macro_f1_score
|
||||
elif metric == "Accuracy":
|
||||
score_function = compute_correct_false_belief
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
models = [
|
||||
'Base', 'DB',
|
||||
'CG$\parallel$', 'CG$\oplus$', 'CG$\otimes$', 'CG$\odot$',
|
||||
'IC$\parallel$', 'IC$\oplus$', 'IC$\otimes$', 'IC$\odot$'
|
||||
]
|
||||
|
||||
parent_dir = 'predictions'
|
||||
minds = categories = ['m1', 'm2', 'm12', 'm21']
|
||||
score_m1_m2 = []
|
||||
score_m12_m21 = []
|
||||
score_all = []
|
||||
std_m1_m2 = []
|
||||
std_m12_m21 = []
|
||||
std_all = []
|
||||
|
||||
for model in models:
|
||||
model_scores_m1_m2 = []
|
||||
model_scores_m12_m21 = []
|
||||
model_scores_all = []
|
||||
for s in range(3):
|
||||
subdir_path = os.path.join(parent_dir, model_to_subdir[model][s])
|
||||
data = read_data_from_csv(subdir_path)
|
||||
model_scores_m1_m2.append(score_function(data, ['m1', 'm2'], model_to_subdir[model][s]))
|
||||
model_scores_m12_m21.append(score_function(data, ['m12', 'm21'], model_to_subdir[model][s]))
|
||||
model_scores_all.append(score_function(data, ['m1', 'm2', 'm12', 'm21'], model_to_subdir[model][s]))
|
||||
score_m1_m2.append(np.mean(model_scores_m1_m2))
|
||||
std_m1_m2.append(np.std(model_scores_m1_m2))
|
||||
score_m12_m21.append(np.mean(model_scores_m12_m21))
|
||||
std_m12_m21.append(np.std(model_scores_m12_m21))
|
||||
score_all.append(np.mean(model_scores_all))
|
||||
std_all.append(np.std(model_scores_all))
|
||||
|
||||
# Create a dataframe to use with sns.catplot
|
||||
data = {
|
||||
'Model': [m for m in models],
|
||||
'FO_FB_mean': score_m1_m2,
|
||||
'FO_FB_std': std_m1_m2,
|
||||
'SO_FB_mean': score_m12_m21,
|
||||
'SO_FB_std': std_m12_m21,
|
||||
'Both_mean': score_all,
|
||||
'Both_std': std_all
|
||||
}
|
||||
|
||||
models = data['Model']
|
||||
fo_fb_mean = data['FO_FB_mean']
|
||||
fo_fb_std = data['FO_FB_std']
|
||||
so_fb_mean = data['SO_FB_mean']
|
||||
so_fb_std = data['SO_FB_std']
|
||||
both_mean = data['Both_mean']
|
||||
both_std = data['Both_std']
|
||||
|
||||
bar_width = BAR_WIDTH
|
||||
x = np.arange(len(models))
|
||||
|
||||
plt.figure(figsize=(13, 3.5))
|
||||
fo_fb_bars = plt.bar(x - bar_width, fo_fb_mean, width=bar_width, yerr=fo_fb_std, capsize=4, label='First-order false belief', alpha=ALPHA)
|
||||
so_fb_bars = plt.bar(x, so_fb_mean, width=bar_width, yerr=so_fb_std, capsize=4, label='Second-order false belief', alpha=ALPHA)
|
||||
both_bars = plt.bar(x + bar_width, both_mean, width=bar_width, yerr=both_std, capsize=4, label='Both', alpha=ALPHA)
|
||||
|
||||
def add_labels(bars, std_values):
|
||||
cnt = 0
|
||||
for bar, std in zip(bars, std_values):
|
||||
height = bar.get_height()
|
||||
offset = std + 0.01
|
||||
if cnt == 0 or cnt == 1 or cnt == 9:
|
||||
plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}*', ha='center', va='bottom', fontsize=10)
|
||||
else:
|
||||
plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}', ha='center', va='bottom', fontsize=10)
|
||||
cnt = cnt + 1
|
||||
|
||||
add_labels(fo_fb_bars, fo_fb_std)
|
||||
add_labels(so_fb_bars, so_fb_std)
|
||||
add_labels(both_bars, both_std)
|
||||
|
||||
plt.gca().spines['top'].set_visible(False)
|
||||
plt.gca().spines['right'].set_visible(False)
|
||||
plt.xlabel('MToMnet', fontsize=14)
|
||||
plt.ylabel('Macro F1 Score' if metric == "Macro F1" else 'Accuracy', fontsize=14)
|
||||
plt.xticks(x, models, rotation=0, fontsize=14)
|
||||
plt.yticks(fontsize=14)
|
||||
plt.legend(fontsize=14, loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=3)
|
||||
plt.tight_layout()
|
||||
plt.savefig('results/false_belief_first_vs_second.pdf')
|
210
tbd/utils/helpers.py
Normal file
210
tbd/utils/helpers.py
Normal file
File diff suppressed because one or more lines are too long
37
tbd/utils/preprocess_img.py
Normal file
37
tbd/utils/preprocess_img.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
import glob
|
||||
|
||||
import cv2
|
||||
|
||||
import torchvision.transforms as T
|
||||
import torch
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
PATH_IN = "/scratch/bortoletto/data/tbd/images"
|
||||
PATH_OUT = "/scratch/bortoletto/data/tbd/images_norm"
|
||||
|
||||
normalisation_steps = [
|
||||
T.ToTensor(),
|
||||
T.Resize((128,128)),
|
||||
T.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
]
|
||||
|
||||
preprocess_img = T.Compose(normalisation_steps)
|
||||
|
||||
def main():
|
||||
print(f"{PATH_IN}/*/*/*.jpg")
|
||||
all_img = glob.glob(f"{PATH_IN}/*/*/*.jpg")
|
||||
print(len(all_img))
|
||||
for img_path in tqdm(all_img):
|
||||
new_img = preprocess_img(cv2.imread(img_path)).numpy()
|
||||
img_path_split = img_path.split("/")
|
||||
os.makedirs(f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}", exist_ok=True)
|
||||
out_img = f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}/{img_path_split[-1][:-4]}.pt"
|
||||
torch.save(new_img, out_img)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
106
tbd/utils/reformat_labels_ours.py
Normal file
106
tbd/utils/reformat_labels_ours.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
import pandas as pd
|
||||
import os
|
||||
import glob
|
||||
import pickle
|
||||
|
||||
DATASET_LOCATION = "YOUR_PATH_HERE"
|
||||
|
||||
def reframe_annotation():
|
||||
annotation_path = f'{DATASET_LOCATION}/retrieve_annotation/all/'
|
||||
save_path = f'{DATASET_LOCATION}/reformat_annotation/'
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
tasks = glob.glob(annotation_path + '*.txt')
|
||||
id_map = pd.read_csv('id_map.csv')
|
||||
for task in tasks:
|
||||
if not task.split('/')[-1].split('_')[2] == '1.txt':
|
||||
continue
|
||||
with open(task, 'r') as f:
|
||||
lines = f.readlines()
|
||||
task_id = int(task.split('/')[-1].split('_')[1]) + 1
|
||||
clip = id_map.loc[id_map['ID'] == task_id].folder
|
||||
print(task_id, len(clip))
|
||||
if len(clip) == 0:
|
||||
continue
|
||||
with open(save_path + clip.item() + '.txt', 'w') as f:
|
||||
for line in lines:
|
||||
words = line.split()
|
||||
f.write(words[0] + ',' + words[1] + ',' + words[2] + ',' + words[3] + ',' + words[4] + ',' + words[5] +
|
||||
',' + words[6] + ',' + words[7] + ',' + words[8] + ',' + words[9] + ',' + ' '.join(words[10:]) + '\n')
|
||||
f.close()
|
||||
|
||||
def get_grid_location(obj_frame):
|
||||
x_min = obj_frame['x_min']#.item()
|
||||
y_min = obj_frame['y_min']#.item()
|
||||
x_max = obj_frame['x_max']#.item()
|
||||
y_max = obj_frame['y_max']#.item()
|
||||
gridLW = 1280 / 25.
|
||||
gridLH = 720 / 15.
|
||||
center_x, center_y = (x_min + x_max)/2, (y_min + y_max)/2
|
||||
X, Y = int(center_x / gridLW), int(center_y / gridLH)
|
||||
return X, Y
|
||||
|
||||
def regenerate_annotation():
|
||||
annotation_path = f'{DATASET_LOCATION}/reformat_annotation/'
|
||||
save_path=f'{DATASET_LOCATION}/regenerate_annotation/'
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
tasks = glob.glob(annotation_path + '*.txt')
|
||||
for task in tasks:
|
||||
print(task)
|
||||
annt = pd.read_csv(task, sep=",", header=None)
|
||||
annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name", "label"]
|
||||
obj_records = {}
|
||||
for index, obj_frame in annt.iterrows():
|
||||
if obj_frame['name'].startswith('P'):
|
||||
continue
|
||||
else:
|
||||
assert obj_frame['name'].startswith('O')
|
||||
obj_name = obj_frame['name']
|
||||
# 0: enter 1: disappear 2: update 3: unchange
|
||||
frame_id = obj_frame['frame']
|
||||
curr_loc = get_grid_location(obj_frame)
|
||||
mind_dict = {'m1': {'fluent': 3, 'loc': None}, 'm2': {'fluent': 3, 'loc': None},
|
||||
'm12': {'fluent': 3, 'loc': None},
|
||||
'm21': {'fluent': 3, 'loc': None}, 'mc': {'fluent': 3, 'loc': None},
|
||||
'mg': {'fluent': 3, 'loc': curr_loc}}
|
||||
mind_dict['mg']['loc'] = curr_loc
|
||||
if not type(obj_frame['label']) == float:
|
||||
mind_labels = obj_frame['label'].split()
|
||||
for mind_label in mind_labels:
|
||||
if mind_label == 'in_m1' or mind_label == 'in_m2' or mind_label == 'in_m12' \
|
||||
or mind_label == 'in_m21' or mind_label == 'in_mc' or mind_label == '"in_m1"' or mind_label == '"in_m2"'\
|
||||
or mind_label == '"in_m12"' or mind_label == '"in_m21"' or mind_label == '"in_mc"':
|
||||
mind_name = mind_label.split('_')[1].split('"')[0]
|
||||
mind_dict[mind_name]['loc'] = curr_loc
|
||||
else:
|
||||
mind_name = mind_label.split('_')[0].split('"')
|
||||
if len(mind_name) > 1:
|
||||
mind_name = mind_name[1]
|
||||
else:
|
||||
mind_name = mind_name[0]
|
||||
last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
|
||||
mind_dict[mind_name]['loc'] = last_loc
|
||||
|
||||
for mind_name in mind_dict.keys():
|
||||
if frame_id > 0:
|
||||
curr_loc = mind_dict[mind_name]['loc']
|
||||
last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
|
||||
if last_loc is None and curr_loc is not None:
|
||||
mind_dict[mind_name]['fluent'] = 0
|
||||
elif last_loc is not None and curr_loc is None:
|
||||
mind_dict[mind_name]['fluent'] = 1
|
||||
elif not last_loc == curr_loc:
|
||||
mind_dict[mind_name]['fluent'] = 2
|
||||
if obj_name not in obj_records:
|
||||
obj_records[obj_name] = [mind_dict]
|
||||
else:
|
||||
obj_records[obj_name].append(mind_dict)
|
||||
|
||||
with open(save_path + task.split('/')[-1].split('.')[0] + '.p', 'wb') as f:
|
||||
pickle.dump(obj_records, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
reframe_annotation()
|
||||
regenerate_annotation()
|
75
tbd/utils/similarity.py
Normal file
75
tbd/utils/similarity.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.decomposition import PCA
|
||||
import seaborn as sns
|
||||
|
||||
|
||||
FOLDER_PATH = 'PATH_TO_FOLDER'
|
||||
|
||||
print(FOLDER_PATH)
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
COLORS = sns.color_palette()
|
||||
|
||||
sns.set_theme(style='white')
|
||||
|
||||
out_left_main_mods_full_test = []
|
||||
out_right_main_mods_full_test = []
|
||||
cell_left_main_mods_full_test = []
|
||||
cell_right_main_mods_full_test = []
|
||||
cm_left_main_mods_full_test = []
|
||||
cm_right_main_mods_full_test = []
|
||||
|
||||
for i in range(len([filename for filename in os.listdir(FOLDER_PATH) if filename.endswith('.pt')])):
|
||||
|
||||
print(f'Computing analysis for test video {i}...', end='\r')
|
||||
|
||||
emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
|
||||
data = torch.load(emb_file)
|
||||
if len(data) == 13: # implicit
|
||||
model = 'impl'
|
||||
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
|
||||
elif len(data) == 12: # common mind
|
||||
model = 'cm'
|
||||
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
|
||||
elif len(data) == 11: # speaker-listener
|
||||
model = 'sl'
|
||||
out_left, out_right, feats = data[0], data[1], data[2:]
|
||||
else: raise ValueError("Data should have 13 (impl), others are not implemented")
|
||||
|
||||
# ====== PCA for left and right embeddings ====== #
|
||||
|
||||
out_left_pca = out_left[0].reshape(-1, 64)
|
||||
out_right_pca = out_right[0].reshape(-1, 64)
|
||||
out_left_and_right = np.concatenate((out_left_pca, out_right_pca), axis=0)
|
||||
|
||||
pca = PCA(n_components=2)
|
||||
pca_result = pca.fit_transform(out_left_and_right)
|
||||
|
||||
# Separate the PCA results for each tensor
|
||||
pca_result_left = pca_result[:out_left_pca.shape[0]]
|
||||
pca_result_right = pca_result[out_right_pca.shape[0]:]
|
||||
|
||||
plt.figure(figsize=(6.8,6))
|
||||
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
|
||||
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
|
||||
plt.xlabel('Principal Component 1', fontsize=32)
|
||||
plt.ylabel('Principal Component 2', fontsize=32)
|
||||
plt.xticks(fontsize=24)
|
||||
plt.xticks([-0.4, -0.2, 0.0, 0.2, 0.4])
|
||||
plt.yticks(fontsize=24)
|
||||
plt.legend(fontsize=32)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
|
||||
plt.close()
|
96
tbd/utils/store_mind_set.py
Normal file
96
tbd/utils/store_mind_set.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
import os
|
||||
import pandas as pd
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def check_append(obj_name, m1, mind_name, obj_frame, flags, label):
|
||||
if label:
|
||||
if not obj_name in m1:
|
||||
m1[obj_name] = []
|
||||
m1[obj_name].append(
|
||||
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
|
||||
flags[mind_name] = 1
|
||||
elif not flags[mind_name]:
|
||||
m1[obj_name].append(
|
||||
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
|
||||
flags[mind_name] = 1
|
||||
else: # false belief
|
||||
if obj_name in m1:
|
||||
if flags[mind_name]:
|
||||
m1[obj_name].append(
|
||||
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 1])
|
||||
flags[mind_name] = 0
|
||||
return flags, m1
|
||||
|
||||
|
||||
def store_mind_set(clip, annotation_path, save_path):
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
annt = pd.read_csv(annotation_path + clip, sep=",", header=None)
|
||||
annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name",
|
||||
"label"]
|
||||
obj_names = annt.name.unique()
|
||||
m1, m2, m12, m21, mc = {}, {}, {}, {}, {}
|
||||
flags = {'m1':0, 'm2':0, 'm12':0, 'm21':0, 'mc':0}
|
||||
for obj_name in obj_names:
|
||||
if obj_name == 'P1' or obj_name == 'P2':
|
||||
continue
|
||||
obj_frames = annt.loc[annt.name == obj_name]
|
||||
for index, obj_frame in obj_frames.iterrows():
|
||||
if type(obj_frame.label) == float:
|
||||
continue
|
||||
labels = obj_frame.label.split()
|
||||
for label in labels:
|
||||
if label == 'in_m1' or label == '"in_m1"':
|
||||
flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 1)
|
||||
elif label == 'in_m2' or label == '"in_m2"':
|
||||
flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 1)
|
||||
elif label == 'in_m12'or label == '"in_m12"':
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 1)
|
||||
elif label == 'in_m21' or label == '"in_m21"':
|
||||
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 1)
|
||||
elif label == 'in_mc'or label == '"in_mc"':
|
||||
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 1)
|
||||
elif label == 'm1_false' or label == '"m1_false"':
|
||||
flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 0)
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
|
||||
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
|
||||
false_belief = 'm1_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
elif label == 'm2_false' or label == '"m2_false"':
|
||||
flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 0)
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
|
||||
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
|
||||
false_belief = 'm2_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
elif label == 'm12_false' or label == '"m12_false"':
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
|
||||
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
|
||||
false_belief = 'm12_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
elif label == 'm21_false' or label == '"m21_false"':
|
||||
flags, m21 = check_append(obj_name, m2, 'm21', obj_frame, flags, 0)
|
||||
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
|
||||
false_belief = 'm21_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
# print('m1', m1)
|
||||
# print('m2', m2)
|
||||
# print('m12', m12)
|
||||
# print('m21', m21)
|
||||
# print('mc', mc)
|
||||
#with open(save_path + clip.split('.')[0] + '.p', 'wb') as f:
|
||||
# pickle.dump([m1, m2, m12, m21, mc], f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
annotation_path = '/scratch/bortoletto/data/tbd/reformat_annotation/'
|
||||
save_path = '/scratch/bortoletto/data/tbd/store_mind_set/'
|
||||
|
||||
for clip in tqdm(os.listdir(annotation_path), desc="Processing videos", unit="item"):
|
||||
store_mind_set(clip, annotation_path, save_path)
|
95
tbd/utils/visualize_bbox.py
Normal file
95
tbd/utils/visualize_bbox.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
import time
|
||||
from tbd_dataloader import TBDv2Dataset
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
|
||||
def point2screen(points):
|
||||
K = [607.13232421875, 0.0, 638.6468505859375, 0.0, 607.1067504882812, 367.1607360839844, 0.0, 0.0, 1.0]
|
||||
K = np.reshape(np.array(K), [3, 3])
|
||||
rot_points = np.array(points) + np.array([0, 0.2, 0])
|
||||
rot_points = rot_points
|
||||
points_camera = rot_points.reshape(3, 1)
|
||||
|
||||
project_matrix = np.array(K).reshape(3, 3)
|
||||
points_prj = project_matrix.dot(points_camera)
|
||||
points_prj = points_prj.transpose()
|
||||
if not points_prj[:, 2][0] == 0.0:
|
||||
points_prj[:, 0] = points_prj[:, 0] / points_prj[:, 2]
|
||||
points_prj[:, 1] = points_prj[:, 1] / points_prj[:, 2]
|
||||
points_screen = points_prj[:, :2]
|
||||
assert points_screen.shape == (1, 2)
|
||||
points_screen = points_screen.reshape(-1)
|
||||
return points_screen
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = TBDv2Dataset(number_frames_to_sample=1, resize_img=None)
|
||||
index = np.random.randint(0, len(data))
|
||||
start = time.time()
|
||||
(
|
||||
kinect_imgs, # <- len x 720 x 1280 x 3
|
||||
tracker_imgs,
|
||||
battery_imgs,
|
||||
skele1,
|
||||
skele2,
|
||||
bbox,
|
||||
tracker_skeID_sample, # <- This is the tracker skeleton ID
|
||||
tracker2d,
|
||||
label,
|
||||
experiment_id, # From here for debugging
|
||||
timestep,
|
||||
obj_id, # <- This is the object ID as a string
|
||||
) = data[index]
|
||||
end = time.time()
|
||||
print(f"Time for one sample: {end-start}")
|
||||
|
||||
img = kinect_imgs[-1]
|
||||
bbox = bbox[-1]
|
||||
print(label.shape)
|
||||
|
||||
print(skele1.shape)
|
||||
print(skele2.shape)
|
||||
|
||||
skele1 = skele1[-1, :,:]
|
||||
skele2 = skele2[-1, :,:]
|
||||
|
||||
print(skele1.shape)
|
||||
|
||||
|
||||
|
||||
# reshape img from c, h, w to h, w, c
|
||||
img = img.permute(1, 2, 0)
|
||||
|
||||
fig, ax = plt.subplots(1)
|
||||
ax.imshow(img)
|
||||
print(bbox[0], bbox[1], bbox[2], bbox[3]) # t(top left x, top left y, width, height)
|
||||
top_left_x, top_left_y, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
|
||||
x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
|
||||
|
||||
|
||||
|
||||
|
||||
for i in range(26):
|
||||
print(skele1[i,0], skele1[i,1])
|
||||
print(skele1[i,:].shape)
|
||||
print(point2screen(skele1[i,:]))
|
||||
x, y = point2screen(skele1[i,:])[0], point2screen(skele1[i,:])[1]
|
||||
ax.text(x, y, f"{i}", fontsize=5, color='w')
|
||||
|
||||
wedge = patches.Wedge((x,y), 10, 0, 360, width=10, color='b')
|
||||
ax.add_patch(wedge)
|
||||
|
||||
for i in range(26):
|
||||
x, y = point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]
|
||||
ax.text(x, y, f"{i}", fontsize=5, color='w')
|
||||
wedge = patches.Wedge((point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]), 10, 0, 360, width=10, color='r')
|
||||
ax.add_patch(wedge)
|
||||
|
||||
# Create a Rectangle patch
|
||||
# rect = patches.Rectangle((top_left_x, top_left_y-height), width, height, linewidth=1, edgecolor='r', facecolor='none')
|
||||
# ax.add_patch(rect)
|
||||
# rect = patches.Rectangle((x_min, y_max), x_max-x_min, y_max-y_min, linewidth=1, edgecolor='g', facecolor='none')
|
||||
# ax.add_patch(rect)
|
||||
fig.savefig(f"bbox_{obj_id}_{index}_{experiment_id}.png")
|
Loading…
Add table
Add a link
Reference in a new issue