import warnings
warnings.filterwarnings("ignore")
import copy
import pandas as pd
import numpy as np
import time, pdb, os, random
import pickle as pkl
import seaborn as sn
import itertools
from itertools import product
from argparse import ArgumentParser
from ast import literal_eval
import torch, gc
import multiprocessing as mp
import torch.multiprocessing as torchmp
from utils_Analysis import setup_seed, flattenTuple, generateBPE, generateVocabdict, generateWordsVocab, fit_transformer, poolsegmentTokenise, sort_by_key_len

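# Parse hyper-parameters and data paths, then run BPE tokenisation and transformer
# training with user-disjoint cross-validation folds.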
if __name__=='__main__':
    parser = ArgumentParser(description='BPE on datasets')
    parser.add_argument('--seed', dest='seed', type=int, required=True)
    parser.add_argument('-m', '--modality', dest='modality', choices=['mouse', 'keyboard', 'both'], type=str, required=True)
    parser.add_argument('--IDTfile', type=str, required=False)
    parser.add_argument('--keyfile', type=str, required=False)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--n_layers', type=int, required=True)
    parser.add_argument('--d_model', type=int, required=True)
    parser.add_argument('--batch_size', type=int, required=True)
    parser.add_argument('--dropout', type=float, required=True)
    parser.add_argument('--nhead', type=int, required=True)
    parser.add_argument('--optimizer_name', type=str, required=True)
    parser.add_argument('--epochs', type=int, required=True)
    parser.add_argument('--iteration', type=int, required=True)
    args = parser.parse_args()
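    # Fix RNG seeds, pick the device, and set per-modality window lengths plus a
    # multiprocessing pool used for BPE and tokenisation.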
    setup_seed(args.seed)
    args.folds = 5
    device = torch.device('cuda')
    WINLENS_key = [10, 50, 100]
    WINLENS_mouse = [20, 100, 200]
    WINLENS_both = [15, 75, 150]
    pool = mp.Pool(mp.cpu_count())

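    # Load the pickled event data for the requested modality; 'both' concatenates the
    # mouse (IDT) and keyboard streams.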
    if args.modality=='mouse':
        WINLENS = WINLENS_mouse
        with open(args.IDTfile, 'rb') as f:
            data = pkl.load(f)
    elif args.modality=='keyboard':
        WINLENS = WINLENS_key
        with open(args.keyfile, 'rb') as f:
            data = pkl.load(f)
    elif args.modality=='both':
        WINLENS = WINLENS_both
        with open(args.IDTfile, 'rb') as f:
            data = pkl.load(f)
        with open(args.keyfile, 'rb') as f:
            keydata = pkl.load(f)
        data = pd.concat((data, keydata))
        data.sort_index(inplace=True)

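    # Shuffle users and split them into user-disjoint cross-validation folds.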
    users = data.user.unique()
    random.shuffle(users)
    userInAFold = int(len(users)/args.folds)

    for fold in range(args.folds):
        print('----- Fold %d -----'%(fold))
        testuser = users[fold*userInAFold:(fold+1)*userInAFold]
        trainuser = set(users)-set(testuser)
        trset = data[data['user'].isin(trainuser)]
        teset = data[data['user'].isin(testuser)]

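        # Build the base vocabulary from the training split and reserve dedicated
        # unknown/padding tokens.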
        vocabdict = generateVocabdict(trset)
        words, vocab = generateWordsVocab(trset, vocabdict)
        vocabdict['unknown'] = tuple([len(vocabdict)])
        vocabdict['padding'] = tuple([len(vocabdict)])
        nLabels = data.task.unique().shape[0]

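        # Run the BPE merges; the last completed iteration is read back from
        # '<iteration-1>.pkl' below.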
        generateBPE(args.iteration, vocab, vocabdict, words, pool)

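        # Reload the merged BPE vocabulary, flatten nested merge tuples, rank tokens by
        # length, and append the unknown and padding tokens at the end of the ranking.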
        with open('%d.pkl'%(args.iteration-1), 'rb') as f:
            vocab, _, _, _ = pkl.load(f)
        flatvocab = set()
        for x in vocab:
            flatvocab = flatvocab.union([(tuple(flattenTuple(x)))])
        BPEvocabdict = dict((x, idx) for idx, x in enumerate(flatvocab))
        assert len(flatvocab)==len(BPEvocabdict)
        rankedvocab = sort_by_key_len(BPEvocabdict)
        rankedvocab = rankedvocab + [{tuple([vocabdict['unknown']]): len(rankedvocab)}, {tuple([vocabdict['padding']]): len(rankedvocab)+1}]

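        # Segment and tokenise the training data per (user, task, session) group and window
        # length, stacking segments of equal width and tracking the smallest raw label seen.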
        stackeddata = pool.starmap(poolsegmentTokenise, [(gkey, gdata, win_len, vocabdict, 'train', rankedvocab) for (gkey, gdata), win_len in product(trset.groupby(['user', 'task', 'session']), WINLENS)])
        minlabel = np.inf
        stackedtrdata, stackedtrlabel = dict(zip(WINLENS, [[] for i in range(len(WINLENS))])), dict(zip(WINLENS, [[] for i in range(len(WINLENS))]))
        for segments, labels, unknowntoken, paddingtoken in stackeddata:
            if segments.shape[0]==0:
                continue
            assert vocabdict['padding'][0]==paddingtoken
            assert vocabdict['unknown'][0]==unknowntoken
            if len(stackedtrdata[segments.shape[1]])==0:
                stackedtrdata[segments.shape[1]] = segments
            else:
                stackedtrdata[segments.shape[1]] = np.concatenate((stackedtrdata[segments.shape[1]], segments), axis=0)
            stackedtrlabel[segments.shape[1]] = np.array(list(stackedtrlabel[segments.shape[1]]) + labels)
            if np.min(labels)<minlabel:
                minlabel = np.min(labels)

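        # Repeat the segmentation and tokenisation for the held-out test users.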
        stackeddata = pool.starmap(poolsegmentTokenise, [(gkey, gdata, win_len, vocabdict, 'test', rankedvocab) for (gkey, gdata), win_len in product(teset.groupby(['user', 'task', 'session']), WINLENS)])
        stackedtedata, stackedtelabel = dict(zip(WINLENS, [[] for i in range(len(WINLENS))])), dict(zip(WINLENS, [[] for i in range(len(WINLENS))]))
        for segments, labels, unknowntoken, paddingtoken in stackeddata:
            if segments.shape[0]==0:
                continue
            assert vocabdict['padding'][0]==paddingtoken
            assert vocabdict['unknown'][0]==unknowntoken
            if len(stackedtedata[segments.shape[1]])==0:
                stackedtedata[segments.shape[1]] = segments
            else:
                stackedtedata[segments.shape[1]] = np.concatenate((stackedtedata[segments.shape[1]], segments), axis=0)
            stackedtelabel[segments.shape[1]] = np.array(list(stackedtelabel[segments.shape[1]]) + labels)
            if np.min(labels)<minlabel:
                minlabel = np.min(labels)

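        # Shift labels so the smallest observed label maps to 0 and check that labels stay
        # aligned with the stacked segments.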
        for key, _ in stackedtrlabel.items():
            stackedtrlabel[key] = (stackedtrlabel[key]-minlabel).astype(int)
            assert stackedtrlabel[key].shape[0]==stackedtrdata[key].shape[0]
            stackedtelabel[key] = (stackedtelabel[key]-minlabel).astype(int)
            assert stackedtelabel[key].shape[0]==stackedtedata[key].shape[0]

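        # Train and evaluate one transformer per window length.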
        for win_len, trdata in stackedtrdata.items():
            trlabel, tedata, telabel = stackedtrlabel[win_len], stackedtedata[win_len], stackedtelabel[win_len]
            assert tedata.shape[1]==trdata.shape[1]
            fit_transformer(trdata, trlabel, tedata, telabel, args, device, paddingtoken, nLabels, len(rankedvocab))