#!/usr/bin/env bash
# Run inference_act_new_dist.py in --evaluate mode for one
# (dataset, horizon, attention variant, mask type) combination.
#
# Usage: bash <this script> [data] [horizon] [attn] [mask_type]
#   data       crosstask | coin | NIV             (default: NIV)
#   horizon    value passed as --horizon          (default: 3)
#   attn       WithAttention | NoAttention        (default: WithAttention)
#   mask_type  only "multi_add" has checkpoints   (default: multi_add)
#
# The Python program's stdout is written to
# ${data}_T${horizon}_dist_mean_act_output.txt in the current directory.
# Unsupported combinations abort with an error instead of launching Python
# with a missing --checkpoint_diff / --lr value.

# -e is deliberately not used: every failure point is checked explicitly, and
# the script's exit status is that of the final python3 command.
set -u

die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

#######################################
# Print the checkpoint path registered for a combination.
# Arguments: $1 data, $2 horizon, $3 attn, $4 mask_type
# Outputs:   the checkpoint path on stdout
# Returns:   exits 1 (via die) if no checkpoint is registered. When called
#            inside $(...), die only ends the subshell, so callers must check
#            the status.
#######################################
resolve_checkpoint() {
  local data=$1 horizon=$2 attn=$3 mask_type=$4
  [[ "$mask_type" == "multi_add" ]] \
    || die "no checkpoint registered for mask_type '$mask_type' (only multi_add)"

  # Common path prefixes; the ablation (NoAttention) runs live under
  # save_max/ablation_attn/.
  local with="save_max/${data}/Ours-MultiAdd-Attention(4-16-4)"
  local without="save_max/ablation_attn/${data}/Ours-MultiAdd-NoAttention(4-16-4)"
  local path

  case "${attn}/${data}/${horizon}" in
    WithAttention/NIV/4) path="${with}-256-1248-3e-4-scaleNorm/T4_act_epoch0095_2.pth.tar" ;;
    WithAttention/NIV/3) path="${with}-256-1248-1e-4-scaleNorm/T3_act_epoch0080_0.pth.tar" ;;
    WithAttention/coin/3) path="${with}-512-124-1e-5-scaleNorm/T3_act_epoch0800_0.pth.tar" ;;
    WithAttention/coin/4) path="${with}-512-124-1e-5-scaleNorm/T4_act_epoch0735_0.pth.tar" ;;
    WithAttention/crosstask/3) path="${with}-512-124-1e-5-scaleNorm/T3_act_epoch0100_0.pth.tar" ;;
    WithAttention/crosstask/4) path="${with}-512-124-1e-5-scaleNorm/T4_act_epoch0090_2.pth.tar" ;;
    WithAttention/crosstask/5) path="${with}-512-124-1e-5-scaleNorm/T5_act_epoch0095_2.pth.tar" ;;
    WithAttention/crosstask/6) path="${with}-512-124-1e-5-scaleNorm/T6_act_epoch0075_2.pth.tar" ;;
    NoAttention/NIV/4) path="${without}-256-1248-3e-4-scaleNorm/T4_act_epoch0060_2.pth.tar" ;;
    NoAttention/NIV/3) path="${without}-256-1248-1e-4-scaleNorm/T3_act_epoch0095_0.pth.tar" ;;
    NoAttention/coin/3) path="${without}-512-124-1e-4-scaleNorm-max/T3_act_epoch0685_1.pth.tar" ;;
    NoAttention/coin/4) path="${without}-512-124-1e-4-scaleNorm-max/T4_act_epoch0795_1.pth.tar" ;;
    NoAttention/crosstask/3) path="${without}-512-124-1e-5-scaleNorm/T3_act_epoch0105_2.pth.tar" ;;
    NoAttention/crosstask/4) path="${without}-512-124-1e-5-scaleNorm/T4_act_epoch0100_2.pth.tar" ;;
    *) die "no checkpoint registered for attn=$attn data=$data horizon=$horizon" ;;
  esac
  printf '%s\n' "$path"
}

#######################################
# Set the per-dataset command-line values.
# Globals:   writes action_dim, class_dim, act_emb, diffusion_step,
#            train_step, json_train, json_val, epoch, lr
# Arguments: $1 data, $2 horizon
# Returns:   exits 1 (via die) for an unknown dataset, or for NIV with a
#            horizon that has no --lr value (only 3 and 4 are defined)
#######################################
configure_dataset() {
  local data=$1 horizon=$2
  case "$data" in
    crosstask)
      action_dim=105
      class_dim=18
      act_emb="dataset/crosstask/act_lang_emb.pkl"
      diffusion_step=200
      train_step=200
      json_train="dataset/crosstask/crosstask_release/train_split_T${horizon}.json"
      json_val="dataset/crosstask/crosstask_release/crosstask_mlp_T${horizon}.json"
      # json_val="dataset/crosstask/crosstask_release/test_split_T${horizon}.json"
      epoch=120
      lr=5e-4
      ;;
    coin)
      action_dim=778
      class_dim=180
      act_emb="dataset/coin/steps_info.pickle"
      diffusion_step=200
      train_step=200
      json_train="dataset/coin/train_split_T${horizon}.json"
      json_val="dataset/coin/coin_mlp_T${horizon}.json"
      # json_val="dataset/coin/test_split_T${horizon}.json"
      epoch=800
      lr=1e-5
      ;;
    NIV)
      action_dim=48
      class_dim=5
      act_emb="dataset/NIV/niv_act_embeddings.pickle"
      diffusion_step=50
      train_step=50
      json_train="dataset/NIV/train_split_T${horizon}.json"
      json_val="dataset/NIV/NIV_mlp_T${horizon}.json"
      # json_val="dataset/NIV/test_split_T${horizon}.json"
      epoch=130
      case "$horizon" in
        3) lr=1e-4 ;;
        4) lr=3e-4 ;;
        # Previously lr was left unset here, which shifted the CLI arguments.
        *) die "no --lr value defined for NIV with horizon $horizon" ;;
      esac
      ;;
    *)
      die "unknown dataset '$data' (expected crosstask, coin or NIV)"
      ;;
  esac
}

main() {
  # Defaults match the values that were previously hard-coded.
  local data=${1:-NIV}
  local horizon=${2:-3}
  local attn=${3:-WithAttention}
  local mask_type=${4:-multi_add}

  local checkpoint_diff
  checkpoint_diff=$(resolve_checkpoint "$data" "$horizon" "$attn" "$mask_type") || exit 1
  configure_dataset "$data" "$horizon"

  python3 inference_act_new_dist.py \
    --multiprocessing-distributed \
    --num_thread_reader=8 \
    --cudnn_benchmark=1 \
    --pin_memory \
    --checkpoint_dir=whl \
    --resume \
    --batch_size=256 \
    --batch_size_val=256 \
    --evaluate \
    --checkpoint_diff "$checkpoint_diff" \
    --dataset "$data" \
    --horizon "$horizon" \
    --mask_type "$mask_type" \
    --attn "$attn" \
    --act_emb_path "$act_emb" \
    --action_dim "$action_dim" \
    --class_dim "$class_dim" \
    --n_diffusion_steps "$diffusion_step" \
    --n_train_steps "$train_step" \
    --json_path_train "$json_train" \
    --json_path_val "$json_val" \
    --epochs "$epoch" \
    --lr "$lr" \
    --use_cls_mask True \
    > "${data}_T${horizon}_dist_mean_act_output.txt"
}

main "$@"