Matteo Bortoletto 2025-01-10 15:39:20 +01:00
parent d4aaf7f4ad
commit 25b8b3f343
55 changed files with 7592 additions and 4 deletions


@@ -22,11 +22,17 @@ This is a temporary reference that will be updated after the proceedings are published.
}
```
<br>
<br>
<br>
# Code
Under construction
This repository has the following structure:
```
mtomnet
├── boss
└── tbd
```
We provide one subfolder per dataset, containing the code to run the corresponding experiments. Inside each subfolder we provide a README with further instructions.
[1]: https://mattbortoletto.github.io/

206
boss/.gitignore vendored Normal file

@@ -0,0 +1,206 @@
experiments/
predictions_*
predictions/
tmp/
### WANDB ###
wandb/*
wandb
wandb-debug.log
logs
summary*
run*
### SYSTEM ###
*/__pycache__/*
# Created by https://www.toptal.com/developers/gitignore/api/python,linux
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# End of https://www.toptal.com/developers/gitignore/api/python,linux

16
boss/README.md Normal file

@@ -0,0 +1,16 @@
# BOSS
## Installing Dependencies
Run `conda env create -f environment.yml` to create the `boss` environment, then activate it with `conda activate boss`.
## New bounding box annotations
We re-extracted bounding box annotations using YOLOv8. The new data is in `new_bbox/`. The rest of the data can be found [here](https://drive.google.com/drive/folders/1b8FdpyoWx9gUps-BX6qbE9Kea3C2Uyua).
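Each annotation file contains one line per detected object in the standard YOLO text format `class_index x_center y_center width height` (this is the format `dataloader.py` parses). A minimal sketch of reading one such file, assuming a hypothetical path under `new_bbox/`:
```python
# Hedged sketch: parse one YOLOv8 annotation file (the file name below is hypothetical).
with open("new_bbox/0/frame_1.txt") as fp:
    for line in fp:
        class_index, x_center, y_center, width, height = map(float, line.split())
        print(int(class_index), x_center, y_center, width, height)
```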
## Train
Run `source run_train.sh`.
## Test
Run `source run_test.sh` (specify the path to the trained model).
## Resources
The original project page for the BOSS dataset can be found [here](https://sites.google.com/view/bossbelief/). Our code is based on the original implementation.
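For orientation, below is a minimal, hypothetical sketch of loading a single episode with the `Data` class from `dataloader.py`; all paths are placeholders and the snippet assumes the BOSS data (including the pre-saved frame pickles) is laid out as the loader expects. `run_train.sh` remains the reference for the actual configuration.
```python
# Hedged sketch, not the training entry point: every path below is a placeholder.
from dataloader import Data

dataset = Data(
    frame_path="path/to/frames",      # placeholder
    label_path="path/to/labels.pkl",  # placeholder
    pose_path="path/to/pose",         # placeholder
    gaze_path="path/to/gaze",         # placeholder
    bbox_path="new_bbox",             # new YOLOv8 bounding boxes
)
images, labels, pose, gaze, bbox, ocr = dataset[0]  # ocr is None unless ocr_graph_path is set
print(images.shape, labels.shape)  # roughly (seq_len, C, H, W) and (seq_len, 2)
```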

500
boss/dataloader.py Normal file

@@ -0,0 +1,500 @@
from torch.utils.data import Dataset
import os
from torchvision import transforms
from natsort import natsorted
import numpy as np
import pickle
import torch
import cv2
import ast
from einops import rearrange
from utils import spherical2cartesial
class Data(Dataset):
def __init__(
self,
frame_path,
label_path,
pose_path=None,
gaze_path=None,
bbox_path=None,
ocr_graph_path=None,
presaved=128,
sizes = (128,128),
spatial_patch_size: int = 16,
temporal_patch_size: int = 4,
img_channels: int = 3,
patch_data: bool = False,
flatten_dim: int = 1
):
self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
self.size_1, self.size_2 = sizes
self.patch_data = patch_data
if self.patch_data:
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
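# With the defaults (128x128 frames, spatial_patch_size=16, img_channels=3) this gives
# num_patches = (128 // 16) ** 2 = 64 patches per frame and spatial_patch_dim = 3 * 16 ** 2 = 768.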
if presaved == 224:
self.presaved = 'presaved'
elif presaved == 128:
self.presaved = 'presaved128'
else: raise ValueError
frame_dirs = os.listdir(frame_path)
episode_ids = natsorted(frame_dirs)
seq_lens = []
for frame_dir in natsorted(frame_dirs):
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
seq_lens.append(len(frame_paths))
self.data['frame_paths'].append(frame_paths)
print('Labels...')
with open(label_path, 'rb') as fp:
labels = pickle.load(fp)
left_labels, right_labels = labels
for i in episode_ids:
episode_id = int(i)
left_labels_i = left_labels[episode_id]
right_labels_i = right_labels[episode_id]
self.data['labels'].append([left_labels_i, right_labels_i])
fp.close()
print('Pose...')
self.pose_path = pose_path
if pose_path is not None:
for i in episode_ids:
pose_i_dir = os.path.join(pose_path, i)
poses = []
for j in natsorted(os.listdir(pose_i_dir)):
fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
if torch.tensor(fs.getNode("pose_0").mat()).shape != (2,25,3):
poses.append(fs.getNode("pose_0").mat()[:2,:,:])
else:
poses.append(fs.getNode("pose_0").mat())
poses = np.array(poses)
poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
#self.data['poses'].append(torch.tensor(np.array(poses)))
print('Gaze...')
"""
Gaze information is represented by a 3D gaze vector in the observing camera's Cartesian eye coordinate system
"""
self.gaze_path = gaze_path
if gaze_path is not None:
for i in episode_ids:
gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
with open(gaze_i_txt, 'r') as fp:
gaze_i = fp.readlines()
fp.close()
gaze_i = ast.literal_eval(gaze_i[0])
gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
for j in range(len(gaze_i)):
if len(gaze_i[j]) >= 2:
gaze_i_tensor[j,:] = torch.tensor(gaze_i[j][:2])
elif len(gaze_i[j]) == 1:
gaze_i_tensor[j,0] = torch.tensor(gaze_i[j][0])
else:
continue
self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))
print('Bbox...')
self.bbox_path = bbox_path
if bbox_path is not None:
self.objects = [
'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
]
"""
# NOTE: old bbox
for i in episode_ids:
bbox_i_dir = os.path.join(bbox_path, i)
with open(bbox_i_dir, 'rb') as fp:
bboxes_i = pickle.load(fp)
len_i = len(bboxes_i)
fp.close()
bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
for j in range(len(bboxes_i)):
items_i_j, bboxes_i_j = bboxes_i[j]
for k in range(len(items_i_j)):
bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
bboxes_i_j[k][0] / 1920, # * self.size_1,
bboxes_i_j[k][1] / 1088, # * self.size_2,
bboxes_i_j[k][2] / 1920, # * self.size_1,
bboxes_i_j[k][3] / 1088, # * self.size_2
]) # [x_min, y_min, x_max, y_max]
self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
"""
# NOTE: new bbox
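# Each episode directory holds one .txt file per frame; every line is
# 'class_index x_center y_center width height' in YOLO-style normalized coordinates.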
for i in episode_ids:
bbox_dir = os.path.join(bbox_path, i)
bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4)) # TODO: we might want to cut it to 10 objects
for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)):
with open(os.path.join(bbox_dir, bbox), 'r') as fp:
bbox_content = fp.readlines()
fp.close()
for bbox_content_line in bbox_content:
bbox_content_values = bbox_content_line.split()
class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values)
bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))
print('OCR...\n')
self.ocr_graph_path = ocr_graph_path
if ocr_graph_path is not None:
ocr_graph = [
[15, [10, 4], [17, 2]],
[13, [16, 7], [18, 4]],
[11, [16, 4], [7, 10]],
[14, [10, 11], [7, 1]],
[12, [10, 9], [16, 3]],
[1, [7, 2], [9, 9], [10, 2]],
[5, [8, 8], [6, 8]],
[4, [9, 8], [7, 6]],
[3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
[2, [10, 1], [7, 7], [9, 3]],
[19, [10, 2], [26, 6]],
[20, [10, 7], [26, 5]],
[22, [25, 4], [10, 8]],
[23, [25, 15]],
[21, [16, 5], [24, 8]]
]
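# Each entry is [object_id, [context_object_id, count], ...]. The counts are row-normalized
# below into a 27x27 object-context co-occurrence matrix and flattened to a 729-dim vector.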
ocr_tensor = torch.zeros((27, 27))
for ocr in ocr_graph:
obj = ocr[0]
contexts = ocr[1:]
total_context_count = sum([i[1] for i in contexts])
for context in contexts:
ocr_tensor[obj, context[0]] = context[1] / total_context_count
ocr_tensor = torch.flatten(ocr_tensor)
for i in episode_ids:
self.data['ocr_graph'].append(ocr_tensor)
self.frame_path = frame_path
self.transform = transforms.Compose([
#transforms.ToPILImage(),
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.data['frame_paths'])
def __getitem__(self, idx):
frames_path = self.data['frame_paths'][idx][0].split('/')
frames_path.insert(5, self.presaved)
frames_path = ''.join(f'{w}/' for w in frames_path[:-1])[:-1]+'.pkl'
with open(frames_path, 'rb') as f:
images = pickle.load(f)
images = torch.stack(images)
if self.patch_data:
images = rearrange(
images,
'(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
p1=self.spatial_patch_size,
p2=self.spatial_patch_size,
p3=self.temporal_patch_size
)
if self.pose_path is not None:
pose = self.data['poses'][idx]
if self.patch_data:
pose = rearrange(
pose,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
pose = None
if self.gaze_path is not None:
gaze = self.data['gazes'][idx]
if self.patch_data:
gaze = rearrange(
gaze,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
gaze = None
if self.bbox_path is not None:
bbox = self.data['bboxes'][idx]
if self.patch_data:
bbox = rearrange(
bbox,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
bbox = None
if self.ocr_graph_path is not None:
ocr_graphs = self.data['ocr_graph'][idx]
else:
ocr_graphs = None
return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs
class DataTest(Dataset):
def __init__(
self,
frame_path,
label_path,
pose_path=None,
gaze_path=None,
bbox_path=None,
ocr_graph_path=None,
presaved=128,
frame_ids=None,
median=None,
sizes = (128,128),
spatial_patch_size: int = 16,
temporal_patch_size: int = 4,
img_channels: int = 3,
patch_data: bool = False,
flatten_dim: int = 1
):
self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
self.size_1, self.size_2 = sizes
self.patch_data = patch_data
if self.patch_data:
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
if presaved == 224:
self.presaved = 'presaved'
elif presaved == 128:
self.presaved = 'presaved128'
else: raise ValueError
if frame_ids is not None:
test_ids = []
frame_dirs = os.listdir(frame_path)
episode_ids = natsorted(frame_dirs)
for frame_dir in episode_ids:
if int(frame_dir) in frame_ids:
test_ids.append(str(frame_dir))
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
self.data['frame_paths'].append(frame_paths)
elif median is not None:
test_ids = []
frame_dirs = os.listdir(frame_path)
episode_ids = natsorted(frame_dirs)
for frame_dir in episode_ids:
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
seq_len = len(frame_paths)
if (median[1] and seq_len >= median[0]) or (not median[1] and seq_len < median[0]):
self.data['frame_paths'].append(frame_paths)
test_ids.append(int(frame_dir))
else:
frame_dirs = os.listdir(frame_path)
episode_ids = natsorted(frame_dirs)
test_ids = episode_ids.copy()
for frame_dir in episode_ids:
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
self.data['frame_paths'].append(frame_paths)
print('Labels...')
with open(label_path, 'rb') as fp:
labels = pickle.load(fp)
left_labels, right_labels = labels
for i in test_ids:
episode_id = int(i)
left_labels_i = left_labels[episode_id]
right_labels_i = right_labels[episode_id]
self.data['labels'].append([left_labels_i, right_labels_i])
fp.close()
print('Pose...')
self.pose_path = pose_path
if pose_path is not None:
for i in test_ids:
pose_i_dir = os.path.join(pose_path, i)
poses = []
for j in natsorted(os.listdir(pose_i_dir)):
fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
if torch.tensor(fs.getNode("pose_0").mat()).shape != (2,25,3):
poses.append(fs.getNode("pose_0").mat()[:2,:,:])
else:
poses.append(fs.getNode("pose_0").mat())
poses = np.array(poses)
poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
print('Gaze...')
self.gaze_path = gaze_path
if gaze_path is not None:
for i in test_ids:
gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
with open(gaze_i_txt, 'r') as fp:
gaze_i = fp.readlines()
fp.close()
gaze_i = ast.literal_eval(gaze_i[0])
gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
for j in range(len(gaze_i)):
if len(gaze_i[j]) >= 2:
gaze_i_tensor[j,:] = torch.tensor(gaze_i[j][:2])
elif len(gaze_i[j]) == 1:
gaze_i_tensor[j,0] = torch.tensor(gaze_i[j][0])
else:
continue
self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))
print('Bbox...')
self.bbox_path = bbox_path
if bbox_path is not None:
self.objects = [
'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
]
""" NOTE: old bbox
for i in test_ids:
bbox_i_dir = os.path.join(bbox_path, i)
with open(bbox_i_dir, 'rb') as fp:
bboxes_i = pickle.load(fp)
len_i = len(bboxes_i)
fp.close()
bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
for j in range(len(bboxes_i)):
items_i_j, bboxes_i_j = bboxes_i[j]
for k in range(len(items_i_j)):
bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
bboxes_i_j[k][0] / 1920, # * self.size_1,
bboxes_i_j[k][1] / 1088, # * self.size_2,
bboxes_i_j[k][2] / 1920, # * self.size_1,
bboxes_i_j[k][3] / 1088, # * self.size_2
]) # [x_min, y_min, x_max, y_max]
self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
"""
for i in test_ids:
bbox_dir = os.path.join(bbox_path, i)
bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4)) # TODO: we might want to cut it to 10 objects
for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)):
with open(os.path.join(bbox_dir, bbox), 'r') as fp:
bbox_content = fp.readlines()
fp.close()
for bbox_content_line in bbox_content:
bbox_content_values = bbox_content_line.split()
class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values)
bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))
print('OCR...\n')
self.ocr_graph_path = ocr_graph_path
if ocr_graph_path is not None:
ocr_graph = [
[15, [10, 4], [17, 2]],
[13, [16, 7], [18, 4]],
[11, [16, 4], [7, 10]],
[14, [10, 11], [7, 1]],
[12, [10, 9], [16, 3]],
[1, [7, 2], [9, 9], [10, 2]],
[5, [8, 8], [6, 8]],
[4, [9, 8], [7, 6]],
[3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
[2, [10, 1], [7, 7], [9, 3]],
[19, [10, 2], [26, 6]],
[20, [10, 7], [26, 5]],
[22, [25, 4], [10, 8]],
[23, [25, 15]],
[21, [16, 5], [24, 8]]
]
ocr_tensor = torch.zeros((27, 27))
for ocr in ocr_graph:
obj = ocr[0]
contexts = ocr[1:]
total_context_count = sum([i[1] for i in contexts])
for context in contexts:
ocr_tensor[obj, context[0]] = context[1] / total_context_count
ocr_tensor = torch.flatten(ocr_tensor)
for i in test_ids:
self.data['ocr_graph'].append(ocr_tensor)
self.frame_path = frame_path
self.transform = transforms.Compose([
#transforms.ToPILImage(),
transforms.Resize((self.size_1, self.size_2)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.data['frame_paths'])
def __getitem__(self, idx):
frames_path = self.data['frame_paths'][idx][0].split('/')
frames_path.insert(5, self.presaved)
frames_path = ''.join(f'{w}/' for w in frames_path[:-1])[:-1]+'.pkl'
with open(frames_path, 'rb') as f:
images = pickle.load(f)
images = torch.stack(images)
if self.patch_data:
images = rearrange(
images,
'(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
p1=self.spatial_patch_size,
p2=self.spatial_patch_size,
p3=self.temporal_patch_size
)
if self.pose_path is not None:
pose = self.data['poses'][idx]
if self.patch_data:
pose = rearrange(
pose,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
pose = None
if self.gaze_path is not None:
gaze = self.data['gazes'][idx]
if self.patch_data:
gaze = rearrange(
gaze,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
gaze = None
if self.bbox_path is not None:
bbox = self.data['bboxes'][idx]
if self.patch_data:
bbox = rearrange(
bbox,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
bbox = None
if self.ocr_graph_path is not None:
ocr_graphs = self.data['ocr_graph'][idx]
else:
ocr_graphs = None
return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs

240
boss/environment.yml Normal file

@@ -0,0 +1,240 @@
name: boss
channels:
- pyg
- anaconda
- conda-forge
- defaults
- pytorch
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- alembic=1.8.0=pyhd8ed1ab_0
- appdirs=1.4.4=pyh9f0ad1d_0
- asttokens=2.2.1=pyhd8ed1ab_0
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
- blas=1.0=mkl
- blosc=1.21.3=h6a678d5_0
- bottle=0.12.23=pyhd8ed1ab_0
- bottleneck=1.3.5=py38h7deecbd_0
- brotli=1.0.9=h5eee18b_7
- brotli-bin=1.0.9=h5eee18b_7
- brotlipy=0.7.0=py38h27cfd23_1003
- brunsli=0.1=h2531618_0
- bzip2=1.0.8=h7b6447c_0
- c-ares=1.18.1=h7f8727e_0
- ca-certificates=2023.01.10=h06a4308_0
- certifi=2023.5.7=py38h06a4308_0
- cffi=1.15.1=py38h74dc2b5_0
- cfitsio=3.470=h5893167_7
- charls=2.2.0=h2531618_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- click=8.1.3=unix_pyhd8ed1ab_2
- cloudpickle=2.0.0=pyhd3eb1b0_0
- cmaes=0.9.1=pyhd8ed1ab_0
- colorlog=6.7.0=py38h578d9bd_1
- contourpy=1.0.5=py38hdb19cb5_0
- cryptography=39.0.1=py38h9ce1e76_0
- cudatoolkit=11.3.1=h2bc3f7f_2
- cycler=0.11.0=pyhd3eb1b0_0
- cytoolz=0.12.0=py38h5eee18b_0
- dask-core=2022.7.0=py38h06a4308_0
- dbus=1.13.18=hb2f20db_0
- debugpy=1.5.1=py38h295c915_0
- decorator=5.1.1=pyhd8ed1ab_0
- docker-pycreds=0.4.0=py_0
- entrypoints=0.4=pyhd8ed1ab_0
- executing=1.2.0=pyhd8ed1ab_0
- expat=2.4.9=h6a678d5_0
- ffmpeg=4.3=hf484d3e_0
- fftw=3.3.9=h27cfd23_1
- flit-core=3.6.0=pyhd3eb1b0_0
- fontconfig=2.14.1=h52c9d5c_1
- fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.12.1=h4a9f257_0
- fsspec=2022.11.0=py38h06a4308_0
- giflib=5.2.1=h5eee18b_1
- gitdb=4.0.10=pyhd8ed1ab_0
- gitpython=3.1.31=pyhd8ed1ab_0
- glib=2.69.1=h4ff587b_1
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- greenlet=2.0.1=py38h6a678d5_0
- gst-plugins-base=1.14.1=h6a678d5_1
- gstreamer=1.14.1=h5eee18b_1
- icu=58.2=he6710b0_3
- idna=3.4=py38h06a4308_0
- imagecodecs=2021.8.26=py38hf0132c2_1
- imageio=2.19.3=py38h06a4308_0
- importlib-metadata=6.1.0=pyha770c72_0
- importlib_resources=5.2.0=pyhd3eb1b0_1
- intel-openmp=2021.4.0=h06a4308_3561
- ipykernel=6.15.0=pyh210e3f2_0
- ipython=8.11.0=pyh41d4057_0
- jedi=0.18.2=pyhd8ed1ab_0
- jinja2=3.1.2=py38h06a4308_0
- joblib=1.1.1=py38h06a4308_0
- jpeg=9e=h7f8727e_0
- jupyter_client=7.0.6=pyhd8ed1ab_0
- jupyter_core=4.12.0=py38h578d9bd_0
- jxrlib=1.1=h7b6447c_2
- kiwisolver=1.4.4=py38h6a678d5_0
- krb5=1.19.4=h568e23c_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libaec=1.0.4=he6710b0_1
- libbrotlicommon=1.0.9=h5eee18b_7
- libbrotlidec=1.0.9=h5eee18b_7
- libbrotlienc=1.0.9=h5eee18b_7
- libclang=10.0.1=default_hb85057a_2
- libcurl=7.87.0=h91b91d3_0
- libdeflate=1.8=h7f8727e_5
- libedit=3.1.20221030=h5eee18b_0
- libev=4.33=h7f8727e_1
- libevent=2.1.12=h8f2d780_0
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.16=h7f8727e_2
- libidn2=2.3.2=h7f8727e_0
- libllvm10=10.0.1=hbcb73fb_5
- libnghttp2=1.46.0=hce63b2e_0
- libpng=1.6.37=hbc83047_0
- libpq=12.9=h16c4e8d_3
- libprotobuf=3.20.3=he621ea3_0
- libsodium=1.0.18=h36c2ea0_1
- libssh2=1.10.0=h8f2d780_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.5.0=h6a678d5_1
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0
- libwebp=1.2.4=h11a3e52_0
- libwebp-base=1.2.4=h5eee18b_0
- libxcb=1.15=h7f8727e_0
- libxkbcommon=1.0.1=hfa300c1_0
- libxml2=2.9.14=h74e7548_0
- libxslt=1.1.35=h4e12654_0
- libzopfli=1.0.3=he6710b0_0
- locket=1.0.0=py38h06a4308_0
- lz4-c=1.9.4=h6a678d5_0
- mako=1.2.4=pyhd8ed1ab_0
- markupsafe=2.1.1=py38h7f8727e_0
- matplotlib=3.7.0=py38h06a4308_0
- matplotlib-base=3.7.0=py38h417a72b_0
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py38h7f8727e_0
- mkl_fft=1.3.1=py38hd3c417c_0
- mkl_random=1.2.2=py38h51133e4_0
- munkres=1.1.4=py_0
- natsort=7.1.1=pyhd3eb1b0_0
- ncurses=6.3=h5eee18b_3
- nest-asyncio=1.5.6=pyhd8ed1ab_0
- nettle=3.7.3=hbbd107a_1
- networkx=2.8.4=py38h06a4308_0
- nspr=4.33=h295c915_0
- nss=3.74=h0370c37_0
- numexpr=2.8.4=py38he184ba9_0
- numpy=1.23.5=py38h14f4228_0
- numpy-base=1.23.5=py38h31eccc5_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.4.0=h3ad879b_0
- openssl=1.1.1t=h7f8727e_0
- optuna=3.1.0=pyhd8ed1ab_0
- optuna-dashboard=0.9.0=pyhd8ed1ab_0
- packaging=22.0=py38h06a4308_0
- pandas=1.5.3=py38h417a72b_0
- parso=0.8.3=pyhd8ed1ab_0
- partd=1.2.0=pyhd3eb1b0_1
- pathtools=0.1.2=py_1
- pcre=8.45=h295c915_0
- pexpect=4.8.0=pyh1a96a4e_2
- pickleshare=0.7.5=py_1003
- pillow=9.4.0=py38h6a678d5_0
- pip=22.3.1=py38h06a4308_0
- ply=3.11=py38_0
- prompt-toolkit=3.0.38=pyha770c72_0
- prompt_toolkit=3.0.38=hd8ed1ab_0
- protobuf=3.20.3=py38h6a678d5_0
- psutil=5.9.0=py38h5eee18b_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pycparser=2.21=pyhd3eb1b0_0
- pyg=2.2.0=py38_torch_1.12.0_cu113
- pygments=2.14.0=pyhd8ed1ab_0
- pyopenssl=23.0.0=py38h06a4308_0
- pyparsing=3.0.9=py38h06a4308_0
- pyqt=5.15.7=py38h6a678d5_1
- pyqt5-sip=12.11.0=py38h6a678d5_1
- pysocks=1.7.1=py38h06a4308_0
- python=3.8.13=haa1d7c7_1
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python_abi=3.8=2_cp38
- pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0
- pytorch-cluster=1.6.0=py38_torch_1.12.0_cu113
- pytorch-mutex=1.0=cuda
- pytorch-scatter=2.1.0=py38_torch_1.12.0_cu113
- pytorch-sparse=0.6.16=py38_torch_1.12.0_cu113
- pytz=2022.7=py38h06a4308_0
- pywavelets=1.4.1=py38h5eee18b_0
- pyyaml=6.0=py38h5eee18b_1
- pyzmq=19.0.2=py38ha71036d_2
- qt-main=5.15.2=h327a75a_7
- qt-webengine=5.15.9=hd2b0992_4
- qtwebkit=5.212=h4eab89a_4
- readline=8.2=h5eee18b_0
- requests=2.28.1=py38h06a4308_1
- scikit-image=0.19.3=py38h6a678d5_1
- scikit-learn=1.2.1=py38h6a678d5_0
- scipy=1.9.3=py38h14f4228_0
- sentry-sdk=1.16.0=pyhd8ed1ab_0
- setproctitle=1.2.2=py38h0a891b7_2
- setuptools=65.6.3=py38h06a4308_0
- sip=6.6.2=py38h6a678d5_0
- six=1.16.0=pyhd3eb1b0_1
- smmap=3.0.5=pyh44b312d_0
- snappy=1.1.9=h295c915_0
- sqlalchemy=1.4.39=py38h5eee18b_0
- sqlite=3.40.1=h5082296_0
- stack_data=0.6.2=pyhd8ed1ab_0
- threadpoolctl=2.2.0=pyh0d69192_0
- tifffile=2021.7.2=pyhd3eb1b0_2
- tk=8.6.12=h1ccaba5_0
- toml=0.10.2=pyhd3eb1b0_0
- toolz=0.12.0=py38h06a4308_0
- torchaudio=0.12.0=py38_cu113
- torchvision=0.13.0=py38_cu113
- tornado=6.1=py38h0a891b7_3
- tqdm=4.64.1=py38h06a4308_0
- traitlets=5.9.0=pyhd8ed1ab_0
- typing_extensions=4.4.0=py38h06a4308_0
- tzdata=2022g=h04d1e81_0
- urllib3=1.26.14=py38h06a4308_0
- wandb=0.13.10=pyhd8ed1ab_0
- wcwidth=0.2.6=pyhd8ed1ab_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.10=h5eee18b_1
- yaml=0.2.5=h7b6447c_0
- zeromq=4.3.4=h9c3ff4c_1
- zfp=0.5.5=h295c915_6
- zipp=3.11.0=py38h06a4308_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.2=ha4553b6_0
- pip:
- einops==0.6.1
- lion-pytorch==0.1.2
- llvmlite==0.40.0
- memory-efficient-attention-pytorch==0.1.2
- numba==0.57.0
- opencv-python==4.7.0.72
- pynndescent==0.5.10
- seaborn==0.12.2
- umap-learn==0.5.3
- x-transformers==1.14.0
prefix: /opt/anaconda3/envs/boss

0
boss/models/__init__.py Normal file

230
boss/models/base.py Normal file

@@ -0,0 +1,230 @@
import torch
import torch.nn as nn
from .utils import pose_edge_index
from torch_geometric.nn import Sequential, GCNConv
from x_transformers import ContinuousTransformerWrapper, Decoder
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim),
nn.GELU(),
nn.Linear(dim, dim))
def forward(self, x):
return self.net(x)
class CNN(nn.Module):
def __init__(self, hidden_dim):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv3(x)
x = nn.functional.relu(x)
x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling
return x
class MindNetLSTM(nn.Module):
"""
Basic MindNet for model-based ToM, just LSTM on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetLSTM, self).__init__()
self.mods = mods
self.gaze_emb = nn.Linear(3, hidden_dim)
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.LSTM = PreNorm(
hidden_dim*len(mods),
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
lstm_inp = torch.cat(feats, 2)
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
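# c_n from the bidirectional LSTM has shape (num_directions, batch, hidden); averaging over
# directions and permuting yields a (batch, 1, hidden) summary that the ToM nets use as this agent's context.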
return self.act(self.proj(lstm_out)), c_n, feats
class MindNetSL(nn.Module):
"""
Basic MindNet for SL ToM, just LSTM on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetSL, self).__init__()
self.mods = mods
self.gaze_emb = nn.Linear(3, hidden_dim)
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.LSTM = PreNorm(
hidden_dim*5,
nn.LSTM(input_size=hidden_dim*5, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
)
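# NOTE: unlike MindNetLSTM above, the LSTM input width here is hard-coded for five
# modalities (hidden_dim*5) rather than derived from len(mods).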
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
lstm_inp = torch.cat(feats, 2)
lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
return self.act(self.proj(lstm_out)), feats
class MindNetTF(nn.Module):
"""
Basic MindNet for model-based ToM, Transformer on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetTF, self).__init__()
self.mods = mods
self.gaze_emb = nn.Linear(3, hidden_dim)
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.tf = ContinuousTransformerWrapper(
dim_in=hidden_dim*len(mods),
dim_out=hidden_dim,
max_seq_len=747,
attn_layers=Decoder(
dim=512,
depth=6,
heads=8
)
)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
tf_inp = torch.cat(feats, 2)
tf_out = self.tf(self.dropout(tf_inp))
return tf_out, feats
class MindNetLSTMXL(nn.Module):
"""
XL MindNet for model-based ToM: two-layer gaze/pose embeddings and a two-layer bidirectional LSTM on the input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetLSTMXL, self).__init__()
self.mods = mods
self.gaze_emb = nn.Sequential(
nn.Linear(3, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim)
)
self.pose_edge_index = pose_edge_index()
self.pose_emb = Sequential('x, edge_index', [
(GCNConv(3, hidden_dim), 'x, edge_index -> x'),
nn.GELU(),
(GCNConv(hidden_dim, hidden_dim), 'x, edge_index -> x'),
])
self.LSTM = PreNorm(
hidden_dim*len(mods),
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
)
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
lstm_inp = torch.cat(feats, 2)
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
return self.act(self.proj(lstm_out)), c_n, feats

249
boss/models/resnet.py Normal file

@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .utils import left_bias, right_bias
class ResNet(nn.Module):
def __init__(self, input_dim, device):
super(ResNet, self).__init__()
# Conv
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
# FFs
self.left = nn.Linear(input_dim, 27)
self.right = nn.Linear(input_dim, 27)
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
image_feats = []
for i in range(sequence_len):
images_i = images[:,i].to(self.device)
image_i_feat = self.resnet(images_i)
image_i_feat = image_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
image_i_feat = torch.cat([image_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
image_i_feat = torch.cat([image_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
image_i_feat = torch.cat([image_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
image_i_feat = torch.cat([image_i_feat, ocr_tensor_feat], 1)
image_feats.append(image_i_feat)
image_feats = torch.permute(torch.stack(image_feats), (1,0,2))
left_beliefs = self.left(self.dropout(image_feats))
right_beliefs = self.right(self.dropout(image_feats))
return left_beliefs, right_beliefs, None
class ResNetGRU(nn.Module):
def __init__(self, input_dim, device):
super(ResNetGRU, self).__init__()
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
self.gru = nn.GRU(input_dim, 512, batch_first=True)
for name, param in self.gru.named_parameters():
if "weight" in name:
nn.init.orthogonal_(param)
elif "bias" in name:
nn.init.constant_(param, 0)
# FFs
self.left = nn.Linear(512, 27)
self.right = nn.Linear(512, 27)
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.dropout = nn.Dropout(p=0.2)
self.relu = nn.ReLU()
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
rnn_inp = []
for i in range(sequence_len):
images_i = images[:,i]
rnn_i_feat = self.resnet(images_i)
rnn_i_feat = rnn_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
rnn_inp.append(rnn_i_feat)
rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
rnn_out, _ = self.gru(rnn_inp)
left_beliefs = self.left(self.dropout(rnn_out))
right_beliefs = self.right(self.dropout(rnn_out))
return left_beliefs, right_beliefs, None
class ResNetConv1D(nn.Module):
def __init__(self, input_dim, device):
super(ResNetConv1D, self).__init__()
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=5, padding=4)
# FFs
self.left = nn.Linear(512, 27)
self.right = nn.Linear(512, 27)
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
conv1d_inp = []
for i in range(sequence_len):
images_i = images[:,i]
images_i_feat = self.resnet(images_i)
images_i_feat = images_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
images_i_feat = torch.cat([images_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
images_i_feat = torch.cat([images_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
images_i_feat = torch.cat([images_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
images_i_feat = torch.cat([images_i_feat, ocr_tensor_feat], 1)
conv1d_inp.append(images_i_feat)
conv1d_inp = torch.permute(torch.stack(conv1d_inp), (1,2,0))
conv1d_out = self.conv1d(conv1d_inp)
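# kernel_size=5 with padding=4 pads both ends, so the output is 4 steps too long; trimming
# the last 4 steps below keeps the convolution causal (each step sees only current and past frames).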
conv1d_out = conv1d_out[:,:,:-4]
conv1d_out = self.relu(torch.permute(conv1d_out, (0,2,1)))
left_beliefs = self.left(self.dropout(conv1d_out))
right_beliefs = self.right(self.dropout(conv1d_out))
return left_beliefs, right_beliefs, None
class ResNetLSTM(nn.Module):
def __init__(self, input_dim, device):
super(ResNetLSTM, self).__init__()
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
self.lstm = nn.LSTM(input_size=input_dim, hidden_size=512, batch_first=True)
# FFs
self.left = nn.Linear(512, 27)
self.right = nn.Linear(512, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
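# Initializing the output biases to the log class priors makes the initial predictions
# roughly match the empirical belief-label distribution before any training.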
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
rnn_inp = []
for i in range(sequence_len):
images_i = images[:,i]
rnn_i_feat = self.resnet(images_i)
rnn_i_feat = rnn_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
rnn_inp.append(rnn_i_feat)
rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
rnn_out, _ = self.lstm(rnn_inp)
left_beliefs = self.left(self.dropout(rnn_out))
right_beliefs = self.right(self.dropout(rnn_out))
return left_beliefs, right_beliefs, None
if __name__ == '__main__':
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 150)
gazes = torch.ones(3, 22, 6)
bboxes = torch.ones(3, 22, 108)
model = ResNet(776, 'cpu')  # 512 (ResNet-34) + 150 (pose) + 6 (gaze) + 108 (bbox) input features
print(model(images, poses, gazes, bboxes, None)[0].shape)


@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
import torchvision.models as models
class SingleMindNet(nn.Module):
"""
Base ToM net. Supports any subset of modalities
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(SingleMindNet, self).__init__()
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net = MindNetLSTM(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
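# A single MindNet (shared weights) processes both participants; only the pose/gaze slice
# and the output head differ between left and right.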
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
if __name__ == "__main__":
def count_parameters(model):
import numpy as np
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = SingleMindNet(64, 'cpu', False, 0.5)
print(count_parameters(model))
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

104
boss/models/tom_base.py Normal file

@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
import numpy as np
class BaseToMnet(nn.Module):
"""
Base ToM net. Supports any subset of modalities
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(BaseToMnet, self).__init__()
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
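# One independent MindNet per participant (no weight sharing), in contrast to SingleMindNet,
# which reuses the same MindNet for both.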
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
def count_parameters(model):
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = BaseToMnet(64, 'cpu', False, 0.5)
print(count_parameters(model))
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)


@@ -0,0 +1,269 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM, MindNetLSTMXL
from memory_efficient_attention_pytorch import Attention
class CommonMindToMnet(nn.Module):
"""
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(CommonMindToMnet, self).__init__()
self.aggr = aggr
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
#for param in self.cnn.parameters():
# param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.ln_left = nn.LayerNorm(hidden_dim)
self.ln_right = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
common_mind = self.proj(torch.cat([cell_left, cell_right], -1))
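# 'Common mind': the two agents' LSTM cell-state summaries are concatenated and projected
# back to hidden_dim; this shared vector is then fused with each agent's per-frame features
# via the chosen aggregation (sum, mult, attn, or concat).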
if self.aggr == 'no_tom':
return self.left(out_left), self.right(out_right)
if self.aggr == 'attn':
l = self.attn_left(x=out_left, context=common_mind)
r = self.attn_right(x=out_right, context=common_mind)
elif self.aggr == 'mult':
l = out_left * common_mind
r = out_right * common_mind
elif self.aggr == 'sum':
l = out_left + common_mind
r = out_right + common_mind
elif self.aggr == 'concat':
l = torch.cat([out_left, common_mind], 1)
r = torch.cat([out_right, common_mind], 1)
else: raise ValueError
l = self.act(l)
l = self.ln_left(l)
r = self.act(r)
r = self.ln_right(r)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
left_beliefs = self.left(l)
right_beliefs = self.right(r)
if self.aggr == 'concat':
left_beliefs = self.left(l)[:, :-1, :]
right_beliefs = self.right(r)[:, :-1, :]
return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
class CommonMindToMnetXL(nn.Module):
"""
XL model.
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(CommonMindToMnetXL, self).__init__()
self.aggr = aggr
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTMXL(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTMXL(hidden_dim, dropout, mods)
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.ln_left = nn.LayerNorm(hidden_dim)
self.ln_right = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.left = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, 27),
)
self.right = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, 27)
)
self.left[-1].bias.data = torch.tensor(left_bias).log()
self.right[-1].bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
common_mind = self.proj(torch.cat([cell_left, cell_right], -1))
if self.aggr == 'no_tom':
return self.left(out_left), self.right(out_right)
if self.aggr == 'attn':
l = self.attn_left(x=out_left, context=common_mind)
r = self.attn_right(x=out_right, context=common_mind)
elif self.aggr == 'mult':
l = out_left * common_mind
r = out_right * common_mind
elif self.aggr == 'sum':
l = out_left + common_mind
r = out_right + common_mind
elif self.aggr == 'concat':
l = torch.cat([out_left, common_mind], 1)
r = torch.cat([out_right, common_mind], 1)
else: raise ValueError
l = self.act(l)
l = self.ln_left(l)
r = self.act(r)
r = self.ln_right(r)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
left_beliefs = self.left(l)
right_beliefs = self.right(r)
if self.aggr == 'concat':
left_beliefs = self.left(l)[:, :-1, :]
right_beliefs = self.right(r)[:, :-1, :]
return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = CommonMindToMnetXL(64, 'cpu', False, 0.5, aggr='attn')
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

144
boss/models/tom_implicit.py Normal file

@@ -0,0 +1,144 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention
class ImplicitToMnet(nn.Module):
"""
Implicit ToM net. Supports any subset of modalities
Possible aggregations: sum, mult, attn, concat
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(ImplicitToMnet, self).__init__()
self.aggr = aggr
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
self.ln_left = nn.LayerNorm(hidden_dim)
self.ln_right = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
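# Implicit communication: each agent's hidden states are combined with the *other* agent's final
# LSTM cell state below, with no shared common-ground module.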
if self.aggr == 'no_tom':
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
if self.aggr == 'attn':
l = self.attn_left(x=out_left, context=cell_right)
r = self.attn_right(x=out_right, context=cell_left)
elif self.aggr == 'mult':
l = out_left * cell_right
r = out_right * cell_left
elif self.aggr == 'sum':
l = out_left + cell_right
r = out_right + cell_left
elif self.aggr == 'concat':
l = torch.cat([out_left, cell_right], 1)
r = torch.cat([out_right, cell_left], 1)
else: raise ValueError
l = self.act(l)
l = self.ln_left(l)
r = self.act(r)
r = self.ln_right(r)
if self.aggr in ('mult', 'sum', 'attn'):
left_beliefs = self.left(l)
right_beliefs = self.right(r)
if self.aggr == 'concat':
left_beliefs = self.left(l)[:, :-1, :]
right_beliefs = self.right(r)[:, :-1, :]
return left_beliefs, right_beliefs, [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = ImplicitToMnet(64, 'cpu', False, 0.5, aggr='attn')
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

98
boss/models/tom_sl.py Normal file
View file

@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph, pose_edge_index
from .base import CNN, MindNetSL
class SLToMnet(nn.Module):
"""
Speaker-Listener ToMnet
"""
def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(SLToMnet, self).__init__()
self.tom_weight = tom_weight
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetSL(hidden_dim, dropout, mods)
self.mind_net_right = MindNetSL(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
left_logits = self.left(out_left)
out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
right_logits = self.right(out_right)
left_ranking = torch.log_softmax(left_logits, dim=-1)
right_ranking = torch.log_softmax(right_logits, dim=-1)
right_beliefs = left_ranking + self.tom_weight * right_ranking
left_beliefs = right_ranking + self.tom_weight * left_ranking
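# Each agent's belief combines its own log-ranking (scaled by tom_weight) with the other agent's
# log-ranking; train.py applies NLLLoss directly to these log-scores.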
return left_beliefs, right_beliefs, [out_left, out_right] + feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = SLToMnet(64, 'cpu', 2.0, False, 0.5)
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

104
boss/models/tom_tf.py Normal file
View file

@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetTF
from memory_efficient_attention_pytorch import Attention
class TFToMnet(nn.Module):
"""
ToM net with transformer-based mind nets (MindNetTF). Supports any subset of modalities.
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(TFToMnet, self).__init__()
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetTF(hidden_dim, dropout, mods)
self.mind_net_right = MindNetTF(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
l = self.dropout(self.act(out_left))
r = self.dropout(self.act(out_right))
left_beliefs = self.left(l)
right_beliefs = self.right(r)
return left_beliefs, right_beliefs, feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = TFToMnet(64, 'cpu', False, 0.5)
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

95
boss/models/utils.py Normal file
View file

@ -0,0 +1,95 @@
import torch
left_bias = [0.10659290976303822,
0.025158905348262015,
0.02811095449589107,
0.026342384511050237,
0.025318475572458178,
0.02283183957873461,
0.021581872822531316,
0.08062285577511237,
0.03824366373234754,
0.04853594319300018,
0.09653998563867983,
0.02961357410707162,
0.02961357410707162,
0.03172787957767081,
0.029985904630196004,
0.02897529321028696,
0.06602218026116327,
0.015345336560197867,
0.026900880295736816,
0.024879657455918726,
0.028669450280577644,
0.01936118720246802,
0.02341693040078721,
0.014707055663413206,
0.027007260445200926,
0.04146166325363687,
0.04243238211749688]
right_bias = [0.13147256721895695,
0.012433179968617855,
0.01623627031195979,
0.013683146724821148,
0.015252253929416771,
0.012579452674131008,
0.03127576394244834,
0.10325523257360177,
0.041155820323927554,
0.06563655221935587,
0.12684503071726816,
0.016156485199861705,
0.0176989973670913,
0.020238823435546928,
0.01918831945958884,
0.01791175766601952,
0.08768383819579266,
0.019002154198026647,
0.029600276588388607,
0.01578415467673732,
0.0176989973670913,
0.011834791627882237,
0.014919815962341426,
0.007552990611951809,
0.029759846812584773,
0.04981250498656951,
0.05533097524002021]
def build_ocr_graph(device):
ocr_graph = [
[15, [10, 4], [17, 2]],
[13, [16, 7], [18, 4]],
[11, [16, 4], [7, 10]],
[14, [10, 11], [7, 1]],
[12, [10, 9], [16, 3]],
[1, [7, 2], [9, 9], [10, 2]],
[5, [8, 8], [6, 8]],
[4, [9, 8], [7, 6]],
[3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
[2, [10, 1], [7, 7], [9, 3]],
[19, [10, 2], [26, 6]],
[20, [10, 7], [26, 5]],
[22, [25, 4], [10, 8]],
[23, [25, 15]],
[21, [16, 5], [24, 8]]
]
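# Each entry above appears to encode [object id, [co-occurring object id, count], ...]. Below, these are
# unrolled into a directed edge list with the counts as scalar edge attributes, and node features are
# one-hot vectors over the 27 object classes.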
edge_index = []
edge_attr = []
for i in range(len(ocr_graph)):
for j in range(1, len(ocr_graph[i])):
source_node = ocr_graph[i][0]
target_node = ocr_graph[i][j][0]
edge_index.append([source_node, target_node])
edge_attr.append(ocr_graph[i][j][1])
ocr_edge_index = torch.tensor(edge_index).t().long()
ocr_edge_attr = torch.tensor(edge_attr).to(torch.float).unsqueeze(1)
x = torch.arange(0, 27)
ocr_x = torch.nn.functional.one_hot(x, num_classes=27).to(torch.float)
return ocr_x.to(device), ocr_edge_index.to(device), ocr_edge_attr.to(device)
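# Skeleton connectivity for the pose graph: bidirectional edges over 25 body keypoints, presumably
# following the OpenPose BODY_25 layout (matching the 75-dim pose input, i.e. 25 triplets).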
def pose_edge_index():
return torch.tensor(
[[17, 15, 15, 0, 0, 16, 16, 18, 0, 1, 4, 3, 3, 2, 2, 1, 1, 5, 5, 6, 6, 7, 1, 8, 8, 9, 9, 10, 10, 11, 11, 24, 11, 23, 23, 22, 8, 12, 12, 13, 13, 14, 14, 21, 14, 19, 19, 20],
[15, 17, 0, 15, 16, 0, 18, 16, 1, 0, 3, 4, 2, 3, 1, 2, 5, 1, 6, 5, 7, 6, 8, 1, 9, 8, 10, 9, 11, 10, 24, 11, 23, 11, 22, 23, 12, 8, 13, 12, 14, 13, 21, 14, 19, 14, 20, 19]],
dtype=torch.long)

BIN
boss/outfile Normal file

Binary file not shown.

View file

@ -0,0 +1,82 @@
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme(style='whitegrid')
COLORS = sns.color_palette()
MTOM_COLORS = {
"MN1": (110/255, 117/255, 161/255),
"MN2": (179/255, 106/255, 98/255),
"Base": (193/255, 198/255, 208/255),
"CG": (170/255, 129/255, 42/255),
"IC": (97/255, 112/255, 83/255),
"DB": (144/255, 63/255, 110/255)
}
abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json'))
abl_old_bbox_mean = [0.539406718, 0.5348262324, 0.529845863]
abl_old_bbox_std = [0.03639921819, 0.01519544901, 0.01718265794]
filename = 'results/abl_cm_concat_old_vs_new_bbox'
#def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
means = []
stds = []
for key, values in abl_tom_cm_concat.items():
mean = np.mean(values)
std = np.std(values)
means.append(mean)
stds.append(std)
fig, ax = plt.subplots(figsize=(10,4))
x = np.arange(len(abl_tom_cm_concat))
width = 0.6
rects1 = ax.bar(
x, means, width, label='New bbox', yerr=stds,
capsize=5,
color=[MTOM_COLORS['CG']]*8,
edgecolor='black',
linewidth=1.5,
alpha=0.6
)
rects2 = ax.bar(
[0, 2, 5], abl_old_bbox_mean, width, label='Original bbox', yerr=abl_old_bbox_std,
capsize=5,
color=[COLORS[9]]*8,
edgecolor='black',
linewidth=1.5,
alpha=0.6
)
ax.set_ylabel('Accuracy', fontsize=18)
xticklabels = list(abl_tom_cm_concat.keys())
ax.set_xticks(np.arange(len(xticklabels)))
ax.set_xticklabels(xticklabels, rotation=0, fontsize=16)
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
# Add value labels above each bar, avoiding overlapping with error bars
for rect, std in zip(rects1, stds):
height = rect.get_height()
if height + std < ax.get_ylim()[1]:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
else:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
for rect, std in zip(rects2, abl_old_bbox_std):
height = rect.get_height()
if height + std < ax.get_ylim()[1]:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
else:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
# Remove spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.legend(fontsize=14)
ax.grid(axis='x')
plt.tight_layout()
plt.savefig(f'{filename}.pdf', bbox_inches='tight')

82
boss/plots/pca.py Normal file
View file

@ -0,0 +1,82 @@
import torch
import os
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-22_12-00-38_train_None" # no_tom seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-39-41_train_None" # impl mult seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-47-07_train_None" # impl sum seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_15-58-44_train_None" # impl attn seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-58-04_train_None" # impl concat seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-40-01_train_None" # cm mult seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-45-55_train_None" # cm sum seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-50-42_train_None" # cm attn seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-57-15_train_None" # cm concat seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-37-50_train_None" # db seed 1
print(FOLDER_PATH)
MTOM_COLORS = {
"MN1": (110/255, 117/255, 161/255),
"MN2": (179/255, 106/255, 98/255),
"Base": (193/255, 198/255, 208/255),
"CG": (170/255, 129/255, 42/255),
"IC": (97/255, 112/255, 83/255),
"DB": (144/255, 63/255, 110/255)
}
sns.set_theme(style='white')
for i in range(60):
print(f'Computing analysis for test video {i}...', end='\r')
emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
data = torch.load(emb_file)
if len(data) == 14: # implicit
model = 'impl'
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
out_left = out_left.squeeze(0)
cell_left = cell_left.squeeze(0)
out_right = out_right.squeeze(0)
cell_right = cell_right.squeeze(0)
elif len(data) == 13: # common mind
model = 'cm'
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
common_mind = common_mind.squeeze(0)
elif len(data) == 12: # speaker-listener
model = 'sl'
out_left, out_right, feats = data[0], data[1], data[2:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
# ====== PCA for left and right embeddings ====== #
out_left_and_right = np.concatenate((out_left, out_right), axis=0)
pca = PCA(n_components=2)
pca_result = pca.fit_transform(out_left_and_right)
# Separate the PCA results for each tensor
pca_result_left = pca_result[:out_left.shape[0]]
pca_result_right = pca_result[out_right.shape[0]:]
plt.figure(figsize=(7,6))
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
plt.xlabel('Principal Component 1', fontsize=30)
plt.ylabel('Principal Component 2', fontsize=30)
plt.grid(False)
plt.legend(fontsize=30)
plt.tight_layout()
plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
plt.close()

View file

@ -0,0 +1,42 @@
{
"all": [
0.7402937327,
0.7171004799,
0.7305511124
],
"rgb\npose\ngaze": [
0.4736803839,
0.4658281227,
0.4948378653
],
"rgb\nocr\nbbox": [
0.7364403083,
0.7407663225,
0.7321506471
],
"rgb\ngaze": [
0.5013814163,
0.418859968,
0.4644103534
],
"rgb\npose": [
0.4545586738,
0.4584484514,
0.4525229024
],
"rgb\nbbox": [
0.716591537,
0.7334593573,
0.758324851
],
"rgb\nocr": [
0.4581939799,
0.456449033,
0.4577577432
],
"rgb": [
0.4896393776,
0.4907299695,
0.4583757452
]
}


72
boss/results/all.json Normal file
View file

@ -0,0 +1,72 @@
{
"no_tom": [
0.6504653192,
0.6404682274,
0.6666787844
],
"sl": [
0.6584993456,
0.6608259415,
0.6584629926
],
"cm mult": [
0.6739130435,
0.6872182638,
0.6591173477
],
"cm sum": [
0.5820852116,
0.6401774029,
0.6001163298
],
"cm attn": [
0.3240511851,
0.2688672386,
0.3028573506
],
"cm concat": [
0.7402937327,
0.7171004799,
0.7305511124
],
"impl mult": [
0.7119746983,
0.7019776065,
0.7244437982
],
"impl sum": [
0.6542460375,
0.6642431293,
0.6678420823
],
"impl attn": [
0.3311763851,
0.3125636179,
0.3112549077
],
"impl concat": [
0.6957975862,
0.7227352043,
0.6971062964
],
"resnet": [
0.467173186,
0.4267485822,
0.4195870292
],
"gru": [
0.5724152974,
0.525628908,
0.4825141777
],
"lstm": [
0.6210193398,
0.5547477098,
0.6572633416
],
"conv1d": [
0.4478333576,
0.4114802966,
0.3853060928
]
}

BIN
boss/results/all.pdf Normal file

Binary file not shown.

181
boss/test.py Normal file
View file

@ -0,0 +1,181 @@
import argparse
import torch
from tqdm import tqdm
import csv
import os
from torch.utils.data import DataLoader
from dataloader import DataTest
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
from models.tom_implicit import ImplicitToMnet
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
from models.tom_sl import SLToMnet
from models.single_mindnet import SingleMindNet
from utils import tried_once, tried_twice, tried_thrice, friends, strangers, get_classification_accuracy, pad_collate, get_input_dim
def test(args):
if args.test_frames == 'friends':
test_frame_ids = friends
elif args.test_frames == 'strangers':
test_frame_ids = strangers
elif args.test_frames == 'once':
test_frame_ids = tried_once
elif args.test_frames == 'twice_thrice':
test_frame_ids = tried_twice + tried_thrice
elif args.test_frames is None:
test_frame_ids = None
else: raise NameError
if args.median is not None:
median = (240, False)
else:
median = None
if args.model_type in ('tom_cm', 'tom_sl', 'tom_impl', 'tom_cm_xl', 'tom_single'):
flatten_dim = 2
else:
flatten_dim = 1
# load datasets
test_dataset = DataTest(
args.test_frame_path,
args.label_path,
args.test_pose_path,
args.test_gaze_path,
args.test_bbox_path,
args.ocr_graph_path,
args.presaved,
test_frame_ids,
median,
flatten_dim=flatten_dim
)
test_dataloader = DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=args.num_workers,
collate_fn=pad_collate,
pin_memory=args.pin_memory
)
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
assert args.load_model_path is not None
if args.model_type == 'resnet':
inp_dim = get_input_dim(args.mods)
model = ResNet(inp_dim, device).to(device)
elif args.model_type == 'gru':
inp_dim = get_input_dim(args.mods)
model = ResNetGRU(inp_dim, device).to(device)
elif args.model_type == 'lstm':
inp_dim = get_input_dim(args.mods)
model = ResNetLSTM(inp_dim, device).to(device)
elif args.model_type == 'conv1d':
inp_dim = get_input_dim(args.mods)
model = ResNetConv1D(inp_dim, device).to(device)
elif args.model_type == 'tom_cm':
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_impl':
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_sl':
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_single':
model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_cm_xl':
model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
else: raise NotImplementedError
model.load_state_dict(torch.load(args.load_model_path, map_location=device))
model.device = device
model.eval()
if args.save_preds:
# Define the output file path
folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}_{args.test_frames}'
if not os.path.exists(folder_path):
os.makedirs(folder_path)
print(f'Saving predictions in {folder_path}.')
print('Testing...')
num_correct = 0
cnt = 0
with torch.no_grad():
for j, batch in tqdm(enumerate(test_dataloader)):
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
if frames is not None: frames = frames.to(device, non_blocking=True)
if poses is not None: poses = poses.to(device, non_blocking=True)
if gazes is not None: gazes = gazes.to(device, non_blocking=True)
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=True)
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=True)
pred_left_labels, pred_right_labels, repr = model(frames, poses, gazes, bboxes, ocr_graphs)
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
labels = torch.reshape(labels, (-1, 2)).to(device)
batch_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
pred_left_labels, pred_right_labels, labels, sequence_lengths
)
cnt += batch_num_pred
num_correct += batch_num_correct
if args.save_preds:
torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt"))
data = [(
i,
torch.argmax(pred_left_labels[i]).cpu().numpy(),
torch.argmax(pred_right_labels[i]).cpu().numpy(),
labels[:, 0][i].cpu().numpy(),
labels[:, 1][i].cpu().numpy()) for i in range(len(labels))
]
header = ['frame', 'left_pred', 'right_pred', 'left_label', 'right_label']
with open(os.path.join(folder_path, f'{j}_{batch_acc:.2f}.csv'), mode='w', newline='') as file:
writer = csv.writer(file)
writer.writerow(header) # Write the header row
writer.writerows(data) # Write the data rows
test_acc = num_correct / cnt
print("Test accuracy: {}".format(num_correct / cnt))
with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f:
f.write(str(test_acc))
f.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Define the command-line arguments
parser.add_argument('--gpu_id', type=int)
parser.add_argument('--presaved', type=int, default=128)
parser.add_argument('--non_blocking', action='store_true')
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--pin_memory', action='store_true')
parser.add_argument('--model_type', type=str)
parser.add_argument('--aggr', type=str, default='mult', required=False)
parser.add_argument('--use_resnet', action='store_true')
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
parser.add_argument('--test_frame_path', type=str, default='/scratch/bortoletto/data/boss/test/frame')
parser.add_argument('--test_pose_path', type=str, default='/scratch/bortoletto/data/boss/test/pose')
parser.add_argument('--test_gaze_path', type=str, default='/scratch/bortoletto/data/boss/test/gaze')
parser.add_argument('--test_bbox_path', type=str, default='/scratch/bortoletto/data/boss/test/new_bbox/labels')
parser.add_argument('--ocr_graph_path', type=str, default='')
parser.add_argument('--label_path', type=str, default='outfile')
parser.add_argument('--save_path', type=str, default='experiments/')
parser.add_argument('--test_frames', type=str, default=None)
parser.add_argument('--median', type=int, default=None)
parser.add_argument('--load_model_path', type=str)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--save_preds', action='store_true')
# Parse the command-line arguments
args = parser.parse_args()
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
if not args.aggr:
parser.error("The choosen --model_type requires --aggr")
if args.model_type == 'tom_sl' and not args.tom_weight:
parser.error("The choosen --model_type requires --tom_weight")
test(args)

324
boss/train.py Normal file
View file

@ -0,0 +1,324 @@
import torch
import os
import argparse
import numpy as np
import random
import datetime
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from dataloader import Data
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
from models.tom_implicit import ImplicitToMnet
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
from models.tom_sl import SLToMnet
from models.tom_tf import TFToMnet
from models.single_mindnet import SingleMindNet
from utils import pad_collate, get_classification_accuracy, mixup, get_classification_accuracy_mixup, count_parameters, get_input_dim
def train(args):
if args.model_type in ('tom_cm', 'tom_sl', 'tom_impl', 'tom_tf', 'tom_cm_xl', 'tom_single'):
flatten_dim = 2
else:
flatten_dim = 1
train_dataset = Data(
args.train_frame_path,
args.label_path,
args.train_pose_path,
args.train_gaze_path,
args.train_bbox_path,
args.ocr_graph_path,
presaved=args.presaved,
flatten_dim=flatten_dim
)
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
collate_fn=pad_collate,
pin_memory=args.pin_memory
)
val_dataset = Data(
args.val_frame_path,
args.label_path,
args.val_pose_path,
args.val_gaze_path,
args.val_bbox_path,
args.ocr_graph_path,
presaved=args.presaved,
flatten_dim=flatten_dim
)
val_dataloader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=pad_collate,
pin_memory=args.pin_memory
)
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
if args.model_type == 'resnet':
inp_dim = get_input_dim(args.mods)
model = ResNet(inp_dim, device).to(device)
elif args.model_type == 'gru':
inp_dim = get_input_dim(args.mods)
model = ResNetGRU(inp_dim, device).to(device)
elif args.model_type == 'lstm':
inp_dim = get_input_dim(args.mods)
model = ResNetLSTM(inp_dim, device).to(device)
elif args.model_type == 'conv1d':
inp_dim = get_input_dim(args.mods)
model = ResNetConv1D(inp_dim, device).to(device)
elif args.model_type == 'tom_cm':
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_cm_xl':
model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_impl':
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_sl':
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_single':
model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_tf':
model = TFToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
else: raise NotImplementedError
if args.resume_from_checkpoint is not None:
model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
if args.scheduler is None:
scheduler = None
else:
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
if args.model_type == 'tom_sl': cross_entropy_loss = nn.NLLLoss(ignore_index=-1)
else: cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing, ignore_index=-1).to(device)
stats = {'train': {'cls_loss': [], 'cls_acc': []}, 'val': {'cls_loss': [], 'cls_acc': []}}
max_val_classification_acc = 0
max_val_classification_epoch = None
counter = 0
print(f'Number of parameters: {count_parameters(model)}')
for i in range(args.num_epoch):
# training
print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
temp_train_classification_loss = []
epoch_num_correct = 0
epoch_cnt = 0
model.train()
for j, batch in tqdm(enumerate(train_dataloader)):
if args.use_mixup:
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = mixup(batch, args.mixup_alpha, 27)
else:
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)
pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
if args.use_mixup:
labels = torch.reshape(labels, (-1, 54)).to(device)
batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy_mixup(
pred_left_labels, pred_right_labels, labels, sequence_lengths
)
loss = cross_entropy_loss(pred_left_labels, labels[:, :27]) + cross_entropy_loss(pred_right_labels, labels[:, 27:])
else:
labels = torch.reshape(labels, (-1, 2)).to(device)
batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
pred_left_labels, pred_right_labels, labels, sequence_lengths
)
loss = cross_entropy_loss(pred_left_labels, labels[:, 0]) + cross_entropy_loss(pred_right_labels, labels[:, 1])
epoch_cnt += batch_num_pred
epoch_num_correct += batch_num_correct
temp_train_classification_loss.append(loss.data.item() * batch_num_pred / 2)
optimizer.zero_grad()
loss.backward()
# Clip after backward so the clipping acts on the freshly computed gradients
if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
optimizer.step()
if args.logger: wandb.log({'batch_train_acc': batch_train_acc, 'batch_train_loss': loss.data.item(), 'lr': optimizer.param_groups[-1]['lr']})
print("Epoch {}/{} batch {}/{} training done with cls loss={}, cls accuracy={}.".format(
i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item(), batch_train_acc)
)
if scheduler: scheduler.step()
print("Epoch {}/{} OVERALL train cls loss={}, cls accuracy={}.\n".format(
i+1, args.num_epoch, sum(temp_train_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
)
stats['train']['cls_loss'].append(sum(temp_train_classification_loss) * 2 / epoch_cnt)
stats['train']['cls_acc'].append(epoch_num_correct / epoch_cnt)
# validation
print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
temp_val_classification_loss = []
epoch_num_correct = 0
epoch_cnt = 0
model.eval()
with torch.no_grad():
for j, batch in tqdm(enumerate(val_dataloader)):
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)
pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
labels = torch.reshape(labels, (-1, 2)).to(device)
batch_val_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
pred_left_labels, pred_right_labels, labels, sequence_lengths
)
epoch_cnt += batch_num_pred
epoch_num_correct += batch_num_correct
loss = cross_entropy_loss(pred_left_labels, labels[:,0]) + cross_entropy_loss(pred_right_labels, labels[:,1])
temp_val_classification_loss.append(loss.data.item() * batch_num_pred / 2)
if args.logger: wandb.log({'batch_val_acc': batch_val_acc, 'batch_val_loss': loss.data.item()})
print("Epoch {}/{} batch {}/{} validation done with cls loss={}, cls accuracy={}.".format(
i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item(), batch_val_acc)
)
print("Epoch {}/{} OVERALL validation cls loss={}, cls accuracy={}.\n".format(
i+1, args.num_epoch, sum(temp_val_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
)
cls_loss = sum(temp_val_classification_loss) * 2 / epoch_cnt
cls_acc = epoch_num_correct / epoch_cnt
stats['val']['cls_loss'].append(cls_loss)
stats['val']['cls_acc'].append(cls_acc)
if args.logger: wandb.log({'cls_loss': cls_loss, 'cls_acc': cls_acc, 'epoch': i})
# check for best stat/model using validation accuracy
if stats['val']['cls_acc'][-1] >= max_val_classification_acc:
max_val_classification_acc = stats['val']['cls_acc'][-1]
max_val_classification_epoch = i+1
torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
counter = 0
else:
counter += 1
print(f'EarlyStopping counter: {counter} out of {args.patience}.')
if counter >= args.patience:
break
with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
f.write('{}\n'.format(CFG))
f.write('{}\n'.format(stats))
f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_classification_acc))
f.close()
print(f'Results saved in {experiment_save_path}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Define the command-line arguments
parser.add_argument('--gpu_id', type=int)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--logger', action='store_true')
parser.add_argument('--presaved', type=int, default=128)
parser.add_argument('--clip_grad_norm', type=float, default=0.5)
parser.add_argument('--use_mixup', action='store_true')
parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
parser.add_argument('--non_blocking', action='store_true')
parser.add_argument('--patience', type=int, default=99)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--pin_memory', action='store_true')
parser.add_argument('--num_epoch', type=int, default=300)
parser.add_argument('--lr', type=float, default=4e-4)
parser.add_argument('--scheduler', type=str, default=None)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--weight_decay', type=float, default=0.005)
parser.add_argument('--label_smoothing', type=float, default=0.1)
parser.add_argument('--model_type', type=str)
parser.add_argument('--aggr', type=str, default='mult', required=False)
parser.add_argument('--use_resnet', action='store_true')
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
parser.add_argument('--train_frame_path', type=str, default='/scratch/bortoletto/data/boss/train/frame')
parser.add_argument('--train_pose_path', type=str, default='/scratch/bortoletto/data/boss/train/pose')
parser.add_argument('--train_gaze_path', type=str, default='/scratch/bortoletto/data/boss/train/gaze')
parser.add_argument('--train_bbox_path', type=str, default='/scratch/bortoletto/data/boss/train/new_bbox/labels')
parser.add_argument('--val_frame_path', type=str, default='/scratch/bortoletto/data/boss/val/frame')
parser.add_argument('--val_pose_path', type=str, default='/scratch/bortoletto/data/boss/val/pose')
parser.add_argument('--val_gaze_path', type=str, default='/scratch/bortoletto/data/boss/val/gaze')
parser.add_argument('--val_bbox_path', type=str, default='/scratch/bortoletto/data/boss/val/new_bbox/labels')
parser.add_argument('--ocr_graph_path', type=str, default='')
parser.add_argument('--label_path', type=str, default='outfile')
parser.add_argument('--save_path', type=str, default='experiments/')
parser.add_argument('--resume_from_checkpoint', type=str, default=None)
# Parse the command-line arguments
args = parser.parse_args()
if args.use_mixup and not args.mixup_alpha:
parser.error("--use_mixup requires --mixup_alpha")
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
if not args.aggr:
parser.error("The choosen --model_type requires --aggr")
if args.model_type == 'tom_sl' and not args.tom_weight:
parser.error("The choosen --model_type requires --tom_weight")
# get experiment ID
experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
if not os.path.exists(args.save_path):
os.makedirs(args.save_path, exist_ok=True)
experiment_save_path = os.path.join(args.save_path, experiment_id)
os.makedirs(experiment_save_path, exist_ok=True)
CFG = {
'use_ocr_custom_loss': 0,
'presaved': args.presaved,
'batch_size': args.batch_size,
'num_epoch': args.num_epoch,
'lr': args.lr,
'scheduler': args.scheduler,
'weight_decay': args.weight_decay,
'model_type': args.model_type,
'use_resnet': args.use_resnet,
'hidden_dim': args.hidden_dim,
'tom_weight': args.tom_weight,
'dropout': args.dropout,
'label_smoothing': args.label_smoothing,
'clip_grad_norm': args.clip_grad_norm,
'use_mixup': args.use_mixup,
'mixup_alpha': args.mixup_alpha,
'non_blocking_tensors': args.non_blocking,
'patience': args.patience,
'pin_memory': args.pin_memory,
'resume_from_checkpoint': args.resume_from_checkpoint,
'aggr': args.aggr,
'mods': args.mods,
'save_path': experiment_save_path ,
'seed': args.seed
}
print(CFG)
print(f'Saving results in {experiment_save_path}')
# set seed values
if args.logger:
wandb.init(project="boss", config=CFG)
os.environ['PYTHONHASHSEED'] = str(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
train(args)

769
boss/utils.py Normal file
View file

@ -0,0 +1,769 @@
import os
import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
#from umap import UMAP
import json
import seaborn as sns
import pandas as pd
from natsort import natsorted
import argparse
import seaborn as sns
COLORS_DM = {
"red": (242/256, 165/256, 179/256),
"blue": (195/256, 219/256, 252/256),
"green": (156/256, 228/256, 213/256),
"yellow": (250/256, 236/256, 144/256),
"violet": (207/256, 187/256, 244/256),
"orange": (244/256, 188/256, 154/256)
}
MTOM_COLORS = {
"MN1": (110/255, 117/255, 161/255),
"MN2": (179/255, 106/255, 98/255),
"Base": (193/255, 198/255, 208/255),
"CG": (170/255, 129/255, 42/255),
"IC": (97/255, 112/255, 83/255),
"DB": (144/255, 63/255, 110/255)
}
COLORS = sns.color_palette()
OBJECTS = [
'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
]
tried_once = [2, 4, 5, 6, 7, 8, 9, 10, 12, 15, 18, 19, 20, 21, 22, 23, 24,
25, 26, 27, 28, 29, 30, 32, 38, 40, 41, 42, 44, 47, 50, 51, 52, 53, 55,
56, 57, 61, 63, 65, 67, 68, 70, 71, 72, 73, 74, 76, 80, 81, 83, 85, 87,
88, 89, 90, 92, 93, 96, 97, 99, 101, 102, 105, 106, 108, 110, 111, 112,
113, 114, 116, 118, 121, 123, 125, 131, 132, 134, 135, 140, 142, 143,
145, 146, 148, 149, 151, 152, 154, 155, 156, 157, 160, 161, 162, 165,
169, 170, 171, 173, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185,
186, 187, 190, 191, 194, 196, 203, 204, 206, 207, 208, 209, 210, 211,
213, 214, 216, 218, 219, 220, 222, 225, 227, 228, 229, 232, 233, 235,
236, 237, 238, 239, 242, 243, 246, 247, 249, 251, 252, 254, 255, 256,
257, 260, 261, 262, 263, 265, 266, 268, 270, 272, 275, 277, 278, 279,
281, 282, 287, 290, 296, 298]
tried_twice = [1, 3, 11, 13, 14, 16, 17, 31, 33, 35, 36, 37, 39,
43, 45, 49, 54, 58, 59, 60, 62, 64, 66, 69, 75, 77, 79, 82, 84, 86,
91, 94, 95, 98, 100, 103, 104, 107, 109, 115, 117, 119, 120, 122,
124, 126, 127, 128, 129, 130, 133, 136, 137, 138, 139, 141, 144,
147, 150, 153, 158, 159, 164, 166, 167, 168, 172, 174, 177, 188,
189, 192, 193, 195, 197, 198, 199, 200, 201, 202, 205, 212, 215,
217, 221, 223, 224, 226, 230, 231, 234, 240, 241, 244, 245, 248,
250, 253, 258, 259, 264, 267, 269, 271, 273, 274, 276, 280, 283,
284, 285, 286, 288, 291, 292, 293, 294, 295, 297, 299]
tried_thrice = [34, 46, 48, 78, 163, 289]
friends = [i for i in range(150, 300)]
strangers = [i for i in range(0, 150)]
def get_input_dim(modalities):
dimensions = {
'rgb': 512,
'pose': 150,
'gaze': 6,
'ocr': 64,
'bbox': 108
}
sum_of_dimensions = sum(dimensions[modality] for modality in modalities)
return sum_of_dimensions
def count_parameters(model):
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
def onehot(label, n_classes):
if len(label.size()) == 3:
batch_size, seq_len, _ = label.size()
label_1_2d = label[:,:,0].contiguous().view(batch_size*seq_len, 1)
label_2_2d = label[:,:,1].contiguous().view(batch_size*seq_len, 1)
#onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1)
#onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1)
#onehot_2d = torch.cat((onehot_1_2d, onehot_2_2d), dim=1)
#onehot_3d = onehot_2d.view(batch_size, seq_len, 2*n_classes)
onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1).view(-1, seq_len, n_classes)
onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1).view(-1, seq_len, n_classes)
onehot_3d = torch.cat((onehot_1_2d, onehot_2_2d), dim=2)
return onehot_3d
else:
return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1)
def mixup(data, alpha, n_classes):
lam = torch.FloatTensor([np.random.beta(alpha, alpha)])
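# lam ~ Beta(alpha, alpha); the same convex combination is applied below to frames, poses, gazes,
# bboxes and the one-hot labels (ocr_graphs are left untouched).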
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = data
indices = torch.randperm(labels.size(0))
# labels
labels2 = labels[indices]
labels = onehot(labels, n_classes)
labels2 = onehot(labels2, n_classes)
labels = labels * lam + labels2 * (1 - lam)
# frames
frames2 = frames[indices]
frames = frames * lam + frames2 * (1 - lam)
# poses
poses2 = poses[indices]
poses = poses * lam + poses2 * (1 - lam)
# gazes
gazes2 = gazes[indices]
gazes = gazes * lam + gazes2 * (1 - lam)
# bboxes
bboxes2 = bboxes[indices]
bboxes = bboxes * lam + bboxes2 * (1 - lam)
return frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths
def get_classification_accuracy(pred_left_labels, pred_right_labels, labels, sequence_lengths):
max_len = max(sequence_lengths)
pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
labels = torch.reshape(labels, (-1, max_len, 2))
left_correct = torch.argmax(pred_left_labels, 2) == labels[:,:,0]
right_correct = torch.argmax(pred_right_labels, 2) == labels[:,:,1]
num_pred = sum(sequence_lengths) * 2
num_correct = 0
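# Only the first sequence_lengths[i] timesteps of each sequence are scored; padded frames are excluded.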
for i in range(len(sequence_lengths)):
size = sequence_lengths[i]
num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
acc = num_correct / num_pred
return acc, num_correct, num_pred
def get_classification_accuracy_mixup(pred_left_labels, pred_right_labels, labels, sequence_lengths):
max_len = max(sequence_lengths)
pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
labels = torch.reshape(labels, (-1, max_len, 54))
left_correct = torch.argmax(pred_left_labels, 2) == torch.argmax(labels[:,:,:27], 2)
right_correct = torch.argmax(pred_right_labels, 2) == torch.argmax(labels[:,:,27:], 2)
num_pred = sum(sequence_lengths) * 2
num_correct = 0
for i in range(len(sequence_lengths)):
size = sequence_lengths[i]
num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
acc = num_correct / num_pred
return acc, num_correct, num_pred
def pad_collate(batch):
(aa, bb, cc, dd, ee, ff) = zip(*batch)
seq_lens = [len(a) for a in aa]
aa_pad = pad_sequence(aa, batch_first=True, padding_value=0)
bb_pad = pad_sequence(bb, batch_first=True, padding_value=-1)
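# Features are zero-padded; labels (bb) are padded with -1 so padded timesteps are ignored by the
# losses configured with ignore_index=-1.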
if cc[0] is not None:
cc_pad = pad_sequence(cc, batch_first=True, padding_value=0)
else:
cc_pad = None
if dd[0] is not None:
dd_pad = pad_sequence(dd, batch_first=True, padding_value=0)
else:
dd_pad = None
if ee[0] is not None:
ee_pad = pad_sequence(ee, batch_first=True, padding_value=0)
else:
ee_pad = None
if ff[0] is not None:
ff_pad = pad_sequence(ff, batch_first=True, padding_value=0)
else:
ff_pad = None
return aa_pad, bb_pad, cc_pad, dd_pad, ee_pad, ff_pad, seq_lens
def ocr_loss(pL, lL, pR, lR, ocr, loss, eta=10.0, mode='abs'):
"""
Custom loss based on the negative OCR matrix.
Input:
pL: tensor of shape [batch_size, num_classes] representing left belief predictions
lL: tensor of shape [batch_size] representing the left belief labels
pR: tensor of shape [batch_size, num_classes] representing right belief predictions
lR: tensor of shape [batch_size] representing the right belief labels
ocr: negative OCR matrix (i.e. 1-OCR)
loss: loss function
eta: hyperparameter for the interaction term
Output:
Final loss resulting from weighting left and right loss with the respective
OCR coefficients and summing them.
"""
bs = pL.shape[0]
if len([*lR[0].size()]) > 0: # for mixup
p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
g = torch.tensor([ocr[torch.argmax(lL[i]), torch.argmax(lR[i])] for i in range(bs)], device=pL.device)
else:
p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
g = torch.tensor([ocr[lL[i], lR[i]] for i in range(bs)], device=pL.device)
left_loss = torch.mean(loss(pL, lL))
right_loss = torch.mean(loss(pR, lR))
if mode == 'abs':
interaction_loss = torch.mean(torch.abs(g - p))
elif mode == 'mse':
interaction_loss = torch.mean(torch.pow(g - p, 2))
else: raise NotImplementedError
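# Note: the eta argument is overridden below; the interaction term is rescaled so its magnitude
# matches the sum of the two belief losses.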
eta = (left_loss + right_loss) / interaction_loss
interaction_loss = interaction_loss * eta
print(f"Left: {left_loss} --- Right: {right_loss} --- Interaction: {interaction_loss}")
return left_loss + right_loss + interaction_loss
def spherical2cartesial(x):
"""
From https://colab.research.google.com/drive/1SJbzd-gFTbiYjfZynIfrG044fWi6svbV?usp=sharing#scrollTo=78QhNw4MYSYp
"""
output = torch.zeros(x.size(0),3)
output[:,2] = -torch.cos(x[:,1])*torch.cos(x[:,0])
output[:,0] = torch.cos(x[:,1])*torch.sin(x[:,0])
output[:,1] = torch.sin(x[:,1])
return output
def presave():
from PIL import Image
from torchvision import transforms
import os
from natsort import natsorted
import pickle as pkl
preprocess = transforms.Compose([
transforms.Resize((128, 128)),
#transforms.Resize(256),
#transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
frame_path = '/scratch/bortoletto/data/boss/test/frame'
frame_dirs = os.listdir(frame_path)
frame_paths = []
for frame_dir in natsorted(frame_dirs):
paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
frame_paths.append(paths)
save_folder = '/scratch/bortoletto/data/boss/presaved128/'+frame_path.split('/')[-2]+'/'+frame_path.split('/')[-1]
if not os.path.exists(save_folder):
os.makedirs(save_folder)
print(save_folder, 'created')
for video in natsorted(frame_paths):
print(video[0].split('/')[-2], end='\r')
images = [preprocess(Image.open(i)) for i in video]
strings = video[0].split('/')
with open(save_folder+'/'+strings[7]+'.pkl', 'wb') as f:
pkl.dump(images, f)
def compute_cosine_similarity(tensor1, tensor2):
return F.cosine_similarity(tensor1, tensor2, dim=-1)
def find_most_similar_embedding(rgb, pose, gaze, ocr, bbox, repr):
gaze_similarity = compute_cosine_similarity(gaze, repr)
pose_similarity = compute_cosine_similarity(pose, repr)
ocr_similarity = compute_cosine_similarity(ocr, repr)
bbox_similarity = compute_cosine_similarity(bbox, repr)
rgb_similarity = compute_cosine_similarity(rgb, repr)
similarities = torch.stack([gaze_similarity, pose_similarity, ocr_similarity, bbox_similarity, rgb_similarity])
max_index = torch.argmax(similarities, dim=0)
main_modality = []
main_modality_name = []
for idx in max_index:
if idx == 0:
main_modality.append(gaze)
main_modality_name.append('gaze')
elif idx == 1:
main_modality.append(pose)
main_modality_name.append('pose')
elif idx == 2:
main_modality.append(ocr)
main_modality_name.append('ocr')
elif idx == 3:
main_modality.append(bbox)
main_modality_name.append('bbox')
else:
main_modality.append(rgb)
main_modality_name.append('rgb')
return main_modality, main_modality_name
def plot_similarity_histogram(values, filename, labels):
def count_elements(list_of_strings, target_list):
counts = []
for element in list_of_strings:
count = target_list.count(element)
counts.append(count)
return counts
fig, ax = plt.subplots(figsize=(12,4))
colors = ["red", "blue", "green", "violet"]
# Remove spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
colors = [MTOM_COLORS['MN1'], MTOM_COLORS['MN2'], MTOM_COLORS['CG'], MTOM_COLORS['CG']]
alphas = [0.6, 0.6, 0.6, 0.6]
edgecolors = ['black', 'black', MTOM_COLORS['MN1'], MTOM_COLORS['MN2']]
linewidths = [1.0, 1.0, 5.0, 5.0]
if isinstance(values[0], list):
num_lists = len(values)
bar_width = 0.8 / num_lists
for i, val in enumerate(values):
unique_strings = ['rgb', 'pose', 'gaze', 'bbox', 'ocr']
counts = count_elements(unique_strings, val)
x = np.arange(len(unique_strings))
x_shifted = x + (i - (len(labels) - 1) / 2) * bar_width #x - (bar_width * num_lists / 2) + (bar_width * i)
ax.bar(x_shifted, counts, width=bar_width, label=f'{labels[i]}', color=colors[i], edgecolor=edgecolors[i], linewidth=linewidths[i], alpha=alphas[i])
ax.set_xlabel('Modality', fontsize=18)
ax.set_ylabel('Counts', fontsize=18)
ax.set_xticks(np.arange(len(unique_strings)))
ax.set_xticklabels(unique_strings, fontsize=18)
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
ax.legend(fontsize=18)
ax.grid(axis='y')
plt.savefig(filename, bbox_inches='tight')
else:
unique_strings, counts = np.unique(values, return_counts=True)
ax.bar(unique_strings, counts)
ax.set_xlabel('Modality')
ax.set_ylabel('Counts')
plt.savefig(filename, bbox_inches='tight')
def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
means = []
stds = []
for key, values in data.items():
mean = np.mean(values)
std = np.std(values)
means.append(mean)
stds.append(std)
fig, ax = plt.subplots(figsize=size)
x = np.arange(len(data))
width = 0.6
rects1 = ax.bar(
x, means, width, label='Mean', yerr=stds,
capsize=5,
color='teal' if colors is None else colors,
edgecolor='black',
linewidth=1.5,
alpha=0.6
)
ax.set_ylabel('Accuracy', fontsize=18)
#ax.set_title('Mean and Standard Deviation of Results', fontsize=14)
if filename == 'results/all':
xticklabels = [
'Base',
'DB',
'CG$\otimes$', 'CG$\oplus$', 'CG$\odot$', 'CG$\parallel$',
'IC$\otimes$', 'IC$\oplus$', 'IC$\odot$', 'IC$\parallel$',
'CNN', 'CNN+GRU', 'CNN+LSTM', 'CNN+Conv1D'
]
else:
xticklabels = list(data.keys())
ax.set_xticks(np.arange(len(xticklabels)))
ax.set_xticklabels(xticklabels, rotation=rotation, fontsize=16) #data.keys(), rotation=rotation, fontsize=16)
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
#ax.grid(axis='y')
#ax.set_axisbelow(True)
#ax.legend(loc='upper right')
# Add value labels above each bar, avoiding overlapping with error bars
for rect, std in zip(rects1, stds):
height = rect.get_height()
if height + std < ax.get_ylim()[1]:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
else:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
# Remove spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.grid(axis='x')
plt.tight_layout()
plt.savefig(f'{filename}.pdf', bbox_inches='tight')
def plot_confusion_matrices(left_cm, right_cm, labels, title, annot=True):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
left_xticklabels = [OBJECTS[i] for i in labels[0]]
left_yticklabels = [OBJECTS[i] for i in labels[1]]
right_xticklabels = [OBJECTS[i] for i in labels[2]]
right_yticklabels = [OBJECTS[i] for i in labels[3]]
sns.heatmap(
left_cm,
annot=annot,
fmt='.0f',
cmap='Blues',
cbar=False,
xticklabels=left_xticklabels,
yticklabels=left_yticklabels,
ax=ax1
)
ax1.set_xlabel('Predicted')
ax1.set_ylabel('True')
ax1.set_title('Left Confusion Matrix')
sns.heatmap(
right_cm,
annot=annot,
fmt='.0f',
cmap='Blues',
cbar=False, #True if annot is False else False,
xticklabels=right_xticklabels,
yticklabels=right_yticklabels,
ax=ax2
)
ax2.set_xlabel('Predicted')
ax2.set_ylabel('True')
ax2.set_title('Right Confusion Matrix')
#plt.suptitle(title)
plt.tight_layout()
plt.savefig(title + '.pdf')
plt.close()
def plot_confusion_matrix(confusion_matrix, labels, title, annot=True):
plt.figure(figsize=(6, 6))
xticklabels = [OBJECTS[i] for i in labels[0]]
yticklabels = [OBJECTS[i] for i in labels[1]]
sns.heatmap(
confusion_matrix,
annot=annot,
fmt='.0f',
cmap='Blues',
cbar=False,
xticklabels=xticklabels,
yticklabels=yticklabels
)
plt.xlabel('Predicted')
plt.ylabel('True')
#plt.title(title)
plt.tight_layout()
plt.savefig(title + '.pdf')
plt.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, choices=['confusion', 'similarity', 'scores', 'friends_vs_strangers'])
parser.add_argument('--folder_path', type=str)
args = parser.parse_args()
if args.task == 'similarity':
sns.set_theme(style='white')
else:
sns.set_theme(style='whitegrid')
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
if args.task == 'similarity':
out_left_main_mods_full_test = []
out_right_main_mods_full_test = []
cell_left_main_mods_full_test = []
cell_right_main_mods_full_test = []
cm_left_main_mods_full_test = []
cm_right_main_mods_full_test = []
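        # Each {i}.pt file in folder_path stores the saved MindNet outputs for test video i followed by the raw input features;
        # the number of stored tensors identifies the model variant (14: implicit, 13: common mind, 12: speaker-listener).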
for i in range(60):
print(f'Computing analysis for test video {i}...', end='\r')
emb_file = os.path.join(args.folder_path, f'{i}.pt')
data = torch.load(emb_file)
if len(data) == 14: # implicit
model = 'impl'
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
out_left = out_left.squeeze(0)
cell_left = cell_left.squeeze(0)
out_right = out_right.squeeze(0)
cell_right = cell_right.squeeze(0)
elif len(data) == 13: # common mind
model = 'cm'
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
common_mind = common_mind.squeeze(0)
elif len(data) == 12: # speaker-listener
model = 'sl'
out_left, out_right, feats = data[0], data[1], data[2:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
# ====== PCA for left and right embeddings ====== #
out_left_and_right = np.concatenate((out_left, out_right), axis=0)
pca = PCA(n_components=2)
pca_result = pca.fit_transform(out_left_and_right)
# Separate the PCA results for each tensor
pca_result_left = pca_result[:out_left.shape[0]]
            pca_result_right = pca_result[out_left.shape[0]:]
plt.figure(figsize=(7,6))
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='MindNet$_1$', color=MTOM_COLORS['MN1'], s=100)
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='MindNet$_2$', color=MTOM_COLORS['MN2'], s=100)
plt.xlabel('Principal Component 1', fontsize=30)
plt.ylabel('Principal Component 2', fontsize=30)
plt.grid(False)
plt.legend(fontsize=30)
plt.tight_layout()
plt.savefig(f'{args.folder_path}/{i}_pca.pdf')
plt.close()
# ====== Feature similarity ====== #
if len(feats) == 10:
left_rgb, left_ocr, left_pose, left_gaze, left_bbox, right_rgb, right_ocr, right_pose, right_gaze, right_bbox = feats
left_rgb = left_rgb.squeeze(0)
left_ocr = left_ocr.squeeze(0)
left_pose = left_pose.squeeze(0)
left_gaze = left_gaze.squeeze(0)
left_bbox = left_bbox.squeeze(0)
right_rgb = right_rgb.squeeze(0)
right_ocr = right_ocr.squeeze(0)
right_pose = right_pose.squeeze(0)
right_gaze = right_gaze.squeeze(0)
right_bbox = right_bbox.squeeze(0)
else: raise NotImplementedError("Ablated versions are not supported yet.")
# out: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
# cell: [1, 1, dim] --------- squeeze ---> [1, dim]
# cm: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
# feat: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
_, out_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, out_left)
_, out_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, out_right)
out_left_main_mods_full_test += out_left_main_mods
out_right_main_mods_full_test += out_right_main_mods
if model == 'impl':
_, cell_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, cell_left)
_, cell_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, cell_right)
cell_left_main_mods_full_test += cell_left_main_mods
cell_right_main_mods_full_test += cell_right_main_mods
if model == 'cm':
_, cm_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, common_mind)
_, cm_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, common_mind)
cm_left_main_mods_full_test += cm_left_main_mods
cm_right_main_mods_full_test += cm_right_main_mods
if model == 'impl':
plot_similarity_histogram(
[out_left_main_mods_full_test,
out_right_main_mods_full_test,
cell_left_main_mods_full_test,
cell_right_main_mods_full_test],
            f'{args.folder_path}/boss_similarity_impl_hist_all.pdf',
[r'$h_1$', r'$h_2$', r'$c_1$', r'$c_2$']
)
elif model == 'cm':
plot_similarity_histogram(
[out_left_main_mods_full_test,
out_right_main_mods_full_test,
cm_left_main_mods_full_test,
cm_right_main_mods_full_test],
f'{args.folder_path}/boss_similarity_cm_hist_all.pdf',
[r'$h_1$', r'$h_2$', r'$cg$ w/ 1', r'$cg$ w/ 2']
)
elif model == 'sl':
plot_similarity_histogram(
[out_left_main_mods_full_test,
out_right_main_mods_full_test],
            f'{args.folder_path}/boss_similarity_sl_hist_all.pdf',
[r'$h_1$', r'$h_2$']
)
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
elif args.task == 'scores':
        all_results = json.load(open('results/all.json'))
abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json'))
plot_scores_histogram(
        all_results,
filename='results/all',
size=(10,4),
rotation=45,
colors=[COLORS[7]] + [MTOM_COLORS['DB']] + [MTOM_COLORS['CG']]*4 + [MTOM_COLORS['IC']]*4 + [COLORS[4]]*4
)
plot_scores_histogram(
abl_tom_cm_concat,
size=(10,4),
filename='results/abl_cm_concat',
colors=[MTOM_COLORS['CG']]*8
)
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
elif args.task == 'confusion':
# List to store confusion matrices
all_df = []
left_matrices = []
right_matrices = []
# Iterate over the CSV files
for filename in natsorted(os.listdir(args.folder_path)):
if filename.endswith('.csv'):
file_path = os.path.join(args.folder_path, filename)
print(f'Processing {file_path}...')
# Read the CSV file
df = pd.read_csv(file_path)
# Extract the left and right labels and predictions
left_labels = df['left_label']
right_labels = df['right_label']
left_preds = df['left_pred']
right_preds = df['right_pred']
# Calculate the confusion matrices for left and right
left_cm = pd.crosstab(left_labels, left_preds)
right_cm = pd.crosstab(right_labels, right_preds)
# Append the confusion matrices to the list
left_matrices.append(left_cm)
right_matrices.append(right_cm)
all_df.append(df)
# Plot and save the confusion matrices for left and right
for i, cm in enumerate(zip(left_matrices, right_matrices)):
print(f'Computing confusion matrices for video {i}...', end='\r')
plot_confusion_matrices(
cm[0],
cm[1],
labels=[cm[0].columns, cm[0].index, cm[1].columns, cm[1].index],
title=f'{args.folder_path}/{i}_cm'
)
merged_df = pd.concat(all_df).reset_index(drop=True)
merged_left_cm = pd.crosstab(merged_df['left_label'], merged_df['left_pred'])
merged_right_cm = pd.crosstab(merged_df['right_label'], merged_df['right_pred'])
plot_confusion_matrices(
merged_left_cm,
merged_right_cm,
labels=[merged_left_cm.columns, merged_left_cm.index, merged_right_cm.columns, merged_right_cm.index],
title=f'{args.folder_path}/all_lr',
annot=False
)
merged_preds = pd.concat([merged_df['left_pred'], merged_df['right_pred']]).reset_index(drop=True)
merged_labels = pd.concat([merged_df['left_label'], merged_df['right_label']]).reset_index(drop=True)
merged_all_cm = pd.crosstab(merged_labels, merged_preds)
plot_confusion_matrix(
merged_all_cm,
labels=[merged_all_cm.columns, merged_all_cm.index],
title=f'{args.folder_path}/all_merged',
annot=False
)
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
elif args.task == 'friends_vs_strangers':
# friends ids: 30-59
# stranger ids: 0-29
out_left_list = []
out_right_list = []
cell_left_list = []
cell_right_list = []
common_mind_list = []
for i in range(60):
print(f'Computing analysis for test video {i}...', end='\r')
emb_file = os.path.join(args.folder_path, f'{i}.pt')
data = torch.load(emb_file)
if len(data) == 14: # implicit
model = 'impl'
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
out_left = out_left.squeeze(0)
cell_left = cell_left.squeeze(0)
out_right = out_right.squeeze(0)
cell_right = cell_right.squeeze(0)
out_left_list.append(out_left)
out_right_list.append(out_right)
cell_left_list.append(cell_left)
cell_right_list.append(cell_right)
elif len(data) == 13: # common mind
model = 'cm'
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
common_mind = common_mind.squeeze(0)
out_left_list.append(out_left)
out_right_list.append(out_right)
common_mind_list.append(common_mind)
elif len(data) == 12: # speaker-listener
model = 'sl'
out_left, out_right, feats = data[0], data[1], data[2:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
out_left_list.append(out_left)
out_right_list.append(out_right)
else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
# ====== PCA for left and right embeddings ====== #
print('\rComputing PCA...')
strangers_nframes = sum([out_left_list[i].shape[0] for i in range(30)])
left = torch.cat(out_left_list, 0)
right = torch.cat(out_right_list, 0)
out_left_and_right = torch.cat([left, right], axis=0).numpy()
#out_left_and_right = StandardScaler().fit_transform(out_left_and_right)
pca = PCA(n_components=2)
pca_result = pca.fit_transform(out_left_and_right)
#pca_result = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=3).fit_transform(out_left_and_right)
#pca_result = UMAP().fit_transform(out_left_and_right)
# Separate the PCA results for each tensor
#pca_result_left = pca_result[:left.shape[0]]
#pca_result_right = pca_result[right.shape[0]:]
        # pca_result stacks left embeddings (videos 0-59) followed by right embeddings (videos 0-59); strangers are videos 0-29,
        # so take the strangers' frames from both halves (left and right have the same number of frames per video)
        n_left = left.shape[0]
        pca_result_strangers = np.concatenate((pca_result[:strangers_nframes], pca_result[n_left:n_left + strangers_nframes]), axis=0)
        pca_result_friends = np.concatenate((pca_result[strangers_nframes:n_left], pca_result[n_left + strangers_nframes:]), axis=0)
#plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='Left')
#plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='Right')
plt.scatter(pca_result_friends[:, 0], pca_result_friends[:, 1], label='Friends')
plt.scatter(pca_result_strangers[:, 0], pca_result_strangers[:, 1], label='Strangers')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.legend()
plt.savefig(f'{args.folder_path}/friends_vs_strangers_pca.pdf')
plt.close()
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
else:
        raise ValueError(f'Unknown task: {args.task}')

196
tbd/.gitignore vendored Normal file
View file

@@ -0,0 +1,196 @@
experiments
wandb
predictions
# Created by https://www.toptal.com/developers/gitignore/api/python,linux
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python,linux

16
tbd/README.md Normal file
View file

@@ -0,0 +1,16 @@
# TBD
## Data
The original code can be found [here](https://github.com/LifengFan/Triadic-Belief-Dynamics). The dataset is not directly available; it must be requested via the Google form linked in the [README](https://github.com/LifengFan/Triadic-Belief-Dynamics?tab=readme-ov-file#dataset).
## Installing Dependencies
Run `conda env create -f environment.yml`.
## Train
`source run_train.sh`.
## Test
`source run_test.sh`. **Make sure to use the same random seed used for training**, otherwise the splits will be different and you will likely have data leakage.
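For example, using the flag sets from `run_train.sh` and `run_test.sh` in this folder (the seed value itself is arbitrary, as long as it is identical in both runs):
```
# train with a fixed seed
python -m train --gpu_id 2 --seed 123 --logger --non_blocking --pin_memory \
    --batch_size 64 --num_workers 16 --num_epoch 300 --lr 5e-4 --dropout 0.1 \
    --model_type tom_cm --aggr no_tom --hidden_dim 64
# test with the same seed (123 here), pointing to the trained checkpoint
python -m test --gpu_id 1 --seed 123 --non_blocking --pin_memory \
    --model_type tom_cm --aggr no_tom --hidden_dim 64 --batch_size 64 \
    --load_model_path /PATH/TO/model
```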
## Visualisations
The plots are made using `utils/fb_scores_err.py` (false belief analysis) and `utils/similarity.py` (PCA of latent representations).

100
tbd/environment.yml Normal file
View file

@@ -0,0 +1,100 @@
name: tbd
channels:
- conda-forge
- defaults
- pytorch
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- ca-certificates=2023.01.10=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- ncurses=6.4=h6a678d5_0
- openssl=1.1.1t=h7f8727e_0
- pip=23.0.1=py38h06a4308_0
- python=3.8.10=h12debd9_8
- readline=8.2=h5eee18b_0
- setuptools=66.0.0=py38h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.38.4=py38h06a4308_0
- xz=5.4.2=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- appdirs==1.4.4
- beautifulsoup4==4.12.2
- certifi==2023.5.7
- charset-normalizer==3.1.0
- click==8.1.3
- cmake==3.26.4
- contourpy==1.1.0
- cycler==0.11.0
- docker-pycreds==0.4.0
- einops==0.6.1
- filelock==3.12.0
- fonttools==4.40.0
- gdown==4.7.1
- gitdb==4.0.10
- gitpython==3.1.31
- idna==3.4
- importlib-resources==5.12.0
- jinja2==3.1.2
- joblib==1.3.1
- kiwisolver==1.4.4
- lit==16.0.6
- markupsafe==2.1.3
- matplotlib==3.7.1
- memory-efficient-attention-pytorch==0.1.2
- mpmath==1.3.0
- networkx==3.1
- numpy==1.24.4
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cuda-cupti-cu11==11.7.101
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cudnn-cu11==8.5.0.96
- nvidia-cufft-cu11==10.9.0.58
- nvidia-curand-cu11==10.2.10.91
- nvidia-cusolver-cu11==11.4.0.1
- nvidia-cusparse-cu11==11.7.4.91
- nvidia-nccl-cu11==2.14.3
- nvidia-nvtx-cu11==11.7.91
- opencv-python==4.8.0.74
- packaging==23.1
- pandas==2.0.3
- pathtools==0.1.2
- pillow==9.5.0
- protobuf==4.23.3
- psutil==5.9.5
- pyparsing==3.1.0
- pysocks==1.7.1
- python-dateutil==2.8.2
- pytz==2023.3
- pyyaml==6.0
- requests==2.30.0
- scikit-learn==1.3.0
- scipy==1.10.1
- seaborn==0.12.2
- sentry-sdk==1.27.0
- setproctitle==1.3.2
- six==1.16.0
- smmap==5.0.0
- soupsieve==2.4.1
- sympy==1.12
- threadpoolctl==3.1.0
- torch==2.0.1
- torch-geometric==2.3.1
- torchaudio==2.0.2
- torchsampler==0.1.2
- torchvision==0.15.2
- tqdm==4.65.0
- triton==2.0.0
- typing-extensions==4.7.0
- tzdata==2023.3
- urllib3==2.0.2
- wandb==0.15.5
- zipp==3.15.0
prefix: /opt/anaconda3/envs/tbd

156
tbd/models/base.py Normal file
View file

@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
from .utils import pose_edge_index
from torch_geometric.nn import GCNConv
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim),
nn.GELU(),
nn.Linear(dim, dim))
def forward(self, x):
return self.net(x)
class CNN(nn.Module):
def __init__(self, hidden_dim):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv3(x)
x = nn.functional.relu(x)
x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling
return x
class MindNetLSTM(nn.Module):
"""
Basic MindNet for model-based ToM, just LSTM on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetLSTM, self).__init__()
self.mods = mods
if 'rgb_1' in mods:
self.img_emb = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
if 'gaze' in mods:
self.gaze_emb = nn.Linear(2, hidden_dim)
if 'pose' in mods:
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.LSTM = PreNorm(
hidden_dim*len(mods),
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
feats = []
if 'rgb_3' in self.mods:
feats.append(rgb_3rd_pov_feats)
if 'rgb_1' in self.mods:
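            # encode each first-person frame independently with the CNN, then stack along the time dimension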
rgb_feat = []
for i in range(rgb_1st_pov.shape[1]):
images_i = rgb_1st_pov[:,i]
img_i_feat = self.img_emb(images_i)
img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
feats.append(rgb_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats.mean(2))
lstm_inp = torch.cat(feats, 2)
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
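        # c_n: (2, bs, hidden_dim) for the bidirectional LSTM -> average the two directions and move batch first: (bs, 1, hidden_dim)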
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
return self.act(self.proj(lstm_out)), c_n, feats
class MindNetSL(nn.Module):
"""
Basic MindNet for SL ToM, just LSTM on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetSL, self).__init__()
self.mods = mods
if 'rgb_1' in mods:
self.img_emb = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
if 'gaze' in mods:
self.gaze_emb = nn.Linear(2, hidden_dim)
if 'pose' in mods:
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.LSTM = PreNorm(
hidden_dim*len(mods),
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
feats = []
if 'rgb_3' in self.mods:
feats.append(rgb_3rd_pov_feats)
if 'rgb_1' in self.mods:
rgb_feat = []
for i in range(rgb_1st_pov.shape[1]):
images_i = rgb_1st_pov[:,i]
img_i_feat = self.img_emb(images_i)
img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
feats.append(rgb_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats.mean(2))
lstm_inp = torch.cat(feats, 2)
lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
return self.act(self.proj(lstm_out)), feats

157
tbd/models/common_mind.py Normal file
View file

@@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention
class CommonMindToMnet(nn.Module):
"""
img: bs, 3, 128, 128
pose: bs, 26, 3
gaze: bs, 2 NOTE: only tracker has gaze
bbox: bs, 4
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_1', 'rgb_3', 'pose', 'gaze', 'bbox']):
super(CommonMindToMnet, self).__init__()
self.aggr = aggr
self.mods = mods
# ---- 3rd POV Images, object and bbox ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
#for param in self.cnn.parameters():
# param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
self.bbox_ff = nn.Linear(4, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
if aggr != 'no_tom': self.cm_proj = nn.Linear(hidden_dim*2, hidden_dim)
self.ln_1 = nn.LayerNorm(hidden_dim)
self.ln_2 = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.m1 = nn.Linear(hidden_dim, 4)
self.m2 = nn.Linear(hidden_dim, 4)
self.m12 = nn.Linear(hidden_dim, 4)
self.m21 = nn.Linear(hidden_dim, 4)
self.mc = nn.Linear(hidden_dim, 4)
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
if 'bbox' in self.mods:
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
else:
bbox_feat = None
if 'rgb_3' in self.mods:
rgb_feat = []
for i in range(sequence_len):
images_i = img_3rd_pov[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
else:
rgb_feat_3rd_pov = None
if tracker_id == 'skele1':
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
else:
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
if self.aggr == 'no_tom':
m1 = self.m1(out_1).mean(1)
m2 = self.m2(out_2).mean(1)
m12 = self.m12(out_1).mean(1)
m21 = self.m21(out_2).mean(1)
            mc = self.mc(out_1*out_2).mean(1) # NOTE: for no_tom, mc is computed from the element-wise product of out_1 and out_2
return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
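        # Common ground: fuse the two agents' final LSTM cell states into a single shared representation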
common_mind = self.cm_proj(torch.cat([cell_1, cell_2], -1)) # (bs, 1, h)
if self.aggr == 'attn':
p1 = self.attn_left(x=out_1, context=common_mind)
p2 = self.attn_right(x=out_2, context=common_mind)
elif self.aggr == 'mult':
p1 = out_1 * common_mind
p2 = out_2 * common_mind
elif self.aggr == 'sum':
p1 = out_1 + common_mind
p2 = out_2 + common_mind
elif self.aggr == 'concat':
p1 = torch.cat([out_1, common_mind], 1)
p2 = torch.cat([out_2, common_mind], 1)
else: raise ValueError
p1 = self.act(p1)
p1 = self.ln_1(p1)
p2 = self.act(p2)
p2 = self.ln_2(p2)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
m1 = self.m1(p1).mean(1)
m2 = self.m2(p2).mean(1)
m12 = self.m12(p1).mean(1)
m21 = self.m21(p2).mean(1)
mc = self.mc(p1*p2).mean(1)
if self.aggr == 'concat':
m1 = self.m1(p1).mean(1)
m2 = self.m2(p2).mean(1)
m12 = self.m12(p1).mean(1)
m21 = self.m21(p2).mean(1)
mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2
return m1, m2, m12, m21, mc, [out_1, out_2, common_mind] + feats_1 + feats_2
if __name__ == "__main__":
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
img_tracker = torch.ones(3, 5, 3, 128, 128)
img_battery = torch.ones(3, 5, 3, 128, 128)
pose1 = torch.ones(3, 5, 26, 3)
pose2 = torch.ones(3, 5, 26, 3)
bbox = torch.ones(3, 5, 13, 4)
tracker_id = 'skele1'
gaze = torch.ones(3, 5, 2)
mods = ['pose', 'bbox', 'rgb_3']
for agg in ['no_tom']:
model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
print(out[0].shape)

151
tbd/models/implicit.py Normal file
View file

@@ -0,0 +1,151 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention
class ImplicitToMnet(nn.Module):
"""
Implicit ToM net. Supports any subset of modalities
Possible aggregations: sum, mult, attn, concat
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
super(ImplicitToMnet, self).__init__()
self.aggr = aggr
self.mods = mods
# ---- 3rd POV Images, object and bbox ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
self.bbox_ff = nn.Linear(4, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
self.ln_1 = nn.LayerNorm(hidden_dim)
self.ln_2 = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.m1 = nn.Linear(hidden_dim, 4)
self.m2 = nn.Linear(hidden_dim, 4)
self.m12 = nn.Linear(hidden_dim, 4)
self.m21 = nn.Linear(hidden_dim, 4)
self.mc = nn.Linear(hidden_dim, 4)
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
if 'bbox' in self.mods:
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
else:
bbox_feat = None
if 'rgb_3' in self.mods:
rgb_feat = []
for i in range(sequence_len):
images_i = img_3rd_pov[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
else:
rgb_feat_3rd_pov = None
if tracker_id == 'skele1':
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
else:
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
if self.aggr == 'no_tom':
m1 = self.m1(out_1).mean(1)
m2 = self.m2(out_2).mean(1)
m12 = self.m12(out_1).mean(1)
m21 = self.m21(out_2).mean(1)
            mc = self.mc(out_1*out_2).mean(1) # NOTE: for no_tom, mc is computed from the element-wise product of out_1 and out_2
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
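        # Implicit ToM aggregation: combine each agent's hidden states with the *other* agent's final cell state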
if self.aggr == 'attn':
p1 = self.attn_left(x=out_1, context=cell_2)
p2 = self.attn_right(x=out_2, context=cell_1)
elif self.aggr == 'mult':
p1 = out_1 * cell_2
p2 = out_2 * cell_1
elif self.aggr == 'sum':
p1 = out_1 + cell_2
p2 = out_2 + cell_1
elif self.aggr == 'concat':
p1 = torch.cat([out_1, cell_2], 1)
p2 = torch.cat([out_2, cell_1], 1)
else: raise ValueError
p1 = self.act(p1)
p1 = self.ln_1(p1)
p2 = self.act(p2)
p2 = self.ln_2(p2)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
m1 = self.m1(p1).mean(1)
m2 = self.m2(p2).mean(1)
m12 = self.m12(p1).mean(1)
m21 = self.m21(p2).mean(1)
mc = self.mc(p1*p2).mean(1)
if self.aggr == 'concat':
m1 = self.m1(p1).mean(1)
m2 = self.m2(p2).mean(1)
m12 = self.m12(p1).mean(1)
m21 = self.m21(p2).mean(1)
mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
if __name__ == "__main__":
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
img_tracker = torch.ones(3, 5, 3, 128, 128)
img_battery = torch.ones(3, 5, 3, 128, 128)
pose1 = torch.ones(3, 5, 26, 3)
pose2 = torch.ones(3, 5, 26, 3)
bbox = torch.ones(3, 5, 13, 4)
tracker_id = 'skele1'
gaze = torch.ones(3, 5, 2)
for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg)
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
print(agg, out[0].shape)

112
tbd/models/sl.py Normal file
View file

@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetSL
class SLToMnet(nn.Module):
"""
Speaker-Listener ToMnet
"""
def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
super(SLToMnet, self).__init__()
self.tom_weight = tom_weight
self.mods = mods
# ---- 3rd POV Images, object and bbox ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
self.bbox_ff = nn.Linear(4, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_1 = MindNetSL(hidden_dim, dropout, mods=mods)
self.mind_net_2 = MindNetSL(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
self.m1 = nn.Linear(hidden_dim, 4)
self.m2 = nn.Linear(hidden_dim, 4)
self.m12 = nn.Linear(hidden_dim, 4)
self.m21 = nn.Linear(hidden_dim, 4)
self.mc = nn.Linear(hidden_dim, 4)
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
if 'bbox' in self.mods:
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
else:
bbox_feat = None
if 'rgb_3' in self.mods:
rgb_feat = []
for i in range(sequence_len):
images_i = img_3rd_pov[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
else:
rgb_feat_3rd_pov = None
if tracker_id == 'skele1':
out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
else:
out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
m1_logits = self.m1(out_1).mean(1)
m2_logits = self.m2(out_2).mean(1)
m12_logits = self.m12(out_1).mean(1)
m21_logits = self.m21(out_2).mean(1)
mc_logits = self.mc(out_1*out_2).mean(1)
m1_ranking = torch.log_softmax(m1_logits, dim=-1)
m2_ranking = torch.log_softmax(m2_logits, dim=-1)
m12_ranking = torch.log_softmax(m12_logits, dim=-1)
m21_ranking = torch.log_softmax(m21_logits, dim=-1)
mc_ranking = torch.log_softmax(mc_logits, dim=-1)
# NOTE: does this make sense?
m1 = m1_ranking + self.tom_weight * m2_ranking
m2 = m2_ranking + self.tom_weight * m1_ranking
m12 = m12_ranking + self.tom_weight * m21_ranking
m21 = m21_ranking + self.tom_weight * m12_ranking
mc = mc_ranking + self.tom_weight * mc_ranking
return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
if __name__ == "__main__":
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
img_tracker = torch.ones(3, 5, 3, 128, 128)
img_battery = torch.ones(3, 5, 3, 128, 128)
pose1 = torch.ones(3, 5, 26, 3)
pose2 = torch.ones(3, 5, 26, 3)
bbox = torch.ones(3, 5, 13, 4)
tracker_id = 'skele1'
gaze = torch.ones(3, 5, 2)
model = SLToMnet(hidden_dim=64, device='cpu', tom_weight=2.0, resnet=False, dropout=0.5)
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
print(out[0].shape)

112
tbd/models/tom_base.py Normal file
View file

@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .base import CNN, MindNetLSTM
import numpy as np
class ImplicitToMnet(nn.Module):
"""
    Implicit ToM net (base variant). Supports any subset of modalities;
    only the 'no_tom' aggregation is implemented in this file.
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
super(ImplicitToMnet, self).__init__()
        self.aggr = 'no_tom'  # only the no-ToM baseline is implemented in forward()
        self.mods = mods
# ---- 3rd POV Images, object and bbox ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
self.bbox_ff = nn.Linear(4, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
self.m1 = nn.Linear(hidden_dim, 4)
self.m2 = nn.Linear(hidden_dim, 4)
self.m12 = nn.Linear(hidden_dim, 4)
self.m21 = nn.Linear(hidden_dim, 4)
self.mc = nn.Linear(hidden_dim, 4)
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
if 'bbox' in self.mods:
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
else:
bbox_feat = None
if 'rgb_3' in self.mods:
rgb_feat = []
for i in range(sequence_len):
images_i = img_3rd_pov[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
else:
rgb_feat_3rd_pov = None
if tracker_id == 'skele1':
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
else:
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
if self.aggr == 'no_tom':
m1 = self.m1(out_1).mean(1)
m2 = self.m2(out_2).mean(1)
m12 = self.m12(out_1).mean(1)
m21 = self.m21(out_2).mean(1)
            mc = self.mc(out_1*out_2).mean(1) # NOTE: for no_tom, mc is computed from the element-wise product of out_1 and out_2
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
def count_parameters(model):
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
if __name__ == "__main__":
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
img_tracker = torch.ones(3, 5, 3, 128, 128)
img_battery = torch.ones(3, 5, 3, 128, 128)
pose1 = torch.ones(3, 5, 26, 3)
pose2 = torch.ones(3, 5, 26, 3)
bbox = torch.ones(3, 5, 13, 4)
tracker_id = 'skele1'
gaze = torch.ones(3, 5, 2)
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5)
print(count_parameters(model))
breakpoint()
for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5)
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
print(agg, out[0].shape)

7
tbd/models/utils.py Normal file
View file

@@ -0,0 +1,7 @@
import torch
def pose_edge_index():
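    # Undirected skeleton graph over the 26 body joints: each (start[i], end[i]) pair is an edge, listed in both directions for the GCN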
start = [15, 14, 13, 12, 19, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 3, 4, 5, 6, 8, 8, 4, 20, 21, 21, 22, 24, 22]
end = [14, 13, 12, 0, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 4, 20, 20, 21, 22, 24, 23, 25, 24]
return torch.tensor([start+end, end+start])

186
tbd/results/abl.json Normal file
View file

@@ -0,0 +1,186 @@
{
"all": [
{
"m1": 0.3803337111550836,
"m2": 0.3900899763574355,
"m12": 0.4441281276628709,
"m21": 0.4818757648120031,
"mc": 0.4485177767702456
},
{
"m1": 0.5186066842992191,
"m2": 0.521895750052127,
"m12": 0.49294626980529677,
"m21": 0.4810118034501327,
"mc": 0.6097300398369058
},
{
"m1": 0.4965589148309122,
"m2": 0.5094894309980568,
"m12": 0.4615136302786905,
"m21": 0.4554005550423429,
"mc": 0.6258118710785031
}
],
"rgb_3_pose_gaze_bbox": [
{
"m1": 0.3776045061727805,
"m2": 0.3996776745150713,
"m12": 0.4762772810038159,
"m21": 0.48643178296718503,
"mc": 0.4575207273412474
},
{
"m1": 0.5176564423560418,
"m2": 0.5109344883698214,
"m12": 0.4630213122846928,
"m21": 0.4826608133674547,
"mc": 0.5979415365779003
},
{
"m1": 0.5114692300997931,
"m2": 0.5027048375802656,
"m12": 0.47527894405588544,
"m21": 0.45223985157847546,
"mc": 0.6054099305712209
}
],
"rgb_3_pose_gaze": [
{
"m1": 0.403207421026191,
"m2": 0.3833413122398237,
"m12": 0.4602455224198077,
"m21": 0.47181798537346287,
"mc": 0.4603675297898878
},
{
"m1": 0.49484810149311514,
"m2": 0.5060275976807422,
"m12": 0.4610412452830618,
"m21": 0.46869095956564044,
"mc": 0.6040674897817755
},
{
"m1": 0.5160598186177866,
"m2": 0.5309683014233921,
"m12": 0.47227245803060636,
"m21": 0.46953974307035984,
"mc": 0.6014771460423635
}
],
"rgb_3_pose": [
{
"m1": 0.4057149181928123,
"m2": 0.4002233785689204,
"m12": 0.46794813614607333,
"m21": 0.4690365183933033,
"mc": 0.4591530208921514
},
{
"m1": 0.5362792166212834,
"m2": 0.5290656046231254,
"m12": 0.4569419683345858,
"m21": 0.4530255281497826,
"mc": 0.4554252731371068
},
{
"m1": 0.49570625763169085,
"m2": 0.5146503967646507,
"m12": 0.4567936139893578,
"m21": 0.45918214877096325,
"mc": 0.5962397441246001
}
],
"rgb_3_gaze": [
{
"m1": 0.40135106828655215,
"m2": 0.38453470155825614,
"m12": 0.4989742833725901,
"m21": 0.47369273992079175,
"mc": 0.48430622854433986
},
{
"m1": 0.508038122818153,
"m2": 0.4875748099051746,
"m12": 0.46665443622698555,
"m21": 0.46635808547742913,
"mc": 0.47936993226840163
},
{
"m1": 0.49795853039610977,
"m2": 0.5028666890527814,
"m12": 0.44176709237564815,
"m21": 0.4483898274665582,
"mc": 0.5867527750929912
}
],
"rgb_3_bbox": [
{
"m1": 0.3951383898241492,
"m2": 0.3818794542844425,
"m12": 0.44108151735270384,
"m21": 0.46539754196523303,
"mc": 0.43982185797713114
},
{
"m1": 0.5093846655989521,
"m2": 0.4923439212866733,
"m12": 0.4598003475323884,
"m21": 0.47647640659290746,
"mc": 0.6349953712994137
},
{
"m1": 0.5325224862402295,
"m2": 0.5092319973570975,
"m12": 0.4435807136490263,
"m21": 0.4576911633624616,
"mc": 0.6282064277856357
}
],
"rgb_3_rgb_1": [
{
"m1": 0.39189391736691903,
"m2": 0.3739995635963588,
"m12": 0.4792392731637056,
"m21": 0.4592726043789752,
"mc": 0.4468645255652386
},
{
"m1": 0.4827892482357646,
"m2": 0.48042899735042716,
"m12": 0.45932653547051094,
"m21": 0.48430209616318126,
"mc": 0.4506104344435269
},
{
"m1": 0.4820247145474279,
"m2": 0.3667553358192628,
"m12": 0.44503028688537,
"m21": 0.45984906207471654,
"mc": 0.465120658971623
}
],
"rgb_3": [
{
"m1": 0.40725462165126114,
"m2": 0.38737351624656846,
"m12": 0.46230461548252094,
"m21": 0.4829312519709871,
"mc": 0.4492175856929955
},
{
"m1": 0.5286274183685061,
"m2": 0.5081429492163979,
"m12": 0.4610256989472217,
"m21": 0.4733487634477733,
"mc": 0.4655243312197501
},
{
"m1": 0.5217968210271873,
"m2": 0.5103780571157844,
"m12": 0.4431266771306429,
"m21": 0.48398542131284883,
"mc": 0.6122314353959392
}
]
}

232
tbd/results/all.json Normal file
View file

@@ -0,0 +1,232 @@
{
"cm_concat": [
{
"m1": 0.38921744471949393,
"m2": 0.38557137008494935,
"m12": 0.44699534554593756,
"m21": 0.4747474437468054,
"mc": 0.4918107834016411
},
{
"m1": 0.5402415140026018,
"m2": 0.48833721513836786,
"m12": 0.4631512445419047,
"m21": 0.4740880083492652,
"mc": 0.6375070925808958
},
{
"m1": 0.5012543523713172,
"m2": 0.5068694866895836,
"m12": 0.4451537834591627,
"m21": 0.45215784721598673,
"mc": 0.6201022576104379
}
],
"cm_sum": [
{
"m1": 0.39403894801783246,
"m2": 0.38541918219411786,
"m12": 0.4600376974144952,
"m21": 0.471919704007463,
"mc": 0.43950812310207055
},
{
"m1": 0.48497621104052574,
"m2": 0.5295044689855949,
"m12": 0.4502949472343065,
"m21": 0.47823492553894387,
"mc": 0.6028290833617195
},
{
"m1": 0.503386104373653,
"m2": 0.49983127146477085,
"m12": 0.46782817568218116,
"m21": 0.45484578845116075,
"mc": 0.5905749126722909
}
],
"cm_mult": [
{
"m1": 0.39070820515470606,
"m2": 0.3996851353903932,
"m12": 0.4455704586852128,
"m21": 0.4713517869738811,
"mc": 0.4450907029478458
},
{
"m1": 0.5066540697731119,
"m2": 0.526507445454099,
"m12": 0.462643008560599,
"m21": 0.48263054309565334,
"mc": 0.6438566476782207
},
{
"m1": 0.48868811674304546,
"m2": 0.5074635877653536,
"m12": 0.44597405775819876,
"m21": 0.45445350963025877,
"mc": 0.5884265473527218
}
],
"cm_attn": [
{
"m1": 0.3949557687114269,
"m2": 0.3919385900921811,
"m12": 0.4850081112466773,
"m21": 0.4849575556679713,
"mc": 0.4516870089239762
},
{
"m1": 0.4925989821370256,
"m2": 0.49409170532242247,
"m12": 0.4664647278240569,
"m21": 0.46783863397462533,
"mc": 0.6398721139927354
},
{
"m1": 0.4945636568169018,
"m2": 0.5049812790749876,
"m12": 0.454359577718189,
"m21": 0.4712184012093268,
"mc": 0.5992735441011302
}
],
"no_tom": [
{
"m1": 0.2570551317,
"m2": 0.375350929686332,
"m12": 0.312451988649724,
"m21": 0.4631371031641,
"mc": 0.457486278214567
},
{
"m1": 0.233046800382043,
"m2": 0.522609755931958,
"m12": 0.326821758467328,
"m21": 0.474338898013257,
"mc": 0.604439456291308
},
{
"m1": 0.33774852598382,
"m2": 0.520943544364353,
"m12": 0.298617214416867,
"m21": 0.482175301427192,
"mc": 0.634948478570852
}
],
"sl": [
{
"m1": 0.365205706591741,
"m2": 0.255259363011619,
"m12": 0.421227579844245,
"m21": 0.376143327741882,
"mc": 0.45614515353718
},
{
"m1": 0.493046934143676,
"m2": 0.331798174804139,
"m12": 0.422821548330913,
"m21": 0.399768928780549,
"mc": 0.450957023549231
},
{
"m1": 0.466266787709392,
"m2": 0.350962671130227,
"m12": 0.431694150269919,
"m21": 0.378863431433258,
"mc": 0.470284405744656
}
],
"impl_concat": [
{
"m1": 0.38427302094644894,
"m2": 0.38673879043767634,
"m12": 0.45694337561663145,
"m21": 0.4737891562722213,
"mc": 0.4502976351448088
},
{
"m1": 0.49951068243751173,
"m2": 0.5084945752383908,
"m12": 0.4604721097809549,
"m21": 0.4826884970930907,
"mc": 0.6200443272625361
},
{
"m1": 0.5013244243339088,
"m2": 0.49476495726495723,
"m12": 0.4596701406290429,
"m21": 0.4554742441542813,
"mc": 0.5988949378402535
}
],
"impl_sum": [
{
"m1": 0.3803337111550836,
"m2": 0.3900899763574355,
"m12": 0.4441281276628709,
"m21": 0.4818757648120031,
"mc": 0.4485177767702456
},
{
"m1": 0.5186066842992191,
"m2": 0.521895750052127,
"m12": 0.49294626980529677,
"m21": 0.4810118034501327,
"mc": 0.6097300398369058
},
{
"m1": 0.4965589148309122,
"m2": 0.5094894309980568,
"m12": 0.4615136302786905,
"m21": 0.4554005550423429,
"mc": 0.6258118710785031
}
],
"impl_mult": [
{
"m1": 0.3789421413006731,
"m2": 0.3818053844554785,
"m12": 0.46402717346945177,
"m21": 0.4903726261039529,
"mc": 0.4461443806398687
},
{
"m1": 0.3789421413006731,
"m2": 0.3818053844554785,
"m12": 0.46402717346945177,
"m21": 0.4903726261039529,
"mc": 0.4461443806398687
},
{
"m1": 0.49338554196342077,
"m2": 0.5066817652688608,
"m12": 0.46253374461930613,
"m21": 0.47782311190445825,
"mc": 0.4581608719646799
}
],
"impl_attn": [
{
"m1": 0.37413691393147924,
"m2": 0.2546966838007244,
"m12": 0.429390512693598,
"m21": 0.292401773870023,
"mc": 0.45706325836224465
},
{
"m1": 0.513917904196177,
"m2": 0.25802580258025803,
"m12": 0.49272662664765543,
"m21": 0.27041556176385584,
"mc": 0.6041394755857196
},
{
"m1": 0.47720445038981674,
"m2": 0.25839328537170264,
"m12": 0.46505055463781547,
"m21": 0.260276985433943,
"mc": 0.6021811271770562
}
]
}

Binary file not shown.

87
tbd/results/fb_ttest.txt Normal file
View file

@@ -0,0 +1,87 @@
========================================================= m1_m2_m12_m21
Model: Base -> yes
Model: DB -> yes
Model: CG$\oplus$ -> no
Model: CG$\otimes$ -> no
Model: CG$\odot$ -> no
Model: IC$\parallel$ -> no
Model: IC$\oplus$ -> no
Model: IC$\otimes$ -> no
Model: IC$\odot$ -> yes
========================================================= m1_m2
Model: Base -> yes
Model: DB -> yes
Model: CG$\oplus$ -> no
Model: CG$\otimes$ -> no
Model: CG$\odot$ -> no
Model: IC$\parallel$ -> no
Model: IC$\oplus$ -> no
Model: IC$\otimes$ -> no
Model: IC$\odot$ -> yes
========================================================= m12_m21
Model: Base -> yes
Model: DB -> yes
Model: CG$\oplus$ -> no
Model: CG$\otimes$ -> no
Model: CG$\odot$ -> no
Model: IC$\parallel$ -> no
Model: IC$\oplus$ -> no
Model: IC$\otimes$ -> no
Model: IC$\odot$ -> yes

View file

@@ -0,0 +1,59 @@
mc =====================================================================
precision recall f1-score support
0 0.000 0.500 0.001 50
1 0.000 0.000 0.000 4
2 0.004 0.038 0.007 238
3 0.999 0.795 0.885 290788
accuracy 0.794 291080
macro avg 0.251 0.333 0.223 291080
weighted avg 0.998 0.794 0.884 291080
m1 =====================================================================
precision recall f1-score support
0 0.000 0.000 0.000 147
1 0.000 0.000 0.000 2
2 0.025 0.051 0.033 1714
3 0.994 0.988 0.991 289217
accuracy 0.982 291080
macro avg 0.255 0.260 0.256 291080
weighted avg 0.988 0.982 0.985 291080
m2 =====================================================================
precision recall f1-score support
0 0.001 0.013 0.001 151
2 0.031 0.084 0.045 2394
3 0.992 0.970 0.981 288535
accuracy 0.962 291080
macro avg 0.341 0.355 0.342 291080
weighted avg 0.983 0.962 0.972 291080
m12 =====================================================================
precision recall f1-score support
0 0.000 0.000 0.000 93
1 0.000 0.000 0.000 8
2 0.015 0.056 0.023 676
3 0.997 0.990 0.994 290303
accuracy 0.988 291080
macro avg 0.253 0.262 0.254 291080
weighted avg 0.995 0.988 0.991 291080
m21 =====================================================================
precision recall f1-score support
0 0.002 0.012 0.003 86
1 0.000 0.000 0.000 12
2 0.010 0.040 0.016 658
3 0.997 0.989 0.993 290324
accuracy 0.987 291080
macro avg 0.252 0.260 0.253 291080
weighted avg 0.995 0.987 0.991 291080

Binary file not shown.

12
tbd/run_test.sh Normal file
View file

@@ -0,0 +1,12 @@
#!/bin/bash
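# NOTE: --seed must match the seed used in run_train.sh, otherwise the train/val/test splits will differ (see README)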
python -m test \
--gpu_id 1 \
--seed 1 \
--non_blocking \
--pin_memory \
--model_type tom_cm \
--aggr no_tom \
--hidden_dim 64 \
--batch_size 64 \
--load_model_path /PATH/TO/model

16
tbd/run_train.sh Normal file
View file

@@ -0,0 +1,16 @@
#!/bin/bash
python -m train \
--gpu_id 2 \
--seed 123 \
--logger \
--non_blocking \
--pin_memory \
--batch_size 64 \
--num_workers 16 \
--num_epoch 300 \
--lr 5e-4 \
--dropout 0.1 \
--model_type tom_cm \
--aggr no_tom \
--hidden_dim 64

568
tbd/tbd_dataloader.py Normal file
View file

@@ -0,0 +1,568 @@
from __future__ import annotations
from typing import Optional, Union
import torch
import pickle
import time
import glob
import random
import os
import numpy as np
import pandas as pd
import cv2
from itertools import product
import csv
import torchvision.transforms as T
from utils.helpers import tracker_skeID, CLIPS_IDS_88, ALL_IDS, UNIQUE_OBJ_IDS
def collate_fn(batch):
# Unpack the batch into individual elements
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep = zip(*batch)
# Determine the maximum number of objects in any batch
max_n_obj = max(bbox.shape[1] for bbox in bboxes)
# Pad the bounding box tensors
bboxes_pad = []
for bbox in bboxes:
pad_size = max_n_obj - bbox.shape[1]
pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
padded_bbox = torch.cat((bbox, pad), dim=1)
bboxes_pad.append(padded_bbox)
# Stack the padded tensors into a batch tensor
bboxes_batch = torch.stack(bboxes_pad, dim=0)
img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
img_tracker = torch.stack(img_tracker, dim=0)
img_battery = torch.stack(img_battery, dim=0)
pose1 = torch.stack(pose1, dim=0)
pose2 = torch.stack(pose2, dim=0)
gaze = torch.stack(gaze, dim=0)
labels = torch.tensor(labels, dtype=torch.long)
# Return the batched tensors
return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep
def collate_fn_test(batch):
# Unpack the batch into individual elements
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep, false_beliefs = zip(*batch)
# Determine the maximum number of objects in any batch
max_n_obj = max(bbox.shape[1] for bbox in bboxes)
# Pad the bounding box tensors
bboxes_pad = []
for bbox in bboxes:
pad_size = max_n_obj - bbox.shape[1]
pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
padded_bbox = torch.cat((bbox, pad), dim=1)
bboxes_pad.append(padded_bbox)
# Stack the padded tensors into a batch tensor
bboxes_batch = torch.stack(bboxes_pad, dim=0)
img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
img_tracker = torch.stack(img_tracker, dim=0)
img_battery = torch.stack(img_battery, dim=0)
pose1 = torch.stack(pose1, dim=0)
pose2 = torch.stack(pose2, dim=0)
gaze = torch.stack(gaze, dim=0)
labels = torch.tensor(labels, dtype=torch.long)
# Return the batched tensors
return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep, false_beliefs
class TBDDataset(torch.utils.data.Dataset):
def __init__(
self,
path: str = "/scratch/bortoletto/data/tbd",
mode: str = "train",
tbd_data_path: str = "/scratch/bortoletto/data/tbd/mind_lstm_training_cnn_att/",
list_of_ids_to_consider: list = ALL_IDS,
use_preprocessed_img: bool = True,
resize_img: Optional[Union[tuple, int]] = (128,128),
):
"""TBD Dataset based on the 88 clip version of the TBD data.
Expects the following folder structure:
- path
- tracker_gt_smooth <- These are eye tracking from POV, 2D coordinates
- images/*/ <- These are the images by experiment id
- battery <- These are the images, 1st Person
- tracker <- These are the images, 1st other Person w/ eye fixation
- kinect <- These are the images, 3rd Person
- skeleton <- Pose estimation, 3D coordinates
- annotation <- These are the labels, i.e. [0,3] (see below)
        Labels are structured as follows:
{
"O1": [ <- Object with id O1
{
"m1": {
"fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange
"loc": null
},
"m2": {
"fluent": 3,
"loc": null
},
"m12": {
"fluent": 3,
"loc": null
},
"m21": {
"fluent": 3,
"loc": null
},
"mc": {
"fluent": 3,
"loc": null
},
"mg": {
"fluent": 3,
"loc": [
22,
9
]
}
}, ...
], ...
}
        This corresponds to a strict subset of the raw dataset collected
        by the TBD authors in their paper "Learning Triadic Belief Dynamics
        in Nonverbal Communication from Videos" (CVPR 2021, Oral).
We keep small amounts of data in memory (everything <100MB).
Otherwise we read from disk on the fly. This dataset applies normalization.
Args:
path (str, optional): Where the folders lie.
            Defaults to "/scratch/bortoletto/data/tbd".
list_of_ids_to_consider (list, optional): List of ids to consider.
Defaults to ALL_IDS. Otherwise specify a list,
e.g. ["test_94342_23", "test_boelter_21", ...].
resize_img (Optional[Union[tuple, int]], optional): Resize image to
this size if required. Defaults to None.
"""
print(f"Loading TBD Dataset in mode {mode}...")
self.mode = mode
start = time.time()
self.skeleton_3D_path = f"{path}/skeleton"
self.tracker_2D_path = f"{path}/tracker_gt_smooth"
self.bbox_csv_path = f"{path}/annotations_with_bbox.csv"
if use_preprocessed_img:
self.img_path = f"{path}/images_norm"
else:
self.img_path = f"{path}/images"
self.obj_ids_path = f"{path}/mind_lstm_training_cnn_att_shu.pkl"
self.label_map = list(product([0, 1, 2, 3], repeat=5))
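        # each scalar label indexes one of 4**5 = 1024 combinations of the five mind fluents (m1, m2, m12, m21, mc),
        # each taking a value in {0: enter, 1: disappear, 2: update, 3: unchange}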
clips = os.listdir(tbd_data_path)
data = []
labels = []
for clip in clips:
with open(tbd_data_path + clip, 'rb') as f:
vec_input, label_ = pickle.load(f, encoding='latin1')
data = data + vec_input
labels = labels + label_
c = list(zip(data, labels))
random.shuffle(c)
train_ratio = int(len(c) * 0.6)
validate_ratio = int(len(c) * 0.2)
data, label = zip(*c)
train_x, train_y = data[:train_ratio], label[:train_ratio]
validate_x, validate_y = data[train_ratio:train_ratio + validate_ratio], label[train_ratio:train_ratio + validate_ratio]
test_x, test_y = data[train_ratio + validate_ratio:], label[train_ratio + validate_ratio:]
self.mind_count = np.zeros(1024) # used for CE weights
if mode == "train":
self.data, self.labels = train_x, train_y
elif mode == "val":
self.data, self.labels = validate_x, validate_y
elif mode == "test":
self.data, self.labels = test_x, test_y
self.false_beliefs_path = f"{path}/store_mind_set"
        # keep small amounts of data in memory
self.skeleton_3D = self.load_skeleton_3D(self.skeleton_3D_path, list_of_ids_to_consider)
self.tracker_2D = self.load_tracker_2D(self.tracker_2D_path, list_of_ids_to_consider)
self.bbox_df = pd.read_csv(self.bbox_csv_path, header=0)
self.obj_ids = self.load_obj_ids(self.obj_ids_path)
if not use_preprocessed_img:
normalisation_steps = [
T.ToTensor(),
T.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
]
if resize_img is not None:
normalisation_steps.insert(1, T.Resize(resize_img))
self.preprocess_img = T.Compose(normalisation_steps)
else:
self.preprocess_img = None
self.use_preprocessed_img = use_preprocessed_img
print(f"Done loading in {time.time() - start}s.")
def __len__(self):
return len(self.data)
def __getitem__(
self, idx: int
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict, str, int, str
]:
"""Given an index, return the corresponding experiment_id and timestep in the experiment.
        Then pick the appropriate data and labels from these.
Args:
            idx (int): Index of the sample to load.
Returns:
tuple: torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict
Returns the following:
- img_kinect: torch.Tensor of shape (T, C, H, W) (Default is [T, 3, 720, 1280])
- img_tracker: torch.Tensor of shape (T, H, W, C)
- img_battery: torch.Tensor of shape (T, H, W, C)
- skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 1)
- skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 2)
- bbox: torch.Tensor of shape (T, num_obj, 5)
- tracker to skeleton ID: str (either skeleton 1 or 2)
- tracker_2D: torch.Tensor of shape (T, 2)
- labels: dict (see below)
"""
labels = self.label_map[self.labels[idx]]
experiment_id = self.data[idx][1][0].split('/')[6]
img_data_path = f"{self.img_path}/{experiment_id}"
frame_ids = [int(os.path.basename(self.data[idx][1][i]).split('_')[0]) for i in range(len(self.data[idx][1]))]
if self.use_preprocessed_img:
kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.pt")))
tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.pt")))
battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.pt")))
kinect_img_paths = [kinect[id] for id in frame_ids]
tracker_img_paths = [tracker[id] for id in frame_ids]
battery_img_paths = [battery[id] for id in frame_ids]
else:
kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.jpg")))
tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.jpg")))
battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.jpg")))
kinect_img_paths = [kinect[id] for id in frame_ids]
tracker_img_paths = [tracker[id] for id in frame_ids]
battery_img_paths = [battery[id] for id in frame_ids]
# load images
kinect_imgs = [
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
for img_path in kinect_img_paths
]
kinect_imgs = torch.stack(kinect_imgs, axis=0)
tracker_imgs = [
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
for img_path in tracker_img_paths
]
tracker_imgs = torch.stack(tracker_imgs, axis=0)
battery_imgs = [
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
for img_path in battery_img_paths
]
battery_imgs = torch.stack(battery_imgs, axis=0)
# load object id to check for false beliefs - only for testing
if self.mode == "test": #or self.mode == "train":
if f"{experiment_id}.txt" in os.listdir(self.false_beliefs_path):
obj_id = self.obj_ids[experiment_id][frame_ids[-1]]
obj_id = next(x for x in obj_id if x is not None)
false_belief = next((line.strip().split(',')[2] for line in open(f"{self.false_beliefs_path}/{experiment_id}.txt") if line.startswith(str(frame_ids[-1]) + ',' + obj_id + ',')), "no")
#if experiment_id in ['test_boelter4_0', 'test_boelter4_7', 'test_boelter4_6', 'test_boelter4_8', 'test_boelter2_3',
# 'test_94342_20', 'test_94342_18', 'test_94342_11', 'test_94342_17', 'test_boelter3_8', 'test_94342_2',
# 'test_boelter2_17', 'test_boelter3_7', 'test_94342_4', 'test_boelter3_9', 'test_boelter_10',
# 'test_boelter2_6', 'test_boelter4_10', 'test_boelter4_2', 'test_boelter4_5', 'test_94342_24',
# 'test_94342_15', 'test_boelter3_5', 'test_94342_8', 'test2', 'test_boelter3_12']:
# print('here!')
# with open(os.path.join(f'results/hgm_test_fb.csv'), mode='a') as file:
# writer = csv.writer(file)
# writer.writerow([experiment_id, obj_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
else:
false_belief = "no"
#with open(os.path.join(f'results/test_fb.csv'), mode='a') as file:
# writer = csv.writer(file)
# writer.writerow([experiment_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
df = self.bbox_df[
(self.bbox_df.experiment_name == experiment_id)
#& (self.bbox_df.name == obj_id) # NOTE: load the bounding boxes for all objects
& (self.bbox_df.name != 'P1')
& (self.bbox_df.name != 'P2')
& (self.bbox_df.frame.isin(frame_ids))
]
bboxes = []
for f in frame_ids:
bbox = torch.tensor(df.loc[df['frame'] == f, ["x_min", "y_min", "x_max", "y_max"]].to_numpy(), dtype=torch.float32)
bbox[:, 0] = bbox[:, 0] / 1280.0
bbox[:, 1] = bbox[:, 1] / 720.0
bbox[:, 2] = bbox[:, 2] / 1280.0
bbox[:, 3] = bbox[:, 3] / 720.0
bboxes.append(bbox)
bboxes = torch.stack(bboxes) # NOTE: this will need a collate function bc not every video has the same number of objects
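# NOTE: bounding-box coordinates are normalised to [0, 1] above using the Kinect
# frame resolution (1280 x 720), matching the default img_kinect size in the docstring.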
skele1 = self.skeleton_3D[experiment_id]["skele1"][frame_ids]
skele2 = self.skeleton_3D[experiment_id]["skele2"][frame_ids]
gaze = self.tracker_2D[experiment_id][frame_ids]
if self.mode == "test":
return (
kinect_imgs,
tracker_imgs,
battery_imgs,
skele1,
skele2,
bboxes,
tracker_skeID[experiment_id], # <- This is the tracker skeleton ID
gaze,
labels, # <- per object "m1", "m2", "m12", "m21", "mc"
experiment_id,
frame_ids,
#self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding
false_belief
)
else:
return (
kinect_imgs,
tracker_imgs,
battery_imgs,
skele1,
skele2,
bboxes,
tracker_skeID[experiment_id], # <- This is the tracker skeleton ID
gaze,
labels, # <- per object "m1", "m2", "m12", "m21", "mc"
experiment_id,
frame_ids
#self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding
)
def onehot(self, x, n=len(UNIQUE_OBJ_IDS)):
retval = torch.zeros(n)
if x > 0:
retval[x-1] = 1
return retval
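# Illustrative example: onehot(3) returns a length-len(UNIQUE_OBJ_IDS) vector with a 1
# at index 2, while onehot(0) returns the all-zero vector (presumably the "no object" case).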
def load_obj_ids(self, path: str):
with open(path, "rb") as f:
ids = pickle.load(f)
return ids
def extract_labels(self):
"""TODO: Converts index label to [m1, m2, m12, m21, mc] format.
"""
return
def _flatten_mind_obj_timestep(self, mind_obj_dict: dict) -> list:
"""Flattens the mind object dict to a list. I.e. takes
{
"m1": {
"fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange
"loc": null
},
"m2": {
"fluent": 3,
"loc": null
},
"m12": {
"fluent": 3,
"loc": null
},
"m21": {
"fluent": 3,
"loc": null
},
"mc": {
"fluent": 3,
"loc": null
},
"mg": {
"fluent": 3,
"loc": [
22,
9
]
}
}
and returns [3, 3, 3, 3, 3] (the "mg" entry is excluded).
Args:
mind_obj_dict (dict): Mind object dict as described in the __init__ docstring.
Returns:
np.ndarray: Array of mind fluent labels (excluding "mg").
"""
return np.array([mind_obj["fluent"] for key, mind_obj in mind_obj_dict.items() if key != "mg"])
def load_skeleton_3D(self, path: str, list_of_ids_to_consider: list):
"""Load skeleton 3D data from disk.
- path
- * <- list of ids
- skele1.p <- 3D coord per id and timestep
- skele2.p <-
Args:
path (str): Where the skeleton 3D data lie.
list_of_ids_to_consider (list): List of ids to consider.
Defaults to None which means all ids. Otherwise specify a list,
e.g. ["test_94342_23", "test_boelter_21", ...].
Returns:
dict: skeleton 3D data as described above in the __init__ docstring.
"""
skeleton_3D = {}
for experiment_id in list_of_ids_to_consider:
skeleton_3D[experiment_id] = {}
with open(f"{path}/{experiment_id}/skele1.p", "rb") as f:
skeleton_3D[experiment_id]["skele1"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
with open(f"{path}/{experiment_id}/skele2.p", "rb") as f:
skeleton_3D[experiment_id]["skele2"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
return skeleton_3D
def load_tracker_2D(self, path: str, list_of_ids_to_consider: list):
"""Load tracker 2D data from disk.
- path
- *.p <- 2D coord per id and timestep
Args:
path (str): Where the tracker 2D data lie.
list_of_ids_to_consider (list): List of ids to consider.
Defaults to None which means all ids. Otherwise specify a list,
e.g. ["test_94342_23", "test_boelter_21", ...].
Returns:
dict: tracker 2D data.
"""
tracker_2D = {}
for experiment_id in list_of_ids_to_consider:
with open(f"{path}/{experiment_id}.p", "rb") as f:
tracker_2D[experiment_id] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
return tracker_2D
def load_bbox(self, path: str, list_of_ids_to_consider: list):
"""Load bbox data from disk.
- bbox_tensors.pickle <- bbox per experiment id one tensor
Args:
path (str): Where the bbox data lie.
list_of_ids_to_consider (list): List of ids to consider.
Returns:
dict: bbox data.
"""
with open(path, "rb") as f:
pickle_data = pickle.load(f)
for key in CLIPS_IDS_88:
if key not in list_of_ids_to_consider:
pickle_data.pop(key, None)
return pickle_data
if __name__ == '__main__':
# os.environ['PYTHONHASHSEED'] = str(42)
# torch.manual_seed(42)
# np.random.seed(42)
# random.seed(42)
data = TBDDataset(use_preprocessed_img=True, mode="test")
from tqdm import tqdm
for i in tqdm(range(data.__len__())):
data[i]
breakpoint()
from torch.utils.data import DataLoader
# Just for guessing time
data_0=data[0]
data_last=data[len(data)-1]
idx = np.random.randint(1, len(data)-1) # Something in between.
start = time.time()
(
kinect_imgs, # <- len x 720 x 1280 x 3 originally, likely smaller now
tracker_imgs,
battery_imgs,
skele1,
skele2,
bbox,
tracker_skeID_sample, # <- This is the tracker skeleton ID
tracker2d,
label,
experiment_id, # From here for debugging
timestep,
#obj_id, # <- This is the object ID as a one-hot
false_belief
) = data[idx]
end = time.time()
print(f"Time for one sample: {end-start}")
print('kinect:', kinect_imgs.shape)
print('tracker:', tracker_imgs.shape)
print('battery:', battery_imgs.shape)
print('skele1:', skele1.shape)
print('skele2:', skele2.shape)
print('gaze:', tracker2d.shape)
print('bbox:', bbox.shape)
print('label:', label)
#breakpoint()
dl = DataLoader(
data,
batch_size=4,
shuffle=False,
collate_fn=collate_fn
)
from tqdm import tqdm
for j, batch in tqdm(enumerate(dl)):
#print(j, end='\r')
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
#breakpoint()
#print(img_3rd_pov.shape)
#print(img_tracker.shape)
#print(img_battery.shape)
#print(pose1.shape, pose2.shape)
#print(bbox.shape)
#print(gaze.shape)
breakpoint()

196
tbd/test.py Normal file
View file

@@ -0,0 +1,196 @@
import torch
import csv
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
import random
import os
import numpy as np
from tbd_dataloader import TBDDataset, collate_fn_test
from models.common_mind import CommonMindToMnet
from models.sl import SLToMnet
from models.implicit import ImplicitToMnet
from utils.helpers import compute_f1_scores
def test(args):
test_dataset = TBDDataset(
path=args.data_path,
mode="test",
use_preprocessed_img=True
)
test_dataloader = DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
collate_fn=collate_fn_test
)
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
# model
if args.model_type == 'tom_cm':
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_sl':
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_impl':
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
else: raise NotImplementedError
model.load_state_dict(torch.load(args.load_model_path, map_location=device))
model.device = device
model.eval()
if args.save_preds:
# Define the output file path
folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}'
if not os.path.exists(folder_path):
os.makedirs(folder_path)
print(f'Saving predictions in {folder_path}.')
print('Testing...')
m1_pred_list = []
m2_pred_list = []
m12_pred_list = []
m21_pred_list = []
mc_pred_list = []
m1_label_list = []
m2_label_list = []
m12_label_list = []
m21_label_list = []
mc_label_list = []
with torch.no_grad():
for j, batch in tqdm(enumerate(test_dataloader)):
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep, false_belief = batch
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, repr = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
m1_pred = m1_pred.reshape(-1, 4)
m2_pred = m2_pred.reshape(-1, 4)
m12_pred = m12_pred.reshape(-1, 4)
m21_pred = m21_pred.reshape(-1, 4)
mc_pred = mc_pred.reshape(-1, 4)
m1_label = labels[:, 0].reshape(-1).to(device)
m2_label = labels[:, 1].reshape(-1).to(device)
m12_label = labels[:, 2].reshape(-1).to(device)
m21_label = labels[:, 3].reshape(-1).to(device)
mc_label = labels[:, 4].reshape(-1).to(device)
m1_pred_list.append(m1_pred)
m2_pred_list.append(m2_pred)
m12_pred_list.append(m12_pred)
m21_pred_list.append(m21_pred)
mc_pred_list.append(mc_pred)
m1_label_list.append(m1_label)
m2_label_list.append(m2_label)
m12_label_list.append(m12_label)
m21_label_list.append(m21_label)
mc_label_list.append(mc_label)
if args.save_preds:
torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt"))
data = [(
i,
torch.argmax(m1_pred[i]).cpu().numpy(),
torch.argmax(m2_pred[i]).cpu().numpy(),
torch.argmax(m12_pred[i]).cpu().numpy(),
torch.argmax(m21_pred[i]).cpu().numpy(),
torch.argmax(mc_pred[i]).cpu().numpy(),
m1_label[i].cpu().numpy(),
m2_label[i].cpu().numpy(),
m12_label[i].cpu().numpy(),
m21_label[i].cpu().numpy(),
mc_label[i].cpu().numpy(),
false_belief[i]) for i in range(len(labels))
]
header = ['frame', 'm1_pred', 'm2_pred', 'm12_pred', 'm21_pred', 'mc_pred', 'm1_label', 'm2_label', 'm12_label', 'm21_label', 'mc_label', 'false_belief']
with open(os.path.join(folder_path, f'{j}.csv'), mode='w', newline='') as file:
writer = csv.writer(file)
writer.writerow(header) # Write the header row
writer.writerows(data) # Write the data rows
#np.savetxt('m1_label_bs1.txt', torch.cat(m1_label_list).cpu().numpy())
test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1 = compute_f1_scores(
torch.cat(m1_pred_list),
torch.cat(m1_label_list),
torch.cat(m2_pred_list),
torch.cat(m2_label_list),
torch.cat(m12_pred_list),
torch.cat(m12_label_list),
torch.cat(m21_pred_list),
torch.cat(m21_label_list),
torch.cat(mc_pred_list),
torch.cat(mc_label_list)
)
print("Test m1 F1: {}".format(test_m1_f1))
print("Test m2 F1: {}".format(test_m2_f1))
print("Test m12 F1: {}".format(test_m12_f1))
print("Test m21 F1: {}".format(test_m21_f1))
print("Test mc F1: {}".format(test_mc_f1))
with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f:
f.write(f"Test data:\n {[data[1] for data in test_dataset.data]}\n")
f.write(f"m1 f1: {test_m1_f1}\n")
f.write(f"m2 f1: {test_m2_f1}\n")
f.write(f"m12 f1: {test_m12_f1}\n")
f.write(f"m21 f1: {test_m21_f1}\n")
f.write(f"mc f1: {test_mc_f1}\n")
f.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Define the command-line arguments
parser.add_argument('--gpu_id', type=int)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--presaved', type=int, default=128)
parser.add_argument('--non_blocking', action='store_true')
parser.add_argument('--num_workers', type=int, default=16)
parser.add_argument('--pin_memory', action='store_true')
parser.add_argument('--model_type', type=str)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--aggr', type=str, default='concat', required=False)
parser.add_argument('--use_resnet', action='store_true')
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
parser.add_argument('--save_path', type=str, default='experiments/')
parser.add_argument('--test_frames', type=str, default=None)
parser.add_argument('--median', type=int, default=None)
parser.add_argument('--load_model_path', type=str)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--save_preds', action='store_true')
# Parse the command-line arguments
args = parser.parse_args()
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
if not args.aggr:
parser.error("The chosen --model_type requires --aggr")
if args.model_type == 'tom_sl' and not args.tom_weight:
parser.error("The chosen --model_type requires --tom_weight")
os.environ['PYTHONHASHSEED'] = str(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
print('###########################################################################')
print('TESTING: MAKE SURE YOU ARE USING THE SAME RANDOM SEED USED DURING TRAINING!')
print('###########################################################################')
test(args)

474
tbd/train.py Normal file
View file

@@ -0,0 +1,474 @@
import torch
import os
import argparse
import numpy as np
import random
import datetime
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from tbd_dataloader import TBDDataset, collate_fn
from models.common_mind import CommonMindToMnet
from models.sl import SLToMnet
from models.implicit import ImplicitToMnet
from utils.helpers import count_parameters, compute_f1_scores
def main(args):
train_dataset = TBDDataset(
path=args.data_path,
mode="train",
use_preprocessed_img=True
)
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
collate_fn=collate_fn
)
val_dataset = TBDDataset(
path=args.data_path,
mode="val",
use_preprocessed_img=True
)
val_dataloader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
collate_fn=collate_fn
)
train_data = [data[1] for data in train_dataset.data]
val_data = [data[1] for data in val_dataset.data]
if args.logger:
wandb.config.update({"train_data": train_data})
wandb.config.update({"val_data": val_data})
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
# model
if args.model_type == 'tom_cm':
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_sl':
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_impl':
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
else: raise NotImplementedError
if args.resume_from_checkpoint is not None:
model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# scheduler
if args.scheduler is None:
scheduler = None
else:
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
# loss function
if args.model_type == 'tom_sl':
ce_loss_m1 = nn.NLLLoss()
ce_loss_m2 = nn.NLLLoss()
ce_loss_m12 = nn.NLLLoss()
ce_loss_m21 = nn.NLLLoss()
ce_loss_mc = nn.NLLLoss()
else:
ce_loss_m1 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
ce_loss_m2 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
ce_loss_m12 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
ce_loss_m21 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
ce_loss_mc = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
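# NOTE: the speaker-listener variant uses NLLLoss, which presumably means SLToMnet
# returns log-probabilities, whereas the other variants output raw logits and use
# CrossEntropyLoss (optionally with label smoothing).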
stats = {
'train': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []},
'val': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []}
}
max_val_f1 = 0
max_val_classification_epoch = None
counter = 0
print(f'Number of parameters: {count_parameters(model)}')
for i in range(args.num_epoch):
# training
print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
epoch_train_loss_m1 = 0.0
epoch_train_loss_m2 = 0.0
epoch_train_loss_m12 = 0.0
epoch_train_loss_m21 = 0.0
epoch_train_loss_mc = 0.0
m1_train_batch_pred_list = []
m2_train_batch_pred_list = []
m12_train_batch_pred_list = []
m21_train_batch_pred_list = []
mc_train_batch_pred_list = []
m1_train_batch_label_list = []
m2_train_batch_label_list = []
m12_train_batch_label_list = []
m21_train_batch_label_list = []
mc_train_batch_label_list = []
model.train()
for j, batch in tqdm(enumerate(train_dataloader)):
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
m1_pred = m1_pred.reshape(-1, 4)
m2_pred = m2_pred.reshape(-1, 4)
m12_pred = m12_pred.reshape(-1, 4)
m21_pred = m21_pred.reshape(-1, 4)
mc_pred = mc_pred.reshape(-1, 4)
m1_label = labels[:, 0].reshape(-1).to(device)
m2_label = labels[:, 1].reshape(-1).to(device)
m12_label = labels[:, 2].reshape(-1).to(device)
m21_label = labels[:, 3].reshape(-1).to(device)
mc_label = labels[:, 4].reshape(-1).to(device)
loss_m1 = ce_loss_m1(m1_pred, m1_label)
loss_m2 = ce_loss_m2(m2_pred, m2_label)
loss_m12 = ce_loss_m12(m12_pred, m12_label)
loss_m21 = ce_loss_m21(m21_pred, m21_label)
loss_mc = ce_loss_mc(mc_pred, mc_label)
loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc
epoch_train_loss_m1 += loss_m1.data.item()
epoch_train_loss_m2 += loss_m2.data.item()
epoch_train_loss_m12 += loss_m12.data.item()
epoch_train_loss_m21 += loss_m21.data.item()
epoch_train_loss_mc += loss_mc.data.item()
m1_train_batch_pred_list.append(m1_pred)
m2_train_batch_pred_list.append(m2_pred)
m12_train_batch_pred_list.append(m12_pred)
m21_train_batch_pred_list.append(m21_pred)
mc_train_batch_pred_list.append(mc_pred)
m1_train_batch_label_list.append(m1_label)
m2_train_batch_label_list.append(m2_label)
m12_train_batch_label_list.append(m12_label)
m21_train_batch_label_list.append(m21_label)
mc_train_batch_label_list.append(mc_label)
optimizer.zero_grad()
loss.backward()
# clip gradients after backward(), otherwise clipping acts on stale/zeroed grads
if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
optimizer.step()
if args.logger: wandb.log({
'batch_train_loss': loss.data.item(),
'lr': optimizer.param_groups[-1]['lr']
})
print("Epoch {}/{} batch {}/{} training done with loss={}".format(
i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item())
)
if scheduler: scheduler.step()
train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score = compute_f1_scores(
torch.cat(m1_train_batch_pred_list),
torch.cat(m1_train_batch_label_list),
torch.cat(m2_train_batch_pred_list),
torch.cat(m2_train_batch_label_list),
torch.cat(m12_train_batch_pred_list),
torch.cat(m12_train_batch_label_list),
torch.cat(m21_train_batch_pred_list),
torch.cat(m21_train_batch_label_list),
torch.cat(mc_train_batch_pred_list),
torch.cat(mc_train_batch_label_list)
)
print("Epoch {}/{} OVERALL train m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format(
i+1,
args.num_epoch,
epoch_train_loss_m1/len(train_dataloader),
epoch_train_loss_m2/len(train_dataloader),
epoch_train_loss_m12/len(train_dataloader),
epoch_train_loss_m21/len(train_dataloader),
epoch_train_loss_mc/len(train_dataloader),
train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score
)
)
stats['train']['loss_m1'].append(epoch_train_loss_m1/len(train_dataloader))
stats['train']['loss_m2'].append(epoch_train_loss_m2/len(train_dataloader))
stats['train']['loss_m12'].append(epoch_train_loss_m12/len(train_dataloader))
stats['train']['loss_m21'].append(epoch_train_loss_m21/len(train_dataloader))
stats['train']['loss_mc'].append(epoch_train_loss_mc/len(train_dataloader))
stats['train']['m1_f1'].append(train_m1_f1_score)
stats['train']['m2_f1'].append(train_m2_f1_score)
stats['train']['m12_f1'].append(train_m12_f1_score)
stats['train']['m21_f1'].append(train_m21_f1_score)
stats['train']['mc_f1'].append(train_mc_f1_score)
if args.logger: wandb.log(
{
'train_m1_loss': epoch_train_loss_m1/len(train_dataloader),
'train_m2_loss': epoch_train_loss_m2/len(train_dataloader),
'train_m12_loss': epoch_train_loss_m12/len(train_dataloader),
'train_m21_loss': epoch_train_loss_m21/len(train_dataloader),
'train_mc_loss': epoch_train_loss_mc/len(train_dataloader),
'train_loss': epoch_train_loss_m1/len(train_dataloader) + \
epoch_train_loss_m2/len(train_dataloader) + \
epoch_train_loss_m12/len(train_dataloader) + \
epoch_train_loss_m21/len(train_dataloader) + \
epoch_train_loss_mc/len(train_dataloader),
'train_m1_f1_score': train_m1_f1_score,
'train_m2_f1_score': train_m2_f1_score,
'train_m12_f1_score': train_m12_f1_score,
'train_m21_f1_score': train_m21_f1_score,
'train_mc_f1_score': train_mc_f1_score
}
)
# validation
print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
epoch_val_loss_m1 = 0.0
epoch_val_loss_m2 = 0.0
epoch_val_loss_m12 = 0.0
epoch_val_loss_m21 = 0.0
epoch_val_loss_mc = 0.0
m1_val_batch_pred_list = []
m2_val_batch_pred_list = []
m12_val_batch_pred_list = []
m21_val_batch_pred_list = []
mc_val_batch_pred_list = []
m1_val_batch_label_list = []
m2_val_batch_label_list = []
m12_val_batch_label_list = []
m21_val_batch_label_list = []
mc_val_batch_label_list = []
model.eval()
with torch.no_grad():
for j, batch in tqdm(enumerate(val_dataloader)):
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
m1_pred = m1_pred.reshape(-1, 4)
m2_pred = m2_pred.reshape(-1, 4)
m12_pred = m12_pred.reshape(-1, 4)
m21_pred = m21_pred.reshape(-1, 4)
mc_pred = mc_pred.reshape(-1, 4)
m1_label = labels[:, 0].reshape(-1).to(device)
m2_label = labels[:, 1].reshape(-1).to(device)
m12_label = labels[:, 2].reshape(-1).to(device)
m21_label = labels[:, 3].reshape(-1).to(device)
mc_label = labels[:, 4].reshape(-1).to(device)
loss_m1 = ce_loss_m1(m1_pred, m1_label)
loss_m2 = ce_loss_m2(m2_pred, m2_label)
loss_m12 = ce_loss_m12(m12_pred, m12_label)
loss_m21 = ce_loss_m21(m21_pred, m21_label)
loss_mc = ce_loss_mc(mc_pred, mc_label)
loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc
epoch_val_loss_m1 += loss_m1.data.item()
epoch_val_loss_m2 += loss_m2.data.item()
epoch_val_loss_m12 += loss_m12.data.item()
epoch_val_loss_m21 += loss_m21.data.item()
epoch_val_loss_mc += loss_mc.data.item()
m1_val_batch_pred_list.append(m1_pred)
m2_val_batch_pred_list.append(m2_pred)
m12_val_batch_pred_list.append(m12_pred)
m21_val_batch_pred_list.append(m21_pred)
mc_val_batch_pred_list.append(mc_pred)
m1_val_batch_label_list.append(m1_label)
m2_val_batch_label_list.append(m2_label)
m12_val_batch_label_list.append(m12_label)
m21_val_batch_label_list.append(m21_label)
mc_val_batch_label_list.append(mc_label)
if args.logger: wandb.log({'batch_val_loss': loss.data.item()})
print("Epoch {}/{} batch {}/{} validation done with loss={}".format(
i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item())
)
val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score = compute_f1_scores(
torch.cat(m1_val_batch_pred_list),
torch.cat(m1_val_batch_label_list),
torch.cat(m2_val_batch_pred_list),
torch.cat(m2_val_batch_label_list),
torch.cat(m12_val_batch_pred_list),
torch.cat(m12_val_batch_label_list),
torch.cat(m21_val_batch_pred_list),
torch.cat(m21_val_batch_label_list),
torch.cat(mc_val_batch_pred_list),
torch.cat(mc_val_batch_label_list)
)
print("Epoch {}/{} OVERALL validation m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format(
i+1,
args.num_epoch,
epoch_val_loss_m1/len(val_dataloader),
epoch_val_loss_m2/len(val_dataloader),
epoch_val_loss_m12/len(val_dataloader),
epoch_val_loss_m21/len(val_dataloader),
epoch_val_loss_mc/len(val_dataloader),
val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score
)
)
stats['val']['loss_m1'].append(epoch_val_loss_m1/len(val_dataloader))
stats['val']['loss_m2'].append(epoch_val_loss_m2/len(val_dataloader))
stats['val']['loss_m12'].append(epoch_val_loss_m12/len(val_dataloader))
stats['val']['loss_m21'].append(epoch_val_loss_m21/len(val_dataloader))
stats['val']['loss_mc'].append(epoch_val_loss_mc/len(val_dataloader))
stats['val']['m1_f1'].append(val_m1_f1_score)
stats['val']['m2_f1'].append(val_m2_f1_score)
stats['val']['m12_f1'].append(val_m12_f1_score)
stats['val']['m21_f1'].append(val_m21_f1_score)
stats['val']['mc_f1'].append(val_mc_f1_score)
if args.logger: wandb.log(
{
'val_m1_loss': epoch_val_loss_m1/len(val_dataloader),
'val_m2_loss': epoch_val_loss_m2/len(val_dataloader),
'val_m12_loss': epoch_val_loss_m12/len(val_dataloader),
'val_m21_loss': epoch_val_loss_m21/len(val_dataloader),
'val_mc_loss': epoch_val_loss_mc/len(val_dataloader),
'val_loss': epoch_val_loss_m1/len(val_dataloader) + \
epoch_val_loss_m2/len(val_dataloader) + \
epoch_val_loss_m12/len(val_dataloader) + \
epoch_val_loss_m21/len(val_dataloader) + \
epoch_val_loss_mc/len(val_dataloader),
'val_m1_f1_score': val_m1_f1_score,
'val_m2_f1_score': val_m2_f1_score,
'val_m12_f1_score': val_m12_f1_score,
'val_m21_f1_score': val_m21_f1_score,
'val_mc_f1_score': val_mc_f1_score
}
)
# check for best stat/model using validation accuracy
if stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1] >= max_val_f1:
max_val_f1 = stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1]
max_val_classification_epoch = i+1
torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
counter = 0
else:
counter += 1
print(f'EarlyStopping counter: {counter} out of {args.patience}.')
if counter >= args.patience:
break
with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
f.write('{}\n'.format(CFG))
f.write('{}\n'.format(train_data))
f.write('{}\n'.format(val_data))
f.write('{}\n'.format(stats))
f.write('Max val classification F1: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_f1))
f.close()
print(f'Results saved in {experiment_save_path}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Define the command-line arguments
parser.add_argument('--gpu_id', type=int)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--logger', action='store_true')
parser.add_argument('--presaved', type=int, default=128)
parser.add_argument('--clip_grad_norm', type=float, default=0.5)
parser.add_argument('--use_mixup', action='store_true')
parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
parser.add_argument('--non_blocking', action='store_true')
parser.add_argument('--patience', type=int, default=99)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--pin_memory', action='store_true')
parser.add_argument('--num_epoch', type=int, default=300)
parser.add_argument('--lr', type=float, default=4e-4)
parser.add_argument('--scheduler', type=str, default=None)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--weight_decay', type=float, default=0.005)
parser.add_argument('--label_smoothing', type=float, default=0.1)
parser.add_argument('--model_type', type=str)
parser.add_argument('--aggr', type=str, default='concat', required=False)
parser.add_argument('--use_resnet', action='store_true')
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
parser.add_argument('--save_path', type=str, default='experiments/')
parser.add_argument('--resume_from_checkpoint', type=str, default=None)
# Parse the command-line arguments
args = parser.parse_args()
if args.use_mixup and not args.mixup_alpha:
parser.error("--use_mixup requires --mixup_alpha")
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
if not args.aggr:
parser.error("The chosen --model_type requires --aggr")
if args.model_type == 'tom_sl' and not args.tom_weight:
parser.error("The chosen --model_type requires --tom_weight")
# get experiment ID
experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
if not os.path.exists(args.save_path):
os.makedirs(args.save_path, exist_ok=True)
experiment_save_path = os.path.join(args.save_path, experiment_id)
os.makedirs(experiment_save_path, exist_ok=True)
CFG = {
'use_ocr_custom_loss': 0,
'presaved': args.presaved,
'batch_size': args.batch_size,
'num_epoch': args.num_epoch,
'lr': args.lr,
'scheduler': args.scheduler,
'weight_decay': args.weight_decay,
'model_type': args.model_type,
'use_resnet': args.use_resnet,
'hidden_dim': args.hidden_dim,
'tom_weight': args.tom_weight,
'dropout': args.dropout,
'label_smoothing': args.label_smoothing,
'clip_grad_norm': args.clip_grad_norm,
'use_mixup': args.use_mixup,
'mixup_alpha': args.mixup_alpha,
'non_blocking_tensors': args.non_blocking,
'patience': args.patience,
'pin_memory': args.pin_memory,
'resume_from_checkpoint': args.resume_from_checkpoint,
'aggr': args.aggr,
'mods': args.mods,
'save_path': experiment_save_path ,
'seed': args.seed
}
print(CFG)
print(f'Saving results in {experiment_save_path}')
# set seed values
if args.logger:
wandb.init(project="tbd", config=CFG)
os.environ['PYTHONHASHSEED'] = str(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
main(args)

224
tbd/utils/fb_scores_err.py Normal file
View file

@@ -0,0 +1,224 @@
import os
import csv
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import f1_score
import numpy as np
from tqdm import tqdm
ALPHA = 0.7
BAR_WIDTH = 0.27
sns.set_theme(style='whitegrid')
#sns.set_palette('mako')
MTOM_COLORS = {
"MN1": (110/255, 117/255, 161/255),
"MN2": (179/255, 106/255, 98/255),
"Base": (193/255, 198/255, 208/255),
"CG": (170/255, 129/255, 42/255),
"IC": (97/255, 112/255, 83/255),
"DB": (144/255, 63/255, 110/255)
}
model_to_subdir = {
"IC$\parallel$": ["2023-07-16_10-34-32_train", "2023-07-18_13-49-57_train", "2023-07-19_12-17-46_train"],
"IC$\oplus$": ["2023-07-16_10-35-02_train", "2023-07-18_13-50-32_train", "2023-07-19_12-18-18_train"],
"IC$\otimes$": ["2023-07-16_10-35-41_train", "2023-07-18_13-52-26_train", "2023-07-19_12-18-49_train"],
"IC$\odot$": ["2023-07-16_10-36-04_train", "2023-07-18_13-53-03_train", "2023-07-19_12-19-50_train"],
"CG$\parallel$": ["2023-07-15_14-12-36_train", "2023-07-17_11-54-28_train", "2023-07-19_00-30-05_train"],
"CG$\oplus$": ["2023-07-15_14-14-08_train", "2023-07-17_11-56-05_train", "2023-07-19_00-30-47_train"],
"CG$\otimes$": ["2023-07-15_14-14-53_train", "2023-07-17_11-56-39_train", "2023-07-19_00-31-36_train"],
"CG$\odot$": ["2023-07-15_14-10-05_train", "2023-07-17_11-57-30_train", "2023-07-19_00-32-10_train"],
"DB": ["2023-08-08_12-56-02_train", "2023-08-08_19-07-43_train", "2023-08-08_19-08-47_train"],
"Base": ["2023-08-08_12-53-38_train", "2023-08-08_19-10-02_train", "2023-08-08_19-10-51_train"]
}
def read_data_from_csv(subdirectory_path):
print(subdirectory_path)
data = []
csv_files = [file for file in os.listdir(subdirectory_path) if file.endswith('.csv')]
for csv_file in csv_files:
file_path = os.path.join(subdirectory_path, csv_file)
with open(file_path, 'r') as file:
reader = csv.reader(file)
header_skipped = False
for row in reader:
if not header_skipped:
header_skipped = True
continue
frame, m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, m1_label, m2_label, m12_label, m21_label, mc_label, false_belief = row
data.append({
'frame': int(frame),
'm1_pred': int(m1_pred),
'm2_pred': int(m2_pred),
'm12_pred': int(m12_pred),
'm21_pred': int(m21_pred),
'mc_pred': int(mc_pred),
'm1_label': int(m1_label),
'm2_label': int(m2_label),
'm12_label': int(m12_label),
'm21_label': int(m21_label),
'mc_label': int(mc_label),
'false_belief': false_belief,
})
return data
def compute_correct_false_belief(data, mind="all", folder=None):
total_false_belief = 0
correct_false_belief = 0
for item in data:
if 'false' in item['false_belief']:
false_belief_type = item['false_belief'].split('_')[0]
if mind == "all" or false_belief_type in mind:
total_false_belief += 1
if item[f"{false_belief_type}_pred"] == item[f"{false_belief_type}_label"]:
if folder is not None:
with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
f.write(f"{str(1)}\n")
correct_false_belief += 1
else:
if folder is not None:
with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
f.write(f"{str(0)}\n")
if total_false_belief == 0:
accuracy = 0.0
else:
accuracy = correct_false_belief / total_false_belief
return accuracy
def compute_macro_f1_score(data, mind="all"):
y_true = []
y_pred = []
for item in data:
if 'false' in item['false_belief']:
false_belief_type = item['false_belief'].split('_')[0]
if mind == "all" or false_belief_type in mind:
y_true.append(int(item[f"{false_belief_type}_label"]))
y_pred.append(int(item[f"{false_belief_type}_pred"]))
if not y_true or not y_pred:
macro_f1 = 0.0
else:
macro_f1 = f1_score(y_true, y_pred, average='macro')
return macro_f1
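# Illustrative example (made-up values): with y_true = [0, 0, 1] and y_pred = [0, 1, 1],
# the per-class F1 scores are 2/3 for class 0 and 2/3 for class 1, so
# f1_score(..., average='macro') returns 2/3.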
def delete_files_in_subfolders(folder_path, file_names_to_delete):
"""
Delete specified files in all subfolders of a given folder.
Parameters:
folder_path: The path to the folder containing subfolders.
file_names_to_delete: A list of file names to be deleted.
Returns:
None
"""
for root, _, _ in os.walk(folder_path):
for file_name in file_names_to_delete:
file_path = os.path.join(root, file_name)
if os.path.exists(file_path):
os.remove(file_path)
print(f"Deleted: {file_path}")
if __name__ == "__main__":
folder_path = "predictions"
files_to_delete = ["fb_m1_m2_m12_m21.txt", "fb_m1_m2.txt", "fb_m12_m21.txt"]
delete_files_in_subfolders(folder_path, files_to_delete)
metric = "Accuracy"
if metric == "Macro F1":
score_function = compute_macro_f1_score
elif metric == "Accuracy":
score_function = compute_correct_false_belief
else:
raise ValueError
models = [
'Base', 'DB',
'CG$\parallel$', 'CG$\oplus$', 'CG$\otimes$', 'CG$\odot$',
'IC$\parallel$', 'IC$\oplus$', 'IC$\otimes$', 'IC$\odot$'
]
parent_dir = 'predictions'
minds = categories = ['m1', 'm2', 'm12', 'm21']
score_m1_m2 = []
score_m12_m21 = []
score_all = []
std_m1_m2 = []
std_m12_m21 = []
std_all = []
for model in models:
model_scores_m1_m2 = []
model_scores_m12_m21 = []
model_scores_all = []
for s in range(3):
subdir_path = os.path.join(parent_dir, model_to_subdir[model][s])
data = read_data_from_csv(subdir_path)
model_scores_m1_m2.append(score_function(data, ['m1', 'm2'], model_to_subdir[model][s]))
model_scores_m12_m21.append(score_function(data, ['m12', 'm21'], model_to_subdir[model][s]))
model_scores_all.append(score_function(data, ['m1', 'm2', 'm12', 'm21'], model_to_subdir[model][s]))
score_m1_m2.append(np.mean(model_scores_m1_m2))
std_m1_m2.append(np.std(model_scores_m1_m2))
score_m12_m21.append(np.mean(model_scores_m12_m21))
std_m12_m21.append(np.std(model_scores_m12_m21))
score_all.append(np.mean(model_scores_all))
std_all.append(np.std(model_scores_all))
# Collect per-model mean/std scores for plotting
data = {
'Model': [m for m in models],
'FO_FB_mean': score_m1_m2,
'FO_FB_std': std_m1_m2,
'SO_FB_mean': score_m12_m21,
'SO_FB_std': std_m12_m21,
'Both_mean': score_all,
'Both_std': std_all
}
models = data['Model']
fo_fb_mean = data['FO_FB_mean']
fo_fb_std = data['FO_FB_std']
so_fb_mean = data['SO_FB_mean']
so_fb_std = data['SO_FB_std']
both_mean = data['Both_mean']
both_std = data['Both_std']
bar_width = BAR_WIDTH
x = np.arange(len(models))
plt.figure(figsize=(13, 3.5))
fo_fb_bars = plt.bar(x - bar_width, fo_fb_mean, width=bar_width, yerr=fo_fb_std, capsize=4, label='First-order false belief', alpha=ALPHA)
so_fb_bars = plt.bar(x, so_fb_mean, width=bar_width, yerr=so_fb_std, capsize=4, label='Second-order false belief', alpha=ALPHA)
both_bars = plt.bar(x + bar_width, both_mean, width=bar_width, yerr=both_std, capsize=4, label='Both', alpha=ALPHA)
def add_labels(bars, std_values):
cnt = 0
for bar, std in zip(bars, std_values):
height = bar.get_height()
offset = std + 0.01
if cnt == 0 or cnt == 1 or cnt == 9:
plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}*', ha='center', va='bottom', fontsize=10)
else:
plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}', ha='center', va='bottom', fontsize=10)
cnt = cnt + 1
add_labels(fo_fb_bars, fo_fb_std)
add_labels(so_fb_bars, so_fb_std)
add_labels(both_bars, both_std)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.xlabel('MToMnet', fontsize=14)
plt.ylabel('Macro F1 Score' if metric == "Macro F1" else 'Accuracy', fontsize=14)
plt.xticks(x, models, rotation=0, fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=14, loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=3)
plt.tight_layout()
plt.savefig('results/false_belief_first_vs_second.pdf')

210
tbd/utils/helpers.py Normal file

File diff suppressed because one or more lines are too long

View file

@@ -0,0 +1,37 @@
import glob
import cv2
import torchvision.transforms as T
import torch
import os
from tqdm import tqdm
PATH_IN = "/scratch/bortoletto/data/tbd/images"
PATH_OUT = "/scratch/bortoletto/data/tbd/images_norm"
normalisation_steps = [
T.ToTensor(),
T.Resize((128,128)),
T.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
]
preprocess_img = T.Compose(normalisation_steps)
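# NOTE: cv2.imread returns images in BGR channel order, while the mean/std above are the
# standard ImageNet statistics quoted for RGB; if strict RGB normalisation is intended,
# an explicit cv2.cvtColor(img, cv2.COLOR_BGR2RGB) step would be needed before ToTensor.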
def main():
print(f"{PATH_IN}/*/*/*.jpg")
all_img = glob.glob(f"{PATH_IN}/*/*/*.jpg")
print(len(all_img))
for img_path in tqdm(all_img):
new_img = preprocess_img(cv2.imread(img_path)).numpy()
img_path_split = img_path.split("/")
os.makedirs(f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}", exist_ok=True)
out_img = f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}/{img_path_split[-1][:-4]}.pt"
torch.save(new_img, out_img)
if __name__ == '__main__':
main()

View file

@@ -0,0 +1,106 @@
import pandas as pd
import os
import glob
import pickle
DATASET_LOCATION = "YOUR_PATH_HERE"
def reframe_annotation():
annotation_path = f'{DATASET_LOCATION}/retrieve_annotation/all/'
save_path = f'{DATASET_LOCATION}/reformat_annotation/'
if not os.path.exists(save_path):
os.makedirs(save_path)
tasks = glob.glob(annotation_path + '*.txt')
id_map = pd.read_csv('id_map.csv')
for task in tasks:
if not task.split('/')[-1].split('_')[2] == '1.txt':
continue
with open(task, 'r') as f:
lines = f.readlines()
task_id = int(task.split('/')[-1].split('_')[1]) + 1
clip = id_map.loc[id_map['ID'] == task_id].folder
print(task_id, len(clip))
if len(clip) == 0:
continue
with open(save_path + clip.item() + '.txt', 'w') as f:
for line in lines:
words = line.split()
f.write(words[0] + ',' + words[1] + ',' + words[2] + ',' + words[3] + ',' + words[4] + ',' + words[5] +
',' + words[6] + ',' + words[7] + ',' + words[8] + ',' + words[9] + ',' + ' '.join(words[10:]) + '\n')
f.close()
def get_grid_location(obj_frame):
x_min = obj_frame['x_min']#.item()
y_min = obj_frame['y_min']#.item()
x_max = obj_frame['x_max']#.item()
y_max = obj_frame['y_max']#.item()
gridLW = 1280 / 25.
gridLH = 720 / 15.
center_x, center_y = (x_min + x_max)/2, (y_min + y_max)/2
X, Y = int(center_x / gridLW), int(center_y / gridLH)
return X, Y
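# Example: the 1280x720 frame is divided into a 25 x 15 grid of 51.2 x 48 px cells, so an
# object whose box is centred at (640, 360) maps to grid cell (12, 7).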
def regenerate_annotation():
annotation_path = f'{DATASET_LOCATION}/reformat_annotation/'
save_path=f'{DATASET_LOCATION}/regenerate_annotation/'
if not os.path.exists(save_path):
os.makedirs(save_path)
tasks = glob.glob(annotation_path + '*.txt')
for task in tasks:
print(task)
annt = pd.read_csv(task, sep=",", header=None)
annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name", "label"]
obj_records = {}
for index, obj_frame in annt.iterrows():
if obj_frame['name'].startswith('P'):
continue
else:
assert obj_frame['name'].startswith('O')
obj_name = obj_frame['name']
# 0: enter 1: disappear 2: update 3: unchange
frame_id = obj_frame['frame']
curr_loc = get_grid_location(obj_frame)
mind_dict = {'m1': {'fluent': 3, 'loc': None}, 'm2': {'fluent': 3, 'loc': None},
'm12': {'fluent': 3, 'loc': None},
'm21': {'fluent': 3, 'loc': None}, 'mc': {'fluent': 3, 'loc': None},
'mg': {'fluent': 3, 'loc': curr_loc}}
mind_dict['mg']['loc'] = curr_loc
if not type(obj_frame['label']) == float:
mind_labels = obj_frame['label'].split()
for mind_label in mind_labels:
if mind_label == 'in_m1' or mind_label == 'in_m2' or mind_label == 'in_m12' \
or mind_label == 'in_m21' or mind_label == 'in_mc' or mind_label == '"in_m1"' or mind_label == '"in_m2"'\
or mind_label == '"in_m12"' or mind_label == '"in_m21"' or mind_label == '"in_mc"':
mind_name = mind_label.split('_')[1].split('"')[0]
mind_dict[mind_name]['loc'] = curr_loc
else:
mind_name = mind_label.split('_')[0].split('"')
if len(mind_name) > 1:
mind_name = mind_name[1]
else:
mind_name = mind_name[0]
last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
mind_dict[mind_name]['loc'] = last_loc
for mind_name in mind_dict.keys():
if frame_id > 0:
curr_loc = mind_dict[mind_name]['loc']
last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
if last_loc is None and curr_loc is not None:
mind_dict[mind_name]['fluent'] = 0
elif last_loc is not None and curr_loc is None:
mind_dict[mind_name]['fluent'] = 1
elif not last_loc == curr_loc:
mind_dict[mind_name]['fluent'] = 2
if obj_name not in obj_records:
obj_records[obj_name] = [mind_dict]
else:
obj_records[obj_name].append(mind_dict)
with open(save_path + task.split('/')[-1].split('.')[0] + '.p', 'wb') as f:
pickle.dump(obj_records, f)
if __name__ == '__main__':
reframe_annotation()
regenerate_annotation()

75
tbd/utils/similarity.py Normal file
View file

@@ -0,0 +1,75 @@
import os
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import seaborn as sns
FOLDER_PATH = 'PATH_TO_FOLDER'
print(FOLDER_PATH)
MTOM_COLORS = {
"MN1": (110/255, 117/255, 161/255),
"MN2": (179/255, 106/255, 98/255),
"Base": (193/255, 198/255, 208/255),
"CG": (170/255, 129/255, 42/255),
"IC": (97/255, 112/255, 83/255),
"DB": (144/255, 63/255, 110/255)
}
COLORS = sns.color_palette()
sns.set_theme(style='white')
out_left_main_mods_full_test = []
out_right_main_mods_full_test = []
cell_left_main_mods_full_test = []
cell_right_main_mods_full_test = []
cm_left_main_mods_full_test = []
cm_right_main_mods_full_test = []
for i in range(len([filename for filename in os.listdir(FOLDER_PATH) if filename.endswith('.pt')])):
print(f'Computing analysis for test video {i}...', end='\r')
emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
data = torch.load(emb_file)
if len(data) == 13: # implicit
model = 'impl'
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
elif len(data) == 12: # common mind
model = 'cm'
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
elif len(data) == 11: # speaker-listener
model = 'sl'
out_left, out_right, feats = data[0], data[1], data[2:]
else: raise ValueError(f"Expected 13 (impl), 12 (cm) or 11 (sl) tensors, got {len(data)}")
# ====== PCA for left and right embeddings ====== #
out_left_pca = out_left[0].reshape(-1, 64)
out_right_pca = out_right[0].reshape(-1, 64)
out_left_and_right = np.concatenate((out_left_pca, out_right_pca), axis=0)
pca = PCA(n_components=2)
pca_result = pca.fit_transform(out_left_and_right)
# Separate the PCA results for each tensor
pca_result_left = pca_result[:out_left_pca.shape[0]]
pca_result_right = pca_result[out_right_pca.shape[0]:]
plt.figure(figsize=(6.8,6))
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
plt.xlabel('Principal Component 1', fontsize=32)
plt.ylabel('Principal Component 2', fontsize=32)
plt.xticks(fontsize=24)
plt.xticks([-0.4, -0.2, 0.0, 0.2, 0.4])
plt.yticks(fontsize=24)
plt.legend(fontsize=32)
plt.tight_layout()
plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
plt.close()

View file

@@ -0,0 +1,96 @@
import os
import pandas as pd
import pickle
from tqdm import tqdm
def check_append(obj_name, m1, mind_name, obj_frame, flags, label):
if label:
if not obj_name in m1:
m1[obj_name] = []
m1[obj_name].append(
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
flags[mind_name] = 1
elif not flags[mind_name]:
m1[obj_name].append(
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
flags[mind_name] = 1
else: # false belief
if obj_name in m1:
if flags[mind_name]:
m1[obj_name].append(
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 1])
flags[mind_name] = 0
return flags, m1
def store_mind_set(clip, annotation_path, save_path):
if not os.path.exists(save_path):
os.makedirs(save_path)
annt = pd.read_csv(annotation_path + clip, sep=",", header=None)
annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name",
"label"]
obj_names = annt.name.unique()
m1, m2, m12, m21, mc = {}, {}, {}, {}, {}
flags = {'m1':0, 'm2':0, 'm12':0, 'm21':0, 'mc':0}
for obj_name in obj_names:
if obj_name == 'P1' or obj_name == 'P2':
continue
obj_frames = annt.loc[annt.name == obj_name]
for index, obj_frame in obj_frames.iterrows():
if type(obj_frame.label) == float:
continue
labels = obj_frame.label.split()
for label in labels:
if label == 'in_m1' or label == '"in_m1"':
flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 1)
elif label == 'in_m2' or label == '"in_m2"':
flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 1)
elif label == 'in_m12'or label == '"in_m12"':
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 1)
elif label == 'in_m21' or label == '"in_m21"':
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 1)
elif label == 'in_mc'or label == '"in_mc"':
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 1)
elif label == 'm1_false' or label == '"m1_false"':
flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 0)
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
false_belief = 'm1_false'
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
elif label == 'm2_false' or label == '"m2_false"':
flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 0)
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
false_belief = 'm2_false'
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
elif label == 'm12_false' or label == '"m12_false"':
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
false_belief = 'm12_false'
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
elif label == 'm21_false' or label == '"m21_false"':
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
false_belief = 'm21_false'
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
# print('m1', m1)
# print('m2', m2)
# print('m12', m12)
# print('m21', m21)
# print('mc', mc)
#with open(save_path + clip.split('.')[0] + '.p', 'wb') as f:
# pickle.dump([m1, m2, m12, m21, mc], f)
if __name__ == "__main__":
annotation_path = '/scratch/bortoletto/data/tbd/reformat_annotation/'
save_path = '/scratch/bortoletto/data/tbd/store_mind_set/'
for clip in tqdm(os.listdir(annotation_path), desc="Processing videos", unit="item"):
store_mind_set(clip, annotation_path, save_path)

View file

@@ -0,0 +1,95 @@
import time
from tbd_dataloader import TBDv2Dataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def point2screen(points):
K = [607.13232421875, 0.0, 638.6468505859375, 0.0, 607.1067504882812, 367.1607360839844, 0.0, 0.0, 1.0]
K = np.reshape(np.array(K), [3, 3])
rot_points = np.array(points) + np.array([0, 0.2, 0])
points_camera = rot_points.reshape(3, 1)
project_matrix = np.array(K).reshape(3, 3)
points_prj = project_matrix.dot(points_camera)
points_prj = points_prj.transpose()
if not points_prj[:, 2][0] == 0.0:
points_prj[:, 0] = points_prj[:, 0] / points_prj[:, 2]
points_prj[:, 1] = points_prj[:, 1] / points_prj[:, 2]
points_screen = points_prj[:, :2]
assert points_screen.shape == (1, 2)
points_screen = points_screen.reshape(-1)
return points_screen
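# Sanity-check example: a 3D point at camera coordinates (0, -0.2, 1) becomes (0, 0, 1)
# after the fixed (0, 0.2, 0) offset and projects to roughly (638.6, 367.2), i.e. the
# principal point (cx, cy) of the intrinsics matrix K above.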
if __name__ == '__main__':
data = TBDv2Dataset(number_frames_to_sample=1, resize_img=None)
index = np.random.randint(0, len(data))
start = time.time()
(
kinect_imgs, # <- len x 720 x 1280 x 3
tracker_imgs,
battery_imgs,
skele1,
skele2,
bbox,
tracker_skeID_sample, # <- This is the tracker skeleton ID
tracker2d,
label,
experiment_id, # From here for debugging
timestep,
obj_id, # <- This is the object ID as a string
) = data[index]
end = time.time()
print(f"Time for one sample: {end-start}")
img = kinect_imgs[-1]
bbox = bbox[-1]
print(label.shape)
print(skele1.shape)
print(skele2.shape)
skele1 = skele1[-1, :,:]
skele2 = skele2[-1, :,:]
print(skele1.shape)
# reshape img from c, h, w to h, w, c
img = img.permute(1, 2, 0)
fig, ax = plt.subplots(1)
ax.imshow(img)
print(bbox[0], bbox[1], bbox[2], bbox[3]) # (top left x, top left y, width, height)
top_left_x, top_left_y, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
for i in range(26):
print(skele1[i,0], skele1[i,1])
print(skele1[i,:].shape)
print(point2screen(skele1[i,:]))
x, y = point2screen(skele1[i,:])[0], point2screen(skele1[i,:])[1]
ax.text(x, y, f"{i}", fontsize=5, color='w')
wedge = patches.Wedge((x,y), 10, 0, 360, width=10, color='b')
ax.add_patch(wedge)
for i in range(26):
x, y = point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]
ax.text(x, y, f"{i}", fontsize=5, color='w')
wedge = patches.Wedge((point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]), 10, 0, 360, width=10, color='r')
ax.add_patch(wedge)
# Create a Rectangle patch
# rect = patches.Rectangle((top_left_x, top_left_y-height), width, height, linewidth=1, edgecolor='r', facecolor='none')
# ax.add_patch(rect)
# rect = patches.Rectangle((x_min, y_max), x_max-x_min, y_max-y_min, linewidth=1, edgecolor='g', facecolor='none')
# ax.add_patch(rect)
fig.savefig(f"bbox_{obj_id}_{index}_{experiment_id}.png")