Mouse2Vec/main.py

import numpy as np
import random, pdb, os, copy
import argparse
import pandas as pd
import pickle as pkl
import torch
from utils import MakeDir, set_seed, str2bool
from CRT import Model, self_supervised_learning, finetuning
from base_models import FTDataSet
# NOTE: loadPretrainDataset and loadDownstreamDataset are called below but were
# not imported in the original file; the module name here is an assumption.
from dataset import loadPretrainDataset, loadDownstreamDataset
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--ssl", type=str2bool, default=False, help='Self-supervised learning (pretrain the autoencoder)')
    parser.add_argument("--sl", type=str2bool, default=True, help='Supervised learning (downstream tasks)')
    parser.add_argument("--load", type=str2bool, default=True, help='Load the pretrained autoencoder checkpoint')
    parser.add_argument("--patch_len", type=int, default=5)
    parser.add_argument("--dim", type=int, default=128)
    parser.add_argument("--depth", type=int, default=6)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--Headdropout", type=float, default=0.1)
    parser.add_argument("--beta2", type=float, default=1)
    parser.add_argument("--AElearningRate", type=float, default=1e-4)
    parser.add_argument("--HeadlearningRate", type=float, default=1e-3)
    parser.add_argument("--AEbatchSize", type=int, default=512)
    parser.add_argument("--HeadbatchSize", type=int, default=32)
    parser.add_argument("--AEepochs", type=int, default=100)
    parser.add_argument("--Headepochs", type=int, default=50)
    parser.add_argument("--AEwin", type=int, default=5)
    parser.add_argument("--Headwin", type=int, default=5)
    parser.add_argument("--slid", type=int, default=1)
    parser.add_argument("--timeWinFreq", type=int, default=20)
    parser.add_argument("--pretrainDataset", type=str, default='Buffalo_EMAKI')
    parser.add_argument("--testDataset", type=str)
    parser.add_argument("--stage", type=int, choices=[0, 1], default=0, help='0: Train the classifier only; 1: Finetune the whole model')
    parser.add_argument("--seed", type=int, default=0)
    opt = parser.parse_args()
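
    # Settings derived from the CLI arguments rather than exposed as flags:
    # input dimensionality, sequence length, compute device, and RNG seed.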
    opt.in_dim = 2  # (x, y) coordinates
    opt.seq_len = opt.AEwin * opt.timeWinFreq * 2
    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    set_seed(opt.seed)

    aemodelfile = 'pretrained_model.pkl'
    if opt.ssl:
        '''
        Load the dataset for pretraining the autoencoder.
        taskdata: on-screen coordinates (x, y), their magnitude and phase
        taskclicks: click events
        '''
        taskdata, taskclicks = loadPretrainDataset(opt)
        model = Model(opt).to(opt.device)
        self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)
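
    # The pretrained weights are presumably written to `aemodelfile` inside
    # self_supervised_learning (it receives the path via modelfile=), so the
    # --load branch below can restore them from disk.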
    if opt.load:
        model = torch.load(aemodelfile, map_location=opt.device)
    else:
        model = Model(opt).to(opt.device)

    if opt.sl:
        '''
        Load the dataset for downstream tasks
        _X: on-screen coordinates (x, y), their magnitude and phase
        _y: labels
        _clicks: clicks
        '''
        train_X, train_y, train_clicks, valid_X, valid_y, valid_clicks = loadDownstreamDataset(opt)
        TrainSet = FTDataSet(train_X, train_y, train_clicks, multi_label=True)  # Binary or multi-class
        ValidSet = FTDataSet(valid_X, valid_y, valid_clicks, multi_label=True)
        finetuning(model, TrainSet, ValidSet, opt, aemodelfile.split('.pkl')[0] + 'downstream.pkl', multi_label=True)
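
# Example invocations (illustrative; <your_dataset> is a placeholder):
#   Pretrain the autoencoder:
#       python main.py --ssl True --sl False --pretrainDataset Buffalo_EMAKI
#   Train only the downstream classifier on top of the pretrained model:
#       python main.py --sl True --load True --stage 0 --testDataset <your_dataset>
#   Finetune the whole model on the downstream task:
#       python main.py --sl True --load True --stage 1 --testDataset <your_dataset>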