Dataset and Code added
parent a89f7125d5 · commit 6d937a54e6
4 changed files with 493 additions and 12 deletions
BPE.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import warnings
warnings.filterwarnings("ignore")
import copy
import pandas as pd
import numpy as np
import time, pdb, os, random
import pickle as pkl
import seaborn as sn
import itertools
from itertools import product
from argparse import ArgumentParser
from ast import literal_eval
import torch, gc
import multiprocessing as mp
import torch.multiprocessing as torchmp
from utils_Analysis import setup_seed, flattenTuple, generateBPE, generateVocabdict, generateWordsVocab, fit_transformer, poolsegmentTokenise, sort_by_key_len


if __name__ == '__main__':
    parser = ArgumentParser(description='BPE on datasets')
    parser.add_argument('--seed', dest='seed', type=int, required=True)
    parser.add_argument('-m', '--modality', dest='modality', choices=['mouse', 'keyboard', 'both'], type=str, required=True)
    parser.add_argument('--IDTfile', type=str, required=False)
    parser.add_argument('--keyfile', type=str, required=False)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--n_layers', type=int, required=True)
    parser.add_argument('--d_model', type=int, required=True)
    parser.add_argument('--batch_size', type=int, required=True)
    parser.add_argument('--dropout', type=float, required=True)
    parser.add_argument('--nhead', type=int, required=True)
    parser.add_argument('--optimizer_name', type=str, required=True)
    parser.add_argument('--epochs', type=int, required=True)
    parser.add_argument('--iteration', type=int, required=True)
    args = parser.parse_args()

    setup_seed(args.seed)
    args.folds = 5
    device = torch.device('cuda')

    # Window lengths (in events) per modality.
    WINLENS_key = [10, 50, 100]
    WINLENS_mouse = [20, 100, 200]
    WINLENS_both = [15, 75, 150]
    pool = mp.Pool(mp.cpu_count())

    # Load the pickled event data for the requested modality.
    if args.modality == 'mouse':
        WINLENS = WINLENS_mouse
        with open(args.IDTfile, 'rb') as f:
            data = pkl.load(f)
    elif args.modality == 'keyboard':
        WINLENS = WINLENS_key
        with open(args.keyfile, 'rb') as f:
            data = pkl.load(f)
    elif args.modality == 'both':
        WINLENS = WINLENS_both
        with open(args.IDTfile, 'rb') as f:
            data = pkl.load(f)
        with open(args.keyfile, 'rb') as f:
            keydata = pkl.load(f)
        data = pd.concat((data, keydata))
        data.sort_index(inplace=True)

    # User-independent cross-validation: split users into folds.
    users = data.user.unique()
    random.shuffle(users)
    userInAFold = int(len(users) / args.folds)

    for fold in range(args.folds):
        print('----- Fold %d -----' % (fold))
        testuser = users[fold*userInAFold:(fold+1)*userInAFold]
        trainuser = set(users) - set(testuser)
        trset = data[data['user'].isin(trainuser)]
        teset = data[data['user'].isin(testuser)]

        # Build the base (event-level) vocabulary from the training folds.
        vocabdict = generateVocabdict(trset)
        words, vocab = generateWordsVocab(trset, vocabdict)
        vocabdict['unknown'] = tuple([len(vocabdict)])
        vocabdict['padding'] = tuple([len(vocabdict)])
        nLabels = data.task.unique().shape[0]

        # Learn BPE merges on the training folds; generateBPE pickles intermediate vocabularies.
        generateBPE(args.iteration, vocab, vocabdict, words, pool)

        with open('%d.pkl' % (args.iteration-1), 'rb') as f:
            vocab, _, _, _ = pkl.load(f)
        flatvocab = set()
        for x in vocab:
            flatvocab = flatvocab.union([(tuple(flattenTuple(x)))])
        BPEvocabdict = dict((x, idx) for idx, x in enumerate(flatvocab))
        assert len(flatvocab) == len(BPEvocabdict)
        # Rank merges so that longer ones are applied first, then append unknown and padding tokens.
        rankedvocab = sort_by_key_len(BPEvocabdict)
        rankedvocab = rankedvocab + [{tuple([vocabdict['unknown']]): len(rankedvocab)}, {tuple([vocabdict['padding']]): len(rankedvocab)+1}]

        # Segment and tokenise the training folds for every window length.
        stackeddata = pool.starmap(poolsegmentTokenise, [(gkey, gdata, win_len, vocabdict, 'train', rankedvocab) for (gkey, gdata), win_len in product(trset.groupby(['user', 'task', 'session']), WINLENS)])
        minlabel = np.inf
        stackedtrdata, stackedtrlabel = dict(zip(WINLENS, [[] for i in range(len(WINLENS))])), dict(zip(WINLENS, [[] for i in range(len(WINLENS))]))
        for segments, labels, unknowntoken, paddingtoken in stackeddata:
            if segments.shape[0] == 0:
                continue
            assert vocabdict['padding'][0] == paddingtoken
            assert vocabdict['unknown'][0] == unknowntoken
            if len(stackedtrdata[segments.shape[1]]) == 0:
                stackedtrdata[segments.shape[1]] = segments
            else:
                stackedtrdata[segments.shape[1]] = np.concatenate((stackedtrdata[segments.shape[1]], segments), axis=0)
            stackedtrlabel[segments.shape[1]] = np.array(list(stackedtrlabel[segments.shape[1]]) + labels)
            if np.min(labels) < minlabel:
                minlabel = np.min(labels)

        # Segment and tokenise the test fold with the same vocabulary.
        stackeddata = pool.starmap(poolsegmentTokenise, [(gkey, gdata, win_len, vocabdict, 'test', rankedvocab) for (gkey, gdata), win_len in product(teset.groupby(['user', 'task', 'session']), WINLENS)])
        stackedtedata, stackedtelabel = dict(zip(WINLENS, [[] for i in range(len(WINLENS))])), dict(zip(WINLENS, [[] for i in range(len(WINLENS))]))
        for segments, labels, unknowntoken, paddingtoken in stackeddata:
            if segments.shape[0] == 0:
                continue
            assert vocabdict['padding'][0] == paddingtoken
            assert vocabdict['unknown'][0] == unknowntoken
            if len(stackedtedata[segments.shape[1]]) == 0:
                stackedtedata[segments.shape[1]] = segments
            else:
                stackedtedata[segments.shape[1]] = np.concatenate((stackedtedata[segments.shape[1]], segments), axis=0)
            stackedtelabel[segments.shape[1]] = np.array(list(stackedtelabel[segments.shape[1]]) + labels)
            if np.min(labels) < minlabel:
                minlabel = np.min(labels)

        # Shift task labels so they start at 0.
        for key, _ in stackedtrlabel.items():
            stackedtrlabel[key] = (stackedtrlabel[key] - minlabel).astype(int)
            assert stackedtrlabel[key].shape[0] == stackedtrdata[key].shape[0]
            stackedtelabel[key] = (stackedtelabel[key] - minlabel).astype(int)
            assert stackedtelabel[key].shape[0] == stackedtedata[key].shape[0]

        # Train one Transformer classifier per window length.
        for win_len, trdata in stackedtrdata.items():
            trlabel, tedata, telabel = stackedtrlabel[win_len], stackedtedata[win_len], stackedtelabel[win_len]
            assert tedata.shape[1] == trdata.shape[1]
            # nTokens must be passed by keyword: the ninth positional parameter of fit_transformer is savemodel.
            fit_transformer(trdata, trlabel, tedata, telabel, args, device, paddingtoken, nLabels, nTokens=len(rankedvocab))
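For orientation, a hypothetical invocation of BPE.py assembled from the argparse flags above; every file name and hyperparameter value below is a placeholder, not something taken from the repository:

```python
import subprocess

# Hypothetical invocation of BPE.py; paths and hyperparameters are illustrative only.
subprocess.run([
    "python", "BPE.py",
    "--seed", "0", "--modality", "mouse", "--IDTfile", "mouse_IDT.pkl",
    "--lr", "1e-4", "--n_layers", "2", "--d_model", "64",
    "--batch_size", "64", "--dropout", "0.1", "--nhead", "4",
    "--optimizer_name", "AdamW", "--epochs", "10", "--iteration", "100",
], check=True)
```

Note that --iteration should line up with generateBPE's savefreq (50 by default), since the script afterwards reloads the snapshot written as '%d.pkl' % (iteration-1).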
EMAKI_utt.pkl (new binary file, not shown)
README.md (38 lines changed: 26 additions, 12 deletions)
@@ -4,12 +4,19 @@
 Proc. IFIP TC13 Conference on Human-Computer Interaction (INTERACT), 2023, York, UK
 
+https://link.springer.com/chapter/10.1007/978-3-031-42286-7_1
+
 ---
 
-This repository contains the EMAKI dataset:
+This repository contains the EMAKI dataset
 - user: User ID 1~39
 - task: Text entry and editing (3), image editing (4) and questionnaire completion (5)
-- To download the dataset, please fill up our license agreement (TBA)
+- trial: Questionnaire completion had four trials
+- type: Action type, one of ['mousemove', 'mousedown', 'mouseup', 'keydown', 'keyup']
+- timestamp: Timestamp of the current mouse or keyboard action
+- X/Y: On-screen coordinates of the current mouse cursor
+- value: Left (1) or right (3) if type is mousedown/up; keystroke content if type is keydown/up; nan/none otherwise
+- resolutionX/Y: The resolution of the user's screen when performing the online user study
 
 As well as code for
 
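The field list above describes per-event rows. A minimal sketch of inspecting the added EMAKI_utt.pkl with pandas, under the assumption that it is a pickled DataFrame with those columns (the exact column names and the grouping are illustrative, mirroring the README fields and the groupby keys used in utils_Analysis.py):

```python
import pandas as pd

# Assumes EMAKI_utt.pkl is a pickled DataFrame; pandas.read_pickle handles pickle-dumped frames.
data = pd.read_pickle('EMAKI_utt.pkl')
print(data.columns.tolist())   # per the README: user, task, trial, type, timestamp, X/Y, value, resolutionX/Y
print(data['task'].unique())   # 3 = text entry/editing, 4 = image editing, 5 = questionnaire completion

# One interaction stream per (user, task, trial), as in generateWordsVocab.
for (user, task, trial), g in data.groupby(['user', 'task', 'trial']):
    print(user, task, trial, len(g))
    break
```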
@@ -18,15 +25,22 @@ As well as code for
 
 ---
 
-If you find this repository helpful, please cite our work:
+If you find our work helpful, please cite:
 
 ```
-@inproceedings{zhang2023exploring,
-  title = {Exploring Natural Language Processing Methods for Interactive Behaviour Modelling},
-  author = {Zhang, Guanhua and Bortoletto, Matteo and Hu, Zhiming and Shi, Lei and B{\^a}ce, Mihai and Bulling, Andreas},
-  booktitle = {Proc. IFIP TC13 Conference on Human-Computer Interaction (INTERACT)},
-  pages = {1--22},
-  year = {2023},
-  publisher = {Springer}
+@InProceedings{zhang2023exploring,
+  author="Zhang, Guanhua
+  and Bortoletto, Matteo
+  and Hu, Zhiming
+  and Shi, Lei
+  and B{\^a}ce, Mihai
+  and Bulling, Andreas",
+  title="Exploring Natural Language Processing Methods for Interactive Behaviour Modelling",
+  booktitle="Human-Computer Interaction -- INTERACT 2023",
+  year="2023",
+  publisher="Springer Nature Switzerland",
+  address="Cham",
+  pages="3--26",
+  isbn="978-3-031-42286-7"
 }
 ```
utils_Analysis.py (new file, 340 lines)
@@ -0,0 +1,340 @@
from json import decoder
from operator import mod
import numpy as np
import time, pdb, os, random, math, copy
import pandas as pd
import pickle as pkl
from sklearn.metrics import f1_score

import matplotlib.lines as mlines
from scipy.stats import norm

import torch
from torch import nn, Tensor
from torchinfo import summary
from torch.utils.data import DataLoader, Dataset
from torch.nn import TransformerEncoder, TransformerEncoderLayer


def find_replace(seq, token, word):
    # Replace every occurrence of the subsequence `seq` in `word` with the merged `token`.
    st = 0
    while st < len(word):
        if seq == tuple(word[st:st+len(seq)]):
            word = word[:st] + [token] + word[st+len(seq):]
        st += 1
    return word


def generateVocabdict(data):
    # Map every distinct event to a 1-tuple token id; 'OVER' marks the end of a stream.
    vocab = list(data.event.unique())
    vocabdict = dict((x, tuple([idx])) for idx, x in enumerate(vocab))
    vocabdict['OVER'] = tuple([-1])
    return vocabdict


def generateWordsVocab(data, vocabdict):
    # One "word" (token sequence) per (user, task, trial) interaction stream.
    word, words = [], []
    for gkey, gdata in data.groupby(['user', 'task', 'trial']):
        events = list(gdata.event.values) + ['OVER']
        word = [vocabdict[x] for x in list(events)]
        words.append(word)
    vocab = set([])
    for word in words:
        vocab = vocab.union(set(word))
    vocab = set(x for x in list(vocab))
    return words, vocab


def generateBPE(iters, vocab, vocabdict, words, pool, savefileprefix='./', savefreq=50):
    # Byte-pair encoding: repeatedly merge the most frequent adjacent token pair across all words.
    savefilename = None  # returned even if no snapshot is written
    for i in range(iters):
        stackeddata = pool.starmap(getPair, [[x] for x in words])
        if len(stackeddata) == 0:
            break
        stackedsubwords = pool.starmap(uniqueSubwords, stackeddata)
        subwords = set([])
        for idx, subword in enumerate(stackedsubwords):
            subwords = subwords.union(subword)

        # Count every adjacent pair across all words.
        paircounter = dict((x, 0) for x in subwords)
        for pairs, word in stackeddata:
            for _, pair in enumerate(pairs):
                subword = []
                for x in list(pair):
                    if len(x) == 1:
                        subword.append(tuple(x))
                    else:
                        subword.append(list(x))
                subword = tuple(map(tuple, subword))
                paircounter[subword] += 1
        if len(paircounter) == 0:
            break
        if max(paircounter.values()) == 1:
            break
        targetpair = max(paircounter, key=paircounter.get)
        prelen = len(vocab)
        vocab.add(targetpair)

        # Merge every non-overlapping occurrence of the selected pair.
        for idx, word in enumerate(words):
            stidx = []
            for st in range(len(word)-len(targetpair)+1):
                if equalTuple(targetpair, tuple(word[st:st+len(targetpair)])):
                    stidx.append(st)
            pre = 0
            updated = []
            for st in stidx:
                if pre <= st:  # merge only matches that do not overlap an earlier merge
                    updated = updated + word[pre:st] + [targetpair]
                    pre = st + len(targetpair)
            words[idx] = updated + word[pre:]

        if (i+1) % savefreq == 0:
            savefilename = savefileprefix + '%d.pkl' % (i)
            print('saving to =>', savefilename)
            with open(savefilename, 'wb') as f:
                pkl.dump([vocab, vocabdict, words, paircounter], f)
    return savefilename


def sort_by_key_len(d):
    # Sort vocabulary entries by key length, longest first, so longer merges are applied before shorter ones.
    dict_len = {key: len(key) for key in d.keys()}
    import operator
    sorted_key_list = sorted(dict_len.items(), key=operator.itemgetter(1), reverse=True)
    sorted_dict = [{item[0]: d[item[0]]} for item in sorted_key_list]
    return sorted_dict


def getPair(word, N=2):
    # Return all adjacent N-grams (default: bigrams) of a word via a sliding window of stride 1.
    word = np.array(word)
    slid = 1
    sub_windows = (
        np.expand_dims(np.arange(N), 0) +
        np.expand_dims(np.arange(0, word.shape[0]-N+1, slid), 0).T
    ).astype(int)
    return word[sub_windows], word


def flattenTuple(x):
    # Flatten a nested tuple of merges into a flat list of 1-tuples.
    flatten = []
    for oo in list(x):
        if type(oo) == int:
            flatten.append(tuple([oo]))
        else:
            if len(oo) == 1:
                flatten.append(oo)
            else:
                flatten = flatten + flattenTuple(oo)
    return flatten


def equalTuple(x, y):
    if set(x) == set(y):
        if flattenTuple(x) == flattenTuple(y):
            return True
    return False


def findTuple(target, candidates):
    for x in candidates:
        if equalTuple(target, x):
            return True
    return False


def uniqueSubwords(pairs, ignore):
    # Collect the distinct pairs of a word; the second argument (the word itself) is unused.
    subwords = set([])
    for pair in pairs:
        subword = []
        for x in list(pair):
            if len(x) == 1:
                subword.append(tuple(x))
            else:
                subword.append(list(x))
        subword = tuple(map(tuple, subword))
        subwords.add(subword)
    return subwords


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def poolsegmentTokenise(gkey, gdata, win_len, vocabdict, mark, BPEvocab=None):
    # Cut one interaction stream into non-overlapping windows of win_len events,
    # tokenise them with the ranked BPE vocabulary, and pad every window to win_len.
    unknowntoken = vocabdict['unknown']
    paddingtoken = vocabdict['padding']
    slid = win_len
    windowTokendict, windowBPEdict = {}, {}
    events = np.array(list(gdata.event.values) + ['OVER'])
    sub_windows = (
        np.expand_dims(np.arange(win_len), 0) +
        np.expand_dims(np.arange(0, len(gdata)-win_len+1, slid), 0).T
    ).astype(int)
    windowOri = events[sub_windows].tolist()
    if len(sub_windows) == 0:
        lastidx = 0
    else:
        lastidx = sub_windows[-1][-1]
    if lastidx < len(gdata)-1:
        windowOri.append(events[lastidx+1:].tolist())

    # Map raw events to base tokens; unseen events are only allowed outside training.
    windowToken = []
    for windowO in windowOri:
        word = []
        for x in list(windowO):
            if x in vocabdict.keys():
                word.append(vocabdict[x])
            else:
                assert mark != 'train'
                word.append(unknowntoken)
        windowToken.append(word)
    windowTokendict[gkey] = windowToken

    if BPEvocab is None:
        BPEvocab = [{tuple([tuple([x-1])]): x} for x in range(len(vocabdict))]

    # Apply the ranked BPE merges (longest first) to every window.
    windowBPE = []
    for word in windowToken:
        prelen = len(word)
        for vocab in BPEvocab:
            seq = list(vocab.keys())[0]
            token = list(vocab.values())[0]
            word = find_replace(seq, token, word)
            assert len(word) <= prelen
        windowBPE.append(word)
    windowBPEdict[gkey] = windowBPE

    # Unwrap remaining 1-tuples into plain token ids.
    for key, values in windowBPEdict.items():
        oo = []
        for word in values:
            newword = []
            for x in word:
                if type(x) == tuple:
                    newword.append(x[0])
                else:
                    newword.append(x)
            oo.append(newword)
        windowBPEdict[key] = oo

    # Stack the padded windows; the task id (gkey[1]) is the classification label.
    segments, labels, nextaction = np.array([]), [], []
    for gkey, groupWindowBPE in windowBPEdict.items():
        for idx, windowBPE in enumerate(groupWindowBPE):
            windowBPE = windowBPE + [paddingtoken[0] for x in range(win_len-len(windowBPE))]
            assert len(windowBPE) == win_len
            if idx < len(groupWindowBPE)-1:
                nextaction.append(windowTokendict[gkey][idx+1][0][0])
            else:
                nextaction.append(vocabdict['OVER'][0])
            labels.append(gkey[1])
            if segments.shape[0] == 0:
                segments = np.array([windowBPE])
            else:
                segments = np.concatenate((segments, np.array([windowBPE])), axis=0)
    return segments, labels, unknowntoken[0], paddingtoken[0]


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:, :x.shape[1]]
        return self.dropout(x)


class SupervisedTransformerv2(nn.Module):
    # Token embedding + positional encoding + Transformer encoder + linear classifier over the flattened sequence.

    def __init__(self, **params):
        super().__init__()
        self.params = params
        self.pos_encoder = PositionalEncoding(d_model=params['d_model'], dropout=params['dropout'])
        if 'ntokens' in params.keys():
            self.embedding = nn.Embedding(num_embeddings=params['ntokens'], embedding_dim=params['d_model'])
        self.transformer_encoder = TransformerEncoder(TransformerEncoderLayer(d_model=params['d_model'], nhead=params['nhead'],
                                                                              dim_feedforward=params['d_model']*4,
                                                                              dropout=params['dropout'], activation='relu', batch_first=True),
                                                      params['n_layers'])
        self.linear = nn.Linear(params['d_model']*params['win_len'], params['nlabels'])

    def forward(self, encoder_input, paddingmask):
        encoder_embed = self.embedding(encoder_input) * math.sqrt(self.params['d_model'])
        encoder_pos = self.pos_encoder(encoder_embed)
        encoder_output = self.transformer_encoder(src=encoder_pos, src_key_padding_mask=paddingmask)
        output = encoder_output.view(encoder_output.shape[0], -1)
        final_output = self.linear(output)
        return final_output


def evaluate(model, criterion, val_loader, device):
    model = model.to(device)
    model.eval()
    losses, f1s = [], []
    with torch.no_grad():
        for batchdata, batchlabel, batchmask in val_loader:
            predictions = model(batchdata.to(device), batchmask.to(device))
            loss = criterion(predictions, batchlabel.reshape(-1).to(device).long())
            if np.isnan(loss.item()):
                raise ValueError("Loss NaN!")
            losses.append(loss.item())
            pred_label = np.argmax(predictions.detach().cpu().numpy(), axis=1)
            f1 = f1_score(batchlabel.numpy(), pred_label, average='macro')
            f1s.append(f1)
    return np.mean(losses), np.mean(f1s)


class DatasetPadding(Dataset):
    # Wraps token windows, labels and the padding mask for DataLoader batching.

    def __init__(self, data, paddingtoken=None, label=None):
        self.data = data
        self.label = label
        if paddingtoken is not None:
            self.mask = data == paddingtoken

    def __getitem__(self, idx):
        if self.label is None:
            return self.data[idx], self.mask[idx]
        return self.data[idx], self.label[idx], self.mask[idx]

    def __len__(self):
        return len(self.data)


def fit_transformer(traindata, trainlabel, testdata, testlabel, args, device, paddingtoken, nLabels, savemodel='model', nTokens=None):
    if os.path.exists(savemodel):
        return

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    traindata = traindata.astype(int)

    params = {'nlabels': nLabels, 'batch_size': args.batch_size, 'd_model': args.d_model,
              'optimizer_name': args.optimizer_name, 'nhead': args.nhead, 'dropout': args.dropout,
              'win_len': traindata.shape[1], 'lr': args.lr, 'n_layers': args.n_layers, 'ntokens': nTokens}
    trainset = DatasetPadding(data=traindata, paddingtoken=paddingtoken, label=trainlabel)
    testset = DatasetPadding(data=testdata, paddingtoken=paddingtoken, label=testlabel)
    trainloader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, num_workers=0)
    testloader = DataLoader(testset, batch_size=params['batch_size'], shuffle=True, num_workers=0)

    model = SupervisedTransformerv2(**params).to(device)
    model.train()

    # The configured optimizer is expected to be Adam-style, since it receives betas and weight_decay.
    optimizer = getattr(torch.optim, params['optimizer_name'])(model.parameters(),
                                                               lr=params['lr'], betas=(0.9, 0.999), weight_decay=0.01)
    if len(trainloader) >= 20:
        LOG = int(len(trainloader)/20)
    else:
        LOG = 1
    trloss, valoss, trf1, vaf1 = [], [], [], []
    evaloss, evaf1 = 0, 0
    for epoch in range(1, args.epochs+1):
        for batch, (batchdata, batchlabel, batchmask) in enumerate(trainloader):
            predictions = model(batchdata.to(device), batchmask.to(device))
            loss = criterion(predictions, batchlabel.reshape(-1).to(device).long())
            if np.isnan(loss.item()):
                raise ValueError("Loss NaN!")
            loss.requires_grad_(True)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            evaloss += loss.item()
            pred_label = np.argmax(predictions.detach().cpu().numpy(), axis=1)
            f1 = f1_score(batchlabel.numpy(), pred_label, average='macro')
            evaf1 += f1
            if batch % (LOG) == 0 or batch == len(trainloader)-1:
                cur_valoss, cur_vaf1 = evaluate(model, criterion, testloader, device)
                model.train()
                trloss.append(evaloss/LOG)
                valoss.append(cur_valoss)
                trf1.append(evaf1/LOG)
                vaf1.append(cur_vaf1)
                evaloss, evaf1 = 0, 0
                print('Epoch [{}/{}], Batch [{}/{}], Train Loss: {:.4f}, Train F1: {:.4f}, Val Loss: {:.4f}, Val F1: {:.4f}'
                      .format(epoch, args.epochs, batch, len(trainloader), trloss[-1], trf1[-1], valoss[-1], vaf1[-1]))
        # Checkpoint a CPU copy after each epoch so training can continue on the original device.
        torch.save([copy.deepcopy(model).cpu(), [trloss, valoss, trf1, vaf1]], savemodel+'%d.pkl'%(epoch))
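To make the nested-tuple conventions in utils_Analysis.py concrete, here is a small illustrative check of the helper functions; it assumes utils_Analysis.py is importable, and the toy word below is made up:

```python
import numpy as np
from utils_Analysis import getPair, flattenTuple, equalTuple

# A toy "word": a sequence of single-token 1-tuples, as produced by generateWordsVocab.
word = [(0,), (1,), (0,), (1,), (2,)]

pairs, same_word = getPair(word)        # all adjacent bigrams, stride 1
print(pairs.shape)                      # (4, 2, 1): four bigrams of two 1-tuples each

print(flattenTuple(((0,), ((1,), (2,)))))          # [(0,), (1,), (2,)] - nested merge flattened
print(equalTuple(((0,), (1,)), ((0,), (1,))))      # True
```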