HOIGaze/hot3d_processing/hot3d_aria_visualisation.ipynb
2025-04-30 14:15:00 +02:00

15 KiB

In [ ]:
import os
os.nice(5)
import rerun as rr
import numpy as np
from math import tan
import time
from utils import remake_dir
import pandas as pd
from data_loaders.ManoHandDataProvider import MANOHandDataProvider
from data_loaders.loader_object_library import load_object_library
from data_loaders.mano_layer import MANOHandModel
from data_loaders.loader_hand_poses import Handedness, HandPose
from data_loaders.hand_common import LANDMARK_CONNECTIVITY
from data_loaders.loader_object_library import ObjectLibrary
from data_loaders.headsets import Headset
from projectaria_tools.core.stream_id import StreamId
from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions
from projectaria_tools.core.sophus import SE3
from projectaria_tools.utils.rerun_helpers import ToTransform3D


# Common prefix of one preprocessed HOT3D recording segment; every per-frame
# array used below is stored as "<prefix><name>.npy".
data_path = '/scratch/hu/pose_forecast/hot3d_hoigaze/P0001_10a27bf7_room_721_890_'
timestamps_path = data_path + 'timestamps.npy'
head_path = data_path + 'head.npy'
gaze_path = data_path + 'gaze.npy'
hand_path = data_path + 'hand.npy'
hand_joint_path = data_path + 'handjoints.npy'
object_path = data_path + 'objects.npy'
object_bbx_path = data_path + 'object_bbx.npy'
object_library_path = '/datasets/public/zhiming_datasets/hot3d/assets/'
mano_hand_model_path = '/datasets/public/zhiming_datasets/hot3d/mano_v1_2/models/'

# Visualisation switches.
show_bbx = False  # also draw each object's 8-vertex bounding box as a line strip
show_hand_mesh = True  # True: log hand meshes; False: log hand-joint skeletons instead

# init the object library
if not os.path.exists(object_library_path):
    print("invalid object library path.")
    print("please follow the instructions at https://github.com/facebookresearch/hot3d to Download the HOT3D Assets Dataset")
    # A bare `raise` has no active exception here and would itself fail with
    # "RuntimeError: No active exception to reraise"; raise a real error instead.
    raise FileNotFoundError(object_library_path)
object_library = load_object_library(object_library_folderpath=object_library_path)

# init the HANDs model
if not os.path.exists(mano_hand_model_path):
    print("invalid mano hand model path.")
    print("please follow the instructions at https://github.com/facebookresearch/hot3d to Download the MANO files")
    raise FileNotFoundError(mano_hand_model_path)
mano_hand_model = MANOHandModel(mano_hand_model_path)

# Per-frame arrays; row i of every array describes the same frame.
timestamps_data = np.load(timestamps_path)
head_data = np.load(head_path) # head_direction (3) + head_translation (3) + head_rotation (4, quat_xyzw)
gaze_data = np.load(gaze_path) # gaze_direction (3) + gaze_center_in_world (3)
hand_data = np.load(hand_path) # left_hand (22) + right_hand (22), hand = wrist_pose (7, translation (3) + rotation (4)) + joint_angles (15)
hand_joint_data = np.load(hand_joint_path) # left_hand (20*3) + right_hand (20*3)
object_data = np.load(object_path) # object_data (8) * 6 objects (at most 6 objects), object_data = object_uid (1) + object_pose (7, translation (3) + rotation (4))
object_bbx_data = np.load(object_bbx_path) # bounding box information: 6 objects (at most 6 objects) * 8 vertexes * 3
# The head array defines how many frames are replayed.
frame_length = len(head_data)

# alias over the HAND pose data provider
hand_data_provider = MANOHandDataProvider('./mano_hand_pose_init/mano_hand_pose_trajectory.jsonl', mano_hand_model)
# keep track of what 3D assets have been loaded/unloaded so we will load them only when needed
object_cache_status = {}

# Init a rerun context
rr.init("hot3d-aria")
# NOTE(review): `rec` is not read again in this cell; the viewer is opened via rr.notebook_show().
rec = rr.memory_recording()

def log_pose(pose: SE3, label: str, static=False) -> None:
    """Log ``pose`` to the rerun entity path ``label`` as a 3D transform.

    Args:
        pose: Rigid transform to log; converted with ``ToTransform3D(pose, False)``.
        label: Rerun entity path that receives the transform.
        static: Forwarded unchanged to ``rr.log``.
    """
    transform = ToTransform3D(pose, False)
    rr.log(label, transform, static=static)


# Number of object slots stored per frame in objects.npy / object_bbx.npy
# (8 values per slot in objects.npy, 8 vertices * 3 coordinates in object_bbx.npy).
object_num_max = 6
# Order in which the 8 bounding-box vertices are visited by the logged line strip.
bbx_strip_order = (0, 1, 2, 3, 0, 7, 6, 5, 4, 7, 6, 1, 2, 5, 4, 3)


def _pose_from_xyzw(rotation_xyzw, translation):
    """Build an SE3 from a quaternion stored as (x, y, z, w) and a translation (3)."""
    # SE3.from_quat_and_translation takes the scalar part first, then the vector part.
    return SE3.from_quat_and_translation(rotation_xyzw[-1], rotation_xyzw[:-1], translation)


