EMAKI/BPE.py

import warnings
warnings.filterwarnings("ignore")
import copy
import pandas as pd
import numpy as np
import time, pdb, os, random
import pickle as pkl
import seaborn as sn
import itertools
from itertools import product
from argparse import ArgumentParser
from ast import literal_eval
import torch, gc
import multiprocessing as mp
import torch.multiprocessing as torchmp
from utils_Analysis import setup_seed, flattenTuple, generateBPE, generateVocabdict, generateWordsVocab, fit_transformer, poolsegmentTokenise, sort_by_key_len
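
# Byte-pair encoding (BPE) pipeline for mouse/keyboard interaction data:
# load the pickled event streams, learn a BPE vocabulary on the training users,
# tokenise fixed-length windows of events, and fit one transformer classifier
# per window length. All helpers come from utils_Analysis.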
if __name__=='__main__':
    parser = ArgumentParser(description='BPE on datasets')
    parser.add_argument('--seed', dest='seed', type=int, required=True)
    parser.add_argument('-m', '--modality', dest='modality', choices=['mouse', 'keyboard', 'both'], type=str, required=True)
    parser.add_argument('--IDTfile', type=str, required=False)
    parser.add_argument('--keyfile', type=str, required=False)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--n_layers', type=int, required=True)
    parser.add_argument('--d_model', type=int, required=True)
    parser.add_argument('--batch_size', type=int, required=True)
    parser.add_argument('--dropout', type=float, required=True)
    parser.add_argument('--nhead', type=int, required=True)
    parser.add_argument('--optimizer_name', type=str, required=True)
    parser.add_argument('--epochs', type=int, required=True)
    parser.add_argument('--iteration', type=int, required=True)
    args = parser.parse_args()

    setup_seed(args.seed)
    args.folds = 5
    device = torch.device('cuda')

    WINLENS_key = [10, 50, 100]
    WINLENS_mouse = [20, 100, 200]
    WINLENS_both = [15, 75, 150]

    pool = mp.Pool(mp.cpu_count())
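
    # Load the pickled data for the requested modality (mouse events from
    # --IDTfile, keyboard events from --keyfile); for 'both', the two streams
    # are concatenated and re-sorted by index.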
    if args.modality=='mouse':
        WINLENS = WINLENS_mouse
        with open(args.IDTfile, 'rb') as f:
            data = pkl.load(f)
    elif args.modality=='keyboard':
        WINLENS = WINLENS_key
        with open(args.keyfile, 'rb') as f:
            data = pkl.load(f)
    elif args.modality=='both':
        WINLENS = WINLENS_both
        with open(args.IDTfile, 'rb') as f:
            data = pkl.load(f)
        with open(args.keyfile, 'rb') as f:
            keydata = pkl.load(f)
        data = pd.concat((data, keydata))
        data.sort_index(inplace=True)

    users = data.user.unique()
    random.shuffle(users)
    userInAFold = int(len(users)/args.folds)
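
    # Cross-validation over users: shuffle the user list once and give each of
    # the args.folds folds a disjoint block of users as its held-out test set.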
    for fold in range(args.folds):
        print('----- Fold %d -----'%(fold))
        testuser = users[fold*userInAFold:(fold+1)*userInAFold]
        trainuser = set(users)-set(testuser)
        trset = data[data['user'].isin(trainuser)]
        teset = data[data['user'].isin(testuser)]
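
        # Build the base vocabulary from the training users, reserve unknown and
        # padding tokens, then run args.iteration BPE merge steps; generateBPE is
        # expected to dump one '<i>.pkl' per iteration, and the final one is
        # reloaded to obtain the merged vocabulary.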
        vocabdict = generateVocabdict(trset)
        words, vocab = generateWordsVocab(trset, vocabdict)
        vocabdict['unknown'] = tuple([len(vocabdict)])
        vocabdict['padding'] = tuple([len(vocabdict)])
        nLabels = data.task.unique().shape[0]

        generateBPE(args.iteration, vocab, vocabdict, words, pool)
        with open('%d.pkl'%(args.iteration-1), 'rb') as f:
            vocab, _, _, _ = pkl.load(f)

        flatvocab = set()
        for x in vocab:
            flatvocab.add(tuple(flattenTuple(x)))
        BPEvocabdict = dict((x, idx) for idx, x in enumerate(flatvocab))
        assert len(flatvocab)==len(BPEvocabdict)
        rankedvocab = sort_by_key_len(BPEvocabdict)
        rankedvocab = rankedvocab + [{tuple([vocabdict['unknown']]): len(rankedvocab)},
                                     {tuple([vocabdict['padding']]): len(rankedvocab)+1}]
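
        # Segment every (user, task, session) group of the training set into
        # windows of each length in WINLENS, tokenise them with the ranked BPE
        # vocabulary, and stack the resulting segments/labels per window length.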
        stackeddata = pool.starmap(
            poolsegmentTokenise,
            [(gkey, gdata, win_len, vocabdict, 'train', rankedvocab)
             for (gkey, gdata), win_len in product(trset.groupby(['user', 'task', 'session']), WINLENS)])
        minlabel = np.inf
        stackedtrdata = {win_len: [] for win_len in WINLENS}
        stackedtrlabel = {win_len: [] for win_len in WINLENS}
        for segments, labels, unknowntoken, paddingtoken in stackeddata:
            if segments.shape[0]==0:
                continue
            assert vocabdict['padding'][0]==paddingtoken
            assert vocabdict['unknown'][0]==unknowntoken
            if len(stackedtrdata[segments.shape[1]])==0:
                stackedtrdata[segments.shape[1]] = segments
            else:
                stackedtrdata[segments.shape[1]] = np.concatenate((stackedtrdata[segments.shape[1]], segments), axis=0)
            stackedtrlabel[segments.shape[1]] = np.array(list(stackedtrlabel[segments.shape[1]]) + labels)
            if np.min(labels)<minlabel:
                minlabel = np.min(labels)
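
        # Same segmentation and tokenisation for the held-out test users.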
        stackeddata = pool.starmap(
            poolsegmentTokenise,
            [(gkey, gdata, win_len, vocabdict, 'test', rankedvocab)
             for (gkey, gdata), win_len in product(teset.groupby(['user', 'task', 'session']), WINLENS)])
        stackedtedata = {win_len: [] for win_len in WINLENS}
        stackedtelabel = {win_len: [] for win_len in WINLENS}
        for segments, labels, unknowntoken, paddingtoken in stackeddata:
            if segments.shape[0]==0:
                continue
            assert vocabdict['padding'][0]==paddingtoken
            assert vocabdict['unknown'][0]==unknowntoken
            if len(stackedtedata[segments.shape[1]])==0:
                stackedtedata[segments.shape[1]] = segments
            else:
                stackedtedata[segments.shape[1]] = np.concatenate((stackedtedata[segments.shape[1]], segments), axis=0)
            stackedtelabel[segments.shape[1]] = np.array(list(stackedtelabel[segments.shape[1]]) + labels)
            if np.min(labels)<minlabel:
                minlabel = np.min(labels)
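
        # Shift labels so the smallest label observed across train and test
        # becomes class 0.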
        for key, _ in stackedtrlabel.items():
            stackedtrlabel[key] = (stackedtrlabel[key]-minlabel).astype(int)
            assert stackedtrlabel[key].shape[0]==stackedtrdata[key].shape[0]
            stackedtelabel[key] = (stackedtelabel[key]-minlabel).astype(int)
            assert stackedtelabel[key].shape[0]==stackedtedata[key].shape[0]
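
        # Train and evaluate one transformer per window length; paddingtoken is
        # whatever the last tokenised group returned, which the asserts above
        # guarantee equals vocabdict['padding'][0].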
        for win_len, trdata in stackedtrdata.items():
            trlabel, tedata, telabel = stackedtrlabel[win_len], stackedtedata[win_len], stackedtelabel[win_len]
            assert tedata.shape[1]==trdata.shape[1]
            fit_transformer(trdata, trlabel, tedata, telabel, args, device, paddingtoken, nLabels, len(rankedvocab))