init
This commit is contained in:
commit
b429f9807d
7 changed files with 1186 additions and 0 deletions
77
main.py
Normal file
77
main.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
import numpy as np
|
||||
import random, pdb, os, copy
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import pickle as pkl
|
||||
import torch
|
||||
|
||||
from utils import MakeDir, set_seed, str2bool
|
||||
from CRT import Model, self_supervised_learning, finetuning
|
||||
from base_models import FTDataSet
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ssl", type=str2bool, default=False, help='Self-supervised learning (pretrain the autoencoder)')
|
||||
parser.add_argument("--sl", type=str2bool, default=True, help='Supervised learning (downstream tasks)')
|
||||
parser.add_argument("--load", type=str2bool, default=True)
|
||||
|
||||
parser.add_argument("--patch_len", type=int, default=5)
|
||||
parser.add_argument("--dim", type=int, default=128)
|
||||
parser.add_argument("--depth", type=int, default=6)
|
||||
parser.add_argument("--dropout", type=float, default=0.1)
|
||||
parser.add_argument("--Headdropout", type=float, default=0.1)
|
||||
parser.add_argument("--beta2", type=float, default=1)
|
||||
|
||||
parser.add_argument("--AElearningRate", type=float, default=1e-4)
|
||||
parser.add_argument("--HeadlearningRate", type=float, default=1e-3)
|
||||
parser.add_argument("--AEbatchSize", type=int, default=512)
|
||||
parser.add_argument("--HeadbatchSize", type=int, default=32)
|
||||
parser.add_argument("--AEepochs", type=int, default=100)
|
||||
parser.add_argument("--Headepochs", type=int, default=50)
|
||||
|
||||
parser.add_argument("--AEwin", type=int, default=5)
|
||||
parser.add_argument("--Headwin", type=int, default=5)
|
||||
parser.add_argument("--slid", type=int, default=1)
|
||||
parser.add_argument("--timeWinFreq", type=int, default=20)
|
||||
|
||||
parser.add_argument("--pretrainDataset", type=str, default='Buffalo_EMAKI')
|
||||
parser.add_argument("--testDataset", type=str)
|
||||
|
||||
parser.add_argument("--stage", type=int, choices=[0,1], default=0, help='0: Train the classifier only; 1: Finetune the whole model')
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
opt = parser.parse_args()
|
||||
opt.in_dim = 2 # x,y
|
||||
opt.seq_len = opt.AEwin * opt.timeWinFreq *2
|
||||
opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
set_seed(opt.seed)
|
||||
aemodelfile = 'pretrained_model.pkl'
|
||||
|
||||
if opt.ssl:
|
||||
'''
|
||||
Load dataset for pretraining the Autoencoder
|
||||
taskdata: on-screen coordinates (x,y), their magnitude and phase
|
||||
'''
|
||||
taskdata, taskclicks = loadPretrainDataset(opt)
|
||||
model = Model(opt).to(opt.device)
|
||||
self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)
|
||||
|
||||
if opt.load:
|
||||
model = torch.load(aemodelfile, map_location=opt.device)
|
||||
else:
|
||||
model = Model(opt).to(opt.device)
|
||||
pdb.set_trace()
|
||||
|
||||
if opt.sl:
|
||||
'''
|
||||
Load dataset for downstream tasks
|
||||
_X: on-screen coordinates (x,y), their magnitude and phase
|
||||
_y: labels
|
||||
_clicks: clicks
|
||||
'''
|
||||
train_X, train_y, train_clicks, valid_X, valid_y, valid_clicks = loadDownstreamDataset(opt)
|
||||
|
||||
TrainSet = FTDataSet(train_X, train_y, train_clicks, multi_label=True) # Binary or multi-class
|
||||
ValidSet = FTDataSet(valid_X, valid_y, valid_clicks, multi_label=True)
|
||||
|
||||
finetuning(model, TrainSet, ValidSet, opt, aemodelfile.split('.pkl')[0]+'downstream.pkl', multi_label=True)
|
Loading…
Add table
Add a link
Reference in a new issue