def _log_hand_mesh(side, hand_pose):
    """Log the triangle mesh of one hand ("left" or "right") for the current time."""
    vertices = hand_data_provider.get_hand_mesh_vertices(hand_pose)
    triangles, vertex_normals = hand_data_provider.get_hand_mesh_faces_and_normals(hand_pose)
    rr.log(
        f"world/{side}_hand/mesh_faces",
        rr.Mesh3D(
            vertex_positions=vertices,
            vertex_normals=vertex_normals,
            triangle_indices=triangles))


def _log_hand_skeleton(side, joints):
    """Log one hand ("left" or "right") as line strips following LANDMARK_CONNECTIVITY."""
    skeleton = [[joints[it].tolist() for it in connectivity]
                for connectivity in LANDMARK_CONNECTIVITY]
    rr.log(
        f"world/{side}_hand_skeleton",
        rr.LineStrips3D(skeleton, radii=0.002),
    )


# Replay every frame: set the rerun time, then log head, gaze, hands and objects.
for i in range(frame_length):
    timestamp_ns = int(timestamps_data[i, 0])
    rr.set_time_nanos("synchronization_time", timestamp_ns)
    rr.set_time_sequence("timestamp", timestamp_ns)

    # head row: head_direction (3) + head_translation (3) + head_rotation (4, quat_xyzw)
    head_translation = head_data[i, 3:6]
    head_rotation = head_data[i, 6:10]
    # gaze row: gaze_direction (3) + gaze_center_in_world (3)
    gaze_direction = gaze_data[i, 0:3]

    # use cpf pose as head pose, see https://facebookresearch.github.io/projectaria_tools/docs/data_formats/coordinate_convention/3d_coordinate_frame_convention
    T_world_cpf = _pose_from_xyzw(head_rotation, head_translation)
    log_pose(pose=T_world_cpf, label="world/head_pose")

    # Gaze is drawn as a point at the head position plus an arrow of length 0.4 * |gaze_direction|.
    rr.log(
        "world/gaze_direction",
        rr.Points3D([head_translation], radii=[0.003]),
        rr.Arrows3D(vectors=[gaze_direction*0.4], colors=[[0, 0.8, 0.2, 0.5]]))

    if show_hand_mesh:
        # plot hands as a triangular mesh representation
        # hand row: left_hand (22) + right_hand (22),
        # hand = translation (3) + rotation (4, quat_xyzw) + joint_angles (15)
        left_hand_wrist_pose = _pose_from_xyzw(hand_data[i, 3:7], hand_data[i, 0:3])
        left_hand_pose = HandPose(Handedness.Left, left_hand_wrist_pose, hand_data[i, 7:22])
        right_hand_wrist_pose = _pose_from_xyzw(hand_data[i, 25:29], hand_data[i, 22:25])
        right_hand_pose = HandPose(Handedness.Right, right_hand_wrist_pose, hand_data[i, 29:44])
        _log_hand_mesh("left", left_hand_pose)
        _log_hand_mesh("right", right_hand_pose)
    else:
        # plot hands as skeletons from the stored joints: left_hand (20*3) + right_hand (20*3)
        _log_hand_skeleton("left", hand_joint_data[i, 0:60].reshape((20, 3)))
        _log_hand_skeleton("right", hand_joint_data[i, 60:120].reshape((20, 3)))

    # load objects
    visible_uids = set()
    for item in range(object_num_max):
        base = item * 8
        object_uid = str(int(object_data[i, base]))
        if object_uid == '0':
            # uid 0 is treated as the end of the used slots for this frame
            break
        visible_uids.add(object_uid)

        object_name = object_library.object_id_to_name_dict[object_uid]
        object_translation = object_data[i, base+1:base+4]
        object_rotation = object_data[i, base+4:base+8]
        object_pose = _pose_from_xyzw(object_rotation, object_translation)
        log_pose(pose=object_pose, label=f"world/objects/{object_name}")

        if show_bbx:
            bbx_vertices = object_bbx_data[i, item*24:item*24+24].reshape((8, 3))
            points = [bbx_vertices[k] for k in bbx_strip_order]
            rr.log(f"world/objects_bbx/{object_name}", rr.LineStrips3D([points]))

        # link the corresponding 3D object to the pose (only once while it stays visible)
        if object_uid not in object_cache_status:
            object_cache_status[object_uid] = True
            object_cad_asset_filepath = ObjectLibrary.get_cad_asset_path(
                object_library_folderpath=object_library.asset_folder_name,
                object_id=object_uid)
            rr.log(
                f"world/objects/{object_name}",
                rr.Asset3D(path=object_cad_asset_filepath))

    # if some objects are not visible any more, we clear the entity.
    # object_cache_status holds exactly the objects whose mesh is currently shown, so
    # anything in it that is absent from this frame has just disappeared.
    stale_uids = [uid for uid in object_cache_status if uid not in visible_uids]
    for object_uid in stale_uids:
        object_name = object_library.object_id_to_name_dict[object_uid]
        rr.log(
            f"world/objects/{object_name}",
            rr.Clear.recursive())
        if show_bbx:
            rr.log(
                f"world/objects_bbx/{object_name}",
                rr.Clear.recursive())
        del object_cache_status[object_uid]  # we will log the mesh again

# show the rerun window
rr.notebook_show()
In [ ]: