Matteo Bortoletto 2025-01-10 15:39:20 +01:00
parent d4aaf7f4ad
commit 25b8b3f343
55 changed files with 7592 additions and 4 deletions

206
boss/.gitignore vendored Normal file

@@ -0,0 +1,206 @@
experiments/
predictions_*
predictions/
tmp/
### WANDB ###
wandb/*
wandb
wandb-debug.log
logs
summary*
run*
### SYSTEM ###
*/__pycache__/*
# Created by https://www.toptal.com/developers/gitignore/api/python,linux
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# End of https://www.toptal.com/developers/gitignore/api/python,linux

16
boss/README.md Normal file

@@ -0,0 +1,16 @@
# BOSS
## Installing Dependencies
Run `conda env create -f environment.yml`, then activate the environment with `conda activate boss`.
## New bounding box annotations
We re-extracted the bounding box annotations using YOLOv8. The new data is in `new_bbox/`. The rest of the data can be found [here](https://drive.google.com/drive/folders/1b8FdpyoWx9gUps-BX6qbE9Kea3C2Uyua).
## Train
`source run_train.sh`.
## Test
`source run_test.sh` (specify the path to the model).
## Resources
The original project page for the BOSS dataset can be found [here](https://sites.google.com/view/bossbelief/). Our code is based on the original implementation.
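## Data loading example
A minimal sketch of how the dataset class in `dataloader.py` can be instantiated. All paths below are placeholders for your local copies of the frames, labels, poses, gazes, and the new YOLO-format bounding boxes; the loader additionally expects pre-extracted frame pickles in a `presaved128` (or `presaved`) directory next to the frame folders.

```python
from dataloader import Data

# All paths are placeholders -- adapt them to your local data layout.
train_data = Data(
    frame_path='data/frames',      # one sub-directory of extracted frames per episode
    label_path='data/labels.pkl',  # pickled (left_labels, right_labels) tuple
    pose_path='data/poses',        # per-episode directories of 2x25x3 keypoint files (cv2.FileStorage format)
    gaze_path='data/gazes',        # one .txt file with per-frame 3D gaze vectors per episode
    bbox_path='new_bbox',          # YOLO-format bounding box labels, one directory per episode
    presaved=128,                  # use the 128x128 pre-extracted frame tensors
)
images, labels, pose, gaze, bbox, ocr = train_data[0]
```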

500
boss/dataloader.py Normal file

@@ -0,0 +1,500 @@
from torch.utils.data import Dataset
import os
from torchvision import transforms
from natsort import natsorted
import numpy as np
import pickle
import torch
import cv2
import ast
from einops import rearrange
from utils import spherical2cartesial
class Data(Dataset):
def __init__(
self,
frame_path,
label_path,
pose_path=None,
gaze_path=None,
bbox_path=None,
ocr_graph_path=None,
presaved=128,
sizes = (128,128),
spatial_patch_size: int = 16,
temporal_patch_size: int = 4,
img_channels: int = 3,
patch_data: bool = False,
flatten_dim: int = 1
):
self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
self.size_1, self.size_2 = sizes
self.patch_data = patch_data
if self.patch_data:
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
if presaved == 224:
self.presaved = 'presaved'
elif presaved == 128:
self.presaved = 'presaved128'
else: raise ValueError
frame_dirs = os.listdir(frame_path)
episode_ids = natsorted(frame_dirs)
seq_lens = []
for frame_dir in natsorted(frame_dirs):
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
seq_lens.append(len(frame_paths))
self.data['frame_paths'].append(frame_paths)
print('Labels...')
with open(label_path, 'rb') as fp:
labels = pickle.load(fp)
left_labels, right_labels = labels
for i in episode_ids:
episode_id = int(i)
left_labels_i = left_labels[episode_id]
right_labels_i = right_labels[episode_id]
self.data['labels'].append([left_labels_i, right_labels_i])
fp.close()
print('Pose...')
self.pose_path = pose_path
if pose_path is not None:
for i in episode_ids:
pose_i_dir = os.path.join(pose_path, i)
poses = []
for j in natsorted(os.listdir(pose_i_dir)):
fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
if torch.tensor(fs.getNode("pose_0").mat()).shape != (2,25,3):
poses.append(fs.getNode("pose_0").mat()[:2,:,:])
else:
poses.append(fs.getNode("pose_0").mat())
poses = np.array(poses)
poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
#self.data['poses'].append(torch.tensor(np.array(poses)))
print('Gaze...')
"""
Gaze information is represented by a 3D gaze vector based on the observing cameras Cartesian eye coordinate system
"""
self.gaze_path = gaze_path
if gaze_path is not None:
for i in episode_ids:
gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
with open(gaze_i_txt, 'r') as fp:
gaze_i = fp.readlines()
fp.close()
gaze_i = ast.literal_eval(gaze_i[0])
gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
for j in range(len(gaze_i)):
if len(gaze_i[j]) >= 2:
gaze_i_tensor[j,:] = torch.tensor(gaze_i[j][:2])
elif len(gaze_i[j]) == 1:
gaze_i_tensor[j,0] = torch.tensor(gaze_i[j][0])
else:
continue
self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))
print('Bbox...')
self.bbox_path = bbox_path
if bbox_path is not None:
self.objects = [
'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
]
"""
# NOTE: old bbox
for i in episode_ids:
bbox_i_dir = os.path.join(bbox_path, i)
with open(bbox_i_dir, 'rb') as fp:
bboxes_i = pickle.load(fp)
len_i = len(bboxes_i)
fp.close()
bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
for j in range(len(bboxes_i)):
items_i_j, bboxes_i_j = bboxes_i[j]
for k in range(len(items_i_j)):
bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
bboxes_i_j[k][0] / 1920, # * self.size_1,
bboxes_i_j[k][1] / 1088, # * self.size_2,
bboxes_i_j[k][2] / 1920, # * self.size_1,
bboxes_i_j[k][3] / 1088, # * self.size_2
]) # [x_min, y_min, x_max, y_max]
self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
"""
# NOTE: new bbox
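# Each YOLO label file holds one line per detection: class_index x_center y_center width height, all normalized to [0, 1].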
for i in episode_ids:
bbox_dir = os.path.join(bbox_path, i)
bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4)) # TODO: we might want to cut it to 10 objects
for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)):
with open(os.path.join(bbox_dir, bbox), 'r') as fp:
bbox_content = fp.readlines()
fp.close()
for bbox_content_line in bbox_content:
bbox_content_values = bbox_content_line.split()
class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values)
bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))
print('OCR...\n')
self.ocr_graph_path = ocr_graph_path
if ocr_graph_path is not None:
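# Hand-coded object-context co-occurrence counts: each entry is [object_id, [context_id, count], ...].
# Counts are normalized per object and flattened into a 27x27 tensor shared by all episodes.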
ocr_graph = [
[15, [10, 4], [17, 2]],
[13, [16, 7], [18, 4]],
[11, [16, 4], [7, 10]],
[14, [10, 11], [7, 1]],
[12, [10, 9], [16, 3]],
[1, [7, 2], [9, 9], [10, 2]],
[5, [8, 8], [6, 8]],
[4, [9, 8], [7, 6]],
[3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
[2, [10, 1], [7, 7], [9, 3]],
[19, [10, 2], [26, 6]],
[20, [10, 7], [26, 5]],
[22, [25, 4], [10, 8]],
[23, [25, 15]],
[21, [16, 5], [24, 8]]
]
ocr_tensor = torch.zeros((27, 27))
for ocr in ocr_graph:
obj = ocr[0]
contexts = ocr[1:]
total_context_count = sum([i[1] for i in contexts])
for context in contexts:
ocr_tensor[obj, context[0]] = context[1] / total_context_count
ocr_tensor = torch.flatten(ocr_tensor)
for i in episode_ids:
self.data['ocr_graph'].append(ocr_tensor)
self.frame_path = frame_path
self.transform = transforms.Compose([
#transforms.ToPILImage(),
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.data['frame_paths'])
def __getitem__(self, idx):
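# Frames are not decoded here: the episode's frame directory is mapped to a single pre-extracted .pkl
# (under the presaved/presaved128 directory) holding one tensor per frame.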
frames_path = self.data['frame_paths'][idx][0].split('/')
frames_path.insert(5, self.presaved)
frames_path = ''.join(f'{w}/' for w in frames_path[:-1])[:-1]+'.pkl'
with open(frames_path, 'rb') as f:
images = pickle.load(f)
images = torch.stack(images)
if self.patch_data:
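# (T*p3, C, H, W) -> (T, p3, num_patches, spatial_patch_dim): group frames into temporal chunks of p3
# and cut each frame into p1 x p2 spatial patches.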
images = rearrange(
images,
'(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
p1=self.spatial_patch_size,
p2=self.spatial_patch_size,
p3=self.temporal_patch_size
)
if self.pose_path is not None:
pose = self.data['poses'][idx]
if self.patch_data:
pose = rearrange(
pose,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
pose = None
if self.gaze_path is not None:
gaze = self.data['gazes'][idx]
if self.patch_data:
gaze = rearrange(
gaze,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
gaze = None
if self.bbox_path is not None:
bbox = self.data['bboxes'][idx]
if self.patch_data:
bbox = rearrange(
bbox,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
bbox = None
if self.ocr_graph_path is not None:
ocr_graphs = self.data['ocr_graph'][idx]
else:
ocr_graphs = None
return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs
class DataTest(Dataset):
def __init__(
self,
frame_path,
label_path,
pose_path=None,
gaze_path=None,
bbox_path=None,
ocr_graph_path=None,
presaved=128,
frame_ids=None,
median=None,
sizes = (128,128),
spatial_patch_size: int = 16,
temporal_patch_size: int = 4,
img_channels: int = 3,
patch_data: bool = False,
flatten_dim: int = 1
):
self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
self.size_1, self.size_2 = sizes
self.patch_data = patch_data
if self.patch_data:
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
if presaved == 224:
self.presaved = 'presaved'
elif presaved == 128:
self.presaved = 'presaved128'
else: raise ValueError
if frame_ids is not None:
test_ids = []
frame_dirs = os.listdir(frame_path)
episode_ids = natsorted(frame_dirs)
for frame_dir in episode_ids:
if int(frame_dir) in frame_ids:
test_ids.append(str(frame_dir))
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
self.data['frame_paths'].append(frame_paths)
elif median is not None:
test_ids = []
frame_dirs = os.listdir(frame_path)
episode_ids = natsorted(frame_dirs)
for frame_dir in episode_ids:
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
seq_len = len(frame_paths)
if (median[1] and seq_len >= median[0]) or (not median[1] and seq_len < median[0]):
self.data['frame_paths'].append(frame_paths)
test_ids.append(int(frame_dir))
else:
frame_dirs = os.listdir(frame_path)
episode_ids = natsorted(frame_dirs)
test_ids = episode_ids.copy()
for frame_dir in episode_ids:
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
self.data['frame_paths'].append(frame_paths)
print('Labels...')
with open(label_path, 'rb') as fp:
labels = pickle.load(fp)
left_labels, right_labels = labels
for i in test_ids:
episode_id = int(i)
left_labels_i = left_labels[episode_id]
right_labels_i = right_labels[episode_id]
self.data['labels'].append([left_labels_i, right_labels_i])
fp.close()
print('Pose...')
self.pose_path = pose_path
if pose_path is not None:
for i in test_ids:
pose_i_dir = os.path.join(pose_path, i)
poses = []
for j in natsorted(os.listdir(pose_i_dir)):
fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
if torch.tensor(fs.getNode("pose_0").mat()).shape != (2,25,3):
poses.append(fs.getNode("pose_0").mat()[:2,:,:])
else:
poses.append(fs.getNode("pose_0").mat())
poses = np.array(poses)
poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
print('Gaze...')
self.gaze_path = gaze_path
if gaze_path is not None:
for i in test_ids:
gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
with open(gaze_i_txt, 'r') as fp:
gaze_i = fp.readlines()
fp.close()
gaze_i = ast.literal_eval(gaze_i[0])
gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
for j in range(len(gaze_i)):
if len(gaze_i[j]) >= 2:
gaze_i_tensor[j,:] = torch.tensor(gaze_i[j][:2])
elif len(gaze_i[j]) == 1:
gaze_i_tensor[j,0] = torch.tensor(gaze_i[j][0])
else:
continue
self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))
print('Bbox...')
self.bbox_path = bbox_path
if bbox_path is not None:
self.objects = [
'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
]
""" NOTE: old bbox
for i in test_ids:
bbox_i_dir = os.path.join(bbox_path, i)
with open(bbox_i_dir, 'rb') as fp:
bboxes_i = pickle.load(fp)
len_i = len(bboxes_i)
fp.close()
bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
for j in range(len(bboxes_i)):
items_i_j, bboxes_i_j = bboxes_i[j]
for k in range(len(items_i_j)):
bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
bboxes_i_j[k][0] / 1920, # * self.size_1,
bboxes_i_j[k][1] / 1088, # * self.size_2,
bboxes_i_j[k][2] / 1920, # * self.size_1,
bboxes_i_j[k][3] / 1088, # * self.size_2
]) # [x_min, y_min, x_max, y_max]
self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
"""
for i in test_ids:
bbox_dir = os.path.join(bbox_path, i)
bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4)) # TODO: we might want to cut it to 10 objects
for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)):
with open(os.path.join(bbox_dir, bbox), 'r') as fp:
bbox_content = fp.readlines()
fp.close()
for bbox_content_line in bbox_content:
bbox_content_values = bbox_content_line.split()
class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values)
bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))
print('OCR...\n')
self.ocr_graph_path = ocr_graph_path
if ocr_graph_path is not None:
ocr_graph = [
[15, [10, 4], [17, 2]],
[13, [16, 7], [18, 4]],
[11, [16, 4], [7, 10]],
[14, [10, 11], [7, 1]],
[12, [10, 9], [16, 3]],
[1, [7, 2], [9, 9], [10, 2]],
[5, [8, 8], [6, 8]],
[4, [9, 8], [7, 6]],
[3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
[2, [10, 1], [7, 7], [9, 3]],
[19, [10, 2], [26, 6]],
[20, [10, 7], [26, 5]],
[22, [25, 4], [10, 8]],
[23, [25, 15]],
[21, [16, 5], [24, 8]]
]
ocr_tensor = torch.zeros((27, 27))
for ocr in ocr_graph:
obj = ocr[0]
contexts = ocr[1:]
total_context_count = sum([i[1] for i in contexts])
for context in contexts:
ocr_tensor[obj, context[0]] = context[1] / total_context_count
ocr_tensor = torch.flatten(ocr_tensor)
for i in test_ids:
self.data['ocr_graph'].append(ocr_tensor)
self.frame_path = frame_path
self.transform = transforms.Compose([
#transforms.ToPILImage(),
transforms.Resize((self.size_1, self.size_2)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.data['frame_paths'])
def __getitem__(self, idx):
frames_path = self.data['frame_paths'][idx][0].split('/')
frames_path.insert(5, self.presaved)
frames_path = ''.join(f'{w}/' for w in frames_path[:-1])[:-1]+'.pkl'
with open(frames_path, 'rb') as f:
images = pickle.load(f)
images = torch.stack(images)
if self.patch_data:
images = rearrange(
images,
'(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
p1=self.spatial_patch_size,
p2=self.spatial_patch_size,
p3=self.temporal_patch_size
)
if self.pose_path is not None:
pose = self.data['poses'][idx]
if self.patch_data:
pose = rearrange(
pose,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
pose = None
if self.gaze_path is not None:
gaze = self.data['gazes'][idx]
if self.patch_data:
gaze = rearrange(
gaze,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
gaze = None
if self.bbox_path is not None:
bbox = self.data['bboxes'][idx]
if self.patch_data:
bbox = rearrange(
bbox,
'(t p1) d -> t p1 d',
p1=self.temporal_patch_size
)
else:
bbox = None
if self.ocr_graph_path is not None:
ocr_graphs = self.data['ocr_graph'][idx]
else:
ocr_graphs = None
return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs

240
boss/environment.yml Normal file

@@ -0,0 +1,240 @@
name: boss
channels:
- pyg
- anaconda
- conda-forge
- defaults
- pytorch
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- alembic=1.8.0=pyhd8ed1ab_0
- appdirs=1.4.4=pyh9f0ad1d_0
- asttokens=2.2.1=pyhd8ed1ab_0
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
- blas=1.0=mkl
- blosc=1.21.3=h6a678d5_0
- bottle=0.12.23=pyhd8ed1ab_0
- bottleneck=1.3.5=py38h7deecbd_0
- brotli=1.0.9=h5eee18b_7
- brotli-bin=1.0.9=h5eee18b_7
- brotlipy=0.7.0=py38h27cfd23_1003
- brunsli=0.1=h2531618_0
- bzip2=1.0.8=h7b6447c_0
- c-ares=1.18.1=h7f8727e_0
- ca-certificates=2023.01.10=h06a4308_0
- certifi=2023.5.7=py38h06a4308_0
- cffi=1.15.1=py38h74dc2b5_0
- cfitsio=3.470=h5893167_7
- charls=2.2.0=h2531618_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- click=8.1.3=unix_pyhd8ed1ab_2
- cloudpickle=2.0.0=pyhd3eb1b0_0
- cmaes=0.9.1=pyhd8ed1ab_0
- colorlog=6.7.0=py38h578d9bd_1
- contourpy=1.0.5=py38hdb19cb5_0
- cryptography=39.0.1=py38h9ce1e76_0
- cudatoolkit=11.3.1=h2bc3f7f_2
- cycler=0.11.0=pyhd3eb1b0_0
- cytoolz=0.12.0=py38h5eee18b_0
- dask-core=2022.7.0=py38h06a4308_0
- dbus=1.13.18=hb2f20db_0
- debugpy=1.5.1=py38h295c915_0
- decorator=5.1.1=pyhd8ed1ab_0
- docker-pycreds=0.4.0=py_0
- entrypoints=0.4=pyhd8ed1ab_0
- executing=1.2.0=pyhd8ed1ab_0
- expat=2.4.9=h6a678d5_0
- ffmpeg=4.3=hf484d3e_0
- fftw=3.3.9=h27cfd23_1
- flit-core=3.6.0=pyhd3eb1b0_0
- fontconfig=2.14.1=h52c9d5c_1
- fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.12.1=h4a9f257_0
- fsspec=2022.11.0=py38h06a4308_0
- giflib=5.2.1=h5eee18b_1
- gitdb=4.0.10=pyhd8ed1ab_0
- gitpython=3.1.31=pyhd8ed1ab_0
- glib=2.69.1=h4ff587b_1
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- greenlet=2.0.1=py38h6a678d5_0
- gst-plugins-base=1.14.1=h6a678d5_1
- gstreamer=1.14.1=h5eee18b_1
- icu=58.2=he6710b0_3
- idna=3.4=py38h06a4308_0
- imagecodecs=2021.8.26=py38hf0132c2_1
- imageio=2.19.3=py38h06a4308_0
- importlib-metadata=6.1.0=pyha770c72_0
- importlib_resources=5.2.0=pyhd3eb1b0_1
- intel-openmp=2021.4.0=h06a4308_3561
- ipykernel=6.15.0=pyh210e3f2_0
- ipython=8.11.0=pyh41d4057_0
- jedi=0.18.2=pyhd8ed1ab_0
- jinja2=3.1.2=py38h06a4308_0
- joblib=1.1.1=py38h06a4308_0
- jpeg=9e=h7f8727e_0
- jupyter_client=7.0.6=pyhd8ed1ab_0
- jupyter_core=4.12.0=py38h578d9bd_0
- jxrlib=1.1=h7b6447c_2
- kiwisolver=1.4.4=py38h6a678d5_0
- krb5=1.19.4=h568e23c_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libaec=1.0.4=he6710b0_1
- libbrotlicommon=1.0.9=h5eee18b_7
- libbrotlidec=1.0.9=h5eee18b_7
- libbrotlienc=1.0.9=h5eee18b_7
- libclang=10.0.1=default_hb85057a_2
- libcurl=7.87.0=h91b91d3_0
- libdeflate=1.8=h7f8727e_5
- libedit=3.1.20221030=h5eee18b_0
- libev=4.33=h7f8727e_1
- libevent=2.1.12=h8f2d780_0
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.16=h7f8727e_2
- libidn2=2.3.2=h7f8727e_0
- libllvm10=10.0.1=hbcb73fb_5
- libnghttp2=1.46.0=hce63b2e_0
- libpng=1.6.37=hbc83047_0
- libpq=12.9=h16c4e8d_3
- libprotobuf=3.20.3=he621ea3_0
- libsodium=1.0.18=h36c2ea0_1
- libssh2=1.10.0=h8f2d780_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.5.0=h6a678d5_1
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0
- libwebp=1.2.4=h11a3e52_0
- libwebp-base=1.2.4=h5eee18b_0
- libxcb=1.15=h7f8727e_0
- libxkbcommon=1.0.1=hfa300c1_0
- libxml2=2.9.14=h74e7548_0
- libxslt=1.1.35=h4e12654_0
- libzopfli=1.0.3=he6710b0_0
- locket=1.0.0=py38h06a4308_0
- lz4-c=1.9.4=h6a678d5_0
- mako=1.2.4=pyhd8ed1ab_0
- markupsafe=2.1.1=py38h7f8727e_0
- matplotlib=3.7.0=py38h06a4308_0
- matplotlib-base=3.7.0=py38h417a72b_0
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py38h7f8727e_0
- mkl_fft=1.3.1=py38hd3c417c_0
- mkl_random=1.2.2=py38h51133e4_0
- munkres=1.1.4=py_0
- natsort=7.1.1=pyhd3eb1b0_0
- ncurses=6.3=h5eee18b_3
- nest-asyncio=1.5.6=pyhd8ed1ab_0
- nettle=3.7.3=hbbd107a_1
- networkx=2.8.4=py38h06a4308_0
- nspr=4.33=h295c915_0
- nss=3.74=h0370c37_0
- numexpr=2.8.4=py38he184ba9_0
- numpy=1.23.5=py38h14f4228_0
- numpy-base=1.23.5=py38h31eccc5_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.4.0=h3ad879b_0
- openssl=1.1.1t=h7f8727e_0
- optuna=3.1.0=pyhd8ed1ab_0
- optuna-dashboard=0.9.0=pyhd8ed1ab_0
- packaging=22.0=py38h06a4308_0
- pandas=1.5.3=py38h417a72b_0
- parso=0.8.3=pyhd8ed1ab_0
- partd=1.2.0=pyhd3eb1b0_1
- pathtools=0.1.2=py_1
- pcre=8.45=h295c915_0
- pexpect=4.8.0=pyh1a96a4e_2
- pickleshare=0.7.5=py_1003
- pillow=9.4.0=py38h6a678d5_0
- pip=22.3.1=py38h06a4308_0
- ply=3.11=py38_0
- prompt-toolkit=3.0.38=pyha770c72_0
- prompt_toolkit=3.0.38=hd8ed1ab_0
- protobuf=3.20.3=py38h6a678d5_0
- psutil=5.9.0=py38h5eee18b_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pycparser=2.21=pyhd3eb1b0_0
- pyg=2.2.0=py38_torch_1.12.0_cu113
- pygments=2.14.0=pyhd8ed1ab_0
- pyopenssl=23.0.0=py38h06a4308_0
- pyparsing=3.0.9=py38h06a4308_0
- pyqt=5.15.7=py38h6a678d5_1
- pyqt5-sip=12.11.0=py38h6a678d5_1
- pysocks=1.7.1=py38h06a4308_0
- python=3.8.13=haa1d7c7_1
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python_abi=3.8=2_cp38
- pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0
- pytorch-cluster=1.6.0=py38_torch_1.12.0_cu113
- pytorch-mutex=1.0=cuda
- pytorch-scatter=2.1.0=py38_torch_1.12.0_cu113
- pytorch-sparse=0.6.16=py38_torch_1.12.0_cu113
- pytz=2022.7=py38h06a4308_0
- pywavelets=1.4.1=py38h5eee18b_0
- pyyaml=6.0=py38h5eee18b_1
- pyzmq=19.0.2=py38ha71036d_2
- qt-main=5.15.2=h327a75a_7
- qt-webengine=5.15.9=hd2b0992_4
- qtwebkit=5.212=h4eab89a_4
- readline=8.2=h5eee18b_0
- requests=2.28.1=py38h06a4308_1
- scikit-image=0.19.3=py38h6a678d5_1
- scikit-learn=1.2.1=py38h6a678d5_0
- scipy=1.9.3=py38h14f4228_0
- sentry-sdk=1.16.0=pyhd8ed1ab_0
- setproctitle=1.2.2=py38h0a891b7_2
- setuptools=65.6.3=py38h06a4308_0
- sip=6.6.2=py38h6a678d5_0
- six=1.16.0=pyhd3eb1b0_1
- smmap=3.0.5=pyh44b312d_0
- snappy=1.1.9=h295c915_0
- sqlalchemy=1.4.39=py38h5eee18b_0
- sqlite=3.40.1=h5082296_0
- stack_data=0.6.2=pyhd8ed1ab_0
- threadpoolctl=2.2.0=pyh0d69192_0
- tifffile=2021.7.2=pyhd3eb1b0_2
- tk=8.6.12=h1ccaba5_0
- toml=0.10.2=pyhd3eb1b0_0
- toolz=0.12.0=py38h06a4308_0
- torchaudio=0.12.0=py38_cu113
- torchvision=0.13.0=py38_cu113
- tornado=6.1=py38h0a891b7_3
- tqdm=4.64.1=py38h06a4308_0
- traitlets=5.9.0=pyhd8ed1ab_0
- typing_extensions=4.4.0=py38h06a4308_0
- tzdata=2022g=h04d1e81_0
- urllib3=1.26.14=py38h06a4308_0
- wandb=0.13.10=pyhd8ed1ab_0
- wcwidth=0.2.6=pyhd8ed1ab_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.10=h5eee18b_1
- yaml=0.2.5=h7b6447c_0
- zeromq=4.3.4=h9c3ff4c_1
- zfp=0.5.5=h295c915_6
- zipp=3.11.0=py38h06a4308_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.2=ha4553b6_0
- pip:
- einops==0.6.1
- lion-pytorch==0.1.2
- llvmlite==0.40.0
- memory-efficient-attention-pytorch==0.1.2
- numba==0.57.0
- opencv-python==4.7.0.72
- pynndescent==0.5.10
- seaborn==0.12.2
- umap-learn==0.5.3
- x-transformers==1.14.0
prefix: /opt/anaconda3/envs/boss

0
boss/models/__init__.py Normal file

230
boss/models/base.py Normal file

@@ -0,0 +1,230 @@
import torch
import torch.nn as nn
from .utils import pose_edge_index
from torch_geometric.nn import Sequential, GCNConv
from x_transformers import ContinuousTransformerWrapper, Decoder
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim),
nn.GELU(),
nn.Linear(dim, dim))
def forward(self, x):
return self.net(x)
class CNN(nn.Module):
def __init__(self, hidden_dim):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv3(x)
x = nn.functional.relu(x)
x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling
return x
class MindNetLSTM(nn.Module):
"""
Basic MindNet for model-based ToM, just LSTM on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetLSTM, self).__init__()
self.mods = mods
self.gaze_emb = nn.Linear(3, hidden_dim)
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.LSTM = PreNorm(
hidden_dim*len(mods),
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
lstm_inp = torch.cat(feats, 2)
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
return self.act(self.proj(lstm_out)), c_n, feats
class MindNetSL(nn.Module):
"""
Basic MindNet for SL ToM, just LSTM on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetSL, self).__init__()
self.mods = mods
self.gaze_emb = nn.Linear(3, hidden_dim)
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.LSTM = PreNorm(
hidden_dim*5,
nn.LSTM(input_size=hidden_dim*5, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
)
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
lstm_inp = torch.cat(feats, 2)
lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
return self.act(self.proj(lstm_out)), feats
class MindNetTF(nn.Module):
"""
Basic MindNet for model-based ToM, Transformer on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetTF, self).__init__()
self.mods = mods
self.gaze_emb = nn.Linear(3, hidden_dim)
self.pose_edge_index = pose_edge_index()
self.pose_emb = GCNConv(3, hidden_dim)
self.tf = ContinuousTransformerWrapper(
dim_in=hidden_dim*len(mods),
dim_out=hidden_dim,
max_seq_len=747,
attn_layers=Decoder(
dim=512,
depth=6,
heads=8
)
)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
tf_inp = torch.cat(feats, 2)
tf_out = self.tf(self.dropout(tf_inp))
return tf_out, feats
class MindNetLSTMXL(nn.Module):
"""
Basic MindNet for model-based ToM, just LSTM on input concatenation
"""
def __init__(self, hidden_dim, dropout, mods):
super(MindNetLSTMXL, self).__init__()
self.mods = mods
self.gaze_emb = nn.Sequential(
nn.Linear(3, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim)
)
self.pose_edge_index = pose_edge_index()
self.pose_emb = Sequential('x, edge_index', [
(GCNConv(3, hidden_dim), 'x, edge_index -> x'),
nn.GELU(),
(GCNConv(hidden_dim, hidden_dim), 'x, edge_index -> x'),
])
self.LSTM = PreNorm(
hidden_dim*len(mods),
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
)
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.act = nn.GELU()
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
feats = []
if 'rgb' in self.mods:
feats.append(rgb_feats)
if 'ocr' in self.mods:
feats.append(ocr_feats)
if 'pose' in self.mods:
bs, seq_len = pose.size(0), pose.size(1)
self.pose_edge_index = self.pose_edge_index.to(pose.device)
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
pose_emb = self.dropout(self.act(pose_emb))
pose_emb = torch.mean(pose_emb, dim=1)
hd = pose_emb.size(-1)
feats.append(pose_emb.view(bs, seq_len, hd))
if 'gaze' in self.mods:
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
feats.append(gaze_feats)
if 'bbox' in self.mods:
feats.append(bbox_feats)
lstm_inp = torch.cat(feats, 2)
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
return self.act(self.proj(lstm_out)), c_n, feats
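# Minimal usage sketch of MindNetLSTM (dummy shapes assumed by this example, not taken from the repo):
# rgb/ocr/bbox features are already hidden_dim-dimensional, pose is 75 = 25 joints x 3 coords per agent,
# gaze is a single 3D vector per agent and frame.
if __name__ == "__main__":
    hidden_dim, bs, seq_len = 64, 2, 10
    net = MindNetLSTM(hidden_dim, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
    rgb = torch.ones(bs, seq_len, hidden_dim)
    ocr = torch.ones(bs, seq_len, hidden_dim)
    pose = torch.ones(bs, seq_len, 75)
    gaze = torch.ones(bs, seq_len, 3)
    bbox = torch.ones(bs, seq_len, hidden_dim)
    out, cell, feats = net(rgb, ocr, pose, gaze, bbox)
    print(out.shape, cell.shape)  # (2, 10, 64) and (2, 1, 64)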

249
boss/models/resnet.py Normal file

@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import torchvision.models as models
from .utils import left_bias, right_bias
class ResNet(nn.Module):
def __init__(self, input_dim, device):
super(ResNet, self).__init__()
# Conv
resnet = models.resnet34(pretrained=True)
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
# FFs
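# input_dim must equal 512 (ResNet features) plus the dimensions of whichever extra modalities are
# concatenated in forward(): 150 (pose), 6 (gaze), 108 (bbox), 64 (OCR).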
self.left = nn.Linear(input_dim, 27)
self.right = nn.Linear(input_dim, 27)
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
image_feats = []
for i in range(sequence_len):
images_i = images[:,i].to(self.device)
image_i_feat = self.resnet(images_i)
image_i_feat = image_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
image_i_feat = torch.cat([image_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
image_i_feat = torch.cat([image_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
image_i_feat = torch.cat([image_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
image_i_feat = torch.cat([image_i_feat, ocr_tensor_feat], 1)
image_feats.append(image_i_feat)
image_feats = torch.permute(torch.stack(image_feats), (1,0,2))
left_beliefs = self.left(self.dropout(image_feats))
right_beliefs = self.right(self.dropout(image_feats))
return left_beliefs, right_beliefs, None
class ResNetGRU(nn.Module):
def __init__(self, input_dim, device):
super(ResNetGRU, self).__init__()
resnet = models.resnet34(pretrained=True)
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
self.gru = nn.GRU(input_dim, 512, batch_first=True)
for name, param in self.gru.named_parameters():
if "weight" in name:
nn.init.orthogonal_(param)
elif "bias" in name:
nn.init.constant_(param, 0)
# FFs
self.left = nn.Linear(512, 27)
self.right = nn.Linear(512, 27)
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.dropout = nn.Dropout(p=0.2)
self.relu = nn.ReLU()
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
rnn_inp = []
for i in range(sequence_len):
images_i = images[:,i]
rnn_i_feat = self.resnet(images_i)
rnn_i_feat = rnn_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
rnn_inp.append(rnn_i_feat)
rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
rnn_out, _ = self.gru(rnn_inp)
left_beliefs = self.left(self.dropout(rnn_out))
right_beliefs = self.right(self.dropout(rnn_out))
return left_beliefs, right_beliefs, None
class ResNetConv1D(nn.Module):
def __init__(self, input_dim, device):
super(ResNetConv1D, self).__init__()
resnet = models.resnet34(pretrained=True)
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=5, padding=4)
# FFs
self.left = nn.Linear(512, 27)
self.right = nn.Linear(512, 27)
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
conv1d_inp = []
for i in range(sequence_len):
images_i = images[:,i]
images_i_feat = self.resnet(images_i)
images_i_feat = images_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
images_i_feat = torch.cat([images_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
images_i_feat = torch.cat([images_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
images_i_feat = torch.cat([images_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
images_i_feat = torch.cat([images_i_feat, ocr_tensor_feat], 1)
conv1d_inp.append(images_i_feat)
conv1d_inp = torch.permute(torch.stack(conv1d_inp), (1,2,0))
conv1d_out = self.conv1d(conv1d_inp)
conv1d_out = conv1d_out[:,:,:-4]
conv1d_out = self.relu(torch.permute(conv1d_out, (0,2,1)))
left_beliefs = self.left(self.dropout(conv1d_out))
right_beliefs = self.right(self.dropout(conv1d_out))
return left_beliefs, right_beliefs, None
class ResNetLSTM(nn.Module):
def __init__(self, input_dim, device):
super(ResNetLSTM, self).__init__()
resnet = models.resnet34(pretrained=True)
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
self.lstm = nn.LSTM(input_size=input_dim, hidden_size=512, batch_first=True)
# FFs
self.left = nn.Linear(512, 27)
self.right = nn.Linear(512, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
# modality FFs
self.pose_ff = nn.Linear(150, 150)
self.gaze_ff = nn.Linear(6, 6)
self.bbox_ff = nn.Linear(108, 108)
self.ocr_ff = nn.Linear(729, 64)
# others
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
self.device = device
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
batch_size, sequence_len, channels, height, width = images.shape
left_beliefs = []
right_beliefs = []
rnn_inp = []
for i in range(sequence_len):
images_i = images[:,i]
rnn_i_feat = self.resnet(images_i)
rnn_i_feat = rnn_i_feat.view(batch_size, 512)
if poses is not None:
poses_i = poses[:,i].float()
poses_i_feat = self.relu(self.pose_ff(poses_i))
rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
if gazes is not None:
gazes_i = gazes[:,i].float()
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
if bboxes is not None:
bboxes_i = bboxes[:,i].float()
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
if ocr_tensor is not None:
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
rnn_inp.append(rnn_i_feat)
rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
rnn_out, _ = self.lstm(rnn_inp)
left_beliefs = self.left(self.dropout(rnn_out))
right_beliefs = self.right(self.dropout(rnn_out))
return left_beliefs, right_beliefs, None
if __name__ == '__main__':
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 150)
gazes = torch.ones(3, 22, 6)
bboxes = torch.ones(3, 22, 108)
model = ResNet(776, 'cpu')  # 512 (ResNet) + 150 (pose) + 6 (gaze) + 108 (bbox)
print(model(images, poses, gazes, bboxes, None)[0].shape)


@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
import torchvision.models as models
class SingleMindNet(nn.Module):
"""
Base ToM net. Supports any subset of modalities
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(SingleMindNet, self).__init__()
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net = MindNetLSTM(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
if __name__ == "__main__":
def count_parameters(model):
import numpy as np
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = SingleMindNet(64, 'cpu', False, 0.5)
print(count_parameters(model))
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

104
boss/models/tom_base.py Normal file

@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
import numpy as np
class BaseToMnet(nn.Module):
"""
Base ToM net. Supports any subset of modalities
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(BaseToMnet, self).__init__()
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
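# Average the OCR node embeddings into one graph-level feature and broadcast it over batch and time.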
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
def count_parameters(model):
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = BaseToMnet(64, 'cpu', False, 0.5)
print(count_parameters(model))
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)


@@ -0,0 +1,269 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM, MindNetLSTMXL
from memory_efficient_attention_pytorch import Attention
class CommonMindToMnet(nn.Module):
"""
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(CommonMindToMnet, self).__init__()
self.aggr = aggr
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
#for param in self.cnn.parameters():
# param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.ln_left = nn.LayerNorm(hidden_dim)
self.ln_right = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
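# Fuse both agents' final LSTM cell states into a single shared 'common mind' representation.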
common_mind = self.proj(torch.cat([cell_left, cell_right], -1))
if self.aggr == 'no_tom':
return self.left(out_left), self.right(out_right)
if self.aggr == 'attn':
l = self.attn_left(x=out_left, context=common_mind)
r = self.attn_right(x=out_right, context=common_mind)
elif self.aggr == 'mult':
l = out_left * common_mind
r = out_right * common_mind
elif self.aggr == 'sum':
l = out_left + common_mind
r = out_right + common_mind
elif self.aggr == 'concat':
l = torch.cat([out_left, common_mind], 1)
r = torch.cat([out_right, common_mind], 1)
else: raise ValueError
l = self.act(l)
l = self.ln_left(l)
r = self.act(r)
r = self.ln_right(r)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
left_beliefs = self.left(l)
right_beliefs = self.right(r)
if self.aggr == 'concat':
left_beliefs = self.left(l)[:, :-1, :]
right_beliefs = self.right(r)[:, :-1, :]
return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
class CommonMindToMnetXL(nn.Module):
"""
XL model.
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(CommonMindToMnetXL, self).__init__()
self.aggr = aggr
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTMXL(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTMXL(hidden_dim, dropout, mods)
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
self.ln_left = nn.LayerNorm(hidden_dim)
self.ln_right = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.left = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, 27),
)
self.right = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, 27)
)
self.left[-1].bias.data = torch.tensor(left_bias).log()
self.right[-1].bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
common_mind = self.proj(torch.cat([cell_left, cell_right], -1))
if self.aggr == 'no_tom':
return self.left(out_left), self.right(out_right)
if self.aggr == 'attn':
l = self.attn_left(x=out_left, context=common_mind)
r = self.attn_right(x=out_right, context=common_mind)
elif self.aggr == 'mult':
l = out_left * common_mind
r = out_right * common_mind
elif self.aggr == 'sum':
l = out_left + common_mind
r = out_right + common_mind
elif self.aggr == 'concat':
l = torch.cat([out_left, common_mind], 1)
r = torch.cat([out_right, common_mind], 1)
else: raise ValueError
l = self.act(l)
l = self.ln_left(l)
r = self.act(r)
r = self.ln_right(r)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
left_beliefs = self.left(l)
right_beliefs = self.right(r)
if self.aggr == 'concat':
left_beliefs = self.left(l)[:, :-1, :]
right_beliefs = self.right(r)[:, :-1, :]
return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = CommonMindToMnetXL(64, 'cpu', False, 0.5, aggr='attn')
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)

144
boss/models/tom_implicit.py Normal file

@@ -0,0 +1,144 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetLSTM
from memory_efficient_attention_pytorch import Attention
class ImplicitToMnet(nn.Module):
"""
Implicit ToM net. Supports any subset of modalities
Possible aggregations: sum, mult, attn, concat
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(ImplicitToMnet, self).__init__()
self.aggr = aggr
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
self.ln_left = nn.LayerNorm(hidden_dim)
self.ln_right = nn.LayerNorm(hidden_dim)
if aggr == 'attn':
self.attn_left = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.attn_right = Attention(
dim = hidden_dim,
dim_head = hidden_dim // 4,
heads = 4,
memory_efficient = True,
q_bucket_size = hidden_dim,
k_bucket_size = hidden_dim)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
if self.aggr == 'no_tom':
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
if self.aggr == 'attn':
l = self.attn_left(x=out_left, context=cell_right)
r = self.attn_right(x=out_right, context=cell_left)
elif self.aggr == 'mult':
l = out_left * cell_right
r = out_right * cell_left
elif self.aggr == 'sum':
l = out_left + cell_right
r = out_right + cell_left
elif self.aggr == 'concat':
l = torch.cat([out_left, cell_right], 1)
r = torch.cat([out_right, cell_left], 1)
        else: raise ValueError(f"Unknown aggregation '{self.aggr}'")
l = self.act(l)
l = self.ln_left(l)
r = self.act(r)
r = self.ln_right(r)
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
left_beliefs = self.left(l)
right_beliefs = self.right(r)
if self.aggr == 'concat':
left_beliefs = self.left(l)[:, :-1, :]
right_beliefs = self.right(r)[:, :-1, :]
return left_beliefs, right_beliefs, [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = ImplicitToMnet(64, 'cpu', False, 0.5, aggr='attn')
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)
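    # A minimal sketch of the "any subset of modalities" claim in the class docstring:
    # it assumes MindNetLSTM simply skips the modalities missing from `mods` and reuses
    # the dummy tensors above; the output shape is printed rather than asserted.
    model_sub = ImplicitToMnet(64, 'cpu', False, 0.5, aggr='attn', mods=['rgb', 'bbox'])
    out_sub = model_sub(images, poses, gazes, bboxes, None)
    print(out_sub[0].shape)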

98
boss/models/tom_sl.py Normal file
View file

@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph, pose_edge_index
from .base import CNN, MindNetSL
class SLToMnet(nn.Module):
"""
Speaker-Listener ToMnet
"""
def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(SLToMnet, self).__init__()
self.tom_weight = tom_weight
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetSL(hidden_dim, dropout, mods)
self.mind_net_right = MindNetSL(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
left_logits = self.left(out_left)
out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
    right_logits = self.right(out_right)
left_ranking = torch.log_softmax(left_logits, dim=-1)
right_ranking = torch.log_softmax(right_logits, dim=-1)
right_beliefs = left_ranking + self.tom_weight * right_ranking
left_beliefs = right_ranking + self.tom_weight * left_ranking
return left_beliefs, right_beliefs, [out_left, out_right] + feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = SLToMnet(64, 'cpu', 2.0, False, 0.5)
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)
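    # Minimal illustration (not part of the model) of the speaker-listener weighting used in
    # forward(): each agent's beliefs are the other agent's log-softmax ranking plus
    # tom_weight times its own, giving unnormalized log-scores (trained with NLLLoss in train.py).
    own = torch.log_softmax(torch.randn(1, 27), dim=-1)
    other = torch.log_softmax(torch.randn(1, 27), dim=-1)
    print((other + 2.0 * own).shape)  # (1, 27)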

104
boss/models/tom_tf.py Normal file
View file

@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn.conv import GCNConv
from .utils import left_bias, right_bias, build_ocr_graph
from .base import CNN, MindNetTF
from memory_efficient_attention_pytorch import Attention
class TFToMnet(nn.Module):
"""
Implicit ToM net. Supports any subset of modalities
Possible aggregations: sum, mult, attn, concat
"""
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
super(TFToMnet, self).__init__()
# ---- Images ----#
if resnet:
resnet = models.resnet34(weights="IMAGENET1K_V1")
self.cnn = nn.Sequential(
*(list(resnet.children())[:-1])
)
for param in self.cnn.parameters():
param.requires_grad = False
self.rgb_ff = nn.Linear(512, hidden_dim)
else:
self.cnn = CNN(hidden_dim)
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
# ---- OCR and bbox -----#
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
self.ocr_gnn = GCNConv(-1, hidden_dim)
self.bbox_ff = nn.Linear(108, hidden_dim)
# ---- Others ----#
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.device = device
# ---- Mind nets ----#
self.mind_net_left = MindNetTF(hidden_dim, dropout, mods)
self.mind_net_right = MindNetTF(hidden_dim, dropout, mods)
self.left = nn.Linear(hidden_dim, 27)
self.right = nn.Linear(hidden_dim, 27)
self.left.bias.data = torch.tensor(left_bias).log()
self.right.bias.data = torch.tensor(right_bias).log()
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
assert images is not None
assert poses is not None
assert gazes is not None
assert bboxes is not None
batch_size, sequence_len, channels, height, width = images.shape
bbox_feat = self.act(self.bbox_ff(bboxes))
rgb_feat = []
for i in range(sequence_len):
images_i = images[:,i]
img_i_feat = self.cnn(images_i)
img_i_feat = img_i_feat.view(batch_size, -1)
rgb_feat.append(img_i_feat)
rgb_feat = torch.stack(rgb_feat, 1)
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
self.ocr_edge_index,
self.ocr_edge_attr)))
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
l = self.dropout(self.act(out_left))
r = self.dropout(self.act(out_right))
left_beliefs = self.left(l)
right_beliefs = self.right(r)
return left_beliefs, right_beliefs, feats_left + feats_right
if __name__ == "__main__":
images = torch.ones(3, 22, 3, 128, 128)
poses = torch.ones(3, 22, 2, 75)
gazes = torch.ones(3, 22, 2, 3)
bboxes = torch.ones(3, 22, 108)
model = TFToMnet(64, 'cpu', False, 0.5)
out = model(images, poses, gazes, bboxes, None)
print(out[0].shape)
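    # Hypothetical decoding step: the 27 belief classes index the OBJECTS list defined in
    # boss/utils.py (index 0 is 'none'), so logits can be turned into object predictions
    # with an argmax over the last dimension.
    pred_idx = out[0].argmax(dim=-1)  # (batch, seq) class indices
    print(pred_idx.shape)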

95
boss/models/utils.py Normal file
View file

@ -0,0 +1,95 @@
import torch
left_bias = [0.10659290976303822,
0.025158905348262015,
0.02811095449589107,
0.026342384511050237,
0.025318475572458178,
0.02283183957873461,
0.021581872822531316,
0.08062285577511237,
0.03824366373234754,
0.04853594319300018,
0.09653998563867983,
0.02961357410707162,
0.02961357410707162,
0.03172787957767081,
0.029985904630196004,
0.02897529321028696,
0.06602218026116327,
0.015345336560197867,
0.026900880295736816,
0.024879657455918726,
0.028669450280577644,
0.01936118720246802,
0.02341693040078721,
0.014707055663413206,
0.027007260445200926,
0.04146166325363687,
0.04243238211749688]
right_bias = [0.13147256721895695,
0.012433179968617855,
0.01623627031195979,
0.013683146724821148,
0.015252253929416771,
0.012579452674131008,
0.03127576394244834,
0.10325523257360177,
0.041155820323927554,
0.06563655221935587,
0.12684503071726816,
0.016156485199861705,
0.0176989973670913,
0.020238823435546928,
0.01918831945958884,
0.01791175766601952,
0.08768383819579266,
0.019002154198026647,
0.029600276588388607,
0.01578415467673732,
0.0176989973670913,
0.011834791627882237,
0.014919815962341426,
0.007552990611951809,
0.029759846812584773,
0.04981250498656951,
0.05533097524002021]
def build_ocr_graph(device):
ocr_graph = [
[15, [10, 4], [17, 2]],
[13, [16, 7], [18, 4]],
[11, [16, 4], [7, 10]],
[14, [10, 11], [7, 1]],
[12, [10, 9], [16, 3]],
[1, [7, 2], [9, 9], [10, 2]],
[5, [8, 8], [6, 8]],
[4, [9, 8], [7, 6]],
[3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
[2, [10, 1], [7, 7], [9, 3]],
[19, [10, 2], [26, 6]],
[20, [10, 7], [26, 5]],
[22, [25, 4], [10, 8]],
[23, [25, 15]],
[21, [16, 5], [24, 8]]
]
edge_index = []
edge_attr = []
for i in range(len(ocr_graph)):
for j in range(1, len(ocr_graph[i])):
source_node = ocr_graph[i][0]
target_node = ocr_graph[i][j][0]
edge_index.append([source_node, target_node])
edge_attr.append(ocr_graph[i][j][1])
ocr_edge_index = torch.tensor(edge_index).t().long()
ocr_edge_attr = torch.tensor(edge_attr).to(torch.float).unsqueeze(1)
x = torch.arange(0, 27)
ocr_x = torch.nn.functional.one_hot(x, num_classes=27).to(torch.float)
return ocr_x.to(device), ocr_edge_index.to(device), ocr_edge_attr.to(device)
def pose_edge_index():
return torch.tensor(
[[17, 15, 15, 0, 0, 16, 16, 18, 0, 1, 4, 3, 3, 2, 2, 1, 1, 5, 5, 6, 6, 7, 1, 8, 8, 9, 9, 10, 10, 11, 11, 24, 11, 23, 23, 22, 8, 12, 12, 13, 13, 14, 14, 21, 14, 19, 19, 20],
[15, 17, 0, 15, 16, 0, 18, 16, 1, 0, 3, 4, 2, 3, 1, 2, 5, 1, 6, 5, 7, 6, 8, 1, 9, 8, 10, 9, 11, 10, 24, 11, 23, 11, 22, 23, 12, 8, 13, 12, 14, 13, 21, 14, 19, 14, 20, 19]],
dtype=torch.long)
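# Minimal shape check (an assumed usage sketch, mirroring how the ToM nets consume these
# tensors): 27 one-hot object nodes, OCR edges as a 2 x E index tensor with one scalar
# co-occurrence weight per edge, and a symmetric pose skeleton edge list.
if __name__ == "__main__":
    x, edge_index, edge_attr = build_ocr_graph('cpu')
    print(x.shape, edge_index.shape, edge_attr.shape)  # (27, 27), (2, E), (E, 1)
    print(pose_edge_index().shape)                     # (2, 48): 24 links, both directions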

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
boss/outfile Normal file

Binary file not shown.

View file

@ -0,0 +1,82 @@
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme(style='whitegrid')
COLORS = sns.color_palette()
MTOM_COLORS = {
"MN1": (110/255, 117/255, 161/255),
"MN2": (179/255, 106/255, 98/255),
"Base": (193/255, 198/255, 208/255),
"CG": (170/255, 129/255, 42/255),
"IC": (97/255, 112/255, 83/255),
"DB": (144/255, 63/255, 110/255)
}
abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json'))
abl_old_bbox_mean = [0.539406718, 0.5348262324, 0.529845863]
abl_old_bbox_std = [0.03639921819, 0.01519544901, 0.01718265794]
filename = 'results/abl_cm_concat_old_vs_new_bbox'
#def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
means = []
stds = []
for key, values in abl_tom_cm_concat.items():
mean = np.mean(values)
std = np.std(values)
means.append(mean)
stds.append(std)
fig, ax = plt.subplots(figsize=(10,4))
x = np.arange(len(abl_tom_cm_concat))
width = 0.6
rects1 = ax.bar(
x, means, width, label='New bbox', yerr=stds,
capsize=5,
color=[MTOM_COLORS['CG']]*8,
edgecolor='black',
linewidth=1.5,
alpha=0.6
)
rects2 = ax.bar(
[0, 2, 5], abl_old_bbox_mean, width, label='Original bbox', yerr=abl_old_bbox_std,
capsize=5,
color=[COLORS[9]]*8,
edgecolor='black',
linewidth=1.5,
alpha=0.6
)
ax.set_ylabel('Accuracy', fontsize=18)
xticklabels = list(abl_tom_cm_concat.keys())
ax.set_xticks(np.arange(len(xticklabels)))
ax.set_xticklabels(xticklabels, rotation=0, fontsize=16)
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
# Add value labels above each bar, avoiding overlapping with error bars
for rect, std in zip(rects1, stds):
height = rect.get_height()
if height + std < ax.get_ylim()[1]:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
else:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
for rect, std in zip(rects2, abl_old_bbox_std):
height = rect.get_height()
if height + std < ax.get_ylim()[1]:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
else:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
# Remove spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.legend(fontsize=14)
ax.grid(axis='x')
plt.tight_layout()
plt.savefig(f'{filename}.pdf', bbox_inches='tight')

82
boss/plots/pca.py Normal file
View file

@ -0,0 +1,82 @@
import torch
import os
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-22_12-00-38_train_None" # no_tom seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-39-41_train_None" # impl mult seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-47-07_train_None" # impl sum seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_15-58-44_train_None" # impl attn seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-58-04_train_None" # impl concat seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-40-01_train_None" # cm mult seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-45-55_train_None" # cm sum seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-50-42_train_None" # cm attn seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-57-15_train_None" # cm concat seed 1
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-37-50_train_None" # db seed 1
print(FOLDER_PATH)
MTOM_COLORS = {
"MN1": (110/255, 117/255, 161/255),
"MN2": (179/255, 106/255, 98/255),
"Base": (193/255, 198/255, 208/255),
"CG": (170/255, 129/255, 42/255),
"IC": (97/255, 112/255, 83/255),
"DB": (144/255, 63/255, 110/255)
}
sns.set_theme(style='white')
for i in range(60):
print(f'Computing analysis for test video {i}...', end='\r')
emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
data = torch.load(emb_file)
if len(data) == 14: # implicit
model = 'impl'
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
out_left = out_left.squeeze(0)
cell_left = cell_left.squeeze(0)
out_right = out_right.squeeze(0)
cell_right = cell_right.squeeze(0)
elif len(data) == 13: # common mind
model = 'cm'
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
common_mind = common_mind.squeeze(0)
elif len(data) == 12: # speaker-listener
model = 'sl'
out_left, out_right, feats = data[0], data[1], data[2:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
# ====== PCA for left and right embeddings ====== #
out_left_and_right = np.concatenate((out_left, out_right), axis=0)
pca = PCA(n_components=2)
pca_result = pca.fit_transform(out_left_and_right)
# Separate the PCA results for each tensor
pca_result_left = pca_result[:out_left.shape[0]]
pca_result_right = pca_result[out_right.shape[0]:]
plt.figure(figsize=(7,6))
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
plt.xlabel('Principal Component 1', fontsize=30)
plt.ylabel('Principal Component 2', fontsize=30)
plt.grid(False)
plt.legend(fontsize=30)
plt.tight_layout()
plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
plt.close()

View file

@ -0,0 +1,42 @@
{
"all": [
0.7402937327,
0.7171004799,
0.7305511124
],
"rgb\npose\ngaze": [
0.4736803839,
0.4658281227,
0.4948378653
],
"rgb\nocr\nbbox": [
0.7364403083,
0.7407663225,
0.7321506471
],
"rgb\ngaze": [
0.5013814163,
0.418859968,
0.4644103534
],
"rgb\npose": [
0.4545586738,
0.4584484514,
0.4525229024
],
"rgb\nbbox": [
0.716591537,
0.7334593573,
0.758324851
],
"rgb\nocr": [
0.4581939799,
0.456449033,
0.4577577432
],
"rgb": [
0.4896393776,
0.4907299695,
0.4583757452
]
}

Binary file not shown.

72
boss/results/all.json Normal file
View file

@ -0,0 +1,72 @@
{
"no_tom": [
0.6504653192,
0.6404682274,
0.6666787844
],
"sl": [
0.6584993456,
0.6608259415,
0.6584629926
],
"cm mult": [
0.6739130435,
0.6872182638,
0.6591173477
],
"cm sum": [
0.5820852116,
0.6401774029,
0.6001163298
],
"cm attn": [
0.3240511851,
0.2688672386,
0.3028573506
],
"cm concat": [
0.7402937327,
0.7171004799,
0.7305511124
],
"impl mult": [
0.7119746983,
0.7019776065,
0.7244437982
],
"impl sum": [
0.6542460375,
0.6642431293,
0.6678420823
],
"impl attn": [
0.3311763851,
0.3125636179,
0.3112549077
],
"impl concat": [
0.6957975862,
0.7227352043,
0.6971062964
],
"resnet": [
0.467173186,
0.4267485822,
0.4195870292
],
"gru": [
0.5724152974,
0.525628908,
0.4825141777
],
"lstm": [
0.6210193398,
0.5547477098,
0.6572633416
],
"conv1d": [
0.4478333576,
0.4114802966,
0.3853060928
]
}

BIN
boss/results/all.pdf Normal file

Binary file not shown.

181
boss/test.py Normal file
View file

@ -0,0 +1,181 @@
import argparse
import torch
from tqdm import tqdm
import csv
import os
from torch.utils.data import DataLoader
from dataloader import DataTest
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
from models.tom_implicit import ImplicitToMnet
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
from models.tom_sl import SLToMnet
from models.single_mindnet import SingleMindNet
from utils import tried_once, tried_twice, tried_thrice, friends, strangers, get_classification_accuracy, pad_collate, get_input_dim
def test(args):
if args.test_frames == 'friends':
test_frame_ids = friends
elif args.test_frames == 'strangers':
test_frame_ids = strangers
elif args.test_frames == 'once':
test_frame_ids = tried_once
elif args.test_frames == 'twice_thrice':
test_frame_ids = tried_twice + tried_thrice
elif args.test_frames is None:
test_frame_ids = None
else: raise NameError
if args.median is not None:
median = (240, False)
else:
median = None
if args.model_type == 'tom_cm' or args.model_type == 'tom_sl' or args.model_type == 'tom_impl' or args.model_type == 'tom_cm_xl' or args.model_type == 'tom_single':
flatten_dim = 2
else:
flatten_dim = 1
# load datasets
test_dataset = DataTest(
args.test_frame_path,
args.label_path,
args.test_pose_path,
args.test_gaze_path,
args.test_bbox_path,
args.ocr_graph_path,
args.presaved,
test_frame_ids,
median,
flatten_dim=flatten_dim
)
test_dataloader = DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=args.num_workers,
collate_fn=pad_collate,
pin_memory=args.pin_memory
)
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
assert args.load_model_path is not None
if args.model_type == 'resnet':
inp_dim = get_input_dim(args.mods)
model = ResNet(inp_dim, device).to(device)
elif args.model_type == 'gru':
inp_dim = get_input_dim(args.mods)
model = ResNetGRU(inp_dim, device).to(device)
elif args.model_type == 'lstm':
inp_dim = get_input_dim(args.mods)
model = ResNetLSTM(inp_dim, device).to(device)
elif args.model_type == 'conv1d':
inp_dim = get_input_dim(args.mods)
model = ResNetConv1D(inp_dim, device).to(device)
elif args.model_type == 'tom_cm':
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_impl':
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_sl':
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_single':
model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_cm_xl':
model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
else: raise NotImplementedError
model.load_state_dict(torch.load(args.load_model_path, map_location=device))
model.device = device
model.eval()
if args.save_preds:
# Define the output file path
folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}_{args.test_frames}'
if not os.path.exists(folder_path):
os.makedirs(folder_path)
print(f'Saving predictions in {folder_path}.')
print('Testing...')
num_correct = 0
cnt = 0
with torch.no_grad():
for j, batch in tqdm(enumerate(test_dataloader)):
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
if frames is not None: frames = frames.to(device, non_blocking=True)
if poses is not None: poses = poses.to(device, non_blocking=True)
if gazes is not None: gazes = gazes.to(device, non_blocking=True)
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=True)
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=True)
pred_left_labels, pred_right_labels, repr = model(frames, poses, gazes, bboxes, ocr_graphs)
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
labels = torch.reshape(labels, (-1, 2)).to(device)
batch_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
pred_left_labels, pred_right_labels, labels, sequence_lengths
)
cnt += batch_num_pred
num_correct += batch_num_correct
if args.save_preds:
torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt"))
data = [(
i,
torch.argmax(pred_left_labels[i]).cpu().numpy(),
torch.argmax(pred_right_labels[i]).cpu().numpy(),
labels[:, 0][i].cpu().numpy(),
labels[:, 1][i].cpu().numpy()) for i in range(len(labels))
]
header = ['frame', 'left_pred', 'right_pred', 'left_label', 'right_label']
with open(os.path.join(folder_path, f'{j}_{batch_acc:.2f}.csv'), mode='w', newline='') as file:
writer = csv.writer(file)
writer.writerow(header) # Write the header row
writer.writerows(data) # Write the data rows
test_acc = num_correct / cnt
print("Test accuracy: {}".format(num_correct / cnt))
with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f:
f.write(str(test_acc))
f.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Define the command-line arguments
parser.add_argument('--gpu_id', type=int)
parser.add_argument('--presaved', type=int, default=128)
parser.add_argument('--non_blocking', action='store_true')
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--pin_memory', action='store_true')
parser.add_argument('--model_type', type=str)
parser.add_argument('--aggr', type=str, default='mult', required=False)
parser.add_argument('--use_resnet', action='store_true')
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
parser.add_argument('--test_frame_path', type=str, default='/scratch/bortoletto/data/boss/test/frame')
parser.add_argument('--test_pose_path', type=str, default='/scratch/bortoletto/data/boss/test/pose')
parser.add_argument('--test_gaze_path', type=str, default='/scratch/bortoletto/data/boss/test/gaze')
parser.add_argument('--test_bbox_path', type=str, default='/scratch/bortoletto/data/boss/test/new_bbox/labels')
parser.add_argument('--ocr_graph_path', type=str, default='')
parser.add_argument('--label_path', type=str, default='outfile')
parser.add_argument('--save_path', type=str, default='experiments/')
parser.add_argument('--test_frames', type=str, default=None)
parser.add_argument('--median', type=int, default=None)
parser.add_argument('--load_model_path', type=str)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--save_preds', action='store_true')
# Parse the command-line arguments
args = parser.parse_args()
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
if not args.aggr:
parser.error("The choosen --model_type requires --aggr")
if args.model_type == 'tom_sl' and not args.tom_weight:
parser.error("The choosen --model_type requires --tom_weight")
test(args)
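# Example invocation (a sketch: the checkpoint path is a placeholder written by train.py,
# the flags are the ones defined above):
#   python test.py --gpu_id 0 --model_type tom_impl --aggr mult --hidden_dim 64 \
#       --load_model_path experiments/<run_id>/model --save_preds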

324
boss/train.py Normal file
View file

@ -0,0 +1,324 @@
import torch
import os
import argparse
import numpy as np
import random
import datetime
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from dataloader import Data
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
from models.tom_implicit import ImplicitToMnet
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
from models.tom_sl import SLToMnet
from models.tom_tf import TFToMnet
from models.single_mindnet import SingleMindNet
from utils import pad_collate, get_classification_accuracy, mixup, get_classification_accuracy_mixup, count_parameters, get_input_dim
def train(args):
if args.model_type == 'tom_cm' or args.model_type == 'tom_sl' or args.model_type == 'tom_impl' or args.model_type == 'tom_tf' or args.model_type == 'tom_cm_xl' or args.model_type == 'tom_single':
flatten_dim = 2
else:
flatten_dim = 1
train_dataset = Data(
args.train_frame_path,
args.label_path,
args.train_pose_path,
args.train_gaze_path,
args.train_bbox_path,
args.ocr_graph_path,
presaved=args.presaved,
flatten_dim=flatten_dim
)
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
collate_fn=pad_collate,
pin_memory=args.pin_memory
)
val_dataset = Data(
args.val_frame_path,
args.label_path,
args.val_pose_path,
args.val_gaze_path,
args.val_bbox_path,
args.ocr_graph_path,
presaved=args.presaved,
flatten_dim=flatten_dim
)
val_dataloader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=pad_collate,
pin_memory=args.pin_memory
)
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
if args.model_type == 'resnet':
inp_dim = get_input_dim(args.mods)
model = ResNet(inp_dim, device).to(device)
elif args.model_type == 'gru':
inp_dim = get_input_dim(args.mods)
model = ResNetGRU(inp_dim, device).to(device)
elif args.model_type == 'lstm':
inp_dim = get_input_dim(args.mods)
model = ResNetLSTM(inp_dim, device).to(device)
elif args.model_type == 'conv1d':
inp_dim = get_input_dim(args.mods)
model = ResNetConv1D(inp_dim, device).to(device)
elif args.model_type == 'tom_cm':
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_cm_xl':
model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_impl':
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_sl':
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_single':
model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_tf':
model = TFToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
else: raise NotImplementedError
if args.resume_from_checkpoint is not None:
model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
if args.scheduler == None:
scheduler = None
else:
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
if args.model_type == 'tom_sl': cross_entropy_loss = nn.NLLLoss(ignore_index=-1)
else: cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing, ignore_index=-1).to(device)
stats = {'train': {'cls_loss': [], 'cls_acc': []}, 'val': {'cls_loss': [], 'cls_acc': []}}
max_val_classification_acc = 0
max_val_classification_epoch = None
counter = 0
print(f'Number of parameters: {count_parameters(model)}')
for i in range(args.num_epoch):
# training
print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
temp_train_classification_loss = []
epoch_num_correct = 0
epoch_cnt = 0
model.train()
for j, batch in tqdm(enumerate(train_dataloader)):
if args.use_mixup:
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = mixup(batch, args.mixup_alpha, 27)
else:
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)
pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
if args.use_mixup:
labels = torch.reshape(labels, (-1, 54)).to(device)
batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy_mixup(
pred_left_labels, pred_right_labels, labels, sequence_lengths
)
loss = cross_entropy_loss(pred_left_labels, labels[:, :27]) + cross_entropy_loss(pred_right_labels, labels[:, 27:])
else:
labels = torch.reshape(labels, (-1, 2)).to(device)
batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
pred_left_labels, pred_right_labels, labels, sequence_lengths
)
loss = cross_entropy_loss(pred_left_labels, labels[:, 0]) + cross_entropy_loss(pred_right_labels, labels[:, 1])
epoch_cnt += batch_num_pred
epoch_num_correct += batch_num_correct
temp_train_classification_loss.append(loss.data.item() * batch_num_pred / 2)
            optimizer.zero_grad()
            loss.backward()
            if args.clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()
if args.logger: wandb.log({'batch_train_acc': batch_train_acc, 'batch_train_loss': loss.data.item(), 'lr': optimizer.param_groups[-1]['lr']})
print("Epoch {}/{} batch {}/{} training done with cls loss={}, cls accuracy={}.".format(
i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item(), batch_train_acc)
)
if scheduler: scheduler.step()
print("Epoch {}/{} OVERALL train cls loss={}, cls accuracy={}.\n".format(
i+1, args.num_epoch, sum(temp_train_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
)
stats['train']['cls_loss'].append(sum(temp_train_classification_loss) * 2 / epoch_cnt)
stats['train']['cls_acc'].append(epoch_num_correct / epoch_cnt)
# validation
print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
temp_val_classification_loss = []
epoch_num_correct = 0
epoch_cnt = 0
model.eval()
with torch.no_grad():
for j, batch in tqdm(enumerate(val_dataloader)):
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)
pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
labels = torch.reshape(labels, (-1, 2)).to(device)
batch_val_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
pred_left_labels, pred_right_labels, labels, sequence_lengths
)
epoch_cnt += batch_num_pred
epoch_num_correct += batch_num_correct
loss = cross_entropy_loss(pred_left_labels, labels[:,0]) + cross_entropy_loss(pred_right_labels, labels[:,1])
temp_val_classification_loss.append(loss.data.item() * batch_num_pred / 2)
if args.logger: wandb.log({'batch_val_acc': batch_val_acc, 'batch_val_loss': loss.data.item()})
print("Epoch {}/{} batch {}/{} validation done with cls loss={}, cls accuracy={}.".format(
i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item(), batch_val_acc)
)
print("Epoch {}/{} OVERALL validation cls loss={}, cls accuracy={}.\n".format(
i+1, args.num_epoch, sum(temp_val_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
)
cls_loss = sum(temp_val_classification_loss) * 2 / epoch_cnt
cls_acc = epoch_num_correct / epoch_cnt
stats['val']['cls_loss'].append(cls_loss)
stats['val']['cls_acc'].append(cls_acc)
if args.logger: wandb.log({'cls_loss': cls_loss, 'cls_acc': cls_acc, 'epoch': i})
# check for best stat/model using validation accuracy
if stats['val']['cls_acc'][-1] >= max_val_classification_acc:
max_val_classification_acc = stats['val']['cls_acc'][-1]
max_val_classification_epoch = i+1
torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
counter = 0
else:
counter += 1
print(f'EarlyStopping counter: {counter} out of {args.patience}.')
if counter >= args.patience:
break
with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
f.write('{}\n'.format(CFG))
f.write('{}\n'.format(stats))
f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_classification_acc))
f.close()
print(f'Results saved in {experiment_save_path}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Define the command-line arguments
parser.add_argument('--gpu_id', type=int)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--logger', action='store_true')
parser.add_argument('--presaved', type=int, default=128)
parser.add_argument('--clip_grad_norm', type=float, default=0.5)
parser.add_argument('--use_mixup', action='store_true')
parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
parser.add_argument('--non_blocking', action='store_true')
parser.add_argument('--patience', type=int, default=99)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--pin_memory', action='store_true')
parser.add_argument('--num_epoch', type=int, default=300)
parser.add_argument('--lr', type=float, default=4e-4)
parser.add_argument('--scheduler', type=str, default=None)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--weight_decay', type=float, default=0.005)
parser.add_argument('--label_smoothing', type=float, default=0.1)
parser.add_argument('--model_type', type=str)
parser.add_argument('--aggr', type=str, default='mult', required=False)
parser.add_argument('--use_resnet', action='store_true')
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
parser.add_argument('--train_frame_path', type=str, default='/scratch/bortoletto/data/boss/train/frame')
parser.add_argument('--train_pose_path', type=str, default='/scratch/bortoletto/data/boss/train/pose')
parser.add_argument('--train_gaze_path', type=str, default='/scratch/bortoletto/data/boss/train/gaze')
parser.add_argument('--train_bbox_path', type=str, default='/scratch/bortoletto/data/boss/train/new_bbox/labels')
parser.add_argument('--val_frame_path', type=str, default='/scratch/bortoletto/data/boss/val/frame')
parser.add_argument('--val_pose_path', type=str, default='/scratch/bortoletto/data/boss/val/pose')
parser.add_argument('--val_gaze_path', type=str, default='/scratch/bortoletto/data/boss/val/gaze')
parser.add_argument('--val_bbox_path', type=str, default='/scratch/bortoletto/data/boss/val/new_bbox/labels')
parser.add_argument('--ocr_graph_path', type=str, default='')
parser.add_argument('--label_path', type=str, default='outfile')
parser.add_argument('--save_path', type=str, default='experiments/')
parser.add_argument('--resume_from_checkpoint', type=str, default=None)
# Parse the command-line arguments
args = parser.parse_args()
if args.use_mixup and not args.mixup_alpha:
parser.error("--use_mixup requires --mixup_alpha")
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
if not args.aggr:
parser.error("The choosen --model_type requires --aggr")
if args.model_type == 'tom_sl' and not args.tom_weight:
parser.error("The choosen --model_type requires --tom_weight")
# get experiment ID
experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
if not os.path.exists(args.save_path):
os.makedirs(args.save_path, exist_ok=True)
experiment_save_path = os.path.join(args.save_path, experiment_id)
os.makedirs(experiment_save_path, exist_ok=True)
CFG = {
'use_ocr_custom_loss': 0,
'presaved': args.presaved,
'batch_size': args.batch_size,
'num_epoch': args.num_epoch,
'lr': args.lr,
'scheduler': args.scheduler,
'weight_decay': args.weight_decay,
'model_type': args.model_type,
'use_resnet': args.use_resnet,
'hidden_dim': args.hidden_dim,
'tom_weight': args.tom_weight,
'dropout': args.dropout,
'label_smoothing': args.label_smoothing,
'clip_grad_norm': args.clip_grad_norm,
'use_mixup': args.use_mixup,
'mixup_alpha': args.mixup_alpha,
'non_blocking_tensors': args.non_blocking,
'patience': args.patience,
'pin_memory': args.pin_memory,
'resume_from_checkpoint': args.resume_from_checkpoint,
'aggr': args.aggr,
'mods': args.mods,
'save_path': experiment_save_path ,
'seed': args.seed
}
print(CFG)
print(f'Saving results in {experiment_save_path}')
# set seed values
if args.logger:
wandb.init(project="boss", config=CFG)
os.environ['PYTHONHASHSEED'] = str(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
train(args)
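# Example invocation (a sketch: dataset paths fall back to the defaults above and the
# experiment folder name is generated from the current timestamp):
#   python train.py --gpu_id 0 --model_type tom_cm --aggr concat --hidden_dim 64 \
#       --batch_size 4 --lr 4e-4 --logger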

769
boss/utils.py Normal file
View file

@ -0,0 +1,769 @@
import os
import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
#from umap import UMAP
import json
import seaborn as sns
import pandas as pd
from natsort import natsorted
import argparse
COLORS_DM = {
"red": (242/256, 165/256, 179/256),
"blue": (195/256, 219/256, 252/256),
"green": (156/256, 228/256, 213/256),
"yellow": (250/256, 236/256, 144/256),
"violet": (207/256, 187/256, 244/256),
"orange": (244/256, 188/256, 154/256)
}
MTOM_COLORS = {
"MN1": (110/255, 117/255, 161/255),
"MN2": (179/255, 106/255, 98/255),
"Base": (193/255, 198/255, 208/255),
"CG": (170/255, 129/255, 42/255),
"IC": (97/255, 112/255, 83/255),
"DB": (144/255, 63/255, 110/255)
}
COLORS = sns.color_palette()
OBJECTS = [
'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
]
tried_once = [2, 4, 5, 6, 7, 8, 9, 10, 12, 15, 18, 19, 20, 21, 22, 23, 24,
25, 26, 27, 28, 29, 30, 32, 38, 40, 41, 42, 44, 47, 50, 51, 52, 53, 55,
56, 57, 61, 63, 65, 67, 68, 70, 71, 72, 73, 74, 76, 80, 81, 83, 85, 87,
88, 89, 90, 92, 93, 96, 97, 99, 101, 102, 105, 106, 108, 110, 111, 112,
113, 114, 116, 118, 121, 123, 125, 131, 132, 134, 135, 140, 142, 143,
145, 146, 148, 149, 151, 152, 154, 155, 156, 157, 160, 161, 162, 165,
169, 170, 171, 173, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185,
186, 187, 190, 191, 194, 196, 203, 204, 206, 207, 208, 209, 210, 211,
213, 214, 216, 218, 219, 220, 222, 225, 227, 228, 229, 232, 233, 235,
236, 237, 238, 239, 242, 243, 246, 247, 249, 251, 252, 254, 255, 256,
257, 260, 261, 262, 263, 265, 266, 268, 270, 272, 275, 277, 278, 279,
281, 282, 287, 290, 296, 298]
tried_twice = [1, 3, 11, 13, 14, 16, 17, 31, 33, 35, 36, 37, 39,
43, 45, 49, 54, 58, 59, 60, 62, 64, 66, 69, 75, 77, 79, 82, 84, 86,
91, 94, 95, 98, 100, 103, 104, 107, 109, 115, 117, 119, 120, 122,
124, 126, 127, 128, 129, 130, 133, 136, 137, 138, 139, 141, 144,
147, 150, 153, 158, 159, 164, 166, 167, 168, 172, 174, 177, 188,
189, 192, 193, 195, 197, 198, 199, 200, 201, 202, 205, 212, 215,
217, 221, 223, 224, 226, 230, 231, 234, 240, 241, 244, 245, 248,
250, 253, 258, 259, 264, 267, 269, 271, 273, 274, 276, 280, 283,
284, 285, 286, 288, 291, 292, 293, 294, 295, 297, 299]
tried_thrice = [34, 46, 48, 78, 163, 289]
friends = [i for i in range(150, 300)]
strangers = [i for i in range(0, 150)]
def get_input_dim(modalities):
dimensions = {
'rgb': 512,
'pose': 150,
'gaze': 6,
'ocr': 64,
'bbox': 108
}
sum_of_dimensions = sum(dimensions[modality] for modality in modalities)
return sum_of_dimensions
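# e.g. get_input_dim(['rgb', 'pose', 'gaze']) -> 512 + 150 + 6 = 668, the per-frame input
# size handed to the CNN/GRU/LSTM/Conv1D baselines in train.py and test.py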
def count_parameters(model):
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
def onehot(label, n_classes):
if len(label.size()) == 3:
batch_size, seq_len, _ = label.size()
label_1_2d = label[:,:,0].contiguous().view(batch_size*seq_len, 1)
label_2_2d = label[:,:,1].contiguous().view(batch_size*seq_len, 1)
#onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1)
#onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1)
#onehot_2d = torch.cat((onehot_1_2d, onehot_2_2d), dim=1)
#onehot_3d = onehot_2d.view(batch_size, seq_len, 2*n_classes)
onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1).view(-1, seq_len, n_classes)
onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1).view(-1, seq_len, n_classes)
onehot_3d = torch.cat((onehot_1_2d, onehot_2_2d), dim=2)
return onehot_3d
else:
return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1)
def mixup(data, alpha, n_classes):
lam = torch.FloatTensor([np.random.beta(alpha, alpha)])
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = data
indices = torch.randperm(labels.size(0))
# labels
labels2 = labels[indices]
labels = onehot(labels, n_classes)
labels2 = onehot(labels2, n_classes)
labels = labels * lam + labels2 * (1 - lam)
# frames
frames2 = frames[indices]
frames = frames * lam + frames2 * (1 - lam)
# poses
poses2 = poses[indices]
poses = poses * lam + poses2 * (1 - lam)
# gazes
gazes2 = gazes[indices]
gazes = gazes * lam + gazes2 * (1 - lam)
# bboxes
bboxes2 = bboxes[indices]
bboxes = bboxes * lam + bboxes2 * (1 - lam)
return frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths
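# Usage sketch (assumed shapes, matching pad_collate output): a single lam ~ Beta(alpha, alpha)
# is drawn per batch and applied to every modality and to the one-hot labels, so the mixed
# labels become soft (batch, max_seq_len, 54) targets (27 classes per participant), which is
# how train.py reshapes them when --use_mixup is set.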
def get_classification_accuracy(pred_left_labels, pred_right_labels, labels, sequence_lengths):
max_len = max(sequence_lengths)
pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
labels = torch.reshape(labels, (-1, max_len, 2))
left_correct = torch.argmax(pred_left_labels, 2) == labels[:,:,0]
right_correct = torch.argmax(pred_right_labels, 2) == labels[:,:,1]
num_pred = sum(sequence_lengths) * 2
num_correct = 0
for i in range(len(sequence_lengths)):
size = sequence_lengths[i]
num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
acc = num_correct / num_pred
return acc, num_correct, num_pred
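# e.g. two padded sequences of true lengths 10 and 7 yield num_pred = (10 + 7) * 2 = 34
# scored predictions: one per participant per valid (unpadded) frame.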
def get_classification_accuracy_mixup(pred_left_labels, pred_right_labels, labels, sequence_lengths):
max_len = max(sequence_lengths)
pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
labels = torch.reshape(labels, (-1, max_len, 54))
left_correct = torch.argmax(pred_left_labels, 2) == torch.argmax(labels[:,:,:27], 2)
right_correct = torch.argmax(pred_right_labels, 2) == torch.argmax(labels[:,:,27:], 2)
num_pred = sum(sequence_lengths) * 2
num_correct = 0
for i in range(len(sequence_lengths)):
size = sequence_lengths[i]
num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
acc = num_correct / num_pred
return acc, num_correct, num_pred
def pad_collate(batch):
(aa, bb, cc, dd, ee, ff) = zip(*batch)
seq_lens = [len(a) for a in aa]
aa_pad = pad_sequence(aa, batch_first=True, padding_value=0)
bb_pad = pad_sequence(bb, batch_first=True, padding_value=-1)
if cc[0] is not None:
cc_pad = pad_sequence(cc, batch_first=True, padding_value=0)
else:
cc_pad = None
if dd[0] is not None:
dd_pad = pad_sequence(dd, batch_first=True, padding_value=0)
else:
dd_pad = None
if ee[0] is not None:
ee_pad = pad_sequence(ee, batch_first=True, padding_value=0)
else:
ee_pad = None
if ff[0] is not None:
ff_pad = pad_sequence(ff, batch_first=True, padding_value=0)
else:
ff_pad = None
return aa_pad, bb_pad, cc_pad, dd_pad, ee_pad, ff_pad, seq_lens
def ocr_loss(pL, lL, pR, lR, ocr, loss, eta=10.0, mode='abs'):
"""
Custom loss based on the negative OCRL matrix.
Input:
pL: tensor of shape [batch_size, num_classes] representing left belief predictions
lL: tensor of shape [batch_size] representing the left belief labels
pR: tensor of shape [batch_size, num_classes] representing right belief predictions
lR: tensor of shape [batch_size] representing the right belief labels
ocr: negative OCR matrix (i.e. 1-OCR)
loss: loss function
eta: hyperparameter for the interaction term
Output:
Final loss resulting from weighting left and right loss with the respective
OCR coefficients and summing them.
"""
bs = pL.shape[0]
if len([*lR[0].size()]) > 0: # for mixup
p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
g = torch.tensor([ocr[torch.argmax(lL[i]), torch.argmax(lR[i])] for i in range(bs)], device=pL.device)
else:
p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
g = torch.tensor([ocr[lL[i], lR[i]] for i in range(bs)], device=pL.device)
left_loss = torch.mean(loss(pL, lL))
right_loss = torch.mean(loss(pR, lR))
if mode == 'abs':
interaction_loss = torch.mean(torch.abs(g - p))
elif mode == 'mse':
interaction_loss = torch.mean(torch.pow(g - p, 2))
else: raise NotImplementedError
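    # the interaction term is rescaled adaptively so it matches the magnitude of the two
    # belief losses; note this overrides the eta value passed as an argument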
eta = (left_loss + right_loss) / interaction_loss
interaction_loss = interaction_loss * eta
print(f"Left: {left_loss} --- Right: {right_loss} --- Interaction: {interaction_loss}")
return left_loss + right_loss + interaction_loss
def spherical2cartesial(x):
"""
From https://colab.research.google.com/drive/1SJbzd-gFTbiYjfZynIfrG044fWi6svbV?usp=sharing#scrollTo=78QhNw4MYSYp
"""
output = torch.zeros(x.size(0),3)
output[:,2] = -torch.cos(x[:,1])*torch.cos(x[:,0])
output[:,0] = torch.cos(x[:,1])*torch.sin(x[:,0])
output[:,1] = torch.sin(x[:,1])
return output
def presave():
from PIL import Image
from torchvision import transforms
import os
from natsort import natsorted
import pickle as pkl
preprocess = transforms.Compose([
transforms.Resize((128, 128)),
#transforms.Resize(256),
#transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
frame_path = '/scratch/bortoletto/data/boss/test/frame'
frame_dirs = os.listdir(frame_path)
frame_paths = []
for frame_dir in natsorted(frame_dirs):
paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
frame_paths.append(paths)
save_folder = '/scratch/bortoletto/data/boss/presaved128/'+frame_path.split('/')[-2]+'/'+frame_path.split('/')[-1]
if not os.path.exists(save_folder):
os.makedirs(save_folder)
print(save_folder, 'created')
for video in natsorted(frame_paths):
print(video[0].split('/')[-2], end='\r')
images = [preprocess(Image.open(i)) for i in video]
strings = video[0].split('/')
with open(save_folder+'/'+strings[7]+'.pkl', 'wb') as f:
pkl.dump(images, f)
def compute_cosine_similarity(tensor1, tensor2):
return F.cosine_similarity(tensor1, tensor2, dim=-1)
def find_most_similar_embedding(rgb, pose, gaze, ocr, bbox, repr):
gaze_similarity = compute_cosine_similarity(gaze, repr)
pose_similarity = compute_cosine_similarity(pose, repr)
ocr_similarity = compute_cosine_similarity(ocr, repr)
bbox_similarity = compute_cosine_similarity(bbox, repr)
rgb_similarity = compute_cosine_similarity(rgb, repr)
similarities = torch.stack([gaze_similarity, pose_similarity, ocr_similarity, bbox_similarity, rgb_similarity])
max_index = torch.argmax(similarities, dim=0)
main_modality = []
main_modality_name = []
for idx in max_index:
if idx == 0:
main_modality.append(gaze)
main_modality_name.append('gaze')
elif idx == 1:
main_modality.append(pose)
main_modality_name.append('pose')
elif idx == 2:
main_modality.append(ocr)
main_modality_name.append('ocr')
elif idx == 3:
main_modality.append(bbox)
main_modality_name.append('bbox')
else:
main_modality.append(rgb)
main_modality_name.append('rgb')
return main_modality, main_modality_name
def plot_similarity_histogram(values, filename, labels):
def count_elements(list_of_strings, target_list):
counts = []
for element in list_of_strings:
count = target_list.count(element)
counts.append(count)
return counts
fig, ax = plt.subplots(figsize=(12,4))
colors = ["red", "blue", "green", "violet"]
# Remove spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
colors = [MTOM_COLORS['MN1'], MTOM_COLORS['MN2'], MTOM_COLORS['CG'], MTOM_COLORS['CG']]
alphas = [0.6, 0.6, 0.6, 0.6]
edgecolors = ['black', 'black', MTOM_COLORS['MN1'], MTOM_COLORS['MN2']]
linewidths = [1.0, 1.0, 5.0, 5.0]
if isinstance(values[0], list):
num_lists = len(values)
bar_width = 0.8 / num_lists
for i, val in enumerate(values):
unique_strings = ['rgb', 'pose', 'gaze', 'bbox', 'ocr']
counts = count_elements(unique_strings, val)
x = np.arange(len(unique_strings))
x_shifted = x + (i - (len(labels) - 1) / 2) * bar_width #x - (bar_width * num_lists / 2) + (bar_width * i)
ax.bar(x_shifted, counts, width=bar_width, label=f'{labels[i]}', color=colors[i], edgecolor=edgecolors[i], linewidth=linewidths[i], alpha=alphas[i])
ax.set_xlabel('Modality', fontsize=18)
ax.set_ylabel('Counts', fontsize=18)
ax.set_xticks(np.arange(len(unique_strings)))
ax.set_xticklabels(unique_strings, fontsize=18)
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
ax.legend(fontsize=18)
ax.grid(axis='y')
plt.savefig(filename, bbox_inches='tight')
else:
unique_strings, counts = np.unique(values, return_counts=True)
ax.bar(unique_strings, counts)
ax.set_xlabel('Modality')
ax.set_ylabel('Counts')
plt.savefig(filename, bbox_inches='tight')
def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
means = []
stds = []
for key, values in data.items():
mean = np.mean(values)
std = np.std(values)
means.append(mean)
stds.append(std)
fig, ax = plt.subplots(figsize=size)
x = np.arange(len(data))
width = 0.6
rects1 = ax.bar(
x, means, width, label='Mean', yerr=stds,
capsize=5,
color='teal' if colors is None else colors,
edgecolor='black',
linewidth=1.5,
alpha=0.6
)
ax.set_ylabel('Accuracy', fontsize=18)
#ax.set_title('Mean and Standard Deviation of Results', fontsize=14)
if filename == 'results/all':
xticklabels = [
'Base',
'DB',
            r'CG$\otimes$', r'CG$\oplus$', r'CG$\odot$', r'CG$\parallel$',
            r'IC$\otimes$', r'IC$\oplus$', r'IC$\odot$', r'IC$\parallel$',
'CNN', 'CNN+GRU', 'CNN+LSTM', 'CNN+Conv1D'
]
else:
xticklabels = list(data.keys())
ax.set_xticks(np.arange(len(xticklabels)))
ax.set_xticklabels(xticklabels, rotation=rotation, fontsize=16) #data.keys(), rotation=rotation, fontsize=16)
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
#ax.grid(axis='y')
#ax.set_axisbelow(True)
#ax.legend(loc='upper right')
# Add value labels above each bar, avoiding overlapping with error bars
for rect, std in zip(rects1, stds):
height = rect.get_height()
if height + std < ax.get_ylim()[1]:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
else:
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
# Remove spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.grid(axis='x')
plt.tight_layout()
plt.savefig(f'{filename}.pdf', bbox_inches='tight')
def plot_confusion_matrices(left_cm, right_cm, labels, title, annot=True):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
left_xticklabels = [OBJECTS[i] for i in labels[0]]
left_yticklabels = [OBJECTS[i] for i in labels[1]]
right_xticklabels = [OBJECTS[i] for i in labels[2]]
right_yticklabels = [OBJECTS[i] for i in labels[3]]
sns.heatmap(
left_cm,
annot=annot,
fmt='.0f',
cmap='Blues',
cbar=False,
xticklabels=left_xticklabels,
yticklabels=left_yticklabels,
ax=ax1
)
ax1.set_xlabel('Predicted')
ax1.set_ylabel('True')
ax1.set_title('Left Confusion Matrix')
sns.heatmap(
right_cm,
annot=annot,
fmt='.0f',
cmap='Blues',
cbar=False, #True if annot is False else False,
xticklabels=right_xticklabels,
yticklabels=right_yticklabels,
ax=ax2
)
ax2.set_xlabel('Predicted')
ax2.set_ylabel('True')
ax2.set_title('Right Confusion Matrix')
#plt.suptitle(title)
plt.tight_layout()
plt.savefig(title + '.pdf')
plt.close()
def plot_confusion_matrix(confusion_matrix, labels, title, annot=True):
plt.figure(figsize=(6, 6))
xticklabels = [OBJECTS[i] for i in labels[0]]
yticklabels = [OBJECTS[i] for i in labels[1]]
sns.heatmap(
confusion_matrix,
annot=annot,
fmt='.0f',
cmap='Blues',
cbar=False,
xticklabels=xticklabels,
yticklabels=yticklabels
)
plt.xlabel('Predicted')
plt.ylabel('True')
#plt.title(title)
plt.tight_layout()
plt.savefig(title + '.pdf')
plt.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, choices=['confusion', 'similarity', 'scores', 'friends_vs_strangers'])
parser.add_argument('--folder_path', type=str)
args = parser.parse_args()
if args.task == 'similarity':
sns.set_theme(style='white')
else:
sns.set_theme(style='whitegrid')
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
if args.task == 'similarity':
out_left_main_mods_full_test = []
out_right_main_mods_full_test = []
cell_left_main_mods_full_test = []
cell_right_main_mods_full_test = []
cm_left_main_mods_full_test = []
cm_right_main_mods_full_test = []
for i in range(60):
print(f'Computing analysis for test video {i}...', end='\r')
emb_file = os.path.join(args.folder_path, f'{i}.pt')
data = torch.load(emb_file)
if len(data) == 14: # implicit
model = 'impl'
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
out_left = out_left.squeeze(0)
cell_left = cell_left.squeeze(0)
out_right = out_right.squeeze(0)
cell_right = cell_right.squeeze(0)
elif len(data) == 13: # common mind
model = 'cm'
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
common_mind = common_mind.squeeze(0)
elif len(data) == 12: # speaker-listener
model = 'sl'
out_left, out_right, feats = data[0], data[1], data[2:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
# ====== PCA for left and right embeddings ====== #
out_left_and_right = np.concatenate((out_left, out_right), axis=0)
pca = PCA(n_components=2)
pca_result = pca.fit_transform(out_left_and_right)
            # Separate the PCA results for each tensor (left frames come first in the concatenation)
            pca_result_left = pca_result[:out_left.shape[0]]
            pca_result_right = pca_result[out_left.shape[0]:]
plt.figure(figsize=(7,6))
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='MindNet$_1$', color=MTOM_COLORS['MN1'], s=100)
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='MindNet$_2$', color=MTOM_COLORS['MN2'], s=100)
plt.xlabel('Principal Component 1', fontsize=30)
plt.ylabel('Principal Component 2', fontsize=30)
plt.grid(False)
plt.legend(fontsize=30)
plt.tight_layout()
plt.savefig(f'{args.folder_path}/{i}_pca.pdf')
plt.close()
# ====== Feature similarity ====== #
if len(feats) == 10:
left_rgb, left_ocr, left_pose, left_gaze, left_bbox, right_rgb, right_ocr, right_pose, right_gaze, right_bbox = feats
left_rgb = left_rgb.squeeze(0)
left_ocr = left_ocr.squeeze(0)
left_pose = left_pose.squeeze(0)
left_gaze = left_gaze.squeeze(0)
left_bbox = left_bbox.squeeze(0)
right_rgb = right_rgb.squeeze(0)
right_ocr = right_ocr.squeeze(0)
right_pose = right_pose.squeeze(0)
right_gaze = right_gaze.squeeze(0)
right_bbox = right_bbox.squeeze(0)
else: raise NotImplementedError("Ablated versions are not supported yet.")
# out: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
# cell: [1, 1, dim] --------- squeeze ---> [1, dim]
# cm: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
# feat: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
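            # find_most_similar_embedding (defined above) is assumed to return, per frame,
            # the input modality (RGB, pose, gaze, OCR or bbox) whose embedding is most
            # similar to the given hidden state; only these modality assignments are kept.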
_, out_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, out_left)
_, out_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, out_right)
out_left_main_mods_full_test += out_left_main_mods
out_right_main_mods_full_test += out_right_main_mods
if model == 'impl':
_, cell_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, cell_left)
_, cell_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, cell_right)
cell_left_main_mods_full_test += cell_left_main_mods
cell_right_main_mods_full_test += cell_right_main_mods
if model == 'cm':
_, cm_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, common_mind)
_, cm_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, common_mind)
cm_left_main_mods_full_test += cm_left_main_mods
cm_right_main_mods_full_test += cm_right_main_mods
if model == 'impl':
plot_similarity_histogram(
[out_left_main_mods_full_test,
out_right_main_mods_full_test,
cell_left_main_mods_full_test,
cell_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_impl_hist_all.pdf',
[r'$h_1$', r'$h_2$', r'$c_1$', r'$c_2$']
)
elif model == 'cm':
plot_similarity_histogram(
[out_left_main_mods_full_test,
out_right_main_mods_full_test,
cm_left_main_mods_full_test,
cm_right_main_mods_full_test],
f'{args.folder_path}/boss_similarity_cm_hist_all.pdf',
[r'$h_1$', r'$h_2$', r'$cg$ w/ 1', r'$cg$ w/ 2']
)
elif model == 'sl':
plot_similarity_histogram(
[out_left_main_mods_full_test,
out_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_sl_hist_all.pdf',
[r'$h_1$', r'$h_2$']
)
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
elif args.task == 'scores':
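        # results/*.json are assumed to map each model/ablation name to its scores
        # across runs; plot_scores_histogram (defined above) shows their mean and
        # standard deviation as a bar plot.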
        with open('results/all.json') as f:
            all_scores = json.load(f)
        with open('results/abl_cm_concat.json') as f:
            abl_tom_cm_concat = json.load(f)
        plot_scores_histogram(
            all_scores,
filename='results/all',
size=(10,4),
rotation=45,
colors=[COLORS[7]] + [MTOM_COLORS['DB']] + [MTOM_COLORS['CG']]*4 + [MTOM_COLORS['IC']]*4 + [COLORS[4]]*4
)
plot_scores_histogram(
abl_tom_cm_concat,
size=(10,4),
filename='results/abl_cm_concat',
colors=[MTOM_COLORS['CG']]*8
)
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
elif args.task == 'confusion':
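        # Each per-video CSV in folder_path is expected to contain the columns
        # left_label, right_label, left_pred and right_pred; confusion matrices are
        # built per video and then for the merged test set.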
# List to store confusion matrices
all_df = []
left_matrices = []
right_matrices = []
# Iterate over the CSV files
for filename in natsorted(os.listdir(args.folder_path)):
if filename.endswith('.csv'):
file_path = os.path.join(args.folder_path, filename)
print(f'Processing {file_path}...')
# Read the CSV file
df = pd.read_csv(file_path)
# Extract the left and right labels and predictions
left_labels = df['left_label']
right_labels = df['right_label']
left_preds = df['left_pred']
right_preds = df['right_pred']
# Calculate the confusion matrices for left and right
left_cm = pd.crosstab(left_labels, left_preds)
right_cm = pd.crosstab(right_labels, right_preds)
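                # Note: pd.crosstab only includes classes that actually occur in this
                # video, which is why the class labels are passed explicitly to the
                # plotting function below.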
# Append the confusion matrices to the list
left_matrices.append(left_cm)
right_matrices.append(right_cm)
all_df.append(df)
# Plot and save the confusion matrices for left and right
        for i, (left_cm, right_cm) in enumerate(zip(left_matrices, right_matrices)):
            print(f'Computing confusion matrices for video {i}...', end='\r')
            plot_confusion_matrices(
                left_cm,
                right_cm,
                labels=[left_cm.columns, left_cm.index, right_cm.columns, right_cm.index],
                title=f'{args.folder_path}/{i}_cm'
            )
merged_df = pd.concat(all_df).reset_index(drop=True)
merged_left_cm = pd.crosstab(merged_df['left_label'], merged_df['left_pred'])
merged_right_cm = pd.crosstab(merged_df['right_label'], merged_df['right_pred'])
plot_confusion_matrices(
merged_left_cm,
merged_right_cm,
labels=[merged_left_cm.columns, merged_left_cm.index, merged_right_cm.columns, merged_right_cm.index],
title=f'{args.folder_path}/all_lr',
annot=False
)
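        # Additionally, build a single confusion matrix over both participants by
        # pooling the left and right labels/predictions.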
merged_preds = pd.concat([merged_df['left_pred'], merged_df['right_pred']]).reset_index(drop=True)
merged_labels = pd.concat([merged_df['left_label'], merged_df['right_label']]).reset_index(drop=True)
merged_all_cm = pd.crosstab(merged_labels, merged_preds)
plot_confusion_matrix(
merged_all_cm,
labels=[merged_all_cm.columns, merged_all_cm.index],
title=f'{args.folder_path}/all_merged',
annot=False
)
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
elif args.task == 'friends_vs_strangers':
# friends ids: 30-59
# stranger ids: 0-29
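        # Pool the saved MindNet hidden states of all test videos and check, via a
        # 2-D PCA, whether embeddings from friend dyads (videos 30-59) separate from
        # those of stranger dyads (videos 0-29).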
out_left_list = []
out_right_list = []
cell_left_list = []
cell_right_list = []
common_mind_list = []
for i in range(60):
print(f'Computing analysis for test video {i}...', end='\r')
emb_file = os.path.join(args.folder_path, f'{i}.pt')
data = torch.load(emb_file)
if len(data) == 14: # implicit
model = 'impl'
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
out_left = out_left.squeeze(0)
cell_left = cell_left.squeeze(0)
out_right = out_right.squeeze(0)
cell_right = cell_right.squeeze(0)
out_left_list.append(out_left)
out_right_list.append(out_right)
cell_left_list.append(cell_left)
cell_right_list.append(cell_right)
elif len(data) == 13: # common mind
model = 'cm'
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
common_mind = common_mind.squeeze(0)
out_left_list.append(out_left)
out_right_list.append(out_right)
common_mind_list.append(common_mind)
elif len(data) == 12: # speaker-listener
model = 'sl'
out_left, out_right, feats = data[0], data[1], data[2:]
out_left = out_left.squeeze(0)
out_right = out_right.squeeze(0)
out_left_list.append(out_left)
out_right_list.append(out_right)
else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
# ====== PCA for left and right embeddings ====== #
print('\rComputing PCA...')
strangers_nframes = sum([out_left_list[i].shape[0] for i in range(30)])
left = torch.cat(out_left_list, 0)
right = torch.cat(out_right_list, 0)
out_left_and_right = torch.cat([left, right], axis=0).numpy()
#out_left_and_right = StandardScaler().fit_transform(out_left_and_right)
pca = PCA(n_components=2)
pca_result = pca.fit_transform(out_left_and_right)
#pca_result = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=3).fit_transform(out_left_and_right)
#pca_result = UMAP().fit_transform(out_left_and_right)
        # Separate the PCA results into strangers (videos 0-29) and friends (videos 30-59).
        # Left embeddings occupy the first left.shape[0] rows of the concatenation, followed
        # by the right embeddings, so the strangers' frames appear in two separate blocks.
        left_nframes = left.shape[0]
        strangers_right_nframes = sum(out_right_list[i].shape[0] for i in range(30))
        pca_result_strangers = np.concatenate([
            pca_result[:strangers_nframes],
            pca_result[left_nframes:left_nframes + strangers_right_nframes]
        ], axis=0)
        pca_result_friends = np.concatenate([
            pca_result[strangers_nframes:left_nframes],
            pca_result[left_nframes + strangers_right_nframes:]
        ], axis=0)
        plt.scatter(pca_result_friends[:, 0], pca_result_friends[:, 1], label='Friends')
        plt.scatter(pca_result_strangers[:, 0], pca_result_strangers[:, 1], label='Strangers')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.legend()
plt.savefig(f'{args.folder_path}/friends_vs_strangers_pca.pdf')
plt.close()
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
else:
raise NameError