HOIGaze/utils/opt.py
2025-04-30 14:15:00 +02:00

74 lines
No EOL
5.2 KiB
Python

import os
import argparse
from pprint import pprint
class options:
    """Command-line options container built on ``argparse``.

    ``parse()`` registers every argument, parses them, optionally creates
    the checkpoint directory, prints the resulting values and returns the
    parsed namespace (also kept in ``self.opt``).
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        # Parsed namespace; stays None until parse() has run.
        self.opt = None
        # Flipped by _initial() once the arguments are registered, so that
        # repeated parse() calls do not re-add them (argparse rejects
        # duplicate option strings).
        self._initialized = False

    def _initial(self):
        """Register all command-line arguments on ``self.parser``.

        Safe to call more than once: registration happens only the first
        time.
        """
        if self._initialized:
            return
        add = self.parser.add_argument
        # ===============================================================
        # General options
        # ===============================================================
        add('--cuda_idx', type=str, default='cuda:0', help='cuda idx')
        add('--data_dir', type=str, default='./dataset/',
            help='path to dataset')
        add('--is_eval', dest='is_eval', action='store_true',
            help='whether to evaluate existing models or not')
        add('--ckpt', type=str, default='./checkpoints/',
            help='path to save checkpoints')
        add('--test_user_id', type=int, default=1,
            help='id of the test participants')
        add('--actions', type=str, default='all', help='actions to use')
        add('--sample_rate', type=int, default=2, help='sample the data')
        add('--save_predictions', dest='save_predictions',
            action='store_true',
            help='whether to save the prediction results or not')
        # ===============================================================
        # Model options
        # ===============================================================
        add('--body_joint_number', type=int, default=3,
            help='number of body joints to use')
        add('--hand_joint_number', type=int, default=20,
            help='number of hand joints to use')
        add('--head_cnn_channels', type=int, default=32,
            help='number of channels used in the head_CNN')
        add('--gcn_latent_features', type=int, default=8,
            help='number of latent features used in the gcn')
        add('--residual_gcns_num', type=int, default=4,
            help='number of residual gcns to use')
        add('--gcn_dropout', type=float, default=0.3,
            help='drop out probability in the gcn')
        add('--gaze_cnn_channels', type=int, default=64,
            help='number of channels used in the gaze_CNN')
        add('--recognition_cnn_channels', type=int, default=64,
            help='number of channels used in the recognition_CNN')
        add('--object_num', type=int, default=1,
            help='number of scene objects for gaze estimation')
        # The use_* switches are integer flags (1 = on, 0 = off), not
        # store_true actions, so they take an explicit value.
        add('--use_self_att', type=int, default=1,
            help='use self attention or not')
        add('--self_att_head_num', type=int, default=1,
            help='number of heads used in self attention')
        add('--self_att_dropout', type=float, default=0.1,
            help='drop out probability in self attention')
        add('--use_cross_att', type=int, default=1,
            help='use cross attention or not')
        add('--cross_att_head_num', type=int, default=1,
            help='number of heads used in cross attention')
        add('--cross_att_dropout', type=float, default=0.1,
            help='drop out probability in cross attention')
        add('--use_attended_hand', type=int, default=1,
            help='use attended hand or use both hands')
        add('--use_attended_hand_gt', type=int, default=0,
            help='use attended hand ground truth or not')
        # ===============================================================
        # Running options
        # ===============================================================
        add('--seq_len', type=int, default=15,
            help='the length of the used sequence')
        add('--learning_rate', type=float, default=0.005)
        add('--gaze_head_loss_factor', type=float, default=4.0)
        add('--gaze_head_cos_threshold', type=float, default=0.8)
        add('--weight_decay', type=float, default=0.0)
        add('--gamma', type=float, default=0.95,
            help='decay learning rate by gamma')
        add('--epoch', type=int, default=50)
        add('--batch_size', type=int, default=32)
        add('--validation_epoch', type=int, default=10,
            help='interval of epoches to test')
        add('--test_batch_size', type=int, default=32)
        self._initialized = True

    def _print(self):
        """Pretty-print the parsed options held in ``self.opt``.

        ``self.opt`` must already be populated (i.e. after ``parse()``).
        """
        print("\n==================Options=================")
        pprint(vars(self.opt), indent=4)
        print("==========================================\n")

    def parse(self, make_dir=True, args=None):
        """Parse the options, store them in ``self.opt`` and return them.

        Args:
            make_dir: when truthy, create the ``--ckpt`` directory (with any
                missing parents) if it does not exist yet.
            args: optional list of argument strings to parse. ``None`` (the
                default) makes argparse read ``sys.argv[1:]``.

        Returns:
            The ``argparse.Namespace`` holding every parsed value.
        """
        self._initial()
        self.opt = self.parser.parse_args(args)
        if make_dir:
            # exist_ok=True removes the check-then-create race of a separate
            # os.path.isdir() test.
            os.makedirs(self.opt.ckpt, exist_ok=True)
        self._print()
        return self.opt