import argparse def get_args(description='whl'): parser = argparse.ArgumentParser(description=description) parser.add_argument('--act_emb_path', type=str, default='dataset/coin/steps_info.pickle', help='action embedding path') parser.add_argument('--checkpoint_mlp', type=str, default='', help='checkpoint path for task prediction model') parser.add_argument('--checkpoint_diff', type=str, default='', help='checkpoint path for diffusion model') parser.add_argument('--mask_type', type=str, default='multi_add', # single_add, multi_add help='action embedding mask type') parser.add_argument('--attn', type=str, default='attention', # single_add, multi_add help='WithAttention: unet with attn. NoAttention: unet without attention.') parser.add_argument('--infer_avg_mask', type=bool, default=False, help='if use average mask for inference') parser.add_argument('--use_cls_mask', type=bool, default=False, help='if use class label in diffusion mask') parser.add_argument('--checkpoint_root', type=str, default='checkpoint', help='checkpoint dir root') parser.add_argument('--log_root', type=str, default='log', help='log dir root') parser.add_argument('--checkpoint_dir', type=str, default='', help='checkpoint model folder') parser.add_argument('--optimizer', type=str, default='adam', help='opt algorithm') parser.add_argument('--num_thread_reader', type=int, default=40, help='') parser.add_argument('--batch_size', type=int, default=256, # 256 help='batch size') parser.add_argument('--batch_size_val', type=int, default=1024, # 1024 help='batch size eval') parser.add_argument('--pretrain_cnn_path', type=str, default='', help='') parser.add_argument('--momemtum', type=float, default=0.9, help='SGD momemtum') parser.add_argument('--log_freq', type=int, default=500, help='how many steps do we log once') parser.add_argument('--save_freq', type=int, default=1, help='how many epochs do we save once') parser.add_argument('--gradient_accumulate_every', type=int, default=1, help='accumulation_steps') parser.add_argument('--ema_decay', type=float, default=0.995, help='') parser.add_argument('--step_start_ema', type=int, default=400, help='') parser.add_argument('--update_ema_every', type=int, default=10, help='') parser.add_argument('--crop_only', type=int, default=1, help='random seed') parser.add_argument('--centercrop', type=int, default=0, help='random seed') parser.add_argument('--random_flip', type=int, default=1, help='random seed') parser.add_argument('--verbose', type=int, default=1, help='') parser.add_argument('--fps', type=int, default=1, help='') parser.add_argument('--cudnn_benchmark', type=int, default=0, help='') parser.add_argument('--horizon', type=int, default=3, help='') parser.add_argument('--dataset', type=str, default='coin', help='dataset') parser.add_argument('--action_dim', type=int, default=778, help='') parser.add_argument('--observation_dim', type=int, default=1536, help='') parser.add_argument('--class_dim', type=int, default=180, help='') parser.add_argument('--n_diffusion_steps', type=int, default=200, help='') parser.add_argument('--n_train_steps', type=int, default=200, help='training_steps_per_epoch') parser.add_argument('--root', type=str, default='', help='root path of dataset crosstask') parser.add_argument('--json_path_train', type=str, default='dataset/coin/train_split_T4.json', help='path of the generated json file for train') parser.add_argument('--json_path_val', type=str, default='dataset/coin/coin_mlp_T4.json', help='path of the generated json file for val') parser.add_argument('--epochs', default=800, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--resume', dest='resume', action='store_true', help='resume training from last checkpoint') parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') parser.add_argument('--pin_memory', dest='pin_memory', action='store_true', help='use pin_memory') parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training') parser.add_argument('--rank', default=0, type=int, help='node rank for distributed training') parser.add_argument('--dist-file', default='dist-file', type=str, help='url used to set up distributed training') parser.add_argument('--dist-url', default='tcp://localhost:20000', type=str, help='url used to set up distributed training') parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend') parser.add_argument('--seed', default=217, type=int, help='seed for initializing training. ') parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') parser.add_argument('--multiprocessing-distributed', action='store_true', help='Use multi-processing distributed training to launch ' 'N processes per node, which has N GPUs. This is the ' 'fastest way to use PyTorch for either single node or ' 'multi node data parallel training') args = parser.parse_args() return args