ActionDiffusion_WACV2025/print_para.py
2024-12-02 15:42:58 +01:00

43 lines
1.5 KiB
Python

from model.temporal_act import TemporalUnetNoAttn, TemporalUnet
from utils.args import get_args
def count_parameters(model, *, trainable_only=True):
    """Return the number of parameter elements in *model*.

    Args:
        model: a ``torch.nn.Module`` (anything whose ``.parameters()`` yields
            tensors exposing ``numel()`` and ``requires_grad``).
        trainable_only: when True (the default), count only parameters with
            ``requires_grad=True``; when False, count every parameter,
            including frozen ones.

    Returns:
        The total element count as an ``int`` (0 for a model with no
        parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )
# Per-dataset sizes used to build the temporal U-Nets:
# (action_dim, observation_dim, class_dim).
DATASET_DIMS = {
    'NIV': (48, 1536, 5),
    'crosstask': (105, 1536, 18),
    'coin': (778, 1536, 180),
}


def unet_config(dataset):
    """Return the ``dim``/``dim_mults`` keyword arguments for *dataset*.

    NIV uses ``dim=256`` with four ``dim_mults`` entries; every other dataset
    uses ``dim=512`` with three entries (presumably one entry per U-Net
    resolution level -- confirm against ``model.temporal_act``).
    """
    if dataset == 'NIV':
        return {'dim': 256, 'dim_mults': (1, 2, 4, 8)}
    return {'dim': 512, 'dim_mults': (1, 2, 4)}


def main(dataset='crosstask'):
    """Build both temporal U-Net variants for *dataset* and print their sizes.

    Prints the trainable-parameter count, in millions, of
    ``TemporalUnetNoAttn`` followed by ``TemporalUnet``.

    Args:
        dataset: one of ``'NIV'``, ``'crosstask'`` or ``'coin'``. It
            overrides any ``dataset`` value parsed by ``get_args()``, exactly
            as the original script's hard-coded assignment did.

    Raises:
        ValueError: if *dataset* is not a key of ``DATASET_DIMS``.
    """
    # Validate first so a typo fails fast, before any CLI parsing or model
    # construction happens.
    if dataset not in DATASET_DIMS:
        raise ValueError(
            f'unknown dataset {dataset!r}; expected one of {sorted(DATASET_DIMS)}'
        )

    args = get_args()
    args.dataset = dataset
    args.action_dim, args.observation_dim, args.class_dim = DATASET_DIMS[dataset]

    # First positional argument is the summed feature width, second is
    # action_dim (presumably input and output widths -- confirm in the model).
    in_dim = args.action_dim + args.observation_dim + args.class_dim
    config = unet_config(args.dataset)

    temporal_model_no_attn = TemporalUnetNoAttn(in_dim, args.action_dim, **config)
    temporal_model = TemporalUnet(in_dim, args.action_dim, **config)

    print('no attention para:', count_parameters(temporal_model_no_attn) / 1e6)
    print('attention para:', count_parameters(temporal_model) / 1e6)


if __name__ == '__main__':
    main()