diff --git a/README.md b/README.md
index c4a4964..09be464 100644
--- a/README.md
+++ b/README.md
@@ -22,11 +22,17 @@ This is a temporary reference that will be updated after the proceedings are pub
}
```
-
-
-
+# Code
-Under construction
+This repository has the following structure:
+
+```
+mtomnet
+├── boss
+└── tbd
+```
+
+We have one subfolder for dataset, containing the code to run the corresponding experiments. Inside each subfolder we provide a README with further instructions.
[1]: https://mattbortoletto.github.io/
diff --git a/boss/.gitignore b/boss/.gitignore
new file mode 100644
index 0000000..b9348b7
--- /dev/null
+++ b/boss/.gitignore
@@ -0,0 +1,206 @@
+experiments/
+predictions_*
+predictions/
+tmp/
+
+### WANDB ###
+
+wandb/*
+wandb
+wandb-debug.log
+logs
+summary*
+run*
+
+### SYSTEM ###
+
+*/__pycache__/*
+
+# Created by https://www.toptal.com/developers/gitignore/api/python,linux
+# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
+
+### Linux ###
+*~
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# End of https://www.toptal.com/developers/gitignore/api/python,linux
\ No newline at end of file
diff --git a/boss/README.md b/boss/README.md
new file mode 100644
index 0000000..92e2d01
--- /dev/null
+++ b/boss/README.md
@@ -0,0 +1,16 @@
+# BOSS
+
+## Installing Dependencies
+Run `conda env create -f environment.yml`.
+
+## New bounding box annotations
+We re-extracted bounding box annotations using Yolo-v8. The new data is in `new_bbox/`. The rest of the data can be found [here](https://drive.google.com/drive/folders/1b8FdpyoWx9gUps-BX6qbE9Kea3C2Uyua).
+
+## Train
+`source run_train.sh`.
+
+## Test
+`source run_test.sh` (specify the path to the model).
+
+## Resources
+The original project page for the BOSS dataset can be found [here](https://sites.google.com/view/bossbelief/). Our code is based on the original implementation.
\ No newline at end of file
diff --git a/boss/dataloader.py b/boss/dataloader.py
new file mode 100644
index 0000000..64b05f9
--- /dev/null
+++ b/boss/dataloader.py
@@ -0,0 +1,500 @@
+from torch.utils.data import Dataset
+import os
+from torchvision import transforms
+from natsort import natsorted
+import numpy as np
+import pickle
+import torch
+import cv2
+import ast
+from einops import rearrange
+from utils import spherical2cartesial
+
+
+class Data(Dataset):
+ def __init__(
+ self,
+ frame_path,
+ label_path,
+ pose_path=None,
+ gaze_path=None,
+ bbox_path=None,
+ ocr_graph_path=None,
+ presaved=128,
+ sizes = (128,128),
+ spatial_patch_size: int = 16,
+ temporal_patch_size: int = 4,
+ img_channels: int = 3,
+ patch_data: bool = False,
+ flatten_dim: int = 1
+ ):
+ self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
+ self.size_1, self.size_2 = sizes
+ self.patch_data = patch_data
+ if self.patch_data:
+ self.spatial_patch_size = spatial_patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
+ self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
+
+ assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
+
+ if presaved == 224:
+ self.presaved = 'presaved'
+ elif presaved == 128:
+ self.presaved = 'presaved128'
+ else: raise ValueError
+
+ frame_dirs = os.listdir(frame_path)
+ episode_ids = natsorted(frame_dirs)
+ seq_lens = []
+ for frame_dir in natsorted(frame_dirs):
+ frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
+ seq_lens.append(len(frame_paths))
+ self.data['frame_paths'].append(frame_paths)
+
+ print('Labels...')
+ with open(label_path, 'rb') as fp:
+ labels = pickle.load(fp)
+ left_labels, right_labels = labels
+ for i in episode_ids:
+ episode_id = int(i)
+ left_labels_i = left_labels[episode_id]
+ right_labels_i = right_labels[episode_id]
+ self.data['labels'].append([left_labels_i, right_labels_i])
+ fp.close()
+
+ print('Pose...')
+ self.pose_path = pose_path
+ if pose_path is not None:
+ for i in episode_ids:
+ pose_i_dir = os.path.join(pose_path, i)
+ poses = []
+ for j in natsorted(os.listdir(pose_i_dir)):
+ fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
+ if torch.tensor(fs.getNode("pose_0").mat()).shape != (2,25,3):
+ poses.append(fs.getNode("pose_0").mat()[:2,:,:])
+ else:
+ poses.append(fs.getNode("pose_0").mat())
+ poses = np.array(poses)
+ poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
+ poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
+ self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
+ #self.data['poses'].append(torch.tensor(np.array(poses)))
+
+ print('Gaze...')
+ """
+ Gaze information is represented by a 3D gaze vector based on the observing camera’s Cartesian eye coordinate system
+ """
+ self.gaze_path = gaze_path
+ if gaze_path is not None:
+ for i in episode_ids:
+ gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
+ with open(gaze_i_txt, 'r') as fp:
+ gaze_i = fp.readlines()
+ fp.close()
+ gaze_i = ast.literal_eval(gaze_i[0])
+ gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
+ for j in range(len(gaze_i)):
+ if len(gaze_i[j]) >= 2:
+ gaze_i_tensor[j,:] = torch.tensor(gaze_i[j][:2])
+ elif len(gaze_i[j]) == 1:
+ gaze_i_tensor[j,0] = torch.tensor(gaze_i[j][0])
+ else:
+ continue
+ self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))
+
+ print('Bbox...')
+ self.bbox_path = bbox_path
+ if bbox_path is not None:
+ self.objects = [
+ 'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
+ 'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
+ 'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
+ 'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
+ ]
+ """
+ # NOTE: old bbox
+ for i in episode_ids:
+ bbox_i_dir = os.path.join(bbox_path, i)
+ with open(bbox_i_dir, 'rb') as fp:
+ bboxes_i = pickle.load(fp)
+ len_i = len(bboxes_i)
+ fp.close()
+ bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
+ for j in range(len(bboxes_i)):
+ items_i_j, bboxes_i_j = bboxes_i[j]
+ for k in range(len(items_i_j)):
+ bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
+ bboxes_i_j[k][0] / 1920, # * self.size_1,
+ bboxes_i_j[k][1] / 1088, # * self.size_2,
+ bboxes_i_j[k][2] / 1920, # * self.size_1,
+ bboxes_i_j[k][3] / 1088, # * self.size_2
+ ]) # [x_min, y_min, x_max, y_max]
+ self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
+ """
+ # NOTE: new bbox
+ for i in episode_ids:
+ bbox_dir = os.path.join(bbox_path, i)
+ bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4)) # TODO: we might want to cut it to 10 objects
+ for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)):
+ with open(os.path.join(bbox_dir, bbox), 'r') as fp:
+ bbox_content = fp.readlines()
+ fp.close()
+ for bbox_content_line in bbox_content:
+ bbox_content_values = bbox_content_line.split()
+ class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values)
+ bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
+ self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))
+
+ print('OCR...\n')
+ self.ocr_graph_path = ocr_graph_path
+ if ocr_graph_path is not None:
+ ocr_graph = [
+ [15, [10, 4], [17, 2]],
+ [13, [16, 7], [18, 4]],
+ [11, [16, 4], [7, 10]],
+ [14, [10, 11], [7, 1]],
+ [12, [10, 9], [16, 3]],
+ [1, [7, 2], [9, 9], [10, 2]],
+ [5, [8, 8], [6, 8]],
+ [4, [9, 8], [7, 6]],
+ [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
+ [2, [10, 1], [7, 7], [9, 3]],
+ [19, [10, 2], [26, 6]],
+ [20, [10, 7], [26, 5]],
+ [22, [25, 4], [10, 8]],
+ [23, [25, 15]],
+ [21, [16, 5], [24, 8]]
+ ]
+ ocr_tensor = torch.zeros((27, 27))
+ for ocr in ocr_graph:
+ obj = ocr[0]
+ contexts = ocr[1:]
+ total_context_count = sum([i[1] for i in contexts])
+ for context in contexts:
+ ocr_tensor[obj, context[0]] = context[1] / total_context_count
+ ocr_tensor = torch.flatten(ocr_tensor)
+ for i in episode_ids:
+ self.data['ocr_graph'].append(ocr_tensor)
+
+ self.frame_path = frame_path
+ self.transform = transforms.Compose([
+ #transforms.ToPILImage(),
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+
+ def __len__(self):
+ return len(self.data['frame_paths'])
+
+ def __getitem__(self, idx):
+ frames_path = self.data['frame_paths'][idx][0].split('/')
+ frames_path.insert(5, self.presaved)
+ frames_path = ''.join(f'{w}/' for w in frames_path[:-1])[:-1]+'.pkl'
+ with open(frames_path, 'rb') as f:
+ images = pickle.load(f)
+ images = torch.stack(images)
+ if self.patch_data:
+ images = rearrange(
+ images,
+ '(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
+ p1=self.spatial_patch_size,
+ p2=self.spatial_patch_size,
+ p3=self.temporal_patch_size
+ )
+
+ if self.pose_path is not None:
+ pose = self.data['poses'][idx]
+ if self.patch_data:
+ pose = rearrange(
+ pose,
+ '(t p1) d -> t p1 d',
+ p1=self.temporal_patch_size
+ )
+ else:
+ pose = None
+
+ if self.gaze_path is not None:
+ gaze = self.data['gazes'][idx]
+ if self.patch_data:
+ gaze = rearrange(
+ gaze,
+ '(t p1) d -> t p1 d',
+ p1=self.temporal_patch_size
+ )
+ else:
+ gaze = None
+
+ if self.bbox_path is not None:
+ bbox = self.data['bboxes'][idx]
+ if self.patch_data:
+ bbox = rearrange(
+ bbox,
+ '(t p1) d -> t p1 d',
+ p1=self.temporal_patch_size
+ )
+ else:
+ bbox = None
+
+ if self.ocr_graph_path is not None:
+ ocr_graphs = self.data['ocr_graph'][idx]
+ else:
+ ocr_graphs = None
+
+ return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs
+
+
+class DataTest(Dataset):
+ def __init__(
+ self,
+ frame_path,
+ label_path,
+ pose_path=None,
+ gaze_path=None,
+ bbox_path=None,
+ ocr_graph_path=None,
+ presaved=128,
+ frame_ids=None,
+ median=None,
+ sizes = (128,128),
+ spatial_patch_size: int = 16,
+ temporal_patch_size: int = 4,
+ img_channels: int = 3,
+ patch_data: bool = False,
+ flatten_dim: int = 1
+ ):
+ self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
+ self.size_1, self.size_2 = sizes
+ self.patch_data = patch_data
+ if self.patch_data:
+ self.spatial_patch_size = spatial_patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
+ self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
+
+ assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
+
+ if presaved == 224:
+ self.presaved = 'presaved'
+ elif presaved == 128:
+ self.presaved = 'presaved128'
+ else: raise ValueError
+
+ if frame_ids is not None:
+ test_ids = []
+ frame_dirs = os.listdir(frame_path)
+ episode_ids = natsorted(frame_dirs)
+ for frame_dir in episode_ids:
+ if int(frame_dir) in frame_ids:
+ test_ids.append(str(frame_dir))
+ frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
+ self.data['frame_paths'].append(frame_paths)
+ elif median is not None:
+ test_ids = []
+ frame_dirs = os.listdir(frame_path)
+ episode_ids = natsorted(frame_dirs)
+ for frame_dir in episode_ids:
+ frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
+ seq_len = len(frame_paths)
+ if (median[1] and seq_len >= median[0]) or (not median[1] and seq_len < median[0]):
+ self.data['frame_paths'].append(frame_paths)
+ test_ids.append(int(frame_dir))
+ else:
+ frame_dirs = os.listdir(frame_path)
+ episode_ids = natsorted(frame_dirs)
+ test_ids = episode_ids.copy()
+ for frame_dir in episode_ids:
+ frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
+ self.data['frame_paths'].append(frame_paths)
+
+ print('Labels...')
+ with open(label_path, 'rb') as fp:
+ labels = pickle.load(fp)
+ left_labels, right_labels = labels
+ for i in test_ids:
+ episode_id = int(i)
+ left_labels_i = left_labels[episode_id]
+ right_labels_i = right_labels[episode_id]
+ self.data['labels'].append([left_labels_i, right_labels_i])
+ fp.close()
+
+ print('Pose...')
+ self.pose_path = pose_path
+ if pose_path is not None:
+ for i in test_ids:
+ pose_i_dir = os.path.join(pose_path, i)
+ poses = []
+ for j in natsorted(os.listdir(pose_i_dir)):
+ fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
+ if torch.tensor(fs.getNode("pose_0").mat()).shape != (2,25,3):
+ poses.append(fs.getNode("pose_0").mat()[:2,:,:])
+ else:
+ poses.append(fs.getNode("pose_0").mat())
+ poses = np.array(poses)
+ poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
+ poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
+ self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
+
+ print('Gaze...')
+ self.gaze_path = gaze_path
+ if gaze_path is not None:
+ for i in test_ids:
+ gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
+ with open(gaze_i_txt, 'r') as fp:
+ gaze_i = fp.readlines()
+ fp.close()
+ gaze_i = ast.literal_eval(gaze_i[0])
+ gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
+ for j in range(len(gaze_i)):
+ if len(gaze_i[j]) >= 2:
+ gaze_i_tensor[j,:] = torch.tensor(gaze_i[j][:2])
+ elif len(gaze_i[j]) == 1:
+ gaze_i_tensor[j,0] = torch.tensor(gaze_i[j][0])
+ else:
+ continue
+ self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))
+
+ print('Bbox...')
+ self.bbox_path = bbox_path
+ if bbox_path is not None:
+ self.objects = [
+ 'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
+ 'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
+ 'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
+ 'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
+ ]
+ """ NOTE: old bbox
+ for i in test_ids:
+ bbox_i_dir = os.path.join(bbox_path, i)
+ with open(bbox_i_dir, 'rb') as fp:
+ bboxes_i = pickle.load(fp)
+ len_i = len(bboxes_i)
+ fp.close()
+ bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
+ for j in range(len(bboxes_i)):
+ items_i_j, bboxes_i_j = bboxes_i[j]
+ for k in range(len(items_i_j)):
+ bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
+ bboxes_i_j[k][0] / 1920, # * self.size_1,
+ bboxes_i_j[k][1] / 1088, # * self.size_2,
+ bboxes_i_j[k][2] / 1920, # * self.size_1,
+ bboxes_i_j[k][3] / 1088, # * self.size_2
+ ]) # [x_min, y_min, x_max, y_max]
+ self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
+ """
+ for i in test_ids:
+ bbox_dir = os.path.join(bbox_path, i)
+ bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4)) # TODO: we might want to cut it to 10 objects
+ for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)):
+ with open(os.path.join(bbox_dir, bbox), 'r') as fp:
+ bbox_content = fp.readlines()
+ fp.close()
+ for bbox_content_line in bbox_content:
+ bbox_content_values = bbox_content_line.split()
+ class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values)
+ bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
+ self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))
+
+ print('OCR...\n')
+ self.ocr_graph_path = ocr_graph_path
+ if ocr_graph_path is not None:
+ ocr_graph = [
+ [15, [10, 4], [17, 2]],
+ [13, [16, 7], [18, 4]],
+ [11, [16, 4], [7, 10]],
+ [14, [10, 11], [7, 1]],
+ [12, [10, 9], [16, 3]],
+ [1, [7, 2], [9, 9], [10, 2]],
+ [5, [8, 8], [6, 8]],
+ [4, [9, 8], [7, 6]],
+ [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
+ [2, [10, 1], [7, 7], [9, 3]],
+ [19, [10, 2], [26, 6]],
+ [20, [10, 7], [26, 5]],
+ [22, [25, 4], [10, 8]],
+ [23, [25, 15]],
+ [21, [16, 5], [24, 8]]
+ ]
+ ocr_tensor = torch.zeros((27, 27))
+ for ocr in ocr_graph:
+ obj = ocr[0]
+ contexts = ocr[1:]
+ total_context_count = sum([i[1] for i in contexts])
+ for context in contexts:
+ ocr_tensor[obj, context[0]] = context[1] / total_context_count
+ ocr_tensor = torch.flatten(ocr_tensor)
+ for i in test_ids:
+ self.data['ocr_graph'].append(ocr_tensor)
+
+ self.frame_path = frame_path
+ self.transform = transforms.Compose([
+ #transforms.ToPILImage(),
+ transforms.Resize((self.size_1, self.size_2)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+
+ def __len__(self):
+ return len(self.data['frame_paths'])
+
+ def __getitem__(self, idx):
+ frames_path = self.data['frame_paths'][idx][0].split('/')
+ frames_path.insert(5, self.presaved)
+ frames_path = ''.join(f'{w}/' for w in frames_path[:-1])[:-1]+'.pkl'
+ with open(frames_path, 'rb') as f:
+ images = pickle.load(f)
+ images = torch.stack(images)
+ if self.patch_data:
+ images = rearrange(
+ images,
+ '(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
+ p1=self.spatial_patch_size,
+ p2=self.spatial_patch_size,
+ p3=self.temporal_patch_size
+ )
+
+ if self.pose_path is not None:
+ pose = self.data['poses'][idx]
+ if self.patch_data:
+ pose = rearrange(
+ pose,
+ '(t p1) d -> t p1 d',
+ p1=self.temporal_patch_size
+ )
+ else:
+ pose = None
+
+ if self.gaze_path is not None:
+ gaze = self.data['gazes'][idx]
+ if self.patch_data:
+ gaze = rearrange(
+ gaze,
+ '(t p1) d -> t p1 d',
+ p1=self.temporal_patch_size
+ )
+ else:
+ gaze = None
+
+ if self.bbox_path is not None:
+ bbox = self.data['bboxes'][idx]
+ if self.patch_data:
+ bbox = rearrange(
+ bbox,
+ '(t p1) d -> t p1 d',
+ p1=self.temporal_patch_size
+ )
+ else:
+ bbox = None
+
+ if self.ocr_graph_path is not None:
+ ocr_graphs = self.data['ocr_graph'][idx]
+ else:
+ ocr_graphs = None
+
+ return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs
+
diff --git a/boss/environment.yml b/boss/environment.yml
new file mode 100644
index 0000000..2142bac
--- /dev/null
+++ b/boss/environment.yml
@@ -0,0 +1,240 @@
+name: boss
+channels:
+ - pyg
+ - anaconda
+ - conda-forge
+ - defaults
+ - pytorch
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - alembic=1.8.0=pyhd8ed1ab_0
+ - appdirs=1.4.4=pyh9f0ad1d_0
+ - asttokens=2.2.1=pyhd8ed1ab_0
+ - backcall=0.2.0=pyh9f0ad1d_0
+ - backports=1.0=pyhd8ed1ab_3
+ - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
+ - blas=1.0=mkl
+ - blosc=1.21.3=h6a678d5_0
+ - bottle=0.12.23=pyhd8ed1ab_0
+ - bottleneck=1.3.5=py38h7deecbd_0
+ - brotli=1.0.9=h5eee18b_7
+ - brotli-bin=1.0.9=h5eee18b_7
+ - brotlipy=0.7.0=py38h27cfd23_1003
+ - brunsli=0.1=h2531618_0
+ - bzip2=1.0.8=h7b6447c_0
+ - c-ares=1.18.1=h7f8727e_0
+ - ca-certificates=2023.01.10=h06a4308_0
+ - certifi=2023.5.7=py38h06a4308_0
+ - cffi=1.15.1=py38h74dc2b5_0
+ - cfitsio=3.470=h5893167_7
+ - charls=2.2.0=h2531618_0
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
+ - click=8.1.3=unix_pyhd8ed1ab_2
+ - cloudpickle=2.0.0=pyhd3eb1b0_0
+ - cmaes=0.9.1=pyhd8ed1ab_0
+ - colorlog=6.7.0=py38h578d9bd_1
+ - contourpy=1.0.5=py38hdb19cb5_0
+ - cryptography=39.0.1=py38h9ce1e76_0
+ - cudatoolkit=11.3.1=h2bc3f7f_2
+ - cycler=0.11.0=pyhd3eb1b0_0
+ - cytoolz=0.12.0=py38h5eee18b_0
+ - dask-core=2022.7.0=py38h06a4308_0
+ - dbus=1.13.18=hb2f20db_0
+ - debugpy=1.5.1=py38h295c915_0
+ - decorator=5.1.1=pyhd8ed1ab_0
+ - docker-pycreds=0.4.0=py_0
+ - entrypoints=0.4=pyhd8ed1ab_0
+ - executing=1.2.0=pyhd8ed1ab_0
+ - expat=2.4.9=h6a678d5_0
+ - ffmpeg=4.3=hf484d3e_0
+ - fftw=3.3.9=h27cfd23_1
+ - flit-core=3.6.0=pyhd3eb1b0_0
+ - fontconfig=2.14.1=h52c9d5c_1
+ - fonttools=4.25.0=pyhd3eb1b0_0
+ - freetype=2.12.1=h4a9f257_0
+ - fsspec=2022.11.0=py38h06a4308_0
+ - giflib=5.2.1=h5eee18b_1
+ - gitdb=4.0.10=pyhd8ed1ab_0
+ - gitpython=3.1.31=pyhd8ed1ab_0
+ - glib=2.69.1=h4ff587b_1
+ - gmp=6.2.1=h295c915_3
+ - gnutls=3.6.15=he1e5248_0
+ - greenlet=2.0.1=py38h6a678d5_0
+ - gst-plugins-base=1.14.1=h6a678d5_1
+ - gstreamer=1.14.1=h5eee18b_1
+ - icu=58.2=he6710b0_3
+ - idna=3.4=py38h06a4308_0
+ - imagecodecs=2021.8.26=py38hf0132c2_1
+ - imageio=2.19.3=py38h06a4308_0
+ - importlib-metadata=6.1.0=pyha770c72_0
+ - importlib_resources=5.2.0=pyhd3eb1b0_1
+ - intel-openmp=2021.4.0=h06a4308_3561
+ - ipykernel=6.15.0=pyh210e3f2_0
+ - ipython=8.11.0=pyh41d4057_0
+ - jedi=0.18.2=pyhd8ed1ab_0
+ - jinja2=3.1.2=py38h06a4308_0
+ - joblib=1.1.1=py38h06a4308_0
+ - jpeg=9e=h7f8727e_0
+ - jupyter_client=7.0.6=pyhd8ed1ab_0
+ - jupyter_core=4.12.0=py38h578d9bd_0
+ - jxrlib=1.1=h7b6447c_2
+ - kiwisolver=1.4.4=py38h6a678d5_0
+ - krb5=1.19.4=h568e23c_0
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - lerc=3.0=h295c915_0
+ - libaec=1.0.4=he6710b0_1
+ - libbrotlicommon=1.0.9=h5eee18b_7
+ - libbrotlidec=1.0.9=h5eee18b_7
+ - libbrotlienc=1.0.9=h5eee18b_7
+ - libclang=10.0.1=default_hb85057a_2
+ - libcurl=7.87.0=h91b91d3_0
+ - libdeflate=1.8=h7f8727e_5
+ - libedit=3.1.20221030=h5eee18b_0
+ - libev=4.33=h7f8727e_1
+ - libevent=2.1.12=h8f2d780_0
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgfortran-ng=11.2.0=h00389a5_1
+ - libgfortran5=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libiconv=1.16=h7f8727e_2
+ - libidn2=2.3.2=h7f8727e_0
+ - libllvm10=10.0.1=hbcb73fb_5
+ - libnghttp2=1.46.0=hce63b2e_0
+ - libpng=1.6.37=hbc83047_0
+ - libpq=12.9=h16c4e8d_3
+ - libprotobuf=3.20.3=he621ea3_0
+ - libsodium=1.0.18=h36c2ea0_1
+ - libssh2=1.10.0=h8f2d780_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtasn1=4.16.0=h27cfd23_0
+ - libtiff=4.5.0=h6a678d5_1
+ - libunistring=0.9.10=h27cfd23_0
+ - libuuid=1.41.5=h5eee18b_0
+ - libwebp=1.2.4=h11a3e52_0
+ - libwebp-base=1.2.4=h5eee18b_0
+ - libxcb=1.15=h7f8727e_0
+ - libxkbcommon=1.0.1=hfa300c1_0
+ - libxml2=2.9.14=h74e7548_0
+ - libxslt=1.1.35=h4e12654_0
+ - libzopfli=1.0.3=he6710b0_0
+ - locket=1.0.0=py38h06a4308_0
+ - lz4-c=1.9.4=h6a678d5_0
+ - mako=1.2.4=pyhd8ed1ab_0
+ - markupsafe=2.1.1=py38h7f8727e_0
+ - matplotlib=3.7.0=py38h06a4308_0
+ - matplotlib-base=3.7.0=py38h417a72b_0
+ - matplotlib-inline=0.1.6=pyhd8ed1ab_0
+ - mkl=2021.4.0=h06a4308_640
+ - mkl-service=2.4.0=py38h7f8727e_0
+ - mkl_fft=1.3.1=py38hd3c417c_0
+ - mkl_random=1.2.2=py38h51133e4_0
+ - munkres=1.1.4=py_0
+ - natsort=7.1.1=pyhd3eb1b0_0
+ - ncurses=6.3=h5eee18b_3
+ - nest-asyncio=1.5.6=pyhd8ed1ab_0
+ - nettle=3.7.3=hbbd107a_1
+ - networkx=2.8.4=py38h06a4308_0
+ - nspr=4.33=h295c915_0
+ - nss=3.74=h0370c37_0
+ - numexpr=2.8.4=py38he184ba9_0
+ - numpy=1.23.5=py38h14f4228_0
+ - numpy-base=1.23.5=py38h31eccc5_0
+ - openh264=2.1.1=h4ff587b_0
+ - openjpeg=2.4.0=h3ad879b_0
+ - openssl=1.1.1t=h7f8727e_0
+ - optuna=3.1.0=pyhd8ed1ab_0
+ - optuna-dashboard=0.9.0=pyhd8ed1ab_0
+ - packaging=22.0=py38h06a4308_0
+ - pandas=1.5.3=py38h417a72b_0
+ - parso=0.8.3=pyhd8ed1ab_0
+ - partd=1.2.0=pyhd3eb1b0_1
+ - pathtools=0.1.2=py_1
+ - pcre=8.45=h295c915_0
+ - pexpect=4.8.0=pyh1a96a4e_2
+ - pickleshare=0.7.5=py_1003
+ - pillow=9.4.0=py38h6a678d5_0
+ - pip=22.3.1=py38h06a4308_0
+ - ply=3.11=py38_0
+ - prompt-toolkit=3.0.38=pyha770c72_0
+ - prompt_toolkit=3.0.38=hd8ed1ab_0
+ - protobuf=3.20.3=py38h6a678d5_0
+ - psutil=5.9.0=py38h5eee18b_0
+ - ptyprocess=0.7.0=pyhd3deb0d_0
+ - pure_eval=0.2.2=pyhd8ed1ab_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pyg=2.2.0=py38_torch_1.12.0_cu113
+ - pygments=2.14.0=pyhd8ed1ab_0
+ - pyopenssl=23.0.0=py38h06a4308_0
+ - pyparsing=3.0.9=py38h06a4308_0
+ - pyqt=5.15.7=py38h6a678d5_1
+ - pyqt5-sip=12.11.0=py38h6a678d5_1
+ - pysocks=1.7.1=py38h06a4308_0
+ - python=3.8.13=haa1d7c7_1
+ - python-dateutil=2.8.2=pyhd8ed1ab_0
+ - python_abi=3.8=2_cp38
+ - pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0
+ - pytorch-cluster=1.6.0=py38_torch_1.12.0_cu113
+ - pytorch-mutex=1.0=cuda
+ - pytorch-scatter=2.1.0=py38_torch_1.12.0_cu113
+ - pytorch-sparse=0.6.16=py38_torch_1.12.0_cu113
+ - pytz=2022.7=py38h06a4308_0
+ - pywavelets=1.4.1=py38h5eee18b_0
+ - pyyaml=6.0=py38h5eee18b_1
+ - pyzmq=19.0.2=py38ha71036d_2
+ - qt-main=5.15.2=h327a75a_7
+ - qt-webengine=5.15.9=hd2b0992_4
+ - qtwebkit=5.212=h4eab89a_4
+ - readline=8.2=h5eee18b_0
+ - requests=2.28.1=py38h06a4308_1
+ - scikit-image=0.19.3=py38h6a678d5_1
+ - scikit-learn=1.2.1=py38h6a678d5_0
+ - scipy=1.9.3=py38h14f4228_0
+ - sentry-sdk=1.16.0=pyhd8ed1ab_0
+ - setproctitle=1.2.2=py38h0a891b7_2
+ - setuptools=65.6.3=py38h06a4308_0
+ - sip=6.6.2=py38h6a678d5_0
+ - six=1.16.0=pyhd3eb1b0_1
+ - smmap=3.0.5=pyh44b312d_0
+ - snappy=1.1.9=h295c915_0
+ - sqlalchemy=1.4.39=py38h5eee18b_0
+ - sqlite=3.40.1=h5082296_0
+ - stack_data=0.6.2=pyhd8ed1ab_0
+ - threadpoolctl=2.2.0=pyh0d69192_0
+ - tifffile=2021.7.2=pyhd3eb1b0_2
+ - tk=8.6.12=h1ccaba5_0
+ - toml=0.10.2=pyhd3eb1b0_0
+ - toolz=0.12.0=py38h06a4308_0
+ - torchaudio=0.12.0=py38_cu113
+ - torchvision=0.13.0=py38_cu113
+ - tornado=6.1=py38h0a891b7_3
+ - tqdm=4.64.1=py38h06a4308_0
+ - traitlets=5.9.0=pyhd8ed1ab_0
+ - typing_extensions=4.4.0=py38h06a4308_0
+ - tzdata=2022g=h04d1e81_0
+ - urllib3=1.26.14=py38h06a4308_0
+ - wandb=0.13.10=pyhd8ed1ab_0
+ - wcwidth=0.2.6=pyhd8ed1ab_0
+ - wheel=0.37.1=pyhd3eb1b0_0
+ - xz=5.2.10=h5eee18b_1
+ - yaml=0.2.5=h7b6447c_0
+ - zeromq=4.3.4=h9c3ff4c_1
+ - zfp=0.5.5=h295c915_6
+ - zipp=3.11.0=py38h06a4308_0
+ - zlib=1.2.13=h5eee18b_0
+ - zstd=1.5.2=ha4553b6_0
+ - pip:
+ - einops==0.6.1
+ - lion-pytorch==0.1.2
+ - llvmlite==0.40.0
+ - memory-efficient-attention-pytorch==0.1.2
+ - numba==0.57.0
+ - opencv-python==4.7.0.72
+ - pynndescent==0.5.10
+ - seaborn==0.12.2
+ - umap-learn==0.5.3
+ - x-transformers==1.14.0
+prefix: /opt/anaconda3/envs/boss
\ No newline at end of file
diff --git a/boss/models/__init__.py b/boss/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/boss/models/base.py b/boss/models/base.py
new file mode 100644
index 0000000..e5246f2
--- /dev/null
+++ b/boss/models/base.py
@@ -0,0 +1,230 @@
+import torch
+import torch.nn as nn
+from .utils import pose_edge_index
+from torch_geometric.nn import Sequential, GCNConv
+from x_transformers import ContinuousTransformerWrapper, Decoder
+
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.fn = fn
+ self.norm = nn.LayerNorm(dim)
+ def forward(self, x, **kwargs):
+ x = self.norm(x)
+ return self.fn(x, **kwargs)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, dim),
+ nn.GELU(),
+ nn.Linear(dim, dim))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class CNN(nn.Module):
+ def __init__(self, hidden_dim):
+ super(CNN, self).__init__()
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+ self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = nn.functional.relu(x)
+ x = self.pool(x)
+ x = self.conv2(x)
+ x = nn.functional.relu(x)
+ x = self.pool(x)
+ x = self.conv3(x)
+ x = nn.functional.relu(x)
+ x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling
+ return x
+
+
+class MindNetLSTM(nn.Module):
+ """
+ Basic MindNet for model-based ToM, just LSTM on input concatenation
+ """
+ def __init__(self, hidden_dim, dropout, mods):
+ super(MindNetLSTM, self).__init__()
+ self.mods = mods
+ self.gaze_emb = nn.Linear(3, hidden_dim)
+ self.pose_edge_index = pose_edge_index()
+ self.pose_emb = GCNConv(3, hidden_dim)
+ self.LSTM = PreNorm(
+ hidden_dim*len(mods),
+ nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
+ self.proj = nn.Linear(hidden_dim*2, hidden_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.act = nn.GELU()
+
+ def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
+ feats = []
+ if 'rgb' in self.mods:
+ feats.append(rgb_feats)
+ if 'ocr' in self.mods:
+ feats.append(ocr_feats)
+ if 'pose' in self.mods:
+ bs, seq_len = pose.size(0), pose.size(1)
+ self.pose_edge_index = self.pose_edge_index.to(pose.device)
+ pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
+ pose_emb = self.dropout(self.act(pose_emb))
+ pose_emb = torch.mean(pose_emb, dim=1)
+ hd = pose_emb.size(-1)
+ feats.append(pose_emb.view(bs, seq_len, hd))
+ if 'gaze' in self.mods:
+ gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
+ feats.append(gaze_feats)
+ if 'bbox' in self.mods:
+ feats.append(bbox_feats)
+ lstm_inp = torch.cat(feats, 2)
+ lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
+ c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
+ return self.act(self.proj(lstm_out)), c_n, feats
+
+
+class MindNetSL(nn.Module):
+ """
+ Basic MindNet for SL ToM, just LSTM on input concatenation
+ """
+ def __init__(self, hidden_dim, dropout, mods):
+ super(MindNetSL, self).__init__()
+ self.mods = mods
+ self.gaze_emb = nn.Linear(3, hidden_dim)
+ self.pose_edge_index = pose_edge_index()
+ self.pose_emb = GCNConv(3, hidden_dim)
+ self.LSTM = PreNorm(
+ hidden_dim*5,
+ nn.LSTM(input_size=hidden_dim*5, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
+ )
+ self.proj = nn.Linear(hidden_dim*2, hidden_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.act = nn.GELU()
+
+ def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
+ feats = []
+ if 'rgb' in self.mods:
+ feats.append(rgb_feats)
+ if 'ocr' in self.mods:
+ feats.append(ocr_feats)
+ if 'pose' in self.mods:
+ bs, seq_len = pose.size(0), pose.size(1)
+ self.pose_edge_index = self.pose_edge_index.to(pose.device)
+ pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
+ pose_emb = self.dropout(self.act(pose_emb))
+ pose_emb = torch.mean(pose_emb, dim=1)
+ hd = pose_emb.size(-1)
+ feats.append(pose_emb.view(bs, seq_len, hd))
+ if 'gaze' in self.mods:
+ gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
+ feats.append(gaze_feats)
+ if 'bbox' in self.mods:
+ feats.append(bbox_feats)
+ lstm_inp = torch.cat(feats, 2)
+ lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
+ return self.act(self.proj(lstm_out)), feats
+
+
+class MindNetTF(nn.Module):
+ """
+ Basic MindNet for model-based ToM, Transformer on input concatenation
+ """
+ def __init__(self, hidden_dim, dropout, mods):
+ super(MindNetTF, self).__init__()
+ self.mods = mods
+ self.gaze_emb = nn.Linear(3, hidden_dim)
+ self.pose_edge_index = pose_edge_index()
+ self.pose_emb = GCNConv(3, hidden_dim)
+ self.tf = ContinuousTransformerWrapper(
+ dim_in=hidden_dim*len(mods),
+ dim_out=hidden_dim,
+ max_seq_len=747,
+ attn_layers=Decoder(
+ dim=512,
+ depth=6,
+ heads=8
+ )
+ )
+ self.dropout = nn.Dropout(dropout)
+ self.act = nn.GELU()
+
+ def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
+ feats = []
+ if 'rgb' in self.mods:
+ feats.append(rgb_feats)
+ if 'ocr' in self.mods:
+ feats.append(ocr_feats)
+ if 'pose' in self.mods:
+ bs, seq_len = pose.size(0), pose.size(1)
+ self.pose_edge_index = self.pose_edge_index.to(pose.device)
+ pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
+ pose_emb = self.dropout(self.act(pose_emb))
+ pose_emb = torch.mean(pose_emb, dim=1)
+ hd = pose_emb.size(-1)
+ feats.append(pose_emb.view(bs, seq_len, hd))
+ if 'gaze' in self.mods:
+ gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
+ feats.append(gaze_feats)
+ if 'bbox' in self.mods:
+ feats.append(bbox_feats)
+ tf_inp = torch.cat(feats, 2)
+ tf_out = self.tf(self.dropout(tf_inp))
+ return tf_out, feats
+
+
+class MindNetLSTMXL(nn.Module):
+ """
+ Basic MindNet for model-based ToM, just LSTM on input concatenation
+ """
+ def __init__(self, hidden_dim, dropout, mods):
+ super(MindNetLSTMXL, self).__init__()
+ self.mods = mods
+ self.gaze_emb = nn.Sequential(
+ nn.Linear(3, hidden_dim),
+ nn.GELU(),
+ nn.Linear(hidden_dim, hidden_dim)
+ )
+ self.pose_edge_index = pose_edge_index()
+ self.pose_emb = Sequential('x, edge_index', [
+ (GCNConv(3, hidden_dim), 'x, edge_index -> x'),
+ nn.GELU(),
+ (GCNConv(hidden_dim, hidden_dim), 'x, edge_index -> x'),
+ ])
+ self.LSTM = PreNorm(
+ hidden_dim*len(mods),
+ nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
+ )
+ self.proj = nn.Linear(hidden_dim*2, hidden_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.act = nn.GELU()
+
+ def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
+ feats = []
+ if 'rgb' in self.mods:
+ feats.append(rgb_feats)
+ if 'ocr' in self.mods:
+ feats.append(ocr_feats)
+ if 'pose' in self.mods:
+ bs, seq_len = pose.size(0), pose.size(1)
+ self.pose_edge_index = self.pose_edge_index.to(pose.device)
+ pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
+ pose_emb = self.dropout(self.act(pose_emb))
+ pose_emb = torch.mean(pose_emb, dim=1)
+ hd = pose_emb.size(-1)
+ feats.append(pose_emb.view(bs, seq_len, hd))
+ if 'gaze' in self.mods:
+ gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
+ feats.append(gaze_feats)
+ if 'bbox' in self.mods:
+ feats.append(bbox_feats)
+ lstm_inp = torch.cat(feats, 2)
+ lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
+ c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
+ return self.act(self.proj(lstm_out)), c_n, feats
\ No newline at end of file
diff --git a/boss/models/resnet.py b/boss/models/resnet.py
new file mode 100644
index 0000000..c3f810c
--- /dev/null
+++ b/boss/models/resnet.py
@@ -0,0 +1,249 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from .utils import left_bias, right_bias
+
+
+class ResNet(nn.Module):
+ def __init__(self, input_dim, device):
+ super(ResNet, self).__init__()
+ # Conv
+ resnet = models.resnet34(pretrained=True)
+ self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
+ # FFs
+ self.left = nn.Linear(input_dim, 27)
+ self.right = nn.Linear(input_dim, 27)
+ # modality FFs
+ self.pose_ff = nn.Linear(150, 150)
+ self.gaze_ff = nn.Linear(6, 6)
+ self.bbox_ff = nn.Linear(108, 108)
+ self.ocr_ff = nn.Linear(729, 64)
+ # others
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(p=0.2)
+ self.device = device
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor):
+ batch_size, sequence_len, channels, height, width = images.shape
+ left_beliefs = []
+ right_beliefs = []
+ image_feats = []
+
+ for i in range(sequence_len):
+ images_i = images[:,i].to(self.device)
+ image_i_feat = self.resnet(images_i)
+ image_i_feat = image_i_feat.view(batch_size, 512)
+ if poses is not None:
+ poses_i = poses[:,i].float()
+ poses_i_feat = self.relu(self.pose_ff(poses_i))
+ image_i_feat = torch.cat([image_i_feat, poses_i_feat], 1)
+ if gazes is not None:
+ gazes_i = gazes[:,i].float()
+ gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
+ image_i_feat = torch.cat([image_i_feat, gazes_i_feat], 1)
+ if bboxes is not None:
+ bboxes_i = bboxes[:,i].float()
+ bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
+ image_i_feat = torch.cat([image_i_feat, bboxes_i_feat], 1)
+ if ocr_tensor is not None:
+ ocr_tensor = ocr_tensor
+ ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
+ image_i_feat = torch.cat([image_i_feat, ocr_tensor_feat], 1)
+ image_feats.append(image_i_feat)
+
+ image_feats = torch.permute(torch.stack(image_feats), (1,0,2))
+ left_beliefs = self.left(self.dropout(image_feats))
+ right_beliefs = self.right(self.dropout(image_feats))
+
+ return left_beliefs, right_beliefs, None
+
+
+class ResNetGRU(nn.Module):
+ def __init__(self, input_dim, device):
+ super(ResNetGRU, self).__init__()
+ resnet = models.resnet34(pretrained=True)
+ self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
+ self.gru = nn.GRU(input_dim, 512, batch_first=True)
+ for name, param in self.gru.named_parameters():
+ if "weight" in name:
+ nn.init.orthogonal_(param)
+ elif "bias" in name:
+ nn.init.constant_(param, 0)
+ # FFs
+ self.left = nn.Linear(512, 27)
+ self.right = nn.Linear(512, 27)
+ # modality FFs
+ self.pose_ff = nn.Linear(150, 150)
+ self.gaze_ff = nn.Linear(6, 6)
+ self.bbox_ff = nn.Linear(108, 108)
+ self.ocr_ff = nn.Linear(729, 64)
+ # others
+ self.dropout = nn.Dropout(p=0.2)
+ self.relu = nn.ReLU()
+ self.device = device
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor):
+ batch_size, sequence_len, channels, height, width = images.shape
+ left_beliefs = []
+ right_beliefs = []
+ rnn_inp = []
+
+ for i in range(sequence_len):
+ images_i = images[:,i]
+ rnn_i_feat = self.resnet(images_i)
+ rnn_i_feat = rnn_i_feat.view(batch_size, 512)
+ if poses is not None:
+ poses_i = poses[:,i].float()
+ poses_i_feat = self.relu(self.pose_ff(poses_i))
+ rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
+ if gazes is not None:
+ gazes_i = gazes[:,i].float()
+ gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
+ rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
+ if bboxes is not None:
+ bboxes_i = bboxes[:,i].float()
+ bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
+ rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
+ if ocr_tensor is not None:
+ ocr_tensor = ocr_tensor
+ ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
+ rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
+ rnn_inp.append(rnn_i_feat)
+
+ rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
+ rnn_out, _ = self.gru(rnn_inp)
+ left_beliefs = self.left(self.dropout(rnn_out))
+ right_beliefs = self.right(self.dropout(rnn_out))
+
+ return left_beliefs, right_beliefs, None
+
+
+class ResNetConv1D(nn.Module):
+ def __init__(self, input_dim, device):
+ super(ResNetConv1D, self).__init__()
+ resnet = models.resnet34(pretrained=True)
+ self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
+ self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=5, padding=4)
+ # FFs
+ self.left = nn.Linear(512, 27)
+ self.right = nn.Linear(512, 27)
+ # modality FFs
+ self.pose_ff = nn.Linear(150, 150)
+ self.gaze_ff = nn.Linear(6, 6)
+ self.bbox_ff = nn.Linear(108, 108)
+ self.ocr_ff = nn.Linear(729, 64)
+ # others
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(p=0.2)
+ self.device = device
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor):
+ batch_size, sequence_len, channels, height, width = images.shape
+ left_beliefs = []
+ right_beliefs = []
+ conv1d_inp = []
+
+ for i in range(sequence_len):
+ images_i = images[:,i]
+ images_i_feat = self.resnet(images_i)
+ images_i_feat = images_i_feat.view(batch_size, 512)
+ if poses is not None:
+ poses_i = poses[:,i].float()
+ poses_i_feat = self.relu(self.pose_ff(poses_i))
+ images_i_feat = torch.cat([images_i_feat, poses_i_feat], 1)
+ if gazes is not None:
+ gazes_i = gazes[:,i].float()
+ gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
+ images_i_feat = torch.cat([images_i_feat, gazes_i_feat], 1)
+ if bboxes is not None:
+ bboxes_i = bboxes[:,i].float()
+ bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
+ images_i_feat = torch.cat([images_i_feat, bboxes_i_feat], 1)
+ if ocr_tensor is not None:
+ ocr_tensor = ocr_tensor
+ ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
+ images_i_feat = torch.cat([images_i_feat, ocr_tensor_feat], 1)
+ conv1d_inp.append(images_i_feat)
+
+ conv1d_inp = torch.permute(torch.stack(conv1d_inp), (1,2,0))
+ conv1d_out = self.conv1d(conv1d_inp)
+ conv1d_out = conv1d_out[:,:,:-4]
+ conv1d_out = self.relu(torch.permute(conv1d_out, (0,2,1)))
+ left_beliefs = self.left(self.dropout(conv1d_out))
+ right_beliefs = self.right(self.dropout(conv1d_out))
+
+ return left_beliefs, right_beliefs, None
+
+
+class ResNetLSTM(nn.Module):
+ def __init__(self, input_dim, device):
+ super(ResNetLSTM, self).__init__()
+ resnet = models.resnet34(pretrained=True)
+ self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
+ self.lstm = nn.LSTM(input_size=input_dim, hidden_size=512, batch_first=True)
+ # FFs
+ self.left = nn.Linear(512, 27)
+ self.right = nn.Linear(512, 27)
+ self.left.bias.data = torch.tensor(left_bias).log()
+ self.right.bias.data = torch.tensor(right_bias).log()
+ # modality FFs
+ self.pose_ff = nn.Linear(150, 150)
+ self.gaze_ff = nn.Linear(6, 6)
+ self.bbox_ff = nn.Linear(108, 108)
+ self.ocr_ff = nn.Linear(729, 64)
+ # others
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(p=0.2)
+ self.device = device
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor):
+ batch_size, sequence_len, channels, height, width = images.shape
+ left_beliefs = []
+ right_beliefs = []
+ rnn_inp = []
+
+ for i in range(sequence_len):
+ images_i = images[:,i]
+ rnn_i_feat = self.resnet(images_i)
+ rnn_i_feat = rnn_i_feat.view(batch_size, 512)
+ if poses is not None:
+ poses_i = poses[:,i].float()
+ poses_i_feat = self.relu(self.pose_ff(poses_i))
+ rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
+ if gazes is not None:
+ gazes_i = gazes[:,i].float()
+ gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
+ rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
+ if bboxes is not None:
+ bboxes_i = bboxes[:,i].float()
+ bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
+ rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
+ if ocr_tensor is not None:
+ ocr_tensor = ocr_tensor
+ ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
+ rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
+ rnn_inp.append(rnn_i_feat)
+
+ rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
+ rnn_out, _ = self.lstm(rnn_inp)
+ left_beliefs = self.left(self.dropout(rnn_out))
+ right_beliefs = self.right(self.dropout(rnn_out))
+
+ return left_beliefs, right_beliefs, None
+
+
+
+
+
+
+
+
+
+
+if __name__ == '__main__':
+ images = torch.ones(3, 22, 3, 128, 128)
+ poses = torch.ones(3, 22, 150)
+ gazes = torch.ones(3, 22, 6)
+ bboxes = torch.ones(3, 22, 108)
+ model = ResNet(32, 'cpu')
+ print(model(images, poses, gazes, bboxes, None)[0].shape)
diff --git a/boss/models/single_mindnet.py b/boss/models/single_mindnet.py
new file mode 100644
index 0000000..c9c6aa8
--- /dev/null
+++ b/boss/models/single_mindnet.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+from torch_geometric.nn.conv import GCNConv
+from .utils import left_bias, right_bias, build_ocr_graph
+from .base import CNN, MindNetLSTM
+import torchvision.models as models
+
+
+class SingleMindNet(nn.Module):
+ """
+ Base ToM net. Supports any subset of modalities
+ """
+ def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
+ super(SingleMindNet, self).__init__()
+
+ # ---- Images ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+
+ # ---- OCR and bbox -----#
+ self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
+ self.ocr_gnn = GCNConv(-1, hidden_dim)
+ self.bbox_ff = nn.Linear(108, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net = MindNetLSTM(hidden_dim, dropout, mods)
+
+ self.left = nn.Linear(hidden_dim, 27)
+ self.right = nn.Linear(hidden_dim, 27)
+ self.left.bias.data = torch.tensor(left_bias).log()
+ self.right.bias.data = torch.tensor(right_bias).log()
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
+
+ assert images is not None
+ assert poses is not None
+ assert gazes is not None
+ assert bboxes is not None
+
+ batch_size, sequence_len, channels, height, width = images.shape
+
+ bbox_feat = self.act(self.bbox_ff(bboxes))
+
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = images[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+
+ ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
+ self.ocr_edge_index,
+ self.ocr_edge_attr)))
+ ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
+
+ out_left, cell_left, feats_left = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
+ out_right, cell_right, feats_right = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
+
+ return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
+
+
+
+
+if __name__ == "__main__":
+
+ def count_parameters(model):
+ import numpy as np
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+ return sum([np.prod(p.size()) for p in model_parameters])
+
+ images = torch.ones(3, 22, 3, 128, 128)
+ poses = torch.ones(3, 22, 2, 75)
+ gazes = torch.ones(3, 22, 2, 3)
+ bboxes = torch.ones(3, 22, 108)
+ model = SingleMindNet(64, 'cpu', False, 0.5)
+ print(count_parameters(model))
+ breakpoint()
+ out = model(images, poses, gazes, bboxes, None)
+ print(out[0].shape)
\ No newline at end of file
diff --git a/boss/models/tom_base.py b/boss/models/tom_base.py
new file mode 100644
index 0000000..ae45d91
--- /dev/null
+++ b/boss/models/tom_base.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from torch_geometric.nn.conv import GCNConv
+from .utils import left_bias, right_bias, build_ocr_graph
+from .base import CNN, MindNetLSTM
+import numpy as np
+
+
+class BaseToMnet(nn.Module):
+ """
+ Base ToM net. Supports any subset of modalities
+ """
+ def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
+ super(BaseToMnet, self).__init__()
+
+ # ---- Images ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+
+ # ---- OCR and bbox -----#
+ self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
+ self.ocr_gnn = GCNConv(-1, hidden_dim)
+ self.bbox_ff = nn.Linear(108, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
+ self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
+
+ self.left = nn.Linear(hidden_dim, 27)
+ self.right = nn.Linear(hidden_dim, 27)
+ self.left.bias.data = torch.tensor(left_bias).log()
+ self.right.bias.data = torch.tensor(right_bias).log()
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
+
+ assert images is not None
+ assert poses is not None
+ assert gazes is not None
+ assert bboxes is not None
+
+ batch_size, sequence_len, channels, height, width = images.shape
+
+ bbox_feat = self.act(self.bbox_ff(bboxes))
+
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = images[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+
+ ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
+ self.ocr_edge_index,
+ self.ocr_edge_attr)))
+ ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
+
+ out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
+ out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
+
+ return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
+
+
+
+
+def count_parameters(model):
+ #return sum(p.numel() for p in model.parameters() if p.requires_grad)
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+ return sum([np.prod(p.size()) for p in model_parameters])
+
+
+
+
+
+
+
+
+if __name__ == "__main__":
+
+ images = torch.ones(3, 22, 3, 128, 128)
+ poses = torch.ones(3, 22, 2, 75)
+ gazes = torch.ones(3, 22, 2, 3)
+ bboxes = torch.ones(3, 22, 108)
+ model = BaseToMnet(64, 'cpu', False, 0.5)
+ print(count_parameters(model))
+ breakpoint()
+ out = model(images, poses, gazes, bboxes, None)
+ print(out[0].shape)
\ No newline at end of file
diff --git a/boss/models/tom_common_mind.py b/boss/models/tom_common_mind.py
new file mode 100644
index 0000000..1e1c158
--- /dev/null
+++ b/boss/models/tom_common_mind.py
@@ -0,0 +1,269 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from torch_geometric.nn.conv import GCNConv
+from .utils import left_bias, right_bias, build_ocr_graph
+from .base import CNN, MindNetLSTM, MindNetLSTMXL
+from memory_efficient_attention_pytorch import Attention
+
+
+class CommonMindToMnet(nn.Module):
+ """
+
+ """
+ def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
+ super(CommonMindToMnet, self).__init__()
+
+ self.aggr = aggr
+
+ # ---- Images ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ #for param in self.cnn.parameters():
+ # param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+
+ # ---- OCR and bbox -----#
+ self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
+ self.ocr_gnn = GCNConv(-1, hidden_dim)
+ self.bbox_ff = nn.Linear(108, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
+ self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
+ self.proj = nn.Linear(hidden_dim*2, hidden_dim)
+ self.ln_left = nn.LayerNorm(hidden_dim)
+ self.ln_right = nn.LayerNorm(hidden_dim)
+ if aggr == 'attn':
+ self.attn_left = Attention(
+ dim = hidden_dim,
+ dim_head = hidden_dim // 4,
+ heads = 4,
+ memory_efficient = True,
+ q_bucket_size = hidden_dim,
+ k_bucket_size = hidden_dim)
+ self.attn_right = Attention(
+ dim = hidden_dim,
+ dim_head = hidden_dim // 4,
+ heads = 4,
+ memory_efficient = True,
+ q_bucket_size = hidden_dim,
+ k_bucket_size = hidden_dim)
+ self.left = nn.Linear(hidden_dim, 27)
+ self.right = nn.Linear(hidden_dim, 27)
+ self.left.bias.data = torch.tensor(left_bias).log()
+ self.right.bias.data = torch.tensor(right_bias).log()
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
+
+ assert images is not None
+ assert poses is not None
+ assert gazes is not None
+ assert bboxes is not None
+
+ batch_size, sequence_len, channels, height, width = images.shape
+
+ bbox_feat = self.act(self.bbox_ff(bboxes))
+
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = images[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+
+ ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
+ self.ocr_edge_index,
+ self.ocr_edge_attr)))
+ ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
+
+ out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
+ out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
+
+ common_mind = self.proj(torch.cat([cell_left, cell_right], -1))
+
+ if self.aggr == 'no_tom':
+ return self.left(out_left), self.right(out_right)
+
+ if self.aggr == 'attn':
+ l = self.attn_left(x=out_left, context=common_mind)
+ r = self.attn_right(x=out_right, context=common_mind)
+ elif self.aggr == 'mult':
+ l = out_left * common_mind
+ r = out_right * common_mind
+ elif self.aggr == 'sum':
+ l = out_left + common_mind
+ r = out_right + common_mind
+ elif self.aggr == 'concat':
+ l = torch.cat([out_left, common_mind], 1)
+ r = torch.cat([out_right, common_mind], 1)
+ else: raise ValueError
+ l = self.act(l)
+ l = self.ln_left(l)
+ r = self.act(r)
+ r = self.ln_right(r)
+ if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
+ left_beliefs = self.left(l)
+ right_beliefs = self.right(r)
+ if self.aggr == 'concat':
+ left_beliefs = self.left(l)[:, :-1, :]
+ right_beliefs = self.right(r)[:, :-1, :]
+
+ return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
+
+
+
+class CommonMindToMnetXL(nn.Module):
+ """
+ XL model.
+ """
+ def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
+ super(CommonMindToMnetXL, self).__init__()
+
+ self.aggr = aggr
+
+ # ---- Images ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+
+ # ---- OCR and bbox -----#
+ self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
+ self.ocr_gnn = GCNConv(-1, hidden_dim)
+ self.bbox_ff = nn.Linear(108, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net_left = MindNetLSTMXL(hidden_dim, dropout, mods)
+ self.mind_net_right = MindNetLSTMXL(hidden_dim, dropout, mods)
+ self.proj = nn.Linear(hidden_dim*2, hidden_dim)
+ self.ln_left = nn.LayerNorm(hidden_dim)
+ self.ln_right = nn.LayerNorm(hidden_dim)
+ if aggr == 'attn':
+ self.attn_left = Attention(
+ dim = hidden_dim,
+ dim_head = hidden_dim // 4,
+ heads = 4,
+ memory_efficient = True,
+ q_bucket_size = hidden_dim,
+ k_bucket_size = hidden_dim)
+ self.attn_right = Attention(
+ dim = hidden_dim,
+ dim_head = hidden_dim // 4,
+ heads = 4,
+ memory_efficient = True,
+ q_bucket_size = hidden_dim,
+ k_bucket_size = hidden_dim)
+ self.left = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.GELU(),
+ nn.Linear(hidden_dim, 27),
+ )
+ self.right = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.GELU(),
+ nn.Linear(hidden_dim, 27)
+ )
+ self.left[-1].bias.data = torch.tensor(left_bias).log()
+ self.right[-1].bias.data = torch.tensor(right_bias).log()
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
+
+ assert images is not None
+ assert poses is not None
+ assert gazes is not None
+ assert bboxes is not None
+
+ batch_size, sequence_len, channels, height, width = images.shape
+
+ bbox_feat = self.act(self.bbox_ff(bboxes))
+
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = images[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+
+ ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
+ self.ocr_edge_index,
+ self.ocr_edge_attr)))
+ ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
+
+ out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
+ out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
+
+ common_mind = self.proj(torch.cat([cell_left, cell_right], -1))
+
+ if self.aggr == 'no_tom':
+ return self.left(out_left), self.right(out_right)
+
+ if self.aggr == 'attn':
+ l = self.attn_left(x=out_left, context=common_mind)
+ r = self.attn_right(x=out_right, context=common_mind)
+ elif self.aggr == 'mult':
+ l = out_left * common_mind
+ r = out_right * common_mind
+ elif self.aggr == 'sum':
+ l = out_left + common_mind
+ r = out_right + common_mind
+ elif self.aggr == 'concat':
+ l = torch.cat([out_left, common_mind], 1)
+ r = torch.cat([out_right, common_mind], 1)
+ else: raise ValueError
+ l = self.act(l)
+ l = self.ln_left(l)
+ r = self.act(r)
+ r = self.ln_right(r)
+ if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
+ left_beliefs = self.left(l)
+ right_beliefs = self.right(r)
+ if self.aggr == 'concat':
+ left_beliefs = self.left(l)[:, :-1, :]
+ right_beliefs = self.right(r)[:, :-1, :]
+
+ return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
+
+
+
+
+
+
+
+
+if __name__ == "__main__":
+
+ images = torch.ones(3, 22, 3, 128, 128)
+ poses = torch.ones(3, 22, 2, 75)
+ gazes = torch.ones(3, 22, 2, 3)
+ bboxes = torch.ones(3, 22, 108)
+ model = CommonMindToMnetXL(64, 'cpu', False, 0.5, aggr='attn')
+ out = model(images, poses, gazes, bboxes, None)
+ print(out[0].shape)
\ No newline at end of file
diff --git a/boss/models/tom_implicit.py b/boss/models/tom_implicit.py
new file mode 100644
index 0000000..6050606
--- /dev/null
+++ b/boss/models/tom_implicit.py
@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from torch_geometric.nn.conv import GCNConv
+from .utils import left_bias, right_bias, build_ocr_graph
+from .base import CNN, MindNetLSTM
+from memory_efficient_attention_pytorch import Attention
+
+
+class ImplicitToMnet(nn.Module):
+ """
+ Implicit ToM net. Supports any subset of modalities
+ Possible aggregations: sum, mult, attn, concat
+ """
+ def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
+ super(ImplicitToMnet, self).__init__()
+
+ self.aggr = aggr
+
+ # ---- Images ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+
+ # ---- OCR and bbox -----#
+ self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
+ self.ocr_gnn = GCNConv(-1, hidden_dim)
+ self.bbox_ff = nn.Linear(108, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
+ self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
+ self.ln_left = nn.LayerNorm(hidden_dim)
+ self.ln_right = nn.LayerNorm(hidden_dim)
+ if aggr == 'attn':
+ self.attn_left = Attention(
+ dim = hidden_dim,
+ dim_head = hidden_dim // 4,
+ heads = 4,
+ memory_efficient = True,
+ q_bucket_size = hidden_dim,
+ k_bucket_size = hidden_dim)
+ self.attn_right = Attention(
+ dim = hidden_dim,
+ dim_head = hidden_dim // 4,
+ heads = 4,
+ memory_efficient = True,
+ q_bucket_size = hidden_dim,
+ k_bucket_size = hidden_dim)
+ self.left = nn.Linear(hidden_dim, 27)
+ self.right = nn.Linear(hidden_dim, 27)
+ self.left.bias.data = torch.tensor(left_bias).log()
+ self.right.bias.data = torch.tensor(right_bias).log()
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
+
+ assert images is not None
+ assert poses is not None
+ assert gazes is not None
+ assert bboxes is not None
+
+ batch_size, sequence_len, channels, height, width = images.shape
+
+ bbox_feat = self.act(self.bbox_ff(bboxes))
+
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = images[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+
+ ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
+ self.ocr_edge_index,
+ self.ocr_edge_attr)))
+ ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
+
+ out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
+ out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
+
+ if self.aggr == 'no_tom':
+ return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
+
+ if self.aggr == 'attn':
+ l = self.attn_left(x=out_left, context=cell_right)
+ r = self.attn_right(x=out_right, context=cell_left)
+ elif self.aggr == 'mult':
+ l = out_left * cell_right
+ r = out_right * cell_left
+ elif self.aggr == 'sum':
+ l = out_left + cell_right
+ r = out_right + cell_left
+ elif self.aggr == 'concat':
+ l = torch.cat([out_left, cell_right], 1)
+ r = torch.cat([out_right, cell_left], 1)
+ else: raise ValueError
+ l = self.act(l)
+ l = self.ln_left(l)
+ r = self.act(r)
+ r = self.ln_right(r)
+ if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
+ left_beliefs = self.left(l)
+ right_beliefs = self.right(r)
+ if self.aggr == 'concat':
+ left_beliefs = self.left(l)[:, :-1, :]
+ right_beliefs = self.right(r)[:, :-1, :]
+
+ return left_beliefs, right_beliefs, [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
+
+
+
+
+
+
+
+
+
+
+
+
+if __name__ == "__main__":
+
+ images = torch.ones(3, 22, 3, 128, 128)
+ poses = torch.ones(3, 22, 2, 75)
+ gazes = torch.ones(3, 22, 2, 3)
+ bboxes = torch.ones(3, 22, 108)
+ model = ImplicitToMnet(64, 'cpu', False, 0.5, aggr='attn')
+ out = model(images, poses, gazes, bboxes, None)
+ print(out[0].shape)
diff --git a/boss/models/tom_sl.py b/boss/models/tom_sl.py
new file mode 100644
index 0000000..34dab4c
--- /dev/null
+++ b/boss/models/tom_sl.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from torch_geometric.nn.conv import GCNConv
+from .utils import left_bias, right_bias, build_ocr_graph, pose_edge_index
+from .base import CNN, MindNetSL
+
+
+class SLToMnet(nn.Module):
+ """
+ Speaker-Listener ToMnet
+ """
+ def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
+ super(SLToMnet, self).__init__()
+
+ self.tom_weight = tom_weight
+
+ # ---- Images ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+
+ # ---- OCR and bbox -----#
+ self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
+ self.ocr_gnn = GCNConv(-1, hidden_dim)
+ self.bbox_ff = nn.Linear(108, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net_left = MindNetSL(hidden_dim, dropout, mods)
+ self.mind_net_right = MindNetSL(hidden_dim, dropout, mods)
+ self.left = nn.Linear(hidden_dim, 27)
+ self.right = nn.Linear(hidden_dim, 27)
+ self.left.bias.data = torch.tensor(left_bias).log()
+ self.right.bias.data = torch.tensor(right_bias).log()
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
+
+ batch_size, sequence_len, channels, height, width = images.shape
+
+ bbox_feat = self.act(self.bbox_ff(bboxes))
+
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = images[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+
+ ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
+ self.ocr_edge_index,
+ self.ocr_edge_attr)))
+ ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
+
+ out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
+ left_logits = self.left(out_left)
+ out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
+ right_logits = self.left(out_right)
+
+ left_ranking = torch.log_softmax(left_logits, dim=-1)
+ right_ranking = torch.log_softmax(right_logits, dim=-1)
+
+ right_beliefs = left_ranking + self.tom_weight * right_ranking
+ left_beliefs = right_ranking + self.tom_weight * left_ranking
+
+ return left_beliefs, right_beliefs, [out_left, out_right] + feats_left + feats_right
+
+
+
+
+
+
+
+
+
+if __name__ == "__main__":
+
+ images = torch.ones(3, 22, 3, 128, 128)
+ poses = torch.ones(3, 22, 2, 75)
+ gazes = torch.ones(3, 22, 2, 3)
+ bboxes = torch.ones(3, 22, 108)
+ model = SLToMnet(64, 'cpu', 2.0, False, 0.5)
+ out = model(images, poses, gazes, bboxes, None)
+ print(out[0].shape)
\ No newline at end of file
diff --git a/boss/models/tom_tf.py b/boss/models/tom_tf.py
new file mode 100644
index 0000000..264a33d
--- /dev/null
+++ b/boss/models/tom_tf.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from torch_geometric.nn.conv import GCNConv
+from .utils import left_bias, right_bias, build_ocr_graph
+from .base import CNN, MindNetTF
+from memory_efficient_attention_pytorch import Attention
+
+
+class TFToMnet(nn.Module):
+ """
+ Implicit ToM net. Supports any subset of modalities
+ Possible aggregations: sum, mult, attn, concat
+ """
+ def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
+ super(TFToMnet, self).__init__()
+
+ # ---- Images ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+
+ # ---- OCR and bbox -----#
+ self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
+ self.ocr_gnn = GCNConv(-1, hidden_dim)
+ self.bbox_ff = nn.Linear(108, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net_left = MindNetTF(hidden_dim, dropout, mods)
+ self.mind_net_right = MindNetTF(hidden_dim, dropout, mods)
+ self.left = nn.Linear(hidden_dim, 27)
+ self.right = nn.Linear(hidden_dim, 27)
+ self.left.bias.data = torch.tensor(left_bias).log()
+ self.right.bias.data = torch.tensor(right_bias).log()
+
+ def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
+
+ assert images is not None
+ assert poses is not None
+ assert gazes is not None
+ assert bboxes is not None
+
+ batch_size, sequence_len, channels, height, width = images.shape
+
+ bbox_feat = self.act(self.bbox_ff(bboxes))
+
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = images[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+
+ ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
+ self.ocr_edge_index,
+ self.ocr_edge_attr)))
+ ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
+
+ out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
+ out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
+
+ l = self.dropout(self.act(out_left))
+ r = self.dropout(self.act(out_right))
+ left_beliefs = self.left(l)
+ right_beliefs = self.right(r)
+
+ return left_beliefs, right_beliefs, feats_left + feats_right
+
+
+
+
+
+
+
+
+
+
+
+
+if __name__ == "__main__":
+
+ images = torch.ones(3, 22, 3, 128, 128)
+ poses = torch.ones(3, 22, 2, 75)
+ gazes = torch.ones(3, 22, 2, 3)
+ bboxes = torch.ones(3, 22, 108)
+ model = TFToMnet(64, 'cpu', False, 0.5)
+ out = model(images, poses, gazes, bboxes, None)
+ print(out[0].shape)
+
diff --git a/boss/models/utils.py b/boss/models/utils.py
new file mode 100644
index 0000000..ba8969e
--- /dev/null
+++ b/boss/models/utils.py
@@ -0,0 +1,95 @@
+import torch
+
+left_bias = [0.10659290976303822,
+ 0.025158905348262015,
+ 0.02811095449589107,
+ 0.026342384511050237,
+ 0.025318475572458178,
+ 0.02283183957873461,
+ 0.021581872822531316,
+ 0.08062285577511237,
+ 0.03824366373234754,
+ 0.04853594319300018,
+ 0.09653998563867983,
+ 0.02961357410707162,
+ 0.02961357410707162,
+ 0.03172787957767081,
+ 0.029985904630196004,
+ 0.02897529321028696,
+ 0.06602218026116327,
+ 0.015345336560197867,
+ 0.026900880295736816,
+ 0.024879657455918726,
+ 0.028669450280577644,
+ 0.01936118720246802,
+ 0.02341693040078721,
+ 0.014707055663413206,
+ 0.027007260445200926,
+ 0.04146166325363687,
+ 0.04243238211749688]
+
+right_bias = [0.13147256721895695,
+ 0.012433179968617855,
+ 0.01623627031195979,
+ 0.013683146724821148,
+ 0.015252253929416771,
+ 0.012579452674131008,
+ 0.03127576394244834,
+ 0.10325523257360177,
+ 0.041155820323927554,
+ 0.06563655221935587,
+ 0.12684503071726816,
+ 0.016156485199861705,
+ 0.0176989973670913,
+ 0.020238823435546928,
+ 0.01918831945958884,
+ 0.01791175766601952,
+ 0.08768383819579266,
+ 0.019002154198026647,
+ 0.029600276588388607,
+ 0.01578415467673732,
+ 0.0176989973670913,
+ 0.011834791627882237,
+ 0.014919815962341426,
+ 0.007552990611951809,
+ 0.029759846812584773,
+ 0.04981250498656951,
+ 0.05533097524002021]
+
+def build_ocr_graph(device):
+ ocr_graph = [
+ [15, [10, 4], [17, 2]],
+ [13, [16, 7], [18, 4]],
+ [11, [16, 4], [7, 10]],
+ [14, [10, 11], [7, 1]],
+ [12, [10, 9], [16, 3]],
+ [1, [7, 2], [9, 9], [10, 2]],
+ [5, [8, 8], [6, 8]],
+ [4, [9, 8], [7, 6]],
+ [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
+ [2, [10, 1], [7, 7], [9, 3]],
+ [19, [10, 2], [26, 6]],
+ [20, [10, 7], [26, 5]],
+ [22, [25, 4], [10, 8]],
+ [23, [25, 15]],
+ [21, [16, 5], [24, 8]]
+ ]
+ edge_index = []
+ edge_attr = []
+ for i in range(len(ocr_graph)):
+ for j in range(1, len(ocr_graph[i])):
+ source_node = ocr_graph[i][0]
+ target_node = ocr_graph[i][j][0]
+ edge_index.append([source_node, target_node])
+ edge_attr.append(ocr_graph[i][j][1])
+ ocr_edge_index = torch.tensor(edge_index).t().long()
+ ocr_edge_attr = torch.tensor(edge_attr).to(torch.float).unsqueeze(1)
+ x = torch.arange(0, 27)
+ ocr_x = torch.nn.functional.one_hot(x, num_classes=27).to(torch.float)
+ return ocr_x.to(device), ocr_edge_index.to(device), ocr_edge_attr.to(device)
+
+def pose_edge_index():
+ return torch.tensor(
+ [[17, 15, 15, 0, 0, 16, 16, 18, 0, 1, 4, 3, 3, 2, 2, 1, 1, 5, 5, 6, 6, 7, 1, 8, 8, 9, 9, 10, 10, 11, 11, 24, 11, 23, 23, 22, 8, 12, 12, 13, 13, 14, 14, 21, 14, 19, 19, 20],
+ [15, 17, 0, 15, 16, 0, 18, 16, 1, 0, 3, 4, 2, 3, 1, 2, 5, 1, 6, 5, 7, 6, 8, 1, 9, 8, 10, 9, 11, 10, 24, 11, 23, 11, 22, 23, 12, 8, 13, 12, 14, 13, 21, 14, 19, 14, 20, 19]],
+ dtype=torch.long)
diff --git a/boss/new_bbox/test_bbox.tar.gz b/boss/new_bbox/test_bbox.tar.gz
new file mode 100644
index 0000000..61b78e0
Binary files /dev/null and b/boss/new_bbox/test_bbox.tar.gz differ
diff --git a/boss/new_bbox/train_bbox.tar.gz b/boss/new_bbox/train_bbox.tar.gz
new file mode 100644
index 0000000..8fba733
Binary files /dev/null and b/boss/new_bbox/train_bbox.tar.gz differ
diff --git a/boss/new_bbox/val_bbox.tar.gz b/boss/new_bbox/val_bbox.tar.gz
new file mode 100644
index 0000000..28dbe56
Binary files /dev/null and b/boss/new_bbox/val_bbox.tar.gz differ
diff --git a/boss/outfile b/boss/outfile
new file mode 100644
index 0000000..5342b84
Binary files /dev/null and b/boss/outfile differ
diff --git a/boss/plots/old_vs_new_bbox.py b/boss/plots/old_vs_new_bbox.py
new file mode 100644
index 0000000..2ef1864
--- /dev/null
+++ b/boss/plots/old_vs_new_bbox.py
@@ -0,0 +1,82 @@
+import json
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+sns.set_theme(style='whitegrid')
+
+COLORS = sns.color_palette()
+
+MTOM_COLORS = {
+ "MN1": (110/255, 117/255, 161/255),
+ "MN2": (179/255, 106/255, 98/255),
+ "Base": (193/255, 198/255, 208/255),
+ "CG": (170/255, 129/255, 42/255),
+ "IC": (97/255, 112/255, 83/255),
+ "DB": (144/255, 63/255, 110/255)
+}
+
+abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json'))
+
+abl_old_bbox_mean = [0.539406718, 0.5348262324, 0.529845863]
+abl_old_bbox_std = [0.03639921819, 0.01519544901, 0.01718265794]
+
+filename = 'results/abl_cm_concat_old_vs_new_bbox'
+
+#def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
+
+means = []
+stds = []
+for key, values in abl_tom_cm_concat.items():
+ mean = np.mean(values)
+ std = np.std(values)
+ means.append(mean)
+ stds.append(std)
+fig, ax = plt.subplots(figsize=(10,4))
+x = np.arange(len(abl_tom_cm_concat))
+width = 0.6
+rects1 = ax.bar(
+ x, means, width, label='New bbox', yerr=stds,
+ capsize=5,
+ color=[MTOM_COLORS['CG']]*8,
+ edgecolor='black',
+ linewidth=1.5,
+ alpha=0.6
+)
+rects2 = ax.bar(
+ [0, 2, 5], abl_old_bbox_mean, width, label='Original bbox', yerr=abl_old_bbox_std,
+ capsize=5,
+ color=[COLORS[9]]*8,
+ edgecolor='black',
+ linewidth=1.5,
+ alpha=0.6
+)
+ax.set_ylabel('Accuracy', fontsize=18)
+xticklabels = list(abl_tom_cm_concat.keys())
+ax.set_xticks(np.arange(len(xticklabels)))
+ax.set_xticklabels(xticklabels, rotation=0, fontsize=16)
+ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
+# Add value labels above each bar, avoiding overlapping with error bars
+for rect, std in zip(rects1, stds):
+ height = rect.get_height()
+ if height + std < ax.get_ylim()[1]:
+ ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
+ xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
+ else:
+ ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
+ xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
+for rect, std in zip(rects2, abl_old_bbox_std):
+ height = rect.get_height()
+ if height + std < ax.get_ylim()[1]:
+ ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
+ xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
+ else:
+ ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
+ xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
+# Remove spines
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+ax.legend(fontsize=14)
+ax.grid(axis='x')
+plt.tight_layout()
+plt.savefig(f'{filename}.pdf', bbox_inches='tight')
\ No newline at end of file
diff --git a/boss/plots/pca.py b/boss/plots/pca.py
new file mode 100644
index 0000000..07bc8dd
--- /dev/null
+++ b/boss/plots/pca.py
@@ -0,0 +1,82 @@
+import torch
+import os
+import seaborn as sns
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.decomposition import PCA
+
+
+FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-22_12-00-38_train_None" # no_tom seed 1
+
+#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-39-41_train_None" # impl mult seed 1
+#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-47-07_train_None" # impl sum seed 1
+#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_15-58-44_train_None" # impl attn seed 1
+#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-58-04_train_None" # impl concat seed 1
+
+#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-40-01_train_None" # cm mult seed 1
+#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-45-55_train_None" # cm sum seed 1
+#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-50-42_train_None" # cm attn seed 1
+#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-57-15_train_None" # cm concat seed 1
+
+#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-37-50_train_None" # db seed 1
+
+print(FOLDER_PATH)
+
+MTOM_COLORS = {
+ "MN1": (110/255, 117/255, 161/255),
+ "MN2": (179/255, 106/255, 98/255),
+ "Base": (193/255, 198/255, 208/255),
+ "CG": (170/255, 129/255, 42/255),
+ "IC": (97/255, 112/255, 83/255),
+ "DB": (144/255, 63/255, 110/255)
+}
+
+sns.set_theme(style='white')
+
+for i in range(60):
+
+ print(f'Computing analysis for test video {i}...', end='\r')
+
+ emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
+ data = torch.load(emb_file)
+ if len(data) == 14: # implicit
+ model = 'impl'
+ out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
+ out_left = out_left.squeeze(0)
+ cell_left = cell_left.squeeze(0)
+ out_right = out_right.squeeze(0)
+ cell_right = cell_right.squeeze(0)
+ elif len(data) == 13: # common mind
+ model = 'cm'
+ out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
+ out_left = out_left.squeeze(0)
+ out_right = out_right.squeeze(0)
+ common_mind = common_mind.squeeze(0)
+ elif len(data) == 12: # speaker-listener
+ model = 'sl'
+ out_left, out_right, feats = data[0], data[1], data[2:]
+ out_left = out_left.squeeze(0)
+ out_right = out_right.squeeze(0)
+ else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
+
+ # ====== PCA for left and right embeddings ====== #
+
+ out_left_and_right = np.concatenate((out_left, out_right), axis=0)
+
+ pca = PCA(n_components=2)
+ pca_result = pca.fit_transform(out_left_and_right)
+
+ # Separate the PCA results for each tensor
+ pca_result_left = pca_result[:out_left.shape[0]]
+ pca_result_right = pca_result[out_right.shape[0]:]
+
+ plt.figure(figsize=(7,6))
+ plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
+ plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
+ plt.xlabel('Principal Component 1', fontsize=30)
+ plt.ylabel('Principal Component 2', fontsize=30)
+ plt.grid(False)
+ plt.legend(fontsize=30)
+ plt.tight_layout()
+ plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
+ plt.close()
\ No newline at end of file
diff --git a/boss/results/abl_cm_concat.json b/boss/results/abl_cm_concat.json
new file mode 100644
index 0000000..abbf2e7
--- /dev/null
+++ b/boss/results/abl_cm_concat.json
@@ -0,0 +1,42 @@
+{
+ "all": [
+ 0.7402937327,
+ 0.7171004799,
+ 0.7305511124
+ ],
+ "rgb\npose\ngaze": [
+ 0.4736803839,
+ 0.4658281227,
+ 0.4948378653
+ ],
+ "rgb\nocr\nbbox": [
+ 0.7364403083,
+ 0.7407663225,
+ 0.7321506471
+ ],
+ "rgb\ngaze": [
+ 0.5013814163,
+ 0.418859968,
+ 0.4644103534
+ ],
+ "rgb\npose": [
+ 0.4545586738,
+ 0.4584484514,
+ 0.4525229024
+ ],
+ "rgb\nbbox": [
+ 0.716591537,
+ 0.7334593573,
+ 0.758324851
+ ],
+ "rgb\nocr": [
+ 0.4581939799,
+ 0.456449033,
+ 0.4577577432
+ ],
+ "rgb": [
+ 0.4896393776,
+ 0.4907299695,
+ 0.4583757452
+ ]
+}
\ No newline at end of file
diff --git a/boss/results/abl_cm_concat.pdf b/boss/results/abl_cm_concat.pdf
new file mode 100644
index 0000000..62c9ef3
Binary files /dev/null and b/boss/results/abl_cm_concat.pdf differ
diff --git a/boss/results/all.json b/boss/results/all.json
new file mode 100644
index 0000000..e003477
--- /dev/null
+++ b/boss/results/all.json
@@ -0,0 +1,72 @@
+{
+ "no_tom": [
+ 0.6504653192,
+ 0.6404682274,
+ 0.6666787844
+ ],
+ "sl": [
+ 0.6584993456,
+ 0.6608259415,
+ 0.6584629926
+ ],
+ "cm mult": [
+ 0.6739130435,
+ 0.6872182638,
+ 0.6591173477
+ ],
+ "cm sum": [
+ 0.5820852116,
+ 0.6401774029,
+ 0.6001163298
+ ],
+ "cm attn": [
+ 0.3240511851,
+ 0.2688672386,
+ 0.3028573506
+ ],
+ "cm concat": [
+ 0.7402937327,
+ 0.7171004799,
+ 0.7305511124
+ ],
+ "impl mult": [
+ 0.7119746983,
+ 0.7019776065,
+ 0.7244437982
+ ],
+ "impl sum": [
+ 0.6542460375,
+ 0.6642431293,
+ 0.6678420823
+ ],
+ "impl attn": [
+ 0.3311763851,
+ 0.3125636179,
+ 0.3112549077
+ ],
+ "impl concat": [
+ 0.6957975862,
+ 0.7227352043,
+ 0.6971062964
+ ],
+ "resnet": [
+ 0.467173186,
+ 0.4267485822,
+ 0.4195870292
+ ],
+ "gru": [
+ 0.5724152974,
+ 0.525628908,
+ 0.4825141777
+ ],
+ "lstm": [
+ 0.6210193398,
+ 0.5547477098,
+ 0.6572633416
+ ],
+ "conv1d": [
+ 0.4478333576,
+ 0.4114802966,
+ 0.3853060928
+ ]
+}
\ No newline at end of file
diff --git a/boss/results/all.pdf b/boss/results/all.pdf
new file mode 100644
index 0000000..4037bee
Binary files /dev/null and b/boss/results/all.pdf differ
diff --git a/boss/test.py b/boss/test.py
new file mode 100644
index 0000000..ec2f599
--- /dev/null
+++ b/boss/test.py
@@ -0,0 +1,181 @@
+import argparse
+import torch
+from tqdm import tqdm
+import csv
+import os
+from torch.utils.data import DataLoader
+
+from dataloader import DataTest
+from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
+from models.tom_implicit import ImplicitToMnet
+from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
+from models.tom_sl import SLToMnet
+from models.single_mindnet import SingleMindNet
+from utils import tried_once, tried_twice, tried_thrice, friends, strangers, get_classification_accuracy, pad_collate, get_input_dim
+
+
+def test(args):
+
+ if args.test_frames == 'friends':
+ test_frame_ids = friends
+ elif args.test_frames == 'strangers':
+ test_frame_ids = strangers
+ elif args.test_frames == 'once':
+ test_frame_ids = tried_once
+ elif args.test_frames == 'twice_thrice':
+ test_frame_ids = tried_twice + tried_thrice
+ elif args.test_frames is None:
+ test_frame_ids = None
+ else: raise NameError
+
+ if args.median is not None:
+ median = (240, False)
+ else:
+ median = None
+
+ if args.model_type == 'tom_cm' or args.model_type == 'tom_sl' or args.model_type == 'tom_impl' or args.model_type == 'tom_cm_xl' or args.model_type == 'tom_single':
+ flatten_dim = 2
+ else:
+ flatten_dim = 1
+
+ # load datasets
+ test_dataset = DataTest(
+ args.test_frame_path,
+ args.label_path,
+ args.test_pose_path,
+ args.test_gaze_path,
+ args.test_bbox_path,
+ args.ocr_graph_path,
+ args.presaved,
+ test_frame_ids,
+ median,
+ flatten_dim=flatten_dim
+ )
+ test_dataloader = DataLoader(
+ test_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=args.num_workers,
+ collate_fn=pad_collate,
+ pin_memory=args.pin_memory
+ )
+ device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
+ assert args.load_model_path is not None
+ if args.model_type == 'resnet':
+ inp_dim = get_input_dim(args.mods)
+ model = ResNet(inp_dim, device).to(device)
+ elif args.model_type == 'gru':
+ inp_dim = get_input_dim(args.mods)
+ model = ResNetGRU(inp_dim, device).to(device)
+ elif args.model_type == 'lstm':
+ inp_dim = get_input_dim(args.mods)
+ model = ResNetLSTM(inp_dim, device).to(device)
+ elif args.model_type == 'conv1d':
+ inp_dim = get_input_dim(args.mods)
+ model = ResNetConv1D(inp_dim, device).to(device)
+ elif args.model_type == 'tom_cm':
+ model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
+ elif args.model_type == 'tom_impl':
+ model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
+ elif args.model_type == 'tom_sl':
+ model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
+ elif args.model_type == 'tom_single':
+ model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
+ elif args.model_type == 'tom_cm_xl':
+ model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
+ else: raise NotImplementedError
+ model.load_state_dict(torch.load(args.load_model_path, map_location=device))
+ model.device = device
+
+ model.eval()
+
+ if args.save_preds:
+ # Define the output file path
+ folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}_{args.test_frames}'
+ if not os.path.exists(folder_path):
+ os.makedirs(folder_path)
+ print(f'Saving predictions in {folder_path}.')
+
+ print('Testing...')
+ num_correct = 0
+ cnt = 0
+ with torch.no_grad():
+ for j, batch in tqdm(enumerate(test_dataloader)):
+ frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
+ if frames is not None: frames = frames.to(device, non_blocking=True)
+ if poses is not None: poses = poses.to(device, non_blocking=True)
+ if gazes is not None: gazes = gazes.to(device, non_blocking=True)
+ if bboxes is not None: bboxes = bboxes.to(device, non_blocking=True)
+ if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=True)
+ pred_left_labels, pred_right_labels, repr = model(frames, poses, gazes, bboxes, ocr_graphs)
+ pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
+ pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
+ labels = torch.reshape(labels, (-1, 2)).to(device)
+ batch_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
+ pred_left_labels, pred_right_labels, labels, sequence_lengths
+ )
+ cnt += batch_num_pred
+ num_correct += batch_num_correct
+
+ if args.save_preds:
+ torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt"))
+ data = [(
+ i,
+ torch.argmax(pred_left_labels[i]).cpu().numpy(),
+ torch.argmax(pred_right_labels[i]).cpu().numpy(),
+ labels[:, 0][i].cpu().numpy(),
+ labels[:, 1][i].cpu().numpy()) for i in range(len(labels))
+ ]
+ header = ['frame', 'left_pred', 'right_pred', 'left_label', 'right_label']
+ with open(os.path.join(folder_path, f'{j}_{batch_acc:.2f}.csv'), mode='w', newline='') as file:
+ writer = csv.writer(file)
+ writer.writerow(header) # Write the header row
+ writer.writerows(data) # Write the data rows
+
+ test_acc = num_correct / cnt
+ print("Test accuracy: {}".format(num_correct / cnt))
+
+ with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f:
+ f.write(str(test_acc))
+ f.close()
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+
+ # Define the command-line arguments
+ parser.add_argument('--gpu_id', type=int)
+ parser.add_argument('--presaved', type=int, default=128)
+ parser.add_argument('--non_blocking', action='store_true')
+ parser.add_argument('--num_workers', type=int, default=4)
+ parser.add_argument('--pin_memory', action='store_true')
+ parser.add_argument('--model_type', type=str)
+ parser.add_argument('--aggr', type=str, default='mult', required=False)
+ parser.add_argument('--use_resnet', action='store_true')
+ parser.add_argument('--hidden_dim', type=int, default=64)
+ parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
+ parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
+ parser.add_argument('--test_frame_path', type=str, default='/scratch/bortoletto/data/boss/test/frame')
+ parser.add_argument('--test_pose_path', type=str, default='/scratch/bortoletto/data/boss/test/pose')
+ parser.add_argument('--test_gaze_path', type=str, default='/scratch/bortoletto/data/boss/test/gaze')
+ parser.add_argument('--test_bbox_path', type=str, default='/scratch/bortoletto/data/boss/test/new_bbox/labels')
+ parser.add_argument('--ocr_graph_path', type=str, default='')
+ parser.add_argument('--label_path', type=str, default='outfile')
+ parser.add_argument('--save_path', type=str, default='experiments/')
+ parser.add_argument('--test_frames', type=str, default=None)
+ parser.add_argument('--median', type=int, default=None)
+ parser.add_argument('--load_model_path', type=str)
+ parser.add_argument('--dropout', type=float, default=0.0)
+ parser.add_argument('--save_preds', action='store_true')
+
+ # Parse the command-line arguments
+ args = parser.parse_args()
+
+ if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
+ if not args.aggr:
+ parser.error("The choosen --model_type requires --aggr")
+ if args.model_type == 'tom_sl' and not args.tom_weight:
+ parser.error("The choosen --model_type requires --tom_weight")
+
+ test(args)
\ No newline at end of file
diff --git a/boss/train.py b/boss/train.py
new file mode 100644
index 0000000..710406c
--- /dev/null
+++ b/boss/train.py
@@ -0,0 +1,324 @@
+import torch
+import os
+import argparse
+import numpy as np
+import random
+import datetime
+import wandb
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+import torch.nn as nn
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+from dataloader import Data
+from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
+from models.tom_implicit import ImplicitToMnet
+from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
+from models.tom_sl import SLToMnet
+from models.tom_tf import TFToMnet
+from models.single_mindnet import SingleMindNet
+from utils import pad_collate, get_classification_accuracy, mixup, get_classification_accuracy_mixup, count_parameters, get_input_dim
+
+
+def train(args):
+
+ if args.model_type == 'tom_cm' or args.model_type == 'tom_sl' or args.model_type == 'tom_impl' or args.model_type == 'tom_tf' or args.model_type == 'tom_cm_xl' or args.model_type == 'tom_single':
+ flatten_dim = 2
+ else:
+ flatten_dim = 1
+
+ train_dataset = Data(
+ args.train_frame_path,
+ args.label_path,
+ args.train_pose_path,
+ args.train_gaze_path,
+ args.train_bbox_path,
+ args.ocr_graph_path,
+ presaved=args.presaved,
+ flatten_dim=flatten_dim
+ )
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ num_workers=args.num_workers,
+ collate_fn=pad_collate,
+ pin_memory=args.pin_memory
+ )
+ val_dataset = Data(
+ args.val_frame_path,
+ args.label_path,
+ args.val_pose_path,
+ args.val_gaze_path,
+ args.val_bbox_path,
+ args.ocr_graph_path,
+ presaved=args.presaved,
+ flatten_dim=flatten_dim
+ )
+ val_dataloader = DataLoader(
+ val_dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=args.num_workers,
+ collate_fn=pad_collate,
+ pin_memory=args.pin_memory
+ )
+ device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
+ if args.model_type == 'resnet':
+ inp_dim = get_input_dim(args.mods)
+ model = ResNet(inp_dim, device).to(device)
+ elif args.model_type == 'gru':
+ inp_dim = get_input_dim(args.mods)
+ model = ResNetGRU(inp_dim, device).to(device)
+ elif args.model_type == 'lstm':
+ inp_dim = get_input_dim(args.mods)
+ model = ResNetLSTM(inp_dim, device).to(device)
+ elif args.model_type == 'conv1d':
+ inp_dim = get_input_dim(args.mods)
+ model = ResNetConv1D(inp_dim, device).to(device)
+ elif args.model_type == 'tom_cm':
+ model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
+ elif args.model_type == 'tom_cm_xl':
+ model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
+ elif args.model_type == 'tom_impl':
+ model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
+ elif args.model_type == 'tom_sl':
+ model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
+ elif args.model_type == 'tom_single':
+ model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
+ elif args.model_type == 'tom_tf':
+ model = TFToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
+ else: raise NotImplementedError
+ if args.resume_from_checkpoint is not None:
+ model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+ if args.scheduler == None:
+ scheduler = None
+ else:
+ scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
+ if args.model_type == 'tom_sl': cross_entropy_loss = nn.NLLLoss(ignore_index=-1)
+ else: cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing, ignore_index=-1).to(device)
+ stats = {'train': {'cls_loss': [], 'cls_acc': []}, 'val': {'cls_loss': [], 'cls_acc': []}}
+
+ max_val_classification_acc = 0
+ max_val_classification_epoch = None
+ counter = 0
+
+ print(f'Number of parameters: {count_parameters(model)}')
+
+ for i in range(args.num_epoch):
+ # training
+ print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
+ temp_train_classification_loss = []
+ epoch_num_correct = 0
+ epoch_cnt = 0
+ model.train()
+ for j, batch in tqdm(enumerate(train_dataloader)):
+ if args.use_mixup:
+ frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = mixup(batch, args.mixup_alpha, 27)
+ else:
+ frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
+ if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
+ if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
+ if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
+ if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
+ if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)
+ pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
+ pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
+ pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
+ if args.use_mixup:
+ labels = torch.reshape(labels, (-1, 54)).to(device)
+ batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy_mixup(
+ pred_left_labels, pred_right_labels, labels, sequence_lengths
+ )
+ loss = cross_entropy_loss(pred_left_labels, labels[:, :27]) + cross_entropy_loss(pred_right_labels, labels[:, 27:])
+ else:
+ labels = torch.reshape(labels, (-1, 2)).to(device)
+ batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
+ pred_left_labels, pred_right_labels, labels, sequence_lengths
+ )
+ loss = cross_entropy_loss(pred_left_labels, labels[:, 0]) + cross_entropy_loss(pred_right_labels, labels[:, 1])
+ epoch_cnt += batch_num_pred
+ epoch_num_correct += batch_num_correct
+ temp_train_classification_loss.append(loss.data.item() * batch_num_pred / 2)
+
+ optimizer.zero_grad()
+ if args.clip_grad_norm is not None:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+ loss.backward()
+ optimizer.step()
+
+ if args.logger: wandb.log({'batch_train_acc': batch_train_acc, 'batch_train_loss': loss.data.item(), 'lr': optimizer.param_groups[-1]['lr']})
+ print("Epoch {}/{} batch {}/{} training done with cls loss={}, cls accuracy={}.".format(
+ i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item(), batch_train_acc)
+ )
+
+ if scheduler: scheduler.step()
+
+ print("Epoch {}/{} OVERALL train cls loss={}, cls accuracy={}.\n".format(
+ i+1, args.num_epoch, sum(temp_train_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
+ )
+ stats['train']['cls_loss'].append(sum(temp_train_classification_loss) * 2 / epoch_cnt)
+ stats['train']['cls_acc'].append(epoch_num_correct / epoch_cnt)
+
+ # validation
+ print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
+ temp_val_classification_loss = []
+ epoch_num_correct = 0
+ epoch_cnt = 0
+ model.eval()
+ with torch.no_grad():
+ for j, batch in tqdm(enumerate(val_dataloader)):
+ frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
+ if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
+ if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
+ if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
+ if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
+ if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)
+ pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
+ pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
+ pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
+ labels = torch.reshape(labels, (-1, 2)).to(device)
+ batch_val_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
+ pred_left_labels, pred_right_labels, labels, sequence_lengths
+ )
+ epoch_cnt += batch_num_pred
+ epoch_num_correct += batch_num_correct
+ loss = cross_entropy_loss(pred_left_labels, labels[:,0]) + cross_entropy_loss(pred_right_labels, labels[:,1])
+ temp_val_classification_loss.append(loss.data.item() * batch_num_pred / 2)
+
+ if args.logger: wandb.log({'batch_val_acc': batch_val_acc, 'batch_val_loss': loss.data.item()})
+ print("Epoch {}/{} batch {}/{} validation done with cls loss={}, cls accuracy={}.".format(
+ i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item(), batch_val_acc)
+ )
+
+ print("Epoch {}/{} OVERALL validation cls loss={}, cls accuracy={}.\n".format(
+ i+1, args.num_epoch, sum(temp_val_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
+ )
+
+ cls_loss = sum(temp_val_classification_loss) * 2 / epoch_cnt
+ cls_acc = epoch_num_correct / epoch_cnt
+ stats['val']['cls_loss'].append(cls_loss)
+ stats['val']['cls_acc'].append(cls_acc)
+ if args.logger: wandb.log({'cls_loss': cls_loss, 'cls_acc': cls_acc, 'epoch': i})
+
+ # check for best stat/model using validation accuracy
+ if stats['val']['cls_acc'][-1] >= max_val_classification_acc:
+ max_val_classification_acc = stats['val']['cls_acc'][-1]
+ max_val_classification_epoch = i+1
+ torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
+ counter = 0
+ else:
+ counter += 1
+ print(f'EarlyStopping counter: {counter} out of {args.patience}.')
+ if counter >= args.patience:
+ break
+
+ with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
+ f.write('{}\n'.format(CFG))
+ f.write('{}\n'.format(stats))
+ f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_classification_acc))
+ f.close()
+
+ print(f'Results saved in {experiment_save_path}')
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+
+ # Define the command-line arguments
+ parser.add_argument('--gpu_id', type=int)
+ parser.add_argument('--seed', type=int, default=1)
+ parser.add_argument('--logger', action='store_true')
+ parser.add_argument('--presaved', type=int, default=128)
+ parser.add_argument('--clip_grad_norm', type=float, default=0.5)
+ parser.add_argument('--use_mixup', action='store_true')
+ parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
+ parser.add_argument('--non_blocking', action='store_true')
+ parser.add_argument('--patience', type=int, default=99)
+ parser.add_argument('--batch_size', type=int, default=4)
+ parser.add_argument('--num_workers', type=int, default=4)
+ parser.add_argument('--pin_memory', action='store_true')
+ parser.add_argument('--num_epoch', type=int, default=300)
+ parser.add_argument('--lr', type=float, default=4e-4)
+ parser.add_argument('--scheduler', type=str, default=None)
+ parser.add_argument('--dropout', type=float, default=0.1)
+ parser.add_argument('--weight_decay', type=float, default=0.005)
+ parser.add_argument('--label_smoothing', type=float, default=0.1)
+ parser.add_argument('--model_type', type=str)
+ parser.add_argument('--aggr', type=str, default='mult', required=False)
+ parser.add_argument('--use_resnet', action='store_true')
+ parser.add_argument('--hidden_dim', type=int, default=64)
+ parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
+ parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
+ parser.add_argument('--train_frame_path', type=str, default='/scratch/bortoletto/data/boss/train/frame')
+ parser.add_argument('--train_pose_path', type=str, default='/scratch/bortoletto/data/boss/train/pose')
+ parser.add_argument('--train_gaze_path', type=str, default='/scratch/bortoletto/data/boss/train/gaze')
+ parser.add_argument('--train_bbox_path', type=str, default='/scratch/bortoletto/data/boss/train/new_bbox/labels')
+ parser.add_argument('--val_frame_path', type=str, default='/scratch/bortoletto/data/boss/val/frame')
+ parser.add_argument('--val_pose_path', type=str, default='/scratch/bortoletto/data/boss/val/pose')
+ parser.add_argument('--val_gaze_path', type=str, default='/scratch/bortoletto/data/boss/val/gaze')
+ parser.add_argument('--val_bbox_path', type=str, default='/scratch/bortoletto/data/boss/val/new_bbox/labels')
+ parser.add_argument('--ocr_graph_path', type=str, default='')
+ parser.add_argument('--label_path', type=str, default='outfile')
+ parser.add_argument('--save_path', type=str, default='experiments/')
+ parser.add_argument('--resume_from_checkpoint', type=str, default=None)
+
+ # Parse the command-line arguments
+ args = parser.parse_args()
+
+ if args.use_mixup and not args.mixup_alpha:
+ parser.error("--use_mixup requires --mixup_alpha")
+ if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
+ if not args.aggr:
+ parser.error("The choosen --model_type requires --aggr")
+ if args.model_type == 'tom_sl' and not args.tom_weight:
+ parser.error("The choosen --model_type requires --tom_weight")
+
+ # get experiment ID
+ experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
+ if not os.path.exists(args.save_path):
+ os.makedirs(args.save_path, exist_ok=True)
+ experiment_save_path = os.path.join(args.save_path, experiment_id)
+ os.makedirs(experiment_save_path, exist_ok=True)
+
+ CFG = {
+ 'use_ocr_custom_loss': 0,
+ 'presaved': args.presaved,
+ 'batch_size': args.batch_size,
+ 'num_epoch': args.num_epoch,
+ 'lr': args.lr,
+ 'scheduler': args.scheduler,
+ 'weight_decay': args.weight_decay,
+ 'model_type': args.model_type,
+ 'use_resnet': args.use_resnet,
+ 'hidden_dim': args.hidden_dim,
+ 'tom_weight': args.tom_weight,
+ 'dropout': args.dropout,
+ 'label_smoothing': args.label_smoothing,
+ 'clip_grad_norm': args.clip_grad_norm,
+ 'use_mixup': args.use_mixup,
+ 'mixup_alpha': args.mixup_alpha,
+ 'non_blocking_tensors': args.non_blocking,
+ 'patience': args.patience,
+ 'pin_memory': args.pin_memory,
+ 'resume_from_checkpoint': args.resume_from_checkpoint,
+ 'aggr': args.aggr,
+ 'mods': args.mods,
+ 'save_path': experiment_save_path ,
+ 'seed': args.seed
+ }
+
+ print(CFG)
+ print(f'Saving results in {experiment_save_path}')
+
+ # set seed values
+ if args.logger:
+ wandb.init(project="boss", config=CFG)
+ os.environ['PYTHONHASHSEED'] = str(args.seed)
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ random.seed(args.seed)
+
+ train(args)
diff --git a/boss/utils.py b/boss/utils.py
new file mode 100644
index 0000000..f89e81d
--- /dev/null
+++ b/boss/utils.py
@@ -0,0 +1,769 @@
+import os
+import torch
+import numpy as np
+from torch.nn.utils.rnn import pad_sequence
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+from sklearn.preprocessing import StandardScaler
+from sklearn.decomposition import PCA
+from sklearn.manifold import TSNE
+#from umap import UMAP
+import json
+import seaborn as sns
+import pandas as pd
+from natsort import natsorted
+import argparse
+import seaborn as sns
+
+
+
+COLORS_DM = {
+ "red": (242/256, 165/256, 179/256),
+ "blue": (195/256, 219/256, 252/256),
+ "green": (156/256, 228/256, 213/256),
+ "yellow": (250/256, 236/256, 144/256),
+ "violet": (207/256, 187/256, 244/256),
+ "orange": (244/256, 188/256, 154/256)
+}
+
+MTOM_COLORS = {
+ "MN1": (110/255, 117/255, 161/255),
+ "MN2": (179/255, 106/255, 98/255),
+ "Base": (193/255, 198/255, 208/255),
+ "CG": (170/255, 129/255, 42/255),
+ "IC": (97/255, 112/255, 83/255),
+ "DB": (144/255, 63/255, 110/255)
+}
+
+COLORS = sns.color_palette()
+
+OBJECTS = [
+ 'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
+ 'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
+ 'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
+ 'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
+]
+
+tried_once = [2, 4, 5, 6, 7, 8, 9, 10, 12, 15, 18, 19, 20, 21, 22, 23, 24,
+ 25, 26, 27, 28, 29, 30, 32, 38, 40, 41, 42, 44, 47, 50, 51, 52, 53, 55,
+ 56, 57, 61, 63, 65, 67, 68, 70, 71, 72, 73, 74, 76, 80, 81, 83, 85, 87,
+ 88, 89, 90, 92, 93, 96, 97, 99, 101, 102, 105, 106, 108, 110, 111, 112,
+ 113, 114, 116, 118, 121, 123, 125, 131, 132, 134, 135, 140, 142, 143,
+ 145, 146, 148, 149, 151, 152, 154, 155, 156, 157, 160, 161, 162, 165,
+ 169, 170, 171, 173, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185,
+ 186, 187, 190, 191, 194, 196, 203, 204, 206, 207, 208, 209, 210, 211,
+ 213, 214, 216, 218, 219, 220, 222, 225, 227, 228, 229, 232, 233, 235,
+ 236, 237, 238, 239, 242, 243, 246, 247, 249, 251, 252, 254, 255, 256,
+ 257, 260, 261, 262, 263, 265, 266, 268, 270, 272, 275, 277, 278, 279,
+ 281, 282, 287, 290, 296, 298]
+
+tried_twice = [1, 3, 11, 13, 14, 16, 17, 31, 33, 35, 36, 37, 39,
+ 43, 45, 49, 54, 58, 59, 60, 62, 64, 66, 69, 75, 77, 79, 82, 84, 86,
+ 91, 94, 95, 98, 100, 103, 104, 107, 109, 115, 117, 119, 120, 122,
+ 124, 126, 127, 128, 129, 130, 133, 136, 137, 138, 139, 141, 144,
+ 147, 150, 153, 158, 159, 164, 166, 167, 168, 172, 174, 177, 188,
+ 189, 192, 193, 195, 197, 198, 199, 200, 201, 202, 205, 212, 215,
+ 217, 221, 223, 224, 226, 230, 231, 234, 240, 241, 244, 245, 248,
+ 250, 253, 258, 259, 264, 267, 269, 271, 273, 274, 276, 280, 283,
+ 284, 285, 286, 288, 291, 292, 293, 294, 295, 297, 299]
+
+tried_thrice = [34, 46, 48, 78, 163, 289]
+
+friends = [i for i in range(150, 300)]
+strangers = [i for i in range(0, 150)]
+
+
+def get_input_dim(modalities):
+ dimensions = {
+ 'rgb': 512,
+ 'pose': 150,
+ 'gaze': 6,
+ 'ocr': 64,
+ 'bbox': 108
+ }
+ sum_of_dimensions = sum(dimensions[modality] for modality in modalities)
+ return sum_of_dimensions
+
+
+def count_parameters(model):
+ #return sum(p.numel() for p in model.parameters() if p.requires_grad)
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+ return sum([np.prod(p.size()) for p in model_parameters])
+
+
+def onehot(label, n_classes):
+ if len(label.size()) == 3:
+ batch_size, seq_len, _ = label.size()
+ label_1_2d = label[:,:,0].contiguous().view(batch_size*seq_len, 1)
+ label_2_2d = label[:,:,1].contiguous().view(batch_size*seq_len, 1)
+ #onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1)
+ #onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1)
+ #onehot_2d = torch.cat((onehot_1_2d, onehot_2_2d), dim=1)
+ #onehot_3d = onehot_2d.view(batch_size, seq_len, 2*n_classes)
+ onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1).view(-1, seq_len, n_classes)
+ onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1).view(-1, seq_len, n_classes)
+ onehot_3d = torch.cat((onehot_1_2d, onehot_2_2d), dim=2)
+ return onehot_3d
+ else:
+ return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1)
+
+
+def mixup(data, alpha, n_classes):
+ lam = torch.FloatTensor([np.random.beta(alpha, alpha)])
+ frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = data
+ indices = torch.randperm(labels.size(0))
+ # labels
+ labels2 = labels[indices]
+ labels = onehot(labels, n_classes)
+ labels2 = onehot(labels2, n_classes)
+ labels = labels * lam + labels2 * (1 - lam)
+ # frames
+ frames2 = frames[indices]
+ frames = frames * lam + frames2 * (1 - lam)
+ # poses
+ poses2 = poses[indices]
+ poses = poses * lam + poses2 * (1 - lam)
+ # gazes
+ gazes2 = gazes[indices]
+ gazes = gazes * lam + gazes2 * (1 - lam)
+ # bboxes
+ bboxes2 = bboxes[indices]
+ bboxes = bboxes * lam + bboxes2 * (1 - lam)
+ return frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths
+
+
+def get_classification_accuracy(pred_left_labels, pred_right_labels, labels, sequence_lengths):
+ max_len = max(sequence_lengths)
+ pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
+ pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
+ labels = torch.reshape(labels, (-1, max_len, 2))
+ left_correct = torch.argmax(pred_left_labels, 2) == labels[:,:,0]
+ right_correct = torch.argmax(pred_right_labels, 2) == labels[:,:,1]
+ num_pred = sum(sequence_lengths) * 2
+ num_correct = 0
+ for i in range(len(sequence_lengths)):
+ size = sequence_lengths[i]
+ num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
+ acc = num_correct / num_pred
+ return acc, num_correct, num_pred
+
+
+def get_classification_accuracy_mixup(pred_left_labels, pred_right_labels, labels, sequence_lengths):
+ max_len = max(sequence_lengths)
+ pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
+ pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
+ labels = torch.reshape(labels, (-1, max_len, 54))
+ left_correct = torch.argmax(pred_left_labels, 2) == torch.argmax(labels[:,:,:27], 2)
+ right_correct = torch.argmax(pred_right_labels, 2) == torch.argmax(labels[:,:,27:], 2)
+ num_pred = sum(sequence_lengths) * 2
+ num_correct = 0
+ for i in range(len(sequence_lengths)):
+ size = sequence_lengths[i]
+ num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
+ acc = num_correct / num_pred
+ return acc, num_correct, num_pred
+
+
+def pad_collate(batch):
+ (aa, bb, cc, dd, ee, ff) = zip(*batch)
+ seq_lens = [len(a) for a in aa]
+ aa_pad = pad_sequence(aa, batch_first=True, padding_value=0)
+ bb_pad = pad_sequence(bb, batch_first=True, padding_value=-1)
+ if cc[0] is not None:
+ cc_pad = pad_sequence(cc, batch_first=True, padding_value=0)
+ else:
+ cc_pad = None
+ if dd[0] is not None:
+ dd_pad = pad_sequence(dd, batch_first=True, padding_value=0)
+ else:
+ dd_pad = None
+ if ee[0] is not None:
+ ee_pad = pad_sequence(ee, batch_first=True, padding_value=0)
+ else:
+ ee_pad = None
+ if ff[0] is not None:
+ ff_pad = pad_sequence(ff, batch_first=True, padding_value=0)
+ else:
+ ff_pad = None
+ return aa_pad, bb_pad, cc_pad, dd_pad, ee_pad, ff_pad, seq_lens
+
+
+def ocr_loss(pL, lL, pR, lR, ocr, loss, eta=10.0, mode='abs'):
+ """
+ Custom loss based on the negative OCRL matrix.
+
+ Input:
+ pL: tensor of shape [batch_size, num_classes] representing left belief predictions
+ lL: tensor of shape [batch_size] representing the left belief labels
+ pR: tensor of shape [batch_size, num_classes] representing right belief predictions
+ lR: tensor of shape [batch_size] representing the right belief labels
+ ocr: negative OCR matrix (i.e. 1-OCR)
+ loss: loss function
+ eta: hyperparameter for the interaction term
+
+ Output:
+ Final loss resulting from weighting left and right loss with the respective
+ OCR coefficients and summing them.
+ """
+ bs = pL.shape[0]
+ if len([*lR[0].size()]) > 0: # for mixup
+ p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
+ g = torch.tensor([ocr[torch.argmax(lL[i]), torch.argmax(lR[i])] for i in range(bs)], device=pL.device)
+ else:
+ p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
+ g = torch.tensor([ocr[lL[i], lR[i]] for i in range(bs)], device=pL.device)
+ left_loss = torch.mean(loss(pL, lL))
+ right_loss = torch.mean(loss(pR, lR))
+ if mode == 'abs':
+ interaction_loss = torch.mean(torch.abs(g - p))
+ elif mode == 'mse':
+ interaction_loss = torch.mean(torch.pow(g - p, 2))
+ else: raise NotImplementedError
+ eta = (left_loss + right_loss) / interaction_loss
+ interaction_loss = interaction_loss * eta
+ print(f"Left: {left_loss} --- Right: {right_loss} --- Interaction: {interaction_loss}")
+ return left_loss + right_loss + interaction_loss
+
+
+def spherical2cartesial(x):
+ """
+ From https://colab.research.google.com/drive/1SJbzd-gFTbiYjfZynIfrG044fWi6svbV?usp=sharing#scrollTo=78QhNw4MYSYp
+ """
+ output = torch.zeros(x.size(0),3)
+ output[:,2] = -torch.cos(x[:,1])*torch.cos(x[:,0])
+ output[:,0] = torch.cos(x[:,1])*torch.sin(x[:,0])
+ output[:,1] = torch.sin(x[:,1])
+ return output
+
+
+def presave():
+ from PIL import Image
+ from torchvision import transforms
+ import os
+ from natsort import natsorted
+ import pickle as pkl
+
+ preprocess = transforms.Compose([
+ transforms.Resize((128, 128)),
+ #transforms.Resize(256),
+ #transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ frame_path = '/scratch/bortoletto/data/boss/test/frame'
+ frame_dirs = os.listdir(frame_path)
+ frame_paths = []
+ for frame_dir in natsorted(frame_dirs):
+ paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
+ frame_paths.append(paths)
+
+ save_folder = '/scratch/bortoletto/data/boss/presaved128/'+frame_path.split('/')[-2]+'/'+frame_path.split('/')[-1]
+ if not os.path.exists(save_folder):
+ os.makedirs(save_folder)
+ print(save_folder, 'created')
+
+ for video in natsorted(frame_paths):
+ print(video[0].split('/')[-2], end='\r')
+ images = [preprocess(Image.open(i)) for i in video]
+ strings = video[0].split('/')
+ with open(save_folder+'/'+strings[7]+'.pkl', 'wb') as f:
+ pkl.dump(images, f)
+
+
+def compute_cosine_similarity(tensor1, tensor2):
+ return F.cosine_similarity(tensor1, tensor2, dim=-1)
+
+
+def find_most_similar_embedding(rgb, pose, gaze, ocr, bbox, repr):
+ gaze_similarity = compute_cosine_similarity(gaze, repr)
+ pose_similarity = compute_cosine_similarity(pose, repr)
+ ocr_similarity = compute_cosine_similarity(ocr, repr)
+ bbox_similarity = compute_cosine_similarity(bbox, repr)
+ rgb_similarity = compute_cosine_similarity(rgb, repr)
+ similarities = torch.stack([gaze_similarity, pose_similarity, ocr_similarity, bbox_similarity, rgb_similarity])
+ max_index = torch.argmax(similarities, dim=0)
+ main_modality = []
+ main_modality_name = []
+ for idx in max_index:
+ if idx == 0:
+ main_modality.append(gaze)
+ main_modality_name.append('gaze')
+ elif idx == 1:
+ main_modality.append(pose)
+ main_modality_name.append('pose')
+ elif idx == 2:
+ main_modality.append(ocr)
+ main_modality_name.append('ocr')
+ elif idx == 3:
+ main_modality.append(bbox)
+ main_modality_name.append('bbox')
+ else:
+ main_modality.append(rgb)
+ main_modality_name.append('rgb')
+ return main_modality, main_modality_name
+
+
+def plot_similarity_histogram(values, filename, labels):
+ def count_elements(list_of_strings, target_list):
+ counts = []
+ for element in list_of_strings:
+ count = target_list.count(element)
+ counts.append(count)
+ return counts
+ fig, ax = plt.subplots(figsize=(12,4))
+ colors = ["red", "blue", "green", "violet"]
+ # Remove spines
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ colors = [MTOM_COLORS['MN1'], MTOM_COLORS['MN2'], MTOM_COLORS['CG'], MTOM_COLORS['CG']]
+ alphas = [0.6, 0.6, 0.6, 0.6]
+ edgecolors = ['black', 'black', MTOM_COLORS['MN1'], MTOM_COLORS['MN2']]
+ linewidths = [1.0, 1.0, 5.0, 5.0]
+ if isinstance(values[0], list):
+ num_lists = len(values)
+ bar_width = 0.8 / num_lists
+ for i, val in enumerate(values):
+ unique_strings = ['rgb', 'pose', 'gaze', 'bbox', 'ocr']
+ counts = count_elements(unique_strings, val)
+ x = np.arange(len(unique_strings))
+ x_shifted = x + (i - (len(labels) - 1) / 2) * bar_width #x - (bar_width * num_lists / 2) + (bar_width * i)
+ ax.bar(x_shifted, counts, width=bar_width, label=f'{labels[i]}', color=colors[i], edgecolor=edgecolors[i], linewidth=linewidths[i], alpha=alphas[i])
+ ax.set_xlabel('Modality', fontsize=18)
+ ax.set_ylabel('Counts', fontsize=18)
+ ax.set_xticks(np.arange(len(unique_strings)))
+ ax.set_xticklabels(unique_strings, fontsize=18)
+ ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
+ ax.legend(fontsize=18)
+ ax.grid(axis='y')
+ plt.savefig(filename, bbox_inches='tight')
+ else:
+ unique_strings, counts = np.unique(values, return_counts=True)
+ ax.bar(unique_strings, counts)
+ ax.set_xlabel('Modality')
+ ax.set_ylabel('Counts')
+ plt.savefig(filename, bbox_inches='tight')
+
+
+def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
+ means = []
+ stds = []
+ for key, values in data.items():
+ mean = np.mean(values)
+ std = np.std(values)
+ means.append(mean)
+ stds.append(std)
+ fig, ax = plt.subplots(figsize=size)
+ x = np.arange(len(data))
+ width = 0.6
+ rects1 = ax.bar(
+ x, means, width, label='Mean', yerr=stds,
+ capsize=5,
+ color='teal' if colors is None else colors,
+ edgecolor='black',
+ linewidth=1.5,
+ alpha=0.6
+ )
+ ax.set_ylabel('Accuracy', fontsize=18)
+ #ax.set_title('Mean and Standard Deviation of Results', fontsize=14)
+ if filename == 'results/all':
+ xticklabels = [
+ 'Base',
+ 'DB',
+ 'CG$\otimes$', 'CG$\oplus$', 'CG$\odot$', 'CG$\parallel$',
+ 'IC$\otimes$', 'IC$\oplus$', 'IC$\odot$', 'IC$\parallel$',
+ 'CNN', 'CNN+GRU', 'CNN+LSTM', 'CNN+Conv1D'
+ ]
+ else:
+ xticklabels = list(data.keys())
+ ax.set_xticks(np.arange(len(xticklabels)))
+ ax.set_xticklabels(xticklabels, rotation=rotation, fontsize=16) #data.keys(), rotation=rotation, fontsize=16)
+ ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
+ #ax.grid(axis='y')
+ #ax.set_axisbelow(True)
+ #ax.legend(loc='upper right')
+ # Add value labels above each bar, avoiding overlapping with error bars
+ for rect, std in zip(rects1, stds):
+ height = rect.get_height()
+ if height + std < ax.get_ylim()[1]:
+ ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
+ xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
+ else:
+ ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
+ xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
+ # Remove spines
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.grid(axis='x')
+ plt.tight_layout()
+ plt.savefig(f'{filename}.pdf', bbox_inches='tight')
+
+
+def plot_confusion_matrices(left_cm, right_cm, labels, title, annot=True):
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
+ left_xticklabels = [OBJECTS[i] for i in labels[0]]
+ left_yticklabels = [OBJECTS[i] for i in labels[1]]
+ right_xticklabels = [OBJECTS[i] for i in labels[2]]
+ right_yticklabels = [OBJECTS[i] for i in labels[3]]
+ sns.heatmap(
+ left_cm,
+ annot=annot,
+ fmt='.0f',
+ cmap='Blues',
+ cbar=False,
+ xticklabels=left_xticklabels,
+ yticklabels=left_yticklabels,
+ ax=ax1
+ )
+ ax1.set_xlabel('Predicted')
+ ax1.set_ylabel('True')
+ ax1.set_title('Left Confusion Matrix')
+ sns.heatmap(
+ right_cm,
+ annot=annot,
+ fmt='.0f',
+ cmap='Blues',
+ cbar=False, #True if annot is False else False,
+ xticklabels=right_xticklabels,
+ yticklabels=right_yticklabels,
+ ax=ax2
+ )
+ ax2.set_xlabel('Predicted')
+ ax2.set_ylabel('True')
+ ax2.set_title('Right Confusion Matrix')
+ #plt.suptitle(title)
+ plt.tight_layout()
+ plt.savefig(title + '.pdf')
+ plt.close()
+
+def plot_confusion_matrix(confusion_matrix, labels, title, annot=True):
+ plt.figure(figsize=(6, 6))
+ xticklabels = [OBJECTS[i] for i in labels[0]]
+ yticklabels = [OBJECTS[i] for i in labels[1]]
+ sns.heatmap(
+ confusion_matrix,
+ annot=annot,
+ fmt='.0f',
+ cmap='Blues',
+ cbar=False,
+ xticklabels=xticklabels,
+ yticklabels=yticklabels
+ )
+ plt.xlabel('Predicted')
+ plt.ylabel('True')
+ #plt.title(title)
+ plt.tight_layout()
+ plt.savefig(title + '.pdf')
+ plt.close()
+
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--task', type=str, choices=['confusion', 'similarity', 'scores', 'friends_vs_strangers'])
+ parser.add_argument('--folder_path', type=str)
+
+ args = parser.parse_args()
+
+ if args.task == 'similarity':
+ sns.set_theme(style='white')
+ else:
+ sns.set_theme(style='whitegrid')
+
+ # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
+
+ if args.task == 'similarity':
+
+ out_left_main_mods_full_test = []
+ out_right_main_mods_full_test = []
+ cell_left_main_mods_full_test = []
+ cell_right_main_mods_full_test = []
+ cm_left_main_mods_full_test = []
+ cm_right_main_mods_full_test = []
+
+ for i in range(60):
+
+ print(f'Computing analysis for test video {i}...', end='\r')
+
+ emb_file = os.path.join(args.folder_path, f'{i}.pt')
+ data = torch.load(emb_file)
+ if len(data) == 14: # implicit
+ model = 'impl'
+ out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
+ out_left = out_left.squeeze(0)
+ cell_left = cell_left.squeeze(0)
+ out_right = out_right.squeeze(0)
+ cell_right = cell_right.squeeze(0)
+ elif len(data) == 13: # common mind
+ model = 'cm'
+ out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
+ out_left = out_left.squeeze(0)
+ out_right = out_right.squeeze(0)
+ common_mind = common_mind.squeeze(0)
+ elif len(data) == 12: # speaker-listener
+ model = 'sl'
+ out_left, out_right, feats = data[0], data[1], data[2:]
+ out_left = out_left.squeeze(0)
+ out_right = out_right.squeeze(0)
+ else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
+
+ # ====== PCA for left and right embeddings ====== #
+
+ out_left_and_right = np.concatenate((out_left, out_right), axis=0)
+
+ pca = PCA(n_components=2)
+ pca_result = pca.fit_transform(out_left_and_right)
+
+ # Separate the PCA results for each tensor
+ pca_result_left = pca_result[:out_left.shape[0]]
+ pca_result_right = pca_result[out_right.shape[0]:]
+
+ plt.figure(figsize=(7,6))
+ plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='MindNet$_1$', color=MTOM_COLORS['MN1'], s=100)
+ plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='MindNet$_2$', color=MTOM_COLORS['MN2'], s=100)
+ plt.xlabel('Principal Component 1', fontsize=30)
+ plt.ylabel('Principal Component 2', fontsize=30)
+ plt.grid(False)
+ plt.legend(fontsize=30)
+ plt.tight_layout()
+ plt.savefig(f'{args.folder_path}/{i}_pca.pdf')
+ plt.close()
+
+ # ====== Feature similarity ====== #
+
+ if len(feats) == 10:
+ left_rgb, left_ocr, left_pose, left_gaze, left_bbox, right_rgb, right_ocr, right_pose, right_gaze, right_bbox = feats
+ left_rgb = left_rgb.squeeze(0)
+ left_ocr = left_ocr.squeeze(0)
+ left_pose = left_pose.squeeze(0)
+ left_gaze = left_gaze.squeeze(0)
+ left_bbox = left_bbox.squeeze(0)
+ right_rgb = right_rgb.squeeze(0)
+ right_ocr = right_ocr.squeeze(0)
+ right_pose = right_pose.squeeze(0)
+ right_gaze = right_gaze.squeeze(0)
+ right_bbox = right_bbox.squeeze(0)
+ else: raise NotImplementedError("Ablated versions are not supported yet.")
+
+ # out: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
+ # cell: [1, 1, dim] --------- squeeze ---> [1, dim]
+ # cm: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
+ # feat: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
+
+ _, out_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, out_left)
+ _, out_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, out_right)
+ out_left_main_mods_full_test += out_left_main_mods
+ out_right_main_mods_full_test += out_right_main_mods
+ if model == 'impl':
+ _, cell_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, cell_left)
+ _, cell_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, cell_right)
+ cell_left_main_mods_full_test += cell_left_main_mods
+ cell_right_main_mods_full_test += cell_right_main_mods
+ if model == 'cm':
+ _, cm_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, common_mind)
+ _, cm_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, common_mind)
+ cm_left_main_mods_full_test += cm_left_main_mods
+ cm_right_main_mods_full_test += cm_right_main_mods
+
+ if model == 'impl':
+ plot_similarity_histogram(
+ [out_left_main_mods_full_test,
+ out_right_main_mods_full_test,
+ cell_left_main_mods_full_test,
+ cell_right_main_mods_full_test],
+ f'{args.folder_path}/boss_similartiy_impl_hist_all.pdf',
+ [r'$h_1$', r'$h_2$', r'$c_1$', r'$c_2$']
+ )
+ elif model == 'cm':
+ plot_similarity_histogram(
+ [out_left_main_mods_full_test,
+ out_right_main_mods_full_test,
+ cm_left_main_mods_full_test,
+ cm_right_main_mods_full_test],
+ f'{args.folder_path}/boss_similarity_cm_hist_all.pdf',
+ [r'$h_1$', r'$h_2$', r'$cg$ w/ 1', r'$cg$ w/ 2']
+ )
+ elif model == 'sl':
+ plot_similarity_histogram(
+ [out_left_main_mods_full_test,
+ out_right_main_mods_full_test],
+ f'{args.folder_path}/boss_similarity_sl_hist_all.py',
+ [r'$h_1$', r'$h_2$']
+ )
+
+ # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
+
+ elif args.task == 'scores':
+
+
+ all = json.load(open('results/all.json'))
+ abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json'))
+
+ plot_scores_histogram(
+ all,
+ filename='results/all',
+ size=(10,4),
+ rotation=45,
+ colors=[COLORS[7]] + [MTOM_COLORS['DB']] + [MTOM_COLORS['CG']]*4 + [MTOM_COLORS['IC']]*4 + [COLORS[4]]*4
+ )
+
+ plot_scores_histogram(
+ abl_tom_cm_concat,
+ size=(10,4),
+ filename='results/abl_cm_concat',
+ colors=[MTOM_COLORS['CG']]*8
+ )
+
+ # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
+
+ elif args.task == 'confusion':
+
+ # List to store confusion matrices
+ all_df = []
+ left_matrices = []
+ right_matrices = []
+
+ # Iterate over the CSV files
+ for filename in natsorted(os.listdir(args.folder_path)):
+ if filename.endswith('.csv'):
+ file_path = os.path.join(args.folder_path, filename)
+
+ print(f'Processing {file_path}...')
+
+ # Read the CSV file
+ df = pd.read_csv(file_path)
+
+ # Extract the left and right labels and predictions
+ left_labels = df['left_label']
+ right_labels = df['right_label']
+ left_preds = df['left_pred']
+ right_preds = df['right_pred']
+
+ # Calculate the confusion matrices for left and right
+ left_cm = pd.crosstab(left_labels, left_preds)
+ right_cm = pd.crosstab(right_labels, right_preds)
+
+ # Append the confusion matrices to the list
+ left_matrices.append(left_cm)
+ right_matrices.append(right_cm)
+ all_df.append(df)
+
+ # Plot and save the confusion matrices for left and right
+ for i, cm in enumerate(zip(left_matrices, right_matrices)):
+ print(f'Computing confusion matrices for video {i}...', end='\r')
+ plot_confusion_matrices(
+ cm[0],
+ cm[1],
+ labels=[cm[0].columns, cm[0].index, cm[1].columns, cm[1].index],
+ title=f'{args.folder_path}/{i}_cm'
+ )
+
+ merged_df = pd.concat(all_df).reset_index(drop=True)
+
+ merged_left_cm = pd.crosstab(merged_df['left_label'], merged_df['left_pred'])
+ merged_right_cm = pd.crosstab(merged_df['right_label'], merged_df['right_pred'])
+ plot_confusion_matrices(
+ merged_left_cm,
+ merged_right_cm,
+ labels=[merged_left_cm.columns, merged_left_cm.index, merged_right_cm.columns, merged_right_cm.index],
+ title=f'{args.folder_path}/all_lr',
+ annot=False
+ )
+
+ merged_preds = pd.concat([merged_df['left_pred'], merged_df['right_pred']]).reset_index(drop=True)
+ merged_labels = pd.concat([merged_df['left_label'], merged_df['right_label']]).reset_index(drop=True)
+ merged_all_cm = pd.crosstab(merged_labels, merged_preds)
+ plot_confusion_matrix(
+ merged_all_cm,
+ labels=[merged_all_cm.columns, merged_all_cm.index],
+ title=f'{args.folder_path}/all_merged',
+ annot=False
+ )
+
+ # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
+
+ elif args.task == 'friends_vs_strangers':
+
+ # friends ids: 30-59
+ # stranger ids: 0-29
+ out_left_list = []
+ out_right_list = []
+ cell_left_list = []
+ cell_right_list = []
+ common_mind_list = []
+
+ for i in range(60):
+
+ print(f'Computing analysis for test video {i}...', end='\r')
+
+ emb_file = os.path.join(args.folder_path, f'{i}.pt')
+ data = torch.load(emb_file)
+ if len(data) == 14: # implicit
+ model = 'impl'
+ out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
+ out_left = out_left.squeeze(0)
+ cell_left = cell_left.squeeze(0)
+ out_right = out_right.squeeze(0)
+ cell_right = cell_right.squeeze(0)
+ out_left_list.append(out_left)
+ out_right_list.append(out_right)
+ cell_left_list.append(cell_left)
+ cell_right_list.append(cell_right)
+ elif len(data) == 13: # common mind
+ model = 'cm'
+ out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
+ out_left = out_left.squeeze(0)
+ out_right = out_right.squeeze(0)
+ common_mind = common_mind.squeeze(0)
+ out_left_list.append(out_left)
+ out_right_list.append(out_right)
+ common_mind_list.append(common_mind)
+ elif len(data) == 12: # speaker-listener
+ model = 'sl'
+ out_left, out_right, feats = data[0], data[1], data[2:]
+ out_left = out_left.squeeze(0)
+ out_right = out_right.squeeze(0)
+ out_left_list.append(out_left)
+ out_right_list.append(out_right)
+ else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
+
+ # ====== PCA for left and right embeddings ====== #
+
+ print('\rComputing PCA...')
+
+ strangers_nframes = sum([out_left_list[i].shape[0] for i in range(30)])
+
+ left = torch.cat(out_left_list, 0)
+ right = torch.cat(out_right_list, 0)
+
+ out_left_and_right = torch.cat([left, right], axis=0).numpy()
+ #out_left_and_right = StandardScaler().fit_transform(out_left_and_right)
+
+ pca = PCA(n_components=2)
+ pca_result = pca.fit_transform(out_left_and_right)
+ #pca_result = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=3).fit_transform(out_left_and_right)
+ #pca_result = UMAP().fit_transform(out_left_and_right)
+
+ # Separate the PCA results for each tensor
+ #pca_result_left = pca_result[:left.shape[0]]
+ #pca_result_right = pca_result[right.shape[0]:]
+ pca_result_strangers = pca_result[:strangers_nframes]
+ pca_result_friends = pca_result[strangers_nframes:]
+
+ #plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='Left')
+ #plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='Right')
+ plt.scatter(pca_result_friends[:, 0], pca_result_friends[:, 1], label='Friends')
+ plt.scatter(pca_result_strangers[:, 0], pca_result_strangers[:, 1], label='Strangers')
+ plt.xlabel('Principal Component 1')
+ plt.ylabel('Principal Component 2')
+ plt.legend()
+ plt.savefig(f'{args.folder_path}/friends_vs_strangers_pca.pdf')
+ plt.close()
+
+ # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
+
+ else:
+
+ raise NameError
\ No newline at end of file
diff --git a/tbd/.gitignore b/tbd/.gitignore
new file mode 100644
index 0000000..99a5b78
--- /dev/null
+++ b/tbd/.gitignore
@@ -0,0 +1,196 @@
+experiments
+wandb
+predictions
+
+
+# Created by https://www.toptal.com/developers/gitignore/api/python,linux
+# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
+
+### Linux ###
+*~
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+# End of https://www.toptal.com/developers/gitignore/api/python,linux
\ No newline at end of file
diff --git a/tbd/README.md b/tbd/README.md
new file mode 100644
index 0000000..4f2c430
--- /dev/null
+++ b/tbd/README.md
@@ -0,0 +1,16 @@
+# TBD
+
+# Data
+The original code can be found [here](https://github.com/LifengFan/Triadic-Belief-Dynamics). The dataset is not directly available but must be requested using the link to the Google form provided in the [README](https://github.com/LifengFan/Triadic-Belief-Dynamics?tab=readme-ov-file#dataset).
+
+## Installing Dependencies
+Run `conda env create -f environment.yml`.
+
+## Train
+`source run_train.sh`.
+
+## Test
+`source run_test.sh`. **Make sure to use the same random seed used for training**, otherwise the splits will be different and you will likely have a data leakage.
+
+## Visualisations
+The plots are made using `utils/fb_scores_err.py` (false belief analysis) and `utils/similarity.py` (PCA of latent representations).
diff --git a/tbd/environment.yml b/tbd/environment.yml
new file mode 100644
index 0000000..d6921cb
--- /dev/null
+++ b/tbd/environment.yml
@@ -0,0 +1,100 @@
+name: tbd
+channels:
+ - conda-forge
+ - defaults
+ - pytorch
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - ca-certificates=2023.01.10=h06a4308_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - ncurses=6.4=h6a678d5_0
+ - openssl=1.1.1t=h7f8727e_0
+ - pip=23.0.1=py38h06a4308_0
+ - python=3.8.10=h12debd9_8
+ - readline=8.2=h5eee18b_0
+ - setuptools=66.0.0=py38h06a4308_0
+ - sqlite=3.41.2=h5eee18b_0
+ - tk=8.6.12=h1ccaba5_0
+ - wheel=0.38.4=py38h06a4308_0
+ - xz=5.4.2=h5eee18b_0
+ - zlib=1.2.13=h5eee18b_0
+ - pip:
+ - appdirs==1.4.4
+ - beautifulsoup4==4.12.2
+ - certifi==2023.5.7
+ - charset-normalizer==3.1.0
+ - click==8.1.3
+ - cmake==3.26.4
+ - contourpy==1.1.0
+ - cycler==0.11.0
+ - docker-pycreds==0.4.0
+ - einops==0.6.1
+ - filelock==3.12.0
+ - fonttools==4.40.0
+ - gdown==4.7.1
+ - gitdb==4.0.10
+ - gitpython==3.1.31
+ - idna==3.4
+ - importlib-resources==5.12.0
+ - jinja2==3.1.2
+ - joblib==1.3.1
+ - kiwisolver==1.4.4
+ - lit==16.0.6
+ - markupsafe==2.1.3
+ - matplotlib==3.7.1
+ - memory-efficient-attention-pytorch==0.1.2
+ - mpmath==1.3.0
+ - networkx==3.1
+ - numpy==1.24.4
+ - nvidia-cublas-cu11==11.10.3.66
+ - nvidia-cuda-cupti-cu11==11.7.101
+ - nvidia-cuda-nvrtc-cu11==11.7.99
+ - nvidia-cuda-runtime-cu11==11.7.99
+ - nvidia-cudnn-cu11==8.5.0.96
+ - nvidia-cufft-cu11==10.9.0.58
+ - nvidia-curand-cu11==10.2.10.91
+ - nvidia-cusolver-cu11==11.4.0.1
+ - nvidia-cusparse-cu11==11.7.4.91
+ - nvidia-nccl-cu11==2.14.3
+ - nvidia-nvtx-cu11==11.7.91
+ - opencv-python==4.8.0.74
+ - packaging==23.1
+ - pandas==2.0.3
+ - pathtools==0.1.2
+ - pillow==9.5.0
+ - protobuf==4.23.3
+ - psutil==5.9.5
+ - pyparsing==3.1.0
+ - pysocks==1.7.1
+ - python-dateutil==2.8.2
+ - pytz==2023.3
+ - pyyaml==6.0
+ - requests==2.30.0
+ - scikit-learn==1.3.0
+ - scipy==1.10.1
+ - seaborn==0.12.2
+ - sentry-sdk==1.27.0
+ - setproctitle==1.3.2
+ - six==1.16.0
+ - smmap==5.0.0
+ - soupsieve==2.4.1
+ - sympy==1.12
+ - threadpoolctl==3.1.0
+ - torch==2.0.1
+ - torch-geometric==2.3.1
+ - torchaudio==2.0.2
+ - torchsampler==0.1.2
+ - torchvision==0.15.2
+ - tqdm==4.65.0
+ - triton==2.0.0
+ - typing-extensions==4.7.0
+ - tzdata==2023.3
+ - urllib3==2.0.2
+ - wandb==0.15.5
+ - zipp==3.15.0
+prefix: /opt/anaconda3/envs/tbd
\ No newline at end of file
diff --git a/tbd/models/base.py b/tbd/models/base.py
new file mode 100644
index 0000000..5c4ca02
--- /dev/null
+++ b/tbd/models/base.py
@@ -0,0 +1,156 @@
+import torch
+import torch.nn as nn
+from .utils import pose_edge_index
+from torch_geometric.nn import GCNConv
+
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.fn = fn
+ self.norm = nn.LayerNorm(dim)
+ def forward(self, x, **kwargs):
+ x = self.norm(x)
+ return self.fn(x, **kwargs)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, dim),
+ nn.GELU(),
+ nn.Linear(dim, dim))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class CNN(nn.Module):
+ def __init__(self, hidden_dim):
+ super(CNN, self).__init__()
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+ self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = nn.functional.relu(x)
+ x = self.pool(x)
+ x = self.conv2(x)
+ x = nn.functional.relu(x)
+ x = self.pool(x)
+ x = self.conv3(x)
+ x = nn.functional.relu(x)
+ x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling
+ return x
+
+
+class MindNetLSTM(nn.Module):
+ """
+ Basic MindNet for model-based ToM, just LSTM on input concatenation
+ """
+ def __init__(self, hidden_dim, dropout, mods):
+ super(MindNetLSTM, self).__init__()
+ self.mods = mods
+ if 'rgb_1' in mods:
+ self.img_emb = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+ if 'gaze' in mods:
+ self.gaze_emb = nn.Linear(2, hidden_dim)
+ if 'pose' in mods:
+ self.pose_edge_index = pose_edge_index()
+ self.pose_emb = GCNConv(3, hidden_dim)
+ self.LSTM = PreNorm(
+ hidden_dim*len(mods),
+ nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
+ self.proj = nn.Linear(hidden_dim*2, hidden_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.act = nn.GELU()
+
+ def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
+ feats = []
+ if 'rgb_3' in self.mods:
+ feats.append(rgb_3rd_pov_feats)
+ if 'rgb_1' in self.mods:
+ rgb_feat = []
+ for i in range(rgb_1st_pov.shape[1]):
+ images_i = rgb_1st_pov[:,i]
+ img_i_feat = self.img_emb(images_i)
+ img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+ feats.append(rgb_feats)
+ if 'pose' in self.mods:
+ bs, seq_len = pose.size(0), pose.size(1)
+ self.pose_edge_index = self.pose_edge_index.to(pose.device)
+ pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
+ pose_emb = self.dropout(self.act(pose_emb))
+ pose_emb = torch.mean(pose_emb, dim=1)
+ hd = pose_emb.size(-1)
+ feats.append(pose_emb.view(bs, seq_len, hd))
+ if 'gaze' in self.mods:
+ gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
+ feats.append(gaze_feats)
+ if 'bbox' in self.mods:
+ feats.append(bbox_feats.mean(2))
+ lstm_inp = torch.cat(feats, 2)
+ lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
+ c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
+ return self.act(self.proj(lstm_out)), c_n, feats
+
+
+class MindNetSL(nn.Module):
+ """
+ Basic MindNet for SL ToM, just LSTM on input concatenation
+ """
+ def __init__(self, hidden_dim, dropout, mods):
+ super(MindNetSL, self).__init__()
+ self.mods = mods
+ if 'rgb_1' in mods:
+ self.img_emb = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+ if 'gaze' in mods:
+ self.gaze_emb = nn.Linear(2, hidden_dim)
+ if 'pose' in mods:
+ self.pose_edge_index = pose_edge_index()
+ self.pose_emb = GCNConv(3, hidden_dim)
+ self.LSTM = PreNorm(
+ hidden_dim*len(mods),
+ nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
+ self.proj = nn.Linear(hidden_dim*2, hidden_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.act = nn.GELU()
+
+ def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
+ feats = []
+ if 'rgb_3' in self.mods:
+ feats.append(rgb_3rd_pov_feats)
+ if 'rgb_1' in self.mods:
+ rgb_feat = []
+ for i in range(rgb_1st_pov.shape[1]):
+ images_i = rgb_1st_pov[:,i]
+ img_i_feat = self.img_emb(images_i)
+ img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+ feats.append(rgb_feats)
+ if 'pose' in self.mods:
+ bs, seq_len = pose.size(0), pose.size(1)
+ self.pose_edge_index = self.pose_edge_index.to(pose.device)
+ pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
+ pose_emb = self.dropout(self.act(pose_emb))
+ pose_emb = torch.mean(pose_emb, dim=1)
+ hd = pose_emb.size(-1)
+ feats.append(pose_emb.view(bs, seq_len, hd))
+ if 'gaze' in self.mods:
+ gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
+ feats.append(gaze_feats)
+ if 'bbox' in self.mods:
+ feats.append(bbox_feats.mean(2))
+ lstm_inp = torch.cat(feats, 2)
+ lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
+ return self.act(self.proj(lstm_out)), feats
\ No newline at end of file
diff --git a/tbd/models/common_mind.py b/tbd/models/common_mind.py
new file mode 100644
index 0000000..6b45761
--- /dev/null
+++ b/tbd/models/common_mind.py
@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from .base import CNN, MindNetLSTM
+from memory_efficient_attention_pytorch import Attention
+
+
+class CommonMindToMnet(nn.Module):
+ """
+ img: bs, 3, 128, 128
+ pose: bs, 26, 3
+ gaze: bs, 2 NOTE: only tracker has gaze
+ bbox: bs, 4
+ """
+ def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_1', 'rgb_3', 'pose', 'gaze', 'bbox']):
+ super(CommonMindToMnet, self).__init__()
+
+ self.aggr = aggr
+ self.mods = mods
+
+ # ---- 3rd POV Images, object and bbox ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ #for param in self.cnn.parameters():
+ # param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+ self.bbox_ff = nn.Linear(4, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
+ self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
+ if aggr != 'no_tom': self.cm_proj = nn.Linear(hidden_dim*2, hidden_dim)
+ self.ln_1 = nn.LayerNorm(hidden_dim)
+ self.ln_2 = nn.LayerNorm(hidden_dim)
+ if aggr == 'attn':
+ self.attn_left = Attention(
+ dim = hidden_dim,
+ dim_head = hidden_dim // 4,
+ heads = 4,
+ memory_efficient = True,
+ q_bucket_size = hidden_dim,
+ k_bucket_size = hidden_dim)
+ self.attn_right = Attention(
+ dim = hidden_dim,
+ dim_head = hidden_dim // 4,
+ heads = 4,
+ memory_efficient = True,
+ q_bucket_size = hidden_dim,
+ k_bucket_size = hidden_dim)
+ self.m1 = nn.Linear(hidden_dim, 4)
+ self.m2 = nn.Linear(hidden_dim, 4)
+ self.m12 = nn.Linear(hidden_dim, 4)
+ self.m21 = nn.Linear(hidden_dim, 4)
+ self.mc = nn.Linear(hidden_dim, 4)
+
+ def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
+
+ batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
+
+ if 'bbox' in self.mods:
+ bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
+ else:
+ bbox_feat = None
+
+ if 'rgb_3' in self.mods:
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = img_3rd_pov[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+ else:
+ rgb_feat_3rd_pov = None
+
+ if tracker_id == 'skele1':
+ out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
+ out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
+ else:
+ out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
+ out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
+
+ if self.aggr == 'no_tom':
+ m1 = self.m1(out_1).mean(1)
+ m2 = self.m2(out_2).mean(1)
+ m12 = self.m12(out_1).mean(1)
+ m21 = self.m21(out_2).mean(1)
+ mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2
+
+ return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
+
+ common_mind = self.cm_proj(torch.cat([cell_1, cell_2], -1)) # (bs, 1, h)
+
+ if self.aggr == 'attn':
+ p1 = self.attn_left(x=out_1, context=common_mind)
+ p2 = self.attn_right(x=out_2, context=common_mind)
+ elif self.aggr == 'mult':
+ p1 = out_1 * common_mind
+ p2 = out_2 * common_mind
+ elif self.aggr == 'sum':
+ p1 = out_1 + common_mind
+ p2 = out_2 + common_mind
+ elif self.aggr == 'concat':
+ p1 = torch.cat([out_1, common_mind], 1)
+ p2 = torch.cat([out_2, common_mind], 1)
+ else: raise ValueError
+ p1 = self.act(p1)
+ p1 = self.ln_1(p1)
+ p2 = self.act(p2)
+ p2 = self.ln_2(p2)
+ if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
+ m1 = self.m1(p1).mean(1)
+ m2 = self.m2(p2).mean(1)
+ m12 = self.m12(p1).mean(1)
+ m21 = self.m21(p2).mean(1)
+ mc = self.mc(p1*p2).mean(1)
+ if self.aggr == 'concat':
+ m1 = self.m1(p1).mean(1)
+ m2 = self.m2(p2).mean(1)
+ m12 = self.m12(p1).mean(1)
+ m21 = self.m21(p2).mean(1)
+ mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2
+
+ return m1, m2, m12, m21, mc, [out_1, out_2, common_mind] + feats_1 + feats_2
+
+
+
+
+if __name__ == "__main__":
+
+ img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
+ img_tracker = torch.ones(3, 5, 3, 128, 128)
+ img_battery = torch.ones(3, 5, 3, 128, 128)
+ pose1 = torch.ones(3, 5, 26, 3)
+ pose2 = torch.ones(3, 5, 26, 3)
+ bbox = torch.ones(3, 5, 13, 4)
+ tracker_id = 'skele1'
+ gaze = torch.ones(3, 5, 2)
+ mods = ['pose', 'bbox', 'rgb_3']
+
+ for agg in ['no_tom']:
+ model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
+ out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
+
+ print(out[0].shape)
\ No newline at end of file
diff --git a/tbd/models/implicit.py b/tbd/models/implicit.py
new file mode 100644
index 0000000..3689b71
--- /dev/null
+++ b/tbd/models/implicit.py
@@ -0,0 +1,151 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from .base import CNN, MindNetLSTM
+from memory_efficient_attention_pytorch import Attention
+
+
+class ImplicitToMnet(nn.Module):
+ """
+ Implicit ToM net. Supports any subset of modalities
+ Possible aggregations: sum, mult, attn, concat
+ """
+ def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
+ super(ImplicitToMnet, self).__init__()
+
+ self.aggr = aggr
+ self.mods = mods
+
+ # ---- 3rd POV Images, object and bbox ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+ self.bbox_ff = nn.Linear(4, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
+ self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
+ self.ln_1 = nn.LayerNorm(hidden_dim)
+ self.ln_2 = nn.LayerNorm(hidden_dim)
+ if aggr == 'attn':
+ self.attn_left = Attention(
+ dim = hidden_dim,
+ dim_head = hidden_dim // 4,
+ heads = 4,
+ memory_efficient = True,
+ q_bucket_size = hidden_dim,
+ k_bucket_size = hidden_dim)
+ self.attn_right = Attention(
+ dim = hidden_dim,
+ dim_head = hidden_dim // 4,
+ heads = 4,
+ memory_efficient = True,
+ q_bucket_size = hidden_dim,
+ k_bucket_size = hidden_dim)
+ self.m1 = nn.Linear(hidden_dim, 4)
+ self.m2 = nn.Linear(hidden_dim, 4)
+ self.m12 = nn.Linear(hidden_dim, 4)
+ self.m21 = nn.Linear(hidden_dim, 4)
+ self.mc = nn.Linear(hidden_dim, 4)
+
+ def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
+
+ batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
+
+ if 'bbox' in self.mods:
+ bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
+ else:
+ bbox_feat = None
+
+ if 'rgb_3' in self.mods:
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = img_3rd_pov[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+ else:
+ rgb_feat_3rd_pov = None
+
+ if tracker_id == 'skele1':
+ out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
+ out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
+ else:
+ out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
+ out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
+
+ if self.aggr == 'no_tom':
+ m1 = self.m1(out_1).mean(1)
+ m2 = self.m2(out_2).mean(1)
+ m12 = self.m12(out_1).mean(1)
+ m21 = self.m21(out_2).mean(1)
+ mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2
+
+ return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
+
+ if self.aggr == 'attn':
+ p1 = self.attn_left(x=out_1, context=cell_2)
+ p2 = self.attn_right(x=out_2, context=cell_1)
+ elif self.aggr == 'mult':
+ p1 = out_1 * cell_2
+ p2 = out_2 * cell_1
+ elif self.aggr == 'sum':
+ p1 = out_1 + cell_2
+ p2 = out_2 + cell_1
+ elif self.aggr == 'concat':
+ p1 = torch.cat([out_1, cell_2], 1)
+ p2 = torch.cat([out_2, cell_1], 1)
+ else: raise ValueError
+ p1 = self.act(p1)
+ p1 = self.ln_1(p1)
+ p2 = self.act(p2)
+ p2 = self.ln_2(p2)
+ if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
+ m1 = self.m1(p1).mean(1)
+ m2 = self.m2(p2).mean(1)
+ m12 = self.m12(p1).mean(1)
+ m21 = self.m21(p2).mean(1)
+ mc = self.mc(p1*p2).mean(1)
+ if self.aggr == 'concat':
+ m1 = self.m1(p1).mean(1)
+ m2 = self.m2(p2).mean(1)
+ m12 = self.m12(p1).mean(1)
+ m21 = self.m21(p2).mean(1)
+ mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2
+
+ return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
+
+
+
+
+if __name__ == "__main__":
+
+ img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
+ img_tracker = torch.ones(3, 5, 3, 128, 128)
+ img_battery = torch.ones(3, 5, 3, 128, 128)
+ pose1 = torch.ones(3, 5, 26, 3)
+ pose2 = torch.ones(3, 5, 26, 3)
+ bbox = torch.ones(3, 5, 13, 4)
+ tracker_id = 'skele1'
+ gaze = torch.ones(3, 5, 2)
+
+ for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
+ model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg)
+ out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
+
+ print(agg, out[0].shape)
\ No newline at end of file
diff --git a/tbd/models/sl.py b/tbd/models/sl.py
new file mode 100644
index 0000000..4503971
--- /dev/null
+++ b/tbd/models/sl.py
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from .base import CNN, MindNetSL
+
+
+class SLToMnet(nn.Module):
+ """
+ Speaker-Listener ToMnet
+ """
+ def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
+ super(SLToMnet, self).__init__()
+
+ self.tom_weight = tom_weight
+ self.mods = mods
+
+ # ---- 3rd POV Images, object and bbox ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+ self.bbox_ff = nn.Linear(4, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net_1 = MindNetSL(hidden_dim, dropout, mods=mods)
+ self.mind_net_2 = MindNetSL(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
+ self.m1 = nn.Linear(hidden_dim, 4)
+ self.m2 = nn.Linear(hidden_dim, 4)
+ self.m12 = nn.Linear(hidden_dim, 4)
+ self.m21 = nn.Linear(hidden_dim, 4)
+ self.mc = nn.Linear(hidden_dim, 4)
+
+ def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
+
+ batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
+
+ if 'bbox' in self.mods:
+ bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
+ else:
+ bbox_feat = None
+
+ if 'rgb_3' in self.mods:
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = img_3rd_pov[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+ else:
+ rgb_feat_3rd_pov = None
+
+ if tracker_id == 'skele1':
+ out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
+ out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
+ else:
+ out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
+ out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
+
+ m1_logits = self.m1(out_1).mean(1)
+ m2_logits = self.m2(out_2).mean(1)
+ m12_logits = self.m12(out_1).mean(1)
+ m21_logits = self.m21(out_2).mean(1)
+ mc_logits = self.mc(out_1*out_2).mean(1)
+
+ m1_ranking = torch.log_softmax(m1_logits, dim=-1)
+ m2_ranking = torch.log_softmax(m2_logits, dim=-1)
+ m12_ranking = torch.log_softmax(m12_logits, dim=-1)
+ m21_ranking = torch.log_softmax(m21_logits, dim=-1)
+ mc_ranking = torch.log_softmax(mc_logits, dim=-1)
+
+ # NOTE: does this make sense?
+ m1 = m1_ranking + self.tom_weight * m2_ranking
+ m2 = m2_ranking + self.tom_weight * m1_ranking
+ m12 = m12_ranking + self.tom_weight * m21_ranking
+ m21 = m21_ranking + self.tom_weight * m12_ranking
+ mc = mc_ranking + self.tom_weight * mc_ranking
+
+ return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
+
+
+
+
+
+if __name__ == "__main__":
+
+ img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
+ img_tracker = torch.ones(3, 5, 3, 128, 128)
+ img_battery = torch.ones(3, 5, 3, 128, 128)
+ pose1 = torch.ones(3, 5, 26, 3)
+ pose2 = torch.ones(3, 5, 26, 3)
+ bbox = torch.ones(3, 5, 13, 4)
+ tracker_id = 'skele1'
+ gaze = torch.ones(3, 5, 2)
+
+ model = SLToMnet(hidden_dim=64, device='cpu', tom_weight=2.0, resnet=False, dropout=0.5)
+ out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
+
+ print(out[0].shape)
\ No newline at end of file
diff --git a/tbd/models/tom_base.py b/tbd/models/tom_base.py
new file mode 100644
index 0000000..e70acdf
--- /dev/null
+++ b/tbd/models/tom_base.py
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from .base import CNN, MindNetLSTM
+import numpy as np
+
+
+class ImplicitToMnet(nn.Module):
+ """
+ Implicit ToM net. Supports any subset of modalities
+ Possible aggregations: sum, mult, attn, concat
+ """
+ def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
+ super(ImplicitToMnet, self).__init__()
+
+ self.mods = mods
+
+ # ---- 3rd POV Images, object and bbox ----#
+ if resnet:
+ resnet = models.resnet34(weights="IMAGENET1K_V1")
+ self.cnn = nn.Sequential(
+ *(list(resnet.children())[:-1])
+ )
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.rgb_ff = nn.Linear(512, hidden_dim)
+ else:
+ self.cnn = CNN(hidden_dim)
+ self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
+ self.bbox_ff = nn.Linear(4, hidden_dim)
+
+ # ---- Others ----#
+ self.act = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+ self.device = device
+
+ # ---- Mind nets ----#
+ self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
+ self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
+
+ self.m1 = nn.Linear(hidden_dim, 4)
+ self.m2 = nn.Linear(hidden_dim, 4)
+ self.m12 = nn.Linear(hidden_dim, 4)
+ self.m21 = nn.Linear(hidden_dim, 4)
+ self.mc = nn.Linear(hidden_dim, 4)
+
+ def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
+
+ batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
+
+ if 'bbox' in self.mods:
+ bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
+ else:
+ bbox_feat = None
+
+ if 'rgb_3' in self.mods:
+ rgb_feat = []
+ for i in range(sequence_len):
+ images_i = img_3rd_pov[:,i]
+ img_i_feat = self.cnn(images_i)
+ img_i_feat = img_i_feat.view(batch_size, -1)
+ rgb_feat.append(img_i_feat)
+ rgb_feat = torch.stack(rgb_feat, 1)
+ rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
+ else:
+ rgb_feat_3rd_pov = None
+
+ if tracker_id == 'skele1':
+ out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
+ out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
+ else:
+ out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
+ out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
+
+ if self.aggr == 'no_tom':
+ m1 = self.m1(out_1).mean(1)
+ m2 = self.m2(out_2).mean(1)
+ m12 = self.m12(out_1).mean(1)
+ m21 = self.m21(out_2).mean(1)
+ mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2
+
+ return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
+
+
+
+def count_parameters(model):
+ #return sum(p.numel() for p in model.parameters() if p.requires_grad)
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+ return sum([np.prod(p.size()) for p in model_parameters])
+
+
+
+if __name__ == "__main__":
+
+ img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
+ img_tracker = torch.ones(3, 5, 3, 128, 128)
+ img_battery = torch.ones(3, 5, 3, 128, 128)
+ pose1 = torch.ones(3, 5, 26, 3)
+ pose2 = torch.ones(3, 5, 26, 3)
+ bbox = torch.ones(3, 5, 13, 4)
+ tracker_id = 'skele1'
+ gaze = torch.ones(3, 5, 2)
+
+ model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5)
+ print(count_parameters(model))
+ breakpoint()
+
+ for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
+ model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5)
+ out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
+
+ print(agg, out[0].shape)
\ No newline at end of file
diff --git a/tbd/models/utils.py b/tbd/models/utils.py
new file mode 100644
index 0000000..582c132
--- /dev/null
+++ b/tbd/models/utils.py
@@ -0,0 +1,7 @@
+import torch
+
+
+def pose_edge_index():
+ start = [15, 14, 13, 12, 19, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 3, 4, 5, 6, 8, 8, 4, 20, 21, 21, 22, 24, 22]
+ end = [14, 13, 12, 0, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 4, 20, 20, 21, 22, 24, 23, 25, 24]
+ return torch.tensor([start+end, end+start])
\ No newline at end of file
diff --git a/tbd/results/abl.json b/tbd/results/abl.json
new file mode 100644
index 0000000..98007d8
--- /dev/null
+++ b/tbd/results/abl.json
@@ -0,0 +1,186 @@
+{
+ "all": [
+ {
+ "m1": 0.3803337111550836,
+ "m2": 0.3900899763574355,
+ "m12": 0.4441281276628709,
+ "m21": 0.4818757648120031,
+ "mc": 0.4485177767702456
+ },
+ {
+ "m1": 0.5186066842992191,
+ "m2": 0.521895750052127,
+ "m12": 0.49294626980529677,
+ "m21": 0.4810118034501327,
+ "mc": 0.6097300398369058
+ },
+ {
+ "m1": 0.4965589148309122,
+ "m2": 0.5094894309980568,
+ "m12": 0.4615136302786905,
+ "m21": 0.4554005550423429,
+ "mc": 0.6258118710785031
+ }
+ ],
+ "rgb_3_pose_gaze_bbox": [
+ {
+ "m1": 0.3776045061727805,
+ "m2": 0.3996776745150713,
+ "m12": 0.4762772810038159,
+ "m21": 0.48643178296718503,
+ "mc": 0.4575207273412474
+ },
+ {
+ "m1": 0.5176564423560418,
+ "m2": 0.5109344883698214,
+ "m12": 0.4630213122846928,
+ "m21": 0.4826608133674547,
+ "mc": 0.5979415365779003
+ },
+ {
+ "m1": 0.5114692300997931,
+ "m2": 0.5027048375802656,
+ "m12": 0.47527894405588544,
+ "m21": 0.45223985157847546,
+ "mc": 0.6054099305712209
+ }
+ ],
+ "rgb_3_pose_gaze": [
+ {
+ "m1": 0.403207421026191,
+ "m2": 0.3833413122398237,
+ "m12": 0.4602455224198077,
+ "m21": 0.47181798537346287,
+ "mc": 0.4603675297898878
+ },
+ {
+ "m1": 0.49484810149311514,
+ "m2": 0.5060275976807422,
+ "m12": 0.4610412452830618,
+ "m21": 0.46869095956564044,
+ "mc": 0.6040674897817755
+ },
+ {
+ "m1": 0.5160598186177866,
+ "m2": 0.5309683014233921,
+ "m12": 0.47227245803060636,
+ "m21": 0.46953974307035984,
+ "mc": 0.6014771460423635
+ }
+ ],
+ "rgb_3_pose": [
+ {
+ "m1": 0.4057149181928123,
+ "m2": 0.4002233785689204,
+ "m12": 0.46794813614607333,
+ "m21": 0.4690365183933033,
+ "mc": 0.4591530208921514
+ },
+ {
+ "m1": 0.5362792166212834,
+ "m2": 0.5290656046231254,
+ "m12": 0.4569419683345858,
+ "m21": 0.4530255281497826,
+ "mc": 0.4554252731371068
+ },
+ {
+ "m1": 0.49570625763169085,
+ "m2": 0.5146503967646507,
+ "m12": 0.4567936139893578,
+ "m21": 0.45918214877096325,
+ "mc": 0.5962397441246001
+ }
+ ],
+ "rgb_3_gaze": [
+ {
+ "m1": 0.40135106828655215,
+ "m2": 0.38453470155825614,
+ "m12": 0.4989742833725901,
+ "m21": 0.47369273992079175,
+ "mc": 0.48430622854433986
+ },
+ {
+ "m1": 0.508038122818153,
+ "m2": 0.4875748099051746,
+ "m12": 0.46665443622698555,
+ "m21": 0.46635808547742913,
+ "mc": 0.47936993226840163
+ },
+ {
+ "m1": 0.49795853039610977,
+ "m2": 0.5028666890527814,
+ "m12": 0.44176709237564815,
+ "m21": 0.4483898274665582,
+ "mc": 0.5867527750929912
+ }
+ ],
+ "rgb_3_bbox": [
+ {
+ "m1": 0.3951383898241492,
+ "m2": 0.3818794542844425,
+ "m12": 0.44108151735270384,
+ "m21": 0.46539754196523303,
+ "mc": 0.43982185797713114
+ },
+ {
+ "m1": 0.5093846655989521,
+ "m2": 0.4923439212866733,
+ "m12": 0.4598003475323884,
+ "m21": 0.47647640659290746,
+ "mc": 0.6349953712994137
+ },
+ {
+ "m1": 0.5325224862402295,
+ "m2": 0.5092319973570975,
+ "m12": 0.4435807136490263,
+ "m21": 0.4576911633624616,
+ "mc": 0.6282064277856357
+ }
+ ],
+ "rgb_3_rgb_1": [
+ {
+ "m1": 0.39189391736691903,
+ "m2": 0.3739995635963588,
+ "m12": 0.4792392731637056,
+ "m21": 0.4592726043789752,
+ "mc": 0.4468645255652386
+ },
+ {
+ "m1": 0.4827892482357646,
+ "m2": 0.48042899735042716,
+ "m12": 0.45932653547051094,
+ "m21": 0.48430209616318126,
+ "mc": 0.4506104344435269
+ },
+ {
+ "m1": 0.4820247145474279,
+ "m2": 0.3667553358192628,
+ "m12": 0.44503028688537,
+ "m21": 0.45984906207471654,
+ "mc": 0.465120658971623
+ }
+ ],
+ "rgb_3": [
+ {
+ "m1": 0.40725462165126114,
+ "m2": 0.38737351624656846,
+ "m12": 0.46230461548252094,
+ "m21": 0.4829312519709871,
+ "mc": 0.4492175856929955
+ },
+ {
+ "m1": 0.5286274183685061,
+ "m2": 0.5081429492163979,
+ "m12": 0.4610256989472217,
+ "m21": 0.4733487634477733,
+ "mc": 0.4655243312197501
+ },
+ {
+ "m1": 0.5217968210271873,
+ "m2": 0.5103780571157844,
+ "m12": 0.4431266771306429,
+ "m21": 0.48398542131284883,
+ "mc": 0.6122314353959392
+ }
+ ]
+}
\ No newline at end of file
diff --git a/tbd/results/all.json b/tbd/results/all.json
new file mode 100644
index 0000000..978f197
--- /dev/null
+++ b/tbd/results/all.json
@@ -0,0 +1,232 @@
+{
+ "cm_concat": [
+ {
+ "m1": 0.38921744471949393,
+ "m2": 0.38557137008494935,
+ "m12": 0.44699534554593756,
+ "m21": 0.4747474437468054,
+ "mc": 0.4918107834016411
+ },
+ {
+ "m1": 0.5402415140026018,
+ "m2": 0.48833721513836786,
+ "m12": 0.4631512445419047,
+ "m21": 0.4740880083492652,
+ "mc": 0.6375070925808958
+ },
+ {
+ "m1": 0.5012543523713172,
+ "m2": 0.5068694866895836,
+ "m12": 0.4451537834591627,
+ "m21": 0.45215784721598673,
+ "mc": 0.6201022576104379
+ }
+ ],
+ "cm_sum": [
+ {
+ "m1": 0.39403894801783246,
+ "m2": 0.38541918219411786,
+ "m12": 0.4600376974144952,
+ "m21": 0.471919704007463,
+ "mc": 0.43950812310207055
+ },
+ {
+ "m1": 0.48497621104052574,
+ "m2": 0.5295044689855949,
+ "m12": 0.4502949472343065,
+ "m21": 0.47823492553894387,
+ "mc": 0.6028290833617195
+ },
+ {
+ "m1": 0.503386104373653,
+ "m2": 0.49983127146477085,
+ "m12": 0.46782817568218116,
+ "m21": 0.45484578845116075,
+ "mc": 0.5905749126722909
+ }
+ ],
+ "cm_mult": [
+ {
+ "m1": 0.39070820515470606,
+ "m2": 0.3996851353903932,
+ "m12": 0.4455704586852128,
+ "m21": 0.4713517869738811,
+ "mc": 0.4450907029478458
+ },
+ {
+ "m1": 0.5066540697731119,
+ "m2": 0.526507445454099,
+ "m12": 0.462643008560599,
+ "m21": 0.48263054309565334,
+ "mc": 0.6438566476782207
+ },
+ {
+ "m1": 0.48868811674304546,
+ "m2": 0.5074635877653536,
+ "m12": 0.44597405775819876,
+ "m21": 0.45445350963025877,
+ "mc": 0.5884265473527218
+ }
+ ],
+ "cm_attn": [
+ {
+ "m1": 0.3949557687114269,
+ "m2": 0.3919385900921811,
+ "m12": 0.4850081112466773,
+ "m21": 0.4849575556679713,
+ "mc": 0.4516870089239762
+ },
+ {
+ "m1": 0.4925989821370256,
+ "m2": 0.49409170532242247,
+ "m12": 0.4664647278240569,
+ "m21": 0.46783863397462533,
+ "mc": 0.6398721139927354
+ },
+ {
+ "m1": 0.4945636568169018,
+ "m2": 0.5049812790749876,
+ "m12": 0.454359577718189,
+ "m21": 0.4712184012093268,
+ "mc": 0.5992735441011302
+ }
+ ],
+ "no_tom": [
+ {
+ "m1": 0.2570551317,
+ "m2": 0.375350929686332,
+ "m12": 0.312451988649724,
+ "m21": 0.4631371031641,
+ "mc": 0.457486278214567
+ },
+ {
+ "m1": 0.233046800382043,
+ "m2": 0.522609755931958,
+ "m12": 0.326821758467328,
+ "m21": 0.474338898013257,
+ "mc": 0.604439456291308
+ },
+ {
+ "m1": 0.33774852598382,
+ "m2": 0.520943544364353,
+ "m12": 0.298617214416867,
+ "m21": 0.482175301427192,
+ "mc": 0.634948478570852
+ }
+ ],
+ "sl": [
+ {
+ "m1": 0.365205706591741,
+ "m2": 0.255259363011619,
+ "m12": 0.421227579844245,
+ "m21": 0.376143327741882,
+ "mc": 0.45614515353718
+ },
+ {
+ "m1": 0.493046934143676,
+ "m2": 0.331798174804139,
+ "m12": 0.422821548330913,
+ "m21": 0.399768928780549,
+ "mc": 0.450957023549231
+ },
+ {
+ "m1": 0.466266787709392,
+ "m2": 0.350962671130227,
+ "m12": 0.431694150269919,
+ "m21": 0.378863431433258,
+ "mc": 0.470284405744656
+ }
+ ],
+ "impl_concat": [
+ {
+ "m1": 0.38427302094644894,
+ "m2": 0.38673879043767634,
+ "m12": 0.45694337561663145,
+ "m21": 0.4737891562722213,
+ "mc": 0.4502976351448088
+ },
+ {
+ "m1": 0.49951068243751173,
+ "m2": 0.5084945752383908,
+ "m12": 0.4604721097809549,
+ "m21": 0.4826884970930907,
+ "mc": 0.6200443272625361
+ },
+ {
+ "m1": 0.5013244243339088,
+ "m2": 0.49476495726495723,
+ "m12": 0.4596701406290429,
+ "m21": 0.4554742441542813,
+ "mc": 0.5988949378402535
+ }
+ ],
+ "impl_sum": [
+ {
+ "m1": 0.3803337111550836,
+ "m2": 0.3900899763574355,
+ "m12": 0.4441281276628709,
+ "m21": 0.4818757648120031,
+ "mc": 0.4485177767702456
+ },
+ {
+ "m1": 0.5186066842992191,
+ "m2": 0.521895750052127,
+ "m12": 0.49294626980529677,
+ "m21": 0.4810118034501327,
+ "mc": 0.6097300398369058
+ },
+ {
+ "m1": 0.4965589148309122,
+ "m2": 0.5094894309980568,
+ "m12": 0.4615136302786905,
+ "m21": 0.4554005550423429,
+ "mc": 0.6258118710785031
+ }
+ ],
+ "impl_mult": [
+ {
+ "m1": 0.3789421413006731,
+ "m2": 0.3818053844554785,
+ "m12": 0.46402717346945177,
+ "m21": 0.4903726261039529,
+ "mc": 0.4461443806398687
+ },
+ {
+ "m1": 0.3789421413006731,
+ "m2": 0.3818053844554785,
+ "m12": 0.46402717346945177,
+ "m21": 0.4903726261039529,
+ "mc": 0.4461443806398687
+ },
+ {
+ "m1": 0.49338554196342077,
+ "m2": 0.5066817652688608,
+ "m12": 0.46253374461930613,
+ "m21": 0.47782311190445825,
+ "mc": 0.4581608719646799
+ }
+ ],
+ "impl_attn": [
+ {
+ "m1": 0.37413691393147924,
+ "m2": 0.2546966838007244,
+ "m12": 0.429390512693598,
+ "m21": 0.292401773870023,
+ "mc": 0.45706325836224465
+ },
+ {
+ "m1": 0.513917904196177,
+ "m2": 0.25802580258025803,
+ "m12": 0.49272662664765543,
+ "m21": 0.27041556176385584,
+ "mc": 0.6041394755857196
+ },
+ {
+ "m1": 0.47720445038981674,
+ "m2": 0.25839328537170264,
+ "m12": 0.46505055463781547,
+ "m21": 0.260276985433943,
+ "mc": 0.6021811271770562
+ }
+ ]
+}
diff --git a/tbd/results/false_belief_first_vs_second.pdf b/tbd/results/false_belief_first_vs_second.pdf
new file mode 100644
index 0000000..8e2ec5a
Binary files /dev/null and b/tbd/results/false_belief_first_vs_second.pdf differ
diff --git a/tbd/results/fb_ttest.txt b/tbd/results/fb_ttest.txt
new file mode 100644
index 0000000..a38ce44
--- /dev/null
+++ b/tbd/results/fb_ttest.txt
@@ -0,0 +1,87 @@
+
+========================================================= m1_m2_m12_m21
+
+
+Model: Base -> yes
+
+
+Model: DB -> yes
+
+
+Model: CG$\oplus$ -> no
+
+
+Model: CG$\otimes$ -> no
+
+
+Model: CG$\odot$ -> no
+
+
+Model: IC$\parallel$ -> no
+
+
+Model: IC$\oplus$ -> no
+
+
+Model: IC$\otimes$ -> no
+
+
+Model: IC$\odot$ -> yes
+
+========================================================= m1_m2
+
+
+Model: Base -> yes
+
+
+Model: DB -> yes
+
+
+Model: CG$\oplus$ -> no
+
+
+Model: CG$\otimes$ -> no
+
+
+Model: CG$\odot$ -> no
+
+
+Model: IC$\parallel$ -> no
+
+
+Model: IC$\oplus$ -> no
+
+
+Model: IC$\otimes$ -> no
+
+
+Model: IC$\odot$ -> yes
+
+========================================================= m12_m21
+
+
+Model: Base -> yes
+
+
+Model: DB -> yes
+
+
+Model: CG$\oplus$ -> no
+
+
+Model: CG$\otimes$ -> no
+
+
+Model: CG$\odot$ -> no
+
+
+Model: IC$\parallel$ -> no
+
+
+Model: IC$\oplus$ -> no
+
+
+Model: IC$\otimes$ -> no
+
+
+Model: IC$\odot$ -> yes
diff --git a/tbd/results/hgm_scores.txt b/tbd/results/hgm_scores.txt
new file mode 100644
index 0000000..b501612
--- /dev/null
+++ b/tbd/results/hgm_scores.txt
@@ -0,0 +1,59 @@
+mc =====================================================================
+ precision recall f1-score support
+
+ 0 0.000 0.500 0.001 50
+ 1 0.000 0.000 0.000 4
+ 2 0.004 0.038 0.007 238
+ 3 0.999 0.795 0.885 290788
+
+ accuracy 0.794 291080
+ macro avg 0.251 0.333 0.223 291080
+weighted avg 0.998 0.794 0.884 291080
+
+m1 =====================================================================
+ precision recall f1-score support
+
+ 0 0.000 0.000 0.000 147
+ 1 0.000 0.000 0.000 2
+ 2 0.025 0.051 0.033 1714
+ 3 0.994 0.988 0.991 289217
+
+ accuracy 0.982 291080
+ macro avg 0.255 0.260 0.256 291080
+weighted avg 0.988 0.982 0.985 291080
+
+m2 =====================================================================
+ precision recall f1-score support
+
+ 0 0.001 0.013 0.001 151
+ 2 0.031 0.084 0.045 2394
+ 3 0.992 0.970 0.981 288535
+
+ accuracy 0.962 291080
+ macro avg 0.341 0.355 0.342 291080
+weighted avg 0.983 0.962 0.972 291080
+
+m12 =====================================================================
+ precision recall f1-score support
+
+ 0 0.000 0.000 0.000 93
+ 1 0.000 0.000 0.000 8
+ 2 0.015 0.056 0.023 676
+ 3 0.997 0.990 0.994 290303
+
+ accuracy 0.988 291080
+ macro avg 0.253 0.262 0.254 291080
+weighted avg 0.995 0.988 0.991 291080
+
+m21 =====================================================================
+ precision recall f1-score support
+
+ 0 0.002 0.012 0.003 86
+ 1 0.000 0.000 0.000 12
+ 2 0.010 0.040 0.016 658
+ 3 0.997 0.989 0.993 290324
+
+ accuracy 0.987 291080
+ macro avg 0.252 0.260 0.253 291080
+weighted avg 0.995 0.987 0.991 291080
+
diff --git a/tbd/results/tbd_abl_avg_only.pdf b/tbd/results/tbd_abl_avg_only.pdf
new file mode 100644
index 0000000..e97ae6f
Binary files /dev/null and b/tbd/results/tbd_abl_avg_only.pdf differ
diff --git a/tbd/run_test.sh b/tbd/run_test.sh
new file mode 100644
index 0000000..d526d80
--- /dev/null
+++ b/tbd/run_test.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+python -m test \
+--gpu_id 1 \
+--seed 1 \
+--non_blocking \
+--pin_memory \
+--model_type tom_cm \
+--aggr no_tom \
+--hidden_dim 64 \
+--batch_size 64 \
+--load_model_path /PATH/TO/model
\ No newline at end of file
diff --git a/tbd/run_train.sh b/tbd/run_train.sh
new file mode 100644
index 0000000..39ea559
--- /dev/null
+++ b/tbd/run_train.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+python -m train \
+--gpu_id 2 \
+--seed 123 \
+--logger \
+--non_blocking \
+--pin_memory \
+--batch_size 64 \
+--num_workers 16 \
+--num_epoch 300 \
+--lr 5e-4 \
+--dropout 0.1 \
+--model_type tom_cm \
+--aggr no_tom \
+--hidden_dim 64
diff --git a/tbd/tbd_dataloader.py b/tbd/tbd_dataloader.py
new file mode 100644
index 0000000..8875859
--- /dev/null
+++ b/tbd/tbd_dataloader.py
@@ -0,0 +1,568 @@
+from __future__ import annotations
+
+from typing import Optional, Union
+
+import torch
+import pickle
+import torch
+import time
+import glob
+import random
+import os
+import numpy as np
+import pandas as pd
+import cv2
+from itertools import product
+import csv
+import torchvision.transforms as T
+
+from utils.helpers import tracker_skeID, CLIPS_IDS_88, ALL_IDS, UNIQUE_OBJ_IDS
+
+
+def collate_fn(batch):
+ # Unpack the batch into individual elements
+ img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep = zip(*batch)
+
+ # Determine the maximum number of objects in any batch
+ max_n_obj = max(bbox.shape[1] for bbox in bboxes)
+
+ # Pad the bounding box tensors
+ bboxes_pad = []
+ for bbox in bboxes:
+ pad_size = max_n_obj - bbox.shape[1]
+ pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
+ padded_bbox = torch.cat((bbox, pad), dim=1)
+ bboxes_pad.append(padded_bbox)
+
+ # Stack the padded tensors into a batch tensor
+ bboxes_batch = torch.stack(bboxes_pad, dim=0)
+
+ img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
+ img_tracker = torch.stack(img_tracker, dim=0)
+ img_battery = torch.stack(img_battery, dim=0)
+ pose1 = torch.stack(pose1, dim=0)
+ pose2 = torch.stack(pose2, dim=0)
+ gaze = torch.stack(gaze, dim=0)
+ labels = torch.tensor(labels, dtype=torch.long)
+
+ # Return the batched tensors
+ return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep
+
+def collate_fn_test(batch):
+ # Unpack the batch into individual elements
+ img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep, false_beliefs = zip(*batch)
+
+ # Determine the maximum number of objects in any batch
+ max_n_obj = max(bbox.shape[1] for bbox in bboxes)
+
+ # Pad the bounding box tensors
+ bboxes_pad = []
+ for bbox in bboxes:
+ pad_size = max_n_obj - bbox.shape[1]
+ pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
+ padded_bbox = torch.cat((bbox, pad), dim=1)
+ bboxes_pad.append(padded_bbox)
+
+ # Stack the padded tensors into a batch tensor
+ bboxes_batch = torch.stack(bboxes_pad, dim=0)
+
+ img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
+ img_tracker = torch.stack(img_tracker, dim=0)
+ img_battery = torch.stack(img_battery, dim=0)
+ pose1 = torch.stack(pose1, dim=0)
+ pose2 = torch.stack(pose2, dim=0)
+ gaze = torch.stack(gaze, dim=0)
+ labels = torch.tensor(labels, dtype=torch.long)
+
+ # Return the batched tensors
+ return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep, false_beliefs
+
+
+class TBDDataset(torch.utils.data.Dataset):
+
+ def __init__(
+ self,
+ path: str = "/scratch/bortoletto/data/tbd",
+ mode: str = "train",
+ tbd_data_path: str = "/scratch/bortoletto/data/tbd/mind_lstm_training_cnn_att/",
+ list_of_ids_to_consider: list = ALL_IDS,
+ use_preprocessed_img: bool = True,
+ resize_img: Optional[Union[tuple, int]] = (128,128),
+ ):
+ """TBD Dataset based on the 88 clip version of the TBD data.
+
+ Expects the following folder structure:
+ - path
+ - tracker_gt_smooth <- These are eye tracking from POV, 2D coordinates
+ - images/*/ <- These are the images by experiment id
+ - battery <- These are the images, 1st Person
+ - tracker <- These are the images, 1st other Person w/ eye fixation
+ - kinect <- These are the images, 3rd Person
+ - skeleton <- Pose estimation, 3D coordinates
+ - annotation <- These are the labels, i.e. [0,3] (see below)
+
+
+ Labels are strcutured as follows:
+ {
+ "O1": [ <- Object with id O1
+ {
+ "m1": {
+ "fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange
+ "loc": null
+ },
+ "m2": {
+ "fluent": 3,
+ "loc": null
+ },
+ "m12": {
+ "fluent": 3,
+ "loc": null
+ },
+ "m21": {
+ "fluent": 3,
+ "loc": null
+ },
+ "mc": {
+ "fluent": 3,
+ "loc": null
+ },
+ "mg": {
+ "fluent": 3,
+ "loc": [
+ 22,
+ 9
+ ]
+ }
+ }, ...
+ ], ...
+ }
+
+ This corresponds to a strict subset of the raw dataset collected
+ by the TBD people in their paper "Learning Traidic Belief Dynamics
+ in Nonverbal Communication from Videos" (CVPR2021, Oral).
+
+ We keep small amounts of data in memory (everything <100MB).
+ Otherwise we read from disk on the fly. This dataset applies normalization.
+
+ Args:
+ path (str, optional): Where the folders lie.
+ Defaults to "/scratch/ruhdorfer/triadic_beleif_data_v2".
+ list_of_ids_to_consider (list, optional): List of ids to consider.
+ Defaults to ALL_IDS. Otherwise specify a list,
+ e.g. ["test_94342_23", "test_boelter_21", ...].
+ resize_img (Optional[Union[tuple, int]], optional): Resize image to
+ this size if required. Defaults to None.
+ """
+ print(f"Loading TBD Dataset in mode {mode}...")
+
+ self.mode = mode
+
+ start = time.time()
+
+ self.skeleton_3D_path = f"{path}/skeleton"
+ self.tracker_2D_path = f"{path}/tracker_gt_smooth"
+ self.bbox_csv_path = f"{path}/annotations_with_bbox.csv"
+ if use_preprocessed_img:
+ self.img_path = f"{path}/images_norm"
+ else:
+ self.img_path = f"{path}/images"
+ self.obj_ids_path = f"{path}/mind_lstm_training_cnn_att_shu.pkl"
+
+ self.label_map = list(product([0, 1, 2, 3], repeat=5))
+
+ clips = os.listdir(tbd_data_path)
+ data = []
+ labels = []
+ for clip in clips:
+ with open(tbd_data_path + clip, 'rb') as f:
+ vec_input, label_ = pickle.load(f, encoding='latin1')
+ data = data + vec_input
+ labels = labels + label_
+ c = list(zip(data, labels))
+ random.shuffle(c)
+ train_ratio = int(len(c) * 0.6)
+ validate_ratio = int(len(c) * 0.2)
+ data, label = zip(*c)
+ train_x, train_y = data[:train_ratio], label[:train_ratio]
+ validate_x, validate_y = data[train_ratio:train_ratio + validate_ratio], label[train_ratio:train_ratio + validate_ratio]
+ test_x, test_y = data[train_ratio + validate_ratio:], label[train_ratio + validate_ratio:]
+ self.mind_count = np.zeros(1024) # used for CE weights
+
+ if mode == "train":
+ self.data, self.labels = train_x, train_y
+ elif mode == "val":
+ self.data, self.labels = validate_x, validate_y
+ elif mode == "test":
+ self.data, self.labels = test_x, test_y
+
+ self.false_beliefs_path = f"{path}/store_mind_set"
+
+ # keep small amouts of data in memory
+ self.skeleton_3D = self.load_skeleton_3D(self.skeleton_3D_path, list_of_ids_to_consider)
+ self.tracker_2D = self.load_tracker_2D(self.tracker_2D_path, list_of_ids_to_consider)
+ self.bbox_df = pd.read_csv(self.bbox_csv_path, header=0)
+ self.obj_ids = self.load_obj_ids(self.obj_ids_path)
+
+ if not use_preprocessed_img:
+ normalisation_steps = [
+ T.ToTensor(),
+ T.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]
+ )
+ ]
+ if resize_img is not None:
+ normalisation_steps.insert(1, T.Resize(resize_img))
+ self.preprocess_img = T.Compose(normalisation_steps)
+ else:
+ self.preprocess_img = None
+
+ self.use_preprocessed_img = use_preprocessed_img
+ print(f"Done loading in {time.time() - start}s.")
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(
+ self, idx: int
+ ) -> tuple[
+ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict, str, int, str
+ ]:
+ """Given an index, return the corresponding experiment_id and timestep in the experiment.
+ Then picky the appropriate data and labels from these.
+
+ Args:
+ idx (int): _description_
+
+ Returns:
+ tuple: torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict
+ Returns the following:
+ - img_kinect: torch.Tensor of shape (T, C, H, W) (Default is [T, 3, 720, 1280])
+ - img_tracker: torch.Tensor of shape (T, H, W, C)
+ - img_battery: torch.Tensor of shape (T, H, W, C)
+ - skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 1)
+ - skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 2)
+ - bbox: torch.Tensor of shape (T, num_obj, 5)
+ - tracker to skeleton ID: str (either skeleton 1 or 2)
+ - tracker_2D: torch.Tensor of shape (T, 2)
+ - labels: dict (see below)
+ """
+ labels = self.label_map[self.labels[idx]]
+ experiment_id = self.data[idx][1][0].split('/')[6]
+ img_data_path = f"{self.img_path}/{experiment_id}"
+ frame_ids = [int(os.path.basename(self.data[idx][1][i]).split('_')[0]) for i in range(len(self.data[idx][1]))]
+
+ if self.use_preprocessed_img:
+ kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.pt")))
+ tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.pt")))
+ battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.pt")))
+ kinect_img_paths = [kinect[id] for id in frame_ids]
+ tracker_img_paths = [tracker[id] for id in frame_ids]
+ battery_img_paths = [battery[id] for id in frame_ids]
+ else:
+ kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.jpg")))
+ tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.jpg")))
+ battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.jpg")))
+ kinect_img_paths = [kinect[id] for id in frame_ids]
+ tracker_img_paths = [tracker[id] for id in frame_ids]
+ battery_img_paths = [battery[id] for id in frame_ids]
+
+ # load images
+ kinect_imgs = [
+ torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
+ for img_path in kinect_img_paths
+ ]
+ kinect_imgs = torch.stack(kinect_imgs, axis=0)
+
+ tracker_imgs = [
+ torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
+ for img_path in tracker_img_paths
+ ]
+ tracker_imgs = torch.stack(tracker_imgs, axis=0)
+
+ battery_imgs = [
+ torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
+ for img_path in battery_img_paths
+ ]
+ battery_imgs = torch.stack(battery_imgs, axis=0)
+
+ # load object id to check for false beliefs - only for testing
+ if self.mode == "test": #or self.mode == "train":
+ if f"{experiment_id}.txt" in os.listdir(self.false_beliefs_path):
+ obj_id = self.obj_ids[experiment_id][frame_ids[-1]]
+ obj_id = next(x for x in obj_id if x is not None)
+ false_belief = next((line.strip().split(',')[2] for line in open(f"{self.false_beliefs_path}/{experiment_id}.txt") if line.startswith(str(frame_ids[-1]) + ',' + obj_id + ',')), "no")
+ #if experiment_id in ['test_boelter4_0', 'test_boelter4_7', 'test_boelter4_6', 'test_boelter4_8', 'test_boelter2_3',
+ # 'test_94342_20', 'test_94342_18', 'test_94342_11', 'test_94342_17', 'test_boelter3_8', 'test_94342_2',
+ # 'test_boelter2_17', 'test_boelter3_7', 'test_94342_4', 'test_boelter3_9', 'test_boelter_10',
+ # 'test_boelter2_6', 'test_boelter4_10', 'test_boelter4_2', 'test_boelter4_5', 'test_94342_24',
+ # 'test_94342_15', 'test_boelter3_5', 'test_94342_8', 'test2', 'test_boelter3_12']:
+ # print('here!')
+ # with open(os.path.join(f'results/hgm_test_fb.csv'), mode='a') as file:
+ # writer = csv.writer(file)
+ # writer.writerow([experiment_id, obj_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
+ else:
+ false_belief = "no"
+ #with open(os.path.join(f'results/test_fb.csv'), mode='a') as file:
+ # writer = csv.writer(file)
+ # writer.writerow([experiment_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
+
+ df = self.bbox_df[
+ (self.bbox_df.experiment_name == experiment_id)
+ #& (self.bbox_df.name == obj_id) # NOTE: load the bounding boxes for all objects
+ & (self.bbox_df.name != 'P1')
+ & (self.bbox_df.name != 'P2')
+ & (self.bbox_df.frame.isin(frame_ids))
+ ]
+
+ bboxes = []
+ for f in frame_ids:
+ bbox = torch.tensor(df.loc[df['frame'] == f, ["x_min", "y_min", "x_max", "y_max"]].to_numpy(), dtype=torch.float32)
+ bbox[:, 0] = bbox[:, 0] / 1280.0
+ bbox[:, 1] = bbox[:, 1] / 720.0
+ bbox[:, 2] = bbox[:, 2] / 1280.0
+ bbox[:, 3] = bbox[:, 3] / 720.0
+ bboxes.append(bbox)
+ bboxes = torch.stack(bboxes) # NOTE: this will need a collate function bc not every video has the same number of objects
+
+ skele1 = self.skeleton_3D[experiment_id]["skele1"][frame_ids]
+ skele2 = self.skeleton_3D[experiment_id]["skele2"][frame_ids]
+
+ gaze = self.tracker_2D[experiment_id][frame_ids]
+
+ if self.mode == "test":
+ return (
+ kinect_imgs,
+ tracker_imgs,
+ battery_imgs,
+ skele1,
+ skele2,
+ bboxes,
+ tracker_skeID[experiment_id], # <- This is the tracker skeleton ID
+ gaze,
+ labels, # <- per object "m1", "m2", "m12", "m21", "mc"
+ experiment_id,
+ frame_ids,
+ #self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding
+ false_belief
+ )
+ else:
+ return (
+ kinect_imgs,
+ tracker_imgs,
+ battery_imgs,
+ skele1,
+ skele2,
+ bboxes,
+ tracker_skeID[experiment_id], # <- This is the tracker skeleton ID
+ gaze,
+ labels, # <- per object "m1", "m2", "m12", "m21", "mc"
+ experiment_id,
+ frame_ids
+ #self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding
+ )
+
+ def onehot(self, x, n=len(UNIQUE_OBJ_IDS)):
+ retval = torch.zeros(n)
+ if x > 0:
+ retval[x-1] = 1
+ return retval
+
+ def load_obj_ids(self, path: str):
+ with open(path, "rb") as f:
+ ids = pickle.load(f)
+ return ids
+
+ def extract_labels(self):
+ """TODO: Converts index label to [m1, m2, m12, m21, mc] format.
+
+ """
+ return
+
+ def _flatten_mind_obj_timestep(self, mind_obj_dict: dict) -> list:
+ """Flattens the mind object dict to a list. I.e. takes
+
+ {
+ "m1": {
+ "fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange
+ "loc": null
+ },
+ "m2": {
+ "fluent": 3,
+ "loc": null
+ },
+ "m12": {
+ "fluent": 3,
+ "loc": null
+ },
+ "m21": {
+ "fluent": 3,
+ "loc": null
+ },
+ "mc": {
+ "fluent": 3,
+ "loc": null
+ },
+ "mg": {
+ "fluent": 3,
+ "loc": [
+ 22,
+ 9
+ ]
+ }
+ }
+
+ and returns [3, 3, 3, 3, 3, 3]
+
+ Args:
+ mind_obj_dict (dict): Mind object dict as described in __init__.doctstring.
+
+ Returns:
+ list: List of mind object labels.
+ """
+ return np.array([mind_obj["fluent"] for key, mind_obj in mind_obj_dict.items() if key != "mg"])
+
+ def load_skeleton_3D(self, path: str, list_of_ids_to_consider: list):
+ """Load skeleton 3D data from disk.
+
+ - path
+ - * <- list of ids
+ - skele1.p <- 3D coord per id and timestep
+ - skele2.p <-
+
+ Args:
+ path (str): Where the skeleton 3D data lie.
+ list_of_ids_to_consider (list): List of ids to consider.
+ Defaults to None which means all ids. Otherwise specify a list,
+ e.g. ["test_94342_23", "test_boelter_21", ...].
+
+ Returns:
+ dict: skeleton 3D data as described above in __init__.doctstring.
+ """
+ skeleton_3D = {}
+ for experiment_id in list_of_ids_to_consider:
+ skeleton_3D[experiment_id] = {}
+ with open(f"{path}/{experiment_id}/skele1.p", "rb") as f:
+ skeleton_3D[experiment_id]["skele1"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
+ with open(f"{path}/{experiment_id}/skele2.p", "rb") as f:
+ skeleton_3D[experiment_id]["skele2"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
+ return skeleton_3D
+
+ def load_tracker_2D(self, path: str, list_of_ids_to_consider: list):
+ """Load tracker 2D data from disk.
+
+ - path
+ - *.p <- 2D coord per id and timestep
+
+ Args:
+ path (str): Where the tracker 2D data lie.
+ list_of_ids_to_consider (list): List of ids to consider.
+ Defaults to None which means all ids. Otherwise specify a list,
+ e.g. ["test_94342_23", "test_boelter_21", ...].
+
+ Returns:
+ dict: tracker 2D data.
+ """
+ tracker_2D = {}
+ for experiment_id in list_of_ids_to_consider:
+ with open(f"{path}/{experiment_id}.p", "rb") as f:
+ tracker_2D[experiment_id] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
+ return tracker_2D
+
+ def load_bbox(self, path: str, list_of_ids_to_consider: list):
+ """Load bbox data from disk.
+
+ - bbox_tensors.pickle <- bbox per experiment id one tensor
+
+ Args:
+ path (str): Where the bbox data lie.
+ list_of_ids_to_consider (list): List of ids to consider.
+
+ Returns:
+ dict: bbox data.
+ """
+ with open(path, "rb") as f:
+ pickle_data = pickle.load(f)
+ for key in CLIPS_IDS_88:
+ if key not in list_of_ids_to_consider:
+ pickle_data.pop(key, None)
+ return pickle_data
+
+
+
+
+
+
+if __name__ == '__main__':
+
+ # os.environ['PYTHONHASHSEED'] = str(42)
+ # torch.manual_seed(42)
+ # np.random.seed(42)
+ # random.seed(42)
+
+ data = TBDDataset(use_preprocessed_img=True, mode="test")
+
+ from tqdm import tqdm
+ for i in tqdm(range(data.__len__())):
+ data[i]
+
+ breakpoint()
+
+ from torch.utils.data import DataLoader
+
+ # Just for guessing time
+ data_0=data[0]
+ data_last=data[len(data)-1]
+ idx = np.random.randint(1, len(data)-1) # Something in between.
+ start = time.time()
+ (
+ kinect_imgs, # <- len x 720 x 1280 x 3 originally, likely smaller now
+ tracker_imgs,
+ battery_imgs,
+ skele1,
+ skele2,
+ bbox,
+ tracker_skeID_sample, # <- This is the tracker skeleton ID
+ tracker2d,
+ label,
+ experiment_id, # From here for debugging
+ timestep,
+ #obj_id, # <- This is the object ID as a one-hot
+ false_belief
+ ) = data[idx]
+ end = time.time()
+ print(f"Time for one sample: {end-start}")
+
+ print('kinect:', kinect_imgs.shape)
+ print('tracker:', tracker_imgs.shape)
+ print('battery:', battery_imgs.shape)
+ print('skele1:', skele1.shape)
+ print('skele2:', skele2.shape)
+ print('gaze:', tracker2d.shape)
+ print('bbox:', bbox.shape)
+ print('label:', label)
+
+ #breakpoint()
+
+ dl = DataLoader(
+ data,
+ batch_size=4,
+ shuffle=False,
+ collate_fn=collate_fn
+ )
+
+ from tqdm import tqdm
+
+ for j, batch in tqdm(enumerate(dl)):
+ #print(j, end='\r')
+ img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
+ #breakpoint()
+ #print(img_3rd_pov.shape)
+ #print(img_tracker.shape)
+ #print(img_battery.shape)
+ #print(pose1.shape, pose2.shape)
+ #print(bbox.shape)
+ #print(gaze.shape)
+
+
+ breakpoint()
diff --git a/tbd/test.py b/tbd/test.py
new file mode 100644
index 0000000..aaf7e9e
--- /dev/null
+++ b/tbd/test.py
@@ -0,0 +1,196 @@
+import torch
+import csv
+import argparse
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+import random
+import os
+import numpy as np
+
+from tbd_dataloader import TBDDataset, collate_fn_test
+from models.common_mind import CommonMindToMnet
+from models.sl import SLToMnet
+from models.implicit import ImplicitToMnet
+from utils.helpers import compute_f1_scores
+
+
+def test(args):
+
+ test_dataset = TBDDataset(
+ path=args.data_path,
+ mode="test",
+ use_preprocessed_img=True
+ )
+ test_dataloader = DataLoader(
+ test_dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_memory,
+ collate_fn=collate_fn_test
+ )
+
+ device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
+
+ # model
+ if args.model_type == 'tom_cm':
+ model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
+ elif args.model_type == 'tom_sl':
+ model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
+ elif args.model_type == 'tom_impl':
+ model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
+ else: raise NotImplementedError
+
+ model.load_state_dict(torch.load(args.load_model_path, map_location=device))
+ model.device = device
+
+ model.eval()
+
+ if args.save_preds:
+ # Define the output file path
+ folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}'
+ if not os.path.exists(folder_path):
+ os.makedirs(folder_path)
+ print(f'Saving predictions in {folder_path}.')
+
+ print('Testing...')
+ m1_pred_list = []
+ m2_pred_list = []
+ m12_pred_list = []
+ m21_pred_list = []
+ mc_pred_list = []
+ m1_label_list = []
+ m2_label_list = []
+ m12_label_list = []
+ m21_label_list = []
+ mc_label_list = []
+ with torch.no_grad():
+ for j, batch in tqdm(enumerate(test_dataloader)):
+ img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep, false_belief = batch
+ if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
+ if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
+ if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
+ if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
+ if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
+ if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
+ if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
+ m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, repr = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
+ m1_pred = m1_pred.reshape(-1, 4)
+ m2_pred = m2_pred.reshape(-1, 4)
+ m12_pred = m12_pred.reshape(-1, 4)
+ m21_pred = m21_pred.reshape(-1, 4)
+ mc_pred = mc_pred.reshape(-1, 4)
+ m1_label = labels[:, 0].reshape(-1).to(device)
+ m2_label = labels[:, 1].reshape(-1).to(device)
+ m12_label = labels[:, 2].reshape(-1).to(device)
+ m21_label = labels[:, 3].reshape(-1).to(device)
+ mc_label = labels[:, 4].reshape(-1).to(device)
+
+ m1_pred_list.append(m1_pred)
+ m2_pred_list.append(m2_pred)
+ m12_pred_list.append(m12_pred)
+ m21_pred_list.append(m21_pred)
+ mc_pred_list.append(mc_pred)
+ m1_label_list.append(m1_label)
+ m2_label_list.append(m2_label)
+ m12_label_list.append(m12_label)
+ m21_label_list.append(m21_label)
+ mc_label_list.append(mc_label)
+
+ if args.save_preds:
+ torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt"))
+ data = [(
+ i,
+ torch.argmax(m1_pred[i]).cpu().numpy(),
+ torch.argmax(m2_pred[i]).cpu().numpy(),
+ torch.argmax(m12_pred[i]).cpu().numpy(),
+ torch.argmax(m21_pred[i]).cpu().numpy(),
+ torch.argmax(mc_pred[i]).cpu().numpy(),
+ m1_label[i].cpu().numpy(),
+ m2_label[i].cpu().numpy(),
+ m12_label[i].cpu().numpy(),
+ m21_label[i].cpu().numpy(),
+ mc_label[i].cpu().numpy(),
+ false_belief[i]) for i in range(len(labels))
+ ]
+ header = ['frame', 'm1_pred', 'm2_pred', 'm12_pred', 'm21_pred', 'mc_pred', 'm1_label', 'm2_label', 'm12_label', 'm21_label', 'mc_label', 'false_belief']
+ with open(os.path.join(folder_path, f'{j}.csv'), mode='w', newline='') as file:
+ writer = csv.writer(file)
+ writer.writerow(header) # Write the header row
+ writer.writerows(data) # Write the data rows
+
+ #np.savetxt('m1_label_bs1.txt', torch.cat(m1_label_list).cpu().numpy())
+ test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1 = compute_f1_scores(
+ torch.cat(m1_pred_list),
+ torch.cat(m1_label_list),
+ torch.cat(m2_pred_list),
+ torch.cat(m2_label_list),
+ torch.cat(m12_pred_list),
+ torch.cat(m12_label_list),
+ torch.cat(m21_pred_list),
+ torch.cat(m21_label_list),
+ torch.cat(mc_pred_list),
+ torch.cat(mc_label_list)
+ )
+
+ print("Test m1 F1: {}".format(test_m1_f1))
+ print("Test m2 F1: {}".format(test_m2_f1))
+ print("Test m12 F1: {}".format(test_m12_f1))
+ print("Test m21 F1: {}".format(test_m21_f1))
+ print("Test mc F1: {}".format(test_mc_f1))
+
+ with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f:
+ f.write(f"Test data:\n {[data[1] for data in test_dataset.data]}")
+ f.write(f"m1 f1: {test_m1_f1}")
+ f.write(f"m2 f1: {test_m2_f1}")
+ f.write(f"m12 f1: {test_m12_f1}")
+ f.write(f"m21 f1: {test_m21_f1}")
+ f.write(f"mc f1: {test_mc_f1}")
+ f.close()
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+
+ # Define the command-line arguments
+ parser.add_argument('--gpu_id', type=int)
+ parser.add_argument('--seed', type=int, default=1)
+ parser.add_argument('--presaved', type=int, default=128)
+ parser.add_argument('--non_blocking', action='store_true')
+ parser.add_argument('--num_workers', type=int, default=16)
+ parser.add_argument('--pin_memory', action='store_true')
+ parser.add_argument('--model_type', type=str)
+ parser.add_argument('--batch_size', type=int, default=64)
+ parser.add_argument('--aggr', type=str, default='concat', required=False)
+ parser.add_argument('--use_resnet', action='store_true')
+ parser.add_argument('--hidden_dim', type=int, default=64)
+ parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
+ parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
+ parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
+ parser.add_argument('--save_path', type=str, default='experiments/')
+ parser.add_argument('--test_frames', type=str, default=None)
+ parser.add_argument('--median', type=int, default=None)
+ parser.add_argument('--load_model_path', type=str)
+ parser.add_argument('--dropout', type=float, default=0.0)
+ parser.add_argument('--save_preds', action='store_true')
+
+ # Parse the command-line arguments
+ args = parser.parse_args()
+
+ if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
+ if not args.aggr:
+ parser.error("The choosen --model_type requires --aggr")
+ if args.model_type == 'tom_sl' and not args.tom_weight:
+ parser.error("The choosen --model_type requires --tom_weight")
+
+ os.environ['PYTHONHASHSEED'] = str(args.seed)
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ random.seed(args.seed)
+
+ print('###########################################################################')
+ print('TESTING: MAKE SURE YOU ARE USING THE SAME RANDOM SEED USED DURING TRAINING!')
+ print('###########################################################################')
+
+ test(args)
\ No newline at end of file
diff --git a/tbd/train.py b/tbd/train.py
new file mode 100644
index 0000000..7ac4b3b
--- /dev/null
+++ b/tbd/train.py
@@ -0,0 +1,474 @@
+import torch
+import os
+import argparse
+import numpy as np
+import random
+import datetime
+import wandb
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+import torch.nn as nn
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+from tbd_dataloader import TBDDataset, collate_fn
+from models.common_mind import CommonMindToMnet
+from models.sl import SLToMnet
+from models.implicit import ImplicitToMnet
+from utils.helpers import count_parameters, compute_f1_scores
+
+
+def main(args):
+
+ train_dataset = TBDDataset(
+ path=args.data_path,
+ mode="train",
+ use_preprocessed_img=True
+ )
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_memory,
+ collate_fn=collate_fn
+ )
+ val_dataset = TBDDataset(
+ path=args.data_path,
+ mode="val",
+ use_preprocessed_img=True
+ )
+ val_dataloader = DataLoader(
+ val_dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_memory,
+ collate_fn=collate_fn
+ )
+
+ train_data = [data[1] for data in train_dataset.data]
+ val_data = [data[1] for data in val_dataset.data]
+ if args.logger:
+ wandb.config.update({"train_data": train_data})
+ wandb.config.update({"val_data": val_data})
+
+ device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
+
+ # model
+ if args.model_type == 'tom_cm':
+ model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
+ elif args.model_type == 'tom_sl':
+ model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
+ elif args.model_type == 'tom_impl':
+ model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
+ else: raise NotImplementedError
+ if args.resume_from_checkpoint is not None:
+ model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
+ # optimizer
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+ # scheduler
+ if args.scheduler == None:
+ scheduler = None
+ else:
+ scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
+ # loss function
+ if args.model_type == 'tom_sl':
+ ce_loss_m1 = nn.NLLLoss()
+ ce_loss_m2 = nn.NLLLoss()
+ ce_loss_m12 = nn.NLLLoss()
+ ce_loss_m21 = nn.NLLLoss()
+ ce_loss_mc = nn.NLLLoss()
+ else:
+ ce_loss_m1 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+ ce_loss_m2 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+ ce_loss_m12 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+ ce_loss_m21 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+ ce_loss_mc = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+
+ stats = {
+ 'train': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []},
+ 'val': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []}
+ }
+ max_val_f1 = 0
+ max_val_classification_epoch = None
+ counter = 0
+
+ print(f'Number of parameters: {count_parameters(model)}')
+
+ for i in range(args.num_epoch):
+ # training
+ print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
+ epoch_train_loss_m1 = 0.0
+ epoch_train_loss_m2 = 0.0
+ epoch_train_loss_m12 = 0.0
+ epoch_train_loss_m21 = 0.0
+ epoch_train_loss_mc = 0.0
+ m1_train_batch_pred_list = []
+ m2_train_batch_pred_list = []
+ m12_train_batch_pred_list = []
+ m21_train_batch_pred_list = []
+ mc_train_batch_pred_list = []
+ m1_train_batch_label_list = []
+ m2_train_batch_label_list = []
+ m12_train_batch_label_list = []
+ m21_train_batch_label_list = []
+ mc_train_batch_label_list = []
+ model.train()
+ for j, batch in tqdm(enumerate(train_dataloader)):
+ img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
+ if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
+ if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
+ if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
+ if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
+ if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
+ if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
+ if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
+ m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
+ m1_pred = m1_pred.reshape(-1, 4)
+ m2_pred = m2_pred.reshape(-1, 4)
+ m12_pred = m12_pred.reshape(-1, 4)
+ m21_pred = m21_pred.reshape(-1, 4)
+ mc_pred = mc_pred.reshape(-1, 4)
+ m1_label = labels[:, 0].reshape(-1).to(device)
+ m2_label = labels[:, 1].reshape(-1).to(device)
+ m12_label = labels[:, 2].reshape(-1).to(device)
+ m21_label = labels[:, 3].reshape(-1).to(device)
+ mc_label = labels[:, 4].reshape(-1).to(device)
+
+ loss_m1 = ce_loss_m1(m1_pred, m1_label)
+ loss_m2 = ce_loss_m2(m2_pred, m2_label)
+ loss_m12 = ce_loss_m12(m12_pred, m12_label)
+ loss_m21 = ce_loss_m21(m21_pred, m21_label)
+ loss_mc = ce_loss_mc(mc_pred, mc_label)
+ loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc
+
+ epoch_train_loss_m1 += loss_m1.data.item()
+ epoch_train_loss_m2 += loss_m2.data.item()
+ epoch_train_loss_m12 += loss_m12.data.item()
+ epoch_train_loss_m21 += loss_m21.data.item()
+ epoch_train_loss_mc += loss_mc.data.item()
+
+ m1_train_batch_pred_list.append(m1_pred)
+ m2_train_batch_pred_list.append(m2_pred)
+ m12_train_batch_pred_list.append(m12_pred)
+ m21_train_batch_pred_list.append(m21_pred)
+ mc_train_batch_pred_list.append(mc_pred)
+ m1_train_batch_label_list.append(m1_label)
+ m2_train_batch_label_list.append(m2_label)
+ m12_train_batch_label_list.append(m12_label)
+ m21_train_batch_label_list.append(m21_label)
+ mc_train_batch_label_list.append(mc_label)
+
+ optimizer.zero_grad()
+ if args.clip_grad_norm is not None:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+ loss.backward()
+ optimizer.step()
+
+ if args.logger: wandb.log({
+ 'batch_train_loss': loss.data.item(),
+ 'lr': optimizer.param_groups[-1]['lr']
+ })
+
+ print("Epoch {}/{} batch {}/{} training done with loss={}".format(
+ i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item())
+ )
+
+ if scheduler: scheduler.step()
+
+ train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score = compute_f1_scores(
+ torch.cat(m1_train_batch_pred_list),
+ torch.cat(m1_train_batch_label_list),
+ torch.cat(m2_train_batch_pred_list),
+ torch.cat(m2_train_batch_label_list),
+ torch.cat(m12_train_batch_pred_list),
+ torch.cat(m12_train_batch_label_list),
+ torch.cat(m21_train_batch_pred_list),
+ torch.cat(m21_train_batch_label_list),
+ torch.cat(mc_train_batch_pred_list),
+ torch.cat(mc_train_batch_label_list)
+ )
+
+ print("Epoch {}/{} OVERALL train m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}.\n".format(
+ i+1,
+ args.num_epoch,
+ epoch_train_loss_m1/len(train_dataloader),
+ epoch_train_loss_m2/len(train_dataloader),
+ epoch_train_loss_m12/len(train_dataloader),
+ epoch_train_loss_m21/len(train_dataloader),
+ epoch_train_loss_mc/len(train_dataloader),
+ train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score
+ )
+ )
+ stats['train']['loss_m1'].append(epoch_train_loss_m1/len(train_dataloader))
+ stats['train']['loss_m2'].append(epoch_train_loss_m2/len(train_dataloader))
+ stats['train']['loss_m12'].append(epoch_train_loss_m12/len(train_dataloader))
+ stats['train']['loss_m21'].append(epoch_train_loss_m21/len(train_dataloader))
+ stats['train']['loss_mc'].append(epoch_train_loss_mc/len(train_dataloader))
+ stats['train']['m1_f1'].append(train_m1_f1_score)
+ stats['train']['m2_f1'].append(train_m2_f1_score)
+ stats['train']['m12_f1'].append(train_m12_f1_score)
+ stats['train']['m21_f1'].append(train_m21_f1_score)
+ stats['train']['mc_f1'].append(train_mc_f1_score)
+
+ if args.logger: wandb.log(
+ {
+ 'train_m1_loss': epoch_train_loss_m1/len(train_dataloader),
+ 'train_m2_loss': epoch_train_loss_m2/len(train_dataloader),
+ 'train_m12_loss': epoch_train_loss_m12/len(train_dataloader),
+ 'train_m21_loss': epoch_train_loss_m21/len(train_dataloader),
+ 'train_mc_loss': epoch_train_loss_mc/len(train_dataloader),
+ 'train_loss': epoch_train_loss_m1/len(train_dataloader) + \
+ epoch_train_loss_m2/len(train_dataloader) + \
+ epoch_train_loss_m12/len(train_dataloader) + \
+ epoch_train_loss_m21/len(train_dataloader) + \
+ epoch_train_loss_mc/len(train_dataloader),
+ 'train_m1_f1_score': train_m1_f1_score,
+ 'train_m2_f1_score': train_m2_f1_score,
+ 'train_m12_f1_score': train_m12_f1_score,
+ 'train_m21_f1_score': train_m21_f1_score,
+ 'train_mc_f1_score': train_mc_f1_score
+ }
+ )
+
+ # validation
+ print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
+ epoch_val_loss_m1 = 0.0
+ epoch_val_loss_m2 = 0.0
+ epoch_val_loss_m12 = 0.0
+ epoch_val_loss_m21 = 0.0
+ epoch_val_loss_mc = 0.0
+ m1_val_batch_pred_list = []
+ m2_val_batch_pred_list = []
+ m12_val_batch_pred_list = []
+ m21_val_batch_pred_list = []
+ mc_val_batch_pred_list = []
+ m1_val_batch_label_list = []
+ m2_val_batch_label_list = []
+ m12_val_batch_label_list = []
+ m21_val_batch_label_list = []
+ mc_val_batch_label_list = []
+ model.eval()
+ with torch.no_grad():
+ for j, batch in tqdm(enumerate(val_dataloader)):
+ img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
+ if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
+ if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
+ if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
+ if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
+ if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
+ if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
+ if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
+ m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
+ m1_pred = m1_pred.reshape(-1, 4)
+ m2_pred = m2_pred.reshape(-1, 4)
+ m12_pred = m12_pred.reshape(-1, 4)
+ m21_pred = m21_pred.reshape(-1, 4)
+ mc_pred = mc_pred.reshape(-1, 4)
+ m1_label = labels[:, 0].reshape(-1).to(device)
+ m2_label = labels[:, 1].reshape(-1).to(device)
+ m12_label = labels[:, 2].reshape(-1).to(device)
+ m21_label = labels[:, 3].reshape(-1).to(device)
+ mc_label = labels[:, 4].reshape(-1).to(device)
+
+ loss_m1 = ce_loss_m1(m1_pred, m1_label)
+ loss_m2 = ce_loss_m2(m2_pred, m2_label)
+ loss_m12 = ce_loss_m12(m12_pred, m12_label)
+ loss_m21 = ce_loss_m21(m21_pred, m21_label)
+ loss_mc = ce_loss_mc(mc_pred, mc_label)
+ loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc
+
+ epoch_val_loss_m1 += loss_m1.data.item()
+ epoch_val_loss_m2 += loss_m2.data.item()
+ epoch_val_loss_m12 += loss_m12.data.item()
+ epoch_val_loss_m21 += loss_m21.data.item()
+ epoch_val_loss_mc += loss_mc.data.item()
+
+ m1_val_batch_pred_list.append(m1_pred)
+ m2_val_batch_pred_list.append(m2_pred)
+ m12_val_batch_pred_list.append(m12_pred)
+ m21_val_batch_pred_list.append(m21_pred)
+ mc_val_batch_pred_list.append(mc_pred)
+ m1_val_batch_label_list.append(m1_label)
+ m2_val_batch_label_list.append(m2_label)
+ m12_val_batch_label_list.append(m12_label)
+ m21_val_batch_label_list.append(m21_label)
+ mc_val_batch_label_list.append(mc_label)
+
+ if args.logger: wandb.log({'batch_val_loss': loss.data.item()})
+ print("Epoch {}/{} batch {}/{} validation done with loss={}".format(
+ i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item())
+ )
+
+ val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score = compute_f1_scores(
+ torch.cat(m1_val_batch_pred_list),
+ torch.cat(m1_val_batch_label_list),
+ torch.cat(m2_val_batch_pred_list),
+ torch.cat(m2_val_batch_label_list),
+ torch.cat(m12_val_batch_pred_list),
+ torch.cat(m12_val_batch_label_list),
+ torch.cat(m21_val_batch_pred_list),
+ torch.cat(m21_val_batch_label_list),
+ torch.cat(mc_val_batch_pred_list),
+ torch.cat(mc_val_batch_label_list)
+ )
+
+ print("Epoch {}/{} OVERALL validation m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format(
+ i+1,
+ args.num_epoch,
+ epoch_val_loss_m1/len(val_dataloader),
+ epoch_val_loss_m2/len(val_dataloader),
+ epoch_val_loss_m12/len(val_dataloader),
+ epoch_val_loss_m21/len(val_dataloader),
+ epoch_val_loss_mc/len(val_dataloader),
+ val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score
+ )
+ )
+
+ stats['val']['loss_m1'].append(epoch_val_loss_m1/len(val_dataloader))
+ stats['val']['loss_m2'].append(epoch_val_loss_m2/len(val_dataloader))
+ stats['val']['loss_m12'].append(epoch_val_loss_m12/len(val_dataloader))
+ stats['val']['loss_m21'].append(epoch_val_loss_m21/len(val_dataloader))
+ stats['val']['loss_mc'].append(epoch_val_loss_mc/len(val_dataloader))
+ stats['val']['m1_f1'].append(val_m1_f1_score)
+ stats['val']['m2_f1'].append(val_m2_f1_score)
+ stats['val']['m12_f1'].append(val_m12_f1_score)
+ stats['val']['m21_f1'].append(val_m21_f1_score)
+ stats['val']['mc_f1'].append(val_mc_f1_score)
+
+ if args.logger: wandb.log(
+ {
+ 'val_m1_loss': epoch_val_loss_m1/len(val_dataloader),
+ 'val_m2_loss': epoch_val_loss_m2/len(val_dataloader),
+ 'val_m12_loss': epoch_val_loss_m12/len(val_dataloader),
+ 'val_m21_loss': epoch_val_loss_m21/len(val_dataloader),
+ 'val_mc_loss': epoch_val_loss_mc/len(val_dataloader),
+ 'val_loss': epoch_val_loss_m1/len(val_dataloader) + \
+ epoch_val_loss_m2/len(val_dataloader) + \
+ epoch_val_loss_m12/len(val_dataloader) + \
+ epoch_val_loss_m21/len(val_dataloader) + \
+ epoch_val_loss_mc/len(val_dataloader),
+ 'val_m1_f1_score': val_m1_f1_score,
+ 'val_m2_f1_score': val_m2_f1_score,
+ 'val_m12_f1_score': val_m12_f1_score,
+ 'val_m21_f1_score': val_m21_f1_score,
+ 'val_mc_f1_score': val_mc_f1_score
+ }
+ )
+
+ # check for best stat/model using validation accuracy
+ if stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1] >= max_val_f1:
+ max_val_f1 = stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1]
+ max_val_classification_epoch = i+1
+ torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
+ counter = 0
+ else:
+ counter += 1
+ print(f'EarlyStopping counter: {counter} out of {args.patience}.')
+ if counter >= args.patience:
+ break
+
+ with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
+ f.write('{}\n'.format(CFG))
+ f.write('{}\n'.format(train_data))
+ f.write('{}\n'.format(val_data))
+ f.write('{}\n'.format(stats))
+ f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_f1))
+ f.close()
+
+ print(f'Results saved in {experiment_save_path}')
+
+
+
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+
+ # Define the command-line arguments
+ parser.add_argument('--gpu_id', type=int)
+ parser.add_argument('--seed', type=int, default=1)
+ parser.add_argument('--logger', action='store_true')
+ parser.add_argument('--presaved', type=int, default=128)
+ parser.add_argument('--clip_grad_norm', type=float, default=0.5)
+ parser.add_argument('--use_mixup', action='store_true')
+ parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
+ parser.add_argument('--non_blocking', action='store_true')
+ parser.add_argument('--patience', type=int, default=99)
+ parser.add_argument('--batch_size', type=int, default=4)
+ parser.add_argument('--num_workers', type=int, default=8)
+ parser.add_argument('--pin_memory', action='store_true')
+ parser.add_argument('--num_epoch', type=int, default=300)
+ parser.add_argument('--lr', type=float, default=4e-4)
+ parser.add_argument('--scheduler', type=str, default=None)
+ parser.add_argument('--dropout', type=float, default=0.1)
+ parser.add_argument('--weight_decay', type=float, default=0.005)
+ parser.add_argument('--label_smoothing', type=float, default=0.1)
+ parser.add_argument('--model_type', type=str)
+ parser.add_argument('--aggr', type=str, default='concat', required=False)
+ parser.add_argument('--use_resnet', action='store_true')
+ parser.add_argument('--hidden_dim', type=int, default=64)
+ parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
+ parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
+ parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
+ parser.add_argument('--save_path', type=str, default='experiments/')
+ parser.add_argument('--resume_from_checkpoint', type=str, default=None)
+
+ # Parse the command-line arguments
+ args = parser.parse_args()
+
+ if args.use_mixup and not args.mixup_alpha:
+ parser.error("--use_mixup requires --mixup_alpha")
+ if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
+ if not args.aggr:
+ parser.error("The choosen --model_type requires --aggr")
+ if args.model_type == 'tom_sl' and not args.tom_weight:
+ parser.error("The choosen --model_type requires --tom_weight")
+
+ # get experiment ID
+ experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
+ if not os.path.exists(args.save_path):
+ os.makedirs(args.save_path, exist_ok=True)
+ experiment_save_path = os.path.join(args.save_path, experiment_id)
+ os.makedirs(experiment_save_path, exist_ok=True)
+
+ CFG = {
+ 'use_ocr_custom_loss': 0,
+ 'presaved': args.presaved,
+ 'batch_size': args.batch_size,
+ 'num_epoch': args.num_epoch,
+ 'lr': args.lr,
+ 'scheduler': args.scheduler,
+ 'weight_decay': args.weight_decay,
+ 'model_type': args.model_type,
+ 'use_resnet': args.use_resnet,
+ 'hidden_dim': args.hidden_dim,
+ 'tom_weight': args.tom_weight,
+ 'dropout': args.dropout,
+ 'label_smoothing': args.label_smoothing,
+ 'clip_grad_norm': args.clip_grad_norm,
+ 'use_mixup': args.use_mixup,
+ 'mixup_alpha': args.mixup_alpha,
+ 'non_blocking_tensors': args.non_blocking,
+ 'patience': args.patience,
+ 'pin_memory': args.pin_memory,
+ 'resume_from_checkpoint': args.resume_from_checkpoint,
+ 'aggr': args.aggr,
+ 'mods': args.mods,
+ 'save_path': experiment_save_path ,
+ 'seed': args.seed
+ }
+
+ print(CFG)
+ print(f'Saving results in {experiment_save_path}')
+
+ # set seed values
+ if args.logger:
+ wandb.init(project="tbd", config=CFG)
+ os.environ['PYTHONHASHSEED'] = str(args.seed)
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ random.seed(args.seed)
+
+ main(args)
\ No newline at end of file
diff --git a/tbd/utils/fb_scores_err.py b/tbd/utils/fb_scores_err.py
new file mode 100644
index 0000000..7e9f3b0
--- /dev/null
+++ b/tbd/utils/fb_scores_err.py
@@ -0,0 +1,224 @@
+import os
+import csv
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.metrics import f1_score
+import numpy as np
+from tqdm import tqdm
+
+ALPHA = 0.7
+BAR_WIDTH = 0.27
+sns.set_theme(style='whitegrid')
+#sns.set_palette('mako')
+
+MTOM_COLORS = {
+ "MN1": (110/255, 117/255, 161/255),
+ "MN2": (179/255, 106/255, 98/255),
+ "Base": (193/255, 198/255, 208/255),
+ "CG": (170/255, 129/255, 42/255),
+ "IC": (97/255, 112/255, 83/255),
+ "DB": (144/255, 63/255, 110/255)
+}
+
+model_to_subdir = {
+ "IC$\parallel$": ["2023-07-16_10-34-32_train", "2023-07-18_13-49-57_train", "2023-07-19_12-17-46_train"],
+ "IC$\oplus$": ["2023-07-16_10-35-02_train", "2023-07-18_13-50-32_train", "2023-07-19_12-18-18_train"],
+ "IC$\otimes$": ["2023-07-16_10-35-41_train", "2023-07-18_13-52-26_train", "2023-07-19_12-18-49_train"],
+ "IC$\odot$": ["2023-07-16_10-36-04_train", "2023-07-18_13-53-03_train", "2023-07-19_12-19-50_train"],
+ "CG$\parallel$": ["2023-07-15_14-12-36_train", "2023-07-17_11-54-28_train", "2023-07-19_00-30-05_train"],
+ "CG$\oplus$": ["2023-07-15_14-14-08_train", "2023-07-17_11-56-05_train", "2023-07-19_00-30-47_train"],
+ "CG$\otimes$": ["2023-07-15_14-14-53_train", "2023-07-17_11-56-39_train", "2023-07-19_00-31-36_train"],
+ "CG$\odot$": ["2023-07-15_14-10-05_train", "2023-07-17_11-57-30_train", "2023-07-19_00-32-10_train"],
+ "DB": ["2023-08-08_12-56-02_train", "2023-08-08_19-07-43_train", "2023-08-08_19-08-47_train"],
+ "Base": ["2023-08-08_12-53-38_train", "2023-08-08_19-10-02_train", "2023-08-08_19-10-51_train"]
+}
+
+def read_data_from_csv(subdirectory_path):
+ print(subdirectory_path)
+ data = []
+ csv_files = [file for file in os.listdir(subdirectory_path) if file.endswith('.csv')]
+ for csv_file in csv_files:
+ file_path = os.path.join(subdirectory_path, csv_file)
+ with open(file_path, 'r') as file:
+ reader = csv.reader(file)
+ header_skipped = False
+ for row in reader:
+ if not header_skipped:
+ header_skipped = True
+ continue
+ frame, m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, m1_label, m2_label, m12_label, m21_label, mc_label, false_belief = row
+ data.append({
+ 'frame': int(frame),
+ 'm1_pred': int(m1_pred),
+ 'm2_pred': int(m2_pred),
+ 'm12_pred': int(m12_pred),
+ 'm21_pred': int(m21_pred),
+ 'mc_pred': int(mc_pred),
+ 'm1_label': int(m1_label),
+ 'm2_label': int(m2_label),
+ 'm12_label': int(m12_label),
+ 'm21_label': int(m21_label),
+ 'mc_label': int(mc_label),
+ 'false_belief': false_belief,
+ })
+ return data
+
+def compute_correct_false_belief(data, mind="all", folder=None):
+ total_false_belief = 0
+ correct_false_belief = 0
+ for item in data:
+ if 'false' in item['false_belief']:
+ false_belief_type = item['false_belief'].split('_')[0]
+ if mind == "all" or false_belief_type in mind:
+ total_false_belief += 1
+ if item[f"{false_belief_type}_pred"] == item[f"{false_belief_type}_label"]:
+ if folder is not None:
+ with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
+ f.write(f"{str(1)}\n")
+ correct_false_belief += 1
+ else:
+ if folder is not None:
+ with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
+ f.write(f"{str(0)}\n")
+ if total_false_belief == 0:
+ accuracy = 0.0
+ else:
+ accuracy = correct_false_belief / total_false_belief
+ return accuracy
+
+def compute_macro_f1_score(data, mind="all"):
+ y_true = []
+ y_pred = []
+
+ for item in data:
+ if 'false' in item['false_belief']:
+ false_belief_type = item['false_belief'].split('_')[0]
+ if mind == "all" or false_belief_type in mind:
+ y_true.append(int(item[f"{false_belief_type}_label"]))
+ y_pred.append(int(item[f"{false_belief_type}_pred"]))
+
+ if not y_true or not y_pred:
+ macro_f1 = 0.0
+ else:
+ macro_f1 = f1_score(y_true, y_pred, average='macro')
+
+ return macro_f1
+
+def delete_files_in_subfolders(folder_path, file_names_to_delete):
+ """
+ Delete specified files in all subfolders of a given folder.
+
+ Parameters:
+ folder_path: The path to the folder containing subfolders.
+ file_names_to_delete: A list of file names to be deleted.
+
+ Returns:
+ None
+ """
+ for root, _, _ in os.walk(folder_path):
+ for file_name in file_names_to_delete:
+ file_path = os.path.join(root, file_name)
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ print(f"Deleted: {file_path}")
+
+
+
+
+if __name__ == "__main__":
+
+ folder_path = "predictions"
+ files_to_delete = ["fb_m1_m2_m12_m21.txt", "fb_m1_m2.txt", "fb_m12_m21.txt"]
+ delete_files_in_subfolders(folder_path, files_to_delete)
+
+ metric = "Accuracy"
+ if metric == "Macro F1":
+ score_function = compute_macro_f1_score
+ elif metric == "Accuracy":
+ score_function = compute_correct_false_belief
+ else:
+ raise ValueError
+
+ models = [
+ 'Base', 'DB',
+ 'CG$\parallel$', 'CG$\oplus$', 'CG$\otimes$', 'CG$\odot$',
+ 'IC$\parallel$', 'IC$\oplus$', 'IC$\otimes$', 'IC$\odot$'
+ ]
+
+ parent_dir = 'predictions'
+ minds = categories = ['m1', 'm2', 'm12', 'm21']
+ score_m1_m2 = []
+ score_m12_m21 = []
+ score_all = []
+ std_m1_m2 = []
+ std_m12_m21 = []
+ std_all = []
+
+ for model in models:
+ model_scores_m1_m2 = []
+ model_scores_m12_m21 = []
+ model_scores_all = []
+ for s in range(3):
+ subdir_path = os.path.join(parent_dir, model_to_subdir[model][s])
+ data = read_data_from_csv(subdir_path)
+ model_scores_m1_m2.append(score_function(data, ['m1', 'm2'], model_to_subdir[model][s]))
+ model_scores_m12_m21.append(score_function(data, ['m12', 'm21'], model_to_subdir[model][s]))
+ model_scores_all.append(score_function(data, ['m1', 'm2', 'm12', 'm21'], model_to_subdir[model][s]))
+ score_m1_m2.append(np.mean(model_scores_m1_m2))
+ std_m1_m2.append(np.std(model_scores_m1_m2))
+ score_m12_m21.append(np.mean(model_scores_m12_m21))
+ std_m12_m21.append(np.std(model_scores_m12_m21))
+ score_all.append(np.mean(model_scores_all))
+ std_all.append(np.std(model_scores_all))
+
+ # Create a dataframe to use with sns.catplot
+ data = {
+ 'Model': [m for m in models],
+ 'FO_FB_mean': score_m1_m2,
+ 'FO_FB_std': std_m1_m2,
+ 'SO_FB_mean': score_m12_m21,
+ 'SO_FB_std': std_m12_m21,
+ 'Both_mean': score_all,
+ 'Both_std': std_all
+ }
+
+ models = data['Model']
+ fo_fb_mean = data['FO_FB_mean']
+ fo_fb_std = data['FO_FB_std']
+ so_fb_mean = data['SO_FB_mean']
+ so_fb_std = data['SO_FB_std']
+ both_mean = data['Both_mean']
+ both_std = data['Both_std']
+
+ bar_width = BAR_WIDTH
+ x = np.arange(len(models))
+
+ plt.figure(figsize=(13, 3.5))
+ fo_fb_bars = plt.bar(x - bar_width, fo_fb_mean, width=bar_width, yerr=fo_fb_std, capsize=4, label='First-order false belief', alpha=ALPHA)
+ so_fb_bars = plt.bar(x, so_fb_mean, width=bar_width, yerr=so_fb_std, capsize=4, label='Second-order false belief', alpha=ALPHA)
+ both_bars = plt.bar(x + bar_width, both_mean, width=bar_width, yerr=both_std, capsize=4, label='Both', alpha=ALPHA)
+
+ def add_labels(bars, std_values):
+ cnt = 0
+ for bar, std in zip(bars, std_values):
+ height = bar.get_height()
+ offset = std + 0.01
+ if cnt == 0 or cnt == 1 or cnt == 9:
+ plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}*', ha='center', va='bottom', fontsize=10)
+ else:
+ plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}', ha='center', va='bottom', fontsize=10)
+ cnt = cnt + 1
+
+ add_labels(fo_fb_bars, fo_fb_std)
+ add_labels(so_fb_bars, so_fb_std)
+ add_labels(both_bars, both_std)
+
+ plt.gca().spines['top'].set_visible(False)
+ plt.gca().spines['right'].set_visible(False)
+ plt.xlabel('MToMnet', fontsize=14)
+ plt.ylabel('Macro F1 Score' if metric == "Macro F1" else 'Accuracy', fontsize=14)
+ plt.xticks(x, models, rotation=0, fontsize=14)
+ plt.yticks(fontsize=14)
+ plt.legend(fontsize=14, loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=3)
+ plt.tight_layout()
+ plt.savefig('results/false_belief_first_vs_second.pdf')
diff --git a/tbd/utils/helpers.py b/tbd/utils/helpers.py
new file mode 100644
index 0000000..97c4a01
--- /dev/null
+++ b/tbd/utils/helpers.py
@@ -0,0 +1,210 @@
+import numpy as np
+import random
+import torch
+from sklearn.metrics import f1_score
+
+
+tracker_skeID = {
+ 'test1': 'skele1', 'test2': 'skele2', 'test6': 'skele2', 'test7': 'skele1','test_9434_1': 'skele2',
+ 'test_9434_3': 'skele2', 'test_9434_18': 'skele1', 'test_94342_0': 'skele2', 'test_94342_1': 'skele2',
+ 'test_94342_2': 'skele2', 'test_94342_3': 'skele2', 'test_94342_4': 'skele1', 'test_94342_5': 'skele1',
+ 'test_94342_6': 'skele1', 'test_94342_7': 'skele1', 'test_94342_8': 'skele1',
+ 'test_94342_10': 'skele2', 'test_94342_11': 'skele2', 'test_94342_12': 'skele1',
+ 'test_94342_13': 'skele2', 'test_94342_14': 'skele1', 'test_94342_15': 'skele2',
+ 'test_94342_16': 'skele1', 'test_94342_17': 'skele2', 'test_94342_18': 'skele1',
+ 'test_94342_19': 'skele2', 'test_94342_20': 'skele1', 'test_94342_21': 'skele2',
+ 'test_94342_22': 'skele1', 'test_94342_23': 'skele1', 'test_94342_24': 'skele1',
+ 'test_94342_25': 'skele2', 'test_94342_26': 'skele1', 'test_boelter_1': 'skele2',
+ 'test_boelter_2': 'skele2', 'test_boelter_3': 'skele2',
+ 'test_boelter_4': 'skele1', 'test_boelter_5': 'skele1', 'test_boelter_6': 'skele1',
+ 'test_boelter_7': 'skele1', 'test_boelter_9': 'skele1', 'test_boelter_10': 'skele1',
+ 'test_boelter_12': 'skele2', 'test_boelter_13': 'skele1', 'test_boelter_14': 'skele1',
+ 'test_boelter_15': 'skele1', 'test_boelter_17': 'skele2', 'test_boelter_18': 'skele1',
+ 'test_boelter_19': 'skele2', 'test_boelter_21': 'skele1', 'test_boelter_22': 'skele2',
+ 'test_boelter_24': 'skele1', 'test_boelter_25': 'skele1',
+ 'test_boelter2_0': 'skele1', 'test_boelter2_2': 'skele1', 'test_boelter2_3': 'skele1',
+ 'test_boelter2_4': 'skele1', 'test_boelter2_5': 'skele1', 'test_boelter2_6': 'skele1',
+ 'test_boelter2_7': 'skele2', 'test_boelter2_8': 'skele2', 'test_boelter2_12': 'skele2',
+ 'test_boelter2_14': 'skele2', 'test_boelter2_15': 'skele2', 'test_boelter2_16': 'skele1',
+ 'test_boelter2_17': 'skele1',
+ 'test_boelter3_0': 'skele1', 'test_boelter3_1': 'skele2', 'test_boelter3_2': 'skele2',
+ 'test_boelter3_3': 'skele2', 'test_boelter3_4': 'skele1', 'test_boelter3_5': 'skele2',
+ 'test_boelter3_6': 'skele2', 'test_boelter3_7': 'skele1', 'test_boelter3_8': 'skele2',
+ 'test_boelter3_9': 'skele2', 'test_boelter3_10': 'skele1', 'test_boelter3_11': 'skele2',
+ 'test_boelter3_12': 'skele2', 'test_boelter3_13': 'skele2',
+ 'test_boelter4_0': 'skele2', 'test_boelter4_1': 'skele2', 'test_boelter4_2': 'skele2',
+ 'test_boelter4_3': 'skele2', 'test_boelter4_4': 'skele2', 'test_boelter4_5': 'skele2',
+ 'test_boelter4_6': 'skele2', 'test_boelter4_7': 'skele2', 'test_boelter4_8': 'skele2',
+ 'test_boelter4_9': 'skele2', 'test_boelter4_10': 'skele2', 'test_boelter4_11': 'skele2',
+ 'test_boelter4_12': 'skele2', 'test_boelter4_13': 'skele2',
+}
+
+event_seg_tracker = {
+ 'test_9434_18': [[0, 749, 0], [750, 824, 0], [825, 863, 2], [864, 974, 0], [975, 1041, 0]],
+ 'test_94342_1': [[0, 13, 0], [14, 104, 0], [105, 333, 0], [334, 451, 0], [452, 652, 0], [653, 897, 0], [898, 1076, 0], [1077, 1181, 0], [1181, 1266, 0],[1267, 1386, 0]],
+ 'test_94342_6': [[0, 95, 0], [96, 267, 1], [268, 441, 1], [442, 559, 1], [560, 681, 1], [682, 796, 1], [797, 835, 1], [836, 901, 0], [902, 943, 1]],
+ 'test_94342_10': [[0, 36, 0], [37, 169, 0], [170, 244, 1], [245, 424, 0], [425, 599, 0], [600, 640, 0], [641, 680, 0], [681, 726, 1], [727, 866, 2], [867, 1155, 2]],
+ 'test_94342_21': [[0, 13, 0], [14, 66, 2], [67, 594, 2], [595, 1097, 2], [1098, 1133, 0]],
+ 'test1': [[0, 477, 0], [478, 559, 0], [560, 689, 2], [690, 698, 0]],
+ 'test6': [[0, 140, 0], [141, 375, 0], [376, 678, 0], [679, 703, 0]],
+ 'test7': [[0, 100, 0], [101, 220, 2], [221, 226, 0]],
+ 'test_boelter_2': [[0, 154, 0], [155, 279, 0], [280, 371, 0], [372, 450, 0], [451, 470, 0], [471, 531, 0],[532, 606, 0]],
+ 'test_boelter_7': [[0, 69, 0], [70, 118, 1], [119, 239, 0], [240, 328, 1], [329, 376, 0], [377, 397, 1], [398, 520, 0], [521, 564, 0], [565, 619, 1], [620, 688, 1], [689, 871, 0], [872, 897, 0], [898, 958, 1], [959, 1010, 0], [1011, 1084, 0], [1085, 1140, 0], [1141, 1178, 0], [1179, 1267, 1], [1268, 1317, 0], [1318, 1327, 0]],
+ 'test_boelter_24': [[0, 62, 0], [63, 185, 2], [186, 233, 2], [234, 292, 2], [293, 314, 0]],
+ 'test_boelter_12': [[0, 47, 1], [48, 119, 0], [120, 157, 1], [158, 231, 0], [232, 317, 0], [318, 423, 0], [424,459,0], [460, 522, 0], [523, 586, 0], [587, 636, 0], [637, 745, 1], [746, 971, 2]],
+ 'test_9434_1': [[0, 57, 0], [58, 124, 0], [125, 182, 1], [183, 251, 2],[252, 417, 0]],
+ 'test_94342_16': [[0, 21, 0], [22, 45, 0], [46, 84, 0], [85, 158, 1], [159, 200, 1], [201, 214, 0],[215, 370, 1], [371, 524, 1], [525, 587, 2], [588, 782, 2],[783, 1009, 2]],
+ 'test_boelter4_12': [[0, 141, 0], [142, 462, 2], [463, 605, 0], [606, 942, 2], [943, 1232, 2], [1233, 1293, 0]],
+ 'test_boelter4_9': [[0, 27, 0], [28, 172, 0], [173, 221, 0], [222, 307, 1], [308, 466, 0], [467, 794, 1], [795, 866, 1], [867, 1005, 2], [1006, 1214, 2], [1215, 1270, 0]],
+ 'test_boelter4_4': [[0, 120, 0], [121, 183, 0], [184, 280, 1], [281, 714, 0]],
+ 'test_boelter4_3': [[0, 117, 0], [118, 200, 1], [201, 293, 1], [294, 404, 1], [405, 600, 1], [601, 800, 1], [801, 905, 1],[906, 1234, 1]],
+ 'test_boelter4_1': [[0, 310, 0], [311, 560, 0], [561, 680, 0], [681, 748, 0], [749, 839, 0], [840, 1129, 0], [1130, 1237, 0]],
+ 'test_boelter3_13': [[0, 204, 2], [205, 300, 2], [301, 488, 2], [489, 755, 2]],
+ 'test_boelter3_11': [[0, 254, 1], [255, 424, 0], [425, 598, 1], [599, 692, 0], [693, 772, 2], [773, 878, 2], [879, 960, 2], [961, 1171, 2],[1172, 1397, 2]],
+ 'test_boelter3_6': [[0, 174, 1], [175, 280, 1], [281, 639, 0], [640, 695, 1], [696, 788, 0], [789, 887, 2], [888, 1035, 1], [1036, 1445, 2]],
+ 'test_boelter3_4': [[0, 158, 1], [159, 309, 1], [310, 477, 1], [478, 668, 1], [669, 780, 1], [781, 817, 0], [818, 848, 1], [849, 942, 1]],
+ 'test_boelter3_0': [[0, 140, 0], [141, 353, 0], [354, 599, 0], [600, 727, 0],[728, 768, 0]],
+ 'test_boelter2_15': [[0, 46, 0], [47, 252, 2], [253, 298, 1], [299, 414, 2], [415, 547, 2], [548, 690, 1], [691, 728, 1], [729, 773, 2],[774, 935, 2]],
+ 'test_boelter2_12': [[0, 163, 0], [164, 285, 1], [286, 444, 1], [445, 519, 0], [520, 583, 1], [584, 623, 0], [624, 660, 0], [661, 854, 1], [855, 921, 1], [922, 1006, 2], [1007, 1125, 2],[1126, 1332, 2], [1333, 1416, 2]],
+ 'test_boelter2_5': [[0, 94, 0], [95, 176, 1], [177, 246, 1], [247, 340, 1], [341, 442, 1], [443, 547, 1], [548, 654, 1], [655, 734, 0], [735, 792, 0], [793, 1019, 0], [1020, 1088, 0], [1089, 1206, 0], [1207, 1316, 1], [1317, 1466, 1], [1467, 1787, 2], [1788, 1936, 1], [1937, 2084, 2]],
+ 'test_boelter2_4': [[0, 260, 1], [261, 421, 1], [422, 635, 1], [636, 741, 1], [742, 846, 1], [847, 903, 1], [904, 953, 1], [954, 1005, 1], [1006, 1148, 1], [1149, 1270, 1], [1271, 1525, 1]],
+ 'test_boelter2_2': [[0, 131, 0], [132, 226, 0], [227, 267, 0], [268, 352, 0], [353, 412, 0], [413, 457, 0], [458, 502, 0], [503, 532, 0], [533, 578, 0], [579, 640, 0], [641, 722, 0], [723, 826, 0], [827, 913, 0], [914, 992, 0], [993, 1070, 0], [1071, 1265, 0], [1266, 1412, 0]],
+ 'test_boelter_21': [[0, 238, 1], [239, 310, 0], [311, 373, 1], [374, 457, 0],[458, 546, 2], [547, 575, 1], [576, 748, 2], [749, 952, 2]],
+}
+
+event_seg_battery={
+ 'test1': [[0, 94, 0], [95, 155, 0], [156, 225, 0], [226, 559, 0], [560, 689, 2], [690, 698, 0]],
+ 'test7': [[0, 70, 0], [71, 100, 0], [101, 220, 2], [221, 226, 0]],
+ 'test6': [[0, 488, 0], [489, 541, 0], [542, 672, 0], [673, 703, 0]],
+ 'test_94342_10': [[0, 156, 0], [157, 169, 0], [170, 244, 1], [245, 274, 0], [275, 389, 0], [390, 525, 0], [526, 665, 0], [666, 680, 0], [681, 726, 1], [727, 866, 2], [867, 1155, 2]],
+ 'test_94342_1': [[0, 751, 0], [752, 876, 0], [877, 1167, 0], [1168, 1386, 0]],
+ 'test_9434_18': [[0, 96, 0], [97, 361, 0], [362, 528, 0], [529, 608, 0], [609, 824, 0], [825, 863, 2], [864, 1041, 0]],
+ 'test_94342_6': [[0, 95, 0], [96, 267, 1], [268, 441, 1], [442, 559, 1], [560, 681, 1], [682, 796, 1], [797, 835, 1], [836, 901, 0], [902, 943, 1]],
+ 'test_boelter_24': [[0, 62, 0], [63, 185, 2], [186, 233, 2], [234, 292, 2], [293, 314, 0]],
+ 'test_boelter2_4': [[0, 260, 1], [261, 421, 1], [422, 635, 1], [636, 741, 1], [742, 846, 1], [847, 903, 1], [904, 953, 1], [954, 1005, 1], [1006, 1148, 1], [1149, 1270, 1], [1271, 1525, 1]],
+ 'test_boelter2_5': [[0, 94, 0], [95, 176, 1], [177, 246, 1], [247, 340, 1], [341, 442, 1], [443, 547, 1], [548, 654, 1], [655, 1206, 0], [1207, 1316, 1], [1317, 1466, 1], [1467, 1787, 2], [1788, 1936, 1], [1937, 2084, 2]],
+ 'test_boelter2_2': [[0, 145, 0], [146, 224, 0], [225, 271, 0], [272, 392, 0], [393, 454, 0], [455, 762, 0], [763, 982, 0], [983, 1412, 0]],
+ 'test_boelter_21': [[0, 238, 1], [239, 285, 0], [286, 310, 0], [311, 373, 1], [374, 457, 0], [458, 546, 2], [547, 575, 1], [576, 748, 2], [749, 952, 2]],
+ 'test_9434_1': [[0, 67, 0], [68, 124, 0], [125, 182, 1], [183, 251, 2], [252, 343, 0], [344, 380, 0], [381, 417, 0]],
+ 'test_boelter3_6': [[0, 174, 1], [175, 280, 1], [281, 498, 0], [499, 639, 0], [640, 695, 1], [696, 748, 0], [749, 788, 0], [789, 887, 2], [888, 1035, 1], [1036, 1445, 2]],
+ 'test_boelter3_4': [[0, 158, 1], [159, 309, 1], [310, 477, 1], [478, 668, 1], [669, 780, 1], [781, 817, 0], [818, 848, 1], [849, 942, 1]],
+ 'test_boelter3_0': [[0, 102, 0], [103, 480, 0], [481, 703, 0], [704, 768, 0]],
+ 'test_boelter2_12': [[0, 163, 0], [164, 285, 1], [286, 444, 1], [445, 519, 0], [520, 583, 1], [584, 660, 0], [661, 854, 1], [855, 921, 1], [922, 1006, 2], [1007, 1125, 2], [1126, 1332, 2], [1333, 1416, 2]],
+ 'test_94342_16': [[0, 84, 0], [85, 158, 1], [159, 200, 1], [201, 214, 0], [215, 370, 1], [371, 524, 1], [525, 587, 2], [588, 782, 2], [783, 1009, 2]],
+ 'test_boelter2_15': [[0, 46, 0], [47, 252, 2], [253, 298, 1], [299, 414, 2], [415, 547, 2], [548, 690, 1], [691, 728, 1], [729, 773, 2], [774, 935, 2]],
+ 'test_boelter3_13': [[0, 204, 2], [205, 300, 2], [301, 488, 2], [489, 755, 2]],
+ 'test_boelter3_11': [[0, 254, 1], [255, 424, 0], [425, 598, 1], [599, 692, 0], [693, 772, 2], [773, 878, 2], [879, 960, 2], [961, 1171, 2], [1172, 1397, 2]],
+ 'test_boelter4_12': [[0, 32, 0], [33, 141, 0], [142, 462, 2], [463, 519, 0], [520, 597, 0], [598, 605, 0], [606, 942, 2], [943, 1232, 2], [1233, 1293, 0]],
+ 'test_boelter4_9': [[0, 221, 0], [222, 307, 1], [308, 466, 0], [467, 794, 1], [795, 866, 1], [867, 1005, 2], [1006, 1214, 2], [1215, 1270, 0]],
+ 'test_boelter4_4': [[0, 183, 0], [184, 280, 1], [281, 529, 0], [530, 714, 0]],
+ 'test_boelter4_1': [[0, 252, 0], [253, 729, 0], [730, 1202, 0], [1203, 1237, 0]],
+ 'test_boelter4_3': [[0, 117, 0], [118, 200, 1], [201, 293, 1], [294, 404, 1], [405, 600, 1], [601, 800, 1], [801, 905, 1], [906, 1234, 1]],
+ 'test_boelter_12': [[0, 47, 1], [48, 119, 0], [120, 157, 1], [158, 636, 0], [637, 745, 1], [746, 971, 2]],
+ 'test_boelter_7': [[0, 69, 0], [70, 118, 1], [119, 133, 0], [134, 187, 0], [188, 239, 0], [240, 328, 1], [329, 376, 0], [377, 397, 1], [398, 491, 0], [492, 564, 0], [565, 619, 1], [620, 688, 1], [689, 774, 0], [775, 862, 0], [863, 897, 0], [898, 958, 1], [959, 1000, 0], [1001, 1178, 0], [1179, 1267, 1], [1268, 1307, 0], [1308, 1327, 0]],
+ 'test_94342_21': [[0, 13, 0], [14, 66, 2], [67, 594, 2], [595, 1097, 2], [1098, 1133, 0]],
+ 'test_boelter_2': [[0, 318, 0], [319, 458, 0], [459, 543, 0], [544, 606, 0]]
+}
+
+CLIPS_OBJ_BY_ID_88 = {'test1': ['P2', 'O1', 'O2', 'O3', 'P1', 'O4'], 'test2': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test6': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test7': ['P1', 'P2', 'O1'], 'test_94342_0': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13'], 'test_94342_1': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_94342_10': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_94342_11': ['P1', 'P2', 'O2', 'O1', 'O3', 'O4'], 'test_94342_12': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_94342_13': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_94342_14': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O14', 'O15'], 'test_94342_15': ['P1', 'P2', 'O1'], 'test_94342_16': ['P1', 'P2', 'O4', 'O1'], 'test_94342_17': ['P1', 'P2', 'O1', 'O2', 'O4', 'O5', 'O6'], 'test_94342_18': ['P1', 'P2', 'O1', 'O2'], 'test_94342_19': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O8', 'O9'], 'test_94342_2': ['P2', 'P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_94342_20': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_94342_21': ['P2', 'P1', 'O1', 'O2'], 'test_94342_22': ['P1', 'P2', 'O1', 'O2'], 'test_94342_23': ['P2', 'P1', 'O1', 'O2'], 'test_94342_24': ['P2', 'P1', 'O1', 'O2'], 'test_94342_25': ['P2', 'P1', 'O1'], 'test_94342_26': ['P2', 'P1', 'O1', 'O2'], 'test_94342_3': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13'], 'test_94342_4': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12'], 'test_94342_5': ['P2', 'P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_94342_6': ['P2', 'P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_94342_7': ['P2', 'P1', 'O1', 'O2', 'O3'], 'test_94342_8': ['P2', 'P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_9434_1': ['P1', 'P2', 'O1', 'O2'], 'test_9434_18': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_9434_3': ['P1', 'P2', 'O1', 'O2'], 'test_boelter2_0': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter2_12': ['P1', 'P2', 'O1', 'O2', 'O4', 'O5'], 'test_boelter2_14': ['P1', 'P2', 'O1', 'O2', 'O3'], 'test_boelter2_15': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter2_16': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9'], 'test_boelter2_17': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter2_3': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter2_6': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter2_7': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter2_8': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter3_0': ['P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'P2'], 'test_boelter3_1': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_boelter3_10': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter3_11': ['P1', 'P2', 'O1', 'O2', 'O3'], 'test_boelter3_12': ['P1', 'P2', 'O1', 'O2'], 'test_boelter3_13': ['P1', 'P2', 'O1', 'O2', 'O3'], 'test_boelter3_2': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12'], 'test_boelter3_3': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter3_4': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter3_5': ['P1', 'P2', 'O1', 'O2'], 'test_boelter3_6': ['P1', 'P2', 'O1', 'O2', 'O3'], 'test_boelter3_7': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter3_8': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter3_9': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter4_0': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter4_1': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_boelter4_10': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_11': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter4_12': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter4_13': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter4_2': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter4_3': ['P1', 'P2', 'O1', 'O5', 'O2', 'O3', 'O4', 'O6'], 'test_boelter4_4': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter4_5': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter4_6': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_7': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter4_8': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_9': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_1': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O7', 'O6', 'O8'], 'test_boelter_10': ['P1', 'O1', 'P2', 'O2'], 'test_boelter_12': ['P2', 'P1', 'O1', 'O2', 'O3'], 'test_boelter_13': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter_14': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_15': ['P1', 'P2', 'O1', 'O2', 'O3'], 'test_boelter_17': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_18': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter_19': ['P1', 'P2', 'O1', 'O2'], 'test_boelter_2': ['P1', 'P2', 'O1', 'O2', 'O3', 'O5', 'O6', 'O7', 'O8'], 'test_boelter_21': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_3': ['P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'P2', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_boelter_4': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter_5': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter_6': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter_7': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_boelter_9': ['P1', 'P2', 'O1', 'O2']}
+CLIPS_OBJ_BY_ID_88_NO_P = {'test1': ['O1', 'O2', 'O3', 'O4'], 'test2': ['O1', 'O2', 'O3', 'O4'], 'test6': ['O1', 'O2', 'O3', 'O4'], 'test7': ['O1'], 'test_94342_0': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13'], 'test_94342_1': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_94342_10': ['O1', 'O2', 'O3', 'O4'], 'test_94342_11': ['O2', 'O1', 'O3', 'O4'], 'test_94342_12': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_94342_13': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_94342_14': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O14', 'O15'], 'test_94342_15': ['O1'], 'test_94342_16': ['O4', 'O1'], 'test_94342_17': ['O1', 'O2', 'O4', 'O5', 'O6'], 'test_94342_18': ['O1', 'O2'], 'test_94342_19': ['O1', 'O2', 'O3', 'O4', 'O5', 'O8', 'O9'], 'test_94342_2': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_94342_20': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_94342_21': ['O1', 'O2'], 'test_94342_22': ['O1', 'O2'], 'test_94342_23': ['O1', 'O2'], 'test_94342_24': ['O1', 'O2'], 'test_94342_25': ['O1'], 'test_94342_26': ['O1', 'O2'], 'test_94342_3': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13'], 'test_94342_4': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12'], 'test_94342_5': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_94342_6': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_94342_7': ['O1', 'O2', 'O3'], 'test_94342_8': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_9434_1': ['O1', 'O2'], 'test_9434_18': ['O1', 'O2', 'O3', 'O4'], 'test_9434_3': ['O1', 'O2'], 'test_boelter2_0': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter2_12': ['O1', 'O2', 'O4', 'O5'], 'test_boelter2_14': ['O1', 'O2', 'O3'], 'test_boelter2_15': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter2_16': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9'], 'test_boelter2_17': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter2_3': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter2_6': ['O1', 'O2', 'O3', 'O4'], 'test_boelter2_7': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter2_8': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter3_0': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter3_1': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_boelter3_10': ['O1', 'O2', 'O3', 'O4'], 'test_boelter3_11': ['O1', 'O2', 'O3'], 'test_boelter3_12': ['O1', 'O2'], 'test_boelter3_13': ['O1', 'O2', 'O3'], 'test_boelter3_2': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12'], 'test_boelter3_3': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter3_4': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter3_5': ['O1', 'O2'], 'test_boelter3_6': ['O1', 'O2', 'O3'], 'test_boelter3_7': ['O1', 'O2', 'O3', 'O4'], 'test_boelter3_8': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter3_9': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter4_0': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter4_1': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_boelter4_10': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_11': ['O1', 'O2', 'O3', 'O4'], 'test_boelter4_12': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter4_13': ['O1', 'O2', 'O3', 'O4'], 'test_boelter4_2': ['O1', 'O2', 'O3', 'O4'], 'test_boelter4_3': ['O1', 'O5', 'O2', 'O3', 'O4', 'O6'], 'test_boelter4_4': ['O1', 'O2', 'O3', 'O4'], 'test_boelter4_5': ['O1', 'O2', 'O3', 'O4'], 'test_boelter4_6': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_7': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter4_8': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_9': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_1': ['O1', 'O2', 'O3', 'O4', 'O5', 'O7', 'O6', 'O8'], 'test_boelter_10': ['O1', 'O2'], 'test_boelter_12': ['O1', 'O2', 'O3'], 'test_boelter_13': ['O1', 'O2', 'O3', 'O4'], 'test_boelter_14': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_15': ['O1', 'O2', 'O3'], 'test_boelter_17': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_18': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter_19': ['O1', 'O2'], 'test_boelter_2': ['O1', 'O2', 'O3', 'O5', 'O6', 'O7', 'O8'], 'test_boelter_21': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_3': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_boelter_4': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter_5': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter_6': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter_7': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_boelter_9': ['O1', 'O2']}
+CLIPS_OBJ_NUM_BY_ID_88 = {'test1': 6, 'test2': 6, 'test6': 6, 'test7': 3, 'test_94342_0': 15, 'test_94342_1': 16, 'test_94342_10': 6, 'test_94342_11': 6, 'test_94342_12': 10, 'test_94342_13': 12, 'test_94342_14': 15, 'test_94342_15': 3, 'test_94342_16': 4, 'test_94342_17': 7, 'test_94342_18': 4, 'test_94342_19': 9, 'test_94342_2': 13, 'test_94342_20': 7, 'test_94342_21': 4, 'test_94342_22': 4, 'test_94342_23': 4, 'test_94342_24': 4, 'test_94342_25': 3, 'test_94342_26': 4, 'test_94342_3': 15, 'test_94342_4': 14, 'test_94342_5': 13, 'test_94342_6': 9, 'test_94342_7': 5, 'test_94342_8': 8, 'test_9434_1': 4, 'test_9434_18': 6, 'test_9434_3': 4, 'test_boelter2_0': 12, 'test_boelter2_12': 6, 'test_boelter2_14': 5, 'test_boelter2_15': 7, 'test_boelter2_16': 11, 'test_boelter2_17': 8, 'test_boelter2_3': 12, 'test_boelter2_6': 6, 'test_boelter2_7': 9, 'test_boelter2_8': 13, 'test_boelter3_0': 9, 'test_boelter3_1': 10, 'test_boelter3_10': 6, 'test_boelter3_11': 5, 'test_boelter3_12': 4, 'test_boelter3_13': 5, 'test_boelter3_2': 14, 'test_boelter3_3': 7, 'test_boelter3_4': 9, 'test_boelter3_5': 4, 'test_boelter3_6': 5, 'test_boelter3_7': 6, 'test_boelter3_8': 10, 'test_boelter3_9': 12, 'test_boelter4_0': 12, 'test_boelter4_1': 10, 'test_boelter4_10': 9, 'test_boelter4_11': 6, 'test_boelter4_12': 7, 'test_boelter4_13': 6, 'test_boelter4_2': 6, 'test_boelter4_3': 8, 'test_boelter4_4': 6, 'test_boelter4_5': 6, 'test_boelter4_6': 9, 'test_boelter4_7': 8, 'test_boelter4_8': 9, 'test_boelter4_9': 7, 'test_boelter_1': 10, 'test_boelter_10': 4, 'test_boelter_12': 5, 'test_boelter_13': 6, 'test_boelter_14': 7, 'test_boelter_15': 5, 'test_boelter_17': 7, 'test_boelter_18': 8, 'test_boelter_19': 4, 'test_boelter_2': 9, 'test_boelter_21': 7, 'test_boelter_3': 16, 'test_boelter_4': 9, 'test_boelter_5': 13, 'test_boelter_6': 13, 'test_boelter_7': 16, 'test_boelter_9': 4}
+
+CLIPS_LEN_BY_ID_88 = {'test_94342_13': 1455, 'test_boelter4_11': 1355, 'test_94342_20': 1865, 'test_94342_0': 1940, 'test_94342_23': 539, 'test_boelter4_5': 1166, 'test_boelter_12': 972, 'test_9434_3': 323, 'test_boelter_15': 1055, 'test_94342_19': 1695, 'test_boelter_21': 953, 'test_boelter3_2': 1326, 'test_boelter4_0': 1322, 'test_boelter_18': 1386, 'test6': 704, 'test_boelter_1': 925, 'test_boelter3_6': 1446, 'test_94342_21': 1134, 'test_boelter4_10': 1263, 'test_9434_1': 418, 'test_94342_17': 1057, 'test_boelter4_9': 1271, 'test_94342_18': 1539, 'test_boelter4_12': 1294, 'test_boelter3_11': 1398, 'test_boelter4_1': 1238, 'test_94342_26': 527, 'test_boelter_10': 654, 'test_boelter4_8': 1006, 'test_boelter3_8': 1161, 'test2': 975, 'test_94342_7': 1386, 'test_94342_16': 1010, 'test_boelter2_17': 1268, 'test_boelter_4': 787, 'test_boelter3_3': 861, 'test_94342_1': 1387, 'test_boelter_13': 1004, 'test_boelter3_1': 1351, 'test_boelter2_8': 1347, 'test_boelter2_14': 920, 'test_boelter2_0': 1143, 'test7': 227, 'test_94342_3': 1776, 'test_boelter2_12': 1417, 'test_94342_8': 1795, 'test_boelter4_7': 1401, 'test_9434_18': 1042, 'test_94342_22': 586, 'test_94342_5': 2292, 'test_boelter3_9': 1383, 'test1': 699, 'test_boelter_6': 1435, 'test_boelter_19': 959, 'test_boelter4_13': 933, 'test_94342_10': 1156, 'test_boelter4_4': 715, 'test_boelter3_4': 943, 'test_boelter2_3': 942, 'test_boelter_5': 834, 'test_94342_12': 2417, 'test_boelter_14': 904, 'test_boelter3_0': 769, 'test_94342_6': 944, 'test_94342_15': 1174, 'test_94342_24': 741, 'test_boelter_2': 607, 'test_boelter_7': 1328, 'test_boelter_3': 596, 'test_94342_4': 1924, 'test_boelter4_2': 1353, 'test_boelter3_13': 756, 'test_94342_25': 568, 'test_boelter2_16': 1734, 'test_boelter3_5': 851, 'test_boelter4_3': 1235, 'test_boelter4_6': 1334, 'test_boelter3_10': 1301, 'test_boelter2_7': 1505, 'test_94342_14': 1841, 'test_boelter3_7': 1544, 'test_boelter2_15': 936, 'test_boelter_9': 636, 'test_boelter2_6': 2100, 'test_boelter3_12': 359, 'test_boelter_17': 817, 'test_94342_11': 1610, 'test_94342_2': 1968}
+CLIPS_IDS_88 = ['test_94342_13', 'test_boelter4_11', 'test_94342_20', 'test_94342_0', 'test_94342_23', 'test_boelter4_5', 'test_boelter_12', 'test_9434_3', 'test_boelter_15', 'test_94342_19', 'test_boelter_21', 'test_boelter3_2', 'test_boelter4_0', 'test_boelter_18', 'test6', 'test_boelter_1', 'test_boelter3_6', 'test_94342_21', 'test_boelter4_10', 'test_9434_1', 'test_94342_17', 'test_boelter4_9', 'test_94342_18', 'test_boelter4_12', 'test_boelter3_11', 'test_boelter4_1', 'test_94342_26', 'test_boelter_10', 'test_boelter4_8', 'test_boelter3_8', 'test2', 'test_94342_7', 'test_94342_16', 'test_boelter2_17', 'test_boelter_4', 'test_boelter3_3', 'test_94342_1', 'test_boelter_13', 'test_boelter3_1', 'test_boelter2_8', 'test_boelter2_14', 'test_boelter2_0', 'test7', 'test_94342_3', 'test_boelter2_12', 'test_94342_8', 'test_boelter4_7', 'test_9434_18', 'test_94342_22', 'test_94342_5', 'test_boelter3_9', 'test1', 'test_boelter_6', 'test_boelter_19', 'test_boelter4_13', 'test_94342_10', 'test_boelter4_4', 'test_boelter3_4', 'test_boelter2_3', 'test_boelter_5', 'test_94342_12', 'test_boelter_14', 'test_boelter3_0', 'test_94342_6', 'test_94342_15', 'test_94342_24', 'test_boelter_2', 'test_boelter_7', 'test_boelter_3', 'test_94342_4', 'test_boelter4_2', 'test_boelter3_13', 'test_94342_25', 'test_boelter2_16', 'test_boelter3_5', 'test_boelter4_3', 'test_boelter4_6', 'test_boelter3_10', 'test_boelter2_7', 'test_94342_14', 'test_boelter3_7', 'test_boelter2_15', 'test_boelter_9', 'test_boelter2_6', 'test_boelter3_12', 'test_boelter_17', 'test_94342_11', 'test_94342_2']
+CLIPS_LEN_88 = [1455, 1355, 1865, 1940, 539, 1166, 972, 323, 1055, 1695, 953, 1326, 1322, 1386, 704, 925, 1446, 1134, 1263, 418, 1057, 1271, 1539, 1294, 1398, 1238, 527, 654, 1006, 1161, 975, 1386, 1010, 1268, 787, 861, 1387, 1004, 1351, 1347, 920, 1143, 227, 1776, 1417, 1795, 1401, 1042, 586, 2292, 1383, 699, 1435, 959, 933, 1156, 715, 943, 942, 834, 2417, 904, 769, 944, 1174, 741, 607, 1328, 596, 1924, 1353, 756, 568, 1734, 851, 1235, 1334, 1301, 1505, 1841, 1544, 936, 636, 2100, 359, 817, 1610, 1968]
+
+CLIPS_WITH_GT_EVENT =['test_boelter2_15', 'test_94342_16', 'test_boelter4_4', 'test_94342_21', 'test_boelter4_1', 'test_boelter4_9', 'test_94342_1', 'test_boelter3_4', 'test_boelter_2', 'test_boelter_21', 'test_boelter4_12', 'test_boelter_7', 'test7', 'test_9434_18', 'test_94342_10', 'test_boelter3_13', 'test_94342_6', 'test1', 'test_boelter_12', 'test_boelter3_0', 'test6', 'test_9434_1', 'test_boelter2_12', 'test_boelter3_6', 'test_boelter4_3', 'test_boelter3_11']
+
+CLIPS_LEN_BY_ID = {'test_94342_13': 1455, 'test_boelter4_11': 1355, 'test_94342_20': 1865, 'test_94342_0': 1940, 'test_94342_23': 539, 'test_boelter4_5': 1166, 'test_boelter_12': 972, 'test_9434_3': 323, 'test_boelter_15': 1055, 'test_94342_19': 1695, 'test_boelter_21': 953, 'test_boelter3_2': 1326, 'test_boelter4_0': 1322, 'test_boelter_18': 1386, 'test6': 704, 'test_boelter_1': 925, 'test_boelter3_6': 1446, 'test_94342_21': 1134, 'test_boelter4_10': 1263, 'test_9434_1': 418, 'test_94342_17': 1057, 'test_boelter4_9': 1271, 'test_94342_18': 1539, 'test_boelter4_12': 1294, 'test_boelter3_11': 1398, 'test_boelter4_1': 1238, 'test_94342_26': 527, 'test_boelter_10': 654, 'test_boelter4_8': 1006, 'test_boelter3_8': 1161, 'test2': 975, 'test_94342_7': 1386, 'test_94342_16': 1010, 'test_boelter2_17': 1268, 'test_boelter_4': 787, 'test_boelter3_3': 861, 'test_94342_1': 1387, 'test_boelter_13': 1004, 'test_boelter_24': 315, 'test_boelter3_1': 1351, 'test_boelter2_8': 1347, 'test_boelter2_2': 1413, 'test_boelter2_14': 920, 'test_boelter2_0': 1143, 'test7': 227, 'test_94342_3': 1776, 'test_boelter2_12': 1417, 'test_94342_8': 1795, 'test_boelter4_7': 1401, 'test_9434_18': 1042, 'test_94342_22': 586, 'test_94342_5': 2292, 'test_boelter3_9': 1383, 'test1': 699, 'test_boelter_6': 1435, 'test_boelter_19': 959, 'test_boelter4_13': 933, 'test_94342_10': 1156, 'test_boelter4_4': 715, 'test_boelter3_4': 943, 'test_boelter2_3': 942, 'test_boelter_5': 834, 'test_94342_12': 2417, 'test_boelter_14': 904, 'test_boelter3_0': 769, 'test_94342_6': 944, 'test_94342_15': 1174, 'test_94342_24': 741, 'test_boelter_2': 607, 'test_boelter2_5': 2085, 'test_boelter_7': 1328, 'test_boelter_3': 596, 'test_94342_4': 1924, 'test_boelter4_2': 1353, 'test_boelter3_13': 756, 'test_94342_25': 568, 'test_boelter2_16': 1734, 'test_boelter3_5': 851, 'test_boelter4_3': 1235, 'test_boelter4_6': 1334, 'test_boelter3_10': 1301, 'test_boelter2_7': 1505, 'test_94342_14': 1841, 'test_boelter_22': 828, 'test_boelter3_7': 1544, 'test_boelter2_15': 936, 'test_boelter_9': 636, 'test_boelter_25': 951, 'test_boelter2_6': 2100, 'test_boelter2_4': 1526, 'test_boelter3_12': 359, 'test_boelter_17': 817, 'test_94342_11': 1610, 'test_94342_2': 1968}
+CLIPS_LEN = [1455, 1355, 1865, 1940, 539, 1166, 972, 323, 1055, 1695, 953, 1326, 1322, 1386, 704, 925, 1446, 1134, 1263, 418, 1057, 1271, 1539, 1294, 1398, 1238, 527, 654, 1006, 1161, 975, 1386, 1010, 1268, 787, 861, 1387, 1004, 315, 1351, 1347, 1413, 920, 1143, 227, 1776, 1417, 1795, 1401, 1042, 586, 2292, 1383, 699, 1435, 959, 933, 1156, 715, 943, 942, 834, 2417, 904, 769, 944, 1174, 741, 607, 2085, 1328, 596, 1924, 1353, 756, 568, 1734, 851, 1235, 1334, 1301, 1505, 1841, 828, 1544, 936, 636, 951, 2100, 1526, 359, 817, 1610, 1968]
+CLIPS_IDS = ['test_94342_13', 'test_boelter4_11', 'test_94342_20', 'test_94342_0', 'test_94342_23', 'test_boelter4_5', 'test_boelter_12', 'test_9434_3', 'test_boelter_15', 'test_94342_19', 'test_boelter_21', 'test_boelter3_2', 'test_boelter4_0', 'test_boelter_18', 'test6', 'test_boelter_1', 'test_boelter3_6', 'test_94342_21', 'test_boelter4_10', 'test_9434_1', 'test_94342_17', 'test_boelter4_9', 'test_94342_18', 'test_boelter4_12', 'test_boelter3_11', 'test_boelter4_1', 'test_94342_26', 'test_boelter_10', 'test_boelter4_8', 'test_boelter3_8', 'test2', 'test_94342_7', 'test_94342_16', 'test_boelter2_17', 'test_boelter_4', 'test_boelter3_3', 'test_94342_1', 'test_boelter_13', 'test_boelter_24', 'test_boelter3_1', 'test_boelter2_8', 'test_boelter2_2', 'test_boelter2_14', 'test_boelter2_0', 'test7', 'test_94342_3', 'test_boelter2_12', 'test_94342_8', 'test_boelter4_7', 'test_9434_18', 'test_94342_22', 'test_94342_5', 'test_boelter3_9', 'test1', 'test_boelter_6', 'test_boelter_19', 'test_boelter4_13', 'test_94342_10', 'test_boelter4_4', 'test_boelter3_4', 'test_boelter2_3', 'test_boelter_5', 'test_94342_12', 'test_boelter_14', 'test_boelter3_0', 'test_94342_6', 'test_94342_15', 'test_94342_24', 'test_boelter_2', 'test_boelter2_5', 'test_boelter_7', 'test_boelter_3', 'test_94342_4', 'test_boelter4_2', 'test_boelter3_13', 'test_94342_25', 'test_boelter2_16', 'test_boelter3_5', 'test_boelter4_3', 'test_boelter4_6', 'test_boelter3_10', 'test_boelter2_7', 'test_94342_14', 'test_boelter_22', 'test_boelter3_7', 'test_boelter2_15', 'test_boelter_9', 'test_boelter_25', 'test_boelter2_6', 'test_boelter2_4', 'test_boelter3_12', 'test_boelter_17', 'test_94342_11', 'test_94342_2']
+
+ALL_IDS = ['test_boelter_15', 'test6', 'test_94342_7', 'test_94342_10', 'test_94342_21', 'test_boelter2_8', 'test_boelter4_1', 'test_boelter3_1', 'test_boelter3_11', 'test_boelter3_10', 'test_boelter3_0', 'test_boelter4_0', 'test_94342_20', 'test_94342_11', 'test_94342_6', 'test_boelter_1', 'test_9434_3', 'test7', 'test_boelter_14', 'test_boelter2_12', 'test_boelter2_3', 'test_9434_1', 'test_94342_4', 'test_boelter_3', 'test_boelter3_8', 'test_94342_13', 'test_boelter4_8', 'test_94342_22', 'test_boelter_9', 'test_boelter4_2', 'test_boelter3_2', 'test_94342_19', 'test_boelter3_12', 'test_boelter3_13', 'test_boelter3_3', 'test_94342_18', 'test_boelter4_3', 'test_9434_18', 'test_94342_23', 'test_boelter4_9', 'test_boelter3_9', 'test_94342_12', 'test_94342_5', 'test_boelter_2', 'test_boelter2_0', 'test_boelter_17', 'test_boelter2_15', 'test_boelter_13', 'test_boelter_6', 'test_94342_1', 'test_94342_16', 'test_boelter4_10', 'test_boelter_19', 'test_boelter4_7', 'test_boelter3_7', 'test_boelter3_6', 'test_boelter4_6', 'test_boelter_18', 'test_boelter4_11', 'test_94342_26', 'test_94342_17', 'test1', 'test_boelter_7', 'test_94342_0', 'test_boelter_12', 'test_boelter2_14', 'test_boelter_10', 'test_boelter2_7', 'test_boelter2_16', 'test_boelter_5', 'test_94342_2', 'test_94342_15', 'test_94342_8', 'test_94342_24', 'test_boelter4_13', 'test_boelter4_4', 'test_boelter_21', 'test_boelter3_4', 'test_boelter3_5', 'test_boelter4_5', 'test_boelter4_12', 'test_94342_25', 'test_94342_14', 'test2', 'test_boelter_4', 'test_94342_3', 'test_boelter2_6', 'test_boelter2_17']
+
+UNIQUE_OBJ_IDS = ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14', 'O15']
+
+mind_test_clips = ['test_boelter4_5.p', 'test_94342_2.p', 'test_boelter4_10.p', 'test_boelter2_3.p', 'test_94342_20.p', 'test_boelter3_9.p', 'test_boelter4_6.p', 'test2.p', 'test_boelter4_2.p', 'test_94342_24.p', 'test_94342_17.p', 'test_94342_8.p', 'test_94342_11.p', 'test_boelter3_7.p', 'test_94342_18.p', 'test_boelter_10.p', 'test_boelter3_8.p', 'test_boelter2_6.p', 'test_boelter4_7.p', 'test_boelter4_8.p', 'test_boelter4_0.p', 'test_boelter2_17.p', 'test_boelter3_12.p', 'test_boelter3_5.p', 'test_94342_4.p', 'test_94342_15.p']
+
+
+def count_parameters(model):
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+ return sum([np.prod(p.size()) for p in model_parameters])
+
+
+def split_train_val_test():
+ # Calculate the total number of frames
+ total_frames = sum(CLIPS_LEN_BY_ID_88.values())
+
+ # Calculate the number of frames for each split
+ train_frames = int(total_frames * 0.6)
+ validation_frames = int(total_frames * 0.2)
+
+ # Convert the dictionary to a list of tuples (video_id, frame_length)
+ video_list = list(CLIPS_LEN_BY_ID_88.items())
+
+ # Shuffle the video list randomly
+ random.shuffle(video_list)
+
+ # Split the videos based on the number of frames
+ train_ids = []
+ validation_ids = []
+ test_ids = []
+ frames_count = 0
+
+ for video_id, frames in video_list:
+ if frames_count < train_frames:
+ train_ids.append(video_id)
+ elif frames_count < train_frames + validation_frames:
+ validation_ids.append(video_id)
+ else:
+ test_ids.append(video_id)
+ frames_count += frames
+
+ # Print the results
+ print("Train IDs:", train_ids)
+ print("Validation IDs:", validation_ids)
+ print("Test IDs:", test_ids)
+
+ train_frames_total = sum(CLIPS_LEN_BY_ID_88[video_id] for video_id in train_ids)
+ validation_frames_total = sum(CLIPS_LEN_BY_ID_88[video_id] for video_id in validation_ids)
+ test_frames_total = sum(CLIPS_LEN_BY_ID_88[video_id] for video_id in test_ids)
+
+ print("Total frames for train_ids:", train_frames_total)
+ print("Total frames for validation_ids:", validation_frames_total)
+ print("Total frames for test_ids:", test_frames_total)
+
+ return train_ids, validation_ids, test_ids
+
+
+def compute_f1_scores(m1_pred, m1_label, m2_pred, m2_label, m12_pred, m12_label, m21_pred, m21_label, mc_pred, mc_label, verbose=False):
+ # Compute F1 score for m1
+ m1_pred_labels = torch.argmax(m1_pred, dim=-1).cpu().numpy()
+ m1_true_labels = m1_label.cpu().numpy()
+ if verbose: print(f'm1 --- pred {np.unique(m1_pred_labels, return_counts=True)}, true {np.unique(m1_true_labels, return_counts=True)}')
+ m1_f1_score = f1_score(m1_true_labels, m1_pred_labels, average='macro')
+
+ # Compute F1 score for m2
+ m2_pred_labels = torch.argmax(m2_pred, dim=-1).cpu().numpy()
+ m2_true_labels = m2_label.cpu().numpy()
+ if verbose: print(f'm2 --- pred {np.unique(m2_pred_labels, return_counts=True)}, true {np.unique(m2_true_labels, return_counts=True)}')
+ m2_f1_score = f1_score(m2_true_labels, m2_pred_labels, average='macro')
+
+ # Compute F1 score for m12
+ m12_pred_labels = torch.argmax(m12_pred, dim=-1).cpu().numpy()
+ m12_true_labels = m12_label.cpu().numpy()
+ if verbose: print(f'm12 --- pred {np.unique(m12_pred_labels, return_counts=True)}, true {np.unique(m12_true_labels, return_counts=True)}')
+ m12_f1_score = f1_score(m12_true_labels, m12_pred_labels, average='macro')
+
+ # Compute F1 score for m21
+ m21_pred_labels = torch.argmax(m21_pred, dim=-1).cpu().numpy()
+ m21_true_labels = m21_label.cpu().numpy()
+ if verbose: print(f'm21 --- pred {np.unique(m21_pred_labels, return_counts=True)}, true {np.unique(m21_true_labels, return_counts=True)}')
+ m21_f1_score = f1_score(m21_true_labels, m21_pred_labels, average='macro')
+
+ # Compute F1 score for mc
+ mc_pred_labels = torch.argmax(mc_pred, dim=-1).cpu().numpy()
+ mc_true_labels = mc_label.cpu().numpy()
+ if verbose: print(f'mc --- pred {np.unique(mc_pred_labels, return_counts=True)}, true {np.unique(mc_true_labels, return_counts=True)}')
+ mc_f1_score = f1_score(mc_true_labels, mc_pred_labels, average='macro')
+
+ return m1_f1_score, m2_f1_score, m12_f1_score, m21_f1_score, mc_f1_score
\ No newline at end of file
diff --git a/tbd/utils/preprocess_img.py b/tbd/utils/preprocess_img.py
new file mode 100644
index 0000000..b1680d1
--- /dev/null
+++ b/tbd/utils/preprocess_img.py
@@ -0,0 +1,37 @@
+import glob
+
+import cv2
+
+import torchvision.transforms as T
+import torch
+import os
+from tqdm import tqdm
+
+
+PATH_IN = "/scratch/bortoletto/data/tbd/images"
+PATH_OUT = "/scratch/bortoletto/data/tbd/images_norm"
+
+normalisation_steps = [
+ T.ToTensor(),
+ T.Resize((128,128)),
+ T.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]
+ )
+]
+
+preprocess_img = T.Compose(normalisation_steps)
+
+def main():
+ print(f"{PATH_IN}/*/*/*.jpg")
+ all_img = glob.glob(f"{PATH_IN}/*/*/*.jpg")
+ print(len(all_img))
+ for img_path in tqdm(all_img):
+ new_img = preprocess_img(cv2.imread(img_path)).numpy()
+ img_path_split = img_path.split("/")
+ os.makedirs(f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}", exist_ok=True)
+ out_img = f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}/{img_path_split[-1][:-4]}.pt"
+ torch.save(new_img, out_img)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/tbd/utils/reformat_labels_ours.py b/tbd/utils/reformat_labels_ours.py
new file mode 100644
index 0000000..9500a6f
--- /dev/null
+++ b/tbd/utils/reformat_labels_ours.py
@@ -0,0 +1,106 @@
+import pandas as pd
+import os
+import glob
+import pickle
+
+DATASET_LOCATION = "YOUR_PATH_HERE"
+
+def reframe_annotation():
+ annotation_path = f'{DATASET_LOCATION}/retrieve_annotation/all/'
+ save_path = f'{DATASET_LOCATION}/reformat_annotation/'
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ tasks = glob.glob(annotation_path + '*.txt')
+ id_map = pd.read_csv('id_map.csv')
+ for task in tasks:
+ if not task.split('/')[-1].split('_')[2] == '1.txt':
+ continue
+ with open(task, 'r') as f:
+ lines = f.readlines()
+ task_id = int(task.split('/')[-1].split('_')[1]) + 1
+ clip = id_map.loc[id_map['ID'] == task_id].folder
+ print(task_id, len(clip))
+ if len(clip) == 0:
+ continue
+ with open(save_path + clip.item() + '.txt', 'w') as f:
+ for line in lines:
+ words = line.split()
+ f.write(words[0] + ',' + words[1] + ',' + words[2] + ',' + words[3] + ',' + words[4] + ',' + words[5] +
+ ',' + words[6] + ',' + words[7] + ',' + words[8] + ',' + words[9] + ',' + ' '.join(words[10:]) + '\n')
+ f.close()
+
+def get_grid_location(obj_frame):
+ x_min = obj_frame['x_min']#.item()
+ y_min = obj_frame['y_min']#.item()
+ x_max = obj_frame['x_max']#.item()
+ y_max = obj_frame['y_max']#.item()
+ gridLW = 1280 / 25.
+ gridLH = 720 / 15.
+ center_x, center_y = (x_min + x_max)/2, (y_min + y_max)/2
+ X, Y = int(center_x / gridLW), int(center_y / gridLH)
+ return X, Y
+
+def regenerate_annotation():
+ annotation_path = f'{DATASET_LOCATION}/reformat_annotation/'
+ save_path=f'{DATASET_LOCATION}/regenerate_annotation/'
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ tasks = glob.glob(annotation_path + '*.txt')
+ for task in tasks:
+ print(task)
+ annt = pd.read_csv(task, sep=",", header=None)
+ annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name", "label"]
+ obj_records = {}
+ for index, obj_frame in annt.iterrows():
+ if obj_frame['name'].startswith('P'):
+ continue
+ else:
+ assert obj_frame['name'].startswith('O')
+ obj_name = obj_frame['name']
+ # 0: enter 1: disappear 2: update 3: unchange
+ frame_id = obj_frame['frame']
+ curr_loc = get_grid_location(obj_frame)
+ mind_dict = {'m1': {'fluent': 3, 'loc': None}, 'm2': {'fluent': 3, 'loc': None},
+ 'm12': {'fluent': 3, 'loc': None},
+ 'm21': {'fluent': 3, 'loc': None}, 'mc': {'fluent': 3, 'loc': None},
+ 'mg': {'fluent': 3, 'loc': curr_loc}}
+ mind_dict['mg']['loc'] = curr_loc
+ if not type(obj_frame['label']) == float:
+ mind_labels = obj_frame['label'].split()
+ for mind_label in mind_labels:
+ if mind_label == 'in_m1' or mind_label == 'in_m2' or mind_label == 'in_m12' \
+ or mind_label == 'in_m21' or mind_label == 'in_mc' or mind_label == '"in_m1"' or mind_label == '"in_m2"'\
+ or mind_label == '"in_m12"' or mind_label == '"in_m21"' or mind_label == '"in_mc"':
+ mind_name = mind_label.split('_')[1].split('"')[0]
+ mind_dict[mind_name]['loc'] = curr_loc
+ else:
+ mind_name = mind_label.split('_')[0].split('"')
+ if len(mind_name) > 1:
+ mind_name = mind_name[1]
+ else:
+ mind_name = mind_name[0]
+ last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
+ mind_dict[mind_name]['loc'] = last_loc
+
+ for mind_name in mind_dict.keys():
+ if frame_id > 0:
+ curr_loc = mind_dict[mind_name]['loc']
+ last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
+ if last_loc is None and curr_loc is not None:
+ mind_dict[mind_name]['fluent'] = 0
+ elif last_loc is not None and curr_loc is None:
+ mind_dict[mind_name]['fluent'] = 1
+ elif not last_loc == curr_loc:
+ mind_dict[mind_name]['fluent'] = 2
+ if obj_name not in obj_records:
+ obj_records[obj_name] = [mind_dict]
+ else:
+ obj_records[obj_name].append(mind_dict)
+
+ with open(save_path + task.split('/')[-1].split('.')[0] + '.p', 'wb') as f:
+ pickle.dump(obj_records, f)
+
+
+if __name__ == '__main__':
+ reframe_annotation()
+ regenerate_annotation()
\ No newline at end of file
diff --git a/tbd/utils/similarity.py b/tbd/utils/similarity.py
new file mode 100644
index 0000000..5517fcf
--- /dev/null
+++ b/tbd/utils/similarity.py
@@ -0,0 +1,75 @@
+import os
+import torch
+import numpy as np
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA
+import seaborn as sns
+
+
+FOLDER_PATH = 'PATH_TO_FOLDER'
+
+print(FOLDER_PATH)
+
+MTOM_COLORS = {
+ "MN1": (110/255, 117/255, 161/255),
+ "MN2": (179/255, 106/255, 98/255),
+ "Base": (193/255, 198/255, 208/255),
+ "CG": (170/255, 129/255, 42/255),
+ "IC": (97/255, 112/255, 83/255),
+ "DB": (144/255, 63/255, 110/255)
+}
+
+COLORS = sns.color_palette()
+
+sns.set_theme(style='white')
+
+out_left_main_mods_full_test = []
+out_right_main_mods_full_test = []
+cell_left_main_mods_full_test = []
+cell_right_main_mods_full_test = []
+cm_left_main_mods_full_test = []
+cm_right_main_mods_full_test = []
+
+for i in range(len([filename for filename in os.listdir(FOLDER_PATH) if filename.endswith('.pt')])):
+
+ print(f'Computing analysis for test video {i}...', end='\r')
+
+ emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
+ data = torch.load(emb_file)
+ if len(data) == 13: # implicit
+ model = 'impl'
+ out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
+ elif len(data) == 12: # common mind
+ model = 'cm'
+ out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
+ elif len(data) == 11: # speaker-listener
+ model = 'sl'
+ out_left, out_right, feats = data[0], data[1], data[2:]
+ else: raise ValueError("Data should have 13 (impl), others are not implemented")
+
+ # ====== PCA for left and right embeddings ====== #
+
+ out_left_pca = out_left[0].reshape(-1, 64)
+ out_right_pca = out_right[0].reshape(-1, 64)
+ out_left_and_right = np.concatenate((out_left_pca, out_right_pca), axis=0)
+
+ pca = PCA(n_components=2)
+ pca_result = pca.fit_transform(out_left_and_right)
+
+ # Separate the PCA results for each tensor
+ pca_result_left = pca_result[:out_left_pca.shape[0]]
+ pca_result_right = pca_result[out_right_pca.shape[0]:]
+
+ plt.figure(figsize=(6.8,6))
+ plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
+ plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
+ plt.xlabel('Principal Component 1', fontsize=32)
+ plt.ylabel('Principal Component 2', fontsize=32)
+ plt.xticks(fontsize=24)
+ plt.xticks([-0.4, -0.2, 0.0, 0.2, 0.4])
+ plt.yticks(fontsize=24)
+ plt.legend(fontsize=32)
+ plt.tight_layout()
+ plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
+ plt.close()
\ No newline at end of file
diff --git a/tbd/utils/store_mind_set.py b/tbd/utils/store_mind_set.py
new file mode 100644
index 0000000..c1ba5f0
--- /dev/null
+++ b/tbd/utils/store_mind_set.py
@@ -0,0 +1,96 @@
+import os
+import pandas as pd
+import pickle
+from tqdm import tqdm
+
+
+def check_append(obj_name, m1, mind_name, obj_frame, flags, label):
+ if label:
+ if not obj_name in m1:
+ m1[obj_name] = []
+ m1[obj_name].append(
+ [obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
+ flags[mind_name] = 1
+ elif not flags[mind_name]:
+ m1[obj_name].append(
+ [obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
+ flags[mind_name] = 1
+ else: # false belief
+ if obj_name in m1:
+ if flags[mind_name]:
+ m1[obj_name].append(
+ [obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 1])
+ flags[mind_name] = 0
+ return flags, m1
+
+
+def store_mind_set(clip, annotation_path, save_path):
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ annt = pd.read_csv(annotation_path + clip, sep=",", header=None)
+ annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name",
+ "label"]
+ obj_names = annt.name.unique()
+ m1, m2, m12, m21, mc = {}, {}, {}, {}, {}
+ flags = {'m1':0, 'm2':0, 'm12':0, 'm21':0, 'mc':0}
+ for obj_name in obj_names:
+ if obj_name == 'P1' or obj_name == 'P2':
+ continue
+ obj_frames = annt.loc[annt.name == obj_name]
+ for index, obj_frame in obj_frames.iterrows():
+ if type(obj_frame.label) == float:
+ continue
+ labels = obj_frame.label.split()
+ for label in labels:
+ if label == 'in_m1' or label == '"in_m1"':
+ flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 1)
+ elif label == 'in_m2' or label == '"in_m2"':
+ flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 1)
+ elif label == 'in_m12'or label == '"in_m12"':
+ flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 1)
+ elif label == 'in_m21' or label == '"in_m21"':
+ flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 1)
+ elif label == 'in_mc'or label == '"in_mc"':
+ flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 1)
+ elif label == 'm1_false' or label == '"m1_false"':
+ flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 0)
+ flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
+ flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
+ false_belief = 'm1_false'
+ with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
+ file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
+ elif label == 'm2_false' or label == '"m2_false"':
+ flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 0)
+ flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
+ flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
+ false_belief = 'm2_false'
+ with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
+ file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
+ elif label == 'm12_false' or label == '"m12_false"':
+ flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
+ flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
+ false_belief = 'm12_false'
+ with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
+ file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
+ elif label == 'm21_false' or label == '"m21_false"':
+ flags, m21 = check_append(obj_name, m2, 'm21', obj_frame, flags, 0)
+ flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
+ false_belief = 'm21_false'
+ with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
+ file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
+ # print('m1', m1)
+ # print('m2', m2)
+ # print('m12', m12)
+ # print('m21', m21)
+ # print('mc', mc)
+ #with open(save_path + clip.split('.')[0] + '.p', 'wb') as f:
+ # pickle.dump([m1, m2, m12, m21, mc], f)
+
+
+if __name__ == "__main__":
+
+ annotation_path = '/scratch/bortoletto/data/tbd/reformat_annotation/'
+ save_path = '/scratch/bortoletto/data/tbd/store_mind_set/'
+
+ for clip in tqdm(os.listdir(annotation_path), desc="Processing videos", unit="item"):
+ store_mind_set(clip, annotation_path, save_path)
\ No newline at end of file
diff --git a/tbd/utils/visualize_bbox.py b/tbd/utils/visualize_bbox.py
new file mode 100644
index 0000000..fc33567
--- /dev/null
+++ b/tbd/utils/visualize_bbox.py
@@ -0,0 +1,95 @@
+import time
+from tbd_dataloader import TBDv2Dataset
+
+import numpy as np
+
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+
+def point2screen(points):
+ K = [607.13232421875, 0.0, 638.6468505859375, 0.0, 607.1067504882812, 367.1607360839844, 0.0, 0.0, 1.0]
+ K = np.reshape(np.array(K), [3, 3])
+ rot_points = np.array(points) + np.array([0, 0.2, 0])
+ rot_points = rot_points
+ points_camera = rot_points.reshape(3, 1)
+
+ project_matrix = np.array(K).reshape(3, 3)
+ points_prj = project_matrix.dot(points_camera)
+ points_prj = points_prj.transpose()
+ if not points_prj[:, 2][0] == 0.0:
+ points_prj[:, 0] = points_prj[:, 0] / points_prj[:, 2]
+ points_prj[:, 1] = points_prj[:, 1] / points_prj[:, 2]
+ points_screen = points_prj[:, :2]
+ assert points_screen.shape == (1, 2)
+ points_screen = points_screen.reshape(-1)
+ return points_screen
+
+if __name__ == '__main__':
+ data = TBDv2Dataset(number_frames_to_sample=1, resize_img=None)
+ index = np.random.randint(0, len(data))
+ start = time.time()
+ (
+ kinect_imgs, # <- len x 720 x 1280 x 3
+ tracker_imgs,
+ battery_imgs,
+ skele1,
+ skele2,
+ bbox,
+ tracker_skeID_sample, # <- This is the tracker skeleton ID
+ tracker2d,
+ label,
+ experiment_id, # From here for debugging
+ timestep,
+ obj_id, # <- This is the object ID as a string
+ ) = data[index]
+ end = time.time()
+ print(f"Time for one sample: {end-start}")
+
+ img = kinect_imgs[-1]
+ bbox = bbox[-1]
+ print(label.shape)
+
+ print(skele1.shape)
+ print(skele2.shape)
+
+ skele1 = skele1[-1, :,:]
+ skele2 = skele2[-1, :,:]
+
+ print(skele1.shape)
+
+
+
+ # reshape img from c, h, w to h, w, c
+ img = img.permute(1, 2, 0)
+
+ fig, ax = plt.subplots(1)
+ ax.imshow(img)
+ print(bbox[0], bbox[1], bbox[2], bbox[3]) # t(top left x, top left y, width, height)
+ top_left_x, top_left_y, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
+ x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
+
+
+
+
+ for i in range(26):
+ print(skele1[i,0], skele1[i,1])
+ print(skele1[i,:].shape)
+ print(point2screen(skele1[i,:]))
+ x, y = point2screen(skele1[i,:])[0], point2screen(skele1[i,:])[1]
+ ax.text(x, y, f"{i}", fontsize=5, color='w')
+
+ wedge = patches.Wedge((x,y), 10, 0, 360, width=10, color='b')
+ ax.add_patch(wedge)
+
+ for i in range(26):
+ x, y = point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]
+ ax.text(x, y, f"{i}", fontsize=5, color='w')
+ wedge = patches.Wedge((point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]), 10, 0, 360, width=10, color='r')
+ ax.add_patch(wedge)
+
+ # Create a Rectangle patch
+ # rect = patches.Rectangle((top_left_x, top_left_y-height), width, height, linewidth=1, edgecolor='r', facecolor='none')
+ # ax.add_patch(rect)
+ # rect = patches.Rectangle((x_min, y_max), x_max-x_min, y_max-y_min, linewidth=1, edgecolor='g', facecolor='none')
+ # ax.add_patch(rect)
+ fig.savefig(f"bbox_{obj_id}_{index}_{experiment_id}.png")