ActionDiffusion_WACV2025/inference_no_mask.sh
2024-12-02 15:42:58 +01:00

71 lines
1.6 KiB
Bash

# Configuration for the no-mask ablation inference run.
#
# Usage: inference_no_mask.sh [dataset] [horizon]
#   dataset  one of: crosstask | coin | NIV   (default: NIV)
#   horizon  planning horizon T               (default: 4)
#
# Each dataset branch sets the model dimensions, action-embedding file,
# diffusion/train step counts, train/val JSON split paths, epoch count and
# learning rate that are forwarded to inference_no_mask.py.
set -euo pipefail

# Print an error to stderr and abort.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

data="${1:-NIV}"
horizon="${2:-4}"
attn="WithAttention"
# Checkpoint of the no-mask ablation (file name suggests epoch 125).
checkpoint_diff="save_max/ablation_no_mask/${data}/T${horizon}_act_epoch0125.pth.tar"

case "$data" in
  crosstask)
    action_dim=105
    class_dim=18
    act_emb="dataset/crosstask/act_lang_emb.pkl"
    diffusion_step=200
    train_step=200
    json_train="dataset/crosstask/crosstask_release/train_split_T${horizon}.json"
    json_val="dataset/crosstask/crosstask_release/crosstask_mlp_T${horizon}.json"
    epoch=120
    lr=5e-4
    ;;
  coin)
    action_dim=778
    class_dim=180
    act_emb="dataset/coin/steps_info.pickle"
    diffusion_step=200
    train_step=200
    json_train="dataset/coin/train_split_T${horizon}.json"
    json_val="dataset/coin/coin_mlp_T${horizon}.json"
    epoch=800
    lr=1e-5
    ;;
  NIV)
    action_dim=48
    class_dim=5
    act_emb="dataset/NIV/niv_act_embeddings.pickle"
    diffusion_step=50
    train_step=50
    json_train="dataset/NIV/train_split_T${horizon}.json"
    json_val="dataset/NIV/NIV_mlp_T${horizon}.json"
    epoch=130
    # Only horizons 3 and 4 have a configured learning rate; fail fast
    # instead of handing python an empty --lr value.
    case "$horizon" in
      3) lr=1e-4 ;;
      4) lr=3e-4 ;;
      *) die "no NIV learning rate configured for horizon ${horizon} (expected 3 or 4)" ;;
    esac
    ;;
  *)
    die "unknown dataset '${data}' (expected crosstask, coin or NIV)"
    ;;
esac
# Run no-mask inference with the configuration selected above.
# Options are collected in an array so each value is passed as exactly one
# argument (even if it contains spaces). The stdout redirection is placed at
# the end of the command, where it is easy to see, instead of in the middle
# of the argument list.
output_file="${data}_T${horizon}_no_mask_output.txt"

# NOTE(review): this is the "no mask" inference script, yet it passes
# --use_cls_mask True. Confirm that this is intentional.
infer_args=(
  --multiprocessing-distributed
  --num_thread_reader=8
  --cudnn_benchmark=1
  --pin_memory
  --checkpoint_dir=whl
  --resume
  --batch_size=256
  --batch_size_val=256
  --evaluate
  --checkpoint_diff "${checkpoint_diff}"
  --dataset "${data}"
  --horizon "${horizon}"
  --attn "${attn}"
  --act_emb_path "${act_emb}"
  --action_dim "${action_dim}"
  --class_dim "${class_dim}"
  --n_diffusion_steps "${diffusion_step}"
  --n_train_steps "${train_step}"
  --json_path_train "${json_train}"
  --json_path_val "${json_val}"
  --epochs "${epoch}"
  --lr "${lr}"
  --use_cls_mask True
)

python3 inference_no_mask.py "${infer_args[@]}" > "${output_file}"