up
This commit is contained in:
parent
d4aaf7f4ad
commit
25b8b3f343
55 changed files with 7592 additions and 4 deletions
14
README.md
14
README.md
|
@ -22,11 +22,17 @@ This is a temporary reference that will be updated after the proceedings are pub
|
|||
}
|
||||
```
|
||||
|
||||
<br>
|
||||
<br>
|
||||
<br>
|
||||
# Code
|
||||
|
||||
Under construction
|
||||
This repository has the following structure:
|
||||
|
||||
```
|
||||
mtomnet
|
||||
├── boss
|
||||
└── tbd
|
||||
```
|
||||
|
||||
We have one subfolder for dataset, containing the code to run the corresponding experiments. Inside each subfolder we provide a README with further instructions.
|
||||
|
||||
|
||||
[1]: https://mattbortoletto.github.io/
|
||||
|
|
206
boss/.gitignore
vendored
Normal file
206
boss/.gitignore
vendored
Normal file
|
@ -0,0 +1,206 @@
|
|||
experiments/
|
||||
predictions_*
|
||||
predictions/
|
||||
tmp/
|
||||
|
||||
### WANDB ###
|
||||
|
||||
wandb/*
|
||||
wandb
|
||||
wandb-debug.log
|
||||
logs
|
||||
summary*
|
||||
run*
|
||||
|
||||
### SYSTEM ###
|
||||
|
||||
*/__pycache__/*
|
||||
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/python,linux
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
|
||||
|
||||
### Linux ###
|
||||
*~
|
||||
|
||||
# temporary files which can be created if a process still has a handle open of a deleted file
|
||||
.fuse_hidden*
|
||||
|
||||
# KDE directory preferences
|
||||
.directory
|
||||
|
||||
# Linux trash folder which might appear on any partition or disk
|
||||
.Trash-*
|
||||
|
||||
# .nfs files are created when an open file is removed but is still being accessed
|
||||
.nfs*
|
||||
|
||||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
### Python Patch ###
|
||||
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
||||
poetry.toml
|
||||
|
||||
# ruff
|
||||
.ruff_cache/
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/python,linux
|
16
boss/README.md
Normal file
16
boss/README.md
Normal file
|
@ -0,0 +1,16 @@
|
|||
# BOSS
|
||||
|
||||
## Installing Dependencies
|
||||
Run `conda env create -f environment.yml`.
|
||||
|
||||
## New bounding box annotations
|
||||
We re-extracted bounding box annotations using Yolo-v8. The new data is in `new_bbox/`. The rest of the data can be found [here](https://drive.google.com/drive/folders/1b8FdpyoWx9gUps-BX6qbE9Kea3C2Uyua).
|
||||
|
||||
## Train
|
||||
`source run_train.sh`.
|
||||
|
||||
## Test
|
||||
`source run_test.sh` (specify the path to the model).
|
||||
|
||||
## Resources
|
||||
The original project page for the BOSS dataset can be found [here](https://sites.google.com/view/bossbelief/). Our code is based on the original implementation.
|
500
boss/dataloader.py
Normal file
500
boss/dataloader.py
Normal file
|
@ -0,0 +1,500 @@
|
|||
from torch.utils.data import Dataset
|
||||
import os
|
||||
from torchvision import transforms
|
||||
from natsort import natsorted
|
||||
import numpy as np
|
||||
import pickle
|
||||
import torch
|
||||
import cv2
|
||||
import ast
|
||||
from einops import rearrange
|
||||
from utils import spherical2cartesial
|
||||
|
||||
|
||||
class Data(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
frame_path,
|
||||
label_path,
|
||||
pose_path=None,
|
||||
gaze_path=None,
|
||||
bbox_path=None,
|
||||
ocr_graph_path=None,
|
||||
presaved=128,
|
||||
sizes = (128,128),
|
||||
spatial_patch_size: int = 16,
|
||||
temporal_patch_size: int = 4,
|
||||
img_channels: int = 3,
|
||||
patch_data: bool = False,
|
||||
flatten_dim: int = 1
|
||||
):
|
||||
self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
|
||||
self.size_1, self.size_2 = sizes
|
||||
self.patch_data = patch_data
|
||||
if self.patch_data:
|
||||
self.spatial_patch_size = spatial_patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
|
||||
self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
|
||||
|
||||
assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
if presaved == 224:
|
||||
self.presaved = 'presaved'
|
||||
elif presaved == 128:
|
||||
self.presaved = 'presaved128'
|
||||
else: raise ValueError
|
||||
|
||||
frame_dirs = os.listdir(frame_path)
|
||||
episode_ids = natsorted(frame_dirs)
|
||||
seq_lens = []
|
||||
for frame_dir in natsorted(frame_dirs):
|
||||
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
|
||||
seq_lens.append(len(frame_paths))
|
||||
self.data['frame_paths'].append(frame_paths)
|
||||
|
||||
print('Labels...')
|
||||
with open(label_path, 'rb') as fp:
|
||||
labels = pickle.load(fp)
|
||||
left_labels, right_labels = labels
|
||||
for i in episode_ids:
|
||||
episode_id = int(i)
|
||||
left_labels_i = left_labels[episode_id]
|
||||
right_labels_i = right_labels[episode_id]
|
||||
self.data['labels'].append([left_labels_i, right_labels_i])
|
||||
fp.close()
|
||||
|
||||
print('Pose...')
|
||||
self.pose_path = pose_path
|
||||
if pose_path is not None:
|
||||
for i in episode_ids:
|
||||
pose_i_dir = os.path.join(pose_path, i)
|
||||
poses = []
|
||||
for j in natsorted(os.listdir(pose_i_dir)):
|
||||
fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
|
||||
if torch.tensor(fs.getNode("pose_0").mat()).shape != (2,25,3):
|
||||
poses.append(fs.getNode("pose_0").mat()[:2,:,:])
|
||||
else:
|
||||
poses.append(fs.getNode("pose_0").mat())
|
||||
poses = np.array(poses)
|
||||
poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
|
||||
poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
|
||||
self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
|
||||
#self.data['poses'].append(torch.tensor(np.array(poses)))
|
||||
|
||||
print('Gaze...')
|
||||
"""
|
||||
Gaze information is represented by a 3D gaze vector based on the observing camera’s Cartesian eye coordinate system
|
||||
"""
|
||||
self.gaze_path = gaze_path
|
||||
if gaze_path is not None:
|
||||
for i in episode_ids:
|
||||
gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
|
||||
with open(gaze_i_txt, 'r') as fp:
|
||||
gaze_i = fp.readlines()
|
||||
fp.close()
|
||||
gaze_i = ast.literal_eval(gaze_i[0])
|
||||
gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
|
||||
for j in range(len(gaze_i)):
|
||||
if len(gaze_i[j]) >= 2:
|
||||
gaze_i_tensor[j,:] = torch.tensor(gaze_i[j][:2])
|
||||
elif len(gaze_i[j]) == 1:
|
||||
gaze_i_tensor[j,0] = torch.tensor(gaze_i[j][0])
|
||||
else:
|
||||
continue
|
||||
self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))
|
||||
|
||||
print('Bbox...')
|
||||
self.bbox_path = bbox_path
|
||||
if bbox_path is not None:
|
||||
self.objects = [
|
||||
'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
|
||||
'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
|
||||
'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
|
||||
'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
|
||||
]
|
||||
"""
|
||||
# NOTE: old bbox
|
||||
for i in episode_ids:
|
||||
bbox_i_dir = os.path.join(bbox_path, i)
|
||||
with open(bbox_i_dir, 'rb') as fp:
|
||||
bboxes_i = pickle.load(fp)
|
||||
len_i = len(bboxes_i)
|
||||
fp.close()
|
||||
bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
|
||||
for j in range(len(bboxes_i)):
|
||||
items_i_j, bboxes_i_j = bboxes_i[j]
|
||||
for k in range(len(items_i_j)):
|
||||
bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
|
||||
bboxes_i_j[k][0] / 1920, # * self.size_1,
|
||||
bboxes_i_j[k][1] / 1088, # * self.size_2,
|
||||
bboxes_i_j[k][2] / 1920, # * self.size_1,
|
||||
bboxes_i_j[k][3] / 1088, # * self.size_2
|
||||
]) # [x_min, y_min, x_max, y_max]
|
||||
self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
|
||||
"""
|
||||
# NOTE: new bbox
|
||||
for i in episode_ids:
|
||||
bbox_dir = os.path.join(bbox_path, i)
|
||||
bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4)) # TODO: we might want to cut it to 10 objects
|
||||
for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)):
|
||||
with open(os.path.join(bbox_dir, bbox), 'r') as fp:
|
||||
bbox_content = fp.readlines()
|
||||
fp.close()
|
||||
for bbox_content_line in bbox_content:
|
||||
bbox_content_values = bbox_content_line.split()
|
||||
class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values)
|
||||
bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
|
||||
self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))
|
||||
|
||||
print('OCR...\n')
|
||||
self.ocr_graph_path = ocr_graph_path
|
||||
if ocr_graph_path is not None:
|
||||
ocr_graph = [
|
||||
[15, [10, 4], [17, 2]],
|
||||
[13, [16, 7], [18, 4]],
|
||||
[11, [16, 4], [7, 10]],
|
||||
[14, [10, 11], [7, 1]],
|
||||
[12, [10, 9], [16, 3]],
|
||||
[1, [7, 2], [9, 9], [10, 2]],
|
||||
[5, [8, 8], [6, 8]],
|
||||
[4, [9, 8], [7, 6]],
|
||||
[3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
|
||||
[2, [10, 1], [7, 7], [9, 3]],
|
||||
[19, [10, 2], [26, 6]],
|
||||
[20, [10, 7], [26, 5]],
|
||||
[22, [25, 4], [10, 8]],
|
||||
[23, [25, 15]],
|
||||
[21, [16, 5], [24, 8]]
|
||||
]
|
||||
ocr_tensor = torch.zeros((27, 27))
|
||||
for ocr in ocr_graph:
|
||||
obj = ocr[0]
|
||||
contexts = ocr[1:]
|
||||
total_context_count = sum([i[1] for i in contexts])
|
||||
for context in contexts:
|
||||
ocr_tensor[obj, context[0]] = context[1] / total_context_count
|
||||
ocr_tensor = torch.flatten(ocr_tensor)
|
||||
for i in episode_ids:
|
||||
self.data['ocr_graph'].append(ocr_tensor)
|
||||
|
||||
self.frame_path = frame_path
|
||||
self.transform = transforms.Compose([
|
||||
#transforms.ToPILImage(),
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data['frame_paths'])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
frames_path = self.data['frame_paths'][idx][0].split('/')
|
||||
frames_path.insert(5, self.presaved)
|
||||
frames_path = ''.join(f'{w}/' for w in frames_path[:-1])[:-1]+'.pkl'
|
||||
with open(frames_path, 'rb') as f:
|
||||
images = pickle.load(f)
|
||||
images = torch.stack(images)
|
||||
if self.patch_data:
|
||||
images = rearrange(
|
||||
images,
|
||||
'(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
|
||||
p1=self.spatial_patch_size,
|
||||
p2=self.spatial_patch_size,
|
||||
p3=self.temporal_patch_size
|
||||
)
|
||||
|
||||
if self.pose_path is not None:
|
||||
pose = self.data['poses'][idx]
|
||||
if self.patch_data:
|
||||
pose = rearrange(
|
||||
pose,
|
||||
'(t p1) d -> t p1 d',
|
||||
p1=self.temporal_patch_size
|
||||
)
|
||||
else:
|
||||
pose = None
|
||||
|
||||
if self.gaze_path is not None:
|
||||
gaze = self.data['gazes'][idx]
|
||||
if self.patch_data:
|
||||
gaze = rearrange(
|
||||
gaze,
|
||||
'(t p1) d -> t p1 d',
|
||||
p1=self.temporal_patch_size
|
||||
)
|
||||
else:
|
||||
gaze = None
|
||||
|
||||
if self.bbox_path is not None:
|
||||
bbox = self.data['bboxes'][idx]
|
||||
if self.patch_data:
|
||||
bbox = rearrange(
|
||||
bbox,
|
||||
'(t p1) d -> t p1 d',
|
||||
p1=self.temporal_patch_size
|
||||
)
|
||||
else:
|
||||
bbox = None
|
||||
|
||||
if self.ocr_graph_path is not None:
|
||||
ocr_graphs = self.data['ocr_graph'][idx]
|
||||
else:
|
||||
ocr_graphs = None
|
||||
|
||||
return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs
|
||||
|
||||
|
||||
class DataTest(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
frame_path,
|
||||
label_path,
|
||||
pose_path=None,
|
||||
gaze_path=None,
|
||||
bbox_path=None,
|
||||
ocr_graph_path=None,
|
||||
presaved=128,
|
||||
frame_ids=None,
|
||||
median=None,
|
||||
sizes = (128,128),
|
||||
spatial_patch_size: int = 16,
|
||||
temporal_patch_size: int = 4,
|
||||
img_channels: int = 3,
|
||||
patch_data: bool = False,
|
||||
flatten_dim: int = 1
|
||||
):
|
||||
self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
|
||||
self.size_1, self.size_2 = sizes
|
||||
self.patch_data = patch_data
|
||||
if self.patch_data:
|
||||
self.spatial_patch_size = spatial_patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
|
||||
self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
|
||||
|
||||
assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
if presaved == 224:
|
||||
self.presaved = 'presaved'
|
||||
elif presaved == 128:
|
||||
self.presaved = 'presaved128'
|
||||
else: raise ValueError
|
||||
|
||||
if frame_ids is not None:
|
||||
test_ids = []
|
||||
frame_dirs = os.listdir(frame_path)
|
||||
episode_ids = natsorted(frame_dirs)
|
||||
for frame_dir in episode_ids:
|
||||
if int(frame_dir) in frame_ids:
|
||||
test_ids.append(str(frame_dir))
|
||||
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
|
||||
self.data['frame_paths'].append(frame_paths)
|
||||
elif median is not None:
|
||||
test_ids = []
|
||||
frame_dirs = os.listdir(frame_path)
|
||||
episode_ids = natsorted(frame_dirs)
|
||||
for frame_dir in episode_ids:
|
||||
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
|
||||
seq_len = len(frame_paths)
|
||||
if (median[1] and seq_len >= median[0]) or (not median[1] and seq_len < median[0]):
|
||||
self.data['frame_paths'].append(frame_paths)
|
||||
test_ids.append(int(frame_dir))
|
||||
else:
|
||||
frame_dirs = os.listdir(frame_path)
|
||||
episode_ids = natsorted(frame_dirs)
|
||||
test_ids = episode_ids.copy()
|
||||
for frame_dir in episode_ids:
|
||||
frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
|
||||
self.data['frame_paths'].append(frame_paths)
|
||||
|
||||
print('Labels...')
|
||||
with open(label_path, 'rb') as fp:
|
||||
labels = pickle.load(fp)
|
||||
left_labels, right_labels = labels
|
||||
for i in test_ids:
|
||||
episode_id = int(i)
|
||||
left_labels_i = left_labels[episode_id]
|
||||
right_labels_i = right_labels[episode_id]
|
||||
self.data['labels'].append([left_labels_i, right_labels_i])
|
||||
fp.close()
|
||||
|
||||
print('Pose...')
|
||||
self.pose_path = pose_path
|
||||
if pose_path is not None:
|
||||
for i in test_ids:
|
||||
pose_i_dir = os.path.join(pose_path, i)
|
||||
poses = []
|
||||
for j in natsorted(os.listdir(pose_i_dir)):
|
||||
fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
|
||||
if torch.tensor(fs.getNode("pose_0").mat()).shape != (2,25,3):
|
||||
poses.append(fs.getNode("pose_0").mat()[:2,:,:])
|
||||
else:
|
||||
poses.append(fs.getNode("pose_0").mat())
|
||||
poses = np.array(poses)
|
||||
poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
|
||||
poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
|
||||
self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
|
||||
|
||||
print('Gaze...')
|
||||
self.gaze_path = gaze_path
|
||||
if gaze_path is not None:
|
||||
for i in test_ids:
|
||||
gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
|
||||
with open(gaze_i_txt, 'r') as fp:
|
||||
gaze_i = fp.readlines()
|
||||
fp.close()
|
||||
gaze_i = ast.literal_eval(gaze_i[0])
|
||||
gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
|
||||
for j in range(len(gaze_i)):
|
||||
if len(gaze_i[j]) >= 2:
|
||||
gaze_i_tensor[j,:] = torch.tensor(gaze_i[j][:2])
|
||||
elif len(gaze_i[j]) == 1:
|
||||
gaze_i_tensor[j,0] = torch.tensor(gaze_i[j][0])
|
||||
else:
|
||||
continue
|
||||
self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))
|
||||
|
||||
print('Bbox...')
|
||||
self.bbox_path = bbox_path
|
||||
if bbox_path is not None:
|
||||
self.objects = [
|
||||
'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
|
||||
'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
|
||||
'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
|
||||
'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
|
||||
]
|
||||
""" NOTE: old bbox
|
||||
for i in test_ids:
|
||||
bbox_i_dir = os.path.join(bbox_path, i)
|
||||
with open(bbox_i_dir, 'rb') as fp:
|
||||
bboxes_i = pickle.load(fp)
|
||||
len_i = len(bboxes_i)
|
||||
fp.close()
|
||||
bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
|
||||
for j in range(len(bboxes_i)):
|
||||
items_i_j, bboxes_i_j = bboxes_i[j]
|
||||
for k in range(len(items_i_j)):
|
||||
bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
|
||||
bboxes_i_j[k][0] / 1920, # * self.size_1,
|
||||
bboxes_i_j[k][1] / 1088, # * self.size_2,
|
||||
bboxes_i_j[k][2] / 1920, # * self.size_1,
|
||||
bboxes_i_j[k][3] / 1088, # * self.size_2
|
||||
]) # [x_min, y_min, x_max, y_max]
|
||||
self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
|
||||
"""
|
||||
for i in test_ids:
|
||||
bbox_dir = os.path.join(bbox_path, i)
|
||||
bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4)) # TODO: we might want to cut it to 10 objects
|
||||
for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)):
|
||||
with open(os.path.join(bbox_dir, bbox), 'r') as fp:
|
||||
bbox_content = fp.readlines()
|
||||
fp.close()
|
||||
for bbox_content_line in bbox_content:
|
||||
bbox_content_values = bbox_content_line.split()
|
||||
class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values)
|
||||
bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
|
||||
self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))
|
||||
|
||||
print('OCR...\n')
|
||||
self.ocr_graph_path = ocr_graph_path
|
||||
if ocr_graph_path is not None:
|
||||
ocr_graph = [
|
||||
[15, [10, 4], [17, 2]],
|
||||
[13, [16, 7], [18, 4]],
|
||||
[11, [16, 4], [7, 10]],
|
||||
[14, [10, 11], [7, 1]],
|
||||
[12, [10, 9], [16, 3]],
|
||||
[1, [7, 2], [9, 9], [10, 2]],
|
||||
[5, [8, 8], [6, 8]],
|
||||
[4, [9, 8], [7, 6]],
|
||||
[3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
|
||||
[2, [10, 1], [7, 7], [9, 3]],
|
||||
[19, [10, 2], [26, 6]],
|
||||
[20, [10, 7], [26, 5]],
|
||||
[22, [25, 4], [10, 8]],
|
||||
[23, [25, 15]],
|
||||
[21, [16, 5], [24, 8]]
|
||||
]
|
||||
ocr_tensor = torch.zeros((27, 27))
|
||||
for ocr in ocr_graph:
|
||||
obj = ocr[0]
|
||||
contexts = ocr[1:]
|
||||
total_context_count = sum([i[1] for i in contexts])
|
||||
for context in contexts:
|
||||
ocr_tensor[obj, context[0]] = context[1] / total_context_count
|
||||
ocr_tensor = torch.flatten(ocr_tensor)
|
||||
for i in test_ids:
|
||||
self.data['ocr_graph'].append(ocr_tensor)
|
||||
|
||||
self.frame_path = frame_path
|
||||
self.transform = transforms.Compose([
|
||||
#transforms.ToPILImage(),
|
||||
transforms.Resize((self.size_1, self.size_2)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data['frame_paths'])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
frames_path = self.data['frame_paths'][idx][0].split('/')
|
||||
frames_path.insert(5, self.presaved)
|
||||
frames_path = ''.join(f'{w}/' for w in frames_path[:-1])[:-1]+'.pkl'
|
||||
with open(frames_path, 'rb') as f:
|
||||
images = pickle.load(f)
|
||||
images = torch.stack(images)
|
||||
if self.patch_data:
|
||||
images = rearrange(
|
||||
images,
|
||||
'(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
|
||||
p1=self.spatial_patch_size,
|
||||
p2=self.spatial_patch_size,
|
||||
p3=self.temporal_patch_size
|
||||
)
|
||||
|
||||
if self.pose_path is not None:
|
||||
pose = self.data['poses'][idx]
|
||||
if self.patch_data:
|
||||
pose = rearrange(
|
||||
pose,
|
||||
'(t p1) d -> t p1 d',
|
||||
p1=self.temporal_patch_size
|
||||
)
|
||||
else:
|
||||
pose = None
|
||||
|
||||
if self.gaze_path is not None:
|
||||
gaze = self.data['gazes'][idx]
|
||||
if self.patch_data:
|
||||
gaze = rearrange(
|
||||
gaze,
|
||||
'(t p1) d -> t p1 d',
|
||||
p1=self.temporal_patch_size
|
||||
)
|
||||
else:
|
||||
gaze = None
|
||||
|
||||
if self.bbox_path is not None:
|
||||
bbox = self.data['bboxes'][idx]
|
||||
if self.patch_data:
|
||||
bbox = rearrange(
|
||||
bbox,
|
||||
'(t p1) d -> t p1 d',
|
||||
p1=self.temporal_patch_size
|
||||
)
|
||||
else:
|
||||
bbox = None
|
||||
|
||||
if self.ocr_graph_path is not None:
|
||||
ocr_graphs = self.data['ocr_graph'][idx]
|
||||
else:
|
||||
ocr_graphs = None
|
||||
|
||||
return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs
|
||||
|
240
boss/environment.yml
Normal file
240
boss/environment.yml
Normal file
|
@ -0,0 +1,240 @@
|
|||
name: boss
|
||||
channels:
|
||||
- pyg
|
||||
- anaconda
|
||||
- conda-forge
|
||||
- defaults
|
||||
- pytorch
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=5.1=1_gnu
|
||||
- alembic=1.8.0=pyhd8ed1ab_0
|
||||
- appdirs=1.4.4=pyh9f0ad1d_0
|
||||
- asttokens=2.2.1=pyhd8ed1ab_0
|
||||
- backcall=0.2.0=pyh9f0ad1d_0
|
||||
- backports=1.0=pyhd8ed1ab_3
|
||||
- backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
|
||||
- blas=1.0=mkl
|
||||
- blosc=1.21.3=h6a678d5_0
|
||||
- bottle=0.12.23=pyhd8ed1ab_0
|
||||
- bottleneck=1.3.5=py38h7deecbd_0
|
||||
- brotli=1.0.9=h5eee18b_7
|
||||
- brotli-bin=1.0.9=h5eee18b_7
|
||||
- brotlipy=0.7.0=py38h27cfd23_1003
|
||||
- brunsli=0.1=h2531618_0
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- c-ares=1.18.1=h7f8727e_0
|
||||
- ca-certificates=2023.01.10=h06a4308_0
|
||||
- certifi=2023.5.7=py38h06a4308_0
|
||||
- cffi=1.15.1=py38h74dc2b5_0
|
||||
- cfitsio=3.470=h5893167_7
|
||||
- charls=2.2.0=h2531618_0
|
||||
- charset-normalizer=2.0.4=pyhd3eb1b0_0
|
||||
- click=8.1.3=unix_pyhd8ed1ab_2
|
||||
- cloudpickle=2.0.0=pyhd3eb1b0_0
|
||||
- cmaes=0.9.1=pyhd8ed1ab_0
|
||||
- colorlog=6.7.0=py38h578d9bd_1
|
||||
- contourpy=1.0.5=py38hdb19cb5_0
|
||||
- cryptography=39.0.1=py38h9ce1e76_0
|
||||
- cudatoolkit=11.3.1=h2bc3f7f_2
|
||||
- cycler=0.11.0=pyhd3eb1b0_0
|
||||
- cytoolz=0.12.0=py38h5eee18b_0
|
||||
- dask-core=2022.7.0=py38h06a4308_0
|
||||
- dbus=1.13.18=hb2f20db_0
|
||||
- debugpy=1.5.1=py38h295c915_0
|
||||
- decorator=5.1.1=pyhd8ed1ab_0
|
||||
- docker-pycreds=0.4.0=py_0
|
||||
- entrypoints=0.4=pyhd8ed1ab_0
|
||||
- executing=1.2.0=pyhd8ed1ab_0
|
||||
- expat=2.4.9=h6a678d5_0
|
||||
- ffmpeg=4.3=hf484d3e_0
|
||||
- fftw=3.3.9=h27cfd23_1
|
||||
- flit-core=3.6.0=pyhd3eb1b0_0
|
||||
- fontconfig=2.14.1=h52c9d5c_1
|
||||
- fonttools=4.25.0=pyhd3eb1b0_0
|
||||
- freetype=2.12.1=h4a9f257_0
|
||||
- fsspec=2022.11.0=py38h06a4308_0
|
||||
- giflib=5.2.1=h5eee18b_1
|
||||
- gitdb=4.0.10=pyhd8ed1ab_0
|
||||
- gitpython=3.1.31=pyhd8ed1ab_0
|
||||
- glib=2.69.1=h4ff587b_1
|
||||
- gmp=6.2.1=h295c915_3
|
||||
- gnutls=3.6.15=he1e5248_0
|
||||
- greenlet=2.0.1=py38h6a678d5_0
|
||||
- gst-plugins-base=1.14.1=h6a678d5_1
|
||||
- gstreamer=1.14.1=h5eee18b_1
|
||||
- icu=58.2=he6710b0_3
|
||||
- idna=3.4=py38h06a4308_0
|
||||
- imagecodecs=2021.8.26=py38hf0132c2_1
|
||||
- imageio=2.19.3=py38h06a4308_0
|
||||
- importlib-metadata=6.1.0=pyha770c72_0
|
||||
- importlib_resources=5.2.0=pyhd3eb1b0_1
|
||||
- intel-openmp=2021.4.0=h06a4308_3561
|
||||
- ipykernel=6.15.0=pyh210e3f2_0
|
||||
- ipython=8.11.0=pyh41d4057_0
|
||||
- jedi=0.18.2=pyhd8ed1ab_0
|
||||
- jinja2=3.1.2=py38h06a4308_0
|
||||
- joblib=1.1.1=py38h06a4308_0
|
||||
- jpeg=9e=h7f8727e_0
|
||||
- jupyter_client=7.0.6=pyhd8ed1ab_0
|
||||
- jupyter_core=4.12.0=py38h578d9bd_0
|
||||
- jxrlib=1.1=h7b6447c_2
|
||||
- kiwisolver=1.4.4=py38h6a678d5_0
|
||||
- krb5=1.19.4=h568e23c_0
|
||||
- lame=3.100=h7b6447c_0
|
||||
- lcms2=2.12=h3be6417_0
|
||||
- ld_impl_linux-64=2.38=h1181459_1
|
||||
- lerc=3.0=h295c915_0
|
||||
- libaec=1.0.4=he6710b0_1
|
||||
- libbrotlicommon=1.0.9=h5eee18b_7
|
||||
- libbrotlidec=1.0.9=h5eee18b_7
|
||||
- libbrotlienc=1.0.9=h5eee18b_7
|
||||
- libclang=10.0.1=default_hb85057a_2
|
||||
- libcurl=7.87.0=h91b91d3_0
|
||||
- libdeflate=1.8=h7f8727e_5
|
||||
- libedit=3.1.20221030=h5eee18b_0
|
||||
- libev=4.33=h7f8727e_1
|
||||
- libevent=2.1.12=h8f2d780_0
|
||||
- libffi=3.3=he6710b0_2
|
||||
- libgcc-ng=11.2.0=h1234567_1
|
||||
- libgfortran-ng=11.2.0=h00389a5_1
|
||||
- libgfortran5=11.2.0=h1234567_1
|
||||
- libgomp=11.2.0=h1234567_1
|
||||
- libiconv=1.16=h7f8727e_2
|
||||
- libidn2=2.3.2=h7f8727e_0
|
||||
- libllvm10=10.0.1=hbcb73fb_5
|
||||
- libnghttp2=1.46.0=hce63b2e_0
|
||||
- libpng=1.6.37=hbc83047_0
|
||||
- libpq=12.9=h16c4e8d_3
|
||||
- libprotobuf=3.20.3=he621ea3_0
|
||||
- libsodium=1.0.18=h36c2ea0_1
|
||||
- libssh2=1.10.0=h8f2d780_0
|
||||
- libstdcxx-ng=11.2.0=h1234567_1
|
||||
- libtasn1=4.16.0=h27cfd23_0
|
||||
- libtiff=4.5.0=h6a678d5_1
|
||||
- libunistring=0.9.10=h27cfd23_0
|
||||
- libuuid=1.41.5=h5eee18b_0
|
||||
- libwebp=1.2.4=h11a3e52_0
|
||||
- libwebp-base=1.2.4=h5eee18b_0
|
||||
- libxcb=1.15=h7f8727e_0
|
||||
- libxkbcommon=1.0.1=hfa300c1_0
|
||||
- libxml2=2.9.14=h74e7548_0
|
||||
- libxslt=1.1.35=h4e12654_0
|
||||
- libzopfli=1.0.3=he6710b0_0
|
||||
- locket=1.0.0=py38h06a4308_0
|
||||
- lz4-c=1.9.4=h6a678d5_0
|
||||
- mako=1.2.4=pyhd8ed1ab_0
|
||||
- markupsafe=2.1.1=py38h7f8727e_0
|
||||
- matplotlib=3.7.0=py38h06a4308_0
|
||||
- matplotlib-base=3.7.0=py38h417a72b_0
|
||||
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
|
||||
- mkl=2021.4.0=h06a4308_640
|
||||
- mkl-service=2.4.0=py38h7f8727e_0
|
||||
- mkl_fft=1.3.1=py38hd3c417c_0
|
||||
- mkl_random=1.2.2=py38h51133e4_0
|
||||
- munkres=1.1.4=py_0
|
||||
- natsort=7.1.1=pyhd3eb1b0_0
|
||||
- ncurses=6.3=h5eee18b_3
|
||||
- nest-asyncio=1.5.6=pyhd8ed1ab_0
|
||||
- nettle=3.7.3=hbbd107a_1
|
||||
- networkx=2.8.4=py38h06a4308_0
|
||||
- nspr=4.33=h295c915_0
|
||||
- nss=3.74=h0370c37_0
|
||||
- numexpr=2.8.4=py38he184ba9_0
|
||||
- numpy=1.23.5=py38h14f4228_0
|
||||
- numpy-base=1.23.5=py38h31eccc5_0
|
||||
- openh264=2.1.1=h4ff587b_0
|
||||
- openjpeg=2.4.0=h3ad879b_0
|
||||
- openssl=1.1.1t=h7f8727e_0
|
||||
- optuna=3.1.0=pyhd8ed1ab_0
|
||||
- optuna-dashboard=0.9.0=pyhd8ed1ab_0
|
||||
- packaging=22.0=py38h06a4308_0
|
||||
- pandas=1.5.3=py38h417a72b_0
|
||||
- parso=0.8.3=pyhd8ed1ab_0
|
||||
- partd=1.2.0=pyhd3eb1b0_1
|
||||
- pathtools=0.1.2=py_1
|
||||
- pcre=8.45=h295c915_0
|
||||
- pexpect=4.8.0=pyh1a96a4e_2
|
||||
- pickleshare=0.7.5=py_1003
|
||||
- pillow=9.4.0=py38h6a678d5_0
|
||||
- pip=22.3.1=py38h06a4308_0
|
||||
- ply=3.11=py38_0
|
||||
- prompt-toolkit=3.0.38=pyha770c72_0
|
||||
- prompt_toolkit=3.0.38=hd8ed1ab_0
|
||||
- protobuf=3.20.3=py38h6a678d5_0
|
||||
- psutil=5.9.0=py38h5eee18b_0
|
||||
- ptyprocess=0.7.0=pyhd3deb0d_0
|
||||
- pure_eval=0.2.2=pyhd8ed1ab_0
|
||||
- pycparser=2.21=pyhd3eb1b0_0
|
||||
- pyg=2.2.0=py38_torch_1.12.0_cu113
|
||||
- pygments=2.14.0=pyhd8ed1ab_0
|
||||
- pyopenssl=23.0.0=py38h06a4308_0
|
||||
- pyparsing=3.0.9=py38h06a4308_0
|
||||
- pyqt=5.15.7=py38h6a678d5_1
|
||||
- pyqt5-sip=12.11.0=py38h6a678d5_1
|
||||
- pysocks=1.7.1=py38h06a4308_0
|
||||
- python=3.8.13=haa1d7c7_1
|
||||
- python-dateutil=2.8.2=pyhd8ed1ab_0
|
||||
- python_abi=3.8=2_cp38
|
||||
- pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0
|
||||
- pytorch-cluster=1.6.0=py38_torch_1.12.0_cu113
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- pytorch-scatter=2.1.0=py38_torch_1.12.0_cu113
|
||||
- pytorch-sparse=0.6.16=py38_torch_1.12.0_cu113
|
||||
- pytz=2022.7=py38h06a4308_0
|
||||
- pywavelets=1.4.1=py38h5eee18b_0
|
||||
- pyyaml=6.0=py38h5eee18b_1
|
||||
- pyzmq=19.0.2=py38ha71036d_2
|
||||
- qt-main=5.15.2=h327a75a_7
|
||||
- qt-webengine=5.15.9=hd2b0992_4
|
||||
- qtwebkit=5.212=h4eab89a_4
|
||||
- readline=8.2=h5eee18b_0
|
||||
- requests=2.28.1=py38h06a4308_1
|
||||
- scikit-image=0.19.3=py38h6a678d5_1
|
||||
- scikit-learn=1.2.1=py38h6a678d5_0
|
||||
- scipy=1.9.3=py38h14f4228_0
|
||||
- sentry-sdk=1.16.0=pyhd8ed1ab_0
|
||||
- setproctitle=1.2.2=py38h0a891b7_2
|
||||
- setuptools=65.6.3=py38h06a4308_0
|
||||
- sip=6.6.2=py38h6a678d5_0
|
||||
- six=1.16.0=pyhd3eb1b0_1
|
||||
- smmap=3.0.5=pyh44b312d_0
|
||||
- snappy=1.1.9=h295c915_0
|
||||
- sqlalchemy=1.4.39=py38h5eee18b_0
|
||||
- sqlite=3.40.1=h5082296_0
|
||||
- stack_data=0.6.2=pyhd8ed1ab_0
|
||||
- threadpoolctl=2.2.0=pyh0d69192_0
|
||||
- tifffile=2021.7.2=pyhd3eb1b0_2
|
||||
- tk=8.6.12=h1ccaba5_0
|
||||
- toml=0.10.2=pyhd3eb1b0_0
|
||||
- toolz=0.12.0=py38h06a4308_0
|
||||
- torchaudio=0.12.0=py38_cu113
|
||||
- torchvision=0.13.0=py38_cu113
|
||||
- tornado=6.1=py38h0a891b7_3
|
||||
- tqdm=4.64.1=py38h06a4308_0
|
||||
- traitlets=5.9.0=pyhd8ed1ab_0
|
||||
- typing_extensions=4.4.0=py38h06a4308_0
|
||||
- tzdata=2022g=h04d1e81_0
|
||||
- urllib3=1.26.14=py38h06a4308_0
|
||||
- wandb=0.13.10=pyhd8ed1ab_0
|
||||
- wcwidth=0.2.6=pyhd8ed1ab_0
|
||||
- wheel=0.37.1=pyhd3eb1b0_0
|
||||
- xz=5.2.10=h5eee18b_1
|
||||
- yaml=0.2.5=h7b6447c_0
|
||||
- zeromq=4.3.4=h9c3ff4c_1
|
||||
- zfp=0.5.5=h295c915_6
|
||||
- zipp=3.11.0=py38h06a4308_0
|
||||
- zlib=1.2.13=h5eee18b_0
|
||||
- zstd=1.5.2=ha4553b6_0
|
||||
- pip:
|
||||
- einops==0.6.1
|
||||
- lion-pytorch==0.1.2
|
||||
- llvmlite==0.40.0
|
||||
- memory-efficient-attention-pytorch==0.1.2
|
||||
- numba==0.57.0
|
||||
- opencv-python==4.7.0.72
|
||||
- pynndescent==0.5.10
|
||||
- seaborn==0.12.2
|
||||
- umap-learn==0.5.3
|
||||
- x-transformers==1.14.0
|
||||
prefix: /opt/anaconda3/envs/boss
|
0
boss/models/__init__.py
Normal file
0
boss/models/__init__.py
Normal file
230
boss/models/base.py
Normal file
230
boss/models/base.py
Normal file
|
@ -0,0 +1,230 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from .utils import pose_edge_index
|
||||
from torch_geometric.nn import Sequential, GCNConv
|
||||
from x_transformers import ContinuousTransformerWrapper, Decoder
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
def forward(self, x, **kwargs):
|
||||
x = self.norm(x)
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim, dim))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CNN(nn.Module):
|
||||
def __init__(self, hidden_dim):
|
||||
super(CNN, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = self.pool(x)
|
||||
x = self.conv2(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = self.pool(x)
|
||||
x = self.conv3(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling
|
||||
return x
|
||||
|
||||
|
||||
class MindNetLSTM(nn.Module):
|
||||
"""
|
||||
Basic MindNet for model-based ToM, just LSTM on input concatenation
|
||||
"""
|
||||
def __init__(self, hidden_dim, dropout, mods):
|
||||
super(MindNetLSTM, self).__init__()
|
||||
self.mods = mods
|
||||
self.gaze_emb = nn.Linear(3, hidden_dim)
|
||||
self.pose_edge_index = pose_edge_index()
|
||||
self.pose_emb = GCNConv(3, hidden_dim)
|
||||
self.LSTM = PreNorm(
|
||||
hidden_dim*len(mods),
|
||||
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
|
||||
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.act = nn.GELU()
|
||||
|
||||
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
|
||||
feats = []
|
||||
if 'rgb' in self.mods:
|
||||
feats.append(rgb_feats)
|
||||
if 'ocr' in self.mods:
|
||||
feats.append(ocr_feats)
|
||||
if 'pose' in self.mods:
|
||||
bs, seq_len = pose.size(0), pose.size(1)
|
||||
self.pose_edge_index = self.pose_edge_index.to(pose.device)
|
||||
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
|
||||
pose_emb = self.dropout(self.act(pose_emb))
|
||||
pose_emb = torch.mean(pose_emb, dim=1)
|
||||
hd = pose_emb.size(-1)
|
||||
feats.append(pose_emb.view(bs, seq_len, hd))
|
||||
if 'gaze' in self.mods:
|
||||
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
|
||||
feats.append(gaze_feats)
|
||||
if 'bbox' in self.mods:
|
||||
feats.append(bbox_feats)
|
||||
lstm_inp = torch.cat(feats, 2)
|
||||
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
|
||||
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
|
||||
return self.act(self.proj(lstm_out)), c_n, feats
|
||||
|
||||
|
||||
class MindNetSL(nn.Module):
|
||||
"""
|
||||
Basic MindNet for SL ToM, just LSTM on input concatenation
|
||||
"""
|
||||
def __init__(self, hidden_dim, dropout, mods):
|
||||
super(MindNetSL, self).__init__()
|
||||
self.mods = mods
|
||||
self.gaze_emb = nn.Linear(3, hidden_dim)
|
||||
self.pose_edge_index = pose_edge_index()
|
||||
self.pose_emb = GCNConv(3, hidden_dim)
|
||||
self.LSTM = PreNorm(
|
||||
hidden_dim*5,
|
||||
nn.LSTM(input_size=hidden_dim*5, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
|
||||
)
|
||||
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.act = nn.GELU()
|
||||
|
||||
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
|
||||
feats = []
|
||||
if 'rgb' in self.mods:
|
||||
feats.append(rgb_feats)
|
||||
if 'ocr' in self.mods:
|
||||
feats.append(ocr_feats)
|
||||
if 'pose' in self.mods:
|
||||
bs, seq_len = pose.size(0), pose.size(1)
|
||||
self.pose_edge_index = self.pose_edge_index.to(pose.device)
|
||||
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
|
||||
pose_emb = self.dropout(self.act(pose_emb))
|
||||
pose_emb = torch.mean(pose_emb, dim=1)
|
||||
hd = pose_emb.size(-1)
|
||||
feats.append(pose_emb.view(bs, seq_len, hd))
|
||||
if 'gaze' in self.mods:
|
||||
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
|
||||
feats.append(gaze_feats)
|
||||
if 'bbox' in self.mods:
|
||||
feats.append(bbox_feats)
|
||||
lstm_inp = torch.cat(feats, 2)
|
||||
lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
|
||||
return self.act(self.proj(lstm_out)), feats
|
||||
|
||||
|
||||
class MindNetTF(nn.Module):
|
||||
"""
|
||||
Basic MindNet for model-based ToM, Transformer on input concatenation
|
||||
"""
|
||||
def __init__(self, hidden_dim, dropout, mods):
|
||||
super(MindNetTF, self).__init__()
|
||||
self.mods = mods
|
||||
self.gaze_emb = nn.Linear(3, hidden_dim)
|
||||
self.pose_edge_index = pose_edge_index()
|
||||
self.pose_emb = GCNConv(3, hidden_dim)
|
||||
self.tf = ContinuousTransformerWrapper(
|
||||
dim_in=hidden_dim*len(mods),
|
||||
dim_out=hidden_dim,
|
||||
max_seq_len=747,
|
||||
attn_layers=Decoder(
|
||||
dim=512,
|
||||
depth=6,
|
||||
heads=8
|
||||
)
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.act = nn.GELU()
|
||||
|
||||
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
|
||||
feats = []
|
||||
if 'rgb' in self.mods:
|
||||
feats.append(rgb_feats)
|
||||
if 'ocr' in self.mods:
|
||||
feats.append(ocr_feats)
|
||||
if 'pose' in self.mods:
|
||||
bs, seq_len = pose.size(0), pose.size(1)
|
||||
self.pose_edge_index = self.pose_edge_index.to(pose.device)
|
||||
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
|
||||
pose_emb = self.dropout(self.act(pose_emb))
|
||||
pose_emb = torch.mean(pose_emb, dim=1)
|
||||
hd = pose_emb.size(-1)
|
||||
feats.append(pose_emb.view(bs, seq_len, hd))
|
||||
if 'gaze' in self.mods:
|
||||
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
|
||||
feats.append(gaze_feats)
|
||||
if 'bbox' in self.mods:
|
||||
feats.append(bbox_feats)
|
||||
tf_inp = torch.cat(feats, 2)
|
||||
tf_out = self.tf(self.dropout(tf_inp))
|
||||
return tf_out, feats
|
||||
|
||||
|
||||
class MindNetLSTMXL(nn.Module):
|
||||
"""
|
||||
Basic MindNet for model-based ToM, just LSTM on input concatenation
|
||||
"""
|
||||
def __init__(self, hidden_dim, dropout, mods):
|
||||
super(MindNetLSTMXL, self).__init__()
|
||||
self.mods = mods
|
||||
self.gaze_emb = nn.Sequential(
|
||||
nn.Linear(3, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, hidden_dim)
|
||||
)
|
||||
self.pose_edge_index = pose_edge_index()
|
||||
self.pose_emb = Sequential('x, edge_index', [
|
||||
(GCNConv(3, hidden_dim), 'x, edge_index -> x'),
|
||||
nn.GELU(),
|
||||
(GCNConv(hidden_dim, hidden_dim), 'x, edge_index -> x'),
|
||||
])
|
||||
self.LSTM = PreNorm(
|
||||
hidden_dim*len(mods),
|
||||
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
|
||||
)
|
||||
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.act = nn.GELU()
|
||||
|
||||
def forward(self, rgb_feats, ocr_feats, pose, gaze, bbox_feats):
|
||||
feats = []
|
||||
if 'rgb' in self.mods:
|
||||
feats.append(rgb_feats)
|
||||
if 'ocr' in self.mods:
|
||||
feats.append(ocr_feats)
|
||||
if 'pose' in self.mods:
|
||||
bs, seq_len = pose.size(0), pose.size(1)
|
||||
self.pose_edge_index = self.pose_edge_index.to(pose.device)
|
||||
pose_emb = self.pose_emb(pose.view(bs*seq_len, 25, 3), self.pose_edge_index)
|
||||
pose_emb = self.dropout(self.act(pose_emb))
|
||||
pose_emb = torch.mean(pose_emb, dim=1)
|
||||
hd = pose_emb.size(-1)
|
||||
feats.append(pose_emb.view(bs, seq_len, hd))
|
||||
if 'gaze' in self.mods:
|
||||
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
|
||||
feats.append(gaze_feats)
|
||||
if 'bbox' in self.mods:
|
||||
feats.append(bbox_feats)
|
||||
lstm_inp = torch.cat(feats, 2)
|
||||
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
|
||||
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
|
||||
return self.act(self.proj(lstm_out)), c_n, feats
|
249
boss/models/resnet.py
Normal file
249
boss/models/resnet.py
Normal file
|
@ -0,0 +1,249 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from .utils import left_bias, right_bias
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, input_dim, device):
|
||||
super(ResNet, self).__init__()
|
||||
# Conv
|
||||
resnet = models.resnet34(pretrained=True)
|
||||
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
|
||||
# FFs
|
||||
self.left = nn.Linear(input_dim, 27)
|
||||
self.right = nn.Linear(input_dim, 27)
|
||||
# modality FFs
|
||||
self.pose_ff = nn.Linear(150, 150)
|
||||
self.gaze_ff = nn.Linear(6, 6)
|
||||
self.bbox_ff = nn.Linear(108, 108)
|
||||
self.ocr_ff = nn.Linear(729, 64)
|
||||
# others
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(p=0.2)
|
||||
self.device = device
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
left_beliefs = []
|
||||
right_beliefs = []
|
||||
image_feats = []
|
||||
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i].to(self.device)
|
||||
image_i_feat = self.resnet(images_i)
|
||||
image_i_feat = image_i_feat.view(batch_size, 512)
|
||||
if poses is not None:
|
||||
poses_i = poses[:,i].float()
|
||||
poses_i_feat = self.relu(self.pose_ff(poses_i))
|
||||
image_i_feat = torch.cat([image_i_feat, poses_i_feat], 1)
|
||||
if gazes is not None:
|
||||
gazes_i = gazes[:,i].float()
|
||||
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
|
||||
image_i_feat = torch.cat([image_i_feat, gazes_i_feat], 1)
|
||||
if bboxes is not None:
|
||||
bboxes_i = bboxes[:,i].float()
|
||||
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
|
||||
image_i_feat = torch.cat([image_i_feat, bboxes_i_feat], 1)
|
||||
if ocr_tensor is not None:
|
||||
ocr_tensor = ocr_tensor
|
||||
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
|
||||
image_i_feat = torch.cat([image_i_feat, ocr_tensor_feat], 1)
|
||||
image_feats.append(image_i_feat)
|
||||
|
||||
image_feats = torch.permute(torch.stack(image_feats), (1,0,2))
|
||||
left_beliefs = self.left(self.dropout(image_feats))
|
||||
right_beliefs = self.right(self.dropout(image_feats))
|
||||
|
||||
return left_beliefs, right_beliefs, None
|
||||
|
||||
|
||||
class ResNetGRU(nn.Module):
|
||||
def __init__(self, input_dim, device):
|
||||
super(ResNetGRU, self).__init__()
|
||||
resnet = models.resnet34(pretrained=True)
|
||||
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
|
||||
self.gru = nn.GRU(input_dim, 512, batch_first=True)
|
||||
for name, param in self.gru.named_parameters():
|
||||
if "weight" in name:
|
||||
nn.init.orthogonal_(param)
|
||||
elif "bias" in name:
|
||||
nn.init.constant_(param, 0)
|
||||
# FFs
|
||||
self.left = nn.Linear(512, 27)
|
||||
self.right = nn.Linear(512, 27)
|
||||
# modality FFs
|
||||
self.pose_ff = nn.Linear(150, 150)
|
||||
self.gaze_ff = nn.Linear(6, 6)
|
||||
self.bbox_ff = nn.Linear(108, 108)
|
||||
self.ocr_ff = nn.Linear(729, 64)
|
||||
# others
|
||||
self.dropout = nn.Dropout(p=0.2)
|
||||
self.relu = nn.ReLU()
|
||||
self.device = device
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
left_beliefs = []
|
||||
right_beliefs = []
|
||||
rnn_inp = []
|
||||
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i]
|
||||
rnn_i_feat = self.resnet(images_i)
|
||||
rnn_i_feat = rnn_i_feat.view(batch_size, 512)
|
||||
if poses is not None:
|
||||
poses_i = poses[:,i].float()
|
||||
poses_i_feat = self.relu(self.pose_ff(poses_i))
|
||||
rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
|
||||
if gazes is not None:
|
||||
gazes_i = gazes[:,i].float()
|
||||
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
|
||||
rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
|
||||
if bboxes is not None:
|
||||
bboxes_i = bboxes[:,i].float()
|
||||
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
|
||||
rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
|
||||
if ocr_tensor is not None:
|
||||
ocr_tensor = ocr_tensor
|
||||
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
|
||||
rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
|
||||
rnn_inp.append(rnn_i_feat)
|
||||
|
||||
rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
|
||||
rnn_out, _ = self.gru(rnn_inp)
|
||||
left_beliefs = self.left(self.dropout(rnn_out))
|
||||
right_beliefs = self.right(self.dropout(rnn_out))
|
||||
|
||||
return left_beliefs, right_beliefs, None
|
||||
|
||||
|
||||
class ResNetConv1D(nn.Module):
|
||||
def __init__(self, input_dim, device):
|
||||
super(ResNetConv1D, self).__init__()
|
||||
resnet = models.resnet34(pretrained=True)
|
||||
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
|
||||
self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=5, padding=4)
|
||||
# FFs
|
||||
self.left = nn.Linear(512, 27)
|
||||
self.right = nn.Linear(512, 27)
|
||||
# modality FFs
|
||||
self.pose_ff = nn.Linear(150, 150)
|
||||
self.gaze_ff = nn.Linear(6, 6)
|
||||
self.bbox_ff = nn.Linear(108, 108)
|
||||
self.ocr_ff = nn.Linear(729, 64)
|
||||
# others
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(p=0.2)
|
||||
self.device = device
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
left_beliefs = []
|
||||
right_beliefs = []
|
||||
conv1d_inp = []
|
||||
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i]
|
||||
images_i_feat = self.resnet(images_i)
|
||||
images_i_feat = images_i_feat.view(batch_size, 512)
|
||||
if poses is not None:
|
||||
poses_i = poses[:,i].float()
|
||||
poses_i_feat = self.relu(self.pose_ff(poses_i))
|
||||
images_i_feat = torch.cat([images_i_feat, poses_i_feat], 1)
|
||||
if gazes is not None:
|
||||
gazes_i = gazes[:,i].float()
|
||||
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
|
||||
images_i_feat = torch.cat([images_i_feat, gazes_i_feat], 1)
|
||||
if bboxes is not None:
|
||||
bboxes_i = bboxes[:,i].float()
|
||||
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
|
||||
images_i_feat = torch.cat([images_i_feat, bboxes_i_feat], 1)
|
||||
if ocr_tensor is not None:
|
||||
ocr_tensor = ocr_tensor
|
||||
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
|
||||
images_i_feat = torch.cat([images_i_feat, ocr_tensor_feat], 1)
|
||||
conv1d_inp.append(images_i_feat)
|
||||
|
||||
conv1d_inp = torch.permute(torch.stack(conv1d_inp), (1,2,0))
|
||||
conv1d_out = self.conv1d(conv1d_inp)
|
||||
conv1d_out = conv1d_out[:,:,:-4]
|
||||
conv1d_out = self.relu(torch.permute(conv1d_out, (0,2,1)))
|
||||
left_beliefs = self.left(self.dropout(conv1d_out))
|
||||
right_beliefs = self.right(self.dropout(conv1d_out))
|
||||
|
||||
return left_beliefs, right_beliefs, None
|
||||
|
||||
|
||||
class ResNetLSTM(nn.Module):
|
||||
def __init__(self, input_dim, device):
|
||||
super(ResNetLSTM, self).__init__()
|
||||
resnet = models.resnet34(pretrained=True)
|
||||
self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
|
||||
self.lstm = nn.LSTM(input_size=input_dim, hidden_size=512, batch_first=True)
|
||||
# FFs
|
||||
self.left = nn.Linear(512, 27)
|
||||
self.right = nn.Linear(512, 27)
|
||||
self.left.bias.data = torch.tensor(left_bias).log()
|
||||
self.right.bias.data = torch.tensor(right_bias).log()
|
||||
# modality FFs
|
||||
self.pose_ff = nn.Linear(150, 150)
|
||||
self.gaze_ff = nn.Linear(6, 6)
|
||||
self.bbox_ff = nn.Linear(108, 108)
|
||||
self.ocr_ff = nn.Linear(729, 64)
|
||||
# others
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(p=0.2)
|
||||
self.device = device
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor):
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
left_beliefs = []
|
||||
right_beliefs = []
|
||||
rnn_inp = []
|
||||
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i]
|
||||
rnn_i_feat = self.resnet(images_i)
|
||||
rnn_i_feat = rnn_i_feat.view(batch_size, 512)
|
||||
if poses is not None:
|
||||
poses_i = poses[:,i].float()
|
||||
poses_i_feat = self.relu(self.pose_ff(poses_i))
|
||||
rnn_i_feat = torch.cat([rnn_i_feat, poses_i_feat], 1)
|
||||
if gazes is not None:
|
||||
gazes_i = gazes[:,i].float()
|
||||
gazes_i_feat = self.relu(self.gaze_ff(gazes_i))
|
||||
rnn_i_feat = torch.cat([rnn_i_feat, gazes_i_feat], 1)
|
||||
if bboxes is not None:
|
||||
bboxes_i = bboxes[:,i].float()
|
||||
bboxes_i_feat = self.relu(self.bbox_ff(bboxes_i))
|
||||
rnn_i_feat = torch.cat([rnn_i_feat, bboxes_i_feat], 1)
|
||||
if ocr_tensor is not None:
|
||||
ocr_tensor = ocr_tensor
|
||||
ocr_tensor_feat = self.relu(self.ocr_ff(ocr_tensor))
|
||||
rnn_i_feat = torch.cat([rnn_i_feat, ocr_tensor_feat], 1)
|
||||
rnn_inp.append(rnn_i_feat)
|
||||
|
||||
rnn_inp = torch.permute(torch.stack(rnn_inp), (1,0,2))
|
||||
rnn_out, _ = self.lstm(rnn_inp)
|
||||
left_beliefs = self.left(self.dropout(rnn_out))
|
||||
right_beliefs = self.right(self.dropout(rnn_out))
|
||||
|
||||
return left_beliefs, right_beliefs, None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
images = torch.ones(3, 22, 3, 128, 128)
|
||||
poses = torch.ones(3, 22, 150)
|
||||
gazes = torch.ones(3, 22, 6)
|
||||
bboxes = torch.ones(3, 22, 108)
|
||||
model = ResNet(32, 'cpu')
|
||||
print(model(images, poses, gazes, bboxes, None)[0].shape)
|
95
boss/models/single_mindnet.py
Normal file
95
boss/models/single_mindnet.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch_geometric.nn.conv import GCNConv
|
||||
from .utils import left_bias, right_bias, build_ocr_graph
|
||||
from .base import CNN, MindNetLSTM
|
||||
import torchvision.models as models
|
||||
|
||||
|
||||
class SingleMindNet(nn.Module):
|
||||
"""
|
||||
Base ToM net. Supports any subset of modalities
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
|
||||
super(SingleMindNet, self).__init__()
|
||||
|
||||
# ---- Images ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
# ---- OCR and bbox -----#
|
||||
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
|
||||
self.ocr_gnn = GCNConv(-1, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(108, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net = MindNetLSTM(hidden_dim, dropout, mods)
|
||||
|
||||
self.left = nn.Linear(hidden_dim, 27)
|
||||
self.right = nn.Linear(hidden_dim, 27)
|
||||
self.left.bias.data = torch.tensor(left_bias).log()
|
||||
self.right.bias.data = torch.tensor(right_bias).log()
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
|
||||
|
||||
assert images is not None
|
||||
assert poses is not None
|
||||
assert gazes is not None
|
||||
assert bboxes is not None
|
||||
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
|
||||
bbox_feat = self.act(self.bbox_ff(bboxes))
|
||||
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
|
||||
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
|
||||
self.ocr_edge_index,
|
||||
self.ocr_edge_attr)))
|
||||
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
|
||||
|
||||
out_left, cell_left, feats_left = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
|
||||
out_right, cell_right, feats_right = self.mind_net(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
|
||||
|
||||
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def count_parameters(model):
|
||||
import numpy as np
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
return sum([np.prod(p.size()) for p in model_parameters])
|
||||
|
||||
images = torch.ones(3, 22, 3, 128, 128)
|
||||
poses = torch.ones(3, 22, 2, 75)
|
||||
gazes = torch.ones(3, 22, 2, 3)
|
||||
bboxes = torch.ones(3, 22, 108)
|
||||
model = SingleMindNet(64, 'cpu', False, 0.5)
|
||||
print(count_parameters(model))
|
||||
breakpoint()
|
||||
out = model(images, poses, gazes, bboxes, None)
|
||||
print(out[0].shape)
|
104
boss/models/tom_base.py
Normal file
104
boss/models/tom_base.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from torch_geometric.nn.conv import GCNConv
|
||||
from .utils import left_bias, right_bias, build_ocr_graph
|
||||
from .base import CNN, MindNetLSTM
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseToMnet(nn.Module):
|
||||
"""
|
||||
Base ToM net. Supports any subset of modalities
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
|
||||
super(BaseToMnet, self).__init__()
|
||||
|
||||
# ---- Images ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
# ---- OCR and bbox -----#
|
||||
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
|
||||
self.ocr_gnn = GCNConv(-1, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(108, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
|
||||
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
|
||||
|
||||
self.left = nn.Linear(hidden_dim, 27)
|
||||
self.right = nn.Linear(hidden_dim, 27)
|
||||
self.left.bias.data = torch.tensor(left_bias).log()
|
||||
self.right.bias.data = torch.tensor(right_bias).log()
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
|
||||
|
||||
assert images is not None
|
||||
assert poses is not None
|
||||
assert gazes is not None
|
||||
assert bboxes is not None
|
||||
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
|
||||
bbox_feat = self.act(self.bbox_ff(bboxes))
|
||||
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
|
||||
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
|
||||
self.ocr_edge_index,
|
||||
self.ocr_edge_attr)))
|
||||
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
|
||||
|
||||
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
|
||||
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
|
||||
|
||||
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
|
||||
|
||||
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
return sum([np.prod(p.size()) for p in model_parameters])
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
images = torch.ones(3, 22, 3, 128, 128)
|
||||
poses = torch.ones(3, 22, 2, 75)
|
||||
gazes = torch.ones(3, 22, 2, 3)
|
||||
bboxes = torch.ones(3, 22, 108)
|
||||
model = BaseToMnet(64, 'cpu', False, 0.5)
|
||||
print(count_parameters(model))
|
||||
breakpoint()
|
||||
out = model(images, poses, gazes, bboxes, None)
|
||||
print(out[0].shape)
|
269
boss/models/tom_common_mind.py
Normal file
269
boss/models/tom_common_mind.py
Normal file
|
@ -0,0 +1,269 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from torch_geometric.nn.conv import GCNConv
|
||||
from .utils import left_bias, right_bias, build_ocr_graph
|
||||
from .base import CNN, MindNetLSTM, MindNetLSTMXL
|
||||
from memory_efficient_attention_pytorch import Attention
|
||||
|
||||
|
||||
class CommonMindToMnet(nn.Module):
|
||||
"""
|
||||
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
|
||||
super(CommonMindToMnet, self).__init__()
|
||||
|
||||
self.aggr = aggr
|
||||
|
||||
# ---- Images ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
#for param in self.cnn.parameters():
|
||||
# param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
# ---- OCR and bbox -----#
|
||||
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
|
||||
self.ocr_gnn = GCNConv(-1, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(108, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
|
||||
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
|
||||
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.ln_left = nn.LayerNorm(hidden_dim)
|
||||
self.ln_right = nn.LayerNorm(hidden_dim)
|
||||
if aggr == 'attn':
|
||||
self.attn_left = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.attn_right = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.left = nn.Linear(hidden_dim, 27)
|
||||
self.right = nn.Linear(hidden_dim, 27)
|
||||
self.left.bias.data = torch.tensor(left_bias).log()
|
||||
self.right.bias.data = torch.tensor(right_bias).log()
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
|
||||
|
||||
assert images is not None
|
||||
assert poses is not None
|
||||
assert gazes is not None
|
||||
assert bboxes is not None
|
||||
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
|
||||
bbox_feat = self.act(self.bbox_ff(bboxes))
|
||||
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
|
||||
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
|
||||
self.ocr_edge_index,
|
||||
self.ocr_edge_attr)))
|
||||
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
|
||||
|
||||
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
|
||||
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
|
||||
|
||||
common_mind = self.proj(torch.cat([cell_left, cell_right], -1))
|
||||
|
||||
if self.aggr == 'no_tom':
|
||||
return self.left(out_left), self.right(out_right)
|
||||
|
||||
if self.aggr == 'attn':
|
||||
l = self.attn_left(x=out_left, context=common_mind)
|
||||
r = self.attn_right(x=out_right, context=common_mind)
|
||||
elif self.aggr == 'mult':
|
||||
l = out_left * common_mind
|
||||
r = out_right * common_mind
|
||||
elif self.aggr == 'sum':
|
||||
l = out_left + common_mind
|
||||
r = out_right + common_mind
|
||||
elif self.aggr == 'concat':
|
||||
l = torch.cat([out_left, common_mind], 1)
|
||||
r = torch.cat([out_right, common_mind], 1)
|
||||
else: raise ValueError
|
||||
l = self.act(l)
|
||||
l = self.ln_left(l)
|
||||
r = self.act(r)
|
||||
r = self.ln_right(r)
|
||||
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
|
||||
left_beliefs = self.left(l)
|
||||
right_beliefs = self.right(r)
|
||||
if self.aggr == 'concat':
|
||||
left_beliefs = self.left(l)[:, :-1, :]
|
||||
right_beliefs = self.right(r)[:, :-1, :]
|
||||
|
||||
return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
|
||||
|
||||
|
||||
|
||||
class CommonMindToMnetXL(nn.Module):
|
||||
"""
|
||||
XL model.
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
|
||||
super(CommonMindToMnetXL, self).__init__()
|
||||
|
||||
self.aggr = aggr
|
||||
|
||||
# ---- Images ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
# ---- OCR and bbox -----#
|
||||
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
|
||||
self.ocr_gnn = GCNConv(-1, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(108, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_left = MindNetLSTMXL(hidden_dim, dropout, mods)
|
||||
self.mind_net_right = MindNetLSTMXL(hidden_dim, dropout, mods)
|
||||
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.ln_left = nn.LayerNorm(hidden_dim)
|
||||
self.ln_right = nn.LayerNorm(hidden_dim)
|
||||
if aggr == 'attn':
|
||||
self.attn_left = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.attn_right = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.left = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, 27),
|
||||
)
|
||||
self.right = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, 27)
|
||||
)
|
||||
self.left[-1].bias.data = torch.tensor(left_bias).log()
|
||||
self.right[-1].bias.data = torch.tensor(right_bias).log()
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
|
||||
|
||||
assert images is not None
|
||||
assert poses is not None
|
||||
assert gazes is not None
|
||||
assert bboxes is not None
|
||||
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
|
||||
bbox_feat = self.act(self.bbox_ff(bboxes))
|
||||
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
|
||||
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
|
||||
self.ocr_edge_index,
|
||||
self.ocr_edge_attr)))
|
||||
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
|
||||
|
||||
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
|
||||
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
|
||||
|
||||
common_mind = self.proj(torch.cat([cell_left, cell_right], -1))
|
||||
|
||||
if self.aggr == 'no_tom':
|
||||
return self.left(out_left), self.right(out_right)
|
||||
|
||||
if self.aggr == 'attn':
|
||||
l = self.attn_left(x=out_left, context=common_mind)
|
||||
r = self.attn_right(x=out_right, context=common_mind)
|
||||
elif self.aggr == 'mult':
|
||||
l = out_left * common_mind
|
||||
r = out_right * common_mind
|
||||
elif self.aggr == 'sum':
|
||||
l = out_left + common_mind
|
||||
r = out_right + common_mind
|
||||
elif self.aggr == 'concat':
|
||||
l = torch.cat([out_left, common_mind], 1)
|
||||
r = torch.cat([out_right, common_mind], 1)
|
||||
else: raise ValueError
|
||||
l = self.act(l)
|
||||
l = self.ln_left(l)
|
||||
r = self.act(r)
|
||||
r = self.ln_right(r)
|
||||
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
|
||||
left_beliefs = self.left(l)
|
||||
right_beliefs = self.right(r)
|
||||
if self.aggr == 'concat':
|
||||
left_beliefs = self.left(l)[:, :-1, :]
|
||||
right_beliefs = self.right(r)[:, :-1, :]
|
||||
|
||||
return left_beliefs, right_beliefs, [out_left, out_right, common_mind] + feats_left + feats_right
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
images = torch.ones(3, 22, 3, 128, 128)
|
||||
poses = torch.ones(3, 22, 2, 75)
|
||||
gazes = torch.ones(3, 22, 2, 3)
|
||||
bboxes = torch.ones(3, 22, 108)
|
||||
model = CommonMindToMnetXL(64, 'cpu', False, 0.5, aggr='attn')
|
||||
out = model(images, poses, gazes, bboxes, None)
|
||||
print(out[0].shape)
|
144
boss/models/tom_implicit.py
Normal file
144
boss/models/tom_implicit.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from torch_geometric.nn.conv import GCNConv
|
||||
from .utils import left_bias, right_bias, build_ocr_graph
|
||||
from .base import CNN, MindNetLSTM
|
||||
from memory_efficient_attention_pytorch import Attention
|
||||
|
||||
|
||||
class ImplicitToMnet(nn.Module):
|
||||
"""
|
||||
Implicit ToM net. Supports any subset of modalities
|
||||
Possible aggregations: sum, mult, attn, concat
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
|
||||
super(ImplicitToMnet, self).__init__()
|
||||
|
||||
self.aggr = aggr
|
||||
|
||||
# ---- Images ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
# ---- OCR and bbox -----#
|
||||
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
|
||||
self.ocr_gnn = GCNConv(-1, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(108, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_left = MindNetLSTM(hidden_dim, dropout, mods)
|
||||
self.mind_net_right = MindNetLSTM(hidden_dim, dropout, mods)
|
||||
self.ln_left = nn.LayerNorm(hidden_dim)
|
||||
self.ln_right = nn.LayerNorm(hidden_dim)
|
||||
if aggr == 'attn':
|
||||
self.attn_left = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.attn_right = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.left = nn.Linear(hidden_dim, 27)
|
||||
self.right = nn.Linear(hidden_dim, 27)
|
||||
self.left.bias.data = torch.tensor(left_bias).log()
|
||||
self.right.bias.data = torch.tensor(right_bias).log()
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
|
||||
|
||||
assert images is not None
|
||||
assert poses is not None
|
||||
assert gazes is not None
|
||||
assert bboxes is not None
|
||||
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
|
||||
bbox_feat = self.act(self.bbox_ff(bboxes))
|
||||
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
|
||||
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
|
||||
self.ocr_edge_index,
|
||||
self.ocr_edge_attr)))
|
||||
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
|
||||
|
||||
out_left, cell_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
|
||||
out_right, cell_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
|
||||
|
||||
if self.aggr == 'no_tom':
|
||||
return self.left(out_left), self.right(out_right), [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
|
||||
|
||||
if self.aggr == 'attn':
|
||||
l = self.attn_left(x=out_left, context=cell_right)
|
||||
r = self.attn_right(x=out_right, context=cell_left)
|
||||
elif self.aggr == 'mult':
|
||||
l = out_left * cell_right
|
||||
r = out_right * cell_left
|
||||
elif self.aggr == 'sum':
|
||||
l = out_left + cell_right
|
||||
r = out_right + cell_left
|
||||
elif self.aggr == 'concat':
|
||||
l = torch.cat([out_left, cell_right], 1)
|
||||
r = torch.cat([out_right, cell_left], 1)
|
||||
else: raise ValueError
|
||||
l = self.act(l)
|
||||
l = self.ln_left(l)
|
||||
r = self.act(r)
|
||||
r = self.ln_right(r)
|
||||
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
|
||||
left_beliefs = self.left(l)
|
||||
right_beliefs = self.right(r)
|
||||
if self.aggr == 'concat':
|
||||
left_beliefs = self.left(l)[:, :-1, :]
|
||||
right_beliefs = self.right(r)[:, :-1, :]
|
||||
|
||||
return left_beliefs, right_beliefs, [out_left, cell_left, out_right, cell_right] + feats_left + feats_right
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
images = torch.ones(3, 22, 3, 128, 128)
|
||||
poses = torch.ones(3, 22, 2, 75)
|
||||
gazes = torch.ones(3, 22, 2, 3)
|
||||
bboxes = torch.ones(3, 22, 108)
|
||||
model = ImplicitToMnet(64, 'cpu', False, 0.5, aggr='attn')
|
||||
out = model(images, poses, gazes, bboxes, None)
|
||||
print(out[0].shape)
|
98
boss/models/tom_sl.py
Normal file
98
boss/models/tom_sl.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from torch_geometric.nn.conv import GCNConv
|
||||
from .utils import left_bias, right_bias, build_ocr_graph, pose_edge_index
|
||||
from .base import CNN, MindNetSL
|
||||
|
||||
|
||||
class SLToMnet(nn.Module):
|
||||
"""
|
||||
Speaker-Listener ToMnet
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
|
||||
super(SLToMnet, self).__init__()
|
||||
|
||||
self.tom_weight = tom_weight
|
||||
|
||||
# ---- Images ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
# ---- OCR and bbox -----#
|
||||
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
|
||||
self.ocr_gnn = GCNConv(-1, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(108, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_left = MindNetSL(hidden_dim, dropout, mods)
|
||||
self.mind_net_right = MindNetSL(hidden_dim, dropout, mods)
|
||||
self.left = nn.Linear(hidden_dim, 27)
|
||||
self.right = nn.Linear(hidden_dim, 27)
|
||||
self.left.bias.data = torch.tensor(left_bias).log()
|
||||
self.right.bias.data = torch.tensor(right_bias).log()
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
|
||||
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
|
||||
bbox_feat = self.act(self.bbox_ff(bboxes))
|
||||
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
|
||||
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
|
||||
self.ocr_edge_index,
|
||||
self.ocr_edge_attr)))
|
||||
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
|
||||
|
||||
out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
|
||||
left_logits = self.left(out_left)
|
||||
out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
|
||||
right_logits = self.left(out_right)
|
||||
|
||||
left_ranking = torch.log_softmax(left_logits, dim=-1)
|
||||
right_ranking = torch.log_softmax(right_logits, dim=-1)
|
||||
|
||||
right_beliefs = left_ranking + self.tom_weight * right_ranking
|
||||
left_beliefs = right_ranking + self.tom_weight * left_ranking
|
||||
|
||||
return left_beliefs, right_beliefs, [out_left, out_right] + feats_left + feats_right
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
images = torch.ones(3, 22, 3, 128, 128)
|
||||
poses = torch.ones(3, 22, 2, 75)
|
||||
gazes = torch.ones(3, 22, 2, 3)
|
||||
bboxes = torch.ones(3, 22, 108)
|
||||
model = SLToMnet(64, 'cpu', 2.0, False, 0.5)
|
||||
out = model(images, poses, gazes, bboxes, None)
|
||||
print(out[0].shape)
|
104
boss/models/tom_tf.py
Normal file
104
boss/models/tom_tf.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from torch_geometric.nn.conv import GCNConv
|
||||
from .utils import left_bias, right_bias, build_ocr_graph
|
||||
from .base import CNN, MindNetTF
|
||||
from memory_efficient_attention_pytorch import Attention
|
||||
|
||||
|
||||
class TFToMnet(nn.Module):
|
||||
"""
|
||||
Implicit ToM net. Supports any subset of modalities
|
||||
Possible aggregations: sum, mult, attn, concat
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb', 'pose', 'gaze', 'ocr', 'bbox']):
|
||||
super(TFToMnet, self).__init__()
|
||||
|
||||
# ---- Images ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
# ---- OCR and bbox -----#
|
||||
self.ocr_x, self.ocr_edge_index, self.ocr_edge_attr = build_ocr_graph(device)
|
||||
self.ocr_gnn = GCNConv(-1, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(108, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_left = MindNetTF(hidden_dim, dropout, mods)
|
||||
self.mind_net_right = MindNetTF(hidden_dim, dropout, mods)
|
||||
self.left = nn.Linear(hidden_dim, 27)
|
||||
self.right = nn.Linear(hidden_dim, 27)
|
||||
self.left.bias.data = torch.tensor(left_bias).log()
|
||||
self.right.bias.data = torch.tensor(right_bias).log()
|
||||
|
||||
def forward(self, images, poses, gazes, bboxes, ocr_tensor=None):
|
||||
|
||||
assert images is not None
|
||||
assert poses is not None
|
||||
assert gazes is not None
|
||||
assert bboxes is not None
|
||||
|
||||
batch_size, sequence_len, channels, height, width = images.shape
|
||||
|
||||
bbox_feat = self.act(self.bbox_ff(bboxes))
|
||||
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = images[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
|
||||
ocr_feat = self.dropout(self.act(self.ocr_gnn(self.ocr_x,
|
||||
self.ocr_edge_index,
|
||||
self.ocr_edge_attr)))
|
||||
ocr_feat = ocr_feat.mean(0).unsqueeze(0).repeat(batch_size, sequence_len, 1)
|
||||
|
||||
out_left, feats_left = self.mind_net_left(rgb_feat, ocr_feat, poses[:, :, 0, :], gazes[:, :, 0, :], bbox_feat)
|
||||
out_right, feats_right = self.mind_net_right(rgb_feat, ocr_feat, poses[:, :, 1, :], gazes[:, :, 1, :], bbox_feat)
|
||||
|
||||
l = self.dropout(self.act(out_left))
|
||||
r = self.dropout(self.act(out_right))
|
||||
left_beliefs = self.left(l)
|
||||
right_beliefs = self.right(r)
|
||||
|
||||
return left_beliefs, right_beliefs, feats_left + feats_right
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
images = torch.ones(3, 22, 3, 128, 128)
|
||||
poses = torch.ones(3, 22, 2, 75)
|
||||
gazes = torch.ones(3, 22, 2, 3)
|
||||
bboxes = torch.ones(3, 22, 108)
|
||||
model = TFToMnet(64, 'cpu', False, 0.5)
|
||||
out = model(images, poses, gazes, bboxes, None)
|
||||
print(out[0].shape)
|
||||
|
95
boss/models/utils.py
Normal file
95
boss/models/utils.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
import torch
|
||||
|
||||
left_bias = [0.10659290976303822,
|
||||
0.025158905348262015,
|
||||
0.02811095449589107,
|
||||
0.026342384511050237,
|
||||
0.025318475572458178,
|
||||
0.02283183957873461,
|
||||
0.021581872822531316,
|
||||
0.08062285577511237,
|
||||
0.03824366373234754,
|
||||
0.04853594319300018,
|
||||
0.09653998563867983,
|
||||
0.02961357410707162,
|
||||
0.02961357410707162,
|
||||
0.03172787957767081,
|
||||
0.029985904630196004,
|
||||
0.02897529321028696,
|
||||
0.06602218026116327,
|
||||
0.015345336560197867,
|
||||
0.026900880295736816,
|
||||
0.024879657455918726,
|
||||
0.028669450280577644,
|
||||
0.01936118720246802,
|
||||
0.02341693040078721,
|
||||
0.014707055663413206,
|
||||
0.027007260445200926,
|
||||
0.04146166325363687,
|
||||
0.04243238211749688]
|
||||
|
||||
right_bias = [0.13147256721895695,
|
||||
0.012433179968617855,
|
||||
0.01623627031195979,
|
||||
0.013683146724821148,
|
||||
0.015252253929416771,
|
||||
0.012579452674131008,
|
||||
0.03127576394244834,
|
||||
0.10325523257360177,
|
||||
0.041155820323927554,
|
||||
0.06563655221935587,
|
||||
0.12684503071726816,
|
||||
0.016156485199861705,
|
||||
0.0176989973670913,
|
||||
0.020238823435546928,
|
||||
0.01918831945958884,
|
||||
0.01791175766601952,
|
||||
0.08768383819579266,
|
||||
0.019002154198026647,
|
||||
0.029600276588388607,
|
||||
0.01578415467673732,
|
||||
0.0176989973670913,
|
||||
0.011834791627882237,
|
||||
0.014919815962341426,
|
||||
0.007552990611951809,
|
||||
0.029759846812584773,
|
||||
0.04981250498656951,
|
||||
0.05533097524002021]
|
||||
|
||||
def build_ocr_graph(device):
|
||||
ocr_graph = [
|
||||
[15, [10, 4], [17, 2]],
|
||||
[13, [16, 7], [18, 4]],
|
||||
[11, [16, 4], [7, 10]],
|
||||
[14, [10, 11], [7, 1]],
|
||||
[12, [10, 9], [16, 3]],
|
||||
[1, [7, 2], [9, 9], [10, 2]],
|
||||
[5, [8, 8], [6, 8]],
|
||||
[4, [9, 8], [7, 6]],
|
||||
[3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
|
||||
[2, [10, 1], [7, 7], [9, 3]],
|
||||
[19, [10, 2], [26, 6]],
|
||||
[20, [10, 7], [26, 5]],
|
||||
[22, [25, 4], [10, 8]],
|
||||
[23, [25, 15]],
|
||||
[21, [16, 5], [24, 8]]
|
||||
]
|
||||
edge_index = []
|
||||
edge_attr = []
|
||||
for i in range(len(ocr_graph)):
|
||||
for j in range(1, len(ocr_graph[i])):
|
||||
source_node = ocr_graph[i][0]
|
||||
target_node = ocr_graph[i][j][0]
|
||||
edge_index.append([source_node, target_node])
|
||||
edge_attr.append(ocr_graph[i][j][1])
|
||||
ocr_edge_index = torch.tensor(edge_index).t().long()
|
||||
ocr_edge_attr = torch.tensor(edge_attr).to(torch.float).unsqueeze(1)
|
||||
x = torch.arange(0, 27)
|
||||
ocr_x = torch.nn.functional.one_hot(x, num_classes=27).to(torch.float)
|
||||
return ocr_x.to(device), ocr_edge_index.to(device), ocr_edge_attr.to(device)
|
||||
|
||||
def pose_edge_index():
|
||||
return torch.tensor(
|
||||
[[17, 15, 15, 0, 0, 16, 16, 18, 0, 1, 4, 3, 3, 2, 2, 1, 1, 5, 5, 6, 6, 7, 1, 8, 8, 9, 9, 10, 10, 11, 11, 24, 11, 23, 23, 22, 8, 12, 12, 13, 13, 14, 14, 21, 14, 19, 19, 20],
|
||||
[15, 17, 0, 15, 16, 0, 18, 16, 1, 0, 3, 4, 2, 3, 1, 2, 5, 1, 6, 5, 7, 6, 8, 1, 9, 8, 10, 9, 11, 10, 24, 11, 23, 11, 22, 23, 12, 8, 13, 12, 14, 13, 21, 14, 19, 14, 20, 19]],
|
||||
dtype=torch.long)
|
BIN
boss/new_bbox/test_bbox.tar.gz
Normal file
BIN
boss/new_bbox/test_bbox.tar.gz
Normal file
Binary file not shown.
BIN
boss/new_bbox/train_bbox.tar.gz
Normal file
BIN
boss/new_bbox/train_bbox.tar.gz
Normal file
Binary file not shown.
BIN
boss/new_bbox/val_bbox.tar.gz
Normal file
BIN
boss/new_bbox/val_bbox.tar.gz
Normal file
Binary file not shown.
BIN
boss/outfile
Normal file
BIN
boss/outfile
Normal file
Binary file not shown.
82
boss/plots/old_vs_new_bbox.py
Normal file
82
boss/plots/old_vs_new_bbox.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import json
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
sns.set_theme(style='whitegrid')
|
||||
|
||||
COLORS = sns.color_palette()
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json'))
|
||||
|
||||
abl_old_bbox_mean = [0.539406718, 0.5348262324, 0.529845863]
|
||||
abl_old_bbox_std = [0.03639921819, 0.01519544901, 0.01718265794]
|
||||
|
||||
filename = 'results/abl_cm_concat_old_vs_new_bbox'
|
||||
|
||||
#def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
|
||||
|
||||
means = []
|
||||
stds = []
|
||||
for key, values in abl_tom_cm_concat.items():
|
||||
mean = np.mean(values)
|
||||
std = np.std(values)
|
||||
means.append(mean)
|
||||
stds.append(std)
|
||||
fig, ax = plt.subplots(figsize=(10,4))
|
||||
x = np.arange(len(abl_tom_cm_concat))
|
||||
width = 0.6
|
||||
rects1 = ax.bar(
|
||||
x, means, width, label='New bbox', yerr=stds,
|
||||
capsize=5,
|
||||
color=[MTOM_COLORS['CG']]*8,
|
||||
edgecolor='black',
|
||||
linewidth=1.5,
|
||||
alpha=0.6
|
||||
)
|
||||
rects2 = ax.bar(
|
||||
[0, 2, 5], abl_old_bbox_mean, width, label='Original bbox', yerr=abl_old_bbox_std,
|
||||
capsize=5,
|
||||
color=[COLORS[9]]*8,
|
||||
edgecolor='black',
|
||||
linewidth=1.5,
|
||||
alpha=0.6
|
||||
)
|
||||
ax.set_ylabel('Accuracy', fontsize=18)
|
||||
xticklabels = list(abl_tom_cm_concat.keys())
|
||||
ax.set_xticks(np.arange(len(xticklabels)))
|
||||
ax.set_xticklabels(xticklabels, rotation=0, fontsize=16)
|
||||
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
|
||||
# Add value labels above each bar, avoiding overlapping with error bars
|
||||
for rect, std in zip(rects1, stds):
|
||||
height = rect.get_height()
|
||||
if height + std < ax.get_ylim()[1]:
|
||||
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
|
||||
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
|
||||
else:
|
||||
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
|
||||
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
|
||||
for rect, std in zip(rects2, abl_old_bbox_std):
|
||||
height = rect.get_height()
|
||||
if height + std < ax.get_ylim()[1]:
|
||||
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
|
||||
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
|
||||
else:
|
||||
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
|
||||
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
|
||||
# Remove spines
|
||||
ax.spines['top'].set_visible(False)
|
||||
ax.spines['right'].set_visible(False)
|
||||
ax.legend(fontsize=14)
|
||||
ax.grid(axis='x')
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'{filename}.pdf', bbox_inches='tight')
|
82
boss/plots/pca.py
Normal file
82
boss/plots/pca.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import torch
|
||||
import os
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
|
||||
FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-22_12-00-38_train_None" # no_tom seed 1
|
||||
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-39-41_train_None" # impl mult seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-47-07_train_None" # impl sum seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_15-58-44_train_None" # impl attn seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-58-04_train_None" # impl concat seed 1
|
||||
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-40-01_train_None" # cm mult seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-45-55_train_None" # cm sum seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-50-42_train_None" # cm attn seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-57-15_train_None" # cm concat seed 1
|
||||
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-37-50_train_None" # db seed 1
|
||||
|
||||
print(FOLDER_PATH)
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
sns.set_theme(style='white')
|
||||
|
||||
for i in range(60):
|
||||
|
||||
print(f'Computing analysis for test video {i}...', end='\r')
|
||||
|
||||
emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
|
||||
data = torch.load(emb_file)
|
||||
if len(data) == 14: # implicit
|
||||
model = 'impl'
|
||||
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
|
||||
out_left = out_left.squeeze(0)
|
||||
cell_left = cell_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
cell_right = cell_right.squeeze(0)
|
||||
elif len(data) == 13: # common mind
|
||||
model = 'cm'
|
||||
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
|
||||
out_left = out_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
common_mind = common_mind.squeeze(0)
|
||||
elif len(data) == 12: # speaker-listener
|
||||
model = 'sl'
|
||||
out_left, out_right, feats = data[0], data[1], data[2:]
|
||||
out_left = out_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
|
||||
|
||||
# ====== PCA for left and right embeddings ====== #
|
||||
|
||||
out_left_and_right = np.concatenate((out_left, out_right), axis=0)
|
||||
|
||||
pca = PCA(n_components=2)
|
||||
pca_result = pca.fit_transform(out_left_and_right)
|
||||
|
||||
# Separate the PCA results for each tensor
|
||||
pca_result_left = pca_result[:out_left.shape[0]]
|
||||
pca_result_right = pca_result[out_right.shape[0]:]
|
||||
|
||||
plt.figure(figsize=(7,6))
|
||||
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
|
||||
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
|
||||
plt.xlabel('Principal Component 1', fontsize=30)
|
||||
plt.ylabel('Principal Component 2', fontsize=30)
|
||||
plt.grid(False)
|
||||
plt.legend(fontsize=30)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
|
||||
plt.close()
|
42
boss/results/abl_cm_concat.json
Normal file
42
boss/results/abl_cm_concat.json
Normal file
|
@ -0,0 +1,42 @@
|
|||
{
|
||||
"all": [
|
||||
0.7402937327,
|
||||
0.7171004799,
|
||||
0.7305511124
|
||||
],
|
||||
"rgb\npose\ngaze": [
|
||||
0.4736803839,
|
||||
0.4658281227,
|
||||
0.4948378653
|
||||
],
|
||||
"rgb\nocr\nbbox": [
|
||||
0.7364403083,
|
||||
0.7407663225,
|
||||
0.7321506471
|
||||
],
|
||||
"rgb\ngaze": [
|
||||
0.5013814163,
|
||||
0.418859968,
|
||||
0.4644103534
|
||||
],
|
||||
"rgb\npose": [
|
||||
0.4545586738,
|
||||
0.4584484514,
|
||||
0.4525229024
|
||||
],
|
||||
"rgb\nbbox": [
|
||||
0.716591537,
|
||||
0.7334593573,
|
||||
0.758324851
|
||||
],
|
||||
"rgb\nocr": [
|
||||
0.4581939799,
|
||||
0.456449033,
|
||||
0.4577577432
|
||||
],
|
||||
"rgb": [
|
||||
0.4896393776,
|
||||
0.4907299695,
|
||||
0.4583757452
|
||||
]
|
||||
}
|
BIN
boss/results/abl_cm_concat.pdf
Normal file
BIN
boss/results/abl_cm_concat.pdf
Normal file
Binary file not shown.
72
boss/results/all.json
Normal file
72
boss/results/all.json
Normal file
|
@ -0,0 +1,72 @@
|
|||
{
|
||||
"no_tom": [
|
||||
0.6504653192,
|
||||
0.6404682274,
|
||||
0.6666787844
|
||||
],
|
||||
"sl": [
|
||||
0.6584993456,
|
||||
0.6608259415,
|
||||
0.6584629926
|
||||
],
|
||||
"cm mult": [
|
||||
0.6739130435,
|
||||
0.6872182638,
|
||||
0.6591173477
|
||||
],
|
||||
"cm sum": [
|
||||
0.5820852116,
|
||||
0.6401774029,
|
||||
0.6001163298
|
||||
],
|
||||
"cm attn": [
|
||||
0.3240511851,
|
||||
0.2688672386,
|
||||
0.3028573506
|
||||
],
|
||||
"cm concat": [
|
||||
0.7402937327,
|
||||
0.7171004799,
|
||||
0.7305511124
|
||||
],
|
||||
"impl mult": [
|
||||
0.7119746983,
|
||||
0.7019776065,
|
||||
0.7244437982
|
||||
],
|
||||
"impl sum": [
|
||||
0.6542460375,
|
||||
0.6642431293,
|
||||
0.6678420823
|
||||
],
|
||||
"impl attn": [
|
||||
0.3311763851,
|
||||
0.3125636179,
|
||||
0.3112549077
|
||||
],
|
||||
"impl concat": [
|
||||
0.6957975862,
|
||||
0.7227352043,
|
||||
0.6971062964
|
||||
],
|
||||
"resnet": [
|
||||
0.467173186,
|
||||
0.4267485822,
|
||||
0.4195870292
|
||||
],
|
||||
"gru": [
|
||||
0.5724152974,
|
||||
0.525628908,
|
||||
0.4825141777
|
||||
],
|
||||
"lstm": [
|
||||
0.6210193398,
|
||||
0.5547477098,
|
||||
0.6572633416
|
||||
],
|
||||
"conv1d": [
|
||||
0.4478333576,
|
||||
0.4114802966,
|
||||
0.3853060928
|
||||
]
|
||||
}
|
BIN
boss/results/all.pdf
Normal file
BIN
boss/results/all.pdf
Normal file
Binary file not shown.
181
boss/test.py
Normal file
181
boss/test.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
import argparse
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import csv
|
||||
import os
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataloader import DataTest
|
||||
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
|
||||
from models.tom_implicit import ImplicitToMnet
|
||||
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
|
||||
from models.tom_sl import SLToMnet
|
||||
from models.single_mindnet import SingleMindNet
|
||||
from utils import tried_once, tried_twice, tried_thrice, friends, strangers, get_classification_accuracy, pad_collate, get_input_dim
|
||||
|
||||
|
||||
def test(args):
|
||||
|
||||
if args.test_frames == 'friends':
|
||||
test_frame_ids = friends
|
||||
elif args.test_frames == 'strangers':
|
||||
test_frame_ids = strangers
|
||||
elif args.test_frames == 'once':
|
||||
test_frame_ids = tried_once
|
||||
elif args.test_frames == 'twice_thrice':
|
||||
test_frame_ids = tried_twice + tried_thrice
|
||||
elif args.test_frames is None:
|
||||
test_frame_ids = None
|
||||
else: raise NameError
|
||||
|
||||
if args.median is not None:
|
||||
median = (240, False)
|
||||
else:
|
||||
median = None
|
||||
|
||||
if args.model_type == 'tom_cm' or args.model_type == 'tom_sl' or args.model_type == 'tom_impl' or args.model_type == 'tom_cm_xl' or args.model_type == 'tom_single':
|
||||
flatten_dim = 2
|
||||
else:
|
||||
flatten_dim = 1
|
||||
|
||||
# load datasets
|
||||
test_dataset = DataTest(
|
||||
args.test_frame_path,
|
||||
args.label_path,
|
||||
args.test_pose_path,
|
||||
args.test_gaze_path,
|
||||
args.test_bbox_path,
|
||||
args.ocr_graph_path,
|
||||
args.presaved,
|
||||
test_frame_ids,
|
||||
median,
|
||||
flatten_dim=flatten_dim
|
||||
)
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
collate_fn=pad_collate,
|
||||
pin_memory=args.pin_memory
|
||||
)
|
||||
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
|
||||
assert args.load_model_path is not None
|
||||
if args.model_type == 'resnet':
|
||||
inp_dim = get_input_dim(args.mods)
|
||||
model = ResNet(inp_dim, device).to(device)
|
||||
elif args.model_type == 'gru':
|
||||
inp_dim = get_input_dim(args.mods)
|
||||
model = ResNetGRU(inp_dim, device).to(device)
|
||||
elif args.model_type == 'lstm':
|
||||
inp_dim = get_input_dim(args.mods)
|
||||
model = ResNetLSTM(inp_dim, device).to(device)
|
||||
elif args.model_type == 'conv1d':
|
||||
inp_dim = get_input_dim(args.mods)
|
||||
model = ResNetConv1D(inp_dim, device).to(device)
|
||||
elif args.model_type == 'tom_cm':
|
||||
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
elif args.model_type == 'tom_impl':
|
||||
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
elif args.model_type == 'tom_sl':
|
||||
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
|
||||
elif args.model_type == 'tom_single':
|
||||
model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
|
||||
elif args.model_type == 'tom_cm_xl':
|
||||
model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
else: raise NotImplementedError
|
||||
model.load_state_dict(torch.load(args.load_model_path, map_location=device))
|
||||
model.device = device
|
||||
|
||||
model.eval()
|
||||
|
||||
if args.save_preds:
|
||||
# Define the output file path
|
||||
folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}_{args.test_frames}'
|
||||
if not os.path.exists(folder_path):
|
||||
os.makedirs(folder_path)
|
||||
print(f'Saving predictions in {folder_path}.')
|
||||
|
||||
print('Testing...')
|
||||
num_correct = 0
|
||||
cnt = 0
|
||||
with torch.no_grad():
|
||||
for j, batch in tqdm(enumerate(test_dataloader)):
|
||||
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
|
||||
if frames is not None: frames = frames.to(device, non_blocking=True)
|
||||
if poses is not None: poses = poses.to(device, non_blocking=True)
|
||||
if gazes is not None: gazes = gazes.to(device, non_blocking=True)
|
||||
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=True)
|
||||
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=True)
|
||||
pred_left_labels, pred_right_labels, repr = model(frames, poses, gazes, bboxes, ocr_graphs)
|
||||
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
|
||||
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
|
||||
labels = torch.reshape(labels, (-1, 2)).to(device)
|
||||
batch_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
|
||||
pred_left_labels, pred_right_labels, labels, sequence_lengths
|
||||
)
|
||||
cnt += batch_num_pred
|
||||
num_correct += batch_num_correct
|
||||
|
||||
if args.save_preds:
|
||||
torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt"))
|
||||
data = [(
|
||||
i,
|
||||
torch.argmax(pred_left_labels[i]).cpu().numpy(),
|
||||
torch.argmax(pred_right_labels[i]).cpu().numpy(),
|
||||
labels[:, 0][i].cpu().numpy(),
|
||||
labels[:, 1][i].cpu().numpy()) for i in range(len(labels))
|
||||
]
|
||||
header = ['frame', 'left_pred', 'right_pred', 'left_label', 'right_label']
|
||||
with open(os.path.join(folder_path, f'{j}_{batch_acc:.2f}.csv'), mode='w', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
writer.writerow(header) # Write the header row
|
||||
writer.writerows(data) # Write the data rows
|
||||
|
||||
test_acc = num_correct / cnt
|
||||
print("Test accuracy: {}".format(num_correct / cnt))
|
||||
|
||||
with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f:
|
||||
f.write(str(test_acc))
|
||||
f.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Define the command-line arguments
|
||||
parser.add_argument('--gpu_id', type=int)
|
||||
parser.add_argument('--presaved', type=int, default=128)
|
||||
parser.add_argument('--non_blocking', action='store_true')
|
||||
parser.add_argument('--num_workers', type=int, default=4)
|
||||
parser.add_argument('--pin_memory', action='store_true')
|
||||
parser.add_argument('--model_type', type=str)
|
||||
parser.add_argument('--aggr', type=str, default='mult', required=False)
|
||||
parser.add_argument('--use_resnet', action='store_true')
|
||||
parser.add_argument('--hidden_dim', type=int, default=64)
|
||||
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
|
||||
parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
|
||||
parser.add_argument('--test_frame_path', type=str, default='/scratch/bortoletto/data/boss/test/frame')
|
||||
parser.add_argument('--test_pose_path', type=str, default='/scratch/bortoletto/data/boss/test/pose')
|
||||
parser.add_argument('--test_gaze_path', type=str, default='/scratch/bortoletto/data/boss/test/gaze')
|
||||
parser.add_argument('--test_bbox_path', type=str, default='/scratch/bortoletto/data/boss/test/new_bbox/labels')
|
||||
parser.add_argument('--ocr_graph_path', type=str, default='')
|
||||
parser.add_argument('--label_path', type=str, default='outfile')
|
||||
parser.add_argument('--save_path', type=str, default='experiments/')
|
||||
parser.add_argument('--test_frames', type=str, default=None)
|
||||
parser.add_argument('--median', type=int, default=None)
|
||||
parser.add_argument('--load_model_path', type=str)
|
||||
parser.add_argument('--dropout', type=float, default=0.0)
|
||||
parser.add_argument('--save_preds', action='store_true')
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
|
||||
if not args.aggr:
|
||||
parser.error("The choosen --model_type requires --aggr")
|
||||
if args.model_type == 'tom_sl' and not args.tom_weight:
|
||||
parser.error("The choosen --model_type requires --tom_weight")
|
||||
|
||||
test(args)
|
324
boss/train.py
Normal file
324
boss/train.py
Normal file
|
@ -0,0 +1,324 @@
|
|||
import torch
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import random
|
||||
import datetime
|
||||
import wandb
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn as nn
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from dataloader import Data
|
||||
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
|
||||
from models.tom_implicit import ImplicitToMnet
|
||||
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
|
||||
from models.tom_sl import SLToMnet
|
||||
from models.tom_tf import TFToMnet
|
||||
from models.single_mindnet import SingleMindNet
|
||||
from utils import pad_collate, get_classification_accuracy, mixup, get_classification_accuracy_mixup, count_parameters, get_input_dim
|
||||
|
||||
|
||||
def train(args):
|
||||
|
||||
if args.model_type == 'tom_cm' or args.model_type == 'tom_sl' or args.model_type == 'tom_impl' or args.model_type == 'tom_tf' or args.model_type == 'tom_cm_xl' or args.model_type == 'tom_single':
|
||||
flatten_dim = 2
|
||||
else:
|
||||
flatten_dim = 1
|
||||
|
||||
train_dataset = Data(
|
||||
args.train_frame_path,
|
||||
args.label_path,
|
||||
args.train_pose_path,
|
||||
args.train_gaze_path,
|
||||
args.train_bbox_path,
|
||||
args.ocr_graph_path,
|
||||
presaved=args.presaved,
|
||||
flatten_dim=flatten_dim
|
||||
)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.num_workers,
|
||||
collate_fn=pad_collate,
|
||||
pin_memory=args.pin_memory
|
||||
)
|
||||
val_dataset = Data(
|
||||
args.val_frame_path,
|
||||
args.label_path,
|
||||
args.val_pose_path,
|
||||
args.val_gaze_path,
|
||||
args.val_bbox_path,
|
||||
args.ocr_graph_path,
|
||||
presaved=args.presaved,
|
||||
flatten_dim=flatten_dim
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
collate_fn=pad_collate,
|
||||
pin_memory=args.pin_memory
|
||||
)
|
||||
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
|
||||
if args.model_type == 'resnet':
|
||||
inp_dim = get_input_dim(args.mods)
|
||||
model = ResNet(inp_dim, device).to(device)
|
||||
elif args.model_type == 'gru':
|
||||
inp_dim = get_input_dim(args.mods)
|
||||
model = ResNetGRU(inp_dim, device).to(device)
|
||||
elif args.model_type == 'lstm':
|
||||
inp_dim = get_input_dim(args.mods)
|
||||
model = ResNetLSTM(inp_dim, device).to(device)
|
||||
elif args.model_type == 'conv1d':
|
||||
inp_dim = get_input_dim(args.mods)
|
||||
model = ResNetConv1D(inp_dim, device).to(device)
|
||||
elif args.model_type == 'tom_cm':
|
||||
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
elif args.model_type == 'tom_cm_xl':
|
||||
model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
elif args.model_type == 'tom_impl':
|
||||
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
elif args.model_type == 'tom_sl':
|
||||
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
|
||||
elif args.model_type == 'tom_single':
|
||||
model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
|
||||
elif args.model_type == 'tom_tf':
|
||||
model = TFToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
|
||||
else: raise NotImplementedError
|
||||
if args.resume_from_checkpoint is not None:
|
||||
model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
if args.scheduler == None:
|
||||
scheduler = None
|
||||
else:
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
|
||||
if args.model_type == 'tom_sl': cross_entropy_loss = nn.NLLLoss(ignore_index=-1)
|
||||
else: cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing, ignore_index=-1).to(device)
|
||||
stats = {'train': {'cls_loss': [], 'cls_acc': []}, 'val': {'cls_loss': [], 'cls_acc': []}}
|
||||
|
||||
max_val_classification_acc = 0
|
||||
max_val_classification_epoch = None
|
||||
counter = 0
|
||||
|
||||
print(f'Number of parameters: {count_parameters(model)}')
|
||||
|
||||
for i in range(args.num_epoch):
|
||||
# training
|
||||
print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
|
||||
temp_train_classification_loss = []
|
||||
epoch_num_correct = 0
|
||||
epoch_cnt = 0
|
||||
model.train()
|
||||
for j, batch in tqdm(enumerate(train_dataloader)):
|
||||
if args.use_mixup:
|
||||
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = mixup(batch, args.mixup_alpha, 27)
|
||||
else:
|
||||
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
|
||||
if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
|
||||
if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
|
||||
if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
|
||||
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
|
||||
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)
|
||||
pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
|
||||
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
|
||||
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
|
||||
if args.use_mixup:
|
||||
labels = torch.reshape(labels, (-1, 54)).to(device)
|
||||
batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy_mixup(
|
||||
pred_left_labels, pred_right_labels, labels, sequence_lengths
|
||||
)
|
||||
loss = cross_entropy_loss(pred_left_labels, labels[:, :27]) + cross_entropy_loss(pred_right_labels, labels[:, 27:])
|
||||
else:
|
||||
labels = torch.reshape(labels, (-1, 2)).to(device)
|
||||
batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
|
||||
pred_left_labels, pred_right_labels, labels, sequence_lengths
|
||||
)
|
||||
loss = cross_entropy_loss(pred_left_labels, labels[:, 0]) + cross_entropy_loss(pred_right_labels, labels[:, 1])
|
||||
epoch_cnt += batch_num_pred
|
||||
epoch_num_correct += batch_num_correct
|
||||
temp_train_classification_loss.append(loss.data.item() * batch_num_pred / 2)
|
||||
|
||||
optimizer.zero_grad()
|
||||
if args.clip_grad_norm is not None:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if args.logger: wandb.log({'batch_train_acc': batch_train_acc, 'batch_train_loss': loss.data.item(), 'lr': optimizer.param_groups[-1]['lr']})
|
||||
print("Epoch {}/{} batch {}/{} training done with cls loss={}, cls accuracy={}.".format(
|
||||
i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item(), batch_train_acc)
|
||||
)
|
||||
|
||||
if scheduler: scheduler.step()
|
||||
|
||||
print("Epoch {}/{} OVERALL train cls loss={}, cls accuracy={}.\n".format(
|
||||
i+1, args.num_epoch, sum(temp_train_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
|
||||
)
|
||||
stats['train']['cls_loss'].append(sum(temp_train_classification_loss) * 2 / epoch_cnt)
|
||||
stats['train']['cls_acc'].append(epoch_num_correct / epoch_cnt)
|
||||
|
||||
# validation
|
||||
print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
|
||||
temp_val_classification_loss = []
|
||||
epoch_num_correct = 0
|
||||
epoch_cnt = 0
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for j, batch in tqdm(enumerate(val_dataloader)):
|
||||
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
|
||||
if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
|
||||
if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
|
||||
if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
|
||||
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
|
||||
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)
|
||||
pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
|
||||
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
|
||||
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
|
||||
labels = torch.reshape(labels, (-1, 2)).to(device)
|
||||
batch_val_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
|
||||
pred_left_labels, pred_right_labels, labels, sequence_lengths
|
||||
)
|
||||
epoch_cnt += batch_num_pred
|
||||
epoch_num_correct += batch_num_correct
|
||||
loss = cross_entropy_loss(pred_left_labels, labels[:,0]) + cross_entropy_loss(pred_right_labels, labels[:,1])
|
||||
temp_val_classification_loss.append(loss.data.item() * batch_num_pred / 2)
|
||||
|
||||
if args.logger: wandb.log({'batch_val_acc': batch_val_acc, 'batch_val_loss': loss.data.item()})
|
||||
print("Epoch {}/{} batch {}/{} validation done with cls loss={}, cls accuracy={}.".format(
|
||||
i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item(), batch_val_acc)
|
||||
)
|
||||
|
||||
print("Epoch {}/{} OVERALL validation cls loss={}, cls accuracy={}.\n".format(
|
||||
i+1, args.num_epoch, sum(temp_val_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
|
||||
)
|
||||
|
||||
cls_loss = sum(temp_val_classification_loss) * 2 / epoch_cnt
|
||||
cls_acc = epoch_num_correct / epoch_cnt
|
||||
stats['val']['cls_loss'].append(cls_loss)
|
||||
stats['val']['cls_acc'].append(cls_acc)
|
||||
if args.logger: wandb.log({'cls_loss': cls_loss, 'cls_acc': cls_acc, 'epoch': i})
|
||||
|
||||
# check for best stat/model using validation accuracy
|
||||
if stats['val']['cls_acc'][-1] >= max_val_classification_acc:
|
||||
max_val_classification_acc = stats['val']['cls_acc'][-1]
|
||||
max_val_classification_epoch = i+1
|
||||
torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
|
||||
counter = 0
|
||||
else:
|
||||
counter += 1
|
||||
print(f'EarlyStopping counter: {counter} out of {args.patience}.')
|
||||
if counter >= args.patience:
|
||||
break
|
||||
|
||||
with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
|
||||
f.write('{}\n'.format(CFG))
|
||||
f.write('{}\n'.format(stats))
|
||||
f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_classification_acc))
|
||||
f.close()
|
||||
|
||||
print(f'Results saved in {experiment_save_path}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Define the command-line arguments
|
||||
parser.add_argument('--gpu_id', type=int)
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--logger', action='store_true')
|
||||
parser.add_argument('--presaved', type=int, default=128)
|
||||
parser.add_argument('--clip_grad_norm', type=float, default=0.5)
|
||||
parser.add_argument('--use_mixup', action='store_true')
|
||||
parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
|
||||
parser.add_argument('--non_blocking', action='store_true')
|
||||
parser.add_argument('--patience', type=int, default=99)
|
||||
parser.add_argument('--batch_size', type=int, default=4)
|
||||
parser.add_argument('--num_workers', type=int, default=4)
|
||||
parser.add_argument('--pin_memory', action='store_true')
|
||||
parser.add_argument('--num_epoch', type=int, default=300)
|
||||
parser.add_argument('--lr', type=float, default=4e-4)
|
||||
parser.add_argument('--scheduler', type=str, default=None)
|
||||
parser.add_argument('--dropout', type=float, default=0.1)
|
||||
parser.add_argument('--weight_decay', type=float, default=0.005)
|
||||
parser.add_argument('--label_smoothing', type=float, default=0.1)
|
||||
parser.add_argument('--model_type', type=str)
|
||||
parser.add_argument('--aggr', type=str, default='mult', required=False)
|
||||
parser.add_argument('--use_resnet', action='store_true')
|
||||
parser.add_argument('--hidden_dim', type=int, default=64)
|
||||
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
|
||||
parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
|
||||
parser.add_argument('--train_frame_path', type=str, default='/scratch/bortoletto/data/boss/train/frame')
|
||||
parser.add_argument('--train_pose_path', type=str, default='/scratch/bortoletto/data/boss/train/pose')
|
||||
parser.add_argument('--train_gaze_path', type=str, default='/scratch/bortoletto/data/boss/train/gaze')
|
||||
parser.add_argument('--train_bbox_path', type=str, default='/scratch/bortoletto/data/boss/train/new_bbox/labels')
|
||||
parser.add_argument('--val_frame_path', type=str, default='/scratch/bortoletto/data/boss/val/frame')
|
||||
parser.add_argument('--val_pose_path', type=str, default='/scratch/bortoletto/data/boss/val/pose')
|
||||
parser.add_argument('--val_gaze_path', type=str, default='/scratch/bortoletto/data/boss/val/gaze')
|
||||
parser.add_argument('--val_bbox_path', type=str, default='/scratch/bortoletto/data/boss/val/new_bbox/labels')
|
||||
parser.add_argument('--ocr_graph_path', type=str, default='')
|
||||
parser.add_argument('--label_path', type=str, default='outfile')
|
||||
parser.add_argument('--save_path', type=str, default='experiments/')
|
||||
parser.add_argument('--resume_from_checkpoint', type=str, default=None)
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.use_mixup and not args.mixup_alpha:
|
||||
parser.error("--use_mixup requires --mixup_alpha")
|
||||
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
|
||||
if not args.aggr:
|
||||
parser.error("The choosen --model_type requires --aggr")
|
||||
if args.model_type == 'tom_sl' and not args.tom_weight:
|
||||
parser.error("The choosen --model_type requires --tom_weight")
|
||||
|
||||
# get experiment ID
|
||||
experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
|
||||
if not os.path.exists(args.save_path):
|
||||
os.makedirs(args.save_path, exist_ok=True)
|
||||
experiment_save_path = os.path.join(args.save_path, experiment_id)
|
||||
os.makedirs(experiment_save_path, exist_ok=True)
|
||||
|
||||
CFG = {
|
||||
'use_ocr_custom_loss': 0,
|
||||
'presaved': args.presaved,
|
||||
'batch_size': args.batch_size,
|
||||
'num_epoch': args.num_epoch,
|
||||
'lr': args.lr,
|
||||
'scheduler': args.scheduler,
|
||||
'weight_decay': args.weight_decay,
|
||||
'model_type': args.model_type,
|
||||
'use_resnet': args.use_resnet,
|
||||
'hidden_dim': args.hidden_dim,
|
||||
'tom_weight': args.tom_weight,
|
||||
'dropout': args.dropout,
|
||||
'label_smoothing': args.label_smoothing,
|
||||
'clip_grad_norm': args.clip_grad_norm,
|
||||
'use_mixup': args.use_mixup,
|
||||
'mixup_alpha': args.mixup_alpha,
|
||||
'non_blocking_tensors': args.non_blocking,
|
||||
'patience': args.patience,
|
||||
'pin_memory': args.pin_memory,
|
||||
'resume_from_checkpoint': args.resume_from_checkpoint,
|
||||
'aggr': args.aggr,
|
||||
'mods': args.mods,
|
||||
'save_path': experiment_save_path ,
|
||||
'seed': args.seed
|
||||
}
|
||||
|
||||
print(CFG)
|
||||
print(f'Saving results in {experiment_save_path}')
|
||||
|
||||
# set seed values
|
||||
if args.logger:
|
||||
wandb.init(project="boss", config=CFG)
|
||||
os.environ['PYTHONHASHSEED'] = str(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
random.seed(args.seed)
|
||||
|
||||
train(args)
|
769
boss/utils.py
Normal file
769
boss/utils.py
Normal file
|
@ -0,0 +1,769 @@
|
|||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.manifold import TSNE
|
||||
#from umap import UMAP
|
||||
import json
|
||||
import seaborn as sns
|
||||
import pandas as pd
|
||||
from natsort import natsorted
|
||||
import argparse
|
||||
import seaborn as sns
|
||||
|
||||
|
||||
|
||||
COLORS_DM = {
|
||||
"red": (242/256, 165/256, 179/256),
|
||||
"blue": (195/256, 219/256, 252/256),
|
||||
"green": (156/256, 228/256, 213/256),
|
||||
"yellow": (250/256, 236/256, 144/256),
|
||||
"violet": (207/256, 187/256, 244/256),
|
||||
"orange": (244/256, 188/256, 154/256)
|
||||
}
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
COLORS = sns.color_palette()
|
||||
|
||||
OBJECTS = [
|
||||
'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
|
||||
'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
|
||||
'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
|
||||
'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
|
||||
]
|
||||
|
||||
tried_once = [2, 4, 5, 6, 7, 8, 9, 10, 12, 15, 18, 19, 20, 21, 22, 23, 24,
|
||||
25, 26, 27, 28, 29, 30, 32, 38, 40, 41, 42, 44, 47, 50, 51, 52, 53, 55,
|
||||
56, 57, 61, 63, 65, 67, 68, 70, 71, 72, 73, 74, 76, 80, 81, 83, 85, 87,
|
||||
88, 89, 90, 92, 93, 96, 97, 99, 101, 102, 105, 106, 108, 110, 111, 112,
|
||||
113, 114, 116, 118, 121, 123, 125, 131, 132, 134, 135, 140, 142, 143,
|
||||
145, 146, 148, 149, 151, 152, 154, 155, 156, 157, 160, 161, 162, 165,
|
||||
169, 170, 171, 173, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185,
|
||||
186, 187, 190, 191, 194, 196, 203, 204, 206, 207, 208, 209, 210, 211,
|
||||
213, 214, 216, 218, 219, 220, 222, 225, 227, 228, 229, 232, 233, 235,
|
||||
236, 237, 238, 239, 242, 243, 246, 247, 249, 251, 252, 254, 255, 256,
|
||||
257, 260, 261, 262, 263, 265, 266, 268, 270, 272, 275, 277, 278, 279,
|
||||
281, 282, 287, 290, 296, 298]
|
||||
|
||||
tried_twice = [1, 3, 11, 13, 14, 16, 17, 31, 33, 35, 36, 37, 39,
|
||||
43, 45, 49, 54, 58, 59, 60, 62, 64, 66, 69, 75, 77, 79, 82, 84, 86,
|
||||
91, 94, 95, 98, 100, 103, 104, 107, 109, 115, 117, 119, 120, 122,
|
||||
124, 126, 127, 128, 129, 130, 133, 136, 137, 138, 139, 141, 144,
|
||||
147, 150, 153, 158, 159, 164, 166, 167, 168, 172, 174, 177, 188,
|
||||
189, 192, 193, 195, 197, 198, 199, 200, 201, 202, 205, 212, 215,
|
||||
217, 221, 223, 224, 226, 230, 231, 234, 240, 241, 244, 245, 248,
|
||||
250, 253, 258, 259, 264, 267, 269, 271, 273, 274, 276, 280, 283,
|
||||
284, 285, 286, 288, 291, 292, 293, 294, 295, 297, 299]
|
||||
|
||||
tried_thrice = [34, 46, 48, 78, 163, 289]
|
||||
|
||||
friends = [i for i in range(150, 300)]
|
||||
strangers = [i for i in range(0, 150)]
|
||||
|
||||
|
||||
def get_input_dim(modalities):
|
||||
dimensions = {
|
||||
'rgb': 512,
|
||||
'pose': 150,
|
||||
'gaze': 6,
|
||||
'ocr': 64,
|
||||
'bbox': 108
|
||||
}
|
||||
sum_of_dimensions = sum(dimensions[modality] for modality in modalities)
|
||||
return sum_of_dimensions
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
return sum([np.prod(p.size()) for p in model_parameters])
|
||||
|
||||
|
||||
def onehot(label, n_classes):
|
||||
if len(label.size()) == 3:
|
||||
batch_size, seq_len, _ = label.size()
|
||||
label_1_2d = label[:,:,0].contiguous().view(batch_size*seq_len, 1)
|
||||
label_2_2d = label[:,:,1].contiguous().view(batch_size*seq_len, 1)
|
||||
#onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1)
|
||||
#onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1)
|
||||
#onehot_2d = torch.cat((onehot_1_2d, onehot_2_2d), dim=1)
|
||||
#onehot_3d = onehot_2d.view(batch_size, seq_len, 2*n_classes)
|
||||
onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1).view(-1, seq_len, n_classes)
|
||||
onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1).view(-1, seq_len, n_classes)
|
||||
onehot_3d = torch.cat((onehot_1_2d, onehot_2_2d), dim=2)
|
||||
return onehot_3d
|
||||
else:
|
||||
return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1)
|
||||
|
||||
|
||||
def mixup(data, alpha, n_classes):
|
||||
lam = torch.FloatTensor([np.random.beta(alpha, alpha)])
|
||||
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = data
|
||||
indices = torch.randperm(labels.size(0))
|
||||
# labels
|
||||
labels2 = labels[indices]
|
||||
labels = onehot(labels, n_classes)
|
||||
labels2 = onehot(labels2, n_classes)
|
||||
labels = labels * lam + labels2 * (1 - lam)
|
||||
# frames
|
||||
frames2 = frames[indices]
|
||||
frames = frames * lam + frames2 * (1 - lam)
|
||||
# poses
|
||||
poses2 = poses[indices]
|
||||
poses = poses * lam + poses2 * (1 - lam)
|
||||
# gazes
|
||||
gazes2 = gazes[indices]
|
||||
gazes = gazes * lam + gazes2 * (1 - lam)
|
||||
# bboxes
|
||||
bboxes2 = bboxes[indices]
|
||||
bboxes = bboxes * lam + bboxes2 * (1 - lam)
|
||||
return frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths
|
||||
|
||||
|
||||
def get_classification_accuracy(pred_left_labels, pred_right_labels, labels, sequence_lengths):
|
||||
max_len = max(sequence_lengths)
|
||||
pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
|
||||
pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
|
||||
labels = torch.reshape(labels, (-1, max_len, 2))
|
||||
left_correct = torch.argmax(pred_left_labels, 2) == labels[:,:,0]
|
||||
right_correct = torch.argmax(pred_right_labels, 2) == labels[:,:,1]
|
||||
num_pred = sum(sequence_lengths) * 2
|
||||
num_correct = 0
|
||||
for i in range(len(sequence_lengths)):
|
||||
size = sequence_lengths[i]
|
||||
num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
|
||||
acc = num_correct / num_pred
|
||||
return acc, num_correct, num_pred
|
||||
|
||||
|
||||
def get_classification_accuracy_mixup(pred_left_labels, pred_right_labels, labels, sequence_lengths):
|
||||
max_len = max(sequence_lengths)
|
||||
pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
|
||||
pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
|
||||
labels = torch.reshape(labels, (-1, max_len, 54))
|
||||
left_correct = torch.argmax(pred_left_labels, 2) == torch.argmax(labels[:,:,:27], 2)
|
||||
right_correct = torch.argmax(pred_right_labels, 2) == torch.argmax(labels[:,:,27:], 2)
|
||||
num_pred = sum(sequence_lengths) * 2
|
||||
num_correct = 0
|
||||
for i in range(len(sequence_lengths)):
|
||||
size = sequence_lengths[i]
|
||||
num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
|
||||
acc = num_correct / num_pred
|
||||
return acc, num_correct, num_pred
|
||||
|
||||
|
||||
def pad_collate(batch):
|
||||
(aa, bb, cc, dd, ee, ff) = zip(*batch)
|
||||
seq_lens = [len(a) for a in aa]
|
||||
aa_pad = pad_sequence(aa, batch_first=True, padding_value=0)
|
||||
bb_pad = pad_sequence(bb, batch_first=True, padding_value=-1)
|
||||
if cc[0] is not None:
|
||||
cc_pad = pad_sequence(cc, batch_first=True, padding_value=0)
|
||||
else:
|
||||
cc_pad = None
|
||||
if dd[0] is not None:
|
||||
dd_pad = pad_sequence(dd, batch_first=True, padding_value=0)
|
||||
else:
|
||||
dd_pad = None
|
||||
if ee[0] is not None:
|
||||
ee_pad = pad_sequence(ee, batch_first=True, padding_value=0)
|
||||
else:
|
||||
ee_pad = None
|
||||
if ff[0] is not None:
|
||||
ff_pad = pad_sequence(ff, batch_first=True, padding_value=0)
|
||||
else:
|
||||
ff_pad = None
|
||||
return aa_pad, bb_pad, cc_pad, dd_pad, ee_pad, ff_pad, seq_lens
|
||||
|
||||
|
||||
def ocr_loss(pL, lL, pR, lR, ocr, loss, eta=10.0, mode='abs'):
|
||||
"""
|
||||
Custom loss based on the negative OCRL matrix.
|
||||
|
||||
Input:
|
||||
pL: tensor of shape [batch_size, num_classes] representing left belief predictions
|
||||
lL: tensor of shape [batch_size] representing the left belief labels
|
||||
pR: tensor of shape [batch_size, num_classes] representing right belief predictions
|
||||
lR: tensor of shape [batch_size] representing the right belief labels
|
||||
ocr: negative OCR matrix (i.e. 1-OCR)
|
||||
loss: loss function
|
||||
eta: hyperparameter for the interaction term
|
||||
|
||||
Output:
|
||||
Final loss resulting from weighting left and right loss with the respective
|
||||
OCR coefficients and summing them.
|
||||
"""
|
||||
bs = pL.shape[0]
|
||||
if len([*lR[0].size()]) > 0: # for mixup
|
||||
p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
|
||||
g = torch.tensor([ocr[torch.argmax(lL[i]), torch.argmax(lR[i])] for i in range(bs)], device=pL.device)
|
||||
else:
|
||||
p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
|
||||
g = torch.tensor([ocr[lL[i], lR[i]] for i in range(bs)], device=pL.device)
|
||||
left_loss = torch.mean(loss(pL, lL))
|
||||
right_loss = torch.mean(loss(pR, lR))
|
||||
if mode == 'abs':
|
||||
interaction_loss = torch.mean(torch.abs(g - p))
|
||||
elif mode == 'mse':
|
||||
interaction_loss = torch.mean(torch.pow(g - p, 2))
|
||||
else: raise NotImplementedError
|
||||
eta = (left_loss + right_loss) / interaction_loss
|
||||
interaction_loss = interaction_loss * eta
|
||||
print(f"Left: {left_loss} --- Right: {right_loss} --- Interaction: {interaction_loss}")
|
||||
return left_loss + right_loss + interaction_loss
|
||||
|
||||
|
||||
def spherical2cartesial(x):
|
||||
"""
|
||||
From https://colab.research.google.com/drive/1SJbzd-gFTbiYjfZynIfrG044fWi6svbV?usp=sharing#scrollTo=78QhNw4MYSYp
|
||||
"""
|
||||
output = torch.zeros(x.size(0),3)
|
||||
output[:,2] = -torch.cos(x[:,1])*torch.cos(x[:,0])
|
||||
output[:,0] = torch.cos(x[:,1])*torch.sin(x[:,0])
|
||||
output[:,1] = torch.sin(x[:,1])
|
||||
return output
|
||||
|
||||
|
||||
def presave():
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
import os
|
||||
from natsort import natsorted
|
||||
import pickle as pkl
|
||||
|
||||
preprocess = transforms.Compose([
|
||||
transforms.Resize((128, 128)),
|
||||
#transforms.Resize(256),
|
||||
#transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
frame_path = '/scratch/bortoletto/data/boss/test/frame'
|
||||
frame_dirs = os.listdir(frame_path)
|
||||
frame_paths = []
|
||||
for frame_dir in natsorted(frame_dirs):
|
||||
paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
|
||||
frame_paths.append(paths)
|
||||
|
||||
save_folder = '/scratch/bortoletto/data/boss/presaved128/'+frame_path.split('/')[-2]+'/'+frame_path.split('/')[-1]
|
||||
if not os.path.exists(save_folder):
|
||||
os.makedirs(save_folder)
|
||||
print(save_folder, 'created')
|
||||
|
||||
for video in natsorted(frame_paths):
|
||||
print(video[0].split('/')[-2], end='\r')
|
||||
images = [preprocess(Image.open(i)) for i in video]
|
||||
strings = video[0].split('/')
|
||||
with open(save_folder+'/'+strings[7]+'.pkl', 'wb') as f:
|
||||
pkl.dump(images, f)
|
||||
|
||||
|
||||
def compute_cosine_similarity(tensor1, tensor2):
|
||||
return F.cosine_similarity(tensor1, tensor2, dim=-1)
|
||||
|
||||
|
||||
def find_most_similar_embedding(rgb, pose, gaze, ocr, bbox, repr):
|
||||
gaze_similarity = compute_cosine_similarity(gaze, repr)
|
||||
pose_similarity = compute_cosine_similarity(pose, repr)
|
||||
ocr_similarity = compute_cosine_similarity(ocr, repr)
|
||||
bbox_similarity = compute_cosine_similarity(bbox, repr)
|
||||
rgb_similarity = compute_cosine_similarity(rgb, repr)
|
||||
similarities = torch.stack([gaze_similarity, pose_similarity, ocr_similarity, bbox_similarity, rgb_similarity])
|
||||
max_index = torch.argmax(similarities, dim=0)
|
||||
main_modality = []
|
||||
main_modality_name = []
|
||||
for idx in max_index:
|
||||
if idx == 0:
|
||||
main_modality.append(gaze)
|
||||
main_modality_name.append('gaze')
|
||||
elif idx == 1:
|
||||
main_modality.append(pose)
|
||||
main_modality_name.append('pose')
|
||||
elif idx == 2:
|
||||
main_modality.append(ocr)
|
||||
main_modality_name.append('ocr')
|
||||
elif idx == 3:
|
||||
main_modality.append(bbox)
|
||||
main_modality_name.append('bbox')
|
||||
else:
|
||||
main_modality.append(rgb)
|
||||
main_modality_name.append('rgb')
|
||||
return main_modality, main_modality_name
|
||||
|
||||
|
||||
def plot_similarity_histogram(values, filename, labels):
|
||||
def count_elements(list_of_strings, target_list):
|
||||
counts = []
|
||||
for element in list_of_strings:
|
||||
count = target_list.count(element)
|
||||
counts.append(count)
|
||||
return counts
|
||||
fig, ax = plt.subplots(figsize=(12,4))
|
||||
colors = ["red", "blue", "green", "violet"]
|
||||
# Remove spines
|
||||
ax.spines['top'].set_visible(False)
|
||||
ax.spines['right'].set_visible(False)
|
||||
colors = [MTOM_COLORS['MN1'], MTOM_COLORS['MN2'], MTOM_COLORS['CG'], MTOM_COLORS['CG']]
|
||||
alphas = [0.6, 0.6, 0.6, 0.6]
|
||||
edgecolors = ['black', 'black', MTOM_COLORS['MN1'], MTOM_COLORS['MN2']]
|
||||
linewidths = [1.0, 1.0, 5.0, 5.0]
|
||||
if isinstance(values[0], list):
|
||||
num_lists = len(values)
|
||||
bar_width = 0.8 / num_lists
|
||||
for i, val in enumerate(values):
|
||||
unique_strings = ['rgb', 'pose', 'gaze', 'bbox', 'ocr']
|
||||
counts = count_elements(unique_strings, val)
|
||||
x = np.arange(len(unique_strings))
|
||||
x_shifted = x + (i - (len(labels) - 1) / 2) * bar_width #x - (bar_width * num_lists / 2) + (bar_width * i)
|
||||
ax.bar(x_shifted, counts, width=bar_width, label=f'{labels[i]}', color=colors[i], edgecolor=edgecolors[i], linewidth=linewidths[i], alpha=alphas[i])
|
||||
ax.set_xlabel('Modality', fontsize=18)
|
||||
ax.set_ylabel('Counts', fontsize=18)
|
||||
ax.set_xticks(np.arange(len(unique_strings)))
|
||||
ax.set_xticklabels(unique_strings, fontsize=18)
|
||||
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
|
||||
ax.legend(fontsize=18)
|
||||
ax.grid(axis='y')
|
||||
plt.savefig(filename, bbox_inches='tight')
|
||||
else:
|
||||
unique_strings, counts = np.unique(values, return_counts=True)
|
||||
ax.bar(unique_strings, counts)
|
||||
ax.set_xlabel('Modality')
|
||||
ax.set_ylabel('Counts')
|
||||
plt.savefig(filename, bbox_inches='tight')
|
||||
|
||||
|
||||
def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
|
||||
means = []
|
||||
stds = []
|
||||
for key, values in data.items():
|
||||
mean = np.mean(values)
|
||||
std = np.std(values)
|
||||
means.append(mean)
|
||||
stds.append(std)
|
||||
fig, ax = plt.subplots(figsize=size)
|
||||
x = np.arange(len(data))
|
||||
width = 0.6
|
||||
rects1 = ax.bar(
|
||||
x, means, width, label='Mean', yerr=stds,
|
||||
capsize=5,
|
||||
color='teal' if colors is None else colors,
|
||||
edgecolor='black',
|
||||
linewidth=1.5,
|
||||
alpha=0.6
|
||||
)
|
||||
ax.set_ylabel('Accuracy', fontsize=18)
|
||||
#ax.set_title('Mean and Standard Deviation of Results', fontsize=14)
|
||||
if filename == 'results/all':
|
||||
xticklabels = [
|
||||
'Base',
|
||||
'DB',
|
||||
'CG$\otimes$', 'CG$\oplus$', 'CG$\odot$', 'CG$\parallel$',
|
||||
'IC$\otimes$', 'IC$\oplus$', 'IC$\odot$', 'IC$\parallel$',
|
||||
'CNN', 'CNN+GRU', 'CNN+LSTM', 'CNN+Conv1D'
|
||||
]
|
||||
else:
|
||||
xticklabels = list(data.keys())
|
||||
ax.set_xticks(np.arange(len(xticklabels)))
|
||||
ax.set_xticklabels(xticklabels, rotation=rotation, fontsize=16) #data.keys(), rotation=rotation, fontsize=16)
|
||||
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
|
||||
#ax.grid(axis='y')
|
||||
#ax.set_axisbelow(True)
|
||||
#ax.legend(loc='upper right')
|
||||
# Add value labels above each bar, avoiding overlapping with error bars
|
||||
for rect, std in zip(rects1, stds):
|
||||
height = rect.get_height()
|
||||
if height + std < ax.get_ylim()[1]:
|
||||
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
|
||||
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
|
||||
else:
|
||||
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
|
||||
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
|
||||
# Remove spines
|
||||
ax.spines['top'].set_visible(False)
|
||||
ax.spines['right'].set_visible(False)
|
||||
ax.grid(axis='x')
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'{filename}.pdf', bbox_inches='tight')
|
||||
|
||||
|
||||
def plot_confusion_matrices(left_cm, right_cm, labels, title, annot=True):
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
|
||||
left_xticklabels = [OBJECTS[i] for i in labels[0]]
|
||||
left_yticklabels = [OBJECTS[i] for i in labels[1]]
|
||||
right_xticklabels = [OBJECTS[i] for i in labels[2]]
|
||||
right_yticklabels = [OBJECTS[i] for i in labels[3]]
|
||||
sns.heatmap(
|
||||
left_cm,
|
||||
annot=annot,
|
||||
fmt='.0f',
|
||||
cmap='Blues',
|
||||
cbar=False,
|
||||
xticklabels=left_xticklabels,
|
||||
yticklabels=left_yticklabels,
|
||||
ax=ax1
|
||||
)
|
||||
ax1.set_xlabel('Predicted')
|
||||
ax1.set_ylabel('True')
|
||||
ax1.set_title('Left Confusion Matrix')
|
||||
sns.heatmap(
|
||||
right_cm,
|
||||
annot=annot,
|
||||
fmt='.0f',
|
||||
cmap='Blues',
|
||||
cbar=False, #True if annot is False else False,
|
||||
xticklabels=right_xticklabels,
|
||||
yticklabels=right_yticklabels,
|
||||
ax=ax2
|
||||
)
|
||||
ax2.set_xlabel('Predicted')
|
||||
ax2.set_ylabel('True')
|
||||
ax2.set_title('Right Confusion Matrix')
|
||||
#plt.suptitle(title)
|
||||
plt.tight_layout()
|
||||
plt.savefig(title + '.pdf')
|
||||
plt.close()
|
||||
|
||||
def plot_confusion_matrix(confusion_matrix, labels, title, annot=True):
|
||||
plt.figure(figsize=(6, 6))
|
||||
xticklabels = [OBJECTS[i] for i in labels[0]]
|
||||
yticklabels = [OBJECTS[i] for i in labels[1]]
|
||||
sns.heatmap(
|
||||
confusion_matrix,
|
||||
annot=annot,
|
||||
fmt='.0f',
|
||||
cmap='Blues',
|
||||
cbar=False,
|
||||
xticklabels=xticklabels,
|
||||
yticklabels=yticklabels
|
||||
)
|
||||
plt.xlabel('Predicted')
|
||||
plt.ylabel('True')
|
||||
#plt.title(title)
|
||||
plt.tight_layout()
|
||||
plt.savefig(title + '.pdf')
|
||||
plt.close()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--task', type=str, choices=['confusion', 'similarity', 'scores', 'friends_vs_strangers'])
|
||||
parser.add_argument('--folder_path', type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.task == 'similarity':
|
||||
sns.set_theme(style='white')
|
||||
else:
|
||||
sns.set_theme(style='whitegrid')
|
||||
|
||||
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
|
||||
|
||||
if args.task == 'similarity':
|
||||
|
||||
out_left_main_mods_full_test = []
|
||||
out_right_main_mods_full_test = []
|
||||
cell_left_main_mods_full_test = []
|
||||
cell_right_main_mods_full_test = []
|
||||
cm_left_main_mods_full_test = []
|
||||
cm_right_main_mods_full_test = []
|
||||
|
||||
for i in range(60):
|
||||
|
||||
print(f'Computing analysis for test video {i}...', end='\r')
|
||||
|
||||
emb_file = os.path.join(args.folder_path, f'{i}.pt')
|
||||
data = torch.load(emb_file)
|
||||
if len(data) == 14: # implicit
|
||||
model = 'impl'
|
||||
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
|
||||
out_left = out_left.squeeze(0)
|
||||
cell_left = cell_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
cell_right = cell_right.squeeze(0)
|
||||
elif len(data) == 13: # common mind
|
||||
model = 'cm'
|
||||
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
|
||||
out_left = out_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
common_mind = common_mind.squeeze(0)
|
||||
elif len(data) == 12: # speaker-listener
|
||||
model = 'sl'
|
||||
out_left, out_right, feats = data[0], data[1], data[2:]
|
||||
out_left = out_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
|
||||
|
||||
# ====== PCA for left and right embeddings ====== #
|
||||
|
||||
out_left_and_right = np.concatenate((out_left, out_right), axis=0)
|
||||
|
||||
pca = PCA(n_components=2)
|
||||
pca_result = pca.fit_transform(out_left_and_right)
|
||||
|
||||
# Separate the PCA results for each tensor
|
||||
pca_result_left = pca_result[:out_left.shape[0]]
|
||||
pca_result_right = pca_result[out_right.shape[0]:]
|
||||
|
||||
plt.figure(figsize=(7,6))
|
||||
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='MindNet$_1$', color=MTOM_COLORS['MN1'], s=100)
|
||||
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='MindNet$_2$', color=MTOM_COLORS['MN2'], s=100)
|
||||
plt.xlabel('Principal Component 1', fontsize=30)
|
||||
plt.ylabel('Principal Component 2', fontsize=30)
|
||||
plt.grid(False)
|
||||
plt.legend(fontsize=30)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'{args.folder_path}/{i}_pca.pdf')
|
||||
plt.close()
|
||||
|
||||
# ====== Feature similarity ====== #
|
||||
|
||||
if len(feats) == 10:
|
||||
left_rgb, left_ocr, left_pose, left_gaze, left_bbox, right_rgb, right_ocr, right_pose, right_gaze, right_bbox = feats
|
||||
left_rgb = left_rgb.squeeze(0)
|
||||
left_ocr = left_ocr.squeeze(0)
|
||||
left_pose = left_pose.squeeze(0)
|
||||
left_gaze = left_gaze.squeeze(0)
|
||||
left_bbox = left_bbox.squeeze(0)
|
||||
right_rgb = right_rgb.squeeze(0)
|
||||
right_ocr = right_ocr.squeeze(0)
|
||||
right_pose = right_pose.squeeze(0)
|
||||
right_gaze = right_gaze.squeeze(0)
|
||||
right_bbox = right_bbox.squeeze(0)
|
||||
else: raise NotImplementedError("Ablated versions are not supported yet.")
|
||||
|
||||
# out: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
|
||||
# cell: [1, 1, dim] --------- squeeze ---> [1, dim]
|
||||
# cm: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
|
||||
# feat: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
|
||||
|
||||
_, out_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, out_left)
|
||||
_, out_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, out_right)
|
||||
out_left_main_mods_full_test += out_left_main_mods
|
||||
out_right_main_mods_full_test += out_right_main_mods
|
||||
if model == 'impl':
|
||||
_, cell_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, cell_left)
|
||||
_, cell_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, cell_right)
|
||||
cell_left_main_mods_full_test += cell_left_main_mods
|
||||
cell_right_main_mods_full_test += cell_right_main_mods
|
||||
if model == 'cm':
|
||||
_, cm_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, common_mind)
|
||||
_, cm_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, common_mind)
|
||||
cm_left_main_mods_full_test += cm_left_main_mods
|
||||
cm_right_main_mods_full_test += cm_right_main_mods
|
||||
|
||||
if model == 'impl':
|
||||
plot_similarity_histogram(
|
||||
[out_left_main_mods_full_test,
|
||||
out_right_main_mods_full_test,
|
||||
cell_left_main_mods_full_test,
|
||||
cell_right_main_mods_full_test],
|
||||
f'{args.folder_path}/boss_similartiy_impl_hist_all.pdf',
|
||||
[r'$h_1$', r'$h_2$', r'$c_1$', r'$c_2$']
|
||||
)
|
||||
elif model == 'cm':
|
||||
plot_similarity_histogram(
|
||||
[out_left_main_mods_full_test,
|
||||
out_right_main_mods_full_test,
|
||||
cm_left_main_mods_full_test,
|
||||
cm_right_main_mods_full_test],
|
||||
f'{args.folder_path}/boss_similarity_cm_hist_all.pdf',
|
||||
[r'$h_1$', r'$h_2$', r'$cg$ w/ 1', r'$cg$ w/ 2']
|
||||
)
|
||||
elif model == 'sl':
|
||||
plot_similarity_histogram(
|
||||
[out_left_main_mods_full_test,
|
||||
out_right_main_mods_full_test],
|
||||
f'{args.folder_path}/boss_similarity_sl_hist_all.py',
|
||||
[r'$h_1$', r'$h_2$']
|
||||
)
|
||||
|
||||
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
|
||||
|
||||
elif args.task == 'scores':
|
||||
|
||||
|
||||
all = json.load(open('results/all.json'))
|
||||
abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json'))
|
||||
|
||||
plot_scores_histogram(
|
||||
all,
|
||||
filename='results/all',
|
||||
size=(10,4),
|
||||
rotation=45,
|
||||
colors=[COLORS[7]] + [MTOM_COLORS['DB']] + [MTOM_COLORS['CG']]*4 + [MTOM_COLORS['IC']]*4 + [COLORS[4]]*4
|
||||
)
|
||||
|
||||
plot_scores_histogram(
|
||||
abl_tom_cm_concat,
|
||||
size=(10,4),
|
||||
filename='results/abl_cm_concat',
|
||||
colors=[MTOM_COLORS['CG']]*8
|
||||
)
|
||||
|
||||
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
|
||||
|
||||
elif args.task == 'confusion':
|
||||
|
||||
# List to store confusion matrices
|
||||
all_df = []
|
||||
left_matrices = []
|
||||
right_matrices = []
|
||||
|
||||
# Iterate over the CSV files
|
||||
for filename in natsorted(os.listdir(args.folder_path)):
|
||||
if filename.endswith('.csv'):
|
||||
file_path = os.path.join(args.folder_path, filename)
|
||||
|
||||
print(f'Processing {file_path}...')
|
||||
|
||||
# Read the CSV file
|
||||
df = pd.read_csv(file_path)
|
||||
|
||||
# Extract the left and right labels and predictions
|
||||
left_labels = df['left_label']
|
||||
right_labels = df['right_label']
|
||||
left_preds = df['left_pred']
|
||||
right_preds = df['right_pred']
|
||||
|
||||
# Calculate the confusion matrices for left and right
|
||||
left_cm = pd.crosstab(left_labels, left_preds)
|
||||
right_cm = pd.crosstab(right_labels, right_preds)
|
||||
|
||||
# Append the confusion matrices to the list
|
||||
left_matrices.append(left_cm)
|
||||
right_matrices.append(right_cm)
|
||||
all_df.append(df)
|
||||
|
||||
# Plot and save the confusion matrices for left and right
|
||||
for i, cm in enumerate(zip(left_matrices, right_matrices)):
|
||||
print(f'Computing confusion matrices for video {i}...', end='\r')
|
||||
plot_confusion_matrices(
|
||||
cm[0],
|
||||
cm[1],
|
||||
labels=[cm[0].columns, cm[0].index, cm[1].columns, cm[1].index],
|
||||
title=f'{args.folder_path}/{i}_cm'
|
||||
)
|
||||
|
||||
merged_df = pd.concat(all_df).reset_index(drop=True)
|
||||
|
||||
merged_left_cm = pd.crosstab(merged_df['left_label'], merged_df['left_pred'])
|
||||
merged_right_cm = pd.crosstab(merged_df['right_label'], merged_df['right_pred'])
|
||||
plot_confusion_matrices(
|
||||
merged_left_cm,
|
||||
merged_right_cm,
|
||||
labels=[merged_left_cm.columns, merged_left_cm.index, merged_right_cm.columns, merged_right_cm.index],
|
||||
title=f'{args.folder_path}/all_lr',
|
||||
annot=False
|
||||
)
|
||||
|
||||
merged_preds = pd.concat([merged_df['left_pred'], merged_df['right_pred']]).reset_index(drop=True)
|
||||
merged_labels = pd.concat([merged_df['left_label'], merged_df['right_label']]).reset_index(drop=True)
|
||||
merged_all_cm = pd.crosstab(merged_labels, merged_preds)
|
||||
plot_confusion_matrix(
|
||||
merged_all_cm,
|
||||
labels=[merged_all_cm.columns, merged_all_cm.index],
|
||||
title=f'{args.folder_path}/all_merged',
|
||||
annot=False
|
||||
)
|
||||
|
||||
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
|
||||
|
||||
elif args.task == 'friends_vs_strangers':
|
||||
|
||||
# friends ids: 30-59
|
||||
# stranger ids: 0-29
|
||||
out_left_list = []
|
||||
out_right_list = []
|
||||
cell_left_list = []
|
||||
cell_right_list = []
|
||||
common_mind_list = []
|
||||
|
||||
for i in range(60):
|
||||
|
||||
print(f'Computing analysis for test video {i}...', end='\r')
|
||||
|
||||
emb_file = os.path.join(args.folder_path, f'{i}.pt')
|
||||
data = torch.load(emb_file)
|
||||
if len(data) == 14: # implicit
|
||||
model = 'impl'
|
||||
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
|
||||
out_left = out_left.squeeze(0)
|
||||
cell_left = cell_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
cell_right = cell_right.squeeze(0)
|
||||
out_left_list.append(out_left)
|
||||
out_right_list.append(out_right)
|
||||
cell_left_list.append(cell_left)
|
||||
cell_right_list.append(cell_right)
|
||||
elif len(data) == 13: # common mind
|
||||
model = 'cm'
|
||||
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
|
||||
out_left = out_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
common_mind = common_mind.squeeze(0)
|
||||
out_left_list.append(out_left)
|
||||
out_right_list.append(out_right)
|
||||
common_mind_list.append(common_mind)
|
||||
elif len(data) == 12: # speaker-listener
|
||||
model = 'sl'
|
||||
out_left, out_right, feats = data[0], data[1], data[2:]
|
||||
out_left = out_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
out_left_list.append(out_left)
|
||||
out_right_list.append(out_right)
|
||||
else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
|
||||
|
||||
# ====== PCA for left and right embeddings ====== #
|
||||
|
||||
print('\rComputing PCA...')
|
||||
|
||||
strangers_nframes = sum([out_left_list[i].shape[0] for i in range(30)])
|
||||
|
||||
left = torch.cat(out_left_list, 0)
|
||||
right = torch.cat(out_right_list, 0)
|
||||
|
||||
out_left_and_right = torch.cat([left, right], axis=0).numpy()
|
||||
#out_left_and_right = StandardScaler().fit_transform(out_left_and_right)
|
||||
|
||||
pca = PCA(n_components=2)
|
||||
pca_result = pca.fit_transform(out_left_and_right)
|
||||
#pca_result = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=3).fit_transform(out_left_and_right)
|
||||
#pca_result = UMAP().fit_transform(out_left_and_right)
|
||||
|
||||
# Separate the PCA results for each tensor
|
||||
#pca_result_left = pca_result[:left.shape[0]]
|
||||
#pca_result_right = pca_result[right.shape[0]:]
|
||||
pca_result_strangers = pca_result[:strangers_nframes]
|
||||
pca_result_friends = pca_result[strangers_nframes:]
|
||||
|
||||
#plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='Left')
|
||||
#plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='Right')
|
||||
plt.scatter(pca_result_friends[:, 0], pca_result_friends[:, 1], label='Friends')
|
||||
plt.scatter(pca_result_strangers[:, 0], pca_result_strangers[:, 1], label='Strangers')
|
||||
plt.xlabel('Principal Component 1')
|
||||
plt.ylabel('Principal Component 2')
|
||||
plt.legend()
|
||||
plt.savefig(f'{args.folder_path}/friends_vs_strangers_pca.pdf')
|
||||
plt.close()
|
||||
|
||||
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
|
||||
|
||||
else:
|
||||
|
||||
raise NameError
|
196
tbd/.gitignore
vendored
Normal file
196
tbd/.gitignore
vendored
Normal file
|
@ -0,0 +1,196 @@
|
|||
experiments
|
||||
wandb
|
||||
predictions
|
||||
|
||||
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/python,linux
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
|
||||
|
||||
### Linux ###
|
||||
*~
|
||||
|
||||
# temporary files which can be created if a process still has a handle open of a deleted file
|
||||
.fuse_hidden*
|
||||
|
||||
# KDE directory preferences
|
||||
.directory
|
||||
|
||||
# Linux trash folder which might appear on any partition or disk
|
||||
.Trash-*
|
||||
|
||||
# .nfs files are created when an open file is removed but is still being accessed
|
||||
.nfs*
|
||||
|
||||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
### Python Patch ###
|
||||
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
||||
poetry.toml
|
||||
|
||||
# ruff
|
||||
.ruff_cache/
|
||||
|
||||
# LSP config files
|
||||
pyrightconfig.json
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/python,linux
|
16
tbd/README.md
Normal file
16
tbd/README.md
Normal file
|
@ -0,0 +1,16 @@
|
|||
# TBD
|
||||
|
||||
# Data
|
||||
The original code can be found [here](https://github.com/LifengFan/Triadic-Belief-Dynamics). The dataset is not directly available but must be requested using the link to the Google form provided in the [README](https://github.com/LifengFan/Triadic-Belief-Dynamics?tab=readme-ov-file#dataset).
|
||||
|
||||
## Installing Dependencies
|
||||
Run `conda env create -f environment.yml`.
|
||||
|
||||
## Train
|
||||
`source run_train.sh`.
|
||||
|
||||
## Test
|
||||
`source run_test.sh`. **Make sure to use the same random seed used for training**, otherwise the splits will be different and you will likely have a data leakage.
|
||||
|
||||
## Visualisations
|
||||
The plots are made using `utils/fb_scores_err.py` (false belief analysis) and `utils/similarity.py` (PCA of latent representations).
|
100
tbd/environment.yml
Normal file
100
tbd/environment.yml
Normal file
|
@ -0,0 +1,100 @@
|
|||
name: tbd
|
||||
channels:
|
||||
- conda-forge
|
||||
- defaults
|
||||
- pytorch
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=5.1=1_gnu
|
||||
- ca-certificates=2023.01.10=h06a4308_0
|
||||
- ld_impl_linux-64=2.38=h1181459_1
|
||||
- libffi=3.3=he6710b0_2
|
||||
- libgcc-ng=11.2.0=h1234567_1
|
||||
- libgomp=11.2.0=h1234567_1
|
||||
- libstdcxx-ng=11.2.0=h1234567_1
|
||||
- ncurses=6.4=h6a678d5_0
|
||||
- openssl=1.1.1t=h7f8727e_0
|
||||
- pip=23.0.1=py38h06a4308_0
|
||||
- python=3.8.10=h12debd9_8
|
||||
- readline=8.2=h5eee18b_0
|
||||
- setuptools=66.0.0=py38h06a4308_0
|
||||
- sqlite=3.41.2=h5eee18b_0
|
||||
- tk=8.6.12=h1ccaba5_0
|
||||
- wheel=0.38.4=py38h06a4308_0
|
||||
- xz=5.4.2=h5eee18b_0
|
||||
- zlib=1.2.13=h5eee18b_0
|
||||
- pip:
|
||||
- appdirs==1.4.4
|
||||
- beautifulsoup4==4.12.2
|
||||
- certifi==2023.5.7
|
||||
- charset-normalizer==3.1.0
|
||||
- click==8.1.3
|
||||
- cmake==3.26.4
|
||||
- contourpy==1.1.0
|
||||
- cycler==0.11.0
|
||||
- docker-pycreds==0.4.0
|
||||
- einops==0.6.1
|
||||
- filelock==3.12.0
|
||||
- fonttools==4.40.0
|
||||
- gdown==4.7.1
|
||||
- gitdb==4.0.10
|
||||
- gitpython==3.1.31
|
||||
- idna==3.4
|
||||
- importlib-resources==5.12.0
|
||||
- jinja2==3.1.2
|
||||
- joblib==1.3.1
|
||||
- kiwisolver==1.4.4
|
||||
- lit==16.0.6
|
||||
- markupsafe==2.1.3
|
||||
- matplotlib==3.7.1
|
||||
- memory-efficient-attention-pytorch==0.1.2
|
||||
- mpmath==1.3.0
|
||||
- networkx==3.1
|
||||
- numpy==1.24.4
|
||||
- nvidia-cublas-cu11==11.10.3.66
|
||||
- nvidia-cuda-cupti-cu11==11.7.101
|
||||
- nvidia-cuda-nvrtc-cu11==11.7.99
|
||||
- nvidia-cuda-runtime-cu11==11.7.99
|
||||
- nvidia-cudnn-cu11==8.5.0.96
|
||||
- nvidia-cufft-cu11==10.9.0.58
|
||||
- nvidia-curand-cu11==10.2.10.91
|
||||
- nvidia-cusolver-cu11==11.4.0.1
|
||||
- nvidia-cusparse-cu11==11.7.4.91
|
||||
- nvidia-nccl-cu11==2.14.3
|
||||
- nvidia-nvtx-cu11==11.7.91
|
||||
- opencv-python==4.8.0.74
|
||||
- packaging==23.1
|
||||
- pandas==2.0.3
|
||||
- pathtools==0.1.2
|
||||
- pillow==9.5.0
|
||||
- protobuf==4.23.3
|
||||
- psutil==5.9.5
|
||||
- pyparsing==3.1.0
|
||||
- pysocks==1.7.1
|
||||
- python-dateutil==2.8.2
|
||||
- pytz==2023.3
|
||||
- pyyaml==6.0
|
||||
- requests==2.30.0
|
||||
- scikit-learn==1.3.0
|
||||
- scipy==1.10.1
|
||||
- seaborn==0.12.2
|
||||
- sentry-sdk==1.27.0
|
||||
- setproctitle==1.3.2
|
||||
- six==1.16.0
|
||||
- smmap==5.0.0
|
||||
- soupsieve==2.4.1
|
||||
- sympy==1.12
|
||||
- threadpoolctl==3.1.0
|
||||
- torch==2.0.1
|
||||
- torch-geometric==2.3.1
|
||||
- torchaudio==2.0.2
|
||||
- torchsampler==0.1.2
|
||||
- torchvision==0.15.2
|
||||
- tqdm==4.65.0
|
||||
- triton==2.0.0
|
||||
- typing-extensions==4.7.0
|
||||
- tzdata==2023.3
|
||||
- urllib3==2.0.2
|
||||
- wandb==0.15.5
|
||||
- zipp==3.15.0
|
||||
prefix: /opt/anaconda3/envs/tbd
|
156
tbd/models/base.py
Normal file
156
tbd/models/base.py
Normal file
|
@ -0,0 +1,156 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from .utils import pose_edge_index
|
||||
from torch_geometric.nn import GCNConv
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
def forward(self, x, **kwargs):
|
||||
x = self.norm(x)
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim, dim))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CNN(nn.Module):
|
||||
def __init__(self, hidden_dim):
|
||||
super(CNN, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3 = nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = self.pool(x)
|
||||
x = self.conv2(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = self.pool(x)
|
||||
x = self.conv3(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = nn.functional.max_pool2d(x, kernel_size=x.shape[2:]) # global max pooling
|
||||
return x
|
||||
|
||||
|
||||
class MindNetLSTM(nn.Module):
|
||||
"""
|
||||
Basic MindNet for model-based ToM, just LSTM on input concatenation
|
||||
"""
|
||||
def __init__(self, hidden_dim, dropout, mods):
|
||||
super(MindNetLSTM, self).__init__()
|
||||
self.mods = mods
|
||||
if 'rgb_1' in mods:
|
||||
self.img_emb = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
if 'gaze' in mods:
|
||||
self.gaze_emb = nn.Linear(2, hidden_dim)
|
||||
if 'pose' in mods:
|
||||
self.pose_edge_index = pose_edge_index()
|
||||
self.pose_emb = GCNConv(3, hidden_dim)
|
||||
self.LSTM = PreNorm(
|
||||
hidden_dim*len(mods),
|
||||
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
|
||||
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.act = nn.GELU()
|
||||
|
||||
def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
|
||||
feats = []
|
||||
if 'rgb_3' in self.mods:
|
||||
feats.append(rgb_3rd_pov_feats)
|
||||
if 'rgb_1' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(rgb_1st_pov.shape[1]):
|
||||
images_i = rgb_1st_pov[:,i]
|
||||
img_i_feat = self.img_emb(images_i)
|
||||
img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
feats.append(rgb_feats)
|
||||
if 'pose' in self.mods:
|
||||
bs, seq_len = pose.size(0), pose.size(1)
|
||||
self.pose_edge_index = self.pose_edge_index.to(pose.device)
|
||||
pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
|
||||
pose_emb = self.dropout(self.act(pose_emb))
|
||||
pose_emb = torch.mean(pose_emb, dim=1)
|
||||
hd = pose_emb.size(-1)
|
||||
feats.append(pose_emb.view(bs, seq_len, hd))
|
||||
if 'gaze' in self.mods:
|
||||
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
|
||||
feats.append(gaze_feats)
|
||||
if 'bbox' in self.mods:
|
||||
feats.append(bbox_feats.mean(2))
|
||||
lstm_inp = torch.cat(feats, 2)
|
||||
lstm_out, (h_n, c_n) = self.LSTM(self.dropout(lstm_inp))
|
||||
c_n = c_n.mean(0, keepdim=True).permute(1, 0, 2)
|
||||
return self.act(self.proj(lstm_out)), c_n, feats
|
||||
|
||||
|
||||
class MindNetSL(nn.Module):
|
||||
"""
|
||||
Basic MindNet for SL ToM, just LSTM on input concatenation
|
||||
"""
|
||||
def __init__(self, hidden_dim, dropout, mods):
|
||||
super(MindNetSL, self).__init__()
|
||||
self.mods = mods
|
||||
if 'rgb_1' in mods:
|
||||
self.img_emb = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
if 'gaze' in mods:
|
||||
self.gaze_emb = nn.Linear(2, hidden_dim)
|
||||
if 'pose' in mods:
|
||||
self.pose_edge_index = pose_edge_index()
|
||||
self.pose_emb = GCNConv(3, hidden_dim)
|
||||
self.LSTM = PreNorm(
|
||||
hidden_dim*len(mods),
|
||||
nn.LSTM(input_size=hidden_dim*len(mods), hidden_size=hidden_dim, batch_first=True, bidirectional=True))
|
||||
self.proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.act = nn.GELU()
|
||||
|
||||
def forward(self, rgb_3rd_pov_feats, bbox_feats, rgb_1st_pov, pose, gaze):
|
||||
feats = []
|
||||
if 'rgb_3' in self.mods:
|
||||
feats.append(rgb_3rd_pov_feats)
|
||||
if 'rgb_1' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(rgb_1st_pov.shape[1]):
|
||||
images_i = rgb_1st_pov[:,i]
|
||||
img_i_feat = self.img_emb(images_i)
|
||||
img_i_feat = img_i_feat.view(rgb_1st_pov.shape[0], -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feats = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
feats.append(rgb_feats)
|
||||
if 'pose' in self.mods:
|
||||
bs, seq_len = pose.size(0), pose.size(1)
|
||||
self.pose_edge_index = self.pose_edge_index.to(pose.device)
|
||||
pose_emb = self.pose_emb(pose.view(bs*seq_len, 26, 3), self.pose_edge_index)
|
||||
pose_emb = self.dropout(self.act(pose_emb))
|
||||
pose_emb = torch.mean(pose_emb, dim=1)
|
||||
hd = pose_emb.size(-1)
|
||||
feats.append(pose_emb.view(bs, seq_len, hd))
|
||||
if 'gaze' in self.mods:
|
||||
gaze_feats = self.dropout(self.act(self.gaze_emb(gaze)))
|
||||
feats.append(gaze_feats)
|
||||
if 'bbox' in self.mods:
|
||||
feats.append(bbox_feats.mean(2))
|
||||
lstm_inp = torch.cat(feats, 2)
|
||||
lstm_out, _ = self.LSTM(self.dropout(lstm_inp))
|
||||
return self.act(self.proj(lstm_out)), feats
|
157
tbd/models/common_mind.py
Normal file
157
tbd/models/common_mind.py
Normal file
|
@ -0,0 +1,157 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from .base import CNN, MindNetLSTM
|
||||
from memory_efficient_attention_pytorch import Attention
|
||||
|
||||
|
||||
class CommonMindToMnet(nn.Module):
|
||||
"""
|
||||
img: bs, 3, 128, 128
|
||||
pose: bs, 26, 3
|
||||
gaze: bs, 2 NOTE: only tracker has gaze
|
||||
bbox: bs, 4
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_1', 'rgb_3', 'pose', 'gaze', 'bbox']):
|
||||
super(CommonMindToMnet, self).__init__()
|
||||
|
||||
self.aggr = aggr
|
||||
self.mods = mods
|
||||
|
||||
# ---- 3rd POV Images, object and bbox ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
#for param in self.cnn.parameters():
|
||||
# param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(4, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
|
||||
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
|
||||
if aggr != 'no_tom': self.cm_proj = nn.Linear(hidden_dim*2, hidden_dim)
|
||||
self.ln_1 = nn.LayerNorm(hidden_dim)
|
||||
self.ln_2 = nn.LayerNorm(hidden_dim)
|
||||
if aggr == 'attn':
|
||||
self.attn_left = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.attn_right = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.m1 = nn.Linear(hidden_dim, 4)
|
||||
self.m2 = nn.Linear(hidden_dim, 4)
|
||||
self.m12 = nn.Linear(hidden_dim, 4)
|
||||
self.m21 = nn.Linear(hidden_dim, 4)
|
||||
self.mc = nn.Linear(hidden_dim, 4)
|
||||
|
||||
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
|
||||
|
||||
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
|
||||
|
||||
if 'bbox' in self.mods:
|
||||
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
|
||||
else:
|
||||
bbox_feat = None
|
||||
|
||||
if 'rgb_3' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = img_3rd_pov[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
else:
|
||||
rgb_feat_3rd_pov = None
|
||||
|
||||
if tracker_id == 'skele1':
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
|
||||
else:
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
|
||||
|
||||
if self.aggr == 'no_tom':
|
||||
m1 = self.m1(out_1).mean(1)
|
||||
m2 = self.m2(out_2).mean(1)
|
||||
m12 = self.m12(out_1).mean(1)
|
||||
m21 = self.m21(out_2).mean(1)
|
||||
mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
|
||||
|
||||
common_mind = self.cm_proj(torch.cat([cell_1, cell_2], -1)) # (bs, 1, h)
|
||||
|
||||
if self.aggr == 'attn':
|
||||
p1 = self.attn_left(x=out_1, context=common_mind)
|
||||
p2 = self.attn_right(x=out_2, context=common_mind)
|
||||
elif self.aggr == 'mult':
|
||||
p1 = out_1 * common_mind
|
||||
p2 = out_2 * common_mind
|
||||
elif self.aggr == 'sum':
|
||||
p1 = out_1 + common_mind
|
||||
p2 = out_2 + common_mind
|
||||
elif self.aggr == 'concat':
|
||||
p1 = torch.cat([out_1, common_mind], 1)
|
||||
p2 = torch.cat([out_2, common_mind], 1)
|
||||
else: raise ValueError
|
||||
p1 = self.act(p1)
|
||||
p1 = self.ln_1(p1)
|
||||
p2 = self.act(p2)
|
||||
p2 = self.ln_2(p2)
|
||||
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
|
||||
m1 = self.m1(p1).mean(1)
|
||||
m2 = self.m2(p2).mean(1)
|
||||
m12 = self.m12(p1).mean(1)
|
||||
m21 = self.m21(p2).mean(1)
|
||||
mc = self.mc(p1*p2).mean(1)
|
||||
if self.aggr == 'concat':
|
||||
m1 = self.m1(p1).mean(1)
|
||||
m2 = self.m2(p2).mean(1)
|
||||
m12 = self.m12(p1).mean(1)
|
||||
m21 = self.m21(p2).mean(1)
|
||||
mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, out_2, common_mind] + feats_1 + feats_2
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
|
||||
img_tracker = torch.ones(3, 5, 3, 128, 128)
|
||||
img_battery = torch.ones(3, 5, 3, 128, 128)
|
||||
pose1 = torch.ones(3, 5, 26, 3)
|
||||
pose2 = torch.ones(3, 5, 26, 3)
|
||||
bbox = torch.ones(3, 5, 13, 4)
|
||||
tracker_id = 'skele1'
|
||||
gaze = torch.ones(3, 5, 2)
|
||||
mods = ['pose', 'bbox', 'rgb_3']
|
||||
|
||||
for agg in ['no_tom']:
|
||||
model = CommonMindToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg, mods=mods)
|
||||
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
|
||||
print(out[0].shape)
|
151
tbd/models/implicit.py
Normal file
151
tbd/models/implicit.py
Normal file
|
@ -0,0 +1,151 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from .base import CNN, MindNetLSTM
|
||||
from memory_efficient_attention_pytorch import Attention
|
||||
|
||||
|
||||
class ImplicitToMnet(nn.Module):
|
||||
"""
|
||||
Implicit ToM net. Supports any subset of modalities
|
||||
Possible aggregations: sum, mult, attn, concat
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, aggr='sum', mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
|
||||
super(ImplicitToMnet, self).__init__()
|
||||
|
||||
self.aggr = aggr
|
||||
self.mods = mods
|
||||
|
||||
# ---- 3rd POV Images, object and bbox ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(4, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
|
||||
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
|
||||
self.ln_1 = nn.LayerNorm(hidden_dim)
|
||||
self.ln_2 = nn.LayerNorm(hidden_dim)
|
||||
if aggr == 'attn':
|
||||
self.attn_left = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.attn_right = Attention(
|
||||
dim = hidden_dim,
|
||||
dim_head = hidden_dim // 4,
|
||||
heads = 4,
|
||||
memory_efficient = True,
|
||||
q_bucket_size = hidden_dim,
|
||||
k_bucket_size = hidden_dim)
|
||||
self.m1 = nn.Linear(hidden_dim, 4)
|
||||
self.m2 = nn.Linear(hidden_dim, 4)
|
||||
self.m12 = nn.Linear(hidden_dim, 4)
|
||||
self.m21 = nn.Linear(hidden_dim, 4)
|
||||
self.mc = nn.Linear(hidden_dim, 4)
|
||||
|
||||
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
|
||||
|
||||
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
|
||||
|
||||
if 'bbox' in self.mods:
|
||||
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
|
||||
else:
|
||||
bbox_feat = None
|
||||
|
||||
if 'rgb_3' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = img_3rd_pov[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
else:
|
||||
rgb_feat_3rd_pov = None
|
||||
|
||||
if tracker_id == 'skele1':
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
|
||||
else:
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
|
||||
|
||||
if self.aggr == 'no_tom':
|
||||
m1 = self.m1(out_1).mean(1)
|
||||
m2 = self.m2(out_2).mean(1)
|
||||
m12 = self.m12(out_1).mean(1)
|
||||
m21 = self.m21(out_2).mean(1)
|
||||
mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
|
||||
|
||||
if self.aggr == 'attn':
|
||||
p1 = self.attn_left(x=out_1, context=cell_2)
|
||||
p2 = self.attn_right(x=out_2, context=cell_1)
|
||||
elif self.aggr == 'mult':
|
||||
p1 = out_1 * cell_2
|
||||
p2 = out_2 * cell_1
|
||||
elif self.aggr == 'sum':
|
||||
p1 = out_1 + cell_2
|
||||
p2 = out_2 + cell_1
|
||||
elif self.aggr == 'concat':
|
||||
p1 = torch.cat([out_1, cell_2], 1)
|
||||
p2 = torch.cat([out_2, cell_1], 1)
|
||||
else: raise ValueError
|
||||
p1 = self.act(p1)
|
||||
p1 = self.ln_1(p1)
|
||||
p2 = self.act(p2)
|
||||
p2 = self.ln_2(p2)
|
||||
if self.aggr == 'mult' or self.aggr == 'sum' or self.aggr == 'attn':
|
||||
m1 = self.m1(p1).mean(1)
|
||||
m2 = self.m2(p2).mean(1)
|
||||
m12 = self.m12(p1).mean(1)
|
||||
m21 = self.m21(p2).mean(1)
|
||||
mc = self.mc(p1*p2).mean(1)
|
||||
if self.aggr == 'concat':
|
||||
m1 = self.m1(p1).mean(1)
|
||||
m2 = self.m2(p2).mean(1)
|
||||
m12 = self.m12(p1).mean(1)
|
||||
m21 = self.m21(p2).mean(1)
|
||||
mc = self.mc(p1*p2).mean(1) # NOTE: here I multiply p1 and p2
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
|
||||
img_tracker = torch.ones(3, 5, 3, 128, 128)
|
||||
img_battery = torch.ones(3, 5, 3, 128, 128)
|
||||
pose1 = torch.ones(3, 5, 26, 3)
|
||||
pose2 = torch.ones(3, 5, 26, 3)
|
||||
bbox = torch.ones(3, 5, 13, 4)
|
||||
tracker_id = 'skele1'
|
||||
gaze = torch.ones(3, 5, 2)
|
||||
|
||||
for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
|
||||
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5, aggr=agg)
|
||||
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
|
||||
print(agg, out[0].shape)
|
112
tbd/models/sl.py
Normal file
112
tbd/models/sl.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from .base import CNN, MindNetSL
|
||||
|
||||
|
||||
class SLToMnet(nn.Module):
|
||||
"""
|
||||
Speaker-Listener ToMnet
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, tom_weight, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
|
||||
super(SLToMnet, self).__init__()
|
||||
|
||||
self.tom_weight = tom_weight
|
||||
self.mods = mods
|
||||
|
||||
# ---- 3rd POV Images, object and bbox ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(4, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_1 = MindNetSL(hidden_dim, dropout, mods=mods)
|
||||
self.mind_net_2 = MindNetSL(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
|
||||
self.m1 = nn.Linear(hidden_dim, 4)
|
||||
self.m2 = nn.Linear(hidden_dim, 4)
|
||||
self.m12 = nn.Linear(hidden_dim, 4)
|
||||
self.m21 = nn.Linear(hidden_dim, 4)
|
||||
self.mc = nn.Linear(hidden_dim, 4)
|
||||
|
||||
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
|
||||
|
||||
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
|
||||
|
||||
if 'bbox' in self.mods:
|
||||
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
|
||||
else:
|
||||
bbox_feat = None
|
||||
|
||||
if 'rgb_3' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = img_3rd_pov[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
else:
|
||||
rgb_feat_3rd_pov = None
|
||||
|
||||
if tracker_id == 'skele1':
|
||||
out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
|
||||
out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
|
||||
else:
|
||||
out_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
|
||||
out_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
|
||||
|
||||
m1_logits = self.m1(out_1).mean(1)
|
||||
m2_logits = self.m2(out_2).mean(1)
|
||||
m12_logits = self.m12(out_1).mean(1)
|
||||
m21_logits = self.m21(out_2).mean(1)
|
||||
mc_logits = self.mc(out_1*out_2).mean(1)
|
||||
|
||||
m1_ranking = torch.log_softmax(m1_logits, dim=-1)
|
||||
m2_ranking = torch.log_softmax(m2_logits, dim=-1)
|
||||
m12_ranking = torch.log_softmax(m12_logits, dim=-1)
|
||||
m21_ranking = torch.log_softmax(m21_logits, dim=-1)
|
||||
mc_ranking = torch.log_softmax(mc_logits, dim=-1)
|
||||
|
||||
# NOTE: does this make sense?
|
||||
m1 = m1_ranking + self.tom_weight * m2_ranking
|
||||
m2 = m2_ranking + self.tom_weight * m1_ranking
|
||||
m12 = m12_ranking + self.tom_weight * m21_ranking
|
||||
m21 = m21_ranking + self.tom_weight * m12_ranking
|
||||
mc = mc_ranking + self.tom_weight * mc_ranking
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, out_2] + feats_1 + feats_2
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
|
||||
img_tracker = torch.ones(3, 5, 3, 128, 128)
|
||||
img_battery = torch.ones(3, 5, 3, 128, 128)
|
||||
pose1 = torch.ones(3, 5, 26, 3)
|
||||
pose2 = torch.ones(3, 5, 26, 3)
|
||||
bbox = torch.ones(3, 5, 13, 4)
|
||||
tracker_id = 'skele1'
|
||||
gaze = torch.ones(3, 5, 2)
|
||||
|
||||
model = SLToMnet(hidden_dim=64, device='cpu', tom_weight=2.0, resnet=False, dropout=0.5)
|
||||
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
|
||||
print(out[0].shape)
|
112
tbd/models/tom_base.py
Normal file
112
tbd/models/tom_base.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from .base import CNN, MindNetLSTM
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ImplicitToMnet(nn.Module):
|
||||
"""
|
||||
Implicit ToM net. Supports any subset of modalities
|
||||
Possible aggregations: sum, mult, attn, concat
|
||||
"""
|
||||
def __init__(self, hidden_dim, device, resnet=False, dropout=0.1, mods=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox']):
|
||||
super(ImplicitToMnet, self).__init__()
|
||||
|
||||
self.mods = mods
|
||||
|
||||
# ---- 3rd POV Images, object and bbox ----#
|
||||
if resnet:
|
||||
resnet = models.resnet34(weights="IMAGENET1K_V1")
|
||||
self.cnn = nn.Sequential(
|
||||
*(list(resnet.children())[:-1])
|
||||
)
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.rgb_ff = nn.Linear(512, hidden_dim)
|
||||
else:
|
||||
self.cnn = CNN(hidden_dim)
|
||||
self.rgb_ff = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.bbox_ff = nn.Linear(4, hidden_dim)
|
||||
|
||||
# ---- Others ----#
|
||||
self.act = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.device = device
|
||||
|
||||
# ---- Mind nets ----#
|
||||
self.mind_net_1 = MindNetLSTM(hidden_dim, dropout, mods=mods)
|
||||
self.mind_net_2 = MindNetLSTM(hidden_dim, dropout, mods=[m for m in mods if m != 'gaze'])
|
||||
|
||||
self.m1 = nn.Linear(hidden_dim, 4)
|
||||
self.m2 = nn.Linear(hidden_dim, 4)
|
||||
self.m12 = nn.Linear(hidden_dim, 4)
|
||||
self.m21 = nn.Linear(hidden_dim, 4)
|
||||
self.mc = nn.Linear(hidden_dim, 4)
|
||||
|
||||
def forward(self, img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze):
|
||||
|
||||
batch_size, sequence_len, channels, height, width = img_3rd_pov.shape
|
||||
|
||||
if 'bbox' in self.mods:
|
||||
bbox_feat = self.dropout(self.act(self.bbox_ff(bbox)))
|
||||
else:
|
||||
bbox_feat = None
|
||||
|
||||
if 'rgb_3' in self.mods:
|
||||
rgb_feat = []
|
||||
for i in range(sequence_len):
|
||||
images_i = img_3rd_pov[:,i]
|
||||
img_i_feat = self.cnn(images_i)
|
||||
img_i_feat = img_i_feat.view(batch_size, -1)
|
||||
rgb_feat.append(img_i_feat)
|
||||
rgb_feat = torch.stack(rgb_feat, 1)
|
||||
rgb_feat_3rd_pov = self.dropout(self.act(self.rgb_ff(rgb_feat)))
|
||||
else:
|
||||
rgb_feat_3rd_pov = None
|
||||
|
||||
if tracker_id == 'skele1':
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose1, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose2, gaze=None)
|
||||
else:
|
||||
out_1, cell_1, feats_1 = self.mind_net_1(rgb_feat_3rd_pov, bbox_feat, img_tracker, pose2, gaze)
|
||||
out_2, cell_2, feats_2 = self.mind_net_2(rgb_feat_3rd_pov, bbox_feat, img_battery, pose1, gaze=None)
|
||||
|
||||
if self.aggr == 'no_tom':
|
||||
m1 = self.m1(out_1).mean(1)
|
||||
m2 = self.m2(out_2).mean(1)
|
||||
m12 = self.m12(out_1).mean(1)
|
||||
m21 = self.m21(out_2).mean(1)
|
||||
mc = self.mc(out_1*out_2).mean(1) # NOTE: if no_tom then mc is computed starting from the concat of out_1 and out_2
|
||||
|
||||
return m1, m2, m12, m21, mc, [out_1, cell_1, out_2, cell_2] + feats_1 + feats_2
|
||||
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
#return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
return sum([np.prod(p.size()) for p in model_parameters])
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
img_3rd_pov = torch.ones(3, 5, 3, 128, 128)
|
||||
img_tracker = torch.ones(3, 5, 3, 128, 128)
|
||||
img_battery = torch.ones(3, 5, 3, 128, 128)
|
||||
pose1 = torch.ones(3, 5, 26, 3)
|
||||
pose2 = torch.ones(3, 5, 26, 3)
|
||||
bbox = torch.ones(3, 5, 13, 4)
|
||||
tracker_id = 'skele1'
|
||||
gaze = torch.ones(3, 5, 2)
|
||||
|
||||
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5)
|
||||
print(count_parameters(model))
|
||||
breakpoint()
|
||||
|
||||
for agg in ['no_tom', 'concat', 'sum', 'mult', 'attn']:
|
||||
model = ImplicitToMnet(hidden_dim=64, device='cpu', resnet=False, dropout=0.5)
|
||||
out = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
|
||||
print(agg, out[0].shape)
|
7
tbd/models/utils.py
Normal file
7
tbd/models/utils.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
import torch
|
||||
|
||||
|
||||
def pose_edge_index():
|
||||
start = [15, 14, 13, 12, 19, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 3, 4, 5, 6, 8, 8, 4, 20, 21, 21, 22, 24, 22]
|
||||
end = [14, 13, 12, 0, 18, 17, 16, 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 4, 20, 20, 21, 22, 24, 23, 25, 24]
|
||||
return torch.tensor([start+end, end+start])
|
186
tbd/results/abl.json
Normal file
186
tbd/results/abl.json
Normal file
|
@ -0,0 +1,186 @@
|
|||
{
|
||||
"all": [
|
||||
{
|
||||
"m1": 0.3803337111550836,
|
||||
"m2": 0.3900899763574355,
|
||||
"m12": 0.4441281276628709,
|
||||
"m21": 0.4818757648120031,
|
||||
"mc": 0.4485177767702456
|
||||
},
|
||||
{
|
||||
"m1": 0.5186066842992191,
|
||||
"m2": 0.521895750052127,
|
||||
"m12": 0.49294626980529677,
|
||||
"m21": 0.4810118034501327,
|
||||
"mc": 0.6097300398369058
|
||||
},
|
||||
{
|
||||
"m1": 0.4965589148309122,
|
||||
"m2": 0.5094894309980568,
|
||||
"m12": 0.4615136302786905,
|
||||
"m21": 0.4554005550423429,
|
||||
"mc": 0.6258118710785031
|
||||
}
|
||||
],
|
||||
"rgb_3_pose_gaze_bbox": [
|
||||
{
|
||||
"m1": 0.3776045061727805,
|
||||
"m2": 0.3996776745150713,
|
||||
"m12": 0.4762772810038159,
|
||||
"m21": 0.48643178296718503,
|
||||
"mc": 0.4575207273412474
|
||||
},
|
||||
{
|
||||
"m1": 0.5176564423560418,
|
||||
"m2": 0.5109344883698214,
|
||||
"m12": 0.4630213122846928,
|
||||
"m21": 0.4826608133674547,
|
||||
"mc": 0.5979415365779003
|
||||
},
|
||||
{
|
||||
"m1": 0.5114692300997931,
|
||||
"m2": 0.5027048375802656,
|
||||
"m12": 0.47527894405588544,
|
||||
"m21": 0.45223985157847546,
|
||||
"mc": 0.6054099305712209
|
||||
}
|
||||
],
|
||||
"rgb_3_pose_gaze": [
|
||||
{
|
||||
"m1": 0.403207421026191,
|
||||
"m2": 0.3833413122398237,
|
||||
"m12": 0.4602455224198077,
|
||||
"m21": 0.47181798537346287,
|
||||
"mc": 0.4603675297898878
|
||||
},
|
||||
{
|
||||
"m1": 0.49484810149311514,
|
||||
"m2": 0.5060275976807422,
|
||||
"m12": 0.4610412452830618,
|
||||
"m21": 0.46869095956564044,
|
||||
"mc": 0.6040674897817755
|
||||
},
|
||||
{
|
||||
"m1": 0.5160598186177866,
|
||||
"m2": 0.5309683014233921,
|
||||
"m12": 0.47227245803060636,
|
||||
"m21": 0.46953974307035984,
|
||||
"mc": 0.6014771460423635
|
||||
}
|
||||
],
|
||||
"rgb_3_pose": [
|
||||
{
|
||||
"m1": 0.4057149181928123,
|
||||
"m2": 0.4002233785689204,
|
||||
"m12": 0.46794813614607333,
|
||||
"m21": 0.4690365183933033,
|
||||
"mc": 0.4591530208921514
|
||||
},
|
||||
{
|
||||
"m1": 0.5362792166212834,
|
||||
"m2": 0.5290656046231254,
|
||||
"m12": 0.4569419683345858,
|
||||
"m21": 0.4530255281497826,
|
||||
"mc": 0.4554252731371068
|
||||
},
|
||||
{
|
||||
"m1": 0.49570625763169085,
|
||||
"m2": 0.5146503967646507,
|
||||
"m12": 0.4567936139893578,
|
||||
"m21": 0.45918214877096325,
|
||||
"mc": 0.5962397441246001
|
||||
}
|
||||
],
|
||||
"rgb_3_gaze": [
|
||||
{
|
||||
"m1": 0.40135106828655215,
|
||||
"m2": 0.38453470155825614,
|
||||
"m12": 0.4989742833725901,
|
||||
"m21": 0.47369273992079175,
|
||||
"mc": 0.48430622854433986
|
||||
},
|
||||
{
|
||||
"m1": 0.508038122818153,
|
||||
"m2": 0.4875748099051746,
|
||||
"m12": 0.46665443622698555,
|
||||
"m21": 0.46635808547742913,
|
||||
"mc": 0.47936993226840163
|
||||
},
|
||||
{
|
||||
"m1": 0.49795853039610977,
|
||||
"m2": 0.5028666890527814,
|
||||
"m12": 0.44176709237564815,
|
||||
"m21": 0.4483898274665582,
|
||||
"mc": 0.5867527750929912
|
||||
}
|
||||
],
|
||||
"rgb_3_bbox": [
|
||||
{
|
||||
"m1": 0.3951383898241492,
|
||||
"m2": 0.3818794542844425,
|
||||
"m12": 0.44108151735270384,
|
||||
"m21": 0.46539754196523303,
|
||||
"mc": 0.43982185797713114
|
||||
},
|
||||
{
|
||||
"m1": 0.5093846655989521,
|
||||
"m2": 0.4923439212866733,
|
||||
"m12": 0.4598003475323884,
|
||||
"m21": 0.47647640659290746,
|
||||
"mc": 0.6349953712994137
|
||||
},
|
||||
{
|
||||
"m1": 0.5325224862402295,
|
||||
"m2": 0.5092319973570975,
|
||||
"m12": 0.4435807136490263,
|
||||
"m21": 0.4576911633624616,
|
||||
"mc": 0.6282064277856357
|
||||
}
|
||||
],
|
||||
"rgb_3_rgb_1": [
|
||||
{
|
||||
"m1": 0.39189391736691903,
|
||||
"m2": 0.3739995635963588,
|
||||
"m12": 0.4792392731637056,
|
||||
"m21": 0.4592726043789752,
|
||||
"mc": 0.4468645255652386
|
||||
},
|
||||
{
|
||||
"m1": 0.4827892482357646,
|
||||
"m2": 0.48042899735042716,
|
||||
"m12": 0.45932653547051094,
|
||||
"m21": 0.48430209616318126,
|
||||
"mc": 0.4506104344435269
|
||||
},
|
||||
{
|
||||
"m1": 0.4820247145474279,
|
||||
"m2": 0.3667553358192628,
|
||||
"m12": 0.44503028688537,
|
||||
"m21": 0.45984906207471654,
|
||||
"mc": 0.465120658971623
|
||||
}
|
||||
],
|
||||
"rgb_3": [
|
||||
{
|
||||
"m1": 0.40725462165126114,
|
||||
"m2": 0.38737351624656846,
|
||||
"m12": 0.46230461548252094,
|
||||
"m21": 0.4829312519709871,
|
||||
"mc": 0.4492175856929955
|
||||
},
|
||||
{
|
||||
"m1": 0.5286274183685061,
|
||||
"m2": 0.5081429492163979,
|
||||
"m12": 0.4610256989472217,
|
||||
"m21": 0.4733487634477733,
|
||||
"mc": 0.4655243312197501
|
||||
},
|
||||
{
|
||||
"m1": 0.5217968210271873,
|
||||
"m2": 0.5103780571157844,
|
||||
"m12": 0.4431266771306429,
|
||||
"m21": 0.48398542131284883,
|
||||
"mc": 0.6122314353959392
|
||||
}
|
||||
]
|
||||
}
|
232
tbd/results/all.json
Normal file
232
tbd/results/all.json
Normal file
|
@ -0,0 +1,232 @@
|
|||
{
|
||||
"cm_concat": [
|
||||
{
|
||||
"m1": 0.38921744471949393,
|
||||
"m2": 0.38557137008494935,
|
||||
"m12": 0.44699534554593756,
|
||||
"m21": 0.4747474437468054,
|
||||
"mc": 0.4918107834016411
|
||||
},
|
||||
{
|
||||
"m1": 0.5402415140026018,
|
||||
"m2": 0.48833721513836786,
|
||||
"m12": 0.4631512445419047,
|
||||
"m21": 0.4740880083492652,
|
||||
"mc": 0.6375070925808958
|
||||
},
|
||||
{
|
||||
"m1": 0.5012543523713172,
|
||||
"m2": 0.5068694866895836,
|
||||
"m12": 0.4451537834591627,
|
||||
"m21": 0.45215784721598673,
|
||||
"mc": 0.6201022576104379
|
||||
}
|
||||
],
|
||||
"cm_sum": [
|
||||
{
|
||||
"m1": 0.39403894801783246,
|
||||
"m2": 0.38541918219411786,
|
||||
"m12": 0.4600376974144952,
|
||||
"m21": 0.471919704007463,
|
||||
"mc": 0.43950812310207055
|
||||
},
|
||||
{
|
||||
"m1": 0.48497621104052574,
|
||||
"m2": 0.5295044689855949,
|
||||
"m12": 0.4502949472343065,
|
||||
"m21": 0.47823492553894387,
|
||||
"mc": 0.6028290833617195
|
||||
},
|
||||
{
|
||||
"m1": 0.503386104373653,
|
||||
"m2": 0.49983127146477085,
|
||||
"m12": 0.46782817568218116,
|
||||
"m21": 0.45484578845116075,
|
||||
"mc": 0.5905749126722909
|
||||
}
|
||||
],
|
||||
"cm_mult": [
|
||||
{
|
||||
"m1": 0.39070820515470606,
|
||||
"m2": 0.3996851353903932,
|
||||
"m12": 0.4455704586852128,
|
||||
"m21": 0.4713517869738811,
|
||||
"mc": 0.4450907029478458
|
||||
},
|
||||
{
|
||||
"m1": 0.5066540697731119,
|
||||
"m2": 0.526507445454099,
|
||||
"m12": 0.462643008560599,
|
||||
"m21": 0.48263054309565334,
|
||||
"mc": 0.6438566476782207
|
||||
},
|
||||
{
|
||||
"m1": 0.48868811674304546,
|
||||
"m2": 0.5074635877653536,
|
||||
"m12": 0.44597405775819876,
|
||||
"m21": 0.45445350963025877,
|
||||
"mc": 0.5884265473527218
|
||||
}
|
||||
],
|
||||
"cm_attn": [
|
||||
{
|
||||
"m1": 0.3949557687114269,
|
||||
"m2": 0.3919385900921811,
|
||||
"m12": 0.4850081112466773,
|
||||
"m21": 0.4849575556679713,
|
||||
"mc": 0.4516870089239762
|
||||
},
|
||||
{
|
||||
"m1": 0.4925989821370256,
|
||||
"m2": 0.49409170532242247,
|
||||
"m12": 0.4664647278240569,
|
||||
"m21": 0.46783863397462533,
|
||||
"mc": 0.6398721139927354
|
||||
},
|
||||
{
|
||||
"m1": 0.4945636568169018,
|
||||
"m2": 0.5049812790749876,
|
||||
"m12": 0.454359577718189,
|
||||
"m21": 0.4712184012093268,
|
||||
"mc": 0.5992735441011302
|
||||
}
|
||||
],
|
||||
"no_tom": [
|
||||
{
|
||||
"m1": 0.2570551317,
|
||||
"m2": 0.375350929686332,
|
||||
"m12": 0.312451988649724,
|
||||
"m21": 0.4631371031641,
|
||||
"mc": 0.457486278214567
|
||||
},
|
||||
{
|
||||
"m1": 0.233046800382043,
|
||||
"m2": 0.522609755931958,
|
||||
"m12": 0.326821758467328,
|
||||
"m21": 0.474338898013257,
|
||||
"mc": 0.604439456291308
|
||||
},
|
||||
{
|
||||
"m1": 0.33774852598382,
|
||||
"m2": 0.520943544364353,
|
||||
"m12": 0.298617214416867,
|
||||
"m21": 0.482175301427192,
|
||||
"mc": 0.634948478570852
|
||||
}
|
||||
],
|
||||
"sl": [
|
||||
{
|
||||
"m1": 0.365205706591741,
|
||||
"m2": 0.255259363011619,
|
||||
"m12": 0.421227579844245,
|
||||
"m21": 0.376143327741882,
|
||||
"mc": 0.45614515353718
|
||||
},
|
||||
{
|
||||
"m1": 0.493046934143676,
|
||||
"m2": 0.331798174804139,
|
||||
"m12": 0.422821548330913,
|
||||
"m21": 0.399768928780549,
|
||||
"mc": 0.450957023549231
|
||||
},
|
||||
{
|
||||
"m1": 0.466266787709392,
|
||||
"m2": 0.350962671130227,
|
||||
"m12": 0.431694150269919,
|
||||
"m21": 0.378863431433258,
|
||||
"mc": 0.470284405744656
|
||||
}
|
||||
],
|
||||
"impl_concat": [
|
||||
{
|
||||
"m1": 0.38427302094644894,
|
||||
"m2": 0.38673879043767634,
|
||||
"m12": 0.45694337561663145,
|
||||
"m21": 0.4737891562722213,
|
||||
"mc": 0.4502976351448088
|
||||
},
|
||||
{
|
||||
"m1": 0.49951068243751173,
|
||||
"m2": 0.5084945752383908,
|
||||
"m12": 0.4604721097809549,
|
||||
"m21": 0.4826884970930907,
|
||||
"mc": 0.6200443272625361
|
||||
},
|
||||
{
|
||||
"m1": 0.5013244243339088,
|
||||
"m2": 0.49476495726495723,
|
||||
"m12": 0.4596701406290429,
|
||||
"m21": 0.4554742441542813,
|
||||
"mc": 0.5988949378402535
|
||||
}
|
||||
],
|
||||
"impl_sum": [
|
||||
{
|
||||
"m1": 0.3803337111550836,
|
||||
"m2": 0.3900899763574355,
|
||||
"m12": 0.4441281276628709,
|
||||
"m21": 0.4818757648120031,
|
||||
"mc": 0.4485177767702456
|
||||
},
|
||||
{
|
||||
"m1": 0.5186066842992191,
|
||||
"m2": 0.521895750052127,
|
||||
"m12": 0.49294626980529677,
|
||||
"m21": 0.4810118034501327,
|
||||
"mc": 0.6097300398369058
|
||||
},
|
||||
{
|
||||
"m1": 0.4965589148309122,
|
||||
"m2": 0.5094894309980568,
|
||||
"m12": 0.4615136302786905,
|
||||
"m21": 0.4554005550423429,
|
||||
"mc": 0.6258118710785031
|
||||
}
|
||||
],
|
||||
"impl_mult": [
|
||||
{
|
||||
"m1": 0.3789421413006731,
|
||||
"m2": 0.3818053844554785,
|
||||
"m12": 0.46402717346945177,
|
||||
"m21": 0.4903726261039529,
|
||||
"mc": 0.4461443806398687
|
||||
},
|
||||
{
|
||||
"m1": 0.3789421413006731,
|
||||
"m2": 0.3818053844554785,
|
||||
"m12": 0.46402717346945177,
|
||||
"m21": 0.4903726261039529,
|
||||
"mc": 0.4461443806398687
|
||||
},
|
||||
{
|
||||
"m1": 0.49338554196342077,
|
||||
"m2": 0.5066817652688608,
|
||||
"m12": 0.46253374461930613,
|
||||
"m21": 0.47782311190445825,
|
||||
"mc": 0.4581608719646799
|
||||
}
|
||||
],
|
||||
"impl_attn": [
|
||||
{
|
||||
"m1": 0.37413691393147924,
|
||||
"m2": 0.2546966838007244,
|
||||
"m12": 0.429390512693598,
|
||||
"m21": 0.292401773870023,
|
||||
"mc": 0.45706325836224465
|
||||
},
|
||||
{
|
||||
"m1": 0.513917904196177,
|
||||
"m2": 0.25802580258025803,
|
||||
"m12": 0.49272662664765543,
|
||||
"m21": 0.27041556176385584,
|
||||
"mc": 0.6041394755857196
|
||||
},
|
||||
{
|
||||
"m1": 0.47720445038981674,
|
||||
"m2": 0.25839328537170264,
|
||||
"m12": 0.46505055463781547,
|
||||
"m21": 0.260276985433943,
|
||||
"mc": 0.6021811271770562
|
||||
}
|
||||
]
|
||||
}
|
BIN
tbd/results/false_belief_first_vs_second.pdf
Normal file
BIN
tbd/results/false_belief_first_vs_second.pdf
Normal file
Binary file not shown.
87
tbd/results/fb_ttest.txt
Normal file
87
tbd/results/fb_ttest.txt
Normal file
|
@ -0,0 +1,87 @@
|
|||
|
||||
========================================================= m1_m2_m12_m21
|
||||
|
||||
|
||||
Model: Base -> yes
|
||||
|
||||
|
||||
Model: DB -> yes
|
||||
|
||||
|
||||
Model: CG$\oplus$ -> no
|
||||
|
||||
|
||||
Model: CG$\otimes$ -> no
|
||||
|
||||
|
||||
Model: CG$\odot$ -> no
|
||||
|
||||
|
||||
Model: IC$\parallel$ -> no
|
||||
|
||||
|
||||
Model: IC$\oplus$ -> no
|
||||
|
||||
|
||||
Model: IC$\otimes$ -> no
|
||||
|
||||
|
||||
Model: IC$\odot$ -> yes
|
||||
|
||||
========================================================= m1_m2
|
||||
|
||||
|
||||
Model: Base -> yes
|
||||
|
||||
|
||||
Model: DB -> yes
|
||||
|
||||
|
||||
Model: CG$\oplus$ -> no
|
||||
|
||||
|
||||
Model: CG$\otimes$ -> no
|
||||
|
||||
|
||||
Model: CG$\odot$ -> no
|
||||
|
||||
|
||||
Model: IC$\parallel$ -> no
|
||||
|
||||
|
||||
Model: IC$\oplus$ -> no
|
||||
|
||||
|
||||
Model: IC$\otimes$ -> no
|
||||
|
||||
|
||||
Model: IC$\odot$ -> yes
|
||||
|
||||
========================================================= m12_m21
|
||||
|
||||
|
||||
Model: Base -> yes
|
||||
|
||||
|
||||
Model: DB -> yes
|
||||
|
||||
|
||||
Model: CG$\oplus$ -> no
|
||||
|
||||
|
||||
Model: CG$\otimes$ -> no
|
||||
|
||||
|
||||
Model: CG$\odot$ -> no
|
||||
|
||||
|
||||
Model: IC$\parallel$ -> no
|
||||
|
||||
|
||||
Model: IC$\oplus$ -> no
|
||||
|
||||
|
||||
Model: IC$\otimes$ -> no
|
||||
|
||||
|
||||
Model: IC$\odot$ -> yes
|
59
tbd/results/hgm_scores.txt
Normal file
59
tbd/results/hgm_scores.txt
Normal file
|
@ -0,0 +1,59 @@
|
|||
mc =====================================================================
|
||||
precision recall f1-score support
|
||||
|
||||
0 0.000 0.500 0.001 50
|
||||
1 0.000 0.000 0.000 4
|
||||
2 0.004 0.038 0.007 238
|
||||
3 0.999 0.795 0.885 290788
|
||||
|
||||
accuracy 0.794 291080
|
||||
macro avg 0.251 0.333 0.223 291080
|
||||
weighted avg 0.998 0.794 0.884 291080
|
||||
|
||||
m1 =====================================================================
|
||||
precision recall f1-score support
|
||||
|
||||
0 0.000 0.000 0.000 147
|
||||
1 0.000 0.000 0.000 2
|
||||
2 0.025 0.051 0.033 1714
|
||||
3 0.994 0.988 0.991 289217
|
||||
|
||||
accuracy 0.982 291080
|
||||
macro avg 0.255 0.260 0.256 291080
|
||||
weighted avg 0.988 0.982 0.985 291080
|
||||
|
||||
m2 =====================================================================
|
||||
precision recall f1-score support
|
||||
|
||||
0 0.001 0.013 0.001 151
|
||||
2 0.031 0.084 0.045 2394
|
||||
3 0.992 0.970 0.981 288535
|
||||
|
||||
accuracy 0.962 291080
|
||||
macro avg 0.341 0.355 0.342 291080
|
||||
weighted avg 0.983 0.962 0.972 291080
|
||||
|
||||
m12 =====================================================================
|
||||
precision recall f1-score support
|
||||
|
||||
0 0.000 0.000 0.000 93
|
||||
1 0.000 0.000 0.000 8
|
||||
2 0.015 0.056 0.023 676
|
||||
3 0.997 0.990 0.994 290303
|
||||
|
||||
accuracy 0.988 291080
|
||||
macro avg 0.253 0.262 0.254 291080
|
||||
weighted avg 0.995 0.988 0.991 291080
|
||||
|
||||
m21 =====================================================================
|
||||
precision recall f1-score support
|
||||
|
||||
0 0.002 0.012 0.003 86
|
||||
1 0.000 0.000 0.000 12
|
||||
2 0.010 0.040 0.016 658
|
||||
3 0.997 0.989 0.993 290324
|
||||
|
||||
accuracy 0.987 291080
|
||||
macro avg 0.252 0.260 0.253 291080
|
||||
weighted avg 0.995 0.987 0.991 291080
|
||||
|
BIN
tbd/results/tbd_abl_avg_only.pdf
Normal file
BIN
tbd/results/tbd_abl_avg_only.pdf
Normal file
Binary file not shown.
12
tbd/run_test.sh
Normal file
12
tbd/run_test.sh
Normal file
|
@ -0,0 +1,12 @@
|
|||
#!/bin/bash
|
||||
|
||||
python -m test \
|
||||
--gpu_id 1 \
|
||||
--seed 1 \
|
||||
--non_blocking \
|
||||
--pin_memory \
|
||||
--model_type tom_cm \
|
||||
--aggr no_tom \
|
||||
--hidden_dim 64 \
|
||||
--batch_size 64 \
|
||||
--load_model_path /PATH/TO/model
|
16
tbd/run_train.sh
Normal file
16
tbd/run_train.sh
Normal file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
|
||||
|
||||
python -m train \
|
||||
--gpu_id 2 \
|
||||
--seed 123 \
|
||||
--logger \
|
||||
--non_blocking \
|
||||
--pin_memory \
|
||||
--batch_size 64 \
|
||||
--num_workers 16 \
|
||||
--num_epoch 300 \
|
||||
--lr 5e-4 \
|
||||
--dropout 0.1 \
|
||||
--model_type tom_cm \
|
||||
--aggr no_tom \
|
||||
--hidden_dim 64
|
568
tbd/tbd_dataloader.py
Normal file
568
tbd/tbd_dataloader.py
Normal file
|
@ -0,0 +1,568 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import pickle
|
||||
import torch
|
||||
import time
|
||||
import glob
|
||||
import random
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import cv2
|
||||
from itertools import product
|
||||
import csv
|
||||
import torchvision.transforms as T
|
||||
|
||||
from utils.helpers import tracker_skeID, CLIPS_IDS_88, ALL_IDS, UNIQUE_OBJ_IDS
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
# Unpack the batch into individual elements
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep = zip(*batch)
|
||||
|
||||
# Determine the maximum number of objects in any batch
|
||||
max_n_obj = max(bbox.shape[1] for bbox in bboxes)
|
||||
|
||||
# Pad the bounding box tensors
|
||||
bboxes_pad = []
|
||||
for bbox in bboxes:
|
||||
pad_size = max_n_obj - bbox.shape[1]
|
||||
pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
|
||||
padded_bbox = torch.cat((bbox, pad), dim=1)
|
||||
bboxes_pad.append(padded_bbox)
|
||||
|
||||
# Stack the padded tensors into a batch tensor
|
||||
bboxes_batch = torch.stack(bboxes_pad, dim=0)
|
||||
|
||||
img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
|
||||
img_tracker = torch.stack(img_tracker, dim=0)
|
||||
img_battery = torch.stack(img_battery, dim=0)
|
||||
pose1 = torch.stack(pose1, dim=0)
|
||||
pose2 = torch.stack(pose2, dim=0)
|
||||
gaze = torch.stack(gaze, dim=0)
|
||||
labels = torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
# Return the batched tensors
|
||||
return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep
|
||||
|
||||
def collate_fn_test(batch):
|
||||
# Unpack the batch into individual elements
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep, false_beliefs = zip(*batch)
|
||||
|
||||
# Determine the maximum number of objects in any batch
|
||||
max_n_obj = max(bbox.shape[1] for bbox in bboxes)
|
||||
|
||||
# Pad the bounding box tensors
|
||||
bboxes_pad = []
|
||||
for bbox in bboxes:
|
||||
pad_size = max_n_obj - bbox.shape[1]
|
||||
pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
|
||||
padded_bbox = torch.cat((bbox, pad), dim=1)
|
||||
bboxes_pad.append(padded_bbox)
|
||||
|
||||
# Stack the padded tensors into a batch tensor
|
||||
bboxes_batch = torch.stack(bboxes_pad, dim=0)
|
||||
|
||||
img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
|
||||
img_tracker = torch.stack(img_tracker, dim=0)
|
||||
img_battery = torch.stack(img_battery, dim=0)
|
||||
pose1 = torch.stack(pose1, dim=0)
|
||||
pose2 = torch.stack(pose2, dim=0)
|
||||
gaze = torch.stack(gaze, dim=0)
|
||||
labels = torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
# Return the batched tensors
|
||||
return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep, false_beliefs
|
||||
|
||||
|
||||
class TBDDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str = "/scratch/bortoletto/data/tbd",
|
||||
mode: str = "train",
|
||||
tbd_data_path: str = "/scratch/bortoletto/data/tbd/mind_lstm_training_cnn_att/",
|
||||
list_of_ids_to_consider: list = ALL_IDS,
|
||||
use_preprocessed_img: bool = True,
|
||||
resize_img: Optional[Union[tuple, int]] = (128,128),
|
||||
):
|
||||
"""TBD Dataset based on the 88 clip version of the TBD data.
|
||||
|
||||
Expects the following folder structure:
|
||||
- path
|
||||
- tracker_gt_smooth <- These are eye tracking from POV, 2D coordinates
|
||||
- images/*/ <- These are the images by experiment id
|
||||
- battery <- These are the images, 1st Person
|
||||
- tracker <- These are the images, 1st other Person w/ eye fixation
|
||||
- kinect <- These are the images, 3rd Person
|
||||
- skeleton <- Pose estimation, 3D coordinates
|
||||
- annotation <- These are the labels, i.e. [0,3] (see below)
|
||||
|
||||
|
||||
Labels are strcutured as follows:
|
||||
{
|
||||
"O1": [ <- Object with id O1
|
||||
{
|
||||
"m1": {
|
||||
"fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange
|
||||
"loc": null
|
||||
},
|
||||
"m2": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"m12": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"m21": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"mc": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"mg": {
|
||||
"fluent": 3,
|
||||
"loc": [
|
||||
22,
|
||||
9
|
||||
]
|
||||
}
|
||||
}, ...
|
||||
], ...
|
||||
}
|
||||
|
||||
This corresponds to a strict subset of the raw dataset collected
|
||||
by the TBD people in their paper "Learning Traidic Belief Dynamics
|
||||
in Nonverbal Communication from Videos" (CVPR2021, Oral).
|
||||
|
||||
We keep small amounts of data in memory (everything <100MB).
|
||||
Otherwise we read from disk on the fly. This dataset applies normalization.
|
||||
|
||||
Args:
|
||||
path (str, optional): Where the folders lie.
|
||||
Defaults to "/scratch/ruhdorfer/triadic_beleif_data_v2".
|
||||
list_of_ids_to_consider (list, optional): List of ids to consider.
|
||||
Defaults to ALL_IDS. Otherwise specify a list,
|
||||
e.g. ["test_94342_23", "test_boelter_21", ...].
|
||||
resize_img (Optional[Union[tuple, int]], optional): Resize image to
|
||||
this size if required. Defaults to None.
|
||||
"""
|
||||
print(f"Loading TBD Dataset in mode {mode}...")
|
||||
|
||||
self.mode = mode
|
||||
|
||||
start = time.time()
|
||||
|
||||
self.skeleton_3D_path = f"{path}/skeleton"
|
||||
self.tracker_2D_path = f"{path}/tracker_gt_smooth"
|
||||
self.bbox_csv_path = f"{path}/annotations_with_bbox.csv"
|
||||
if use_preprocessed_img:
|
||||
self.img_path = f"{path}/images_norm"
|
||||
else:
|
||||
self.img_path = f"{path}/images"
|
||||
self.obj_ids_path = f"{path}/mind_lstm_training_cnn_att_shu.pkl"
|
||||
|
||||
self.label_map = list(product([0, 1, 2, 3], repeat=5))
|
||||
|
||||
clips = os.listdir(tbd_data_path)
|
||||
data = []
|
||||
labels = []
|
||||
for clip in clips:
|
||||
with open(tbd_data_path + clip, 'rb') as f:
|
||||
vec_input, label_ = pickle.load(f, encoding='latin1')
|
||||
data = data + vec_input
|
||||
labels = labels + label_
|
||||
c = list(zip(data, labels))
|
||||
random.shuffle(c)
|
||||
train_ratio = int(len(c) * 0.6)
|
||||
validate_ratio = int(len(c) * 0.2)
|
||||
data, label = zip(*c)
|
||||
train_x, train_y = data[:train_ratio], label[:train_ratio]
|
||||
validate_x, validate_y = data[train_ratio:train_ratio + validate_ratio], label[train_ratio:train_ratio + validate_ratio]
|
||||
test_x, test_y = data[train_ratio + validate_ratio:], label[train_ratio + validate_ratio:]
|
||||
self.mind_count = np.zeros(1024) # used for CE weights
|
||||
|
||||
if mode == "train":
|
||||
self.data, self.labels = train_x, train_y
|
||||
elif mode == "val":
|
||||
self.data, self.labels = validate_x, validate_y
|
||||
elif mode == "test":
|
||||
self.data, self.labels = test_x, test_y
|
||||
|
||||
self.false_beliefs_path = f"{path}/store_mind_set"
|
||||
|
||||
# keep small amouts of data in memory
|
||||
self.skeleton_3D = self.load_skeleton_3D(self.skeleton_3D_path, list_of_ids_to_consider)
|
||||
self.tracker_2D = self.load_tracker_2D(self.tracker_2D_path, list_of_ids_to_consider)
|
||||
self.bbox_df = pd.read_csv(self.bbox_csv_path, header=0)
|
||||
self.obj_ids = self.load_obj_ids(self.obj_ids_path)
|
||||
|
||||
if not use_preprocessed_img:
|
||||
normalisation_steps = [
|
||||
T.ToTensor(),
|
||||
T.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
]
|
||||
if resize_img is not None:
|
||||
normalisation_steps.insert(1, T.Resize(resize_img))
|
||||
self.preprocess_img = T.Compose(normalisation_steps)
|
||||
else:
|
||||
self.preprocess_img = None
|
||||
|
||||
self.use_preprocessed_img = use_preprocessed_img
|
||||
print(f"Done loading in {time.time() - start}s.")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(
|
||||
self, idx: int
|
||||
) -> tuple[
|
||||
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict, str, int, str
|
||||
]:
|
||||
"""Given an index, return the corresponding experiment_id and timestep in the experiment.
|
||||
Then picky the appropriate data and labels from these.
|
||||
|
||||
Args:
|
||||
idx (int): _description_
|
||||
|
||||
Returns:
|
||||
tuple: torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict
|
||||
Returns the following:
|
||||
- img_kinect: torch.Tensor of shape (T, C, H, W) (Default is [T, 3, 720, 1280])
|
||||
- img_tracker: torch.Tensor of shape (T, H, W, C)
|
||||
- img_battery: torch.Tensor of shape (T, H, W, C)
|
||||
- skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 1)
|
||||
- skeleton_3D: torch.Tensor of shape (T, 26, 3) (skele 2)
|
||||
- bbox: torch.Tensor of shape (T, num_obj, 5)
|
||||
- tracker to skeleton ID: str (either skeleton 1 or 2)
|
||||
- tracker_2D: torch.Tensor of shape (T, 2)
|
||||
- labels: dict (see below)
|
||||
"""
|
||||
labels = self.label_map[self.labels[idx]]
|
||||
experiment_id = self.data[idx][1][0].split('/')[6]
|
||||
img_data_path = f"{self.img_path}/{experiment_id}"
|
||||
frame_ids = [int(os.path.basename(self.data[idx][1][i]).split('_')[0]) for i in range(len(self.data[idx][1]))]
|
||||
|
||||
if self.use_preprocessed_img:
|
||||
kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.pt")))
|
||||
tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.pt")))
|
||||
battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.pt")))
|
||||
kinect_img_paths = [kinect[id] for id in frame_ids]
|
||||
tracker_img_paths = [tracker[id] for id in frame_ids]
|
||||
battery_img_paths = [battery[id] for id in frame_ids]
|
||||
else:
|
||||
kinect = sorted(list(glob.glob(f"{img_data_path}/kinect/*.jpg")))
|
||||
tracker = sorted(list(glob.glob(f"{img_data_path}/tracker/*.jpg")))
|
||||
battery = sorted(list(glob.glob(f"{img_data_path}/battery/*.jpg")))
|
||||
kinect_img_paths = [kinect[id] for id in frame_ids]
|
||||
tracker_img_paths = [tracker[id] for id in frame_ids]
|
||||
battery_img_paths = [battery[id] for id in frame_ids]
|
||||
|
||||
# load images
|
||||
kinect_imgs = [
|
||||
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
|
||||
for img_path in kinect_img_paths
|
||||
]
|
||||
kinect_imgs = torch.stack(kinect_imgs, axis=0)
|
||||
|
||||
tracker_imgs = [
|
||||
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
|
||||
for img_path in tracker_img_paths
|
||||
]
|
||||
tracker_imgs = torch.stack(tracker_imgs, axis=0)
|
||||
|
||||
battery_imgs = [
|
||||
torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
|
||||
for img_path in battery_img_paths
|
||||
]
|
||||
battery_imgs = torch.stack(battery_imgs, axis=0)
|
||||
|
||||
# load object id to check for false beliefs - only for testing
|
||||
if self.mode == "test": #or self.mode == "train":
|
||||
if f"{experiment_id}.txt" in os.listdir(self.false_beliefs_path):
|
||||
obj_id = self.obj_ids[experiment_id][frame_ids[-1]]
|
||||
obj_id = next(x for x in obj_id if x is not None)
|
||||
false_belief = next((line.strip().split(',')[2] for line in open(f"{self.false_beliefs_path}/{experiment_id}.txt") if line.startswith(str(frame_ids[-1]) + ',' + obj_id + ',')), "no")
|
||||
#if experiment_id in ['test_boelter4_0', 'test_boelter4_7', 'test_boelter4_6', 'test_boelter4_8', 'test_boelter2_3',
|
||||
# 'test_94342_20', 'test_94342_18', 'test_94342_11', 'test_94342_17', 'test_boelter3_8', 'test_94342_2',
|
||||
# 'test_boelter2_17', 'test_boelter3_7', 'test_94342_4', 'test_boelter3_9', 'test_boelter_10',
|
||||
# 'test_boelter2_6', 'test_boelter4_10', 'test_boelter4_2', 'test_boelter4_5', 'test_94342_24',
|
||||
# 'test_94342_15', 'test_boelter3_5', 'test_94342_8', 'test2', 'test_boelter3_12']:
|
||||
# print('here!')
|
||||
# with open(os.path.join(f'results/hgm_test_fb.csv'), mode='a') as file:
|
||||
# writer = csv.writer(file)
|
||||
# writer.writerow([experiment_id, obj_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
|
||||
else:
|
||||
false_belief = "no"
|
||||
#with open(os.path.join(f'results/test_fb.csv'), mode='a') as file:
|
||||
# writer = csv.writer(file)
|
||||
# writer.writerow([experiment_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
|
||||
|
||||
df = self.bbox_df[
|
||||
(self.bbox_df.experiment_name == experiment_id)
|
||||
#& (self.bbox_df.name == obj_id) # NOTE: load the bounding boxes for all objects
|
||||
& (self.bbox_df.name != 'P1')
|
||||
& (self.bbox_df.name != 'P2')
|
||||
& (self.bbox_df.frame.isin(frame_ids))
|
||||
]
|
||||
|
||||
bboxes = []
|
||||
for f in frame_ids:
|
||||
bbox = torch.tensor(df.loc[df['frame'] == f, ["x_min", "y_min", "x_max", "y_max"]].to_numpy(), dtype=torch.float32)
|
||||
bbox[:, 0] = bbox[:, 0] / 1280.0
|
||||
bbox[:, 1] = bbox[:, 1] / 720.0
|
||||
bbox[:, 2] = bbox[:, 2] / 1280.0
|
||||
bbox[:, 3] = bbox[:, 3] / 720.0
|
||||
bboxes.append(bbox)
|
||||
bboxes = torch.stack(bboxes) # NOTE: this will need a collate function bc not every video has the same number of objects
|
||||
|
||||
skele1 = self.skeleton_3D[experiment_id]["skele1"][frame_ids]
|
||||
skele2 = self.skeleton_3D[experiment_id]["skele2"][frame_ids]
|
||||
|
||||
gaze = self.tracker_2D[experiment_id][frame_ids]
|
||||
|
||||
if self.mode == "test":
|
||||
return (
|
||||
kinect_imgs,
|
||||
tracker_imgs,
|
||||
battery_imgs,
|
||||
skele1,
|
||||
skele2,
|
||||
bboxes,
|
||||
tracker_skeID[experiment_id], # <- This is the tracker skeleton ID
|
||||
gaze,
|
||||
labels, # <- per object "m1", "m2", "m12", "m21", "mc"
|
||||
experiment_id,
|
||||
frame_ids,
|
||||
#self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding
|
||||
false_belief
|
||||
)
|
||||
else:
|
||||
return (
|
||||
kinect_imgs,
|
||||
tracker_imgs,
|
||||
battery_imgs,
|
||||
skele1,
|
||||
skele2,
|
||||
bboxes,
|
||||
tracker_skeID[experiment_id], # <- This is the tracker skeleton ID
|
||||
gaze,
|
||||
labels, # <- per object "m1", "m2", "m12", "m21", "mc"
|
||||
experiment_id,
|
||||
frame_ids
|
||||
#self.onehot(int(obj_id[1:])) # <- This is the object ID as a one-hot encoding
|
||||
)
|
||||
|
||||
def onehot(self, x, n=len(UNIQUE_OBJ_IDS)):
|
||||
retval = torch.zeros(n)
|
||||
if x > 0:
|
||||
retval[x-1] = 1
|
||||
return retval
|
||||
|
||||
def load_obj_ids(self, path: str):
|
||||
with open(path, "rb") as f:
|
||||
ids = pickle.load(f)
|
||||
return ids
|
||||
|
||||
def extract_labels(self):
|
||||
"""TODO: Converts index label to [m1, m2, m12, m21, mc] format.
|
||||
|
||||
"""
|
||||
return
|
||||
|
||||
def _flatten_mind_obj_timestep(self, mind_obj_dict: dict) -> list:
|
||||
"""Flattens the mind object dict to a list. I.e. takes
|
||||
|
||||
{
|
||||
"m1": {
|
||||
"fluent": 3, <- # 0: enter 1: disappear 2: update 3: unchange
|
||||
"loc": null
|
||||
},
|
||||
"m2": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"m12": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"m21": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"mc": {
|
||||
"fluent": 3,
|
||||
"loc": null
|
||||
},
|
||||
"mg": {
|
||||
"fluent": 3,
|
||||
"loc": [
|
||||
22,
|
||||
9
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
and returns [3, 3, 3, 3, 3, 3]
|
||||
|
||||
Args:
|
||||
mind_obj_dict (dict): Mind object dict as described in __init__.doctstring.
|
||||
|
||||
Returns:
|
||||
list: List of mind object labels.
|
||||
"""
|
||||
return np.array([mind_obj["fluent"] for key, mind_obj in mind_obj_dict.items() if key != "mg"])
|
||||
|
||||
def load_skeleton_3D(self, path: str, list_of_ids_to_consider: list):
|
||||
"""Load skeleton 3D data from disk.
|
||||
|
||||
- path
|
||||
- * <- list of ids
|
||||
- skele1.p <- 3D coord per id and timestep
|
||||
- skele2.p <-
|
||||
|
||||
Args:
|
||||
path (str): Where the skeleton 3D data lie.
|
||||
list_of_ids_to_consider (list): List of ids to consider.
|
||||
Defaults to None which means all ids. Otherwise specify a list,
|
||||
e.g. ["test_94342_23", "test_boelter_21", ...].
|
||||
|
||||
Returns:
|
||||
dict: skeleton 3D data as described above in __init__.doctstring.
|
||||
"""
|
||||
skeleton_3D = {}
|
||||
for experiment_id in list_of_ids_to_consider:
|
||||
skeleton_3D[experiment_id] = {}
|
||||
with open(f"{path}/{experiment_id}/skele1.p", "rb") as f:
|
||||
skeleton_3D[experiment_id]["skele1"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
|
||||
with open(f"{path}/{experiment_id}/skele2.p", "rb") as f:
|
||||
skeleton_3D[experiment_id]["skele2"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
|
||||
return skeleton_3D
|
||||
|
||||
def load_tracker_2D(self, path: str, list_of_ids_to_consider: list):
|
||||
"""Load tracker 2D data from disk.
|
||||
|
||||
- path
|
||||
- *.p <- 2D coord per id and timestep
|
||||
|
||||
Args:
|
||||
path (str): Where the tracker 2D data lie.
|
||||
list_of_ids_to_consider (list): List of ids to consider.
|
||||
Defaults to None which means all ids. Otherwise specify a list,
|
||||
e.g. ["test_94342_23", "test_boelter_21", ...].
|
||||
|
||||
Returns:
|
||||
dict: tracker 2D data.
|
||||
"""
|
||||
tracker_2D = {}
|
||||
for experiment_id in list_of_ids_to_consider:
|
||||
with open(f"{path}/{experiment_id}.p", "rb") as f:
|
||||
tracker_2D[experiment_id] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
|
||||
return tracker_2D
|
||||
|
||||
def load_bbox(self, path: str, list_of_ids_to_consider: list):
|
||||
"""Load bbox data from disk.
|
||||
|
||||
- bbox_tensors.pickle <- bbox per experiment id one tensor
|
||||
|
||||
Args:
|
||||
path (str): Where the bbox data lie.
|
||||
list_of_ids_to_consider (list): List of ids to consider.
|
||||
|
||||
Returns:
|
||||
dict: bbox data.
|
||||
"""
|
||||
with open(path, "rb") as f:
|
||||
pickle_data = pickle.load(f)
|
||||
for key in CLIPS_IDS_88:
|
||||
if key not in list_of_ids_to_consider:
|
||||
pickle_data.pop(key, None)
|
||||
return pickle_data
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# os.environ['PYTHONHASHSEED'] = str(42)
|
||||
# torch.manual_seed(42)
|
||||
# np.random.seed(42)
|
||||
# random.seed(42)
|
||||
|
||||
data = TBDDataset(use_preprocessed_img=True, mode="test")
|
||||
|
||||
from tqdm import tqdm
|
||||
for i in tqdm(range(data.__len__())):
|
||||
data[i]
|
||||
|
||||
breakpoint()
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Just for guessing time
|
||||
data_0=data[0]
|
||||
data_last=data[len(data)-1]
|
||||
idx = np.random.randint(1, len(data)-1) # Something in between.
|
||||
start = time.time()
|
||||
(
|
||||
kinect_imgs, # <- len x 720 x 1280 x 3 originally, likely smaller now
|
||||
tracker_imgs,
|
||||
battery_imgs,
|
||||
skele1,
|
||||
skele2,
|
||||
bbox,
|
||||
tracker_skeID_sample, # <- This is the tracker skeleton ID
|
||||
tracker2d,
|
||||
label,
|
||||
experiment_id, # From here for debugging
|
||||
timestep,
|
||||
#obj_id, # <- This is the object ID as a one-hot
|
||||
false_belief
|
||||
) = data[idx]
|
||||
end = time.time()
|
||||
print(f"Time for one sample: {end-start}")
|
||||
|
||||
print('kinect:', kinect_imgs.shape)
|
||||
print('tracker:', tracker_imgs.shape)
|
||||
print('battery:', battery_imgs.shape)
|
||||
print('skele1:', skele1.shape)
|
||||
print('skele2:', skele2.shape)
|
||||
print('gaze:', tracker2d.shape)
|
||||
print('bbox:', bbox.shape)
|
||||
print('label:', label)
|
||||
|
||||
#breakpoint()
|
||||
|
||||
dl = DataLoader(
|
||||
data,
|
||||
batch_size=4,
|
||||
shuffle=False,
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
for j, batch in tqdm(enumerate(dl)):
|
||||
#print(j, end='\r')
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
|
||||
#breakpoint()
|
||||
#print(img_3rd_pov.shape)
|
||||
#print(img_tracker.shape)
|
||||
#print(img_battery.shape)
|
||||
#print(pose1.shape, pose2.shape)
|
||||
#print(bbox.shape)
|
||||
#print(gaze.shape)
|
||||
|
||||
|
||||
breakpoint()
|
196
tbd/test.py
Normal file
196
tbd/test.py
Normal file
|
@ -0,0 +1,196 @@
|
|||
import torch
|
||||
import csv
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
import random
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from tbd_dataloader import TBDDataset, collate_fn_test
|
||||
from models.common_mind import CommonMindToMnet
|
||||
from models.sl import SLToMnet
|
||||
from models.implicit import ImplicitToMnet
|
||||
from utils.helpers import compute_f1_scores
|
||||
|
||||
|
||||
def test(args):
|
||||
|
||||
test_dataset = TBDDataset(
|
||||
path=args.data_path,
|
||||
mode="test",
|
||||
use_preprocessed_img=True
|
||||
)
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
collate_fn=collate_fn_test
|
||||
)
|
||||
|
||||
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# model
|
||||
if args.model_type == 'tom_cm':
|
||||
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
elif args.model_type == 'tom_sl':
|
||||
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
|
||||
elif args.model_type == 'tom_impl':
|
||||
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
else: raise NotImplementedError
|
||||
|
||||
model.load_state_dict(torch.load(args.load_model_path, map_location=device))
|
||||
model.device = device
|
||||
|
||||
model.eval()
|
||||
|
||||
if args.save_preds:
|
||||
# Define the output file path
|
||||
folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}'
|
||||
if not os.path.exists(folder_path):
|
||||
os.makedirs(folder_path)
|
||||
print(f'Saving predictions in {folder_path}.')
|
||||
|
||||
print('Testing...')
|
||||
m1_pred_list = []
|
||||
m2_pred_list = []
|
||||
m12_pred_list = []
|
||||
m21_pred_list = []
|
||||
mc_pred_list = []
|
||||
m1_label_list = []
|
||||
m2_label_list = []
|
||||
m12_label_list = []
|
||||
m21_label_list = []
|
||||
mc_label_list = []
|
||||
with torch.no_grad():
|
||||
for j, batch in tqdm(enumerate(test_dataloader)):
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep, false_belief = batch
|
||||
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
|
||||
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
|
||||
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
|
||||
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
|
||||
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
|
||||
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
|
||||
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
|
||||
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, repr = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
m1_pred = m1_pred.reshape(-1, 4)
|
||||
m2_pred = m2_pred.reshape(-1, 4)
|
||||
m12_pred = m12_pred.reshape(-1, 4)
|
||||
m21_pred = m21_pred.reshape(-1, 4)
|
||||
mc_pred = mc_pred.reshape(-1, 4)
|
||||
m1_label = labels[:, 0].reshape(-1).to(device)
|
||||
m2_label = labels[:, 1].reshape(-1).to(device)
|
||||
m12_label = labels[:, 2].reshape(-1).to(device)
|
||||
m21_label = labels[:, 3].reshape(-1).to(device)
|
||||
mc_label = labels[:, 4].reshape(-1).to(device)
|
||||
|
||||
m1_pred_list.append(m1_pred)
|
||||
m2_pred_list.append(m2_pred)
|
||||
m12_pred_list.append(m12_pred)
|
||||
m21_pred_list.append(m21_pred)
|
||||
mc_pred_list.append(mc_pred)
|
||||
m1_label_list.append(m1_label)
|
||||
m2_label_list.append(m2_label)
|
||||
m12_label_list.append(m12_label)
|
||||
m21_label_list.append(m21_label)
|
||||
mc_label_list.append(mc_label)
|
||||
|
||||
if args.save_preds:
|
||||
torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt"))
|
||||
data = [(
|
||||
i,
|
||||
torch.argmax(m1_pred[i]).cpu().numpy(),
|
||||
torch.argmax(m2_pred[i]).cpu().numpy(),
|
||||
torch.argmax(m12_pred[i]).cpu().numpy(),
|
||||
torch.argmax(m21_pred[i]).cpu().numpy(),
|
||||
torch.argmax(mc_pred[i]).cpu().numpy(),
|
||||
m1_label[i].cpu().numpy(),
|
||||
m2_label[i].cpu().numpy(),
|
||||
m12_label[i].cpu().numpy(),
|
||||
m21_label[i].cpu().numpy(),
|
||||
mc_label[i].cpu().numpy(),
|
||||
false_belief[i]) for i in range(len(labels))
|
||||
]
|
||||
header = ['frame', 'm1_pred', 'm2_pred', 'm12_pred', 'm21_pred', 'mc_pred', 'm1_label', 'm2_label', 'm12_label', 'm21_label', 'mc_label', 'false_belief']
|
||||
with open(os.path.join(folder_path, f'{j}.csv'), mode='w', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
writer.writerow(header) # Write the header row
|
||||
writer.writerows(data) # Write the data rows
|
||||
|
||||
#np.savetxt('m1_label_bs1.txt', torch.cat(m1_label_list).cpu().numpy())
|
||||
test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1 = compute_f1_scores(
|
||||
torch.cat(m1_pred_list),
|
||||
torch.cat(m1_label_list),
|
||||
torch.cat(m2_pred_list),
|
||||
torch.cat(m2_label_list),
|
||||
torch.cat(m12_pred_list),
|
||||
torch.cat(m12_label_list),
|
||||
torch.cat(m21_pred_list),
|
||||
torch.cat(m21_label_list),
|
||||
torch.cat(mc_pred_list),
|
||||
torch.cat(mc_label_list)
|
||||
)
|
||||
|
||||
print("Test m1 F1: {}".format(test_m1_f1))
|
||||
print("Test m2 F1: {}".format(test_m2_f1))
|
||||
print("Test m12 F1: {}".format(test_m12_f1))
|
||||
print("Test m21 F1: {}".format(test_m21_f1))
|
||||
print("Test mc F1: {}".format(test_mc_f1))
|
||||
|
||||
with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f:
|
||||
f.write(f"Test data:\n {[data[1] for data in test_dataset.data]}")
|
||||
f.write(f"m1 f1: {test_m1_f1}")
|
||||
f.write(f"m2 f1: {test_m2_f1}")
|
||||
f.write(f"m12 f1: {test_m12_f1}")
|
||||
f.write(f"m21 f1: {test_m21_f1}")
|
||||
f.write(f"mc f1: {test_mc_f1}")
|
||||
f.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Define the command-line arguments
|
||||
parser.add_argument('--gpu_id', type=int)
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--presaved', type=int, default=128)
|
||||
parser.add_argument('--non_blocking', action='store_true')
|
||||
parser.add_argument('--num_workers', type=int, default=16)
|
||||
parser.add_argument('--pin_memory', action='store_true')
|
||||
parser.add_argument('--model_type', type=str)
|
||||
parser.add_argument('--batch_size', type=int, default=64)
|
||||
parser.add_argument('--aggr', type=str, default='concat', required=False)
|
||||
parser.add_argument('--use_resnet', action='store_true')
|
||||
parser.add_argument('--hidden_dim', type=int, default=64)
|
||||
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
|
||||
parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
|
||||
parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
|
||||
parser.add_argument('--save_path', type=str, default='experiments/')
|
||||
parser.add_argument('--test_frames', type=str, default=None)
|
||||
parser.add_argument('--median', type=int, default=None)
|
||||
parser.add_argument('--load_model_path', type=str)
|
||||
parser.add_argument('--dropout', type=float, default=0.0)
|
||||
parser.add_argument('--save_preds', action='store_true')
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
|
||||
if not args.aggr:
|
||||
parser.error("The choosen --model_type requires --aggr")
|
||||
if args.model_type == 'tom_sl' and not args.tom_weight:
|
||||
parser.error("The choosen --model_type requires --tom_weight")
|
||||
|
||||
os.environ['PYTHONHASHSEED'] = str(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
random.seed(args.seed)
|
||||
|
||||
print('###########################################################################')
|
||||
print('TESTING: MAKE SURE YOU ARE USING THE SAME RANDOM SEED USED DURING TRAINING!')
|
||||
print('###########################################################################')
|
||||
|
||||
test(args)
|
474
tbd/train.py
Normal file
474
tbd/train.py
Normal file
|
@ -0,0 +1,474 @@
|
|||
import torch
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import random
|
||||
import datetime
|
||||
import wandb
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn as nn
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from tbd_dataloader import TBDDataset, collate_fn
|
||||
from models.common_mind import CommonMindToMnet
|
||||
from models.sl import SLToMnet
|
||||
from models.implicit import ImplicitToMnet
|
||||
from utils.helpers import count_parameters, compute_f1_scores
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
train_dataset = TBDDataset(
|
||||
path=args.data_path,
|
||||
mode="train",
|
||||
use_preprocessed_img=True
|
||||
)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
val_dataset = TBDDataset(
|
||||
path=args.data_path,
|
||||
mode="val",
|
||||
use_preprocessed_img=True
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
|
||||
train_data = [data[1] for data in train_dataset.data]
|
||||
val_data = [data[1] for data in val_dataset.data]
|
||||
if args.logger:
|
||||
wandb.config.update({"train_data": train_data})
|
||||
wandb.config.update({"val_data": val_data})
|
||||
|
||||
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# model
|
||||
if args.model_type == 'tom_cm':
|
||||
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
elif args.model_type == 'tom_sl':
|
||||
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
|
||||
elif args.model_type == 'tom_impl':
|
||||
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
||||
else: raise NotImplementedError
|
||||
if args.resume_from_checkpoint is not None:
|
||||
model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
|
||||
# optimizer
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
# scheduler
|
||||
if args.scheduler == None:
|
||||
scheduler = None
|
||||
else:
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
|
||||
# loss function
|
||||
if args.model_type == 'tom_sl':
|
||||
ce_loss_m1 = nn.NLLLoss()
|
||||
ce_loss_m2 = nn.NLLLoss()
|
||||
ce_loss_m12 = nn.NLLLoss()
|
||||
ce_loss_m21 = nn.NLLLoss()
|
||||
ce_loss_mc = nn.NLLLoss()
|
||||
else:
|
||||
ce_loss_m1 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
ce_loss_m2 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
ce_loss_m12 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
ce_loss_m21 = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
ce_loss_mc = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
||||
|
||||
stats = {
|
||||
'train': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []},
|
||||
'val': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [], 'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []}
|
||||
}
|
||||
max_val_f1 = 0
|
||||
max_val_classification_epoch = None
|
||||
counter = 0
|
||||
|
||||
print(f'Number of parameters: {count_parameters(model)}')
|
||||
|
||||
for i in range(args.num_epoch):
|
||||
# training
|
||||
print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
|
||||
epoch_train_loss_m1 = 0.0
|
||||
epoch_train_loss_m2 = 0.0
|
||||
epoch_train_loss_m12 = 0.0
|
||||
epoch_train_loss_m21 = 0.0
|
||||
epoch_train_loss_mc = 0.0
|
||||
m1_train_batch_pred_list = []
|
||||
m2_train_batch_pred_list = []
|
||||
m12_train_batch_pred_list = []
|
||||
m21_train_batch_pred_list = []
|
||||
mc_train_batch_pred_list = []
|
||||
m1_train_batch_label_list = []
|
||||
m2_train_batch_label_list = []
|
||||
m12_train_batch_label_list = []
|
||||
m21_train_batch_label_list = []
|
||||
mc_train_batch_label_list = []
|
||||
model.train()
|
||||
for j, batch in tqdm(enumerate(train_dataloader)):
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
|
||||
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
|
||||
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
|
||||
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
|
||||
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
|
||||
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
|
||||
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
|
||||
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
|
||||
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
m1_pred = m1_pred.reshape(-1, 4)
|
||||
m2_pred = m2_pred.reshape(-1, 4)
|
||||
m12_pred = m12_pred.reshape(-1, 4)
|
||||
m21_pred = m21_pred.reshape(-1, 4)
|
||||
mc_pred = mc_pred.reshape(-1, 4)
|
||||
m1_label = labels[:, 0].reshape(-1).to(device)
|
||||
m2_label = labels[:, 1].reshape(-1).to(device)
|
||||
m12_label = labels[:, 2].reshape(-1).to(device)
|
||||
m21_label = labels[:, 3].reshape(-1).to(device)
|
||||
mc_label = labels[:, 4].reshape(-1).to(device)
|
||||
|
||||
loss_m1 = ce_loss_m1(m1_pred, m1_label)
|
||||
loss_m2 = ce_loss_m2(m2_pred, m2_label)
|
||||
loss_m12 = ce_loss_m12(m12_pred, m12_label)
|
||||
loss_m21 = ce_loss_m21(m21_pred, m21_label)
|
||||
loss_mc = ce_loss_mc(mc_pred, mc_label)
|
||||
loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc
|
||||
|
||||
epoch_train_loss_m1 += loss_m1.data.item()
|
||||
epoch_train_loss_m2 += loss_m2.data.item()
|
||||
epoch_train_loss_m12 += loss_m12.data.item()
|
||||
epoch_train_loss_m21 += loss_m21.data.item()
|
||||
epoch_train_loss_mc += loss_mc.data.item()
|
||||
|
||||
m1_train_batch_pred_list.append(m1_pred)
|
||||
m2_train_batch_pred_list.append(m2_pred)
|
||||
m12_train_batch_pred_list.append(m12_pred)
|
||||
m21_train_batch_pred_list.append(m21_pred)
|
||||
mc_train_batch_pred_list.append(mc_pred)
|
||||
m1_train_batch_label_list.append(m1_label)
|
||||
m2_train_batch_label_list.append(m2_label)
|
||||
m12_train_batch_label_list.append(m12_label)
|
||||
m21_train_batch_label_list.append(m21_label)
|
||||
mc_train_batch_label_list.append(mc_label)
|
||||
|
||||
optimizer.zero_grad()
|
||||
if args.clip_grad_norm is not None:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if args.logger: wandb.log({
|
||||
'batch_train_loss': loss.data.item(),
|
||||
'lr': optimizer.param_groups[-1]['lr']
|
||||
})
|
||||
|
||||
print("Epoch {}/{} batch {}/{} training done with loss={}".format(
|
||||
i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item())
|
||||
)
|
||||
|
||||
if scheduler: scheduler.step()
|
||||
|
||||
train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score = compute_f1_scores(
|
||||
torch.cat(m1_train_batch_pred_list),
|
||||
torch.cat(m1_train_batch_label_list),
|
||||
torch.cat(m2_train_batch_pred_list),
|
||||
torch.cat(m2_train_batch_label_list),
|
||||
torch.cat(m12_train_batch_pred_list),
|
||||
torch.cat(m12_train_batch_label_list),
|
||||
torch.cat(m21_train_batch_pred_list),
|
||||
torch.cat(m21_train_batch_label_list),
|
||||
torch.cat(mc_train_batch_pred_list),
|
||||
torch.cat(mc_train_batch_label_list)
|
||||
)
|
||||
|
||||
print("Epoch {}/{} OVERALL train m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}.\n".format(
|
||||
i+1,
|
||||
args.num_epoch,
|
||||
epoch_train_loss_m1/len(train_dataloader),
|
||||
epoch_train_loss_m2/len(train_dataloader),
|
||||
epoch_train_loss_m12/len(train_dataloader),
|
||||
epoch_train_loss_m21/len(train_dataloader),
|
||||
epoch_train_loss_mc/len(train_dataloader),
|
||||
train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score
|
||||
)
|
||||
)
|
||||
stats['train']['loss_m1'].append(epoch_train_loss_m1/len(train_dataloader))
|
||||
stats['train']['loss_m2'].append(epoch_train_loss_m2/len(train_dataloader))
|
||||
stats['train']['loss_m12'].append(epoch_train_loss_m12/len(train_dataloader))
|
||||
stats['train']['loss_m21'].append(epoch_train_loss_m21/len(train_dataloader))
|
||||
stats['train']['loss_mc'].append(epoch_train_loss_mc/len(train_dataloader))
|
||||
stats['train']['m1_f1'].append(train_m1_f1_score)
|
||||
stats['train']['m2_f1'].append(train_m2_f1_score)
|
||||
stats['train']['m12_f1'].append(train_m12_f1_score)
|
||||
stats['train']['m21_f1'].append(train_m21_f1_score)
|
||||
stats['train']['mc_f1'].append(train_mc_f1_score)
|
||||
|
||||
if args.logger: wandb.log(
|
||||
{
|
||||
'train_m1_loss': epoch_train_loss_m1/len(train_dataloader),
|
||||
'train_m2_loss': epoch_train_loss_m2/len(train_dataloader),
|
||||
'train_m12_loss': epoch_train_loss_m12/len(train_dataloader),
|
||||
'train_m21_loss': epoch_train_loss_m21/len(train_dataloader),
|
||||
'train_mc_loss': epoch_train_loss_mc/len(train_dataloader),
|
||||
'train_loss': epoch_train_loss_m1/len(train_dataloader) + \
|
||||
epoch_train_loss_m2/len(train_dataloader) + \
|
||||
epoch_train_loss_m12/len(train_dataloader) + \
|
||||
epoch_train_loss_m21/len(train_dataloader) + \
|
||||
epoch_train_loss_mc/len(train_dataloader),
|
||||
'train_m1_f1_score': train_m1_f1_score,
|
||||
'train_m2_f1_score': train_m2_f1_score,
|
||||
'train_m12_f1_score': train_m12_f1_score,
|
||||
'train_m21_f1_score': train_m21_f1_score,
|
||||
'train_mc_f1_score': train_mc_f1_score
|
||||
}
|
||||
)
|
||||
|
||||
# validation
|
||||
print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
|
||||
epoch_val_loss_m1 = 0.0
|
||||
epoch_val_loss_m2 = 0.0
|
||||
epoch_val_loss_m12 = 0.0
|
||||
epoch_val_loss_m21 = 0.0
|
||||
epoch_val_loss_mc = 0.0
|
||||
m1_val_batch_pred_list = []
|
||||
m2_val_batch_pred_list = []
|
||||
m12_val_batch_pred_list = []
|
||||
m21_val_batch_pred_list = []
|
||||
mc_val_batch_pred_list = []
|
||||
m1_val_batch_label_list = []
|
||||
m2_val_batch_label_list = []
|
||||
m12_val_batch_label_list = []
|
||||
m21_val_batch_label_list = []
|
||||
mc_val_batch_label_list = []
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for j, batch in tqdm(enumerate(val_dataloader)):
|
||||
img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
|
||||
if img_3rd_pov is not None: img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
|
||||
if img_tracker is not None: img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
|
||||
if img_battery is not None: img_battery = img_battery.to(device, non_blocking=args.non_blocking)
|
||||
if pose1 is not None: pose1 = pose1.to(device, non_blocking=args.non_blocking)
|
||||
if pose2 is not None: pose2 = pose2.to(device, non_blocking=args.non_blocking)
|
||||
if bbox is not None: bbox = bbox.to(device, non_blocking=args.non_blocking)
|
||||
if gaze is not None: gaze = gaze.to(device, non_blocking=args.non_blocking)
|
||||
m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
|
||||
m1_pred = m1_pred.reshape(-1, 4)
|
||||
m2_pred = m2_pred.reshape(-1, 4)
|
||||
m12_pred = m12_pred.reshape(-1, 4)
|
||||
m21_pred = m21_pred.reshape(-1, 4)
|
||||
mc_pred = mc_pred.reshape(-1, 4)
|
||||
m1_label = labels[:, 0].reshape(-1).to(device)
|
||||
m2_label = labels[:, 1].reshape(-1).to(device)
|
||||
m12_label = labels[:, 2].reshape(-1).to(device)
|
||||
m21_label = labels[:, 3].reshape(-1).to(device)
|
||||
mc_label = labels[:, 4].reshape(-1).to(device)
|
||||
|
||||
loss_m1 = ce_loss_m1(m1_pred, m1_label)
|
||||
loss_m2 = ce_loss_m2(m2_pred, m2_label)
|
||||
loss_m12 = ce_loss_m12(m12_pred, m12_label)
|
||||
loss_m21 = ce_loss_m21(m21_pred, m21_label)
|
||||
loss_mc = ce_loss_mc(mc_pred, mc_label)
|
||||
loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc
|
||||
|
||||
epoch_val_loss_m1 += loss_m1.data.item()
|
||||
epoch_val_loss_m2 += loss_m2.data.item()
|
||||
epoch_val_loss_m12 += loss_m12.data.item()
|
||||
epoch_val_loss_m21 += loss_m21.data.item()
|
||||
epoch_val_loss_mc += loss_mc.data.item()
|
||||
|
||||
m1_val_batch_pred_list.append(m1_pred)
|
||||
m2_val_batch_pred_list.append(m2_pred)
|
||||
m12_val_batch_pred_list.append(m12_pred)
|
||||
m21_val_batch_pred_list.append(m21_pred)
|
||||
mc_val_batch_pred_list.append(mc_pred)
|
||||
m1_val_batch_label_list.append(m1_label)
|
||||
m2_val_batch_label_list.append(m2_label)
|
||||
m12_val_batch_label_list.append(m12_label)
|
||||
m21_val_batch_label_list.append(m21_label)
|
||||
mc_val_batch_label_list.append(mc_label)
|
||||
|
||||
if args.logger: wandb.log({'batch_val_loss': loss.data.item()})
|
||||
print("Epoch {}/{} batch {}/{} validation done with loss={}".format(
|
||||
i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item())
|
||||
)
|
||||
|
||||
val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score = compute_f1_scores(
|
||||
torch.cat(m1_val_batch_pred_list),
|
||||
torch.cat(m1_val_batch_label_list),
|
||||
torch.cat(m2_val_batch_pred_list),
|
||||
torch.cat(m2_val_batch_label_list),
|
||||
torch.cat(m12_val_batch_pred_list),
|
||||
torch.cat(m12_val_batch_label_list),
|
||||
torch.cat(m21_val_batch_pred_list),
|
||||
torch.cat(m21_val_batch_label_list),
|
||||
torch.cat(mc_val_batch_pred_list),
|
||||
torch.cat(mc_val_batch_label_list)
|
||||
)
|
||||
|
||||
print("Epoch {}/{} OVERALL validation m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format(
|
||||
i+1,
|
||||
args.num_epoch,
|
||||
epoch_val_loss_m1/len(val_dataloader),
|
||||
epoch_val_loss_m2/len(val_dataloader),
|
||||
epoch_val_loss_m12/len(val_dataloader),
|
||||
epoch_val_loss_m21/len(val_dataloader),
|
||||
epoch_val_loss_mc/len(val_dataloader),
|
||||
val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score
|
||||
)
|
||||
)
|
||||
|
||||
stats['val']['loss_m1'].append(epoch_val_loss_m1/len(val_dataloader))
|
||||
stats['val']['loss_m2'].append(epoch_val_loss_m2/len(val_dataloader))
|
||||
stats['val']['loss_m12'].append(epoch_val_loss_m12/len(val_dataloader))
|
||||
stats['val']['loss_m21'].append(epoch_val_loss_m21/len(val_dataloader))
|
||||
stats['val']['loss_mc'].append(epoch_val_loss_mc/len(val_dataloader))
|
||||
stats['val']['m1_f1'].append(val_m1_f1_score)
|
||||
stats['val']['m2_f1'].append(val_m2_f1_score)
|
||||
stats['val']['m12_f1'].append(val_m12_f1_score)
|
||||
stats['val']['m21_f1'].append(val_m21_f1_score)
|
||||
stats['val']['mc_f1'].append(val_mc_f1_score)
|
||||
|
||||
if args.logger: wandb.log(
|
||||
{
|
||||
'val_m1_loss': epoch_val_loss_m1/len(val_dataloader),
|
||||
'val_m2_loss': epoch_val_loss_m2/len(val_dataloader),
|
||||
'val_m12_loss': epoch_val_loss_m12/len(val_dataloader),
|
||||
'val_m21_loss': epoch_val_loss_m21/len(val_dataloader),
|
||||
'val_mc_loss': epoch_val_loss_mc/len(val_dataloader),
|
||||
'val_loss': epoch_val_loss_m1/len(val_dataloader) + \
|
||||
epoch_val_loss_m2/len(val_dataloader) + \
|
||||
epoch_val_loss_m12/len(val_dataloader) + \
|
||||
epoch_val_loss_m21/len(val_dataloader) + \
|
||||
epoch_val_loss_mc/len(val_dataloader),
|
||||
'val_m1_f1_score': val_m1_f1_score,
|
||||
'val_m2_f1_score': val_m2_f1_score,
|
||||
'val_m12_f1_score': val_m12_f1_score,
|
||||
'val_m21_f1_score': val_m21_f1_score,
|
||||
'val_mc_f1_score': val_mc_f1_score
|
||||
}
|
||||
)
|
||||
|
||||
# check for best stat/model using validation accuracy
|
||||
if stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1] >= max_val_f1:
|
||||
max_val_f1 = stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] + stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] + stats['val']['mc_f1'][-1]
|
||||
max_val_classification_epoch = i+1
|
||||
torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
|
||||
counter = 0
|
||||
else:
|
||||
counter += 1
|
||||
print(f'EarlyStopping counter: {counter} out of {args.patience}.')
|
||||
if counter >= args.patience:
|
||||
break
|
||||
|
||||
with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
|
||||
f.write('{}\n'.format(CFG))
|
||||
f.write('{}\n'.format(train_data))
|
||||
f.write('{}\n'.format(val_data))
|
||||
f.write('{}\n'.format(stats))
|
||||
f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_f1))
|
||||
f.close()
|
||||
|
||||
print(f'Results saved in {experiment_save_path}')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Define the command-line arguments
|
||||
parser.add_argument('--gpu_id', type=int)
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--logger', action='store_true')
|
||||
parser.add_argument('--presaved', type=int, default=128)
|
||||
parser.add_argument('--clip_grad_norm', type=float, default=0.5)
|
||||
parser.add_argument('--use_mixup', action='store_true')
|
||||
parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
|
||||
parser.add_argument('--non_blocking', action='store_true')
|
||||
parser.add_argument('--patience', type=int, default=99)
|
||||
parser.add_argument('--batch_size', type=int, default=4)
|
||||
parser.add_argument('--num_workers', type=int, default=8)
|
||||
parser.add_argument('--pin_memory', action='store_true')
|
||||
parser.add_argument('--num_epoch', type=int, default=300)
|
||||
parser.add_argument('--lr', type=float, default=4e-4)
|
||||
parser.add_argument('--scheduler', type=str, default=None)
|
||||
parser.add_argument('--dropout', type=float, default=0.1)
|
||||
parser.add_argument('--weight_decay', type=float, default=0.005)
|
||||
parser.add_argument('--label_smoothing', type=float, default=0.1)
|
||||
parser.add_argument('--model_type', type=str)
|
||||
parser.add_argument('--aggr', type=str, default='concat', required=False)
|
||||
parser.add_argument('--use_resnet', action='store_true')
|
||||
parser.add_argument('--hidden_dim', type=int, default=64)
|
||||
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
|
||||
parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
|
||||
parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
|
||||
parser.add_argument('--save_path', type=str, default='experiments/')
|
||||
parser.add_argument('--resume_from_checkpoint', type=str, default=None)
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.use_mixup and not args.mixup_alpha:
|
||||
parser.error("--use_mixup requires --mixup_alpha")
|
||||
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
|
||||
if not args.aggr:
|
||||
parser.error("The choosen --model_type requires --aggr")
|
||||
if args.model_type == 'tom_sl' and not args.tom_weight:
|
||||
parser.error("The choosen --model_type requires --tom_weight")
|
||||
|
||||
# get experiment ID
|
||||
experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
|
||||
if not os.path.exists(args.save_path):
|
||||
os.makedirs(args.save_path, exist_ok=True)
|
||||
experiment_save_path = os.path.join(args.save_path, experiment_id)
|
||||
os.makedirs(experiment_save_path, exist_ok=True)
|
||||
|
||||
CFG = {
|
||||
'use_ocr_custom_loss': 0,
|
||||
'presaved': args.presaved,
|
||||
'batch_size': args.batch_size,
|
||||
'num_epoch': args.num_epoch,
|
||||
'lr': args.lr,
|
||||
'scheduler': args.scheduler,
|
||||
'weight_decay': args.weight_decay,
|
||||
'model_type': args.model_type,
|
||||
'use_resnet': args.use_resnet,
|
||||
'hidden_dim': args.hidden_dim,
|
||||
'tom_weight': args.tom_weight,
|
||||
'dropout': args.dropout,
|
||||
'label_smoothing': args.label_smoothing,
|
||||
'clip_grad_norm': args.clip_grad_norm,
|
||||
'use_mixup': args.use_mixup,
|
||||
'mixup_alpha': args.mixup_alpha,
|
||||
'non_blocking_tensors': args.non_blocking,
|
||||
'patience': args.patience,
|
||||
'pin_memory': args.pin_memory,
|
||||
'resume_from_checkpoint': args.resume_from_checkpoint,
|
||||
'aggr': args.aggr,
|
||||
'mods': args.mods,
|
||||
'save_path': experiment_save_path ,
|
||||
'seed': args.seed
|
||||
}
|
||||
|
||||
print(CFG)
|
||||
print(f'Saving results in {experiment_save_path}')
|
||||
|
||||
# set seed values
|
||||
if args.logger:
|
||||
wandb.init(project="tbd", config=CFG)
|
||||
os.environ['PYTHONHASHSEED'] = str(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
random.seed(args.seed)
|
||||
|
||||
main(args)
|
224
tbd/utils/fb_scores_err.py
Normal file
224
tbd/utils/fb_scores_err.py
Normal file
|
@ -0,0 +1,224 @@
|
|||
import os
|
||||
import csv
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from sklearn.metrics import f1_score
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
ALPHA = 0.7
|
||||
BAR_WIDTH = 0.27
|
||||
sns.set_theme(style='whitegrid')
|
||||
#sns.set_palette('mako')
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
model_to_subdir = {
|
||||
"IC$\parallel$": ["2023-07-16_10-34-32_train", "2023-07-18_13-49-57_train", "2023-07-19_12-17-46_train"],
|
||||
"IC$\oplus$": ["2023-07-16_10-35-02_train", "2023-07-18_13-50-32_train", "2023-07-19_12-18-18_train"],
|
||||
"IC$\otimes$": ["2023-07-16_10-35-41_train", "2023-07-18_13-52-26_train", "2023-07-19_12-18-49_train"],
|
||||
"IC$\odot$": ["2023-07-16_10-36-04_train", "2023-07-18_13-53-03_train", "2023-07-19_12-19-50_train"],
|
||||
"CG$\parallel$": ["2023-07-15_14-12-36_train", "2023-07-17_11-54-28_train", "2023-07-19_00-30-05_train"],
|
||||
"CG$\oplus$": ["2023-07-15_14-14-08_train", "2023-07-17_11-56-05_train", "2023-07-19_00-30-47_train"],
|
||||
"CG$\otimes$": ["2023-07-15_14-14-53_train", "2023-07-17_11-56-39_train", "2023-07-19_00-31-36_train"],
|
||||
"CG$\odot$": ["2023-07-15_14-10-05_train", "2023-07-17_11-57-30_train", "2023-07-19_00-32-10_train"],
|
||||
"DB": ["2023-08-08_12-56-02_train", "2023-08-08_19-07-43_train", "2023-08-08_19-08-47_train"],
|
||||
"Base": ["2023-08-08_12-53-38_train", "2023-08-08_19-10-02_train", "2023-08-08_19-10-51_train"]
|
||||
}
|
||||
|
||||
def read_data_from_csv(subdirectory_path):
|
||||
print(subdirectory_path)
|
||||
data = []
|
||||
csv_files = [file for file in os.listdir(subdirectory_path) if file.endswith('.csv')]
|
||||
for csv_file in csv_files:
|
||||
file_path = os.path.join(subdirectory_path, csv_file)
|
||||
with open(file_path, 'r') as file:
|
||||
reader = csv.reader(file)
|
||||
header_skipped = False
|
||||
for row in reader:
|
||||
if not header_skipped:
|
||||
header_skipped = True
|
||||
continue
|
||||
frame, m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, m1_label, m2_label, m12_label, m21_label, mc_label, false_belief = row
|
||||
data.append({
|
||||
'frame': int(frame),
|
||||
'm1_pred': int(m1_pred),
|
||||
'm2_pred': int(m2_pred),
|
||||
'm12_pred': int(m12_pred),
|
||||
'm21_pred': int(m21_pred),
|
||||
'mc_pred': int(mc_pred),
|
||||
'm1_label': int(m1_label),
|
||||
'm2_label': int(m2_label),
|
||||
'm12_label': int(m12_label),
|
||||
'm21_label': int(m21_label),
|
||||
'mc_label': int(mc_label),
|
||||
'false_belief': false_belief,
|
||||
})
|
||||
return data
|
||||
|
||||
def compute_correct_false_belief(data, mind="all", folder=None):
|
||||
total_false_belief = 0
|
||||
correct_false_belief = 0
|
||||
for item in data:
|
||||
if 'false' in item['false_belief']:
|
||||
false_belief_type = item['false_belief'].split('_')[0]
|
||||
if mind == "all" or false_belief_type in mind:
|
||||
total_false_belief += 1
|
||||
if item[f"{false_belief_type}_pred"] == item[f"{false_belief_type}_label"]:
|
||||
if folder is not None:
|
||||
with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
|
||||
f.write(f"{str(1)}\n")
|
||||
correct_false_belief += 1
|
||||
else:
|
||||
if folder is not None:
|
||||
with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
|
||||
f.write(f"{str(0)}\n")
|
||||
if total_false_belief == 0:
|
||||
accuracy = 0.0
|
||||
else:
|
||||
accuracy = correct_false_belief / total_false_belief
|
||||
return accuracy
|
||||
|
||||
def compute_macro_f1_score(data, mind="all"):
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for item in data:
|
||||
if 'false' in item['false_belief']:
|
||||
false_belief_type = item['false_belief'].split('_')[0]
|
||||
if mind == "all" or false_belief_type in mind:
|
||||
y_true.append(int(item[f"{false_belief_type}_label"]))
|
||||
y_pred.append(int(item[f"{false_belief_type}_pred"]))
|
||||
|
||||
if not y_true or not y_pred:
|
||||
macro_f1 = 0.0
|
||||
else:
|
||||
macro_f1 = f1_score(y_true, y_pred, average='macro')
|
||||
|
||||
return macro_f1
|
||||
|
||||
def delete_files_in_subfolders(folder_path, file_names_to_delete):
|
||||
"""
|
||||
Delete specified files in all subfolders of a given folder.
|
||||
|
||||
Parameters:
|
||||
folder_path: The path to the folder containing subfolders.
|
||||
file_names_to_delete: A list of file names to be deleted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for root, _, _ in os.walk(folder_path):
|
||||
for file_name in file_names_to_delete:
|
||||
file_path = os.path.join(root, file_name)
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
print(f"Deleted: {file_path}")
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
folder_path = "predictions"
|
||||
files_to_delete = ["fb_m1_m2_m12_m21.txt", "fb_m1_m2.txt", "fb_m12_m21.txt"]
|
||||
delete_files_in_subfolders(folder_path, files_to_delete)
|
||||
|
||||
metric = "Accuracy"
|
||||
if metric == "Macro F1":
|
||||
score_function = compute_macro_f1_score
|
||||
elif metric == "Accuracy":
|
||||
score_function = compute_correct_false_belief
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
models = [
|
||||
'Base', 'DB',
|
||||
'CG$\parallel$', 'CG$\oplus$', 'CG$\otimes$', 'CG$\odot$',
|
||||
'IC$\parallel$', 'IC$\oplus$', 'IC$\otimes$', 'IC$\odot$'
|
||||
]
|
||||
|
||||
parent_dir = 'predictions'
|
||||
minds = categories = ['m1', 'm2', 'm12', 'm21']
|
||||
score_m1_m2 = []
|
||||
score_m12_m21 = []
|
||||
score_all = []
|
||||
std_m1_m2 = []
|
||||
std_m12_m21 = []
|
||||
std_all = []
|
||||
|
||||
for model in models:
|
||||
model_scores_m1_m2 = []
|
||||
model_scores_m12_m21 = []
|
||||
model_scores_all = []
|
||||
for s in range(3):
|
||||
subdir_path = os.path.join(parent_dir, model_to_subdir[model][s])
|
||||
data = read_data_from_csv(subdir_path)
|
||||
model_scores_m1_m2.append(score_function(data, ['m1', 'm2'], model_to_subdir[model][s]))
|
||||
model_scores_m12_m21.append(score_function(data, ['m12', 'm21'], model_to_subdir[model][s]))
|
||||
model_scores_all.append(score_function(data, ['m1', 'm2', 'm12', 'm21'], model_to_subdir[model][s]))
|
||||
score_m1_m2.append(np.mean(model_scores_m1_m2))
|
||||
std_m1_m2.append(np.std(model_scores_m1_m2))
|
||||
score_m12_m21.append(np.mean(model_scores_m12_m21))
|
||||
std_m12_m21.append(np.std(model_scores_m12_m21))
|
||||
score_all.append(np.mean(model_scores_all))
|
||||
std_all.append(np.std(model_scores_all))
|
||||
|
||||
# Create a dataframe to use with sns.catplot
|
||||
data = {
|
||||
'Model': [m for m in models],
|
||||
'FO_FB_mean': score_m1_m2,
|
||||
'FO_FB_std': std_m1_m2,
|
||||
'SO_FB_mean': score_m12_m21,
|
||||
'SO_FB_std': std_m12_m21,
|
||||
'Both_mean': score_all,
|
||||
'Both_std': std_all
|
||||
}
|
||||
|
||||
models = data['Model']
|
||||
fo_fb_mean = data['FO_FB_mean']
|
||||
fo_fb_std = data['FO_FB_std']
|
||||
so_fb_mean = data['SO_FB_mean']
|
||||
so_fb_std = data['SO_FB_std']
|
||||
both_mean = data['Both_mean']
|
||||
both_std = data['Both_std']
|
||||
|
||||
bar_width = BAR_WIDTH
|
||||
x = np.arange(len(models))
|
||||
|
||||
plt.figure(figsize=(13, 3.5))
|
||||
fo_fb_bars = plt.bar(x - bar_width, fo_fb_mean, width=bar_width, yerr=fo_fb_std, capsize=4, label='First-order false belief', alpha=ALPHA)
|
||||
so_fb_bars = plt.bar(x, so_fb_mean, width=bar_width, yerr=so_fb_std, capsize=4, label='Second-order false belief', alpha=ALPHA)
|
||||
both_bars = plt.bar(x + bar_width, both_mean, width=bar_width, yerr=both_std, capsize=4, label='Both', alpha=ALPHA)
|
||||
|
||||
def add_labels(bars, std_values):
|
||||
cnt = 0
|
||||
for bar, std in zip(bars, std_values):
|
||||
height = bar.get_height()
|
||||
offset = std + 0.01
|
||||
if cnt == 0 or cnt == 1 or cnt == 9:
|
||||
plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}*', ha='center', va='bottom', fontsize=10)
|
||||
else:
|
||||
plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}', ha='center', va='bottom', fontsize=10)
|
||||
cnt = cnt + 1
|
||||
|
||||
add_labels(fo_fb_bars, fo_fb_std)
|
||||
add_labels(so_fb_bars, so_fb_std)
|
||||
add_labels(both_bars, both_std)
|
||||
|
||||
plt.gca().spines['top'].set_visible(False)
|
||||
plt.gca().spines['right'].set_visible(False)
|
||||
plt.xlabel('MToMnet', fontsize=14)
|
||||
plt.ylabel('Macro F1 Score' if metric == "Macro F1" else 'Accuracy', fontsize=14)
|
||||
plt.xticks(x, models, rotation=0, fontsize=14)
|
||||
plt.yticks(fontsize=14)
|
||||
plt.legend(fontsize=14, loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=3)
|
||||
plt.tight_layout()
|
||||
plt.savefig('results/false_belief_first_vs_second.pdf')
|
210
tbd/utils/helpers.py
Normal file
210
tbd/utils/helpers.py
Normal file
File diff suppressed because one or more lines are too long
37
tbd/utils/preprocess_img.py
Normal file
37
tbd/utils/preprocess_img.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
import glob
|
||||
|
||||
import cv2
|
||||
|
||||
import torchvision.transforms as T
|
||||
import torch
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
PATH_IN = "/scratch/bortoletto/data/tbd/images"
|
||||
PATH_OUT = "/scratch/bortoletto/data/tbd/images_norm"
|
||||
|
||||
normalisation_steps = [
|
||||
T.ToTensor(),
|
||||
T.Resize((128,128)),
|
||||
T.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
]
|
||||
|
||||
preprocess_img = T.Compose(normalisation_steps)
|
||||
|
||||
def main():
|
||||
print(f"{PATH_IN}/*/*/*.jpg")
|
||||
all_img = glob.glob(f"{PATH_IN}/*/*/*.jpg")
|
||||
print(len(all_img))
|
||||
for img_path in tqdm(all_img):
|
||||
new_img = preprocess_img(cv2.imread(img_path)).numpy()
|
||||
img_path_split = img_path.split("/")
|
||||
os.makedirs(f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}", exist_ok=True)
|
||||
out_img = f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}/{img_path_split[-1][:-4]}.pt"
|
||||
torch.save(new_img, out_img)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
106
tbd/utils/reformat_labels_ours.py
Normal file
106
tbd/utils/reformat_labels_ours.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
import pandas as pd
|
||||
import os
|
||||
import glob
|
||||
import pickle
|
||||
|
||||
DATASET_LOCATION = "YOUR_PATH_HERE"
|
||||
|
||||
def reframe_annotation():
|
||||
annotation_path = f'{DATASET_LOCATION}/retrieve_annotation/all/'
|
||||
save_path = f'{DATASET_LOCATION}/reformat_annotation/'
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
tasks = glob.glob(annotation_path + '*.txt')
|
||||
id_map = pd.read_csv('id_map.csv')
|
||||
for task in tasks:
|
||||
if not task.split('/')[-1].split('_')[2] == '1.txt':
|
||||
continue
|
||||
with open(task, 'r') as f:
|
||||
lines = f.readlines()
|
||||
task_id = int(task.split('/')[-1].split('_')[1]) + 1
|
||||
clip = id_map.loc[id_map['ID'] == task_id].folder
|
||||
print(task_id, len(clip))
|
||||
if len(clip) == 0:
|
||||
continue
|
||||
with open(save_path + clip.item() + '.txt', 'w') as f:
|
||||
for line in lines:
|
||||
words = line.split()
|
||||
f.write(words[0] + ',' + words[1] + ',' + words[2] + ',' + words[3] + ',' + words[4] + ',' + words[5] +
|
||||
',' + words[6] + ',' + words[7] + ',' + words[8] + ',' + words[9] + ',' + ' '.join(words[10:]) + '\n')
|
||||
f.close()
|
||||
|
||||
def get_grid_location(obj_frame):
|
||||
x_min = obj_frame['x_min']#.item()
|
||||
y_min = obj_frame['y_min']#.item()
|
||||
x_max = obj_frame['x_max']#.item()
|
||||
y_max = obj_frame['y_max']#.item()
|
||||
gridLW = 1280 / 25.
|
||||
gridLH = 720 / 15.
|
||||
center_x, center_y = (x_min + x_max)/2, (y_min + y_max)/2
|
||||
X, Y = int(center_x / gridLW), int(center_y / gridLH)
|
||||
return X, Y
|
||||
|
||||
def regenerate_annotation():
|
||||
annotation_path = f'{DATASET_LOCATION}/reformat_annotation/'
|
||||
save_path=f'{DATASET_LOCATION}/regenerate_annotation/'
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
tasks = glob.glob(annotation_path + '*.txt')
|
||||
for task in tasks:
|
||||
print(task)
|
||||
annt = pd.read_csv(task, sep=",", header=None)
|
||||
annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name", "label"]
|
||||
obj_records = {}
|
||||
for index, obj_frame in annt.iterrows():
|
||||
if obj_frame['name'].startswith('P'):
|
||||
continue
|
||||
else:
|
||||
assert obj_frame['name'].startswith('O')
|
||||
obj_name = obj_frame['name']
|
||||
# 0: enter 1: disappear 2: update 3: unchange
|
||||
frame_id = obj_frame['frame']
|
||||
curr_loc = get_grid_location(obj_frame)
|
||||
mind_dict = {'m1': {'fluent': 3, 'loc': None}, 'm2': {'fluent': 3, 'loc': None},
|
||||
'm12': {'fluent': 3, 'loc': None},
|
||||
'm21': {'fluent': 3, 'loc': None}, 'mc': {'fluent': 3, 'loc': None},
|
||||
'mg': {'fluent': 3, 'loc': curr_loc}}
|
||||
mind_dict['mg']['loc'] = curr_loc
|
||||
if not type(obj_frame['label']) == float:
|
||||
mind_labels = obj_frame['label'].split()
|
||||
for mind_label in mind_labels:
|
||||
if mind_label == 'in_m1' or mind_label == 'in_m2' or mind_label == 'in_m12' \
|
||||
or mind_label == 'in_m21' or mind_label == 'in_mc' or mind_label == '"in_m1"' or mind_label == '"in_m2"'\
|
||||
or mind_label == '"in_m12"' or mind_label == '"in_m21"' or mind_label == '"in_mc"':
|
||||
mind_name = mind_label.split('_')[1].split('"')[0]
|
||||
mind_dict[mind_name]['loc'] = curr_loc
|
||||
else:
|
||||
mind_name = mind_label.split('_')[0].split('"')
|
||||
if len(mind_name) > 1:
|
||||
mind_name = mind_name[1]
|
||||
else:
|
||||
mind_name = mind_name[0]
|
||||
last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
|
||||
mind_dict[mind_name]['loc'] = last_loc
|
||||
|
||||
for mind_name in mind_dict.keys():
|
||||
if frame_id > 0:
|
||||
curr_loc = mind_dict[mind_name]['loc']
|
||||
last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
|
||||
if last_loc is None and curr_loc is not None:
|
||||
mind_dict[mind_name]['fluent'] = 0
|
||||
elif last_loc is not None and curr_loc is None:
|
||||
mind_dict[mind_name]['fluent'] = 1
|
||||
elif not last_loc == curr_loc:
|
||||
mind_dict[mind_name]['fluent'] = 2
|
||||
if obj_name not in obj_records:
|
||||
obj_records[obj_name] = [mind_dict]
|
||||
else:
|
||||
obj_records[obj_name].append(mind_dict)
|
||||
|
||||
with open(save_path + task.split('/')[-1].split('.')[0] + '.p', 'wb') as f:
|
||||
pickle.dump(obj_records, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
reframe_annotation()
|
||||
regenerate_annotation()
|
75
tbd/utils/similarity.py
Normal file
75
tbd/utils/similarity.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.decomposition import PCA
|
||||
import seaborn as sns
|
||||
|
||||
|
||||
FOLDER_PATH = 'PATH_TO_FOLDER'
|
||||
|
||||
print(FOLDER_PATH)
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
COLORS = sns.color_palette()
|
||||
|
||||
sns.set_theme(style='white')
|
||||
|
||||
out_left_main_mods_full_test = []
|
||||
out_right_main_mods_full_test = []
|
||||
cell_left_main_mods_full_test = []
|
||||
cell_right_main_mods_full_test = []
|
||||
cm_left_main_mods_full_test = []
|
||||
cm_right_main_mods_full_test = []
|
||||
|
||||
for i in range(len([filename for filename in os.listdir(FOLDER_PATH) if filename.endswith('.pt')])):
|
||||
|
||||
print(f'Computing analysis for test video {i}...', end='\r')
|
||||
|
||||
emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
|
||||
data = torch.load(emb_file)
|
||||
if len(data) == 13: # implicit
|
||||
model = 'impl'
|
||||
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
|
||||
elif len(data) == 12: # common mind
|
||||
model = 'cm'
|
||||
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
|
||||
elif len(data) == 11: # speaker-listener
|
||||
model = 'sl'
|
||||
out_left, out_right, feats = data[0], data[1], data[2:]
|
||||
else: raise ValueError("Data should have 13 (impl), others are not implemented")
|
||||
|
||||
# ====== PCA for left and right embeddings ====== #
|
||||
|
||||
out_left_pca = out_left[0].reshape(-1, 64)
|
||||
out_right_pca = out_right[0].reshape(-1, 64)
|
||||
out_left_and_right = np.concatenate((out_left_pca, out_right_pca), axis=0)
|
||||
|
||||
pca = PCA(n_components=2)
|
||||
pca_result = pca.fit_transform(out_left_and_right)
|
||||
|
||||
# Separate the PCA results for each tensor
|
||||
pca_result_left = pca_result[:out_left_pca.shape[0]]
|
||||
pca_result_right = pca_result[out_right_pca.shape[0]:]
|
||||
|
||||
plt.figure(figsize=(6.8,6))
|
||||
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
|
||||
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
|
||||
plt.xlabel('Principal Component 1', fontsize=32)
|
||||
plt.ylabel('Principal Component 2', fontsize=32)
|
||||
plt.xticks(fontsize=24)
|
||||
plt.xticks([-0.4, -0.2, 0.0, 0.2, 0.4])
|
||||
plt.yticks(fontsize=24)
|
||||
plt.legend(fontsize=32)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
|
||||
plt.close()
|
96
tbd/utils/store_mind_set.py
Normal file
96
tbd/utils/store_mind_set.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
import os
|
||||
import pandas as pd
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def check_append(obj_name, m1, mind_name, obj_frame, flags, label):
|
||||
if label:
|
||||
if not obj_name in m1:
|
||||
m1[obj_name] = []
|
||||
m1[obj_name].append(
|
||||
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
|
||||
flags[mind_name] = 1
|
||||
elif not flags[mind_name]:
|
||||
m1[obj_name].append(
|
||||
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
|
||||
flags[mind_name] = 1
|
||||
else: # false belief
|
||||
if obj_name in m1:
|
||||
if flags[mind_name]:
|
||||
m1[obj_name].append(
|
||||
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 1])
|
||||
flags[mind_name] = 0
|
||||
return flags, m1
|
||||
|
||||
|
||||
def store_mind_set(clip, annotation_path, save_path):
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
annt = pd.read_csv(annotation_path + clip, sep=",", header=None)
|
||||
annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name",
|
||||
"label"]
|
||||
obj_names = annt.name.unique()
|
||||
m1, m2, m12, m21, mc = {}, {}, {}, {}, {}
|
||||
flags = {'m1':0, 'm2':0, 'm12':0, 'm21':0, 'mc':0}
|
||||
for obj_name in obj_names:
|
||||
if obj_name == 'P1' or obj_name == 'P2':
|
||||
continue
|
||||
obj_frames = annt.loc[annt.name == obj_name]
|
||||
for index, obj_frame in obj_frames.iterrows():
|
||||
if type(obj_frame.label) == float:
|
||||
continue
|
||||
labels = obj_frame.label.split()
|
||||
for label in labels:
|
||||
if label == 'in_m1' or label == '"in_m1"':
|
||||
flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 1)
|
||||
elif label == 'in_m2' or label == '"in_m2"':
|
||||
flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 1)
|
||||
elif label == 'in_m12'or label == '"in_m12"':
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 1)
|
||||
elif label == 'in_m21' or label == '"in_m21"':
|
||||
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 1)
|
||||
elif label == 'in_mc'or label == '"in_mc"':
|
||||
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 1)
|
||||
elif label == 'm1_false' or label == '"m1_false"':
|
||||
flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 0)
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
|
||||
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
|
||||
false_belief = 'm1_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
elif label == 'm2_false' or label == '"m2_false"':
|
||||
flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 0)
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
|
||||
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
|
||||
false_belief = 'm2_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
elif label == 'm12_false' or label == '"m12_false"':
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
|
||||
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
|
||||
false_belief = 'm12_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
elif label == 'm21_false' or label == '"m21_false"':
|
||||
flags, m21 = check_append(obj_name, m2, 'm21', obj_frame, flags, 0)
|
||||
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
|
||||
false_belief = 'm21_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
# print('m1', m1)
|
||||
# print('m2', m2)
|
||||
# print('m12', m12)
|
||||
# print('m21', m21)
|
||||
# print('mc', mc)
|
||||
#with open(save_path + clip.split('.')[0] + '.p', 'wb') as f:
|
||||
# pickle.dump([m1, m2, m12, m21, mc], f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
annotation_path = '/scratch/bortoletto/data/tbd/reformat_annotation/'
|
||||
save_path = '/scratch/bortoletto/data/tbd/store_mind_set/'
|
||||
|
||||
for clip in tqdm(os.listdir(annotation_path), desc="Processing videos", unit="item"):
|
||||
store_mind_set(clip, annotation_path, save_path)
|
95
tbd/utils/visualize_bbox.py
Normal file
95
tbd/utils/visualize_bbox.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
import time
|
||||
from tbd_dataloader import TBDv2Dataset
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
|
||||
def point2screen(points):
|
||||
K = [607.13232421875, 0.0, 638.6468505859375, 0.0, 607.1067504882812, 367.1607360839844, 0.0, 0.0, 1.0]
|
||||
K = np.reshape(np.array(K), [3, 3])
|
||||
rot_points = np.array(points) + np.array([0, 0.2, 0])
|
||||
rot_points = rot_points
|
||||
points_camera = rot_points.reshape(3, 1)
|
||||
|
||||
project_matrix = np.array(K).reshape(3, 3)
|
||||
points_prj = project_matrix.dot(points_camera)
|
||||
points_prj = points_prj.transpose()
|
||||
if not points_prj[:, 2][0] == 0.0:
|
||||
points_prj[:, 0] = points_prj[:, 0] / points_prj[:, 2]
|
||||
points_prj[:, 1] = points_prj[:, 1] / points_prj[:, 2]
|
||||
points_screen = points_prj[:, :2]
|
||||
assert points_screen.shape == (1, 2)
|
||||
points_screen = points_screen.reshape(-1)
|
||||
return points_screen
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = TBDv2Dataset(number_frames_to_sample=1, resize_img=None)
|
||||
index = np.random.randint(0, len(data))
|
||||
start = time.time()
|
||||
(
|
||||
kinect_imgs, # <- len x 720 x 1280 x 3
|
||||
tracker_imgs,
|
||||
battery_imgs,
|
||||
skele1,
|
||||
skele2,
|
||||
bbox,
|
||||
tracker_skeID_sample, # <- This is the tracker skeleton ID
|
||||
tracker2d,
|
||||
label,
|
||||
experiment_id, # From here for debugging
|
||||
timestep,
|
||||
obj_id, # <- This is the object ID as a string
|
||||
) = data[index]
|
||||
end = time.time()
|
||||
print(f"Time for one sample: {end-start}")
|
||||
|
||||
img = kinect_imgs[-1]
|
||||
bbox = bbox[-1]
|
||||
print(label.shape)
|
||||
|
||||
print(skele1.shape)
|
||||
print(skele2.shape)
|
||||
|
||||
skele1 = skele1[-1, :,:]
|
||||
skele2 = skele2[-1, :,:]
|
||||
|
||||
print(skele1.shape)
|
||||
|
||||
|
||||
|
||||
# reshape img from c, h, w to h, w, c
|
||||
img = img.permute(1, 2, 0)
|
||||
|
||||
fig, ax = plt.subplots(1)
|
||||
ax.imshow(img)
|
||||
print(bbox[0], bbox[1], bbox[2], bbox[3]) # t(top left x, top left y, width, height)
|
||||
top_left_x, top_left_y, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
|
||||
x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
|
||||
|
||||
|
||||
|
||||
|
||||
for i in range(26):
|
||||
print(skele1[i,0], skele1[i,1])
|
||||
print(skele1[i,:].shape)
|
||||
print(point2screen(skele1[i,:]))
|
||||
x, y = point2screen(skele1[i,:])[0], point2screen(skele1[i,:])[1]
|
||||
ax.text(x, y, f"{i}", fontsize=5, color='w')
|
||||
|
||||
wedge = patches.Wedge((x,y), 10, 0, 360, width=10, color='b')
|
||||
ax.add_patch(wedge)
|
||||
|
||||
for i in range(26):
|
||||
x, y = point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]
|
||||
ax.text(x, y, f"{i}", fontsize=5, color='w')
|
||||
wedge = patches.Wedge((point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]), 10, 0, 360, width=10, color='r')
|
||||
ax.add_patch(wedge)
|
||||
|
||||
# Create a Rectangle patch
|
||||
# rect = patches.Rectangle((top_left_x, top_left_y-height), width, height, linewidth=1, edgecolor='r', facecolor='none')
|
||||
# ax.add_patch(rect)
|
||||
# rect = patches.Rectangle((x_min, y_max), x_max-x_min, y_max-y_min, linewidth=1, edgecolor='g', facecolor='none')
|
||||
# ax.add_patch(rect)
|
||||
fig.savefig(f"bbox_{obj_id}_{index}_{experiment_id}.png")
|
Loading…
Reference in a new issue