# TrainConfig, GenerativeType and ModelName are expected to come from here.
from experiment import *  # noqa: F403


def _set_fields(conf, **fields):
    """Assign every keyword argument to ``conf`` as an attribute.

    Attributes are set one by one in the order the keywords were passed,
    exactly as a sequence of ``conf.<name> = <value>`` statements would.

    Returns:
        The same ``conf`` object, for convenience.
    """
    for attr_name, attr_value in fields.items():
        setattr(conf, attr_name, attr_value)
    return conf


def autoenc_base():
    """Build the base configuration for the BeatGANs autoencoder diffusion model.

    Sets a fixed group of training, diffusion and network fields on a fresh
    ``TrainConfig``, then calls ``make_model_conf()`` on it.

    Returns:
        The populated ``TrainConfig`` instance.
    """
    base = TrainConfig()
    _set_fields(
        base,
        batch_size=32,
        beatgans_gen_type=GenerativeType.ddim,
        beta_scheduler='linear',
        data_name='ffhq',
        diffusion_type='beatgans',
        eval_ema_every_samples=200_000,
        eval_every_samples=200_000,
        fp16=True,
        lr=1e-4,
        model_name=ModelName.beatgans_autoenc,
        net_attn=(16, ),
        net_beatgans_attn_head=1,
        net_beatgans_embed_channels=128,
        net_beatgans_resnet_two_cond=True,
        net_ch_mult=(1, 2, 4, 8),
        net_ch=64,
        net_enc_channel_mult=(1, 2, 4, 8, 8),
        net_enc_pool='adaptivenonzero',
        sample_size=32,
        T_eval=20,
        T=1000,
    )
    base.make_model_conf()
    return base


def mouse_autoenc(mode):
    """Build the autoencoder configuration for the Clarkson mouse-data setup.

    Starts from ``autoenc_base()``, calls ``scale_up_gpus(1)``, then overrides
    image size, network widths, evaluation/total sample counts, batch size and
    the dataset-specific fields before calling ``make_model_conf()`` again.

    Args:
        mode: Stored unchanged as ``conf.mode``. Its meaning is decided by
            whatever consumes the config (NOTE(review): not visible here).

    Returns:
        The populated ``TrainConfig`` instance.
    """
    # Number of users per pretraining dataset; only Clarkson is listed.
    users_per_dataset = {'Clarkson': 75}

    conf = autoenc_base()
    conf.scale_up_gpus(1)
    _set_fields(
        conf,
        img_size=256,
        net_ch=128,
        net_ch_mult=(1, 1, 2, 2, 4, 4),
        net_enc_channel_mult=(1, 1, 2, 2, 4, 4, 4),
        eval_every_samples=10_000_000,
        eval_ema_every_samples=10_000_000,
        total_samples=200_000_000,
        batch_size=512,
        name='mouse_autoenc',
        pretrainDataset='Clarkson',
        data_name='Clarkson',
    )
    # Data location is derived from the dataset name stored on the config.
    conf.path = f'../mousedata/{conf.pretrainDataset}/'
    # Presumably window length / slide step / time-window frequency for the
    # mouse-movement features — TODO confirm against the data loader.
    _set_fields(conf, AEwin=8, slid=1, timeWinFreq=20)
    conf.num_users = users_per_dataset[conf.pretrainDataset]
    conf.mode = mode
    # Called again so the model config reflects the overrides above
    # (NOTE(review): assumes make_model_conf reads the current fields).
    conf.make_model_conf()
    return conf