first commit
parent 99ce0acafb
commit 8f6b6a34e7
73 changed files with 11656 additions and 0 deletions
README.md: 46 additions (new file)
@@ -0,0 +1,46 @@
# HOIGaze: Gaze Estimation During Hand-Object Interactions in Extended Reality Exploiting Eye-Hand-Head Coordination

Project homepage: https://zhiminghu.net/hu25_hoigaze.


## Abstract

```
We present HOIGaze, a novel learning-based approach for gaze estimation during hand-object interactions (HOI) in extended reality (XR).
HOIGaze addresses the challenging HOI setting by building on one key insight: eye, hand, and head movements are closely coordinated during HOIs, and this coordination can be exploited to identify the samples that are most useful for gaze estimator training, effectively denoising the training data.
This denoising approach is in stark contrast to previous gaze estimation methods, which treated all training samples as equally important.
Specifically, we propose: 1) a novel hierarchical framework that first recognises the hand currently visually attended to and then estimates gaze direction based on the attended hand; 2) a new gaze estimator that uses cross-modal Transformers to fuse head and hand-object features extracted using a convolutional neural network and a spatio-temporal graph convolutional network; and 3) a novel eye-head coordination loss that up-weights training samples that exhibit coordinated eye-head movements.
We evaluate HOIGaze on the HOT3D and Aria Digital Twin (ADT) datasets and show that it significantly outperforms state-of-the-art methods, achieving an average improvement in mean angular error of 15.6% on HOT3D and 6.0% on ADT.
To demonstrate the potential of our method, we further report significant performance improvements for the example downstream task of eye-based activity recognition on ADT.
Taken together, our results underline the significant information content available in eye-hand-head coordination and, as such, open up an exciting new direction for learning-based gaze estimation.
```


## Environment:

Ubuntu 22.04
Python 3.8+
PyTorch 1.8.1
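
A quick sanity check of the installed versions (a minimal sketch; the exact CUDA setup depends on your machine):

```python
import torch
print(torch.__version__)           # expect 1.8.1
print(torch.cuda.is_available())   # should be True for GPU training
```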


## Usage:

Step 1: Create the environment:

```
conda env create -f ./environment/hoigaze.yaml -n hoigaze
conda activate hoigaze
```

Step 2: Follow the instructions in './adt_processing/' and './hot3d_processing/' to process the datasets.

Step 3: Set 'data_dir' and 'cuda_idx' in 'train_hot3d_userX.sh' (X = 1, 2, or 3) to evaluate on HOT3D for different users, and in 'train_hot3d_sceneX.sh' (X = 1, 2, or 3) to evaluate on HOT3D for different scenes.

Step 4: Set 'data_dir' and 'cuda_idx' in 'train_adt.sh' to evaluate on ADT.


## Citation

```bibtex
@inproceedings{hu25hoigaze,
  title={HOIGaze: Gaze Estimation During Hand-Object Interactions in Extended Reality Exploiting Eye-Hand-Head Coordination},
  author={Hu, Zhiming and Haeufle, Daniel and Schmitt, Syn and Bulling, Andreas},
  booktitle={Proceedings of the 2025 ACM Special Interest Group on Computer Graphics and Interactive Techniques},
  year={2025}}
```
adt_processing/README.md: 29 additions (new file)
@@ -0,0 +1,29 @@
## Code to process the ADT dataset

Note: processing the ADT dataset is more involved than the other datasets because it relies on the Project Aria Tools; it may be easier to get started with the other datasets first.


## Usage:

Step 1: Follow https://facebookresearch.github.io/projectaria_tools/docs/open_datasets/aria_digital_twin_dataset/dataset_download to prepare the environment and download the dataset. Please note that in our paper we used version 1.1.0 of the dataset with Project Aria Tools 1.1.0 and Python 3.8, and built the Project Aria Tools with 'python setup.py build_py'.

Step 2: Set 'dataset_path' and 'dataset_processed_path' in 'adt_preprocessing.py', put 'adt_preprocessing.py', 'adt.csv', and 'utils' into the codebase of the Project Aria Tools, and run the script to process the dataset.
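
Only the two paths at the top of 'adt_preprocessing.py' need to be edited; the values below are placeholders, not working defaults:

```python
# in adt_preprocessing.py (placeholder paths, adapt to your machine):
dataset_path = '/path/to/adt/'                    # the downloaded ADT sequences
dataset_processed_path = '/path/to/adt_hoigaze/'  # output directory for the processed .npy files
```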

Step 3: Optionally (but highly recommended), set 'data_path' in 'dataset_visualisation.py' to visualise the processed data and get familiar with the dataset.
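
For reference, a minimal sketch of loading the files that Step 2 writes (file names as produced by 'adt_preprocessing.py'; the path and sequence name are placeholders):

```python
import numpy as np

data_path = '/path/to/adt_hoigaze/test/SEQUENCE_NAME_'
gaze = np.load(data_path + 'gaze.npy')               # gaze_direction (3) + gaze_2d (2) + frame_id (1)
head = np.load(data_path + 'head.npy')               # head_direction (3) + head_translation (3)
hand = np.load(data_path + 'hand.npy')               # left- and right-hand translations (3 + 3)
hand_joints = np.load(data_path + 'handjoints.npy')  # 15 joints per hand (90) + attended-hand labels (2)
print(gaze.shape, head.shape, hand.shape, hand_joints.shape)
```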


## Citations

```bibtex
@inproceedings{hu25hoigaze,
  title={HOIGaze: Gaze Estimation During Hand-Object Interactions in Extended Reality Exploiting Eye-Hand-Head Coordination},
  author={Hu, Zhiming and Haeufle, Daniel and Schmitt, Syn and Bulling, Andreas},
  booktitle={Proceedings of the 2025 ACM Special Interest Group on Computer Graphics and Interactive Techniques},
  year={2025}}

@inproceedings{pan2023aria,
  title={Aria digital twin: A new benchmark dataset for egocentric 3d machine perception},
  author={Pan, Xiaqing and Charron, Nicholas and Yang, Yongqian and Peters, Scott and Whelan, Thomas and Kong, Chen and Parkhi, Omkar and Newcombe, Richard and Ren, Yuheng Carl},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={20133--20143},
  year={2023}}
```
adt_processing/adt.csv: 35 additions (new file)
@@ -0,0 +1,35 @@
sequence_name,training,action
Apartment_release_work_skeleton_seq132,0,work
Apartment_release_work_skeleton_seq138,0,work
Apartment_release_meal_skeleton_seq132,0,meal
Apartment_release_decoration_skeleton_seq133,0,decoration
Apartment_release_decoration_skeleton_seq139,0,decoration
Apartment_release_decoration_skeleton_seq134,0,decoration
Apartment_release_work_skeleton_seq107,0,work
Apartment_release_meal_skeleton_seq135,0,meal
Apartment_release_work_skeleton_seq135,0,work
Apartment_release_meal_skeleton_seq131,0,meal
Apartment_release_work_skeleton_seq131,1,work
Apartment_release_work_skeleton_seq109,1,work
Apartment_release_work_skeleton_seq110,1,work
Apartment_release_decoration_skeleton_seq140,1,decoration
Apartment_release_decoration_skeleton_seq137,1,decoration
Apartment_release_work_skeleton_seq136,1,work
Apartment_release_meal_skeleton_seq136,1,meal
Apartment_release_work_skeleton_seq106,1,work
Apartment_release_meal_skeleton_seq134,1,meal
Apartment_release_work_skeleton_seq134,1,work
Apartment_release_decoration_skeleton_seq135,1,decoration
Apartment_release_decoration_skeleton_seq138,1,decoration
Apartment_release_decoration_skeleton_seq132,1,decoration
Apartment_release_work_skeleton_seq139,1,work
Apartment_release_work_skeleton_seq133,1,work
Apartment_release_meal_skeleton_seq139,1,meal
Apartment_release_meal_skeleton_seq133,1,meal
Apartment_release_work_skeleton_seq140,1,work
Apartment_release_work_skeleton_seq137,1,work
Apartment_release_meal_skeleton_seq140,1,meal
Apartment_release_meal_skeleton_seq137,1,meal
Apartment_release_decoration_skeleton_seq136,1,decoration
Apartment_release_decoration_skeleton_seq131,1,decoration
Apartment_release_work_skeleton_seq108,1,work
adt_processing/adt_preprocessing.py: 272 additions (new file)
@@ -0,0 +1,272 @@
import numpy as np
import os
os.nice(5)
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import plotly.graph_objects as go
import math
from math import tan
import random
from scipy.linalg import pinv
import projectaria_tools.core.mps as mps
import shutil
import json
from PIL import Image
from utils import remake_dir
import pandas as pd
import pylab as p
from IPython.display import display
import time


from projectaria_tools import utils
from projectaria_tools.core.stream_id import StreamId
from projectaria_tools.core import calibration
from projectaria_tools.projects.adt import (
    AriaDigitalTwinDataProvider,
    AriaDigitalTwinSkeletonProvider,
    AriaDigitalTwinDataPathsProvider,
    bbox3d_to_line_coordinates,
    bbox2d_to_image_coordinates,
    utils as adt_utils,
    Aria3dPose
)


dataset_path = '/datasets/public/zhiming_datasets/adt/'
dataset_processed_path = '/scratch/hu/pose_forecast/adt_hoigaze/'

remake_dir(dataset_processed_path)
remake_dir(dataset_processed_path + "train/")
remake_dir(dataset_processed_path + "test/")
dataset_info = pd.read_csv('adt.csv')
object_num = 5  # number of extracted dynamic objects that are closest to the left or right hand


for i, seq in enumerate(dataset_info['sequence_name']):
    action = dataset_info['action'][i]
    print("\nprocessing {}th seq: {}, action: {}...".format(i+1, seq, action))
    seq_path = dataset_path + seq + '/'
    if dataset_info['training'][i] == 1:
        save_path = dataset_processed_path + 'train/' + seq + '_'
    if dataset_info['training'][i] == 0:
        save_path = dataset_processed_path + 'test/' + seq + '_'

    paths_provider = AriaDigitalTwinDataPathsProvider(seq_path)
    all_device_serials = paths_provider.get_device_serial_numbers()
    selected_device_number = 0
    data_paths = paths_provider.get_datapaths_by_device_num(selected_device_number)
    print("loading ground truth data...")
    gt_provider = AriaDigitalTwinDataProvider(data_paths)
    print("loading ground truth data done")

    stream_id = StreamId("214-1")
    img_timestamps_ns = gt_provider.get_aria_device_capture_timestamps_ns(stream_id)
    frame_num = len(img_timestamps_ns)
    print("There are {} frames".format(frame_num))

    # get all available skeletons in a sequence
    skeleton_ids = gt_provider.get_skeleton_ids()
    skeleton_info = gt_provider.get_instance_info_by_id(skeleton_ids[0])
    print("skeleton ", skeleton_info.name, " wears ", skeleton_info.associated_device_serial)

    useful_frames = []
    gaze_data = np.zeros((frame_num, 6))  # gaze_direction (3) + gaze_2d (2) + frame_id (1)
    head_data = np.zeros((frame_num, 6))  # head_direction (3) + head_translation (3)
    hand_data = np.zeros((frame_num, 6))  # left_hand_translation (3) + right_hand_translation (3)
    hand_joint_data = np.zeros((frame_num, 92))  # left_hand (15*3) + right_hand (15*3) + attended_hand_gt + attended_hand_baseline (closest_hand)
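    # column layout of hand_joint_data: 0-44 left-hand joints, 45-89 right-hand joints,
    # 90 attended-hand ground truth (0 = left, 1 = right, set below from the angular
    # distance between hand direction and gaze direction), 91 closest-hand baseline
    # (set further below from the hand-to-object distances)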
    object_all_data = []
    object_bbx_all_data = []
    object_center_all_data = []

    local_time = time.asctime(time.localtime(time.time()))
    print('\nProcessing starts at ' + local_time)
    for j in range(frame_num):
        timestamps_ns = img_timestamps_ns[j]

        skeleton_with_dt = gt_provider.get_skeleton_by_timestamp_ns(timestamps_ns, skeleton_ids[0])
        assert skeleton_with_dt.is_valid(), "skeleton is not valid"

        skeleton = skeleton_with_dt.data()
        head_translation_id = [4]
        hand_translation_id = [8, 27]
        hand_joints_id = [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42]
        hand_translation = np.array(skeleton.joints)[hand_translation_id, :].reshape(2*3)
        head_translation = np.array(skeleton.joints)[head_translation_id, :].reshape(1*3)
        hand_joints = np.array(skeleton.joints)[hand_joints_id, :].reshape(30*3)
        hand_data[j] = hand_translation
        hand_joint_data[j, :90] = hand_joints
        left_hand_joints = hand_joints[:45].reshape(15, 3)
        left_hand_center = np.mean(left_hand_joints, axis=0)
        right_hand_joints = hand_joints[45:].reshape(15, 3)
        right_hand_center = np.mean(right_hand_joints, axis=0)

        # get the Aria pose
        aria3dpose_with_dt = gt_provider.get_aria_3d_pose_by_timestamp_ns(timestamps_ns)
        if not aria3dpose_with_dt.is_valid():
            print("aria 3d pose is not available")
        aria3dpose = aria3dpose_with_dt.data()
        transform_scene_device = aria3dpose.transform_scene_device.matrix()

        # get projection function
        cam_calibration = gt_provider.get_aria_camera_calibration(stream_id)
        assert cam_calibration is not None, "no camera calibration"

        eye_gaze_with_dt = gt_provider.get_eyegaze_by_timestamp_ns(timestamps_ns)
        assert eye_gaze_with_dt.is_valid(), "Eye gaze not available"

        # project the gaze center in the CPF frame into the camera sensor plane, with multiplication performed in homogeneous coordinates
        eye_gaze = eye_gaze_with_dt.data()
        gaze_center_in_cpf = np.array([tan(eye_gaze.yaw), tan(eye_gaze.pitch), 1.0], dtype=np.float64) * eye_gaze.depth
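        # CPF = central pupil frame; the line above converts the eye-tracker output
        # (yaw/pitch angles plus a depth estimate) into a 3D gaze point by scaling the
        # direction (tan(yaw), tan(pitch), 1) with the estimated depth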
        head_center_in_cpf = np.array([0.0, 0.0, 1.0], dtype=np.float64)
        transform_cpf_sensor = gt_provider.raw_data_provider_ptr().get_device_calibration().get_transform_cpf_sensor(cam_calibration.get_label())
        gaze_center_in_camera = transform_cpf_sensor.inverse().matrix() @ np.hstack((gaze_center_in_cpf, 1)).T
        gaze_center_in_camera = gaze_center_in_camera[:3] / gaze_center_in_camera[3:]
        gaze_center_in_pixels = cam_calibration.project(gaze_center_in_camera)
        head_center_in_camera = transform_cpf_sensor.inverse().matrix() @ np.hstack((head_center_in_cpf, 0)).T
        head_center_in_camera = head_center_in_camera[:3]

        extrinsic_matrix = cam_calibration.get_transform_device_camera().matrix()
        gaze_center_in_device = (extrinsic_matrix @ np.hstack((gaze_center_in_camera, 1)))[0:3]
        gaze_center_in_scene = (transform_scene_device @ np.hstack((gaze_center_in_device, 1)))[0:3]
        head_center_in_device = (extrinsic_matrix @ np.hstack((head_center_in_camera, 0)))[0:3]
        head_center_in_scene = (transform_scene_device @ np.hstack((head_center_in_device, 0)))[0:3]

        gaze_direction = gaze_center_in_scene - head_translation
        if np.linalg.norm(gaze_direction) == 0:  # invalid data that will be filtered
            gaze_direction = np.array([0.0, 0.0, 1.0], dtype=np.float64)
        else:
            gaze_direction = [x / np.linalg.norm(gaze_direction) for x in gaze_direction]
        head_direction = head_center_in_scene
        head_direction = [x / np.linalg.norm(head_direction) for x in head_direction]
        head_data[j, 0:3] = head_direction
        head_data[j, 3:6] = head_translation

        left_hand_direction = left_hand_center - head_translation
        left_hand_direction = np.array([x / np.linalg.norm(left_hand_direction) for x in left_hand_direction])
        left_hand_distance_to_gaze = np.arccos(np.sum(gaze_direction*left_hand_direction))
        right_hand_direction = right_hand_center - head_translation
        right_hand_direction = np.array([x / np.linalg.norm(right_hand_direction) for x in right_hand_direction])
        right_hand_distance_to_gaze = np.arccos(np.sum(gaze_direction*right_hand_direction))
        if left_hand_distance_to_gaze < right_hand_distance_to_gaze:
            hand_joint_data[j, 90:91] = 0
        else:
            hand_joint_data[j, 90:91] = 1

        if gaze_center_in_pixels is not None:
            # swap the projected coordinates into (x, y) pixel order
            x_pixel = gaze_center_in_pixels[1]
            y_pixel = gaze_center_in_pixels[0]
            gaze_center_in_pixels[0] = x_pixel
            gaze_center_in_pixels[1] = y_pixel

            useful_frames.append(j)
            gaze_2d = np.divide(gaze_center_in_pixels, cam_calibration.get_image_size())

            gaze_data[j, 0:3] = gaze_direction
            gaze_data[j, 3:5] = gaze_2d
            gaze_data[j, 5:6] = j

        # get the objects
        bbox3d_with_dt = gt_provider.get_object_3d_boundingboxes_by_timestamp_ns(timestamps_ns)
        assert bbox3d_with_dt.is_valid(), "3D bounding box is not available"
        bbox3d_all = bbox3d_with_dt.data()

        object_all = []
        object_bbx_all = []
        object_center_all = []

        for obj_id in bbox3d_all:
            bbox3d = bbox3d_all[obj_id]
            aabb = bbox3d.aabb
            aabb_coords = bbox3d_to_line_coordinates(aabb)
            obb = np.zeros(shape=(len(aabb_coords), 3))
            for k in range(0, len(aabb_coords)):
                aabb_pt = aabb_coords[k]
                aabb_pt_homo = np.append(aabb_pt, [1])
                obb_pt = (bbox3d.transform_scene_object.matrix() @ aabb_pt_homo)[0:3]
                obb[k] = obb_pt
            motion_type = gt_provider.get_instance_info_by_id(obj_id).motion_type
            if str(motion_type) == 'MotionType.DYNAMIC':
                object_all.append(obb)
                bbx_idx = [0, 1, 2, 3, 5, 6, 7, 8]
                obb_bbx = obb[bbx_idx, :]
                object_bbx_all.append(obb_bbx)
                obb_center = np.mean(obb_bbx, axis=0)
                object_center_all.append(obb_center)

        object_all_data.append(object_all)
        object_bbx_all_data.append(object_bbx_all)
        object_center_all_data.append(object_center_all)

    gaze_data = gaze_data[useful_frames, :]  # useful_frames are actually continuous
    head_data = head_data[useful_frames, :]
    hand_data = hand_data[useful_frames, :]
    hand_joint_data = hand_joint_data[useful_frames, :]

    object_all_data = np.array(object_all_data)
    object_all_data = object_all_data[useful_frames, :, :, :]
    #print("Objects shape: {}".format(object_all_data.shape))
    object_bbx_all_data = np.array(object_bbx_all_data)
    object_bbx_all_data = object_bbx_all_data[useful_frames, :, :, :]
    object_center_all_data = np.array(object_center_all_data)
    object_center_all_data = object_center_all_data[useful_frames, :, :]

    # extract the closest objects to the left or right hand
    useful_frames_num = len(useful_frames)
    print("There are {} useful frames".format(useful_frames_num))
    object_num_all = object_all_data.shape[1]
    object_left_hand_data = np.zeros((useful_frames_num, object_num, 16, 3))
    object_bbx_left_hand_data = np.zeros((useful_frames_num, object_num, 8, 3))
    object_distance_to_left_hand = np.zeros((useful_frames_num, object_num_all))
    object_right_hand_data = np.zeros((useful_frames_num, object_num, 16, 3))
    object_bbx_right_hand_data = np.zeros((useful_frames_num, object_num, 8, 3))
    object_distance_to_right_hand = np.zeros((useful_frames_num, object_num_all))
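
    # the loops below rank all dynamic objects by their mean distance to the 15 joints of
    # each hand and keep the object_num nearest ones per hand (16 line-strip points and
    # 8 bounding-box corners each)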
    for j in range(useful_frames_num):
        left_hand_joints = hand_joint_data[j, :45].reshape(15, 3)
        right_hand_joints = hand_joint_data[j, 45:90].reshape(15, 3)
        for k in range(object_num_all):
            object_pos = object_center_all_data[j, k, :]
            object_distance_to_left_hand[j, k] = np.mean(np.linalg.norm(left_hand_joints-object_pos, axis=1))
            object_distance_to_right_hand[j, k] = np.mean(np.linalg.norm(right_hand_joints-object_pos, axis=1))

    for j in range(useful_frames_num):
        distance_to_left_hand = object_distance_to_left_hand[j, :]
        distance_to_left_hand_min = np.min(distance_to_left_hand)
        distance_to_right_hand = object_distance_to_right_hand[j, :]
        distance_to_right_hand_min = np.min(distance_to_right_hand)
        if distance_to_left_hand_min < distance_to_right_hand_min:
            hand_joint_data[j, 91:92] = 0
        else:
            hand_joint_data[j, 91:92] = 1

        left_hand_index = np.argsort(distance_to_left_hand)
        right_hand_index = np.argsort(distance_to_right_hand)
        for k in range(object_num):
            object_left_hand_data[j, k] = object_all_data[j, left_hand_index[k]]
            object_bbx_left_hand_data[j, k] = object_bbx_all_data[j, left_hand_index[k]]
            object_right_hand_data[j, k] = object_all_data[j, right_hand_index[k]]
            object_bbx_right_hand_data[j, k] = object_bbx_all_data[j, right_hand_index[k]]

    gaze_path = save_path + 'gaze.npy'
    head_path = save_path + 'head.npy'
    hand_path = save_path + 'hand.npy'
    hand_joint_path = save_path + 'handjoints.npy'
    object_left_hand_path = save_path + 'object_left.npy'
    object_bbx_left_hand_path = save_path + 'object_bbxleft.npy'
    object_right_hand_path = save_path + 'object_right.npy'
    object_bbx_right_hand_path = save_path + 'object_bbxright.npy'

    np.save(gaze_path, gaze_data)
    np.save(head_path, head_data)
    np.save(hand_path, hand_data)
    np.save(hand_joint_path, hand_joint_data)
    np.save(object_left_hand_path, object_left_hand_data)
    np.save(object_bbx_left_hand_path, object_bbx_left_hand_data)
    np.save(object_right_hand_path, object_right_hand_data)
    np.save(object_bbx_right_hand_path, object_bbx_right_hand_data)

local_time = time.asctime(time.localtime(time.time()))
print('\nProcessing ends at ' + local_time)
adt_processing/dataset_visualisation.py: 162 additions (new file)
@@ -0,0 +1,162 @@
# visualise data in the ADT dataset
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


# play human pose using a skeleton
class Player_Skeleton:
    def __init__(self, fps=30.0, object_num=10):

        self._fps = fps
        self.object_num = object_num
        # names of all the joints: head + left_hand + right_hand + left_hand_joints + right_hand_joints + gaze_direction + head_direction
        self._joint_names = ['Head', 'LHand', 'RHand', 'LThumb1', 'LThumb2', 'LThumb3', 'LIndex1', 'LIndex2', 'LIndex3', 'LMiddle1', 'LMiddle2', 'LMiddle3', 'LRing1', 'LRing2', 'LRing3', 'LPinky1', 'LPinky2', 'LPinky3', 'RThumb1', 'RThumb2', 'RThumb3', 'RIndex1', 'RIndex2', 'RIndex3', 'RMiddle1', 'RMiddle2', 'RMiddle3', 'RRing1', 'RRing2', 'RRing3', 'RPinky1', 'RPinky2', 'RPinky3', 'Gaze_direction', 'Head_direction']

        self._joint_ids = {name: idx for idx, name in enumerate(self._joint_names)}

        # parent of every joint
        self._joint_parent_names = {
            # root
            'Head': 'Head',
            'LHand': 'LHand',
            'RHand': 'RHand',
            'LThumb1': 'LHand',
            'LThumb2': 'LThumb1',
            'LThumb3': 'LThumb2',
            'LIndex1': 'LHand',
            'LIndex2': 'LIndex1',
            'LIndex3': 'LIndex2',
            'LMiddle1': 'LHand',
            'LMiddle2': 'LMiddle1',
            'LMiddle3': 'LMiddle2',
            'LRing1': 'LHand',
            'LRing2': 'LRing1',
            'LRing3': 'LRing2',
            'LPinky1': 'LHand',
            'LPinky2': 'LPinky1',
            'LPinky3': 'LPinky2',
            'RThumb1': 'RHand',
            'RThumb2': 'RThumb1',
            'RThumb3': 'RThumb2',
            'RIndex1': 'RHand',
            'RIndex2': 'RIndex1',
            'RIndex3': 'RIndex2',
            'RMiddle1': 'RHand',
            'RMiddle2': 'RMiddle1',
            'RMiddle3': 'RMiddle2',
            'RRing1': 'RHand',
            'RRing2': 'RRing1',
            'RRing3': 'RRing2',
            'RPinky1': 'RHand',
            'RPinky2': 'RPinky1',
            'RPinky3': 'RPinky2',
            'Gaze_direction': 'Head',
            'Head_direction': 'Head',
        }

        # id of every joint's parent
        self._joint_parent_ids = [self._joint_ids[self._joint_parent_names[child_name]] for child_name in self._joint_names]
        self._joint_links = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
        # link colors: 0 for head, 1 for the left hand, 2 for the right hand, 3 for the gaze direction, 4 for the head direction
        self._link_colors = [1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4]

        self._fig = plt.figure()
        self._ax = plt.gca(projection='3d')

        self._plots = []
        for i in range(len(self._joint_links)):
            if self._link_colors[i] == 0:
                color = "#3498db"
            if self._link_colors[i] == 1:
                color = "#3498db"
            if self._link_colors[i] == 2:
                color = "#3498db"
            if self._link_colors[i] == 3:
                color = "#6aa84f"
            if self._link_colors[i] == 4:
                color = "#a64d79"
            self._plots.append(self._ax.plot([0, 0], [0, 0], [0, 0], lw=2.0, c=color))

        for i in range(self.object_num):
            self._plots.append(self._ax.plot([0, 0], [0, 0], [0, 0], lw=1.0, c='#ff0000'))

        self._ax.set_xlabel("x")
        self._ax.set_ylabel("y")
        self._ax.set_zlabel("z")

    # play the sequence of human poses in xyz representation
    def play_xyz(self, pose_xyz, gaze, head, objects):
        gaze_direction = pose_xyz[:, :3] + gaze[:, :3]*0.5
        head_direction = pose_xyz[:, :3] + head[:, :3]*0.5
        pose_xyz = np.concatenate((pose_xyz, gaze_direction), axis=1)
        pose_xyz = np.concatenate((pose_xyz, head_direction), axis=1)

        for i in range(pose_xyz.shape[0]):
            joint_number = len(self._joint_names)
            pose_xyz_tmp = pose_xyz[i].reshape(joint_number, 3)
            objects_xyz = objects[i, :, :, :]
            for j in range(len(self._joint_links)):
                idx = self._joint_links[j]
                start_point = pose_xyz_tmp[idx]
                end_point = pose_xyz_tmp[self._joint_parent_ids[idx]]
                x = np.array([start_point[0], end_point[0]])
                y = np.array([start_point[2], end_point[2]])
                z = np.array([start_point[1], end_point[1]])
                self._plots[j][0].set_xdata(x)
                self._plots[j][0].set_ydata(y)
                self._plots[j][0].set_3d_properties(z)

            for j in range(len(self._joint_links), len(self._joint_links) + objects_xyz.shape[0]):
                object_xyz = objects_xyz[j - len(self._joint_links), :, :]
                self._plots[j][0].set_xdata(object_xyz[:, 0])
                self._plots[j][0].set_ydata(object_xyz[:, 2])
                self._plots[j][0].set_3d_properties(object_xyz[:, 1])

            r = 1.0
            x_root, y_root, z_root = pose_xyz_tmp[0, 0], pose_xyz_tmp[0, 2], pose_xyz_tmp[0, 1]
            self._ax.set_xlim3d([-r + x_root, r + x_root])
            self._ax.set_ylim3d([-r + y_root, r + y_root])
            self._ax.set_zlim3d([-r + z_root, r + z_root])
            #self._ax.view_init(elev=30, azim=-110)

            self._ax.grid(False)
            #self._ax.axis('off')

            self._ax.set_aspect('auto')
            plt.show(block=False)
            self._fig.canvas.draw()
            past_time = f"{i / self._fps:.1f}"
            plt.title(f"Time: {past_time} s", fontsize=15)
            plt.pause(0.000000001)


if __name__ == "__main__":
    data_path = '/scratch/hu/pose_forecast/adt_hoigaze/test/Apartment_release_meal_skeleton_seq132_'
    gaze_path = data_path + 'gaze.npy'
    head_path = data_path + 'head.npy'
    hand_path = data_path + 'hand.npy'
    hand_joint_path = data_path + 'handjoints.npy'
    object_left_hand_path = data_path + 'object_left.npy'
    object_right_hand_path = data_path + 'object_right.npy'

    gaze = np.load(gaze_path)  # gaze_direction (3) + gaze_2d (2) + frame_id (1)
    print("Gaze shape: {}".format(gaze.shape))
    gaze_direction = gaze[:, :3]
    head = np.load(head_path)  # head_direction (3) + head_translation (3)
    print("Head shape: {}".format(head.shape))
    head_direction = head[:, :3]
    head_translation = head[:, 3:]
    hand_translation = np.load(hand_path)  # left_hand_translation (3) + right_hand_translation (3)
    print("Hand shape: {}".format(hand_translation.shape))
    hand_joint = np.load(hand_joint_path)  # left_hand (15*3) + right_hand (15*3) + attended_hand_gt + closest_hand
    print("Hand joint shape: {}".format(hand_joint.shape))
    hand_joint = hand_joint[:, :90]
    pose = np.concatenate((head_translation, hand_translation), axis=1)
    pose = np.concatenate((pose, hand_joint), axis=1)
    object_left = np.load(object_left_hand_path)[:, :, :, :]
    object_right = np.load(object_right_hand_path)[:, :, :, :]
    object_all = np.concatenate((object_left, object_right), axis=1)
    print("Object shape: {}".format(object_all.shape))

    player = Player_Skeleton(object_num=object_all.shape[1])
    player.play_xyz(pose, gaze_direction, head_direction, object_all)
adt_processing/utils/__init__.py: 4 additions (new file)
@@ -0,0 +1,4 @@
__all__ = ['file_systems']

from .file_systems import remake_dir, make_dir
adt_processing/utils/file_systems.py: 50 additions (new file)
@@ -0,0 +1,50 @@
import os
import shutil
import time


# remove a directory
def remove_dir(dirName):
    if os.path.exists(dirName):
        shutil.rmtree(dirName)
    else:
        print("Invalid directory path!")


# remake a directory
def remake_dir(dirName):
    if os.path.exists(dirName):
        shutil.rmtree(dirName)
        os.makedirs(dirName)
    else:
        os.makedirs(dirName)


# calculate the number of lines in a file
def file_lines(fileName):
    if os.path.exists(fileName):
        with open(fileName, 'r') as fr:
            return len(fr.readlines())
    else:
        print("Invalid file path!")
        return 0


# make a directory if it does not exist
def make_dir(dirName):
    if os.path.exists(dirName):
        print("Directory " + dirName + " already exists.")
    else:
        os.makedirs(dirName)


if __name__ == "__main__":
    dirName = "test"
    remake_dir(dirName)
    time.sleep(3)
    make_dir(dirName)
    remove_dir(dirName)
    time.sleep(3)
    make_dir(dirName)
    #print(file_lines('233.txt'))
attended_hand_recognition_adt.py: 269 additions (new file)
@@ -0,0 +1,269 @@
from utils import adt_dataset, seed_torch
from model import attended_hand_recognition
from utils.opt import options
from utils import log
from torch.utils.data import DataLoader
import torch
import numpy as np
import time
import datetime
import torch.optim as optim
import os
os.nice(5)
import math


def main(opt):
    # set the random seed to ensure reproducibility
    seed_torch.seed_torch(seed=0)
    torch.set_num_threads(1)

    data_dir = opt.data_dir
    seq_len = opt.seq_len
    opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
    learning_rate = opt.learning_rate
    print('>>> create model')
    net = attended_hand_recognition.attended_hand_recognition(opt=opt).to(opt.cuda_idx)
    optimizer = optim.AdamW(filter(lambda x: x.requires_grad, net.parameters()), lr=learning_rate, weight_decay=opt.weight_decay)
    print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
    print('>>> loading datasets')

    train_actions = 'all'
    test_actions = opt.actions
    train_dataset = adt_dataset.adt_dataset(data_dir, seq_len, train_actions, 1, opt.object_num, opt.hand_joint_number, opt.sample_rate)
    train_data_size = train_dataset.dataset.shape
    print("Training data size: {}".format(train_data_size))
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True)
    valid_dataset = adt_dataset.adt_dataset(data_dir, seq_len, test_actions, 0, opt.object_num, opt.hand_joint_number, opt.sample_rate)
    valid_data_size = valid_dataset.dataset.shape
    print("Validation data size: {}".format(valid_data_size))
    valid_loader = DataLoader(valid_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)

    # training
    local_time = time.asctime(time.localtime(time.time()))
    print('\nTraining starts at ' + local_time)
    start_time = datetime.datetime.now()
    start_epoch = 1

    acc_best = 0
    best_epoch = 0
    exp_lr = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma, last_epoch=-1)
    for epo in range(start_epoch, opt.epoch + 1):
        is_best = False
        learning_rate = exp_lr.optimizer.param_groups[0]["lr"]

        train_start_time = datetime.datetime.now()
        result_train = run_model(net, optimizer, is_train=1, data_loader=train_loader, opt=opt)
        train_end_time = datetime.datetime.now()
        train_time = (train_end_time - train_start_time).seconds*1000
        train_batch_num = math.ceil(train_data_size[0]/opt.batch_size)
        train_time_per_batch = math.ceil(train_time/train_batch_num)
        #print('\nTraining time per batch: {} ms'.format(train_time_per_batch))

        exp_lr.step()
        rng_state = torch.get_rng_state()
        if epo % opt.validation_epoch == 0:
            print('>>> training epoch: {:d}, lr: {:.12f}'.format(epo, learning_rate))
            print('Training data size: {}'.format(train_data_size))
            print('Average baseline acc: {:.2f}%'.format(result_train['baseline_acc_average']*100))
            print('Average training acc: {:.2f}%'.format(result_train['prediction_acc_average']*100))

            test_start_time = datetime.datetime.now()
            result_valid = run_model(net, is_train=0, data_loader=valid_loader, opt=opt)
            test_end_time = datetime.datetime.now()
            test_time = (test_end_time - test_start_time).seconds*1000
            test_batch_num = math.ceil(valid_data_size[0]/opt.test_batch_size)
            test_time_per_batch = math.ceil(test_time/test_batch_num)
            #print('\nTest time per batch: {} ms'.format(test_time_per_batch))
            print('Validation data size: {}'.format(valid_data_size))

            print('Average baseline acc: {:.2f}%'.format(result_valid['baseline_acc_average']*100))
            print('Average validation acc: {:.2f}%'.format(result_valid['prediction_acc_average']*100))

            if result_valid['prediction_acc_average'] > acc_best:
                acc_best = result_valid['prediction_acc_average']
                is_best = True
                best_epoch = epo

            print('Best validation acc: {:.2f}%, best epoch: {}'.format(acc_best*100, best_epoch))
            end_time = datetime.datetime.now()
            total_training_time = (end_time - start_time).seconds/60
            print('\nTotal training time: {:.1f} min'.format(total_training_time))
            local_time = time.asctime(time.localtime(time.time()))
            print('\nTraining ends at ' + local_time)

            result_log = np.array([epo, learning_rate])
            head = np.array(['epoch', 'lr'])
            for k in result_train.keys():
                result_log = np.append(result_log, [result_train[k]])
                head = np.append(head, [k])
            for k in result_valid.keys():
                result_log = np.append(result_log, [result_valid[k]])
                head = np.append(head, ['valid_' + k])

            csv_name = 'attended_hand_recognition_results'
            model_name = 'attended_hand_recognition_model.pt'
            log.save_csv_log(opt, head, result_log, is_create=(epo == 1), file_name=csv_name)
            log.save_ckpt({'epoch': epo,
                           'lr': learning_rate,
                           'acc': result_valid['prediction_acc_average'],
                           'state_dict': net.state_dict(),
                           'optimizer': optimizer.state_dict()},
                          opt=opt,
                          file_name=model_name)

        torch.set_rng_state(rng_state)


def eval(opt):
    data_dir = opt.data_dir
    seq_len = opt.seq_len
    opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2

    print('>>> create model')
    net = attended_hand_recognition.attended_hand_recognition(opt=opt).to(opt.cuda_idx)
    print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
    # load the model
    model_name = 'attended_hand_recognition_model.pt'
    model_path = os.path.join(opt.ckpt, model_name)
    print(">>> loading ckpt from '{}'".format(model_path))
    ckpt = torch.load(model_path)
    net.load_state_dict(ckpt['state_dict'])
    print(">>> ckpt loaded (epoch: {} | acc: {})".format(ckpt['epoch'], ckpt['acc']))

    print('>>> loading datasets')
    train_actions = 'all'
    test_actions = opt.actions
    train_dataset = adt_dataset.adt_dataset(data_dir, seq_len, train_actions, 1, opt.object_num, opt.hand_joint_number, opt.sample_rate)
    train_data_size = train_dataset.dataset.shape
    print("Train data size: {}".format(train_data_size))
    train_loader = DataLoader(train_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
    test_dataset = adt_dataset.adt_dataset(data_dir, seq_len, test_actions, 0, opt.object_num, opt.hand_joint_number, opt.sample_rate)
    test_data_size = test_dataset.dataset.shape
    print("Test data size: {}".format(test_data_size))
    test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)

    # test
    local_time = time.asctime(time.localtime(time.time()))
    print('\nTest starts at ' + local_time)
    start_time = datetime.datetime.now()
    if opt.save_predictions:
        result_train, predictions_train = run_model(net, is_train=0, data_loader=train_loader, opt=opt)
        result_test, predictions_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
    else:
        result_train = run_model(net, is_train=0, data_loader=train_loader, opt=opt)
        result_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)

    print('Average train baseline acc: {:.2f}%'.format(result_train['baseline_acc_average']*100))
    print('Average train method acc: {:.2f}%'.format(result_train['prediction_acc_average']*100))
    print('Average test baseline acc: {:.2f}%'.format(result_test['baseline_acc_average']*100))
    print('Average test method acc: {:.2f}%'.format(result_test['prediction_acc_average']*100))

    end_time = datetime.datetime.now()
    total_test_time = (end_time - start_time).seconds/60
    print('\nTotal test time: {:.1f} min'.format(total_test_time))
    local_time = time.asctime(time.localtime(time.time()))
    print('\nTest ends at ' + local_time)

    if opt.save_predictions:
        prediction = predictions_train[:, :, -3:-1].reshape(-1, 2)
        attended_hand_gt = predictions_train[:, :, -1:].reshape(-1)
        y_prd = np.argmax(prediction, axis=1)
        acc = np.sum(y_prd == attended_hand_gt)/prediction.shape[0]
        print('Average train acc: {:.2f}%'.format(acc*100))
        predictions_train_path = os.path.join(opt.ckpt, "attended_hand_recognition_train.npy")
        np.save(predictions_train_path, predictions_train)

        prediction = predictions_test[:, :, -3:-1].reshape(-1, 2)
        attended_hand_gt = predictions_test[:, :, -1:].reshape(-1)
        y_prd = np.argmax(prediction, axis=1)
        acc = np.sum(y_prd == attended_hand_gt)/prediction.shape[0]
        print('Average test acc: {:.2f}%'.format(acc*100))
        predictions_test_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")
        np.save(predictions_test_path, predictions_test)


def run_model(net, optimizer=None, is_train=1, data_loader=None, opt=None):
    if is_train == 1:
        net.train()
    else:
        net.eval()

    if opt.is_eval and opt.save_predictions:
        predictions = []

    prediction_acc_average = 0
    baseline_acc_average = 0
    criterion = torch.nn.CrossEntropyLoss()

    n = 0
    input_n = opt.seq_len

    for i, (data) in enumerate(data_loader):
        batch_size, seq_n, dim = data.shape
        joint_number = opt.joint_number
        object_num = opt.object_num
        # skip batches that contain only one sample
        if batch_size == 1 and is_train == 1:
            continue
        n += batch_size*seq_n
        data = data.float().to(opt.cuda_idx)

        eye_gaze = data.clone()[:, :, :3]
        joints = data.clone()[:, :, 3:(joint_number+1)*3]
        head_directions = data.clone()[:, :, (joint_number+1)*3:(joint_number+2)*3]
        attended_hand_gt = data.clone()[:, :, (joint_number+2+8*object_num*2)*3:(joint_number+2+8*object_num*2)*3+1].type(torch.LongTensor).to(opt.cuda_idx)
        attended_hand_baseline = data.clone()[:, :, (joint_number+2+8*object_num*2)*3+1:(joint_number+2+8*object_num*2)*3+2].type(torch.LongTensor).to(opt.cuda_idx)
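        # the slices above unpack the per-frame feature layout: eye gaze (3) | joints
        # (joint_number*3) | head direction (3) | 8 bounding-box corners per object,
        # object_num objects per hand (8*object_num*2*3) | attended-hand ground truth (1) |
        # closest-hand baseline (1)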

        input = torch.cat((joints, head_directions), dim=2)
        if object_num > 0:
            object_positions = data.clone()[:, :, (joint_number+2)*3:(joint_number+2+8*object_num*2)*3]
            input = torch.cat((input, object_positions), dim=2)
        prediction = net(input, input_n=input_n)

        if opt.is_eval and opt.save_predictions:
            # eye_gaze + joints + head_directions + object_positions + predictions + attended_hand_gt
            prediction = torch.nn.functional.softmax(prediction, dim=2)
            prediction_cpu = torch.cat((eye_gaze, input), dim=2)
            prediction_cpu = torch.cat((prediction_cpu, prediction), dim=2)
            prediction_cpu = torch.cat((prediction_cpu, attended_hand_gt), dim=2)
            prediction_cpu = prediction_cpu.cpu().data.numpy()
            if len(predictions) == 0:
                predictions = prediction_cpu
            else:
                predictions = np.concatenate((predictions, prediction_cpu), axis=0)

        attended_hand_gt = attended_hand_gt.reshape(batch_size*input_n)
        attended_hand_baseline = attended_hand_baseline.reshape(batch_size*input_n)
        prediction = prediction.reshape(-1, 2)
        loss = criterion(prediction, attended_hand_gt)

        if is_train == 1:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # calculate the prediction accuracy
        _, y_prd = torch.max(prediction.data, 1)
        acc = torch.sum(y_prd == attended_hand_gt)/(batch_size*input_n)
        prediction_acc_average += acc.cpu().data.numpy() * batch_size*input_n

        acc = torch.sum(attended_hand_gt == attended_hand_baseline)/(batch_size*input_n)
        baseline_acc_average += acc.cpu().data.numpy() * batch_size*input_n

    result = {}
    result["baseline_acc_average"] = baseline_acc_average / n
    result["prediction_acc_average"] = prediction_acc_average / n

    if opt.is_eval and opt.save_predictions:
        return result, predictions
    else:
        return result


if __name__ == '__main__':
    option = options().parse()
    if not option.is_eval:
        main(option)
    else:
        eval(option)
attended_hand_recognition_hot3d.py: 368 additions (new file)
@@ -0,0 +1,368 @@
from utils import hot3d_aria_dataset, seed_torch
from model import attended_hand_recognition
from utils.opt import options
from utils import log
from torch.utils.data import DataLoader
import torch
import numpy as np
import time
import datetime
import torch.optim as optim
import os
os.nice(5)
import math


def main(opt):
    # set the random seed to ensure reproducibility
    seed_torch.seed_torch(seed=0)
    torch.set_num_threads(1)

    data_dir = opt.data_dir
    seq_len = opt.seq_len
    opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
    learning_rate = opt.learning_rate
    print('>>> create model')
    net = attended_hand_recognition.attended_hand_recognition(opt=opt).to(opt.cuda_idx)
    optimizer = optim.AdamW(filter(lambda x: x.requires_grad, net.parameters()), lr=learning_rate, weight_decay=opt.weight_decay)
    print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
    print('>>> loading datasets')

    actions = opt.actions
    test_user_id = opt.test_user_id
    if actions == 'all':
        if test_user_id == 1:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
            test_subjects = ['P0001', 'P0002', 'P0003']
            opt.ckpt = opt.ckpt + '/user1/'
        if test_user_id == 2:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0001', 'P0002', 'P0003', 'P0012', 'P0014', 'P0015']
            test_subjects = ['P0009', 'P0010', 'P0011']
            opt.ckpt = opt.ckpt + '/user2/'
        if test_user_id == 3:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011']
            test_subjects = ['P0012', 'P0014', 'P0015']
            opt.ckpt = opt.ckpt + '/user3/'
    elif actions == 'room':
        train_actions = ['kitchen', 'office']
        test_actions = ['room']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene1/'
    elif actions == 'kitchen':
        train_actions = ['room', 'office']
        test_actions = ['kitchen']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene2/'
    elif actions == 'office':
        train_actions = ['room', 'kitchen']
        test_actions = ['office']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene3/'
    else:
        raise ValueError("Unrecognised actions: %s" % actions)

    if not os.path.isdir(opt.ckpt):
        os.makedirs(opt.ckpt)

    train_dataset = hot3d_aria_dataset.hot3d_aria_dataset(data_dir, train_subjects, seq_len, train_actions, opt.object_num, opt.sample_rate)
    train_data_size = train_dataset.dataset.shape
    print("Training data size: {}".format(train_data_size))
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True)
    valid_dataset = hot3d_aria_dataset.hot3d_aria_dataset(data_dir, test_subjects, seq_len, test_actions, opt.object_num, opt.sample_rate)
    valid_data_size = valid_dataset.dataset.shape
    print("Validation data size: {}".format(valid_data_size))
    valid_loader = DataLoader(valid_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)

    # training
    local_time = time.asctime(time.localtime(time.time()))
    print('\nTraining starts at ' + local_time)
    start_time = datetime.datetime.now()
    start_epoch = 1

    acc_best = 0
    best_epoch = 0
    exp_lr = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma, last_epoch=-1)
    for epo in range(start_epoch, opt.epoch + 1):
        is_best = False
        learning_rate = exp_lr.optimizer.param_groups[0]["lr"]

        train_start_time = datetime.datetime.now()
        result_train = run_model(net, optimizer, is_train=1, data_loader=train_loader, opt=opt)
        train_end_time = datetime.datetime.now()
        train_time = (train_end_time - train_start_time).seconds*1000
        train_batch_num = math.ceil(train_data_size[0]/opt.batch_size)
        train_time_per_batch = math.ceil(train_time/train_batch_num)
        #print('\nTraining time per batch: {} ms'.format(train_time_per_batch))

        exp_lr.step()
        rng_state = torch.get_rng_state()
        if epo % opt.validation_epoch == 0:
            if actions == 'all':
                print("\ntest user id: {}\n".format(test_user_id))
            elif actions == 'room':
                print("\ntest scene/action: room\n")
            elif actions == 'kitchen':
                print("\ntest scene/action: kitchen\n")
            elif actions == 'office':
                print("\ntest scene/action: office\n")
            print('>>> training epoch: {:d}, lr: {:.12f}'.format(epo, learning_rate))
            print('Training data size: {}'.format(train_data_size))
            print('Average baseline acc: {:.2f}%'.format(result_train['baseline_acc_average']*100))
            print('Average training acc: {:.2f}%'.format(result_train['prediction_acc_average']*100))

            test_start_time = datetime.datetime.now()
            result_valid = run_model(net, is_train=0, data_loader=valid_loader, opt=opt)
            test_end_time = datetime.datetime.now()
            test_time = (test_end_time - test_start_time).seconds*1000
            test_batch_num = math.ceil(valid_data_size[0]/opt.test_batch_size)
            test_time_per_batch = math.ceil(test_time/test_batch_num)
            #print('\nTest time per batch: {} ms'.format(test_time_per_batch))
            print('Validation data size: {}'.format(valid_data_size))

            print('Average baseline acc: {:.2f}%'.format(result_valid['baseline_acc_average']*100))
            print('Average validation acc: {:.2f}%'.format(result_valid['prediction_acc_average']*100))

            if result_valid['prediction_acc_average'] > acc_best:
                acc_best = result_valid['prediction_acc_average']
                is_best = True
                best_epoch = epo

            print('Best validation acc: {:.2f}%, best epoch: {}'.format(acc_best*100, best_epoch))
            end_time = datetime.datetime.now()
            total_training_time = (end_time - start_time).seconds/60
            print('\nTotal training time: {:.1f} min'.format(total_training_time))
            local_time = time.asctime(time.localtime(time.time()))
            print('\nTraining ends at ' + local_time)

            result_log = np.array([epo, learning_rate])
            head = np.array(['epoch', 'lr'])
            for k in result_train.keys():
                result_log = np.append(result_log, [result_train[k]])
                head = np.append(head, [k])
            for k in result_valid.keys():
                result_log = np.append(result_log, [result_valid[k]])
                head = np.append(head, ['valid_' + k])

            csv_name = 'attended_hand_recognition_results'
            model_name = 'attended_hand_recognition_model.pt'
            log.save_csv_log(opt, head, result_log, is_create=(epo == 1), file_name=csv_name)
            log.save_ckpt({'epoch': epo,
                           'lr': learning_rate,
                           'acc': result_valid['prediction_acc_average'],
                           'state_dict': net.state_dict(),
                           'optimizer': optimizer.state_dict()},
                          opt=opt,
                          file_name=model_name)

        torch.set_rng_state(rng_state)


def eval(opt):
    data_dir = opt.data_dir
    seq_len = opt.seq_len
    opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2

    print('>>> create model')
    net = attended_hand_recognition.attended_hand_recognition(opt=opt).to(opt.cuda_idx)
    print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
    # load the model
    actions = opt.actions
    test_user_id = opt.test_user_id
    if actions == 'all':
        if test_user_id == 1:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
            test_subjects = ['P0001', 'P0002', 'P0003']
            opt.ckpt = opt.ckpt + '/user1/'
        if test_user_id == 2:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0001', 'P0002', 'P0003', 'P0012', 'P0014', 'P0015']
            test_subjects = ['P0009', 'P0010', 'P0011']
            opt.ckpt = opt.ckpt + '/user2/'
        if test_user_id == 3:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011']
            test_subjects = ['P0012', 'P0014', 'P0015']
            opt.ckpt = opt.ckpt + '/user3/'
    elif actions == 'room':
        train_actions = ['kitchen', 'office']
        test_actions = ['room']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene1/'
    elif actions == 'kitchen':
        train_actions = ['room', 'office']
        test_actions = ['kitchen']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene2/'
    elif actions == 'office':
        train_actions = ['room', 'kitchen']
        test_actions = ['office']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene3/'
    else:
        raise ValueError("Unrecognised actions: %s" % actions)
|
||||
|
||||
model_name = 'attended_hand_recognition_model.pt'
|
||||
model_path = os.path.join(opt.ckpt, model_name)
|
||||
print(">>> loading ckpt from '{}'".format(model_path))
|
||||
ckpt = torch.load(model_path)
|
||||
net.load_state_dict(ckpt['state_dict'])
|
||||
print(">>> ckpt loaded (epoch: {} | acc: {})".format(ckpt['epoch'], ckpt['acc']))
|
||||
|
||||
print('>>> loading datasets')
|
||||
train_dataset = hot3d_aria_dataset.hot3d_aria_dataset(data_dir, train_subjects, seq_len, train_actions, opt.object_num, opt.sample_rate)
|
||||
train_data_size = train_dataset.dataset.shape
|
||||
print("Train data size: {}".format(train_data_size))
|
||||
train_loader = DataLoader(train_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
|
||||
test_dataset = hot3d_aria_dataset.hot3d_aria_dataset(data_dir, test_subjects, seq_len, test_actions, opt.object_num, opt.sample_rate)
|
||||
test_data_size = test_dataset.dataset.shape
|
||||
print("Test data size: {}".format(test_data_size))
|
||||
test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
|
||||
|
||||
# test
|
||||
local_time = time.asctime(time.localtime(time.time()))
|
||||
print('\nTest starts at ' + local_time)
|
||||
start_time = datetime.datetime.now()
|
||||
if actions == 'all':
|
||||
print("\ntest user id: {}\n".format(test_user_id))
|
||||
elif actions == 'room':
|
||||
print("\ntest scene/action: room\n")
|
||||
elif actions == 'kitchen':
|
||||
print("\ntest scene/action: kitchen\n")
|
||||
elif actions == 'office':
|
||||
print("\ntest scene/action: office\n")
|
||||
if opt.save_predictions:
|
||||
result_train, predictions_train = run_model(net, is_train=0, data_loader=train_loader, opt=opt)
|
||||
result_test, predictions_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
|
||||
else:
|
||||
result_train = run_model(net, is_train=0, data_loader=train_loader, opt=opt)
|
||||
result_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
|
||||
|
||||
print('Average train baseline acc: {:.2f}%'.format(result_train['baseline_acc_average']*100))
|
||||
print('Average train method acc: {:.2f}%'.format(result_train['prediction_acc_average']*100))
|
||||
print('Average test baseline acc: {:.2f}%'.format(result_test['baseline_acc_average']*100))
|
||||
print('Average test method acc: {:.2f}%'.format(result_test['prediction_acc_average']*100))
|
||||
|
||||
end_time = datetime.datetime.now()
|
||||
total_test_time = (end_time - start_time).seconds/60
|
||||
print('\nTotal test time: {:.1f} min'.format(total_test_time))
|
||||
local_time = time.asctime(time.localtime(time.time()))
|
||||
print('\nTest ends at ' + local_time)
|
||||
|
||||
if opt.save_predictions:
|
||||
prediction = predictions_train[:, :, -3:-1].reshape(-1, 2)
|
||||
attended_hand_gt = predictions_train[:, :, -1:].reshape(-1)
|
||||
y_prd = np.argmax(prediction, axis=1)
|
||||
acc = np.sum(y_prd == attended_hand_gt)/prediction.shape[0]
|
||||
print('Average train acc: {:.2f}%'.format(acc*100))
|
||||
predictions_train_path = os.path.join(opt.ckpt, "attended_hand_recognition_train.npy")
|
||||
np.save(predictions_train_path, predictions_train)
|
||||
|
||||
prediction = predictions_test[:, :, -3:-1].reshape(-1, 2)
|
||||
attended_hand_gt = predictions_test[:, :, -1:].reshape(-1)
|
||||
y_prd = np.argmax(prediction, axis=1)
|
||||
acc = np.sum(y_prd == attended_hand_gt)/prediction.shape[0]
|
||||
print('Average test acc: {:.2f}%'.format(acc*100))
|
||||
predictions_test_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")
|
||||
np.save(predictions_test_path, predictions_test)
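

# Minimal offline sanity check (added for illustration, not part of the original
# pipeline): reload the predictions saved above and recompute the attended-hand
# accuracy from the last three feature columns (two softmax scores followed by
# the ground-truth label), mirroring the computation in eval().
def recompute_attended_hand_acc(npy_path):
    preds = np.load(npy_path)                   # (samples, seq_len, dim)
    scores = preds[:, :, -3:-1].reshape(-1, 2)  # per-frame left/right scores
    labels = preds[:, :, -1].reshape(-1)        # per-frame ground-truth labels
    return np.mean(np.argmax(scores, axis=1) == labels)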


def run_model(net, optimizer=None, is_train=1, data_loader=None, opt=None):
    if is_train == 1:
        net.train()
    else:
        net.eval()

    if opt.is_eval and opt.save_predictions:
        predictions = []

    prediction_acc_average = 0
    baseline_acc_average = 0
    criterion = torch.nn.CrossEntropyLoss()

    n = 0
    input_n = opt.seq_len

    for i, data in enumerate(data_loader):
        batch_size, seq_n, dim = data.shape
        joint_number = opt.joint_number
        object_num = opt.object_num
        # skip when there is only one sample in this batch
        if batch_size == 1 and is_train == 1:
            continue
        n += batch_size*seq_n
        data = data.float().to(opt.cuda_idx)

        eye_gaze = data.clone()[:, :, :3]
        joints = data.clone()[:, :, 3:(joint_number+1)*3]
        head_directions = data.clone()[:, :, (joint_number+1)*3:(joint_number+2)*3]
        attended_hand_gt = data.clone()[:, :, (joint_number+2+8*object_num*2)*3:(joint_number+2+8*object_num*2)*3+1].type(torch.LongTensor).to(opt.cuda_idx)
        attended_hand_baseline = data.clone()[:, :, (joint_number+2+8*object_num*2)*3+1:(joint_number+2+8*object_num*2)*3+2].type(torch.LongTensor).to(opt.cuda_idx)
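
        # Per-frame feature layout implied by the slicing above: gaze (3) |
        # joint positions (joint_number*3) | head direction (3) | object
        # keypoints (8*object_num*2*3 values; presumably eight keypoints per
        # object, object_num objects for each of the two hands) |
        # attended-hand ground truth | attended-hand baseline label.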
        input = torch.cat((joints, head_directions), dim=2)
        if object_num > 0:
            object_positions = data.clone()[:, :, (joint_number+2)*3:(joint_number+2+8*object_num*2)*3]
            input = torch.cat((input, object_positions), dim=2)
        prediction = net(input, input_n=input_n)

        if opt.is_eval and opt.save_predictions:
            # eye_gaze + joints + head_directions + object_positions + predictions + attended_hand_gt
            prediction = torch.nn.functional.softmax(prediction, dim=2)
            prediction_cpu = torch.cat((eye_gaze, input), dim=2)
            prediction_cpu = torch.cat((prediction_cpu, prediction), dim=2)
            prediction_cpu = torch.cat((prediction_cpu, attended_hand_gt), dim=2)
            prediction_cpu = prediction_cpu.cpu().data.numpy()
            if len(predictions) == 0:
                predictions = prediction_cpu
            else:
                predictions = np.concatenate((predictions, prediction_cpu), axis=0)

        attended_hand_gt = attended_hand_gt.reshape(batch_size*input_n)
        attended_hand_baseline = attended_hand_baseline.reshape(batch_size*input_n)
        prediction = prediction.reshape(-1, 2)
        loss = criterion(prediction, attended_hand_gt)

        if is_train == 1:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # calculate prediction accuracy
        _, y_prd = torch.max(prediction.data, 1)
        acc = torch.sum(y_prd == attended_hand_gt)/(batch_size*input_n)
        prediction_acc_average += acc.cpu().data.numpy() * batch_size*input_n

        acc = torch.sum(attended_hand_gt == attended_hand_baseline)/(batch_size*input_n)
        baseline_acc_average += acc.cpu().data.numpy() * batch_size*input_n

    result = {}
    result["baseline_acc_average"] = baseline_acc_average / n
    result["prediction_acc_average"] = prediction_acc_average / n

    if opt.is_eval and opt.save_predictions:
        return result, predictions
    else:
        return result


if __name__ == '__main__':
    option = options().parse()
    if not option.is_eval:
        main(option)
    else:
        eval(option)
BIN  checkpoints/adt/attended_hand_recognition_model.pt (new file; binary file not shown)

checkpoints/adt/attended_hand_recognition_results.csv (new file, 7 lines)
@@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.6745207058933649,0.8722085664853889,0.6619414385512743,0.8330426198868408
20.0,0.0018867680126765363,0.6745207054940767,0.8948185154648364,0.6619414385512743,0.8627868126688715
30.0,0.0011296777049628276,0.674520705352511,0.908833516738123,0.6619414385512743,0.8652105219446476
40.0,0.0006763797713952794,0.6745207055067813,0.915962817709497,0.6619414385512743,0.8569523197573703
50.0,0.0004049735540879638,0.6745207047935076,0.9201425524065895,0.6619414385512743,0.8638178665451156
60.0,0.00024247262624711545,0.674520706799023,0.9228809992886139,0.6619414385512743,0.8646388007627592

BIN  checkpoints/adt/gaze_estimation_model_best.pt (new file; binary file not shown)
BIN  checkpoints/adt/gaze_estimation_model_last.pt (new file; binary file not shown)

checkpoints/adt/gaze_estimation_results.csv (new file, 9 lines)
@@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0006710886400000004,7.775465094474378,21.685016596478096,9.080147244893045,22.25119689216382
20.0,7.205759403792802e-05,7.335939544994145,21.685016544265725,8.836104823252782,22.25119689216382
30.0,7.73712524553364e-06,7.269482698812543,21.68501663562285,8.795702427614733,22.25119689216382
40.0,8.307674973655742e-07,7.256998672841371,21.685016551757823,8.785421505532184,22.25119689216382
50.0,8.920298079412272e-08,7.269379717798182,21.68501661848976,8.782566600764051,22.25119689216382
60.0,9.578097130411839e-09,7.259323381483945,21.685016565057747,8.782792238863054,22.25119689216382
70.0,1.028440348325758e-09,7.265019663627232,21.685016472364822,8.782802083944372,22.25119689216382
80.0,1.1042794154864952e-10,7.262544146220354,21.685016528410358,8.782802388717515,22.25119689216382
BIN  checkpoints/hot3d/scene1/attended_hand_recognition_model.pt (new file; binary file not shown)

checkpoints/hot3d/scene1/attended_hand_recognition_results.csv (new file, 7 lines)
@@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.6816170010865429,0.852783038866365,0.7241098171488276,0.8022957816714968
20.0,0.0018867680126765363,0.6816170011223602,0.8716868679089899,0.7241098171488276,0.819822787597355
30.0,0.0011296777049628276,0.6816170013059245,0.8825985085263077,0.7241098171488276,0.8229033185732267
40.0,0.0006763797713952794,0.6816170021420371,0.8881970469835044,0.7241098171488276,0.8239226446778548
50.0,0.0004049735540879638,0.6816170025595338,0.8919577810969009,0.7241098171488276,0.8248788276437122
60.0,0.00024247262624711545,0.6816170014827729,0.8943664556306086,0.7241098171488276,0.8249509930207405

BIN  checkpoints/hot3d/scene1/gaze_estimation_model_best.pt (new file; binary file not shown)
BIN  checkpoints/hot3d/scene1/gaze_estimation_model_last.pt (new file; binary file not shown)

checkpoints/hot3d/scene1/gaze_estimation_results.csv (new file, 9 lines)
@@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,8.74235833757656,22.929272996737275,10.080445969203966,23.691519417545646
20.0,0.0018867680126765363,7.879988509958766,22.92927302668063,9.581555476071555,23.691519417545646
30.0,0.0011296777049628276,7.3950617451102945,22.929273043729715,9.036600356479111,23.691519417545646
40.0,0.0006763797713952794,7.07688238604103,22.92927304215375,8.760339399561913,23.691519417545646
50.0,0.0004049735540879638,6.9027446432564235,22.929273003757487,8.635063652772557,23.691519417545646
60.0,0.00024247262624711545,6.775644442450048,22.929273132700157,8.549054237754392,23.691519417545646
70.0,0.00014517731808828924,6.714401105429294,22.929273023958505,8.694673094179414,23.691519417545646
80.0,8.69230230790188e-05,6.677202380895346,22.92927304903069,8.57591808497942,23.691519417545646
BIN  checkpoints/hot3d/scene2/attended_hand_recognition_model.pt (new file; binary file not shown)

checkpoints/hot3d/scene2/attended_hand_recognition_results.csv (new file, 7 lines)
@@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.7021381167567796,0.8402169120884563,0.6901813036133927,0.8264494744106567
20.0,0.0018867680126765363,0.7021381167072633,0.8624022141829342,0.6901813036133927,0.8494821318172411
30.0,0.0011296777049628276,0.7021381160817937,0.8747672176315217,0.6901813036133927,0.8499820602572988
40.0,0.0006763797713952794,0.7021381155110528,0.8815414452682059,0.6901813036133927,0.8484319224921495
50.0,0.0004049735540879638,0.7021381158498488,0.8865171791290529,0.6901813036133927,0.8488455337082769
60.0,0.00024247262624711545,0.7021381161156733,0.8899246960738945,0.6901813036133927,0.847752164009063

BIN  checkpoints/hot3d/scene2/gaze_estimation_model_best.pt (new file; binary file not shown)
BIN  checkpoints/hot3d/scene2/gaze_estimation_model_last.pt (new file; binary file not shown)

checkpoints/hot3d/scene2/gaze_estimation_results.csv (new file, 9 lines)
@@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,8.839351579997135,23.505237637305886,9.701928491884642,22.82643943394779
20.0,0.0018867680126765363,8.0279754825312,23.505237625713846,9.222408025492149,22.82643943394779
30.0,0.0011296777049628276,7.478540301192631,23.50523756792046,8.860391529150753,22.82643943394779
40.0,0.0006763797713952794,7.155920212053355,23.505237645728876,8.689459985116457,22.82643943394779
50.0,0.0004049735540879638,6.943932257432076,23.505237599193936,8.75292004044837,22.82643943394779
60.0,0.00024247262624711545,6.828086620144171,23.505237697934735,8.7168934796975,22.82643943394779
70.0,0.00014517731808828924,6.745624930361051,23.505237540900172,8.723768544186793,22.82643943394779
80.0,8.69230230790188e-05,6.688888099966966,23.505237563083494,8.702503739012487,22.82643943394779
BIN  checkpoints/hot3d/scene3/attended_hand_recognition_model.pt (new file; binary file not shown)

checkpoints/hot3d/scene3/attended_hand_recognition_results.csv (new file, 7 lines)
@@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.7052336015288246,0.8534142175928565,0.661994266023742,0.8135229105560388
20.0,0.0018867680126765363,0.7052336013105647,0.8725956164871361,0.661994266023742,0.831561643393579
30.0,0.0011296777049628276,0.7052336017095152,0.883130780337477,0.661994266023742,0.8272765153270717
40.0,0.0006763797713952794,0.7052336016146974,0.8908625688394565,0.661994266023742,0.8306551724193417
50.0,0.0004049735540879638,0.7052336010994608,0.8941962028439242,0.661994266023742,0.8302761033822049
60.0,0.00024247262624711545,0.7052336012712063,0.8975138279303041,0.661994266023742,0.8269304087636498

BIN  checkpoints/hot3d/scene3/gaze_estimation_model_best.pt (new file; binary file not shown)
BIN  checkpoints/hot3d/scene3/gaze_estimation_model_last.pt (new file; binary file not shown)

checkpoints/hot3d/scene3/gaze_estimation_results.csv (new file, 9 lines)
@@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,8.883382703181908,23.21022984623519,10.4297500209549,23.16488788543437
20.0,0.0018867680126765363,7.782790481198015,23.210229789501916,9.385957283378383,23.16488788543437
30.0,0.0011296777049628276,7.3295042459914015,23.210229956037832,8.8375927512519,23.16488788543437
40.0,0.0006763797713952794,7.08854188886722,23.21022992380692,8.90408793085143,23.16488788543437
50.0,0.0004049735540879638,6.931669339063979,23.21022993880603,8.754324602049273,23.16488788543437
60.0,0.00024247262624711545,6.825089292923836,23.21022987319924,8.792744936666029,23.16488788543437
70.0,0.00014517731808828924,6.76963417329046,23.210229846750426,8.694600144187069,23.16488788543437
80.0,8.69230230790188e-05,6.737586019620012,23.210229833812264,8.71311124122629,23.16488788543437
BIN  checkpoints/hot3d/user1/attended_hand_recognition_model.pt (new file; binary file not shown)

checkpoints/hot3d/user1/attended_hand_recognition_results.csv (new file, 7 lines)
@@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.6967981925509015,0.8302275884102602,0.6967302647043779,0.8086298808019007
20.0,0.0018867680126765363,0.696825941169604,0.8459507790012438,0.6967302647043779,0.8300570517057798
30.0,0.0011296777049628276,0.6967939224521644,0.8587645590916032,0.6967302647043779,0.8328419266690059
40.0,0.0006763797713952794,0.6967939235514304,0.866474601211118,0.6967302647043779,0.8400695022933019
50.0,0.0004049735540879638,0.6967939230934029,0.8740437618288838,0.6967302647043779,0.8441059351178314
60.0,0.00024247262624711545,0.6967939231544733,0.8783235607332871,0.6967302647043779,0.8460717301882389

BIN  checkpoints/hot3d/user1/gaze_estimation_model_best.pt (new file; binary file not shown)
BIN  checkpoints/hot3d/user1/gaze_estimation_model_last.pt (new file; binary file not shown)

checkpoints/hot3d/user1/gaze_estimation_results.csv (new file, 9 lines)
@@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,8.723726561323542,23.18913731418672,10.435996211392954,23.239844569106314
20.0,0.0018867680126765363,7.853544009024979,23.188485688850527,9.561553679680529,23.239844569106314
30.0,0.0011296777049628276,7.340472601476263,23.18896493169128,9.575626259120058,23.239844569106314
40.0,0.0006763797713952794,7.076471305772906,23.188663758215355,9.33665127001508,23.239844569106314
50.0,0.0004049735540879638,6.922020800289561,23.189159344454282,9.39458229979059,23.239844569106314
60.0,0.00024247262624711545,6.816650379388059,23.18947358795854,9.273735522755194,23.239844569106314
70.0,0.00014517731808828924,6.751277917232669,23.188865782784635,9.232296420427865,23.239844569106314
80.0,8.69230230790188e-05,6.698476171395818,23.1890932419261,9.404228606204375,23.239844569106314
BIN  checkpoints/hot3d/user2/attended_hand_recognition_model.pt (new file; binary file not shown)

checkpoints/hot3d/user2/attended_hand_recognition_results.csv (new file, 7 lines)
@@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.6960166899282387,0.8442398727362511,0.6979547482149892,0.8083841012201665
20.0,0.0018867680126765363,0.6960166904245663,0.8555533757053816,0.6979547482149892,0.7882599223551159
30.0,0.0011296777049628276,0.6960166902980982,0.8661169178601649,0.6979547482149892,0.8318399900875737
40.0,0.0006763797713952794,0.6960166916629991,0.8750150583824932,0.6979547482149892,0.8195155037706253
50.0,0.0004049735540879638,0.6960166904341111,0.8811001704870125,0.6979547482149892,0.8152017306608094
60.0,0.00024247262624711545,0.6960166910258863,0.8855065376006734,0.6979547482149892,0.8149095465877441

BIN  checkpoints/hot3d/user2/gaze_estimation_model_best.pt (new file; binary file not shown)
BIN  checkpoints/hot3d/user2/gaze_estimation_model_last.pt (new file; binary file not shown)

checkpoints/hot3d/user2/gaze_estimation_results.csv (new file, 9 lines)
@@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,9.092302583226918,20.045022424015006,10.871499093396837,28.00064259211687
20.0,0.0018867680126765363,8.449609629257155,20.045022394846217,9.163385394021835,28.00064259211687
30.0,0.0011296777049628276,8.004199968918563,20.04502237598577,9.715205940504797,28.00064259211687
40.0,0.0006763797713952794,7.736533732776755,20.045022482581658,10.492947047604881,28.00064259211687
50.0,0.0004049735540879638,7.550458214825991,20.04502240393283,9.826412407127586,28.00064259211687
60.0,0.00024247262624711545,7.439071941005396,20.045022420044386,9.617718728174395,28.00064259211687
70.0,0.00014517731808828924,7.354865613532145,20.045022362546746,9.901499594044505,28.00064259211687
80.0,8.69230230790188e-05,7.31953181233149,20.04502238805035,9.986867171353252,28.00064259211687
BIN  checkpoints/hot3d/user3/attended_hand_recognition_model.pt (new file; binary file not shown)

checkpoints/hot3d/user3/attended_hand_recognition_results.csv (new file, 7 lines)
@@ -0,0 +1,7 @@
epoch,lr,baseline_acc_average,prediction_acc_average,valid_baseline_acc_average,valid_prediction_acc_average
10.0,0.0031512470486230455,0.6974864419165719,0.8427737693608807,0.6955263208594125,0.8233390038358083
20.0,0.0018867680126765363,0.6974864421003018,0.8595619837544608,0.6955263208594125,0.8226455437919652
30.0,0.0011296777049628276,0.6974864414101947,0.8704107908384844,0.6955263208594125,0.8303636568959843
40.0,0.0006763797713952794,0.6974864407873057,0.8765607045589645,0.6955263208594125,0.8233164893276423
50.0,0.0004049735540879638,0.69748644039744,0.8822820797718322,0.6955263208594125,0.832633161305254
60.0,0.00024247262624711545,0.6974864419569029,0.887539829914574,0.6955263208594125,0.8349657074195717

BIN  checkpoints/hot3d/user3/gaze_estimation_model_best.pt (new file; binary file not shown)
BIN  checkpoints/hot3d/user3/gaze_estimation_model_last.pt (new file; binary file not shown)

checkpoints/hot3d/user3/gaze_estimation_results.csv (new file, 9 lines)
@@ -0,0 +1,9 @@
epoch,lr,prediction_error_average,baseline_error_average,valid_prediction_error_average,valid_baseline_error_average
10.0,0.0031512470486230455,8.638778725599636,26.1798637664889,10.833028636933662,17.849539670393458
20.0,0.0018867680126765363,7.791013009296702,26.179863858981182,10.735866879370116,17.849539670393458
30.0,0.0011296777049628276,7.266133595065025,26.17986385037725,9.690580197114954,17.849539670393458
40.0,0.0006763797713952794,6.93408266801277,26.17986383503357,10.310306123348958,17.849539670393458
50.0,0.0004049735540879638,6.787781177862328,26.179863753869796,10.454030626426473,17.849539670393458
60.0,0.00024247262624711545,6.659840548202495,26.179863708842547,10.255192241327299,17.849539670393458
70.0,0.00014517731808828924,6.597368892964792,26.17986375731137,10.151271223901132,17.849539670393458
80.0,8.69230230790188e-05,6.558900393612285,26.179863782692973,10.218871813100643,17.849539670393458
environment/hoigaze.yml (new file, 101 lines)
@@ -0,0 +1,101 @@
name: hoigaze
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - ca-certificates=2023.12.12=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.12=h7f8727e_0
  - pip=23.3.1=py38h06a4308_0
  - python=3.8.18=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - setuptools=68.2.2=py38h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.41.2=py38h06a4308_0
  - xz=5.4.5=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
    - absl-py==2.0.0
    - aiohttp==3.9.1
    - aiosignal==1.3.1
    - appdirs==1.4.4
    - async-timeout==4.0.3
    - attrs==23.2.0
    - cachetools==5.3.2
    - certifi==2023.11.17
    - charset-normalizer==3.3.2
    - click==8.1.7
    - contourpy==1.1.1
    - cycler==0.12.1
    - cython==0.29.37
    - docker-pycreds==0.4.0
    - fonttools==4.47.2
    - frozenlist==1.4.1
    - fsspec==2023.12.2
    - ftfy==6.1.3
    - future==0.18.3
    - gitdb==4.0.11
    - gitpython==3.1.41
    - google-auth==2.26.2
    - google-auth-oauthlib==1.0.0
    - grpcio==1.60.0
    - hdbscan==0.8.33
    - idna==3.6
    - importlib-metadata==7.0.1
    - importlib-resources==6.1.1
    - joblib==1.3.2
    - kiwisolver==1.4.5
    - lmdb==1.2.1
    - lpips==0.1.4
    - markdown==3.5.2
    - markupsafe==2.1.3
    - matplotlib==3.5.3
    - multidict==6.0.4
    - numpy==1.24.4
    - oauthlib==3.2.2
    - packaging==23.2
    - pandas==1.5.3
    - pillow==10.2.0
    - protobuf==4.25.2
    - psutil==5.9.8
    - pyasn1==0.5.1
    - pyasn1-modules==0.3.0
    - pydeprecate==0.3.1
    - pyparsing==3.1.1
    - python-dateutil==2.8.2
    - pytorch-fid==0.2.0
    - pytorch-lightning==1.4.5
    - pytz==2023.3.post1
    - pyyaml==6.0.1
    - regex==2023.12.25
    - requests==2.31.0
    - requests-oauthlib==1.3.1
    - rsa==4.9
    - scikit-learn==1.3.2
    - scipy==1.5.4
    - sentry-sdk==1.39.2
    - setproctitle==1.3.3
    - six==1.16.0
    - smmap==5.0.1
    - tensorboard==2.14.0
    - tensorboard-data-server==0.7.2
    - threadpoolctl==3.2.0
    - torch==1.8.1
    - torchmetrics==0.5.0
    - torchvision==0.9.1
    - tqdm==4.66.1
    - typing-extensions==4.9.0
    - tzdata==2023.4
    - urllib3==2.1.0
    - wandb==0.16.2
    - wcwidth==0.2.13
    - werkzeug==3.0.1
    - yarl==1.9.4
    - zipp==3.17.0
gaze_estimation_adt.py (new file, 323 lines)
@@ -0,0 +1,323 @@
from utils import adt_dataset, seed_torch
from model import gaze_estimation
from utils.opt import options
from utils import log
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import numpy as np
import time
import datetime
import torch.optim as optim
import torch.nn.functional as F
import os
os.nice(5)
import math


def main(opt):
    # set the random seed to ensure reproducibility
    seed_torch.seed_torch(seed=0)
    torch.set_num_threads(1)

    data_dir = opt.data_dir
    seq_len = opt.seq_len
    opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
    learning_rate = opt.learning_rate
    print('>>> create model')
    net = gaze_estimation.gaze_estimation(opt=opt).to(opt.cuda_idx)
    optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=learning_rate)
    print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
    print('>>> loading datasets')

    train_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_train.npy")
    valid_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")

    train_dataset = np.load(train_data_path)
    train_data_size = train_dataset.shape
    print("Training data size: {}".format(train_data_size))
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True)
    valid_dataset = np.load(valid_data_path)
    valid_data_size = valid_dataset.shape
    print("Validation data size: {}".format(valid_data_size))
    valid_loader = DataLoader(valid_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)

    # training
    local_time = time.asctime(time.localtime(time.time()))
    print('\nTraining starts at ' + local_time)
    start_time = datetime.datetime.now()
    start_epoch = 1

    err_best = 1000
    best_epoch = 0
    exp_lr = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma, last_epoch=-1)
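    # ExponentialLR decays the learning rate as lr_e = lr_0 * gamma**e; it is
    # stepped once per epoch after each training pass below.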
    for epo in range(start_epoch, opt.epoch + 1):
        is_best = False
        learning_rate = exp_lr.optimizer.param_groups[0]["lr"]

        train_start_time = datetime.datetime.now()
        result_train = run_model(net, optimizer, is_train=1, data_loader=train_loader, opt=opt)
        train_end_time = datetime.datetime.now()
        train_time = (train_end_time - train_start_time).seconds*1000
        train_batch_num = math.ceil(train_data_size[0]/opt.batch_size)
        train_time_per_batch = math.ceil(train_time/train_batch_num)
        #print('\nTraining time per batch: {} ms'.format(train_time_per_batch))

        exp_lr.step()
        rng_state = torch.get_rng_state()
        if epo % opt.validation_epoch == 0:
            print('>>> training epoch: {:d}, lr: {:.12f}'.format(epo, learning_rate))
            print('Training data size: {}'.format(train_data_size))
            print('Average baseline error: {:.2f} degree'.format(result_train['baseline_error_average']))
            print('Average training error: {:.2f} degree'.format(result_train['prediction_error_average']))

            test_start_time = datetime.datetime.now()
            result_valid = run_model(net, is_train=0, data_loader=valid_loader, opt=opt)
            test_end_time = datetime.datetime.now()
            test_time = (test_end_time - test_start_time).seconds*1000
            test_batch_num = math.ceil(valid_data_size[0]/opt.test_batch_size)
            test_time_per_batch = math.ceil(test_time/test_batch_num)
            #print('\nTest time per batch: {} ms'.format(test_time_per_batch))
            print('Validation data size: {}'.format(valid_data_size))

            print('Average baseline error: {:.2f} degree'.format(result_valid['baseline_error_average']))
            print('Average validation error: {:.2f} degree'.format(result_valid['prediction_error_average']))

            if result_valid['prediction_error_average'] < err_best:
                err_best = result_valid['prediction_error_average']
                is_best = True
                best_epoch = epo

            print('Best validation error: {:.2f} degree, best epoch: {}'.format(err_best, best_epoch))
            end_time = datetime.datetime.now()
            total_training_time = (end_time - start_time).seconds/60
            print('\nTotal training time: {:.2f} min'.format(total_training_time))
            local_time = time.asctime(time.localtime(time.time()))
            print('\nTraining ends at ' + local_time)

            result_log = np.array([epo, learning_rate])
            head = np.array(['epoch', 'lr'])
            for k in result_train.keys():
                result_log = np.append(result_log, [result_train[k]])
                head = np.append(head, [k])
            for k in result_valid.keys():
                result_log = np.append(result_log, [result_valid[k]])
                head = np.append(head, ['valid_' + k])

            csv_name = 'gaze_estimation_results'
            log.save_csv_log(opt, head, result_log, is_create=(epo == 1), file_name=csv_name)
            last_model_name = 'gaze_estimation_model_last.pt'
            log.save_ckpt({'epoch': epo,
                           'lr': learning_rate,
                           'err': result_valid['prediction_error_average'],
                           'state_dict': net.state_dict(),
                           'optimizer': optimizer.state_dict()},
                          opt=opt,
                          file_name=last_model_name)
            if epo == best_epoch:
                best_model_name = 'gaze_estimation_model_best.pt'
                log.save_ckpt({'epoch': epo,
                               'lr': learning_rate,
                               'err': result_valid['prediction_error_average'],
                               'state_dict': net.state_dict(),
                               'optimizer': optimizer.state_dict()},
                              opt=opt,
                              file_name=best_model_name)

        torch.set_rng_state(rng_state)


def eval(opt):
    data_dir = opt.data_dir
    seq_len = opt.seq_len
    opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2

    print('>>> create model')
    net = gaze_estimation.gaze_estimation(opt=opt).to(opt.cuda_idx)
    print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
    # load model
    model_name = 'gaze_estimation_model_best.pt'
    model_path = os.path.join(opt.ckpt, model_name)
    print(">>> loading ckpt from '{}'".format(model_path))
    ckpt = torch.load(model_path)
    net.load_state_dict(ckpt['state_dict'])
    print(">>> ckpt loaded (epoch: {} | err: {})".format(ckpt['epoch'], ckpt['err']))

    print('>>> loading datasets')
    test_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")
    test_dataset = np.load(test_data_path)
    test_data_size = test_dataset.shape
    print("Test data size: {}".format(test_data_size))
    test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)

    # test
    local_time = time.asctime(time.localtime(time.time()))
    print('\nTest starts at ' + local_time)
    start_time = datetime.datetime.now()
    if opt.save_predictions:
        result_test, predictions = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
    else:
        result_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)

    print('Average baseline error: {:.2f} degree'.format(result_test['baseline_error_average']))
    print('Average prediction error: {:.2f} degree'.format(result_test['prediction_error_average']))

    end_time = datetime.datetime.now()
    total_test_time = (end_time - start_time).seconds/60
    print('\nTotal test time: {:.2f} min'.format(total_test_time))
    local_time = time.asctime(time.localtime(time.time()))
    print('\nTest ends at ' + local_time)

    if opt.save_predictions:
        # ground_truth + joints + head_directions + object_positions + attended_hand_prd + attended_hand_gt + predictions
        batch_size, seq_n, dim = predictions.shape
        predictions = predictions.reshape(-1, dim)
        ground_truth = predictions[:, :3]
        head_directions = predictions[:, 3+opt.joint_number*3:6+opt.joint_number*3]
        head_cos = np.sum(head_directions*ground_truth, 1)
        head_cos = np.clip(head_cos, -1, 1)
        head_errors = np.arccos(head_cos)/np.pi * 180.0
        print('Average baseline error: {:.2f} degree'.format(np.mean(head_errors)))
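        # The angular error between two (assumed unit-length) direction vectors
        # is arccos of their dot product, converted from radians to degrees; the
        # clip to [-1, 1] guards against floating-point overshoot in np.arccos.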

        prediction = predictions[:, -3:]
        prd_cos = np.sum(prediction*ground_truth, 1)
        prd_cos = np.clip(prd_cos, -1, 1)
        prediction_errors = np.arccos(prd_cos)/np.pi * 180.0
        print('Average prediction error: {:.2f} degree'.format(np.mean(prediction_errors)))

        attended_hand_gt = predictions[:, -4]
        attended_hand_prd_left = predictions[:, -6]
        attended_hand_prd_right = predictions[:, -5]
        # use a fresh array rather than aliasing attended_hand_prd_left, which
        # the original code overwrote in place
        attended_hand_correct = np.zeros_like(attended_hand_gt)
        for i in range(attended_hand_correct.shape[0]):
            if attended_hand_gt[i] == 0 and attended_hand_prd_left[i] > attended_hand_prd_right[i]:
                attended_hand_correct[i] = 1
            elif attended_hand_gt[i] == 1 and attended_hand_prd_left[i] < attended_hand_prd_right[i]:
                attended_hand_correct[i] = 1
            else:
                attended_hand_correct[i] = 0

        correct_ratio = np.sum(attended_hand_correct)/attended_hand_correct.shape[0]
        print("hand recognition acc: {:.2f}%".format(correct_ratio*100))
        attended_hand_wrong = 1 - attended_hand_correct
        wrong_ratio = np.sum(attended_hand_wrong)/attended_hand_wrong.shape[0]

        head_errors_correct = np.sum(head_errors*attended_hand_correct)/np.sum(attended_hand_correct)
        print("hand recognition correct size: {}".format(np.sum(attended_hand_correct)))
        print("hand recognition correct, average baseline error: {:.2f} degree".format(head_errors_correct))
        head_errors_wrong = np.sum(head_errors*attended_hand_wrong)/np.sum(attended_hand_wrong)
        print("hand recognition wrong size: {}".format(np.sum(attended_hand_wrong)))
        print("hand recognition wrong, average baseline error: {:.2f} degree".format(head_errors_wrong))
        head_errors_avg = head_errors_correct*correct_ratio + head_errors_wrong*wrong_ratio
        print('Average baseline error: {:.2f} degree'.format(head_errors_avg))

        prediction_errors_correct = np.sum(prediction_errors*attended_hand_correct)/np.sum(attended_hand_correct)
        print("hand recognition correct, average prediction error: {:.2f} degree".format(prediction_errors_correct))
        prediction_errors_wrong = np.sum(prediction_errors*attended_hand_wrong)/np.sum(attended_hand_wrong)
        print("hand recognition wrong, average prediction error: {:.2f} degree".format(prediction_errors_wrong))
        prediction_errors_avg = prediction_errors_correct*correct_ratio + prediction_errors_wrong*wrong_ratio
        print('Average prediction error: {:.2f} degree'.format(prediction_errors_avg))

        predictions_path = os.path.join(opt.ckpt, "gaze_predictions.npy")
        np.save(predictions_path, predictions)
        prediction_errors_path = os.path.join(opt.ckpt, "prediction_errors.npy")
        np.save(prediction_errors_path, prediction_errors)
        attended_hand_correct_path = os.path.join(opt.ckpt, "attended_hand_correct.npy")
        np.save(attended_hand_correct_path, attended_hand_correct)


def acos_safe(x, eps=1e-6):
    slope = np.arccos(1-eps) / eps
    buf = torch.empty_like(x)
    good = abs(x) <= 1-eps
    bad = ~good
    sign = torch.sign(x[bad])
    buf[good] = torch.acos(x[good])
    buf[bad] = torch.acos(sign * (1 - eps)) - slope*sign*(abs(x[bad]) - 1 + eps)
    return buf
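
# acos_safe evaluates torch.acos only where |x| <= 1-eps and continues linearly
# beyond that point with the slope of arccos at 1-eps, so the result and its
# gradient stay finite where plain torch.acos would blow up at |x| = 1.
# Illustrative check (added for clarity, not part of the original script):
#   x = torch.tensor([0.0, 0.5, 1.0, -1.0], requires_grad=True)
#   acos_safe(x).sum().backward()   # x.grad is finite at all four entries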


def run_model(net, optimizer=None, is_train=1, data_loader=None, opt=None):
    if is_train == 1:
        net.train()
    else:
        net.eval()

    if opt.is_eval and opt.save_predictions:
        predictions = []

    prediction_error_average = 0
    baseline_error_average = 0
    criterion = torch.nn.MSELoss(reduction='none')

    n = 0
    input_n = opt.seq_len

    for i, data in enumerate(data_loader):
        batch_size, seq_n, dim = data.shape
        joint_number = opt.joint_number
        object_num = opt.object_num
        # skip when there is only one sample in this batch
        if batch_size == 1 and is_train == 1:
            continue
        n += batch_size
        data = data.float().to(opt.cuda_idx)

        ground_truth = data.clone()[:, :, :3]
        joints = data.clone()[:, :, 3:(joint_number+1)*3]
        head_directions = data.clone()[:, :, (joint_number+1)*3:(joint_number+2)*3]
        attended_hand_prd = data.clone()[:, :, (joint_number+2+8*object_num*2)*3:(joint_number+2+8*object_num*2)*3+2]
        attended_hand_gt = data.clone()[:, :, (joint_number+2+8*object_num*2)*3+2:(joint_number+2+8*object_num*2)*3+3]

        input = torch.cat((joints, head_directions), dim=2)
        if object_num > 0:
            object_positions = data.clone()[:, :, (joint_number+2)*3:(joint_number+2+8*object_num*2)*3]
            input = torch.cat((input, object_positions), dim=2)
        input = torch.cat((input, attended_hand_prd), dim=2)
        input = torch.cat((input, attended_hand_gt), dim=2)
        prediction = net(input, input_n=input_n)

        if opt.is_eval and opt.save_predictions:
            # ground_truth + joints + head_directions + object_positions + attended_hand_prd + attended_hand_gt + predictions
            prediction_cpu = torch.cat((ground_truth, input), dim=2)
            prediction_cpu = torch.cat((prediction_cpu, prediction), dim=2)
            prediction_cpu = prediction_cpu.cpu().data.numpy()
            if len(predictions) == 0:
                predictions = prediction_cpu
            else:
                predictions = np.concatenate((predictions, prediction_cpu), axis=0)

        gaze_head_cos = torch.sum(ground_truth*head_directions, dim=2, keepdim=True)
        gaze_weight = torch.where(gaze_head_cos>opt.gaze_head_cos_threshold, opt.gaze_head_loss_factor, 1.0)

        loss = criterion(prediction, ground_truth)
        loss = torch.mean(loss*gaze_weight)
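        # Eye-head coordination weighting: frames whose gaze and head directions
        # agree (cosine above opt.gaze_head_cos_threshold) are re-weighted by
        # opt.gaze_head_loss_factor, so the loss emphasises well-coordinated
        # eye-head samples relative to the rest (weight 1.0).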

        if is_train == 1:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # calculate prediction errors
        error = torch.mean(acos_safe(torch.sum(ground_truth*prediction, 2)))/torch.tensor(math.pi) * 180.0
        prediction_error_average += error.cpu().data.numpy() * batch_size

        # use head directions as the baseline
        baseline_error = torch.mean(acos_safe(torch.sum(ground_truth*head_directions, 2)))/torch.tensor(math.pi) * 180.0
        baseline_error_average += baseline_error.cpu().data.numpy() * batch_size

    result = {}
    result["prediction_error_average"] = prediction_error_average / n
    result["baseline_error_average"] = baseline_error_average / n

    if opt.is_eval and opt.save_predictions:
        return result, predictions
    else:
        return result


if __name__ == '__main__':
    option = options().parse()
    if not option.is_eval:
        main(option)
    else:
        eval(option)
gaze_estimation_hot3d.py (new file, 552 lines)
@@ -0,0 +1,552 @@
from utils import hot3d_aria_dataset, seed_torch
from model import gaze_estimation
from utils.opt import options
from utils import log
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import numpy as np
import time
import datetime
import torch.optim as optim
import torch.nn.functional as F
import os
os.nice(5)
import math


def main(opt):
    # set the random seed to ensure reproducibility
    seed_torch.seed_torch(seed=0)
    torch.set_num_threads(1)

    data_dir = opt.data_dir
    seq_len = opt.seq_len
    opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
    learning_rate = opt.learning_rate
    print('>>> create model')
    net = gaze_estimation.gaze_estimation(opt=opt).to(opt.cuda_idx)
    optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=learning_rate)
    print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
    print('>>> loading datasets')

    actions = opt.actions
    test_user_id = opt.test_user_id
    if actions == 'all':
        if test_user_id == 1:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
            test_subjects = ['P0001', 'P0002', 'P0003']
            opt.ckpt = opt.ckpt + '/user1/'
        if test_user_id == 2:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0001', 'P0002', 'P0003', 'P0012', 'P0014', 'P0015']
            test_subjects = ['P0009', 'P0010', 'P0011']
            opt.ckpt = opt.ckpt + '/user2/'
        if test_user_id == 3:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011']
            test_subjects = ['P0012', 'P0014', 'P0015']
            opt.ckpt = opt.ckpt + '/user3/'
    elif actions == 'room':
        train_actions = ['kitchen', 'office']
        test_actions = ['room']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene1/'
    elif actions == 'kitchen':
        train_actions = ['room', 'office']
        test_actions = ['kitchen']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene2/'
    elif actions == 'office':
        train_actions = ['room', 'kitchen']
        test_actions = ['office']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene3/'
    else:
        raise ValueError("Unrecognised actions: %s" % actions)
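
    # Split scheme implied above: actions == 'all' gives a cross-user split
    # (train on six participants, test on the held-out three), while a named
    # scene gives a cross-scene split (train on the other two scenes, test on
    # the target scene, with all participants in both sets).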

    train_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_train.npy")
    valid_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")

    train_dataset = np.load(train_data_path)
    train_data_size = train_dataset.shape
    print("Training data size: {}".format(train_data_size))
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True)
    valid_dataset = np.load(valid_data_path)
    valid_data_size = valid_dataset.shape
    print("Validation data size: {}".format(valid_data_size))
    valid_loader = DataLoader(valid_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)

    # training
    local_time = time.asctime(time.localtime(time.time()))
    print('\nTraining starts at ' + local_time)
    start_time = datetime.datetime.now()
    start_epoch = 1

    err_best = 1000
    best_epoch = 0
    exp_lr = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt.gamma, last_epoch=-1)
    for epo in range(start_epoch, opt.epoch + 1):
        is_best = False
        learning_rate = exp_lr.optimizer.param_groups[0]["lr"]

        train_start_time = datetime.datetime.now()
        result_train = run_model(net, optimizer, is_train=1, data_loader=train_loader, opt=opt)
        train_end_time = datetime.datetime.now()
        train_time = (train_end_time - train_start_time).seconds*1000
        train_batch_num = math.ceil(train_data_size[0]/opt.batch_size)
        train_time_per_batch = math.ceil(train_time/train_batch_num)
        #print('\nTraining time per batch: {} ms'.format(train_time_per_batch))

        exp_lr.step()
        rng_state = torch.get_rng_state()
        if epo % opt.validation_epoch == 0:
            if actions == 'all':
                print("\ntest user id: {}\n".format(test_user_id))
            elif actions == 'room':
                print("\ntest scene/action: room\n")
            elif actions == 'kitchen':
                print("\ntest scene/action: kitchen\n")
            elif actions == 'office':
                print("\ntest scene/action: office\n")
            print('>>> training epoch: {:d}, lr: {:.12f}'.format(epo, learning_rate))
            print('Training data size: {}'.format(train_data_size))
            print('Average baseline error: {:.2f} degree'.format(result_train['baseline_error_average']))
            print('Average training error: {:.2f} degree'.format(result_train['prediction_error_average']))

            test_start_time = datetime.datetime.now()
            result_valid = run_model(net, is_train=0, data_loader=valid_loader, opt=opt)
            test_end_time = datetime.datetime.now()
            test_time = (test_end_time - test_start_time).seconds*1000
            test_batch_num = math.ceil(valid_data_size[0]/opt.test_batch_size)
            test_time_per_batch = math.ceil(test_time/test_batch_num)
            #print('\nTest time per batch: {} ms'.format(test_time_per_batch))
            print('Validation data size: {}'.format(valid_data_size))

            print('Average baseline error: {:.2f} degree'.format(result_valid['baseline_error_average']))
            print('Average validation error: {:.2f} degree'.format(result_valid['prediction_error_average']))

            if result_valid['prediction_error_average'] < err_best:
                err_best = result_valid['prediction_error_average']
                is_best = True
                best_epoch = epo

            print('Best validation error: {:.2f} degree, best epoch: {}'.format(err_best, best_epoch))
            end_time = datetime.datetime.now()
            total_training_time = (end_time - start_time).seconds/60
            print('\nTotal training time: {:.2f} min'.format(total_training_time))
            local_time = time.asctime(time.localtime(time.time()))
            print('\nTraining ends at ' + local_time)

            result_log = np.array([epo, learning_rate])
            head = np.array(['epoch', 'lr'])
            for k in result_train.keys():
                result_log = np.append(result_log, [result_train[k]])
                head = np.append(head, [k])
            for k in result_valid.keys():
                result_log = np.append(result_log, [result_valid[k]])
                head = np.append(head, ['valid_' + k])

            csv_name = 'gaze_estimation_results'
            log.save_csv_log(opt, head, result_log, is_create=(epo == 1), file_name=csv_name)
            model_name = 'gaze_estimation_model_last.pt'
            log.save_ckpt({'epoch': epo,
                           'lr': learning_rate,
                           'err': result_valid['prediction_error_average'],
                           'state_dict': net.state_dict(),
                           'optimizer': optimizer.state_dict()},
                          opt=opt,
                          file_name=model_name)
            if epo == best_epoch:
                model_name = 'gaze_estimation_model_best.pt'
                log.save_ckpt({'epoch': epo,
                               'lr': learning_rate,
                               'err': result_valid['prediction_error_average'],
                               'state_dict': net.state_dict(),
                               'optimizer': optimizer.state_dict()},
                              opt=opt,
                              file_name=model_name)

        torch.set_rng_state(rng_state)


def eval(opt):
    data_dir = opt.data_dir
    seq_len = opt.seq_len
    opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2

    print('>>> create model')
    net = gaze_estimation.gaze_estimation(opt=opt).to(opt.cuda_idx)
    print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
    actions = opt.actions
    test_user_id = opt.test_user_id
    if actions == 'all':
        if test_user_id == 1:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
            test_subjects = ['P0001', 'P0002', 'P0003']
            opt.ckpt = opt.ckpt + '/user1/'
        if test_user_id == 2:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0001', 'P0002', 'P0003', 'P0012', 'P0014', 'P0015']
            test_subjects = ['P0009', 'P0010', 'P0011']
            opt.ckpt = opt.ckpt + '/user2/'
        if test_user_id == 3:
            train_actions = 'all'
            test_actions = 'all'
            train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011']
            test_subjects = ['P0012', 'P0014', 'P0015']
            opt.ckpt = opt.ckpt + '/user3/'
    elif actions == 'room':
        train_actions = ['kitchen', 'office']
        test_actions = ['room']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene1/'
    elif actions == 'kitchen':
        train_actions = ['room', 'office']
        test_actions = ['kitchen']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene2/'
    elif actions == 'office':
        train_actions = ['room', 'kitchen']
        test_actions = ['office']
        train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
        opt.ckpt = opt.ckpt + '/scene3/'
    else:
        raise ValueError("Unrecognised actions: %s" % actions)
|
||||
|
||||
#load model
|
||||
model_name = 'gaze_estimation_model_best.pt'
|
||||
model_path = os.path.join(opt.ckpt, model_name)
|
||||
print(">>> loading ckpt from '{}'".format(model_path))
|
||||
ckpt = torch.load(model_path)
|
||||
net.load_state_dict(ckpt['state_dict'])
|
||||
print(">>> ckpt loaded (epoch: {} | err: {})".format(ckpt['epoch'], ckpt['err']))
|
||||
|
||||
print('>>> loading datasets')
|
||||
test_data_path = os.path.join(opt.ckpt, "attended_hand_recognition_test.npy")
|
||||
test_dataset = np.load(test_data_path)
|
||||
test_data_size = test_dataset.shape
|
||||
print("Test data size: {}".format(test_data_size))
|
||||
test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
|
||||
|
||||
# test
|
||||
local_time = time.asctime(time.localtime(time.time()))
|
||||
print('\nTest starts at ' + local_time)
|
||||
start_time = datetime.datetime.now()
|
||||
if actions == 'all':
|
||||
print("\ntest user id: {}\n".format(test_user_id))
|
||||
elif actions == 'room':
|
||||
print("\ntest scene/action: room\n")
|
||||
elif actions == 'kitchen':
|
||||
print("\ntest scene/action: kitchen\n")
|
||||
elif actions == 'office':
|
||||
print("\ntest scene/action: office\n")
|
||||
if opt.save_predictions:
|
||||
result_test, predictions = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
|
||||
else:
|
||||
result_test = run_model(net, is_train=0, data_loader=test_loader, opt=opt)
|
||||
|
||||
print('Average baseline error: {:.2f} degree'.format(result_test['baseline_error_average']))
|
||||
print('Average prediction error: {:.2f} degree'.format(result_test['prediction_error_average']))
|
||||
|
||||
end_time = datetime.datetime.now()
|
||||
total_test_time = (end_time - start_time).seconds/60
|
||||
print('\nTotal test time: {:.2f} min'.format(total_test_time))
|
||||
local_time = time.asctime(time.localtime(time.time()))
|
||||
print('\nTest ends at ' + local_time)
|
||||
|
||||
if opt.save_predictions:
|
||||
# ground_truth + joints + head_directions + object_positions + attended_hand_prd + attended_hand_gt + predictions
|
||||
batch_size, seq_n, dim = predictions.shape
|
||||
predictions = predictions.reshape(-1, dim)
|
||||
ground_truth = predictions[:, :3]
|
||||
head_directions = predictions[:, 3+opt.joint_number*3:6+opt.joint_number*3]
|
||||
head_cos = np.sum(head_directions*ground_truth, 1)
|
||||
head_cos = np.clip(head_cos, -1, 1)
|
||||
head_errors = np.arccos(head_cos)/np.pi * 180.0
|
||||
print('Average baseline error: {:.2f} degree'.format(np.mean(head_errors)))
|
||||
|
||||
prediction = predictions[:, -3:]
|
||||
prd_cos = np.sum(prediction*ground_truth, 1)
|
||||
prd_cos = np.clip(prd_cos, -1, 1)
|
||||
prediction_errors = np.arccos(prd_cos)/np.pi * 180.0
|
||||
print('Average prediction error: {:.2f} degree'.format(np.mean(prediction_errors)))
|
||||
|
||||
attended_hand_gt = predictions[:, -4]
|
||||
attended_hand_prd_left = predictions[:, -6]
|
||||
attended_hand_prd_right = predictions[:, -5]
|
||||
attended_hand_correct = attended_hand_prd_left
|
||||
for i in range(attended_hand_correct.shape[0]):
|
||||
if attended_hand_gt[i] == 0 and attended_hand_prd_left[i] > attended_hand_prd_right[i]:
|
||||
attended_hand_correct[i] = 1
|
||||
elif attended_hand_gt[i] == 1 and attended_hand_prd_left[i] < attended_hand_prd_right[i]:
|
||||
attended_hand_correct[i] = 1
|
||||
else:
|
||||
attended_hand_correct[i] = 0
|
||||
|
||||
correct_ratio = np.sum(attended_hand_correct)/attended_hand_correct.shape[0]
|
||||
print("hand recognition acc: {:.2f}%".format(correct_ratio*100))
|
||||
attended_hand_wrong = 1 - attended_hand_correct
|
||||
wrong_ratio = np.sum(attended_hand_wrong)/attended_hand_wrong.shape[0]
|
||||
|
||||
head_errors_correct = np.sum(head_errors*attended_hand_correct)/np.sum(attended_hand_correct)
|
||||
print("hand recognition correct size: {}".format(np.sum(attended_hand_correct)))
|
||||
print("hand recognition correct, average baseline error: {:.2f} degree".format(head_errors_correct))
|
||||
head_errors_wrong = np.sum(head_errors*attended_hand_wrong)/np.sum(attended_hand_wrong)
|
||||
print("hand recognition wrong size: {}".format(np.sum(attended_hand_wrong)))
|
||||
print("hand recognition wrong, average baseline error: {:.2f} degree".format(head_errors_wrong))
|
||||
head_errors_avg = head_errors_correct*correct_ratio + head_errors_wrong*wrong_ratio
|
||||
print('Average baseline error: {:.2f} degree'.format(head_errors_avg))
|
||||
|
||||
prediction_errors_correct = np.sum(prediction_errors*attended_hand_correct)/np.sum(attended_hand_correct)
|
||||
print("hand recognition correct, average prediction error: {:.2f} degree".format(prediction_errors_correct))
|
||||
prediction_errors_wrong = np.sum(prediction_errors*attended_hand_wrong)/np.sum(attended_hand_wrong)
|
||||
print("hand recognition wrong, average prediction error: {:.2f} degree".format(prediction_errors_wrong))
|
||||
prediction_errors_avg = prediction_errors_correct*correct_ratio + prediction_errors_wrong*wrong_ratio
|
||||
print('Average prediction error: {:.2f} degree'.format(prediction_errors_avg))
|
||||
|
||||
predictions_path = os.path.join(opt.ckpt, "gaze_predictions.npy")
|
||||
np.save(predictions_path, predictions)
|
||||
prediction_errors_path = os.path.join(opt.ckpt, "prediction_errors.npy")
|
||||
np.save(prediction_errors_path, prediction_errors)
|
||||
attended_hand_correct_path = os.path.join(opt.ckpt, "attended_hand_correct.npy")
|
||||
np.save(attended_hand_correct_path, attended_hand_correct)
|
||||
|
||||
|
||||
def eval_single(opt):
|
||||
from utils import hot3d_aria_single_dataset
|
||||
from model import attended_hand_recognition
|
||||
seq_len = opt.seq_len
|
||||
opt.joint_number = opt.body_joint_number + opt.hand_joint_number*2
|
||||
|
||||
print('>>> create model')
|
||||
opt.residual_gcns_num = 2
|
||||
hand_model = attended_hand_recognition.attended_hand_recognition(opt=opt).to(opt.cuda_idx)
|
||||
opt.residual_gcns_num = 4
|
||||
net = gaze_estimation.gaze_estimation(opt=opt).to(opt.cuda_idx)
|
||||
print(">>> total params: {:.2f}M".format(sum(p.numel() for p in net.parameters()) / 1000000.0))
|
||||
actions = opt.actions
|
||||
test_user_id = opt.test_user_id
|
||||
if actions == 'all':
|
||||
if test_user_id == 1:
|
||||
train_actions = 'all'
|
||||
test_actions = 'all'
|
||||
train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
|
||||
test_subjects = ['P0001', 'P0002', 'P0003']
|
||||
opt.ckpt = opt.ckpt + '/user1/'
|
||||
if test_user_id == 2:
|
||||
train_actions = 'all'
|
||||
test_actions = 'all'
|
||||
train_subjects = ['P0001', 'P0002', 'P0003', 'P0012', 'P0014', 'P0015']
|
||||
test_subjects = ['P0009', 'P0010', 'P0011']
|
||||
opt.ckpt = opt.ckpt + '/user2/'
|
||||
if test_user_id == 3:
|
||||
train_actions = 'all'
|
||||
test_actions = 'all'
|
||||
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011']
|
||||
test_subjects = ['P0012', 'P0014', 'P0015']
|
||||
opt.ckpt = opt.ckpt + '/user3/'
|
||||
elif actions == 'room':
|
||||
train_actions = ['kitchen', 'office']
|
||||
test_actions = ['room']
|
||||
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
|
||||
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
|
||||
opt.ckpt = opt.ckpt + '/scene1/'
|
||||
elif actions == 'kitchen':
|
||||
train_actions = ['room', 'office']
|
||||
test_actions = ['kitchen']
|
||||
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
|
||||
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
|
||||
opt.ckpt = opt.ckpt + '/scene2/'
|
||||
elif actions == 'office':
|
||||
train_actions = ['room', 'kitchen']
|
||||
test_actions = ['office']
|
||||
train_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
|
||||
test_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
|
||||
opt.ckpt = opt.ckpt + '/scene3/'
|
||||
else:
|
||||
raise ValueError("Unrecognised actions: %s" % actions)
|
||||
|
||||
# load the attended-hand recognition model
|
||||
model_name = 'attended_hand_recognition_model.pt'
|
||||
model_path = os.path.join(opt.ckpt, model_name)
|
||||
print(">>> loading ckpt from '{}'".format(model_path))
|
||||
ckpt = torch.load(model_path)
|
||||
hand_model.load_state_dict(ckpt['state_dict'])
|
||||
print(">>> ckpt loaded (epoch: {} | acc: {})".format(ckpt['epoch'], ckpt['acc']))
|
||||
hand_model.eval()
|
||||
|
||||
# load the gaze estimation model
|
||||
model_name = 'gaze_estimation_model_best.pt'
|
||||
model_path = os.path.join(opt.ckpt, model_name)
|
||||
print(">>> loading ckpt from '{}'".format(model_path))
|
||||
ckpt = torch.load(model_path)
|
||||
net.load_state_dict(ckpt['state_dict'])
|
||||
print(">>> ckpt loaded (epoch: {} | err: {})".format(ckpt['epoch'], ckpt['err']))
|
||||
net.eval()
|
||||
|
||||
test_dir = '/scratch/hu/pose_forecast/hot3d_hoigaze/'
|
||||
test_file = 'P0002_016222d1_kitchen_0_1527_'
|
||||
test_path = test_dir + test_file
|
||||
test_dataset = hot3d_aria_single_dataset.hot3d_aria_dataset(test_path, seq_len)
|
||||
print("Test data size: {}".format(test_dataset.dataset.shape))
|
||||
|
||||
test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
|
||||
|
||||
predictions = []
|
||||
for data in test_loader:
|
||||
batch_size, seq_n, dim = data.shape
|
||||
joint_number = opt.joint_number
|
||||
object_num = opt.object_num
|
||||
data = data.float().to(opt.cuda_idx)
|
||||
|
||||
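# per-frame layout of 'data': gaze ground truth (3) | joint positions
# (joint_number*3) | head direction (3) | object bounding-box vertices
# (2 hand-based rankings x object_num objects x 8 vertices x 3)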
ground_truth = data.clone()[:, :, :3]
|
||||
joints = data.clone()[:, :, 3:(joint_number+1)*3]
|
||||
head_directions = data.clone()[:, :, (joint_number+1)*3:(joint_number+2)*3]
|
||||
object_positions = data.clone()[:, :, (joint_number+2)*3:(joint_number+2+8*object_num*2)*3]
|
||||
input = torch.cat((joints, head_directions), dim=2)
|
||||
input = torch.cat((input, object_positions), dim=2)
|
||||
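# hierarchical inference: first recognise the attended hand, then feed its
# softmax probabilities to the gaze estimator together with the other features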
hand_prd = hand_model(input)
|
||||
hand_prd = torch.nn.functional.softmax(hand_prd, dim=2)
|
||||
input = torch.cat((input, hand_prd), dim=2)
|
||||
prediction = net(input)
|
||||
|
||||
prediction_cpu = torch.cat((ground_truth, head_directions), dim=2)
|
||||
prediction_cpu = torch.cat((prediction_cpu, prediction), dim=2)
|
||||
prediction_cpu = prediction_cpu.cpu().data.numpy()
|
||||
if len(predictions) == 0:
|
||||
predictions = prediction_cpu
|
||||
else:
|
||||
predictions = np.concatenate((predictions, prediction_cpu), axis=0)
|
||||
|
||||
predictions = predictions.reshape(-1, predictions.shape[2])
|
||||
ground_truth = predictions[:, :3]
|
||||
head = predictions[:, 3:6]
|
||||
head_cos = np.sum(head*ground_truth, 1)
|
||||
head_cos = np.clip(head_cos, -1, 1)
|
||||
head_errors = np.arccos(head_cos)/np.pi * 180.0
|
||||
print('Average baseline error: {:.2f} degree'.format(np.mean(head_errors)))
|
||||
|
||||
prediction = predictions[:, -3:]
|
||||
prd_cos = np.sum(prediction*ground_truth, 1)
|
||||
prd_cos = np.clip(prd_cos, -1, 1)
|
||||
prediction_errors = np.arccos(prd_cos)/np.pi * 180.0
|
||||
print('Average prediction error: {:.2f} degree'.format(np.mean(prediction_errors)))
|
||||
|
||||
save_dir = '/scratch/hu/pose_forecast/hot3d_hoigaze_prd/'
|
||||
save_path = save_dir + test_file + "hoigaze.npy"
|
||||
np.save(save_path, prediction)
|
||||
save_path = save_dir + test_file + "gaze.npy"
|
||||
np.save(save_path, ground_truth)
|
||||
|
||||
|
||||
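# note: dot products of unit vectors can drift slightly outside [-1, 1] due to
# floating-point error; acos_safe matches torch.acos for |x| <= 1-eps and
# continues linearly beyond, keeping both the value and its gradient finite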
def acos_safe(x, eps=1e-6):
|
||||
slope = np.arccos(1-eps) / eps
|
||||
buf = torch.empty_like(x)
|
||||
good = abs(x) <= 1-eps
|
||||
bad = ~good
|
||||
sign = torch.sign(x[bad])
|
||||
buf[good] = torch.acos(x[good])
|
||||
buf[bad] = torch.acos(sign * (1 - eps)) - slope*sign*(abs(x[bad]) - 1 + eps)
|
||||
return buf
|
||||
|
||||
|
||||
def run_model(net, optimizer=None, is_train=1, data_loader=None, opt=None):
|
||||
if is_train == 1:
|
||||
net.train()
|
||||
else:
|
||||
net.eval()
|
||||
|
||||
if opt.is_eval and opt.save_predictions:
|
||||
predictions = []
|
||||
|
||||
prediction_error_average = 0
|
||||
baseline_error_average = 0
|
||||
criterion = torch.nn.MSELoss(reduction='none')
|
||||
|
||||
n = 0
|
||||
input_n = opt.seq_len
|
||||
|
||||
for data in data_loader:
|
||||
batch_size, seq_n, dim = data.shape
|
||||
joint_number = opt.joint_number
|
||||
object_num = opt.object_num
|
||||
# when only one sample in this batch
|
||||
if batch_size == 1 and is_train == 1:
|
||||
continue
|
||||
n += batch_size
|
||||
data = data.float().to(opt.cuda_idx)
|
||||
|
||||
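# same per-frame layout as in eval_single(), with three extra columns appended
# after the object vertices: attended-hand softmax probabilities (2) +
# ground-truth attended-hand label (1)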
ground_truth = data.clone()[:, :, :3]
|
||||
joints = data.clone()[:, :, 3:(joint_number+1)*3]
|
||||
head_directions = data.clone()[:, :, (joint_number+1)*3:(joint_number+2)*3]
|
||||
attended_hand_prd = data.clone()[:, :, (joint_number+2+8*object_num*2)*3:(joint_number+2+8*object_num*2)*3+2]
|
||||
attended_hand_gt = data.clone()[:, :, (joint_number+2+8*object_num*2)*3+2:(joint_number+2+8*object_num*2)*3+3]
|
||||
|
||||
input = torch.cat((joints, head_directions), dim=2)
|
||||
if object_num > 0:
|
||||
object_positions = data.clone()[:, :, (joint_number+2)*3:(joint_number+2+8*object_num*2)*3]
|
||||
input = torch.cat((input, object_positions), dim=2)
|
||||
input = torch.cat((input, attended_hand_prd), dim=2)
|
||||
input = torch.cat((input, attended_hand_gt), dim=2)
|
||||
prediction = net(input, input_n=input_n)
|
||||
|
||||
if opt.is_eval and opt.save_predictions:
|
||||
# ground_truth + joints + head_directions + object_positions + attended_hand_prd + attended_hand_gt + predictions
|
||||
prediction_cpu = torch.cat((ground_truth, input), dim=2)
|
||||
prediction_cpu = torch.cat((prediction_cpu, prediction), dim=2)
|
||||
prediction_cpu = prediction_cpu.cpu().data.numpy()
|
||||
if len(predictions) == 0:
|
||||
predictions = prediction_cpu
|
||||
else:
|
||||
predictions = np.concatenate((predictions, prediction_cpu), axis=0)
|
||||
|
||||
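# eye-head coordination weighting: frames whose gaze direction is closely
# aligned with the head direction (cosine above the threshold) are re-weighted
# by gaze_head_loss_factor in the loss below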
gaze_head_cos = torch.sum(ground_truth*head_directions, dim=2, keepdim=True)
|
||||
gaze_weight = torch.where(gaze_head_cos>opt.gaze_head_cos_threshold, opt.gaze_head_loss_factor, 1.0)
|
||||
|
||||
loss = criterion(ground_truth, prediction)
|
||||
loss = torch.mean(loss*gaze_weight)
|
||||
|
||||
if is_train == 1:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Calculate prediction errors
|
||||
error = torch.mean(acos_safe(torch.sum(ground_truth*prediction, 2)))/torch.tensor(math.pi) * 180.0
|
||||
prediction_error_average += error.cpu().data.numpy() * batch_size
|
||||
|
||||
# Use head directions as the baseline
|
||||
baseline_error = torch.mean(acos_safe(torch.sum(ground_truth*head_directions, 2)))/torch.tensor(math.pi) * 180.0
|
||||
baseline_error_average += baseline_error.cpu().data.numpy() * batch_size
|
||||
|
||||
result = {}
|
||||
result["prediction_error_average"] = prediction_error_average / n
|
||||
result["baseline_error_average"] = baseline_error_average / n
|
||||
|
||||
if opt.is_eval and opt.save_predictions:
|
||||
return result, predictions
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
option = options().parse()
|
||||
if not option.is_eval:
|
||||
main(option)
|
||||
else:
|
||||
eval(option)
|
||||
#eval_single(option)
|
26
hot3d_processing/README.md
Normal file
|
@ -0,0 +1,26 @@
|
|||
## Code to process the HOT3D dataset
|
||||
|
||||
|
||||
## Usage:
|
||||
Step 1: Follow the instructions at the official repository https://github.com/facebookresearch/hot3d to prepare the environment and download the dataset. Also install Jupyter, since the processing code runs in Jupyter notebooks.
|
||||
|
||||
Step 2: Set 'dataset_path', 'dataset_processed_path', 'object_library_path', and 'mano_hand_model_path' in 'hot3d_aria_preprocessing.ipynb'; copy 'hot3d_aria_preprocessing.ipynb', 'hot3d_aria_scene.csv', 'hot3d_objects.csv', and 'utils' into the official repository ('hot3d/hot3d/'); then run the notebook to process the dataset.
|
||||
|
||||
Step 3 (optional but highly recommended): Set 'data_path', 'object_library_path', and 'mano_hand_model_path' in 'hot3d_aria_visualisation.ipynb'; copy 'hot3d_aria_visualisation.ipynb' and 'mano_hand_pose_init' into the official repository ('hot3d/hot3d/'); then run the notebook to visualise the processed data and get familiar with the dataset. A minimal loading sketch is shown below.
|
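A minimal sketch for inspecting one processed segment (assumes the default 'dataset_processed_path' and the '<sequence>_<scene>_<start>_<end>_' file prefix written by Step 2; adjust both to your setup):

```
import numpy as np

# example prefix produced by the preprocessing notebook (adjust to your paths)
prefix = '/scratch/hu/pose_forecast/hot3d_hoigaze/P0001_10a27bf7_room_721_890_'
gaze = np.load(prefix + 'gaze.npy')         # (frames, 6): gaze direction (3) + 3D gaze point (3)
head = np.load(prefix + 'head.npy')         # (frames, 10): direction (3) + translation (3) + quaternion xyzw (4)
hands = np.load(prefix + 'handjoints.npy')  # (frames, 122): left joints (60) + right joints (60) + attended/closest-hand labels (2)
print(gaze.shape, head.shape, hands.shape)
```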
||||
|
||||
|
||||
## Citations
|
||||
|
||||
```bibtex
|
||||
@inproceedings{hu25hoigaze,
|
||||
title={HOIGaze: Gaze Estimation During Hand-Object Interactions in Extended Reality Exploiting Eye-Hand-Head Coordination},
|
||||
author={Hu, Zhiming and Haeufle, Daniel and Schmitt, Syn and Bulling, Andreas},
|
||||
booktitle={Proceedings of the 2025 ACM Special Interest Group on Computer Graphics and Interactive Techniques},
|
||||
year={2025}}
|
||||
|
||||
@article{banerjee2024introducing,
|
||||
title={Introducing HOT3D: An Egocentric Dataset for 3D Hand and Object Tracking},
|
||||
author={Banerjee, Prithviraj and Shkodrani, Sindi and Moulon, Pierre and Hampali, Shreyas and Zhang, Fan and Fountain, Jade and Miller, Edward and Basol, Selen and Newcombe, Richard and Wang, Robert and others},
|
||||
journal={arXiv preprint arXiv:2406.09598},
|
||||
year={2024}}
|
||||
```
|
565
hot3d_processing/hot3d_aria_preprocessing.ipynb
Normal file
|
@ -0,0 +1,565 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bbfb8da4-14a5-44d8-b9b0-d66f032f09fb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.nice(5)\n",
|
||||
"import rerun as rr\n",
|
||||
"import numpy as np\n",
|
||||
"from math import tan\n",
|
||||
"import time\n",
|
||||
"from utils import remake_dir\n",
|
||||
"import pandas as pd\n",
|
||||
"from dataset_api import Hot3dDataProvider\n",
|
||||
"from data_loaders.loader_object_library import load_object_library\n",
|
||||
"from data_loaders.mano_layer import MANOHandModel\n",
|
||||
"from data_loaders.loader_masks import combine_mask_data, load_mask_data, MaskData\n",
|
||||
"from data_loaders.loader_hand_poses import Handedness, HandPose\n",
|
||||
"from data_loaders.loader_object_library import ObjectLibrary\n",
|
||||
"from data_loaders.headsets import Headset\n",
|
||||
"from projectaria_tools.core.stream_id import StreamId\n",
|
||||
"from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions\n",
|
||||
"from projectaria_tools.core.sophus import SE3\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"dataset_path = '/datasets/public/zhiming_datasets/hot3d/aria/'\n",
|
||||
"dataset_processed_path = '/scratch/hu/pose_forecast/hot3d_hoigaze/'\n",
|
||||
"object_library_path = '/datasets/public/zhiming_datasets/hot3d/assets/'\n",
|
||||
"mano_hand_model_path = '/datasets/public/zhiming_datasets/hot3d/mano_v1_2/models/'\n",
|
||||
"remake_dir(dataset_processed_path)\n",
|
||||
"dataset_info = pd.read_csv('hot3d_aria_scene.csv')\n",
|
||||
"valid_frame_length = 60 # 60 frames -> 2 seconds, dropout the recordings that are too short\n",
|
||||
"\n",
|
||||
"# init the object library\n",
|
||||
"if not os.path.exists(object_library_path):\n",
|
||||
" print(\"invalid library path.\")\n",
|
||||
" print(\"please do update the path to VALID values for your system.\")\n",
|
||||
" raise\n",
|
||||
"object_library = load_object_library(object_library_folderpath=object_library_path)\n",
|
||||
"\n",
|
||||
"# load the bounding box information of the objects\n",
|
||||
"object_info = pd.read_csv('hot3d_objects.csv')\n",
|
||||
"object_bbx = {}\n",
|
||||
"for i, uid in enumerate(object_info['object_uid']): \n",
|
||||
" bbx_x_min = object_info['bbx_x_min'][i]\n",
|
||||
" bbx_x_max = object_info['bbx_x_max'][i]\n",
|
||||
" bbx_y_min = object_info['bbx_y_min'][i]\n",
|
||||
" bbx_y_max = object_info['bbx_y_max'][i]\n",
|
||||
" bbx_z_min = object_info['bbx_z_min'][i]\n",
|
||||
" bbx_z_max = object_info['bbx_z_max'][i]\n",
|
||||
" bbx = [bbx_x_min, bbx_x_max, bbx_y_min, bbx_y_max, bbx_z_min, bbx_z_max]\n",
|
||||
" object_bbx[str(uid)] = bbx\n",
|
||||
" \n",
|
||||
"# init the HANDs model. If None, the UmeTrack HANDs model will be used\n",
|
||||
"mano_hand_model = None\n",
|
||||
"if mano_hand_model_path is not None:\n",
|
||||
" mano_hand_model = MANOHandModel(mano_hand_model_path)\n",
|
||||
" \n",
|
||||
"for i, seq in enumerate(dataset_info['sequence_name']):\n",
|
||||
" scene = dataset_info['scene'][i]\n",
|
||||
" print(\"\\nprocessing {}th seq: {}, scene: {}...\".format(i+1, seq, scene))\n",
|
||||
" seq_path = dataset_path + seq + '/'\n",
|
||||
" if not os.path.exists(seq_path):\n",
|
||||
" print(\"invalid input sequence path.\")\n",
|
||||
" print(\"please do update the path to VALID values for your system.\")\n",
|
||||
" raise\n",
|
||||
" save_path = dataset_processed_path + seq + '_' + scene + '_'\n",
|
||||
"\n",
|
||||
" # segment the sequence into valid and invalid parts using the masks \n",
|
||||
" mask_list = [\n",
|
||||
" \"masks/mask_object_pose_available.csv\",\n",
|
||||
" \"masks/mask_hand_pose_available.csv\", \n",
|
||||
" \"masks/mask_headset_pose_available.csv\",\n",
|
||||
" #\"masks/mask_object_visibility.csv\", \n",
|
||||
" #\"masks/mask_hand_visible.csv\",\n",
|
||||
" #\"masks/mask_good_exposure.csv\",\n",
|
||||
" \"masks/mask_qa_pass.csv\"]\n",
|
||||
" \n",
|
||||
" # load the referred masks\n",
|
||||
" mask_data_list = []\n",
|
||||
" for it in mask_list:\n",
|
||||
" if os.path.exists(os.path.join(seq_path, it)):\n",
|
||||
" ret = load_mask_data(os.path.join(seq_path, it))\n",
|
||||
" mask_data_list.append(ret)\n",
|
||||
" # combine the masks (you can choose logical \"and\"/\"or\")\n",
|
||||
" combined_masks = combine_mask_data(mask_data_list, \"and\")\n",
|
||||
" masks = []\n",
|
||||
" for value in combined_masks.data['214-1'].values():\n",
|
||||
" masks.append(value)\n",
|
||||
" print(\"valid frames: {}/{}\".format(sum(masks), len(masks)))\n",
|
||||
" \n",
|
||||
" # initialize hot3d data provider\n",
|
||||
" hot3d_data_provider = Hot3dDataProvider(\n",
|
||||
" sequence_folder=seq_path,\n",
|
||||
" object_library=object_library,\n",
|
||||
" mano_hand_model=mano_hand_model)\n",
|
||||
" #print(f\"data_provider statistics: {hot3d_data_provider.get_data_statistics()}\") \n",
|
||||
" # alias over the HAND pose data provider\n",
|
||||
" hand_data_provider = hot3d_data_provider.mano_hand_data_provider if hot3d_data_provider.mano_hand_data_provider is not None else hot3d_data_provider.umetrack_hand_data_provider\n",
|
||||
" # alias over the Object pose data provider\n",
|
||||
" object_pose_data_provider = hot3d_data_provider.object_pose_data_provider\n",
|
||||
" # alias over the HEADSET/Device pose data provider\n",
|
||||
" device_pose_provider = hot3d_data_provider.device_pose_data_provider\n",
|
||||
" # alias over the Device data provider\n",
|
||||
" device_data_provider = hot3d_data_provider.device_data_provider\n",
|
||||
" device_calibration = device_data_provider.get_device_calibration()\n",
|
||||
" transform_device_cpf = device_calibration.get_transform_device_cpf()\n",
|
||||
" # retrieve a list of timestamps for the sequence (in nanoseconds)\n",
|
||||
" timestamps = device_data_provider.get_sequence_timestamps()\n",
|
||||
" \n",
|
||||
" # segment valid data\n",
|
||||
" index = 0\n",
|
||||
" valid_frames = 0\n",
|
||||
" start_frames = []\n",
|
||||
" end_frames = []\n",
|
||||
" while(index<len(masks)):\n",
|
||||
" value = masks[index] \n",
|
||||
" if value == True:\n",
|
||||
" start = index\n",
|
||||
" start_frames.append(start)\n",
|
||||
" #print(\"start at: {}\".format(start))\n",
|
||||
" while(value == True):\n",
|
||||
" index += 1\n",
|
||||
" if index<len(masks):\n",
|
||||
" value = masks[index]\n",
|
||||
" else:\n",
|
||||
" break \n",
|
||||
" end = index-1\n",
|
||||
" end_frames.append(end)\n",
|
||||
" #print(\"end at: {}\".format(end))\n",
|
||||
" valid_frames += end - start + 1 \n",
|
||||
" else:\n",
|
||||
" index += 1\n",
|
||||
" \n",
|
||||
" segment_num = len(start_frames)\n",
|
||||
" local_time = time.asctime(time.localtime(time.time()))\n",
|
||||
" print('\\nprocessing starts at ' + local_time) \n",
|
||||
" for i in range(segment_num):\n",
|
||||
" start_frame = start_frames[i]\n",
|
||||
" end_frame = end_frames[i]\n",
|
||||
" frame_length = end_frame - start_frame + 1\n",
|
||||
" if frame_length < valid_frame_length:\n",
|
||||
" continue\n",
|
||||
" print(\"start frame: {}, end frame: {}, length: {}\".format(start_frame, end_frame, frame_length))\n",
|
||||
" \n",
|
||||
" timestamps_data = np.zeros((frame_length, 1))\n",
|
||||
" head_data = np.zeros((frame_length, 10)) # head_direction (3) + head_translation (3) + head_rotation (4, quat_xyzw)\n",
|
||||
" gaze_data = np.zeros((frame_length, 6)) # gaze_direction (3) + gaze_center_in_world (3) \n",
|
||||
" hand_data = np.zeros((frame_length, 44)) # left_hand (22) + right_hand (22), hand = wrist_pose (7, translation (3) + rotation (4)) + joint_angles (15) \n",
|
||||
" hand_joint_data = np.zeros((frame_length, 122)) # left_hand (20*3) + right_hand (20*3) + attended_hand_gt + attended_hand_baseline (closest_hand) \n",
|
||||
" hand_joint_initial_data = np.zeros((frame_length, 122)) # left_hand (20*3) + right_hand (20*3) + attended_hand_gt + closest_hand\n",
|
||||
" object_data = np.zeros((frame_length, 48)) # object_data (8) * 6 objects (at most 6 objects), object_data = object_uid (1) + object_pose (7, translation (3) + rotation (4)) \n",
|
||||
" object_bbx_data = np.zeros((frame_length, 144)) # bounding box information: 6 objects (at most 6 objects) * 8 vertexes * 3\n",
|
||||
" object_bbx_left_hand_data = np.zeros((frame_length, 144)) # bounding box information of the objects ranked using distances to the left hand\n",
|
||||
" object_bbx_right_hand_data = np.zeros((frame_length, 144)) # bounding box information of the objects ranked using distances to the right hand\n",
|
||||
" object_bbx_left_hand_initial_data = np.zeros((frame_length, 144)) # bounding box information of the objects ranked using distances to the left hand\n",
|
||||
" object_bbx_right_hand_initial_data = np.zeros((frame_length, 144)) # bounding box information of the objects ranked using distances to the right hand\n",
|
||||
" \n",
|
||||
" # extract the valid frames\n",
|
||||
" for frame in range(start_frame, end_frame+1):\n",
|
||||
" timestamp_ns = timestamps[frame]\n",
|
||||
" timestamps_data[frame-start_frame] = timestamp_ns\n",
|
||||
" \n",
|
||||
" # extract head data\n",
|
||||
" headset_pose3d_with_dt = device_pose_provider.get_pose_at_timestamp(\n",
|
||||
" timestamp_ns=timestamp_ns,\n",
|
||||
" time_query_options=TimeQueryOptions.CLOSEST,\n",
|
||||
" time_domain=TimeDomain.TIME_CODE) \n",
|
||||
" headset_pose3d = headset_pose3d_with_dt.pose3d\n",
|
||||
" T_world_device = headset_pose3d.T_world_device\n",
|
||||
" # use cpf pose as head pose, see https://facebookresearch.github.io/projectaria_tools/docs/data_formats/coordinate_convention/3d_coordinate_frame_convention\n",
|
||||
" T_world_cpf = T_world_device @ transform_device_cpf \n",
|
||||
" head_translation = T_world_cpf.translation()[0]\n",
|
||||
" head_center_in_cpf = np.array([0, 0, 1.0], dtype = np.float64)\n",
|
||||
" head_center_in_world = T_world_cpf @ head_center_in_cpf\n",
|
||||
" head_center_in_world = head_center_in_world.reshape(3, )\n",
|
||||
" head_direction = head_center_in_world - head_translation\n",
|
||||
" head_direction = np.array([x / np.linalg.norm(head_direction) for x in head_direction]) \n",
|
||||
" head_rotation = np.roll(T_world_cpf.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w\n",
|
||||
" head_data[frame-start_frame, 0:3] = head_direction\n",
|
||||
" head_data[frame-start_frame, 3:6] = head_translation\n",
|
||||
" head_data[frame-start_frame, 6:10] = head_rotation\n",
|
||||
" \n",
|
||||
" # extract eye gaze data\n",
|
||||
" aria_eye_gaze_data = device_data_provider.get_eye_gaze(timestamp_ns) \n",
|
||||
" yaw = aria_eye_gaze_data.yaw\n",
|
||||
" pitch = aria_eye_gaze_data.pitch\n",
|
||||
" depth = aria_eye_gaze_data.depth\n",
|
||||
" if depth == 0:\n",
|
||||
" depth = 1\n",
|
||||
" gaze_center_in_cpf = np.array([tan(yaw), tan(pitch), 1.0], dtype = np.float64)*depth\n",
|
||||
" gaze_center_in_world = T_world_cpf @ gaze_center_in_cpf\n",
|
||||
" gaze_center_in_world = gaze_center_in_world.reshape(3, )\n",
|
||||
" gaze_direction = gaze_center_in_world - head_translation\n",
|
||||
" gaze_direction = np.array([x / np.linalg.norm(gaze_direction) for x in gaze_direction])\n",
|
||||
" # in rare cases, yaw, pitch is nan\n",
|
||||
" if np.isnan(np.sum(gaze_direction)):\n",
|
||||
" # use previous frame as an alternative\n",
|
||||
" gaze_direction = gaze_data[frame-start_frame-1, 0:3]\n",
|
||||
" gaze_center_in_world = gaze_data[frame-start_frame-1, 3:6]\n",
|
||||
" gaze_data[frame-start_frame, 0:3] = gaze_direction\n",
|
||||
" gaze_data[frame-start_frame, 3:6] = gaze_center_in_world\n",
|
||||
" \n",
|
||||
" # extract hand data\n",
|
||||
" hand_poses_with_dt = hand_data_provider.get_pose_at_timestamp(\n",
|
||||
" timestamp_ns=timestamp_ns,\n",
|
||||
" time_query_options=TimeQueryOptions.CLOSEST,\n",
|
||||
" time_domain=TimeDomain.TIME_CODE) \n",
|
||||
" hand_pose_collection = hand_poses_with_dt.pose3d_collection\n",
|
||||
" left_hand = hand_pose_collection.poses[Handedness.Left]\n",
|
||||
" left_hand_translation = left_hand.wrist_pose.translation()[0]\n",
|
||||
" left_hand_rotation = np.roll(left_hand.wrist_pose.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w\n",
|
||||
" left_hand_joint_angles = left_hand.joint_angles\n",
|
||||
" left_hand_joints = hand_data_provider.get_hand_landmarks(left_hand).numpy().reshape(-1)\n",
|
||||
" left_hand_initial = HandPose(Handedness.Left, left_hand.wrist_pose, np.zeros(15))\n",
|
||||
" left_hand_initial_joints = hand_data_provider.get_hand_landmarks(left_hand_initial).numpy().reshape(-1) \n",
|
||||
" right_hand = hand_pose_collection.poses[Handedness.Right]\n",
|
||||
" right_hand_translation = right_hand.wrist_pose.translation()[0]\n",
|
||||
" right_hand_rotation = np.roll(right_hand.wrist_pose.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w\n",
|
||||
" right_hand_joint_angles = right_hand.joint_angles\n",
|
||||
" right_hand_joints = hand_data_provider.get_hand_landmarks(right_hand).numpy().reshape(-1)\n",
|
||||
" right_hand_initial = HandPose(Handedness.Right, right_hand.wrist_pose, np.zeros(15))\n",
|
||||
" right_hand_initial_joints = hand_data_provider.get_hand_landmarks(right_hand_initial).numpy().reshape(-1)\n",
|
||||
" \n",
|
||||
" left_hand_direction = np.mean(left_hand_joints.reshape((20, 3)), axis=0) - head_translation\n",
|
||||
" left_hand_direction = np.array([x / np.linalg.norm(left_hand_direction) for x in left_hand_direction]) \n",
|
||||
" left_hand_distance_to_gaze = np.arccos(np.sum(gaze_direction*left_hand_direction))\n",
|
||||
" right_hand_direction = np.mean(right_hand_joints.reshape((20, 3)), axis=0) - head_translation\n",
|
||||
" right_hand_direction = np.array([x / np.linalg.norm(right_hand_direction) for x in right_hand_direction]) \n",
|
||||
" right_hand_distance_to_gaze = np.arccos(np.sum(gaze_direction*right_hand_direction))\n",
|
||||
" if left_hand_distance_to_gaze < right_hand_distance_to_gaze:\n",
|
||||
" hand_joint_data[frame-start_frame, 120:121] = 0\n",
|
||||
" else:\n",
|
||||
" hand_joint_data[frame-start_frame, 120:121] = 1\n",
|
||||
"\n",
|
||||
" left_hand_initial_direction = np.mean(left_hand_initial_joints.reshape((20, 3)), axis=0) - head_translation\n",
|
||||
" left_hand_initial_direction = np.array([x / np.linalg.norm(left_hand_initial_direction) for x in left_hand_initial_direction]) \n",
|
||||
" left_hand_initial_distance_to_gaze = np.arccos(np.sum(gaze_direction*left_hand_initial_direction))\n",
|
||||
" right_hand_initial_direction = np.mean(right_hand_initial_joints.reshape((20, 3)), axis=0) - head_translation\n",
|
||||
" right_hand_initial_direction = np.array([x / np.linalg.norm(right_hand_initial_direction) for x in right_hand_initial_direction]) \n",
|
||||
" right_hand_initial_distance_to_gaze = np.arccos(np.sum(gaze_direction*right_hand_initial_direction))\n",
|
||||
" if left_hand_initial_distance_to_gaze < right_hand_initial_distance_to_gaze:\n",
|
||||
" hand_joint_initial_data[frame-start_frame, 120:121] = 0\n",
|
||||
" else:\n",
|
||||
" hand_joint_initial_data[frame-start_frame, 120:121] = 1\n",
|
||||
" \n",
|
||||
" hand_data[frame-start_frame, 0:3] = left_hand_translation\n",
|
||||
" hand_data[frame-start_frame, 3:7] = left_hand_rotation\n",
|
||||
" hand_data[frame-start_frame, 7:22] = left_hand_joint_angles\n",
|
||||
" hand_data[frame-start_frame, 22:25] = right_hand_translation\n",
|
||||
" hand_data[frame-start_frame, 25:29] = right_hand_rotation\n",
|
||||
" hand_data[frame-start_frame, 29:44] = right_hand_joint_angles\n",
|
||||
" hand_joint_data[frame-start_frame, 0:60] = left_hand_joints\n",
|
||||
" hand_joint_data[frame-start_frame, 60:120] = right_hand_joints\n",
|
||||
" hand_joint_initial_data[frame-start_frame, 0:60] = left_hand_initial_joints\n",
|
||||
" hand_joint_initial_data[frame-start_frame, 60:120] = right_hand_initial_joints\n",
|
||||
" \n",
|
||||
" # extract object data\n",
|
||||
" object_poses_with_dt = object_pose_data_provider.get_pose_at_timestamp(\n",
|
||||
" timestamp_ns=timestamp_ns,\n",
|
||||
" time_query_options=TimeQueryOptions.CLOSEST,\n",
|
||||
" time_domain=TimeDomain.TIME_CODE)\n",
|
||||
" objects_pose3d = object_poses_with_dt.pose3d_collection.poses\n",
|
||||
" object_num = len(objects_pose3d)\n",
|
||||
" objects_distance_to_left_hand = {}\n",
|
||||
" objects_distance_to_right_hand = {} \n",
|
||||
" objects_distance_to_left_hand_initial = {}\n",
|
||||
" objects_distance_to_right_hand_initial = {}\n",
|
||||
" objects_pose3d_dict = {}\n",
|
||||
" item = 0\n",
|
||||
" for (object_uid, object_pose3d) in objects_pose3d.items():\n",
|
||||
" object_translation = object_pose3d.T_world_object.translation()[0] \n",
|
||||
" object_distance_to_left_hand = np.mean(np.linalg.norm(left_hand_joints.reshape((20, 3))-object_translation, axis=1))\n",
|
||||
" object_distance_to_right_hand = np.mean(np.linalg.norm(right_hand_joints.reshape((20, 3))-object_translation, axis=1))\n",
|
||||
" object_distance_to_left_hand_initial = np.mean(np.linalg.norm(left_hand_initial_joints.reshape((20, 3))-object_translation, axis=1))\n",
|
||||
" object_distance_to_right_hand_initial = np.mean(np.linalg.norm(right_hand_initial_joints.reshape((20, 3))-object_translation, axis=1)) \n",
|
||||
" objects_distance_to_left_hand[object_uid] = object_distance_to_left_hand \n",
|
||||
" objects_distance_to_right_hand[object_uid] = object_distance_to_right_hand\n",
|
||||
" objects_distance_to_left_hand_initial[object_uid] = object_distance_to_left_hand_initial\n",
|
||||
" objects_distance_to_right_hand_initial[object_uid] = object_distance_to_right_hand_initial\n",
|
||||
" objects_pose3d_dict[object_uid] = object_pose3d.T_world_object \n",
|
||||
" item += 1\n",
|
||||
" \n",
|
||||
" objects_distance_to_left_hand_sorted = sorted(objects_distance_to_left_hand.items(), key = lambda kv:(kv[1], kv[0]))\n",
|
||||
" objects_distance_to_right_hand_sorted = sorted(objects_distance_to_right_hand.items(), key = lambda kv:(kv[1], kv[0])) \n",
|
||||
" left_object_closest_uid = objects_distance_to_left_hand_sorted[0][0]\n",
|
||||
" left_object_closest_distance = objects_distance_to_left_hand_sorted[0][1]\n",
|
||||
" right_object_closest_uid = objects_distance_to_right_hand_sorted[0][0]\n",
|
||||
" right_object_closest_distance = objects_distance_to_right_hand_sorted[0][1]\n",
|
||||
" if left_object_closest_distance < right_object_closest_distance:\n",
|
||||
" hand_joint_data[frame-start_frame, -1] = 0\n",
|
||||
" else:\n",
|
||||
" hand_joint_data[frame-start_frame, -1] = 1\n",
|
||||
"\n",
|
||||
" objects_distance_to_left_hand_initial_sorted = sorted(objects_distance_to_left_hand_initial.items(), key = lambda kv:(kv[1], kv[0]))\n",
|
||||
" objects_distance_to_right_hand_initial_sorted = sorted(objects_distance_to_right_hand_initial.items(), key = lambda kv:(kv[1], kv[0]))\n",
|
||||
" left_initial_object_closest_uid = objects_distance_to_left_hand_initial_sorted[0][0]\n",
|
||||
" left_initial_object_closest_distance = objects_distance_to_left_hand_initial_sorted[0][1]\n",
|
||||
" right_initial_object_closest_uid = objects_distance_to_right_hand_initial_sorted[0][0]\n",
|
||||
" right_initial_object_closest_distance = objects_distance_to_right_hand_initial_sorted[0][1]\n",
|
||||
" if left_initial_object_closest_distance < right_initial_object_closest_distance:\n",
|
||||
" hand_joint_initial_data[frame-start_frame, -1] = 0\n",
|
||||
" else:\n",
|
||||
" hand_joint_initial_data[frame-start_frame, -1] = 1\n",
|
||||
" \n",
|
||||
" item = 0\n",
|
||||
" for object_uid in objects_pose3d_dict:\n",
|
||||
" object_pose3d = objects_pose3d_dict[object_uid]\n",
|
||||
" object_translation = object_pose3d.translation()[0]\n",
|
||||
" object_rotation = np.roll(object_pose3d.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w \n",
|
||||
" object_data[frame-start_frame, item*8:item*8+1] = object_uid\n",
|
||||
" object_data[frame-start_frame, item*8+1:item*8+4] = object_translation\n",
|
||||
" object_data[frame-start_frame, item*8+4:item*8+8] = object_rotation\n",
|
||||
" bbx = object_bbx[object_uid]\n",
|
||||
" #print(\"uid: {}, bbx: {}\".format(object_uid, bbx))\n",
|
||||
" x_min = bbx[0]\n",
|
||||
" x_max = bbx[1]\n",
|
||||
" y_min = bbx[2]\n",
|
||||
" y_max = bbx[3]\n",
|
||||
" z_min = bbx[4]\n",
|
||||
" z_max = bbx[5]\n",
|
||||
" bbx_vertex = np.array([x_min, y_min, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, ) \n",
|
||||
" object_bbx_data[frame-start_frame, item*24:item*24+3] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_min, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_data[frame-start_frame, item*24+3:item*24+6] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_min, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_data[frame-start_frame, item*24+6:item*24+9] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_min, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_data[frame-start_frame, item*24+9:item*24+12] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_max, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_data[frame-start_frame, item*24+12:item*24+15] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_max, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_data[frame-start_frame, item*24+15:item*24+18] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_max, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_data[frame-start_frame, item*24+18:item*24+21] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_max, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_data[frame-start_frame, item*24+21:item*24+24] = bbx_vertex \n",
|
||||
" item += 1\n",
|
||||
" \n",
|
||||
" for item in range(len(objects_distance_to_left_hand_sorted)):\n",
|
||||
" object_uid = objects_distance_to_left_hand_sorted[item][0]\n",
|
||||
" object_pose3d = objects_pose3d_dict[object_uid] \n",
|
||||
" object_translation = object_pose3d.translation()[0]\n",
|
||||
" object_rotation = np.roll(object_pose3d.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w \n",
|
||||
" bbx = object_bbx[object_uid]\n",
|
||||
" #print(\"uid: {}, bbx: {}\".format(object_uid, bbx))\n",
|
||||
" x_min = bbx[0]\n",
|
||||
" x_max = bbx[1]\n",
|
||||
" y_min = bbx[2]\n",
|
||||
" y_max = bbx[3]\n",
|
||||
" z_min = bbx[4]\n",
|
||||
" z_max = bbx[5]\n",
|
||||
" bbx_vertex = np.array([x_min, y_min, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, ) \n",
|
||||
" object_bbx_left_hand_data[frame-start_frame, item*24:item*24+3] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_min, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_data[frame-start_frame, item*24+3:item*24+6] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_min, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_data[frame-start_frame, item*24+6:item*24+9] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_min, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_data[frame-start_frame, item*24+9:item*24+12] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_max, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_data[frame-start_frame, item*24+12:item*24+15] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_max, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_data[frame-start_frame, item*24+15:item*24+18] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_max, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_data[frame-start_frame, item*24+18:item*24+21] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_max, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_data[frame-start_frame, item*24+21:item*24+24] = bbx_vertex \n",
|
||||
"\n",
|
||||
" for item in range(len(objects_distance_to_right_hand_sorted)):\n",
|
||||
" object_uid = objects_distance_to_right_hand_sorted[item][0]\n",
|
||||
" object_pose3d = objects_pose3d_dict[object_uid] \n",
|
||||
" object_translation = object_pose3d.translation()[0]\n",
|
||||
" object_rotation = np.roll(object_pose3d.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w \n",
|
||||
" bbx = object_bbx[object_uid]\n",
|
||||
" #print(\"uid: {}, bbx: {}\".format(object_uid, bbx))\n",
|
||||
" x_min = bbx[0]\n",
|
||||
" x_max = bbx[1]\n",
|
||||
" y_min = bbx[2]\n",
|
||||
" y_max = bbx[3]\n",
|
||||
" z_min = bbx[4]\n",
|
||||
" z_max = bbx[5]\n",
|
||||
" bbx_vertex = np.array([x_min, y_min, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, ) \n",
|
||||
" object_bbx_right_hand_data[frame-start_frame, item*24:item*24+3] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_min, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_data[frame-start_frame, item*24+3:item*24+6] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_min, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_data[frame-start_frame, item*24+6:item*24+9] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_min, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_data[frame-start_frame, item*24+9:item*24+12] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_max, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_data[frame-start_frame, item*24+12:item*24+15] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_max, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_data[frame-start_frame, item*24+15:item*24+18] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_max, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_data[frame-start_frame, item*24+18:item*24+21] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_max, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_data[frame-start_frame, item*24+21:item*24+24] = bbx_vertex \n",
|
||||
"\n",
|
||||
" for item in range(len(objects_distance_to_left_hand_initial_sorted)):\n",
|
||||
" object_uid = objects_distance_to_left_hand_initial_sorted[item][0]\n",
|
||||
" object_pose3d = objects_pose3d_dict[object_uid] \n",
|
||||
" object_translation = object_pose3d.translation()[0]\n",
|
||||
" object_rotation = np.roll(object_pose3d.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w \n",
|
||||
" bbx = object_bbx[object_uid]\n",
|
||||
" #print(\"uid: {}, bbx: {}\".format(object_uid, bbx))\n",
|
||||
" x_min = bbx[0]\n",
|
||||
" x_max = bbx[1]\n",
|
||||
" y_min = bbx[2]\n",
|
||||
" y_max = bbx[3]\n",
|
||||
" z_min = bbx[4]\n",
|
||||
" z_max = bbx[5]\n",
|
||||
" bbx_vertex = np.array([x_min, y_min, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, ) \n",
|
||||
" object_bbx_left_hand_initial_data[frame-start_frame, item*24:item*24+3] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_min, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+3:item*24+6] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_min, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+6:item*24+9] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_min, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+9:item*24+12] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_max, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+12:item*24+15] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_max, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+15:item*24+18] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_max, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+18:item*24+21] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_max, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_left_hand_initial_data[frame-start_frame, item*24+21:item*24+24] = bbx_vertex \n",
|
||||
"\n",
|
||||
" for item in range(len(objects_distance_to_right_hand_initial_sorted)):\n",
|
||||
" object_uid = objects_distance_to_right_hand_initial_sorted[item][0]\n",
|
||||
" object_pose3d = objects_pose3d_dict[object_uid] \n",
|
||||
" object_translation = object_pose3d.translation()[0]\n",
|
||||
" object_rotation = np.roll(object_pose3d.rotation().to_quat()[0], -1) # change from w,x,y,z to x,y,z,w \n",
|
||||
" bbx = object_bbx[object_uid]\n",
|
||||
" #print(\"uid: {}, bbx: {}\".format(object_uid, bbx))\n",
|
||||
" x_min = bbx[0]\n",
|
||||
" x_max = bbx[1]\n",
|
||||
" y_min = bbx[2]\n",
|
||||
" y_max = bbx[3]\n",
|
||||
" z_min = bbx[4]\n",
|
||||
" z_max = bbx[5]\n",
|
||||
" bbx_vertex = np.array([x_min, y_min, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, ) \n",
|
||||
" object_bbx_right_hand_initial_data[frame-start_frame, item*24:item*24+3] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_min, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+3:item*24+6] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_min, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+6:item*24+9] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_min, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+9:item*24+12] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_max, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+12:item*24+15] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_max, z_max], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+15:item*24+18] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_max, y_max, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+18:item*24+21] = bbx_vertex\n",
|
||||
" bbx_vertex = np.array([x_min, y_max, z_min], dtype = np.float64) \n",
|
||||
" bbx_vertex = (object_pose3d @ bbx_vertex).reshape(3, )\n",
|
||||
" object_bbx_right_hand_initial_data[frame-start_frame, item*24+21:item*24+24] = bbx_vertex \n",
|
||||
" \n",
|
||||
" # save the data\n",
|
||||
" timestamps_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_timestamps.npy'\n",
|
||||
" head_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_head.npy'\n",
|
||||
" gaze_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_gaze.npy'\n",
|
||||
" hand_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_hand.npy' \n",
|
||||
" hand_joint_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_handjoints.npy'\n",
|
||||
" hand_joint_initial_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_inithandjoints.npy'\n",
|
||||
" object_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_objects.npy'\n",
|
||||
" object_bbx_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_object_bbx.npy'\n",
|
||||
" object_bbx_left_hand_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_object_bbxleft.npy'\n",
|
||||
" object_bbx_right_hand_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_object_bbxright.npy'\n",
|
||||
" object_bbx_left_hand_initial_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_object_initbbxleft.npy'\n",
|
||||
" object_bbx_right_hand_initial_path = save_path + str(start_frame) + \"_\" + str(end_frame) + '_object_initbbxright.npy'\n",
|
||||
" \n",
|
||||
" np.save(timestamps_path, timestamps_data)\n",
|
||||
" np.save(head_path, head_data)\n",
|
||||
" np.save(gaze_path, gaze_data)\n",
|
||||
" np.save(hand_path, hand_data)\n",
|
||||
" np.save(hand_joint_path, hand_joint_data)\n",
|
||||
" np.save(hand_joint_initial_path, hand_joint_initial_data)\n",
|
||||
" np.save(object_path, object_data)\n",
|
||||
" np.save(object_bbx_path, object_bbx_data)\n",
|
||||
" np.save(object_bbx_left_hand_path, object_bbx_left_hand_data)\n",
|
||||
" np.save(object_bbx_right_hand_path, object_bbx_right_hand_data)\n",
|
||||
" np.save(object_bbx_left_hand_initial_path, object_bbx_left_hand_initial_data)\n",
|
||||
" np.save(object_bbx_right_hand_initial_path, object_bbx_right_hand_initial_data)\n",
|
||||
" \n",
|
||||
" local_time = time.asctime(time.localtime(time.time()))\n",
|
||||
" print('\\nprocessing ends at ' + local_time)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec30e81d-8812-4665-be58-00a7e0aa1915",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
137
hot3d_processing/hot3d_aria_scene.csv
Normal file
|
@ -0,0 +1,137 @@
|
|||
sequence_name,scene
|
||||
P0001_10a27bf7,room
|
||||
P0001_15c4300c,kitchen
|
||||
P0001_23fa0ee8,office
|
||||
P0001_4bf4e21a,room
|
||||
P0001_550ea2ac,room
|
||||
P0001_624f2ba9,kitchen
|
||||
P0001_83b755aa,room
|
||||
P0001_8d136980,room
|
||||
P0001_95eabc0a,kitchen
|
||||
P0001_9b6feab7,office
|
||||
P0001_9c030609,room
|
||||
P0001_a68492d5,kitchen
|
||||
P0001_a9d6c83d,room
|
||||
P0001_b2bcbe28,room
|
||||
P0001_f6cc0cc8,office
|
||||
P0001_f71fc9b1,room
|
||||
P0002_016222d1,kitchen
|
||||
P0002_2ea9af5b,kitchen
|
||||
P0002_59a84a3a,kitchen
|
||||
P0002_65085bfc,kitchen
|
||||
P0002_c68438ce,kitchen
|
||||
P0003_1fa0bb65,room
|
||||
P0003_5766eae8,kitchen
|
||||
P0003_743d3087,office
|
||||
P0003_c701bd11,office
|
||||
P0003_cbd17f20,kitchen
|
||||
P0003_e3a74169,kitchen
|
||||
P0003_ebdc6ff7,kitchen
|
||||
P0003_f4730606,kitchen
|
||||
P0009_02511c2f,room
|
||||
P0009_0e02bf39,kitchen
|
||||
P0009_1b21bb01,kitchen
|
||||
P0009_247b5403,kitchen
|
||||
P0009_24af5475,room
|
||||
P0009_37e30628,office
|
||||
P0009_5bb7cc2a,room
|
||||
P0009_5f2583e0,room
|
||||
P0009_88a58b3b,room
|
||||
P0009_8c0caf6c,kitchen
|
||||
P0009_8c95667c,room
|
||||
P0009_911a4ba5,room
|
||||
P0009_927973d7,office
|
||||
P0009_9e03121a,kitchen
|
||||
P0009_b1e71d7f,room
|
||||
P0009_b7d8c222,kitchen
|
||||
P0009_cf323827,room
|
||||
P0009_e27ad19f,room
|
||||
P0009_e71e2f24,kitchen
|
||||
P0009_ea91844d,office
|
||||
P0010_0ecbf39f,kitchen
|
||||
P0010_1573a837,office
|
||||
P0010_160e551c,room
|
||||
P0010_1c9fe708,kitchen
|
||||
P0010_41c4c626,kitchen
|
||||
P0010_573b2de6,office
|
||||
P0010_6d15aa57,room
|
||||
P0010_8ff3e5c4,room
|
||||
P0010_924e574e,kitchen
|
||||
P0010_99385ad8,room
|
||||
P0010_b152143e,office
|
||||
P0010_e481fd15,room
|
||||
P0010_fd1891a8,room
|
||||
P0011_0c2d00ed,kitchen
|
||||
P0011_11475e24,kitchen
|
||||
P0011_2255f410,kitchen
|
||||
P0011_30a5b61d,room
|
||||
P0011_451a7734,room
|
||||
P0011_47878e48,office
|
||||
P0011_4dcdcfec,office
|
||||
P0011_72efb935,kitchen
|
||||
P0011_76ea6d47,kitchen
|
||||
P0011_7a8886e7,room
|
||||
P0011_7c7e8d86,room
|
||||
P0011_a497658f,office
|
||||
P0011_ad5f7dee,room
|
||||
P0011_b0c895c1,kitchen
|
||||
P0011_ccc678c7,room
|
||||
P0011_cd22f5e0,kitchen
|
||||
P0011_cee8fe4f,room
|
||||
P0011_d0f907a0,office
|
||||
P0011_dbf1ebb7,room
|
||||
P0011_e5cb6a80,kitchen
|
||||
P0011_ff7cfc3a,kitchen
|
||||
P0012_0323d770,room
|
||||
P0012_0a21e4c2,room
|
||||
P0012_119de519,room
|
||||
P0012_130a66e1,office
|
||||
P0012_1c45bbd9,office
|
||||
P0012_44c5f677,kitchen
|
||||
P0012_476bae57,kitchen
|
||||
P0012_6e4f7815,office
|
||||
P0012_73e66984,room
|
||||
P0012_915e71c6,office
|
||||
P0012_af3dab9a,room
|
||||
P0012_b8fc4c1b,kitchen
|
||||
P0012_c06d939b,kitchen
|
||||
P0012_c4353c31,room
|
||||
P0012_ca1f6626,room
|
||||
P0012_d6272ce1,office
|
||||
P0012_d85e10f6,kitchen
|
||||
P0012_db543fe8,kitchen
|
||||
P0012_e846d3cc,kitchen
|
||||
P0012_e97d31b6,kitchen
|
||||
P0012_f1a33781,room
|
||||
P0012_f7e3880b,room
|
||||
P0014_0de48e2c,room
|
||||
P0014_1fdca00e,office
|
||||
P0014_24cb3bf0,kitchen
|
||||
P0014_65c6a968,office
|
||||
P0014_6db96fd0,kitchen
|
||||
P0014_8254f925,kitchen
|
||||
P0014_84ea2dcc,kitchen
|
||||
P0014_9a25ec6a,office
|
||||
P0014_9b7a0725,room
|
||||
P0014_e40eec5d,kitchen
|
||||
P0014_f7ba43e0,room
|
||||
P0015_179e1b84,kitchen
|
||||
P0015_1eb4f17f,kitchen
|
||||
P0015_367cc58d,room
|
||||
P0015_3a9bb2ae,room
|
||||
P0015_3c7f5241,room
|
||||
P0015_3f46732d,room
|
||||
P0015_42b8b389,room
|
||||
P0015_4e5d1c14,room
|
||||
P0015_53358a64,office
|
||||
P0015_60573a3b,kitchen
|
||||
P0015_745920f0,kitchen
|
||||
P0015_7e95628d,room
|
||||
P0015_7fc6548d,kitchen
|
||||
P0015_b0c5102b,room
|
||||
P0015_c3e1f590,kitchen
|
||||
P0015_cc5dfd13,room
|
||||
P0015_cc739faa,room
|
||||
P0015_dbe00981,room
|
||||
P0015_e3f96b3f,office
|
||||
P0015_e7458eb3,office
|
|
296
hot3d_processing/hot3d_aria_visualisation.ipynb
Normal file
|
@ -0,0 +1,296 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bbfb8da4-14a5-44d8-b9b0-d66f032f09fb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.nice(5)\n",
|
||||
"import rerun as rr\n",
|
||||
"import numpy as np\n",
|
||||
"from math import tan\n",
|
||||
"import time\n",
|
||||
"from utils import remake_dir\n",
|
||||
"import pandas as pd\n",
|
||||
"from data_loaders.ManoHandDataProvider import MANOHandDataProvider\n",
|
||||
"from data_loaders.loader_object_library import load_object_library\n",
|
||||
"from data_loaders.mano_layer import MANOHandModel\n",
|
||||
"from data_loaders.loader_hand_poses import Handedness, HandPose\n",
|
||||
"from data_loaders.hand_common import LANDMARK_CONNECTIVITY\n",
|
||||
"from data_loaders.loader_object_library import ObjectLibrary\n",
|
||||
"from data_loaders.headsets import Headset\n",
|
||||
"from projectaria_tools.core.stream_id import StreamId\n",
|
||||
"from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions\n",
|
||||
"from projectaria_tools.core.sophus import SE3\n",
|
||||
"from projectaria_tools.utils.rerun_helpers import ToTransform3D\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"data_path = '/scratch/hu/pose_forecast/hot3d_hoigaze/P0001_10a27bf7_room_721_890_'\n",
|
||||
"timestamps_path = data_path + 'timestamps.npy'\n",
|
||||
"head_path = data_path + 'head.npy'\n",
|
||||
"gaze_path = data_path + 'gaze.npy'\n",
|
||||
"hand_path = data_path + 'hand.npy'\n",
|
||||
"hand_joint_path = data_path + 'handjoints.npy'\n",
|
||||
"object_path = data_path + 'objects.npy'\n",
|
||||
"object_bbx_path = data_path + 'object_bbx.npy'\n",
|
||||
"object_library_path = '/datasets/public/zhiming_datasets/hot3d/assets/'\n",
|
||||
"mano_hand_model_path = '/datasets/public/zhiming_datasets/hot3d/mano_v1_2/models/'\n",
|
||||
"\n",
|
||||
"show_bbx = False\n",
|
||||
"show_hand_mesh = True\n",
|
||||
"# init the object library\n",
|
||||
"if not os.path.exists(object_library_path):\n",
|
||||
" print(\"invalid object library path.\")\n",
|
||||
" print(\"please follow the instructions at https://github.com/facebookresearch/hot3d to Download the HOT3D Assets Dataset\") \n",
|
||||
" raise \n",
|
||||
"object_library = load_object_library(object_library_folderpath=object_library_path)\n",
|
||||
"\n",
|
||||
"# init the HANDs model\n",
|
||||
"if not os.path.exists(mano_hand_model_path):\n",
|
||||
" print(\"invalid mano hand model path.\")\n",
|
||||
" print(\"please follow the instructions at https://github.com/facebookresearch/hot3d to Download the MANO files\")\n",
|
||||
" raise\n",
|
||||
"mano_hand_model = MANOHandModel(mano_hand_model_path)\n",
|
||||
"\n",
|
||||
"timestamps_data = np.load(timestamps_path)\n",
|
||||
"head_data = np.load(head_path) # head_direction (3) + head_translation (3) + head_rotation (4, quat_xyzw)\n",
|
||||
"gaze_data = np.load(gaze_path) # gaze_direction (3) + gaze_center_in_world (3)\n",
|
||||
"hand_data = np.load(hand_path) # left_hand (22) + right_hand (22), hand = wrist_pose (7, translation (3) + rotation (4)) + joint_angles (15)\n",
|
||||
"hand_joint_data = np.load(hand_joint_path) # left_hand (20*3) + right_hand (20*3)\n",
|
||||
"object_data = np.load(object_path) # object_data (8) * 6 objects (at most 6 objects), object_data = object_uid (1) + object_pose (7, translation (3) + rotation (4))\n",
|
||||
"object_bbx_data = np.load(object_bbx_path) # bounding box information: 6 objects (at most 6 objects) * 8 vertexes * 3\n",
|
||||
"frame_length = len(head_data)\n",
|
||||
"\n",
|
||||
"# alias over the HAND pose data provider\n",
|
||||
"hand_data_provider = MANOHandDataProvider('./mano_hand_pose_init/mano_hand_pose_trajectory.jsonl', mano_hand_model)\n",
|
||||
"# keep track of what 3D assets have been loaded/unloaded so we will load them only when needed\n",
|
||||
"object_cache_status = {}\n",
|
||||
"\n",
|
||||
"# Init a rerun context\n",
|
||||
"rr.init(\"hot3d-aria\")\n",
|
||||
"rec = rr.memory_recording()\n",
|
||||
"\n",
|
||||
"def log_pose(\n",
|
||||
" pose: SE3,\n",
|
||||
" label: str,\n",
|
||||
" static=False\n",
|
||||
") -> None:\n",
|
||||
" rr.log(label, ToTransform3D(pose, False), static=static)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for i in range(frame_length):\n",
|
||||
" timestamp_ns = timestamps_data[i, 0]\n",
|
||||
" rr.set_time_nanos(\"synchronization_time\", int(timestamp_ns))\n",
|
||||
" rr.set_time_sequence(\"timestamp\", int(timestamp_ns))\n",
|
||||
" \n",
|
||||
" head_direction = head_data[i, 0:3]\n",
|
||||
" head_translation = head_data[i, 3:6]\n",
|
||||
" head_rotation = head_data[i, 6:10] \n",
|
||||
" gaze_direction = gaze_data[i, 0:3]\n",
|
||||
" gaze_center_in_world = gaze_data[i, 3:6]\n",
|
||||
" left_hand_translation = hand_data[i, 0:3]\n",
|
||||
" left_hand_rotation = hand_data[i, 3:7]\n",
|
||||
" left_hand_joint_angles = hand_data[i, 7:22]\n",
|
||||
" left_hand_joints = hand_joint_data[i, 0:60].reshape((20, 3))\n",
|
||||
" right_hand_translation = hand_data[i, 22:25]\n",
|
||||
" right_hand_rotation = hand_data[i, 25:29]\n",
|
||||
" right_hand_joint_angles = hand_data[i, 29:44]\n",
|
||||
" right_hand_joints = hand_joint_data[i, 60:120].reshape((20, 3))\n",
|
||||
" left_hand_wrist_pose = SE3.from_quat_and_translation(left_hand_rotation[-1], left_hand_rotation[:-1], left_hand_translation)\n",
|
||||
" left_hand_pose = HandPose(Handedness.Left, left_hand_wrist_pose, left_hand_joint_angles)\n",
|
||||
" right_hand_wrist_pose = SE3.from_quat_and_translation(right_hand_rotation[-1], right_hand_rotation[:-1], right_hand_translation)\n",
|
||||
" right_hand_pose = HandPose(Handedness.Right, right_hand_wrist_pose, right_hand_joint_angles)\n",
|
||||
" \n",
|
||||
" # use cpf pose as head pose, see https://facebookresearch.github.io/projectaria_tools/docs/data_formats/coordinate_convention/3d_coordinate_frame_convention\n",
|
||||
" T_world_cpf = SE3.from_quat_and_translation(head_rotation[-1], head_rotation[:-1], head_translation)\n",
|
||||
" log_pose(pose=T_world_cpf, label=\"world/head_pose\")\n",
|
||||
" #rr.log(\n",
|
||||
" #\"world/head_direction\",\n",
|
||||
" #rr.Points3D([head_translation], radii=[0.003]),\n",
|
||||
" #rr.Arrows3D(vectors=[head_direction*0.4], colors=[[0, 0.8, 0.8, 0.5]]))\n",
|
||||
" #log_pose(pose=left_hand_wrist_pose, label=\"world/left_hand_pose\")\n",
|
||||
" #log_pose(pose=right_hand_wrist_pose, label=\"world/right_hand_pose\")\n",
|
||||
" \n",
|
||||
" rr.log(\n",
|
||||
" \"world/gaze_direction\",\n",
|
||||
" rr.Points3D([head_translation], radii=[0.003]),\n",
|
||||
" rr.Arrows3D(vectors=[gaze_direction*0.4], colors=[[0, 0.8, 0.2, 0.5]]))\n",
|
||||
" #print(\"frame: {}, gaze: {}\".format(i+1119, gaze_direction))\n",
|
||||
" \n",
|
||||
" # plot hands as a triangular mesh representation\n",
|
||||
" if show_hand_mesh:\n",
|
||||
" left_hand_mesh_vertices = hand_data_provider.get_hand_mesh_vertices(left_hand_pose)\n",
|
||||
" left_hand_triangles, left_hand_vertex_normals = hand_data_provider.get_hand_mesh_faces_and_normals(left_hand_pose) \n",
|
||||
" rr.log(\n",
|
||||
" f\"world/left_hand/mesh_faces\",\n",
|
||||
" rr.Mesh3D(\n",
|
||||
" vertex_positions=left_hand_mesh_vertices,\n",
|
||||
" vertex_normals=left_hand_vertex_normals,\n",
|
||||
" triangle_indices=left_hand_triangles))\n",
|
||||
" right_hand_mesh_vertices = hand_data_provider.get_hand_mesh_vertices(right_hand_pose)\n",
|
||||
" right_hand_triangles, right_hand_vertex_normals = hand_data_provider.get_hand_mesh_faces_and_normals(right_hand_pose)\n",
|
||||
" rr.log(\n",
|
||||
" f\"world/right_hand/mesh_faces\",\n",
|
||||
" rr.Mesh3D(\n",
|
||||
" vertex_positions=right_hand_mesh_vertices,\n",
|
||||
" vertex_normals=right_hand_vertex_normals,\n",
|
||||
" triangle_indices=right_hand_triangles))\n",
|
||||
" else:\n",
|
||||
" #left_hand_translation = np.array([0, 0, 0])\n",
|
||||
" #left_hand_rotation = np.array([0, 0, 0, 1])\n",
|
||||
" #left_hand_joint_angles = np.zeros(15)\n",
|
||||
" #left_hand_wrist_pose = SE3.from_quat_and_translation(left_hand_rotation[-1], left_hand_rotation[:-1], left_hand_translation)\n",
|
||||
" #log_pose(pose=left_hand_wrist_pose, label=\"world/left_hand_pose\")\n",
|
||||
" #left_hand_pose = HandPose(Handedness.Left, left_hand_wrist_pose, left_hand_joint_angles)\n",
|
||||
" #left_hand_joints = hand_data_provider.get_hand_landmarks(left_hand_pose)\n",
|
||||
" #left_hand_wrist = left_hand_joints[5, :].clone()\n",
|
||||
" #joint_number = left_hand_joints.shape[0]\n",
|
||||
" #for index in range(joint_number):\n",
|
||||
" # left_hand_joints[index, :] -= left_hand_wrist\n",
|
||||
" #for index in range(joint_number): \n",
|
||||
" # tmp = left_hand_joints[index, :].clone()\n",
|
||||
" # left_hand_joints[index, 1] = -tmp[2]\n",
|
||||
" # left_hand_joints[index, 2] = tmp[1]\n",
|
||||
" #for index in range(joint_number): \n",
|
||||
" # print(left_hand_joints[index])\n",
|
||||
" \n",
|
||||
" #right_hand_translation = np.array([0, 0, 0])\n",
|
||||
" #right_hand_rotation = np.array([0, 0, 0, 1])\n",
|
||||
" #right_hand_joint_angles = np.zeros(15)\n",
|
||||
" #right_hand_wrist_pose = SE3.from_quat_and_translation(right_hand_rotation[-1], right_hand_rotation[:-1], right_hand_translation)\n",
|
||||
" #log_pose(pose=right_hand_wrist_pose, label=\"world/right_hand_pose\")\n",
|
||||
" #right_hand_pose = HandPose(Handedness.Right, right_hand_wrist_pose, right_hand_joint_angles)\n",
|
||||
" #right_hand_joints = hand_data_provider.get_hand_landmarks(right_hand_pose)\n",
|
||||
" #right_hand_wrist = right_hand_joints[5, :].clone()\n",
|
||||
" #joint_number = right_hand_joints.shape[0]\n",
|
||||
" #for index in range(joint_number):\n",
|
||||
" # right_hand_joints[index, :] -= right_hand_wrist\n",
|
||||
" #for index in range(joint_number): \n",
|
||||
" # tmp = right_hand_joints[index, :].clone()\n",
|
||||
" # right_hand_joints[index, 1] = -tmp[2]\n",
|
||||
" # right_hand_joints[index, 2] = tmp[1]\n",
|
||||
" #for index in range(joint_number): \n",
|
||||
" # print(right_hand_joints[index])\n",
|
||||
" \n",
|
||||
" left_hand_skeleton = [connections\n",
|
||||
" for connectivity in LANDMARK_CONNECTIVITY\n",
|
||||
" for connections in [[left_hand_joints[it].tolist() for it in connectivity]]] \n",
|
||||
" rr.log(\n",
|
||||
" f\"world/left_hand_skeleton\",\n",
|
||||
" rr.LineStrips3D(left_hand_skeleton, radii=0.002),\n",
|
||||
" )\n",
|
||||
" right_hand_skeleton = [connections\n",
|
||||
" for connectivity in LANDMARK_CONNECTIVITY\n",
|
||||
" for connections in [[right_hand_joints[it].tolist() for it in connectivity]]] \n",
|
||||
" rr.log(\n",
|
||||
" f\"world/right_hand_skeleton\",\n",
|
||||
" rr.LineStrips3D(right_hand_skeleton, radii=0.002),\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # load objects\n",
|
||||
" object_num_max = 6\n",
|
||||
" logging_status = {} \n",
|
||||
" for item in range(object_num_max):\n",
|
||||
" object_uid = str(int(object_data[i, item*8]))\n",
|
||||
" if object_uid == '0':\n",
|
||||
" break\n",
|
||||
" logging_status[object_uid] = False\n",
|
||||
" object_num = len(logging_status)\n",
|
||||
" for item in range(object_num):\n",
|
||||
" object_uid = str(int(object_data[i, item*8]))\n",
|
||||
" object_name = object_library.object_id_to_name_dict[object_uid]\n",
|
||||
" object_cad_asset_filepath = ObjectLibrary.get_cad_asset_path(\n",
|
||||
" object_library_folderpath=object_library.asset_folder_name,\n",
|
||||
" object_id=object_uid)\n",
|
||||
" object_translation = object_data[i, item*8+1:item*8+4]\n",
|
||||
" object_rotation = object_data[i, item*8+4:item*8+8]\n",
|
||||
" object_pose = SE3.from_quat_and_translation(object_rotation[-1], object_rotation[:-1], object_translation) \n",
|
||||
" log_pose(pose=object_pose, label=f\"world/objects/{object_name}\")\n",
|
||||
" logging_status[object_uid] = True # mark this object as logged, so we can tell which objects were not seen this frame\n",
|
||||
" \n",
|
||||
" if show_bbx:\n",
|
||||
" bbx_vertex_0 = object_bbx_data[i, item*24:item*24+3]\n",
|
||||
" bbx_vertex_1 = object_bbx_data[i, item*24+3:item*24+6]\n",
|
||||
" bbx_vertex_2 = object_bbx_data[i, item*24+6:item*24+9]\n",
|
||||
" bbx_vertex_3 = object_bbx_data[i, item*24+9:item*24+12]\n",
|
||||
" bbx_vertex_4 = object_bbx_data[i, item*24+12:item*24+15]\n",
|
||||
" bbx_vertex_5 = object_bbx_data[i, item*24+15:item*24+18]\n",
|
||||
" bbx_vertex_6 = object_bbx_data[i, item*24+18:item*24+21]\n",
|
||||
" bbx_vertex_7 = object_bbx_data[i, item*24+21:item*24+24] \n",
|
||||
" points = [\n",
|
||||
" bbx_vertex_0,\n",
|
||||
" bbx_vertex_1,\n",
|
||||
" bbx_vertex_2,\n",
|
||||
" bbx_vertex_3,\n",
|
||||
" bbx_vertex_0,\n",
|
||||
" bbx_vertex_7,\n",
|
||||
" bbx_vertex_6,\n",
|
||||
" bbx_vertex_5,\n",
|
||||
" bbx_vertex_4,\n",
|
||||
" bbx_vertex_7,\n",
|
||||
" bbx_vertex_6,\n",
|
||||
" bbx_vertex_1,\n",
|
||||
" bbx_vertex_2,\n",
|
||||
" bbx_vertex_5,\n",
|
||||
" bbx_vertex_4,\n",
|
||||
" bbx_vertex_3]\n",
|
||||
" rr.log(f\"world/objects_bbx/{object_name}\", rr.LineStrips3D([points]))\n",
|
||||
" \n",
|
||||
" # link the corresponding 3D object to the pose\n",
|
||||
" if object_uid not in object_cache_status.keys():\n",
|
||||
" object_cache_status[object_uid] = True \n",
|
||||
" rr.log(\n",
|
||||
" f\"world/objects/{object_name}\",\n",
|
||||
" rr.Asset3D(path=object_cad_asset_filepath)) \n",
|
||||
" # clear the entities of objects that are no longer visible\n",
|
||||
" for object_uid, displayed in logging_status.items():\n",
|
||||
" if not displayed:\n",
|
||||
" object_name = object_library.object_id_to_name_dict[object_uid]\n",
|
||||
" rr.log(\n",
|
||||
" f\"world/objects/{object_name}\",\n",
|
||||
" rr.Clear.recursive())\n",
|
||||
" if show_bbx:\n",
|
||||
" rr.log(\n",
|
||||
" f\"world/objects_bbx/{object_name}\",\n",
|
||||
" rr.Clear.recursive()) \n",
|
||||
" if object_uid in object_cache_status.keys():\n",
|
||||
" del object_cache_status[object_uid] # we will log the mesh again\n",
|
||||
" \n",
|
||||
"# show the rerun window\n",
|
||||
"rr.notebook_show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7811d038-6150-46b3-ac62-d4875191e0f4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
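The packed `.npy` layouts documented in the notebook cells above can be decoded with plain NumPy slicing. A minimal sketch for the hand stream, assuming a file exported by the processing scripts (the `hand.npy` path is illustrative):

```python
import numpy as np

hand_data = np.load("hand.npy")  # (frames, 44): left_hand (22) + right_hand (22)

left_wrist_translation = hand_data[:, 0:3]    # wrist translation (3)
left_wrist_rotation = hand_data[:, 3:7]       # wrist rotation, quaternion xyzw (4)
left_joint_angles = hand_data[:, 7:22]        # 15 MANO joint angles
right_wrist_translation = hand_data[:, 22:25]
right_wrist_rotation = hand_data[:, 25:29]
right_joint_angles = hand_data[:, 29:44]
```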
34
hot3d_processing/hot3d_objects.csv
Normal file
|
@ -0,0 +1,34 @@
|
|||
object_uid,object_name,bbx_x_min,bbx_x_max,bbx_y_min,bbx_y_max,bbx_z_min,bbx_z_max
|
||||
106434519822892,bottle_bbq,-0.023,0.021,0,0.145,-0.02,0.04
|
||||
106957734975303,puzzle_toy,-0.039,0.019,0,0.058,-0.031,0.027
|
||||
111305855142671,aria_small,-0.073,0.073,-0.03,0.015,-0.141,0.023
|
||||
117658302265452,mug_patterned,-0.052,0.045,0,0.084,-0.07,0.06
|
||||
125863066770940,carton_milk,-0.039,0.033,0,0.192,-0.04,0.04
|
||||
143541090750632,holder_gray,-0.035,0.051,0,0.124,-0.04,0.03
|
||||
163340972252026,vase,-0.084,0.052,0,0.168,-0.052,0.048
|
||||
171735388712249,can_parmesan,-0.031,0.035,0,0.102,-0.04,0.03
|
||||
183136364331389,can_tomato_sauce,-0.034,0.036,0,0.082,-0.036,0.03
|
||||
194930206998778,bowl,-0.139,0.141,0,0.13,-0.14,0.14
|
||||
195041665639898,birdhouse_toy,-0.102,0.078,0,0.233,-0.09,0.091
|
||||
204462113746498,carton_oj,-0.032,0.043,0,0.192,-0.033,0.041
|
||||
208243983021975,dino_toy,-0.147,0.092,0,0.135,-0.08,0.06
|
||||
223371871635142,mug_white,-0.068,0.049,0,0.091,-0.04,0.05
|
||||
225397651484143,spoon_wooden,-0.034,0.034,0,0.018,-0.155,0.155
|
||||
228358276546933,coffee_pot,-0.066,0.08,0,0.151,-0.043,0.042
|
||||
232501673989606,holder_black,-0.022,0.04,0,0.111,-0.046,0.045
|
||||
238686662724712,whiteboard_marker,-0.01,0.009,0,0.121,-0.009,0.011
|
||||
248247003992116,plate_bamboo,-0.163,0.138,0,0.047,-0.17,0.132
|
||||
249541253457812,dvd_remote,-0.018,0.017,-0.082,0.082,0.012,0.038
|
||||
253405647833885,food_waffles,-0.048,0.022,0,0.021,-0.062,0.068
|
||||
258906041248094,whiteboard_eraser,-0.043,0.024,0,0.037,-0.074,0.074
|
||||
261746112525368,bottle_mustard,-0.03,0.018,0,0.149,-0.032,0.033
|
||||
265826671143948,mouse,-0.036,0.034,-0.021,0.021,-0.056,0.054
|
||||
270231216246839,spatula_red,-0.154,0.149,0,0.042,-0.114,0.097
|
||||
27078911029651,dumbbell_5lb,-0.042,0.042,-0.09,0.09,-0.037,0.037
|
||||
37787722328019,keyboard,-0.215,0.215,-0.018,0.002,-0.063,0.066
|
||||
4111539686391,flask,-0.051,0.069,0,0.301,-0.104,0.056
|
||||
5462893327580,cellphone,-0.037,0.036,-0.006,0.005,-0.077,0.077
|
||||
70709727230291,bottle_ranch,-0.023,0.021,0,0.147,-0.034,0.03
|
||||
79582884925181,potato_masher,-0.06,0.03,0,0.283,-0.047,0.043
|
||||
96945373046044,food_vegetables,-0.045,0.023,0,0.02,-0.049,0.048
|
||||
98604936546412,can_soup,-0.033,0.038,0,0.082,-0.035,0.031
|
|
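Each row above gives an object's uid, name, and axis-aligned bounding-box extents (the values are plausibly in metres). A minimal lookup sketch, assuming the CSV is read from the repository root:

```python
import csv

with open("hot3d_processing/hot3d_objects.csv") as f:
    objects = {row["object_uid"]: row for row in csv.DictReader(f)}

bowl = objects["194930206998778"]
extent_x = float(bowl["bbx_x_max"]) - float(bowl["bbx_x_min"])
print(bowl["object_name"], extent_x)  # bowl 0.28
```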
7200
hot3d_processing/mano_hand_pose_init/mano_hand_pose_trajectory.jsonl
Normal file
File diff suppressed because it is too large
4
hot3d_processing/utils/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
__all__ = ['file_systems']
|
||||
|
||||
from .file_systems import remake_dir, make_dir
|
||||
|
50
hot3d_processing/utils/file_systems.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import os
|
||||
import shutil
|
||||
import time
|
||||
|
||||
|
||||
# remove a directory
|
||||
def remove_dir(dirName):
|
||||
if os.path.exists(dirName):
|
||||
shutil.rmtree(dirName)
|
||||
else:
|
||||
print("Invalid directory path!")
|
||||
|
||||
|
||||
# remake a directory
|
||||
def remake_dir(dirName):
|
||||
if os.path.exists(dirName):
|
||||
shutil.rmtree(dirName)
|
||||
os.makedirs(dirName)
|
||||
else:
|
||||
os.makedirs(dirName)
|
||||
|
||||
|
||||
# calculate the number of lines in a file
|
||||
def file_lines(fileName):
|
||||
if os.path.exists(fileName):
|
||||
with open(fileName, 'r') as fr:
|
||||
return len(fr.readlines())
|
||||
else:
|
||||
print("Invalid file path!")
|
||||
return 0
|
||||
|
||||
|
||||
# make a directory if it does not exist.
|
||||
def make_dir(dirName):
|
||||
if os.path.exists(dirName):
|
||||
print("Directory "+ dirName + " already exists.")
|
||||
else:
|
||||
os.makedirs(dirName)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dirName = "test"
|
||||
remake_dir(dirName)
|
||||
time.sleep(3)
|
||||
make_dir(dirName)
|
||||
remove_dir(dirName)
|
||||
time.sleep(3)
|
||||
make_dir(dirName)
|
||||
#print(file_lines('233.txt'))
|
98
model/attended_hand_recognition.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
from torch import nn
|
||||
import torch
|
||||
from model import graph_convolution_network
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class attended_hand_recognition(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
self.body_joint_number = opt.body_joint_number
|
||||
self.hand_joint_number = opt.hand_joint_number
|
||||
self.joint_number = self.body_joint_number + self.hand_joint_number
|
||||
self.input_n = opt.seq_len
|
||||
gcn_latent_features = opt.gcn_latent_features
|
||||
residual_gcns_num = opt.residual_gcns_num
|
||||
gcn_dropout = opt.gcn_dropout
|
||||
head_cnn_channels = opt.head_cnn_channels
|
||||
recognition_cnn_channels = opt.recognition_cnn_channels
|
||||
|
||||
# 1D CNN for extracting features from head directions
|
||||
in_channels_head = 3
|
||||
cnn_kernel_size = 3
|
||||
cnn_padding = (cnn_kernel_size -1)//2
|
||||
out_channels_1_head = head_cnn_channels
|
||||
out_channels_2_head = head_cnn_channels
|
||||
out_channels_head = head_cnn_channels
|
||||
|
||||
self.head_cnn = nn.Sequential(
|
||||
nn.Conv1d(in_channels = in_channels_head, out_channels=out_channels_1_head, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_1_head, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_1_head, out_channels=out_channels_2_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_2_head, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_2_head, out_channels=out_channels_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
# GCN for extracting features from body and left hand joints
|
||||
self.left_hand_gcn = graph_convolution_network.graph_convolution_network(in_features=3,
|
||||
latent_features=gcn_latent_features,
|
||||
node_n=self.joint_number,
|
||||
seq_len=self.input_n,
|
||||
p_dropout=gcn_dropout,
|
||||
residual_gcns_num=residual_gcns_num)
|
||||
|
||||
# GCN for extracting features from body and right hand joints
|
||||
self.right_hand_gcn = graph_convolution_network.graph_convolution_network(in_features=3,
|
||||
latent_features=gcn_latent_features,
|
||||
node_n=self.joint_number,
|
||||
seq_len=self.input_n,
|
||||
p_dropout=gcn_dropout,
|
||||
residual_gcns_num=residual_gcns_num)
|
||||
|
||||
# 1D CNN for recognising attended hand (left or right)
|
||||
in_channels_recognition = self.joint_number*gcn_latent_features*2 + out_channels_head
|
||||
cnn_kernel_size = 3
|
||||
cnn_padding = (cnn_kernel_size -1)//2
|
||||
out_channels_1_recognition = recognition_cnn_channels
|
||||
out_channels_recognition = 2
|
||||
|
||||
self.recognition_cnn = nn.Sequential(
|
||||
nn.Conv1d(in_channels = in_channels_recognition, out_channels=out_channels_1_recognition, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_1_recognition, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_1_recognition, out_channels=out_channels_recognition, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
)
|
||||
|
||||
|
||||
def forward(self, src, input_n=15):
|
||||
|
||||
bs, seq_len, features = src.shape
|
||||
body_joints = src.clone()[:, :, :self.body_joint_number*3]
|
||||
left_hand_joints = src.clone()[:, :, self.body_joint_number*3:(self.body_joint_number+self.hand_joint_number)*3]
|
||||
right_hand_joints = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number)*3:(self.body_joint_number+self.hand_joint_number*2)*3]
|
||||
head_direction = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2)*3:(self.body_joint_number+self.hand_joint_number*2+1)*3]
|
||||
|
||||
left_hand_joints = torch.cat((left_hand_joints, body_joints), dim=2)
|
||||
left_hand_joints = left_hand_joints.permute(0, 2, 1).reshape(bs, -1, 3, input_n).permute(0, 2, 1, 3)
|
||||
left_hand_features = self.left_hand_gcn(left_hand_joints)
|
||||
left_hand_features = left_hand_features.permute(0, 2, 1, 3).reshape(bs, -1, input_n)
|
||||
|
||||
right_hand_joints = torch.cat((right_hand_joints, body_joints), dim=2)
|
||||
right_hand_joints = right_hand_joints.permute(0, 2, 1).reshape(bs, -1, 3, input_n).permute(0, 2, 1, 3)
|
||||
right_hand_features = self.right_hand_gcn(right_hand_joints)
|
||||
right_hand_features = right_hand_features.permute(0, 2, 1, 3).reshape(bs, -1, input_n)
|
||||
|
||||
head_direction = head_direction.permute(0,2,1)
|
||||
head_features = self.head_cnn(head_direction)
|
||||
|
||||
# fuse head and hand features
|
||||
features = torch.cat((left_hand_features, right_hand_features), dim=1)
|
||||
features = torch.cat((features, head_features), dim=1)
|
||||
# recognise attended hand from fused features
|
||||
prediction = self.recognition_cnn(features).permute(0, 2, 1)
|
||||
|
||||
return prediction
|
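A minimal smoke-test sketch for the recogniser above. The hyperparameter values are illustrative stand-ins for the `opt` namespace parsed in the training scripts, and `input_n` must equal the temporal length of `src` for the reshapes in `forward` to work:

```python
import torch
from types import SimpleNamespace
from model.attended_hand_recognition import attended_hand_recognition

# illustrative hyperparameters; the real values come from the training scripts
opt = SimpleNamespace(body_joint_number=4, hand_joint_number=1, seq_len=15,
                      gcn_latent_features=16, residual_gcns_num=2, gcn_dropout=0.3,
                      head_cnn_channels=32, recognition_cnn_channels=32)
model = attended_hand_recognition(opt)

# src packs, per frame: body joints | left-hand joints | right-hand joints | head direction (xyz each)
feature_dim = (opt.body_joint_number + 2 * opt.hand_joint_number + 1) * 3
src = torch.randn(8, opt.seq_len, feature_dim)
logits = model(src, input_n=opt.seq_len)  # (8, 15, 2): per-frame left/right scores
```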
140
model/gaze_estimation.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
from torch import nn
|
||||
import torch
|
||||
from model import graph_convolution_network, transformer
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class gaze_estimation(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
self.body_joint_number = opt.body_joint_number
|
||||
self.hand_joint_number = opt.hand_joint_number
|
||||
self.input_n = opt.seq_len
|
||||
self.object_num = opt.object_num
|
||||
gcn_latent_features = opt.gcn_latent_features
|
||||
residual_gcns_num = opt.residual_gcns_num
|
||||
gcn_dropout = opt.gcn_dropout
|
||||
head_cnn_channels = opt.head_cnn_channels
|
||||
gaze_cnn_channels = opt.gaze_cnn_channels
|
||||
self.use_self_att = opt.use_self_att
|
||||
self_att_head_num = opt.self_att_head_num
|
||||
self_att_dropout = opt.self_att_dropout
|
||||
self.use_cross_att = opt.use_cross_att
|
||||
cross_att_head_num = opt.cross_att_head_num
|
||||
cross_att_dropout = opt.cross_att_dropout
|
||||
self.use_attended_hand = opt.use_attended_hand
|
||||
self.use_attended_hand_gt = opt.use_attended_hand_gt
|
||||
if self.use_attended_hand:
|
||||
self.joint_number = self.body_joint_number + self.hand_joint_number + self.object_num
|
||||
else:
|
||||
self.joint_number = self.body_joint_number + self.hand_joint_number*2 + self.object_num*2
|
||||
|
||||
# 1D CNN for extracting features from head directions
|
||||
in_channels_head = 3
|
||||
cnn_kernel_size = 3
|
||||
cnn_padding = (cnn_kernel_size -1)//2
|
||||
out_channels_1_head = head_cnn_channels
|
||||
out_channels_2_head = head_cnn_channels
|
||||
out_channels_head = head_cnn_channels
|
||||
|
||||
self.head_cnn = nn.Sequential(
|
||||
nn.Conv1d(in_channels = in_channels_head, out_channels=out_channels_1_head, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_1_head, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_1_head, out_channels=out_channels_2_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_2_head, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_2_head, out_channels=out_channels_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
# GCN for extracting features from hand joints, body joints, and scene objects
|
||||
self.hand_gcn = graph_convolution_network.graph_convolution_network(in_features=3,
|
||||
latent_features=gcn_latent_features,
|
||||
node_n=self.joint_number,
|
||||
seq_len=self.input_n,
|
||||
p_dropout=gcn_dropout,
|
||||
residual_gcns_num=residual_gcns_num)
|
||||
|
||||
if self.use_self_att:
|
||||
self.head_self_att = transformer.temporal_self_attention(out_channels_head, self_att_head_num, self_att_dropout)
|
||||
self.hand_self_att = transformer.temporal_self_attention(self.joint_number*gcn_latent_features, self_att_head_num, self_att_dropout)
|
||||
|
||||
if self.use_cross_att:
|
||||
self.head_hand_cross_att = transformer.temporal_cross_attention(out_channels_head, self.joint_number*gcn_latent_features, cross_att_head_num, cross_att_dropout)
|
||||
self.hand_head_cross_att = transformer.temporal_cross_attention(self.joint_number*gcn_latent_features, out_channels_head, cross_att_head_num, cross_att_dropout)
|
||||
|
||||
# 1D CNN for estimating eye gaze
|
||||
in_channels_gaze = self.joint_number*gcn_latent_features + out_channels_head
|
||||
cnn_kernel_size = 3
|
||||
cnn_padding = (cnn_kernel_size -1)//2
|
||||
out_channels_1_gaze = gaze_cnn_channels
|
||||
out_channels_gaze = 3
|
||||
|
||||
self.gaze_cnn = nn.Sequential(
|
||||
nn.Conv1d(in_channels = in_channels_gaze, out_channels=out_channels_1_gaze, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_1_gaze, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_1_gaze, out_channels=out_channels_gaze, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
|
||||
def forward(self, src, input_n=15):
|
||||
|
||||
bs, seq_len, features = src.shape
|
||||
body_joints = src.clone()[:, :, :self.body_joint_number*3]
|
||||
left_hand_joints = src.clone()[:, :, self.body_joint_number*3:(self.body_joint_number+self.hand_joint_number)*3]
|
||||
right_hand_joints = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number)*3:(self.body_joint_number+self.hand_joint_number*2)*3]
|
||||
head_direction = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2)*3:(self.body_joint_number+self.hand_joint_number*2+1)*3]
|
||||
if self.object_num > 0:
|
||||
left_object_position = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1)*3:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num)*3]
|
||||
left_object_position = torch.mean(left_object_position.reshape(bs, seq_len, self.object_num, 8, 3), dim=3).reshape(bs, seq_len, self.object_num*3)
|
||||
right_object_position = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num)*3:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3]
|
||||
right_object_position = torch.mean(right_object_position.reshape(bs, seq_len, self.object_num, 8, 3), dim=3).reshape(bs, seq_len, self.object_num*3)
|
||||
|
||||
attended_hand_prd = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3+2]
|
||||
left_hand_weights = torch.round(attended_hand_prd[:, :, 0:1])
|
||||
right_hand_weights = torch.round(attended_hand_prd[:, :, 1:2])
|
||||
if self.use_attended_hand_gt:
|
||||
attended_hand_gt = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3+2:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3+3]
|
||||
left_hand_weights = 1-attended_hand_gt
|
||||
right_hand_weights = attended_hand_gt
|
||||
|
||||
if self.use_attended_hand:
|
||||
hand_joints = left_hand_joints*left_hand_weights + right_hand_joints*right_hand_weights
|
||||
else:
|
||||
hand_joints = torch.cat((left_hand_joints, right_hand_joints), dim=2)
|
||||
hand_joints = torch.cat((hand_joints, body_joints), dim=2)
|
||||
if self.object_num > 0:
|
||||
if self.use_attended_hand:
|
||||
object_position = left_object_position*left_hand_weights + right_object_position*right_hand_weights
|
||||
else:
|
||||
object_position = torch.cat((left_object_position, right_object_position), dim=2)
|
||||
hand_joints = torch.cat((hand_joints, object_position), dim=2)
|
||||
|
||||
hand_joints = hand_joints.permute(0, 2, 1).reshape(bs, -1, 3, input_n).permute(0, 2, 1, 3)
|
||||
hand_features = self.hand_gcn(hand_joints)
|
||||
hand_features = hand_features.permute(0, 2, 1, 3).reshape(bs, -1, input_n)
|
||||
|
||||
head_direction = head_direction.permute(0,2,1)
|
||||
head_features = self.head_cnn(head_direction)
|
||||
|
||||
if self.use_self_att:
|
||||
head_features = self.head_self_att(head_features.permute(0,2,1)).permute(0,2,1)
|
||||
hand_features = self.hand_self_att(hand_features.permute(0,2,1)).permute(0,2,1)
|
||||
|
||||
if self.use_cross_att:
|
||||
head_features_copy = head_features.clone()
|
||||
head_features = self.head_hand_cross_att(head_features.permute(0,2,1), hand_features.permute(0,2,1)).permute(0,2,1)
|
||||
hand_features = self.hand_head_cross_att(hand_features.permute(0,2,1), head_features_copy.permute(0,2,1)).permute(0,2,1)
|
||||
|
||||
# fuse head and hand features
|
||||
features = torch.cat((hand_features, head_features), dim=1)
|
||||
# estimate eye gaze
|
||||
prediction = self.gaze_cnn(features).permute(0, 2, 1)
|
||||
# normalize to unit vectors
|
||||
prediction = F.normalize(prediction, dim=2)
|
||||
|
||||
return prediction
|
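The estimator consumes one long packed feature vector per frame; a shape-check sketch with illustrative hyperparameters (the real `opt` comes from the training scripts):

```python
import torch
from types import SimpleNamespace
from model.gaze_estimation import gaze_estimation

opt = SimpleNamespace(body_joint_number=4, hand_joint_number=1, seq_len=15, object_num=1,
                      gcn_latent_features=16, residual_gcns_num=2, gcn_dropout=0.3,
                      head_cnn_channels=32, gaze_cnn_channels=32,
                      use_self_att=True, self_att_head_num=2, self_att_dropout=0.1,
                      use_cross_att=True, cross_att_head_num=2, cross_att_dropout=0.1,
                      use_attended_hand=True, use_attended_hand_gt=False)
model = gaze_estimation(opt)

# per frame: joints + head direction (xyz each), 8 bbx vertices per object for each hand,
# then the attended-hand channels (2 predicted scores + 1 ground-truth label)
feature_dim = (opt.body_joint_number + 2 * opt.hand_joint_number + 1
               + 8 * opt.object_num * 2) * 3 + 3
src = torch.randn(8, opt.seq_len, feature_dim)
gaze = model(src)  # (8, 15, 3): unit gaze direction per frame
```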
79
model/graph_convolution_network.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
import math
|
||||
|
||||
|
||||
class graph_convolution(nn.Module):
|
||||
def __init__(self, in_features, out_features, node_n = 21, seq_len = 40, bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len))
|
||||
self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features))
|
||||
self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n))
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.FloatTensor(seq_len))
else:
self.register_parameter('bias', None)  # keeps the `is not None` checks below safe when bias=False
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
stdv = 1. / math.sqrt(self.spatial_graph_weights.size(1))
|
||||
self.feature_weights.data.uniform_(-stdv, stdv)
|
||||
self.temporal_graph_weights.data.uniform_(-stdv, stdv)
|
||||
self.spatial_graph_weights.data.uniform_(-stdv, stdv)
|
||||
if self.bias is not None:
|
||||
self.bias.data.uniform_(-stdv, stdv)
|
||||
|
||||
def forward(self, input):
|
||||
y = torch.matmul(input, self.temporal_graph_weights)
|
||||
y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights)
|
||||
y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous()
|
||||
|
||||
if self.bias is not None:
|
||||
return (y + self.bias)
|
||||
else:
|
||||
return y
|
||||
|
||||
|
||||
class residual_graph_convolution(nn.Module):
|
||||
def __init__(self, features, node_n=21, seq_len = 40, bias=True, p_dropout=0.3):
|
||||
super().__init__()
|
||||
|
||||
self.gcn = graph_convolution(features, features, node_n=node_n, seq_len=seq_len, bias=bias)
|
||||
self.ln = nn.LayerNorm([features, node_n, seq_len], elementwise_affine=True)
|
||||
self.act_f = nn.Tanh()
|
||||
self.dropout = nn.Dropout(p_dropout)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
y = self.gcn(x)
|
||||
y = self.ln(y)
|
||||
y = self.act_f(y)
|
||||
y = self.dropout(y)
|
||||
|
||||
return y + x
|
||||
|
||||
|
||||
class graph_convolution_network(nn.Module):
|
||||
def __init__(self, in_features, latent_features, node_n=21, seq_len=40, p_dropout=0.3, residual_gcns_num=1):
|
||||
super().__init__()
|
||||
self.residual_gcns_num = residual_gcns_num
|
||||
self.seq_len = seq_len
|
||||
|
||||
self.start_gcn = graph_convolution(in_features=in_features, out_features=latent_features, node_n=node_n, seq_len=seq_len)
|
||||
|
||||
self.residual_gcns = []
|
||||
for i in range(residual_gcns_num):
|
||||
self.residual_gcns.append(residual_graph_convolution(features=latent_features, node_n=node_n, seq_len=seq_len*2, p_dropout=p_dropout))
|
||||
self.residual_gcns = nn.ModuleList(self.residual_gcns)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.start_gcn(x)
|
||||
|
||||
# duplicate along the temporal axis: the residual GCNs are built for seq_len*2 frames
y = torch.cat((y, y), dim=3)
|
||||
for i in range(self.residual_gcns_num):
|
||||
y = self.residual_gcns[i](y)
|
||||
y = y[:, :, :, :self.seq_len]  # truncate back to the original sequence length
|
||||
|
||||
return y
|
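A quick shape check for the network above, under illustrative dimensions:

```python
import torch
from model.graph_convolution_network import graph_convolution_network

net = graph_convolution_network(in_features=3, latent_features=16,
                                node_n=21, seq_len=40, p_dropout=0.3, residual_gcns_num=2)
x = torch.randn(8, 3, 21, 40)  # (batch, xyz, joints, frames)
y = net(x)                     # (8, 16, 21, 40): latent features per joint and frame
```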
138
model/transformer.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
|
||||
class temporal_self_attention(nn.Module):
|
||||
|
||||
def __init__(self, latent_dim, num_head, dropout):
|
||||
super().__init__()
|
||||
self.num_head = num_head
|
||||
self.norm = nn.LayerNorm(latent_dim)
|
||||
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.key = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.value = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: B, T, D
|
||||
"""
|
||||
B, T, D = x.shape
|
||||
H = self.num_head
|
||||
# B, T, 1, D
|
||||
query = self.query(self.norm(x)).unsqueeze(2)
|
||||
# B, 1, T, D
|
||||
key = self.key(self.norm(x)).unsqueeze(1)
|
||||
query = query.view(B, T, H, -1)
|
||||
key = key.view(B, T, H, -1)
|
||||
# B, T, T, H
|
||||
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
|
||||
weight = self.dropout(F.softmax(attention, dim=2))
|
||||
value = self.value(self.norm(x)).view(B, T, H, -1)
|
||||
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
|
||||
y = x + y
|
||||
return y
|
||||
|
||||
|
||||
class spatial_self_attention(nn.Module):
|
||||
|
||||
def __init__(self, latent_dim, num_head, dropout):
|
||||
super().__init__()
|
||||
self.num_head = num_head
|
||||
self.norm = nn.LayerNorm(latent_dim)
|
||||
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.key = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.value = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: B, S, D
|
||||
"""
|
||||
B, S, D = x.shape
|
||||
H = self.num_head
|
||||
# B, S, 1, D
|
||||
query = self.query(self.norm(x)).unsqueeze(2)
|
||||
# B, 1, S, D
|
||||
key = self.key(self.norm(x)).unsqueeze(1)
|
||||
query = query.view(B, S, H, -1)
|
||||
key = key.view(B, S, H, -1)
|
||||
# B, S, S, H
|
||||
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
|
||||
weight = self.dropout(F.softmax(attention, dim=2))
|
||||
value = self.value(self.norm(x)).view(B, S, H, -1)
|
||||
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, S, D)
|
||||
y = x + y
|
||||
return y
|
||||
|
||||
|
||||
class temporal_cross_attention(nn.Module):
|
||||
|
||||
def __init__(self, latent_dim, mod_dim, num_head, dropout):
|
||||
super().__init__()
|
||||
self.num_head = num_head
|
||||
self.norm = nn.LayerNorm(latent_dim)
|
||||
self.mod_norm = nn.LayerNorm(mod_dim)
|
||||
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.key = nn.Linear(mod_dim, latent_dim, bias=False)
|
||||
self.value = nn.Linear(mod_dim, latent_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, xf):
|
||||
"""
|
||||
x: B, T, D
|
||||
xf: B, N, L
|
||||
"""
|
||||
B, T, D = x.shape
|
||||
N = xf.shape[1]
|
||||
H = self.num_head
|
||||
# B, T, 1, D
|
||||
query = self.query(self.norm(x)).unsqueeze(2)
|
||||
# B, 1, N, D
|
||||
key = self.key(self.mod_norm(xf)).unsqueeze(1)
|
||||
query = query.view(B, T, H, -1)
|
||||
key = key.view(B, N, H, -1)
|
||||
# B, T, N, H
|
||||
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
|
||||
weight = self.dropout(F.softmax(attention, dim=2))
|
||||
value = self.value(self.mod_norm(xf)).view(B, N, H, -1)
|
||||
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
|
||||
y = x + y
|
||||
return y
|
||||
|
||||
|
||||
class spatial_cross_attention(nn.Module):
|
||||
|
||||
def __init__(self, latent_dim, mod_dim, num_head, dropout):
|
||||
super().__init__()
|
||||
self.num_head = num_head
|
||||
self.norm = nn.LayerNorm(latent_dim)
|
||||
self.mod_norm = nn.LayerNorm(mod_dim)
|
||||
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.key = nn.Linear(mod_dim, latent_dim, bias=False)
|
||||
self.value = nn.Linear(mod_dim, latent_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, xf):
|
||||
"""
|
||||
x: B, S, D
|
||||
xf: B, N, L
|
||||
"""
|
||||
B, S, D = x.shape
|
||||
N = xf.shape[1]
|
||||
H = self.num_head
|
||||
# B, S, 1, D
|
||||
query = self.query(self.norm(x)).unsqueeze(2)
|
||||
# B, 1, N, D
|
||||
key = self.key(self.mod_norm(xf)).unsqueeze(1)
|
||||
query = query.view(B, S, H, -1)
|
||||
key = key.view(B, N, H, -1)
|
||||
# B, S, N, H
|
||||
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
|
||||
weight = self.dropout(F.softmax(attention, dim=2))
|
||||
value = self.value(self.mod_norm(xf)).view(B, N, H, -1)
|
||||
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, S, D)
|
||||
y = x + y
|
||||
return y
|
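A minimal sketch of the temporal cross-attention used to fuse head and hand-object features; the dimensions are illustrative and match the gaze-estimation sketch above:

```python
import torch
from model.transformer import temporal_cross_attention

att = temporal_cross_attention(latent_dim=32, mod_dim=96, num_head=2, dropout=0.1)
head_feat = torch.randn(8, 15, 32)  # queries: head features over time
hand_feat = torch.randn(8, 15, 96)  # keys/values: hand-object features over time
fused = att(head_feat, hand_feat)   # (8, 15, 32), residual connection included
```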
7
train_adt.sh
Normal file
|
@ -0,0 +1,7 @@
|
|||
python attended_hand_recognition_adt.py --data_dir /scratch/hu/pose_forecast/adt_hoigaze/ --ckpt ./checkpoints/adt/ --cuda_idx cuda:6 --seq_len 15 --sample_rate 2 --residual_gcns_num 2 --gamma 0.95 --learning_rate 0.005 --epoch 60 --object_num 1 --hand_joint_number 1;
|
||||
|
||||
python attended_hand_recognition_adt.py --data_dir /scratch/hu/pose_forecast/adt_hoigaze/ --ckpt ./checkpoints/adt/ --cuda_idx cuda:6 --seq_len 15 --sample_rate 2 --residual_gcns_num 2 --gamma 0.95 --learning_rate 0.005 --epoch 60 --object_num 1 --hand_joint_number 1 --is_eval --save_predictions;
|
||||
|
||||
python gaze_estimation_adt.py --data_dir /scratch/hu/pose_forecast/adt_hoigaze/ --ckpt ./checkpoints/adt/ --cuda_idx cuda:6 --seq_len 15 --residual_gcns_num 4 --gamma 0.8 --learning_rate 0.005 --epoch 80 --object_num 1 --hand_joint_number 1;
|
||||
|
||||
python gaze_estimation_adt.py --data_dir /scratch/hu/pose_forecast/adt_hoigaze/ --ckpt ./checkpoints/adt/ --cuda_idx cuda:6 --seq_len 15 --residual_gcns_num 4 --gamma 0.8 --learning_rate 0.005 --epoch 80 --object_num 1 --hand_joint_number 1 --is_eval;
|
7
train_hot3d_scene1.sh
Normal file
|
@ -0,0 +1,7 @@
|
|||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --actions 'room' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0;
|
||||
|
||||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --actions 'room' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0 --is_eval --save_predictions;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --actions 'room' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --actions 'room' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;
|
7
train_hot3d_scene2.sh
Normal file
|
@ -0,0 +1,7 @@
|
|||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --actions 'kitchen' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0;
|
||||
|
||||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --actions 'kitchen' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0 --is_eval --save_predictions;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --actions 'kitchen' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --actions 'kitchen' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;
|
7
train_hot3d_scene3.sh
Normal file
|
@ -0,0 +1,7 @@
|
|||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --actions 'office' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0;
|
||||
|
||||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --actions 'office' --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.0 --is_eval --save_predictions;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --actions 'office' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --actions 'office' --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;
|
7
train_hot3d_user1.sh
Normal file
|
@ -0,0 +1,7 @@
|
|||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --test_user_id 1 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05;
|
||||
|
||||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --test_user_id 1 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05 --is_eval --save_predictions;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --test_user_id 1 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:1 --test_user_id 1 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;
|
7
train_hot3d_user2.sh
Normal file
|
@ -0,0 +1,7 @@
|
|||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --test_user_id 2 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05;
|
||||
|
||||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --test_user_id 2 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05 --is_eval --save_predictions;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --test_user_id 2 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:3 --test_user_id 2 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;
|
7
train_hot3d_user3.sh
Normal file
|
@ -0,0 +1,7 @@
|
|||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --test_user_id 3 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05;
|
||||
|
||||
python attended_hand_recognition_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --test_user_id 3 --seq_len 15 --sample_rate 8 --gcn_dropout 0.3 --residual_gcns_num 2 --gamma 0.95 --epoch 60 --object_num 1 --weight_decay 0.05 --is_eval --save_predictions;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --test_user_id 3 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1;
|
||||
|
||||
python gaze_estimation_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --ckpt ./checkpoints/hot3d/ --cuda_idx cuda:2 --test_user_id 3 --seq_len 15 --residual_gcns_num 4 --gamma 0.95 --learning_rate 0.005 --epoch 80 --object_num 1 --is_eval;
|
1
utils/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from utils import *
|
171
utils/adt_dataset.py
Normal file
|
@ -0,0 +1,171 @@
|
|||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class adt_dataset(Dataset):
|
||||
def __init__(self, data_dir, seq_len, actions = 'all', train_flag = 1, object_num=1, hand_joint_number=1, sample_rate=1):
|
||||
actions = self.define_actions(actions)
|
||||
self.sample_rate = sample_rate
|
||||
if train_flag == 1:
|
||||
data_dir = data_dir + 'train/'
|
||||
if train_flag == 0:
|
||||
data_dir = data_dir + 'test/'
|
||||
|
||||
self.dataset = self.load_data(data_dir, seq_len, actions, object_num, hand_joint_number)
|
||||
|
||||
def define_actions(self, action):
|
||||
"""
|
||||
Define the list of actions we are using.
|
||||
|
||||
Args
|
||||
action: String with the passed action. Could be "all"
|
||||
Returns
|
||||
actions: List of strings of actions
|
||||
Raises
|
||||
ValueError if the action is not included.
|
||||
"""
|
||||
|
||||
actions = ['work', 'decoration', 'meal']
|
||||
if action in actions:
|
||||
return [action]
|
||||
|
||||
if action == "all":
|
||||
return actions
|
||||
raise ValueError("Unrecognised action: %s" % action)
|
||||
|
||||
def load_data(self, data_dir, seq_len, actions, object_num, hand_joint_number):
|
||||
action_number = len(actions)
|
||||
dataset = []
|
||||
file_names = sorted(os.listdir(data_dir))
|
||||
gaze_file_names = {}
|
||||
hand_file_names = {}
|
||||
hand_joint_file_names = {}
|
||||
head_file_names = {}
|
||||
object_left_file_names = {}
|
||||
object_right_file_names = {}
|
||||
for action_idx in np.arange(action_number):
|
||||
gaze_file_names[actions[ action_idx ]] = []
|
||||
hand_file_names[actions[ action_idx ]] = []
|
||||
hand_joint_file_names[actions[ action_idx ]] = []
|
||||
head_file_names[actions[ action_idx ]] = []
|
||||
object_left_file_names[actions[ action_idx ]] = []
|
||||
object_right_file_names[actions[ action_idx ]] = []
|
||||
|
||||
for name in file_names:
|
||||
name_split = name.split('_')
|
||||
action = name_split[2]
|
||||
if action in actions:
|
||||
data_type = name_split[-1][:-4]
|
||||
if(data_type == 'gaze'):
|
||||
gaze_file_names[action].append(name)
|
||||
if(data_type == 'hand'):
|
||||
hand_file_names[action].append(name)
|
||||
if(data_type == 'handjoints'):
|
||||
hand_joint_file_names[action].append(name)
|
||||
if(data_type == 'head'):
|
||||
head_file_names[action].append(name)
|
||||
if(data_type == 'bbxleft'):
|
||||
object_left_file_names[action].append(name)
|
||||
if(data_type == 'bbxright'):
|
||||
object_right_file_names[action].append(name)
|
||||
|
||||
for action_idx in np.arange(action_number):
|
||||
action = actions[ action_idx ]
|
||||
segments_number = len(gaze_file_names[action])
|
||||
print("Reading action {}, segments number {}".format(action, segments_number))
|
||||
for i in range(segments_number):
|
||||
gaze_data_path = data_dir + gaze_file_names[action][i]
|
||||
gaze_data = np.load(gaze_data_path)
|
||||
gaze_direction = gaze_data[:, :3]
|
||||
num_frames = gaze_data.shape[0]
|
||||
if num_frames < seq_len:
|
||||
continue
|
||||
hand_data_path = data_dir + hand_file_names[action][i]
|
||||
hand_translation = np.load(hand_data_path)
|
||||
hand_joint_data_path = data_dir + hand_joint_file_names[action][i]
|
||||
hand_joint_data_all = np.load(hand_joint_data_path)
|
||||
hand_joint_number_default = 15
|
||||
hand_joint_data = hand_joint_data_all[:, :hand_joint_number_default*6]
|
||||
left_hand_center = np.mean(hand_joint_data[:, :hand_joint_number_default*3].reshape(hand_joint_data.shape[0], hand_joint_number_default, 3), axis=1)
|
||||
right_hand_center = np.mean(hand_joint_data[:, hand_joint_number_default*3:].reshape(hand_joint_data.shape[0], hand_joint_number_default, 3), axis=1)
|
||||
if hand_joint_number == 1:
|
||||
hand_joint_data = np.concatenate((left_hand_center, right_hand_center), axis=1)
|
||||
|
||||
attended_hand_gt = hand_joint_data_all[:, hand_joint_number_default*6:hand_joint_number_default*6+1]
|
||||
attended_hand_baseline = hand_joint_data_all[:, hand_joint_number_default*6+1:hand_joint_number_default*6+2]
|
||||
|
||||
head_data_path = data_dir + head_file_names[action][i]
|
||||
head_data = np.load(head_data_path)
|
||||
head_direction = head_data[:, :3]
|
||||
head_translation = head_data[:, 3:]
|
||||
|
||||
object_left_data_path = data_dir + object_left_file_names[action][i]
|
||||
object_left_data = np.load(object_left_data_path)
|
||||
object_left_data = object_left_data.reshape(object_left_data.shape[0], -1)
|
||||
object_right_data_path = data_dir + object_right_file_names[action][i]
|
||||
object_right_data = np.load(object_right_data_path)
|
||||
object_right_data = object_right_data.reshape(object_right_data.shape[0], -1)
|
||||
|
||||
object_left_bbx = []
|
||||
object_right_bbx = []
|
||||
for item in range(object_num):
|
||||
left_bbx = object_left_data[:, item*24:item*24+24]
|
||||
right_bbx = object_right_data[:, item*24:item*24+24]
|
||||
if len(object_left_bbx) == 0:
|
||||
object_left_bbx = left_bbx
|
||||
object_right_bbx = right_bbx
|
||||
else:
|
||||
object_left_bbx = np.concatenate((object_left_bbx, left_bbx), axis=1)
|
||||
object_right_bbx = np.concatenate((object_right_bbx, right_bbx), axis=1)
|
||||
|
||||
#object_left_positions = np.mean(object_left_bbx.reshape(num_frames, object_num, 8, 3), axis=2).reshape(num_frames, -1)
|
||||
#object_right_positions = np.mean(object_right_bbx.reshape(num_frames, object_num, 8, 3), axis=2).reshape(num_frames, -1)
|
||||
|
||||
data = gaze_direction
|
||||
data = np.concatenate((data, hand_translation), axis=1)
|
||||
data = np.concatenate((data, head_translation), axis=1)
|
||||
data = np.concatenate((data, hand_joint_data), axis=1)
|
||||
data = np.concatenate((data, head_direction), axis=1)
|
||||
if object_num > 0:
|
||||
data = np.concatenate((data, object_left_bbx), axis=1)
|
||||
data = np.concatenate((data, object_right_bbx), axis=1)
|
||||
data = np.concatenate((data, attended_hand_gt), axis=1)
|
||||
data = np.concatenate((data, attended_hand_baseline), axis=1)
|
||||
|
||||
# all sliding windows of length seq_len (subsampled by sample_rate below)
fs = np.arange(0, num_frames - seq_len + 1)
|
||||
fs_sel = fs
|
||||
for i in np.arange(seq_len - 1):
|
||||
fs_sel = np.vstack((fs_sel, fs + i + 1))
|
||||
fs_sel = fs_sel.transpose()
|
||||
#print(fs_sel)
|
||||
seq_sel = data[fs_sel, :]
|
||||
seq_sel = seq_sel[0::self.sample_rate, :, :]
|
||||
#print(seq_sel.shape)
|
||||
if len(dataset) == 0:
|
||||
dataset = seq_sel
|
||||
else:
|
||||
dataset = np.concatenate((dataset, seq_sel), axis=0)
|
||||
return dataset
|
||||
|
||||
def __len__(self):
|
||||
return np.shape(self.dataset)[0]
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.dataset[item]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_dir = "/scratch/hu/pose_forecast/adt_hoigaze/"
|
||||
seq_len = 15
|
||||
actions = 'all'
|
||||
sample_rate = 1
|
||||
train_flag = 1
|
||||
object_num = 1
|
||||
hand_joint_number = 1
|
||||
train_dataset = adt_dataset(data_dir, seq_len, actions, train_flag, object_num, hand_joint_number, sample_rate)
|
||||
print("Training data size: {}".format(train_dataset.dataset.shape))
|
||||
|
||||
hand_joint_dominance = train_dataset[:, :, -2:-1].flatten()
|
||||
print("right hand ratio: {:.2f}".format(np.sum(hand_joint_dominance)/hand_joint_dominance.shape[0]*100))
|
||||
|
137
utils/hot3d_aria_dataset.py
Normal file
|
@ -0,0 +1,137 @@
|
|||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class hot3d_aria_dataset(Dataset):
|
||||
def __init__(self, data_dir, subjects, seq_len, actions = 'all', object_num=1, sample_rate=1):
|
||||
if actions == 'all':
|
||||
actions = ['room', 'kitchen', 'office']
|
||||
self.sample_rate = sample_rate
|
||||
self.dataset = self.load_data(data_dir, subjects, seq_len, actions, object_num)
|
||||
|
||||
def load_data(self, data_dir, subjects, seq_len, actions, object_num):
|
||||
dataset = []
|
||||
file_names = sorted(os.listdir(data_dir))
|
||||
gaze_file_names = []
|
||||
hand_file_names = []
|
||||
hand_joint_file_names = []
|
||||
head_file_names = []
|
||||
object_left_file_names = []
|
||||
object_right_file_names = []
|
||||
for name in file_names:
|
||||
name_split = name.split('_')
|
||||
subject = name_split[0]
|
||||
action = name_split[2]
|
||||
if subject in subjects and action in actions:
|
||||
data_type = name_split[-1][:-4]
|
||||
if(data_type == 'gaze'):
|
||||
gaze_file_names.append(name)
|
||||
if(data_type == 'hand'):
|
||||
hand_file_names.append(name)
|
||||
if(data_type == 'handjoints'):
|
||||
hand_joint_file_names.append(name)
|
||||
if(data_type == 'head'):
|
||||
head_file_names.append(name)
|
||||
if(data_type == 'bbxleft'):
|
||||
object_left_file_names.append(name)
|
||||
if(data_type == 'bbxright'):
|
||||
object_right_file_names.append(name)
|
||||
|
||||
segments_number = len(hand_file_names)
|
||||
# print("segments number {}".format(segments_number))
|
||||
for i in range(segments_number):
|
||||
gaze_data_path = data_dir + gaze_file_names[i]
|
||||
gaze_data = np.load(gaze_data_path)
|
||||
num_frames = gaze_data.shape[0]
|
||||
if num_frames < seq_len:
|
||||
continue
|
||||
hand_data_path = data_dir + hand_file_names[i]
|
||||
hand_data = np.load(hand_data_path)
|
||||
hand_joint_data_path = data_dir + hand_joint_file_names[i]
|
||||
hand_joint_data_all = np.load(hand_joint_data_path)
|
||||
hand_joint_data = hand_joint_data_all[:, :120]
|
||||
attended_hand_gt = hand_joint_data_all[:, 120:121]
|
||||
attended_hand_baseline = hand_joint_data_all[:, 121:122]
|
||||
|
||||
head_data_path = data_dir + head_file_names[i]
|
||||
head_data = np.load(head_data_path)
|
||||
object_left_data_path = data_dir + object_left_file_names[i]
|
||||
object_left_data = np.load(object_left_data_path)
|
||||
object_right_data_path = data_dir + object_right_file_names[i]
|
||||
object_right_data = np.load(object_right_data_path)
|
||||
|
||||
left_hand_translation = hand_data[:, 0:3]
|
||||
right_hand_translation = hand_data[:, 22:25]
|
||||
head_direction = head_data[:, 0:3]
|
||||
head_translation = head_data[:, 3:6]
|
||||
gaze_direction = gaze_data[:, 0:3]
|
||||
object_left_bbx = []
|
||||
object_right_bbx = []
|
||||
for item in range(object_num):
|
||||
left_bbx = object_left_data[:, item*24:item*24+24]
|
||||
right_bbx = object_right_data[:, item*24:item*24+24]
|
||||
if len(object_left_bbx) == 0:
|
||||
object_left_bbx = left_bbx
|
||||
object_right_bbx = right_bbx
|
||||
else:
|
||||
object_left_bbx = np.concatenate((object_left_bbx, left_bbx), axis=1)
|
||||
object_right_bbx = np.concatenate((object_right_bbx, right_bbx), axis=1)
|
||||
|
||||
#object_left_positions = np.mean(object_left_bbx.reshape(num_frames, object_num, 8, 3), axis=2).reshape(num_frames, -1)
|
||||
#object_right_positions = np.mean(object_right_bbx.reshape(num_frames, object_num, 8, 3), axis=2).reshape(num_frames, -1)
|
||||
|
||||
data = gaze_direction
|
||||
data = np.concatenate((data, left_hand_translation), axis=1)
|
||||
data = np.concatenate((data, right_hand_translation), axis=1)
|
||||
data = np.concatenate((data, head_translation), axis=1)
|
||||
data = np.concatenate((data, hand_joint_data), axis=1)
|
||||
data = np.concatenate((data, head_direction), axis=1)
|
||||
if object_num > 0:
|
||||
data = np.concatenate((data, object_left_bbx), axis=1)
|
||||
data = np.concatenate((data, object_right_bbx), axis=1)
|
||||
data = np.concatenate((data, attended_hand_gt), axis=1)
|
||||
data = np.concatenate((data, attended_hand_baseline), axis=1)
|
||||
|
||||
# all sliding windows of length seq_len (subsampled by sample_rate below)
fs = np.arange(0, num_frames - seq_len + 1)
|
||||
fs_sel = fs
|
||||
for i in np.arange(seq_len - 1):
|
||||
fs_sel = np.vstack((fs_sel, fs + i + 1))
|
||||
fs_sel = fs_sel.transpose()
|
||||
seq_sel = data[fs_sel, :]
|
||||
seq_sel = seq_sel[0::self.sample_rate, :, :]
|
||||
if len(dataset) == 0:
|
||||
dataset = seq_sel
|
||||
else:
|
||||
dataset = np.concatenate((dataset, seq_sel), axis=0)
|
||||
return dataset
|
||||
|
||||
def __len__(self):
|
||||
return np.shape(self.dataset)[0]
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.dataset[item]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_dir = "/scratch/hu/pose_forecast/hot3d_hoigaze/"
|
||||
seq_len = 15
|
||||
actions = 'all'
|
||||
all_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
|
||||
train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
|
||||
object_num = 1
|
||||
sample_rate = 10
|
||||
|
||||
train_dataset = hot3d_aria_dataset(data_dir, train_subjects, seq_len, actions, object_num, sample_rate)
|
||||
print("Training data size: {}".format(train_dataset.dataset.shape))
|
||||
|
||||
hand_joint_dominance = train_dataset[:, :, -2:-1].flatten()
|
||||
print("right hand ratio: {:.2f}".format(np.sum(hand_joint_dominance)/hand_joint_dominance.shape[0]*100))
|
||||
|
||||
#test_subjects = ['P0001', 'P0002', 'P0003']
|
||||
#sample_rate = 8
|
||||
#test_dataset = hot3d_aria_dataset(data_dir, test_subjects, seq_len, actions, object_num, sample_rate)
|
||||
# print("Test data size: {}".format(test_dataset.dataset.shape))
|
||||
|
||||
#hand_joint_dominance = test_dataset[:, :, -2:-1].flatten()
|
||||
#print("right hand ratio: {:.2f}".format(np.sum(hand_joint_dominance)/hand_joint_dominance.shape[0]*100))
|
91
utils/hot3d_aria_single_dataset.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
from torch.utils.data import Dataset
import numpy as np
import os


class hot3d_aria_dataset(Dataset):
    def __init__(self, data_path, seq_len, object_num=1):
        self.dataset = self.load_data(data_path, seq_len, object_num)

    def load_data(self, data_path, seq_len, object_num):
        dataset = []
        gaze_file_name = data_path + 'gaze.npy'
        hand_file_name = data_path + 'hand.npy'
        hand_joint_file_name = data_path + 'handjoints.npy'
        head_file_name = data_path + 'head.npy'
        object_left_file_name = data_path + 'object_bbxleft.npy'
        object_right_file_name = data_path + 'object_bbxright.npy'

        gaze_data_path = gaze_file_name
        gaze_data = np.load(gaze_data_path)
        num_frames = gaze_data.shape[0]
        hand_data_path = hand_file_name
        hand_data = np.load(hand_data_path)
        hand_joint_data_path = hand_joint_file_name
        hand_joint_data_all = np.load(hand_joint_data_path)
        # first 120 columns: hand joint positions; the last two columns are
        # the attended-hand ground truth and the baseline prediction
        hand_joint_data = hand_joint_data_all[:, :120]
        attended_hand_gt = hand_joint_data_all[:, 120:121]
        attended_hand_baseline = hand_joint_data_all[:, 121:122]

        head_data_path = head_file_name
        head_data = np.load(head_data_path)
        object_left_data_path = object_left_file_name
        object_left_data = np.load(object_left_data_path)
        object_right_data_path = object_right_file_name
        object_right_data = np.load(object_right_data_path)

        left_hand_translation = hand_data[:, 0:3]
        right_hand_translation = hand_data[:, 22:25]
        head_direction = head_data[:, 0:3]
        head_translation = head_data[:, 3:6]
        gaze_direction = gaze_data[:, 0:3]
        object_left_bbx = []
        object_right_bbx = []
        # each object contributes a 24-dim bounding box (8 corners x 3D)
        for item in range(object_num):
            left_bbx = object_left_data[:, item*24:item*24+24]
            right_bbx = object_right_data[:, item*24:item*24+24]
            if len(object_left_bbx) == 0:
                object_left_bbx = left_bbx
                object_right_bbx = right_bbx
            else:
                object_left_bbx = np.concatenate((object_left_bbx, left_bbx), axis=1)
                object_right_bbx = np.concatenate((object_right_bbx, right_bbx), axis=1)

        data = gaze_direction
        data = np.concatenate((data, left_hand_translation), axis=1)
        data = np.concatenate((data, right_hand_translation), axis=1)
        data = np.concatenate((data, head_translation), axis=1)
        data = np.concatenate((data, hand_joint_data), axis=1)
        data = np.concatenate((data, head_direction), axis=1)
        if object_num > 0:
            data = np.concatenate((data, object_left_bbx), axis=1)
            data = np.concatenate((data, object_right_bbx), axis=1)
        data = np.concatenate((data, attended_hand_gt), axis=1)
        data = np.concatenate((data, attended_hand_baseline), axis=1)

        # stride by seq_len here, i.e. non-overlapping windows
        fs = np.arange(0, num_frames - seq_len + 1)
        fs_sel = fs
        for i in np.arange(seq_len - 1):
            fs_sel = np.vstack((fs_sel, fs + i + 1))
        fs_sel = fs_sel.transpose()
        seq_sel = data[fs_sel, :]
        seq_sel = seq_sel[0::seq_len, :, :]
        if len(dataset) == 0:
            dataset = seq_sel
        else:
            dataset = np.concatenate((dataset, seq_sel), axis=0)
        return dataset

    def __len__(self):
        return np.shape(self.dataset)[0]

    def __getitem__(self, item):
        return self.dataset[item]


if __name__ == "__main__":
    data_path = '/scratch/hu/pose_forecast/hot3d_hoigaze/P0001_10a27bf7_room_721_890_'
    seq_len = 15
    object_num = 1
    train_dataset = hot3d_aria_dataset(data_path, seq_len, object_num)
    print("Training data size: {}".format(train_dataset.dataset.shape))
utils/log.py (new file, 28 lines)
@@ -0,0 +1,28 @@
import json
import os
import torch
import pandas as pd
import numpy as np


def save_csv_log(opt, head, value, is_create=False, file_name='results'):
    if len(value.shape) < 2:
        value = np.expand_dims(value, axis=0)
    df = pd.DataFrame(value)
    file_path = opt.ckpt + '/{}.csv'.format(file_name)
    print(file_path)
    if not os.path.exists(file_path) or is_create:
        df.to_csv(file_path, header=head, index=False)
    else:
        with open(file_path, 'a') as f:
            df.to_csv(f, header=False, index=False)


def save_ckpt(state, opt=None, file_name='model.pt'):
    file_path = os.path.join(opt.ckpt, file_name)
    torch.save(state, file_path)


def save_options(opt):
    with open(opt.ckpt + '/options.json', 'w') as f:
        f.write(json.dumps(vars(opt), sort_keys=False, indent=4))
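A short usage sketch for these helpers, assuming an `opt` namespace with a `ckpt` attribute (the directory name and metric values below are illustrative, not from the repo):

```python
import os
import numpy as np
from types import SimpleNamespace

opt = SimpleNamespace(ckpt='./checkpoints/example')  # illustrative checkpoint dir
os.makedirs(opt.ckpt, exist_ok=True)

# Write the header once on creation, then append one row of metrics per call.
errors = np.array([12.3, 10.8])  # e.g. angular errors in degrees
save_csv_log(opt, head=['train_error', 'test_error'], value=errors,
             is_create=True, file_name='results')

# Persist the run's options; save_ckpt works the same way for model state:
# save_ckpt({'epoch': 1, 'state_dict': model.state_dict()}, opt=opt)
save_options(opt)
```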
utils/opt.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import os
import argparse
from pprint import pprint


class options:
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.opt = None

    def _initial(self):
        # ===============================================================
        # General options
        # ===============================================================
        self.parser.add_argument('--cuda_idx', type=str, default='cuda:0', help='cuda idx')
        self.parser.add_argument('--data_dir', type=str,
                                 default='./dataset/',
                                 help='path to dataset')
        self.parser.add_argument('--is_eval', dest='is_eval', action='store_true',
                                 help='whether to evaluate existing models or not')
        self.parser.add_argument('--ckpt', type=str, default='./checkpoints/', help='path to save checkpoints')
        self.parser.add_argument('--test_user_id', type=int, default=1, help='id of the test participants')
        self.parser.add_argument('--actions', type=str, default='all', help='actions to use')
        self.parser.add_argument('--sample_rate', type=int, default=2, help='sample the data')
        self.parser.add_argument('--save_predictions', dest='save_predictions', action='store_true',
                                 help='whether to save the prediction results or not')
        # ===============================================================
        # Model options
        # ===============================================================
        self.parser.add_argument('--body_joint_number', type=int, default=3, help='number of body joints to use')
        self.parser.add_argument('--hand_joint_number', type=int, default=20, help='number of hand joints to use')
        self.parser.add_argument('--head_cnn_channels', type=int, default=32, help='number of channels used in the head_CNN')
        self.parser.add_argument('--gcn_latent_features', type=int, default=8, help='number of latent features used in the gcn')
        self.parser.add_argument('--residual_gcns_num', type=int, default=4, help='number of residual gcns to use')
        self.parser.add_argument('--gcn_dropout', type=float, default=0.3, help='dropout probability in the gcn')
        self.parser.add_argument('--gaze_cnn_channels', type=int, default=64, help='number of channels used in the gaze_CNN')
        self.parser.add_argument('--recognition_cnn_channels', type=int, default=64, help='number of channels used in the recognition_CNN')
        self.parser.add_argument('--object_num', type=int, default=1, help='number of scene objects for gaze estimation')
        self.parser.add_argument('--use_self_att', type=int, default=1, help='use self attention or not')
        self.parser.add_argument('--self_att_head_num', type=int, default=1, help='number of heads used in self attention')
        self.parser.add_argument('--self_att_dropout', type=float, default=0.1, help='dropout probability in self attention')
        self.parser.add_argument('--use_cross_att', type=int, default=1, help='use cross attention or not')
        self.parser.add_argument('--cross_att_head_num', type=int, default=1, help='number of heads used in cross attention')
        self.parser.add_argument('--cross_att_dropout', type=float, default=0.1, help='dropout probability in cross attention')
        self.parser.add_argument('--use_attended_hand', type=int, default=1, help='use attended hand or use both hands')
        self.parser.add_argument('--use_attended_hand_gt', type=int, default=0, help='use attended hand ground truth or not')
        # ===============================================================
        # Running options
        # ===============================================================
        self.parser.add_argument('--seq_len', type=int, default=15, help='the length of the used sequence')
        self.parser.add_argument('--learning_rate', type=float, default=0.005)
        self.parser.add_argument('--gaze_head_loss_factor', type=float, default=4.0)
        self.parser.add_argument('--gaze_head_cos_threshold', type=float, default=0.8)
        self.parser.add_argument('--weight_decay', type=float, default=0.0)
        self.parser.add_argument('--gamma', type=float, default=0.95, help='decay learning rate by gamma')
        self.parser.add_argument('--epoch', type=int, default=50)
        self.parser.add_argument('--batch_size', type=int, default=32)
        self.parser.add_argument('--validation_epoch', type=int, default=10, help='interval of epochs to test')
        self.parser.add_argument('--test_batch_size', type=int, default=32)

    def _print(self):
        print("\n==================Options=================")
        pprint(vars(self.opt), indent=4)
        print("==========================================\n")

    def parse(self, make_dir=True):
        self._initial()
        self.opt = self.parser.parse_args()
        ckpt = self.opt.ckpt
        if make_dir:
            if not os.path.isdir(ckpt):
                os.makedirs(ckpt)
        self._print()
        return self.opt
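These options are consumed by the training entry points; a minimal sketch of the typical call pattern (the script name is illustrative, and the `train_*.sh` scripts mentioned in the README presumably pass `--data_dir` and `--cuda_idx` the same way):

```python
from utils.opt import options

# e.g. invoked as:
#   python train_hot3d.py --data_dir /scratch/hu/pose_forecast/hot3d_hoigaze/ --cuda_idx cuda:0
opt = options().parse()                 # parses sys.argv, creates opt.ckpt if missing
print(opt.seq_len, opt.learning_rate)   # 15 0.005 with the defaults above
```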
utils/seed_torch.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import os
import random
import numpy as np
import torch


def seed_torch(seed=0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
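A usage sketch (not repo code): call this once at the very top of a training script, before any model construction or data loading, so that every RNG consumer sees the seed. Note that `cudnn.deterministic=True` with `benchmark=False` trades some speed for run-to-run reproducibility of convolution results.

```python
from utils.seed_torch import seed_torch

seed_torch(seed=0)  # fix Python, NumPy, and (CUDA) PyTorch RNGs before anything else runs
```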