77 lines
No EOL
3.5 KiB
Python
77 lines
No EOL
3.5 KiB
Python
import argparse
|
|
import torch
|
|
|
|
from utils import MakeDir, set_seed, str2bool
|
|
from CRT import Model, self_supervised_learning, finetuning
|
|
from base_models import FTDataSet
|
|
|
|
def _parse_options():
    """Parse the command-line flags and attach the derived run settings.

    Returns:
        argparse.Namespace: the parsed flags plus three attributes computed
        here: ``in_dim`` (2, one channel each for x and y), ``seq_len``
        (``AEwin * timeWinFreq * 2``) and ``device`` (CUDA when available,
        otherwise CPU).
    """
    parser = argparse.ArgumentParser()

    # Phase switches: which stages of the pipeline to run.
    parser.add_argument("--ssl", type=str2bool, default=False, help='Self-supervised learning (pretrain the autoencoder)')
    parser.add_argument("--sl", type=str2bool, default=True, help='Supervised learning (downstream tasks)')
    # When true, the model is read back from the pretrained-model file;
    # otherwise a freshly constructed Model is used (see main()).
    parser.add_argument("--load", type=str2bool, default=True)

    # Hyperparameters handed to the model/training code through ``opt``;
    # none of them is interpreted in this script.
    parser.add_argument("--patch_len", type=int, default=5)
    parser.add_argument("--dim", type=int, default=128)
    parser.add_argument("--depth", type=int, default=6)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--Headdropout", type=float, default=0.1)
    parser.add_argument("--beta2", type=float, default=1)

    # Optimisation settings; the AE*/Head* prefixes presumably map to the
    # pretraining and downstream phases -- confirm against CRT.py.
    parser.add_argument("--AElearningRate", type=float, default=1e-4)
    parser.add_argument("--HeadlearningRate", type=float, default=1e-3)
    parser.add_argument("--AEbatchSize", type=int, default=512)
    parser.add_argument("--HeadbatchSize", type=int, default=32)
    parser.add_argument("--AEepochs", type=int, default=100)
    parser.add_argument("--Headepochs", type=int, default=50)

    # Windowing settings. AEwin and timeWinFreq determine seq_len below.
    parser.add_argument("--AEwin", type=int, default=5)
    parser.add_argument("--Headwin", type=int, default=5)
    parser.add_argument("--slid", type=int, default=1)
    parser.add_argument("--timeWinFreq", type=int, default=20)

    # Dataset selectors, read by the user-supplied loaders.
    # NOTE(review): --testDataset has no default, so it is None when omitted.
    parser.add_argument("--pretrainDataset", type=str, default='Buffalo_EMAKI')
    parser.add_argument("--testDataset", type=str)

    parser.add_argument("--stage", type=int, choices=[0,1], default=0, help='0: Train the classifier only; 1: Finetune the whole model')
    parser.add_argument("--seed", type=int, default=0)

    opt = parser.parse_args()
    opt.in_dim = 2  # x,y
    # With the defaults this is 5 * 20 * 2 = 200, matching the last dimension
    # of the pretraining data described in main().
    opt.seq_len = opt.AEwin * opt.timeWinFreq * 2
    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return opt


def _load_pretrained(path, device):
    """Load the pickled pretrained model stored at ``path`` onto ``device``.

    PyTorch 2.6 changed the default of ``torch.load`` to
    ``weights_only=True``, which refuses to unpickle arbitrary objects. The
    value loaded here is used directly as the model (presumably a whole
    pickled module -- confirm against ``self_supervised_learning``), so the
    legacy behaviour is requested explicitly whenever the installed PyTorch
    knows the keyword.

    NOTE: full unpickling can execute arbitrary code; only load checkpoint
    files that you produced yourself or otherwise trust.

    Args:
        path: file name of the saved model.
        device: ``map_location`` target for the loaded tensors.

    Returns:
        Whatever object was saved in ``path``.
    """
    import inspect

    load_kwargs = {}
    # Older PyTorch releases have no ``weights_only`` keyword; passing it
    # there would be forwarded to the pickle module and raise a TypeError.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, map_location=device, **load_kwargs)


def main():
    """Run optional self-supervised pretraining, then optional finetuning."""
    opt = _parse_options()

    set_seed(opt.seed)
    aemodelfile = 'pretrained_model.pkl'

    if opt.ssl:
        # Load dataset for pretraining the Autoencoder
        # taskdata: on-screen coordinates (x,y), their magnitude and phase. Shape (N, 2, 200)
        #     - N: the number of sliding windows generated from the mouse data
        #     - 2: X or Y
        #     - 200: 100 (20Hz * 5s) + 50 (half of 100 data points) + 50 (half of 100 data points), see the DFT function in utils.py
        # taskclicks: indicate if each mouse data point is a click (1) or move (0) event. Shape (N, 100)
        #     - 100: 20Hz * 5s
        #
        # NOTE: loadPretrainDataset is not defined or imported in the visible
        # part of this file; it is the hook where your own loader goes.
        taskdata, taskclicks = loadPretrainDataset(opt)  # Plug the loader of your dataset
        model = Model(opt).to(opt.device)
        self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)

    # The downstream model is either read back from the pretrained-model file
    # or built fresh. Either way the in-memory instance trained above (if any)
    # is replaced here.
    if opt.load:
        model = _load_pretrained(aemodelfile, opt.device)
    else:
        model = Model(opt).to(opt.device)

    if opt.sl:
        # Load dataset for downstream tasks
        # _X: on-screen coordinates (x,y), their magnitude and phase
        # _y: labels
        # _clicks: clicks
        #
        # NOTE: loadDownstreamDataset is likewise a user-supplied hook that is
        # not defined or imported in the visible part of this file.
        train_X, train_y, train_clicks, valid_X, valid_y, valid_clicks = loadDownstreamDataset(opt)

        TrainSet = FTDataSet(train_X, train_y, train_clicks, multi_label=True)  # Binary or multi-class
        ValidSet = FTDataSet(valid_X, valid_y, valid_clicks, multi_label=True)

        # Yields 'pretrained_modeldownstream.pkl' (no separator); kept as-is
        # so existing saved files keep their name.
        finetuning(model, TrainSet, ValidSet, opt, aemodelfile.split('.pkl')[0]+'downstream.pkl', multi_label=True)


if __name__ == '__main__':
    main()