"""Entry point: pretrain the autoencoder (SSL) and/or train the downstream head."""
import argparse
import inspect

import torch

from utils import MakeDir, set_seed, str2bool
from CRT import Model, self_supervised_learning, finetuning
from base_models import FTDataSet

# File the pretrained autoencoder is written to by the SSL stage and read back from.
AE_MODEL_FILE = 'pretrained_model.pkl'


def loadPretrainDataset(opt):
    """Placeholder for the pretraining data loader -- replace with your own.

    Expected to return ``(taskdata, taskclicks)``. Per the original notes:
      * taskdata: on-screen coordinates (x, y), their magnitude and phase,
        shape (N, 2, 200) with the default options
          - N: number of sliding windows generated from the mouse data
          - 2: X or Y
          - 200: 100 (20Hz * 5s) + 50 + 50 (each half of the 100 data
            points); see the DFT function in utils.py
      * taskclicks: whether each mouse data point is a click (1) or a
        move (0) event, shape (N, 100), where 100 = 20Hz * 5s
    """
    raise NotImplementedError(
        "loadPretrainDataset(opt) is a placeholder: plug in the loader of your "
        "pretraining dataset; it must return (taskdata, taskclicks)."
    )


def loadDownstreamDataset(opt):
    """Placeholder for the downstream-task data loader -- replace with your own.

    Expected to return, in this order,
    ``(train_X, train_y, train_clicks, valid_X, valid_y, valid_clicks)`` where
    ``_X`` holds on-screen coordinates (x, y) with their magnitude and phase,
    ``_y`` the labels and ``_clicks`` the click indicators.
    """
    raise NotImplementedError(
        "loadDownstreamDataset(opt) is a placeholder: plug in the loader of your "
        "downstream dataset; it must return (train_X, train_y, train_clicks, "
        "valid_X, valid_y, valid_clicks)."
    )


def parse_args():
    """Parse the command line and attach the derived run settings to the namespace.

    Returns:
        argparse.Namespace with every CLI option plus ``in_dim``, ``seq_len``
        and ``device``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ssl", type=str2bool, default=False,
                        help='Self-supervised learning (pretrain the autoencoder)')
    parser.add_argument("--sl", type=str2bool, default=True,
                        help='Supervised learning (downstream tasks)')
    parser.add_argument("--load", type=str2bool, default=True)
    parser.add_argument("--patch_len", type=int, default=5)
    parser.add_argument("--dim", type=int, default=128)
    parser.add_argument("--depth", type=int, default=6)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--Headdropout", type=float, default=0.1)
    parser.add_argument("--beta2", type=float, default=1)
    parser.add_argument("--AElearningRate", type=float, default=1e-4)
    parser.add_argument("--HeadlearningRate", type=float, default=1e-3)
    parser.add_argument("--AEbatchSize", type=int, default=512)
    parser.add_argument("--HeadbatchSize", type=int, default=32)
    parser.add_argument("--AEepochs", type=int, default=100)
    parser.add_argument("--Headepochs", type=int, default=50)
    parser.add_argument("--AEwin", type=int, default=5)
    parser.add_argument("--Headwin", type=int, default=5)
    parser.add_argument("--slid", type=int, default=1)
    parser.add_argument("--timeWinFreq", type=int, default=20)
    parser.add_argument("--pretrainDataset", type=str, default='Buffalo_EMAKI')
    parser.add_argument("--testDataset", type=str)
    parser.add_argument("--stage", type=int, choices=[0, 1], default=0,
                        help='0: Train the classifier only; 1: Finetune the whole model')
    parser.add_argument("--seed", type=int, default=0)
    opt = parser.parse_args()

    opt.in_dim = 2  # x,y
    # Window length in samples, doubled (200 with the default 5 s at 20 Hz).
    opt.seq_len = opt.AEwin * opt.timeWinFreq * 2
    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return opt


def _load_pretrained_model(path, device):
    """Load the pretrained autoencoder stored at *path* onto *device*.

    The loaded object is used directly as the model, so the checkpoint is
    expected to be a whole pickled model rather than a state dict. PyTorch
    2.6+ defaults ``torch.load`` to ``weights_only=True``, which rejects such
    checkpoints, so the flag is turned off explicitly when it is supported.
    """
    if 'weights_only' in inspect.signature(torch.load).parameters:
        # NOTE: weights_only=False performs full unpickling; only load
        # checkpoints produced by this project (i.e. trusted files).
        return torch.load(path, map_location=device, weights_only=False)
    # Older PyTorch releases have no `weights_only` keyword.
    return torch.load(path, map_location=device)


def main():
    """Run the stages selected by --ssl / --load / --sl."""
    opt = parse_args()
    set_seed(opt.seed)

    if opt.ssl:
        # Pretrain the autoencoder on (taskdata, taskclicks).
        taskdata, taskclicks = loadPretrainDataset(opt)  # Plug the loader of your dataset
        model = Model(opt).to(opt.device)
        self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=AE_MODEL_FILE)

    # Start the downstream stage from the saved autoencoder, or from a
    # freshly initialised model when --load is false.
    if opt.load:
        model = _load_pretrained_model(AE_MODEL_FILE, opt.device)
    else:
        model = Model(opt).to(opt.device)

    if opt.sl:
        (train_X, train_y, train_clicks,
         valid_X, valid_y, valid_clicks) = loadDownstreamDataset(opt)
        TrainSet = FTDataSet(train_X, train_y, train_clicks, multi_label=True)  # Binary or multi-class
        ValidSet = FTDataSet(valid_X, valid_y, valid_clicks, multi_label=True)
        downstream_file = AE_MODEL_FILE.split('.pkl')[0] + 'downstream.pkl'
        finetuning(model, TrainSet, ValidSet, opt, downstream_file, multi_label=True)


if __name__ == '__main__':
    main()