first commit

Zhiming Hu 2025-04-30 14:15:00 +02:00
parent 99ce0acafb
commit 8f6b6a34e7
73 changed files with 11656 additions and 0 deletions

@@ -0,0 +1,46 @@
# HOIGaze: Gaze Estimation During Hand-Object Interactions in Extended Reality Exploiting Eye-Hand-Head Coordination
Project homepage: https://zhiminghu.net/hu25_hoigaze.
## Abstract
```
We present HOIGaze, a novel learning-based approach for gaze estimation during hand-object interactions (HOI) in extended reality (XR).
HOIGaze addresses the challenging HOI setting by building on one key insight: The eye, hand, and head movements are closely coordinated during HOIs and this coordination can be exploited to identify samples that are most useful for gaze estimator training and, as such, to effectively denoise the training data.
This denoising approach is in stark contrast to previous gaze estimation methods that treated all training samples as equal.
Specifically, we propose: 1) a novel hierarchical framework that first recognises the hand currently visually attended to and then estimates gaze direction based on the attended hand; 2) a new gaze estimator that uses cross-modal Transformers to fuse head and hand-object features extracted using a convolutional neural network and a spatio-temporal graph convolutional network; and 3) a novel eye-head coordination loss that upgrades training samples belonging to the coordinated eye-head movements.
We evaluate HOIGaze on the HOT3D and Aria digital twin (ADT) datasets and show that it significantly outperforms state-of-the-art methods, achieving an average improvement of 15.6% on HOT3D and 6.0% on ADT in mean angular error.
To demonstrate the potential of our method, we further report significant performance improvements for the sample downstream task of eye-based activity recognition on ADT.
Taken together, our results underline the significant information content available in eye-hand-head coordination and, as such, open up an exciting new direction for learning-based gaze estimation.
```
## Environment:
Ubuntu 22.04
python 3.8+
pytorch 1.8.1
## Usage:
Step 1: Create the environment
```
conda env create -f ./environment/hoigaze.yaml -n hoigaze
conda activate hoigaze
```
Step 2: Follow the instructions in './adt_processing/' and './hot3d_processing/' to process the datasets.
Step 3: Set 'data_dir' and 'cuda_idx' in 'train_hot3d_userX.sh' (X = 1, 2, or 3) to train and evaluate on HOT3D for the different user splits, and in 'train_hot3d_sceneX.sh' (X = 1, 2, or 3) for the different scene splits.
Step 4: Set 'data_dir' and 'cuda_idx' in 'train_adt.sh' to train and evaluate on ADT.
## Citation
```bibtex
@inproceedings{hu25hoigaze,
title={HOIGaze: Gaze Estimation During Hand-Object Interactions in Extended Reality Exploiting Eye-Hand-Head Coordination},
author={Hu, Zhiming and Haeufle, Daniel and Schmitt, Syn and Bulling, Andreas},
booktitle={Proceedings of the 2025 ACM Special Interest Group on Computer Graphics and Interactive Techniques},
year={2025}}
```

adt_processing/README.md
@@ -0,0 +1,29 @@
## Code to process the ADT dataset
Note: processing the ADT dataset is more involved than the other datasets because it relies on Project Aria Tools. It may be easier to get started with the other datasets first.
## Usage:
Step 1: Follow https://facebookresearch.github.io/projectaria_tools/docs/open_datasets/aria_digital_twin_dataset/dataset_download to prepare the environment and download the dataset. Please note that in our paper we used version 1.1.0 of the dataset with Project Aria Tools 1.1.0 and Python 3.8, and built Project Aria Tools with 'python setup.py build_py'.
Step 2: Set 'dataset_path' and 'dataset_processed_path' in 'adt_preprocessing.py', put 'adt_preprocessing.py', 'adt.csv', and 'utils' into the codebase of Project Aria Tools, and run 'adt_preprocessing.py' to process the dataset.
Step 3: It is optional but highly recommended to set 'data_path' in 'dataset_visualisation.py' and run it to visualise the processed data and get familiar with the dataset (see the loading sketch below for the expected output files).
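For reference, 'adt_preprocessing.py' saves each sequence as a set of '.npy' files whose layouts are documented in the code. Below is a minimal sketch, not part of the codebase, of how to load and inspect one processed sequence; the prefix is just the example path used in 'dataset_visualisation.py' and should be replaced with your own 'dataset_processed_path'.
```python
import numpy as np

# example prefix: <dataset_processed_path>/<train|test>/<sequence_name>_
prefix = '/scratch/hu/pose_forecast/adt_hoigaze/test/Apartment_release_meal_skeleton_seq132_'

gaze = np.load(prefix + 'gaze.npy')                # (N, 6): gaze_direction (3) + gaze_2d (2) + frame_id (1)
head = np.load(prefix + 'head.npy')                # (N, 6): head_direction (3) + head_translation (3)
hand = np.load(prefix + 'hand.npy')                # (N, 6): left_hand_translation (3) + right_hand_translation (3)
hand_joints = np.load(prefix + 'handjoints.npy')   # (N, 92): left hand (15*3) + right hand (15*3) + attended_hand_gt + closest_hand
bbx_left = np.load(prefix + 'object_bbxleft.npy')  # (N, 5, 8, 3): corners of the 5 dynamic objects closest to the left hand
bbx_right = np.load(prefix + 'object_bbxright.npy')  # (N, 5, 8, 3): corners of the 5 dynamic objects closest to the right hand

print(gaze.shape, head.shape, hand.shape, hand_joints.shape, bbx_left.shape, bbx_right.shape)
```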
## Citations
```bibtex
@inproceedings{hu25hoigaze,
title={HOIGaze: Gaze Estimation During Hand-Object Interactions in Extended Reality Exploiting Eye-Hand-Head Coordination},
author={Hu, Zhiming and Haeufle, Daniel and Schmitt, Syn and Bulling, Andreas},
booktitle={Proceedings of the 2025 ACM Special Interest Group on Computer Graphics and Interactive Techniques},
year={2025}}
@inproceedings{pan2023aria,
title={Aria digital twin: A new benchmark dataset for egocentric 3d machine perception},
author={Pan, Xiaqing and Charron, Nicholas and Yang, Yongqian and Peters, Scott and Whelan, Thomas and Kong, Chen and Parkhi, Omkar and Newcombe, Richard and Ren, Yuheng Carl},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={20133--20143},
year={2023}}
```

adt_processing/adt.csv
@@ -0,0 +1,35 @@
sequence_name,training,action
Apartment_release_work_skeleton_seq132,0,work
Apartment_release_work_skeleton_seq138,0,work
Apartment_release_meal_skeleton_seq132,0,meal
Apartment_release_decoration_skeleton_seq133,0,decoration
Apartment_release_decoration_skeleton_seq139,0,decoration
Apartment_release_decoration_skeleton_seq134,0,decoration
Apartment_release_work_skeleton_seq107,0,work
Apartment_release_meal_skeleton_seq135,0,meal
Apartment_release_work_skeleton_seq135,0,work
Apartment_release_meal_skeleton_seq131,0,meal
Apartment_release_work_skeleton_seq131,1,work
Apartment_release_work_skeleton_seq109,1,work
Apartment_release_work_skeleton_seq110,1,work
Apartment_release_decoration_skeleton_seq140,1,decoration
Apartment_release_decoration_skeleton_seq137,1,decoration
Apartment_release_work_skeleton_seq136,1,work
Apartment_release_meal_skeleton_seq136,1,meal
Apartment_release_work_skeleton_seq106,1,work
Apartment_release_meal_skeleton_seq134,1,meal
Apartment_release_work_skeleton_seq134,1,work
Apartment_release_decoration_skeleton_seq135,1,decoration
Apartment_release_decoration_skeleton_seq138,1,decoration
Apartment_release_decoration_skeleton_seq132,1,decoration
Apartment_release_work_skeleton_seq139,1,work
Apartment_release_work_skeleton_seq133,1,work
Apartment_release_meal_skeleton_seq139,1,meal
Apartment_release_meal_skeleton_seq133,1,meal
Apartment_release_work_skeleton_seq140,1,work
Apartment_release_work_skeleton_seq137,1,work
Apartment_release_meal_skeleton_seq140,1,meal
Apartment_release_meal_skeleton_seq137,1,meal
Apartment_release_decoration_skeleton_seq136,1,decoration
Apartment_release_decoration_skeleton_seq131,1,decoration
Apartment_release_work_skeleton_seq108,1,work

@@ -0,0 +1,272 @@
import numpy as np
import os
os.nice(5)
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import plotly.graph_objects as go
import math
from math import tan
import random
from scipy.linalg import pinv
import projectaria_tools.core.mps as mps
import shutil
import json
from PIL import Image
from utils import remake_dir
import pandas as pd
import pylab as p
from IPython.display import display
import time
from projectaria_tools import utils
from projectaria_tools.core.stream_id import StreamId
from projectaria_tools.core import calibration
from projectaria_tools.projects.adt import (
AriaDigitalTwinDataProvider,
AriaDigitalTwinSkeletonProvider,
AriaDigitalTwinDataPathsProvider,
bbox3d_to_line_coordinates,
bbox2d_to_image_coordinates,
utils as adt_utils,
Aria3dPose
)
dataset_path = '/datasets/public/zhiming_datasets/adt/'
dataset_processed_path = '/scratch/hu/pose_forecast/adt_hoigaze/'
remake_dir(dataset_processed_path)
remake_dir(dataset_processed_path + "train/")
remake_dir(dataset_processed_path + "test/")
dataset_info = pd.read_csv('adt.csv')
object_num = 5 # number of extracted dynamic objects that are closest to the left or right hands
for i, seq in enumerate(dataset_info['sequence_name']):
action = dataset_info['action'][i]
print("\nprocessing {}th seq: {}, action: {}...".format(i+1, seq, action))
seq_path = dataset_path + seq + '/'
if dataset_info['training'][i] == 1:
save_path = dataset_processed_path + 'train/' + seq + '_'
if dataset_info['training'][i] == 0:
save_path = dataset_processed_path + 'test/' + seq + '_'
paths_provider = AriaDigitalTwinDataPathsProvider(seq_path)
all_device_serials = paths_provider.get_device_serial_numbers()
selected_device_number = 0
data_paths = paths_provider.get_datapaths_by_device_num(selected_device_number)
print("loading ground truth data...")
gt_provider = AriaDigitalTwinDataProvider(data_paths)
print("loading ground truth data done")
stream_id = StreamId("214-1")
img_timestamps_ns = gt_provider.get_aria_device_capture_timestamps_ns(stream_id)
frame_num = len(img_timestamps_ns)
print("There are {} frames".format(frame_num))
# get all available skeletons in a sequence
skeleton_ids = gt_provider.get_skeleton_ids()
skeleton_info = gt_provider.get_instance_info_by_id(skeleton_ids[0])
print("skeleton ", skeleton_info.name, " wears ", skeleton_info.associated_device_serial)
useful_frames = []
gaze_data = np.zeros((frame_num, 6)) # gaze_direction (3) + gaze_2d (2) + frame_id (1)
head_data = np.zeros((frame_num, 6)) # head_direction (3) + head_translation (3)
hand_data = np.zeros((frame_num, 6)) # left_hand_translation (3) + right_hand_translation (3)
hand_joint_data = np.zeros((frame_num, 92)) # left_hand (15*3) + right_hand (15*3) + attended_hand_gt + attended_hand_baseline (closest_hand)
object_all_data = []
object_bbx_all_data = []
object_center_all_data = []
local_time = time.asctime(time.localtime(time.time()))
print('\nProcessing starts at ' + local_time)
for j in range(frame_num):
timestamps_ns = img_timestamps_ns[j]
skeleton_with_dt = gt_provider.get_skeleton_by_timestamp_ns(timestamps_ns, skeleton_ids[0])
assert skeleton_with_dt.is_valid(), "skeleton is not valid"
skeleton = skeleton_with_dt.data()
head_translation_id = [4]
hand_translation_id = [8, 27]
hand_joints_id = [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42]
hand_translation = np.array(skeleton.joints)[hand_translation_id, :].reshape(2*3)
head_translation = np.array(skeleton.joints)[head_translation_id, :].reshape(1*3)
hand_joints = np.array(skeleton.joints)[hand_joints_id, :].reshape(30*3)
hand_data[j] = hand_translation
hand_joint_data[j, :90] = hand_joints
left_hand_joints = hand_joints[:45].reshape(15, 3)
left_hand_center = np.mean(left_hand_joints, axis=0)
right_hand_joints = hand_joints[45:].reshape(15, 3)
right_hand_center = np.mean(right_hand_joints, axis=0)
# get the Aria pose
aria3dpose_with_dt = gt_provider.get_aria_3d_pose_by_timestamp_ns(timestamps_ns)
if not aria3dpose_with_dt.is_valid():
print("aria 3d pose is not available")
aria3dpose = aria3dpose_with_dt.data()
transform_scene_device = aria3dpose.transform_scene_device.matrix()
# get projection function
cam_calibration = gt_provider.get_aria_camera_calibration(stream_id)
assert cam_calibration is not None, "no camera calibration"
eye_gaze_with_dt = gt_provider.get_eyegaze_by_timestamp_ns(timestamps_ns)
assert eye_gaze_with_dt.is_valid(), "Eye gaze not available"
# Project the gaze center in the CPF frame into the camera sensor plane, with the multiplication performed in homogeneous coordinates
eye_gaze = eye_gaze_with_dt.data()
gaze_center_in_cpf = np.array([tan(eye_gaze.yaw), tan(eye_gaze.pitch), 1.0], dtype=np.float64) * eye_gaze.depth
head_center_in_cpf = np.array([0.0, 0.0, 1.0], dtype=np.float64)
transform_cpf_sensor = gt_provider.raw_data_provider_ptr().get_device_calibration().get_transform_cpf_sensor(cam_calibration.get_label())
gaze_center_in_camera = transform_cpf_sensor.inverse().matrix() @ np.hstack((gaze_center_in_cpf, 1)).T
gaze_center_in_camera = gaze_center_in_camera[:3] / gaze_center_in_camera[3:]
gaze_center_in_pixels = cam_calibration.project(gaze_center_in_camera)
head_center_in_camera = transform_cpf_sensor.inverse().matrix() @ np.hstack((head_center_in_cpf, 0)).T
head_center_in_camera = head_center_in_camera[:3]
extrinsic_matrix = cam_calibration.get_transform_device_camera().matrix()
gaze_center_in_device = (extrinsic_matrix @ np.hstack((gaze_center_in_camera, 1)))[0:3]
gaze_center_in_scene = (transform_scene_device @ np.hstack((gaze_center_in_device, 1)))[0:3]
head_center_in_device = (extrinsic_matrix @ np.hstack((head_center_in_camera, 0)))[0:3]
head_center_in_scene = (transform_scene_device @ np.hstack((head_center_in_device, 0)))[0:3]
gaze_direction = gaze_center_in_scene - head_translation
if np.linalg.norm(gaze_direction) == 0: # invalid data that will be filtered
gaze_direction = np.array([0.0, 0.0, 1.0], dtype=np.float64)
else:
gaze_direction = [x / np.linalg.norm(gaze_direction) for x in gaze_direction]
head_direction = head_center_in_scene
head_direction = [x / np.linalg.norm(head_direction) for x in head_direction]
head_data[j, 0:3] = head_direction
head_data[j, 3:6] = head_translation
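# ground-truth attended hand (hand_joint_data index 90): the hand whose direction from the head has the smaller angular distance to the gaze direction (0: left, 1: right)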
left_hand_direction = left_hand_center - head_translation
left_hand_direction = np.array([x / np.linalg.norm(left_hand_direction) for x in left_hand_direction])
left_hand_distance_to_gaze = np.arccos(np.sum(gaze_direction*left_hand_direction))
right_hand_direction = right_hand_center - head_translation
right_hand_direction = np.array([x / np.linalg.norm(right_hand_direction) for x in right_hand_direction])
right_hand_distance_to_gaze = np.arccos(np.sum(gaze_direction*right_hand_direction))
if left_hand_distance_to_gaze < right_hand_distance_to_gaze:
hand_joint_data[j, 90:91] = 0
else:
hand_joint_data[j, 90:91] = 1
if gaze_center_in_pixels is not None:
x_pixel = gaze_center_in_pixels[1]
y_pixel = gaze_center_in_pixels[0]
gaze_center_in_pixels[0] = x_pixel
gaze_center_in_pixels[1] = y_pixel
useful_frames.append(j)
gaze_2d = np.divide(gaze_center_in_pixels, cam_calibration.get_image_size())
gaze_data[j, 0:3] = gaze_direction
gaze_data[j, 3:5] = gaze_2d
gaze_data[j, 5:6] = j
# get the objects
bbox3d_with_dt = gt_provider.get_object_3d_boundingboxes_by_timestamp_ns(timestamps_ns)
assert bbox3d_with_dt.is_valid(), "3D bounding box is not available"
bbox3d_all = bbox3d_with_dt.data()
object_all = []
object_bbx_all = []
object_center_all = []
for obj_id in bbox3d_all:
bbox3d = bbox3d_all[obj_id]
aabb = bbox3d.aabb
aabb_coords = bbox3d_to_line_coordinates(aabb)
obb = np.zeros(shape=(len(aabb_coords), 3))
for k in range(0, len(aabb_coords)):
aabb_pt = aabb_coords[k]
aabb_pt_homo = np.append(aabb_pt, [1])
obb_pt = (bbox3d.transform_scene_object.matrix() @ aabb_pt_homo)[0:3]
obb[k] = obb_pt
motion_type = gt_provider.get_instance_info_by_id(obj_id).motion_type
if(str(motion_type) == 'MotionType.DYNAMIC'):
object_all.append(obb)
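# select 8 of the 16 wireframe coordinates as the bounding-box corners and use their mean as the object centre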
bbx_idx = [0, 1, 2, 3, 5, 6, 7, 8]
obb_bbx = obb[bbx_idx, :]
object_bbx_all.append(obb_bbx)
obb_center = np.mean(obb_bbx, axis=0)
object_center_all.append(obb_center)
object_all_data.append(object_all)
object_bbx_all_data.append(object_bbx_all)
object_center_all_data.append(object_center_all)
gaze_data = gaze_data[useful_frames, :] # useful_frames are actually continuous
head_data = head_data[useful_frames, :]
hand_data = hand_data[useful_frames, :]
hand_joint_data = hand_joint_data[useful_frames, :]
object_all_data = np.array(object_all_data)
object_all_data = object_all_data[useful_frames, :, :, :]
#print("Objects shape: {}".format(object_all_data.shape))
object_bbx_all_data = np.array(object_bbx_all_data)
object_bbx_all_data = object_bbx_all_data[useful_frames, :, :, :]
object_center_all_data = np.array(object_center_all_data)
object_center_all_data = object_center_all_data[useful_frames, :, :]
# extract the closest objects to the left or right hands
useful_frames_num = len(useful_frames)
print("There are {} useful frames".format(useful_frames_num))
object_num_all = object_all_data.shape[1]
object_left_hand_data = np.zeros((useful_frames_num, object_num, 16, 3))
object_bbx_left_hand_data = np.zeros((useful_frames_num, object_num, 8, 3))
object_distance_to_left_hand = np.zeros((useful_frames_num, object_num_all))
object_right_hand_data = np.zeros((useful_frames_num, object_num, 16, 3))
object_bbx_right_hand_data = np.zeros((useful_frames_num, object_num, 8, 3))
object_distance_to_right_hand = np.zeros((useful_frames_num, object_num_all))
for j in range(useful_frames_num):
left_hand_joints = hand_joint_data[j, :45].reshape(15, 3)
right_hand_joints = hand_joint_data[j, 45:90].reshape(15, 3)
for k in range(object_num_all):
object_pos = object_center_all_data[j, k, :]
object_distance_to_left_hand[j, k] = np.mean(np.linalg.norm(left_hand_joints-object_pos, axis=1))
object_distance_to_right_hand[j, k] = np.mean(np.linalg.norm(right_hand_joints-object_pos, axis=1))
for j in range(useful_frames_num):
distance_to_left_hand = object_distance_to_left_hand[j, :]
distance_to_left_hand_min = np.min(distance_to_left_hand)
distance_to_right_hand = object_distance_to_right_hand[j, :]
distance_to_right_hand_min = np.min(distance_to_right_hand)
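# closest-hand baseline (hand_joint_data index 91): the hand that is closer to its nearest dynamic object (0: left, 1: right)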
if distance_to_left_hand_min < distance_to_right_hand_min:
hand_joint_data[j, 91:92] = 0
else:
hand_joint_data[j, 91:92] = 1
left_hand_index = np.argsort(distance_to_left_hand)
right_hand_index = np.argsort(distance_to_right_hand)
for k in range(object_num):
object_left_hand_data[j, k] = object_all_data[j, left_hand_index[k]]
object_bbx_left_hand_data[j, k] = object_bbx_all_data[j, left_hand_index[k]]
object_right_hand_data[j, k] = object_all_data[j, right_hand_index[k]]
object_bbx_right_hand_data[j, k] = object_bbx_all_data[j, right_hand_index[k]]
gaze_path = save_path + 'gaze.npy'
head_path = save_path + 'head.npy'
hand_path = save_path + 'hand.npy'
hand_joint_path = save_path + 'handjoints.npy'
object_left_hand_path = save_path + 'object_left.npy'
object_bbx_left_hand_path = save_path + 'object_bbxleft.npy'
object_right_hand_path = save_path + 'object_right.npy'
object_bbx_right_hand_path = save_path + 'object_bbxright.npy'
np.save(gaze_path, gaze_data)
np.save(head_path, head_data)
np.save(hand_path, hand_data)
np.save(hand_joint_path, hand_joint_data)
np.save(object_left_hand_path, object_left_hand_data)
np.save(object_bbx_left_hand_path, object_bbx_left_hand_data)
np.save(object_right_hand_path, object_right_hand_data)
np.save(object_bbx_right_hand_path, object_bbx_right_hand_data)
local_time = time.asctime(time.localtime(time.time()))
print('\nProcessing ends at ' + local_time)

@@ -0,0 +1,162 @@
# visualise data in the ADT dataset
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# play human pose using a skeleton
class Player_Skeleton:
def __init__(self, fps=30.0, object_num=10):
self._fps = fps
self.object_num = object_num
# names of all the joints: head + left_hand + right_hand + left_hand_joint + right_hand_joint + gaze_direction + head_direction
self._joint_names = ['Head', 'LHand', 'RHand', 'LThumb1', 'LThumb2', 'LThumb3', 'LIndex1', 'LIndex2', 'LIndex3', 'LMiddle1', 'LMiddle2', 'LMiddle3', 'LRing1', 'LRing2', 'LRing3', 'LPinky1', 'LPinky2', 'LPinky3', 'RThumb1', 'RThumb2', 'RThumb3', 'RIndex1', 'RIndex2', 'RIndex3', 'RMiddle1', 'RMiddle2', 'RMiddle3', 'RRing1', 'RRing2', 'RRing3', 'RPinky1', 'RPinky2', 'RPinky3', 'Gaze_direction', 'Head_direction']
self._joint_ids = {name: idx for idx, name in enumerate(self._joint_names)}
# parent of every joint
self._joint_parent_names = {
# root
'Head': 'Head',
'LHand': 'LHand',
'RHand': 'RHand',
'LThumb1': 'LHand',
'LThumb2': 'LThumb1',
'LThumb3': 'LThumb2',
'LIndex1': 'LHand',
'LIndex2': 'LIndex1',
'LIndex3': 'LIndex2',
'LMiddle1': 'LHand',
'LMiddle2': 'LMiddle1',
'LMiddle3': 'LMiddle2',
'LRing1': 'LHand',
'LRing2': 'LRing1',
'LRing3': 'LRing2',
'LPinky1': 'LHand',
'LPinky2': 'LPinky1',
'LPinky3': 'LPinky2',
'RThumb1': 'RHand',
'RThumb2': 'RThumb1',
'RThumb3': 'RThumb2',
'RIndex1': 'RHand',
'RIndex2': 'RIndex1',
'RIndex3': 'RIndex2',
'RMiddle1': 'RHand',
'RMiddle2': 'RMiddle1',
'RMiddle3': 'RMiddle2',
'RRing1': 'RHand',
'RRing2': 'RRing1',
'RRing3': 'RRing2',
'RPinky1': 'RHand',
'RPinky2': 'RPinky1',
'RPinky3': 'RPinky2',
'Gaze_direction': 'Head',
'Head_direction': 'Head',}
# id of joint parent
self._joint_parent_ids = [self._joint_ids[self._joint_parent_names[child_name]] for child_name in self._joint_names]
self._joint_links = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
# link colour codes: 0-2 are drawn in blue (head/left/right hand links), 3 in green (gaze direction), 4 in purple (head direction)
self._link_colors = [1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4]
self._fig = plt.figure()
self._ax = plt.gca(projection='3d')
self._plots = []
for i in range(len(self._joint_links)):
if self._link_colors[i] == 0:
color = "#3498db"
if self._link_colors[i] == 1:
color = "#3498db"
if self._link_colors[i] == 2:
color = "#3498db"
if self._link_colors[i] == 3:
color = "#6aa84f"
if self._link_colors[i] == 4:
color = "#a64d79"
self._plots.append(self._ax.plot([0, 0], [0, 0], [0, 0], lw=2.0, c=color))
for i in range(self.object_num):
self._plots.append(self._ax.plot([0, 0], [0, 0], [0, 0], lw=1.0, c='#ff0000'))
self._ax.set_xlabel("x")
self._ax.set_ylabel("y")
self._ax.set_zlabel("z")
# play the sequence of human pose in xyz representations
def play_xyz(self, pose_xyz, gaze, head, objects):
gaze_direction = pose_xyz[:, :3] + gaze[:, :3]*0.5
head_direction = pose_xyz[:, :3] + head[:, :3]*0.5
pose_xyz = np.concatenate((pose_xyz, gaze_direction), axis = 1)
pose_xyz = np.concatenate((pose_xyz, head_direction), axis = 1)
for i in range(pose_xyz.shape[0]):
joint_number = len(self._joint_names)
pose_xyz_tmp = pose_xyz[i].reshape(joint_number, 3)
objects_xyz = objects[i, :, :, :]
for j in range(len(self._joint_links)):
idx = self._joint_links[j]
start_point = pose_xyz_tmp[idx]
end_point = pose_xyz_tmp[self._joint_parent_ids[idx]]
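# note: the plot's y axis takes the data's z component (and vice versa)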
x = np.array([start_point[0], end_point[0]])
y = np.array([start_point[2], end_point[2]])
z = np.array([start_point[1], end_point[1]])
self._plots[j][0].set_xdata(x)
self._plots[j][0].set_ydata(y)
self._plots[j][0].set_3d_properties(z)
for j in range(len(self._joint_links), len(self._joint_links) + objects_xyz.shape[0]):
object_xyz = objects_xyz[j - len(self._joint_links), :, :]
self._plots[j][0].set_xdata(object_xyz[:, 0])
self._plots[j][0].set_ydata(object_xyz[:, 2])
self._plots[j][0].set_3d_properties(object_xyz[:, 1])
r = 1.0
x_root, y_root, z_root = pose_xyz_tmp[0, 0], pose_xyz_tmp[0, 2], pose_xyz_tmp[0, 1]
self._ax.set_xlim3d([-r + x_root, r + x_root])
self._ax.set_ylim3d([-r + y_root, r + y_root])
self._ax.set_zlim3d([-r + z_root, r + z_root])
#self._ax.view_init(elev=30, azim=-110)
self._ax.grid(False)
#self._ax.axis('off')
self._ax.set_aspect('auto')
plt.show(block=False)
self._fig.canvas.draw()
past_time = f"{i / self._fps:.1f}"
plt.title(f"Time: {past_time} s", fontsize=15)
plt.pause(0.000000001)
if __name__ == "__main__":
data_path = '/scratch/hu/pose_forecast/adt_hoigaze/test/Apartment_release_meal_skeleton_seq132_'
gaze_path = data_path + 'gaze.npy'
head_path = data_path + 'head.npy'
hand_path = data_path + 'hand.npy'
hand_joint_path = data_path + 'handjoints.npy'
object_left_hand_path = data_path + 'object_left.npy'
object_right_hand_path = data_path + 'object_right.npy'
gaze = np.load(gaze_path) # gaze_direction (3) + gaze_2d (2) + frame_id (1)
print("Gaze shape: {}".format(gaze.shape))
gaze_direction = gaze[:, :3]
head = np.load(head_path) # head_direction (3) + head_translation (3)
print("Head shape: {}".format(head.shape))
head_direction = head[:, :3]
head_translation = head[:, 3:]
hand_translation = np.load(hand_path) # left_hand_translation (3) + right_hand_translation (3)
print("Hand shape: {}".format(hand_translation.shape))
hand_joint = np.load(hand_joint_path) # left_hand (15*3) + right_hand (15*3) + hand_dominance + closest_hand
print("Hand joint shape: {}".format(hand_joint.shape))
hand_joint = hand_joint[:, :90]
pose = np.concatenate((head_translation, hand_translation), axis=1)
pose = np.concatenate((pose, hand_joint), axis=1)
object_left = np.load(object_left_hand_path)[:, :, :, :]
object_right = np.load(object_right_hand_path)[:, :, :, :]
object_all = np.concatenate((object_left, object_right), axis=1)
print("Object shape: {}".format(object_all.shape))
player = Player_Skeleton(object_num = object_all.shape[1])
player.play_xyz(pose, gaze_direction, head_direction, object_all)

@@ -0,0 +1,4 @@
__all__ = ['file_systems']
from .file_systems import remake_dir, make_dir

@@ -0,0 +1,50 @@
import os
import shutil
import time
# remove a directory
def remove_dir(dirName):
if os.path.exists(dirName):
shutil.rmtree(dirName)
else:
print("Invalid directory path!")
# remake a directory
def remake_dir(dirName):
if os.path.exists(dirName):
shutil.rmtree(dirName)
os.makedirs(dirName)
else:
os.makedirs(dirName)
# calculate the number of lines in a file
def file_lines(fileName):
if os.path.exists(fileName):
with open(fileName, 'r') as fr:
return len(fr.readlines())
else:
print("Invalid file path!")
return 0
# make a directory if it does not exist.
def make_dir(dirName):
if os.path.exists(dirName):
print("Directory "+ dirName + " already exists.")
else:
os.makedirs(dirName)
if __name__ == "__main__":
dirName = "test"
remake_dir(dirName)
time.sleep(3)
make_dir(dirName)
remove_dir(dirName)
time.sleep(3)
make_dir(dirName)
#print(file_lines('233.txt'))

@@ -0,0 +1,269 @@
from utils import adt_dataset, seed_torch
from model import attended_hand_recognition
from utils.opt import options
from utils import log
from torch.utils.data import DataLoader
import torch
import numpy as np
import time
import datetime
import torch.optim as optim
import os
os.nice(5)
import math
def main(opt):
# set the random seed to ensure reproducibility
seed_torch.seed_torch(seed=0)
torch.set_num_threads(1)
data_dir = opt.data_dir
seq_len = opt.seq_len
opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
learning_rate = opt.learning_rate
print('>>> create model')
net = attended_hand_recognition.attended_hand_recognition(opt=opt).to(opt.cuda_idx)
optimizer = optim.AdamW(filter(lambda x: x.requires_grad, net.parameters()), lr=learning_rate, weight_decay=opt.weight_decay)
print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
print('>>> loading datasets')
train_actions = 'all'
test_actions = opt.actions
train_dataset = adt_dataset.adt_dataset(data_dir, seq_len, train_actions, 1, opt.object_num, opt.hand_joint_number, opt.sample_rate)
train_data_size = train_dataset.dataset.shape
print("Training data size: {}".format(train_data_size))
train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True)
valid_dataset = adt_dataset.adt_dataset(data_dir, seq_len, test_actions, 0, opt.object_num, opt.hand_joint_number, opt.sample_rate)
valid_data_size = valid_dataset.dataset.shape
print("Validation data size: {}".format(valid_data_size))
valid_loader = DataLoader(valid_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
# training
local_time = time.asctime(time.localtime(time.time()))
print('\nTraining starts at ' + local_time)
start_time = datetime.datetime.now()
start_epoch = 1
acc_best = 0
best_epoch = 0
exp_lr = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma, last_epoch=-1)
for epo in range(start_epoch, opt.epoch + 1):
is_best = False
learning_rate = exp_lr.optimizer.param_groups[0]["lr"]
train_start_time = datetime.datetime.now()
result_train = run_model(net, optimizer, is_train=1, data_loader=train_loader, opt=opt)
train_end_time = datetime.datetime.now()
train_time = (train_end_time - train_start_time).seconds*1000
train_batch_num = math.ceil(train_data_size[0]/opt.batch_size)
train_time_per_batch = math.ceil(train_time/train_batch_num)
#print('\nTraining time per batch: {} ms'.format(train_time_per_batch))
exp_lr.step()
rng_state = torch.get_rng_state()
if epo % opt.validation_epoch == 0:
print('>>> training epoch: {:d}, lr: {:.12f}'.format(epo, learning_rate))
print('Training data size: {}'.format(train_data_size))
print('Average baseline acc: {:.2f}%'.format(result_train['baseline_acc_average']*100))
print('Average training acc: {:.2f}%'.format(result_train['prediction_acc_average']*100))
test_start_time = datetime.datetime.now()
result_valid = run_model(net, is_train=0, data_loader=valid_loader, opt=opt)
test_end_time = datetime.datetime.now()
test_time = (test_end_time - test_start_time).seconds*1000
test_batch_num = math.ceil(valid_data_size[0]/opt.test_batch_size)
test_time_per_batch = math.ceil(test_time/test_batch_num)
#print('\nTest time per batch: {} ms'.format(test_time_per_batch))
print('Validation data size: {}'.format(valid_data_size))
print('Average baseline acc: {:.2f}%'.format(result_valid['baseline_acc_average']*100))
print('Average validation acc: {:.2f}%'.format(result_valid['prediction_acc_average']*100))
if result_valid['prediction_acc_average'] > acc_best:
acc_best = result_valid['prediction_acc_average']
is_best = True
best_epoch = epo
print('Best validation acc: {:.2f}%, best epoch: {}'.format(acc_best*100, best_epoch))
end_time = datetime.datetime.now()
total_training_time = (end_time - start_time).seconds/60
print('\nTotal training time: {:.1f} min'.format(total_training_time))
local_time = time.asctime(time.localtime(time.time()))
print('\nTraining ends at ' + local_time)
result_log = np.array([epo, learning_rate])
head = np.array(['epoch', 'lr'])
for k in result_train.keys():
result_log = np.append(result_log, [result_train[k]])
head = np.append(head, [k])
for k in result_valid.keys():
result_log = np.append(result_log, [result_valid[k]])
head = np.append(head, ['valid_' + k])
csv_name = 'attended_hand_recognition_results'
model_name = 'attended_hand_recognition_model.pt'
log.save_csv_log(opt, head, result_log, is_create=(epo == 1), file_name=csv_name)
log.save_ckpt({'epoch': epo,
'lr': learning_rate,
'acc': result_valid['prediction_acc_average'],
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict()},
opt=opt,
file_name = model_name)
torch.set_rng_state(rng_state)
def eval(opt):
data_dir = opt.data_dir
seq_len = opt.seq_len
opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
print('>>> create model')
net = attended_hand_recognition.attended_hand_recognition(opt=opt).to(opt.cuda_idx)
print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
#load model
model_name = 'attended_hand_recognition_model.pt'
model_path = os.path.join(opt.ckpt, model_name)
print(">>> loading ckpt from '{}'".format(model_path))
ckpt = torch.load(model_path)
net.load_state_dict(ckpt['state_dict'])
print(">>> ckpt loaded (epoch: {} | acc: {})".format(ckpt['epoch'], ckpt['acc']))
print('>>> loading datasets')
train_actions = 'all'
test_actions = opt.actions
train_dataset = adt_dataset.adt_dataset(data_dir, seq_len, train_actions, 1, opt.object_num, opt.hand_joint_number, opt.sample_rate)
train_data_size = train_dataset.dataset.shape
print("Train data size: {}".format(train_data_size))
train_loader = DataLoader(train_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
test_dataset = adt_dataset.adt_dataset(data_dir, seq_len, test_actions, 0, opt.object_num, opt.hand_joint_number, opt.sample_rate)
test_data_size = test_dataset.dataset.shape
print("Test data size: {}".format(test_data_size))
test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
# test
local_time = time.asctime(time.localtime(time.time()))
print('\nTest starts at ' + local_time)
start_time = datetime.datetime.now()
if opt.save_predictions:
result_train, predictions_train = run_model(net, is_train=0, data_loader=train_loader, opt=opt)
result_test, predictions_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
else:
result_train = run_model(net, is_train=0, data_loader=train_loader, opt=opt)
result_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
print('Average train baseline acc: {:.2f}%'.format(result_train['baseline_acc_average']*100))
print('Average train method acc: {:.2f}%'.format(result_train['prediction_acc_average']*100))
print('Average test baseline acc: {:.2f}%'.format(result_test['baseline_acc_average']*100))
print('Average test method acc: {:.2f}%'.format(result_test['prediction_acc_average']*100))
end_time = datetime.datetime.now()
total_test_time = (end_time - start_time).seconds/60
print('\nTotal test time: {:.1f} min'.format(total_test_time))
local_time = time.asctime(time.localtime(time.time()))
print('\nTest ends at ' + local_time)
if opt.save_predictions:
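# the last three channels of the saved predictions are the two softmax scores and the attended-hand ground truth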
prediction = predictions_train[:, :, -3:-1].reshape(-1, 2)
attended_hand_gt = predictions_train[:, :, -1:].reshape(-1)
y_prd = np.argmax(prediction, axis=1)
acc = np.sum(y_prd == attended_hand_gt)/prediction.shape[0]
print('Average train acc: {:.2f}%'.format(acc*100))
predictions_train_path = os.path.join(opt.ckpt, "attended_hand_recognition_train.npy")
np.save(predictions_train_path, predictions_train)
prediction = predictions_test[:, :, -3:-1].reshape(-1, 2)
attended_hand_gt = predictions_test[:, :, -1:].reshape(-1)
y_prd = np.argmax(prediction, axis=1)
acc = np.sum(y_prd == attended_hand_gt)/prediction.shape[0]
print('Average test acc: {:.2f}%'.format(acc*100))
predictions_test_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")
np.save(predictions_test_path, predictions_test)
def run_model(net, optimizer=None, is_train=1, data_loader=None, opt=None):
if is_train == 1:
net.train()
else:
net.eval()
if opt.is_eval and opt.save_predictions:
predictions = []
prediction_acc_average = 0
baseline_acc_average = 0
criterion = torch.nn.CrossEntropyLoss()
n = 0
input_n = opt.seq_len
for i, (data) in enumerate(data_loader):
batch_size, seq_n, dim = data.shape
joint_number = opt.joint_number
object_num = opt.object_num
# when only one sample in this batch
if batch_size == 1 and is_train == 1:
continue
n += batch_size*seq_n
data = data.float().to(opt.cuda_idx)
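# per-frame feature layout: eye_gaze (3) + joints (joint_number*3) + head_directions (3) + object bounding-box corners (8*object_num*2*3) + attended_hand_gt (1) + attended_hand_baseline (1)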
eye_gaze = data.clone()[:, :, :3]
joints = data.clone()[:, :, 3:(joint_number+1)*3]
head_directions = data.clone()[:, :, (joint_number+1)*3:(joint_number+2)*3]
attended_hand_gt = data.clone()[:, :, (joint_number+2+8*object_num*2)*3:(joint_number+2+8*object_num*2)*3+1].type(torch.LongTensor).to(opt.cuda_idx)
attended_hand_baseline = data.clone()[:, :, (joint_number+2+8*object_num*2)*3+1:(joint_number+2+8*object_num*2)*3+2].type(torch.LongTensor).to(opt.cuda_idx)
input = torch.cat((joints, head_directions), dim=2)
if object_num > 0:
object_positions = data.clone()[:, :, (joint_number+2)*3:(joint_number+2+8*object_num*2)*3]
input = torch.cat((input, object_positions), dim=2)
prediction = net(input, input_n=input_n)
if opt.is_eval and opt.save_predictions:
# eye_gaze + joints + head_directions + object_positions + predictions + attended_hand_gt
prediction = torch.nn.functional.softmax(prediction, dim=2)
prediction_cpu = torch.cat((eye_gaze, input), dim=2)
prediction_cpu = torch.cat((prediction_cpu, prediction), dim=2)
prediction_cpu = torch.cat((prediction_cpu, attended_hand_gt), dim=2)
prediction_cpu = prediction_cpu.cpu().data.numpy()
if len(predictions) == 0:
predictions = prediction_cpu
else:
predictions = np.concatenate((predictions, prediction_cpu), axis=0)
attended_hand_gt = attended_hand_gt.reshape(batch_size*input_n)
attended_hand_baseline = attended_hand_baseline.reshape(batch_size*input_n)
prediction = prediction.reshape(-1, 2)
loss = criterion(prediction, attended_hand_gt)
if is_train == 1:
optimizer.zero_grad()
loss.backward()
optimizer.step()
# calculate prediction accuracy
_, y_prd = torch.max(prediction.data, 1)
acc = torch.sum(y_prd == attended_hand_gt)/(batch_size*input_n)
prediction_acc_average += acc.cpu().data.numpy() * batch_size*input_n
acc = torch.sum(attended_hand_gt == attended_hand_baseline)/(batch_size*input_n)
baseline_acc_average += acc.cpu().data.numpy() * batch_size*input_n
result = {}
result["baseline_acc_average"] = baseline_acc_average / n
result["prediction_acc_average"] = prediction_acc_average / n
if opt.is_eval and opt.save_predictions:
return result, predictions
else:
return result
if __name__ == '__main__':
option = options().parse()
if option.is_eval == False:
main(option)
else:
eval(option)

@@ -0,0 +1,368 @@
from utils import hot3d_aria_dataset, seed_torch
from model import attended_hand_recognition
from utils.opt import options
from utils import log
from torch.utils.data import DataLoader
import torch
import numpy as np
import time
import datetime
import torch.optim as optim
import os
os.nice(5)
import math
def main(opt):
# set the random seed to ensure reproducibility
seed_torch.seed_torch(seed=0)
torch.set_num_threads(1)
data_dir = opt.data_dir
seq_len = opt.seq_len
opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
learning_rate = opt.learning_rate
print('>>> create model')
net = attended_hand_recognition.attended_hand_recognition(opt=opt).to(opt.cuda_idx)
optimizer = optim.AdamW(filter(lambda x: x.requires_grad, net.parameters()), lr=learning_rate, weight_decay=opt.weight_decay)
print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
print('>>> loading datasets')
actions = opt.actions
test_user_id = opt.test_user_id
if actions == 'all':
if test_user_id == 1:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003']
opt.ckpt = opt.ckpt + '/user1/'
if test_user_id == 2:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0001', 'P0002', 'P0003', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0009', 'P0010', 'P0011']
opt.ckpt = opt.ckpt + '/user2/'
if test_user_id == 3:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011']
test_subjects = ['P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/user3/'
elif actions == 'room':
train_actions = ['kitchen', 'office']
test_actions = ['room']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene1/'
elif actions == 'kitchen':
train_actions = ['room', 'office']
test_actions = ['kitchen']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene2/'
elif actions == 'office':
train_actions = ['room', 'kitchen']
test_actions = ['office']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene3/'
else:
raise ValueError("Unrecognised actions: {}".format(actions))
if not os.path.isdir(opt.ckpt):
os.makedirs(opt.ckpt)
train_dataset = hot3d_aria_dataset.hot3d_aria_dataset(data_dir, train_subjects, seq_len, train_actions, opt.object_num, opt.sample_rate)
train_data_size = train_dataset.dataset.shape
print("Training data size: {}".format(train_data_size))
train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True)
valid_dataset = hot3d_aria_dataset.hot3d_aria_dataset(data_dir, test_subjects, seq_len, test_actions, opt.object_num, opt.sample_rate)
valid_data_size = valid_dataset.dataset.shape
print("Validation data size: {}".format(valid_data_size))
valid_loader = DataLoader(valid_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
# training
local_time = time.asctime(time.localtime(time.time()))
print('\nTraining starts at ' + local_time)
start_time = datetime.datetime.now()
start_epoch = 1
acc_best = 0
best_epoch = 0
exp_lr = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma, last_epoch=-1)
for epo in range(start_epoch, opt.epoch + 1):
is_best = False
learning_rate = exp_lr.optimizer.param_groups[0]["lr"]
train_start_time = datetime.datetime.now()
result_train = run_model(net, optimizer, is_train=1, data_loader=train_loader, opt=opt)
train_end_time = datetime.datetime.now()
train_time = (train_end_time - train_start_time).seconds*1000
train_batch_num = math.ceil(train_data_size[0]/opt.batch_size)
train_time_per_batch = math.ceil(train_time/train_batch_num)
#print('\nTraining time per batch: {} ms'.format(train_time_per_batch))
exp_lr.step()
rng_state = torch.get_rng_state()
if epo % opt.validation_epoch == 0:
if actions == 'all':
print("\ntest user id: {}\n".format(test_user_id))
elif actions == 'room':
print("\ntest scene/action: room\n")
elif actions == 'kitchen':
print("\ntest scene/action: kitchen\n")
elif actions == 'office':
print("\ntest scene/action: office\n")
print('>>> training epoch: {:d}, lr: {:.12f}'.format(epo, learning_rate))
print('Training data size: {}'.format(train_data_size))
print('Average baseline acc: {:.2f}%'.format(result_train['baseline_acc_average']*100))
print('Average training acc: {:.2f}%'.format(result_train['prediction_acc_average']*100))
test_start_time = datetime.datetime.now()
result_valid = run_model(net, is_train=0, data_loader=valid_loader, opt=opt)
test_end_time = datetime.datetime.now()
test_time = (test_end_time - test_start_time).seconds*1000
test_batch_num = math.ceil(valid_data_size[0]/opt.test_batch_size)
test_time_per_batch = math.ceil(test_time/test_batch_num)
#print('\nTest time per batch: {} ms'.format(test_time_per_batch))
print('Validation data size: {}'.format(valid_data_size))
print('Average baseline acc: {:.2f}%'.format(result_valid['baseline_acc_average']*100))
print('Average validation acc: {:.2f}%'.format(result_valid['prediction_acc_average']*100))
if result_valid['prediction_acc_average'] > acc_best:
acc_best = result_valid['prediction_acc_average']
is_best = True
best_epoch = epo
print('Best validation acc: {:.2f}%, best epoch: {}'.format(acc_best*100, best_epoch))
end_time = datetime.datetime.now()
total_training_time = (end_time - start_time).seconds/60
print('\nTotal training time: {:.1f} min'.format(total_training_time))
local_time = time.asctime(time.localtime(time.time()))
print('\nTraining ends at ' + local_time)
result_log = np.array([epo, learning_rate])
head = np.array(['epoch', 'lr'])
for k in result_train.keys():
result_log = np.append(result_log, [result_train[k]])
head = np.append(head, [k])
for k in result_valid.keys():
result_log = np.append(result_log, [result_valid[k]])
head = np.append(head, ['valid_' + k])
csv_name = 'attended_hand_recognition_results'
model_name = 'attended_hand_recognition_model.pt'
log.save_csv_log(opt, head, result_log, is_create=(epo == 1), file_name=csv_name)
log.save_ckpt({'epoch': epo,
'lr': learning_rate,
'acc': result_valid['prediction_acc_average'],
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict()},
opt=opt,
file_name = model_name)
torch.set_rng_state(rng_state)
def eval(opt):
data_dir = opt.data_dir
seq_len = opt.seq_len
opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
print('>>> create model')
net = attended_hand_recognition.attended_hand_recognition(opt=opt).to(opt.cuda_idx)
print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
#load model
actions = opt.actions
test_user_id = opt.test_user_id
if actions == 'all':
if test_user_id == 1:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003']
opt.ckpt = opt.ckpt + '/user1/'
if test_user_id == 2:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0001', 'P0002', 'P0003', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0009', 'P0010', 'P0011']
opt.ckpt = opt.ckpt + '/user2/'
if test_user_id == 3:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011']
test_subjects = ['P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/user3/'
elif actions == 'room':
train_actions = ['kitchen', 'office']
test_actions = ['room']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene1/'
elif actions == 'kitchen':
train_actions = ['room', 'office']
test_actions = ['kitchen']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene2/'
elif actions == 'office':
train_actions = ['room', 'kitchen']
test_actions = ['office']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene3/'
else:
raise ValueError("Unrecognised actions: {}".format(actions))
model_name = 'attended_hand_recognition_model.pt'
model_path = os.path.join(opt.ckpt, model_name)
print(">>> loading ckpt from '{}'".format(model_path))
ckpt = torch.load(model_path)
net.load_state_dict(ckpt['state_dict'])
print(">>> ckpt loaded (epoch: {} | acc: {})".format(ckpt['epoch'], ckpt['acc']))
print('>>> loading datasets')
train_dataset = hot3d_aria_dataset.hot3d_aria_dataset(data_dir, train_subjects, seq_len, train_actions, opt.object_num, opt.sample_rate)
train_data_size = train_dataset.dataset.shape
print("Train data size: {}".format(train_data_size))
train_loader = DataLoader(train_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
test_dataset = hot3d_aria_dataset.hot3d_aria_dataset(data_dir, test_subjects, seq_len, test_actions, opt.object_num, opt.sample_rate)
test_data_size = test_dataset.dataset.shape
print("Test data size: {}".format(test_data_size))
test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
# test
local_time = time.asctime(time.localtime(time.time()))
print('\nTest starts at ' + local_time)
start_time = datetime.datetime.now()
if actions == 'all':
print("\ntest user id: {}\n".format(test_user_id))
elif actions == 'room':
print("\ntest scene/action: room\n")
elif actions == 'kitchen':
print("\ntest scene/action: kitchen\n")
elif actions == 'office':
print("\ntest scene/action: office\n")
if opt.save_predictions:
result_train, predictions_train = run_model(net, is_train=0, data_loader=train_loader, opt=opt)
result_test, predictions_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
else:
result_train = run_model(net, is_train=0, data_loader=train_loader, opt=opt)
result_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
print('Average train baseline acc: {:.2f}%'.format(result_train['baseline_acc_average']*100))
print('Average train method acc: {:.2f}%'.format(result_train['prediction_acc_average']*100))
print('Average test baseline acc: {:.2f}%'.format(result_test['baseline_acc_average']*100))
print('Average test method acc: {:.2f}%'.format(result_test['prediction_acc_average']*100))
end_time = datetime.datetime.now()
total_test_time = (end_time - start_time).seconds/60
print('\nTotal test time: {:.1f} min'.format(total_test_time))
local_time = time.asctime(time.localtime(time.time()))
print('\nTest ends at ' + local_time)
if opt.save_predictions:
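# the last three channels of the saved predictions are the two softmax scores and the attended-hand ground truth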
prediction = predictions_train[:, :, -3:-1].reshape(-1, 2)
attended_hand_gt = predictions_train[:, :, -1:].reshape(-1)
y_prd = np.argmax(prediction, axis=1)
acc = np.sum(y_prd == attended_hand_gt)/prediction.shape[0]
print('Average train acc: {:.2f}%'.format(acc*100))
predictions_train_path = os.path.join(opt.ckpt, "attended_hand_recognition_train.npy")
np.save(predictions_train_path, predictions_train)
prediction = predictions_test[:, :, -3:-1].reshape(-1, 2)
attended_hand_gt = predictions_test[:, :, -1:].reshape(-1)
y_prd = np.argmax(prediction, axis=1)
acc = np.sum(y_prd == attended_hand_gt)/prediction.shape[0]
print('Average test acc: {:.2f}%'.format(acc*100))
predictions_test_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")
np.save(predictions_test_path, predictions_test)
def run_model(net, optimizer=None, is_train=1, data_loader=None, opt=None):
if is_train == 1:
net.train()
else:
net.eval()
if opt.is_eval and opt.save_predictions:
predictions = []
prediction_acc_average = 0
baseline_acc_average = 0
criterion = torch.nn.CrossEntropyLoss()
n = 0
input_n = opt.seq_len
for i, (data) in enumerate(data_loader):
batch_size, seq_n, dim = data.shape
joint_number = opt.joint_number
object_num = opt.object_num
# when only one sample in this batch
if batch_size == 1 and is_train == 1:
continue
n += batch_size*seq_n
data = data.float().to(opt.cuda_idx)
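# per-frame feature layout: eye_gaze (3) + joints (joint_number*3) + head_directions (3) + object bounding-box corners (8*object_num*2*3) + attended_hand_gt (1) + attended_hand_baseline (1)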
eye_gaze = data.clone()[:, :, :3]
joints = data.clone()[:, :, 3:(joint_number+1)*3]
head_directions = data.clone()[:, :, (joint_number+1)*3:(joint_number+2)*3]
attended_hand_gt = data.clone()[:, :, (joint_number+2+8*object_num*2)*3:(joint_number+2+8*object_num*2)*3+1].type(torch.LongTensor).to(opt.cuda_idx)
attended_hand_baseline = data.clone()[:, :, (joint_number+2+8*object_num*2)*3+1:(joint_number+2+8*object_num*2)*3+2].type(torch.LongTensor).to(opt.cuda_idx)
input = torch.cat((joints, head_directions), dim=2)
if object_num > 0:
object_positions = data.clone()[:, :, (joint_number+2)*3:(joint_number+2+8*object_num*2)*3]
input = torch.cat((input, object_positions), dim=2)
prediction = net(input, input_n=input_n)
if opt.is_eval and opt.save_predictions:
# eye_gaze + joints + head_directions + object_positions + predictions + attended_hand_gt
prediction = torch.nn.functional.softmax(prediction, dim=2)
prediction_cpu = torch.cat((eye_gaze, input), dim=2)
prediction_cpu = torch.cat((prediction_cpu, prediction), dim=2)
prediction_cpu = torch.cat((prediction_cpu, attended_hand_gt), dim=2)
prediction_cpu = prediction_cpu.cpu().data.numpy()
if len(predictions) == 0:
predictions = prediction_cpu
else:
predictions = np.concatenate((predictions, prediction_cpu), axis=0)
attended_hand_gt = attended_hand_gt.reshape(batch_size*input_n)
attended_hand_baseline = attended_hand_baseline.reshape(batch_size*input_n)
prediction = prediction.reshape(-1, 2)
loss = criterion(prediction, attended_hand_gt)
if is_train == 1:
optimizer.zero_grad()
loss.backward()
optimizer.step()
# calculate prediction accuracy
_, y_prd = torch.max(prediction.data, 1)
acc = torch.sum(y_prd == attended_hand_gt)/(batch_size*input_n)
prediction_acc_average += acc.cpu().data.numpy() * batch_size*input_n
acc = torch.sum(attended_hand_gt == attended_hand_baseline)/(batch_size*input_n)
baseline_acc_average += acc.cpu().data.numpy() * batch_size*input_n
result = {}
result["baseline_acc_average"] = baseline_acc_average / n
result["prediction_acc_average"] = prediction_acc_average / n
if opt.is_eval and opt.save_predictions:
return result, predictions
else:
return result
if __name__ == '__main__':
option = options().parse()
if option.is_eval == False:
main(option)
else:
eval(option)

Binary file not shown.

@@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.6745207058933649,0.8722085664853889,0.6619414385512743,0.8330426198868408
20.0,0.0018867680126765363,0.6745207054940767,0.8948185154648364,0.6619414385512743,0.8627868126688715
30.0,0.0011296777049628276,0.674520705352511,0.908833516738123,0.6619414385512743,0.8652105219446476
40.0,0.0006763797713952794,0.6745207055067813,0.915962817709497,0.6619414385512743,0.8569523197573703
50.0,0.0004049735540879638,0.6745207047935076,0.9201425524065895,0.6619414385512743,0.8638178665451156
60.0,0.00024247262624711545,0.674520706799023,0.9228809992886139,0.6619414385512743,0.8646388007627592

Binary file not shown.

Binary file not shown.

@@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0006710886400000004,7.775465094474378,21.685016596478096,9.080147244893045,22.25119689216382
20.0,7.205759403792802e-05,7.335939544994145,21.685016544265725,8.836104823252782,22.25119689216382
30.0,7.73712524553364e-06,7.269482698812543,21.68501663562285,8.795702427614733,22.25119689216382
40.0,8.307674973655742e-07,7.256998672841371,21.685016551757823,8.785421505532184,22.25119689216382
50.0,8.920298079412272e-08,7.269379717798182,21.68501661848976,8.782566600764051,22.25119689216382
60.0,9.578097130411839e-09,7.259323381483945,21.685016565057747,8.782792238863054,22.25119689216382
70.0,1.028440348325758e-09,7.265019663627232,21.685016472364822,8.782802083944372,22.25119689216382
80.0,1.1042794154864952e-10,7.262544146220354,21.685016528410358,8.782802388717515,22.25119689216382

@@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.6816170010865429,0.852783038866365,0.7241098171488276,0.8022957816714968
20.0,0.0018867680126765363,0.6816170011223602,0.8716868679089899,0.7241098171488276,0.819822787597355
30.0,0.0011296777049628276,0.6816170013059245,0.8825985085263077,0.7241098171488276,0.8229033185732267
40.0,0.0006763797713952794,0.6816170021420371,0.8881970469835044,0.7241098171488276,0.8239226446778548
50.0,0.0004049735540879638,0.6816170025595338,0.8919577810969009,0.7241098171488276,0.8248788276437122
60.0,0.00024247262624711545,0.6816170014827729,0.8943664556306086,0.7241098171488276,0.8249509930207405

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,8.74235833757656,22.929272996737275,10.080445969203966,23.691519417545646
20.0,0.0018867680126765363,7.879988509958766,22.92927302668063,9.581555476071555,23.691519417545646
30.0,0.0011296777049628276,7.3950617451102945,22.929273043729715,9.036600356479111,23.691519417545646
40.0,0.0006763797713952794,7.07688238604103,22.92927304215375,8.760339399561913,23.691519417545646
50.0,0.0004049735540879638,6.9027446432564235,22.929273003757487,8.635063652772557,23.691519417545646
60.0,0.00024247262624711545,6.775644442450048,22.929273132700157,8.549054237754392,23.691519417545646
70.0,0.00014517731808828924,6.714401105429294,22.929273023958505,8.694673094179414,23.691519417545646
80.0,8.69230230790188e-05,6.677202380895346,22.92927304903069,8.57591808497942,23.691519417545646

View file

@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.7021381167567796,0.8402169120884563,0.6901813036133927,0.8264494744106567
20.0,0.0018867680126765363,0.7021381167072633,0.8624022141829342,0.6901813036133927,0.8494821318172411
30.0,0.0011296777049628276,0.7021381160817937,0.8747672176315217,0.6901813036133927,0.8499820602572988
40.0,0.0006763797713952794,0.7021381155110528,0.8815414452682059,0.6901813036133927,0.8484319224921495
50.0,0.0004049735540879638,0.7021381158498488,0.8865171791290529,0.6901813036133927,0.8488455337082769
60.0,0.00024247262624711545,0.7021381161156733,0.8899246960738945,0.6901813036133927,0.847752164009063

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,8.839351579997135,23.505237637305886,9.701928491884642,22.82643943394779
20.0,0.0018867680126765363,8.0279754825312,23.505237625713846,9.222408025492149,22.82643943394779
30.0,0.0011296777049628276,7.478540301192631,23.50523756792046,8.860391529150753,22.82643943394779
40.0,0.0006763797713952794,7.155920212053355,23.505237645728876,8.689459985116457,22.82643943394779
50.0,0.0004049735540879638,6.943932257432076,23.505237599193936,8.75292004044837,22.82643943394779
60.0,0.00024247262624711545,6.828086620144171,23.505237697934735,8.7168934796975,22.82643943394779
70.0,0.00014517731808828924,6.745624930361051,23.505237540900172,8.723768544186793,22.82643943394779
80.0,8.69230230790188e-05,6.688888099966966,23.505237563083494,8.702503739012487,22.82643943394779

View file

@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.7052336015288246,0.8534142175928565,0.661994266023742,0.8135229105560388
20.0,0.0018867680126765363,0.7052336013105647,0.8725956164871361,0.661994266023742,0.831561643393579
30.0,0.0011296777049628276,0.7052336017095152,0.883130780337477,0.661994266023742,0.8272765153270717
40.0,0.0006763797713952794,0.7052336016146974,0.8908625688394565,0.661994266023742,0.8306551724193417
50.0,0.0004049735540879638,0.7052336010994608,0.8941962028439242,0.661994266023742,0.8302761033822049
60.0,0.00024247262624711545,0.7052336012712063,0.8975138279303041,0.661994266023742,0.8269304087636498

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,8.883382703181908,23.21022984623519,10.4297500209549,23.16488788543437
20.0,0.0018867680126765363,7.782790481198015,23.210229789501916,9.385957283378383,23.16488788543437
30.0,0.0011296777049628276,7.3295042459914015,23.210229956037832,8.8375927512519,23.16488788543437
40.0,0.0006763797713952794,7.08854188886722,23.21022992380692,8.90408793085143,23.16488788543437
50.0,0.0004049735540879638,6.931669339063979,23.21022993880603,8.754324602049273,23.16488788543437
60.0,0.00024247262624711545,6.825089292923836,23.21022987319924,8.792744936666029,23.16488788543437
70.0,0.00014517731808828924,6.76963417329046,23.210229846750426,8.694600144187069,23.16488788543437
80.0,8.69230230790188e-05,6.737586019620012,23.210229833812264,8.71311124122629,23.16488788543437

View file

@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.6967981925509015,0.8302275884102602,0.6967302647043779,0.8086298808019007
20.0,0.0018867680126765363,0.696825941169604,0.8459507790012438,0.6967302647043779,0.8300570517057798
30.0,0.0011296777049628276,0.6967939224521644,0.8587645590916032,0.6967302647043779,0.8328419266690059
40.0,0.0006763797713952794,0.6967939235514304,0.866474601211118,0.6967302647043779,0.8400695022933019
50.0,0.0004049735540879638,0.6967939230934029,0.8740437618288838,0.6967302647043779,0.8441059351178314
60.0,0.00024247262624711545,0.6967939231544733,0.8783235607332871,0.6967302647043779,0.8460717301882389

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,8.723726561323542,23.18913731418672,10.435996211392954,23.239844569106314
20.0,0.0018867680126765363,7.853544009024979,23.188485688850527,9.561553679680529,23.239844569106314
30.0,0.0011296777049628276,7.340472601476263,23.18896493169128,9.575626259120058,23.239844569106314
40.0,0.0006763797713952794,7.076471305772906,23.188663758215355,9.33665127001508,23.239844569106314
50.0,0.0004049735540879638,6.922020800289561,23.189159344454282,9.39458229979059,23.239844569106314
60.0,0.00024247262624711545,6.816650379388059,23.18947358795854,9.273735522755194,23.239844569106314
70.0,0.00014517731808828924,6.751277917232669,23.188865782784635,9.232296420427865,23.239844569106314
80.0,8.69230230790188e-05,6.698476171395818,23.1890932419261,9.404228606204375,23.239844569106314

View file

@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.6960166899282387,0.8442398727362511,0.6979547482149892,0.8083841012201665
20.0,0.0018867680126765363,0.6960166904245663,0.8555533757053816,0.6979547482149892,0.7882599223551159
30.0,0.0011296777049628276,0.6960166902980982,0.8661169178601649,0.6979547482149892,0.8318399900875737
40.0,0.0006763797713952794,0.6960166916629991,0.8750150583824932,0.6979547482149892,0.8195155037706253
50.0,0.0004049735540879638,0.6960166904341111,0.8811001704870125,0.6979547482149892,0.8152017306608094
60.0,0.00024247262624711545,0.6960166910258863,0.8855065376006734,0.6979547482149892,0.8149095465877441

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,9.092302583226918,20.045022424015006,10.871499093396837,28.00064259211687
20.0,0.0018867680126765363,8.449609629257155,20.045022394846217,9.163385394021835,28.00064259211687
30.0,0.0011296777049628276,8.004199968918563,20.04502237598577,9.715205940504797,28.00064259211687
40.0,0.0006763797713952794,7.736533732776755,20.045022482581658,10.492947047604881,28.00064259211687
50.0,0.0004049735540879638,7.550458214825991,20.04502240393283,9.826412407127586,28.00064259211687
60.0,0.00024247262624711545,7.439071941005396,20.045022420044386,9.617718728174395,28.00064259211687
70.0,0.00014517731808828924,7.354865613532145,20.045022362546746,9.901499594044505,28.00064259211687
80.0,8.69230230790188e-05,7.31953181233149,20.04502238805035,9.986867171353252,28.00064259211687

View file

@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.6974864419165719,0.8427737693608807,0.6955263208594125,0.8233390038358083
20.0,0.0018867680126765363,0.6974864421003018,0.8595619837544608,0.6955263208594125,0.8226455437919652
30.0,0.0011296777049628276,0.6974864414101947,0.8704107908384844,0.6955263208594125,0.8303636568959843
40.0,0.0006763797713952794,0.6974864407873057,0.8765607045589645,0.6955263208594125,0.8233164893276423
50.0,0.0004049735540879638,0.69748644039744,0.8822820797718322,0.6955263208594125,0.832633161305254
60.0,0.00024247262624711545,0.6974864419569029,0.887539829914574,0.6955263208594125,0.8349657074195717

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,8.638778725599636,26.1798637664889,10.833028636933662,17.849539670393458
20.0,0.0018867680126765363,7.791013009296702,26.179863858981182,10.735866879370116,17.849539670393458
30.0,0.0011296777049628276,7.266133595065025,26.17986385037725,9.690580197114954,17.849539670393458
40.0,0.0006763797713952794,6.93408266801277,26.17986383503357,10.310306123348958,17.849539670393458
50.0,0.0004049735540879638,6.787781177862328,26.179863753869796,10.454030626426473,17.849539670393458
60.0,0.00024247262624711545,6.659840548202495,26.179863708842547,10.255192241327299,17.849539670393458
70.0,0.00014517731808828924,6.597368892964792,26.17986375731137,10.151271223901132,17.849539670393458
80.0,8.69230230790188e-05,6.558900393612285,26.179863782692973,10.218871813100643,17.849539670393458

101
environment/hoigaze.yml Normal file
View file

@ -0,0 +1,101 @@
name: hoigaze
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- ca-certificates=2023.12.12=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- ncurses=6.4=h6a678d5_0
- openssl=3.0.12=h7f8727e_0
- pip=23.3.1=py38h06a4308_0
- python=3.8.18=h955ad1f_0
- readline=8.2=h5eee18b_0
- setuptools=68.2.2=py38h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.41.2=py38h06a4308_0
- xz=5.4.5=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- absl-py==2.0.0
- aiohttp==3.9.1
- aiosignal==1.3.1
- appdirs==1.4.4
- async-timeout==4.0.3
- attrs==23.2.0
- cachetools==5.3.2
- certifi==2023.11.17
- charset-normalizer==3.3.2
- click==8.1.7
- contourpy==1.1.1
- cycler==0.12.1
- cython==0.29.37
- docker-pycreds==0.4.0
- fonttools==4.47.2
- frozenlist==1.4.1
- fsspec==2023.12.2
- ftfy==6.1.3
- future==0.18.3
- gitdb==4.0.11
- gitpython==3.1.41
- google-auth==2.26.2
- google-auth-oauthlib==1.0.0
- grpcio==1.60.0
- hdbscan==0.8.33
- idna==3.6
- importlib-metadata==7.0.1
- importlib-resources==6.1.1
- joblib==1.3.2
- kiwisolver==1.4.5
- lmdb==1.2.1
- lpips==0.1.4
- markdown==3.5.2
- markupsafe==2.1.3
- matplotlib==3.5.3
- multidict==6.0.4
- numpy==1.24.4
- oauthlib==3.2.2
- packaging==23.2
- pandas==1.5.3
- pillow==10.2.0
- protobuf==4.25.2
- psutil==5.9.8
- pyasn1==0.5.1
- pyasn1-modules==0.3.0
- pydeprecate==0.3.1
- pyparsing==3.1.1
- python-dateutil==2.8.2
- pytorch-fid==0.2.0
- pytorch-lightning==1.4.5
- pytz==2023.3.post1
- pyyaml==6.0.1
- regex==2023.12.25
- requests==2.31.0
- requests-oauthlib==1.3.1
- rsa==4.9
- scikit-learn==1.3.2
- scipy==1.5.4
- sentry-sdk==1.39.2
- setproctitle==1.3.3
- six==1.16.0
- smmap==5.0.1
- tensorboard==2.14.0
- tensorboard-data-server==0.7.2
- threadpoolctl==3.2.0
- torch==1.8.1
- torchmetrics==0.5.0
- torchvision==0.9.1
- tqdm==4.66.1
- typing-extensions==4.9.0
- tzdata==2023.4
- urllib3==2.1.0
- wandb==0.16.2
- wcwidth==0.2.13
- werkzeug==3.0.1
- yarl==1.9.4
- zipp==3.17.0

323
gaze_estimation_adt.py Normal file
View file

@ -0,0 +1,323 @@
from utils import adt_dataset, seed_torch
from model import gaze_estimation
from utils.opt import options
from utils import log
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import numpy as np
import time
import datetime
import torch.optim as optim
import torch.nn.functional as F
import os
os.nice(5)
import math
def main(opt):
# set the random seed to ensure reproducibility
seed_torch.seed_torch(seed=0)
torch.set_num_threads(1)
data_dir = opt.data_dir
seq_len = opt.seq_len
opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
learning_rate = opt.learning_rate
print('>>> create model')
net = gaze_estimation.gaze_estimation(opt=opt).to(opt.cuda_idx)
optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=learning_rate)
print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
print('>>> loading datasets')
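# Load the training/test feature files (attended_hand_recognition_train/test.npy) from the checkpoint directory; they already contain the attended-hand predictions that are consumed later in run_model.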
train_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_train.npy")
valid_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")
train_dataset = np.load(train_data_path)
train_data_size = train_dataset.shape
print("Training data size: {}".format(train_data_size))
train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True)
valid_dataset = np.load(valid_data_path)
valid_data_size = valid_dataset.shape
print("Validation data size: {}".format(valid_data_size))
valid_loader = DataLoader(valid_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
# training
local_time = time.asctime(time.localtime(time.time()))
print('\nTraining starts at ' + local_time)
start_time = datetime.datetime.now()
start_epoch = 1
err_best = 1000
best_epoch = 0
exp_lr = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma, last_epoch=-1)
for epo in range(start_epoch, opt.epoch + 1):
is_best = False
learning_rate = exp_lr.optimizer.param_groups[0]["lr"]
train_start_time = datetime.datetime.now()
result_train = run_model(net, optimizer, is_train=1, data_loader=train_loader, opt=opt)
train_end_time = datetime.datetime.now()
train_time = (train_end_time - train_start_time).seconds*1000
train_batch_num = math.ceil(train_data_size[0]/opt.batch_size)
train_time_per_batch = math.ceil(train_time/train_batch_num)
#print('\nTraining time per batch: {} ms'.format(train_time_per_batch))
exp_lr.step()
rng_state = torch.get_rng_state()
if epo % opt.validation_epoch == 0:
print('>>> training epoch: {:d}, lr: {:.12f}'.format(epo, learning_rate))
print('Training data size: {}'.format(train_data_size))
print('Average baseline error: {:.2f} degree'.format(result_train['baseline_error_average']))
print('Average training error: {:.2f} degree'.format(result_train['prediction_error_average']))
test_start_time = datetime.datetime.now()
result_valid = run_model(net, is_train=0, data_loader=valid_loader, opt=opt)
test_end_time = datetime.datetime.now()
test_time = (test_end_time - test_start_time).seconds*1000
test_batch_num = math.ceil(valid_data_size[0]/opt.test_batch_size)
test_time_per_batch = math.ceil(test_time/test_batch_num)
#print('\nTest time per batch: {} ms'.format(test_time_per_batch))
print('Validation data size: {}'.format(valid_data_size))
print('Average baseline error: {:.2f} degree'.format(result_valid['baseline_error_average']))
print('Average validation error: {:.2f} degree'.format(result_valid['prediction_error_average']))
if result_valid['prediction_error_average'] < err_best:
err_best = result_valid['prediction_error_average']
is_best = True
best_epoch = epo
print('Best validation error: {:.2f} degree, best epoch: {}'.format(err_best, best_epoch))
end_time = datetime.datetime.now()
total_training_time = (end_time - start_time).seconds/60
print('\nTotal training time: {:.2f} min'.format(total_training_time))
local_time = time.asctime(time.localtime(time.time()))
print('\nTraining ends at ' + local_time)
result_log = np.array([epo, learning_rate])
head = np.array(['epoch', 'lr'])
for k in result_train.keys():
result_log = np.append(result_log, [result_train[k]])
head = np.append(head, [k])
for k in result_valid.keys():
result_log = np.append(result_log, [result_valid[k]])
head = np.append(head, ['valid_' + k])
csv_name = 'gaze_estimation_results'
log.save_csv_log(opt, head, result_log, is_create=(epo == 1), file_name=csv_name)
last_model_name = 'gaze_estimation_model_last.pt'
log.save_ckpt({'epoch': epo,
'lr': learning_rate,
'err': result_valid['prediction_error_average'],
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict()},
opt=opt,
file_name = last_model_name)
if epo == best_epoch:
best_model_name = 'gaze_estimation_model_best.pt'
log.save_ckpt({'epoch': epo,
'lr': learning_rate,
'err': result_valid['prediction_error_average'],
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict()},
opt=opt,
file_name = best_model_name)
torch.set_rng_state(rng_state)
def eval(opt):
data_dir = opt.data_dir
seq_len = opt.seq_len
opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
print('>>> create model')
net = gaze_estimation.gaze_estimation(opt=opt).to(opt.cuda_idx)
print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
#load model
model_name = 'gaze_estimation_model_best.pt'
model_path = os.path.join(opt.ckpt, model_name)
print(">>> loading ckpt from '{}'".format(model_path))
ckpt = torch.load(model_path)
net.load_state_dict(ckpt['state_dict'])
print(">>> ckpt loaded (epoch: {} | err: {})".format(ckpt['epoch'], ckpt['err']))
print('>>> loading datasets')
test_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")
test_dataset = np.load(test_data_path)
test_data_size = test_dataset.shape
print("Test data size: {}".format(test_data_size))
test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
# test
local_time = time.asctime(time.localtime(time.time()))
print('\nTest starts at ' + local_time)
start_time = datetime.datetime.now()
if opt.save_predictions:
result_test, predictions = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
else:
result_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
print('Average baseline error: {:.2f} degree'.format(result_test['baseline_error_average']))
print('Average prediction error: {:.2f} degree'.format(result_test['prediction_error_average']))
end_time = datetime.datetime.now()
total_test_time = (end_time - start_time).seconds/60
print('\nTotal test time: {:.2f} min'.format(total_test_time))
local_time = time.asctime(time.localtime(time.time()))
print('\nTest ends at ' + local_time)
if opt.save_predictions:
# ground_truth + joints + head_directions + object_positions + attended_hand_prd + attended_hand_gt + predictions
batch_size, seq_n, dim = predictions.shape
predictions = predictions.reshape(-1, dim)
ground_truth = predictions[:, :3]
head_directions = predictions[:, 3+opt.joint_number*3:6+opt.joint_number*3]
head_cos = np.sum(head_directions*ground_truth, 1)
head_cos = np.clip(head_cos, -1, 1)
head_errors = np.arccos(head_cos)/np.pi * 180.0
print('Average baseline error: {:.2f} degree'.format(np.mean(head_errors)))
prediction = predictions[:, -3:]
prd_cos = np.sum(prediction*ground_truth, 1)
prd_cos = np.clip(prd_cos, -1, 1)
prediction_errors = np.arccos(prd_cos)/np.pi * 180.0
print('Average prediction error: {:.2f} degree'.format(np.mean(prediction_errors)))
attended_hand_gt = predictions[:, -4]
attended_hand_prd_left = predictions[:, -6]
attended_hand_prd_right = predictions[:, -5]
attended_hand_correct = attended_hand_prd_left.copy()  # copy so the raw left-hand probabilities in 'predictions' are not overwritten below
for i in range(attended_hand_correct.shape[0]):
if attended_hand_gt[i] == 0 and attended_hand_prd_left[i] > attended_hand_prd_right[i]:
attended_hand_correct[i] = 1
elif attended_hand_gt[i] == 1 and attended_hand_prd_left[i] < attended_hand_prd_right[i]:
attended_hand_correct[i] = 1
else:
attended_hand_correct[i] = 0
correct_ratio = np.sum(attended_hand_correct)/attended_hand_correct.shape[0]
print("hand recognition acc: {:.2f}%".format(correct_ratio*100))
attended_hand_wrong = 1 - attended_hand_correct
wrong_ratio = np.sum(attended_hand_wrong)/attended_hand_wrong.shape[0]
head_errors_correct = np.sum(head_errors*attended_hand_correct)/np.sum(attended_hand_correct)
print("hand recognition correct size: {}".format(np.sum(attended_hand_correct)))
print("hand recognition correct, average baseline error: {:.2f} degree".format(head_errors_correct))
head_errors_wrong = np.sum(head_errors*attended_hand_wrong)/np.sum(attended_hand_wrong)
print("hand recognition wrong size: {}".format(np.sum(attended_hand_wrong)))
print("hand recognition wrong, average baseline error: {:.2f} degree".format(head_errors_wrong))
head_errors_avg = head_errors_correct*correct_ratio + head_errors_wrong*wrong_ratio
print('Average baseline error: {:.2f} degree'.format(head_errors_avg))
prediction_errors_correct = np.sum(prediction_errors*attended_hand_correct)/np.sum(attended_hand_correct)
print("hand recognition correct, average prediction error: {:.2f} degree".format(prediction_errors_correct))
prediction_errors_wrong = np.sum(prediction_errors*attended_hand_wrong)/np.sum(attended_hand_wrong)
print("hand recognition wrong, average prediction error: {:.2f} degree".format(prediction_errors_wrong))
prediction_errors_avg = prediction_errors_correct*correct_ratio + prediction_errors_wrong*wrong_ratio
print('Average prediction error: {:.2f} degree'.format(prediction_errors_avg))
predictions_path = os.path.join(opt.ckpt, "gaze_predictions.npy")
np.save(predictions_path, predictions)
prediction_errors_path = os.path.join(opt.ckpt, "prediction_errors.npy")
np.save(prediction_errors_path, prediction_errors)
attended_hand_correct_path = os.path.join(opt.ckpt, "attended_hand_correct.npy")
np.save(attended_hand_correct_path, attended_hand_correct)
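# Numerically safe arccos: uses torch.acos for |x| <= 1-eps and extrapolates linearly beyond that, so values slightly outside [-1, 1] do not produce NaNs or exploding gradients.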
def acos_safe(x, eps=1e-6):
slope = np.arccos(1-eps) / eps
buf = torch.empty_like(x)
good = abs(x) <= 1-eps
bad = ~good
sign = torch.sign(x[bad])
buf[good] = torch.acos(x[good])
buf[bad] = torch.acos(sign * (1 - eps)) - slope*sign*(abs(x[bad]) - 1 + eps)
return buf
def run_model(net, optimizer=None, is_train=1, data_loader=None, opt=None):
if is_train == 1:
net.train()
else:
net.eval()
if opt.is_eval and opt.save_predictions:
predictions = []
prediction_error_average = 0
baseline_error_average = 0
criterion = torch.nn.MSELoss(reduction='none')
n = 0
input_n = opt.seq_len
for i, (data) in enumerate(data_loader):
batch_size, seq_n, dim = data.shape
joint_number = opt.joint_number
object_num = opt.object_num
# when only one sample in this batch
if batch_size == 1 and is_train == 1:
continue
n += batch_size
data = data.float().to(opt.cuda_idx)
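# Per-frame feature layout: gaze ground truth (3) | joint positions (joint_number*3) | head direction (3) | object points (8*object_num*2 points, 3 values each) | attended-hand probabilities (2) | attended-hand label (1).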
ground_truth = data.clone()[:, :, :3]
joints = data.clone()[:, :, 3:(joint_number+1)*3]
head_directions = data.clone()[:, :, (joint_number+1)*3:(joint_number+2)*3]
attended_hand_prd = data.clone()[:, :, (joint_number+2+8*object_num*2)*3:(joint_number+2+8*object_num*2)*3+2]
attended_hand_gt = data.clone()[:, :, (joint_number+2+8*object_num*2)*3+2:(joint_number+2+8*object_num*2)*3+3]
input = torch.cat((joints, head_directions), dim=2)
if object_num > 0:
object_positions = data.clone()[:, :, (joint_number+2)*3:(joint_number+2+8*object_num*2)*3]
input = torch.cat((input, object_positions), dim=2)
input = torch.cat((input, attended_hand_prd), dim=2)
input = torch.cat((input, attended_hand_gt), dim=2)
prediction = net(input, input_n=input_n)
if opt.is_eval and opt.save_predictions:
# ground_truth + joints + head_directions + object_positions + attended_hand_prd + attended_hand_gt + predictions
prediction_cpu = torch.cat((ground_truth, input), dim=2)
prediction_cpu = torch.cat((prediction_cpu, prediction), dim=2)
prediction_cpu = prediction_cpu.cpu().data.numpy()
if len(predictions) == 0:
predictions = prediction_cpu
else:
predictions = np.concatenate((predictions, prediction_cpu), axis=0)
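# Eye-head coordination weighting: frames whose gaze-head cosine exceeds opt.gaze_head_cos_threshold get weight opt.gaze_head_loss_factor in the MSE loss; all other frames keep weight 1.0.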
gaze_head_cos = torch.sum(ground_truth*head_directions, dim=2, keepdim=True)
gaze_weight = torch.where(gaze_head_cos>opt.gaze_head_cos_threshold, opt.gaze_head_loss_factor, 1.0)
loss = criterion(ground_truth, prediction)
loss = torch.mean(loss*gaze_weight)
if is_train == 1:
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Calculate prediction errors
error = torch.mean(acos_safe(torch.sum(ground_truth*prediction, 2)))/torch.tensor(math.pi) * 180.0
prediction_error_average += error.cpu().data.numpy() * batch_size
# Use head directions as the baseline
baseline_error = torch.mean(acos_safe(torch.sum(ground_truth*head_directions, 2)))/torch.tensor(math.pi) * 180.0
baseline_error_average += baseline_error.cpu().data.numpy() * batch_size
result = {}
result["prediction_error_average"] = prediction_error_average / n
result["baseline_error_average"] = baseline_error_average / n
if opt.is_eval and opt.save_predictions:
return result, predictions
else:
return result
if __name__ == '__main__':
option = options().parse()
if not option.is_eval:
main(option)
else:
eval(option)

552
gaze_estimation_hot3d.py Normal file
View file

@ -0,0 +1,552 @@
from utils import hot3d_aria_dataset, seed_torch
from model import gaze_estimation
from utils.opt import options
from utils import log
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import numpy as np
import time
import datetime
import torch.optim as optim
import torch.nn.functional as F
import os
os.nice(5)
import math
def main(opt):
# set the random seed to ensure reproducibility
seed_torch.seed_torch(seed=0)
torch.set_num_threads(1)
data_dir = opt.data_dir
seq_len = opt.seq_len
opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
learning_rate = opt.learning_rate
print('>>> create model')
net = gaze_estimation.gaze_estimation(opt=opt).to(opt.cuda_idx)
optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=learning_rate)
print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
print('>>> loading datasets')
actions = opt.actions
test_user_id = opt.test_user_id
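# Choose the train/test split: 'all' runs cross-user evaluation (the held-out participants are selected by test_user_id), while 'room', 'kitchen', and 'office' run cross-scene evaluation by training on the other two scenes.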
if actions == 'all':
if test_user_id == 1:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003']
opt.ckpt = opt.ckpt + '/user1/'
if test_user_id == 2:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0001', 'P0002', 'P0003', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0009', 'P0010', 'P0011']
opt.ckpt = opt.ckpt + '/user2/'
if test_user_id == 3:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011']
test_subjects = ['P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/user3/'
elif actions == 'room':
train_actions = ['kitchen', 'office']
test_actions = ['room']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene1/'
elif actions == 'kitchen':
train_actions = ['room', 'office']
test_actions = ['kitchen']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene2/'
elif actions == 'office':
train_actions = ['room', 'kitchen']
test_actions = ['office']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene3/'
else:
raise ValueError("Unrecognised actions: {}".format(actions))
train_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_train.npy")
valid_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")
train_dataset = np.load(train_data_path)
train_data_size = train_dataset.shape
print("Training data size: {}".format(train_data_size))
train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True)
valid_dataset = np.load(valid_data_path)
valid_data_size = valid_dataset.shape
print("Validation data size: {}".format(valid_data_size))
valid_loader = DataLoader(valid_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
# training
local_time = time.asctime(time.localtime(time.time()))
print('\nTraining starts at ' + local_time)
start_time = datetime.datetime.now()
start_epoch = 1
err_best = 1000
best_epoch = 0
exp_lr = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma, last_epoch=-1)
for epo in range(start_epoch, opt.epoch + 1):
is_best = False
learning_rate = exp_lr.optimizer.param_groups[0]["lr"]
train_start_time = datetime.datetime.now()
result_train = run_model(net, optimizer, is_train=1, data_loader=train_loader, opt=opt)
train_end_time = datetime.datetime.now()
train_time = (train_end_time - train_start_time).seconds*1000
train_batch_num = math.ceil(train_data_size[0]/opt.batch_size)
train_time_per_batch = math.ceil(train_time/train_batch_num)
#print('\nTraining time per batch: {} ms'.format(train_time_per_batch))
exp_lr.step()
rng_state = torch.get_rng_state()
if epo % opt.validation_epoch == 0:
if actions == 'all':
print("\ntest user id: {}\n".format(test_user_id))
elif actions == 'room':
print("\ntest scene/action: room\n")
elif actions == 'kitchen':
print("\ntest scene/action: kitchen\n")
elif actions == 'office':
print("\ntest scene/action: office\n")
print('>>> training epoch: {:d}, lr: {:.12f}'.format(epo, learning_rate))
print('Training data size: {}'.format(train_data_size))
print('Average baseline error: {:.2f} degree'.format(result_train['baseline_error_average']))
print('Average training error: {:.2f} degree'.format(result_train['prediction_error_average']))
test_start_time = datetime.datetime.now()
result_valid = run_model(net, is_train=0, data_loader=valid_loader, opt=opt)
test_end_time = datetime.datetime.now()
test_time = (test_end_time - test_start_time).seconds*1000
test_batch_num = math.ceil(valid_data_size[0]/opt.test_batch_size)
test_time_per_batch = math.ceil(test_time/test_batch_num)
#print('\nTest time per batch: {} ms'.format(test_time_per_batch))
print('Validation data size: {}'.format(valid_data_size))
print('Average baseline error: {:.2f} degree'.format(result_valid['baseline_error_average']))
print('Average validation error: {:.2f} degree'.format(result_valid['prediction_error_average']))
if result_valid['prediction_error_average'] < err_best:
err_best = result_valid['prediction_error_average']
is_best = True
best_epoch = epo
print('Best validation error: {:.2f} degree, best epoch: {}'.format(err_best, best_epoch))
end_time = datetime.datetime.now()
total_training_time = (end_time - start_time).seconds/60
print('\nTotal training time: {:.2f} min'.format(total_training_time))
local_time = time.asctime(time.localtime(time.time()))
print('\nTraining ends at ' + local_time)
result_log = np.array([epo, learning_rate])
head = np.array(['epoch', 'lr'])
for k in result_train.keys():
result_log = np.append(result_log, [result_train[k]])
head = np.append(head, [k])
for k in result_valid.keys():
result_log = np.append(result_log, [result_valid[k]])
head = np.append(head, ['valid_' + k])
csv_name = 'gaze_estimation_results'
log.save_csv_log(opt, head, result_log, is_create=(epo == 1), file_name=csv_name)
model_name = 'gaze_estimation_model_last.pt'
log.save_ckpt({'epoch': epo,
'lr': learning_rate,
'err': result_valid['prediction_error_average'],
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict()},
opt=opt,
file_name = model_name)
if epo == best_epoch:
model_name = 'gaze_estimation_model_best.pt'
log.save_ckpt({'epoch': epo,
'lr': learning_rate,
'err': result_valid['prediction_error_average'],
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict()},
opt=opt,
file_name = model_name)
torch.set_rng_state(rng_state)
def eval(opt):
data_dir = opt.data_dir
seq_len = opt.seq_len
opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
print('>>> create model')
net = gaze_estimation.gaze_estimation(opt=opt).to(opt.cuda_idx)
print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
actions = opt.actions
test_user_id = opt.test_user_id
if actions == 'all':
if test_user_id == 1:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003']
opt.ckpt = opt.ckpt + '/user1/'
if test_user_id == 2:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0001', 'P0002', 'P0003', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0009', 'P0010', 'P0011']
opt.ckpt = opt.ckpt + '/user2/'
if test_user_id == 3:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011']
test_subjects = ['P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/user3/'
elif actions == 'room':
train_actions = ['kitchen', 'office']
test_actions = ['room']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene1/'
elif actions == 'kitchen':
train_actions = ['room', 'office']
test_actions = ['kitchen']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene2/'
elif actions == 'office':
train_actions = ['room', 'kitchen']
test_actions = ['office']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene3/'
else:
raise ValueError("Unrecognised actions: {}".format(actions))
#load model
model_name = 'gaze_estimation_model_best.pt'
model_path = os.path.join(opt.ckpt, model_name)
print(">>> loading ckpt from '{}'".format(model_path))
ckpt = torch.load(model_path)
net.load_state_dict(ckpt['state_dict'])
print(">>> ckpt loaded (epoch: {} | err: {})".format(ckpt['epoch'], ckpt['err']))
print('>>> loading datasets')
test_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")
test_dataset = np.load(test_data_path)
test_data_size = test_dataset.shape
print("Test data size: {}".format(test_data_size))
test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
# test
local_time = time.asctime(time.localtime(time.time()))
print('\nTest starts at ' + local_time)
start_time = datetime.datetime.now()
if actions == 'all':
print("\ntest user id: {}\n".format(test_user_id))
elif actions == 'room':
print("\ntest scene/action: room\n")
elif actions == 'kitchen':
print("\ntest scene/action: kitchen\n")
elif actions == 'office':
print("\ntest scene/action: office\n")
if opt.save_predictions:
result_test, predictions = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
else:
result_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
print('Average baseline error: {:.2f} degree'.format(result_test['baseline_error_average']))
print('Average prediction error: {:.2f} degree'.format(result_test['prediction_error_average']))
end_time = datetime.datetime.now()
total_test_time = (end_time - start_time).seconds/60
print('\nTotal test time: {:.2f} min'.format(total_test_time))
local_time = time.asctime(time.localtime(time.time()))
print('\nTest ends at ' + local_time)
if opt.save_predictions:
# ground_truth + joints + head_directions + object_positions + attended_hand_prd + attended_hand_gt + predictions
batch_size, seq_n, dim = predictions.shape
predictions = predictions.reshape(-1, dim)
ground_truth = predictions[:, :3]
head_directions = predictions[:, 3+opt.joint_number*3:6+opt.joint_number*3]
head_cos = np.sum(head_directions*ground_truth, 1)
head_cos = np.clip(head_cos, -1, 1)
head_errors = np.arccos(head_cos)/np.pi * 180.0
print('Average baseline error: {:.2f} degree'.format(np.mean(head_errors)))
prediction = predictions[:, -3:]
prd_cos = np.sum(prediction*ground_truth, 1)
prd_cos = np.clip(prd_cos, -1, 1)
prediction_errors = np.arccos(prd_cos)/np.pi * 180.0
print('Average prediction error: {:.2f} degree'.format(np.mean(prediction_errors)))
attended_hand_gt = predictions[:, -4]
attended_hand_prd_left = predictions[:, -6]
attended_hand_prd_right = predictions[:, -5]
attended_hand_correct = attended_hand_prd_left.copy()  # copy so the raw left-hand probabilities in 'predictions' are not overwritten below
for i in range(attended_hand_correct.shape[0]):
if attended_hand_gt[i] == 0 and attended_hand_prd_left[i] > attended_hand_prd_right[i]:
attended_hand_correct[i] = 1
elif attended_hand_gt[i] == 1 and attended_hand_prd_left[i] < attended_hand_prd_right[i]:
attended_hand_correct[i] = 1
else:
attended_hand_correct[i] = 0
correct_ratio = np.sum(attended_hand_correct)/attended_hand_correct.shape[0]
print("hand recognition acc: {:.2f}%".format(correct_ratio*100))
attended_hand_wrong = 1 - attended_hand_correct
wrong_ratio = np.sum(attended_hand_wrong)/attended_hand_wrong.shape[0]
head_errors_correct = np.sum(head_errors*attended_hand_correct)/np.sum(attended_hand_correct)
print("hand recognition correct size: {}".format(np.sum(attended_hand_correct)))
print("hand recognition correct, average baseline error: {:.2f} degree".format(head_errors_correct))
head_errors_wrong = np.sum(head_errors*attended_hand_wrong)/np.sum(attended_hand_wrong)
print("hand recognition wrong size: {}".format(np.sum(attended_hand_wrong)))
print("hand recognition wrong, average baseline error: {:.2f} degree".format(head_errors_wrong))
head_errors_avg = head_errors_correct*correct_ratio + head_errors_wrong*wrong_ratio
print('Average baseline error: {:.2f} degree'.format(head_errors_avg))
prediction_errors_correct = np.sum(prediction_errors*attended_hand_correct)/np.sum(attended_hand_correct)
print("hand recognition correct, average prediction error: {:.2f} degree".format(prediction_errors_correct))
prediction_errors_wrong = np.sum(prediction_errors*attended_hand_wrong)/np.sum(attended_hand_wrong)
print("hand recognition wrong, average prediction error: {:.2f} degree".format(prediction_errors_wrong))
prediction_errors_avg = prediction_errors_correct*correct_ratio + prediction_errors_wrong*wrong_ratio
print('Average prediction error: {:.2f} degree'.format(prediction_errors_avg))
predictions_path = os.path.join(opt.ckpt, "gaze_predictions.npy")
np.save(predictions_path, predictions)
prediction_errors_path = os.path.join(opt.ckpt, "prediction_errors.npy")
np.save(prediction_errors_path, prediction_errors)
attended_hand_correct_path = os.path.join(opt.ckpt, "attended_hand_correct.npy")
np.save(attended_hand_correct_path, attended_hand_correct)
def eval_single(opt):
from utils import hot3d_aria_single_dataset
from model import attended_hand_recognition
seq_len = opt.seq_len
opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
print('>>> create model')
opt.residual_gcns_num = 2
hand_model = attended_hand_recognition.attended_hand_recognition(opt=opt).to(opt.cuda_idx)
opt.residual_gcns_num = 4
net = gaze_estimation.gaze_estimation(opt=opt).to(opt.cuda_idx)
print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
actions = opt.actions
test_user_id = opt.test_user_id
if actions == 'all':
if test_user_id == 1:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003']
opt.ckpt = opt.ckpt + '/user1/'
if test_user_id == 2:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0001', 'P0002', 'P0003', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0009', 'P0010', 'P0011']
opt.ckpt = opt.ckpt + '/user2/'
if test_user_id == 3:
train_actions = 'all'
test_actions = 'all'
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011']
test_subjects = ['P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/user3/'
elif actions == 'room':
train_actions = ['kitchen', 'office']
test_actions = ['room']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene1/'
elif actions == 'kitchen':
train_actions = ['room', 'office']
test_actions = ['kitchen']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene2/'
elif actions == 'office':
train_actions = ['room', 'kitchen']
test_actions = ['office']
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
opt.ckpt = opt.ckpt + '/scene3/'
else:
raise ValueError("Unrecognised actions: {}".format(actions))
#load hand model
model_name = 'attended_hand_recognition_model.pt'
model_path = os.path.join(opt.ckpt, model_name)
print(">>> loading ckpt from '{}'".format(model_path))
ckpt = torch.load(model_path)
hand_model.load_state_dict(ckpt['state_dict'])
print(">>> ckpt loaded (epoch: {} | acc: {})".format(ckpt['epoch'], ckpt['acc']))
hand_model.eval()
#load gaze model
model_name = 'gaze_estimation_model_best.pt'
model_path = os.path.join(opt.ckpt, model_name)
print(">>> loading ckpt from '{}'".format(model_path))
ckpt = torch.load(model_path)
net.load_state_dict(ckpt['state_dict'])
print(">>> ckpt loaded (epoch: {} | err: {})".format(ckpt['epoch'], ckpt['err']))
net.eval()
test_dir = '/scratch/hu/pose_forecast/hot3d_hoigaze/'
test_file = 'P0002_016222d1_kitchen_0_1527_'
test_path = test_dir + test_file
test_dataset = hot3d_aria_single_dataset.hot3d_aria_dataset(test_path, seq_len)
print("Test data size: {}".format(test_dataset.dataset.shape))
test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
predictions = []
for i, (data) in enumerate(test_loader):
batch_size, seq_n, dim = data.shape
joint_number = opt.joint_number
object_num = opt.object_num
data = data.float().to(opt.cuda_idx)
ground_truth = data.clone()[:, :, :3]
joints = data.clone()[:, :, 3:(joint_number+1)*3]
head_directions = data.clone()[:, :, (joint_number+1)*3:(joint_number+2)*3]
object_positions = data.clone()[:, :, (joint_number+2)*3:(joint_number+2+8*object_num*2)*3]
input = torch.cat((joints, head_directions), dim=2)
input = torch.cat((input, object_positions), dim=2)
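# Run the attended-hand recogniser, convert its logits to probabilities, and append them as extra inputs for the gaze estimator.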
hand_prd = hand_model(input)
hand_prd = torch.nn.functional.softmax(hand_prd, dim=2)
input = torch.cat((input, hand_prd), dim=2)
prediction = net(input)
prediction_cpu = torch.cat((ground_truth, head_directions), dim=2)
prediction_cpu = torch.cat((prediction_cpu, prediction), dim=2)
prediction_cpu = prediction_cpu.cpu().data.numpy()
if len(predictions) == 0:
predictions = prediction_cpu
else:
predictions = np.concatenate((predictions, prediction_cpu), axis=0)
predictions = predictions.reshape(-1, predictions.shape[2])
ground_truth = predictions[:, :3]
head = predictions[:, 3:6]
head_cos = np.sum(head*ground_truth, 1)
head_cos = np.clip(head_cos, -1, 1)
head_errors = np.arccos(head_cos)/np.pi * 180.0
print('Average baseline error: {:.2f} degree'.format(np.mean(head_errors)))
prediction = predictions[:, -3:]
prd_cos = np.sum(prediction*ground_truth, 1)
prd_cos = np.clip(prd_cos, -1, 1)
prediction_errors = np.arccos(prd_cos)/np.pi * 180.0
print('Average prediction error: {:.2f} degree'.format(np.mean(prediction_errors)))
save_dir = '/scratch/hu/pose_forecast/hot3d_hoigaze_prd/'
save_path = save_dir + test_file + "hoigaze.npy"
np.save(save_path, prediction)
save_path = save_dir + test_file + "gaze.npy"
np.save(save_path, ground_truth)
def acos_safe(x, eps=1e-6):
slope = np.arccos(1-eps) / eps
buf = torch.empty_like(x)
good = abs(x) <= 1-eps
bad = ~good
sign = torch.sign(x[bad])
buf[good] = torch.acos(x[good])
buf[bad] = torch.acos(sign * (1 - eps)) - slope*sign*(abs(x[bad]) - 1 + eps)
return buf
def run_model(net, optimizer=None, is_train=1, data_loader=None, opt=None):
if is_train == 1:
net.train()
else:
net.eval()
if opt.is_eval and opt.save_predictions:
predictions = []
prediction_error_average = 0
baseline_error_average = 0
criterion = torch.nn.MSELoss(reduction='none')
n = 0
input_n = opt.seq_len
for i, (data) in enumerate(data_loader):
batch_size, seq_n, dim = data.shape
joint_number = opt.joint_number
object_num = opt.object_num
# when only one sample in this batch
if batch_size == 1 and is_train == 1:
continue
n += batch_size
data = data.float().to(opt.cuda_idx)
ground_truth = data.clone()[:, :, :3]
joints = data.clone()[:, :, 3:(joint_number+1)*3]
head_directions = data.clone()[:, :, (joint_number+1)*3:(joint_number+2)*3]
attended_hand_prd = data.clone()[:, :, (joint_number+2+8*object_num*2)*3:(joint_number+2+8*object_num*2)*3+2]
attended_hand_gt = data.clone()[:, :, (joint_number+2+8*object_num*2)*3+2:(joint_number+2+8*object_num*2)*3+3]
input = torch.cat((joints, head_directions), dim=2)
if object_num > 0:
object_positions = data.clone()[:, :, (joint_number+2)*3:(joint_number+2+8*object_num*2)*3]
input = torch.cat((input, object_positions), dim=2)
input = torch.cat((input, attended_hand_prd), dim=2)
input = torch.cat((input, attended_hand_gt), dim=2)
prediction = net(input, input_n=input_n)
if opt.is_eval and opt.save_predictions:
# ground_truth + joints + head_directions + object_positions + attended_hand_prd + attended_hand_gt + predictions
prediction_cpu = torch.cat((ground_truth, input), dim=2)
prediction_cpu = torch.cat((prediction_cpu, prediction), dim=2)
prediction_cpu = prediction_cpu.cpu().data.numpy()
if len(predictions) == 0:
predictions = prediction_cpu
else:
predictions = np.concatenate((predictions, prediction_cpu), axis=0)
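# Eye-head coordination weighting of the loss, as in the ADT training script.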
gaze_head_cos = torch.sum(ground_truth*head_directions, dim=2, keepdim=True)
gaze_weight = torch.where(gaze_head_cos>opt.gaze_head_cos_threshold, opt.gaze_head_loss_factor, 1.0)
loss = criterion(ground_truth, prediction)
loss = torch.mean(loss*gaze_weight)
if is_train == 1:
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Calculate prediction errors
error = torch.mean(acos_safe(torch.sum(ground_truth*prediction, 2)))/torch.tensor(math.pi) * 180.0
prediction_error_average += error.cpu().data.numpy() * batch_size
# Use head directions as the baseline
baseline_error = torch.mean(acos_safe(torch.sum(ground_truth*head_directions, 2)))/torch.tensor(math.pi) * 180.0
baseline_error_average += baseline_error.cpu().data.numpy() * batch_size
result = {}
result["prediction_error_average"] = prediction_error_average / n
result["baseline_error_average"] = baseline_error_average / n
if opt.is_eval and opt.save_predictions:
return result, predictions
else:
return result
if __name__ == '__main__':
    option = options().parse()
    if option.is_eval == False:
        main(option)
    else:
        eval(option)
        #eval_single(option)

View file

@ -0,0 +1,26 @@
## Code to process the HOT3D dataset
## Usage:
Step 1: Follow the instructions at the official repository https://github.com/facebookresearch/hot3d to prepare the environment and download the dataset. Also enable Jupyter, since the processing code runs in Jupyter notebooks.
Step 2: Set 'dataset_path', 'dataset_processed_path', 'object_library_path', and 'mano_hand_model_path' in 'hot3d_aria_preprocessing.ipynb', copy 'hot3d_aria_preprocessing.ipynb', 'hot3d_aria_scene.csv', 'hot3d_objects.csv', and 'utils' into the official repository ('hot3d/hot3d/'), and run the notebook to process the dataset (the sketch after these steps shows how to sanity-check its outputs).
Step 3: (Optional but highly recommended) Set 'data_path', 'object_library_path', and 'mano_hand_model_path' in 'hot3d_aria_visualisation.ipynb', copy 'hot3d_aria_visualisation.ipynb' and 'mano_hand_pose_init' into the official repository ('hot3d/hot3d/'), and run the notebook to visualise the processed data and get familiar with the dataset.
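Below is a minimal sketch for sanity-checking the output of Step 2. It assumes the 'dataset_processed_path' set in the notebook and reuses the example segment prefix from 'hot3d_aria_visualisation.ipynb'; the actual prefix depends on your sequence, scene, and the start/end frames of each valid segment, and the per-frame layouts follow the comments in 'hot3d_aria_preprocessing.ipynb'.
```python
import numpy as np

# example segment prefix (the same one used in hot3d_aria_visualisation.ipynb);
# adjust the directory and segment name to your own processed data
prefix = '/scratch/hu/pose_forecast/hot3d_hoigaze/P0001_10a27bf7_room_721_890_'

gaze = np.load(prefix + 'gaze.npy')              # gaze_direction (3) + gaze_center_in_world (3)
head = np.load(prefix + 'head.npy')              # head_direction (3) + head_translation (3) + head_rotation (4, quat_xyzw)
hand_joints = np.load(prefix + 'handjoints.npy') # left_hand (20*3) + right_hand (20*3) + attended_hand_gt + closest_hand
print(gaze.shape, head.shape, hand_joints.shape) # expected: (frames, 6), (frames, 10), (frames, 122)

# angular error (in degrees) of the head-direction baseline against the
# ground-truth gaze direction, mirroring the evaluation code
cos = np.clip(np.sum(gaze[:, :3] * head[:, :3], axis=1), -1.0, 1.0)
print('head baseline error: {:.2f} deg'.format(np.degrees(np.arccos(cos)).mean()))
```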
## Citations
```bibtex
@inproceedings{hu25hoigaze,
title={HOIGaze: Gaze Estimation During Hand-Object Interactions in Extended Reality Exploiting Eye-Hand-Head Coordination},
author={Hu, Zhiming and Haeufle, Daniel and Schmitt, Syn and Bulling, Andreas},
booktitle={Proceedings of the 2025 ACM Special Interest Group on Computer Graphics and Interactive Techniques},
year={2025}}
@article{banerjee2024introducing,
title={Introducing HOT3D: An Egocentric Dataset for 3D Hand and Object Tracking},
author={Banerjee, Prithviraj and Shkodrani, Sindi and Moulon, Pierre and Hampali, Shreyas and Zhang, Fan and Fountain, Jade and Miller, Edward and Basol, Selen and Newcombe, Richard and Wang, Robert and others},
journal={arXiv preprint arXiv:2406.09598},
year={2024}}
```

View file

@ -0,0 +1,565 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "bbfb8da4-14a5-44d8-b9b0-d66f032f09fb",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.nice(5)\n",
"import rerun as rr\n",
"import numpy as np\n",
"from math import tan\n",
"import time\n",
"from utils import remake_dir\n",
"import pandas as pd\n",
"from dataset_api import Hot3dDataProvider\n",
"from data_loaders.loader_object_library import load_object_library\n",
"from data_loaders.mano_layer import MANOHandModel\n",
"from data_loaders.loader_masks import combine_mask_data, load_mask_data, MaskData\n",
"from data_loaders.loader_hand_poses import Handedness, HandPose\n",
"from data_loaders.loader_object_library import ObjectLibrary\n",
"from data_loaders.headsets import Headset\n",
"from projectaria_tools.core.stream_id import StreamId\n",
"from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions\n",
"from projectaria_tools.core.sophus import SE3\n",
"\n",
"\n",
"dataset_path = '/datasets/public/zhiming_datasets/hot3d/aria/'\n",
"dataset_processed_path = '/scratch/hu/pose_forecast/hot3d_hoigaze/'\n",
"object_library_path = '/datasets/public/zhiming_datasets/hot3d/assets/'\n",
"mano_hand_model_path = '/datasets/public/zhiming_datasets/hot3d/mano_v1_2/models/'\n",
"remake_dir(dataset_processed_path)\n",
"dataset_info = pd.read_csv('hot3d_aria_scene.csv')\n",
"valid_frame_length = 60 # 60 frames -> 2 seconds, dropout the recordings that are too short\n",
"\n",
"# init the object library\n",
"if not os.path.exists(object_library_path):\n",
" print(\"invalid library path.\")\n",
" print(\"please do update the path to VALID values for your system.\")\n",
" raise\n",
"object_library = load_object_library(object_library_folderpath=object_library_path)\n",
"\n",
"# load the bounding box information of the objects\n",
"object_info = pd.read_csv('hot3d_objects.csv')\n",
"object_bbx = {}\n",
"for i, uid in enumerate(object_info['object_uid']): \n",
" bbx_x_min = object_info['bbx_x_min'][i]\n",
" bbx_x_max = object_info['bbx_x_max'][i]\n",
" bbx_y_min = object_info['bbx_y_min'][i]\n",
" bbx_y_max = object_info['bbx_y_max'][i]\n",
" bbx_z_min = object_info['bbx_z_min'][i]\n",
" bbx_z_max = object_info['bbx_z_max'][i]\n",
" bbx = [bbx_x_min, bbx_x_max, bbx_y_min, bbx_y_max, bbx_z_min, bbx_z_max]\n",
" object_bbx[str(uid)] = bbx\n",
" \n",
"# init the HANDs model. If None, the UmeTrack HANDs model will be used\n",
"mano_hand_model = None\n",
"if mano_hand_model_path is not None:\n",
" mano_hand_model = MANOHandModel(mano_hand_model_path)\n",
" \n",
"for i, seq in enumerate(dataset_info['sequence_name']):\n",
" scene = dataset_info['scene'][i]\n",
" print(\"\\nprocessing {}th seq: {}, scene: {}...\".format(i+1, seq, scene))\n",
" seq_path = dataset_path + seq + '/'\n",
" if not os.path.exists(seq_path):\n",
" print(\"invalid input sequence path.\")\n",
" print(\"please do update the path to VALID values for your system.\")\n",
" raise\n",
" save_path = dataset_processed_path + seq + '_' + scene + '_'\n",
"\n",
" # segment the sequence into valid and invalid parts using the masks \n",
" mask_list = [\n",
" \"masks/mask_object_pose_available.csv\",\n",
" \"masks/mask_hand_pose_available.csv\", \n",
" \"masks/mask_headset_pose_available.csv\",\n",
" #\"masks/mask_object_visibility.csv\", \n",
" #\"masks/mask_hand_visible.csv\",\n",
" #\"masks/mask_good_exposure.csv\",\n",
" \"masks/mask_qa_pass.csv\"]\n",
" \n",
" # load the referred masks\n",
" mask_data_list = []\n",
" for it in mask_list:\n",
" if os.path.exists(os.path.join(seq_path, it)):\n",
" ret = load_mask_data(os.path.join(seq_path, it))\n",
" mask_data_list.append(ret)\n",
" # combine the masks (you can choose logical \"and\"/\"or\")\n",
" combined_masks = combine_mask_data(mask_data_list, \"and\")\n",
" masks = []\n",
" for value in combined_masks.data['214-1'].values():\n",
" masks.append(value)\n",
" print(\"valid frames: {}/{}\".format(sum(masks), len(masks)))\n",
" \n",
" # initialize hot3d data provider\n",
" hot3d_data_provider = Hot3dDataProvider(\n",
" sequence_folder=seq_path,\n",
" object_library=object_library,\n",
" mano_hand_model=mano_hand_model)\n",
" #print(f\"data_provider statistics: {hot3d_data_provider.get_data_statistics()}\") \n",
" # alias over the HAND pose data provider\n",
" hand_data_provider = hot3d_data_provider.mano_hand_data_provider if hot3d_data_provider.mano_hand_data_provider is not None else hot3d_data_provider.umetrack_hand_data_provider\n",
" # alias over the Object pose data provider\n",
" object_pose_data_provider = hot3d_data_provider.object_pose_data_provider\n",
" # alias over the HEADSET/Device pose data provider\n",
" device_pose_provider = hot3d_data_provider.device_pose_data_provider\n",
" # alias over the Device data provider\n",
" device_data_provider = hot3d_data_provider.device_data_provider\n",
" device_calibration = device_data_provider.get_device_calibration()\n",
" transform_device_cpf = device_calibration.get_transform_device_cpf()\n",
" # retrieve a list of timestamps for the sequence (in nanoseconds)\n",
" timestamps = device_data_provider.get_sequence_timestamps()\n",
" \n",
" # segment valid data\n",
" index = 0\n",
" valid_frames = 0\n",
" start_frames = []\n",
" end_frames = []\n",
" while(index<len(masks)):\n",
" value = masks[index] \n",
" if value == True:\n",
" start = index\n",
" start_frames.append(start)\n",
" #print(\"start at: {}\".format(start))\n",
" while(value == True):\n",
" index += 1\n",
" if index<len(masks):\n",
" value = masks[index]\n",
" else:\n",
" break \n",
" end = index-1\n",
" end_frames.append(end)\n",
" #print(\"end at: {}\".format(end))\n",
" valid_frames += end - start + 1 \n",
" else:\n",
" index += 1\n",
" \n",
" segment_num = len(start_frames)\n",
" local_time = time.asctime(time.localtime(time.time()))\n",
" print('\\nprocessing starts at ' + local_time) \n",
" for i in range(segment_num):\n",
" start_frame = start_frames[i]\n",
" end_frame = end_frames[i]\n",
" frame_length = end_frame - start_frame + 1\n",
" if frame_length < valid_frame_length:\n",
" continue\n",
" print(\"start frame: {}, end frame: {}, length: {}\".format(start_frame, end_frame, frame_length))\n",
" \n",
" timestamps_data = np.zeros((frame_length, 1))\n",
" head_data = np.zeros((frame_length, 10)) # head_direction (3) + head_translation (3) + head_rotation (4, quat_xyzw)\n",
" gaze_data = np.zeros((frame_length, 6)) # gaze_direction (3) + gaze_center_in_world (3) \n",
" hand_data = np.zeros((frame_length, 44)) # left_hand (22) + right_hand (22), hand = wrist_pose (7, translation (3) + rotation (4)) + joint_angles (15) \n",
" hand_joint_data = np.zeros((frame_length, 122)) # left_hand (20*3) + right_hand (20*3) + attended_hand_gt + attended_hand_baseline (closest_hand) \n",
" hand_joint_initial_data = np.zeros((frame_length, 122)) # left_hand (20*3) + right_hand (20*3) + attended_hand_gt + closest_hand\n",
" object_data = np.zeros((frame_length, 48)) # object_data (8) * 6 objects (at most 6 objects), object_data = object_uid (1) + object_pose (7, translation (3) + rotation (4)) \n",
" object_bbx_data = np.zeros((frame_length, 144)) # bounding box information: 6 objects (at most 6 objects) * 8 vertexes * 3\n",
" object_bbx_left_hand_data = np.zeros((frame_length, 144)) # bounding box information of the objects ranked using distances to the left hand\n",
" object_bbx_right_hand_data = np.zeros((frame_length, 144)) # bounding box information of the objects ranked using distances to the right hand\n",
" object_bbx_left_hand_initial_data = np.zeros((frame_length, 144)) # bounding box information of the objects ranked using distances to the left hand\n",
" object_bbx_right_hand_initial_data = np.zeros((frame_length, 144)) # bounding box information of the objects ranked using distances to the right hand\n",
" \n",
" # extract the valid frames\n",
" for frame in range(start_frame, end_frame+1):\n",
" timestamp_ns = timestamps[frame]\n",
" timestamps_data[frame-start_frame] = timestamp_ns\n",
" \n",
" # extract head data\n",
" headset_pose3d_with_dt = device_pose_provider.get_pose_at_timestamp(\n",
" timestamp_ns=timestamp_ns,\n",
" time_query_options=TimeQueryOptions.CLOSEST,\n",
" time_domain=TimeDomain.TIME_CODE) \n",
" headset_pose3d = headset_pose3d_with_dt.pose3d\n",
" T_world_device = headset_pose3d.T_world_device\n",
" # use cpf pose as head pose, see https://facebookresearch.github.io/projectaria_tools/docs/data_formats/coordinate_convention/3d_coordinate_frame_convention\n",
" T_world_cpf = T_world_device @ transform_device_cpf \n",
" head_translation = T_world_cpf.translation()[0]\n",
" head_center_in_cpf = np.array([0, 0, 1.0], dtype = np.float64)\n",
" head_center_in_world = T_world_cpf @ head_center_in_cpf\n",
" head_center_in_world = head_center_in_world.reshape(3, )\n",
" head_direction = head_center_in_world - head_translation\n",
" head_direction = np.array([x / np.linalg.norm(head_direction) for x in head_direction]) \n",
" head_rotation = np.roll(T_world_cpf.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w\n",
" head_data[frame-start_frame, 0:3] = head_direction\n",
" head_data[frame-start_frame, 3:6] = head_translation\n",
" head_data[frame-start_frame, 6:10] = head_rotation\n",
" \n",
" # extract eye gaze data\n",
" aria_eye_gaze_data = device_data_provider.get_eye_gaze(timestamp_ns) \n",
" yaw = aria_eye_gaze_data.yaw\n",
" pitch = aria_eye_gaze_data.pitch\n",
" depth = aria_eye_gaze_data.depth\n",
" if depth == 0:\n",
" depth = 1\n",
" gaze_center_in_cpf = np.array([tan(yaw), tan(pitch), 1.0], dtype = np.float64)*depth\n",
" gaze_center_in_world = T_world_cpf @ gaze_center_in_cpf\n",
" gaze_center_in_world = gaze_center_in_world.reshape(3, )\n",
" gaze_direction = gaze_center_in_world - head_translation\n",
" gaze_direction = np.array([x / np.linalg.norm(gaze_direction) for x in gaze_direction])\n",
" # in rare cases, yaw, pitch is nan\n",
" if np.isnan(np.sum(gaze_direction)):\n",
" # use previous frame as an alternative\n",
" gaze_direction = gaze_data[frame-start_frame-1, 0:3]\n",
" gaze_center_in_world = gaze_data[frame-start_frame-1, 3:6]\n",
" gaze_data[frame-start_frame, 0:3] = gaze_direction\n",
" gaze_data[frame-start_frame, 3:6] = gaze_center_in_world\n",
" \n",
" # extract hand data\n",
" hand_poses_with_dt = hand_data_provider.get_pose_at_timestamp(\n",
" timestamp_ns=timestamp_ns,\n",
" time_query_options=TimeQueryOptions.CLOSEST,\n",
" time_domain=TimeDomain.TIME_CODE) \n",
" hand_pose_collection = hand_poses_with_dt.pose3d_collection\n",
" left_hand = hand_pose_collection.poses[Handedness.Left]\n",
" left_hand_translation = left_hand.wrist_pose.translation()[0]\n",
" left_hand_rotation = np.roll(left_hand.wrist_pose.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w\n",
" left_hand_joint_angles = left_hand.joint_angles\n",
" left_hand_joints = hand_data_provider.get_hand_landmarks(left_hand).numpy().reshape(-1)\n",
" left_hand_initial = HandPose(Handedness.Left, left_hand.wrist_pose, np.zeros(15))\n",
" left_hand_initial_joints = hand_data_provider.get_hand_landmarks(left_hand_initial).numpy().reshape(-1) \n",
" right_hand = hand_pose_collection.poses[Handedness.Right]\n",
" right_hand_translation = right_hand.wrist_pose.translation()[0]\n",
" right_hand_rotation = np.roll(right_hand.wrist_pose.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w\n",
" right_hand_joint_angles = right_hand.joint_angles\n",
" right_hand_joints = hand_data_provider.get_hand_landmarks(right_hand).numpy().reshape(-1)\n",
" right_hand_initial = HandPose(Handedness.Right, right_hand.wrist_pose, np.zeros(15))\n",
" right_hand_initial_joints = hand_data_provider.get_hand_landmarks(right_hand_initial).numpy().reshape(-1)\n",
" \n",
" left_hand_direction = np.mean(left_hand_joints.reshape((20, 3)), axis=0) - head_translation\n",
" left_hand_direction = np.array([x / np.linalg.norm(left_hand_direction) for x in left_hand_direction]) \n",
" left_hand_distance_to_gaze = np.arccos(np.sum(gaze_direction*left_hand_direction))\n",
" right_hand_direction = np.mean(right_hand_joints.reshape((20, 3)), axis=0) - head_translation\n",
" right_hand_direction = np.array([x / np.linalg.norm(right_hand_direction) for x in right_hand_direction]) \n",
" right_hand_distance_to_gaze = np.arccos(np.sum(gaze_direction*right_hand_direction))\n",
" if left_hand_distance_to_gaze < right_hand_distance_to_gaze:\n",
" hand_joint_data[frame-start_frame, 120:121] = 0\n",
" else:\n",
" hand_joint_data[frame-start_frame, 120:121] = 1\n",
"\n",
" left_hand_initial_direction = np.mean(left_hand_initial_joints.reshape((20, 3)), axis=0) - head_translation\n",
" left_hand_initial_direction = np.array([x / np.linalg.norm(left_hand_initial_direction) for x in left_hand_initial_direction]) \n",
" left_hand_initial_distance_to_gaze = np.arccos(np.sum(gaze_direction*left_hand_initial_direction))\n",
" right_hand_initial_direction = np.mean(right_hand_initial_joints.reshape((20, 3)), axis=0) - head_translation\n",
" right_hand_initial_direction = np.array([x / np.linalg.norm(right_hand_initial_direction) for x in right_hand_initial_direction]) \n",
" right_hand_initial_distance_to_gaze = np.arccos(np.sum(gaze_direction*right_hand_initial_direction))\n",
" if left_hand_initial_distance_to_gaze < right_hand_initial_distance_to_gaze:\n",
" hand_joint_initial_data[frame-start_frame, 120:121] = 0\n",
" else:\n",
" hand_joint_initial_data[frame-start_frame, 120:121] = 1\n",
" \n",
" hand_data[frame-start_frame, 0:3] = left_hand_translation\n",
" hand_data[frame-start_frame, 3:7] = left_hand_rotation\n",
" hand_data[frame-start_frame, 7:22] = left_hand_joint_angles\n",
" hand_data[frame-start_frame, 22:25] = right_hand_translation\n",
" hand_data[frame-start_frame, 25:29] = right_hand_rotation\n",
" hand_data[frame-start_frame, 29:44] = right_hand_joint_angles\n",
" hand_joint_data[frame-start_frame, 0:60] = left_hand_joints\n",
" hand_joint_data[frame-start_frame, 60:120] = right_hand_joints\n",
" hand_joint_initial_data[frame-start_frame, 0:60] = left_hand_initial_joints\n",
" hand_joint_initial_data[frame-start_frame, 60:120] = right_hand_initial_joints\n",
" \n",
" # extract object data\n",
" object_poses_with_dt = object_pose_data_provider.get_pose_at_timestamp(\n",
" timestamp_ns=timestamp_ns,\n",
" time_query_options=TimeQueryOptions.CLOSEST,\n",
" time_domain=TimeDomain.TIME_CODE)\n",
" objects_pose3d = object_poses_with_dt.pose3d_collection.poses\n",
" object_num = len(objects_pose3d)\n",
" objects_distance_to_left_hand = {}\n",
" objects_distance_to_right_hand = {} \n",
" objects_distance_to_left_hand_initial = {}\n",
" objects_distance_to_right_hand_initial = {}\n",
" objects_pose3d_dict = {}\n",
" item = 0\n",
" for (object_uid, object_pose3d) in objects_pose3d.items():\n",
" object_translation = object_pose3d.T_world_object.translation()[0] \n",
" object_distance_to_left_hand = np.mean(np.linalg.norm(left_hand_joints.reshape((20, 3))-object_translation, axis=1))\n",
" object_distance_to_right_hand = np.mean(np.linalg.norm(right_hand_joints.reshape((20, 3))-object_translation, axis=1))\n",
" object_distance_to_left_hand_initial = np.mean(np.linalg.norm(left_hand_initial_joints.reshape((20, 3))-object_translation, axis=1))\n",
" object_distance_to_right_hand_initial = np.mean(np.linalg.norm(right_hand_initial_joints.reshape((20, 3))-object_translation, axis=1)) \n",
" objects_distance_to_left_hand[object_uid] = object_distance_to_left_hand \n",
" objects_distance_to_right_hand[object_uid] = object_distance_to_right_hand\n",
" objects_distance_to_left_hand_initial[object_uid] = object_distance_to_left_hand_initial\n",
" objects_distance_to_right_hand_initial[object_uid] = object_distance_to_right_hand_initial\n",
" objects_pose3d_dict[object_uid] = object_pose3d.T_world_object \n",
" item += 1\n",
" \n",
" objects_distance_to_left_hand_sorted = sorted(objects_distance_to_left_hand.items(), key = lambda kv:(kv[1], kv[0]))\n",
" objects_distance_to_right_hand_sorted = sorted(objects_distance_to_right_hand.items(), key = lambda kv:(kv[1], kv[0])) \n",
" left_object_closest_uid = objects_distance_to_left_hand_sorted[0][0]\n",
" left_object_closest_distance = objects_distance_to_left_hand_sorted[0][1]\n",
" right_object_closest_uid = objects_distance_to_right_hand_sorted[0][0]\n",
" right_object_closest_distance = objects_distance_to_right_hand_sorted[0][1]\n",
" if left_object_closest_distance < right_object_closest_distance:\n",
" hand_joint_data[frame-start_frame, -1] = 0\n",
" else:\n",
" hand_joint_data[frame-start_frame, -1] = 1\n",
"\n",
" objects_distance_to_left_hand_initial_sorted = sorted(objects_distance_to_left_hand_initial.items(), key = lambda kv:(kv[1], kv[0]))\n",
" objects_distance_to_right_hand_initial_sorted = sorted(objects_distance_to_right_hand_initial.items(), key = lambda kv:(kv[1], kv[0]))\n",
" left_initial_object_closest_uid = objects_distance_to_left_hand_initial_sorted[0][0]\n",
" left_initial_object_closest_distance = objects_distance_to_left_hand_initial_sorted[0][1]\n",
" right_initial_object_closest_uid = objects_distance_to_right_hand_initial_sorted[0][0]\n",
" right_initial_object_closest_distance = objects_distance_to_right_hand_initial_sorted[0][1]\n",
" if left_initial_object_closest_distance < right_initial_object_closest_distance:\n",
" hand_joint_initial_data[frame-start_frame, -1] = 0\n",
" else:\n",
" hand_joint_initial_data[frame-start_frame, -1] = 1\n",
" \n",
" item = 0\n",
" for object_uid in objects_pose3d_dict:\n",
" object_pose3d = objects_pose3d_dict[object_uid]\n",
" object_translation = object_pose3d.translation()[0]\n",
" object_rotation = np.roll(object_pose3d.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w \n",
" object_data[frame-start_frame, item*8:item*8+1] = object_uid\n",
" object_data[frame-start_frame, item*8+1:item*8+4] = object_translation\n",
" object_data[frame-start_frame, item*8+4:item*8+8] = object_rotation\n",
" bbx = object_bbx[object_uid]\n",
" #print(\"uid: {}, bbx: {}\".format(object_uid, bbx))\n",
" x_min = bbx[0]\n",
" x_max = bbx[1]\n",
" y_min = bbx[2]\n",
" y_max = bbx[3]\n",
" z_min = bbx[4]\n",
" z_max = bbx[5]\n",
" bbx_vertex = np.array([x_min, y_min, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, ) \n",
" object_bbx_data[frame-start_frame, item*24:item*24+3] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_min, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_data[frame-start_frame, item*24+3:item*24+6] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_min, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_data[frame-start_frame, item*24+6:item*24+9] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_min, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_data[frame-start_frame, item*24+9:item*24+12] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_max, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_data[frame-start_frame, item*24+12:item*24+15] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_max, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_data[frame-start_frame, item*24+15:item*24+18] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_max, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_data[frame-start_frame, item*24+18:item*24+21] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_max, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_data[frame-start_frame, item*24+21:item*24+24] = bbx_vertex \n",
" item += 1\n",
" \n",
" for item in range(len(objects_distance_to_left_hand_sorted)):\n",
" object_uid = objects_distance_to_left_hand_sorted[item][0]\n",
" object_pose3d = objects_pose3d_dict[object_uid] \n",
" object_translation = object_pose3d.translation()[0]\n",
" object_rotation = np.roll(object_pose3d.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w \n",
" bbx = object_bbx[object_uid]\n",
" #print(\"uid: {}, bbx: {}\".format(object_uid, bbx))\n",
" x_min = bbx[0]\n",
" x_max = bbx[1]\n",
" y_min = bbx[2]\n",
" y_max = bbx[3]\n",
" z_min = bbx[4]\n",
" z_max = bbx[5]\n",
" bbx_vertex = np.array([x_min, y_min, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, ) \n",
" object_bbx_left_hand_data[frame-start_frame, item*24:item*24+3] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_min, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_data[frame-start_frame, item*24+3:item*24+6] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_min, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_data[frame-start_frame, item*24+6:item*24+9] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_min, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_data[frame-start_frame, item*24+9:item*24+12] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_max, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_data[frame-start_frame, item*24+12:item*24+15] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_max, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_data[frame-start_frame, item*24+15:item*24+18] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_max, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_data[frame-start_frame, item*24+18:item*24+21] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_max, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_data[frame-start_frame, item*24+21:item*24+24] = bbx_vertex \n",
"\n",
" for item in range(len(objects_distance_to_right_hand_sorted)):\n",
" object_uid = objects_distance_to_right_hand_sorted[item][0]\n",
" object_pose3d = objects_pose3d_dict[object_uid] \n",
" object_translation = object_pose3d.translation()[0]\n",
" object_rotation = np.roll(object_pose3d.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w \n",
" bbx = object_bbx[object_uid]\n",
" #print(\"uid: {}, bbx: {}\".format(object_uid, bbx))\n",
" x_min = bbx[0]\n",
" x_max = bbx[1]\n",
" y_min = bbx[2]\n",
" y_max = bbx[3]\n",
" z_min = bbx[4]\n",
" z_max = bbx[5]\n",
" bbx_vertex = np.array([x_min, y_min, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, ) \n",
" object_bbx_right_hand_data[frame-start_frame, item*24:item*24+3] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_min, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_data[frame-start_frame, item*24+3:item*24+6] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_min, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_data[frame-start_frame, item*24+6:item*24+9] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_min, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_data[frame-start_frame, item*24+9:item*24+12] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_max, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_data[frame-start_frame, item*24+12:item*24+15] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_max, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_data[frame-start_frame, item*24+15:item*24+18] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_max, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_data[frame-start_frame, item*24+18:item*24+21] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_max, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_data[frame-start_frame, item*24+21:item*24+24] = bbx_vertex \n",
"\n",
" for item in range(len(objects_distance_to_left_hand_initial_sorted)):\n",
" object_uid = objects_distance_to_left_hand_initial_sorted[item][0]\n",
" object_pose3d = objects_pose3d_dict[object_uid] \n",
" object_translation = object_pose3d.translation()[0]\n",
" object_rotation = np.roll(object_pose3d.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w \n",
" bbx = object_bbx[object_uid]\n",
" #print(\"uid: {}, bbx: {}\".format(object_uid, bbx))\n",
" x_min = bbx[0]\n",
" x_max = bbx[1]\n",
" y_min = bbx[2]\n",
" y_max = bbx[3]\n",
" z_min = bbx[4]\n",
" z_max = bbx[5]\n",
" bbx_vertex = np.array([x_min, y_min, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, ) \n",
" object_bbx_left_hand_initial_data[frame-start_frame, item*24:item*24+3] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_min, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+3:item*24+6] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_min, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+6:item*24+9] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_min, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+9:item*24+12] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_max, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+12:item*24+15] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_max, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+15:item*24+18] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_max, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+18:item*24+21] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_max, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+21:item*24+24] = bbx_vertex \n",
"\n",
" for item in range(len(objects_distance_to_right_hand_initial_sorted)):\n",
" object_uid = objects_distance_to_right_hand_initial_sorted[item][0]\n",
" object_pose3d = objects_pose3d_dict[object_uid] \n",
" object_translation = object_pose3d.translation()[0]\n",
" object_rotation = np.roll(object_pose3d.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w \n",
" bbx = object_bbx[object_uid]\n",
" #print(\"uid: {}, bbx: {}\".format(object_uid, bbx))\n",
" x_min = bbx[0]\n",
" x_max = bbx[1]\n",
" y_min = bbx[2]\n",
" y_max = bbx[3]\n",
" z_min = bbx[4]\n",
" z_max = bbx[5]\n",
" bbx_vertex = np.array([x_min, y_min, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, ) \n",
" object_bbx_right_hand_initial_data[frame-start_frame, item*24:item*24+3] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_min, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+3:item*24+6] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_min, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+6:item*24+9] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_min, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+9:item*24+12] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_max, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+12:item*24+15] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_max, z_max], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+15:item*24+18] = bbx_vertex\n",
" bbx_vertex = np.array([x_max, y_max, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+18:item*24+21] = bbx_vertex\n",
" bbx_vertex = np.array([x_min, y_max, z_min], dtype = np.float64) \n",
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+21:item*24+24] = bbx_vertex \n",
" \n",
" # save the data\n",
" timestamps_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_timestamps.npy'\n",
" head_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_head.npy'\n",
" gaze_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_gaze.npy'\n",
" hand_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_hand.npy' \n",
" hand_joint_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_handjoints.npy'\n",
" hand_joint_initial_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_inithandjoints.npy'\n",
" object_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_objects.npy'\n",
" object_bbx_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_object_bbx.npy'\n",
" object_bbx_left_hand_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_object_bbxleft.npy'\n",
" object_bbx_right_hand_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_object_bbxright.npy'\n",
" object_bbx_left_hand_initial_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_object_initbbxleft.npy'\n",
" object_bbx_right_hand_initial_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_object_initbbxright.npy'\n",
" \n",
" np.save(timestamps_path, timestamps_data)\n",
" np.save(head_path, head_data)\n",
" np.save(gaze_path, gaze_data)\n",
" np.save(hand_path, hand_data)\n",
" np.save(hand_joint_path, hand_joint_data)\n",
" np.save(hand_joint_initial_path, hand_joint_initial_data)\n",
" np.save(object_path, object_data)\n",
" np.save(object_bbx_path, object_bbx_data)\n",
" np.save(object_bbx_left_hand_path, object_bbx_left_hand_data)\n",
" np.save(object_bbx_right_hand_path, object_bbx_right_hand_data)\n",
" np.save(object_bbx_left_hand_initial_path, object_bbx_left_hand_initial_data)\n",
" np.save(object_bbx_right_hand_initial_path, object_bbx_right_hand_initial_data)\n",
" \n",
" local_time = time.asctime(time.localtime(time.time()))\n",
" print('\\nprocessing ends at ' + local_time)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec30e81d-8812-4665-be58-00a7e0aa1915",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@ -0,0 +1,137 @@
sequence_name,scene
P0001_10a27bf7,room
P0001_15c4300c,kitchen
P0001_23fa0ee8,office
P0001_4bf4e21a,room
P0001_550ea2ac,room
P0001_624f2ba9,kitchen
P0001_83b755aa,room
P0001_8d136980,room
P0001_95eabc0a,kitchen
P0001_9b6feab7,office
P0001_9c030609,room
P0001_a68492d5,kitchen
P0001_a9d6c83d,room
P0001_b2bcbe28,room
P0001_f6cc0cc8,office
P0001_f71fc9b1,room
P0002_016222d1,kitchen
P0002_2ea9af5b,kitchen
P0002_59a84a3a,kitchen
P0002_65085bfc,kitchen
P0002_c68438ce,kitchen
P0003_1fa0bb65,room
P0003_5766eae8,kitchen
P0003_743d3087,office
P0003_c701bd11,office
P0003_cbd17f20,kitchen
P0003_e3a74169,kitchen
P0003_ebdc6ff7,kitchen
P0003_f4730606,kitchen
P0009_02511c2f,room
P0009_0e02bf39,kitchen
P0009_1b21bb01,kitchen
P0009_247b5403,kitchen
P0009_24af5475,room
P0009_37e30628,office
P0009_5bb7cc2a,room
P0009_5f2583e0,room
P0009_88a58b3b,room
P0009_8c0caf6c,kitchen
P0009_8c95667c,room
P0009_911a4ba5,room
P0009_927973d7,office
P0009_9e03121a,kitchen
P0009_b1e71d7f,room
P0009_b7d8c222,kitchen
P0009_cf323827,room
P0009_e27ad19f,room
P0009_e71e2f24,kitchen
P0009_ea91844d,office
P0010_0ecbf39f,kitchen
P0010_1573a837,office
P0010_160e551c,room
P0010_1c9fe708,kitchen
P0010_41c4c626,kitchen
P0010_573b2de6,office
P0010_6d15aa57,room
P0010_8ff3e5c4,room
P0010_924e574e,kitchen
P0010_99385ad8,room
P0010_b152143e,office
P0010_e481fd15,room
P0010_fd1891a8,room
P0011_0c2d00ed,kitchen
P0011_11475e24,kitchen
P0011_2255f410,kitchen
P0011_30a5b61d,room
P0011_451a7734,room
P0011_47878e48,office
P0011_4dcdcfec,office
P0011_72efb935,kitchen
P0011_76ea6d47,kitchen
P0011_7a8886e7,room
P0011_7c7e8d86,room
P0011_a497658f,office
P0011_ad5f7dee,room
P0011_b0c895c1,kitchen
P0011_ccc678c7,room
P0011_cd22f5e0,kitchen
P0011_cee8fe4f,room
P0011_d0f907a0,office
P0011_dbf1ebb7,room
P0011_e5cb6a80,kitchen
P0011_ff7cfc3a,kitchen
P0012_0323d770,room
P0012_0a21e4c2,room
P0012_119de519,room
P0012_130a66e1,office
P0012_1c45bbd9,office
P0012_44c5f677,kitchen
P0012_476bae57,kitchen
P0012_6e4f7815,office
P0012_73e66984,room
P0012_915e71c6,office
P0012_af3dab9a,room
P0012_b8fc4c1b,kitchen
P0012_c06d939b,kitchen
P0012_c4353c31,room
P0012_ca1f6626,room
P0012_d6272ce1,office
P0012_d85e10f6,kitchen
P0012_db543fe8,kitchen
P0012_e846d3cc,kitchen
P0012_e97d31b6,kitchen
P0012_f1a33781,room
P0012_f7e3880b,room
P0014_0de48e2c,room
P0014_1fdca00e,office
P0014_24cb3bf0,kitchen
P0014_65c6a968,office
P0014_6db96fd0,kitchen
P0014_8254f925,kitchen
P0014_84ea2dcc,kitchen
P0014_9a25ec6a,office
P0014_9b7a0725,room
P0014_e40eec5d,kitchen
P0014_f7ba43e0,room
P0015_179e1b84,kitchen
P0015_1eb4f17f,kitchen
P0015_367cc58d,room
P0015_3a9bb2ae,room
P0015_3c7f5241,room
P0015_3f46732d,room
P0015_42b8b389,room
P0015_4e5d1c14,room
P0015_53358a64,office
P0015_60573a3b,kitchen
P0015_745920f0,kitchen
P0015_7e95628d,room
P0015_7fc6548d,kitchen
P0015_b0c5102b,room
P0015_c3e1f590,kitchen
P0015_cc5dfd13,room
P0015_cc739faa,room
P0015_dbe00981,room
P0015_e3f96b3f,office
P0015_e7458eb3,office

View file

@ -0,0 +1,296 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "bbfb8da4-14a5-44d8-b9b0-d66f032f09fb",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.nice(5)\n",
"import rerun as rr\n",
"import numpy as np\n",
"from math import tan\n",
"import time\n",
"from utils import remake_dir\n",
"import pandas as pd\n",
"from data_loaders.ManoHandDataProvider import MANOHandDataProvider\n",
"from data_loaders.loader_object_library import load_object_library\n",
"from data_loaders.mano_layer import MANOHandModel\n",
"from data_loaders.loader_hand_poses import Handedness, HandPose\n",
"from data_loaders.hand_common import LANDMARK_CONNECTIVITY\n",
"from data_loaders.loader_object_library import ObjectLibrary\n",
"from data_loaders.headsets import Headset\n",
"from projectaria_tools.core.stream_id import StreamId\n",
"from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions\n",
"from projectaria_tools.core.sophus import SE3\n",
"from projectaria_tools.utils.rerun_helpers import ToTransform3D\n",
"\n",
"\n",
"data_path = '/scratch/hu/pose_forecast/hot3d_hoigaze/P0001_10a27bf7_room_721_890_'\n",
"timestamps_path = data_path + 'timestamps.npy'\n",
"head_path = data_path + 'head.npy'\n",
"gaze_path = data_path + 'gaze.npy'\n",
"hand_path = data_path + 'hand.npy'\n",
"hand_joint_path = data_path + 'handjoints.npy'\n",
"object_path = data_path + 'objects.npy'\n",
"object_bbx_path = data_path + 'object_bbx.npy'\n",
"object_library_path = '/datasets/public/zhiming_datasets/hot3d/assets/'\n",
"mano_hand_model_path = '/datasets/public/zhiming_datasets/hot3d/mano_v1_2/models/'\n",
"\n",
"show_bbx = False\n",
"show_hand_mesh = True\n",
"# init the object library\n",
"if not os.path.exists(object_library_path):\n",
" print(\"invalid object library path.\")\n",
" print(\"please follow the instructions at https://github.com/facebookresearch/hot3d to Download the HOT3D Assets Dataset\") \n",
" raise \n",
"object_library = load_object_library(object_library_folderpath=object_library_path)\n",
"\n",
"# init the HANDs model\n",
"if not os.path.exists(mano_hand_model_path):\n",
" print(\"invalid mano hand model path.\")\n",
" print(\"please follow the instructions at https://github.com/facebookresearch/hot3d to Download the MANO files\")\n",
" raise\n",
"mano_hand_model = MANOHandModel(mano_hand_model_path)\n",
"\n",
"timestamps_data = np.load(timestamps_path)\n",
"head_data = np.load(head_path) # head_direction (3) + head_translation (3) + head_rotation (4, quat_xyzw)\n",
"gaze_data = np.load(gaze_path) # gaze_direction (3) + gaze_center_in_world (3)\n",
"hand_data = np.load(hand_path) # left_hand (22) + right_hand (22), hand = wrist_pose (7, translation (3) + rotation (4)) + joint_angles (15)\n",
"hand_joint_data = np.load(hand_joint_path) # left_hand (20*3) + right_hand (20*3)\n",
"object_data = np.load(object_path) # object_data (8) * 6 objects (at most 6 objects), object_data = object_uid (1) + object_pose (7, translation (3) + rotation (4))\n",
"object_bbx_data = np.load(object_bbx_path) # bounding box information: 6 objects (at most 6 objects) * 8 vertexes * 3\n",
"frame_length = len(head_data)\n",
"\n",
"# alias over the HAND pose data provider\n",
"hand_data_provider = MANOHandDataProvider('./mano_hand_pose_init/mano_hand_pose_trajectory.jsonl', mano_hand_model)\n",
"# keep track of what 3D assets have been loaded/unloaded so we will load them only when needed\n",
"object_cache_status = {}\n",
"\n",
"# Init a rerun context\n",
"rr.init(\"hot3d-aria\")\n",
"rec = rr.memory_recording()\n",
"\n",
"def log_pose(\n",
" pose: SE3,\n",
" label: str,\n",
" static=False\n",
") -> None:\n",
" rr.log(label, ToTransform3D(pose, False), static=static)\n",
"\n",
"\n",
"for i in range(frame_length):\n",
" timestamp_ns = timestamps_data[i, 0]\n",
" rr.set_time_nanos(\"synchronization_time\", int(timestamp_ns))\n",
" rr.set_time_sequence(\"timestamp\", int(timestamp_ns))\n",
" \n",
" head_direction = head_data[i, 0:3]\n",
" head_translation = head_data[i, 3:6]\n",
" head_rotation = head_data[i, 6:10] \n",
" gaze_direction = gaze_data[i, 0:3]\n",
" gaze_center_in_world = gaze_data[i, 3:6]\n",
" left_hand_translation = hand_data[i, 0:3]\n",
" left_hand_rotation = hand_data[i, 3:7]\n",
" left_hand_joint_angles = hand_data[i, 7:22]\n",
" left_hand_joints = hand_joint_data[i, 0:60].reshape((20, 3))\n",
" right_hand_translation = hand_data[i, 22:25]\n",
" right_hand_rotation = hand_data[i, 25:29]\n",
" right_hand_joint_angles = hand_data[i, 29:44]\n",
" right_hand_joints = hand_joint_data[i, 60:120].reshape((20, 3))\n",
" left_hand_wrist_pose = SE3.from_quat_and_translation(left_hand_rotation[-1], left_hand_rotation[:-1], left_hand_translation)\n",
" left_hand_pose = HandPose(Handedness.Left, left_hand_wrist_pose, left_hand_joint_angles)\n",
" right_hand_wrist_pose = SE3.from_quat_and_translation(right_hand_rotation[-1], right_hand_rotation[:-1], right_hand_translation)\n",
" right_hand_pose = HandPose(Handedness.Right, right_hand_wrist_pose, right_hand_joint_angles)\n",
" \n",
" # use cpf pose as head pose, see https://facebookresearch.github.io/projectaria_tools/docs/data_formats/coordinate_convention/3d_coordinate_frame_convention\n",
" T_world_cpf = SE3.from_quat_and_translation(head_rotation[-1], head_rotation[:-1], head_translation)\n",
" log_pose(pose=T_world_cpf, label=\"world/head_pose\")\n",
" #rr.log(\n",
" #\"world/head_direction\",\n",
" #rr.Points3D([head_translation], radii=[0.003]),\n",
" #rr.Arrows3D(vectors=[head_direction*0.4], colors=[[0, 0.8, 0.8, 0.5]]))\n",
" #log_pose(pose=left_hand_wrist_pose, label=\"world/left_hand_pose\")\n",
" #log_pose(pose=right_hand_wrist_pose, label=\"world/right_hand_pose\")\n",
" \n",
" rr.log(\n",
" \"world/gaze_direction\",\n",
" rr.Points3D([head_translation], radii=[0.003]),\n",
" rr.Arrows3D(vectors=[gaze_direction*0.4], colors=[[0, 0.8, 0.2, 0.5]]))\n",
" #print(\"frame: {}, gaze: {}\".format(i+1119, gaze_direction))\n",
" \n",
" # plot hands as a triangular mesh representation\n",
" if show_hand_mesh:\n",
" left_hand_mesh_vertices = hand_data_provider.get_hand_mesh_vertices(left_hand_pose)\n",
" left_hand_triangles, left_hand_vertex_normals = hand_data_provider.get_hand_mesh_faces_and_normals(left_hand_pose) \n",
" rr.log(\n",
" f\"world/left_hand/mesh_faces\",\n",
" rr.Mesh3D(\n",
" vertex_positions=left_hand_mesh_vertices,\n",
" vertex_normals=left_hand_vertex_normals,\n",
" triangle_indices=left_hand_triangles))\n",
" right_hand_mesh_vertices = hand_data_provider.get_hand_mesh_vertices(right_hand_pose)\n",
" right_hand_triangles, right_hand_vertex_normals = hand_data_provider.get_hand_mesh_faces_and_normals(right_hand_pose)\n",
" rr.log(\n",
" f\"world/right_hand/mesh_faces\",\n",
" rr.Mesh3D(\n",
" vertex_positions=right_hand_mesh_vertices,\n",
" vertex_normals=right_hand_vertex_normals,\n",
" triangle_indices=right_hand_triangles))\n",
" else:\n",
" #left_hand_translation = np.array([0, 0, 0])\n",
" #left_hand_rotation = np.array([0, 0, 0, 1])\n",
" #left_hand_joint_angles = np.zeros(15)\n",
" #left_hand_wrist_pose = SE3.from_quat_and_translation(left_hand_rotation[-1], left_hand_rotation[:-1], left_hand_translation)\n",
" #log_pose(pose=left_hand_wrist_pose, label=\"world/left_hand_pose\")\n",
" #left_hand_pose = HandPose(Handedness.Left, left_hand_wrist_pose, left_hand_joint_angles)\n",
" #left_hand_joints = hand_data_provider.get_hand_landmarks(left_hand_pose)\n",
" #left_hand_wrist = left_hand_joints[5, :].clone()\n",
" #joint_number = left_hand_joints.shape[0]\n",
" #for index in range(joint_number):\n",
" # left_hand_joints[index, :] -= left_hand_wrist\n",
" #for index in range(joint_number): \n",
" # tmp = left_hand_joints[index, :].clone()\n",
" # left_hand_joints[index, 1] = -tmp[2]\n",
" # left_hand_joints[index, 2] = tmp[1]\n",
" #for index in range(joint_number): \n",
" # print(left_hand_joints[index])\n",
" \n",
" #right_hand_translation = np.array([0, 0, 0])\n",
" #right_hand_rotation = np.array([0, 0, 0, 1])\n",
" #right_hand_joint_angles = np.zeros(15)\n",
" #right_hand_wrist_pose = SE3.from_quat_and_translation(right_hand_rotation[-1], right_hand_rotation[:-1], right_hand_translation)\n",
" #log_pose(pose=right_hand_wrist_pose, label=\"world/right_hand_pose\")\n",
" #right_hand_pose = HandPose(Handedness.Right, right_hand_wrist_pose, right_hand_joint_angles)\n",
" #right_hand_joints = hand_data_provider.get_hand_landmarks(right_hand_pose)\n",
" #right_hand_wrist = right_hand_joints[5, :].clone()\n",
" #joint_number = right_hand_joints.shape[0]\n",
" #for index in range(joint_number):\n",
" # right_hand_joints[index, :] -= right_hand_wrist\n",
" #for index in range(joint_number): \n",
" # tmp = right_hand_joints[index, :].clone()\n",
" # right_hand_joints[index, 1] = -tmp[2]\n",
" # right_hand_joints[index, 2] = tmp[1]\n",
" #for index in range(joint_number): \n",
" # print(right_hand_joints[index])\n",
" \n",
" left_hand_skeleton = [connections\n",
" for connectivity in LANDMARK_CONNECTIVITY\n",
" for connections in [[left_hand_joints[it].tolist() for it in connectivity]]] \n",
" rr.log(\n",
" f\"world/left_hand_skeleton\",\n",
" rr.LineStrips3D(left_hand_skeleton, radii=0.002),\n",
" )\n",
" right_hand_skeleton = [connections\n",
" for connectivity in LANDMARK_CONNECTIVITY\n",
" for connections in [[right_hand_joints[it].tolist() for it in connectivity]]] \n",
" rr.log(\n",
" f\"world/right_hand_skeleton\",\n",
" rr.LineStrips3D(right_hand_skeleton, radii=0.002),\n",
" )\n",
" \n",
" # load objects\n",
" object_num_max = 6\n",
" logging_status = {} \n",
" for item in range(object_num_max):\n",
" object_uid = str(int(object_data[i, item*8]))\n",
" if object_uid == '0':\n",
" break\n",
" logging_status[object_uid] = False\n",
" object_num = len(logging_status)\n",
" for item in range(object_num):\n",
" object_uid = str(int(object_data[i, item*8]))\n",
" object_name = object_library.object_id_to_name_dict[object_uid]\n",
" object_cad_asset_filepath = ObjectLibrary.get_cad_asset_path(\n",
" object_library_folderpath=object_library.asset_folder_name,\n",
" object_id=object_uid)\n",
" object_translation = object_data[i, item*8+1:item*8+4]\n",
" object_rotation = object_data[i, item*8+4:item*8+8]\n",
" object_pose = SE3.from_quat_and_translation(object_rotation[-1], object_rotation[:-1], object_translation) \n",
" log_pose(pose=object_pose, label=f\"world/objects/{object_name}\")\n",
" logging_status[object_uid] = True # mark object has been seen (enable to know which object has been logged or not)\n",
" \n",
" if show_bbx:\n",
" bbx_vertex_0 = object_bbx_data[i, item*24:item*24+3]\n",
" bbx_vertex_1 = object_bbx_data[i, item*24+3:item*24+6]\n",
" bbx_vertex_2 = object_bbx_data[i, item*24+6:item*24+9]\n",
" bbx_vertex_3 = object_bbx_data[i, item*24+9:item*24+12]\n",
" bbx_vertex_4 = object_bbx_data[i, item*24+12:item*24+15]\n",
" bbx_vertex_5 = object_bbx_data[i, item*24+15:item*24+18]\n",
" bbx_vertex_6 = object_bbx_data[i, item*24+18:item*24+21]\n",
" bbx_vertex_7 = object_bbx_data[i, item*24+21:item*24+24] \n",
" points = [\n",
" bbx_vertex_0,\n",
" bbx_vertex_1,\n",
" bbx_vertex_2,\n",
" bbx_vertex_3,\n",
" bbx_vertex_0,\n",
" bbx_vertex_7,\n",
" bbx_vertex_6,\n",
" bbx_vertex_5,\n",
" bbx_vertex_4,\n",
" bbx_vertex_7,\n",
" bbx_vertex_6,\n",
" bbx_vertex_1,\n",
" bbx_vertex_2,\n",
" bbx_vertex_5,\n",
" bbx_vertex_4,\n",
" bbx_vertex_3]\n",
" rr.log(f\"world/objects_bbx/{object_name}\", rr.LineStrips3D([points]))\n",
" \n",
" # link the corresponding 3D object to the pose\n",
" if object_uid not in object_cache_status.keys():\n",
" object_cache_status[object_uid] = True \n",
" rr.log(\n",
" f\"world/objects/{object_name}\",\n",
" rr.Asset3D(path=object_cad_asset_filepath)) \n",
" # if some objects are not visible, we clear the entity\n",
" for object_uid, displayed in logging_status.items():\n",
" if not displayed:\n",
" object_name = object_library.object_id_to_name_dict[object_uid]\n",
" rr.log(\n",
" f\"world/objects/{object_name}\",\n",
" rr.Clear.recursive())\n",
" if show_bbx:\n",
" rr.log(\n",
" f\"world/objects_bbx/{object_name}\",\n",
" rr.Clear.recursive()) \n",
" if object_uid in object_cache_status.keys():\n",
" del object_cache_status[object_uid] # we will log the mesh again\n",
" \n",
"# show the rerun window\n",
"rr.notebook_show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7811d038-6150-46b3-ac62-d4875191e0f4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@ -0,0 +1,34 @@
object_uid,object_name,bbx_x_min,bbx_x_max,bbx_y_min,bbx_y_max,bbx_z_min,bbx_z_max
106434519822892,bottle_bbq,-0.023,0.021,0,0.145,-0.02,0.04
106957734975303,puzzle_toy,-0.039,0.019,0,0.058,-0.031,0.027
111305855142671,aria_small,-0.073,0.073,-0.03,0.015,-0.141,0.023
117658302265452,mug_patterned,-0.052,0.045,0,0.084,-0.07,0.06
125863066770940,carton_milk,-0.039,0.033,0,0.192,-0.04,0.04
143541090750632,holder_gray,-0.035,0.051,0,0.124,-0.04,0.03
163340972252026,vase,-0.084,0.052,0,0.168,-0.052,0.048
171735388712249,can_parmesan,-0.031,0.035,0,0.102,-0.04,0.03
183136364331389,can_tomato_sauce,-0.034,0.036,0,0.082,-0.036,0.03
194930206998778,bowl,-0.139,0.141,0,0.13,-0.14,0.14
195041665639898,birdhouse_toy,-0.102,0.078,0,0.233,-0.09,0.091
204462113746498,carton_oj,-0.032,0.043,0,0.192,-0.033,0.041
208243983021975,dino_toy,-0.147,0.092,0,0.135,-0.08,0.06
223371871635142,mug_white,-0.068,0.049,0,0.091,-0.04,0.05
225397651484143,spoon_wooden,-0.034,0.034,0,0.018,-0.155,0.155
228358276546933,coffee_pot,-0.066,0.08,0,0.151,-0.043,0.042
232501673989606,holder_black,-0.022,0.04,0,0.111,-0.046,0.045
238686662724712,whiteboard_marker,-0.01,0.009,0,0.121,-0.009,0.011
248247003992116,plate_bamboo,-0.163,0.138,0,0.047,-0.17,0.132
249541253457812,dvd_remote,-0.018,0.017,-0.082,0.082,0.012,0.038
253405647833885,food_waffles,-0.048,0.022,0,0.021,-0.062,0.068
258906041248094,whiteboard_eraser,-0.043,0.024,0,0.037,-0.074,0.074
261746112525368,bottle_mustard,-0.03,0.018,0,0.149,-0.032,0.033
265826671143948,mouse,-0.036,0.034,-0.021,0.021,-0.056,0.054
270231216246839,spatula_red,-0.154,0.149,0,0.042,-0.114,0.097
27078911029651,dumbbell_5lb,-0.042,0.042,-0.09,0.09,-0.037,0.037
37787722328019,keyboard,-0.215,0.215,-0.018,0.002,-0.063,0.066
4111539686391,flask,-0.051,0.069,0,0.301,-0.104,0.056
5462893327580,cellphone,-0.037,0.036,-0.006,0.005,-0.077,0.077
70709727230291,bottle_ranch,-0.023,0.021,0,0.147,-0.034,0.03
79582884925181,potato_masher,-0.06,0.03,0,0.283,-0.047,0.043
96945373046044,food_vegetables,-0.045,0.023,0,0.02,-0.049,0.048
98604936546412,can_soup,-0.033,0.038,0,0.082,-0.035,0.031

File diff suppressed because it is too large

View file

@ -0,0 +1,4 @@
__all__ = ['file_systems']
from .file_systems import remake_dir, make_dir

View file

@ -0,0 +1,50 @@
import os
import shutil
import time
# remove a directory
def remove_dir(dirName):
if os.path.exists(dirName):
shutil.rmtree(dirName)
else:
print("Invalid directory path!")
# remake a directory
def remake_dir(dirName):
if os.path.exists(dirName):
shutil.rmtree(dirName)
os.makedirs(dirName)
else:
os.makedirs(dirName)
# calculate the number of lines in a file
def file_lines(fileName):
if os.path.exists(fileName):
with open(fileName, 'r') as fr:
return len(fr.readlines())
else:
print("Invalid file path!")
return 0
# make a directory if it does not exist.
def make_dir(dirName):
if os.path.exists(dirName):
print("Directory "+ dirName + " already exists.")
else:
os.makedirs(dirName)
if __name__ == "__main__":
dirName = "test"
remake_dir(dirName)
time.sleep(3)
make_dir(dirName)
remove_dir(dirName)
time.sleep(3)
make_dir(dirName)
#print(file_lines('233.txt'))

View file

@ -0,0 +1,98 @@
from torch import nn
import torch
from model import graph_convolution_network
import torch.nn.functional as F
class attended_hand_recognition(nn.Module):
def __init__(self, opt):
super().__init__()
self.opt = opt
self.body_joint_number = opt.body_joint_number
self.hand_joint_number = opt.hand_joint_number
self.joint_number = self.body_joint_number + self.hand_joint_number
self.input_n = opt.seq_len
gcn_latent_features = opt.gcn_latent_features
residual_gcns_num = opt.residual_gcns_num
gcn_dropout = opt.gcn_dropout
head_cnn_channels = opt.head_cnn_channels
recognition_cnn_channels = opt.recognition_cnn_channels
# 1D CNN for extracting features from head directions
in_channels_head = 3
cnn_kernel_size = 3
cnn_padding = (cnn_kernel_size -1)//2
out_channels_1_head = head_cnn_channels
out_channels_2_head = head_cnn_channels
out_channels_head = head_cnn_channels
self.head_cnn = nn.Sequential(
nn.Conv1d(in_channels = in_channels_head, out_channels=out_channels_1_head, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
nn.LayerNorm([out_channels_1_head, self.input_n]),
nn.Tanh(),
nn.Conv1d(in_channels=out_channels_1_head, out_channels=out_channels_2_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
nn.LayerNorm([out_channels_2_head, self.input_n]),
nn.Tanh(),
nn.Conv1d(in_channels=out_channels_2_head, out_channels=out_channels_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
nn.Tanh()
)
# GCN for extracting features from body and left hand joints
self.left_hand_gcn = graph_convolution_network.graph_convolution_network(in_features=3,
latent_features=gcn_latent_features,
node_n=self.joint_number,
seq_len=self.input_n,
p_dropout=gcn_dropout,
residual_gcns_num=residual_gcns_num)
# GCN for extracting features from body and right hand joints
self.right_hand_gcn = graph_convolution_network.graph_convolution_network(in_features=3,
latent_features=gcn_latent_features,
node_n=self.joint_number,
seq_len=self.input_n,
p_dropout=gcn_dropout,
residual_gcns_num=residual_gcns_num)
# 1D CNN for recognising attended hand (left or right)
in_channels_recognition = self.joint_number*gcn_latent_features*2 + out_channels_head
cnn_kernel_size = 3
cnn_padding = (cnn_kernel_size -1)//2
out_channels_1_recognition = recognition_cnn_channels
out_channels_recognition = 2
self.recognition_cnn = nn.Sequential(
nn.Conv1d(in_channels = in_channels_recognition, out_channels=out_channels_1_recognition, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
nn.LayerNorm([out_channels_1_recognition, self.input_n]),
nn.Tanh(),
nn.Conv1d(in_channels=out_channels_1_recognition, out_channels=out_channels_recognition, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
)
def forward(self, src, input_n=15):
bs, seq_len, features = src.shape
body_joints = src.clone()[:, :, :self.body_joint_number*3]
left_hand_joints = src.clone()[:, :, self.body_joint_number*3:(self.body_joint_number+self.hand_joint_number)*3]
right_hand_joints = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number)*3:(self.body_joint_number+self.hand_joint_number*2)*3]
head_direction = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2)*3:(self.body_joint_number+self.hand_joint_number*2+1)*3]
left_hand_joints = torch.cat((left_hand_joints, body_joints), dim=2)
left_hand_joints = left_hand_joints.permute(0, 2, 1).reshape(bs, -1, 3, input_n).permute(0, 2, 1, 3)
left_hand_features = self.left_hand_gcn(left_hand_joints)
left_hand_features = left_hand_features.permute(0, 2, 1, 3).reshape(bs, -1, input_n)
right_hand_joints = torch.cat((right_hand_joints, body_joints), dim=2)
right_hand_joints = right_hand_joints.permute(0, 2, 1).reshape(bs, -1, 3, input_n).permute(0, 2, 1, 3)
right_hand_features = self.right_hand_gcn(right_hand_joints)
right_hand_features = right_hand_features.permute(0, 2, 1, 3).reshape(bs, -1, input_n)
head_direction = head_direction.permute(0,2,1)
head_features = self.head_cnn(head_direction)
# fuse head and hand features
features = torch.cat((left_hand_features, right_hand_features), dim=1)
features = torch.cat((features, head_features), dim=1)
# recognise attended hand from fused features
prediction = self.recognition_cnn(features).permute(0, 2, 1)
return prediction
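A minimal shape sketch of driving this module, assuming the repository root is on `PYTHONPATH`; the `SimpleNamespace` stand-in for the parsed options and the concrete hyper-parameter values (chosen in the spirit of `utils/opt.py` and `train_adt.sh`) are illustrative assumptions.

```python
import torch
from types import SimpleNamespace
from model import attended_hand_recognition

# illustrative hyper-parameters, mirroring utils/opt.py and train_adt.sh
opt = SimpleNamespace(body_joint_number=3, hand_joint_number=1, seq_len=15,
                      gcn_latent_features=8, residual_gcns_num=2, gcn_dropout=0.3,
                      head_cnn_channels=32, recognition_cnn_channels=64)
model = attended_hand_recognition.attended_hand_recognition(opt)

# per-frame layout read off forward(): body joints, left-hand joints,
# right-hand joints (3 values each), then the 3D head direction
features = (opt.body_joint_number + 2 * opt.hand_joint_number + 1) * 3
src = torch.randn(4, opt.seq_len, features)   # (batch, seq_len, features)
logits = model(src, input_n=opt.seq_len)      # (batch, seq_len, 2) left/right scores
print(logits.shape)
```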

140
model/gaze_estimation.py Normal file
View file

@ -0,0 +1,140 @@
from torch import nn
import torch
from model import graph_convolution_network, transformer
import torch.nn.functional as F
class gaze_estimation(nn.Module):
def __init__(self, opt):
super().__init__()
self.opt = opt
self.body_joint_number = opt.body_joint_number
self.hand_joint_number = opt.hand_joint_number
self.input_n = opt.seq_len
self.object_num = opt.object_num
gcn_latent_features = opt.gcn_latent_features
residual_gcns_num = opt.residual_gcns_num
gcn_dropout = opt.gcn_dropout
head_cnn_channels = opt.head_cnn_channels
gaze_cnn_channels = opt.gaze_cnn_channels
self.use_self_att = opt.use_self_att
self_att_head_num = opt.self_att_head_num
self_att_dropout = opt.self_att_dropout
self.use_cross_att = opt.use_cross_att
cross_att_head_num = opt.cross_att_head_num
cross_att_dropout = opt.cross_att_dropout
self.use_attended_hand = opt.use_attended_hand
self.use_attended_hand_gt = opt.use_attended_hand_gt
if self.use_attended_hand:
self.joint_number = self.body_joint_number + self.hand_joint_number + self.object_num
else:
self.joint_number = self.body_joint_number + self.hand_joint_number*2 + self.object_num*2
# 1D CNN for extracting features from head directions
in_channels_head = 3
cnn_kernel_size = 3
cnn_padding = (cnn_kernel_size -1)//2
out_channels_1_head = head_cnn_channels
out_channels_2_head = head_cnn_channels
out_channels_head = head_cnn_channels
self.head_cnn = nn.Sequential(
nn.Conv1d(in_channels = in_channels_head, out_channels=out_channels_1_head, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
nn.LayerNorm([out_channels_1_head, self.input_n]),
nn.Tanh(),
nn.Conv1d(in_channels=out_channels_1_head, out_channels=out_channels_2_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
nn.LayerNorm([out_channels_2_head, self.input_n]),
nn.Tanh(),
nn.Conv1d(in_channels=out_channels_2_head, out_channels=out_channels_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
nn.Tanh()
)
# GCN for extracting features from hand joints, body joints, and scene objects
self.hand_gcn = graph_convolution_network.graph_convolution_network(in_features=3,
latent_features=gcn_latent_features,
node_n=self.joint_number,
seq_len=self.input_n,
p_dropout=gcn_dropout,
residual_gcns_num=residual_gcns_num)
if self.use_self_att:
self.head_self_att = transformer.temporal_self_attention(out_channels_head, self_att_head_num, self_att_dropout)
self.hand_self_att = transformer.temporal_self_attention(self.joint_number*gcn_latent_features, self_att_head_num, self_att_dropout)
if self.use_cross_att:
self.head_hand_cross_att = transformer.temporal_cross_attention(out_channels_head, self.joint_number*gcn_latent_features, cross_att_head_num, cross_att_dropout)
self.hand_head_cross_att = transformer.temporal_cross_attention(self.joint_number*gcn_latent_features, out_channels_head, cross_att_head_num, cross_att_dropout)
# 1D CNN for estimating eye gaze
in_channels_gaze = self.joint_number*gcn_latent_features + out_channels_head
cnn_kernel_size = 3
cnn_padding = (cnn_kernel_size -1)//2
out_channels_1_gaze = gaze_cnn_channels
out_channels_gaze = 3
self.gaze_cnn = nn.Sequential(
nn.Conv1d(in_channels = in_channels_gaze, out_channels=out_channels_1_gaze, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
nn.LayerNorm([out_channels_1_gaze, self.input_n]),
nn.Tanh(),
nn.Conv1d(in_channels=out_channels_1_gaze, out_channels=out_channels_gaze, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
nn.Tanh()
)
def forward(self, src, input_n=15):
bs, seq_len, features = src.shape
body_joints = src.clone()[:, :, :self.body_joint_number*3]
left_hand_joints = src.clone()[:, :, self.body_joint_number*3:(self.body_joint_number+self.hand_joint_number)*3]
right_hand_joints = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number)*3:(self.body_joint_number+self.hand_joint_number*2)*3]
head_direction = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2)*3:(self.body_joint_number+self.hand_joint_number*2+1)*3]
if self.object_num > 0:
left_object_position = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1)*3:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num)*3]
left_object_position = torch.mean(left_object_position.reshape(bs, seq_len, self.object_num, 8, 3), dim=3).reshape(bs, seq_len, self.object_num*3)
right_object_position = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num)*3:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3]
right_object_position = torch.mean(right_object_position.reshape(bs, seq_len, self.object_num, 8, 3), dim=3).reshape(bs, seq_len, self.object_num*3)
attended_hand_prd = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3+2]
left_hand_weights = torch.round(attended_hand_prd[:, :, 0:1])
right_hand_weights = torch.round(attended_hand_prd[:, :, 1:2])
if self.use_attended_hand_gt:
attended_hand_gt = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3+2:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3+3]
left_hand_weights = 1-attended_hand_gt
right_hand_weights = attended_hand_gt
if self.use_attended_hand:
hand_joints = left_hand_joints*left_hand_weights + right_hand_joints*right_hand_weights
else:
hand_joints = torch.cat((left_hand_joints, right_hand_joints), dim=2)
hand_joints = torch.cat((hand_joints, body_joints), dim=2)
if self.object_num > 0:
if self.use_attended_hand:
object_position = left_object_position*left_hand_weights + right_object_position*right_hand_weights
else:
object_position = torch.cat((left_object_position, right_object_position), dim=2)
hand_joints = torch.cat((hand_joints, object_position), dim=2)
hand_joints = hand_joints.permute(0, 2, 1).reshape(bs, -1, 3, input_n).permute(0, 2, 1, 3)
hand_features = self.hand_gcn(hand_joints)
hand_features = hand_features.permute(0, 2, 1, 3).reshape(bs, -1, input_n)
head_direction = head_direction.permute(0,2,1)
head_features = self.head_cnn(head_direction)
if self.use_self_att:
head_features = self.head_self_att(head_features.permute(0,2,1)).permute(0,2,1)
hand_features = self.hand_self_att(hand_features.permute(0,2,1)).permute(0,2,1)
if self.use_cross_att:
head_features_copy = head_features.clone()
head_features = self.head_hand_cross_att(head_features.permute(0,2,1), hand_features.permute(0,2,1)).permute(0,2,1)
hand_features = self.hand_head_cross_att(hand_features.permute(0,2,1), head_features_copy.permute(0,2,1)).permute(0,2,1)
# fuse head and hand features
features = torch.cat((hand_features, head_features), dim=1)
# estimate eye gaze
prediction = self.gaze_cnn(features).permute(0, 2, 1)
# normalize to unit vectors
prediction = F.normalize(prediction, dim=2)
return prediction
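Analogously, a hedged shape sketch for the gaze estimator; the per-frame feature layout below is read off the slicing in `forward()`, and the `SimpleNamespace` options object with these particular values is an illustrative assumption.

```python
import torch
from types import SimpleNamespace
from model import gaze_estimation

# illustrative hyper-parameters in the spirit of utils/opt.py and train_adt.sh
opt = SimpleNamespace(body_joint_number=3, hand_joint_number=1, seq_len=15, object_num=1,
                      gcn_latent_features=8, residual_gcns_num=4, gcn_dropout=0.3,
                      head_cnn_channels=32, gaze_cnn_channels=64,
                      use_self_att=1, self_att_head_num=1, self_att_dropout=0.1,
                      use_cross_att=1, cross_att_head_num=1, cross_att_dropout=0.1,
                      use_attended_hand=1, use_attended_hand_gt=1)
model = gaze_estimation.gaze_estimation(opt)

# per-frame layout: body, left hand, right hand, head direction (3 values each),
# 8-corner bounding boxes of the object(s) closest to each hand (24 values per
# object and hand), 2 attended-hand logits, and 1 attended-hand ground-truth label
features = (opt.body_joint_number + 2 * opt.hand_joint_number + 1
            + 16 * opt.object_num) * 3 + 3
gaze = model(torch.randn(4, opt.seq_len, features), input_n=opt.seq_len)
print(gaze.shape)   # (4, 15, 3), unit gaze direction per frame
```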

View file

@ -0,0 +1,79 @@
import torch.nn as nn
import torch
from torch.nn.parameter import Parameter
import math
class graph_convolution(nn.Module):
def __init__(self, in_features, out_features, node_n = 21, seq_len = 40, bias=True):
super().__init__()
self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len))
self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features))
self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n))
if bias:
self.bias = Parameter(torch.FloatTensor(seq_len))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.spatial_graph_weights.size(1))
self.feature_weights.data.uniform_(-stdv, stdv)
self.temporal_graph_weights.data.uniform_(-stdv, stdv)
self.spatial_graph_weights.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, input):
y = torch.matmul(input, self.temporal_graph_weights)
y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights)
y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous()
if self.bias is not None:
return (y + self.bias)
else:
return y
class residual_graph_convolution(nn.Module):
def __init__(self, features, node_n=21, seq_len = 40, bias=True, p_dropout=0.3):
super().__init__()
self.gcn = graph_convolution(features, features, node_n=node_n, seq_len=seq_len, bias=bias)
self.ln = nn.LayerNorm([features, node_n, seq_len], elementwise_affine=True)
self.act_f = nn.Tanh()
self.dropout = nn.Dropout(p_dropout)
def forward(self, x):
y = self.gcn(x)
y = self.ln(y)
y = self.act_f(y)
y = self.dropout(y)
return y + x
class graph_convolution_network(nn.Module):
def __init__(self, in_features, latent_features, node_n=21, seq_len=40, p_dropout=0.3, residual_gcns_num=1):
super().__init__()
self.residual_gcns_num = residual_gcns_num
self.seq_len = seq_len
self.start_gcn = graph_convolution(in_features=in_features, out_features=latent_features, node_n=node_n, seq_len=seq_len)
self.residual_gcns = []
for i in range(residual_gcns_num):
self.residual_gcns.append(residual_graph_convolution(features=latent_features, node_n=node_n, seq_len=seq_len*2, p_dropout=p_dropout))
self.residual_gcns = nn.ModuleList(self.residual_gcns)
def forward(self, x):
y = self.start_gcn(x)
y = torch.cat((y, y), dim=3)
for i in range(self.residual_gcns_num):
y = self.residual_gcns[i](y)
y = y[:, :, :, :self.seq_len]
return y
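A quick shape check of the network above (a sketch, assuming the repository root is on `PYTHONPATH`): 21 graph nodes with 3D coordinates over 15 frames are lifted to 8 latent features per node while the temporal length is preserved.

```python
import torch
from model import graph_convolution_network

gcn = graph_convolution_network.graph_convolution_network(
    in_features=3, latent_features=8, node_n=21, seq_len=15,
    p_dropout=0.3, residual_gcns_num=4)
x = torch.randn(2, 3, 21, 15)   # (batch, features, nodes, frames)
y = gcn(x)
print(y.shape)                  # (2, 8, 21, 15)
```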

138
model/transformer.py Normal file
View file

@ -0,0 +1,138 @@
import torch
import torch.nn.functional as F
from torch import layer_norm, nn
import math
class temporal_self_attention(nn.Module):
def __init__(self, latent_dim, num_head, dropout):
super().__init__()
self.num_head = num_head
self.norm = nn.LayerNorm(latent_dim)
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
self.key = nn.Linear(latent_dim, latent_dim, bias=False)
self.value = nn.Linear(latent_dim, latent_dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
x: B, T, D
"""
B, T, D = x.shape
H = self.num_head
# B, T, 1, D
query = self.query(self.norm(x)).unsqueeze(2)
# B, 1, T, D
key = self.key(self.norm(x)).unsqueeze(1)
query = query.view(B, T, H, -1)
key = key.view(B, T, H, -1)
# B, T, T, H
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
weight = self.dropout(F.softmax(attention, dim=2))
value = self.value(self.norm(x)).view(B, T, H, -1)
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
y = x + y
return y
class spatial_self_attention(nn.Module):
def __init__(self, latent_dim, num_head, dropout):
super().__init__()
self.num_head = num_head
self.norm = nn.LayerNorm(latent_dim)
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
self.key = nn.Linear(latent_dim, latent_dim, bias=False)
self.value = nn.Linear(latent_dim, latent_dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
x: B, S, D
"""
B, S, D = x.shape
H = self.num_head
# B, S, 1, D
query = self.query(self.norm(x)).unsqueeze(2)
# B, 1, S, D
key = self.key(self.norm(x)).unsqueeze(1)
query = query.view(B, S, H, -1)
key = key.view(B, S, H, -1)
# B, S, S, H
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
weight = self.dropout(F.softmax(attention, dim=2))
value = self.value(self.norm(x)).view(B, S, H, -1)
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, S, D)
y = x + y
return y
class temporal_cross_attention(nn.Module):
def __init__(self, latent_dim, mod_dim, num_head, dropout):
super().__init__()
self.num_head = num_head
self.norm = nn.LayerNorm(latent_dim)
self.mod_norm = nn.LayerNorm(mod_dim)
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
self.key = nn.Linear(mod_dim, latent_dim, bias=False)
self.value = nn.Linear(mod_dim, latent_dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, xf):
"""
x: B, T, D
xf: B, N, L
"""
B, T, D = x.shape
N = xf.shape[1]
H = self.num_head
# B, T, 1, D
query = self.query(self.norm(x)).unsqueeze(2)
# B, 1, N, D
key = self.key(self.mod_norm(xf)).unsqueeze(1)
query = query.view(B, T, H, -1)
key = key.view(B, N, H, -1)
# B, T, N, H
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
weight = self.dropout(F.softmax(attention, dim=2))
value = self.value(self.mod_norm(xf)).view(B, N, H, -1)
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
y = x + y
return y
class spatial_cross_attention(nn.Module):
def __init__(self, latent_dim, mod_dim, num_head, dropout):
super().__init__()
self.num_head = num_head
self.norm = nn.LayerNorm(latent_dim)
self.mod_norm = nn.LayerNorm(mod_dim)
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
self.key = nn.Linear(mod_dim, latent_dim, bias=False)
self.value = nn.Linear(mod_dim, latent_dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, xf):
"""
x: B, S, D
xf: B, N, L
"""
B, S, D = x.shape
N = xf.shape[1]
H = self.num_head
# B, S, 1, D
query = self.query(self.norm(x)).unsqueeze(2)
# B, 1, N, D
key = self.key(self.mod_norm(xf)).unsqueeze(1)
query = query.view(B, S, H, -1)
key = key.view(B, N, H, -1)
# B, S, N, H
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
weight = self.dropout(F.softmax(attention, dim=2))
value = self.value(self.mod_norm(xf)).view(B, N, H, -1)
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, S, D)
y = x + y
return y
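A brief sketch of how the temporal cross-attention is wired in `gaze_estimation.py`: head features attend to hand-object features over time (and vice versa), with a residual connection back onto the query stream. The feature widths 32 and 40 below are chosen purely for illustration.

```python
import torch
from model import transformer

B, T = 4, 15
head_feat = torch.randn(B, T, 32)   # e.g. head-CNN features
hand_feat = torch.randn(B, T, 40)   # e.g. hand-object GCN features
head_to_hand = transformer.temporal_cross_attention(latent_dim=32, mod_dim=40,
                                                    num_head=1, dropout=0.1)
fused_head = head_to_hand(head_feat, hand_feat)   # (B, T, 32), residual added
print(fused_head.shape)
```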

7
train_adt.sh Normal file
View file

@ -0,0 +1,7 @@
python attended_hand_recognition_adt.py --data_dir /scratch/hu/pose_forecast/adt_hoigaze/ --ckpt ./checkpoints/adt/ --cuda_idx cuda:6 --seq_len 15 --sample_rate 2 --residual_gcns_num 2 --gamma 0.95 --learning_rate 0.005 --epoch 60 --object_num 1 --hand_joint_number 1;
python attended_hand_recognition_adt.py --data_dir /scratch/hu/pose_forecast/adt_hoigaze/ --ckpt ./checkpoints/adt/ --cuda_idx cuda:6 --seq_len 15 --sample_rate 2 --residual_gcns_num 2 --gamma 0.95 --learning_rate 0.005 --epoch 60 --object_num 1 --hand_joint_number 1 --is_eval --save_predictions;
python gaze_estimation_adt.py --data_dir /scratch/hu/pose_forecast/adt_hoigaze/ --ckpt ./checkpoints/adt/ --cuda_idx cuda:6 --seq_len 15 --residual_gcns_num 4 --gamma 0.8 --learning_rate 0.005 --epoch 80 --object_num 1 --hand_joint_number 1;
python gaze_estimation_adt.py --data_dir /scratch/hu/pose_forecast/adt_hoigaze/ --ckpt ./checkpoints/adt/ --cuda_idx cuda:6 --seq_len 15 --residual_gcns_num 4 --gamma 0.8 --learning_rate 0.005 --epoch 80 --object_num 1 --hand_joint_number 1 --is_eval;

7
train_hot3d_scene1.sh Normal file
View file

@ -0,0 +1,7 @@
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --actions 'room' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0;
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --actions 'room' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0 --is_eval --save_predictions;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --actions 'room' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --actions 'room' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;

7
train_hot3d_scene2.sh Normal file
View file

@ -0,0 +1,7 @@
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --actions 'kitchen' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0;
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --actions 'kitchen' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0 --is_eval --save_predictions;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --actions 'kitchen' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --actions 'kitchen' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;

7
train_hot3d_scene3.sh Normal file
View file

@ -0,0 +1,7 @@
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --actions 'office' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0;
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --actions 'office' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0 --is_eval --save_predictions;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --actions 'office' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --actions 'office' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;

7
train_hot3d_user1.sh Normal file
View file

@ -0,0 +1,7 @@
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --test_user_id 1 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05;
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --test_user_id 1 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05 --is_eval --save_predictions;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --test_user_id 1 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --test_user_id 1 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;

7
train_hot3d_user2.sh Normal file
View file

@ -0,0 +1,7 @@
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --test_user_id 2 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05;
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --test_user_id 2 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05 --is_eval --save_predictions;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --test_user_id 2 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --test_user_id 2 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;

7
train_hot3d_user3.sh Normal file
View file

@ -0,0 +1,7 @@
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --test_user_id 3 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05;
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --test_user_id 3 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05 --is_eval --save_predictions;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --test_user_id 3 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --test_user_id 3 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;

1
utils/__init__.py Normal file
View file

@ -0,0 +1 @@
from utils import *

171
utils/adt_dataset.py Normal file
View file

@ -0,0 +1,171 @@
from torch.utils.data import Dataset
import numpy as np
import os
class adt_dataset(Dataset):
def __init__(self, data_dir, seq_len, actions = 'all', train_flag = 1, object_num=1, hand_joint_number=1, sample_rate=1):
actions = self.define_actions(actions)
self.sample_rate = sample_rate
if train_flag == 1:
data_dir = data_dir + 'train/'
if train_flag == 0:
data_dir = data_dir + 'test/'
self.dataset = self.load_data(data_dir, seq_len, actions, object_num, hand_joint_number)
def define_actions(self, action):
"""
Define the list of actions we are using.
Args
action: String with the passed action. Could be "all"
Returns
actions: List of strings of actions
Raises
ValueError if the action is not included.
"""
actions = ['work', 'decoration', 'meal']
if action in actions:
return [action]
if action == "all":
return actions
raise ValueError("Unrecognised action: %s" % action)
def load_data(self, data_dir, seq_len, actions, object_num, hand_joint_number):
action_number = len(actions)
dataset = []
file_names = sorted(os.listdir(data_dir))
gaze_file_names = {}
hand_file_names = {}
hand_joint_file_names = {}
head_file_names = {}
object_left_file_names = {}
object_right_file_names = {}
for action_idx in np.arange(action_number):
gaze_file_names[actions[ action_idx ]] = []
hand_file_names[actions[ action_idx ]] = []
hand_joint_file_names[actions[ action_idx ]] = []
head_file_names[actions[ action_idx ]] = []
object_left_file_names[actions[ action_idx ]] = []
object_right_file_names[actions[ action_idx ]] = []
for name in file_names:
name_split = name.split('_')
action = name_split[2]
if action in actions:
data_type = name_split[-1][:-4]
if(data_type == 'gaze'):
gaze_file_names[action].append(name)
if(data_type == 'hand'):
hand_file_names[action].append(name)
if(data_type == 'handjoints'):
hand_joint_file_names[action].append(name)
if(data_type == 'head'):
head_file_names[action].append(name)
if(data_type == 'bbxleft'):
object_left_file_names[action].append(name)
if(data_type == 'bbxright'):
object_right_file_names[action].append(name)
for action_idx in np.arange(action_number):
action = actions[ action_idx ]
segments_number = len(gaze_file_names[action])
print("Reading action {}, segments number {}".format(action, segments_number))
for i in range(segments_number):
gaze_data_path = data_dir + gaze_file_names[action][i]
gaze_data = np.load(gaze_data_path)
gaze_direction = gaze_data[:, :3]
num_frames = gaze_data.shape[0]
if num_frames < seq_len:
continue
hand_data_path = data_dir + hand_file_names[action][i]
hand_translation = np.load(hand_data_path)
hand_joint_data_path = data_dir + hand_joint_file_names[action][i]
hand_joint_data_all = np.load(hand_joint_data_path)
hand_joint_number_default = 15
hand_joint_data = hand_joint_data_all[:, :hand_joint_number_default*6]
left_hand_center = np.mean(hand_joint_data[:, :hand_joint_number_default*3].reshape(hand_joint_data.shape[0], hand_joint_number_default, 3), axis=1)
right_hand_center = np.mean(hand_joint_data[:, hand_joint_number_default*3:].reshape(hand_joint_data.shape[0], hand_joint_number_default, 3), axis=1)
if hand_joint_number == 1:
hand_joint_data = np.concatenate((left_hand_center, right_hand_center), axis=1)
attended_hand_gt = hand_joint_data_all[:, hand_joint_number_default*6:hand_joint_number_default*6+1]
attended_hand_baseline = hand_joint_data_all[:, hand_joint_number_default*6+1:hand_joint_number_default*6+2]
head_data_path = data_dir + head_file_names[action][i]
head_data = np.load(head_data_path)
head_direction = head_data[:, :3]
head_translation = head_data[:, 3:]
object_left_data_path = data_dir + object_left_file_names[action][i]
object_left_data = np.load(object_left_data_path)
object_left_data = object_left_data.reshape(object_left_data.shape[0], -1)
object_right_data_path = data_dir + object_right_file_names[action][i]
object_right_data = np.load(object_right_data_path)
object_right_data = object_right_data.reshape(object_right_data.shape[0], -1)
object_left_bbx = []
object_right_bbx = []
for item in range(object_num):
left_bbx = object_left_data[:, item*24:item*24+24]
right_bbx = object_right_data[:, item*24:item*24+24]
if len(object_left_bbx) == 0:
object_left_bbx = left_bbx
object_right_bbx = right_bbx
else:
object_left_bbx = np.concatenate((object_left_bbx, left_bbx), axis=1)
object_right_bbx = np.concatenate((object_right_bbx, right_bbx), axis=1)
#object_left_positions = np.mean(object_left_bbx.reshape(num_frames, object_num, 8, 3), axis=2).reshape(num_frames, -1)
#object_right_positions = np.mean(object_right_bbx.reshape(num_frames, object_num, 8, 3), axis=2).reshape(num_frames, -1)
data = gaze_direction
data = np.concatenate((data, hand_translation), axis=1)
data = np.concatenate((data, head_translation), axis=1)
data = np.concatenate((data, hand_joint_data), axis=1)
data = np.concatenate((data, head_direction), axis=1)
if object_num > 0:
data = np.concatenate((data, object_left_bbx), axis=1)
data = np.concatenate((data, object_right_bbx), axis=1)
data = np.concatenate((data, attended_hand_gt), axis=1)
data = np.concatenate((data, attended_hand_baseline), axis=1)
fs = np.arange(0, num_frames - seq_len + 1)
fs_sel = fs
for i in np.arange(seq_len - 1):
fs_sel = np.vstack((fs_sel, fs + i + 1))
fs_sel = fs_sel.transpose()
#print(fs_sel)
seq_sel = data[fs_sel, :]
seq_sel = seq_sel[0::self.sample_rate, :, :]
#print(seq_sel.shape)
if len(dataset) == 0:
dataset = seq_sel
else:
dataset = np.concatenate((dataset, seq_sel), axis=0)
return dataset
def __len__(self):
return np.shape(self.dataset)[0]
def __getitem__(self, item):
return self.dataset[item]
if __name__ == "__main__":
data_dir = "/scratch/hu/pose_forecast/adt_hoigaze/"
seq_len = 15
actions = 'all'
sample_rate = 1
train_flag = 1
object_num = 1
hand_joint_number = 1
train_dataset = adt_dataset(data_dir, seq_len, actions, train_flag, object_num, hand_joint_number, sample_rate)
print("Training data size: {}".format(train_dataset.dataset.shape))
hand_joint_dominance = train_dataset[:, :, -2:-1].flatten()
print("right hand ratio: {:.2f}".format(np.sum(hand_joint_dominance)/hand_joint_dominance.shape[0]*100))

137
utils/hot3d_aria_dataset.py Normal file
View file

@ -0,0 +1,137 @@
from torch.utils.data import Dataset
import numpy as np
import os
class hot3d_aria_dataset(Dataset):
def __init__(self, data_dir, subjects, seq_len, actions = 'all', object_num=1, sample_rate=1):
if actions == 'all':
actions = ['room', 'kitchen', 'office']
self.sample_rate = sample_rate
self.dataset = self.load_data(data_dir, subjects, seq_len, actions, object_num)
def load_data(self, data_dir, subjects, seq_len, actions, object_num):
dataset = []
file_names = sorted(os.listdir(data_dir))
gaze_file_names = []
hand_file_names = []
hand_joint_file_names = []
head_file_names = []
object_left_file_names = []
object_right_file_names = []
for name in file_names:
name_split = name.split('_')
subject = name_split[0]
action = name_split[2]
if subject in subjects and action in actions:
data_type = name_split[-1][:-4]
if(data_type == 'gaze'):
gaze_file_names.append(name)
if(data_type == 'hand'):
hand_file_names.append(name)
if(data_type == 'handjoints'):
hand_joint_file_names.append(name)
if(data_type == 'head'):
head_file_names.append(name)
if(data_type == 'bbxleft'):
object_left_file_names.append(name)
if(data_type == 'bbxright'):
object_right_file_names.append(name)
segments_number = len(hand_file_names)
# print("segments number {}".format(segments_number))
for i in range(segments_number):
gaze_data_path = data_dir + gaze_file_names[i]
gaze_data = np.load(gaze_data_path)
num_frames = gaze_data.shape[0]
if num_frames < seq_len:
continue
hand_data_path = data_dir + hand_file_names[i]
hand_data = np.load(hand_data_path)
hand_joint_data_path = data_dir + hand_joint_file_names[i]
hand_joint_data_all = np.load(hand_joint_data_path)
hand_joint_data = hand_joint_data_all[:, :120]
attended_hand_gt = hand_joint_data_all[:, 120:121]
attended_hand_baseline = hand_joint_data_all[:, 121:122]
head_data_path = data_dir + head_file_names[i]
head_data = np.load(head_data_path)
object_left_data_path = data_dir + object_left_file_names[i]
object_left_data = np.load(object_left_data_path)
object_right_data_path = data_dir + object_right_file_names[i]
object_right_data = np.load(object_right_data_path)
left_hand_translation = hand_data[:, 0:3]
right_hand_translation = hand_data[:, 22:25]
head_direction = head_data[:, 0:3]
head_translation = head_data[:, 3:6]
gaze_direction = gaze_data[:, 0:3]
object_left_bbx = []
object_right_bbx = []
for item in range(object_num):
left_bbx = object_left_data[:, item*24:item*24+24]
right_bbx = object_right_data[:, item*24:item*24+24]
if len(object_left_bbx) == 0:
object_left_bbx = left_bbx
object_right_bbx = right_bbx
else:
object_left_bbx = np.concatenate((object_left_bbx, left_bbx), axis=1)
object_right_bbx = np.concatenate((object_right_bbx, right_bbx), axis=1)
#object_left_positions = np.mean(object_left_bbx.reshape(num_frames, object_num, 8, 3), axis=2).reshape(num_frames, -1)
#object_right_positions = np.mean(object_right_bbx.reshape(num_frames, object_num, 8, 3), axis=2).reshape(num_frames, -1)
data = gaze_direction
data = np.concatenate((data, left_hand_translation), axis=1)
data = np.concatenate((data, right_hand_translation), axis=1)
data = np.concatenate((data, head_translation), axis=1)
data = np.concatenate((data, hand_joint_data), axis=1)
data = np.concatenate((data, head_direction), axis=1)
if object_num > 0:
data = np.concatenate((data, object_left_bbx), axis=1)
data = np.concatenate((data, object_right_bbx), axis=1)
data = np.concatenate((data, attended_hand_gt), axis=1)
data = np.concatenate((data, attended_hand_baseline), axis=1)
fs = np.arange(0, num_frames - seq_len + 1)
fs_sel = fs
for i in np.arange(seq_len - 1):
fs_sel = np.vstack((fs_sel, fs + i + 1))
fs_sel = fs_sel.transpose()
seq_sel = data[fs_sel, :]
seq_sel = seq_sel[0::self.sample_rate, :, :]
if len(dataset) == 0:
dataset = seq_sel
else:
dataset = np.concatenate((dataset, seq_sel), axis=0)
return dataset
def __len__(self):
return np.shape(self.dataset)[0]
def __getitem__(self, item):
return self.dataset[item]
if __name__ == "__main__":
data_dir = "/scratch/hu/pose_forecast/hot3d_hoigaze/"
seq_len = 15
actions = 'all'
all_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
object_num = 1
sample_rate = 10
train_dataset = hot3d_aria_dataset(data_dir, train_subjects, seq_len, actions, object_num, sample_rate)
print("Training data size: {}".format(train_dataset.dataset.shape))
hand_joint_dominance = train_dataset[:, :, -2:-1].flatten()
print("right hand ratio: {:.2f}".format(np.sum(hand_joint_dominance)/hand_joint_dominance.shape[0]*100))
#test_subjects = ['P0001', 'P0002', 'P0003']
#sample_rate = 8
#test_dataset = hot3d_aria_dataset(data_dir, test_subjects, seq_len, actions, object_num, sample_rate)
# print("Test data size: {}".format(test_dataset.dataset.shape))
#hand_joint_dominance = test_dataset[:, :, -2:-1].flatten()
#print("right hand ratio: {:.2f}".format(np.sum(hand_joint_dominance)/hand_joint_dominance.shape[0]*100))

View file

@ -0,0 +1,91 @@
from torch.utils.data import Dataset
import numpy as np
import os
class hot3d_aria_dataset(Dataset):
def __init__(self, data_path, seq_len, object_num=1):
self.dataset = self.load_data(data_path, seq_len, object_num)
def load_data(self, data_path, seq_len, object_num):
dataset = []
gaze_file_name = data_path + 'gaze.npy'
hand_file_name = data_path + 'hand.npy'
hand_joint_file_name = data_path + 'handjoints.npy'
head_file_name = data_path + 'head.npy'
object_left_file_name = data_path + 'object_bbxleft.npy'
object_right_file_name = data_path + 'object_bbxright.npy'
gaze_data_path = gaze_file_name
gaze_data = np.load(gaze_data_path)
num_frames = gaze_data.shape[0]
hand_data_path = hand_file_name
hand_data = np.load(hand_data_path)
hand_joint_data_path = hand_joint_file_name
hand_joint_data_all = np.load(hand_joint_data_path)
hand_joint_data = hand_joint_data_all[:, :120]
attended_hand_gt = hand_joint_data_all[:, 120:121]
attended_hand_baseline = hand_joint_data_all[:, 121:122]
head_data_path = head_file_name
head_data = np.load(head_data_path)
object_left_data_path = object_left_file_name
object_left_data = np.load(object_left_data_path)
object_right_data_path = object_right_file_name
object_right_data = np.load(object_right_data_path)
left_hand_translation = hand_data[:, 0:3]
right_hand_translation = hand_data[:, 22:25]
head_direction = head_data[:, 0:3]
head_translation = head_data[:, 3:6]
gaze_direction = gaze_data[:, 0:3]
object_left_bbx = []
object_right_bbx = []
for item in range(object_num):
left_bbx = object_left_data[:, item*24:item*24+24]
right_bbx = object_right_data[:, item*24:item*24+24]
if len(object_left_bbx) == 0:
object_left_bbx = left_bbx
object_right_bbx = right_bbx
else:
object_left_bbx = np.concatenate((object_left_bbx, left_bbx), axis=1)
object_right_bbx = np.concatenate((object_right_bbx, right_bbx), axis=1)
data = gaze_direction
data = np.concatenate((data, left_hand_translation), axis=1)
data = np.concatenate((data, right_hand_translation), axis=1)
data = np.concatenate((data, head_translation), axis=1)
data = np.concatenate((data, hand_joint_data), axis=1)
data = np.concatenate((data, head_direction), axis=1)
if object_num > 0:
data = np.concatenate((data, object_left_bbx), axis=1)
data = np.concatenate((data, object_right_bbx), axis=1)
data = np.concatenate((data, attended_hand_gt), axis=1)
data = np.concatenate((data, attended_hand_baseline), axis=1)
fs = np.arange(0, num_frames - seq_len + 1)
fs_sel = fs
for i in np.arange(seq_len - 1):
fs_sel = np.vstack((fs_sel, fs + i + 1))
fs_sel = fs_sel.transpose()
seq_sel = data[fs_sel, :]
seq_sel = seq_sel[0::seq_len, :, :]
if len(dataset) == 0:
dataset = seq_sel
else:
dataset = np.concatenate((dataset, seq_sel), axis=0)
return dataset
def __len__(self):
return np.shape(self.dataset)[0]
def __getitem__(self, item):
return self.dataset[item]
if __name__ == "__main__":
data_path = '/scratch/hu/pose_forecast/hot3d_hoigaze/P0001_10a27bf7_room_721_890_'
seq_len = 15
object_num = 1
train_dataset = hot3d_aria_dataset(data_path, seq_len, object_num)
print("Training data size: {}".format(train_dataset.dataset.shape))

28
utils/log.py Normal file
View file

@ -0,0 +1,28 @@
import json
import os
import torch
import pandas as pd
import numpy as np
def save_csv_log(opt, head, value, is_create=False, file_name='results'):
if len(value.shape) < 2:
value = np.expand_dims(value, axis=0)
df = pd.DataFrame(value)
file_path = opt.ckpt + '/{}.csv'.format(file_name)
print(file_path)
if not os.path.exists(file_path) or is_create:
df.to_csv(file_path, header=head, index=False)
else:
with open(file_path, 'a') as f:
df.to_csv(f, header=False, index=False)
def save_ckpt(state, opt=None, file_name = 'model.pt'):
file_path = os.path.join(opt.ckpt, file_name)
torch.save(state, file_path)
def save_options(opt):
with open(opt.ckpt + '/options.json', 'w') as f:
f.write(json.dumps(vars(opt), sort_keys=False, indent=4))

74
utils/opt.py Normal file
View file

@ -0,0 +1,74 @@
import os
import argparse
from pprint import pprint
class options:
def __init__(self):
self.parser = argparse.ArgumentParser()
self.opt = None
def _initial(self):
# ===============================================================
# General options
# ===============================================================
self.parser.add_argument('--cuda_idx', type=str, default='cuda:0', help='cuda idx')
self.parser.add_argument('--data_dir', type=str,
default='./dataset/',
help='path to dataset')
self.parser.add_argument('--is_eval', dest='is_eval', action='store_true',
help='whether to evaluate existing models or not')
self.parser.add_argument('--ckpt', type=str, default='./checkpoints/', help='path to save checkpoints')
self.parser.add_argument('--test_user_id', type=int, default=1, help='ID of the test participant')
self.parser.add_argument('--actions', type=str, default='all', help='actions to use')
self.parser.add_argument('--sample_rate', type=int, default=2, help='rate at which to downsample the data')
self.parser.add_argument('--save_predictions', dest='save_predictions', action='store_true',
help='whether to save the prediction results or not')
# ===============================================================
# Model options
# ===============================================================
self.parser.add_argument('--body_joint_number', type=int, default=3, help='number of body joints to use')
self.parser.add_argument('--hand_joint_number', type=int, default=20, help='number of hand joints to use')
self.parser.add_argument('--head_cnn_channels', type=int, default=32, help='number of channels used in the head_CNN')
self.parser.add_argument('--gcn_latent_features', type=int, default=8, help='number of latent features used in the gcn')
self.parser.add_argument('--residual_gcns_num', type=int, default=4, help='number of residual gcns to use')
self.parser.add_argument('--gcn_dropout', type=float, default=0.3, help='drop out probability in the gcn')
self.parser.add_argument('--gaze_cnn_channels', type=int, default=64, help='number of channels used in the gaze_CNN')
self.parser.add_argument('--recognition_cnn_channels', type=int, default=64, help='number of channels used in the recognition_CNN')
self.parser.add_argument('--object_num', type=int, default=1, help='number of scene objects for gaze estimation')
self.parser.add_argument('--use_self_att', type=int, default=1, help='use self attention or not')
self.parser.add_argument('--self_att_head_num', type=int, default=1, help='number of heads used in self attention')
self.parser.add_argument('--self_att_dropout', type=float, default=0.1, help='drop out probability in self attention')
self.parser.add_argument('--use_cross_att', type=int, default=1, help='use cross attention or not')
self.parser.add_argument('--cross_att_head_num', type=int, default=1, help='number of heads used in cross attention')
self.parser.add_argument('--cross_att_dropout', type=float, default=0.1, help='drop out probability in cross attention')
self.parser.add_argument('--use_attended_hand', type=int, default=1, help='use attended hand or use both hands')
self.parser.add_argument('--use_attended_hand_gt', type=int, default=0, help='use attended hand ground truth or not')
# ===============================================================
# Running options
# ===============================================================
self.parser.add_argument('--seq_len', type=int, default=15, help='the length of the used sequence')
self.parser.add_argument('--learning_rate', type=float, default=0.005)
self.parser.add_argument('--gaze_head_loss_factor', type=float, default=4.0)
self.parser.add_argument('--gaze_head_cos_threshold', type=float, default=0.8)
self.parser.add_argument('--weight_decay', type=float, default=0.0)
self.parser.add_argument('--gamma', type=float, default=0.95, help='decay learning rate by gamma')
self.parser.add_argument('--epoch', type=int, default=50)
self.parser.add_argument('--batch_size', type=int, default=32)
self.parser.add_argument('--validation_epoch', type=int, default=10, help='interval of epochs to test')
self.parser.add_argument('--test_batch_size', type=int, default=32)
def _print(self):
print("\n==================Options=================")
pprint(vars(self.opt), indent=4)
print("==========================================\n")
def parse(self, make_dir=True):
self._initial()
self.opt = self.parser.parse_args()
ckpt = self.opt.ckpt
if make_dir==True:
if not os.path.isdir(ckpt):
os.makedirs(ckpt)
self._print()
return self.opt
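A minimal usage sketch; the training scripts presumably instantiate the parser along these lines (the exact call site is an assumption), with command-line flags as in `train_adt.sh`.

```python
from utils.opt import options

# e.g. python gaze_estimation_adt.py --data_dir ./dataset/ --ckpt ./checkpoints/adt/ \
#          --cuda_idx cuda:0 --seq_len 15 --residual_gcns_num 4 --object_num 1
opt = options().parse()   # parses sys.argv and creates opt.ckpt if it does not exist
print(opt.seq_len, opt.cuda_idx, opt.ckpt)
```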

15
utils/seed_torch.py Normal file
View file

@ -0,0 +1,15 @@
import os
import random
import numpy as np
import torch
def seed_torch(seed=0):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
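A minimal sketch of the intended use: call it once at the very start of a run, before datasets and models are built, so that results are reproducible (assuming the repository root is on `PYTHONPATH`).

```python
import torch
from utils.seed_torch import seed_torch

seed_torch(seed=0)      # fix the Python, NumPy, and PyTorch RNGs
x = torch.randn(3)      # deterministic across runs with the same seed
print(x)
```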