diff --git a/README.md b/README.md index c4a4964..09be464 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,17 @@ This is a temporary reference that will be updated after the proceedings are pub } ``` -
-
-
+# Code -Under construction +This repository has the following structure: + +``` +mtomnet +├── boss +└── tbd +``` + +We have one subfolder for dataset, containing the code to run the corresponding experiments. Inside each subfolder we provide a README with further instructions. [1]: https://mattbortoletto.github.io/ diff --git a/boss/.gitignore b/boss/.gitignore new file mode 100644 index 0000000..b9348b7 --- /dev/null +++ b/boss/.gitignore @@ -0,0 +1,206 @@ +experiments/ +predictions_* +predictions/ +tmp/ + +### WANDB ### + +wandb/* +wandb +wandb-debug.log +logs +summary* +run* + +### SYSTEM ### + +*/__pycache__/* + +# Created by https://www.toptal.com/developers/gitignore/api/python,linux +# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# End of https://www.toptal.com/developers/gitignore/api/python,linux \ No newline at end of file diff --git a/boss/README.md b/boss/README.md new file mode 100644 index 0000000..92e2d01 --- /dev/null +++ b/boss/README.md @@ -0,0 +1,16 @@ +# BOSS + +## Installing Dependencies +Run `conda env create -f environment.yml`. + +## New bounding box annotations +We re-extracted bounding box annotations using Yolo-v8. The new data is in `new_bbox/`. The rest of the data can be found [here](https://drive.google.com/drive/folders/1b8FdpyoWx9gUps-BX6qbE9Kea3C2Uyua). + +## Train +`source run_train.sh`. + +## Test +`source run_test.sh` (specify the path to the model). + +## Resources +The original project page for the BOSS dataset can be found [here](https://sites.google.com/view/bossbelief/). Our code is based on the original implementation. \ No newline at end of file diff --git a/boss/dataloader.py b/boss/dataloader.py new file mode 100644 index 0000000..64b05f9 --- /dev/null +++ b/boss/dataloader.py @@ -0,0 +1,500 @@ +from torch.utils.data import Dataset +import os +from torchvision import transforms +from natsort import natsorted +import numpy as np +import pickle +import torch +import cv2 +import ast +from einops import rearrange +from utils import spherical2cartesial + + +class Data(Dataset): + def __init__( + self, + frame_path, + label_path, + pose_path=None, + gaze_path=None, + bbox_path=None, + ocr_graph_path=None, + presaved=128, + sizes = (128,128), + spatial_patch_size: int = 16, + temporal_patch_size: int = 4, + img_channels: int = 3, + patch_data: bool = False, + flatten_dim: int = 1 + ): + self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []} + self.size_1, self.size_2 = sizes + self.patch_data = patch_data + if self.patch_data: + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2 + self.spatial_patch_dim = img_channels * spatial_patch_size ** 2 + + assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
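+        # 'presaved' selects which cache of preprocessed frame tensors __getitem__ loads:
+        # 224 maps to the 'presaved' subdirectory, 128 to 'presaved128'; any other value is rejected.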
+ + if presaved == 224: + self.presaved = 'presaved' + elif presaved == 128: + self.presaved = 'presaved128' + else: raise ValueError + + frame_dirs = os.listdir(frame_path) + episode_ids = natsorted(frame_dirs) + seq_lens = [] + for frame_dir in natsorted(frame_dirs): + frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))] + seq_lens.append(len(frame_paths)) + self.data['frame_paths'].append(frame_paths) + + print('Labels...') + with open(label_path, 'rb') as fp: + labels = pickle.load(fp) + left_labels, right_labels = labels + for i in episode_ids: + episode_id = int(i) + left_labels_i = left_labels[episode_id] + right_labels_i = right_labels[episode_id] + self.data['labels'].append([left_labels_i, right_labels_i]) + fp.close() + + print('Pose...') + self.pose_path = pose_path + if pose_path is not None: + for i in episode_ids: + pose_i_dir = os.path.join(pose_path, i) + poses = [] + for j in natsorted(os.listdir(pose_i_dir)): + fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ) + if torch.tensor(fs.getNode("pose_0").mat()).shape != (2,25,3): + poses.append(fs.getNode("pose_0").mat()[:2,:,:]) + else: + poses.append(fs.getNode("pose_0").mat()) + poses = np.array(poses) + poses[:, :, :, 0] = poses[:, :, :, 0] / 1920 + poses[:, :, :, 1] = poses[:, :, :, 1] / 1088 + self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim)) + #self.data['poses'].append(torch.tensor(np.array(poses))) + + print('Gaze...') + """ + Gaze information is represented by a 3D gaze vector based on the observing camera’s Cartesian eye coordinate system + """ + self.gaze_path = gaze_path + if gaze_path is not None: + for i in episode_ids: + gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i)) + with open(gaze_i_txt, 'r') as fp: + gaze_i = fp.readlines() + fp.close() + gaze_i = ast.literal_eval(gaze_i[0]) + gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3)) + for j in range(len(gaze_i)): + if len(gaze_i[j]) >= 2: + gaze_i_tensor[j,:] = torch.tensor(gaze_i[j][:2]) + elif len(gaze_i[j]) == 1: + gaze_i_tensor[j,0] = torch.tensor(gaze_i[j][0]) + else: + continue + self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim)) + + print('Bbox...') + self.bbox_path = bbox_path + if bbox_path is not None: + self.objects = [ + 'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener', + 'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine', + 'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan', + 'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener' + ] + """ + # NOTE: old bbox + for i in episode_ids: + bbox_i_dir = os.path.join(bbox_path, i) + with open(bbox_i_dir, 'rb') as fp: + bboxes_i = pickle.load(fp) + len_i = len(bboxes_i) + fp.close() + bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4)) + for j in range(len(bboxes_i)): + items_i_j, bboxes_i_j = bboxes_i[j] + for k in range(len(items_i_j)): + bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([ + bboxes_i_j[k][0] / 1920, # * self.size_1, + bboxes_i_j[k][1] / 1088, # * self.size_2, + bboxes_i_j[k][2] / 1920, # * self.size_1, + bboxes_i_j[k][3] / 1088, # * self.size_2 + ]) # [x_min, y_min, x_max, y_max] + self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1)) + """ + # NOTE: new bbox + for i in episode_ids: + bbox_dir = os.path.join(bbox_path, i) + bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4)) # TODO: we might want to cut it to 10 objects + 
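+                # Each per-frame .txt file holds YOLO-format detections, one object per line:
+                # '<class_index> <x_center> <y_center> <width> <height>', with coordinates normalized to [0, 1].
+                # A (hypothetical) line such as '8 0.51 0.43 0.06 0.12' would fill the row for 'mug' below.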
for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)): + with open(os.path.join(bbox_dir, bbox), 'r') as fp: + bbox_content = fp.readlines() + fp.close() + for bbox_content_line in bbox_content: + bbox_content_values = bbox_content_line.split() + class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values) + bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height]) + self.data['bboxes'].append(torch.flatten(bbox_tensor, 1)) + + print('OCR...\n') + self.ocr_graph_path = ocr_graph_path + if ocr_graph_path is not None: + ocr_graph = [ + [15, [10, 4], [17, 2]], + [13, [16, 7], [18, 4]], + [11, [16, 4], [7, 10]], + [14, [10, 11], [7, 1]], + [12, [10, 9], [16, 3]], + [1, [7, 2], [9, 9], [10, 2]], + [5, [8, 8], [6, 8]], + [4, [9, 8], [7, 6]], + [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]], + [2, [10, 1], [7, 7], [9, 3]], + [19, [10, 2], [26, 6]], + [20, [10, 7], [26, 5]], + [22, [25, 4], [10, 8]], + [23, [25, 15]], + [21, [16, 5], [24, 8]] + ] + ocr_tensor = torch.zeros((27, 27)) + for ocr in ocr_graph: + obj = ocr[0] + contexts = ocr[1:] + total_context_count = sum([i[1] for i in contexts]) + for context in contexts: + ocr_tensor[obj, context[0]] = context[1] / total_context_count + ocr_tensor = torch.flatten(ocr_tensor) + for i in episode_ids: + self.data['ocr_graph'].append(ocr_tensor) + + self.frame_path = frame_path + self.transform = transforms.Compose([ + #transforms.ToPILImage(), + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + def __len__(self): + return len(self.data['frame_paths']) + + def __getitem__(self, idx): + frames_path = self.data['frame_paths'][idx][0].split('/') + frames_path.insert(5, self.presaved) + frames_path = ''.join(f'{w}/' for w in frames_path[:-1])[:-1]+'.pkl' + with open(frames_path, 'rb') as f: + images = pickle.load(f) + images = torch.stack(images) + if self.patch_data: + images = rearrange( + images, + '(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)', + p1=self.spatial_patch_size, + p2=self.spatial_patch_size, + p3=self.temporal_patch_size + ) + + if self.pose_path is not None: + pose = self.data['poses'][idx] + if self.patch_data: + pose = rearrange( + pose, + '(t p1) d -> t p1 d', + p1=self.temporal_patch_size + ) + else: + pose = None + + if self.gaze_path is not None: + gaze = self.data['gazes'][idx] + if self.patch_data: + gaze = rearrange( + gaze, + '(t p1) d -> t p1 d', + p1=self.temporal_patch_size + ) + else: + gaze = None + + if self.bbox_path is not None: + bbox = self.data['bboxes'][idx] + if self.patch_data: + bbox = rearrange( + bbox, + '(t p1) d -> t p1 d', + p1=self.temporal_patch_size + ) + else: + bbox = None + + if self.ocr_graph_path is not None: + ocr_graphs = self.data['ocr_graph'][idx] + else: + ocr_graphs = None + + return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs + + +class DataTest(Dataset): + def __init__( + self, + frame_path, + label_path, + pose_path=None, + gaze_path=None, + bbox_path=None, + ocr_graph_path=None, + presaved=128, + frame_ids=None, + median=None, + sizes = (128,128), + spatial_patch_size: int = 16, + temporal_patch_size: int = 4, + img_channels: int = 3, + patch_data: bool = False, + flatten_dim: int = 1 + ): + self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []} + self.size_1, self.size_2 = sizes + self.patch_data 
= patch_data + if self.patch_data: + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2 + self.spatial_patch_dim = img_channels * spatial_patch_size ** 2 + + assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.' + + if presaved == 224: + self.presaved = 'presaved' + elif presaved == 128: + self.presaved = 'presaved128' + else: raise ValueError + + if frame_ids is not None: + test_ids = [] + frame_dirs = os.listdir(frame_path) + episode_ids = natsorted(frame_dirs) + for frame_dir in episode_ids: + if int(frame_dir) in frame_ids: + test_ids.append(str(frame_dir)) + frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))] + self.data['frame_paths'].append(frame_paths) + elif median is not None: + test_ids = [] + frame_dirs = os.listdir(frame_path) + episode_ids = natsorted(frame_dirs) + for frame_dir in episode_ids: + frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))] + seq_len = len(frame_paths) + if (median[1] and seq_len >= median[0]) or (not median[1] and seq_len < median[0]): + self.data['frame_paths'].append(frame_paths) + test_ids.append(int(frame_dir)) + else: + frame_dirs = os.listdir(frame_path) + episode_ids = natsorted(frame_dirs) + test_ids = episode_ids.copy() + for frame_dir in episode_ids: + frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))] + self.data['frame_paths'].append(frame_paths) + + print('Labels...') + with open(label_path, 'rb') as fp: + labels = pickle.load(fp) + left_labels, right_labels = labels + for i in test_ids: + episode_id = int(i) + left_labels_i = left_labels[episode_id] + right_labels_i = right_labels[episode_id] + self.data['labels'].append([left_labels_i, right_labels_i]) + fp.close() + + print('Pose...') + self.pose_path = pose_path + if pose_path is not None: + for i in test_ids: + pose_i_dir = os.path.join(pose_path, i) + poses = [] + for j in natsorted(os.listdir(pose_i_dir)): + fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ) + if torch.tensor(fs.getNode("pose_0").mat()).shape != (2,25,3): + poses.append(fs.getNode("pose_0").mat()[:2,:,:]) + else: + poses.append(fs.getNode("pose_0").mat()) + poses = np.array(poses) + poses[:, :, :, 0] = poses[:, :, :, 0] / 1920 + poses[:, :, :, 1] = poses[:, :, :, 1] / 1088 + self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim)) + + print('Gaze...') + self.gaze_path = gaze_path + if gaze_path is not None: + for i in test_ids: + gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i)) + with open(gaze_i_txt, 'r') as fp: + gaze_i = fp.readlines() + fp.close() + gaze_i = ast.literal_eval(gaze_i[0]) + gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3)) + for j in range(len(gaze_i)): + if len(gaze_i[j]) >= 2: + gaze_i_tensor[j,:] = torch.tensor(gaze_i[j][:2]) + elif len(gaze_i[j]) == 1: + gaze_i_tensor[j,0] = torch.tensor(gaze_i[j][0]) + else: + continue + self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim)) + + print('Bbox...') + self.bbox_path = bbox_path + if bbox_path is not None: + self.objects = [ + 'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener', + 'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine', + 'cracker', 
'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan', + 'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener' + ] + """ NOTE: old bbox + for i in test_ids: + bbox_i_dir = os.path.join(bbox_path, i) + with open(bbox_i_dir, 'rb') as fp: + bboxes_i = pickle.load(fp) + len_i = len(bboxes_i) + fp.close() + bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4)) + for j in range(len(bboxes_i)): + items_i_j, bboxes_i_j = bboxes_i[j] + for k in range(len(items_i_j)): + bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([ + bboxes_i_j[k][0] / 1920, # * self.size_1, + bboxes_i_j[k][1] / 1088, # * self.size_2, + bboxes_i_j[k][2] / 1920, # * self.size_1, + bboxes_i_j[k][3] / 1088, # * self.size_2 + ]) # [x_min, y_min, x_max, y_max] + self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1)) + """ + for i in test_ids: + bbox_dir = os.path.join(bbox_path, i) + bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4)) # TODO: we might want to cut it to 10 objects + for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)): + with open(os.path.join(bbox_dir, bbox), 'r') as fp: + bbox_content = fp.readlines() + fp.close() + for bbox_content_line in bbox_content: + bbox_content_values = bbox_content_line.split() + class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values) + bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height]) + self.data['bboxes'].append(torch.flatten(bbox_tensor, 1)) + + print('OCR...\n') + self.ocr_graph_path = ocr_graph_path + if ocr_graph_path is not None: + ocr_graph = [ + [15, [10, 4], [17, 2]], + [13, [16, 7], [18, 4]], + [11, [16, 4], [7, 10]], + [14, [10, 11], [7, 1]], + [12, [10, 9], [16, 3]], + [1, [7, 2], [9, 9], [10, 2]], + [5, [8, 8], [6, 8]], + [4, [9, 8], [7, 6]], + [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]], + [2, [10, 1], [7, 7], [9, 3]], + [19, [10, 2], [26, 6]], + [20, [10, 7], [26, 5]], + [22, [25, 4], [10, 8]], + [23, [25, 15]], + [21, [16, 5], [24, 8]] + ] + ocr_tensor = torch.zeros((27, 27)) + for ocr in ocr_graph: + obj = ocr[0] + contexts = ocr[1:] + total_context_count = sum([i[1] for i in contexts]) + for context in contexts: + ocr_tensor[obj, context[0]] = context[1] / total_context_count + ocr_tensor = torch.flatten(ocr_tensor) + for i in test_ids: + self.data['ocr_graph'].append(ocr_tensor) + + self.frame_path = frame_path + self.transform = transforms.Compose([ + #transforms.ToPILImage(), + transforms.Resize((self.size_1, self.size_2)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + def __len__(self): + return len(self.data['frame_paths']) + + def __getitem__(self, idx): + frames_path = self.data['frame_paths'][idx][0].split('/') + frames_path.insert(5, self.presaved) + frames_path = ''.join(f'{w}/' for w in frames_path[:-1])[:-1]+'.pkl' + with open(frames_path, 'rb') as f: + images = pickle.load(f) + images = torch.stack(images) + if self.patch_data: + images = rearrange( + images, + '(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)', + p1=self.spatial_patch_size, + p2=self.spatial_patch_size, + p3=self.temporal_patch_size + ) + + if self.pose_path is not None: + pose = self.data['poses'][idx] + if self.patch_data: + pose = rearrange( + pose, + '(t p1) d -> t p1 d', + p1=self.temporal_patch_size + ) + else: + pose = None + + if self.gaze_path is not None: + gaze = self.data['gazes'][idx] + if self.patch_data: + gaze = rearrange( + gaze, + '(t p1) d 
-> t p1 d', + p1=self.temporal_patch_size + ) + else: + gaze = None + + if self.bbox_path is not None: + bbox = self.data['bboxes'][idx] + if self.patch_data: + bbox = rearrange( + bbox, + '(t p1) d -> t p1 d', + p1=self.temporal_patch_size + ) + else: + bbox = None + + if self.ocr_graph_path is not None: + ocr_graphs = self.data['ocr_graph'][idx] + else: + ocr_graphs = None + + return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs + diff --git a/boss/environment.yml b/boss/environment.yml new file mode 100644 index 0000000..2142bac --- /dev/null +++ b/boss/environment.yml @@ -0,0 +1,240 @@ +name: boss +channels: + - pyg + - anaconda + - conda-forge + - defaults + - pytorch +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - alembic=1.8.0=pyhd8ed1ab_0 + - appdirs=1.4.4=pyh9f0ad1d_0 + - asttokens=2.2.1=pyhd8ed1ab_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - backports=1.0=pyhd8ed1ab_3 + - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 + - blas=1.0=mkl + - blosc=1.21.3=h6a678d5_0 + - bottle=0.12.23=pyhd8ed1ab_0 + - bottleneck=1.3.5=py38h7deecbd_0 + - brotli=1.0.9=h5eee18b_7 + - brotli-bin=1.0.9=h5eee18b_7 + - brotlipy=0.7.0=py38h27cfd23_1003 + - brunsli=0.1=h2531618_0 + - bzip2=1.0.8=h7b6447c_0 + - c-ares=1.18.1=h7f8727e_0 + - ca-certificates=2023.01.10=h06a4308_0 + - certifi=2023.5.7=py38h06a4308_0 + - cffi=1.15.1=py38h74dc2b5_0 + - cfitsio=3.470=h5893167_7 + - charls=2.2.0=h2531618_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - click=8.1.3=unix_pyhd8ed1ab_2 + - cloudpickle=2.0.0=pyhd3eb1b0_0 + - cmaes=0.9.1=pyhd8ed1ab_0 + - colorlog=6.7.0=py38h578d9bd_1 + - contourpy=1.0.5=py38hdb19cb5_0 + - cryptography=39.0.1=py38h9ce1e76_0 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - cycler=0.11.0=pyhd3eb1b0_0 + - cytoolz=0.12.0=py38h5eee18b_0 + - dask-core=2022.7.0=py38h06a4308_0 + - dbus=1.13.18=hb2f20db_0 + - debugpy=1.5.1=py38h295c915_0 + - decorator=5.1.1=pyhd8ed1ab_0 + - docker-pycreds=0.4.0=py_0 + - entrypoints=0.4=pyhd8ed1ab_0 + - executing=1.2.0=pyhd8ed1ab_0 + - expat=2.4.9=h6a678d5_0 + - ffmpeg=4.3=hf484d3e_0 + - fftw=3.3.9=h27cfd23_1 + - flit-core=3.6.0=pyhd3eb1b0_0 + - fontconfig=2.14.1=h52c9d5c_1 + - fonttools=4.25.0=pyhd3eb1b0_0 + - freetype=2.12.1=h4a9f257_0 + - fsspec=2022.11.0=py38h06a4308_0 + - giflib=5.2.1=h5eee18b_1 + - gitdb=4.0.10=pyhd8ed1ab_0 + - gitpython=3.1.31=pyhd8ed1ab_0 + - glib=2.69.1=h4ff587b_1 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - greenlet=2.0.1=py38h6a678d5_0 + - gst-plugins-base=1.14.1=h6a678d5_1 + - gstreamer=1.14.1=h5eee18b_1 + - icu=58.2=he6710b0_3 + - idna=3.4=py38h06a4308_0 + - imagecodecs=2021.8.26=py38hf0132c2_1 + - imageio=2.19.3=py38h06a4308_0 + - importlib-metadata=6.1.0=pyha770c72_0 + - importlib_resources=5.2.0=pyhd3eb1b0_1 + - intel-openmp=2021.4.0=h06a4308_3561 + - ipykernel=6.15.0=pyh210e3f2_0 + - ipython=8.11.0=pyh41d4057_0 + - jedi=0.18.2=pyhd8ed1ab_0 + - jinja2=3.1.2=py38h06a4308_0 + - joblib=1.1.1=py38h06a4308_0 + - jpeg=9e=h7f8727e_0 + - jupyter_client=7.0.6=pyhd8ed1ab_0 + - jupyter_core=4.12.0=py38h578d9bd_0 + - jxrlib=1.1=h7b6447c_2 + - kiwisolver=1.4.4=py38h6a678d5_0 + - krb5=1.19.4=h568e23c_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libaec=1.0.4=he6710b0_1 + - libbrotlicommon=1.0.9=h5eee18b_7 + - libbrotlidec=1.0.9=h5eee18b_7 + - libbrotlienc=1.0.9=h5eee18b_7 + - libclang=10.0.1=default_hb85057a_2 + - libcurl=7.87.0=h91b91d3_0 + - libdeflate=1.8=h7f8727e_5 + - 
libedit=3.1.20221030=h5eee18b_0 + - libev=4.33=h7f8727e_1 + - libevent=2.1.12=h8f2d780_0 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgfortran-ng=11.2.0=h00389a5_1 + - libgfortran5=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.2=h7f8727e_0 + - libllvm10=10.0.1=hbcb73fb_5 + - libnghttp2=1.46.0=hce63b2e_0 + - libpng=1.6.37=hbc83047_0 + - libpq=12.9=h16c4e8d_3 + - libprotobuf=3.20.3=he621ea3_0 + - libsodium=1.0.18=h36c2ea0_1 + - libssh2=1.10.0=h8f2d780_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.5.0=h6a678d5_1 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libwebp=1.2.4=h11a3e52_0 + - libwebp-base=1.2.4=h5eee18b_0 + - libxcb=1.15=h7f8727e_0 + - libxkbcommon=1.0.1=hfa300c1_0 + - libxml2=2.9.14=h74e7548_0 + - libxslt=1.1.35=h4e12654_0 + - libzopfli=1.0.3=he6710b0_0 + - locket=1.0.0=py38h06a4308_0 + - lz4-c=1.9.4=h6a678d5_0 + - mako=1.2.4=pyhd8ed1ab_0 + - markupsafe=2.1.1=py38h7f8727e_0 + - matplotlib=3.7.0=py38h06a4308_0 + - matplotlib-base=3.7.0=py38h417a72b_0 + - matplotlib-inline=0.1.6=pyhd8ed1ab_0 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py38h7f8727e_0 + - mkl_fft=1.3.1=py38hd3c417c_0 + - mkl_random=1.2.2=py38h51133e4_0 + - munkres=1.1.4=py_0 + - natsort=7.1.1=pyhd3eb1b0_0 + - ncurses=6.3=h5eee18b_3 + - nest-asyncio=1.5.6=pyhd8ed1ab_0 + - nettle=3.7.3=hbbd107a_1 + - networkx=2.8.4=py38h06a4308_0 + - nspr=4.33=h295c915_0 + - nss=3.74=h0370c37_0 + - numexpr=2.8.4=py38he184ba9_0 + - numpy=1.23.5=py38h14f4228_0 + - numpy-base=1.23.5=py38h31eccc5_0 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=1.1.1t=h7f8727e_0 + - optuna=3.1.0=pyhd8ed1ab_0 + - optuna-dashboard=0.9.0=pyhd8ed1ab_0 + - packaging=22.0=py38h06a4308_0 + - pandas=1.5.3=py38h417a72b_0 + - parso=0.8.3=pyhd8ed1ab_0 + - partd=1.2.0=pyhd3eb1b0_1 + - pathtools=0.1.2=py_1 + - pcre=8.45=h295c915_0 + - pexpect=4.8.0=pyh1a96a4e_2 + - pickleshare=0.7.5=py_1003 + - pillow=9.4.0=py38h6a678d5_0 + - pip=22.3.1=py38h06a4308_0 + - ply=3.11=py38_0 + - prompt-toolkit=3.0.38=pyha770c72_0 + - prompt_toolkit=3.0.38=hd8ed1ab_0 + - protobuf=3.20.3=py38h6a678d5_0 + - psutil=5.9.0=py38h5eee18b_0 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pure_eval=0.2.2=pyhd8ed1ab_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyg=2.2.0=py38_torch_1.12.0_cu113 + - pygments=2.14.0=pyhd8ed1ab_0 + - pyopenssl=23.0.0=py38h06a4308_0 + - pyparsing=3.0.9=py38h06a4308_0 + - pyqt=5.15.7=py38h6a678d5_1 + - pyqt5-sip=12.11.0=py38h6a678d5_1 + - pysocks=1.7.1=py38h06a4308_0 + - python=3.8.13=haa1d7c7_1 + - python-dateutil=2.8.2=pyhd8ed1ab_0 + - python_abi=3.8=2_cp38 + - pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0 + - pytorch-cluster=1.6.0=py38_torch_1.12.0_cu113 + - pytorch-mutex=1.0=cuda + - pytorch-scatter=2.1.0=py38_torch_1.12.0_cu113 + - pytorch-sparse=0.6.16=py38_torch_1.12.0_cu113 + - pytz=2022.7=py38h06a4308_0 + - pywavelets=1.4.1=py38h5eee18b_0 + - pyyaml=6.0=py38h5eee18b_1 + - pyzmq=19.0.2=py38ha71036d_2 + - qt-main=5.15.2=h327a75a_7 + - qt-webengine=5.15.9=hd2b0992_4 + - qtwebkit=5.212=h4eab89a_4 + - readline=8.2=h5eee18b_0 + - requests=2.28.1=py38h06a4308_1 + - scikit-image=0.19.3=py38h6a678d5_1 + - scikit-learn=1.2.1=py38h6a678d5_0 + - scipy=1.9.3=py38h14f4228_0 + - sentry-sdk=1.16.0=pyhd8ed1ab_0 + - setproctitle=1.2.2=py38h0a891b7_2 + - setuptools=65.6.3=py38h06a4308_0 + - sip=6.6.2=py38h6a678d5_0 + - six=1.16.0=pyhd3eb1b0_1 + - smmap=3.0.5=pyh44b312d_0 + - snappy=1.1.9=h295c915_0 + - sqlalchemy=1.4.39=py38h5eee18b_0 + - 
sqlite=3.40.1=h5082296_0 + - stack_data=0.6.2=pyhd8ed1ab_0 + - threadpoolctl=2.2.0=pyh0d69192_0 + - tifffile=2021.7.2=pyhd3eb1b0_2 + - tk=8.6.12=h1ccaba5_0 + - toml=0.10.2=pyhd3eb1b0_0 + - toolz=0.12.0=py38h06a4308_0 + - torchaudio=0.12.0=py38_cu113 + - torchvision=0.13.0=py38_cu113 + - tornado=6.1=py38h0a891b7_3 + - tqdm=4.64.1=py38h06a4308_0 + - traitlets=5.9.0=pyhd8ed1ab_0 + - typing_extensions=4.4.0=py38h06a4308_0 + - tzdata=2022g=h04d1e81_0 + - urllib3=1.26.14=py38h06a4308_0 + - wandb=0.13.10=pyhd8ed1ab_0 + - wcwidth=0.2.6=pyhd8ed1ab_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - xz=5.2.10=h5eee18b_1 + - yaml=0.2.5=h7b6447c_0 + - zeromq=4.3.4=h9c3ff4c_1 + - zfp=0.5.5=h295c915_6 + - zipp=3.11.0=py38h06a4308_0 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.2=ha4553b6_0 + - pip: + - einops==0.6.1 + - lion-pytorch==0.1.2 + - llvmlite==0.40.0 + - memory-efficient-attention-pytorch==0.1.2 + - numba==0.57.0 + - opencv-python==4.7.0.72 + - pynndescent==0.5.10 + - seaborn==0.12.2 + - umap-learn==0.5.3 + - x-transformers==1.14.0 +prefix: /opt/anaconda3/envs/boss \ No newline at end of file diff --git a/boss/models/__init__.py b/boss/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/boss/models/base.py b/boss/models/base.py new file mode 100644 index 0000000..e5246f2 --- /dev/null +++ b/boss/models/base.py @@ -0,0 +1,230 @@ +import torch +import torch.nn as nn +from .utils import pose_edge_index +from torch_geometric.nn import Sequential, GCNConv +from x_transformers import ContinuousTransformerWrapper, Decoder + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = nn.LayerNorm(dim) + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim), + nn.GELU(), + nn.Linear(dim, dim)) + + def forward(self, x): + return self.net(x) + + +class CNN(nn.Module): + def __init__(self, hidden_dim): + super(CNN, self).__init__() + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1) + self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = self.conv1(x) + x = nn.functional.relu(x) + x = self.pool(x) + x = self.conv2(x) + x = nn.functional.relu(x) + x = self.pool(x) + x = self.conv3(x) + x = nn.functional.relu(x) + x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling + return x + + +class MindNetLSTM(nn.Module): + """ + Basic MindNet for model-based ToM, just LSTM on input concatenation + """ + def __init__(self, hidden_dim, dropout, mods): + super(MindNetLSTM, self).__init__() + self.mods = mods + self.gaze_emb = nn.Linear(3, hidden_dim) + self.pose_edge_index = pose_edge_index() + self.pose_emb = GCNConv(3, hidden_dim) + self.LSTM = PreNorm( + hidden_dim*len(mods), + nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True)) + self.proj = nn.Linear(hidden_dim*2, hidden_dim) + self.dropout = nn.Dropout(dropout) + self.act = nn.GELU() + + def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats): + feats = [] + if 'rgb' in self.mods: + feats.append(rgb_feats) + if 'ocr' in self.mods: + feats.append(ocr_feats) + if 'pose' in self.mods: + bs, seq_len = pose.size(0), pose.size(1) + self.pose_edge_index = self.pose_edge_index.to(pose.device) + 
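+            # Run the GCN over the 25 pose keypoints (3 values each) using the precomputed pose edge index,
+            # then mean-pool the keypoint dimension to obtain one pose embedding per frame.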
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index) + pose_emb = self.dropout(self.act(pose_emb)) + pose_emb = torch.mean(pose_emb, dim=1) + hd = pose_emb.size(-1) + feats.append(pose_emb.view(bs, seq_len, hd)) + if 'gaze' in self.mods: + gaze_feats = self.dropout(self.act(self.gaze_emb(gaze))) + feats.append(gaze_feats) + if 'bbox' in self.mods: + feats.append(bbox_feats) + lstm_inp = torch.cat(feats, 2) + lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp)) + c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2) + return self.act(self.proj(lstm_out)), c_n, feats + + +class MindNetSL(nn.Module): + """ + Basic MindNet for SL ToM, just LSTM on input concatenation + """ + def __init__(self, hidden_dim, dropout, mods): + super(MindNetSL, self).__init__() + self.mods = mods + self.gaze_emb = nn.Linear(3, hidden_dim) + self.pose_edge_index = pose_edge_index() + self.pose_emb = GCNConv(3, hidden_dim) + self.LSTM = PreNorm( + hidden_dim*5, + nn.LSTM(input_size=hidden_dim*5, hidden_size=hidden_dim, batch_first=True, bidirectional=True) + ) + self.proj = nn.Linear(hidden_dim*2, hidden_dim) + self.dropout = nn.Dropout(dropout) + self.act = nn.GELU() + + def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats): + feats = [] + if 'rgb' in self.mods: + feats.append(rgb_feats) + if 'ocr' in self.mods: + feats.append(ocr_feats) + if 'pose' in self.mods: + bs, seq_len = pose.size(0), pose.size(1) + self.pose_edge_index = self.pose_edge_index.to(pose.device) + pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index) + pose_emb = self.dropout(self.act(pose_emb)) + pose_emb = torch.mean(pose_emb, dim=1) + hd = pose_emb.size(-1) + feats.append(pose_emb.view(bs, seq_len, hd)) + if 'gaze' in self.mods: + gaze_feats = self.dropout(self.act(self.gaze_emb(gaze))) + feats.append(gaze_feats) + if 'bbox' in self.mods: + feats.append(bbox_feats) + lstm_inp = torch.cat(feats, 2) + lstm_out, _ = self.LSTM(self.dropout(lstm_inp)) + return self.act(self.proj(lstm_out)), feats + + +class MindNetTF(nn.Module): + """ + Basic MindNet for model-based ToM, Transformer on input concatenation + """ + def __init__(self, hidden_dim, dropout, mods): + super(MindNetTF, self).__init__() + self.mods = mods + self.gaze_emb = nn.Linear(3, hidden_dim) + self.pose_edge_index = pose_edge_index() + self.pose_emb = GCNConv(3, hidden_dim) + self.tf = ContinuousTransformerWrapper( + dim_in=hidden_dim*len(mods), + dim_out=hidden_dim, + max_seq_len=747, + attn_layers=Decoder( + dim=512, + depth=6, + heads=8 + ) + ) + self.dropout = nn.Dropout(dropout) + self.act = nn.GELU() + + def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats): + feats = [] + if 'rgb' in self.mods: + feats.append(rgb_feats) + if 'ocr' in self.mods: + feats.append(ocr_feats) + if 'pose' in self.mods: + bs, seq_len = pose.size(0), pose.size(1) + self.pose_edge_index = self.pose_edge_index.to(pose.device) + pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index) + pose_emb = self.dropout(self.act(pose_emb)) + pose_emb = torch.mean(pose_emb, dim=1) + hd = pose_emb.size(-1) + feats.append(pose_emb.view(bs, seq_len, hd)) + if 'gaze' in self.mods: + gaze_feats = self.dropout(self.act(self.gaze_emb(gaze))) + feats.append(gaze_feats) + if 'bbox' in self.mods: + feats.append(bbox_feats) + tf_inp = torch.cat(feats, 2) + tf_out = self.tf(self.dropout(tf_inp)) + return tf_out, feats + + +class MindNetLSTMXL(nn.Module): + """ + Basic MindNet for model-based ToM, just LSTM on input concatenation + """ + def 
__init__(self, hidden_dim, dropout, mods): + super(MindNetLSTMXL, self).__init__() + self.mods = mods + self.gaze_emb = nn.Sequential( + nn.Linear(3, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, hidden_dim) + ) + self.pose_edge_index = pose_edge_index() + self.pose_emb = Sequential('x, edge_index', [ + (GCNConv(3, hidden_dim), 'x, edge_index -> x'), + nn.GELU(), + (GCNConv(hidden_dim, hidden_dim), 'x, edge_index -> x'), + ]) + self.LSTM = PreNorm( + hidden_dim*len(mods), + nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, num_layers=2, batch_first=True, bidirectional=True) + ) + self.proj = nn.Linear(hidden_dim*2, hidden_dim) + self.dropout = nn.Dropout(dropout) + self.act = nn.GELU() + + def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats): + feats = [] + if 'rgb' in self.mods: + feats.append(rgb_feats) + if 'ocr' in self.mods: + feats.append(ocr_feats) + if 'pose' in self.mods: + bs, seq_len = pose.size(0), pose.size(1) + self.pose_edge_index = self.pose_edge_index.to(pose.device) + pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index) + pose_emb = self.dropout(self.act(pose_emb)) + pose_emb = torch.mean(pose_emb, dim=1) + hd = pose_emb.size(-1) + feats.append(pose_emb.view(bs, seq_len, hd)) + if 'gaze' in self.mods: + gaze_feats = self.dropout(self.act(self.gaze_emb(gaze))) + feats.append(gaze_feats) + if 'bbox' in self.mods: + feats.append(bbox_feats) + lstm_inp = torch.cat(feats, 2) + lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp)) + c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2) + return self.act(self.proj(lstm_out)), c_n, feats \ No newline at end of file diff --git a/boss/models/resnet.py b/boss/models/resnet.py new file mode 100644 index 0000000..c3f810c --- /dev/null +++ b/boss/models/resnet.py @@ -0,0 +1,249 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from .utils import left_bias, right_bias + + +class ResNet(nn.Module): + def __init__(self, input_dim, device): + super(ResNet, self).__init__() + # Conv + resnet = models.resnet34(pretrained=True) + self.resnet = nn.Sequential(*(list(resnet.children())[:-1])) + # FFs + self.left = nn.Linear(input_dim, 27) + self.right = nn.Linear(input_dim, 27) + # modality FFs + self.pose_ff = nn.Linear(150, 150) + self.gaze_ff = nn.Linear(6, 6) + self.bbox_ff = nn.Linear(108, 108) + self.ocr_ff = nn.Linear(729, 64) + # others + self.relu = nn.ReLU() + self.dropout = nn.Dropout(p=0.2) + self.device = device + + def forward(self, images, poses, gazes, bboxes, ocr_tensor): + batch_size, sequence_len, channels, height, width = images.shape + left_beliefs = [] + right_beliefs = [] + image_feats = [] + + for i in range(sequence_len): + images_i = images[:,i].to(self.device) + image_i_feat = self.resnet(images_i) + image_i_feat = image_i_feat.view(batch_size, 512) + if poses is not None: + poses_i = poses[:,i].float() + poses_i_feat = self.relu(self.pose_ff(poses_i)) + image_i_feat = torch.cat([image_i_feat, poses_i_feat], 1) + if gazes is not None: + gazes_i = gazes[:,i].float() + gazes_i_feat = self.relu(self.gaze_ff(gazes_i)) + image_i_feat = torch.cat([image_i_feat, gazes_i_feat], 1) + if bboxes is not None: + bboxes_i = bboxes[:,i].float() + bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i)) + image_i_feat = torch.cat([image_i_feat, bboxes_i_feat], 1) + if ocr_tensor is not None: + ocr_tensor = ocr_tensor + ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor)) + image_i_feat = torch.cat([image_i_feat, ocr_tensor_feat], 1) + 
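+            # image_i_feat now holds the per-frame ResNet embedding concatenated with whichever
+            # optional modality features (pose, gaze, bbox, OCR) were provided.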
image_feats.append(image_i_feat) + + image_feats = torch.permute(torch.stack(image_feats), (1,0,2)) + left_beliefs = self.left(self.dropout(image_feats)) + right_beliefs = self.right(self.dropout(image_feats)) + + return left_beliefs, right_beliefs, None + + +class ResNetGRU(nn.Module): + def __init__(self, input_dim, device): + super(ResNetGRU, self).__init__() + resnet = models.resnet34(pretrained=True) + self.resnet = nn.Sequential(*(list(resnet.children())[:-1])) + self.gru = nn.GRU(input_dim, 512, batch_first=True) + for name, param in self.gru.named_parameters(): + if "weight" in name: + nn.init.orthogonal_(param) + elif "bias" in name: + nn.init.constant_(param, 0) + # FFs + self.left = nn.Linear(512, 27) + self.right = nn.Linear(512, 27) + # modality FFs + self.pose_ff = nn.Linear(150, 150) + self.gaze_ff = nn.Linear(6, 6) + self.bbox_ff = nn.Linear(108, 108) + self.ocr_ff = nn.Linear(729, 64) + # others + self.dropout = nn.Dropout(p=0.2) + self.relu = nn.ReLU() + self.device = device + + def forward(self, images, poses, gazes, bboxes, ocr_tensor): + batch_size, sequence_len, channels, height, width = images.shape + left_beliefs = [] + right_beliefs = [] + rnn_inp = [] + + for i in range(sequence_len): + images_i = images[:,i] + rnn_i_feat = self.resnet(images_i) + rnn_i_feat = rnn_i_feat.view(batch_size, 512) + if poses is not None: + poses_i = poses[:,i].float() + poses_i_feat = self.relu(self.pose_ff(poses_i)) + rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1) + if gazes is not None: + gazes_i = gazes[:,i].float() + gazes_i_feat = self.relu(self.gaze_ff(gazes_i)) + rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1) + if bboxes is not None: + bboxes_i = bboxes[:,i].float() + bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i)) + rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1) + if ocr_tensor is not None: + ocr_tensor = ocr_tensor + ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor)) + rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1) + rnn_inp.append(rnn_i_feat) + + rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2)) + rnn_out, _ = self.gru(rnn_inp) + left_beliefs = self.left(self.dropout(rnn_out)) + right_beliefs = self.right(self.dropout(rnn_out)) + + return left_beliefs, right_beliefs, None + + +class ResNetConv1D(nn.Module): + def __init__(self, input_dim, device): + super(ResNetConv1D, self).__init__() + resnet = models.resnet34(pretrained=True) + self.resnet = nn.Sequential(*(list(resnet.children())[:-1])) + self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=5, padding=4) + # FFs + self.left = nn.Linear(512, 27) + self.right = nn.Linear(512, 27) + # modality FFs + self.pose_ff = nn.Linear(150, 150) + self.gaze_ff = nn.Linear(6, 6) + self.bbox_ff = nn.Linear(108, 108) + self.ocr_ff = nn.Linear(729, 64) + # others + self.relu = nn.ReLU() + self.dropout = nn.Dropout(p=0.2) + self.device = device + + def forward(self, images, poses, gazes, bboxes, ocr_tensor): + batch_size, sequence_len, channels, height, width = images.shape + left_beliefs = [] + right_beliefs = [] + conv1d_inp = [] + + for i in range(sequence_len): + images_i = images[:,i] + images_i_feat = self.resnet(images_i) + images_i_feat = images_i_feat.view(batch_size, 512) + if poses is not None: + poses_i = poses[:,i].float() + poses_i_feat = self.relu(self.pose_ff(poses_i)) + images_i_feat = torch.cat([images_i_feat, poses_i_feat], 1) + if gazes is not None: + gazes_i = gazes[:,i].float() + gazes_i_feat = self.relu(self.gaze_ff(gazes_i)) + images_i_feat = 
torch.cat([images_i_feat, gazes_i_feat], 1) + if bboxes is not None: + bboxes_i = bboxes[:,i].float() + bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i)) + images_i_feat = torch.cat([images_i_feat, bboxes_i_feat], 1) + if ocr_tensor is not None: + ocr_tensor = ocr_tensor + ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor)) + images_i_feat = torch.cat([images_i_feat, ocr_tensor_feat], 1) + conv1d_inp.append(images_i_feat) + + conv1d_inp = torch.permute(torch.stack(conv1d_inp), (1,2,0)) + conv1d_out = self.conv1d(conv1d_inp) + conv1d_out = conv1d_out[:,:,:-4] + conv1d_out = self.relu(torch.permute(conv1d_out, (0,2,1))) + left_beliefs = self.left(self.dropout(conv1d_out)) + right_beliefs = self.right(self.dropout(conv1d_out)) + + return left_beliefs, right_beliefs, None + + +class ResNetLSTM(nn.Module): + def __init__(self, input_dim, device): + super(ResNetLSTM, self).__init__() + resnet = models.resnet34(pretrained=True) + self.resnet = nn.Sequential(*(list(resnet.children())[:-1])) + self.lstm = nn.LSTM(input_size=input_dim, hidden_size=512, batch_first=True) + # FFs + self.left = nn.Linear(512, 27) + self.right = nn.Linear(512, 27) + self.left.bias.data = torch.tensor(left_bias).log() + self.right.bias.data = torch.tensor(right_bias).log() + # modality FFs + self.pose_ff = nn.Linear(150, 150) + self.gaze_ff = nn.Linear(6, 6) + self.bbox_ff = nn.Linear(108, 108) + self.ocr_ff = nn.Linear(729, 64) + # others + self.relu = nn.ReLU() + self.dropout = nn.Dropout(p=0.2) + self.device = device + + def forward(self, images, poses, gazes, bboxes, ocr_tensor): + batch_size, sequence_len, channels, height, width = images.shape + left_beliefs = [] + right_beliefs = [] + rnn_inp = [] + + for i in range(sequence_len): + images_i = images[:,i] + rnn_i_feat = self.resnet(images_i) + rnn_i_feat = rnn_i_feat.view(batch_size, 512) + if poses is not None: + poses_i = poses[:,i].float() + poses_i_feat = self.relu(self.pose_ff(poses_i)) + rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1) + if gazes is not None: + gazes_i = gazes[:,i].float() + gazes_i_feat = self.relu(self.gaze_ff(gazes_i)) + rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1) + if bboxes is not None: + bboxes_i = bboxes[:,i].float() + bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i)) + rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1) + if ocr_tensor is not None: + ocr_tensor = ocr_tensor + ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor)) + rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1) + rnn_inp.append(rnn_i_feat) + + rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2)) + rnn_out, _ = self.lstm(rnn_inp) + left_beliefs = self.left(self.dropout(rnn_out)) + right_beliefs = self.right(self.dropout(rnn_out)) + + return left_beliefs, right_beliefs, None + + + + + + + + + + +if __name__ == '__main__': + images = torch.ones(3, 22, 3, 128, 128) + poses = torch.ones(3, 22, 150) + gazes = torch.ones(3, 22, 6) + bboxes = torch.ones(3, 22, 108) + model = ResNet(32, 'cpu') + print(model(images, poses, gazes, bboxes, None)[0].shape) diff --git a/boss/models/single_mindnet.py b/boss/models/single_mindnet.py new file mode 100644 index 0000000..c9c6aa8 --- /dev/null +++ b/boss/models/single_mindnet.py @@ -0,0 +1,95 @@ +import torch +import torch.nn as nn +from torch_geometric.nn.conv import GCNConv +from .utils import left_bias, right_bias, build_ocr_graph +from .base import CNN, MindNetLSTM +import torchvision.models as models + + +class SingleMindNet(nn.Module): + """ + Base ToM net. 
Supports any subset of modalities + """ + def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']): + super(SingleMindNet, self).__init__() + + # ---- Images ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + for param in self.cnn.parameters(): + param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + + # ---- OCR and bbox -----# + self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device) + self.ocr_gnn = GCNConv(-1, hidden_dim) + self.bbox_ff = nn.Linear(108, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net = MindNetLSTM(hidden_dim, dropout, mods) + + self.left = nn.Linear(hidden_dim, 27) + self.right = nn.Linear(hidden_dim, 27) + self.left.bias.data = torch.tensor(left_bias).log() + self.right.bias.data = torch.tensor(right_bias).log() + + def forward(self, images, poses, gazes, bboxes, ocr_tensor=None): + + assert images is not None + assert poses is not None + assert gazes is not None + assert bboxes is not None + + batch_size, sequence_len, channels, height, width = images.shape + + bbox_feat = self.act(self.bbox_ff(bboxes)) + + rgb_feat = [] + for i in range(sequence_len): + images_i = images[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat))) + + ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x, + self.ocr_edge_index, + self.ocr_edge_attr))) + ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1) + + out_left, cell_left, feats_left = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat) + out_right, cell_right, feats_right = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat) + + return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right + + + + +if __name__ == "__main__": + + def count_parameters(model): + import numpy as np + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + return sum([np.prod(p.size()) for p in model_parameters]) + + images = torch.ones(3, 22, 3, 128, 128) + poses = torch.ones(3, 22, 2, 75) + gazes = torch.ones(3, 22, 2, 3) + bboxes = torch.ones(3, 22, 108) + model = SingleMindNet(64, 'cpu', False, 0.5) + print(count_parameters(model)) + breakpoint() + out = model(images, poses, gazes, bboxes, None) + print(out[0].shape) \ No newline at end of file diff --git a/boss/models/tom_base.py b/boss/models/tom_base.py new file mode 100644 index 0000000..ae45d91 --- /dev/null +++ b/boss/models/tom_base.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from torch_geometric.nn.conv import GCNConv +from .utils import left_bias, right_bias, build_ocr_graph +from .base import CNN, MindNetLSTM +import numpy as np + + +class BaseToMnet(nn.Module): + """ + Base ToM net. 
Supports any subset of modalities + """ + def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']): + super(BaseToMnet, self).__init__() + + # ---- Images ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + for param in self.cnn.parameters(): + param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + + # ---- OCR and bbox -----# + self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device) + self.ocr_gnn = GCNConv(-1, hidden_dim) + self.bbox_ff = nn.Linear(108, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods) + self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods) + + self.left = nn.Linear(hidden_dim, 27) + self.right = nn.Linear(hidden_dim, 27) + self.left.bias.data = torch.tensor(left_bias).log() + self.right.bias.data = torch.tensor(right_bias).log() + + def forward(self, images, poses, gazes, bboxes, ocr_tensor=None): + + assert images is not None + assert poses is not None + assert gazes is not None + assert bboxes is not None + + batch_size, sequence_len, channels, height, width = images.shape + + bbox_feat = self.act(self.bbox_ff(bboxes)) + + rgb_feat = [] + for i in range(sequence_len): + images_i = images[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat))) + + ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x, + self.ocr_edge_index, + self.ocr_edge_attr))) + ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1) + + out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat) + out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat) + + return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right + + + + +def count_parameters(model): + #return sum(p.numel() for p in model.parameters() if p.requires_grad) + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + return sum([np.prod(p.size()) for p in model_parameters]) + + + + + + + + +if __name__ == "__main__": + + images = torch.ones(3, 22, 3, 128, 128) + poses = torch.ones(3, 22, 2, 75) + gazes = torch.ones(3, 22, 2, 3) + bboxes = torch.ones(3, 22, 108) + model = BaseToMnet(64, 'cpu', False, 0.5) + print(count_parameters(model)) + breakpoint() + out = model(images, poses, gazes, bboxes, None) + print(out[0].shape) \ No newline at end of file diff --git a/boss/models/tom_common_mind.py b/boss/models/tom_common_mind.py new file mode 100644 index 0000000..1e1c158 --- /dev/null +++ b/boss/models/tom_common_mind.py @@ -0,0 +1,269 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from torch_geometric.nn.conv import GCNConv +from .utils import left_bias, right_bias, build_ocr_graph +from .base import CNN, MindNetLSTM, MindNetLSTMXL +from memory_efficient_attention_pytorch import Attention + + +class CommonMindToMnet(nn.Module): + """ + + """ + def __init__(self, hidden_dim, device, 
resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']): + super(CommonMindToMnet, self).__init__() + + self.aggr = aggr + + # ---- Images ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + #for param in self.cnn.parameters(): + # param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + + # ---- OCR and bbox -----# + self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device) + self.ocr_gnn = GCNConv(-1, hidden_dim) + self.bbox_ff = nn.Linear(108, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods) + self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods) + self.proj = nn.Linear(hidden_dim*2, hidden_dim) + self.ln_left = nn.LayerNorm(hidden_dim) + self.ln_right = nn.LayerNorm(hidden_dim) + if aggr == 'attn': + self.attn_left = Attention( + dim = hidden_dim, + dim_head = hidden_dim // 4, + heads = 4, + memory_efficient = True, + q_bucket_size = hidden_dim, + k_bucket_size = hidden_dim) + self.attn_right = Attention( + dim = hidden_dim, + dim_head = hidden_dim // 4, + heads = 4, + memory_efficient = True, + q_bucket_size = hidden_dim, + k_bucket_size = hidden_dim) + self.left = nn.Linear(hidden_dim, 27) + self.right = nn.Linear(hidden_dim, 27) + self.left.bias.data = torch.tensor(left_bias).log() + self.right.bias.data = torch.tensor(right_bias).log() + + def forward(self, images, poses, gazes, bboxes, ocr_tensor=None): + + assert images is not None + assert poses is not None + assert gazes is not None + assert bboxes is not None + + batch_size, sequence_len, channels, height, width = images.shape + + bbox_feat = self.act(self.bbox_ff(bboxes)) + + rgb_feat = [] + for i in range(sequence_len): + images_i = images[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat))) + + ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x, + self.ocr_edge_index, + self.ocr_edge_attr))) + ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1) + + out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat) + out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat) + + common_mind = self.proj(torch.cat([cell_left, cell_right], -1)) + + if self.aggr == 'no_tom': + return self.left(out_left), self.right(out_right) + + if self.aggr == 'attn': + l = self.attn_left(x=out_left, context=common_mind) + r = self.attn_right(x=out_right, context=common_mind) + elif self.aggr == 'mult': + l = out_left * common_mind + r = out_right * common_mind + elif self.aggr == 'sum': + l = out_left + common_mind + r = out_right + common_mind + elif self.aggr == 'concat': + l = torch.cat([out_left, common_mind], 1) + r = torch.cat([out_right, common_mind], 1) + else: raise ValueError + l = self.act(l) + l = self.ln_left(l) + r = self.act(r) + r = self.ln_right(r) + if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn': + left_beliefs = self.left(l) + right_beliefs = self.right(r) + if self.aggr == 'concat': + left_beliefs = 
self.left(l)[:, :-1, :] + right_beliefs = self.right(r)[:, :-1, :] + + return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right + + + +class CommonMindToMnetXL(nn.Module): + """ + XL model. + """ + def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']): + super(CommonMindToMnetXL, self).__init__() + + self.aggr = aggr + + # ---- Images ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + for param in self.cnn.parameters(): + param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + + # ---- OCR and bbox -----# + self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device) + self.ocr_gnn = GCNConv(-1, hidden_dim) + self.bbox_ff = nn.Linear(108, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net_left = MindNetLSTMXL(hidden_dim, dropout, mods) + self.mind_net_right = MindNetLSTMXL(hidden_dim, dropout, mods) + self.proj = nn.Linear(hidden_dim*2, hidden_dim) + self.ln_left = nn.LayerNorm(hidden_dim) + self.ln_right = nn.LayerNorm(hidden_dim) + if aggr == 'attn': + self.attn_left = Attention( + dim = hidden_dim, + dim_head = hidden_dim // 4, + heads = 4, + memory_efficient = True, + q_bucket_size = hidden_dim, + k_bucket_size = hidden_dim) + self.attn_right = Attention( + dim = hidden_dim, + dim_head = hidden_dim // 4, + heads = 4, + memory_efficient = True, + q_bucket_size = hidden_dim, + k_bucket_size = hidden_dim) + self.left = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, 27), + ) + self.right = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, 27) + ) + self.left[-1].bias.data = torch.tensor(left_bias).log() + self.right[-1].bias.data = torch.tensor(right_bias).log() + + def forward(self, images, poses, gazes, bboxes, ocr_tensor=None): + + assert images is not None + assert poses is not None + assert gazes is not None + assert bboxes is not None + + batch_size, sequence_len, channels, height, width = images.shape + + bbox_feat = self.act(self.bbox_ff(bboxes)) + + rgb_feat = [] + for i in range(sequence_len): + images_i = images[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat))) + + ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x, + self.ocr_edge_index, + self.ocr_edge_attr))) + ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1) + + out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat) + out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat) + + common_mind = self.proj(torch.cat([cell_left, cell_right], -1)) + + if self.aggr == 'no_tom': + return self.left(out_left), self.right(out_right) + + if self.aggr == 'attn': + l = self.attn_left(x=out_left, context=common_mind) + r = self.attn_right(x=out_right, context=common_mind) + elif self.aggr == 'mult': + l = out_left * common_mind + r = out_right * common_mind + elif self.aggr == 'sum': + l = out_left + common_mind 
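+            # Both agents are fused with the same shared common-mind state; layer norm and the belief heads follow.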
+ r = out_right + common_mind + elif self.aggr == 'concat': + l = torch.cat([out_left, common_mind], 1) + r = torch.cat([out_right, common_mind], 1) + else: raise ValueError + l = self.act(l) + l = self.ln_left(l) + r = self.act(r) + r = self.ln_right(r) + if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn': + left_beliefs = self.left(l) + right_beliefs = self.right(r) + if self.aggr == 'concat': + left_beliefs = self.left(l)[:, :-1, :] + right_beliefs = self.right(r)[:, :-1, :] + + return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right + + + + + + + + +if __name__ == "__main__": + + images = torch.ones(3, 22, 3, 128, 128) + poses = torch.ones(3, 22, 2, 75) + gazes = torch.ones(3, 22, 2, 3) + bboxes = torch.ones(3, 22, 108) + model = CommonMindToMnetXL(64, 'cpu', False, 0.5, aggr='attn') + out = model(images, poses, gazes, bboxes, None) + print(out[0].shape) \ No newline at end of file diff --git a/boss/models/tom_implicit.py b/boss/models/tom_implicit.py new file mode 100644 index 0000000..6050606 --- /dev/null +++ b/boss/models/tom_implicit.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from torch_geometric.nn.conv import GCNConv +from .utils import left_bias, right_bias, build_ocr_graph +from .base import CNN, MindNetLSTM +from memory_efficient_attention_pytorch import Attention + + +class ImplicitToMnet(nn.Module): + """ + Implicit ToM net. Supports any subset of modalities + Possible aggregations: sum, mult, attn, concat + """ + def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']): + super(ImplicitToMnet, self).__init__() + + self.aggr = aggr + + # ---- Images ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + for param in self.cnn.parameters(): + param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + + # ---- OCR and bbox -----# + self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device) + self.ocr_gnn = GCNConv(-1, hidden_dim) + self.bbox_ff = nn.Linear(108, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods) + self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods) + self.ln_left = nn.LayerNorm(hidden_dim) + self.ln_right = nn.LayerNorm(hidden_dim) + if aggr == 'attn': + self.attn_left = Attention( + dim = hidden_dim, + dim_head = hidden_dim // 4, + heads = 4, + memory_efficient = True, + q_bucket_size = hidden_dim, + k_bucket_size = hidden_dim) + self.attn_right = Attention( + dim = hidden_dim, + dim_head = hidden_dim // 4, + heads = 4, + memory_efficient = True, + q_bucket_size = hidden_dim, + k_bucket_size = hidden_dim) + self.left = nn.Linear(hidden_dim, 27) + self.right = nn.Linear(hidden_dim, 27) + self.left.bias.data = torch.tensor(left_bias).log() + self.right.bias.data = torch.tensor(right_bias).log() + + def forward(self, images, poses, gazes, bboxes, ocr_tensor=None): + + assert images is not None + assert poses is not None + assert gazes is not None + assert bboxes is not None + + batch_size, sequence_len, channels, height, width = images.shape + + bbox_feat = self.act(self.bbox_ff(bboxes)) + + rgb_feat = [] + for i in 
range(sequence_len): + images_i = images[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat))) + + ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x, + self.ocr_edge_index, + self.ocr_edge_attr))) + ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1) + + out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat) + out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat) + + if self.aggr == 'no_tom': + return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right + + if self.aggr == 'attn': + l = self.attn_left(x=out_left, context=cell_right) + r = self.attn_right(x=out_right, context=cell_left) + elif self.aggr == 'mult': + l = out_left * cell_right + r = out_right * cell_left + elif self.aggr == 'sum': + l = out_left + cell_right + r = out_right + cell_left + elif self.aggr == 'concat': + l = torch.cat([out_left, cell_right], 1) + r = torch.cat([out_right, cell_left], 1) + else: raise ValueError + l = self.act(l) + l = self.ln_left(l) + r = self.act(r) + r = self.ln_right(r) + if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn': + left_beliefs = self.left(l) + right_beliefs = self.right(r) + if self.aggr == 'concat': + left_beliefs = self.left(l)[:, :-1, :] + right_beliefs = self.right(r)[:, :-1, :] + + return left_beliefs, right_beliefs, [out_left, cell_left, out_right, cell_right] + feats_left + feats_right + + + + + + + + + + + + +if __name__ == "__main__": + + images = torch.ones(3, 22, 3, 128, 128) + poses = torch.ones(3, 22, 2, 75) + gazes = torch.ones(3, 22, 2, 3) + bboxes = torch.ones(3, 22, 108) + model = ImplicitToMnet(64, 'cpu', False, 0.5, aggr='attn') + out = model(images, poses, gazes, bboxes, None) + print(out[0].shape) diff --git a/boss/models/tom_sl.py b/boss/models/tom_sl.py new file mode 100644 index 0000000..34dab4c --- /dev/null +++ b/boss/models/tom_sl.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from torch_geometric.nn.conv import GCNConv +from .utils import left_bias, right_bias, build_ocr_graph, pose_edge_index +from .base import CNN, MindNetSL + + +class SLToMnet(nn.Module): + """ + Speaker-Listener ToMnet + """ + def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']): + super(SLToMnet, self).__init__() + + self.tom_weight = tom_weight + + # ---- Images ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + for param in self.cnn.parameters(): + param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + + # ---- OCR and bbox -----# + self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device) + self.ocr_gnn = GCNConv(-1, hidden_dim) + self.bbox_ff = nn.Linear(108, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net_left = MindNetSL(hidden_dim, dropout, mods) + self.mind_net_right = MindNetSL(hidden_dim, dropout, mods) + self.left = 
nn.Linear(hidden_dim, 27) + self.right = nn.Linear(hidden_dim, 27) + self.left.bias.data = torch.tensor(left_bias).log() + self.right.bias.data = torch.tensor(right_bias).log() + + def forward(self, images, poses, gazes, bboxes, ocr_tensor=None): + + batch_size, sequence_len, channels, height, width = images.shape + + bbox_feat = self.act(self.bbox_ff(bboxes)) + + rgb_feat = [] + for i in range(sequence_len): + images_i = images[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat))) + + ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x, + self.ocr_edge_index, + self.ocr_edge_attr))) + ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1) + + out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat) + left_logits = self.left(out_left) + out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat) + # score each agent with its own belief head + right_logits = self.right(out_right) + + left_ranking = torch.log_softmax(left_logits, dim=-1) + right_ranking = torch.log_softmax(right_logits, dim=-1) + + right_beliefs = left_ranking + self.tom_weight * right_ranking + left_beliefs = right_ranking + self.tom_weight * left_ranking + + return left_beliefs, right_beliefs, [out_left, out_right] + feats_left + feats_right + + + + + + + + +if __name__ == "__main__": + + images = torch.ones(3, 22, 3, 128, 128) + poses = torch.ones(3, 22, 2, 75) + gazes = torch.ones(3, 22, 2, 3) + bboxes = torch.ones(3, 22, 108) + model = SLToMnet(64, 'cpu', 2.0, False, 0.5) + out = model(images, poses, gazes, bboxes, None) + print(out[0].shape) \ No newline at end of file diff --git a/boss/models/tom_tf.py b/boss/models/tom_tf.py new file mode 100644 index 0000000..264a33d --- /dev/null +++ b/boss/models/tom_tf.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from torch_geometric.nn.conv import GCNConv +from .utils import left_bias, right_bias, build_ocr_graph +from .base import CNN, MindNetTF +from memory_efficient_attention_pytorch import Attention + + +class TFToMnet(nn.Module): + """ + Implicit ToM net.
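+ Transformer-based variant: each agent is encoded with a MindNetTF mind net.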
Supports any subset of modalities + Possible aggregations: sum, mult, attn, concat + """ + def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']): + super(TFToMnet, self).__init__() + + # ---- Images ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + for param in self.cnn.parameters(): + param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + + # ---- OCR and bbox -----# + self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device) + self.ocr_gnn = GCNConv(-1, hidden_dim) + self.bbox_ff = nn.Linear(108, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net_left = MindNetTF(hidden_dim, dropout, mods) + self.mind_net_right = MindNetTF(hidden_dim, dropout, mods) + self.left = nn.Linear(hidden_dim, 27) + self.right = nn.Linear(hidden_dim, 27) + self.left.bias.data = torch.tensor(left_bias).log() + self.right.bias.data = torch.tensor(right_bias).log() + + def forward(self, images, poses, gazes, bboxes, ocr_tensor=None): + + assert images is not None + assert poses is not None + assert gazes is not None + assert bboxes is not None + + batch_size, sequence_len, channels, height, width = images.shape + + bbox_feat = self.act(self.bbox_ff(bboxes)) + + rgb_feat = [] + for i in range(sequence_len): + images_i = images[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat))) + + ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x, + self.ocr_edge_index, + self.ocr_edge_attr))) + ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1) + + out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat) + out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat) + + l = self.dropout(self.act(out_left)) + r = self.dropout(self.act(out_right)) + left_beliefs = self.left(l) + right_beliefs = self.right(r) + + return left_beliefs, right_beliefs, feats_left + feats_right + + + + + + + + + + + + +if __name__ == "__main__": + + images = torch.ones(3, 22, 3, 128, 128) + poses = torch.ones(3, 22, 2, 75) + gazes = torch.ones(3, 22, 2, 3) + bboxes = torch.ones(3, 22, 108) + model = TFToMnet(64, 'cpu', False, 0.5) + out = model(images, poses, gazes, bboxes, None) + print(out[0].shape) + diff --git a/boss/models/utils.py b/boss/models/utils.py new file mode 100644 index 0000000..ba8969e --- /dev/null +++ b/boss/models/utils.py @@ -0,0 +1,95 @@ +import torch + +left_bias = [0.10659290976303822, + 0.025158905348262015, + 0.02811095449589107, + 0.026342384511050237, + 0.025318475572458178, + 0.02283183957873461, + 0.021581872822531316, + 0.08062285577511237, + 0.03824366373234754, + 0.04853594319300018, + 0.09653998563867983, + 0.02961357410707162, + 0.02961357410707162, + 0.03172787957767081, + 0.029985904630196004, + 0.02897529321028696, + 0.06602218026116327, + 0.015345336560197867, + 0.026900880295736816, + 0.024879657455918726, + 0.028669450280577644, + 0.01936118720246802, + 0.02341693040078721, + 0.014707055663413206, + 0.027007260445200926, + 0.04146166325363687, + 
0.04243238211749688] + +right_bias = [0.13147256721895695, + 0.012433179968617855, + 0.01623627031195979, + 0.013683146724821148, + 0.015252253929416771, + 0.012579452674131008, + 0.03127576394244834, + 0.10325523257360177, + 0.041155820323927554, + 0.06563655221935587, + 0.12684503071726816, + 0.016156485199861705, + 0.0176989973670913, + 0.020238823435546928, + 0.01918831945958884, + 0.01791175766601952, + 0.08768383819579266, + 0.019002154198026647, + 0.029600276588388607, + 0.01578415467673732, + 0.0176989973670913, + 0.011834791627882237, + 0.014919815962341426, + 0.007552990611951809, + 0.029759846812584773, + 0.04981250498656951, + 0.05533097524002021] + +def build_ocr_graph(device): + ocr_graph = [ + [15, [10, 4], [17, 2]], + [13, [16, 7], [18, 4]], + [11, [16, 4], [7, 10]], + [14, [10, 11], [7, 1]], + [12, [10, 9], [16, 3]], + [1, [7, 2], [9, 9], [10, 2]], + [5, [8, 8], [6, 8]], + [4, [9, 8], [7, 6]], + [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]], + [2, [10, 1], [7, 7], [9, 3]], + [19, [10, 2], [26, 6]], + [20, [10, 7], [26, 5]], + [22, [25, 4], [10, 8]], + [23, [25, 15]], + [21, [16, 5], [24, 8]] + ] + edge_index = [] + edge_attr = [] + for i in range(len(ocr_graph)): + for j in range(1, len(ocr_graph[i])): + source_node = ocr_graph[i][0] + target_node = ocr_graph[i][j][0] + edge_index.append([source_node, target_node]) + edge_attr.append(ocr_graph[i][j][1]) + ocr_edge_index = torch.tensor(edge_index).t().long() + ocr_edge_attr = torch.tensor(edge_attr).to(torch.float).unsqueeze(1) + x = torch.arange(0, 27) + ocr_x = torch.nn.functional.one_hot(x, num_classes=27).to(torch.float) + return ocr_x.to(device), ocr_edge_index.to(device), ocr_edge_attr.to(device) + +def pose_edge_index(): + return torch.tensor( + [[17, 15, 15, 0, 0, 16, 16, 18, 0, 1, 4, 3, 3, 2, 2, 1, 1, 5, 5, 6, 6, 7, 1, 8, 8, 9, 9, 10, 10, 11, 11, 24, 11, 23, 23, 22, 8, 12, 12, 13, 13, 14, 14, 21, 14, 19, 19, 20], + [15, 17, 0, 15, 16, 0, 18, 16, 1, 0, 3, 4, 2, 3, 1, 2, 5, 1, 6, 5, 7, 6, 8, 1, 9, 8, 10, 9, 11, 10, 24, 11, 23, 11, 22, 23, 12, 8, 13, 12, 14, 13, 21, 14, 19, 14, 20, 19]], + dtype=torch.long) diff --git a/boss/new_bbox/test_bbox.tar.gz b/boss/new_bbox/test_bbox.tar.gz new file mode 100644 index 0000000..61b78e0 Binary files /dev/null and b/boss/new_bbox/test_bbox.tar.gz differ diff --git a/boss/new_bbox/train_bbox.tar.gz b/boss/new_bbox/train_bbox.tar.gz new file mode 100644 index 0000000..8fba733 Binary files /dev/null and b/boss/new_bbox/train_bbox.tar.gz differ diff --git a/boss/new_bbox/val_bbox.tar.gz b/boss/new_bbox/val_bbox.tar.gz new file mode 100644 index 0000000..28dbe56 Binary files /dev/null and b/boss/new_bbox/val_bbox.tar.gz differ diff --git a/boss/outfile b/boss/outfile new file mode 100644 index 0000000..5342b84 Binary files /dev/null and b/boss/outfile differ diff --git a/boss/plots/old_vs_new_bbox.py b/boss/plots/old_vs_new_bbox.py new file mode 100644 index 0000000..2ef1864 --- /dev/null +++ b/boss/plots/old_vs_new_bbox.py @@ -0,0 +1,82 @@ +import json +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +sns.set_theme(style='whitegrid') + +COLORS = sns.color_palette() + +MTOM_COLORS = { + "MN1": (110/255, 117/255, 161/255), + "MN2": (179/255, 106/255, 98/255), + "Base": (193/255, 198/255, 208/255), + "CG": (170/255, 129/255, 42/255), + "IC": (97/255, 112/255, 83/255), + "DB": (144/255, 63/255, 110/255) +} + +abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json')) + +abl_old_bbox_mean = [0.539406718, 0.5348262324, 0.529845863] +abl_old_bbox_std = 
[0.03639921819, 0.01519544901, 0.01718265794] + +filename = 'results/abl_cm_concat_old_vs_new_bbox' + +#def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None): + +means = [] +stds = [] +for key, values in abl_tom_cm_concat.items(): + mean = np.mean(values) + std = np.std(values) + means.append(mean) + stds.append(std) +fig, ax = plt.subplots(figsize=(10,4)) +x = np.arange(len(abl_tom_cm_concat)) +width = 0.6 +rects1 = ax.bar( + x, means, width, label='New bbox', yerr=stds, + capsize=5, + color=[MTOM_COLORS['CG']]*8, + edgecolor='black', + linewidth=1.5, + alpha=0.6 +) +rects2 = ax.bar( + [0, 2, 5], abl_old_bbox_mean, width, label='Original bbox', yerr=abl_old_bbox_std, + capsize=5, + color=[COLORS[9]]*8, + edgecolor='black', + linewidth=1.5, + alpha=0.6 +) +ax.set_ylabel('Accuracy', fontsize=18) +xticklabels = list(abl_tom_cm_concat.keys()) +ax.set_xticks(np.arange(len(xticklabels))) +ax.set_xticklabels(xticklabels, rotation=0, fontsize=16) +ax.set_yticklabels(ax.get_yticklabels(), fontsize=16) +# Add value labels above each bar, avoiding overlapping with error bars +for rect, std in zip(rects1, stds): + height = rect.get_height() + if height + std < ax.get_ylim()[1]: + ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std), + xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14) + else: + ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std), + xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14) +for rect, std in zip(rects2, abl_old_bbox_std): + height = rect.get_height() + if height + std < ax.get_ylim()[1]: + ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std), + xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14) + else: + ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std), + xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14) +# Remove spines +ax.spines['top'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.legend(fontsize=14) +ax.grid(axis='x') +plt.tight_layout() +plt.savefig(f'{filename}.pdf', bbox_inches='tight') \ No newline at end of file diff --git a/boss/plots/pca.py b/boss/plots/pca.py new file mode 100644 index 0000000..07bc8dd --- /dev/null +++ b/boss/plots/pca.py @@ -0,0 +1,82 @@ +import torch +import os +import seaborn as sns +import matplotlib.pyplot as plt +import numpy as np +from sklearn.decomposition import PCA + + +FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-22_12-00-38_train_None" # no_tom seed 1 + +#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-39-41_train_None" # impl mult seed 1 +#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-47-07_train_None" # impl sum seed 1 +#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_15-58-44_train_None" # impl attn seed 1 +#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-58-04_train_None" # impl concat seed 1 + +#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-40-01_train_None" # cm mult seed 1 +#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-45-55_train_None" # cm sum seed 1 +#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-50-42_train_None" # cm attn seed 1 +#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-57-15_train_None" # cm concat seed 1 + +#FOLDER_PATH = 
"/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-37-50_train_None" # db seed 1 + +print(FOLDER_PATH) + +MTOM_COLORS = { + "MN1": (110/255, 117/255, 161/255), + "MN2": (179/255, 106/255, 98/255), + "Base": (193/255, 198/255, 208/255), + "CG": (170/255, 129/255, 42/255), + "IC": (97/255, 112/255, 83/255), + "DB": (144/255, 63/255, 110/255) +} + +sns.set_theme(style='white') + +for i in range(60): + + print(f'Computing analysis for test video {i}...', end='\r') + + emb_file = os.path.join(FOLDER_PATH, f'{i}.pt') + data = torch.load(emb_file) + if len(data) == 14: # implicit + model = 'impl' + out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:] + out_left = out_left.squeeze(0) + cell_left = cell_left.squeeze(0) + out_right = out_right.squeeze(0) + cell_right = cell_right.squeeze(0) + elif len(data) == 13: # common mind + model = 'cm' + out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:] + out_left = out_left.squeeze(0) + out_right = out_right.squeeze(0) + common_mind = common_mind.squeeze(0) + elif len(data) == 12: # speaker-listener + model = 'sl' + out_left, out_right, feats = data[0], data[1], data[2:] + out_left = out_left.squeeze(0) + out_right = out_right.squeeze(0) + else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!") + + # ====== PCA for left and right embeddings ====== # + + out_left_and_right = np.concatenate((out_left, out_right), axis=0) + + pca = PCA(n_components=2) + pca_result = pca.fit_transform(out_left_and_right) + + # Separate the PCA results for each tensor + pca_result_left = pca_result[:out_left.shape[0]] + pca_result_right = pca_result[out_right.shape[0]:] + + plt.figure(figsize=(7,6)) + plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100) + plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100) + plt.xlabel('Principal Component 1', fontsize=30) + plt.ylabel('Principal Component 2', fontsize=30) + plt.grid(False) + plt.legend(fontsize=30) + plt.tight_layout() + plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf') + plt.close() \ No newline at end of file diff --git a/boss/results/abl_cm_concat.json b/boss/results/abl_cm_concat.json new file mode 100644 index 0000000..abbf2e7 --- /dev/null +++ b/boss/results/abl_cm_concat.json @@ -0,0 +1,42 @@ +{ + "all": [ + 0.7402937327, + 0.7171004799, + 0.7305511124 + ], + "rgb\npose\ngaze": [ + 0.4736803839, + 0.4658281227, + 0.4948378653 + ], + "rgb\nocr\nbbox": [ + 0.7364403083, + 0.7407663225, + 0.7321506471 + ], + "rgb\ngaze": [ + 0.5013814163, + 0.418859968, + 0.4644103534 + ], + "rgb\npose": [ + 0.4545586738, + 0.4584484514, + 0.4525229024 + ], + "rgb\nbbox": [ + 0.716591537, + 0.7334593573, + 0.758324851 + ], + "rgb\nocr": [ + 0.4581939799, + 0.456449033, + 0.4577577432 + ], + "rgb": [ + 0.4896393776, + 0.4907299695, + 0.4583757452 + ] +} \ No newline at end of file diff --git a/boss/results/abl_cm_concat.pdf b/boss/results/abl_cm_concat.pdf new file mode 100644 index 0000000..62c9ef3 Binary files /dev/null and b/boss/results/abl_cm_concat.pdf differ diff --git a/boss/results/all.json b/boss/results/all.json new file mode 100644 index 0000000..e003477 --- /dev/null +++ b/boss/results/all.json @@ -0,0 +1,72 @@ +{ + "no_tom": [ + 0.6504653192, + 0.6404682274, + 0.6666787844 + ], + "sl": [ + 0.6584993456, + 0.6608259415, + 0.6584629926 + ], + "cm mult": [ + 0.6739130435, + 0.6872182638, + 0.6591173477 + ], + "cm sum": [ + 
0.5820852116, + 0.6401774029, + 0.6001163298 + ], + "cm attn": [ + 0.3240511851, + 0.2688672386, + 0.3028573506 + ], + "cm concat": [ + 0.7402937327, + 0.7171004799, + 0.7305511124 + ], + "impl mult": [ + 0.7119746983, + 0.7019776065, + 0.7244437982 + ], + "impl sum": [ + 0.6542460375, + 0.6642431293, + 0.6678420823 + ], + "impl attn": [ + 0.3311763851, + 0.3125636179, + 0.3112549077 + ], + "impl concat": [ + 0.6957975862, + 0.7227352043, + 0.6971062964 + ], + "resnet": [ + 0.467173186, + 0.4267485822, + 0.4195870292 + ], + "gru": [ + 0.5724152974, + 0.525628908, + 0.4825141777 + ], + "lstm": [ + 0.6210193398, + 0.5547477098, + 0.6572633416 + ], + "conv1d": [ + 0.4478333576, + 0.4114802966, + 0.3853060928 + ] +} \ No newline at end of file diff --git a/boss/results/all.pdf b/boss/results/all.pdf new file mode 100644 index 0000000..4037bee Binary files /dev/null and b/boss/results/all.pdf differ diff --git a/boss/test.py b/boss/test.py new file mode 100644 index 0000000..ec2f599 --- /dev/null +++ b/boss/test.py @@ -0,0 +1,181 @@ +import argparse +import torch +from tqdm import tqdm +import csv +import os +from torch.utils.data import DataLoader + +from dataloader import DataTest +from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D +from models.tom_implicit import ImplicitToMnet +from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL +from models.tom_sl import SLToMnet +from models.single_mindnet import SingleMindNet +from utils import tried_once, tried_twice, tried_thrice, friends, strangers, get_classification_accuracy, pad_collate, get_input_dim + + +def test(args): + + if args.test_frames == 'friends': + test_frame_ids = friends + elif args.test_frames == 'strangers': + test_frame_ids = strangers + elif args.test_frames == 'once': + test_frame_ids = tried_once + elif args.test_frames == 'twice_thrice': + test_frame_ids = tried_twice + tried_thrice + elif args.test_frames is None: + test_frame_ids = None + else: raise NameError + + if args.median is not None: + median = (240, False) + else: + median = None + + if args.model_type == 'tom_cm' or args.model_type == 'tom_sl' or args.model_type == 'tom_impl' or args.model_type == 'tom_cm_xl' or args.model_type == 'tom_single': + flatten_dim = 2 + else: + flatten_dim = 1 + + # load datasets + test_dataset = DataTest( + args.test_frame_path, + args.label_path, + args.test_pose_path, + args.test_gaze_path, + args.test_bbox_path, + args.ocr_graph_path, + args.presaved, + test_frame_ids, + median, + flatten_dim=flatten_dim + ) + test_dataloader = DataLoader( + test_dataset, + batch_size=1, + shuffle=False, + num_workers=args.num_workers, + collate_fn=pad_collate, + pin_memory=args.pin_memory + ) + device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu') + assert args.load_model_path is not None + if args.model_type == 'resnet': + inp_dim = get_input_dim(args.mods) + model = ResNet(inp_dim, device).to(device) + elif args.model_type == 'gru': + inp_dim = get_input_dim(args.mods) + model = ResNetGRU(inp_dim, device).to(device) + elif args.model_type == 'lstm': + inp_dim = get_input_dim(args.mods) + model = ResNetLSTM(inp_dim, device).to(device) + elif args.model_type == 'conv1d': + inp_dim = get_input_dim(args.mods) + model = ResNetConv1D(inp_dim, device).to(device) + elif args.model_type == 'tom_cm': + model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device) + elif args.model_type == 'tom_impl': + model = 
ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device) + elif args.model_type == 'tom_sl': + model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device) + elif args.model_type == 'tom_single': + model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device) + elif args.model_type == 'tom_cm_xl': + model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device) + else: raise NotImplementedError + model.load_state_dict(torch.load(args.load_model_path, map_location=device)) + model.device = device + + model.eval() + + if args.save_preds: + # Define the output file path + folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}_{args.test_frames}' + if not os.path.exists(folder_path): + os.makedirs(folder_path) + print(f'Saving predictions in {folder_path}.') + + print('Testing...') + num_correct = 0 + cnt = 0 + with torch.no_grad(): + for j, batch in tqdm(enumerate(test_dataloader)): + frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch + if frames is not None: frames = frames.to(device, non_blocking=True) + if poses is not None: poses = poses.to(device, non_blocking=True) + if gazes is not None: gazes = gazes.to(device, non_blocking=True) + if bboxes is not None: bboxes = bboxes.to(device, non_blocking=True) + if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=True) + pred_left_labels, pred_right_labels, repr = model(frames, poses, gazes, bboxes, ocr_graphs) + pred_left_labels = torch.reshape(pred_left_labels, (-1, 27)) + pred_right_labels = torch.reshape(pred_right_labels, (-1, 27)) + labels = torch.reshape(labels, (-1, 2)).to(device) + batch_acc, batch_num_correct, batch_num_pred = get_classification_accuracy( + pred_left_labels, pred_right_labels, labels, sequence_lengths + ) + cnt += batch_num_pred + num_correct += batch_num_correct + + if args.save_preds: + torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt")) + data = [( + i, + torch.argmax(pred_left_labels[i]).cpu().numpy(), + torch.argmax(pred_right_labels[i]).cpu().numpy(), + labels[:, 0][i].cpu().numpy(), + labels[:, 1][i].cpu().numpy()) for i in range(len(labels)) + ] + header = ['frame', 'left_pred', 'right_pred', 'left_label', 'right_label'] + with open(os.path.join(folder_path, f'{j}_{batch_acc:.2f}.csv'), mode='w', newline='') as file: + writer = csv.writer(file) + writer.writerow(header) # Write the header row + writer.writerows(data) # Write the data rows + + test_acc = num_correct / cnt + print("Test accuracy: {}".format(num_correct / cnt)) + + with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f: + f.write(str(test_acc)) + f.close() + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + # Define the command-line arguments + parser.add_argument('--gpu_id', type=int) + parser.add_argument('--presaved', type=int, default=128) + parser.add_argument('--non_blocking', action='store_true') + parser.add_argument('--num_workers', type=int, default=4) + parser.add_argument('--pin_memory', action='store_true') + parser.add_argument('--model_type', type=str) + parser.add_argument('--aggr', type=str, default='mult', required=False) + parser.add_argument('--use_resnet', action='store_true') + parser.add_argument('--hidden_dim', type=int, default=64) + parser.add_argument('--tom_weight', type=float, 
default=2.0, required=False) + parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox']) + parser.add_argument('--test_frame_path', type=str, default='/scratch/bortoletto/data/boss/test/frame') + parser.add_argument('--test_pose_path', type=str, default='/scratch/bortoletto/data/boss/test/pose') + parser.add_argument('--test_gaze_path', type=str, default='/scratch/bortoletto/data/boss/test/gaze') + parser.add_argument('--test_bbox_path', type=str, default='/scratch/bortoletto/data/boss/test/new_bbox/labels') + parser.add_argument('--ocr_graph_path', type=str, default='') + parser.add_argument('--label_path', type=str, default='outfile') + parser.add_argument('--save_path', type=str, default='experiments/') + parser.add_argument('--test_frames', type=str, default=None) + parser.add_argument('--median', type=int, default=None) + parser.add_argument('--load_model_path', type=str) + parser.add_argument('--dropout', type=float, default=0.0) + parser.add_argument('--save_preds', action='store_true') + + # Parse the command-line arguments + args = parser.parse_args() + + if args.model_type == 'tom_cm' or args.model_type == 'tom_impl': + if not args.aggr: + parser.error("The choosen --model_type requires --aggr") + if args.model_type == 'tom_sl' and not args.tom_weight: + parser.error("The choosen --model_type requires --tom_weight") + + test(args) \ No newline at end of file diff --git a/boss/train.py b/boss/train.py new file mode 100644 index 0000000..710406c --- /dev/null +++ b/boss/train.py @@ -0,0 +1,324 @@ +import torch +import os +import argparse +import numpy as np +import random +import datetime +import wandb +from tqdm import tqdm +from torch.utils.data import DataLoader +import torch.nn as nn +from torch.optim.lr_scheduler import CosineAnnealingLR + +from dataloader import Data +from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D +from models.tom_implicit import ImplicitToMnet +from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL +from models.tom_sl import SLToMnet +from models.tom_tf import TFToMnet +from models.single_mindnet import SingleMindNet +from utils import pad_collate, get_classification_accuracy, mixup, get_classification_accuracy_mixup, count_parameters, get_input_dim + + +def train(args): + + if args.model_type == 'tom_cm' or args.model_type == 'tom_sl' or args.model_type == 'tom_impl' or args.model_type == 'tom_tf' or args.model_type == 'tom_cm_xl' or args.model_type == 'tom_single': + flatten_dim = 2 + else: + flatten_dim = 1 + + train_dataset = Data( + args.train_frame_path, + args.label_path, + args.train_pose_path, + args.train_gaze_path, + args.train_bbox_path, + args.ocr_graph_path, + presaved=args.presaved, + flatten_dim=flatten_dim + ) + train_dataloader = DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + collate_fn=pad_collate, + pin_memory=args.pin_memory + ) + val_dataset = Data( + args.val_frame_path, + args.label_path, + args.val_pose_path, + args.val_gaze_path, + args.val_bbox_path, + args.ocr_graph_path, + presaved=args.presaved, + flatten_dim=flatten_dim + ) + val_dataloader = DataLoader( + val_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + collate_fn=pad_collate, + pin_memory=args.pin_memory + ) + device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu') + if args.model_type == 'resnet': + inp_dim = get_input_dim(args.mods) + model = 
ResNet(inp_dim, device).to(device) + elif args.model_type == 'gru': + inp_dim = get_input_dim(args.mods) + model = ResNetGRU(inp_dim, device).to(device) + elif args.model_type == 'lstm': + inp_dim = get_input_dim(args.mods) + model = ResNetLSTM(inp_dim, device).to(device) + elif args.model_type == 'conv1d': + inp_dim = get_input_dim(args.mods) + model = ResNetConv1D(inp_dim, device).to(device) + elif args.model_type == 'tom_cm': + model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device) + elif args.model_type == 'tom_cm_xl': + model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device) + elif args.model_type == 'tom_impl': + model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device) + elif args.model_type == 'tom_sl': + model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device) + elif args.model_type == 'tom_single': + model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device) + elif args.model_type == 'tom_tf': + model = TFToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device) + else: raise NotImplementedError + if args.resume_from_checkpoint is not None: + model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device)) + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + if args.scheduler == None: + scheduler = None + else: + scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5) + if args.model_type == 'tom_sl': cross_entropy_loss = nn.NLLLoss(ignore_index=-1) + else: cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing, ignore_index=-1).to(device) + stats = {'train': {'cls_loss': [], 'cls_acc': []}, 'val': {'cls_loss': [], 'cls_acc': []}} + + max_val_classification_acc = 0 + max_val_classification_epoch = None + counter = 0 + + print(f'Number of parameters: {count_parameters(model)}') + + for i in range(args.num_epoch): + # training + print('Training for epoch {}/{}...'.format(i+1, args.num_epoch)) + temp_train_classification_loss = [] + epoch_num_correct = 0 + epoch_cnt = 0 + model.train() + for j, batch in tqdm(enumerate(train_dataloader)): + if args.use_mixup: + frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = mixup(batch, args.mixup_alpha, 27) + else: + frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch + if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking) + if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking) + if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking) + if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking) + if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking) + pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs) + pred_left_labels = torch.reshape(pred_left_labels, (-1, 27)) + pred_right_labels = torch.reshape(pred_right_labels, (-1, 27)) + if args.use_mixup: + labels = torch.reshape(labels, (-1, 54)).to(device) + batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy_mixup( + pred_left_labels, pred_right_labels, labels, sequence_lengths + ) + loss = cross_entropy_loss(pred_left_labels, labels[:, :27]) + 
cross_entropy_loss(pred_right_labels, labels[:, 27:]) + else: + labels = torch.reshape(labels, (-1, 2)).to(device) + batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy( + pred_left_labels, pred_right_labels, labels, sequence_lengths + ) + loss = cross_entropy_loss(pred_left_labels, labels[:, 0]) + cross_entropy_loss(pred_right_labels, labels[:, 1]) + epoch_cnt += batch_num_pred + epoch_num_correct += batch_num_correct + temp_train_classification_loss.append(loss.data.item() * batch_num_pred / 2) + + optimizer.zero_grad() + # backward first so that gradients exist before they are clipped + loss.backward() + if args.clip_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) + optimizer.step() + + if args.logger: wandb.log({'batch_train_acc': batch_train_acc, 'batch_train_loss': loss.data.item(), 'lr': optimizer.param_groups[-1]['lr']}) + print("Epoch {}/{} batch {}/{} training done with cls loss={}, cls accuracy={}.".format( + i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item(), batch_train_acc) + ) + + if scheduler: scheduler.step() + + print("Epoch {}/{} OVERALL train cls loss={}, cls accuracy={}.\n".format( + i+1, args.num_epoch, sum(temp_train_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt) + ) + stats['train']['cls_loss'].append(sum(temp_train_classification_loss) * 2 / epoch_cnt) + stats['train']['cls_acc'].append(epoch_num_correct / epoch_cnt) + + # validation + print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch)) + temp_val_classification_loss = [] + epoch_num_correct = 0 + epoch_cnt = 0 + model.eval() + with torch.no_grad(): + for j, batch in tqdm(enumerate(val_dataloader)): + frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch + if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking) + if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking) + if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking) + if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking) + if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking) + pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs) + pred_left_labels = torch.reshape(pred_left_labels, (-1, 27)) + pred_right_labels = torch.reshape(pred_right_labels, (-1, 27)) + labels = torch.reshape(labels, (-1, 2)).to(device) + batch_val_acc, batch_num_correct, batch_num_pred = get_classification_accuracy( + pred_left_labels, pred_right_labels, labels, sequence_lengths + ) + epoch_cnt += batch_num_pred + epoch_num_correct += batch_num_correct + loss = cross_entropy_loss(pred_left_labels, labels[:,0]) + cross_entropy_loss(pred_right_labels, labels[:,1]) + temp_val_classification_loss.append(loss.data.item() * batch_num_pred / 2) + + if args.logger: wandb.log({'batch_val_acc': batch_val_acc, 'batch_val_loss': loss.data.item()}) + print("Epoch {}/{} batch {}/{} validation done with cls loss={}, cls accuracy={}.".format( + i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item(), batch_val_acc) + ) + + print("Epoch {}/{} OVERALL validation cls loss={}, cls accuracy={}.\n".format( + i+1, args.num_epoch, sum(temp_val_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt) + ) + + cls_loss = sum(temp_val_classification_loss) * 2 / epoch_cnt + cls_acc = epoch_num_correct / epoch_cnt + stats['val']['cls_loss'].append(cls_loss) + stats['val']['cls_acc'].append(cls_acc) + if args.logger: wandb.log({'cls_loss':
cls_loss, 'cls_acc': cls_acc, 'epoch': i}) + + # check for best stat/model using validation accuracy + if stats['val']['cls_acc'][-1] >= max_val_classification_acc: + max_val_classification_acc = stats['val']['cls_acc'][-1] + max_val_classification_epoch = i+1 + torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model')) + counter = 0 + else: + counter += 1 + print(f'EarlyStopping counter: {counter} out of {args.patience}.') + if counter >= args.patience: + break + + with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f: + f.write('{}\n'.format(CFG)) + f.write('{}\n'.format(stats)) + f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_classification_acc)) + f.close() + + print(f'Results saved in {experiment_save_path}') + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + # Define the command-line arguments + parser.add_argument('--gpu_id', type=int) + parser.add_argument('--seed', type=int, default=1) + parser.add_argument('--logger', action='store_true') + parser.add_argument('--presaved', type=int, default=128) + parser.add_argument('--clip_grad_norm', type=float, default=0.5) + parser.add_argument('--use_mixup', action='store_true') + parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False) + parser.add_argument('--non_blocking', action='store_true') + parser.add_argument('--patience', type=int, default=99) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--num_workers', type=int, default=4) + parser.add_argument('--pin_memory', action='store_true') + parser.add_argument('--num_epoch', type=int, default=300) + parser.add_argument('--lr', type=float, default=4e-4) + parser.add_argument('--scheduler', type=str, default=None) + parser.add_argument('--dropout', type=float, default=0.1) + parser.add_argument('--weight_decay', type=float, default=0.005) + parser.add_argument('--label_smoothing', type=float, default=0.1) + parser.add_argument('--model_type', type=str) + parser.add_argument('--aggr', type=str, default='mult', required=False) + parser.add_argument('--use_resnet', action='store_true') + parser.add_argument('--hidden_dim', type=int, default=64) + parser.add_argument('--tom_weight', type=float, default=2.0, required=False) + parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox']) + parser.add_argument('--train_frame_path', type=str, default='/scratch/bortoletto/data/boss/train/frame') + parser.add_argument('--train_pose_path', type=str, default='/scratch/bortoletto/data/boss/train/pose') + parser.add_argument('--train_gaze_path', type=str, default='/scratch/bortoletto/data/boss/train/gaze') + parser.add_argument('--train_bbox_path', type=str, default='/scratch/bortoletto/data/boss/train/new_bbox/labels') + parser.add_argument('--val_frame_path', type=str, default='/scratch/bortoletto/data/boss/val/frame') + parser.add_argument('--val_pose_path', type=str, default='/scratch/bortoletto/data/boss/val/pose') + parser.add_argument('--val_gaze_path', type=str, default='/scratch/bortoletto/data/boss/val/gaze') + parser.add_argument('--val_bbox_path', type=str, default='/scratch/bortoletto/data/boss/val/new_bbox/labels') + parser.add_argument('--ocr_graph_path', type=str, default='') + parser.add_argument('--label_path', type=str, default='outfile') + parser.add_argument('--save_path', type=str, default='experiments/') + parser.add_argument('--resume_from_checkpoint', type=str, default=None) + + # 
Parse the command-line arguments + args = parser.parse_args() + + if args.use_mixup and not args.mixup_alpha: + parser.error("--use_mixup requires --mixup_alpha") + if args.model_type == 'tom_cm' or args.model_type == 'tom_impl': + if not args.aggr: + parser.error("The choosen --model_type requires --aggr") + if args.model_type == 'tom_sl' and not args.tom_weight: + parser.error("The choosen --model_type requires --tom_weight") + + # get experiment ID + experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train' + if not os.path.exists(args.save_path): + os.makedirs(args.save_path, exist_ok=True) + experiment_save_path = os.path.join(args.save_path, experiment_id) + os.makedirs(experiment_save_path, exist_ok=True) + + CFG = { + 'use_ocr_custom_loss': 0, + 'presaved': args.presaved, + 'batch_size': args.batch_size, + 'num_epoch': args.num_epoch, + 'lr': args.lr, + 'scheduler': args.scheduler, + 'weight_decay': args.weight_decay, + 'model_type': args.model_type, + 'use_resnet': args.use_resnet, + 'hidden_dim': args.hidden_dim, + 'tom_weight': args.tom_weight, + 'dropout': args.dropout, + 'label_smoothing': args.label_smoothing, + 'clip_grad_norm': args.clip_grad_norm, + 'use_mixup': args.use_mixup, + 'mixup_alpha': args.mixup_alpha, + 'non_blocking_tensors': args.non_blocking, + 'patience': args.patience, + 'pin_memory': args.pin_memory, + 'resume_from_checkpoint': args.resume_from_checkpoint, + 'aggr': args.aggr, + 'mods': args.mods, + 'save_path': experiment_save_path , + 'seed': args.seed + } + + print(CFG) + print(f'Saving results in {experiment_save_path}') + + # set seed values + if args.logger: + wandb.init(project="boss", config=CFG) + os.environ['PYTHONHASHSEED'] = str(args.seed) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + train(args) diff --git a/boss/utils.py b/boss/utils.py new file mode 100644 index 0000000..f89e81d --- /dev/null +++ b/boss/utils.py @@ -0,0 +1,769 @@ +import os +import torch +import numpy as np +from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F +import matplotlib.pyplot as plt +from sklearn.preprocessing import StandardScaler +from sklearn.decomposition import PCA +from sklearn.manifold import TSNE +#from umap import UMAP +import json +import seaborn as sns +import pandas as pd +from natsort import natsorted +import argparse +import seaborn as sns + + + +COLORS_DM = { + "red": (242/256, 165/256, 179/256), + "blue": (195/256, 219/256, 252/256), + "green": (156/256, 228/256, 213/256), + "yellow": (250/256, 236/256, 144/256), + "violet": (207/256, 187/256, 244/256), + "orange": (244/256, 188/256, 154/256) +} + +MTOM_COLORS = { + "MN1": (110/255, 117/255, 161/255), + "MN2": (179/255, 106/255, 98/255), + "Base": (193/255, 198/255, 208/255), + "CG": (170/255, 129/255, 42/255), + "IC": (97/255, 112/255, 83/255), + "DB": (144/255, 63/255, 110/255) +} + +COLORS = sns.color_palette() + +OBJECTS = [ + 'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener', + 'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine', + 'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan', + 'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener' +] + +tried_once = [2, 4, 5, 6, 7, 8, 9, 10, 12, 15, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 32, 38, 40, 41, 42, 44, 47, 50, 51, 52, 53, 55, + 56, 57, 61, 63, 65, 67, 68, 70, 71, 72, 73, 74, 76, 80, 81, 83, 85, 87, + 88, 89, 90, 92, 93, 96, 97, 99, 101, 102, 105, 106, 108, 110, 111, 112, + 113, 
114, 116, 118, 121, 123, 125, 131, 132, 134, 135, 140, 142, 143, + 145, 146, 148, 149, 151, 152, 154, 155, 156, 157, 160, 161, 162, 165, + 169, 170, 171, 173, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185, + 186, 187, 190, 191, 194, 196, 203, 204, 206, 207, 208, 209, 210, 211, + 213, 214, 216, 218, 219, 220, 222, 225, 227, 228, 229, 232, 233, 235, + 236, 237, 238, 239, 242, 243, 246, 247, 249, 251, 252, 254, 255, 256, + 257, 260, 261, 262, 263, 265, 266, 268, 270, 272, 275, 277, 278, 279, + 281, 282, 287, 290, 296, 298] + +tried_twice = [1, 3, 11, 13, 14, 16, 17, 31, 33, 35, 36, 37, 39, + 43, 45, 49, 54, 58, 59, 60, 62, 64, 66, 69, 75, 77, 79, 82, 84, 86, + 91, 94, 95, 98, 100, 103, 104, 107, 109, 115, 117, 119, 120, 122, + 124, 126, 127, 128, 129, 130, 133, 136, 137, 138, 139, 141, 144, + 147, 150, 153, 158, 159, 164, 166, 167, 168, 172, 174, 177, 188, + 189, 192, 193, 195, 197, 198, 199, 200, 201, 202, 205, 212, 215, + 217, 221, 223, 224, 226, 230, 231, 234, 240, 241, 244, 245, 248, + 250, 253, 258, 259, 264, 267, 269, 271, 273, 274, 276, 280, 283, + 284, 285, 286, 288, 291, 292, 293, 294, 295, 297, 299] + +tried_thrice = [34, 46, 48, 78, 163, 289] + +friends = [i for i in range(150, 300)] +strangers = [i for i in range(0, 150)] + + +def get_input_dim(modalities): + dimensions = { + 'rgb': 512, + 'pose': 150, + 'gaze': 6, + 'ocr': 64, + 'bbox': 108 + } + sum_of_dimensions = sum(dimensions[modality] for modality in modalities) + return sum_of_dimensions + + +def count_parameters(model): + #return sum(p.numel() for p in model.parameters() if p.requires_grad) + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + return sum([np.prod(p.size()) for p in model_parameters]) + + +def onehot(label, n_classes): + if len(label.size()) == 3: + batch_size, seq_len, _ = label.size() + label_1_2d = label[:,:,0].contiguous().view(batch_size*seq_len, 1) + label_2_2d = label[:,:,1].contiguous().view(batch_size*seq_len, 1) + #onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1) + #onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1) + #onehot_2d = torch.cat((onehot_1_2d, onehot_2_2d), dim=1) + #onehot_3d = onehot_2d.view(batch_size, seq_len, 2*n_classes) + onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1).view(-1, seq_len, n_classes) + onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1).view(-1, seq_len, n_classes) + onehot_3d = torch.cat((onehot_1_2d, onehot_2_2d), dim=2) + return onehot_3d + else: + return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1) + + +def mixup(data, alpha, n_classes): + lam = torch.FloatTensor([np.random.beta(alpha, alpha)]) + frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = data + indices = torch.randperm(labels.size(0)) + # labels + labels2 = labels[indices] + labels = onehot(labels, n_classes) + labels2 = onehot(labels2, n_classes) + labels = labels * lam + labels2 * (1 - lam) + # frames + frames2 = frames[indices] + frames = frames * lam + frames2 * (1 - lam) + # poses + poses2 = poses[indices] + poses = poses * lam + poses2 * (1 - lam) + # gazes + gazes2 = gazes[indices] + gazes = gazes * lam + gazes2 * (1 - lam) + # bboxes + bboxes2 = bboxes[indices] + bboxes = bboxes * lam + bboxes2 * (1 - lam) + return frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths + + +def get_classification_accuracy(pred_left_labels, pred_right_labels, labels, sequence_lengths): + max_len = 
max(sequence_lengths) + pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27)) + pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27)) + labels = torch.reshape(labels, (-1, max_len, 2)) + left_correct = torch.argmax(pred_left_labels, 2) == labels[:,:,0] + right_correct = torch.argmax(pred_right_labels, 2) == labels[:,:,1] + num_pred = sum(sequence_lengths) * 2 + num_correct = 0 + for i in range(len(sequence_lengths)): + size = sequence_lengths[i] + num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item() + acc = num_correct / num_pred + return acc, num_correct, num_pred + + +def get_classification_accuracy_mixup(pred_left_labels, pred_right_labels, labels, sequence_lengths): + max_len = max(sequence_lengths) + pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27)) + pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27)) + labels = torch.reshape(labels, (-1, max_len, 54)) + left_correct = torch.argmax(pred_left_labels, 2) == torch.argmax(labels[:,:,:27], 2) + right_correct = torch.argmax(pred_right_labels, 2) == torch.argmax(labels[:,:,27:], 2) + num_pred = sum(sequence_lengths) * 2 + num_correct = 0 + for i in range(len(sequence_lengths)): + size = sequence_lengths[i] + num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item() + acc = num_correct / num_pred + return acc, num_correct, num_pred + + +def pad_collate(batch): + (aa, bb, cc, dd, ee, ff) = zip(*batch) + seq_lens = [len(a) for a in aa] + aa_pad = pad_sequence(aa, batch_first=True, padding_value=0) + bb_pad = pad_sequence(bb, batch_first=True, padding_value=-1) + if cc[0] is not None: + cc_pad = pad_sequence(cc, batch_first=True, padding_value=0) + else: + cc_pad = None + if dd[0] is not None: + dd_pad = pad_sequence(dd, batch_first=True, padding_value=0) + else: + dd_pad = None + if ee[0] is not None: + ee_pad = pad_sequence(ee, batch_first=True, padding_value=0) + else: + ee_pad = None + if ff[0] is not None: + ff_pad = pad_sequence(ff, batch_first=True, padding_value=0) + else: + ff_pad = None + return aa_pad, bb_pad, cc_pad, dd_pad, ee_pad, ff_pad, seq_lens + + +def ocr_loss(pL, lL, pR, lR, ocr, loss, eta=10.0, mode='abs'): + """ + Custom loss based on the negative OCRL matrix. + + Input: + pL: tensor of shape [batch_size, num_classes] representing left belief predictions + lL: tensor of shape [batch_size] representing the left belief labels + pR: tensor of shape [batch_size, num_classes] representing right belief predictions + lR: tensor of shape [batch_size] representing the right belief labels + ocr: negative OCR matrix (i.e. 1-OCR) + loss: loss function + eta: hyperparameter for the interaction term + + Output: + Final loss resulting from weighting left and right loss with the respective + OCR coefficients and summing them. 
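+ Note: eta is re-estimated inside the function as (left_loss + right_loss) / interaction_loss, so the value passed in is effectively ignored and the interaction term is rescaled to the magnitude of the two belief losses.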
+ """ + bs = pL.shape[0] + if len([*lR[0].size()]) > 0: # for mixup + p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device) + g = torch.tensor([ocr[torch.argmax(lL[i]), torch.argmax(lR[i])] for i in range(bs)], device=pL.device) + else: + p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device) + g = torch.tensor([ocr[lL[i], lR[i]] for i in range(bs)], device=pL.device) + left_loss = torch.mean(loss(pL, lL)) + right_loss = torch.mean(loss(pR, lR)) + if mode == 'abs': + interaction_loss = torch.mean(torch.abs(g - p)) + elif mode == 'mse': + interaction_loss = torch.mean(torch.pow(g - p, 2)) + else: raise NotImplementedError + eta = (left_loss + right_loss) / interaction_loss + interaction_loss = interaction_loss * eta + print(f"Left: {left_loss} --- Right: {right_loss} --- Interaction: {interaction_loss}") + return left_loss + right_loss + interaction_loss + + +def spherical2cartesial(x): + """ + From https://colab.research.google.com/drive/1SJbzd-gFTbiYjfZynIfrG044fWi6svbV?usp=sharing#scrollTo=78QhNw4MYSYp + """ + output = torch.zeros(x.size(0),3) + output[:,2] = -torch.cos(x[:,1])*torch.cos(x[:,0]) + output[:,0] = torch.cos(x[:,1])*torch.sin(x[:,0]) + output[:,1] = torch.sin(x[:,1]) + return output + + +def presave(): + from PIL import Image + from torchvision import transforms + import os + from natsort import natsorted + import pickle as pkl + + preprocess = transforms.Compose([ + transforms.Resize((128, 128)), + #transforms.Resize(256), + #transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + frame_path = '/scratch/bortoletto/data/boss/test/frame' + frame_dirs = os.listdir(frame_path) + frame_paths = [] + for frame_dir in natsorted(frame_dirs): + paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))] + frame_paths.append(paths) + + save_folder = '/scratch/bortoletto/data/boss/presaved128/'+frame_path.split('/')[-2]+'/'+frame_path.split('/')[-1] + if not os.path.exists(save_folder): + os.makedirs(save_folder) + print(save_folder, 'created') + + for video in natsorted(frame_paths): + print(video[0].split('/')[-2], end='\r') + images = [preprocess(Image.open(i)) for i in video] + strings = video[0].split('/') + with open(save_folder+'/'+strings[7]+'.pkl', 'wb') as f: + pkl.dump(images, f) + + +def compute_cosine_similarity(tensor1, tensor2): + return F.cosine_similarity(tensor1, tensor2, dim=-1) + + +def find_most_similar_embedding(rgb, pose, gaze, ocr, bbox, repr): + gaze_similarity = compute_cosine_similarity(gaze, repr) + pose_similarity = compute_cosine_similarity(pose, repr) + ocr_similarity = compute_cosine_similarity(ocr, repr) + bbox_similarity = compute_cosine_similarity(bbox, repr) + rgb_similarity = compute_cosine_similarity(rgb, repr) + similarities = torch.stack([gaze_similarity, pose_similarity, ocr_similarity, bbox_similarity, rgb_similarity]) + max_index = torch.argmax(similarities, dim=0) + main_modality = [] + main_modality_name = [] + for idx in max_index: + if idx == 0: + main_modality.append(gaze) + main_modality_name.append('gaze') + elif idx == 1: + main_modality.append(pose) + main_modality_name.append('pose') + elif idx == 2: + main_modality.append(ocr) + main_modality_name.append('ocr') + elif idx == 3: + main_modality.append(bbox) + main_modality_name.append('bbox') + else: + main_modality.append(rgb) + 
main_modality_name.append('rgb') + return main_modality, main_modality_name + + +def plot_similarity_histogram(values, filename, labels): + def count_elements(list_of_strings, target_list): + counts = [] + for element in list_of_strings: + count = target_list.count(element) + counts.append(count) + return counts + fig, ax = plt.subplots(figsize=(12,4)) + colors = ["red", "blue", "green", "violet"] + # Remove spines + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + colors = [MTOM_COLORS['MN1'], MTOM_COLORS['MN2'], MTOM_COLORS['CG'], MTOM_COLORS['CG']] + alphas = [0.6, 0.6, 0.6, 0.6] + edgecolors = ['black', 'black', MTOM_COLORS['MN1'], MTOM_COLORS['MN2']] + linewidths = [1.0, 1.0, 5.0, 5.0] + if isinstance(values[0], list): + num_lists = len(values) + bar_width = 0.8 / num_lists + for i, val in enumerate(values): + unique_strings = ['rgb', 'pose', 'gaze', 'bbox', 'ocr'] + counts = count_elements(unique_strings, val) + x = np.arange(len(unique_strings)) + x_shifted = x + (i - (len(labels) - 1) / 2) * bar_width #x - (bar_width * num_lists / 2) + (bar_width * i) + ax.bar(x_shifted, counts, width=bar_width, label=f'{labels[i]}', color=colors[i], edgecolor=edgecolors[i], linewidth=linewidths[i], alpha=alphas[i]) + ax.set_xlabel('Modality', fontsize=18) + ax.set_ylabel('Counts', fontsize=18) + ax.set_xticks(np.arange(len(unique_strings))) + ax.set_xticklabels(unique_strings, fontsize=18) + ax.set_yticklabels(ax.get_yticklabels(), fontsize=16) + ax.legend(fontsize=18) + ax.grid(axis='y') + plt.savefig(filename, bbox_inches='tight') + else: + unique_strings, counts = np.unique(values, return_counts=True) + ax.bar(unique_strings, counts) + ax.set_xlabel('Modality') + ax.set_ylabel('Counts') + plt.savefig(filename, bbox_inches='tight') + + +def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None): + means = [] + stds = [] + for key, values in data.items(): + mean = np.mean(values) + std = np.std(values) + means.append(mean) + stds.append(std) + fig, ax = plt.subplots(figsize=size) + x = np.arange(len(data)) + width = 0.6 + rects1 = ax.bar( + x, means, width, label='Mean', yerr=stds, + capsize=5, + color='teal' if colors is None else colors, + edgecolor='black', + linewidth=1.5, + alpha=0.6 + ) + ax.set_ylabel('Accuracy', fontsize=18) + #ax.set_title('Mean and Standard Deviation of Results', fontsize=14) + if filename == 'results/all': + xticklabels = [ + 'Base', + 'DB', + 'CG$\otimes$', 'CG$\oplus$', 'CG$\odot$', 'CG$\parallel$', + 'IC$\otimes$', 'IC$\oplus$', 'IC$\odot$', 'IC$\parallel$', + 'CNN', 'CNN+GRU', 'CNN+LSTM', 'CNN+Conv1D' + ] + else: + xticklabels = list(data.keys()) + ax.set_xticks(np.arange(len(xticklabels))) + ax.set_xticklabels(xticklabels, rotation=rotation, fontsize=16) #data.keys(), rotation=rotation, fontsize=16) + ax.set_yticklabels(ax.get_yticklabels(), fontsize=16) + #ax.grid(axis='y') + #ax.set_axisbelow(True) + #ax.legend(loc='upper right') + # Add value labels above each bar, avoiding overlapping with error bars + for rect, std in zip(rects1, stds): + height = rect.get_height() + if height + std < ax.get_ylim()[1]: + ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std), + xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14) + else: + ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std), + xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14) + # Remove spines + ax.spines['top'].set_visible(False) + 
ax.spines['right'].set_visible(False) + ax.grid(axis='x') + plt.tight_layout() + plt.savefig(f'{filename}.pdf', bbox_inches='tight') + + +def plot_confusion_matrices(left_cm, right_cm, labels, title, annot=True): + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6)) + left_xticklabels = [OBJECTS[i] for i in labels[0]] + left_yticklabels = [OBJECTS[i] for i in labels[1]] + right_xticklabels = [OBJECTS[i] for i in labels[2]] + right_yticklabels = [OBJECTS[i] for i in labels[3]] + sns.heatmap( + left_cm, + annot=annot, + fmt='.0f', + cmap='Blues', + cbar=False, + xticklabels=left_xticklabels, + yticklabels=left_yticklabels, + ax=ax1 + ) + ax1.set_xlabel('Predicted') + ax1.set_ylabel('True') + ax1.set_title('Left Confusion Matrix') + sns.heatmap( + right_cm, + annot=annot, + fmt='.0f', + cmap='Blues', + cbar=False, #True if annot is False else False, + xticklabels=right_xticklabels, + yticklabels=right_yticklabels, + ax=ax2 + ) + ax2.set_xlabel('Predicted') + ax2.set_ylabel('True') + ax2.set_title('Right Confusion Matrix') + #plt.suptitle(title) + plt.tight_layout() + plt.savefig(title + '.pdf') + plt.close() + +def plot_confusion_matrix(confusion_matrix, labels, title, annot=True): + plt.figure(figsize=(6, 6)) + xticklabels = [OBJECTS[i] for i in labels[0]] + yticklabels = [OBJECTS[i] for i in labels[1]] + sns.heatmap( + confusion_matrix, + annot=annot, + fmt='.0f', + cmap='Blues', + cbar=False, + xticklabels=xticklabels, + yticklabels=yticklabels + ) + plt.xlabel('Predicted') + plt.ylabel('True') + #plt.title(title) + plt.tight_layout() + plt.savefig(title + '.pdf') + plt.close() + + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + parser.add_argument('--task', type=str, choices=['confusion', 'similarity', 'scores', 'friends_vs_strangers']) + parser.add_argument('--folder_path', type=str) + + args = parser.parse_args() + + if args.task == 'similarity': + sns.set_theme(style='white') + else: + sns.set_theme(style='whitegrid') + + # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* # + + if args.task == 'similarity': + + out_left_main_mods_full_test = [] + out_right_main_mods_full_test = [] + cell_left_main_mods_full_test = [] + cell_right_main_mods_full_test = [] + cm_left_main_mods_full_test = [] + cm_right_main_mods_full_test = [] + + for i in range(60): + + print(f'Computing analysis for test video {i}...', end='\r') + + emb_file = os.path.join(args.folder_path, f'{i}.pt') + data = torch.load(emb_file) + if len(data) == 14: # implicit + model = 'impl' + out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:] + out_left = out_left.squeeze(0) + cell_left = cell_left.squeeze(0) + out_right = out_right.squeeze(0) + cell_right = cell_right.squeeze(0) + elif len(data) == 13: # common mind + model = 'cm' + out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:] + out_left = out_left.squeeze(0) + out_right = out_right.squeeze(0) + common_mind = common_mind.squeeze(0) + elif len(data) == 12: # speaker-listener + model = 'sl' + out_left, out_right, feats = data[0], data[1], data[2:] + out_left = out_left.squeeze(0) + out_right = out_right.squeeze(0) + else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!") + + # ====== PCA for left and right embeddings ====== # + + out_left_and_right = np.concatenate((out_left, out_right), axis=0) + + pca = PCA(n_components=2) + pca_result = pca.fit_transform(out_left_and_right) + + # Separate the PCA results for each tensor + pca_result_left = 
pca_result[:out_left.shape[0]] + pca_result_right = pca_result[out_right.shape[0]:] + + plt.figure(figsize=(7,6)) + plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='MindNet$_1$', color=MTOM_COLORS['MN1'], s=100) + plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='MindNet$_2$', color=MTOM_COLORS['MN2'], s=100) + plt.xlabel('Principal Component 1', fontsize=30) + plt.ylabel('Principal Component 2', fontsize=30) + plt.grid(False) + plt.legend(fontsize=30) + plt.tight_layout() + plt.savefig(f'{args.folder_path}/{i}_pca.pdf') + plt.close() + + # ====== Feature similarity ====== # + + if len(feats) == 10: + left_rgb, left_ocr, left_pose, left_gaze, left_bbox, right_rgb, right_ocr, right_pose, right_gaze, right_bbox = feats + left_rgb = left_rgb.squeeze(0) + left_ocr = left_ocr.squeeze(0) + left_pose = left_pose.squeeze(0) + left_gaze = left_gaze.squeeze(0) + left_bbox = left_bbox.squeeze(0) + right_rgb = right_rgb.squeeze(0) + right_ocr = right_ocr.squeeze(0) + right_pose = right_pose.squeeze(0) + right_gaze = right_gaze.squeeze(0) + right_bbox = right_bbox.squeeze(0) + else: raise NotImplementedError("Ablated versions are not supported yet.") + + # out: [1, seq_len, dim] --- squeeze ---> [seq_len, dim] + # cell: [1, 1, dim] --------- squeeze ---> [1, dim] + # cm: [1, seq_len, dim] --- squeeze ---> [seq_len, dim] + # feat: [1, seq_len, dim] --- squeeze ---> [seq_len, dim] + + _, out_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, out_left) + _, out_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, out_right) + out_left_main_mods_full_test += out_left_main_mods + out_right_main_mods_full_test += out_right_main_mods + if model == 'impl': + _, cell_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, cell_left) + _, cell_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, cell_right) + cell_left_main_mods_full_test += cell_left_main_mods + cell_right_main_mods_full_test += cell_right_main_mods + if model == 'cm': + _, cm_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, common_mind) + _, cm_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, common_mind) + cm_left_main_mods_full_test += cm_left_main_mods + cm_right_main_mods_full_test += cm_right_main_mods + + if model == 'impl': + plot_similarity_histogram( + [out_left_main_mods_full_test, + out_right_main_mods_full_test, + cell_left_main_mods_full_test, + cell_right_main_mods_full_test], + f'{args.folder_path}/boss_similartiy_impl_hist_all.pdf', + [r'$h_1$', r'$h_2$', r'$c_1$', r'$c_2$'] + ) + elif model == 'cm': + plot_similarity_histogram( + [out_left_main_mods_full_test, + out_right_main_mods_full_test, + cm_left_main_mods_full_test, + cm_right_main_mods_full_test], + f'{args.folder_path}/boss_similarity_cm_hist_all.pdf', + [r'$h_1$', r'$h_2$', r'$cg$ w/ 1', r'$cg$ w/ 2'] + ) + elif model == 'sl': + plot_similarity_histogram( + [out_left_main_mods_full_test, + out_right_main_mods_full_test], + f'{args.folder_path}/boss_similarity_sl_hist_all.py', + [r'$h_1$', r'$h_2$'] + ) + + # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* # + + elif args.task == 'scores': + + + all = json.load(open('results/all.json')) + abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json')) + + plot_scores_histogram( + all, + 
filename='results/all', + size=(10,4), + rotation=45, + colors=[COLORS[7]] + [MTOM_COLORS['DB']] + [MTOM_COLORS['CG']]*4 + [MTOM_COLORS['IC']]*4 + [COLORS[4]]*4 + ) + + plot_scores_histogram( + abl_tom_cm_concat, + size=(10,4), + filename='results/abl_cm_concat', + colors=[MTOM_COLORS['CG']]*8 + ) + + # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* # + + elif args.task == 'confusion': + + # List to store confusion matrices + all_df = [] + left_matrices = [] + right_matrices = [] + + # Iterate over the CSV files + for filename in natsorted(os.listdir(args.folder_path)): + if filename.endswith('.csv'): + file_path = os.path.join(args.folder_path, filename) + + print(f'Processing {file_path}...') + + # Read the CSV file + df = pd.read_csv(file_path) + + # Extract the left and right labels and predictions + left_labels = df['left_label'] + right_labels = df['right_label'] + left_preds = df['left_pred'] + right_preds = df['right_pred'] + + # Calculate the confusion matrices for left and right + left_cm = pd.crosstab(left_labels, left_preds) + right_cm = pd.crosstab(right_labels, right_preds) + + # Append the confusion matrices to the list + left_matrices.append(left_cm) + right_matrices.append(right_cm) + all_df.append(df) + + # Plot and save the confusion matrices for left and right + for i, cm in enumerate(zip(left_matrices, right_matrices)): + print(f'Computing confusion matrices for video {i}...', end='\r') + plot_confusion_matrices( + cm[0], + cm[1], + labels=[cm[0].columns, cm[0].index, cm[1].columns, cm[1].index], + title=f'{args.folder_path}/{i}_cm' + ) + + merged_df = pd.concat(all_df).reset_index(drop=True) + + merged_left_cm = pd.crosstab(merged_df['left_label'], merged_df['left_pred']) + merged_right_cm = pd.crosstab(merged_df['right_label'], merged_df['right_pred']) + plot_confusion_matrices( + merged_left_cm, + merged_right_cm, + labels=[merged_left_cm.columns, merged_left_cm.index, merged_right_cm.columns, merged_right_cm.index], + title=f'{args.folder_path}/all_lr', + annot=False + ) + + merged_preds = pd.concat([merged_df['left_pred'], merged_df['right_pred']]).reset_index(drop=True) + merged_labels = pd.concat([merged_df['left_label'], merged_df['right_label']]).reset_index(drop=True) + merged_all_cm = pd.crosstab(merged_labels, merged_preds) + plot_confusion_matrix( + merged_all_cm, + labels=[merged_all_cm.columns, merged_all_cm.index], + title=f'{args.folder_path}/all_merged', + annot=False + ) + + # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* # + + elif args.task == 'friends_vs_strangers': + + # friends ids: 30-59 + # stranger ids: 0-29 + out_left_list = [] + out_right_list = [] + cell_left_list = [] + cell_right_list = [] + common_mind_list = [] + + for i in range(60): + + print(f'Computing analysis for test video {i}...', end='\r') + + emb_file = os.path.join(args.folder_path, f'{i}.pt') + data = torch.load(emb_file) + if len(data) == 14: # implicit + model = 'impl' + out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:] + out_left = out_left.squeeze(0) + cell_left = cell_left.squeeze(0) + out_right = out_right.squeeze(0) + cell_right = cell_right.squeeze(0) + out_left_list.append(out_left) + out_right_list.append(out_right) + cell_left_list.append(cell_left) + cell_right_list.append(cell_right) + elif len(data) == 13: # common mind + model = 'cm' + out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:] + out_left = out_left.squeeze(0) + out_right = out_right.squeeze(0) + common_mind = common_mind.squeeze(0) + 
out_left_list.append(out_left) + out_right_list.append(out_right) + common_mind_list.append(common_mind) + elif len(data) == 12: # speaker-listener + model = 'sl' + out_left, out_right, feats = data[0], data[1], data[2:] + out_left = out_left.squeeze(0) + out_right = out_right.squeeze(0) + out_left_list.append(out_left) + out_right_list.append(out_right) + else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!") + + # ====== PCA for left and right embeddings ====== # + + print('\rComputing PCA...') + + strangers_nframes = sum([out_left_list[i].shape[0] for i in range(30)]) + + left = torch.cat(out_left_list, 0) + right = torch.cat(out_right_list, 0) + + out_left_and_right = torch.cat([left, right], axis=0).numpy() + #out_left_and_right = StandardScaler().fit_transform(out_left_and_right) + + pca = PCA(n_components=2) + pca_result = pca.fit_transform(out_left_and_right) + #pca_result = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=3).fit_transform(out_left_and_right) + #pca_result = UMAP().fit_transform(out_left_and_right) + + # Separate the PCA results for each tensor + #pca_result_left = pca_result[:left.shape[0]] + #pca_result_right = pca_result[right.shape[0]:] + pca_result_strangers = pca_result[:strangers_nframes] + pca_result_friends = pca_result[strangers_nframes:] + + #plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='Left') + #plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='Right') + plt.scatter(pca_result_friends[:, 0], pca_result_friends[:, 1], label='Friends') + plt.scatter(pca_result_strangers[:, 0], pca_result_strangers[:, 1], label='Strangers') + plt.xlabel('Principal Component 1') + plt.ylabel('Principal Component 2') + plt.legend() + plt.savefig(f'{args.folder_path}/friends_vs_strangers_pca.pdf') + plt.close() + + # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* # + + else: + + raise NameError \ No newline at end of file diff --git a/tbd/.gitignore b/tbd/.gitignore new file mode 100644 index 0000000..99a5b78 --- /dev/null +++ b/tbd/.gitignore @@ -0,0 +1,196 @@ +experiments +wandb +predictions + + +# Created by https://www.toptal.com/developers/gitignore/api/python,linux +# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +# End of https://www.toptal.com/developers/gitignore/api/python,linux \ No newline at end of file diff --git a/tbd/README.md b/tbd/README.md new file mode 100644 index 0000000..4f2c430 --- /dev/null +++ b/tbd/README.md @@ -0,0 +1,16 @@ +# TBD + +# Data +The original code can be found [here](https://github.com/LifengFan/Triadic-Belief-Dynamics). The dataset is not directly available but must be requested using the link to the Google form provided in the [README](https://github.com/LifengFan/Triadic-Belief-Dynamics?tab=readme-ov-file#dataset). 
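+
+Once you have obtained the data, `TBDDataset` in `tbd_dataloader.py` reads it from a root folder (by default `/scratch/bortoletto/data/tbd`). As a rough sketch based on the docstring of `TBDDataset`, the expected layout is along these lines, where `<data_root>` and `<exp_id>` are placeholders:
+
+```
+<data_root>
+├── tracker_gt_smooth     # 2D gaze coordinates from the tracker's POV
+├── images/<exp_id>/
+│   ├── battery           # first-person frames (battery person)
+│   ├── tracker           # first-person frames with eye fixation (tracker)
+│   └── kinect            # third-person frames
+├── skeleton              # 3D pose estimation
+└── annotation            # labels (fluent values in [0, 3])
+```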
+ +## Installing Dependencies +Run `conda env create -f environment.yml`. + +## Train +`source run_train.sh`. + +## Test +`source run_test.sh`. **Make sure to use the same random seed used for training**, otherwise the splits will be different and you will likely have a data leakage. + +## Visualisations +The plots are made using `utils/fb_scores_err.py` (false belief analysis) and `utils/similarity.py` (PCA of latent representations). diff --git a/tbd/environment.yml b/tbd/environment.yml new file mode 100644 index 0000000..d6921cb --- /dev/null +++ b/tbd/environment.yml @@ -0,0 +1,100 @@ +name: tbd +channels: + - conda-forge + - defaults + - pytorch +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - ca-certificates=2023.01.10=h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - ncurses=6.4=h6a678d5_0 + - openssl=1.1.1t=h7f8727e_0 + - pip=23.0.1=py38h06a4308_0 + - python=3.8.10=h12debd9_8 + - readline=8.2=h5eee18b_0 + - setuptools=66.0.0=py38h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.38.4=py38h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - appdirs==1.4.4 + - beautifulsoup4==4.12.2 + - certifi==2023.5.7 + - charset-normalizer==3.1.0 + - click==8.1.3 + - cmake==3.26.4 + - contourpy==1.1.0 + - cycler==0.11.0 + - docker-pycreds==0.4.0 + - einops==0.6.1 + - filelock==3.12.0 + - fonttools==4.40.0 + - gdown==4.7.1 + - gitdb==4.0.10 + - gitpython==3.1.31 + - idna==3.4 + - importlib-resources==5.12.0 + - jinja2==3.1.2 + - joblib==1.3.1 + - kiwisolver==1.4.4 + - lit==16.0.6 + - markupsafe==2.1.3 + - matplotlib==3.7.1 + - memory-efficient-attention-pytorch==0.1.2 + - mpmath==1.3.0 + - networkx==3.1 + - numpy==1.24.4 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cuda-cupti-cu11==11.7.101 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cudnn-cu11==8.5.0.96 + - nvidia-cufft-cu11==10.9.0.58 + - nvidia-curand-cu11==10.2.10.91 + - nvidia-cusolver-cu11==11.4.0.1 + - nvidia-cusparse-cu11==11.7.4.91 + - nvidia-nccl-cu11==2.14.3 + - nvidia-nvtx-cu11==11.7.91 + - opencv-python==4.8.0.74 + - packaging==23.1 + - pandas==2.0.3 + - pathtools==0.1.2 + - pillow==9.5.0 + - protobuf==4.23.3 + - psutil==5.9.5 + - pyparsing==3.1.0 + - pysocks==1.7.1 + - python-dateutil==2.8.2 + - pytz==2023.3 + - pyyaml==6.0 + - requests==2.30.0 + - scikit-learn==1.3.0 + - scipy==1.10.1 + - seaborn==0.12.2 + - sentry-sdk==1.27.0 + - setproctitle==1.3.2 + - six==1.16.0 + - smmap==5.0.0 + - soupsieve==2.4.1 + - sympy==1.12 + - threadpoolctl==3.1.0 + - torch==2.0.1 + - torch-geometric==2.3.1 + - torchaudio==2.0.2 + - torchsampler==0.1.2 + - torchvision==0.15.2 + - tqdm==4.65.0 + - triton==2.0.0 + - typing-extensions==4.7.0 + - tzdata==2023.3 + - urllib3==2.0.2 + - wandb==0.15.5 + - zipp==3.15.0 +prefix: /opt/anaconda3/envs/tbd \ No newline at end of file diff --git a/tbd/models/base.py b/tbd/models/base.py new file mode 100644 index 0000000..5c4ca02 --- /dev/null +++ b/tbd/models/base.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +from .utils import pose_edge_index +from torch_geometric.nn import GCNConv + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = nn.LayerNorm(dim) + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim): + 
super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim), + nn.GELU(), + nn.Linear(dim, dim)) + + def forward(self, x): + return self.net(x) + + +class CNN(nn.Module): + def __init__(self, hidden_dim): + super(CNN, self).__init__() + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1) + self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = self.conv1(x) + x = nn.functional.relu(x) + x = self.pool(x) + x = self.conv2(x) + x = nn.functional.relu(x) + x = self.pool(x) + x = self.conv3(x) + x = nn.functional.relu(x) + x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling + return x + + +class MindNetLSTM(nn.Module): + """ + Basic MindNet for model-based ToM, just LSTM on input concatenation + """ + def __init__(self, hidden_dim, dropout, mods): + super(MindNetLSTM, self).__init__() + self.mods = mods + if 'rgb_1' in mods: + self.img_emb = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + if 'gaze' in mods: + self.gaze_emb = nn.Linear(2, hidden_dim) + if 'pose' in mods: + self.pose_edge_index = pose_edge_index() + self.pose_emb = GCNConv(3, hidden_dim) + self.LSTM = PreNorm( + hidden_dim*len(mods), + nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True)) + self.proj = nn.Linear(hidden_dim*2, hidden_dim) + self.dropout = nn.Dropout(dropout) + self.act = nn.GELU() + + def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze): + feats = [] + if 'rgb_3' in self.mods: + feats.append(rgb_3rd_pov_feats) + if 'rgb_1' in self.mods: + rgb_feat = [] + for i in range(rgb_1st_pov.shape[1]): + images_i = rgb_1st_pov[:,i] + img_i_feat = self.img_emb(images_i) + img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat))) + feats.append(rgb_feats) + if 'pose' in self.mods: + bs, seq_len = pose.size(0), pose.size(1) + self.pose_edge_index = self.pose_edge_index.to(pose.device) + pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index) + pose_emb = self.dropout(self.act(pose_emb)) + pose_emb = torch.mean(pose_emb, dim=1) + hd = pose_emb.size(-1) + feats.append(pose_emb.view(bs, seq_len, hd)) + if 'gaze' in self.mods: + gaze_feats = self.dropout(self.act(self.gaze_emb(gaze))) + feats.append(gaze_feats) + if 'bbox' in self.mods: + feats.append(bbox_feats.mean(2)) + lstm_inp = torch.cat(feats, 2) + lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp)) + c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2) + return self.act(self.proj(lstm_out)), c_n, feats + + +class MindNetSL(nn.Module): + """ + Basic MindNet for SL ToM, just LSTM on input concatenation + """ + def __init__(self, hidden_dim, dropout, mods): + super(MindNetSL, self).__init__() + self.mods = mods + if 'rgb_1' in mods: + self.img_emb = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + if 'gaze' in mods: + self.gaze_emb = nn.Linear(2, hidden_dim) + if 'pose' in mods: + self.pose_edge_index = pose_edge_index() + self.pose_emb = GCNConv(3, hidden_dim) + self.LSTM = PreNorm( + hidden_dim*len(mods), + nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True)) + self.proj = nn.Linear(hidden_dim*2, hidden_dim) + self.dropout = nn.Dropout(dropout) + self.act 
= nn.GELU() + + def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze): + feats = [] + if 'rgb_3' in self.mods: + feats.append(rgb_3rd_pov_feats) + if 'rgb_1' in self.mods: + rgb_feat = [] + for i in range(rgb_1st_pov.shape[1]): + images_i = rgb_1st_pov[:,i] + img_i_feat = self.img_emb(images_i) + img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat))) + feats.append(rgb_feats) + if 'pose' in self.mods: + bs, seq_len = pose.size(0), pose.size(1) + self.pose_edge_index = self.pose_edge_index.to(pose.device) + pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index) + pose_emb = self.dropout(self.act(pose_emb)) + pose_emb = torch.mean(pose_emb, dim=1) + hd = pose_emb.size(-1) + feats.append(pose_emb.view(bs, seq_len, hd)) + if 'gaze' in self.mods: + gaze_feats = self.dropout(self.act(self.gaze_emb(gaze))) + feats.append(gaze_feats) + if 'bbox' in self.mods: + feats.append(bbox_feats.mean(2)) + lstm_inp = torch.cat(feats, 2) + lstm_out, _ = self.LSTM(self.dropout(lstm_inp)) + return self.act(self.proj(lstm_out)), feats \ No newline at end of file diff --git a/tbd/models/common_mind.py b/tbd/models/common_mind.py new file mode 100644 index 0000000..6b45761 --- /dev/null +++ b/tbd/models/common_mind.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from .base import CNN, MindNetLSTM +from memory_efficient_attention_pytorch import Attention + + +class CommonMindToMnet(nn.Module): + """ + img: bs, 3, 128, 128 + pose: bs, 26, 3 + gaze: bs, 2 NOTE: only tracker has gaze + bbox: bs, 4 + """ + def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_1', 'rgb_3', 'pose', 'gaze', 'bbox']): + super(CommonMindToMnet, self).__init__() + + self.aggr = aggr + self.mods = mods + + # ---- 3rd POV Images, object and bbox ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + #for param in self.cnn.parameters(): + # param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + self.bbox_ff = nn.Linear(4, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods) + self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze']) + if aggr != 'no_tom': self.cm_proj = nn.Linear(hidden_dim*2, hidden_dim) + self.ln_1 = nn.LayerNorm(hidden_dim) + self.ln_2 = nn.LayerNorm(hidden_dim) + if aggr == 'attn': + self.attn_left = Attention( + dim = hidden_dim, + dim_head = hidden_dim // 4, + heads = 4, + memory_efficient = True, + q_bucket_size = hidden_dim, + k_bucket_size = hidden_dim) + self.attn_right = Attention( + dim = hidden_dim, + dim_head = hidden_dim // 4, + heads = 4, + memory_efficient = True, + q_bucket_size = hidden_dim, + k_bucket_size = hidden_dim) + self.m1 = nn.Linear(hidden_dim, 4) + self.m2 = nn.Linear(hidden_dim, 4) + self.m12 = nn.Linear(hidden_dim, 4) + self.m21 = nn.Linear(hidden_dim, 4) + self.mc = nn.Linear(hidden_dim, 4) + + def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze): + + batch_size, sequence_len, channels, height, width = img_3rd_pov.shape + + if 'bbox' 
in self.mods: + bbox_feat = self.dropout(self.act(self.bbox_ff(bbox))) + else: + bbox_feat = None + + if 'rgb_3' in self.mods: + rgb_feat = [] + for i in range(sequence_len): + images_i = img_3rd_pov[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat))) + else: + rgb_feat_3rd_pov = None + + if tracker_id == 'skele1': + out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze) + out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None) + else: + out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze) + out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None) + + if self.aggr == 'no_tom': + m1 = self.m1(out_1).mean(1) + m2 = self.m2(out_2).mean(1) + m12 = self.m12(out_1).mean(1) + m21 = self.m21(out_2).mean(1) + mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2 + + return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2 + + common_mind = self.cm_proj(torch.cat([cell_1, cell_2], -1)) # (bs, 1, h) + + if self.aggr == 'attn': + p1 = self.attn_left(x=out_1, context=common_mind) + p2 = self.attn_right(x=out_2, context=common_mind) + elif self.aggr == 'mult': + p1 = out_1 * common_mind + p2 = out_2 * common_mind + elif self.aggr == 'sum': + p1 = out_1 + common_mind + p2 = out_2 + common_mind + elif self.aggr == 'concat': + p1 = torch.cat([out_1, common_mind], 1) + p2 = torch.cat([out_2, common_mind], 1) + else: raise ValueError + p1 = self.act(p1) + p1 = self.ln_1(p1) + p2 = self.act(p2) + p2 = self.ln_2(p2) + if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn': + m1 = self.m1(p1).mean(1) + m2 = self.m2(p2).mean(1) + m12 = self.m12(p1).mean(1) + m21 = self.m21(p2).mean(1) + mc = self.mc(p1*p2).mean(1) + if self.aggr == 'concat': + m1 = self.m1(p1).mean(1) + m2 = self.m2(p2).mean(1) + m12 = self.m12(p1).mean(1) + m21 = self.m21(p2).mean(1) + mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2 + + return m1, m2, m12, m21, mc, [out_1, out_2, common_mind] + feats_1 + feats_2 + + + + +if __name__ == "__main__": + + img_3rd_pov = torch.ones(3, 5, 3, 128, 128) + img_tracker = torch.ones(3, 5, 3, 128, 128) + img_battery = torch.ones(3, 5, 3, 128, 128) + pose1 = torch.ones(3, 5, 26, 3) + pose2 = torch.ones(3, 5, 26, 3) + bbox = torch.ones(3, 5, 13, 4) + tracker_id = 'skele1' + gaze = torch.ones(3, 5, 2) + mods = ['pose', 'bbox', 'rgb_3'] + + for agg in ['no_tom']: + model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods) + out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze) + + print(out[0].shape) \ No newline at end of file diff --git a/tbd/models/implicit.py b/tbd/models/implicit.py new file mode 100644 index 0000000..3689b71 --- /dev/null +++ b/tbd/models/implicit.py @@ -0,0 +1,151 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from .base import CNN, MindNetLSTM +from memory_efficient_attention_pytorch import Attention + + +class ImplicitToMnet(nn.Module): + """ + Implicit ToM net. 
Supports any subset of modalities + Possible aggregations: sum, mult, attn, concat + """ + def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']): + super(ImplicitToMnet, self).__init__() + + self.aggr = aggr + self.mods = mods + + # ---- 3rd POV Images, object and bbox ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + for param in self.cnn.parameters(): + param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + self.bbox_ff = nn.Linear(4, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods) + self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze']) + self.ln_1 = nn.LayerNorm(hidden_dim) + self.ln_2 = nn.LayerNorm(hidden_dim) + if aggr == 'attn': + self.attn_left = Attention( + dim = hidden_dim, + dim_head = hidden_dim // 4, + heads = 4, + memory_efficient = True, + q_bucket_size = hidden_dim, + k_bucket_size = hidden_dim) + self.attn_right = Attention( + dim = hidden_dim, + dim_head = hidden_dim // 4, + heads = 4, + memory_efficient = True, + q_bucket_size = hidden_dim, + k_bucket_size = hidden_dim) + self.m1 = nn.Linear(hidden_dim, 4) + self.m2 = nn.Linear(hidden_dim, 4) + self.m12 = nn.Linear(hidden_dim, 4) + self.m21 = nn.Linear(hidden_dim, 4) + self.mc = nn.Linear(hidden_dim, 4) + + def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze): + + batch_size, sequence_len, channels, height, width = img_3rd_pov.shape + + if 'bbox' in self.mods: + bbox_feat = self.dropout(self.act(self.bbox_ff(bbox))) + else: + bbox_feat = None + + if 'rgb_3' in self.mods: + rgb_feat = [] + for i in range(sequence_len): + images_i = img_3rd_pov[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat))) + else: + rgb_feat_3rd_pov = None + + if tracker_id == 'skele1': + out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze) + out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None) + else: + out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze) + out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None) + + if self.aggr == 'no_tom': + m1 = self.m1(out_1).mean(1) + m2 = self.m2(out_2).mean(1) + m12 = self.m12(out_1).mean(1) + m21 = self.m21(out_2).mean(1) + mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2 + + return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2 + + if self.aggr == 'attn': + p1 = self.attn_left(x=out_1, context=cell_2) + p2 = self.attn_right(x=out_2, context=cell_1) + elif self.aggr == 'mult': + p1 = out_1 * cell_2 + p2 = out_2 * cell_1 + elif self.aggr == 'sum': + p1 = out_1 + cell_2 + p2 = out_2 + cell_1 + elif self.aggr == 'concat': + p1 = torch.cat([out_1, cell_2], 1) + p2 = torch.cat([out_2, cell_1], 1) + else: raise ValueError + p1 = self.act(p1) + p1 = self.ln_1(p1) + p2 = self.act(p2) + 
p2 = self.ln_2(p2) + if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn': + m1 = self.m1(p1).mean(1) + m2 = self.m2(p2).mean(1) + m12 = self.m12(p1).mean(1) + m21 = self.m21(p2).mean(1) + mc = self.mc(p1*p2).mean(1) + if self.aggr == 'concat': + m1 = self.m1(p1).mean(1) + m2 = self.m2(p2).mean(1) + m12 = self.m12(p1).mean(1) + m21 = self.m21(p2).mean(1) + mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2 + + return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2 + + + + +if __name__ == "__main__": + + img_3rd_pov = torch.ones(3, 5, 3, 128, 128) + img_tracker = torch.ones(3, 5, 3, 128, 128) + img_battery = torch.ones(3, 5, 3, 128, 128) + pose1 = torch.ones(3, 5, 26, 3) + pose2 = torch.ones(3, 5, 26, 3) + bbox = torch.ones(3, 5, 13, 4) + tracker_id = 'skele1' + gaze = torch.ones(3, 5, 2) + + for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']: + model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg) + out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze) + + print(agg, out[0].shape) \ No newline at end of file diff --git a/tbd/models/sl.py b/tbd/models/sl.py new file mode 100644 index 0000000..4503971 --- /dev/null +++ b/tbd/models/sl.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from .base import CNN, MindNetSL + + +class SLToMnet(nn.Module): + """ + Speaker-Listener ToMnet + """ + def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']): + super(SLToMnet, self).__init__() + + self.tom_weight = tom_weight + self.mods = mods + + # ---- 3rd POV Images, object and bbox ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + for param in self.cnn.parameters(): + param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + self.bbox_ff = nn.Linear(4, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net_1 = MindNetSL(hidden_dim, dropout, mods=mods) + self.mind_net_2 = MindNetSL(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze']) + self.m1 = nn.Linear(hidden_dim, 4) + self.m2 = nn.Linear(hidden_dim, 4) + self.m12 = nn.Linear(hidden_dim, 4) + self.m21 = nn.Linear(hidden_dim, 4) + self.mc = nn.Linear(hidden_dim, 4) + + def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze): + + batch_size, sequence_len, channels, height, width = img_3rd_pov.shape + + if 'bbox' in self.mods: + bbox_feat = self.dropout(self.act(self.bbox_ff(bbox))) + else: + bbox_feat = None + + if 'rgb_3' in self.mods: + rgb_feat = [] + for i in range(sequence_len): + images_i = img_3rd_pov[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat))) + else: + rgb_feat_3rd_pov = None + + if tracker_id == 'skele1': + out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze) + out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None) + else: + out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze) + out_2, 
feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None) + + m1_logits = self.m1(out_1).mean(1) + m2_logits = self.m2(out_2).mean(1) + m12_logits = self.m12(out_1).mean(1) + m21_logits = self.m21(out_2).mean(1) + mc_logits = self.mc(out_1*out_2).mean(1) + + m1_ranking = torch.log_softmax(m1_logits, dim=-1) + m2_ranking = torch.log_softmax(m2_logits, dim=-1) + m12_ranking = torch.log_softmax(m12_logits, dim=-1) + m21_ranking = torch.log_softmax(m21_logits, dim=-1) + mc_ranking = torch.log_softmax(mc_logits, dim=-1) + + # NOTE: does this make sense? + m1 = m1_ranking + self.tom_weight * m2_ranking + m2 = m2_ranking + self.tom_weight * m1_ranking + m12 = m12_ranking + self.tom_weight * m21_ranking + m21 = m21_ranking + self.tom_weight * m12_ranking + mc = mc_ranking + self.tom_weight * mc_ranking + + return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2 + + + + + +if __name__ == "__main__": + + img_3rd_pov = torch.ones(3, 5, 3, 128, 128) + img_tracker = torch.ones(3, 5, 3, 128, 128) + img_battery = torch.ones(3, 5, 3, 128, 128) + pose1 = torch.ones(3, 5, 26, 3) + pose2 = torch.ones(3, 5, 26, 3) + bbox = torch.ones(3, 5, 13, 4) + tracker_id = 'skele1' + gaze = torch.ones(3, 5, 2) + + model = SLToMnet(hidden_dim=64, device='cpu', tom_weight=2.0, resnet=False, dropout=0.5) + out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze) + + print(out[0].shape) \ No newline at end of file diff --git a/tbd/models/tom_base.py b/tbd/models/tom_base.py new file mode 100644 index 0000000..e70acdf --- /dev/null +++ b/tbd/models/tom_base.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from .base import CNN, MindNetLSTM +import numpy as np + + +class ImplicitToMnet(nn.Module): + """ + Implicit ToM net. 
Supports any subset of modalities + Possible aggregations: sum, mult, attn, concat + """ + def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']): + super(ImplicitToMnet, self).__init__() + + self.mods = mods + + # ---- 3rd POV Images, object and bbox ----# + if resnet: + resnet = models.resnet34(weights="IMAGENET1K_V1") + self.cnn = nn.Sequential( + *(list(resnet.children())[:-1]) + ) + for param in self.cnn.parameters(): + param.requires_grad = False + self.rgb_ff = nn.Linear(512, hidden_dim) + else: + self.cnn = CNN(hidden_dim) + self.rgb_ff = nn.Linear(hidden_dim, hidden_dim) + self.bbox_ff = nn.Linear(4, hidden_dim) + + # ---- Others ----# + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.device = device + + # ---- Mind nets ----# + self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods) + self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze']) + + self.m1 = nn.Linear(hidden_dim, 4) + self.m2 = nn.Linear(hidden_dim, 4) + self.m12 = nn.Linear(hidden_dim, 4) + self.m21 = nn.Linear(hidden_dim, 4) + self.mc = nn.Linear(hidden_dim, 4) + + def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze): + + batch_size, sequence_len, channels, height, width = img_3rd_pov.shape + + if 'bbox' in self.mods: + bbox_feat = self.dropout(self.act(self.bbox_ff(bbox))) + else: + bbox_feat = None + + if 'rgb_3' in self.mods: + rgb_feat = [] + for i in range(sequence_len): + images_i = img_3rd_pov[:,i] + img_i_feat = self.cnn(images_i) + img_i_feat = img_i_feat.view(batch_size, -1) + rgb_feat.append(img_i_feat) + rgb_feat = torch.stack(rgb_feat, 1) + rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat))) + else: + rgb_feat_3rd_pov = None + + if tracker_id == 'skele1': + out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze) + out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None) + else: + out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze) + out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None) + + if self.aggr == 'no_tom': + m1 = self.m1(out_1).mean(1) + m2 = self.m2(out_2).mean(1) + m12 = self.m12(out_1).mean(1) + m21 = self.m21(out_2).mean(1) + mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2 + + return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2 + + + +def count_parameters(model): + #return sum(p.numel() for p in model.parameters() if p.requires_grad) + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + return sum([np.prod(p.size()) for p in model_parameters]) + + + +if __name__ == "__main__": + + img_3rd_pov = torch.ones(3, 5, 3, 128, 128) + img_tracker = torch.ones(3, 5, 3, 128, 128) + img_battery = torch.ones(3, 5, 3, 128, 128) + pose1 = torch.ones(3, 5, 26, 3) + pose2 = torch.ones(3, 5, 26, 3) + bbox = torch.ones(3, 5, 13, 4) + tracker_id = 'skele1' + gaze = torch.ones(3, 5, 2) + + model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5) + print(count_parameters(model)) + breakpoint() + + for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']: + model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5) + out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze) + + 
print(agg, out[0].shape) \ No newline at end of file diff --git a/tbd/models/utils.py b/tbd/models/utils.py new file mode 100644 index 0000000..582c132 --- /dev/null +++ b/tbd/models/utils.py @@ -0,0 +1,7 @@ +import torch + + +def pose_edge_index(): + start = [15, 14, 13, 12, 19, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 3, 4, 5, 6, 8, 8, 4, 20, 21, 21, 22, 24, 22] + end = [14, 13, 12, 0, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 4, 20, 20, 21, 22, 24, 23, 25, 24] + return torch.tensor([start+end, end+start]) \ No newline at end of file diff --git a/tbd/results/abl.json b/tbd/results/abl.json new file mode 100644 index 0000000..98007d8 --- /dev/null +++ b/tbd/results/abl.json @@ -0,0 +1,186 @@ +{ + "all": [ + { + "m1": 0.3803337111550836, + "m2": 0.3900899763574355, + "m12": 0.4441281276628709, + "m21": 0.4818757648120031, + "mc": 0.4485177767702456 + }, + { + "m1": 0.5186066842992191, + "m2": 0.521895750052127, + "m12": 0.49294626980529677, + "m21": 0.4810118034501327, + "mc": 0.6097300398369058 + }, + { + "m1": 0.4965589148309122, + "m2": 0.5094894309980568, + "m12": 0.4615136302786905, + "m21": 0.4554005550423429, + "mc": 0.6258118710785031 + } + ], + "rgb_3_pose_gaze_bbox": [ + { + "m1": 0.3776045061727805, + "m2": 0.3996776745150713, + "m12": 0.4762772810038159, + "m21": 0.48643178296718503, + "mc": 0.4575207273412474 + }, + { + "m1": 0.5176564423560418, + "m2": 0.5109344883698214, + "m12": 0.4630213122846928, + "m21": 0.4826608133674547, + "mc": 0.5979415365779003 + }, + { + "m1": 0.5114692300997931, + "m2": 0.5027048375802656, + "m12": 0.47527894405588544, + "m21": 0.45223985157847546, + "mc": 0.6054099305712209 + } + ], + "rgb_3_pose_gaze": [ + { + "m1": 0.403207421026191, + "m2": 0.3833413122398237, + "m12": 0.4602455224198077, + "m21": 0.47181798537346287, + "mc": 0.4603675297898878 + }, + { + "m1": 0.49484810149311514, + "m2": 0.5060275976807422, + "m12": 0.4610412452830618, + "m21": 0.46869095956564044, + "mc": 0.6040674897817755 + }, + { + "m1": 0.5160598186177866, + "m2": 0.5309683014233921, + "m12": 0.47227245803060636, + "m21": 0.46953974307035984, + "mc": 0.6014771460423635 + } + ], + "rgb_3_pose": [ + { + "m1": 0.4057149181928123, + "m2": 0.4002233785689204, + "m12": 0.46794813614607333, + "m21": 0.4690365183933033, + "mc": 0.4591530208921514 + }, + { + "m1": 0.5362792166212834, + "m2": 0.5290656046231254, + "m12": 0.4569419683345858, + "m21": 0.4530255281497826, + "mc": 0.4554252731371068 + }, + { + "m1": 0.49570625763169085, + "m2": 0.5146503967646507, + "m12": 0.4567936139893578, + "m21": 0.45918214877096325, + "mc": 0.5962397441246001 + } + ], + "rgb_3_gaze": [ + { + "m1": 0.40135106828655215, + "m2": 0.38453470155825614, + "m12": 0.4989742833725901, + "m21": 0.47369273992079175, + "mc": 0.48430622854433986 + }, + { + "m1": 0.508038122818153, + "m2": 0.4875748099051746, + "m12": 0.46665443622698555, + "m21": 0.46635808547742913, + "mc": 0.47936993226840163 + }, + { + "m1": 0.49795853039610977, + "m2": 0.5028666890527814, + "m12": 0.44176709237564815, + "m21": 0.4483898274665582, + "mc": 0.5867527750929912 + } + ], + "rgb_3_bbox": [ + { + "m1": 0.3951383898241492, + "m2": 0.3818794542844425, + "m12": 0.44108151735270384, + "m21": 0.46539754196523303, + "mc": 0.43982185797713114 + }, + { + "m1": 0.5093846655989521, + "m2": 0.4923439212866733, + "m12": 0.4598003475323884, + "m21": 0.47647640659290746, + "mc": 0.6349953712994137 + }, + { + "m1": 0.5325224862402295, + "m2": 0.5092319973570975, + "m12": 0.4435807136490263, + "m21": 0.4576911633624616, + "mc": 
0.6282064277856357 + } + ], + "rgb_3_rgb_1": [ + { + "m1": 0.39189391736691903, + "m2": 0.3739995635963588, + "m12": 0.4792392731637056, + "m21": 0.4592726043789752, + "mc": 0.4468645255652386 + }, + { + "m1": 0.4827892482357646, + "m2": 0.48042899735042716, + "m12": 0.45932653547051094, + "m21": 0.48430209616318126, + "mc": 0.4506104344435269 + }, + { + "m1": 0.4820247145474279, + "m2": 0.3667553358192628, + "m12": 0.44503028688537, + "m21": 0.45984906207471654, + "mc": 0.465120658971623 + } + ], + "rgb_3": [ + { + "m1": 0.40725462165126114, + "m2": 0.38737351624656846, + "m12": 0.46230461548252094, + "m21": 0.4829312519709871, + "mc": 0.4492175856929955 + }, + { + "m1": 0.5286274183685061, + "m2": 0.5081429492163979, + "m12": 0.4610256989472217, + "m21": 0.4733487634477733, + "mc": 0.4655243312197501 + }, + { + "m1": 0.5217968210271873, + "m2": 0.5103780571157844, + "m12": 0.4431266771306429, + "m21": 0.48398542131284883, + "mc": 0.6122314353959392 + } + ] +} \ No newline at end of file diff --git a/tbd/results/all.json b/tbd/results/all.json new file mode 100644 index 0000000..978f197 --- /dev/null +++ b/tbd/results/all.json @@ -0,0 +1,232 @@ +{ + "cm_concat": [ + { + "m1": 0.38921744471949393, + "m2": 0.38557137008494935, + "m12": 0.44699534554593756, + "m21": 0.4747474437468054, + "mc": 0.4918107834016411 + }, + { + "m1": 0.5402415140026018, + "m2": 0.48833721513836786, + "m12": 0.4631512445419047, + "m21": 0.4740880083492652, + "mc": 0.6375070925808958 + }, + { + "m1": 0.5012543523713172, + "m2": 0.5068694866895836, + "m12": 0.4451537834591627, + "m21": 0.45215784721598673, + "mc": 0.6201022576104379 + } + ], + "cm_sum": [ + { + "m1": 0.39403894801783246, + "m2": 0.38541918219411786, + "m12": 0.4600376974144952, + "m21": 0.471919704007463, + "mc": 0.43950812310207055 + }, + { + "m1": 0.48497621104052574, + "m2": 0.5295044689855949, + "m12": 0.4502949472343065, + "m21": 0.47823492553894387, + "mc": 0.6028290833617195 + }, + { + "m1": 0.503386104373653, + "m2": 0.49983127146477085, + "m12": 0.46782817568218116, + "m21": 0.45484578845116075, + "mc": 0.5905749126722909 + } + ], + "cm_mult": [ + { + "m1": 0.39070820515470606, + "m2": 0.3996851353903932, + "m12": 0.4455704586852128, + "m21": 0.4713517869738811, + "mc": 0.4450907029478458 + }, + { + "m1": 0.5066540697731119, + "m2": 0.526507445454099, + "m12": 0.462643008560599, + "m21": 0.48263054309565334, + "mc": 0.6438566476782207 + }, + { + "m1": 0.48868811674304546, + "m2": 0.5074635877653536, + "m12": 0.44597405775819876, + "m21": 0.45445350963025877, + "mc": 0.5884265473527218 + } + ], + "cm_attn": [ + { + "m1": 0.3949557687114269, + "m2": 0.3919385900921811, + "m12": 0.4850081112466773, + "m21": 0.4849575556679713, + "mc": 0.4516870089239762 + }, + { + "m1": 0.4925989821370256, + "m2": 0.49409170532242247, + "m12": 0.4664647278240569, + "m21": 0.46783863397462533, + "mc": 0.6398721139927354 + }, + { + "m1": 0.4945636568169018, + "m2": 0.5049812790749876, + "m12": 0.454359577718189, + "m21": 0.4712184012093268, + "mc": 0.5992735441011302 + } + ], + "no_tom": [ + { + "m1": 0.2570551317, + "m2": 0.375350929686332, + "m12": 0.312451988649724, + "m21": 0.4631371031641, + "mc": 0.457486278214567 + }, + { + "m1": 0.233046800382043, + "m2": 0.522609755931958, + "m12": 0.326821758467328, + "m21": 0.474338898013257, + "mc": 0.604439456291308 + }, + { + "m1": 0.33774852598382, + "m2": 0.520943544364353, + "m12": 0.298617214416867, + "m21": 0.482175301427192, + "mc": 0.634948478570852 + } + ], + "sl": [ + { + "m1": 0.365205706591741, + "m2": 
0.255259363011619, + "m12": 0.421227579844245, + "m21": 0.376143327741882, + "mc": 0.45614515353718 + }, + { + "m1": 0.493046934143676, + "m2": 0.331798174804139, + "m12": 0.422821548330913, + "m21": 0.399768928780549, + "mc": 0.450957023549231 + }, + { + "m1": 0.466266787709392, + "m2": 0.350962671130227, + "m12": 0.431694150269919, + "m21": 0.378863431433258, + "mc": 0.470284405744656 + } + ], + "impl_concat": [ + { + "m1": 0.38427302094644894, + "m2": 0.38673879043767634, + "m12": 0.45694337561663145, + "m21": 0.4737891562722213, + "mc": 0.4502976351448088 + }, + { + "m1": 0.49951068243751173, + "m2": 0.5084945752383908, + "m12": 0.4604721097809549, + "m21": 0.4826884970930907, + "mc": 0.6200443272625361 + }, + { + "m1": 0.5013244243339088, + "m2": 0.49476495726495723, + "m12": 0.4596701406290429, + "m21": 0.4554742441542813, + "mc": 0.5988949378402535 + } + ], + "impl_sum": [ + { + "m1": 0.3803337111550836, + "m2": 0.3900899763574355, + "m12": 0.4441281276628709, + "m21": 0.4818757648120031, + "mc": 0.4485177767702456 + }, + { + "m1": 0.5186066842992191, + "m2": 0.521895750052127, + "m12": 0.49294626980529677, + "m21": 0.4810118034501327, + "mc": 0.6097300398369058 + }, + { + "m1": 0.4965589148309122, + "m2": 0.5094894309980568, + "m12": 0.4615136302786905, + "m21": 0.4554005550423429, + "mc": 0.6258118710785031 + } + ], + "impl_mult": [ + { + "m1": 0.3789421413006731, + "m2": 0.3818053844554785, + "m12": 0.46402717346945177, + "m21": 0.4903726261039529, + "mc": 0.4461443806398687 + }, + { + "m1": 0.3789421413006731, + "m2": 0.3818053844554785, + "m12": 0.46402717346945177, + "m21": 0.4903726261039529, + "mc": 0.4461443806398687 + }, + { + "m1": 0.49338554196342077, + "m2": 0.5066817652688608, + "m12": 0.46253374461930613, + "m21": 0.47782311190445825, + "mc": 0.4581608719646799 + } + ], + "impl_attn": [ + { + "m1": 0.37413691393147924, + "m2": 0.2546966838007244, + "m12": 0.429390512693598, + "m21": 0.292401773870023, + "mc": 0.45706325836224465 + }, + { + "m1": 0.513917904196177, + "m2": 0.25802580258025803, + "m12": 0.49272662664765543, + "m21": 0.27041556176385584, + "mc": 0.6041394755857196 + }, + { + "m1": 0.47720445038981674, + "m2": 0.25839328537170264, + "m12": 0.46505055463781547, + "m21": 0.260276985433943, + "mc": 0.6021811271770562 + } + ] +} diff --git a/tbd/results/false_belief_first_vs_second.pdf b/tbd/results/false_belief_first_vs_second.pdf new file mode 100644 index 0000000..8e2ec5a Binary files /dev/null and b/tbd/results/false_belief_first_vs_second.pdf differ diff --git a/tbd/results/fb_ttest.txt b/tbd/results/fb_ttest.txt new file mode 100644 index 0000000..a38ce44 --- /dev/null +++ b/tbd/results/fb_ttest.txt @@ -0,0 +1,87 @@ + +========================================================= m1_m2_m12_m21 + + +Model: Base -> yes + + +Model: DB -> yes + + +Model: CG$\oplus$ -> no + + +Model: CG$\otimes$ -> no + + +Model: CG$\odot$ -> no + + +Model: IC$\parallel$ -> no + + +Model: IC$\oplus$ -> no + + +Model: IC$\otimes$ -> no + + +Model: IC$\odot$ -> yes + +========================================================= m1_m2 + + +Model: Base -> yes + + +Model: DB -> yes + + +Model: CG$\oplus$ -> no + + +Model: CG$\otimes$ -> no + + +Model: CG$\odot$ -> no + + +Model: IC$\parallel$ -> no + + +Model: IC$\oplus$ -> no + + +Model: IC$\otimes$ -> no + + +Model: IC$\odot$ -> yes + +========================================================= m12_m21 + + +Model: Base -> yes + + +Model: DB -> yes + + +Model: CG$\oplus$ -> no + + +Model: CG$\otimes$ -> no + + +Model: CG$\odot$ -> no + + 
+Model: IC$\parallel$ -> no + + +Model: IC$\oplus$ -> no + + +Model: IC$\otimes$ -> no + + +Model: IC$\odot$ -> yes diff --git a/tbd/results/hgm_scores.txt b/tbd/results/hgm_scores.txt new file mode 100644 index 0000000..b501612 --- /dev/null +++ b/tbd/results/hgm_scores.txt @@ -0,0 +1,59 @@ +mc ===================================================================== + precision recall f1-score support + + 0 0.000 0.500 0.001 50 + 1 0.000 0.000 0.000 4 + 2 0.004 0.038 0.007 238 + 3 0.999 0.795 0.885 290788 + + accuracy 0.794 291080 + macro avg 0.251 0.333 0.223 291080 +weighted avg 0.998 0.794 0.884 291080 + +m1 ===================================================================== + precision recall f1-score support + + 0 0.000 0.000 0.000 147 + 1 0.000 0.000 0.000 2 + 2 0.025 0.051 0.033 1714 + 3 0.994 0.988 0.991 289217 + + accuracy 0.982 291080 + macro avg 0.255 0.260 0.256 291080 +weighted avg 0.988 0.982 0.985 291080 + +m2 ===================================================================== + precision recall f1-score support + + 0 0.001 0.013 0.001 151 + 2 0.031 0.084 0.045 2394 + 3 0.992 0.970 0.981 288535 + + accuracy 0.962 291080 + macro avg 0.341 0.355 0.342 291080 +weighted avg 0.983 0.962 0.972 291080 + +m12 ===================================================================== + precision recall f1-score support + + 0 0.000 0.000 0.000 93 + 1 0.000 0.000 0.000 8 + 2 0.015 0.056 0.023 676 + 3 0.997 0.990 0.994 290303 + + accuracy 0.988 291080 + macro avg 0.253 0.262 0.254 291080 +weighted avg 0.995 0.988 0.991 291080 + +m21 ===================================================================== + precision recall f1-score support + + 0 0.002 0.012 0.003 86 + 1 0.000 0.000 0.000 12 + 2 0.010 0.040 0.016 658 + 3 0.997 0.989 0.993 290324 + + accuracy 0.987 291080 + macro avg 0.252 0.260 0.253 291080 +weighted avg 0.995 0.987 0.991 291080 + diff --git a/tbd/results/tbd_abl_avg_only.pdf b/tbd/results/tbd_abl_avg_only.pdf new file mode 100644 index 0000000..e97ae6f Binary files /dev/null and b/tbd/results/tbd_abl_avg_only.pdf differ diff --git a/tbd/run_test.sh b/tbd/run_test.sh new file mode 100644 index 0000000..d526d80 --- /dev/null +++ b/tbd/run_test.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +python -m test \ +--gpu_id 1 \ +--seed 1 \ +--non_blocking \ +--pin_memory \ +--model_type tom_cm \ +--aggr no_tom \ +--hidden_dim 64 \ +--batch_size 64 \ +--load_model_path /PATH/TO/model \ No newline at end of file diff --git a/tbd/run_train.sh b/tbd/run_train.sh new file mode 100644 index 0000000..39ea559 --- /dev/null +++ b/tbd/run_train.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +python -m train \ +--gpu_id 2 \ +--seed 123 \ +--logger \ +--non_blocking \ +--pin_memory \ +--batch_size 64 \ +--num_workers 16 \ +--num_epoch 300 \ +--lr 5e-4 \ +--dropout 0.1 \ +--model_type tom_cm \ +--aggr no_tom \ +--hidden_dim 64 diff --git a/tbd/tbd_dataloader.py b/tbd/tbd_dataloader.py new file mode 100644 index 0000000..8875859 --- /dev/null +++ b/tbd/tbd_dataloader.py @@ -0,0 +1,568 @@ +from __future__ import annotations + +from typing import Optional, Union + +import torch +import pickle +import torch +import time +import glob +import random +import os +import numpy as np +import pandas as pd +import cv2 +from itertools import product +import csv +import torchvision.transforms as T + +from utils.helpers import tracker_skeID, CLIPS_IDS_88, ALL_IDS, UNIQUE_OBJ_IDS + + +def collate_fn(batch): + # Unpack the batch into individual elements + img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, 
labels, exp_id, timestep = zip(*batch) + + # Determine the maximum number of objects in any batch + max_n_obj = max(bbox.shape[1] for bbox in bboxes) + + # Pad the bounding box tensors + bboxes_pad = [] + for bbox in bboxes: + pad_size = max_n_obj - bbox.shape[1] + pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32) + padded_bbox = torch.cat((bbox, pad), dim=1) + bboxes_pad.append(padded_bbox) + + # Stack the padded tensors into a batch tensor + bboxes_batch = torch.stack(bboxes_pad, dim=0) + + img_3rd_pov = torch.stack(img_3rd_pov, dim=0) + img_tracker = torch.stack(img_tracker, dim=0) + img_battery = torch.stack(img_battery, dim=0) + pose1 = torch.stack(pose1, dim=0) + pose2 = torch.stack(pose2, dim=0) + gaze = torch.stack(gaze, dim=0) + labels = torch.tensor(labels, dtype=torch.long) + + # Return the batched tensors + return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep + +def collate_fn_test(batch): + # Unpack the batch into individual elements + img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep, false_beliefs = zip(*batch) + + # Determine the maximum number of objects in any batch + max_n_obj = max(bbox.shape[1] for bbox in bboxes) + + # Pad the bounding box tensors + bboxes_pad = [] + for bbox in bboxes: + pad_size = max_n_obj - bbox.shape[1] + pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32) + padded_bbox = torch.cat((bbox, pad), dim=1) + bboxes_pad.append(padded_bbox) + + # Stack the padded tensors into a batch tensor + bboxes_batch = torch.stack(bboxes_pad, dim=0) + + img_3rd_pov = torch.stack(img_3rd_pov, dim=0) + img_tracker = torch.stack(img_tracker, dim=0) + img_battery = torch.stack(img_battery, dim=0) + pose1 = torch.stack(pose1, dim=0) + pose2 = torch.stack(pose2, dim=0) + gaze = torch.stack(gaze, dim=0) + labels = torch.tensor(labels, dtype=torch.long) + + # Return the batched tensors + return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep, false_beliefs + + +class TBDDataset(torch.utils.data.Dataset): + + def __init__( + self, + path: str = "/scratch/bortoletto/data/tbd", + mode: str = "train", + tbd_data_path: str = "/scratch/bortoletto/data/tbd/mind_lstm_training_cnn_att/", + list_of_ids_to_consider: list = ALL_IDS, + use_preprocessed_img: bool = True, + resize_img: Optional[Union[tuple, int]] = (128,128), + ): + """TBD Dataset based on the 88 clip version of the TBD data. + + Expects the following folder structure: + - path + - tracker_gt_smooth <- These are eye tracking from POV, 2D coordinates + - images/*/ <- These are the images by experiment id + - battery <- These are the images, 1st Person + - tracker <- These are the images, 1st other Person w/ eye fixation + - kinect <- These are the images, 3rd Person + - skeleton <- Pose estimation, 3D coordinates + - annotation <- These are the labels, i.e. [0,3] (see below) + + + Labels are strcutured as follows: + { + "O1": [ <- Object with id O1 + { + "m1": { + "fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange + "loc": null + }, + "m2": { + "fluent": 3, + "loc": null + }, + "m12": { + "fluent": 3, + "loc": null + }, + "m21": { + "fluent": 3, + "loc": null + }, + "mc": { + "fluent": 3, + "loc": null + }, + "mg": { + "fluent": 3, + "loc": [ + 22, + 9 + ] + } + }, ... + ], ... 
+ } + + This corresponds to a strict subset of the raw dataset collected + by the TBD people in their paper "Learning Traidic Belief Dynamics + in Nonverbal Communication from Videos" (CVPR2021, Oral). + + We keep small amounts of data in memory (everything <100MB). + Otherwise we read from disk on the fly. This dataset applies normalization. + + Args: + path (str, optional): Where the folders lie. + Defaults to "/scratch/ruhdorfer/triadic_beleif_data_v2". + list_of_ids_to_consider (list, optional): List of ids to consider. + Defaults to ALL_IDS. Otherwise specify a list, + e.g. ["test_94342_23", "test_boelter_21", ...]. + resize_img (Optional[Union[tuple, int]], optional): Resize image to + this size if required. Defaults to None. + """ + print(f"Loading TBD Dataset in mode {mode}...") + + self.mode = mode + + start = time.time() + + self.skeleton_3D_path = f"{path}/skeleton" + self.tracker_2D_path = f"{path}/tracker_gt_smooth" + self.bbox_csv_path = f"{path}/annotations_with_bbox.csv" + if use_preprocessed_img: + self.img_path = f"{path}/images_norm" + else: + self.img_path = f"{path}/images" + self.obj_ids_path = f"{path}/mind_lstm_training_cnn_att_shu.pkl" + + self.label_map = list(product([0, 1, 2, 3], repeat=5)) + + clips = os.listdir(tbd_data_path) + data = [] + labels = [] + for clip in clips: + with open(tbd_data_path + clip, 'rb') as f: + vec_input, label_ = pickle.load(f, encoding='latin1') + data = data + vec_input + labels = labels + label_ + c = list(zip(data, labels)) + random.shuffle(c) + train_ratio = int(len(c) * 0.6) + validate_ratio = int(len(c) * 0.2) + data, label = zip(*c) + train_x, train_y = data[:train_ratio], label[:train_ratio] + validate_x, validate_y = data[train_ratio:train_ratio + validate_ratio], label[train_ratio:train_ratio + validate_ratio] + test_x, test_y = data[train_ratio + validate_ratio:], label[train_ratio + validate_ratio:] + self.mind_count = np.zeros(1024) # used for CE weights + + if mode == "train": + self.data, self.labels = train_x, train_y + elif mode == "val": + self.data, self.labels = validate_x, validate_y + elif mode == "test": + self.data, self.labels = test_x, test_y + + self.false_beliefs_path = f"{path}/store_mind_set" + + # keep small amouts of data in memory + self.skeleton_3D = self.load_skeleton_3D(self.skeleton_3D_path, list_of_ids_to_consider) + self.tracker_2D = self.load_tracker_2D(self.tracker_2D_path, list_of_ids_to_consider) + self.bbox_df = pd.read_csv(self.bbox_csv_path, header=0) + self.obj_ids = self.load_obj_ids(self.obj_ids_path) + + if not use_preprocessed_img: + normalisation_steps = [ + T.ToTensor(), + T.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + ] + if resize_img is not None: + normalisation_steps.insert(1, T.Resize(resize_img)) + self.preprocess_img = T.Compose(normalisation_steps) + else: + self.preprocess_img = None + + self.use_preprocessed_img = use_preprocessed_img + print(f"Done loading in {time.time() - start}s.") + + def __len__(self): + return len(self.data) + + def __getitem__( + self, idx: int + ) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict, str, int, str + ]: + """Given an index, return the corresponding experiment_id and timestep in the experiment. + Then picky the appropriate data and labels from these. 
+ + Args: + idx (int): _description_ + + Returns: + tuple: torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict + Returns the following: + - img_kinect: torch.Tensor of shape (T, C, H, W) (Default is [T, 3, 720, 1280]) + - img_tracker: torch.Tensor of shape (T, H, W, C) + - img_battery: torch.Tensor of shape (T, H, W, C) + - skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 1) + - skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 2) + - bbox: torch.Tensor of shape (T, num_obj, 5) + - tracker to skeleton ID: str (either skeleton 1 or 2) + - tracker_2D: torch.Tensor of shape (T, 2) + - labels: dict (see below) + """ + labels = self.label_map[self.labels[idx]] + experiment_id = self.data[idx][1][0].split('/')[6] + img_data_path = f"{self.img_path}/{experiment_id}" + frame_ids = [int(os.path.basename(self.data[idx][1][i]).split('_')[0]) for i in range(len(self.data[idx][1]))] + + if self.use_preprocessed_img: + kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.pt"))) + tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.pt"))) + battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.pt"))) + kinect_img_paths = [kinect[id] for id in frame_ids] + tracker_img_paths = [tracker[id] for id in frame_ids] + battery_img_paths = [battery[id] for id in frame_ids] + else: + kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.jpg"))) + tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.jpg"))) + battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.jpg"))) + kinect_img_paths = [kinect[id] for id in frame_ids] + tracker_img_paths = [tracker[id] for id in frame_ids] + battery_img_paths = [battery[id] for id in frame_ids] + + # load images + kinect_imgs = [ + torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path)) + for img_path in kinect_img_paths + ] + kinect_imgs = torch.stack(kinect_imgs, axis=0) + + tracker_imgs = [ + torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path)) + for img_path in tracker_img_paths + ] + tracker_imgs = torch.stack(tracker_imgs, axis=0) + + battery_imgs = [ + torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path)) + for img_path in battery_img_paths + ] + battery_imgs = torch.stack(battery_imgs, axis=0) + + # load object id to check for false beliefs - only for testing + if self.mode == "test": #or self.mode == "train": + if f"{experiment_id}.txt" in os.listdir(self.false_beliefs_path): + obj_id = self.obj_ids[experiment_id][frame_ids[-1]] + obj_id = next(x for x in obj_id if x is not None) + false_belief = next((line.strip().split(',')[2] for line in open(f"{self.false_beliefs_path}/{experiment_id}.txt") if line.startswith(str(frame_ids[-1]) + ',' + obj_id + ',')), "no") + #if experiment_id in ['test_boelter4_0', 'test_boelter4_7', 'test_boelter4_6', 'test_boelter4_8', 'test_boelter2_3', + # 'test_94342_20', 'test_94342_18', 'test_94342_11', 'test_94342_17', 'test_boelter3_8', 'test_94342_2', + # 'test_boelter2_17', 'test_boelter3_7', 'test_94342_4', 'test_boelter3_9', 'test_boelter_10', + # 'test_boelter2_6', 'test_boelter4_10', 'test_boelter4_2', 'test_boelter4_5', 'test_94342_24', + # 'test_94342_15', 'test_boelter3_5', 'test_94342_8', 'test2', 'test_boelter3_12']: + # print('here!') + # with open(os.path.join(f'results/hgm_test_fb.csv'), mode='a') as file: + # writer = csv.writer(file) + # writer.writerow([experiment_id, 
obj_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]]) + else: + false_belief = "no" + #with open(os.path.join(f'results/test_fb.csv'), mode='a') as file: + # writer = csv.writer(file) + # writer.writerow([experiment_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]]) + + df = self.bbox_df[ + (self.bbox_df.experiment_name == experiment_id) + #& (self.bbox_df.name == obj_id) # NOTE: load the bounding boxes for all objects + & (self.bbox_df.name != 'P1') + & (self.bbox_df.name != 'P2') + & (self.bbox_df.frame.isin(frame_ids)) + ] + + bboxes = [] + for f in frame_ids: + bbox = torch.tensor(df.loc[df['frame'] == f, ["x_min", "y_min", "x_max", "y_max"]].to_numpy(), dtype=torch.float32) + bbox[:, 0] = bbox[:, 0] / 1280.0 + bbox[:, 1] = bbox[:, 1] / 720.0 + bbox[:, 2] = bbox[:, 2] / 1280.0 + bbox[:, 3] = bbox[:, 3] / 720.0 + bboxes.append(bbox) + bboxes = torch.stack(bboxes) # NOTE: this will need a collate function bc not every video has the same number of objects + + skele1 = self.skeleton_3D[experiment_id]["skele1"][frame_ids] + skele2 = self.skeleton_3D[experiment_id]["skele2"][frame_ids] + + gaze = self.tracker_2D[experiment_id][frame_ids] + + if self.mode == "test": + return ( + kinect_imgs, + tracker_imgs, + battery_imgs, + skele1, + skele2, + bboxes, + tracker_skeID[experiment_id], # <- This is the tracker skeleton ID + gaze, + labels, # <- per object "m1", "m2", "m12", "m21", "mc" + experiment_id, + frame_ids, + #self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding + false_belief + ) + else: + return ( + kinect_imgs, + tracker_imgs, + battery_imgs, + skele1, + skele2, + bboxes, + tracker_skeID[experiment_id], # <- This is the tracker skeleton ID + gaze, + labels, # <- per object "m1", "m2", "m12", "m21", "mc" + experiment_id, + frame_ids + #self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding + ) + + def onehot(self, x, n=len(UNIQUE_OBJ_IDS)): + retval = torch.zeros(n) + if x > 0: + retval[x-1] = 1 + return retval + + def load_obj_ids(self, path: str): + with open(path, "rb") as f: + ids = pickle.load(f) + return ids + + def extract_labels(self): + """TODO: Converts index label to [m1, m2, m12, m21, mc] format. + + """ + return + + def _flatten_mind_obj_timestep(self, mind_obj_dict: dict) -> list: + """Flattens the mind object dict to a list. I.e. takes + + { + "m1": { + "fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange + "loc": null + }, + "m2": { + "fluent": 3, + "loc": null + }, + "m12": { + "fluent": 3, + "loc": null + }, + "m21": { + "fluent": 3, + "loc": null + }, + "mc": { + "fluent": 3, + "loc": null + }, + "mg": { + "fluent": 3, + "loc": [ + 22, + 9 + ] + } + } + + and returns [3, 3, 3, 3, 3, 3] + + Args: + mind_obj_dict (dict): Mind object dict as described in __init__.doctstring. + + Returns: + list: List of mind object labels. + """ + return np.array([mind_obj["fluent"] for key, mind_obj in mind_obj_dict.items() if key != "mg"]) + + def load_skeleton_3D(self, path: str, list_of_ids_to_consider: list): + """Load skeleton 3D data from disk. + + - path + - * <- list of ids + - skele1.p <- 3D coord per id and timestep + - skele2.p <- + + Args: + path (str): Where the skeleton 3D data lie. + list_of_ids_to_consider (list): List of ids to consider. + Defaults to None which means all ids. Otherwise specify a list, + e.g. ["test_94342_23", "test_boelter_21", ...]. 
+ + Returns: + dict: skeleton 3D data as described above in __init__.doctstring. + """ + skeleton_3D = {} + for experiment_id in list_of_ids_to_consider: + skeleton_3D[experiment_id] = {} + with open(f"{path}/{experiment_id}/skele1.p", "rb") as f: + skeleton_3D[experiment_id]["skele1"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32) + with open(f"{path}/{experiment_id}/skele2.p", "rb") as f: + skeleton_3D[experiment_id]["skele2"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32) + return skeleton_3D + + def load_tracker_2D(self, path: str, list_of_ids_to_consider: list): + """Load tracker 2D data from disk. + + - path + - *.p <- 2D coord per id and timestep + + Args: + path (str): Where the tracker 2D data lie. + list_of_ids_to_consider (list): List of ids to consider. + Defaults to None which means all ids. Otherwise specify a list, + e.g. ["test_94342_23", "test_boelter_21", ...]. + + Returns: + dict: tracker 2D data. + """ + tracker_2D = {} + for experiment_id in list_of_ids_to_consider: + with open(f"{path}/{experiment_id}.p", "rb") as f: + tracker_2D[experiment_id] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32) + return tracker_2D + + def load_bbox(self, path: str, list_of_ids_to_consider: list): + """Load bbox data from disk. + + - bbox_tensors.pickle <- bbox per experiment id one tensor + + Args: + path (str): Where the bbox data lie. + list_of_ids_to_consider (list): List of ids to consider. + + Returns: + dict: bbox data. + """ + with open(path, "rb") as f: + pickle_data = pickle.load(f) + for key in CLIPS_IDS_88: + if key not in list_of_ids_to_consider: + pickle_data.pop(key, None) + return pickle_data + + + + + + +if __name__ == '__main__': + + # os.environ['PYTHONHASHSEED'] = str(42) + # torch.manual_seed(42) + # np.random.seed(42) + # random.seed(42) + + data = TBDDataset(use_preprocessed_img=True, mode="test") + + from tqdm import tqdm + for i in tqdm(range(data.__len__())): + data[i] + + breakpoint() + + from torch.utils.data import DataLoader + + # Just for guessing time + data_0=data[0] + data_last=data[len(data)-1] + idx = np.random.randint(1, len(data)-1) # Something in between. 
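As an aside, the scalar labels loaded from the pickled TBD clips index into `label_map = list(product([0, 1, 2, 3], repeat=5))` defined in `TBDDataset.__init__`, i.e. one fluent value per mind (m1, m2, m12, m21, mc) with 0: enter, 1: disappear, 2: update, 3: unchange. A minimal, self-contained sketch of that encoding and its inverse (illustrative values only, separate from the demo script around it):

```python
# Sketch only: decode a scalar TBD label into per-mind fluents and back.
from itertools import product

label_map = list(product([0, 1, 2, 3], repeat=5))  # 4**5 = 1024 joint labels

scalar_label = 1023
m1, m2, m12, m21, mc = label_map[scalar_label]     # -> (3, 3, 3, 3, 3), all "unchange"

# Inverse lookup, e.g. when accumulating counts over the 1024 joint labels:
index = label_map.index((3, 3, 3, 3, 3))           # -> 1023
```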
+ start = time.time() + ( + kinect_imgs, # <- len x 720 x 1280 x 3 originally, likely smaller now + tracker_imgs, + battery_imgs, + skele1, + skele2, + bbox, + tracker_skeID_sample, # <- This is the tracker skeleton ID + tracker2d, + label, + experiment_id, # From here for debugging + timestep, + #obj_id, # <- This is the object ID as a one-hot + false_belief + ) = data[idx] + end = time.time() + print(f"Time for one sample: {end-start}") + + print('kinect:', kinect_imgs.shape) + print('tracker:', tracker_imgs.shape) + print('battery:', battery_imgs.shape) + print('skele1:', skele1.shape) + print('skele2:', skele2.shape) + print('gaze:', tracker2d.shape) + print('bbox:', bbox.shape) + print('label:', label) + + #breakpoint() + + dl = DataLoader( + data, + batch_size=4, + shuffle=False, + collate_fn=collate_fn + ) + + from tqdm import tqdm + + for j, batch in tqdm(enumerate(dl)): + #print(j, end='\r') + img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch + #breakpoint() + #print(img_3rd_pov.shape) + #print(img_tracker.shape) + #print(img_battery.shape) + #print(pose1.shape, pose2.shape) + #print(bbox.shape) + #print(gaze.shape) + + + breakpoint() diff --git a/tbd/test.py b/tbd/test.py new file mode 100644 index 0000000..aaf7e9e --- /dev/null +++ b/tbd/test.py @@ -0,0 +1,196 @@ +import torch +import csv +import argparse +from tqdm import tqdm +from torch.utils.data import DataLoader +import random +import os +import numpy as np + +from tbd_dataloader import TBDDataset, collate_fn_test +from models.common_mind import CommonMindToMnet +from models.sl import SLToMnet +from models.implicit import ImplicitToMnet +from utils.helpers import compute_f1_scores + + +def test(args): + + test_dataset = TBDDataset( + path=args.data_path, + mode="test", + use_preprocessed_img=True + ) + test_dataloader = DataLoader( + test_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + collate_fn=collate_fn_test + ) + + device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu') + + # model + if args.model_type == 'tom_cm': + model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device) + elif args.model_type == 'tom_sl': + model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device) + elif args.model_type == 'tom_impl': + model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device) + else: raise NotImplementedError + + model.load_state_dict(torch.load(args.load_model_path, map_location=device)) + model.device = device + + model.eval() + + if args.save_preds: + # Define the output file path + folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}' + if not os.path.exists(folder_path): + os.makedirs(folder_path) + print(f'Saving predictions in {folder_path}.') + + print('Testing...') + m1_pred_list = [] + m2_pred_list = [] + m12_pred_list = [] + m21_pred_list = [] + mc_pred_list = [] + m1_label_list = [] + m2_label_list = [] + m12_label_list = [] + m21_label_list = [] + mc_label_list = [] + with torch.no_grad(): + for j, batch in tqdm(enumerate(test_dataloader)): + img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep, false_belief = batch + if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, 
non_blocking=args.non_blocking) + if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking) + if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking) + if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking) + if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking) + if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking) + if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking) + m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, repr = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze) + m1_pred = m1_pred.reshape(-1, 4) + m2_pred = m2_pred.reshape(-1, 4) + m12_pred = m12_pred.reshape(-1, 4) + m21_pred = m21_pred.reshape(-1, 4) + mc_pred = mc_pred.reshape(-1, 4) + m1_label = labels[:, 0].reshape(-1).to(device) + m2_label = labels[:, 1].reshape(-1).to(device) + m12_label = labels[:, 2].reshape(-1).to(device) + m21_label = labels[:, 3].reshape(-1).to(device) + mc_label = labels[:, 4].reshape(-1).to(device) + + m1_pred_list.append(m1_pred) + m2_pred_list.append(m2_pred) + m12_pred_list.append(m12_pred) + m21_pred_list.append(m21_pred) + mc_pred_list.append(mc_pred) + m1_label_list.append(m1_label) + m2_label_list.append(m2_label) + m12_label_list.append(m12_label) + m21_label_list.append(m21_label) + mc_label_list.append(mc_label) + + if args.save_preds: + torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt")) + data = [( + i, + torch.argmax(m1_pred[i]).cpu().numpy(), + torch.argmax(m2_pred[i]).cpu().numpy(), + torch.argmax(m12_pred[i]).cpu().numpy(), + torch.argmax(m21_pred[i]).cpu().numpy(), + torch.argmax(mc_pred[i]).cpu().numpy(), + m1_label[i].cpu().numpy(), + m2_label[i].cpu().numpy(), + m12_label[i].cpu().numpy(), + m21_label[i].cpu().numpy(), + mc_label[i].cpu().numpy(), + false_belief[i]) for i in range(len(labels)) + ] + header = ['frame', 'm1_pred', 'm2_pred', 'm12_pred', 'm21_pred', 'mc_pred', 'm1_label', 'm2_label', 'm12_label', 'm21_label', 'mc_label', 'false_belief'] + with open(os.path.join(folder_path, f'{j}.csv'), mode='w', newline='') as file: + writer = csv.writer(file) + writer.writerow(header) # Write the header row + writer.writerows(data) # Write the data rows + + #np.savetxt('m1_label_bs1.txt', torch.cat(m1_label_list).cpu().numpy()) + test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1 = compute_f1_scores( + torch.cat(m1_pred_list), + torch.cat(m1_label_list), + torch.cat(m2_pred_list), + torch.cat(m2_label_list), + torch.cat(m12_pred_list), + torch.cat(m12_label_list), + torch.cat(m21_pred_list), + torch.cat(m21_label_list), + torch.cat(mc_pred_list), + torch.cat(mc_label_list) + ) + + print("Test m1 F1: {}".format(test_m1_f1)) + print("Test m2 F1: {}".format(test_m2_f1)) + print("Test m12 F1: {}".format(test_m12_f1)) + print("Test m21 F1: {}".format(test_m21_f1)) + print("Test mc F1: {}".format(test_mc_f1)) + + with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f: + f.write(f"Test data:\n {[data[1] for data in test_dataset.data]}") + f.write(f"m1 f1: {test_m1_f1}") + f.write(f"m2 f1: {test_m2_f1}") + f.write(f"m12 f1: {test_m12_f1}") + f.write(f"m21 f1: {test_m21_f1}") + f.write(f"mc f1: {test_mc_f1}") + f.close() + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + # Define the command-line arguments + parser.add_argument('--gpu_id', type=int) + parser.add_argument('--seed', type=int, default=1) + 
parser.add_argument('--presaved', type=int, default=128) + parser.add_argument('--non_blocking', action='store_true') + parser.add_argument('--num_workers', type=int, default=16) + parser.add_argument('--pin_memory', action='store_true') + parser.add_argument('--model_type', type=str) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--aggr', type=str, default='concat', required=False) + parser.add_argument('--use_resnet', action='store_true') + parser.add_argument('--hidden_dim', type=int, default=64) + parser.add_argument('--tom_weight', type=float, default=2.0, required=False) + parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']) + parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd') + parser.add_argument('--save_path', type=str, default='experiments/') + parser.add_argument('--test_frames', type=str, default=None) + parser.add_argument('--median', type=int, default=None) + parser.add_argument('--load_model_path', type=str) + parser.add_argument('--dropout', type=float, default=0.0) + parser.add_argument('--save_preds', action='store_true') + + # Parse the command-line arguments + args = parser.parse_args() + + if args.model_type == 'tom_cm' or args.model_type == 'tom_impl': + if not args.aggr: + parser.error("The choosen --model_type requires --aggr") + if args.model_type == 'tom_sl' and not args.tom_weight: + parser.error("The choosen --model_type requires --tom_weight") + + os.environ['PYTHONHASHSEED'] = str(args.seed) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + print('###########################################################################') + print('TESTING: MAKE SURE YOU ARE USING THE SAME RANDOM SEED USED DURING TRAINING!') + print('###########################################################################') + + test(args) \ No newline at end of file diff --git a/tbd/train.py b/tbd/train.py new file mode 100644 index 0000000..7ac4b3b --- /dev/null +++ b/tbd/train.py @@ -0,0 +1,474 @@ +import torch +import os +import argparse +import numpy as np +import random +import datetime +import wandb +from tqdm import tqdm +from torch.utils.data import DataLoader +import torch.nn as nn +from torch.optim.lr_scheduler import CosineAnnealingLR + +from tbd_dataloader import TBDDataset, collate_fn +from models.common_mind import CommonMindToMnet +from models.sl import SLToMnet +from models.implicit import ImplicitToMnet +from utils.helpers import count_parameters, compute_f1_scores + + +def main(args): + + train_dataset = TBDDataset( + path=args.data_path, + mode="train", + use_preprocessed_img=True + ) + train_dataloader = DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + collate_fn=collate_fn + ) + val_dataset = TBDDataset( + path=args.data_path, + mode="val", + use_preprocessed_img=True + ) + val_dataloader = DataLoader( + val_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + collate_fn=collate_fn + ) + + train_data = [data[1] for data in train_dataset.data] + val_data = [data[1] for data in val_dataset.data] + if args.logger: + wandb.config.update({"train_data": train_data}) + wandb.config.update({"val_data": val_data}) + + device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu') + + # model + if args.model_type == 'tom_cm': + model = 
CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device) + elif args.model_type == 'tom_sl': + model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device) + elif args.model_type == 'tom_impl': + model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device) + else: raise NotImplementedError + if args.resume_from_checkpoint is not None: + model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device)) + # optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + # scheduler + if args.scheduler == None: + scheduler = None + else: + scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5) + # loss function + if args.model_type == 'tom_sl': + ce_loss_m1 = nn.NLLLoss() + ce_loss_m2 = nn.NLLLoss() + ce_loss_m12 = nn.NLLLoss() + ce_loss_m21 = nn.NLLLoss() + ce_loss_mc = nn.NLLLoss() + else: + ce_loss_m1 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + ce_loss_m2 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + ce_loss_m12 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + ce_loss_m21 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + ce_loss_mc = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + + stats = { + 'train': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []}, + 'val': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []} + } + max_val_f1 = 0 + max_val_classification_epoch = None + counter = 0 + + print(f'Number of parameters: {count_parameters(model)}') + + for i in range(args.num_epoch): + # training + print('Training for epoch {}/{}...'.format(i+1, args.num_epoch)) + epoch_train_loss_m1 = 0.0 + epoch_train_loss_m2 = 0.0 + epoch_train_loss_m12 = 0.0 + epoch_train_loss_m21 = 0.0 + epoch_train_loss_mc = 0.0 + m1_train_batch_pred_list = [] + m2_train_batch_pred_list = [] + m12_train_batch_pred_list = [] + m21_train_batch_pred_list = [] + mc_train_batch_pred_list = [] + m1_train_batch_label_list = [] + m2_train_batch_label_list = [] + m12_train_batch_label_list = [] + m21_train_batch_label_list = [] + mc_train_batch_label_list = [] + model.train() + for j, batch in tqdm(enumerate(train_dataloader)): + img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch + if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking) + if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking) + if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking) + if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking) + if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking) + if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking) + if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking) + m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze) + m1_pred = m1_pred.reshape(-1, 4) + m2_pred = m2_pred.reshape(-1, 4) + m12_pred = m12_pred.reshape(-1, 4) + m21_pred = m21_pred.reshape(-1, 4) + mc_pred = 
mc_pred.reshape(-1, 4) + m1_label = labels[:, 0].reshape(-1).to(device) + m2_label = labels[:, 1].reshape(-1).to(device) + m12_label = labels[:, 2].reshape(-1).to(device) + m21_label = labels[:, 3].reshape(-1).to(device) + mc_label = labels[:, 4].reshape(-1).to(device) + + loss_m1 = ce_loss_m1(m1_pred, m1_label) + loss_m2 = ce_loss_m2(m2_pred, m2_label) + loss_m12 = ce_loss_m12(m12_pred, m12_label) + loss_m21 = ce_loss_m21(m21_pred, m21_label) + loss_mc = ce_loss_mc(mc_pred, mc_label) + loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc + + epoch_train_loss_m1 += loss_m1.data.item() + epoch_train_loss_m2 += loss_m2.data.item() + epoch_train_loss_m12 += loss_m12.data.item() + epoch_train_loss_m21 += loss_m21.data.item() + epoch_train_loss_mc += loss_mc.data.item() + + m1_train_batch_pred_list.append(m1_pred) + m2_train_batch_pred_list.append(m2_pred) + m12_train_batch_pred_list.append(m12_pred) + m21_train_batch_pred_list.append(m21_pred) + mc_train_batch_pred_list.append(mc_pred) + m1_train_batch_label_list.append(m1_label) + m2_train_batch_label_list.append(m2_label) + m12_train_batch_label_list.append(m12_label) + m21_train_batch_label_list.append(m21_label) + mc_train_batch_label_list.append(mc_label) + + optimizer.zero_grad() + if args.clip_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) + loss.backward() + optimizer.step() + + if args.logger: wandb.log({ + 'batch_train_loss': loss.data.item(), + 'lr': optimizer.param_groups[-1]['lr'] + }) + + print("Epoch {}/{} batch {}/{} training done with loss={}".format( + i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item()) + ) + + if scheduler: scheduler.step() + + train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score = compute_f1_scores( + torch.cat(m1_train_batch_pred_list), + torch.cat(m1_train_batch_label_list), + torch.cat(m2_train_batch_pred_list), + torch.cat(m2_train_batch_label_list), + torch.cat(m12_train_batch_pred_list), + torch.cat(m12_train_batch_label_list), + torch.cat(m21_train_batch_pred_list), + torch.cat(m21_train_batch_label_list), + torch.cat(mc_train_batch_pred_list), + torch.cat(mc_train_batch_label_list) + ) + + print("Epoch {}/{} OVERALL train m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}.\n".format( + i+1, + args.num_epoch, + epoch_train_loss_m1/len(train_dataloader), + epoch_train_loss_m2/len(train_dataloader), + epoch_train_loss_m12/len(train_dataloader), + epoch_train_loss_m21/len(train_dataloader), + epoch_train_loss_mc/len(train_dataloader), + train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score + ) + ) + stats['train']['loss_m1'].append(epoch_train_loss_m1/len(train_dataloader)) + stats['train']['loss_m2'].append(epoch_train_loss_m2/len(train_dataloader)) + stats['train']['loss_m12'].append(epoch_train_loss_m12/len(train_dataloader)) + stats['train']['loss_m21'].append(epoch_train_loss_m21/len(train_dataloader)) + stats['train']['loss_mc'].append(epoch_train_loss_mc/len(train_dataloader)) + stats['train']['m1_f1'].append(train_m1_f1_score) + stats['train']['m2_f1'].append(train_m2_f1_score) + stats['train']['m12_f1'].append(train_m12_f1_score) + stats['train']['m21_f1'].append(train_m21_f1_score) + stats['train']['mc_f1'].append(train_mc_f1_score) + + if args.logger: wandb.log( + { + 'train_m1_loss': epoch_train_loss_m1/len(train_dataloader), + 'train_m2_loss': epoch_train_loss_m2/len(train_dataloader), + 
'train_m12_loss': epoch_train_loss_m12/len(train_dataloader), + 'train_m21_loss': epoch_train_loss_m21/len(train_dataloader), + 'train_mc_loss': epoch_train_loss_mc/len(train_dataloader), + 'train_loss': epoch_train_loss_m1/len(train_dataloader) + \ + epoch_train_loss_m2/len(train_dataloader) + \ + epoch_train_loss_m12/len(train_dataloader) + \ + epoch_train_loss_m21/len(train_dataloader) + \ + epoch_train_loss_mc/len(train_dataloader), + 'train_m1_f1_score': train_m1_f1_score, + 'train_m2_f1_score': train_m2_f1_score, + 'train_m12_f1_score': train_m12_f1_score, + 'train_m21_f1_score': train_m21_f1_score, + 'train_mc_f1_score': train_mc_f1_score + } + ) + + # validation + print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch)) + epoch_val_loss_m1 = 0.0 + epoch_val_loss_m2 = 0.0 + epoch_val_loss_m12 = 0.0 + epoch_val_loss_m21 = 0.0 + epoch_val_loss_mc = 0.0 + m1_val_batch_pred_list = [] + m2_val_batch_pred_list = [] + m12_val_batch_pred_list = [] + m21_val_batch_pred_list = [] + mc_val_batch_pred_list = [] + m1_val_batch_label_list = [] + m2_val_batch_label_list = [] + m12_val_batch_label_list = [] + m21_val_batch_label_list = [] + mc_val_batch_label_list = [] + model.eval() + with torch.no_grad(): + for j, batch in tqdm(enumerate(val_dataloader)): + img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch + if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking) + if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking) + if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking) + if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking) + if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking) + if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking) + if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking) + m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze) + m1_pred = m1_pred.reshape(-1, 4) + m2_pred = m2_pred.reshape(-1, 4) + m12_pred = m12_pred.reshape(-1, 4) + m21_pred = m21_pred.reshape(-1, 4) + mc_pred = mc_pred.reshape(-1, 4) + m1_label = labels[:, 0].reshape(-1).to(device) + m2_label = labels[:, 1].reshape(-1).to(device) + m12_label = labels[:, 2].reshape(-1).to(device) + m21_label = labels[:, 3].reshape(-1).to(device) + mc_label = labels[:, 4].reshape(-1).to(device) + + loss_m1 = ce_loss_m1(m1_pred, m1_label) + loss_m2 = ce_loss_m2(m2_pred, m2_label) + loss_m12 = ce_loss_m12(m12_pred, m12_label) + loss_m21 = ce_loss_m21(m21_pred, m21_label) + loss_mc = ce_loss_mc(mc_pred, mc_label) + loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc + + epoch_val_loss_m1 += loss_m1.data.item() + epoch_val_loss_m2 += loss_m2.data.item() + epoch_val_loss_m12 += loss_m12.data.item() + epoch_val_loss_m21 += loss_m21.data.item() + epoch_val_loss_mc += loss_mc.data.item() + + m1_val_batch_pred_list.append(m1_pred) + m2_val_batch_pred_list.append(m2_pred) + m12_val_batch_pred_list.append(m12_pred) + m21_val_batch_pred_list.append(m21_pred) + mc_val_batch_pred_list.append(mc_pred) + m1_val_batch_label_list.append(m1_label) + m2_val_batch_label_list.append(m2_label) + m12_val_batch_label_list.append(m12_label) + m21_val_batch_label_list.append(m21_label) + mc_val_batch_label_list.append(mc_label) + + if args.logger: wandb.log({'batch_val_loss': 
loss.data.item()}) + print("Epoch {}/{} batch {}/{} validation done with loss={}".format( + i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item()) + ) + + val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score = compute_f1_scores( + torch.cat(m1_val_batch_pred_list), + torch.cat(m1_val_batch_label_list), + torch.cat(m2_val_batch_pred_list), + torch.cat(m2_val_batch_label_list), + torch.cat(m12_val_batch_pred_list), + torch.cat(m12_val_batch_label_list), + torch.cat(m21_val_batch_pred_list), + torch.cat(m21_val_batch_label_list), + torch.cat(mc_val_batch_pred_list), + torch.cat(mc_val_batch_label_list) + ) + + print("Epoch {}/{} OVERALL validation m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format( + i+1, + args.num_epoch, + epoch_val_loss_m1/len(val_dataloader), + epoch_val_loss_m2/len(val_dataloader), + epoch_val_loss_m12/len(val_dataloader), + epoch_val_loss_m21/len(val_dataloader), + epoch_val_loss_mc/len(val_dataloader), + val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score + ) + ) + + stats['val']['loss_m1'].append(epoch_val_loss_m1/len(val_dataloader)) + stats['val']['loss_m2'].append(epoch_val_loss_m2/len(val_dataloader)) + stats['val']['loss_m12'].append(epoch_val_loss_m12/len(val_dataloader)) + stats['val']['loss_m21'].append(epoch_val_loss_m21/len(val_dataloader)) + stats['val']['loss_mc'].append(epoch_val_loss_mc/len(val_dataloader)) + stats['val']['m1_f1'].append(val_m1_f1_score) + stats['val']['m2_f1'].append(val_m2_f1_score) + stats['val']['m12_f1'].append(val_m12_f1_score) + stats['val']['m21_f1'].append(val_m21_f1_score) + stats['val']['mc_f1'].append(val_mc_f1_score) + + if args.logger: wandb.log( + { + 'val_m1_loss': epoch_val_loss_m1/len(val_dataloader), + 'val_m2_loss': epoch_val_loss_m2/len(val_dataloader), + 'val_m12_loss': epoch_val_loss_m12/len(val_dataloader), + 'val_m21_loss': epoch_val_loss_m21/len(val_dataloader), + 'val_mc_loss': epoch_val_loss_mc/len(val_dataloader), + 'val_loss': epoch_val_loss_m1/len(val_dataloader) + \ + epoch_val_loss_m2/len(val_dataloader) + \ + epoch_val_loss_m12/len(val_dataloader) + \ + epoch_val_loss_m21/len(val_dataloader) + \ + epoch_val_loss_mc/len(val_dataloader), + 'val_m1_f1_score': val_m1_f1_score, + 'val_m2_f1_score': val_m2_f1_score, + 'val_m12_f1_score': val_m12_f1_score, + 'val_m21_f1_score': val_m21_f1_score, + 'val_mc_f1_score': val_mc_f1_score + } + ) + + # check for best stat/model using validation accuracy + if stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1] >= max_val_f1: + max_val_f1 = stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1] + max_val_classification_epoch = i+1 + torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model')) + counter = 0 + else: + counter += 1 + print(f'EarlyStopping counter: {counter} out of {args.patience}.') + if counter >= args.patience: + break + + with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f: + f.write('{}\n'.format(CFG)) + f.write('{}\n'.format(train_data)) + f.write('{}\n'.format(val_data)) + f.write('{}\n'.format(stats)) + f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_f1)) + f.close() + + print(f'Results saved in {experiment_save_path}') + + + + + +if __name__ == 
'__main__': + + parser = argparse.ArgumentParser() + + # Define the command-line arguments + parser.add_argument('--gpu_id', type=int) + parser.add_argument('--seed', type=int, default=1) + parser.add_argument('--logger', action='store_true') + parser.add_argument('--presaved', type=int, default=128) + parser.add_argument('--clip_grad_norm', type=float, default=0.5) + parser.add_argument('--use_mixup', action='store_true') + parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False) + parser.add_argument('--non_blocking', action='store_true') + parser.add_argument('--patience', type=int, default=99) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--num_workers', type=int, default=8) + parser.add_argument('--pin_memory', action='store_true') + parser.add_argument('--num_epoch', type=int, default=300) + parser.add_argument('--lr', type=float, default=4e-4) + parser.add_argument('--scheduler', type=str, default=None) + parser.add_argument('--dropout', type=float, default=0.1) + parser.add_argument('--weight_decay', type=float, default=0.005) + parser.add_argument('--label_smoothing', type=float, default=0.1) + parser.add_argument('--model_type', type=str) + parser.add_argument('--aggr', type=str, default='concat', required=False) + parser.add_argument('--use_resnet', action='store_true') + parser.add_argument('--hidden_dim', type=int, default=64) + parser.add_argument('--tom_weight', type=float, default=2.0, required=False) + parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']) + parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd') + parser.add_argument('--save_path', type=str, default='experiments/') + parser.add_argument('--resume_from_checkpoint', type=str, default=None) + + # Parse the command-line arguments + args = parser.parse_args() + + if args.use_mixup and not args.mixup_alpha: + parser.error("--use_mixup requires --mixup_alpha") + if args.model_type == 'tom_cm' or args.model_type == 'tom_impl': + if not args.aggr: + parser.error("The choosen --model_type requires --aggr") + if args.model_type == 'tom_sl' and not args.tom_weight: + parser.error("The choosen --model_type requires --tom_weight") + + # get experiment ID + experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train' + if not os.path.exists(args.save_path): + os.makedirs(args.save_path, exist_ok=True) + experiment_save_path = os.path.join(args.save_path, experiment_id) + os.makedirs(experiment_save_path, exist_ok=True) + + CFG = { + 'use_ocr_custom_loss': 0, + 'presaved': args.presaved, + 'batch_size': args.batch_size, + 'num_epoch': args.num_epoch, + 'lr': args.lr, + 'scheduler': args.scheduler, + 'weight_decay': args.weight_decay, + 'model_type': args.model_type, + 'use_resnet': args.use_resnet, + 'hidden_dim': args.hidden_dim, + 'tom_weight': args.tom_weight, + 'dropout': args.dropout, + 'label_smoothing': args.label_smoothing, + 'clip_grad_norm': args.clip_grad_norm, + 'use_mixup': args.use_mixup, + 'mixup_alpha': args.mixup_alpha, + 'non_blocking_tensors': args.non_blocking, + 'patience': args.patience, + 'pin_memory': args.pin_memory, + 'resume_from_checkpoint': args.resume_from_checkpoint, + 'aggr': args.aggr, + 'mods': args.mods, + 'save_path': experiment_save_path , + 'seed': args.seed + } + + print(CFG) + print(f'Saving results in {experiment_save_path}') + + # set seed values + if args.logger: + wandb.init(project="tbd", config=CFG) + 
os.environ['PYTHONHASHSEED'] = str(args.seed) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + main(args) \ No newline at end of file diff --git a/tbd/utils/fb_scores_err.py b/tbd/utils/fb_scores_err.py new file mode 100644 index 0000000..7e9f3b0 --- /dev/null +++ b/tbd/utils/fb_scores_err.py @@ -0,0 +1,224 @@ +import os +import csv +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.metrics import f1_score +import numpy as np +from tqdm import tqdm + +ALPHA = 0.7 +BAR_WIDTH = 0.27 +sns.set_theme(style='whitegrid') +#sns.set_palette('mako') + +MTOM_COLORS = { + "MN1": (110/255, 117/255, 161/255), + "MN2": (179/255, 106/255, 98/255), + "Base": (193/255, 198/255, 208/255), + "CG": (170/255, 129/255, 42/255), + "IC": (97/255, 112/255, 83/255), + "DB": (144/255, 63/255, 110/255) +} + +model_to_subdir = { + "IC$\parallel$": ["2023-07-16_10-34-32_train", "2023-07-18_13-49-57_train", "2023-07-19_12-17-46_train"], + "IC$\oplus$": ["2023-07-16_10-35-02_train", "2023-07-18_13-50-32_train", "2023-07-19_12-18-18_train"], + "IC$\otimes$": ["2023-07-16_10-35-41_train", "2023-07-18_13-52-26_train", "2023-07-19_12-18-49_train"], + "IC$\odot$": ["2023-07-16_10-36-04_train", "2023-07-18_13-53-03_train", "2023-07-19_12-19-50_train"], + "CG$\parallel$": ["2023-07-15_14-12-36_train", "2023-07-17_11-54-28_train", "2023-07-19_00-30-05_train"], + "CG$\oplus$": ["2023-07-15_14-14-08_train", "2023-07-17_11-56-05_train", "2023-07-19_00-30-47_train"], + "CG$\otimes$": ["2023-07-15_14-14-53_train", "2023-07-17_11-56-39_train", "2023-07-19_00-31-36_train"], + "CG$\odot$": ["2023-07-15_14-10-05_train", "2023-07-17_11-57-30_train", "2023-07-19_00-32-10_train"], + "DB": ["2023-08-08_12-56-02_train", "2023-08-08_19-07-43_train", "2023-08-08_19-08-47_train"], + "Base": ["2023-08-08_12-53-38_train", "2023-08-08_19-10-02_train", "2023-08-08_19-10-51_train"] +} + +def read_data_from_csv(subdirectory_path): + print(subdirectory_path) + data = [] + csv_files = [file for file in os.listdir(subdirectory_path) if file.endswith('.csv')] + for csv_file in csv_files: + file_path = os.path.join(subdirectory_path, csv_file) + with open(file_path, 'r') as file: + reader = csv.reader(file) + header_skipped = False + for row in reader: + if not header_skipped: + header_skipped = True + continue + frame, m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, m1_label, m2_label, m12_label, m21_label, mc_label, false_belief = row + data.append({ + 'frame': int(frame), + 'm1_pred': int(m1_pred), + 'm2_pred': int(m2_pred), + 'm12_pred': int(m12_pred), + 'm21_pred': int(m21_pred), + 'mc_pred': int(mc_pred), + 'm1_label': int(m1_label), + 'm2_label': int(m2_label), + 'm12_label': int(m12_label), + 'm21_label': int(m21_label), + 'mc_label': int(mc_label), + 'false_belief': false_belief, + }) + return data + +def compute_correct_false_belief(data, mind="all", folder=None): + total_false_belief = 0 + correct_false_belief = 0 + for item in data: + if 'false' in item['false_belief']: + false_belief_type = item['false_belief'].split('_')[0] + if mind == "all" or false_belief_type in mind: + total_false_belief += 1 + if item[f"{false_belief_type}_pred"] == item[f"{false_belief_type}_label"]: + if folder is not None: + with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f: + f.write(f"{str(1)}\n") + correct_false_belief += 1 + else: + if folder is not None: + with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if 
isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f: + f.write(f"{str(0)}\n") + if total_false_belief == 0: + accuracy = 0.0 + else: + accuracy = correct_false_belief / total_false_belief + return accuracy + +def compute_macro_f1_score(data, mind="all"): + y_true = [] + y_pred = [] + + for item in data: + if 'false' in item['false_belief']: + false_belief_type = item['false_belief'].split('_')[0] + if mind == "all" or false_belief_type in mind: + y_true.append(int(item[f"{false_belief_type}_label"])) + y_pred.append(int(item[f"{false_belief_type}_pred"])) + + if not y_true or not y_pred: + macro_f1 = 0.0 + else: + macro_f1 = f1_score(y_true, y_pred, average='macro') + + return macro_f1 + +def delete_files_in_subfolders(folder_path, file_names_to_delete): + """ + Delete specified files in all subfolders of a given folder. + + Parameters: + folder_path: The path to the folder containing subfolders. + file_names_to_delete: A list of file names to be deleted. + + Returns: + None + """ + for root, _, _ in os.walk(folder_path): + for file_name in file_names_to_delete: + file_path = os.path.join(root, file_name) + if os.path.exists(file_path): + os.remove(file_path) + print(f"Deleted: {file_path}") + + + + +if __name__ == "__main__": + + folder_path = "predictions" + files_to_delete = ["fb_m1_m2_m12_m21.txt", "fb_m1_m2.txt", "fb_m12_m21.txt"] + delete_files_in_subfolders(folder_path, files_to_delete) + + metric = "Accuracy" + if metric == "Macro F1": + score_function = compute_macro_f1_score + elif metric == "Accuracy": + score_function = compute_correct_false_belief + else: + raise ValueError + + models = [ + 'Base', 'DB', + 'CG$\parallel$', 'CG$\oplus$', 'CG$\otimes$', 'CG$\odot$', + 'IC$\parallel$', 'IC$\oplus$', 'IC$\otimes$', 'IC$\odot$' + ] + + parent_dir = 'predictions' + minds = categories = ['m1', 'm2', 'm12', 'm21'] + score_m1_m2 = [] + score_m12_m21 = [] + score_all = [] + std_m1_m2 = [] + std_m12_m21 = [] + std_all = [] + + for model in models: + model_scores_m1_m2 = [] + model_scores_m12_m21 = [] + model_scores_all = [] + for s in range(3): + subdir_path = os.path.join(parent_dir, model_to_subdir[model][s]) + data = read_data_from_csv(subdir_path) + model_scores_m1_m2.append(score_function(data, ['m1', 'm2'], model_to_subdir[model][s])) + model_scores_m12_m21.append(score_function(data, ['m12', 'm21'], model_to_subdir[model][s])) + model_scores_all.append(score_function(data, ['m1', 'm2', 'm12', 'm21'], model_to_subdir[model][s])) + score_m1_m2.append(np.mean(model_scores_m1_m2)) + std_m1_m2.append(np.std(model_scores_m1_m2)) + score_m12_m21.append(np.mean(model_scores_m12_m21)) + std_m12_m21.append(np.std(model_scores_m12_m21)) + score_all.append(np.mean(model_scores_all)) + std_all.append(np.std(model_scores_all)) + + # Create a dataframe to use with sns.catplot + data = { + 'Model': [m for m in models], + 'FO_FB_mean': score_m1_m2, + 'FO_FB_std': std_m1_m2, + 'SO_FB_mean': score_m12_m21, + 'SO_FB_std': std_m12_m21, + 'Both_mean': score_all, + 'Both_std': std_all + } + + models = data['Model'] + fo_fb_mean = data['FO_FB_mean'] + fo_fb_std = data['FO_FB_std'] + so_fb_mean = data['SO_FB_mean'] + so_fb_std = data['SO_FB_std'] + both_mean = data['Both_mean'] + both_std = data['Both_std'] + + bar_width = BAR_WIDTH + x = np.arange(len(models)) + + plt.figure(figsize=(13, 3.5)) + fo_fb_bars = plt.bar(x - bar_width, fo_fb_mean, width=bar_width, yerr=fo_fb_std, capsize=4, label='First-order false belief', alpha=ALPHA) + so_fb_bars = plt.bar(x, so_fb_mean, 
width=bar_width, yerr=so_fb_std, capsize=4, label='Second-order false belief', alpha=ALPHA) + both_bars = plt.bar(x + bar_width, both_mean, width=bar_width, yerr=both_std, capsize=4, label='Both', alpha=ALPHA) + + def add_labels(bars, std_values): + cnt = 0 + for bar, std in zip(bars, std_values): + height = bar.get_height() + offset = std + 0.01 + if cnt == 0 or cnt == 1 or cnt == 9: + plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}*', ha='center', va='bottom', fontsize=10) + else: + plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}', ha='center', va='bottom', fontsize=10) + cnt = cnt + 1 + + add_labels(fo_fb_bars, fo_fb_std) + add_labels(so_fb_bars, so_fb_std) + add_labels(both_bars, both_std) + + plt.gca().spines['top'].set_visible(False) + plt.gca().spines['right'].set_visible(False) + plt.xlabel('MToMnet', fontsize=14) + plt.ylabel('Macro F1 Score' if metric == "Macro F1" else 'Accuracy', fontsize=14) + plt.xticks(x, models, rotation=0, fontsize=14) + plt.yticks(fontsize=14) + plt.legend(fontsize=14, loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=3) + plt.tight_layout() + plt.savefig('results/false_belief_first_vs_second.pdf') diff --git a/tbd/utils/helpers.py b/tbd/utils/helpers.py new file mode 100644 index 0000000..97c4a01 --- /dev/null +++ b/tbd/utils/helpers.py @@ -0,0 +1,210 @@ +import numpy as np +import random +import torch +from sklearn.metrics import f1_score + + +tracker_skeID = { + 'test1': 'skele1', 'test2': 'skele2', 'test6': 'skele2', 'test7': 'skele1','test_9434_1': 'skele2', + 'test_9434_3': 'skele2', 'test_9434_18': 'skele1', 'test_94342_0': 'skele2', 'test_94342_1': 'skele2', + 'test_94342_2': 'skele2', 'test_94342_3': 'skele2', 'test_94342_4': 'skele1', 'test_94342_5': 'skele1', + 'test_94342_6': 'skele1', 'test_94342_7': 'skele1', 'test_94342_8': 'skele1', + 'test_94342_10': 'skele2', 'test_94342_11': 'skele2', 'test_94342_12': 'skele1', + 'test_94342_13': 'skele2', 'test_94342_14': 'skele1', 'test_94342_15': 'skele2', + 'test_94342_16': 'skele1', 'test_94342_17': 'skele2', 'test_94342_18': 'skele1', + 'test_94342_19': 'skele2', 'test_94342_20': 'skele1', 'test_94342_21': 'skele2', + 'test_94342_22': 'skele1', 'test_94342_23': 'skele1', 'test_94342_24': 'skele1', + 'test_94342_25': 'skele2', 'test_94342_26': 'skele1', 'test_boelter_1': 'skele2', + 'test_boelter_2': 'skele2', 'test_boelter_3': 'skele2', + 'test_boelter_4': 'skele1', 'test_boelter_5': 'skele1', 'test_boelter_6': 'skele1', + 'test_boelter_7': 'skele1', 'test_boelter_9': 'skele1', 'test_boelter_10': 'skele1', + 'test_boelter_12': 'skele2', 'test_boelter_13': 'skele1', 'test_boelter_14': 'skele1', + 'test_boelter_15': 'skele1', 'test_boelter_17': 'skele2', 'test_boelter_18': 'skele1', + 'test_boelter_19': 'skele2', 'test_boelter_21': 'skele1', 'test_boelter_22': 'skele2', + 'test_boelter_24': 'skele1', 'test_boelter_25': 'skele1', + 'test_boelter2_0': 'skele1', 'test_boelter2_2': 'skele1', 'test_boelter2_3': 'skele1', + 'test_boelter2_4': 'skele1', 'test_boelter2_5': 'skele1', 'test_boelter2_6': 'skele1', + 'test_boelter2_7': 'skele2', 'test_boelter2_8': 'skele2', 'test_boelter2_12': 'skele2', + 'test_boelter2_14': 'skele2', 'test_boelter2_15': 'skele2', 'test_boelter2_16': 'skele1', + 'test_boelter2_17': 'skele1', + 'test_boelter3_0': 'skele1', 'test_boelter3_1': 'skele2', 'test_boelter3_2': 'skele2', + 'test_boelter3_3': 'skele2', 'test_boelter3_4': 'skele1', 'test_boelter3_5': 'skele2', + 'test_boelter3_6': 'skele2', 'test_boelter3_7': 
'skele1', 'test_boelter3_8': 'skele2', + 'test_boelter3_9': 'skele2', 'test_boelter3_10': 'skele1', 'test_boelter3_11': 'skele2', + 'test_boelter3_12': 'skele2', 'test_boelter3_13': 'skele2', + 'test_boelter4_0': 'skele2', 'test_boelter4_1': 'skele2', 'test_boelter4_2': 'skele2', + 'test_boelter4_3': 'skele2', 'test_boelter4_4': 'skele2', 'test_boelter4_5': 'skele2', + 'test_boelter4_6': 'skele2', 'test_boelter4_7': 'skele2', 'test_boelter4_8': 'skele2', + 'test_boelter4_9': 'skele2', 'test_boelter4_10': 'skele2', 'test_boelter4_11': 'skele2', + 'test_boelter4_12': 'skele2', 'test_boelter4_13': 'skele2', +} + +event_seg_tracker = { + 'test_9434_18': [[0, 749, 0], [750, 824, 0], [825, 863, 2], [864, 974, 0], [975, 1041, 0]], + 'test_94342_1': [[0, 13, 0], [14, 104, 0], [105, 333, 0], [334, 451, 0], [452, 652, 0], [653, 897, 0], [898, 1076, 0], [1077, 1181, 0], [1181, 1266, 0],[1267, 1386, 0]], + 'test_94342_6': [[0, 95, 0], [96, 267, 1], [268, 441, 1], [442, 559, 1], [560, 681, 1], [682, 796, 1], [797, 835, 1], [836, 901, 0], [902, 943, 1]], + 'test_94342_10': [[0, 36, 0], [37, 169, 0], [170, 244, 1], [245, 424, 0], [425, 599, 0], [600, 640, 0], [641, 680, 0], [681, 726, 1], [727, 866, 2], [867, 1155, 2]], + 'test_94342_21': [[0, 13, 0], [14, 66, 2], [67, 594, 2], [595, 1097, 2], [1098, 1133, 0]], + 'test1': [[0, 477, 0], [478, 559, 0], [560, 689, 2], [690, 698, 0]], + 'test6': [[0, 140, 0], [141, 375, 0], [376, 678, 0], [679, 703, 0]], + 'test7': [[0, 100, 0], [101, 220, 2], [221, 226, 0]], + 'test_boelter_2': [[0, 154, 0], [155, 279, 0], [280, 371, 0], [372, 450, 0], [451, 470, 0], [471, 531, 0],[532, 606, 0]], + 'test_boelter_7': [[0, 69, 0], [70, 118, 1], [119, 239, 0], [240, 328, 1], [329, 376, 0], [377, 397, 1], [398, 520, 0], [521, 564, 0], [565, 619, 1], [620, 688, 1], [689, 871, 0], [872, 897, 0], [898, 958, 1], [959, 1010, 0], [1011, 1084, 0], [1085, 1140, 0], [1141, 1178, 0], [1179, 1267, 1], [1268, 1317, 0], [1318, 1327, 0]], + 'test_boelter_24': [[0, 62, 0], [63, 185, 2], [186, 233, 2], [234, 292, 2], [293, 314, 0]], + 'test_boelter_12': [[0, 47, 1], [48, 119, 0], [120, 157, 1], [158, 231, 0], [232, 317, 0], [318, 423, 0], [424,459,0], [460, 522, 0], [523, 586, 0], [587, 636, 0], [637, 745, 1], [746, 971, 2]], + 'test_9434_1': [[0, 57, 0], [58, 124, 0], [125, 182, 1], [183, 251, 2],[252, 417, 0]], + 'test_94342_16': [[0, 21, 0], [22, 45, 0], [46, 84, 0], [85, 158, 1], [159, 200, 1], [201, 214, 0],[215, 370, 1], [371, 524, 1], [525, 587, 2], [588, 782, 2],[783, 1009, 2]], + 'test_boelter4_12': [[0, 141, 0], [142, 462, 2], [463, 605, 0], [606, 942, 2], [943, 1232, 2], [1233, 1293, 0]], + 'test_boelter4_9': [[0, 27, 0], [28, 172, 0], [173, 221, 0], [222, 307, 1], [308, 466, 0], [467, 794, 1], [795, 866, 1], [867, 1005, 2], [1006, 1214, 2], [1215, 1270, 0]], + 'test_boelter4_4': [[0, 120, 0], [121, 183, 0], [184, 280, 1], [281, 714, 0]], + 'test_boelter4_3': [[0, 117, 0], [118, 200, 1], [201, 293, 1], [294, 404, 1], [405, 600, 1], [601, 800, 1], [801, 905, 1],[906, 1234, 1]], + 'test_boelter4_1': [[0, 310, 0], [311, 560, 0], [561, 680, 0], [681, 748, 0], [749, 839, 0], [840, 1129, 0], [1130, 1237, 0]], + 'test_boelter3_13': [[0, 204, 2], [205, 300, 2], [301, 488, 2], [489, 755, 2]], + 'test_boelter3_11': [[0, 254, 1], [255, 424, 0], [425, 598, 1], [599, 692, 0], [693, 772, 2], [773, 878, 2], [879, 960, 2], [961, 1171, 2],[1172, 1397, 2]], + 'test_boelter3_6': [[0, 174, 1], [175, 280, 1], [281, 639, 0], [640, 695, 1], [696, 788, 0], [789, 887, 2], [888, 1035, 1], [1036, 1445, 2]], + 
'test_boelter3_4': [[0, 158, 1], [159, 309, 1], [310, 477, 1], [478, 668, 1], [669, 780, 1], [781, 817, 0], [818, 848, 1], [849, 942, 1]], + 'test_boelter3_0': [[0, 140, 0], [141, 353, 0], [354, 599, 0], [600, 727, 0],[728, 768, 0]], + 'test_boelter2_15': [[0, 46, 0], [47, 252, 2], [253, 298, 1], [299, 414, 2], [415, 547, 2], [548, 690, 1], [691, 728, 1], [729, 773, 2],[774, 935, 2]], + 'test_boelter2_12': [[0, 163, 0], [164, 285, 1], [286, 444, 1], [445, 519, 0], [520, 583, 1], [584, 623, 0], [624, 660, 0], [661, 854, 1], [855, 921, 1], [922, 1006, 2], [1007, 1125, 2],[1126, 1332, 2], [1333, 1416, 2]], + 'test_boelter2_5': [[0, 94, 0], [95, 176, 1], [177, 246, 1], [247, 340, 1], [341, 442, 1], [443, 547, 1], [548, 654, 1], [655, 734, 0], [735, 792, 0], [793, 1019, 0], [1020, 1088, 0], [1089, 1206, 0], [1207, 1316, 1], [1317, 1466, 1], [1467, 1787, 2], [1788, 1936, 1], [1937, 2084, 2]], + 'test_boelter2_4': [[0, 260, 1], [261, 421, 1], [422, 635, 1], [636, 741, 1], [742, 846, 1], [847, 903, 1], [904, 953, 1], [954, 1005, 1], [1006, 1148, 1], [1149, 1270, 1], [1271, 1525, 1]], + 'test_boelter2_2': [[0, 131, 0], [132, 226, 0], [227, 267, 0], [268, 352, 0], [353, 412, 0], [413, 457, 0], [458, 502, 0], [503, 532, 0], [533, 578, 0], [579, 640, 0], [641, 722, 0], [723, 826, 0], [827, 913, 0], [914, 992, 0], [993, 1070, 0], [1071, 1265, 0], [1266, 1412, 0]], + 'test_boelter_21': [[0, 238, 1], [239, 310, 0], [311, 373, 1], [374, 457, 0],[458, 546, 2], [547, 575, 1], [576, 748, 2], [749, 952, 2]], +} + +event_seg_battery={ + 'test1': [[0, 94, 0], [95, 155, 0], [156, 225, 0], [226, 559, 0], [560, 689, 2], [690, 698, 0]], + 'test7': [[0, 70, 0], [71, 100, 0], [101, 220, 2], [221, 226, 0]], + 'test6': [[0, 488, 0], [489, 541, 0], [542, 672, 0], [673, 703, 0]], + 'test_94342_10': [[0, 156, 0], [157, 169, 0], [170, 244, 1], [245, 274, 0], [275, 389, 0], [390, 525, 0], [526, 665, 0], [666, 680, 0], [681, 726, 1], [727, 866, 2], [867, 1155, 2]], + 'test_94342_1': [[0, 751, 0], [752, 876, 0], [877, 1167, 0], [1168, 1386, 0]], + 'test_9434_18': [[0, 96, 0], [97, 361, 0], [362, 528, 0], [529, 608, 0], [609, 824, 0], [825, 863, 2], [864, 1041, 0]], + 'test_94342_6': [[0, 95, 0], [96, 267, 1], [268, 441, 1], [442, 559, 1], [560, 681, 1], [682, 796, 1], [797, 835, 1], [836, 901, 0], [902, 943, 1]], + 'test_boelter_24': [[0, 62, 0], [63, 185, 2], [186, 233, 2], [234, 292, 2], [293, 314, 0]], + 'test_boelter2_4': [[0, 260, 1], [261, 421, 1], [422, 635, 1], [636, 741, 1], [742, 846, 1], [847, 903, 1], [904, 953, 1], [954, 1005, 1], [1006, 1148, 1], [1149, 1270, 1], [1271, 1525, 1]], + 'test_boelter2_5': [[0, 94, 0], [95, 176, 1], [177, 246, 1], [247, 340, 1], [341, 442, 1], [443, 547, 1], [548, 654, 1], [655, 1206, 0], [1207, 1316, 1], [1317, 1466, 1], [1467, 1787, 2], [1788, 1936, 1], [1937, 2084, 2]], + 'test_boelter2_2': [[0, 145, 0], [146, 224, 0], [225, 271, 0], [272, 392, 0], [393, 454, 0], [455, 762, 0], [763, 982, 0], [983, 1412, 0]], + 'test_boelter_21': [[0, 238, 1], [239, 285, 0], [286, 310, 0], [311, 373, 1], [374, 457, 0], [458, 546, 2], [547, 575, 1], [576, 748, 2], [749, 952, 2]], + 'test_9434_1': [[0, 67, 0], [68, 124, 0], [125, 182, 1], [183, 251, 2], [252, 343, 0], [344, 380, 0], [381, 417, 0]], + 'test_boelter3_6': [[0, 174, 1], [175, 280, 1], [281, 498, 0], [499, 639, 0], [640, 695, 1], [696, 748, 0], [749, 788, 0], [789, 887, 2], [888, 1035, 1], [1036, 1445, 2]], + 'test_boelter3_4': [[0, 158, 1], [159, 309, 1], [310, 477, 1], [478, 668, 1], [669, 780, 1], [781, 817, 0], [818, 848, 1], [849, 
942, 1]], + 'test_boelter3_0': [[0, 102, 0], [103, 480, 0], [481, 703, 0], [704, 768, 0]], + 'test_boelter2_12': [[0, 163, 0], [164, 285, 1], [286, 444, 1], [445, 519, 0], [520, 583, 1], [584, 660, 0], [661, 854, 1], [855, 921, 1], [922, 1006, 2], [1007, 1125, 2], [1126, 1332, 2], [1333, 1416, 2]], + 'test_94342_16': [[0, 84, 0], [85, 158, 1], [159, 200, 1], [201, 214, 0], [215, 370, 1], [371, 524, 1], [525, 587, 2], [588, 782, 2], [783, 1009, 2]], + 'test_boelter2_15': [[0, 46, 0], [47, 252, 2], [253, 298, 1], [299, 414, 2], [415, 547, 2], [548, 690, 1], [691, 728, 1], [729, 773, 2], [774, 935, 2]], + 'test_boelter3_13': [[0, 204, 2], [205, 300, 2], [301, 488, 2], [489, 755, 2]], + 'test_boelter3_11': [[0, 254, 1], [255, 424, 0], [425, 598, 1], [599, 692, 0], [693, 772, 2], [773, 878, 2], [879, 960, 2], [961, 1171, 2], [1172, 1397, 2]], + 'test_boelter4_12': [[0, 32, 0], [33, 141, 0], [142, 462, 2], [463, 519, 0], [520, 597, 0], [598, 605, 0], [606, 942, 2], [943, 1232, 2], [1233, 1293, 0]], + 'test_boelter4_9': [[0, 221, 0], [222, 307, 1], [308, 466, 0], [467, 794, 1], [795, 866, 1], [867, 1005, 2], [1006, 1214, 2], [1215, 1270, 0]], + 'test_boelter4_4': [[0, 183, 0], [184, 280, 1], [281, 529, 0], [530, 714, 0]], + 'test_boelter4_1': [[0, 252, 0], [253, 729, 0], [730, 1202, 0], [1203, 1237, 0]], + 'test_boelter4_3': [[0, 117, 0], [118, 200, 1], [201, 293, 1], [294, 404, 1], [405, 600, 1], [601, 800, 1], [801, 905, 1], [906, 1234, 1]], + 'test_boelter_12': [[0, 47, 1], [48, 119, 0], [120, 157, 1], [158, 636, 0], [637, 745, 1], [746, 971, 2]], + 'test_boelter_7': [[0, 69, 0], [70, 118, 1], [119, 133, 0], [134, 187, 0], [188, 239, 0], [240, 328, 1], [329, 376, 0], [377, 397, 1], [398, 491, 0], [492, 564, 0], [565, 619, 1], [620, 688, 1], [689, 774, 0], [775, 862, 0], [863, 897, 0], [898, 958, 1], [959, 1000, 0], [1001, 1178, 0], [1179, 1267, 1], [1268, 1307, 0], [1308, 1327, 0]], + 'test_94342_21': [[0, 13, 0], [14, 66, 2], [67, 594, 2], [595, 1097, 2], [1098, 1133, 0]], + 'test_boelter_2': [[0, 318, 0], [319, 458, 0], [459, 543, 0], [544, 606, 0]] +} + +CLIPS_OBJ_BY_ID_88 = {'test1': ['P2', 'O1', 'O2', 'O3', 'P1', 'O4'], 'test2': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test6': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test7': ['P1', 'P2', 'O1'], 'test_94342_0': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13'], 'test_94342_1': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_94342_10': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_94342_11': ['P1', 'P2', 'O2', 'O1', 'O3', 'O4'], 'test_94342_12': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_94342_13': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_94342_14': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O14', 'O15'], 'test_94342_15': ['P1', 'P2', 'O1'], 'test_94342_16': ['P1', 'P2', 'O4', 'O1'], 'test_94342_17': ['P1', 'P2', 'O1', 'O2', 'O4', 'O5', 'O6'], 'test_94342_18': ['P1', 'P2', 'O1', 'O2'], 'test_94342_19': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O8', 'O9'], 'test_94342_2': ['P2', 'P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_94342_20': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_94342_21': ['P2', 'P1', 'O1', 'O2'], 'test_94342_22': ['P1', 'P2', 'O1', 'O2'], 'test_94342_23': ['P2', 'P1', 'O1', 'O2'], 'test_94342_24': ['P2', 'P1', 'O1', 'O2'], 'test_94342_25': ['P2', 'P1', 'O1'], 'test_94342_26': 
['P2', 'P1', 'O1', 'O2'], 'test_94342_3': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13'], 'test_94342_4': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12'], 'test_94342_5': ['P2', 'P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_94342_6': ['P2', 'P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_94342_7': ['P2', 'P1', 'O1', 'O2', 'O3'], 'test_94342_8': ['P2', 'P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_9434_1': ['P1', 'P2', 'O1', 'O2'], 'test_9434_18': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_9434_3': ['P1', 'P2', 'O1', 'O2'], 'test_boelter2_0': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter2_12': ['P1', 'P2', 'O1', 'O2', 'O4', 'O5'], 'test_boelter2_14': ['P1', 'P2', 'O1', 'O2', 'O3'], 'test_boelter2_15': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter2_16': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9'], 'test_boelter2_17': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter2_3': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter2_6': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter2_7': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter2_8': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter3_0': ['P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'P2'], 'test_boelter3_1': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_boelter3_10': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter3_11': ['P1', 'P2', 'O1', 'O2', 'O3'], 'test_boelter3_12': ['P1', 'P2', 'O1', 'O2'], 'test_boelter3_13': ['P1', 'P2', 'O1', 'O2', 'O3'], 'test_boelter3_2': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12'], 'test_boelter3_3': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter3_4': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter3_5': ['P1', 'P2', 'O1', 'O2'], 'test_boelter3_6': ['P1', 'P2', 'O1', 'O2', 'O3'], 'test_boelter3_7': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter3_8': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter3_9': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter4_0': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter4_1': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_boelter4_10': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_11': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter4_12': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter4_13': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter4_2': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter4_3': ['P1', 'P2', 'O1', 'O5', 'O2', 'O3', 'O4', 'O6'], 'test_boelter4_4': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter4_5': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter4_6': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_7': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter4_8': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_9': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_1': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O7', 'O6', 'O8'], 'test_boelter_10': ['P1', 'O1', 'P2', 'O2'], 'test_boelter_12': ['P2', 'P1', 'O1', 'O2', 'O3'], 'test_boelter_13': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4'], 'test_boelter_14': 
['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_15': ['P1', 'P2', 'O1', 'O2', 'O3'], 'test_boelter_17': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_18': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter_19': ['P1', 'P2', 'O1', 'O2'], 'test_boelter_2': ['P1', 'P2', 'O1', 'O2', 'O3', 'O5', 'O6', 'O7', 'O8'], 'test_boelter_21': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_3': ['P1', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'P2', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_boelter_4': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter_5': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter_6': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter_7': ['P1', 'P2', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_boelter_9': ['P1', 'P2', 'O1', 'O2']} +CLIPS_OBJ_BY_ID_88_NO_P = {'test1': ['O1', 'O2', 'O3', 'O4'], 'test2': ['O1', 'O2', 'O3', 'O4'], 'test6': ['O1', 'O2', 'O3', 'O4'], 'test7': ['O1'], 'test_94342_0': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13'], 'test_94342_1': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_94342_10': ['O1', 'O2', 'O3', 'O4'], 'test_94342_11': ['O2', 'O1', 'O3', 'O4'], 'test_94342_12': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_94342_13': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_94342_14': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O14', 'O15'], 'test_94342_15': ['O1'], 'test_94342_16': ['O4', 'O1'], 'test_94342_17': ['O1', 'O2', 'O4', 'O5', 'O6'], 'test_94342_18': ['O1', 'O2'], 'test_94342_19': ['O1', 'O2', 'O3', 'O4', 'O5', 'O8', 'O9'], 'test_94342_2': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_94342_20': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_94342_21': ['O1', 'O2'], 'test_94342_22': ['O1', 'O2'], 'test_94342_23': ['O1', 'O2'], 'test_94342_24': ['O1', 'O2'], 'test_94342_25': ['O1'], 'test_94342_26': ['O1', 'O2'], 'test_94342_3': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13'], 'test_94342_4': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12'], 'test_94342_5': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_94342_6': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_94342_7': ['O1', 'O2', 'O3'], 'test_94342_8': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_9434_1': ['O1', 'O2'], 'test_9434_18': ['O1', 'O2', 'O3', 'O4'], 'test_9434_3': ['O1', 'O2'], 'test_boelter2_0': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter2_12': ['O1', 'O2', 'O4', 'O5'], 'test_boelter2_14': ['O1', 'O2', 'O3'], 'test_boelter2_15': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter2_16': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9'], 'test_boelter2_17': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter2_3': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter2_6': ['O1', 'O2', 'O3', 'O4'], 'test_boelter2_7': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter2_8': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter3_0': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter3_1': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_boelter3_10': ['O1', 'O2', 'O3', 'O4'], 'test_boelter3_11': ['O1', 'O2', 'O3'], 
'test_boelter3_12': ['O1', 'O2'], 'test_boelter3_13': ['O1', 'O2', 'O3'], 'test_boelter3_2': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12'], 'test_boelter3_3': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter3_4': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter3_5': ['O1', 'O2'], 'test_boelter3_6': ['O1', 'O2', 'O3'], 'test_boelter3_7': ['O1', 'O2', 'O3', 'O4'], 'test_boelter3_8': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter3_9': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter4_0': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10'], 'test_boelter4_1': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8'], 'test_boelter4_10': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_11': ['O1', 'O2', 'O3', 'O4'], 'test_boelter4_12': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter4_13': ['O1', 'O2', 'O3', 'O4'], 'test_boelter4_2': ['O1', 'O2', 'O3', 'O4'], 'test_boelter4_3': ['O1', 'O5', 'O2', 'O3', 'O4', 'O6'], 'test_boelter4_4': ['O1', 'O2', 'O3', 'O4'], 'test_boelter4_5': ['O1', 'O2', 'O3', 'O4'], 'test_boelter4_6': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_7': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter4_8': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter4_9': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_1': ['O1', 'O2', 'O3', 'O4', 'O5', 'O7', 'O6', 'O8'], 'test_boelter_10': ['O1', 'O2'], 'test_boelter_12': ['O1', 'O2', 'O3'], 'test_boelter_13': ['O1', 'O2', 'O3', 'O4'], 'test_boelter_14': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_15': ['O1', 'O2', 'O3'], 'test_boelter_17': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_18': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6'], 'test_boelter_19': ['O1', 'O2'], 'test_boelter_2': ['O1', 'O2', 'O3', 'O5', 'O6', 'O7', 'O8'], 'test_boelter_21': ['O1', 'O2', 'O3', 'O4', 'O5'], 'test_boelter_3': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_boelter_4': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7'], 'test_boelter_5': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter_6': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11'], 'test_boelter_7': ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14'], 'test_boelter_9': ['O1', 'O2']} +CLIPS_OBJ_NUM_BY_ID_88 = {'test1': 6, 'test2': 6, 'test6': 6, 'test7': 3, 'test_94342_0': 15, 'test_94342_1': 16, 'test_94342_10': 6, 'test_94342_11': 6, 'test_94342_12': 10, 'test_94342_13': 12, 'test_94342_14': 15, 'test_94342_15': 3, 'test_94342_16': 4, 'test_94342_17': 7, 'test_94342_18': 4, 'test_94342_19': 9, 'test_94342_2': 13, 'test_94342_20': 7, 'test_94342_21': 4, 'test_94342_22': 4, 'test_94342_23': 4, 'test_94342_24': 4, 'test_94342_25': 3, 'test_94342_26': 4, 'test_94342_3': 15, 'test_94342_4': 14, 'test_94342_5': 13, 'test_94342_6': 9, 'test_94342_7': 5, 'test_94342_8': 8, 'test_9434_1': 4, 'test_9434_18': 6, 'test_9434_3': 4, 'test_boelter2_0': 12, 'test_boelter2_12': 6, 'test_boelter2_14': 5, 'test_boelter2_15': 7, 'test_boelter2_16': 11, 'test_boelter2_17': 8, 'test_boelter2_3': 12, 'test_boelter2_6': 6, 'test_boelter2_7': 9, 'test_boelter2_8': 13, 'test_boelter3_0': 9, 'test_boelter3_1': 10, 'test_boelter3_10': 6, 'test_boelter3_11': 5, 'test_boelter3_12': 4, 'test_boelter3_13': 5, 'test_boelter3_2': 14, 'test_boelter3_3': 7, 'test_boelter3_4': 9, 'test_boelter3_5': 4, 'test_boelter3_6': 5, 'test_boelter3_7': 6, 'test_boelter3_8': 10, 'test_boelter3_9': 12, 
'test_boelter4_0': 12, 'test_boelter4_1': 10, 'test_boelter4_10': 9, 'test_boelter4_11': 6, 'test_boelter4_12': 7, 'test_boelter4_13': 6, 'test_boelter4_2': 6, 'test_boelter4_3': 8, 'test_boelter4_4': 6, 'test_boelter4_5': 6, 'test_boelter4_6': 9, 'test_boelter4_7': 8, 'test_boelter4_8': 9, 'test_boelter4_9': 7, 'test_boelter_1': 10, 'test_boelter_10': 4, 'test_boelter_12': 5, 'test_boelter_13': 6, 'test_boelter_14': 7, 'test_boelter_15': 5, 'test_boelter_17': 7, 'test_boelter_18': 8, 'test_boelter_19': 4, 'test_boelter_2': 9, 'test_boelter_21': 7, 'test_boelter_3': 16, 'test_boelter_4': 9, 'test_boelter_5': 13, 'test_boelter_6': 13, 'test_boelter_7': 16, 'test_boelter_9': 4} + +CLIPS_LEN_BY_ID_88 = {'test_94342_13': 1455, 'test_boelter4_11': 1355, 'test_94342_20': 1865, 'test_94342_0': 1940, 'test_94342_23': 539, 'test_boelter4_5': 1166, 'test_boelter_12': 972, 'test_9434_3': 323, 'test_boelter_15': 1055, 'test_94342_19': 1695, 'test_boelter_21': 953, 'test_boelter3_2': 1326, 'test_boelter4_0': 1322, 'test_boelter_18': 1386, 'test6': 704, 'test_boelter_1': 925, 'test_boelter3_6': 1446, 'test_94342_21': 1134, 'test_boelter4_10': 1263, 'test_9434_1': 418, 'test_94342_17': 1057, 'test_boelter4_9': 1271, 'test_94342_18': 1539, 'test_boelter4_12': 1294, 'test_boelter3_11': 1398, 'test_boelter4_1': 1238, 'test_94342_26': 527, 'test_boelter_10': 654, 'test_boelter4_8': 1006, 'test_boelter3_8': 1161, 'test2': 975, 'test_94342_7': 1386, 'test_94342_16': 1010, 'test_boelter2_17': 1268, 'test_boelter_4': 787, 'test_boelter3_3': 861, 'test_94342_1': 1387, 'test_boelter_13': 1004, 'test_boelter3_1': 1351, 'test_boelter2_8': 1347, 'test_boelter2_14': 920, 'test_boelter2_0': 1143, 'test7': 227, 'test_94342_3': 1776, 'test_boelter2_12': 1417, 'test_94342_8': 1795, 'test_boelter4_7': 1401, 'test_9434_18': 1042, 'test_94342_22': 586, 'test_94342_5': 2292, 'test_boelter3_9': 1383, 'test1': 699, 'test_boelter_6': 1435, 'test_boelter_19': 959, 'test_boelter4_13': 933, 'test_94342_10': 1156, 'test_boelter4_4': 715, 'test_boelter3_4': 943, 'test_boelter2_3': 942, 'test_boelter_5': 834, 'test_94342_12': 2417, 'test_boelter_14': 904, 'test_boelter3_0': 769, 'test_94342_6': 944, 'test_94342_15': 1174, 'test_94342_24': 741, 'test_boelter_2': 607, 'test_boelter_7': 1328, 'test_boelter_3': 596, 'test_94342_4': 1924, 'test_boelter4_2': 1353, 'test_boelter3_13': 756, 'test_94342_25': 568, 'test_boelter2_16': 1734, 'test_boelter3_5': 851, 'test_boelter4_3': 1235, 'test_boelter4_6': 1334, 'test_boelter3_10': 1301, 'test_boelter2_7': 1505, 'test_94342_14': 1841, 'test_boelter3_7': 1544, 'test_boelter2_15': 936, 'test_boelter_9': 636, 'test_boelter2_6': 2100, 'test_boelter3_12': 359, 'test_boelter_17': 817, 'test_94342_11': 1610, 'test_94342_2': 1968} +CLIPS_IDS_88 = ['test_94342_13', 'test_boelter4_11', 'test_94342_20', 'test_94342_0', 'test_94342_23', 'test_boelter4_5', 'test_boelter_12', 'test_9434_3', 'test_boelter_15', 'test_94342_19', 'test_boelter_21', 'test_boelter3_2', 'test_boelter4_0', 'test_boelter_18', 'test6', 'test_boelter_1', 'test_boelter3_6', 'test_94342_21', 'test_boelter4_10', 'test_9434_1', 'test_94342_17', 'test_boelter4_9', 'test_94342_18', 'test_boelter4_12', 'test_boelter3_11', 'test_boelter4_1', 'test_94342_26', 'test_boelter_10', 'test_boelter4_8', 'test_boelter3_8', 'test2', 'test_94342_7', 'test_94342_16', 'test_boelter2_17', 'test_boelter_4', 'test_boelter3_3', 'test_94342_1', 'test_boelter_13', 'test_boelter3_1', 'test_boelter2_8', 'test_boelter2_14', 'test_boelter2_0', 'test7', 
'test_94342_3', 'test_boelter2_12', 'test_94342_8', 'test_boelter4_7', 'test_9434_18', 'test_94342_22', 'test_94342_5', 'test_boelter3_9', 'test1', 'test_boelter_6', 'test_boelter_19', 'test_boelter4_13', 'test_94342_10', 'test_boelter4_4', 'test_boelter3_4', 'test_boelter2_3', 'test_boelter_5', 'test_94342_12', 'test_boelter_14', 'test_boelter3_0', 'test_94342_6', 'test_94342_15', 'test_94342_24', 'test_boelter_2', 'test_boelter_7', 'test_boelter_3', 'test_94342_4', 'test_boelter4_2', 'test_boelter3_13', 'test_94342_25', 'test_boelter2_16', 'test_boelter3_5', 'test_boelter4_3', 'test_boelter4_6', 'test_boelter3_10', 'test_boelter2_7', 'test_94342_14', 'test_boelter3_7', 'test_boelter2_15', 'test_boelter_9', 'test_boelter2_6', 'test_boelter3_12', 'test_boelter_17', 'test_94342_11', 'test_94342_2'] +CLIPS_LEN_88 = [1455, 1355, 1865, 1940, 539, 1166, 972, 323, 1055, 1695, 953, 1326, 1322, 1386, 704, 925, 1446, 1134, 1263, 418, 1057, 1271, 1539, 1294, 1398, 1238, 527, 654, 1006, 1161, 975, 1386, 1010, 1268, 787, 861, 1387, 1004, 1351, 1347, 920, 1143, 227, 1776, 1417, 1795, 1401, 1042, 586, 2292, 1383, 699, 1435, 959, 933, 1156, 715, 943, 942, 834, 2417, 904, 769, 944, 1174, 741, 607, 1328, 596, 1924, 1353, 756, 568, 1734, 851, 1235, 1334, 1301, 1505, 1841, 1544, 936, 636, 2100, 359, 817, 1610, 1968] + +CLIPS_WITH_GT_EVENT =['test_boelter2_15', 'test_94342_16', 'test_boelter4_4', 'test_94342_21', 'test_boelter4_1', 'test_boelter4_9', 'test_94342_1', 'test_boelter3_4', 'test_boelter_2', 'test_boelter_21', 'test_boelter4_12', 'test_boelter_7', 'test7', 'test_9434_18', 'test_94342_10', 'test_boelter3_13', 'test_94342_6', 'test1', 'test_boelter_12', 'test_boelter3_0', 'test6', 'test_9434_1', 'test_boelter2_12', 'test_boelter3_6', 'test_boelter4_3', 'test_boelter3_11'] + +CLIPS_LEN_BY_ID = {'test_94342_13': 1455, 'test_boelter4_11': 1355, 'test_94342_20': 1865, 'test_94342_0': 1940, 'test_94342_23': 539, 'test_boelter4_5': 1166, 'test_boelter_12': 972, 'test_9434_3': 323, 'test_boelter_15': 1055, 'test_94342_19': 1695, 'test_boelter_21': 953, 'test_boelter3_2': 1326, 'test_boelter4_0': 1322, 'test_boelter_18': 1386, 'test6': 704, 'test_boelter_1': 925, 'test_boelter3_6': 1446, 'test_94342_21': 1134, 'test_boelter4_10': 1263, 'test_9434_1': 418, 'test_94342_17': 1057, 'test_boelter4_9': 1271, 'test_94342_18': 1539, 'test_boelter4_12': 1294, 'test_boelter3_11': 1398, 'test_boelter4_1': 1238, 'test_94342_26': 527, 'test_boelter_10': 654, 'test_boelter4_8': 1006, 'test_boelter3_8': 1161, 'test2': 975, 'test_94342_7': 1386, 'test_94342_16': 1010, 'test_boelter2_17': 1268, 'test_boelter_4': 787, 'test_boelter3_3': 861, 'test_94342_1': 1387, 'test_boelter_13': 1004, 'test_boelter_24': 315, 'test_boelter3_1': 1351, 'test_boelter2_8': 1347, 'test_boelter2_2': 1413, 'test_boelter2_14': 920, 'test_boelter2_0': 1143, 'test7': 227, 'test_94342_3': 1776, 'test_boelter2_12': 1417, 'test_94342_8': 1795, 'test_boelter4_7': 1401, 'test_9434_18': 1042, 'test_94342_22': 586, 'test_94342_5': 2292, 'test_boelter3_9': 1383, 'test1': 699, 'test_boelter_6': 1435, 'test_boelter_19': 959, 'test_boelter4_13': 933, 'test_94342_10': 1156, 'test_boelter4_4': 715, 'test_boelter3_4': 943, 'test_boelter2_3': 942, 'test_boelter_5': 834, 'test_94342_12': 2417, 'test_boelter_14': 904, 'test_boelter3_0': 769, 'test_94342_6': 944, 'test_94342_15': 1174, 'test_94342_24': 741, 'test_boelter_2': 607, 'test_boelter2_5': 2085, 'test_boelter_7': 1328, 'test_boelter_3': 596, 'test_94342_4': 1924, 'test_boelter4_2': 1353, 'test_boelter3_13': 
756, 'test_94342_25': 568, 'test_boelter2_16': 1734, 'test_boelter3_5': 851, 'test_boelter4_3': 1235, 'test_boelter4_6': 1334, 'test_boelter3_10': 1301, 'test_boelter2_7': 1505, 'test_94342_14': 1841, 'test_boelter_22': 828, 'test_boelter3_7': 1544, 'test_boelter2_15': 936, 'test_boelter_9': 636, 'test_boelter_25': 951, 'test_boelter2_6': 2100, 'test_boelter2_4': 1526, 'test_boelter3_12': 359, 'test_boelter_17': 817, 'test_94342_11': 1610, 'test_94342_2': 1968} +CLIPS_LEN = [1455, 1355, 1865, 1940, 539, 1166, 972, 323, 1055, 1695, 953, 1326, 1322, 1386, 704, 925, 1446, 1134, 1263, 418, 1057, 1271, 1539, 1294, 1398, 1238, 527, 654, 1006, 1161, 975, 1386, 1010, 1268, 787, 861, 1387, 1004, 315, 1351, 1347, 1413, 920, 1143, 227, 1776, 1417, 1795, 1401, 1042, 586, 2292, 1383, 699, 1435, 959, 933, 1156, 715, 943, 942, 834, 2417, 904, 769, 944, 1174, 741, 607, 2085, 1328, 596, 1924, 1353, 756, 568, 1734, 851, 1235, 1334, 1301, 1505, 1841, 828, 1544, 936, 636, 951, 2100, 1526, 359, 817, 1610, 1968] +CLIPS_IDS = ['test_94342_13', 'test_boelter4_11', 'test_94342_20', 'test_94342_0', 'test_94342_23', 'test_boelter4_5', 'test_boelter_12', 'test_9434_3', 'test_boelter_15', 'test_94342_19', 'test_boelter_21', 'test_boelter3_2', 'test_boelter4_0', 'test_boelter_18', 'test6', 'test_boelter_1', 'test_boelter3_6', 'test_94342_21', 'test_boelter4_10', 'test_9434_1', 'test_94342_17', 'test_boelter4_9', 'test_94342_18', 'test_boelter4_12', 'test_boelter3_11', 'test_boelter4_1', 'test_94342_26', 'test_boelter_10', 'test_boelter4_8', 'test_boelter3_8', 'test2', 'test_94342_7', 'test_94342_16', 'test_boelter2_17', 'test_boelter_4', 'test_boelter3_3', 'test_94342_1', 'test_boelter_13', 'test_boelter_24', 'test_boelter3_1', 'test_boelter2_8', 'test_boelter2_2', 'test_boelter2_14', 'test_boelter2_0', 'test7', 'test_94342_3', 'test_boelter2_12', 'test_94342_8', 'test_boelter4_7', 'test_9434_18', 'test_94342_22', 'test_94342_5', 'test_boelter3_9', 'test1', 'test_boelter_6', 'test_boelter_19', 'test_boelter4_13', 'test_94342_10', 'test_boelter4_4', 'test_boelter3_4', 'test_boelter2_3', 'test_boelter_5', 'test_94342_12', 'test_boelter_14', 'test_boelter3_0', 'test_94342_6', 'test_94342_15', 'test_94342_24', 'test_boelter_2', 'test_boelter2_5', 'test_boelter_7', 'test_boelter_3', 'test_94342_4', 'test_boelter4_2', 'test_boelter3_13', 'test_94342_25', 'test_boelter2_16', 'test_boelter3_5', 'test_boelter4_3', 'test_boelter4_6', 'test_boelter3_10', 'test_boelter2_7', 'test_94342_14', 'test_boelter_22', 'test_boelter3_7', 'test_boelter2_15', 'test_boelter_9', 'test_boelter_25', 'test_boelter2_6', 'test_boelter2_4', 'test_boelter3_12', 'test_boelter_17', 'test_94342_11', 'test_94342_2'] + +ALL_IDS = ['test_boelter_15', 'test6', 'test_94342_7', 'test_94342_10', 'test_94342_21', 'test_boelter2_8', 'test_boelter4_1', 'test_boelter3_1', 'test_boelter3_11', 'test_boelter3_10', 'test_boelter3_0', 'test_boelter4_0', 'test_94342_20', 'test_94342_11', 'test_94342_6', 'test_boelter_1', 'test_9434_3', 'test7', 'test_boelter_14', 'test_boelter2_12', 'test_boelter2_3', 'test_9434_1', 'test_94342_4', 'test_boelter_3', 'test_boelter3_8', 'test_94342_13', 'test_boelter4_8', 'test_94342_22', 'test_boelter_9', 'test_boelter4_2', 'test_boelter3_2', 'test_94342_19', 'test_boelter3_12', 'test_boelter3_13', 'test_boelter3_3', 'test_94342_18', 'test_boelter4_3', 'test_9434_18', 'test_94342_23', 'test_boelter4_9', 'test_boelter3_9', 'test_94342_12', 'test_94342_5', 'test_boelter_2', 'test_boelter2_0', 'test_boelter_17', 'test_boelter2_15', 
'test_boelter_13', 'test_boelter_6', 'test_94342_1', 'test_94342_16', 'test_boelter4_10', 'test_boelter_19', 'test_boelter4_7', 'test_boelter3_7', 'test_boelter3_6', 'test_boelter4_6', 'test_boelter_18', 'test_boelter4_11', 'test_94342_26', 'test_94342_17', 'test1', 'test_boelter_7', 'test_94342_0', 'test_boelter_12', 'test_boelter2_14', 'test_boelter_10', 'test_boelter2_7', 'test_boelter2_16', 'test_boelter_5', 'test_94342_2', 'test_94342_15', 'test_94342_8', 'test_94342_24', 'test_boelter4_13', 'test_boelter4_4', 'test_boelter_21', 'test_boelter3_4', 'test_boelter3_5', 'test_boelter4_5', 'test_boelter4_12', 'test_94342_25', 'test_94342_14', 'test2', 'test_boelter_4', 'test_94342_3', 'test_boelter2_6', 'test_boelter2_17'] + +UNIQUE_OBJ_IDS = ['O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8', 'O9', 'O10', 'O11', 'O12', 'O13', 'O14', 'O15'] + +mind_test_clips = ['test_boelter4_5.p', 'test_94342_2.p', 'test_boelter4_10.p', 'test_boelter2_3.p', 'test_94342_20.p', 'test_boelter3_9.p', 'test_boelter4_6.p', 'test2.p', 'test_boelter4_2.p', 'test_94342_24.p', 'test_94342_17.p', 'test_94342_8.p', 'test_94342_11.p', 'test_boelter3_7.p', 'test_94342_18.p', 'test_boelter_10.p', 'test_boelter3_8.p', 'test_boelter2_6.p', 'test_boelter4_7.p', 'test_boelter4_8.p', 'test_boelter4_0.p', 'test_boelter2_17.p', 'test_boelter3_12.p', 'test_boelter3_5.p', 'test_94342_4.p', 'test_94342_15.p'] + + +def count_parameters(model): + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + return sum([np.prod(p.size()) for p in model_parameters]) + + +def split_train_val_test(): + # Calculate the total number of frames + total_frames = sum(CLIPS_LEN_BY_ID_88.values()) + + # Calculate the number of frames for each split + train_frames = int(total_frames * 0.6) + validation_frames = int(total_frames * 0.2) + + # Convert the dictionary to a list of tuples (video_id, frame_length) + video_list = list(CLIPS_LEN_BY_ID_88.items()) + + # Shuffle the video list randomly + random.shuffle(video_list) + + # Split the videos based on the number of frames + train_ids = [] + validation_ids = [] + test_ids = [] + frames_count = 0 + + for video_id, frames in video_list: + if frames_count < train_frames: + train_ids.append(video_id) + elif frames_count < train_frames + validation_frames: + validation_ids.append(video_id) + else: + test_ids.append(video_id) + frames_count += frames + + # Print the results + print("Train IDs:", train_ids) + print("Validation IDs:", validation_ids) + print("Test IDs:", test_ids) + + train_frames_total = sum(CLIPS_LEN_BY_ID_88[video_id] for video_id in train_ids) + validation_frames_total = sum(CLIPS_LEN_BY_ID_88[video_id] for video_id in validation_ids) + test_frames_total = sum(CLIPS_LEN_BY_ID_88[video_id] for video_id in test_ids) + + print("Total frames for train_ids:", train_frames_total) + print("Total frames for validation_ids:", validation_frames_total) + print("Total frames for test_ids:", test_frames_total) + + return train_ids, validation_ids, test_ids + + +def compute_f1_scores(m1_pred, m1_label, m2_pred, m2_label, m12_pred, m12_label, m21_pred, m21_label, mc_pred, mc_label, verbose=False): + # Compute F1 score for m1 + m1_pred_labels = torch.argmax(m1_pred, dim=-1).cpu().numpy() + m1_true_labels = m1_label.cpu().numpy() + if verbose: print(f'm1 --- pred {np.unique(m1_pred_labels, return_counts=True)}, true {np.unique(m1_true_labels, return_counts=True)}') + m1_f1_score = f1_score(m1_true_labels, m1_pred_labels, average='macro') + + # Compute F1 score for m2 + m2_pred_labels = 
torch.argmax(m2_pred, dim=-1).cpu().numpy() + m2_true_labels = m2_label.cpu().numpy() + if verbose: print(f'm2 --- pred {np.unique(m2_pred_labels, return_counts=True)}, true {np.unique(m2_true_labels, return_counts=True)}') + m2_f1_score = f1_score(m2_true_labels, m2_pred_labels, average='macro') + + # Compute F1 score for m12 + m12_pred_labels = torch.argmax(m12_pred, dim=-1).cpu().numpy() + m12_true_labels = m12_label.cpu().numpy() + if verbose: print(f'm12 --- pred {np.unique(m12_pred_labels, return_counts=True)}, true {np.unique(m12_true_labels, return_counts=True)}') + m12_f1_score = f1_score(m12_true_labels, m12_pred_labels, average='macro') + + # Compute F1 score for m21 + m21_pred_labels = torch.argmax(m21_pred, dim=-1).cpu().numpy() + m21_true_labels = m21_label.cpu().numpy() + if verbose: print(f'm21 --- pred {np.unique(m21_pred_labels, return_counts=True)}, true {np.unique(m21_true_labels, return_counts=True)}') + m21_f1_score = f1_score(m21_true_labels, m21_pred_labels, average='macro') + + # Compute F1 score for mc + mc_pred_labels = torch.argmax(mc_pred, dim=-1).cpu().numpy() + mc_true_labels = mc_label.cpu().numpy() + if verbose: print(f'mc --- pred {np.unique(mc_pred_labels, return_counts=True)}, true {np.unique(mc_true_labels, return_counts=True)}') + mc_f1_score = f1_score(mc_true_labels, mc_pred_labels, average='macro') + + return m1_f1_score, m2_f1_score, m12_f1_score, m21_f1_score, mc_f1_score \ No newline at end of file diff --git a/tbd/utils/preprocess_img.py b/tbd/utils/preprocess_img.py new file mode 100644 index 0000000..b1680d1 --- /dev/null +++ b/tbd/utils/preprocess_img.py @@ -0,0 +1,37 @@ +import glob + +import cv2 + +import torchvision.transforms as T +import torch +import os +from tqdm import tqdm + + +PATH_IN = "/scratch/bortoletto/data/tbd/images" +PATH_OUT = "/scratch/bortoletto/data/tbd/images_norm" + +normalisation_steps = [ + T.ToTensor(), + T.Resize((128,128)), + T.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) +] + +preprocess_img = T.Compose(normalisation_steps) + +def main(): + print(f"{PATH_IN}/*/*/*.jpg") + all_img = glob.glob(f"{PATH_IN}/*/*/*.jpg") + print(len(all_img)) + for img_path in tqdm(all_img): + new_img = preprocess_img(cv2.imread(img_path)).numpy() + img_path_split = img_path.split("/") + os.makedirs(f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}", exist_ok=True) + out_img = f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}/{img_path_split[-1][:-4]}.pt" + torch.save(new_img, out_img) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tbd/utils/reformat_labels_ours.py b/tbd/utils/reformat_labels_ours.py new file mode 100644 index 0000000..9500a6f --- /dev/null +++ b/tbd/utils/reformat_labels_ours.py @@ -0,0 +1,106 @@ +import pandas as pd +import os +import glob +import pickle + +DATASET_LOCATION = "YOUR_PATH_HERE" + +def reframe_annotation(): + annotation_path = f'{DATASET_LOCATION}/retrieve_annotation/all/' + save_path = f'{DATASET_LOCATION}/reformat_annotation/' + if not os.path.exists(save_path): + os.makedirs(save_path) + tasks = glob.glob(annotation_path + '*.txt') + id_map = pd.read_csv('id_map.csv') + for task in tasks: + if not task.split('/')[-1].split('_')[2] == '1.txt': + continue + with open(task, 'r') as f: + lines = f.readlines() + task_id = int(task.split('/')[-1].split('_')[1]) + 1 + clip = id_map.loc[id_map['ID'] == task_id].folder + print(task_id, len(clip)) + if len(clip) == 0: + continue + with open(save_path + clip.item() + '.txt', 'w') as f: + for 
line in lines: + words = line.split() + f.write(words[0] + ',' + words[1] + ',' + words[2] + ',' + words[3] + ',' + words[4] + ',' + words[5] + + ',' + words[6] + ',' + words[7] + ',' + words[8] + ',' + words[9] + ',' + ' '.join(words[10:]) + '\n') + f.close() + +def get_grid_location(obj_frame): + x_min = obj_frame['x_min']#.item() + y_min = obj_frame['y_min']#.item() + x_max = obj_frame['x_max']#.item() + y_max = obj_frame['y_max']#.item() + gridLW = 1280 / 25. + gridLH = 720 / 15. + center_x, center_y = (x_min + x_max)/2, (y_min + y_max)/2 + X, Y = int(center_x / gridLW), int(center_y / gridLH) + return X, Y + +def regenerate_annotation(): + annotation_path = f'{DATASET_LOCATION}/reformat_annotation/' + save_path=f'{DATASET_LOCATION}/regenerate_annotation/' + if not os.path.exists(save_path): + os.makedirs(save_path) + tasks = glob.glob(annotation_path + '*.txt') + for task in tasks: + print(task) + annt = pd.read_csv(task, sep=",", header=None) + annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name", "label"] + obj_records = {} + for index, obj_frame in annt.iterrows(): + if obj_frame['name'].startswith('P'): + continue + else: + assert obj_frame['name'].startswith('O') + obj_name = obj_frame['name'] + # 0: enter 1: disappear 2: update 3: unchange + frame_id = obj_frame['frame'] + curr_loc = get_grid_location(obj_frame) + mind_dict = {'m1': {'fluent': 3, 'loc': None}, 'm2': {'fluent': 3, 'loc': None}, + 'm12': {'fluent': 3, 'loc': None}, + 'm21': {'fluent': 3, 'loc': None}, 'mc': {'fluent': 3, 'loc': None}, + 'mg': {'fluent': 3, 'loc': curr_loc}} + mind_dict['mg']['loc'] = curr_loc + if not type(obj_frame['label']) == float: + mind_labels = obj_frame['label'].split() + for mind_label in mind_labels: + if mind_label == 'in_m1' or mind_label == 'in_m2' or mind_label == 'in_m12' \ + or mind_label == 'in_m21' or mind_label == 'in_mc' or mind_label == '"in_m1"' or mind_label == '"in_m2"'\ + or mind_label == '"in_m12"' or mind_label == '"in_m21"' or mind_label == '"in_mc"': + mind_name = mind_label.split('_')[1].split('"')[0] + mind_dict[mind_name]['loc'] = curr_loc + else: + mind_name = mind_label.split('_')[0].split('"') + if len(mind_name) > 1: + mind_name = mind_name[1] + else: + mind_name = mind_name[0] + last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc'] + mind_dict[mind_name]['loc'] = last_loc + + for mind_name in mind_dict.keys(): + if frame_id > 0: + curr_loc = mind_dict[mind_name]['loc'] + last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc'] + if last_loc is None and curr_loc is not None: + mind_dict[mind_name]['fluent'] = 0 + elif last_loc is not None and curr_loc is None: + mind_dict[mind_name]['fluent'] = 1 + elif not last_loc == curr_loc: + mind_dict[mind_name]['fluent'] = 2 + if obj_name not in obj_records: + obj_records[obj_name] = [mind_dict] + else: + obj_records[obj_name].append(mind_dict) + + with open(save_path + task.split('/')[-1].split('.')[0] + '.p', 'wb') as f: + pickle.dump(obj_records, f) + + +if __name__ == '__main__': + reframe_annotation() + regenerate_annotation() \ No newline at end of file diff --git a/tbd/utils/similarity.py b/tbd/utils/similarity.py new file mode 100644 index 0000000..5517fcf --- /dev/null +++ b/tbd/utils/similarity.py @@ -0,0 +1,75 @@ +import os +import torch +import numpy as np +import torch.nn.functional as F +import matplotlib.pyplot as plt +from sklearn.decomposition import PCA +import seaborn as sns + + +FOLDER_PATH = 'PATH_TO_FOLDER' + 
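+# FOLDER_PATH is expected to contain one saved embedding file per test video (0.pt, 1.pt, ...).
+# The number of tensors stored in each file identifies the MToMnet variant that produced it:
+# 13 tensors for the implicit model, 12 for the common-mind model, 11 for the speaker-listener model.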
+print(FOLDER_PATH) + +MTOM_COLORS = { + "MN1": (110/255, 117/255, 161/255), + "MN2": (179/255, 106/255, 98/255), + "Base": (193/255, 198/255, 208/255), + "CG": (170/255, 129/255, 42/255), + "IC": (97/255, 112/255, 83/255), + "DB": (144/255, 63/255, 110/255) +} + +COLORS = sns.color_palette() + +sns.set_theme(style='white') + +out_left_main_mods_full_test = [] +out_right_main_mods_full_test = [] +cell_left_main_mods_full_test = [] +cell_right_main_mods_full_test = [] +cm_left_main_mods_full_test = [] +cm_right_main_mods_full_test = [] + +for i in range(len([filename for filename in os.listdir(FOLDER_PATH) if filename.endswith('.pt')])): + + print(f'Computing analysis for test video {i}...', end='\r') + + emb_file = os.path.join(FOLDER_PATH, f'{i}.pt') + data = torch.load(emb_file) + if len(data) == 13: # implicit + model = 'impl' + out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:] + elif len(data) == 12: # common mind + model = 'cm' + out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:] + elif len(data) == 11: # speaker-listener + model = 'sl' + out_left, out_right, feats = data[0], data[1], data[2:] + else: raise ValueError("Data should have 13 (impl), others are not implemented") + + # ====== PCA for left and right embeddings ====== # + + out_left_pca = out_left[0].reshape(-1, 64) + out_right_pca = out_right[0].reshape(-1, 64) + out_left_and_right = np.concatenate((out_left_pca, out_right_pca), axis=0) + + pca = PCA(n_components=2) + pca_result = pca.fit_transform(out_left_and_right) + + # Separate the PCA results for each tensor + pca_result_left = pca_result[:out_left_pca.shape[0]] + pca_result_right = pca_result[out_right_pca.shape[0]:] + + plt.figure(figsize=(6.8,6)) + plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100) + plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100) + plt.xlabel('Principal Component 1', fontsize=32) + plt.ylabel('Principal Component 2', fontsize=32) + plt.xticks(fontsize=24) + plt.xticks([-0.4, -0.2, 0.0, 0.2, 0.4]) + plt.yticks(fontsize=24) + plt.legend(fontsize=32) + plt.tight_layout() + plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf') + plt.close() \ No newline at end of file diff --git a/tbd/utils/store_mind_set.py b/tbd/utils/store_mind_set.py new file mode 100644 index 0000000..c1ba5f0 --- /dev/null +++ b/tbd/utils/store_mind_set.py @@ -0,0 +1,96 @@ +import os +import pandas as pd +import pickle +from tqdm import tqdm + + +def check_append(obj_name, m1, mind_name, obj_frame, flags, label): + if label: + if not obj_name in m1: + m1[obj_name] = [] + m1[obj_name].append( + [obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0]) + flags[mind_name] = 1 + elif not flags[mind_name]: + m1[obj_name].append( + [obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0]) + flags[mind_name] = 1 + else: # false belief + if obj_name in m1: + if flags[mind_name]: + m1[obj_name].append( + [obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 1]) + flags[mind_name] = 0 + return flags, m1 + + +def store_mind_set(clip, annotation_path, save_path): + if not os.path.exists(save_path): + os.makedirs(save_path) + annt = pd.read_csv(annotation_path + clip, sep=",", header=None) + annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name", + "label"] + 
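+ # Each annotation row is one object bounding box in one frame; the free-text "label" column
+ # encodes belief annotations (in_m1, in_m2, in_m12, in_m21, in_mc and the corresponding *_false
+ # events), which are parsed below to build the per-mind object sets and false-belief event files.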
obj_names = annt.name.unique() + m1, m2, m12, m21, mc = {}, {}, {}, {}, {} + flags = {'m1':0, 'm2':0, 'm12':0, 'm21':0, 'mc':0} + for obj_name in obj_names: + if obj_name == 'P1' or obj_name == 'P2': + continue + obj_frames = annt.loc[annt.name == obj_name] + for index, obj_frame in obj_frames.iterrows(): + if type(obj_frame.label) == float: + continue + labels = obj_frame.label.split() + for label in labels: + if label == 'in_m1' or label == '"in_m1"': + flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 1) + elif label == 'in_m2' or label == '"in_m2"': + flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 1) + elif label == 'in_m12'or label == '"in_m12"': + flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 1) + elif label == 'in_m21' or label == '"in_m21"': + flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 1) + elif label == 'in_mc'or label == '"in_mc"': + flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 1) + elif label == 'm1_false' or label == '"m1_false"': + flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 0) + flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0) + flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0) + false_belief = 'm1_false' + with open(save_path + clip.split('.')[0] + '.txt', "a") as file: + file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n") + elif label == 'm2_false' or label == '"m2_false"': + flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 0) + flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0) + flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0) + false_belief = 'm2_false' + with open(save_path + clip.split('.')[0] + '.txt', "a") as file: + file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n") + elif label == 'm12_false' or label == '"m12_false"': + flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0) + flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0) + false_belief = 'm12_false' + with open(save_path + clip.split('.')[0] + '.txt', "a") as file: + file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n") + elif label == 'm21_false' or label == '"m21_false"': + flags, m21 = check_append(obj_name, m2, 'm21', obj_frame, flags, 0) + flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0) + false_belief = 'm21_false' + with open(save_path + clip.split('.')[0] + '.txt', "a") as file: + file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n") + # print('m1', m1) + # print('m2', m2) + # print('m12', m12) + # print('m21', m21) + # print('mc', mc) + #with open(save_path + clip.split('.')[0] + '.p', 'wb') as f: + # pickle.dump([m1, m2, m12, m21, mc], f) + + +if __name__ == "__main__": + + annotation_path = '/scratch/bortoletto/data/tbd/reformat_annotation/' + save_path = '/scratch/bortoletto/data/tbd/store_mind_set/' + + for clip in tqdm(os.listdir(annotation_path), desc="Processing videos", unit="item"): + store_mind_set(clip, annotation_path, save_path) \ No newline at end of file diff --git a/tbd/utils/visualize_bbox.py b/tbd/utils/visualize_bbox.py new file mode 100644 index 0000000..fc33567 --- /dev/null +++ b/tbd/utils/visualize_bbox.py @@ -0,0 +1,95 @@ +import time +from tbd_dataloader import TBDv2Dataset + +import numpy as np + +import matplotlib.pyplot as plt +import matplotlib.patches as patches + +def point2screen(points): + K = [607.13232421875, 0.0, 638.6468505859375, 0.0, 607.1067504882812, 367.1607360839844, 0.0, 
0.0, 1.0] + K = np.reshape(np.array(K), [3, 3]) + rot_points = np.array(points) + np.array([0, 0.2, 0]) + rot_points = rot_points + points_camera = rot_points.reshape(3, 1) + + project_matrix = np.array(K).reshape(3, 3) + points_prj = project_matrix.dot(points_camera) + points_prj = points_prj.transpose() + if not points_prj[:, 2][0] == 0.0: + points_prj[:, 0] = points_prj[:, 0] / points_prj[:, 2] + points_prj[:, 1] = points_prj[:, 1] / points_prj[:, 2] + points_screen = points_prj[:, :2] + assert points_screen.shape == (1, 2) + points_screen = points_screen.reshape(-1) + return points_screen + +if __name__ == '__main__': + data = TBDv2Dataset(number_frames_to_sample=1, resize_img=None) + index = np.random.randint(0, len(data)) + start = time.time() + ( + kinect_imgs, # <- len x 720 x 1280 x 3 + tracker_imgs, + battery_imgs, + skele1, + skele2, + bbox, + tracker_skeID_sample, # <- This is the tracker skeleton ID + tracker2d, + label, + experiment_id, # From here for debugging + timestep, + obj_id, # <- This is the object ID as a string + ) = data[index] + end = time.time() + print(f"Time for one sample: {end-start}") + + img = kinect_imgs[-1] + bbox = bbox[-1] + print(label.shape) + + print(skele1.shape) + print(skele2.shape) + + skele1 = skele1[-1, :,:] + skele2 = skele2[-1, :,:] + + print(skele1.shape) + + + + # reshape img from c, h, w to h, w, c + img = img.permute(1, 2, 0) + + fig, ax = plt.subplots(1) + ax.imshow(img) + print(bbox[0], bbox[1], bbox[2], bbox[3]) # t(top left x, top left y, width, height) + top_left_x, top_left_y, width, height = bbox[0], bbox[1], bbox[2], bbox[3] + x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3] + + + + + for i in range(26): + print(skele1[i,0], skele1[i,1]) + print(skele1[i,:].shape) + print(point2screen(skele1[i,:])) + x, y = point2screen(skele1[i,:])[0], point2screen(skele1[i,:])[1] + ax.text(x, y, f"{i}", fontsize=5, color='w') + + wedge = patches.Wedge((x,y), 10, 0, 360, width=10, color='b') + ax.add_patch(wedge) + + for i in range(26): + x, y = point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1] + ax.text(x, y, f"{i}", fontsize=5, color='w') + wedge = patches.Wedge((point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]), 10, 0, 360, width=10, color='r') + ax.add_patch(wedge) + + # Create a Rectangle patch + # rect = patches.Rectangle((top_left_x, top_left_y-height), width, height, linewidth=1, edgecolor='r', facecolor='none') + # ax.add_patch(rect) + # rect = patches.Rectangle((x_min, y_max), x_max-x_min, y_max-y_min, linewidth=1, edgecolor='g', facecolor='none') + # ax.add_patch(rect) + fig.savefig(f"bbox_{obj_id}_{index}_{experiment_id}.png")
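As a side note, `point2screen` in `tbd/utils/visualize_bbox.py` is a standard pinhole-camera projection using what appear to be the Kinect colour-camera intrinsics hard-coded in `K`. The snippet below is a minimal, self-contained sketch of that projection (it omits the fixed `[0, 0.2, 0]` offset that `point2screen` adds to each point before projecting); it is illustrative only and not part of the repository.

```
import numpy as np

# Intrinsics as hard-coded in visualize_bbox.py (focal lengths and principal point)
K = np.array([
    [607.13232421875, 0.0, 638.6468505859375],
    [0.0, 607.1067504882812, 367.1607360839844],
    [0.0, 0.0, 1.0],
])

def project(point_3d):
    """Project a single 3D point in camera coordinates to pixel coordinates."""
    p = np.asarray(point_3d, dtype=float).reshape(3, 1)
    uvw = K @ p                      # homogeneous image coordinates
    if uvw[2, 0] != 0.0:             # avoid division by zero at the camera plane
        return uvw[:2, 0] / uvw[2, 0]
    return uvw[:2, 0]

# A point one metre straight ahead projects to the principal point (~638.6, ~367.2).
print(project([0.0, 0.0, 1.0]))
```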