first commit
This commit is contained in:
parent
99ce0acafb
commit
8f6b6a34e7
73 changed files with 11656 additions and 0 deletions
1
utils/__init__.py
Normal file
1
utils/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from utils import *
|
171
utils/adt_dataset.py
Normal file
171
utils/adt_dataset.py
Normal file
|
@ -0,0 +1,171 @@
|
|||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class adt_dataset(Dataset):
    """Sliding-window sequences assembled from per-segment ``.npy`` files.

    Files in the data directory are grouped by action (third ``_``-separated
    field of the file name) and by data type (last field with its 4-character
    extension removed).  For each action, the i-th ``gaze``, ``hand``,
    ``handjoints``, ``head``, ``bbxleft`` and ``bbxright`` files (in sorted
    name order) are treated as one segment, concatenated column-wise and cut
    into overlapping windows of ``seq_len`` consecutive frames.

    Per-frame column order: gaze direction (3), all columns of the ``hand``
    file, head translation, hand joints, head direction (3), optional
    left/right object boxes, and finally the two columns the code calls
    ``attended_hand_gt`` and ``attended_hand_baseline``.
    """

    # Joints stored per hand in the ``handjoints`` files (3 values each).
    _HAND_JOINTS = 15
    # Values stored per object bounding box.
    _BBX_DIM = 24
    # Data types recognised from the last field of a file name.
    _KINDS = ('gaze', 'hand', 'handjoints', 'head', 'bbxleft', 'bbxright')

    def __init__(self, data_dir, seq_len, actions='all', train_flag=1,
                 object_num=1, hand_joint_number=1, sample_rate=1):
        """Load every matching segment below ``data_dir``.

        Args:
            data_dir: Root directory of the dataset.
            seq_len: Number of consecutive frames per window.
            actions: One action name or ``'all'``.
            train_flag: 1 reads ``<data_dir>/train/``, 0 reads
                ``<data_dir>/test/``; any other value reads ``data_dir``.
            object_num: Number of object boxes kept per hand side.
            hand_joint_number: When 1, each hand is reduced to the mean of
                its joints; otherwise all joints are kept.
            sample_rate: Keep every ``sample_rate``-th window of a segment.
        """
        actions = self.define_actions(actions)
        self.sample_rate = sample_rate
        # os.path.join tolerates a data_dir with or without a trailing slash.
        if train_flag == 1:
            data_dir = os.path.join(data_dir, 'train/')
        elif train_flag == 0:
            data_dir = os.path.join(data_dir, 'test/')

        self.dataset = self.load_data(data_dir, seq_len, actions, object_num, hand_joint_number)

    def define_actions(self, action):
        """
        Define the list of actions we are using.

        Args
          action: String with the passed action. Could be "all"
        Returns
          actions: List of strings of actions
        Raises
          ValueError if the action is not included.
        """
        actions = ['work', 'decoration', 'meal']
        if action in actions:
            return [action]

        if action == "all":
            return actions

        raise ValueError("Unrecognised action: {}".format(action))

    @staticmethod
    def _window_indices(num_frames, seq_len):
        """Return a (num_windows, seq_len) array of consecutive frame indices."""
        starts = np.arange(0, num_frames - seq_len + 1)
        # Broadcasting keeps the result 2-D even when seq_len == 1.
        return starts[:, None] + np.arange(seq_len)

    def load_data(self, data_dir, seq_len, actions, object_num, hand_joint_number):
        """Read all segments of ``actions`` and return the stacked windows.

        Returns:
            Array of shape (num_windows, seq_len, num_features), or an empty
            list when no segment has at least ``seq_len`` frames.
        """
        # file_names[action][kind] -> file names in sorted order.
        file_names = {action: {kind: [] for kind in self._KINDS} for action in actions}
        for name in sorted(os.listdir(data_dir)):
            name_split = name.split('_')
            if len(name_split) < 3:
                # Not a segment file; it has no action field.
                continue
            action = name_split[2]
            data_type = name_split[-1][:-4]
            if action in file_names and data_type in file_names[action]:
                file_names[action][data_type].append(name)

        per_hand = self._HAND_JOINTS * 3
        flag_col = 2 * per_hand  # first column after both hands' joints
        windows = []
        for action in actions:
            names = file_names[action]
            segments_number = len(names['gaze'])
            print("Reading action {}, segments number {}".format(action, segments_number))
            for seg in range(segments_number):
                gaze_data = np.load(os.path.join(data_dir, names['gaze'][seg]))
                num_frames = gaze_data.shape[0]
                if num_frames < seq_len:
                    # Too short to yield a single window.
                    continue
                gaze_direction = gaze_data[:, :3]

                hand_translation = np.load(os.path.join(data_dir, names['hand'][seg]))

                hand_joint_data_all = np.load(os.path.join(data_dir, names['handjoints'][seg]))
                hand_joint_data = hand_joint_data_all[:, :flag_col]
                if hand_joint_number == 1:
                    # Collapse each hand to the mean position of its joints.
                    rows = hand_joint_data.shape[0]
                    left_hand_center = np.mean(
                        hand_joint_data[:, :per_hand].reshape(rows, self._HAND_JOINTS, 3), axis=1)
                    right_hand_center = np.mean(
                        hand_joint_data[:, per_hand:].reshape(rows, self._HAND_JOINTS, 3), axis=1)
                    hand_joint_data = np.concatenate((left_hand_center, right_hand_center), axis=1)
                attended_hand_gt = hand_joint_data_all[:, flag_col:flag_col + 1]
                attended_hand_baseline = hand_joint_data_all[:, flag_col + 1:flag_col + 2]

                head_data = np.load(os.path.join(data_dir, names['head'][seg]))
                head_direction = head_data[:, :3]
                head_translation = head_data[:, 3:]

                columns = [gaze_direction, hand_translation, head_translation,
                           hand_joint_data, head_direction]
                if object_num > 0:
                    object_left_data = np.load(os.path.join(data_dir, names['bbxleft'][seg]))
                    object_left_data = object_left_data.reshape(object_left_data.shape[0], -1)
                    object_right_data = np.load(os.path.join(data_dir, names['bbxright'][seg]))
                    object_right_data = object_right_data.reshape(object_right_data.shape[0], -1)
                    # The first object_num boxes are contiguous, so one slice suffices.
                    columns.append(object_left_data[:, :object_num * self._BBX_DIM])
                    columns.append(object_right_data[:, :object_num * self._BBX_DIM])
                columns.append(attended_hand_gt)
                columns.append(attended_hand_baseline)
                data = np.concatenate(columns, axis=1)

                seq_sel = data[self._window_indices(num_frames, seq_len), :]
                windows.append(seq_sel[0::self.sample_rate, :, :])

        if not windows:
            return []
        # Concatenate once instead of growing the array segment by segment.
        return np.concatenate(windows, axis=0)

    def __len__(self):
        """Return the number of windows."""
        return np.shape(self.dataset)[0]

    def __getitem__(self, item):
        """Return the window(s) selected by ``item`` (any NumPy index)."""
        return self.dataset[item]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: load the training split and report how often the
    # second-to-last feature column (filled from attended_hand_gt in
    # load_data) is set.
    data_dir = "/scratch/hu/pose_forecast/adt_hoigaze/"
    seq_len, actions, sample_rate = 15, 'all', 1
    train_flag, object_num, hand_joint_number = 1, 1, 1

    train_dataset = adt_dataset(
        data_dir,
        seq_len,
        actions=actions,
        train_flag=train_flag,
        object_num=object_num,
        hand_joint_number=hand_joint_number,
        sample_rate=sample_rate,
    )
    print("Training data size: {}".format(train_dataset.dataset.shape))

    hand_joint_dominance = train_dataset[:, :, -2:-1].flatten()
    frame_count = hand_joint_dominance.shape[0]
    right_hand_percent = np.sum(hand_joint_dominance) / frame_count * 100
    print("right hand ratio: {:.2f}".format(right_hand_percent))
