55 lines
No EOL
1.5 KiB
Python
55 lines
No EOL
1.5 KiB
Python
from experiment import *
|
|
|
|
def autoenc_base():
    """Return the base ``TrainConfig`` shared by the autoencoder templates.

    A default ``TrainConfig`` is created and a fixed table of attribute
    overrides is applied to it, in the order listed below. Then
    ``make_model_conf()`` is called on the result. Presumably this derives
    the model sub-config from the fields just set; confirm in ``experiment``.

    Returns:
        The configured ``TrainConfig`` instance.
    """
    # Attribute overrides, applied in insertion order (dicts preserve it).
    overrides = {
        'batch_size': 32,
        'beatgans_gen_type': GenerativeType.ddim,
        'beta_scheduler': 'linear',
        'data_name': 'ffhq',
        'diffusion_type': 'beatgans',
        'eval_ema_every_samples': 200_000,
        'eval_every_samples': 200_000,
        'fp16': True,
        'lr': 1e-4,
        'model_name': ModelName.beatgans_autoenc,
        'net_attn': (16,),
        'net_beatgans_attn_head': 1,
        'net_beatgans_embed_channels': 128,
        'net_beatgans_resnet_two_cond': True,
        'net_ch_mult': (1, 2, 4, 8),
        'net_ch': 64,
        'net_enc_channel_mult': (1, 2, 4, 8, 8),
        'net_enc_pool': 'adaptivenonzero',
        'sample_size': 32,
        'T_eval': 20,
        'T': 1000,
    }

    base = TrainConfig()
    for attr_name, attr_value in overrides.items():
        setattr(base, attr_name, attr_value)

    base.make_model_conf()
    return base
|
|
|
|
def mouse_autoenc(mode, dataset='Clarkson'):
    """Build the training config for the mouse-data autoencoder.

    Starts from ``autoenc_base()`` and then overrides several settings:
    ``img_size = 256``, more network/encoder channels, a larger batch,
    longer training and a sparser evaluation cadence. It also sets the
    mouse-dataset-specific fields (data path, windowing parameters, user
    count).

    Args:
        mode: Stored verbatim on ``conf.mode``.
        dataset: Name of the pretraining dataset. It is used for
            ``conf.pretrainDataset``, ``conf.data_name``, the data directory
            ``../mousedata/<dataset>/`` and the user-count lookup. Defaults
            to ``'Clarkson'``, the only dataset currently listed.

    Returns:
        The configured ``TrainConfig``.

    Raises:
        ValueError: If ``dataset`` has no entry in the user-count table.
    """
    # Number of users per supported dataset; add an entry here to support
    # another dataset.
    num_users = {'Clarkson': 75}
    # Fail fast with a readable message instead of a bare KeyError after the
    # whole config has already been built.
    if dataset not in num_users:
        raise ValueError(
            f'Unknown dataset {dataset!r}; expected one of {sorted(num_users)}'
        )

    conf = autoenc_base()
    conf.scale_up_gpus(1)
    conf.img_size = 256
    conf.net_ch = 128
    conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
    conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
    conf.eval_every_samples = 10_000_000
    conf.eval_ema_every_samples = 10_000_000
    conf.total_samples = 200_000_000
    conf.batch_size = 512
    conf.name = 'mouse_autoenc'

    conf.pretrainDataset = dataset
    conf.data_name = dataset

    conf.path = f'../mousedata/{conf.pretrainDataset}/'
    conf.AEwin = 8
    conf.slid = 1
    conf.timeWinFreq = 20
    conf.num_users = num_users[conf.pretrainDataset]

    conf.mode = mode
    # Called again because fields above (net_ch, net_ch_mult, img_size, ...)
    # were changed after autoenc_base() already built its model config.
    conf.make_model_conf()
    return conf