# Compare the trainable parameter counts of TemporalUnetNoAttn and TemporalUnet.
from model.temporal_act import TemporalUnetNoAttn, TemporalUnet
|
||
|
from utils.args import get_args
|
||
|
|
||
|
def count_parameters(model):
    """Count the trainable scalar elements of ``model``.

    Only parameters whose ``requires_grad`` flag is set contribute; each one
    adds its ``numel()`` to the total.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The sum of ``numel()`` over all trainable parameters (0 if none).
    """
    total = 0
    for param in model.parameters():
        # Frozen parameters are not part of the trainable budget.
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
# Parse CLI arguments, then fill in the dataset-dependent dimensions that the
# temporal U-Nets below are built from.
args = get_args()

# NOTE(review): this hard-coded override discards whatever dataset value
# get_args() produced (e.g. from the command line); remove it to make the
# CLI choice effective.
args.dataset = 'crosstask'

# (action_dim, observation_dim, class_dim) for each supported dataset.
_DATASET_DIMS = {
    'NIV': (48, 1536, 5),
    'crosstask': (105, 1536, 18),
    'coin': (778, 1536, 180),
}

# Fail loudly on an unrecognised dataset instead of silently leaving the
# dimension attributes unset/stale and erroring later in model construction.
try:
    args.action_dim, args.observation_dim, args.class_dim = _DATASET_DIMS[args.dataset]
except KeyError:
    raise ValueError(
        'unknown dataset {!r}; expected one of {}'.format(
            args.dataset, sorted(_DATASET_DIMS)
        )
    ) from None
# Build the same temporal U-Net with and without attention so their parameter
# counts can be compared side by side. Both constructors receive identical
# arguments; only the class differs.
_in_dim = args.action_dim + args.observation_dim + args.class_dim

# NIV uses base dim 256 with four dim_mults entries; every other dataset uses
# base dim 512 with three.
if args.dataset == 'NIV':
    _unet_kwargs = dict(dim=256, dim_mults=(1, 2, 4, 8))
else:
    _unet_kwargs = dict(dim=512, dim_mults=(1, 2, 4))

temporal_model_no_attn = TemporalUnetNoAttn(_in_dim, args.action_dim, **_unet_kwargs)
temporal_model = TemporalUnet(_in_dim, args.action_dim, **_unet_kwargs)

# Report trainable parameter counts in millions.
print('no attention para:', count_parameters(temporal_model_no_attn) / 1e6)
print('attention para:', count_parameters(temporal_model) / 1e6)