|
||||
|
137
utils/hot3d_aria_dataset.py
Normal file
137
utils/hot3d_aria_dataset.py
Normal file
|
@ -0,0 +1,137 @@
|
|||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class hot3d_aria_dataset(Dataset):
    """Sliding-window sequences built from per-segment ``.npy`` files.

    A file is used when its first ``_``-separated name field is one of
    ``subjects`` and its third field is one of ``actions``; the last field
    (minus its 4-character extension) selects the data type.  The i-th file of
    each type, in sorted name order, forms one segment.

    Per-frame column order: gaze direction (3), left hand translation (3),
    right hand translation (3), head translation (3), hand joints (120), head
    direction (3), optional left/right object boxes, and the two columns the
    code calls ``attended_hand_gt`` and ``attended_hand_baseline``.
    """

    # Values stored per object bounding box.
    _BBX_DIM = 24
    # Columns of the ``handjoints`` files that hold joint data; the two
    # columns after them hold the attended-hand flags.
    _JOINT_COLS = 120

    def __init__(self, data_dir, subjects, seq_len, actions='all', object_num=1, sample_rate=1):
        """Load every segment of the given subjects and actions.

        Args:
            data_dir: Directory containing the ``.npy`` files.
            subjects: Subject id(s) to include (a string or a collection).
            seq_len: Number of consecutive frames per window.
            actions: ``'all'``, one action name, or a collection of names.
            object_num: Number of object boxes kept per hand side.
            sample_rate: Keep every ``sample_rate``-th window of a segment.
        """
        if actions == 'all':
            actions = ['room', 'kitchen', 'office']
        elif isinstance(actions, str):
            # A bare string would otherwise be matched by substring.
            actions = [actions]
        if isinstance(subjects, str):
            subjects = [subjects]
        self.sample_rate = sample_rate
        self.dataset = self.load_data(data_dir, subjects, seq_len, actions, object_num)

    @staticmethod
    def _window_indices(num_frames, seq_len):
        """Return a (num_windows, seq_len) array of consecutive frame indices."""
        starts = np.arange(0, num_frames - seq_len + 1)
        # Broadcasting keeps the result 2-D even when seq_len == 1.
        return starts[:, None] + np.arange(seq_len)

    def load_data(self, data_dir, subjects, seq_len, actions, object_num):
        """Read the selected segments and return the stacked windows.

        Returns:
            Array of shape (num_windows, seq_len, num_features), or an empty
            list when no segment has at least ``seq_len`` frames.
        """
        names = {kind: [] for kind in
                 ('gaze', 'hand', 'handjoints', 'head', 'bbxleft', 'bbxright')}
        for name in sorted(os.listdir(data_dir)):
            name_split = name.split('_')
            if len(name_split) < 3:
                # Not a segment file; it has no action field.
                continue
            subject = name_split[0]
            action = name_split[2]
            if subject in subjects and action in actions:
                data_type = name_split[-1][:-4]
                if data_type in names:
                    names[data_type].append(name)

        joint_cols = self._JOINT_COLS
        windows = []
        segments_number = len(names['hand'])
        for seg in range(segments_number):
            gaze_data = np.load(os.path.join(data_dir, names['gaze'][seg]))
            num_frames = gaze_data.shape[0]
            if num_frames < seq_len:
                # Too short to yield a single window.
                continue

            hand_data = np.load(os.path.join(data_dir, names['hand'][seg]))
            hand_joint_data_all = np.load(os.path.join(data_dir, names['handjoints'][seg]))
            head_data = np.load(os.path.join(data_dir, names['head'][seg]))

            columns = [
                gaze_data[:, 0:3],                    # gaze direction
                hand_data[:, 0:3],                    # left hand translation
                hand_data[:, 22:25],                  # right hand translation
                head_data[:, 3:6],                    # head translation
                hand_joint_data_all[:, :joint_cols],  # hand joints
                head_data[:, 0:3],                    # head direction
            ]
            if object_num > 0:
                object_left_data = np.load(os.path.join(data_dir, names['bbxleft'][seg]))
                object_right_data = np.load(os.path.join(data_dir, names['bbxright'][seg]))
                # The first object_num boxes are contiguous, so one slice suffices.
                columns.append(object_left_data[:, :object_num * self._BBX_DIM])
                columns.append(object_right_data[:, :object_num * self._BBX_DIM])
            columns.append(hand_joint_data_all[:, joint_cols:joint_cols + 1])      # attended_hand_gt
            columns.append(hand_joint_data_all[:, joint_cols + 1:joint_cols + 2])  # attended_hand_baseline
            data = np.concatenate(columns, axis=1)

            seq_sel = data[self._window_indices(num_frames, seq_len), :]
            windows.append(seq_sel[0::self.sample_rate, :, :])

        if not windows:
            return []
        # Concatenate once instead of growing the array segment by segment.
        return np.concatenate(windows, axis=0)

    def __len__(self):
        """Return the number of windows."""
        return np.shape(self.dataset)[0]

    def __getitem__(self, item):
        """Return the window(s) selected by ``item`` (any NumPy index)."""
        return self.dataset[item]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: load the training subjects and report how often the
    # second-to-last feature column (filled from attended_hand_gt in
    # load_data) is set.
    data_dir = "/scratch/hu/pose_forecast/hot3d_hoigaze/"
    seq_len, actions = 15, 'all'
    all_subjects = ['P0001', 'P0002', 'P0003', 'P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
    train_subjects = ['P0009', 'P0010', 'P0011', 'P0012', 'P0014', 'P0015']
    object_num, sample_rate = 1, 10

    train_dataset = hot3d_aria_dataset(
        data_dir,
        train_subjects,
        seq_len,
        actions=actions,
        object_num=object_num,
        sample_rate=sample_rate,
    )
    print("Training data size: {}".format(train_dataset.dataset.shape))

    hand_joint_dominance = train_dataset[:, :, -2:-1].flatten()
    frame_count = hand_joint_dominance.shape[0]
    right_hand_percent = np.sum(hand_joint_dominance) / frame_count * 100
    print("right hand ratio: {:.2f}".format(right_hand_percent))

    # Disabled check of the test subjects, kept from the original:
    #test_subjects = ['P0001', 'P0002', 'P0003']
    #sample_rate = 8
    #test_dataset = hot3d_aria_dataset(data_dir, test_subjects, seq_len, actions, #object_num, sample_rate)
    # print("Test data size: {}".format(test_dataset.dataset.shape))

    #hand_joint_dominance = test_dataset[:, :, -2:-1].flatten()
    #print("right hand ratio: {:.2f}".format(np.sum(hand_joint_dominance)/hand_joint_dominance.shape[0]*100))
|
91
utils/hot3d_aria_single_dataset.py
Normal file
91
utils/hot3d_aria_single_dataset.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class hot3d_aria_dataset(Dataset):
    """Windows cut from a single recording stored as six ``.npy`` files.

    ``data_path`` is a file-name *prefix*: the suffixes ``gaze.npy``,
    ``hand.npy``, ``handjoints.npy``, ``head.npy``, ``object_bbxleft.npy`` and
    ``object_bbxright.npy`` are appended to it directly.

    Per-frame column order: gaze direction (3), left hand translation (3),
    right hand translation (3), head translation (3), hand joints (120), head
    direction (3), optional left/right object boxes, and the two columns the
    code calls ``attended_hand_gt`` and ``attended_hand_baseline``.
    """

    # Values stored per object bounding box.
    _BBX_DIM = 24
    # Columns of ``handjoints.npy`` that hold joint data; the two columns
    # after them hold the attended-hand flags.
    _JOINT_COLS = 120

    def __init__(self, data_path, seq_len, object_num=1):
        """Load the recording identified by the prefix ``data_path``.

        Args:
            data_path: Common file-name prefix of the recording's files.
            seq_len: Frames per window; also the stride between windows.
            object_num: Number of object boxes kept per hand side.
        """
        self.dataset = self.load_data(data_path, seq_len, object_num)

    def load_data(self, data_path, seq_len, object_num):
        """Return an array of shape (num_windows, seq_len, num_features).

        Windows start every ``seq_len`` frames, so they do not overlap.  A
        recording shorter than ``seq_len`` yields zero windows.
        """
        # data_path is a prefix, so plain concatenation is intentional here.
        gaze_data = np.load(data_path + 'gaze.npy')
        hand_data = np.load(data_path + 'hand.npy')
        hand_joint_data_all = np.load(data_path + 'handjoints.npy')
        head_data = np.load(data_path + 'head.npy')
        object_left_data = np.load(data_path + 'object_bbxleft.npy')
        object_right_data = np.load(data_path + 'object_bbxright.npy')

        num_frames = gaze_data.shape[0]
        joint_cols = self._JOINT_COLS

        columns = [
            gaze_data[:, 0:3],                    # gaze direction
            hand_data[:, 0:3],                    # left hand translation
            hand_data[:, 22:25],                  # right hand translation
            head_data[:, 3:6],                    # head translation
            hand_joint_data_all[:, :joint_cols],  # hand joints
            head_data[:, 0:3],                    # head direction
        ]
        if object_num > 0:
            # The first object_num boxes are contiguous, so one slice suffices.
            columns.append(object_left_data[:, :object_num * self._BBX_DIM])
            columns.append(object_right_data[:, :object_num * self._BBX_DIM])
        columns.append(hand_joint_data_all[:, joint_cols:joint_cols + 1])      # attended_hand_gt
        columns.append(hand_joint_data_all[:, joint_cols + 1:joint_cols + 2])  # attended_hand_baseline
        data = np.concatenate(columns, axis=1)

        starts = np.arange(0, num_frames - seq_len + 1)
        # Broadcasting keeps the index 2-D even when seq_len == 1.
        frame_idx = starts[:, None] + np.arange(seq_len)
        seq_sel = data[frame_idx, :]
        return seq_sel[0::seq_len, :, :]

    def __len__(self):
        """Return the number of windows."""
        return np.shape(self.dataset)[0]

    def __getitem__(self, item):
        """Return the window(s) selected by ``item`` (any NumPy index)."""
        return self.dataset[item]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test on one recording; data_path is the common file-name
    # prefix of that recording's .npy files.
    data_path = '/scratch/hu/pose_forecast/hot3d_hoigaze/P0001_10a27bf7_room_721_890_'
    seq_len, object_num = 15, 1

    train_dataset = hot3d_aria_dataset(data_path, seq_len, object_num=object_num)
    loaded_shape = train_dataset.dataset.shape
    print("Training data size: {}".format(loaded_shape))
|
28
utils/log.py
Normal file
28
utils/log.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import json
|
||||
import os
|
||||
import torch
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
def save_csv_log(opt, head, value, is_create=False, file_name='results'):
    """Write or append rows of ``value`` to ``<opt.ckpt>/<file_name>.csv``.

    Args:
        opt: Object with a ``ckpt`` attribute naming the output directory.
        head: Column names written as the header row when the file is created.
        value: Array-like of numbers; anything with fewer than two dimensions
            is treated as a single row.
        is_create: When true, overwrite the file (with header) even if it
            already exists; otherwise existing files are appended to.
        file_name: Base name of the CSV file, without extension.
    """
    # Accept lists/tuples as well as arrays.
    value = np.asarray(value)
    if value.ndim < 2:
        value = np.expand_dims(value, axis=0)
    df = pd.DataFrame(value)
    file_path = opt.ckpt + '/{}.csv'.format(file_name)
    print(file_path)
    if is_create or not os.path.exists(file_path):
        df.to_csv(file_path, header=head, index=False)
    else:
        # Let pandas open the file itself: a text handle opened without
        # newline='' can produce doubled line endings on some platforms.
        df.to_csv(file_path, mode='a', header=False, index=False)
|
||||
|
||||
|
||||
def save_ckpt(state, opt=None, file_name = 'model.pt'):
    """Serialise ``state`` with ``torch.save``.

    Args:
        state: Object to save (anything ``torch.save`` accepts).
        opt: Object with a ``ckpt`` attribute naming the target directory.
            When omitted, ``file_name`` is used as the full path instead of
            failing on the ``None`` default.
        file_name: File name inside ``opt.ckpt`` (or the path when ``opt`` is
            ``None``).
    """
    if opt is None:
        file_path = file_name
    else:
        file_path = os.path.join(opt.ckpt, file_name)
    torch.save(state, file_path)
|
||||
|
||||
|
||||
def save_options(opt):
    """Dump the attributes of ``opt`` as JSON to ``<opt.ckpt>/options.json``.

    Keys keep their attribute order and the output is indented by four
    spaces.
    """
    options_path = opt.ckpt + '/options.json'
    with open(options_path, 'w') as out_file:
        json.dump(vars(opt), out_file, sort_keys=False, indent=4)
|
74
utils/opt.py
Normal file
74
utils/opt.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
import os
|
||||
import argparse
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
class options:
    """Command-line options for training and evaluation.

    Call :meth:`parse` to register the arguments, parse ``sys.argv`` and get
    the resulting namespace (also kept in ``self.opt``).
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.opt = None
        # Set once the arguments are registered so parse() can be repeated.
        self._initialised = False

    def _initial(self):
        """Register every supported argument on ``self.parser``."""
        add = self.parser.add_argument
        # ===============================================================
        #                     General options
        # ===============================================================
        add('--cuda_idx', type=str, default='cuda:0', help='cuda idx')
        add('--data_dir', type=str,
            default='./dataset/',
            help='path to dataset')
        add('--is_eval', dest='is_eval', action='store_true',
            help='whether to evaluate existing models or not')
        add('--ckpt', type=str, default='./checkpoints/', help='path to save checkpoints')
        add('--test_user_id', type=int, default=1, help='id of the test participants')
        add('--actions', type=str, default='all', help='actions to use')
        add('--sample_rate', type=int, default=2, help='sample the data')
        add('--save_predictions', dest='save_predictions', action='store_true',
            help='whether to save the prediction results or not')
        # ===============================================================
        #                     Model options
        # ===============================================================
        add('--body_joint_number', type=int, default=3, help='number of body joints to use')
        add('--hand_joint_number', type=int, default=20, help='number of hand joints to use')
        add('--head_cnn_channels', type=int, default=32, help='number of channels used in the head_CNN')
        add('--gcn_latent_features', type=int, default=8, help='number of latent features used in the gcn')
        add('--residual_gcns_num', type=int, default=4, help='number of residual gcns to use')
        add('--gcn_dropout', type=float, default=0.3, help='drop out probability in the gcn')
        add('--gaze_cnn_channels', type=int, default=64, help='number of channels used in the gaze_CNN')
        add('--recognition_cnn_channels', type=int, default=64, help='number of channels used in the recognition_CNN')
        add('--object_num', type=int, default=1, help='number of scene objects for gaze estimation')
        add('--use_self_att', type=int, default=1, help='use self attention or not')
        add('--self_att_head_num', type=int, default=1, help='number of heads used in self attention')
        add('--self_att_dropout', type=float, default=0.1, help='drop out probability in self attention')
        add('--use_cross_att', type=int, default=1, help='use cross attention or not')
        add('--cross_att_head_num', type=int, default=1, help='number of heads used in cross attention')
        add('--cross_att_dropout', type=float, default=0.1, help='drop out probability in cross attention')
        add('--use_attended_hand', type=int, default=1, help='use attended hand or use both hands')
        add('--use_attended_hand_gt', type=int, default=0, help='use attended hand ground truth or not')
        # ===============================================================
        #                     Running options
        # ===============================================================
        add('--seq_len', type=int, default=15, help='the length of the used sequence')
        add('--learning_rate', type=float, default=0.005)
        add('--gaze_head_loss_factor', type=float, default=4.0)
        add('--gaze_head_cos_threshold', type=float, default=0.8)
        add('--weight_decay', type=float, default=0.0)
        add('--gamma', type=float, default=0.95, help='decay learning rate by gamma')
        add('--epoch', type=int, default=50)
        add('--batch_size', type=int, default=32)
        add('--validation_epoch', type=int, default=10, help='interval of epoches to test')
        add('--test_batch_size', type=int, default=32)

    def _print(self):
        """Pretty-print the parsed options."""
        print("\n==================Options=================")
        pprint(vars(self.opt), indent=4)
        print("==========================================\n")

    def parse(self, make_dir=True):
        """Parse the command line and return the options namespace.

        Args:
            make_dir: When true, create the ``--ckpt`` directory if needed.

        Returns:
            The ``argparse.Namespace`` produced by the parser.
        """
        # Registering the same argument twice makes argparse raise, so only
        # do it on the first call.
        if not getattr(self, '_initialised', False):
            self._initial()
            self._initialised = True
        self.opt = self.parser.parse_args()
        if make_dir:
            # exist_ok avoids the check-then-create race.
            os.makedirs(self.opt.ckpt, exist_ok=True)
        self._print()
        return self.opt
|
15
utils/seed_torch.py
Normal file
15
utils/seed_torch.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def seed_torch(seed=0):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN behaviour.

    Also exports ``PYTHONHASHSEED`` into the process environment, marks cuDNN
    as deterministic and disables its benchmark mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
Loading…
Add table
Add a link
Reference in a new issue