29 lines
780 B
Python
29 lines
780 B
Python
|
import os
|
||
|
from templates import *
|
||
|
|
||
|
def format_betas(betas):
    """Return a tag string encoding the loss weights in ``betas``.

    Each ``key: value`` pair is rendered as ``f'{key}{value}'`` and the pieces
    are joined with ``'_'`` in dict iteration order, e.g.
    ``{'recon': 1, 'mi': 0.01}`` -> ``'recon1_mi0.01'``. An empty mapping
    yields ``''``.
    """
    return '_'.join(f'{k}{v}' for k, v in betas.items())


def main():
    """Train the mouse autoencoder diffusion model with a fixed loss weighting.

    Builds the ``'trainDiff'`` config via ``mouse_autoenc``, wraps it in a
    ``LitModel`` together with the loss weights below, derives a log directory
    that encodes the run's hyper-parameters, points wandb's cache/data
    directories at it, and hands everything to ``train`` on the listed GPUs.
    """
    gpus = [0]

    conf = mouse_autoenc('trainDiff')

    # Loss-term weights handed to LitModel; also encoded in the log path so
    # runs with different weightings land in different directories.
    betas = {
        'recon': 1,
        'noise': 1,
        'user': 0.01,
        'nonuser': 0.01,
        'mi': 0.01,
    }
    betastr = format_betas(betas)

    diffmodel = LitModel(conf, betas)

    # NOTE(review): conf.logdir is rewritten *after* LitModel is constructed,
    # so anything LitModel.__init__ reads from conf.logdir sees the base
    # directory, not this per-run one. Confirm that is intended.
    conf.logdir = (
        f'{conf.logdir}/mouse_autoenc/{conf.pretrainDataset}/{betastr}'
        f'/embDim{conf.net_beatgans_embed_channels}/win{conf.AEwin}'
        f'/slid{conf.slid}/GRL/'
    )
    MakeDir(conf.logdir)

    # Keep wandb's cache and staging data inside the run directory, and tell
    # wandb not to sync checkpoint files matching *.ckpt.
    os.environ['WANDB_CACHE_DIR'] = conf.logdir
    os.environ['WANDB_DATA_DIR'] = conf.logdir
    os.environ['WANDB_IGNORE_GLOBS'] = '*.ckpt'

    train(conf, diffmodel, gpus=gpus)


if __name__ == '__main__':
    main()
|