77 lines
3.2 KiB
Python
77 lines
3.2 KiB
Python
|
import numpy as np
|
||
|
import random, pdb, os, copy
|
||
|
import argparse
|
||
|
import pandas as pd
|
||
|
import pickle as pkl
|
||
|
import torch
|
||
|
|
||
|
from utils import MakeDir, set_seed, str2bool
|
||
|
from CRT import Model, self_supervised_learning, finetuning
|
||
|
from base_models import FTDataSet
|
||
|
|
||
|
if __name__ == '__main__':
    # Script entry point. Three optional phases, selected by CLI flags:
    #   1. --ssl  : pretrain the autoencoder with self-supervised learning.
    #   2. --load : load the pretrained model from disk (otherwise build a new one).
    #   3. --sl   : train / finetune on a downstream task.
    parser = argparse.ArgumentParser()

    # Phase switches.
    parser.add_argument("--ssl", type=str2bool, default=False, help='Self-supervised learning (pretrain the autoencoder)')
    parser.add_argument("--sl", type=str2bool, default=True, help='Supervised learning (downstream tasks)')
    parser.add_argument("--load", type=str2bool, default=True)

    # The options below are not interpreted in this script (apart from AEwin
    # and timeWinFreq, which feed seq_len); they travel inside `opt` to
    # Model, self_supervised_learning and finetuning.
    parser.add_argument("--patch_len", type=int, default=5)
    parser.add_argument("--dim", type=int, default=128)
    parser.add_argument("--depth", type=int, default=6)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--Headdropout", type=float, default=0.1)
    parser.add_argument("--beta2", type=float, default=1)

    parser.add_argument("--AElearningRate", type=float, default=1e-4)
    parser.add_argument("--HeadlearningRate", type=float, default=1e-3)
    parser.add_argument("--AEbatchSize", type=int, default=512)
    parser.add_argument("--HeadbatchSize", type=int, default=32)
    parser.add_argument("--AEepochs", type=int, default=100)
    parser.add_argument("--Headepochs", type=int, default=50)

    parser.add_argument("--AEwin", type=int, default=5)
    parser.add_argument("--Headwin", type=int, default=5)
    parser.add_argument("--slid", type=int, default=1)
    parser.add_argument("--timeWinFreq", type=int, default=20)

    parser.add_argument("--pretrainDataset", type=str, default='Buffalo_EMAKI')
    # No default: opt.testDataset is None unless given on the command line.
    parser.add_argument("--testDataset", type=str)

    parser.add_argument("--stage", type=int, choices=[0,1], default=0, help='0: Train the classifier only; 1: Finetune the whole model')
    parser.add_argument("--seed", type=int, default=0)
    opt = parser.parse_args()

    # Derived settings attached to the same namespace.
    opt.in_dim = 2  # x,y
    # NOTE(review): the meaning of the trailing factor 2 is not visible in
    # this file -- confirm against the data preparation code.
    opt.seq_len = opt.AEwin * opt.timeWinFreq * 2
    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    set_seed(opt.seed)
    aemodelfile = 'pretrained_model.pkl'

    if opt.ssl:
        # Load dataset for pretraining the Autoencoder
        # taskdata: on-screen coordinates (x,y), their magnitude and phase
        #
        # NOTE(review): loadPretrainDataset is not imported in the visible
        # part of this file -- confirm it is defined before enabling --ssl.
        taskdata, taskclicks = loadPretrainDataset(opt)
        model = Model(opt).to(opt.device)
        self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)

    if opt.load:
        # torch.load unpickles the stored object; only point this at
        # checkpoint files you trust.
        # NOTE(review): recent PyTorch releases changed the default of
        # `weights_only`; if the file holds a whole pickled model, loading
        # may need `weights_only=False` -- confirm for the installed version.
        model = torch.load(aemodelfile, map_location=opt.device)
    else:
        # Start from an untrained model. A leftover pdb.set_trace() was
        # removed here so that non-interactive runs do not halt in the
        # debugger.
        # NOTE(review): with --ssl True --load False this replaces the model
        # that was just pretrained above -- confirm that is intended.
        model = Model(opt).to(opt.device)

    if opt.sl:
        # Load dataset for downstream tasks
        # _X: on-screen coordinates (x,y), their magnitude and phase
        # _y: labels
        # _clicks: clicks
        #
        # NOTE(review): loadDownstreamDataset is not imported in the visible
        # part of this file -- confirm it is defined before running.
        train_X, train_y, train_clicks, valid_X, valid_y, valid_clicks = loadDownstreamDataset(opt)

        TrainSet = FTDataSet(train_X, train_y, train_clicks, multi_label=True)  # Binary or multi-class
        ValidSet = FTDataSet(valid_X, valid_y, valid_clicks, multi_label=True)

        # The downstream checkpoint name is derived from the pretrained file
        # name: 'pretrained_model.pkl' -> 'pretrained_modeldownstream.pkl'.
        finetuning(model, TrainSet, ValidSet, opt, aemodelfile.split('.pkl')[0]+'downstream.pkl', multi_label=True)
|