DisMouse/main.py

29 lines
780 B
Python
Raw Normal View History

2024-10-08 14:18:47 +02:00
import os
from templates import *
def format_betas(betas):
    """Encode a mapping of loss weights as a compact run tag.

    Each ``(name, weight)`` pair is rendered as ``f'{name}{weight}'`` and the
    pieces are joined with underscores in the mapping's iteration order, e.g.
    ``{'recon': 1, 'mi': 0.01}`` -> ``'recon1_mi0.01'``. An empty mapping
    yields ``''``.
    """
    return '_'.join(f'{name}{weight}' for name, weight in betas.items())


# Script entry point: build the 'trainDiff' mouse autoencoder config, derive a
# run-specific log directory from its settings, and launch training.
if __name__ == '__main__':
    gpus = [0]
    conf = mouse_autoenc('trainDiff')

    # Per-term weights passed to LitModel. The same weights are also encoded
    # into the log directory name, so runs with different weightings are
    # written to different directories.
    betas = {
        'recon': 1,
        'noise': 1,
        'user': 0.01,
        'nonuser': 0.01,
        'mi': 0.01,
    }
    betastr = format_betas(betas)

    # NOTE(review): LitModel is constructed before conf.logdir is rewritten
    # below. If LitModel copies or logs conf.logdir, it sees the un-suffixed
    # value. Confirm this ordering is intended before changing it.
    diffmodel = LitModel(conf, betas)

    conf.logdir = (
        f'{conf.logdir}/mouse_autoenc/{conf.pretrainDataset}/{betastr}'
        f'/embDim{conf.net_beatgans_embed_channels}/win{conf.AEwin}'
        f'/slid{conf.slid}/GRL/'
    )
    MakeDir(conf.logdir)

    # Point W&B's cache and data directories at this run's log directory.
    # Exclude checkpoint files (*.ckpt) from W&B syncing.
    os.environ['WANDB_CACHE_DIR'] = conf.logdir
    os.environ['WANDB_DATA_DIR'] = conf.logdir
    os.environ['WANDB_IGNORE_GLOBS'] = '*.ckpt'

    train(conf, diffmodel, gpus=gpus)