# Mouse2Vec/main.py
import argparse
import torch
from utils import MakeDir, set_seed, str2bool
from CRT import Model, self_supervised_learning, finetuning
from base_models import FTDataSet
if __name__ == '__main__':
    # Script entry point. Three optional phases, driven by command-line flags:
    #   --ssl  : pretrain the autoencoder with self-supervised learning
    #   --load : obtain the model from `aemodelfile` instead of a fresh Model
    #   --sl   : train/evaluate on a downstream (supervised) task
    parser = argparse.ArgumentParser()
    parser.add_argument("--ssl", type=str2bool, default=False, help='Self-supervised learning (pretrain the autoencoder)')
    parser.add_argument("--sl", type=str2bool, default=True, help='Supervised learning (downstream tasks)')
    parser.add_argument("--load", type=str2bool, default=True)
    parser.add_argument("--patch_len", type=int, default=5)
    parser.add_argument("--dim", type=int, default=128)
    parser.add_argument("--depth", type=int, default=6)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--Headdropout", type=float, default=0.1)
    parser.add_argument("--beta2", type=float, default=1)
    parser.add_argument("--AElearningRate", type=float, default=1e-4)
    parser.add_argument("--HeadlearningRate", type=float, default=1e-3)
    parser.add_argument("--AEbatchSize", type=int, default=512)
    parser.add_argument("--HeadbatchSize", type=int, default=32)
    parser.add_argument("--AEepochs", type=int, default=100)
    parser.add_argument("--Headepochs", type=int, default=50)
    parser.add_argument("--AEwin", type=int, default=5)
    parser.add_argument("--Headwin", type=int, default=5)
    parser.add_argument("--slid", type=int, default=1)
    parser.add_argument("--timeWinFreq", type=int, default=20)
    parser.add_argument("--pretrainDataset", type=str, default='Buffalo_EMAKI')
    parser.add_argument("--testDataset", type=str)
    parser.add_argument("--stage", type=int, choices=[0,1], default=0, help='0: Train the classifier only; 1: Finetune the whole model')
    parser.add_argument("--seed", type=int, default=0)
    opt = parser.parse_args()

    # Derived settings attached to the option namespace so that they travel
    # together with the parsed arguments.
    opt.in_dim = 2  # x,y
    # With the defaults this is 5 * 20 * 2 = 200, i.e. the last dimension of
    # `taskdata` described below.
    opt.seq_len = opt.AEwin * opt.timeWinFreq * 2
    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    set_seed(opt.seed)

    aemodelfile = 'pretrained_model.pkl'

    if opt.ssl:
        # Load dataset for pretraining the Autoencoder
        # taskdata: on-screen coordinates (x,y), their magnitude and phase. Shape (N, 2, 200)
        #   - N: the number of sliding windows generated from the mouse data
        #   - 2: X or Y
        #   - 200: 100 (20Hz * 5s) + 50 (half of 100 data points) + 50 (half of 100 data points), see the DFT function in utils.py
        # taskclicks: indicate if each mouse data point is a click (1) or move (0) event. Shape (N, 100)
        #   - 100: 20Hz * 5s
        #
        # NOTE(review): loadPretrainDataset is neither defined nor imported in
        # the visible part of this file; supply it before running with --ssl True.
        taskdata, taskclicks = loadPretrainDataset(opt)  # Plug the loader of your dataset
        model = Model(opt).to(opt.device)
        # `aemodelfile` is handed over as the checkpoint path (presumably the
        # pretrained model is written there -- confirm in CRT.py).
        self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)

    # NOTE(review): with `--ssl True --load False` the model pretrained above
    # is replaced by a freshly constructed one -- confirm this is intended.
    if opt.load:
        # The loaded object is used directly as the model, so the file must
        # hold a whole pickled module rather than a state dict. PyTorch >= 2.6
        # defaults to weights_only=True, which rejects such files, hence the
        # explicit opt-out (the argument exists since PyTorch 1.13).
        # SECURITY: full unpickling can execute arbitrary code -- only load
        # checkpoints from a trusted source.
        model = torch.load(aemodelfile, map_location=opt.device, weights_only=False)
    else:
        model = Model(opt).to(opt.device)

    if opt.sl:
        # Load dataset for downstream tasks
        # _X: on-screen coordinates (x,y), their magnitude and phase
        # _y: labels
        # _clicks: clicks
        #
        # NOTE(review): loadDownstreamDataset is neither defined nor imported
        # in the visible part of this file; supply it before running with --sl True.
        train_X, train_y, train_clicks, valid_X, valid_y, valid_clicks = loadDownstreamDataset(opt)
        TrainSet = FTDataSet(train_X, train_y, train_clicks, multi_label=True)  # Binary or multi-class
        ValidSet = FTDataSet(valid_X, valid_y, valid_clicks, multi_label=True)
        # Output file name resolves to 'pretrained_modeldownstream.pkl'.
        finetuning(model, TrainSet, ValidSet, opt, aemodelfile.split('.pkl')[0]+'downstream.pkl', multi_label=True)