"""Report trainable-parameter counts of TemporalUnet with and without attention.

Builds a ``TemporalUnetNoAttn`` and a ``TemporalUnet`` for the selected dataset
and prints each model's trainable parameter count in millions.
"""
from model.temporal_act import TemporalUnetNoAttn, TemporalUnet
from utils.args import get_args

# Per-dataset feature sizes: (action_dim, observation_dim, class_dim).
DATASET_DIMS = {
    'NIV': (48, 1536, 5),
    'crosstask': (105, 1536, 18),
    'coin': (778, 1536, 180),
}


def count_parameters(model):
    """Return the number of trainable parameters (``requires_grad=True``) in *model*."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def _unet_widths(dataset):
    """Return the ``(dim, dim_mults)`` U-Net width settings used for *dataset*."""
    if dataset == 'NIV':
        return 256, (1, 2, 4, 8)
    return 512, (1, 2, 4)


def main():
    """Build both model variants for the configured dataset and print their sizes.

    Raises:
        ValueError: if ``args.dataset`` is not a key of ``DATASET_DIMS``.
    """
    args = get_args()
    # NOTE(review): this hard-coded value overrides whatever --dataset was
    # passed on the command line, so only crosstask is ever measured. Remove
    # this line to honor the CLI flag.
    args.dataset = 'crosstask'

    try:
        dims = DATASET_DIMS[args.dataset]
    except KeyError:
        raise ValueError(
            f"unknown dataset {args.dataset!r}; expected one of {sorted(DATASET_DIMS)}"
        ) from None
    args.action_dim, args.observation_dim, args.class_dim = dims

    # The network input is the concatenation of action, observation and class features.
    input_dim = args.action_dim + args.observation_dim + args.class_dim
    dim, dim_mults = _unet_widths(args.dataset)

    temporal_model_no_attn = TemporalUnetNoAttn(
        input_dim,
        args.action_dim,
        dim=dim,
        dim_mults=dim_mults,
    )
    temporal_model = TemporalUnet(
        input_dim,
        args.action_dim,
        dim=dim,
        dim_mults=dim_mults,
    )

    print('no attention para:', count_parameters(temporal_model_no_attn) / (1000 * 1000))
    print('attention para:', count_parameters(temporal_model) / (1000 * 1000))


if __name__ == '__main__':
    main()