first commit
This commit is contained in:
commit
83b04e2133
109 changed files with 12081 additions and 0 deletions
29
README.md
Normal file
29
README.md
Normal file
|
@ -0,0 +1,29 @@
|
|||
# Inferring Human Intentions from Predicted Action Probabilities
|
||||
|
||||
*Lei Shi, Paul Bürkner, Andreas Bulling*
|
||||
|
||||
*University of Stuttgart, Stuttgart, Germany*
|
||||
|
||||
Accepted by [Workshop on Theory of Mind in Human-AI Interaction at CHI 2024](https://theoryofmindinhaichi2024.wordpress.com/)
|
||||
|
||||
## Requirements
|
||||
The code is test in Ubuntu 20.04.
|
||||
|
||||
```
|
||||
pytorch 1.11.0
|
||||
matplotlib 3.3.2
|
||||
pickle 4.0
|
||||
pandas 1.4.3
|
||||
R 4.2.1
|
||||
RStan 2.26.3
|
||||
```
|
||||
To install R, [see here](https://cran.r-project.org/bin/linux/ubuntu/fullREADME.html)
|
||||
|
||||
To install RStan, [see here](https://mc-stan.org/users/interfaces/rstan.html)
|
||||
|
||||
## Experiments
|
||||
|
||||
To train and evaluate the method on Watch-And-Help dataset, see [here](watch_and_help/README.md)
|
||||
|
||||
To train and evaluate the method on Keyboard and Mouse Interaction dataset, see [here](keyboard_and_mouse/README.md)
|
||||
|
69
keyboard_and_mouse/README.MD
Normal file
69
keyboard_and_mouse/README.MD
Normal file
|
@ -0,0 +1,69 @@
|
|||
# Keyboard And Mouse Interactive Dataset
|
||||
|
||||
|
||||
# Neural Network
|
||||
|
||||
## Requirements
|
||||
The code is test in Ubuntu 20.04.
|
||||
|
||||
pytorch 1.11.0
|
||||
matplotlib 3.3.2
|
||||
pickle 4.0
|
||||
pandas 1.4.3
|
||||
|
||||
## Train
|
||||
|
||||
Set training parameters in train.sh
|
||||
|
||||
Run `sh train.sh` to train the model
|
||||
|
||||
|
||||
## Test
|
||||
|
||||
Run `sh test.sh` to run test on trained model
|
||||
|
||||
Predictions are saved under `prediction/task$i$/`
|
||||
|
||||
|
||||
# Bayesian Inference
|
||||
|
||||
## Requirements
|
||||
R 4.2.1
|
||||
RStan [](https://mc-stan.org/users/interfaces/rstan.html)
|
||||
|
||||
|
||||
Run `sh sampler_user.sh` to split prediction to 10% to 90%
|
||||
|
||||
Run `Rscript stan/strategy_inference_test.R` to get results of intention prediction for all users
|
||||
Run `sh stan/plot_user.sh` to plot the bar chart for user intention prediction results of all action sequences
|
||||
|
||||
Run `Rscript stan/strategy_inference_test_full_length.R` to get results of intention prediction (0% to 100%) for all users
|
||||
Run `sh stan/plot_user_length_10_steps.sh` to plot the bar chart for user intention prediction results (0% to 100%) of all action sequences
|
||||
|
||||
Run `sh sampler_single_act.sh` to get the predictions for each individual action sequence.
|
||||
Run `Rscript stan/strategy_inference_test_all_individual_act.R` to get all action sequences (0% to 100%) of all users for intention prediction
|
||||
Run `sh plot_user_all_individual.sh` to plot the bar chart for user intention prediction results of all action sequences
|
||||
Run `sh plot_user_length_10_steps_all_individual.sh` to plot the user intention prediction results (0% to 100%) of all action sequences
|
||||
|
||||
|
||||
|
||||
|
||||
Set training and test parameters in train.sh and test.sh
|
||||
|
||||
Run sh train.sh to train the model.
|
||||
|
||||
Run sh test.sh to run test on trained model.
|
||||
Predictions are saved under prediction/task$i$/
|
||||
|
||||
Run sh sampler_user.sh to split prediction to 10% to 90%
|
||||
|
||||
Run stan/strategy_inference_test.R to get results of intention prediction for all users
|
||||
Run stan/plot_user.py to plot the bar chart for user intention prediction results of all action sequences
|
||||
|
||||
Run stan/strategy_inference_test_full_length.R to get results of intention prediction (0% to 100%) for all users
|
||||
Run stan/plot_user_length_10_users.py to plot the bar chart for user intention prediction results (0% to 100%) of all action sequences
|
||||
|
||||
|
||||
Run stan/strategy_inference_test_all_individual_act.R to get all action sequences (0% to 100%) of all users for intention prediction
|
||||
Run stan/plot_user_all_individual.py to plot the bar chart for user intention prediction results of all action sequences
|
||||
Run stan/plot_user_length_10_steps_all_individual.py to plot the user intention prediction results (0% to 100%) of all action sequences
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,852 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "f5cb2ecf",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2021-09-28 16:10:28.497166: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import os, pdb\n",
|
||||
"from sklearn.model_selection import GridSearchCV \n",
|
||||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||
"from sklearn.metrics import accuracy_score\n",
|
||||
"from tensorflow import keras\n",
|
||||
"from keras.preprocessing.sequence import pad_sequences"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "459ad77b",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"study_data_path = \"../IntentData/\"\n",
|
||||
"data = pd.read_pickle(study_data_path + \"/Preprocessing_data/clean_data.pkl\")\n",
|
||||
"Task_IDs = np.arange(7).tolist()\n",
|
||||
"StartIndexOffset = 0 #if set to 5 ignore first 5 elements\n",
|
||||
"EndIndexOffset = 0 #if set to 5 ignore last 5 elements"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "4ab8c0cc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array(['Cmd', 'Toolbar'], dtype=object)"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data.Rule.unique()\n",
|
||||
"data.columns\n",
|
||||
"data.Type.unique()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "05550387",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# grouping by part is needed to have one ruleset for the whole part\n",
|
||||
"# Participant [1,16]\n",
|
||||
"# Repeat for 5 times [1,5]\n",
|
||||
"# ???????? [0,6]\n",
|
||||
"g = data.groupby([\"PID\", \"Part\", \"TaskID\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "7819da48",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"param_grid = {'n_estimators': [10,50,100], \n",
|
||||
" 'max_depth': [10,20,30]}\n",
|
||||
"\n",
|
||||
"grid = GridSearchCV(RandomForestClassifier(), param_grid, refit = True, verbose = 0, return_train_score=True) "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "e64a6920",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def createTrainTest(test_IDs, task_IDs, start_index_offset, end_index_offset, shapes=False):\n",
|
||||
" assert isinstance(test_IDs, list)\n",
|
||||
" assert isinstance(task_IDs, list)\n",
|
||||
" # Fill data arrays\n",
|
||||
" y_train = []\n",
|
||||
" x_train = []\n",
|
||||
" y_test = []\n",
|
||||
" x_test = []\n",
|
||||
" for current in g.groups.keys():\n",
|
||||
" c = g.get_group(current)\n",
|
||||
" if (c.TaskID.isin(task_IDs).all()):\n",
|
||||
" new_rule = c.Rule.unique()[0]\n",
|
||||
" if end_index_offset == 0:\n",
|
||||
" new_data = c.Event.values[start_index_offset:]\n",
|
||||
" else:\n",
|
||||
" new_data = c.Event.values[start_index_offset:-end_index_offset]\n",
|
||||
" if (c.PID.isin(test_IDs).all()):\n",
|
||||
" y_test.append(new_rule)\n",
|
||||
" x_test.append(new_data)\n",
|
||||
" else:\n",
|
||||
" y_train.append(new_rule)\n",
|
||||
" x_train.append(new_data)\n",
|
||||
" x_train = np.array(x_train)\n",
|
||||
" y_train = np.array(y_train)\n",
|
||||
" x_test = np.array(x_test)\n",
|
||||
" y_test = np.array(y_test)\n",
|
||||
" print('x_train\\n',x_train)\n",
|
||||
" print('y_train\\n',y_train)\n",
|
||||
" print('x_test\\n',x_test)\n",
|
||||
" print('y_test\\n',y_test)\n",
|
||||
" pdb.set_trace()\n",
|
||||
" if (shapes):\n",
|
||||
" print(x_train.shape)\n",
|
||||
" print(y_train.shape)\n",
|
||||
" print(x_test.shape)\n",
|
||||
" print(y_test.shape)\n",
|
||||
" print(np.unique(y_test))\n",
|
||||
" print(np.unique(y_train))\n",
|
||||
" return (x_train, y_train, x_test, y_test)\n",
|
||||
"\n",
|
||||
"def runSVMS(train_test, maxlen=None, plots=False, last_elements=False):\n",
|
||||
" x_train, y_train, x_test, y_test = train_test\n",
|
||||
" # Get maxlen to pad and pad\n",
|
||||
" if (maxlen==None):\n",
|
||||
" maxlen = 0\n",
|
||||
" for d in np.concatenate((x_train,x_test)):\n",
|
||||
" if len(d) > maxlen:\n",
|
||||
" maxlen = len(d)\n",
|
||||
" \n",
|
||||
" truncating_elements = \"post\"\n",
|
||||
" if last_elements:\n",
|
||||
" truncating_elements = \"pre\"\n",
|
||||
"\n",
|
||||
" x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen, dtype='int32', padding='post', truncating=truncating_elements, value=0)\n",
|
||||
" x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen, dtype='int32', padding='post', truncating=truncating_elements, value=0)\n",
|
||||
"\n",
|
||||
" # fitting the model for grid search \n",
|
||||
" grid.fit(x_train, y_train) \n",
|
||||
"\n",
|
||||
" # print how our model looks after hyper-parameter tuning\n",
|
||||
" if (plots==True):\n",
|
||||
" print(grid.best_estimator_) \n",
|
||||
"\n",
|
||||
" # Predict with best SVM\n",
|
||||
" pred = grid.predict(x_test)\n",
|
||||
"\n",
|
||||
" return accuracy_score(pred, y_test), pred, y_test "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "50dac7db",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/tmp/ipykernel_97850/1264473745.py:23: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
|
||||
" x_train = np.array(x_train)\n",
|
||||
"/tmp/ipykernel_97850/1264473745.py:25: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
|
||||
" x_test = np.array(x_test)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"x_train\n",
|
||||
" [array([2, 7, 7, 7, 7, 7, 7, 2, 6, 6, 6, 2, 2, 2])\n",
|
||||
" array([4, 1, 4, 1, 1, 1, 1, 1, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([5, 7, 5, 7, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 3, 3, 3, 3, 3, 6, 6, 5, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([5, 3, 5, 3, 3, 5, 3, 5, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 2, 6, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 3, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 6, 6, 2, 2, 2, 7, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 4, 4, 7, 4, 1, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 5, 7, 7, 5, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 3, 5, 3, 6, 3, 6, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([3, 5, 3, 4, 3, 5, 3, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 3, 4, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 1, 4, 7, 4, 1, 4, 1, 7, 7, 7, 7, 7, 4, 4])\n",
|
||||
" array([5, 7, 7, 1, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 3, 5, 3, 6, 3, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([5, 3, 3, 4, 3, 5, 3, 5, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 6, 6, 6, 2, 2, 7, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 1, 4, 7, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([5, 7, 7, 1, 5, 7, 5, 7, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 3, 5, 2, 2, 6, 3, 3, 6, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([5, 3, 3, 4, 3, 5, 3, 5, 3, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 2, 1, 1, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 1, 4, 7, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 7, 5, 7, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 3, 5, 3, 6, 6, 3, 6, 3, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3])\n",
|
||||
" array([3, 5, 3, 4, 3, 5, 3, 5, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 2, 2, 6, 2, 2, 6, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 2, 2, 2, 2, 6, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 4, 4, 7, 1, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 5, 3, 4, 3, 3, 5, 3, 5, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 6, 3, 5, 3, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([5, 7, 7, 1, 7, 7, 7, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 6, 1, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 1, 4, 7, 1, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([3, 2, 2, 4, 2, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 5, 3, 4, 3, 3, 5, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 6, 3, 5, 3, 3, 3, 6, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 5, 5, 7, 1, 7, 7, 5, 7, 5, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 2, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 4, 4, 7, 1, 4, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 5, 3, 4, 3, 3, 3, 3, 5, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 6, 3, 5, 3, 6, 3, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([5, 7, 7, 1, 5, 7, 7, 7, 5, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 2, 6, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 2, 7, 2, 2, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 4, 4, 7, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 5, 3, 4, 3, 5, 3, 3, 3, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 6, 3, 5, 3, 3, 6, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 1, 7, 7, 7, 5, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 2, 2, 6, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 4, 4, 7, 1, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 4, 2, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 5, 3, 4, 3, 5, 3, 3, 3, 5, 3, 4, 4, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 6, 3, 5, 3, 6, 3, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 1, 7, 7, 5, 5, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 5, 3, 4, 4, 4, 3, 5, 3, 5, 5, 3, 3])\n",
|
||||
" array([6, 3, 5, 5, 5, 5, 5, 5, 6, 3, 6, 3, 3, 3, 3])\n",
|
||||
" array([2, 2, 3, 2, 2, 2, 3, 3, 2, 2, 2, 3, 2, 2, 2, 3])\n",
|
||||
" array([1, 7, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 6, 2, 1, 1, 1, 6, 2, 1, 2, 6, 1, 2, 1])\n",
|
||||
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 7, 7, 1, 1, 1, 1, 1, 1, 5, 7, 5, 7])\n",
|
||||
" array([3, 3, 5, 3, 3, 5, 3, 3, 3, 3, 5])\n",
|
||||
" array([2, 3, 3, 5, 3, 3, 3, 3, 2]) array([2, 3, 2, 2, 2, 2, 2, 3])\n",
|
||||
" array([1, 7, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 2, 6, 7, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 5, 7, 7, 5, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 3, 3, 3, 5, 3, 5, 3, 5, 3])\n",
|
||||
" array([6, 3, 3, 5, 6, 3, 3, 6, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 3, 2, 2, 2, 3, 2, 3, 2]) array([1, 7, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 2, 7, 7, 7, 2, 2, 7, 6, 7, 7, 2])\n",
|
||||
" array([5, 7, 7, 1, 1, 7, 5, 7, 1, 1, 7, 1, 7, 1])\n",
|
||||
" array([3, 5, 3, 3, 5, 3, 5, 3, 5, 3])\n",
|
||||
" array([3, 6, 3, 5, 3, 6, 6, 3, 3, 6, 3, 6, 5, 5, 5, 5, 5, 6, 6, 6])\n",
|
||||
" array([2, 2, 3, 2, 2, 3, 2, 2, 3]) array([1, 7, 7, 1, 7, 7, 7, 7, 1])\n",
|
||||
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 2, 7, 6, 6, 2, 2, 2, 2, 7, 7, 4, 7, 7, 7])\n",
|
||||
" array([7, 7, 5, 1, 7, 5, 7, 5, 7, 7, 1, 1, 1, 1, 1, 7])\n",
|
||||
" array([3, 5, 3, 5, 3, 3, 5, 3, 3])\n",
|
||||
" array([3, 6, 3, 5, 3, 3, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 3, 2, 2, 2, 2, 2, 3]) array([1, 7, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 6, 2, 1, 1, 1, 1, 1, 1, 1, 2, 6, 6, 2, 2, 6, 2])\n",
|
||||
" array([6, 2, 7, 6, 6, 6, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 7, 7, 5, 7, 5, 7, 5])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([4, 1, 4, 4, 4, 4, 4, 7, 4, 1, 4, 1, 4, 1, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 5, 7, 5, 7, 7, 7])\n",
|
||||
" array([6, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 2, 6])\n",
|
||||
" array([4, 5, 3, 3, 3, 3, 3, 3, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 6, 2, 7, 7, 7, 7, 7, 7, 6, 2, 2, 2])\n",
|
||||
" array([6, 3, 3, 5, 5, 5, 5, 5, 5, 3, 6, 3, 3, 6, 3])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([4, 1, 4, 7, 4, 4, 4, 1, 4, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 5, 7, 7, 1, 7, 7, 7, 7, 5, 5])\n",
|
||||
" array([6, 2, 2, 1, 1, 2, 1, 2, 6, 1, 2, 1, 2, 6, 1])\n",
|
||||
" array([5, 4, 3, 4, 5, 3, 4, 3, 4, 3, 4, 5, 3, 4])\n",
|
||||
" array([6, 2, 7, 7, 6, 7, 2, 6, 7, 7, 6, 7])\n",
|
||||
" array([6, 3, 3, 5, 5, 1, 1, 3, 5, 6, 3, 5, 6, 3, 5, 6, 3])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([4, 1, 4, 7, 7, 4, 1, 7, 4, 7, 4, 1, 4, 1, 7, 7])\n",
|
||||
" array([7, 4, 5, 7, 1, 1, 7, 1, 7, 5, 5, 7, 2, 7, 7, 7])\n",
|
||||
" array([6, 2, 2, 1, 2, 1, 2, 6, 6, 2, 6, 2, 1, 1, 1])\n",
|
||||
" array([5, 3, 3, 4, 3, 4, 4, 3, 4, 5, 3, 4, 4, 3])\n",
|
||||
" array([6, 2, 7, 7, 2, 7, 7, 6, 2, 2, 7, 7])\n",
|
||||
" array([1, 3, 1, 6, 3, 3, 5, 3, 5, 5, 6, 3, 5, 3, 5, 5, 3])\n",
|
||||
" array([2, 3, 4, 3, 4, 2, 3, 2, 3, 4, 4, 2, 3, 3, 4, 3, 3, 4, 3, 2, 3, 2])\n",
|
||||
" array([4, 1, 4, 7, 7, 4, 1, 4, 7, 7, 4, 7, 4, 1, 7])\n",
|
||||
" array([7, 5, 7, 1, 1, 7, 5, 1, 7, 7, 7, 1, 1, 1, 7, 5])\n",
|
||||
" array([6, 2, 2, 1, 1, 1, 1, 1, 1, 2, 6, 2, 2, 2])\n",
|
||||
" array([5, 3, 3, 4, 3, 5, 3, 5, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([6, 3, 3, 5, 3, 6, 6, 3, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 3, 2, 4, 2, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([4, 1, 4, 7, 7, 7, 7, 7, 7, 4, 1, 4, 4, 1, 4, 1])\n",
|
||||
" array([7, 5, 7, 1, 1, 7, 5, 1, 7, 7, 5, 1, 1, 1, 7])\n",
|
||||
" array([6, 2, 2, 1, 1, 2, 2, 1, 1, 2, 1, 2, 6, 1])\n",
|
||||
" array([5, 3, 3, 4, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([6, 3, 3, 5, 5, 3, 6, 3, 5, 5, 6, 3, 5, 5, 3, 5, 5])\n",
|
||||
" array([6, 2, 7, 7, 7, 7, 7, 7, 6, 6, 2, 2])\n",
|
||||
" array([5, 3, 4, 3, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([5, 7, 1, 7, 1, 3, 1, 1, 3, 1, 1, 1, 1, 7, 7, 7, 7, 5, 5])\n",
|
||||
" array([1, 7, 4, 4, 1, 1, 7, 7, 7, 7, 7, 4, 4, 4, 1, 4])\n",
|
||||
" array([6, 2, 1, 2, 2, 2, 2, 2, 6, 6])\n",
|
||||
" array([6, 3, 5, 5, 6, 6, 6, 6, 6, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 3, 2, 4, 2, 2, 2, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([3, 5, 4, 3, 3, 3, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 7, 7, 7, 7, 5, 5, 5, 5])\n",
|
||||
" array([1, 4, 7, 4, 7, 7, 7, 7, 7, 1, 1])\n",
|
||||
" array([6, 2, 2, 1, 2, 2, 2, 2, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 3, 3, 5, 3, 6, 6, 3, 3, 3, 6])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 7, 2, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([3, 5, 3, 4, 3, 3, 3, 3, 5, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([7, 5, 7, 1, 7, 5, 7, 7, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||
" array([1, 4, 4, 7, 4, 4, 4, 4, 7, 7, 7, 7, 7, 1, 1, 1])\n",
|
||||
" array([2, 6, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 6, 6])\n",
|
||||
" array([3, 6, 3, 5, 3, 3, 3, 3, 3, 5, 5, 5, 5, 5, 6])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 2, 6, 2, 7, 6, 2, 2, 6, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([5, 3, 3, 4, 3, 3, 5, 5, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([7, 5, 1, 7, 2, 2, 3, 7, 7, 7, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||
" array([1, 4, 4, 7, 1, 4, 7, 7, 7, 7, 7, 4, 1, 1, 4, 4])\n",
|
||||
" array([6, 2, 2, 1, 2, 2, 2, 2, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 5, 3, 3, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([5, 3, 3, 4, 3, 3, 3, 3, 5, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 5, 5, 7, 4, 7, 7, 7])\n",
|
||||
" array([4, 1, 4, 7, 7, 7, 7, 7, 7, 4, 1, 4, 1, 4, 4])\n",
|
||||
" array([6, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 6, 6])\n",
|
||||
" array([6, 3, 3, 5, 3, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 5, 4, 3, 3, 5, 5, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 3, 2, 2, 3, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 3, 3, 2, 3, 3, 2, 3, 3])\n",
|
||||
" array([2, 6, 2, 2, 2, 6, 2, 6, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([1, 4, 4, 1, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 1, 7, 1, 7, 5, 1, 7, 1, 7, 1, 7, 5, 1])\n",
|
||||
" array([2, 2, 2, 2, 7, 7, 7, 7, 7, 7, 2, 2, 7, 2, 6, 2, 6, 6, 2, 2, 6])\n",
|
||||
" array([5, 3, 5, 3, 3, 3, 5, 3, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 2, 4, 3, 3, 2, 4, 2, 4, 4, 2, 2, 3, 4, 4, 2, 3, 4])\n",
|
||||
" array([2, 3, 3, 3, 4, 2, 3, 4, 3, 4, 2, 3, 4])\n",
|
||||
" array([6, 2, 1, 2, 1, 2, 6, 1, 2, 6, 1, 2, 1, 2, 6, 1])\n",
|
||||
" array([4, 1, 7, 4, 1, 1, 7, 1, 4, 7, 4, 7, 4, 7, 1, 4, 7])\n",
|
||||
" array([7, 5, 1, 7, 1, 7, 7, 1, 7, 5, 1, 7, 5, 1, 7, 1])\n",
|
||||
" array([6, 7, 2, 7, 2, 7, 6, 7, 2, 2, 2, 2, 7, 6])\n",
|
||||
" array([3, 5, 4, 3, 4, 4, 3, 5, 4, 3, 5, 4, 3, 4, 3, 4])\n",
|
||||
" array([2, 3, 4, 2, 4, 2, 3, 4, 2, 4, 2, 4, 2, 3, 4])\n",
|
||||
" array([6, 3, 3, 4, 3, 4, 6, 3, 4, 3, 4, 4, 3])\n",
|
||||
" array([2, 6, 1, 2, 1, 2, 6, 1, 2, 1, 2, 6, 1, 2, 1, 2])\n",
|
||||
" array([1, 4, 7, 4, 7, 1, 4, 7, 1, 4, 7, 4, 7, 7, 1, 4])\n",
|
||||
" array([7, 4, 5, 1, 7, 7, 1, 7, 7, 1, 7, 5, 1, 7, 1, 7, 1])\n",
|
||||
" array([6, 7, 2, 7, 6, 7, 2, 7, 2, 7, 6, 7])\n",
|
||||
" array([3, 5, 3, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 3, 2, 2, 2, 2, 3, 4, 4, 4, 4, 2, 4, 4])\n",
|
||||
" array([2, 3, 3, 2, 3, 2, 3, 2, 3, 3, 4, 4, 2, 6, 2, 6, 2, 6, 2, 6])\n",
|
||||
" array([2, 6, 1, 2, 1, 2, 1, 2, 2, 6, 1, 1, 1, 1, 2, 1])\n",
|
||||
" array([4, 1, 7, 1, 7, 1, 7, 1, 4, 7, 1, 4, 7, 1, 7, 7])\n",
|
||||
" array([7, 5, 1, 7, 1, 7, 5, 1, 7, 1, 7, 5, 1, 7, 5, 1])\n",
|
||||
" array([6, 7, 2, 7, 6, 7, 6, 7, 6, 7, 2, 7])\n",
|
||||
" array([3, 5, 3, 3, 3, 5, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 3, 2, 2, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 3, 3, 5, 5, 5, 5, 3, 5, 3, 5, 3, 6, 5, 3, 6, 5])\n",
|
||||
" array([2, 6, 1, 2, 1, 2, 6, 1, 2, 1, 2, 2, 2, 1, 2, 1])\n",
|
||||
" array([1, 4, 7, 4, 7, 1, 4, 7, 4, 7, 1, 4, 7, 4, 7])\n",
|
||||
" array([7, 5, 1, 7, 1, 7, 5, 1, 7, 1, 7, 1, 7, 5])\n",
|
||||
" array([6, 6, 7, 2, 7, 6, 7, 6, 7, 6, 7, 2, 7])\n",
|
||||
" array([2, 3, 2, 6, 3, 5, 5, 3, 6, 3, 6, 3, 6, 3, 5, 5, 5, 5])\n",
|
||||
" array([3, 5, 3, 4, 3, 5, 3, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 2, 1, 1, 1, 1, 1, 1, 2, 2, 6, 2, 6, 2])\n",
|
||||
" array([2, 3, 3, 4, 2, 3, 3, 3, 2, 2, 2, 4, 4, 4, 4, 4, 2, 3, 3, 2, 3, 2])\n",
|
||||
" array([4, 4, 1, 7, 4, 4, 4, 4, 1, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 7, 7, 7, 5, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 3, 3, 5, 3, 3, 6, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([5, 3, 3, 4, 4, 4, 4, 4, 4, 3, 5, 3, 3, 3, 5])\n",
|
||||
" array([6, 2, 2, 1, 2, 6, 2, 2, 6, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([4, 1, 4, 7, 4, 1, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([6, 2, 7, 2, 6, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 7, 7, 7, 5, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 3, 5, 2, 2, 3, 6, 3, 3, 2, 6, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([5, 3, 3, 4, 3, 5, 3, 5, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 2, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 2, 2, 3, 2, 4, 2, 3, 2, 2, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([4, 1, 4, 7, 4, 4, 4, 4, 1, 7, 7, 7, 7, 7, 1])\n",
|
||||
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 1, 7, 1, 7, 5, 1, 7, 1, 7, 1])\n",
|
||||
" array([6, 3, 3, 5, 3, 6, 5, 5, 3, 6, 5, 3, 5, 3, 5])\n",
|
||||
" array([3, 5, 3, 4, 3, 3, 5, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 6, 2, 6, 2, 2, 6, 3, 3, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 4, 4, 2, 3, 4, 2, 3, 3, 4, 2, 3, 4])\n",
|
||||
" array([3, 3, 1, 4, 4, 7, 4, 4, 4, 4, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([6, 2, 7, 2, 2, 7, 7, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 7, 7, 5, 5, 5, 7, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 3, 5, 5, 3, 5, 5, 3, 6, 5, 3, 5, 3, 6, 5])\n",
|
||||
" array([5, 3, 3, 3, 4, 4, 3, 5, 4, 3, 4, 3, 5, 4])\n",
|
||||
" array([6, 2, 2, 1, 2, 2, 6, 2, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 4, 2, 2, 3, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([1, 4, 4, 7, 1, 4, 1, 4, 1, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([6, 2, 7, 7, 2, 7, 2, 7, 6, 7, 6, 7, 7, 6])\n",
|
||||
" array([7, 5, 7, 1, 7, 1, 1, 7, 1, 7, 5, 1, 7, 5, 1])\n",
|
||||
" array([7, 5, 1, 7, 2, 2, 5, 5, 5, 5, 5, 2, 2, 2, 2, 1, 5, 5, 5, 5, 2, 7,\n",
|
||||
" 2, 7, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 4, 4, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4, 2])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 2, 2, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 5, 4, 3, 4, 3, 5, 4, 3, 3, 4, 3, 5, 4, 4, 3, 3, 5, 4])\n",
|
||||
" array([2, 6, 7, 2, 7, 2, 2, 2, 6, 7, 2, 6, 7, 2, 7, 2, 6, 7])\n",
|
||||
" array([2, 3, 5, 3, 3, 5, 2, 3, 5, 2, 3, 3, 5, 3, 4, 5, 2, 3, 5])\n",
|
||||
" array([1, 4, 7, 4, 7, 4, 1, 7, 4, 7, 1, 4, 7, 4, 7])\n",
|
||||
" array([5, 7, 1, 7, 1, 5, 7, 1, 7, 5, 1, 7, 1, 7, 1])\n",
|
||||
" array([2, 3, 1, 2, 1, 2, 1, 2, 1, 3, 2, 3, 1])\n",
|
||||
" array([2, 6, 1, 6, 1, 6, 1, 2, 6, 1, 2, 6, 1, 6, 1])\n",
|
||||
" array([5, 3, 4, 3, 3, 3, 4, 3, 4, 3, 4, 4, 3, 5, 4, 3, 5, 4])\n",
|
||||
" array([6, 7, 2, 7, 2, 7, 6, 7, 6, 7, 2, 7])\n",
|
||||
" array([2, 3, 5, 3, 5, 3, 5, 3, 5, 6, 3, 2, 6, 5, 5, 3])\n",
|
||||
" array([1, 4, 7, 4, 7, 4, 1, 7, 1, 4, 7, 4, 7, 4, 7])\n",
|
||||
" array([5, 7, 1, 7, 1, 5, 7, 1, 7, 5, 1, 7, 1, 1, 5, 7])\n",
|
||||
" array([3, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 3, 4])\n",
|
||||
" array([2, 6, 1, 2, 1, 2, 6, 1, 2, 6, 1, 2, 6, 1, 2, 1])\n",
|
||||
" array([3, 5, 4, 5, 4, 5, 4, 5, 4, 5, 3, 4, 4, 5])\n",
|
||||
" array([7, 2, 6, 7, 7, 2, 2, 6, 7, 2, 7, 2, 7, 6, 7])\n",
|
||||
" array([6, 3, 5, 3, 5, 6, 5, 5, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5])\n",
|
||||
" array([1, 4, 7, 4, 7, 4, 1, 7, 4, 1, 7, 4, 7, 4, 1, 7])\n",
|
||||
" array([5, 7, 1, 7, 1, 7, 1, 5, 7, 1, 5, 7, 1, 7, 1])\n",
|
||||
" array([3, 2, 4, 2, 4, 4, 2, 4, 2, 3, 4, 2, 3, 4, 4, 2, 3])\n",
|
||||
" array([2, 6, 1, 2, 1, 2, 6, 1, 2, 6, 1, 2, 1, 2, 1])\n",
|
||||
" array([5, 3, 5, 5, 4, 3, 4, 3, 4, 3, 5, 4, 3, 5, 4, 4, 3, 5])\n",
|
||||
" array([3, 6, 3, 7, 2, 7, 6, 7, 2, 7, 6, 2, 6, 7])\n",
|
||||
" array([6, 3, 5, 3, 5, 6, 3, 5, 5, 3, 5, 6, 3, 5, 5, 3, 6])\n",
|
||||
" array([1, 4, 7, 4, 7, 4, 1, 7, 4, 1, 7, 4, 7, 4, 7])\n",
|
||||
" array([5, 7, 1, 7, 1, 7, 1, 7, 7, 1, 5, 7, 1, 5, 7, 1])\n",
|
||||
" array([3, 2, 4, 2, 4, 3, 2, 4, 2, 4, 3, 2, 4, 2, 4])\n",
|
||||
" array([2, 6, 1, 2, 1, 2, 6, 1, 2, 6, 1, 2, 1, 2, 6, 1])\n",
|
||||
" array([5, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 5, 4])\n",
|
||||
" array([6, 7, 2, 7, 2, 6, 2, 6, 2, 2, 5, 7, 6, 7, 2, 7, 2, 7])\n",
|
||||
" array([6, 3, 5, 3, 5, 3, 6, 5, 5, 3, 5, 3, 6, 5, 6, 3, 5])\n",
|
||||
" array([1, 4, 7, 4, 7, 4, 7, 4, 7, 4, 7, 1, 4, 7])\n",
|
||||
" array([7, 5, 5, 7, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 5, 5, 3, 3, 3, 3, 3, 3, 3])\n",
|
||||
" array([7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 2, 2])\n",
|
||||
" array([2, 3, 2, 3, 2, 3, 2, 2, 2])\n",
|
||||
" array([1, 1, 1, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 6, 2, 6, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 3, 6, 3, 6, 3, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 3, 5, 3, 3, 3, 3, 3, 4, 4, 4, 4])\n",
|
||||
" array([6, 6, 2, 2, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 3, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([1, 4, 1, 4, 1, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([6, 2, 6, 2, 6, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([6, 3, 6, 3, 6, 3, 3, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 3, 5, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 6, 6, 6, 2, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 3, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([1, 1, 1, 4, 4, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([6, 2, 2, 6, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 3, 6, 3, 3, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([5, 7, 5, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 3, 3, 5, 3, 5, 3, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 6, 6, 6, 2, 2, 7, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 3, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([1, 1, 1, 4, 4, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 2, 2, 2, 2, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 3, 3, 3, 3, 3, 6, 6, 7, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 3, 3, 3, 5, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 6, 6, 6, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([4, 4, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 2, 2, 2, 2, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 5, 5, 5, 5, 5, 3, 3, 3, 3, 3, 3, 6, 6, 6, 6])\n",
|
||||
" array([7, 5, 7, 1, 1, 7, 7, 7, 7, 5, 1, 1, 1, 1])\n",
|
||||
" array([2, 6, 1, 2, 1, 2, 1, 2, 6, 2, 6, 1, 1, 1, 2])\n",
|
||||
" array([6, 2, 7, 7, 2, 6, 2, 2, 7, 7, 7, 7, 2, 6])\n",
|
||||
" array([5, 3, 3, 4, 4, 3, 4, 3, 3, 4, 3, 5, 4, 3, 5, 4])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 6, 3, 5, 3, 3, 6, 3, 6, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([4, 1, 4, 7, 7, 4, 7, 4, 7, 1, 4, 7, 4, 6, 6, 7, 6])\n",
|
||||
" array([7, 5, 7, 1, 1, 7, 5, 1, 7, 1, 7, 5, 1, 7, 1, 5])\n",
|
||||
" array([2, 6, 2, 1, 1, 2, 1, 2, 1, 2, 6, 1, 2, 6, 1])\n",
|
||||
" array([6, 2, 7, 7, 6, 7, 2, 7, 6, 7, 2, 7])\n",
|
||||
" array([5, 3, 3, 4, 4, 3, 5, 3, 3, 5, 3, 4, 4, 4, 4])\n",
|
||||
" array([2, 3, 1, 1, 2, 4, 4, 2, 4, 2, 3, 2, 3, 4, 4, 2, 4])\n",
|
||||
" array([3, 6, 3, 5, 5, 3, 6, 5, 3, 5, 6, 3, 5, 3, 5])\n",
|
||||
" array([1, 4, 4, 7, 7, 1, 4, 5, 5, 7, 4, 7, 4, 7, 1, 4, 7])\n",
|
||||
" array([7, 5, 7, 1, 7, 5, 1, 7, 5, 1, 7, 5, 1, 7, 1, 7, 1])\n",
|
||||
" array([2, 6, 2, 1, 1, 2, 6, 2, 1, 1, 2, 1, 2, 1])\n",
|
||||
" array([6, 2, 7, 7, 6, 2, 7, 6, 7, 2, 7, 6, 7])\n",
|
||||
" array([5, 3, 3, 4, 4, 3, 4, 3, 5, 3, 5, 4, 4, 3, 4])\n",
|
||||
" array([2, 3, 2, 4, 4, 2, 3, 4, 2, 4, 2, 3, 4, 2, 3, 3, 4])\n",
|
||||
" array([6, 3, 3, 5, 5, 3, 3, 6, 3, 3, 5, 5, 5, 5])\n",
|
||||
" array([4, 1, 4, 7, 7, 4, 7, 4, 7, 4, 7, 1, 4, 7])\n",
|
||||
" array([7, 5, 7, 1, 7, 5, 1, 7, 1, 5, 7, 1, 7, 1, 1, 5, 1, 7, 1])\n",
|
||||
" array([2, 6, 2, 1, 1, 2, 1, 2, 1, 2, 6, 1, 2, 1])\n",
|
||||
" array([6, 2, 7, 7, 6, 7, 2, 7, 6, 7, 2, 7])\n",
|
||||
" array([5, 3, 3, 4, 5, 3, 4, 5, 3, 4, 3, 4, 3, 4, 3, 5, 4])\n",
|
||||
" array([2, 3, 2, 4, 4, 2, 3, 4, 2, 4, 2, 3, 4, 2, 4])\n",
|
||||
" array([3, 6, 3, 5, 5, 3, 5, 3, 6, 5, 3, 6, 5, 3, 3, 5])\n",
|
||||
" array([4, 1, 4, 7, 4, 7, 7, 4, 7, 4, 1, 1, 4, 7, 7])\n",
|
||||
" array([7, 5, 7, 1, 1, 7, 1, 7, 5, 5, 1, 7, 1, 7, 5, 1])\n",
|
||||
" array([2, 6, 2, 1, 1, 1, 1, 1, 1, 2, 6, 2, 2, 2, 6])\n",
|
||||
" array([6, 2, 7, 7, 6, 7, 6, 7, 2, 7, 2, 7])\n",
|
||||
" array([3, 5, 3, 4, 4, 5, 3, 4, 3, 4, 5, 3, 4, 3, 4])\n",
|
||||
" array([2, 3, 2, 4, 4, 2, 4, 2, 4, 2, 3, 4, 2, 4])\n",
|
||||
" array([6, 3, 3, 5, 5, 3, 6, 3, 5, 5, 3, 6, 5, 3, 6, 5])\n",
|
||||
" array([1, 4, 4, 7, 7, 4, 1, 7, 4, 4, 4, 7, 7, 7, 1])\n",
|
||||
" array([2, 3, 2, 2, 3, 2, 3, 2, 2])\n",
|
||||
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 1, 7, 7, 7, 5, 7, 5])\n",
|
||||
" array([2, 6, 1, 2, 1, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 3, 3, 3, 3, 3, 5, 5, 3, 4, 4, 4, 4])\n",
|
||||
" array([3, 6, 3, 5, 3, 6, 3, 6, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 4, 4, 7, 7, 7, 7, 7, 7, 1, 4, 4, 4])\n",
|
||||
" array([2, 3, 2, 4, 4, 4, 4, 4, 4, 2, 3, 2, 3, 2, 2])\n",
|
||||
" array([1, 7, 5, 7, 7, 5, 7, 5, 7, 5, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([4, 3, 5, 3, 4, 3, 5, 5, 3, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 3, 3, 5, 3, 3, 3, 6, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 4, 7, 3, 7, 3, 1, 1, 4, 4, 1, 4, 4, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 4, 2, 2, 2, 3, 3, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([7, 5, 7, 1, 7, 5, 7, 7, 5, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 3, 5, 4, 3, 5, 3, 5, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([5, 6, 3, 3, 3, 3, 3, 3, 5, 5, 5, 5, 5, 6, 6, 6])\n",
|
||||
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 1, 4, 7, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 4, 3, 2, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([1, 7, 5, 7, 7, 7, 7, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([4, 3, 3, 5, 3, 3, 5, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 6, 3, 5, 3, 3, 3, 3, 6, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([6, 2, 7, 6, 6, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 4, 1, 7, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 2, 3, 2, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([1, 7, 5, 7, 1, 1, 1, 1, 1, 7, 7, 7, 7, 5])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 3, 3, 4, 3, 5, 3, 5, 5, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 3, 3, 5, 2, 3, 2, 6, 6, 3, 2, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 1, 4, 7, 4, 4, 4, 4, 1, 7, 7, 7, 7, 7, 4])\n",
|
||||
" array([4, 4, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7, 1, 1])\n",
|
||||
" array([2, 6, 2, 2, 2, 6, 2, 6, 2, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([4, 4, 4, 4, 4, 4, 4, 5, 2, 5, 5, 5, 5, 5, 2, 2, 2])\n",
|
||||
" array([6, 3, 3, 6, 3, 3, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 7, 7, 7, 7, 5, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 3, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 6, 6, 6, 2, 2, 4, 4, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 4, 4, 4, 4, 4, 1, 1, 1, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([6, 2, 2, 6, 2, 2, 6, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 3, 4, 5, 5, 3, 3, 3, 5, 5, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 3, 3, 3, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 5, 7, 5, 7, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 2, 2, 3, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 6, 6, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 4, 4, 4, 4, 4, 4, 1, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 6, 2, 2, 6, 2, 2, 6, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 3, 3, 3, 3, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 3, 3, 3, 6, 3, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 7, 5, 7, 7, 5, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 2, 3, 2, 3, 3, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 2, 6, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 4, 4, 1, 4, 4, 4, 1, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 6, 2, 2, 2, 2, 6, 2, 6, 1, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([5, 3, 3, 3, 5, 3, 3, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 3, 3, 3, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 7, 7, 7, 5, 7, 5, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 7, 2, 7, 2, 2, 6, 6, 6, 7, 6, 7, 7, 7])\n",
|
||||
" array([1, 4, 4, 4, 4, 1, 1, 1, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 6, 2, 2, 2, 6, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 5, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 6, 3, 6, 3, 3, 6, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 7, 5, 7, 5, 7, 7, 5, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 2, 3, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 6, 6, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([7, 6, 2, 6, 2, 2, 2, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 4, 1, 7, 4, 4, 4, 4, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 7, 5, 7, 5, 7, 1, 1, 1, 1, 1, 5, 7, 7, 7])\n",
|
||||
" array([4, 3, 3, 5, 3, 4, 4, 4, 3, 3, 3, 5, 5, 4, 4, 4])\n",
|
||||
" array([6, 3, 3, 5, 2, 3, 2, 3, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 2, 6, 1, 1, 1, 1, 1, 1, 2, 6, 2, 6, 2, 2])\n",
|
||||
" array([2, 2, 3, 4, 2, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 7, 6, 2, 6, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 1, 4, 4, 4, 7, 4, 1, 4, 1, 4, 4, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 7, 5, 7, 7, 7, 5, 5, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 4, 5, 3, 3, 5, 4, 4, 4, 4, 4, 3, 3, 3, 5])\n",
|
||||
" array([3, 3, 6, 5, 3, 3, 6, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([1, 2, 2, 6, 2, 2, 6, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 4, 7, 1, 4, 1, 4, 4, 1, 1, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 7, 7, 5, 7, 5, 7, 7, 7, 1, 1, 1, 1, 1, 5])\n",
|
||||
" array([3, 3, 5, 4, 3, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([5, 6, 3, 3, 3, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 1, 2, 6, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([2, 4, 2, 3, 2, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 7, 2, 2, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 4, 1, 1, 7, 4, 1, 4, 1, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([7, 7, 5, 1, 7, 5, 7, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 5, 3, 4, 3, 5, 3, 3, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 3, 6, 5, 3, 6, 3, 6, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 2, 6, 1, 2, 2, 2, 2, 1, 2, 2, 6, 1, 1, 1, 1])\n",
|
||||
" array([2, 4, 2, 3, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([2, 6, 7, 2, 6, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([4, 4, 1, 7, 4, 4, 1, 4, 1, 4, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 7, 7, 5, 7, 5, 7, 5, 7, 5, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([4, 5, 3, 3, 3, 5, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([5, 6, 3, 3, 3, 6, 6, 3, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 1, 2, 6, 2, 2, 6, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([1, 2, 3, 1, 2, 2, 2, 3, 2, 3, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 2, 2, 6, 6, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 1, 1, 1, 7, 7, 7, 7, 7, 7, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 3, 3, 3, 3, 3, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([5, 2, 2, 6, 3, 3, 3, 3, 6, 6, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 6, 2, 2, 2, 2, 2, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([7, 5, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1, 5, 5])\n",
|
||||
" array([2, 3, 2, 2, 3, 2, 2, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 7, 7, 7, 7, 7, 6, 6, 2, 2])\n",
|
||||
" array([1, 1, 1, 7, 7, 7, 7, 7, 7, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([5, 3, 3, 3, 3, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4, 4, 4, 5, 5, 4, 5, 5,\n",
|
||||
" 4])\n",
|
||||
" array([3, 6, 3, 5, 5, 5, 5, 5, 5, 6, 3, 3, 6, 3, 3])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 6, 2, 6, 2])\n",
|
||||
" array([1, 7, 7, 1, 7, 1, 7, 1, 7, 1, 7, 1, 5, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 3, 2, 2, 3, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 2, 2, 6, 6, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 7, 1, 7, 7, 7, 7, 7, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5])\n",
|
||||
" array([3, 6, 3, 3, 3, 6, 3, 6, 3, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([6, 2, 2, 1, 2, 2, 2, 2, 6, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([7, 1, 1, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 5, 5, 5])\n",
|
||||
" array([2, 3, 2, 2, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 6, 2, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 7, 1, 1, 1, 7, 7, 7, 7, 7, 4, 4, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 3, 3, 3, 3, 3, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 6, 3, 3, 6, 3, 3, 6, 3, 6, 3, 6, 5, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 6, 2, 2, 2, 2, 6, 2, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([7, 1, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 5, 5])\n",
|
||||
" array([2, 3, 2, 2, 3, 2, 2, 3, 2, 3, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 3, 6, 2, 2, 2, 6, 6, 6, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([1, 7, 7, 7, 7, 7, 7, 1, 1, 4, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 3, 3, 3, 3, 3, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 5, 4])\n",
|
||||
" array([3, 6, 3, 3, 3, 3, 3, 6, 5, 5, 5, 5, 5, 5, 5])\n",
|
||||
" array([2, 6, 2, 2, 6, 2, 2, 2, 2, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([7, 7, 1, 7, 7, 7, 7, 1, 1, 1, 1, 1, 5, 5])\n",
|
||||
" array([5, 3, 6, 3, 5, 3, 5, 3, 5, 6, 3, 5, 3, 6, 5])\n",
|
||||
" array([2, 1, 2, 6, 1, 2, 1, 2, 6, 1, 2, 6, 2, 2, 2, 6, 6, 6, 2, 6, 1, 2,\n",
|
||||
" 1])\n",
|
||||
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 7, 5, 7, 5, 7, 7, 5])\n",
|
||||
" array([2, 2, 3, 3, 3, 2, 2, 3, 4, 4, 4, 4, 4, 2, 2])\n",
|
||||
" array([3, 3, 5, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 5, 5])\n",
|
||||
" array([1, 4, 7, 5, 5, 7, 4, 1, 4, 4, 1, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 2, 6, 7, 7, 7, 7, 7, 7, 6, 6, 6, 2, 2, 6])\n",
|
||||
" array([2, 3, 6, 2, 6, 5, 5, 5, 5, 5, 5, 6, 3, 3, 6, 3, 6, 6])\n",
|
||||
" array([2, 6, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 6, 6])\n",
|
||||
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1, 7, 5, 7, 5, 7, 7])\n",
|
||||
" array([2, 2, 3, 4, 4, 4, 4, 4, 4, 2, 2, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3])\n",
|
||||
" array([5, 3, 3, 4, 4, 4, 4, 4, 4, 3, 5, 3, 5, 3, 3])\n",
|
||||
" array([1, 4, 4, 7, 7, 4, 1, 7, 4, 7, 4, 1, 7, 4, 7])\n",
|
||||
" array([6, 2, 7, 7, 7, 7, 7, 7, 6, 2, 6, 2])\n",
|
||||
" array([2, 3, 2, 5, 5, 5, 5, 5, 5, 2, 3, 2, 3, 2, 2])\n",
|
||||
" array([2, 2, 6, 1, 2, 2, 2, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 7, 7, 7, 7, 5])\n",
|
||||
" array([2, 3, 4, 2, 4, 4, 4, 4, 4, 2, 3, 2, 2, 2])\n",
|
||||
" array([3, 5, 3, 4, 4, 4, 4, 4, 4, 3, 5, 3, 5, 3, 5, 3])\n",
|
||||
" array([4, 1, 4, 7, 7, 7, 7, 7, 7, 4, 4, 1, 4, 1, 1, 1, 4])\n",
|
||||
" array([6, 2, 7, 7, 7, 7, 7, 7, 2, 2, 6, 6])\n",
|
||||
" array([6, 3, 3, 3, 5, 3, 5, 5, 5, 5, 5, 3, 6, 3, 3, 3])\n",
|
||||
" array([6, 2, 2, 1, 1, 1, 1, 1, 1, 2, 6, 2, 2, 2])\n",
|
||||
" array([7, 5, 7, 1, 7, 7, 1, 1, 1, 1, 1, 7, 7, 7, 5, 7, 5])\n",
|
||||
" array([2, 3, 2, 4, 4, 4, 4, 4, 4, 2, 3, 2, 2, 2, 2])\n",
|
||||
" array([5, 3, 3, 4, 4, 4, 4, 4, 4, 3, 5, 3, 5, 3, 5, 3])\n",
|
||||
" array([4, 1, 4, 7, 7, 7, 7, 7, 7, 4, 1, 1, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 7, 7, 7, 7, 7, 2, 6, 6, 2])\n",
|
||||
" array([6, 3, 3, 5, 5, 5, 5, 5, 5, 3, 3, 3, 3, 6])\n",
|
||||
" array([6, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 6, 6])\n",
|
||||
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 7, 5, 7, 5, 7, 7, 5])\n",
|
||||
" array([2, 3, 2, 4, 4, 4, 4, 4, 4, 2, 3, 2, 3, 2, 2])\n",
|
||||
" array([5, 3, 3, 4, 4, 4, 4, 4, 4, 3, 5, 3, 5, 3, 5, 3])\n",
|
||||
" array([4, 1, 4, 7, 7, 7, 7, 7, 7, 4, 1, 4, 1, 1, 4, 4])\n",
|
||||
" array([6, 2, 7, 7, 7, 7, 7, 7, 2, 2, 6, 6])]\n",
|
||||
"y_train\n",
|
||||
" [5. 3. 6. 1. 0. 4. 2. 5. 3. 6. 1. 0. 4. 2. 5. 3. 6. 1. 0. 4. 2. 5. 3. 6.\n",
|
||||
" 1. 0. 4. 2. 5. 3. 6. 1. 0. 4. 2. 4. 5. 3. 2. 0. 1. 6. 4. 5. 3. 2. 0. 1.\n",
|
||||
" 6. 4. 5. 3. 2. 0. 1. 6. 4. 5. 3. 2. 0. 1. 6. 4. 5. 3. 2. 0. 1. 6. 0. 1.\n",
|
||||
" 2. 3. 4. 5. 6. 0. 1. 2. 3. 4. 5. 6. 0. 1. 2. 3. 4. 5. 6. 0. 1. 2. 3. 4.\n",
|
||||
" 5. 6. 0. 1. 2. 3. 4. 5. 6. 2. 3. 6. 4. 0. 5. 1. 2. 3. 6. 4. 0. 5. 1. 2.\n",
|
||||
" 3. 6. 4. 0. 5. 1. 2. 3. 6. 4. 0. 5. 1. 2. 3. 6. 4. 0. 5. 1. 5. 0. 6. 3.\n",
|
||||
" 4. 1. 2. 5. 0. 6. 3. 4. 1. 2. 5. 0. 6. 3. 4. 1. 2. 5. 0. 6. 3. 4. 1. 2.\n",
|
||||
" 5. 0. 6. 3. 4. 1. 2. 0. 2. 1. 4. 3. 6. 5. 0. 2. 1. 4. 3. 6. 5. 0. 2. 1.\n",
|
||||
" 4. 3. 6. 5. 0. 2. 1. 4. 3. 6. 5. 0. 2. 1. 4. 3. 6. 5. 1. 0. 4. 2. 3. 5.\n",
|
||||
" 6. 1. 0. 4. 2. 3. 5. 6. 1. 0. 4. 2. 3. 5. 6. 1. 0. 4. 2. 3. 5. 6. 1. 0.\n",
|
||||
" 4. 2. 3. 5. 6. 6. 2. 4. 0. 5. 1. 3. 6. 2. 4. 0. 5. 1. 3. 6. 2. 4. 0. 5.\n",
|
||||
" 1. 3. 6. 2. 4. 0. 5. 1. 3. 6. 2. 4. 0. 5. 1. 3. 6. 0. 5. 2. 3. 4. 1. 6.\n",
|
||||
" 0. 5. 2. 3. 4. 1. 6. 0. 5. 2. 3. 4. 1. 6. 0. 5. 2. 3. 4. 1. 6. 0. 5. 2.\n",
|
||||
" 3. 4. 1. 6. 4. 5. 0. 2. 1. 3. 6. 4. 5. 0. 2. 1. 3. 6. 4. 5. 0. 2. 1. 3.\n",
|
||||
" 6. 4. 5. 0. 2. 1. 3. 6. 4. 5. 0. 2. 1. 3. 2. 6. 4. 0. 1. 5. 3. 2. 6. 4.\n",
|
||||
" 0. 1. 5. 3. 2. 6. 4. 0. 1. 5. 3. 2. 6. 4. 0. 1. 5. 3. 2. 6. 4. 0. 1. 5.\n",
|
||||
" 3. 3. 4. 0. 1. 6. 2. 5. 3. 4. 0. 1. 6. 2. 5. 3. 4. 0. 1. 6. 2. 5. 3. 4.\n",
|
||||
" 0. 1. 6. 2. 5. 3. 4. 0. 1. 6. 2. 5. 2. 5. 3. 6. 0. 1. 4. 2. 5. 3. 6. 0.\n",
|
||||
" 1. 4. 2. 5. 3. 6. 0. 1. 4. 2. 5. 3. 6. 0. 1. 4. 2. 5. 3. 6. 0. 1. 4. 2.\n",
|
||||
" 5. 3. 0. 1. 4. 6. 2. 5. 3. 0. 1. 4. 6. 2. 5. 3. 0. 1. 4. 6. 2. 5. 3. 0.\n",
|
||||
" 1. 4. 6. 2. 5. 3. 0. 1. 4. 6. 1. 4. 6. 2. 0. 3. 5. 1. 4. 6. 2. 0. 3. 5.\n",
|
||||
" 1. 4. 6. 2. 0. 3. 5. 1. 4. 6. 2. 0. 3. 5. 1. 4. 6. 2. 0. 3. 5.]\n",
|
||||
"x_test\n",
|
||||
" [array([4, 1, 1, 4, 4, 7, 7, 7, 7, 7, 7, 4, 1, 4, 4, 4])\n",
|
||||
" array([2, 3, 2, 4, 4, 4, 4, 4, 4, 2, 2, 2, 3, 2, 3])\n",
|
||||
" array([5, 3, 3, 4, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5])\n",
|
||||
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 2, 2, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 3, 6, 5, 3, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 7, 5, 1, 7, 7, 5, 7, 7, 1, 1, 1, 1, 1, 5, 5])\n",
|
||||
" array([1, 4, 7, 4, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 2, 4, 3, 2, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 5, 3, 4, 4, 3, 3, 3, 3, 5, 5, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 2, 2, 7, 7, 7, 6, 6, 7, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 2, 6, 1, 1, 2, 2, 2, 2, 6, 6, 6, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 3, 5, 2, 2, 3, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([1, 5, 7, 7, 7, 5, 7, 5, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([7, 4, 4, 7, 4, 4, 4, 4, 7, 1, 7, 7, 7, 7, 1, 1, 7])\n",
|
||||
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||
" array([3, 5, 3, 4, 3, 3, 3, 3, 5, 5, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 7, 7, 7, 7, 7, 2, 2, 6, 6])\n",
|
||||
" array([2, 2, 1, 6, 2, 2, 2, 2, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 3, 5, 3, 3, 3, 3, 6, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 7, 7, 7, 7, 7, 5, 1, 1, 1, 1, 1, 1])\n",
|
||||
" array([4, 1, 4, 4, 4, 4, 4, 1, 1, 1, 7, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 4, 2, 2, 2, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([5, 3, 3, 4, 5, 3, 3, 3, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 6, 2, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 6, 6])\n",
|
||||
" array([3, 6, 3, 5, 3, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([5, 7, 7, 1, 7, 7, 5, 7, 5, 7, 1, 1, 1, 1, 1])\n",
|
||||
" array([1, 4, 4, 7, 7, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7])\n",
|
||||
" array([2, 3, 2, 4, 4, 4, 4, 4, 4, 2, 3, 2, 2, 2])\n",
|
||||
" array([4, 3, 5, 3, 3, 5, 3, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||
" array([6, 2, 7, 6, 6, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||
" array([2, 2, 6, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||
" array([3, 6, 3, 5, 3, 6, 3, 6, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||
" array([7, 5, 7, 1, 7, 5, 1, 1, 1, 1, 1, 7, 7, 7, 5])]\n",
|
||||
"y_test\n",
|
||||
" [3. 2. 0. 5. 4. 1. 6. 3. 2. 0. 5. 4. 1. 6. 3. 2. 0. 5. 4. 1. 6. 3. 2. 0.\n",
|
||||
" 5. 4. 1. 6. 3. 2. 0. 5. 4. 1. 6.]\n",
|
||||
"> \u001b[0;32m/tmp/ipykernel_97850/1264473745.py\u001b[0m(32)\u001b[0;36mcreateTrainTest\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 30 \u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'y_test\\n'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 31 \u001b[0;31m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m---> 32 \u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mshapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 33 \u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 34 \u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ipdb> x_train.shape\n",
|
||||
"(525,)\n",
|
||||
"ipdb> x_test.shape\n",
|
||||
"(35,)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"accuracies_full = dict()\n",
|
||||
"accuracies_small = dict()\n",
|
||||
"accuracies_last = dict()\n",
|
||||
"\n",
|
||||
"for current_PID in sorted(data.PID.unique()):\n",
|
||||
" accuracies_full[current_PID], pred_label, test_label = runSVMS(createTrainTest([current_PID], Task_IDs, StartIndexOffset, EndIndexOffset, shapes=True))\n",
|
||||
" # Only the first 5\n",
|
||||
" accuracies_small[current_PID], pred_label, test_label = runSVMS(createTrainTest([current_PID], Task_IDs, StartIndexOffset, EndIndexOffset, shapes=True), 5)\n",
|
||||
" # Only the last 5\n",
|
||||
" accuracies_last[current_PID], pred_label, test_label = runSVMS(createTrainTest([current_PID], Task_IDs, StartIndexOffset, EndIndexOffset, shapes=True), 5, last_elements=True)\n",
|
||||
" #pdb.set_trace()\n",
|
||||
"print(accuracies_full)\n",
|
||||
"print(accuracies_small)\n",
|
||||
"print(accuracies_last)\n",
|
||||
"print(\"mean full\", np.array(list(accuracies_full.values())).mean())\n",
|
||||
"print(\"mean small\", np.array(list(accuracies_small.values())).mean())\n",
|
||||
"print(\"mean last\", np.array(list(accuracies_last.values())).mean())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fdd4c915",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"len(g.groups.keys())\n",
|
||||
"g.groups.keys()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
|
@ -0,0 +1,687 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "3aed8aec",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2021-09-27 15:31:30.518074: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import datetime\n",
|
||||
"import time,pdb\n",
|
||||
"import json\n",
|
||||
"import random\n",
|
||||
"import statistics\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"import tensorflow as tf\n",
|
||||
"from tensorflow import keras\n",
|
||||
"from sklearn import svm\n",
|
||||
"from sklearn.model_selection import GridSearchCV \n",
|
||||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||
"from sklearn.metrics import accuracy_score\n",
|
||||
"from tensorflow.keras.layers import *\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from tensorflow.keras.models import Sequential\n",
|
||||
"from tensorflow.keras.optimizers import *\n",
|
||||
"from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, Callback\n",
|
||||
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
||||
"from sklearn.neighbors import KNeighborsClassifier\n",
|
||||
"from sklearn.metrics import mean_squared_error\n",
|
||||
"from sklearn.metrics import accuracy_score\n",
|
||||
"import tqdm\n",
|
||||
"from multiprocessing import Pool\n",
|
||||
"import os\n",
|
||||
"from tensorflow.compat.v1.keras.layers import Bidirectional, CuDNNLSTM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "817f7108",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"available PIDs [ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16.]\n",
|
||||
"available TaskIDs [0. 1. 2. 3. 4. 5. 6.]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>Timestamp</th>\n",
|
||||
" <th>Event</th>\n",
|
||||
" <th>TaskID</th>\n",
|
||||
" <th>Part</th>\n",
|
||||
" <th>PID</th>\n",
|
||||
" <th>TextRule</th>\n",
|
||||
" <th>Rule</th>\n",
|
||||
" <th>Type</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>1.575388e+12</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>{'Title': ['1', 'Indent', 'and', 'Italic'], 'S...</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>Cmd</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>1.575388e+12</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>{'Title': ['1', 'Indent', 'and', 'Italic'], 'S...</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>Toolbar</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>1.575388e+12</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>{'Title': ['1', 'Indent', 'and', 'Italic'], 'S...</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>Cmd</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>1.575388e+12</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>{'Title': ['1', 'Indent', 'and', 'Italic'], 'S...</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>Cmd</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>1.575388e+12</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>{'Title': ['1', 'Indent', 'and', 'Italic'], 'S...</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>Cmd</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8376</th>\n",
|
||||
" <td>1.603898e+12</td>\n",
|
||||
" <td>7</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>16.0</td>\n",
|
||||
" <td>{'Title': ['Size', 'Big'], 'Subtitle': ['Bold'...</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>Toolbar</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8377</th>\n",
|
||||
" <td>1.603898e+12</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>16.0</td>\n",
|
||||
" <td>{'Title': ['Size', 'Big'], 'Subtitle': ['Bold'...</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>Cmd</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8378</th>\n",
|
||||
" <td>1.603898e+12</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>16.0</td>\n",
|
||||
" <td>{'Title': ['Size', 'Big'], 'Subtitle': ['Bold'...</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>Cmd</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8379</th>\n",
|
||||
" <td>1.603898e+12</td>\n",
|
||||
" <td>6</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>16.0</td>\n",
|
||||
" <td>{'Title': ['Size', 'Big'], 'Subtitle': ['Bold'...</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>Toolbar</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8380</th>\n",
|
||||
" <td>1.603898e+12</td>\n",
|
||||
" <td>6</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>16.0</td>\n",
|
||||
" <td>{'Title': ['Size', 'Big'], 'Subtitle': ['Bold'...</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>Toolbar</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>8381 rows × 8 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" Timestamp Event TaskID Part PID \\\n",
|
||||
"0 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||
"1 1.575388e+12 1 0.0 1.0 1.0 \n",
|
||||
"2 1.575388e+12 1 0.0 1.0 1.0 \n",
|
||||
"3 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||
"4 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||
"... ... ... ... ... ... \n",
|
||||
"8376 1.603898e+12 7 6.0 5.0 16.0 \n",
|
||||
"8377 1.603898e+12 2 6.0 5.0 16.0 \n",
|
||||
"8378 1.603898e+12 2 6.0 5.0 16.0 \n",
|
||||
"8379 1.603898e+12 6 6.0 5.0 16.0 \n",
|
||||
"8380 1.603898e+12 6 6.0 5.0 16.0 \n",
|
||||
"\n",
|
||||
" TextRule Rule Type \n",
|
||||
"0 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"1 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||
"2 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"3 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"4 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"... ... ... ... \n",
|
||||
"8376 {'Title': ['Size', 'Big'], 'Subtitle': ['Bold'... 5.0 Toolbar \n",
|
||||
"8377 {'Title': ['Size', 'Big'], 'Subtitle': ['Bold'... 5.0 Cmd \n",
|
||||
"8378 {'Title': ['Size', 'Big'], 'Subtitle': ['Bold'... 5.0 Cmd \n",
|
||||
"8379 {'Title': ['Size', 'Big'], 'Subtitle': ['Bold'... 5.0 Toolbar \n",
|
||||
"8380 {'Title': ['Size', 'Big'], 'Subtitle': ['Bold'... 5.0 Toolbar \n",
|
||||
"\n",
|
||||
"[8381 rows x 8 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"study_data_path = \"../IntentData/\"\n",
|
||||
"data = pd.read_pickle(study_data_path + \"/Preprocessing_data/clean_data.pkl\")\n",
|
||||
"#val_data = pd.read_pickle(study_data_path + \"/Preprocessing_data/clean_data_condition2.pkl\")\n",
|
||||
"\n",
|
||||
"print(\"available PIDs\", data.PID.unique())\n",
|
||||
"\n",
|
||||
"print(\"available TaskIDs\", data.TaskID.unique())\n",
|
||||
"\n",
|
||||
"data.Event.unique()\n",
|
||||
"data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "ab778228",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"count 560.000000\n",
|
||||
"mean 14.966071\n",
|
||||
"std 2.195440\n",
|
||||
"min 8.000000\n",
|
||||
"25% 14.000000\n",
|
||||
"50% 15.000000\n",
|
||||
"75% 16.000000\n",
|
||||
"max 28.000000\n",
|
||||
"Name: Event, dtype: float64"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data.groupby([\"PID\", \"Part\", \"TaskID\"])[\"Event\"].count().describe()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "32550f71",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"Task_IDs = list(range(0,7))\n",
|
||||
"\n",
|
||||
"# grouping by part is needed to have one ruleset for the whole part\n",
|
||||
"g = data.groupby([\"PID\", \"Part\", \"TaskID\"])\n",
|
||||
"df_all = []"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "f6fecc2f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def createTrainTestalaSven(test_IDs, task_IDs, window_size, stride, shapes=False, val_IDs=None):\n",
|
||||
" if not isinstance(test_IDs, list):\n",
|
||||
" raise ValueError(\"Test_IDs are not a list\")\n",
|
||||
" if not isinstance(task_IDs, list):\n",
|
||||
" raise ValueError(\"Task_IDs are not a list\")\n",
|
||||
" # Fill data arrays\n",
|
||||
" all_elem = []\n",
|
||||
" for current in g.groups.keys():\n",
|
||||
" c = g.get_group(current)\n",
|
||||
" if (c.TaskID.isin(task_IDs).all()):\n",
|
||||
" \n",
|
||||
" new_data = c.Event.values\n",
|
||||
" stepper = 0\n",
|
||||
" while stepper <= (len(new_data)-window_size-1):\n",
|
||||
" tmp = new_data[stepper:stepper + window_size]\n",
|
||||
" x = tmp[:-1]\n",
|
||||
" y = tmp[-1]\n",
|
||||
" stepper += stride\n",
|
||||
" \n",
|
||||
" if (c.PID.isin(test_IDs).all()):\n",
|
||||
" all_elem.append([\"Test\", x, y])\n",
|
||||
" elif (c.PID.isin(val_IDs).all()):\n",
|
||||
" all_elem.append([\"Val\", x, y])\n",
|
||||
" else:\n",
|
||||
" all_elem.append([\"Train\", x, y])\n",
|
||||
" df_tmp = pd.DataFrame(all_elem, columns =[\"Split\", \"X\", \"Y\"])\n",
|
||||
" turbo = []\n",
|
||||
" for s in df_tmp.Split.unique():\n",
|
||||
" dfX = df_tmp[df_tmp.Split == s]\n",
|
||||
" max_amount = dfX.groupby([\"Y\"]).count().max().X\n",
|
||||
" for y in dfX.Y.unique():\n",
|
||||
" df_turbotmp = dfX[dfX.Y == y]\n",
|
||||
" turbo.append(df_turbotmp)\n",
|
||||
" turbo.append(df_turbotmp.sample(max_amount-len(df_turbotmp), replace=True))\n",
|
||||
" # if len(df_turbotmp) < max_amount:\n",
|
||||
"\n",
|
||||
" df_tmp = pd.concat(turbo)\n",
|
||||
" x_train, y_train = df_tmp[df_tmp.Split == \"Train\"].X.values, df_tmp[df_tmp.Split == \"Train\"].Y.values\n",
|
||||
" x_test, y_test = df_tmp[df_tmp.Split == \"Test\"].X.values, df_tmp[df_tmp.Split == \"Test\"].Y.values\n",
|
||||
" x_val, y_val = df_tmp[df_tmp.Split == \"Val\"].X.values, df_tmp[df_tmp.Split == \"Val\"].Y.values\n",
|
||||
" \n",
|
||||
" x_train = np.expand_dims(np.stack(x_train), axis=2)\n",
|
||||
" y_train = np.array(y_train)\n",
|
||||
" x_test = np.expand_dims(np.stack(x_test), axis=2)\n",
|
||||
" y_test = np.array(y_test)\n",
|
||||
" if len(x_val) > 0:\n",
|
||||
" x_val = np.expand_dims(np.stack(x_val), axis=2)\n",
|
||||
" y_val = np.array(y_val)\n",
|
||||
" return(x_train, y_train, x_test, y_test, x_val, y_val)\n",
|
||||
" return(x_train, y_train, x_test, y_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "b8f92bc1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def createTrainTest(test_IDs, task_IDs, window_size, stride, shapes=False, val_IDs=None):\n",
|
||||
" if not isinstance(test_IDs, list):\n",
|
||||
" raise ValueError(\"Test_IDs are not a list\")\n",
|
||||
" if not isinstance(task_IDs, list):\n",
|
||||
" raise ValueError(\"Task_IDs are not a list\")\n",
|
||||
" # Fill data arrays\n",
|
||||
" y_train = []\n",
|
||||
" x_train = []\n",
|
||||
" y_test = []\n",
|
||||
" x_test = []\n",
|
||||
" x_val = []\n",
|
||||
" y_val = []\n",
|
||||
" \n",
|
||||
" for current in g.groups.keys():\n",
|
||||
" c = g.get_group(current)\n",
|
||||
" if (c.TaskID.isin(task_IDs).all()):\n",
|
||||
" \n",
|
||||
" new_data = c.Event.values\n",
|
||||
" stepper = 0\n",
|
||||
" while stepper <= (len(new_data)-window_size-1):\n",
|
||||
" tmp = new_data[stepper:stepper + window_size]\n",
|
||||
" pdb.set_trace()\n",
|
||||
" x = tmp[:-1]\n",
|
||||
" y = tmp[-1]\n",
|
||||
" stepper += stride\n",
|
||||
" if (c.PID.isin(test_IDs).all()):\n",
|
||||
" if y == 6:\n",
|
||||
" y_test.append(y)\n",
|
||||
" x_test.append(x)\n",
|
||||
" y_test.append(y)\n",
|
||||
" x_test.append(x)\n",
|
||||
" elif (c.PID.isin(val_IDs).all()):\n",
|
||||
" if y == 6:\n",
|
||||
" y_val.append(y)\n",
|
||||
" x_val.append(x)\n",
|
||||
" y_val.append(y)\n",
|
||||
" x_val.append(x)\n",
|
||||
" else:\n",
|
||||
" if y == 6:\n",
|
||||
" y_train.append(y)\n",
|
||||
" x_train.append(x)\n",
|
||||
" y_train.append(y)\n",
|
||||
" x_train.append(x)\n",
|
||||
" x_train = np.array(x_train)\n",
|
||||
" y_train = np.array(y_train)\n",
|
||||
" x_test = np.array(x_test)\n",
|
||||
" y_test = np.array(y_test)\n",
|
||||
" x_val = np.array(x_val)\n",
|
||||
" y_val = np.array(y_val)\n",
|
||||
" pdb.set_trace()\n",
|
||||
" if (shapes):\n",
|
||||
" print(x_train.shape)\n",
|
||||
" print(y_train.shape)\n",
|
||||
" print(x_test.shape)\n",
|
||||
" print(y_test.shape)\n",
|
||||
" print(x_val.shape)\n",
|
||||
" print(y_val.shape)\n",
|
||||
" print(np.unique(y_test))\n",
|
||||
" print(np.unique(y_train))\n",
|
||||
" if len(x_val) > 0:\n",
|
||||
" return(x_train, y_train, x_test, y_test, x_val, y_val)\n",
|
||||
" return (x_train, y_train, x_test, y_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "e56fbc58",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"maxlen = 1000\n",
|
||||
"lens = []\n",
|
||||
"for current in g.groups.keys():\n",
|
||||
" c = g.get_group(current)\n",
|
||||
" lens.append(len(c.Event.values))\n",
|
||||
" maxlen = min(maxlen, len(c.Event.values))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "c02cbdae",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Number of trees in random forest\n",
|
||||
"n_estimators = np.arange(5,100, 5)\n",
|
||||
"# Number of features to consider at every split\n",
|
||||
"max_features = ['sqrt']\n",
|
||||
"# Maximum number of levels in tree\n",
|
||||
"max_depth = np.arange(5,100, 5)\n",
|
||||
"# Minimum number of samples required to split a node\n",
|
||||
"min_samples_split = np.arange(2,10, 1)\n",
|
||||
"# Minimum number of samples required at each leaf node\n",
|
||||
"min_samples_leaf = np.arange(2,5, 1)\n",
|
||||
"# Method of selecting samples for training each tree\n",
|
||||
"bootstrap = [True, False]\n",
|
||||
"\n",
|
||||
"# Create the random grid\n",
|
||||
"param_grid = {'n_estimators': n_estimators,\n",
|
||||
" 'max_features': max_features,\n",
|
||||
" 'max_depth': max_depth,\n",
|
||||
" 'min_samples_split': min_samples_split,\n",
|
||||
" 'min_samples_leaf': min_samples_leaf,\n",
|
||||
" 'bootstrap': bootstrap}\n",
|
||||
"\n",
|
||||
"grid = GridSearchCV(RandomForestClassifier(), param_grid, refit = True, verbose = 0, return_train_score=True) "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "c2bcfe7f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def doTrainSlideWindowNoPad(currentPid):\n",
|
||||
" print(f\"doTrain: {currentPid}\")\n",
|
||||
" dfs = []\n",
|
||||
" for window_size in range(8, 15): \n",
|
||||
" (x_train, y_train, x_test, y_test) = createTrainTest([currentPid], Task_IDs, window_size, 1, False, [200])\n",
|
||||
" print(f\"doTrain: created TrainTestsplit\")\n",
|
||||
"\n",
|
||||
" # print(\"window_size\", 5, \"PID\", currentPid, \"samples\", x_train.shape[0], \"generated_samples\", \"samples\", x_train_window.shape[0])\n",
|
||||
"\n",
|
||||
" grid.fit(x_train, y_train)\n",
|
||||
" print(\"fitted\")\n",
|
||||
" # y_pred = grid.predict(x_test)\n",
|
||||
"\n",
|
||||
" df_params = pd.DataFrame(grid.cv_results_[\"params\"])\n",
|
||||
" df_params[\"Mean_test\"] = grid.cv_results_[\"mean_test_score\"]\n",
|
||||
" df_params[\"Mean_train\"] = grid.cv_results_[\"mean_train_score\"]\n",
|
||||
" df_params[\"STD_test\"] = grid.cv_results_[\"std_test_score\"]\n",
|
||||
" df_params[\"STD_train\"] = grid.cv_results_[\"std_train_score\"]\n",
|
||||
" df_params['Window_Size'] = window_size\n",
|
||||
" df_params['PID'] = currentPid\n",
|
||||
" # df_params[\"Accuracy\"] = accuracy_score(y_pred, y_test)\n",
|
||||
" dfs.append(df_params)\n",
|
||||
"\n",
|
||||
" return pd.concat(dfs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "9e3d86f1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"doTrain: 1\n",
|
||||
"> \u001b[0;32m/tmp/ipykernel_90176/2602038955.py\u001b[0m(23)\u001b[0;36mcreateTrainTest\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 21 \u001b[0;31m \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstepper\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 22 \u001b[0;31m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m---> 23 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 24 \u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 25 \u001b[0;31m \u001b[0mstepper\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mstride\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> tmp\n",
|
||||
"array([4, 1, 1, 4, 4, 7, 7, 7])\n",
|
||||
"ipdb> new_data\n",
|
||||
"array([4, 1, 1, 4, 4, 7, 7, 7, 7, 7, 7, 4, 1, 4, 4, 4])\n",
|
||||
"ipdb> current\n",
|
||||
"(1.0, 1.0, 0.0)\n",
|
||||
"ipdb> print(c)\n",
|
||||
" Timestamp Event TaskID Part PID \\\n",
|
||||
"0 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||
"1 1.575388e+12 1 0.0 1.0 1.0 \n",
|
||||
"2 1.575388e+12 1 0.0 1.0 1.0 \n",
|
||||
"3 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||
"4 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||
"5 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||
"6 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||
"7 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||
"8 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||
"9 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||
"10 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||
"11 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||
"12 1.575388e+12 1 0.0 1.0 1.0 \n",
|
||||
"13 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||
"14 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||
"15 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||
"\n",
|
||||
" TextRule Rule Type \n",
|
||||
"0 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"1 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||
"2 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"3 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"4 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"5 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||
"6 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||
"7 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||
"8 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||
"9 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||
"10 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||
"11 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"12 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||
"13 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"14 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"15 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||
"ipdb> print(c.TextRule)\n",
|
||||
"0 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"1 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"2 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"3 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"4 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"5 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"6 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"7 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"8 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"9 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"10 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"11 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"12 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"13 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"14 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"15 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||
"Name: TextRule, dtype: object\n",
|
||||
"ipdb> print(c.Event)\n",
|
||||
"0 4\n",
|
||||
"1 1\n",
|
||||
"2 1\n",
|
||||
"3 4\n",
|
||||
"4 4\n",
|
||||
"5 7\n",
|
||||
"6 7\n",
|
||||
"7 7\n",
|
||||
"8 7\n",
|
||||
"9 7\n",
|
||||
"10 7\n",
|
||||
"11 4\n",
|
||||
"12 1\n",
|
||||
"13 4\n",
|
||||
"14 4\n",
|
||||
"15 4\n",
|
||||
"Name: Event, dtype: int64\n",
|
||||
"ipdb> val\n",
|
||||
"*** NameError: name 'val' is not defined\n",
|
||||
"ipdb> val_IDs\n",
|
||||
"[200]\n",
|
||||
"--KeyboardInterrupt--\n",
|
||||
"\n",
|
||||
"KeyboardInterrupt: Interrupted by user\n",
|
||||
"> \u001b[0;32m/tmp/ipykernel_90176/2602038955.py\u001b[0m(22)\u001b[0;36mcreateTrainTest\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 20 \u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mstepper\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 21 \u001b[0;31m \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstepper\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m---> 22 \u001b[0;31m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 23 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 24 \u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"--KeyboardInterrupt--\n",
|
||||
"\n",
|
||||
"KeyboardInterrupt: Interrupted by user\n",
|
||||
"> \u001b[0;32m/tmp/ipykernel_90176/2602038955.py\u001b[0m(23)\u001b[0;36mcreateTrainTest\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 21 \u001b[0;31m \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstepper\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 22 \u001b[0;31m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m---> 23 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 24 \u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 25 \u001b[0;31m \u001b[0mstepper\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mstride\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> q\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "BdbQuit",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mBdbQuit\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m/tmp/ipykernel_90176/1128965594.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdoTrainSlideWindowNoPad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||
"\u001b[0;32m/tmp/ipykernel_90176/2629087375.py\u001b[0m in \u001b[0;36mdoTrainSlideWindowNoPad\u001b[0;34m(currentPid)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mdfs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mwindow_size\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m15\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreateTrainTest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcurrentPid\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTask_IDs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"doTrain: created TrainTestsplit\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/tmp/ipykernel_90176/2602038955.py\u001b[0m in \u001b[0;36mcreateTrainTest\u001b[0;34m(test_IDs, task_IDs, window_size, stride, shapes, val_IDs)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstepper\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mstepper\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mstride\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/tmp/ipykernel_90176/2602038955.py\u001b[0m in \u001b[0;36mcreateTrainTest\u001b[0;34m(test_IDs, task_IDs, window_size, stride, shapes, val_IDs)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstepper\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mstepper\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mstride\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniconda3/envs/intentPrediction/lib/python3.9/bdb.py\u001b[0m in \u001b[0;36mtrace_dispatch\u001b[0;34m(self, frame, event, arg)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;31m# None\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'line'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_line\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'call'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniconda3/envs/intentPrediction/lib/python3.9/bdb.py\u001b[0m in \u001b[0;36mdispatch_line\u001b[0;34m(self, frame)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstop_here\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbreak_here\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_line\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 113\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquitting\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mBdbQuit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 114\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_dispatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mBdbQuit\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"doTrainSlideWindowNoPad(1)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_0.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_0.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_1.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_1.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_2.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_2.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_3.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_3.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_4.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_4.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_5.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_5.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_6.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_6.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_0.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_0.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_1.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_1.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_2.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_2.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_3.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_3.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_4.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_4.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_5.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_5.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_6.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_6.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_0.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_0.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_1.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_1.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_2.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_2.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_3.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_3.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_4.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_4.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_5.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_5.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_6.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_6.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_0.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_0.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_1.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_1.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_2.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_2.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_3.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_3.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_4.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_4.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_5.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_5.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_6.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_6.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_0.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_0.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_1.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_1.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_2.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_2.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_3.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_3.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_4.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_4.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_5.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_5.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_6.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_6.pkl
Normal file
Binary file not shown.
167
keyboard_and_mouse/networks.py
Normal file
167
keyboard_and_mouse/networks.py
Normal file
|
@ -0,0 +1,167 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class fc_block(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, norm, activation_fn):
|
||||
super(fc_block, self).__init__()
|
||||
block = nn.Sequential()
|
||||
block.add_module('linear', nn.Linear(in_channels, out_channels))
|
||||
if norm:
|
||||
block.add_module('batchnorm', nn.BatchNorm1d(out_channels))
|
||||
if activation_fn is not None:
|
||||
block.add_module('activation', activation_fn())
|
||||
|
||||
self.block = block
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
class ActionDemoEncoder(nn.Module):
|
||||
def __init__(self, args, pooling):
|
||||
super(ActionDemoEncoder, self).__init__()
|
||||
hidden_size = args.demo_hidden
|
||||
self.hidden_size = hidden_size
|
||||
self.bs = args.batch_size
|
||||
|
||||
len_action_predicates = 35 # max_action_len
|
||||
self.action_embed = nn.Embedding(len_action_predicates, hidden_size)
|
||||
|
||||
feat2hidden = nn.Sequential()
|
||||
feat2hidden.add_module(
|
||||
'fc_block1', fc_block(hidden_size, hidden_size, False, nn.ReLU))
|
||||
self.feat2hidden = feat2hidden
|
||||
|
||||
self.pooling = pooling
|
||||
|
||||
if 'lstm' in self.pooling:
|
||||
self.lstm = nn.LSTM(hidden_size, hidden_size)
|
||||
|
||||
def forward(self, batch_data):
|
||||
batch_data = batch_data.view(-1,1)
|
||||
stacked_demo_feat = self.action_embed(batch_data)
|
||||
stacked_demo_feat = self.feat2hidden(stacked_demo_feat)
|
||||
batch_demo_feat = []
|
||||
start = 0
|
||||
|
||||
for length in range(0,batch_data.shape[0]):
|
||||
if length == 0:
|
||||
feat = stacked_demo_feat[0:1, :]
|
||||
else:
|
||||
feat = stacked_demo_feat[(length-1):length, :]
|
||||
if len(feat.size()) == 3:
|
||||
feat = feat.unsqueeze(0)
|
||||
|
||||
if self.pooling == 'max':
|
||||
feat = torch.max(feat, 0)[0]
|
||||
elif self.pooling == 'avg':
|
||||
feat = torch.mean(feat, 0)
|
||||
elif self.pooling == 'lstmavg':
|
||||
lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
|
||||
lstm_out = lstm_out.view(len(feat), -1)
|
||||
feat = torch.mean(lstm_out, 0)
|
||||
elif self.pooling == 'lstmlast':
|
||||
lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
|
||||
lstm_out = lstm_out.view(len(feat), -1)
|
||||
feat = lstm_out[-1]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
batch_demo_feat.append(feat)
|
||||
|
||||
demo_emb = torch.stack(batch_demo_feat, 0)
|
||||
demo_emb = demo_emb.view(self.bs, 35, -1)
|
||||
return demo_emb
|
||||
|
||||
class PredicateClassifier(nn.Module):
|
||||
|
||||
def __init__(self, args,):
|
||||
super(PredicateClassifier, self).__init__()
|
||||
hidden_size = args.demo_hidden
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
classifier = nn.Sequential()
|
||||
classifier.add_module('fc_block1', fc_block(hidden_size*35, hidden_size, False, nn.Tanh))
|
||||
classifier.add_module('dropout', nn.Dropout(args.dropout))
|
||||
classifier.add_module('fc_block2', fc_block(hidden_size, 7, False, None)) # 7 is all possible actions
|
||||
|
||||
self.classifier = classifier
|
||||
|
||||
def forward(self, input_emb):
|
||||
input_emb = input_emb.view(-1, self.hidden_size*35)
|
||||
return self.classifier(input_emb)
|
||||
|
||||
|
||||
class ActionDemo2Predicate(nn.Module):
|
||||
def __init__(self, args, **kwargs):
|
||||
super(ActionDemo2Predicate, self).__init__()
|
||||
|
||||
print('------------------------------------------------------------------------------------------')
|
||||
print('ActionDemo2Predicate')
|
||||
print('------------------------------------------------------------------------------------------')
|
||||
|
||||
model_type = args.model_type
|
||||
print('model_type', model_type)
|
||||
|
||||
if model_type.lower() == 'max':
|
||||
demo_encoder = ActionDemoEncoder(args, 'max')
|
||||
elif model_type.lower() == 'avg':
|
||||
demo_encoder = ActionDemoEncoder(args, 'avg')
|
||||
elif model_type.lower() == 'lstmavg':
|
||||
demo_encoder = ActionDemoEncoder(args, 'lstmavg')
|
||||
elif model_type.lower() == 'bilstmavg':
|
||||
demo_encoder = ActionDemoEncoder(args, 'bilstmavg')
|
||||
elif model_type.lower() == 'lstmlast':
|
||||
demo_encoder = ActionDemoEncoder(args, 'lstmlast')
|
||||
elif model_type.lower() == 'bilstmlast':
|
||||
demo_encoder = ActionDemoEncoder(args, 'bilstmlast')
|
||||
else:
|
||||
raise ValueError
|
||||
demo_encoder = torch.nn.DataParallel(demo_encoder)
|
||||
|
||||
predicate_decoder = PredicateClassifier(args)
|
||||
|
||||
# for quick save and load
|
||||
all_modules = nn.Sequential()
|
||||
all_modules.add_module('demo_encoder', demo_encoder)
|
||||
all_modules.add_module('predicate_decoder', predicate_decoder)
|
||||
|
||||
self.demo_encoder = demo_encoder
|
||||
self.predicate_decoder = predicate_decoder
|
||||
self.all_modules = all_modules
|
||||
self.to_cuda_fn = None
|
||||
|
||||
def set_to_cuda_fn(self, to_cuda_fn):
|
||||
self.to_cuda_fn = to_cuda_fn
|
||||
|
||||
def forward(self, data, **kwargs):
|
||||
'''
|
||||
Note: The order of the `data` won't change in this function
|
||||
'''
|
||||
if self.to_cuda_fn:
|
||||
data = self.to_cuda_fn(data)
|
||||
|
||||
batch_demo_emb = self.demo_encoder(data)
|
||||
pred = self.predicate_decoder(batch_demo_emb)
|
||||
return pred
|
||||
|
||||
def write_summary(self, writer, info, postfix):
|
||||
model_name = 'Demo2Predicate-{}/'.format(postfix)
|
||||
for k in self.summary_keys:
|
||||
if k in info.keys():
|
||||
writer.scalar_summary(model_name + k, info[k])
|
||||
|
||||
def save(self, path, verbose=False):
|
||||
if verbose:
|
||||
print(colored('[*] Save model at {}'.format(path), 'magenta'))
|
||||
torch.save(self.all_modules.state_dict(), path)
|
||||
|
||||
def load(self, path, verbose=False):
|
||||
if verbose:
|
||||
print(colored('[*] Load model at {}'.format(path), 'magenta'))
|
||||
self.all_modules.load_state_dict(
|
||||
torch.load(
|
||||
path,
|
||||
map_location=lambda storage,
|
||||
loc: storage))
|
||||
|
207
keyboard_and_mouse/process_data.py
Normal file
207
keyboard_and_mouse/process_data.py
Normal file
|
@ -0,0 +1,207 @@
|
|||
import pickle
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
def view_clean_data():
|
||||
with open('dataset/clean_data.pkl', 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
print(type(data), len(data))
|
||||
print(data.keys())
|
||||
print('length of data:',len(data))
|
||||
print('event', data['Event'], 'length of event', len(data['Event']))
|
||||
print('rule', data['Rule'], 'length of event', len(data['Rule']))
|
||||
|
||||
print('rule unique', data.Rule.unique())
|
||||
print('task id unique', data.TaskID.unique())
|
||||
print('pid unique', data.PID.unique())
|
||||
print('event unique', data.Event.unique())
|
||||
|
||||
def split_org_data():
|
||||
# generate train, test data by split user, aggregate action sequence for next action prediction
|
||||
# orignial action seq: a = [a_0 ... a_n]
|
||||
# new action seq: for a: a0 = [a_0], a1 = [a_0, a_1] ...
|
||||
|
||||
# split original data into train and test based on user
|
||||
with open('dataset/clean_data.pkl', 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
print('original data keys', data.keys())
|
||||
print('len of original data', len(data))
|
||||
print('rule unique', data.Rule.unique())
|
||||
print('event unique', data.Event.unique())
|
||||
|
||||
data_train = data[data['PID']<=11]
|
||||
data_test = data[data['PID']>11]
|
||||
print('train set len', len(data_train))
|
||||
print('test set len', len(data_test))
|
||||
|
||||
# split data by task
|
||||
train_data_intent = []
|
||||
test_data_intent = []
|
||||
for i in range(7):
|
||||
# 7 different rules, each as an intention
|
||||
train_data_intent.append(data_train[data_train['Rule']==i])
|
||||
test_data_intent.append(data_test[data_test['Rule']==i])
|
||||
|
||||
# generate train set
|
||||
max_len = 0 # max len is 35
|
||||
for i in range(7): # 7 tasks/rules
|
||||
train_data = [] # [task]
|
||||
train_label = []
|
||||
for u in range(1,12):
|
||||
user_data = train_data_intent[i][train_data_intent[i]['PID']==u]
|
||||
for j in range(1,6): # 5 parts == 5 trials
|
||||
part_data = user_data[user_data['Part']==j]
|
||||
for l in range(1,len(part_data['Event'])-1):
|
||||
print(part_data['Event'][:l].tolist())
|
||||
train_data.append(part_data['Event'][:l].tolist())
|
||||
train_label.append(part_data['Event'].iat[l+1])
|
||||
if len(part_data['Event'])>max_len:
|
||||
max_len = len(part_data['Event'])
|
||||
|
||||
for k in range(len(train_data)):
|
||||
while len(train_data[k])<35:
|
||||
train_data[k].append(0) # padding with 0
|
||||
|
||||
print('x_len', len(train_data), type(train_data[0]), len(train_data[0]))
|
||||
print('y_len', len(train_label), type(train_label[0]))
|
||||
|
||||
Path("dataset/strategy_dataset").mkdir(parents=True, exist_ok=True)
|
||||
with open('dataset/strategy_dataset/train_label_'+str(i)+'.pkl', 'wb') as f:
|
||||
pickle.dump(train_label, f)
|
||||
with open('dataset/strategy_dataset/train_data_'+str(i)+'.pkl', 'wb') as f:
|
||||
pickle.dump(train_data, f)
|
||||
print('max_len', max_len)
|
||||
|
||||
# generate test set
|
||||
max_len = 0 # max len is 33, total max is 35
|
||||
for i in range(7): # 7 tasks/rules
|
||||
test_data = [] # [task][user]
|
||||
test_label = []
|
||||
test_action_id = []
|
||||
for u in range(12,17):
|
||||
user_data = test_data_intent[i][test_data_intent[i]['PID']==u]
|
||||
test_data_user = []
|
||||
test_label_user = []
|
||||
test_action_id_user = []
|
||||
for j in range(1,6): # 5 parts == 5 trials
|
||||
part_data = user_data[user_data['Part']==j]
|
||||
|
||||
for l in range(1,len(part_data['Event'])-1):
|
||||
test_data_user.append(part_data['Event'][:l].tolist())
|
||||
test_label_user.append(part_data['Event'].iat[l+1])
|
||||
test_action_id_user.append(part_data['Part'].iat[l])
|
||||
|
||||
if len(part_data['Event'])>max_len:
|
||||
max_len = len(part_data['Event'])
|
||||
|
||||
for k in range(len(test_data_user)):
|
||||
while len(test_data_user[k])<35:
|
||||
test_data_user[k].append(0) # padding with 0
|
||||
|
||||
test_data.append(test_data_user)
|
||||
test_label.append(test_label_user)
|
||||
test_action_id.append(test_action_id_user)
|
||||
|
||||
|
||||
print('x_len', len(test_data), type(test_data[0]), len(test_data[0]))
|
||||
print('y_len', len(test_label), type(test_label[0]))
|
||||
with open('dataset/strategy_dataset/test_label_'+str(i)+'.pkl', 'wb') as f:
|
||||
pickle.dump(test_label, f)
|
||||
with open('dataset/strategy_dataset/test_data_'+str(i)+'.pkl', 'wb') as f:
|
||||
pickle.dump(test_data, f)
|
||||
with open('dataset/strategy_dataset/test_action_id_'+str(i)+'.pkl', 'wb') as f:
|
||||
pickle.dump(test_action_id, f)
|
||||
print('max_len', max_len)
|
||||
|
||||
def calc_gt_prob():
|
||||
# train set unique label
|
||||
for i in range(7):
|
||||
with open('dataset/strategy_dataset/train_label_'+str(i)+'.pkl', 'rb') as f:
|
||||
y = pickle.load(f)
|
||||
y = np.array(y)
|
||||
print('task ', i)
|
||||
print('unique train label', np.unique(y))
|
||||
|
||||
def plot_gt_dist():
|
||||
|
||||
full_data = []
|
||||
for i in range(7):
|
||||
with open('dataset/strategy_dataset/' + 'test' + '_label_' + str(i) + '.pkl', 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
#print(len(data))
|
||||
full_data.append(data)
|
||||
|
||||
fig, axs = plt.subplots(7)
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(16)
|
||||
act_name = ["Italic", "Bold", "Underline", "Indent", "Align", "FontSize", "FontFamily"]
|
||||
x = np.arange(7)
|
||||
|
||||
width = 0.1
|
||||
for i in range(7):
|
||||
for u in range(len(data)): # 5 users
|
||||
values, counts = np.unique(full_data[i][u], return_counts=True)
|
||||
counts_vis = [0]*7
|
||||
for j in range(len(values)):
|
||||
counts_vis[values[j]-1] = counts[j]
|
||||
print('task', i, 'actions', values, 'num', counts)
|
||||
|
||||
axs[i].set_title('Intention '+str(i))
|
||||
axs[i].set_xlabel('action id')
|
||||
axs[i].set_ylabel('num of actions')
|
||||
axs[i].bar(x+u*width, counts_vis, width=0.1, label='user '+str(u))
|
||||
axs[i].set_xticks(np.arange(len(x)))
|
||||
axs[i].set_xticklabels(act_name)
|
||||
axs[i].set_ylim([0,80])
|
||||
|
||||
axs[0].legend(loc='upper right', ncol=1)
|
||||
plt.tight_layout()
|
||||
plt.savefig('dataset/'+'test'+'_gt_dist.png')
|
||||
plt.show()
|
||||
|
||||
def plot_act():
|
||||
full_data = []
|
||||
for i in range(7):
|
||||
with open('dataset/strategy_dataset/' + 'test' + '_label_' + str(i) + '.pkl', 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
full_data.append(data)
|
||||
|
||||
width = 0.1
|
||||
for i in range(7):
|
||||
fig, axs = plt.subplots(5)
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(16)
|
||||
act_name = ["Italic", "Bold", "Underline", "Indent", "Align", "FontSize", "FontFamily"]
|
||||
for u in range(len(full_data[i])): # 5 users
|
||||
x = np.arange(len(full_data[i][u]))
|
||||
axs[u].set_xlabel('action id')
|
||||
axs[u].set_ylabel('num of actions')
|
||||
axs[u].plot(x, full_data[i][u])
|
||||
|
||||
axs[0].legend(loc='upper right', ncol=1)
|
||||
plt.tight_layout()
|
||||
#plt.savefig('test'+'_act.png')
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("func", help="select what function to run. view_clean_data, split_org_data, calc_gt_prob, plot_gt_dist, plot_act", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.func == 'view_clean_data':
|
||||
view_clean_data() # view original keyboad and mouse interaction dataset
|
||||
if args.func == 'split_org_data':
|
||||
split_org_data() # split the original keyboad and mouse interaction dataset. User 1-11 for training, rest for testing
|
||||
if args.func == 'calc_gt_prob':
|
||||
calc_gt_prob() # see unique label in train set
|
||||
if args.func == 'plot_gt_dist':
|
||||
plot_gt_dist() # plot the label distribution of test set
|
||||
if args.func == 'plot_act':
|
||||
plot_act() # plot the label of test set
|
||||
|
||||
|
||||
|
56
keyboard_and_mouse/sampler_single_act.py
Normal file
56
keyboard_and_mouse/sampler_single_act.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import csv
|
||||
import pandas
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
def sample_single_act(pred_path, save_path, j):
|
||||
data = pandas.read_csv(pred_path).values
|
||||
total_data = []
|
||||
|
||||
for u in range(1,6):
|
||||
act_data = data[data[:,1]==u]
|
||||
final_save_path = save_path + "/rate_" + str(j) + "_act_" + str(int(u)) + "_pred.csv"
|
||||
head = []
|
||||
for r in range(7):
|
||||
head.append('act'+str(r+1))
|
||||
head.append('task_name')
|
||||
head.append('gt')
|
||||
head.insert(0,'action_id')
|
||||
pandas.DataFrame(act_data[:,1:]).to_csv(final_save_path, header=head)
|
||||
|
||||
|
||||
def main():
|
||||
# parsing parameters
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||
parser.add_argument('--hidden_size', type=int, default=64, help='hidden_size')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
task = np.arange(7)
|
||||
user_num = 5
|
||||
bs = args.batch_size
|
||||
lr = args.lr # 1e-4
|
||||
hs = args.hidden_size #128
|
||||
model_type = args.model_type #'lstmlast'
|
||||
|
||||
rate = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
|
||||
|
||||
for i in task:
|
||||
for j in rate:
|
||||
for l in range(user_num):
|
||||
pred_path = "prediction/task" + str(i) + "/" + model_type + "_bs_" + str(bs) + "_lr_" + str(lr) + "_hidden_size_" + str(hs) + "/user" + str(l) + "_rate_" + str(j) + "_pred.csv"
|
||||
if j == 100:
|
||||
pred_path = "prediction/task" + str(i) + "/" + model_type + "_bs_" + str(bs) + "_lr_" + str(lr) + "_hidden_size_" + str(hs) + "/user" + str(l) + "_pred.csv"
|
||||
save_path = "prediction/single_act/task" + str(i) + "/" + model_type + "_bs_" + str(bs) + "_lr_" + str(lr) + "_hidden_size_" + str(hs) + "/user" + str(l)
|
||||
Path(save_path).mkdir(parents=True, exist_ok=True)
|
||||
data = sample_single_act(pred_path, save_path, j)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# split the prediction by action sequence id, from 10% to 90%
|
||||
main()
|
||||
|
5
keyboard_and_mouse/sampler_single_act.sh
Normal file
5
keyboard_and_mouse/sampler_single_act.sh
Normal file
|
@ -0,0 +1,5 @@
|
|||
python3 sampler_single_act.py \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--model_type lstmlast \
|
||||
--hidden_size 128
|
68
keyboard_and_mouse/sampler_user.py
Normal file
68
keyboard_and_mouse/sampler_user.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import csv
|
||||
import pandas
|
||||
import argparse
|
||||
|
||||
def sample_predciton(path, rate):
|
||||
data = pandas.read_csv(path).values
|
||||
task_list = [0, 1, 2, 3, 4, 5, 6]
|
||||
|
||||
start = 0
|
||||
stop = 0
|
||||
num_unique = np.unique(data[:,1])
|
||||
|
||||
samples = []
|
||||
for j in task_list:
|
||||
for i in num_unique:
|
||||
inx = np.where((data[:,1] == i) & (data[:,-2] == j))
|
||||
samples.append(data[inx])
|
||||
|
||||
for i in range(len(samples)):
|
||||
n = int(len(samples[i])*(100-rate)/100)
|
||||
if n == 0:
|
||||
n = 1
|
||||
samples[i] = samples[i][:-n]
|
||||
if len(samples[i]) == 0:
|
||||
print('len of after sampling',len(samples[i]))
|
||||
|
||||
return np.vstack(samples)
|
||||
|
||||
def main():
|
||||
# parsing parameters
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
task = np.arange(7)
|
||||
user_num = 5
|
||||
bs = args.batch_size
|
||||
lr = args.lr # 1e-4
|
||||
hs = args.hidden_size #128
|
||||
model_type = args.model_type #'lstmlast'
|
||||
|
||||
rate = [10, 20, 30, 40, 50, 60, 70, 80, 90]
|
||||
|
||||
for i in task:
|
||||
for j in rate:
|
||||
for l in range(user_num):
|
||||
pred_path = "prediction/task" + str(i) + "/" + model_type + "_bs_" + str(bs) + "_lr_" + str(lr) + "_hidden_size_" + str(hs) + "/user" + str(l) + "_pred.csv"
|
||||
save_path = "prediction/task" + str(i) + "/" + model_type + "_bs_" + str(bs) + "_lr_" + str(lr) + "_hidden_size_" + str(hs) + "/user" + str(l) + "_rate_" + str(j) + "_pred.csv"
|
||||
data = sample_predciton(pred_path, j)
|
||||
|
||||
head = []
|
||||
for r in range(7):
|
||||
head.append('act'+str(r+1))
|
||||
head.append('task_name')
|
||||
head.append('gt')
|
||||
head.insert(0,'action_id')
|
||||
pandas.DataFrame(data[:,1:]).to_csv(save_path, header=head)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# split the prediction by length, from 10% to 90%
|
||||
main()
|
||||
|
5
keyboard_and_mouse/sampler_user.sh
Normal file
5
keyboard_and_mouse/sampler_user.sh
Normal file
|
@ -0,0 +1,5 @@
|
|||
python3 sampler_user.py \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--model_type lstmlast \
|
||||
--hidden_size 128
|
88
keyboard_and_mouse/stan/plot_user.py
Normal file
88
keyboard_and_mouse/stan/plot_user.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||
|
||||
args = parser.parse_args()
|
||||
plot_type = 'bar' # line bar
|
||||
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||
|
||||
# read data
|
||||
user_data_list = []
|
||||
for i in range(args.user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
user_data_list.append(model_data_list)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(14)
|
||||
fig.set_figwidth(25)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(user_data_list)):
|
||||
y.append(user_data_list[i][j+ax*7][0])
|
||||
y_low.append(user_data_list[i][j+ax*7][2])
|
||||
y_high.append(user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
if plot_type == 'line':
|
||||
axs[ax].plot(range(args.user), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(args.user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
if plot_type == 'bar':
|
||||
width = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
|
||||
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||
axs[ax].bar(np.arange(args.user)+width[i], y_total[i], width=0.08, yerr=yerror, label=legend[i], color=color[i])
|
||||
axs[ax].tick_params(axis='x', which='both', length=0)
|
||||
axs[ax].set_ylabel('prob', fontsize=22)
|
||||
for k,x in enumerate(np.arange(args.user)+width[i]):
|
||||
y = y_total[i][k] + yerror[1][k]
|
||||
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-18,3), fontsize=16)
|
||||
|
||||
axs[0].text(-0.1, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 22) # all: -0.3,0.5 3rows: -0.5,0.5
|
||||
axs[ax].text(-0.1, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 22, color=color[ax])
|
||||
axs[ax].tick_params(axis='both', which='major', labelsize=16)
|
||||
|
||||
plt.xticks(range(args.user),('1', '2', '3', '4', '5'))
|
||||
plt.xlabel('user', fontsize= 22)
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
plt.ylim([0, 1])
|
||||
Path("figure").mkdir(parents=True, exist_ok=True)
|
||||
if plot_type == 'line':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_line.png", bbox_inches='tight')
|
||||
if plot_type == 'bar':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_bar.png", bbox_inches='tight')
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
8
keyboard_and_mouse/stan/plot_user.sh
Normal file
8
keyboard_and_mouse/stan/plot_user.sh
Normal file
|
@ -0,0 +1,8 @@
|
|||
python3 plot_user.py \
|
||||
--model_type lstmlast_ \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--hidden_size 128 \
|
||||
--N 1 \
|
||||
--user 5
|
||||
|
99
keyboard_and_mouse/stan/plot_user_all_individual.py
Normal file
99
keyboard_and_mouse/stan/plot_user_all_individual.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||
|
||||
args = parser.parse_args()
|
||||
plot_type = 'bar' # line bar
|
||||
act_series = 5
|
||||
|
||||
# read data
|
||||
plot_list = []
|
||||
for act in range(1,act_series+1):
|
||||
user_data_list = []
|
||||
for i in range(args.user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) + "_rate__100" + "_act_" + str(act) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
print(model_data_list.shape)
|
||||
user_data_list.append(model_data_list)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(14)
|
||||
fig.set_figwidth(25)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(user_data_list)):
|
||||
y.append(user_data_list[i][j+ax*7][0])
|
||||
y_low.append(user_data_list[i][j+ax*7][2])
|
||||
y_high.append(user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
if plot_type == 'line':
|
||||
axs[ax].plot(range(args.user), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(args.user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
if plot_type == 'bar':
|
||||
width = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
|
||||
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||
axs[ax].bar(np.arange(args.user)+width[i], y_total[i], width=0.08, yerr=[np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])], label=legend[i], color=color[i])
|
||||
axs[ax].tick_params(axis='x', which='both', length=0)
|
||||
axs[ax].set_ylabel('prob', fontsize=36) # was 22,
|
||||
for k,x in enumerate(np.arange(args.user)+width[i]):
|
||||
y = y_total[i][k] + yerror[1][k]
|
||||
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-18,3), fontsize=16) #was 16
|
||||
|
||||
axs[0].text(-0.17, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 46) # was -0.1 0.9 25
|
||||
axs[ax].text(-0.17, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 46, color=color[ax]) # was 25
|
||||
axs[ax].tick_params(axis='both', which='major', labelsize=42) # was 18
|
||||
for tick in axs[ax].xaxis.get_major_ticks():
|
||||
tick.set_pad(20)
|
||||
|
||||
plt.xticks(range(args.user),('1', '2', '3', '4', '5'))
|
||||
plt.xlabel('user', fontsize= 42) # was 22
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.ylim([0, 1])
|
||||
plt.tight_layout()
|
||||
if plot_type == 'line':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_line_all_individual.png", bbox_inches='tight')
|
||||
if plot_type == 'bar':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_bar_all_individual.png", bbox_inches='tight')
|
||||
|
||||
if plot_type == 'line':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_line_all_individual.eps", bbox_inches='tight', format='eps')
|
||||
if plot_type == 'bar':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_bar_all_individual.eps", bbox_inches='tight', format='eps')
|
||||
#plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
8
keyboard_and_mouse/stan/plot_user_all_individual.sh
Normal file
8
keyboard_and_mouse/stan/plot_user_all_individual.sh
Normal file
|
@ -0,0 +1,8 @@
|
|||
python3 plot_user_all_individual.py \
|
||||
--model_type lstmlast_ \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--hidden_size 128 \
|
||||
--N 1 \
|
||||
--user 5
|
||||
|
93
keyboard_and_mouse/stan/plot_user_all_individual_chiw.py
Normal file
93
keyboard_and_mouse/stan/plot_user_all_individual_chiw.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
model_type = "lstmlast_"
|
||||
batch_size = 8
|
||||
lr = 1e-4
|
||||
hidden_size = 128
|
||||
N = 1
|
||||
user = 5
|
||||
plot_type = 'bar' # line bar
|
||||
act_series = 5
|
||||
|
||||
# read data
|
||||
plot_list = []
|
||||
for act in range(1,act_series+1):
|
||||
user_data_list = []
|
||||
for i in range(user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(N) + "/" + model_type + "bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_result_user" + str(i) + "_rate__100" + "_act_" + str(act) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
print(model_data_list.shape)
|
||||
user_data_list.append(model_data_list)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(14)
|
||||
fig.set_figwidth(25)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(user_data_list)):
|
||||
y.append(user_data_list[i][j+ax*7][0])
|
||||
y_low.append(user_data_list[i][j+ax*7][2])
|
||||
y_high.append(user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print(legend[ax])
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
if plot_type == 'line':
|
||||
axs[ax].plot(range(user), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
if plot_type == 'bar':
|
||||
width = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
|
||||
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||
axs[ax].bar(np.arange(user)+width[i], y_total[i], width=0.08, yerr=[np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])], label=legend[i], color=color[i])
|
||||
axs[ax].tick_params(axis='x', which='both', length=0)
|
||||
axs[ax].set_ylabel('prob', fontsize=26) # was 22,
|
||||
axs[ax].set_title(legend[ax], color=color[ax], fontsize=26)
|
||||
for k,x in enumerate(np.arange(user)+width[i]):
|
||||
y = y_total[i][k] + yerror[1][k]
|
||||
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-18,3), fontsize=16) #was 16
|
||||
|
||||
#axs[0].text(-0.17, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 46) # was -0.1 0.9 25
|
||||
#axs[ax].text(-0.17, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 46, color=color[ax]) # was 25
|
||||
axs[ax].tick_params(axis='both', which='major', labelsize=18) # was 18
|
||||
for tick in axs[ax].xaxis.get_major_ticks():
|
||||
tick.set_pad(20)
|
||||
|
||||
plt.xticks(range(user),('1', '2', '3', '4', '5'))
|
||||
plt.xlabel('user', fontsize= 26) # was 22
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.ylim([0, 1.2])
|
||||
plt.tight_layout()
|
||||
if plot_type == 'line':
|
||||
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_line_all_individual_chiw.png", bbox_inches='tight')
|
||||
if plot_type == 'bar':
|
||||
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_bar_all_individual_chiw.png", bbox_inches='tight')
|
||||
#plt.show()
|
||||
if plot_type == 'line':
|
||||
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_line_all_individual_chiw.eps", bbox_inches='tight', format='eps')
|
||||
if plot_type == 'bar':
|
||||
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_bar_all_individual_chiw.eps", bbox_inches='tight', format='eps')
|
||||
|
||||
|
||||
|
88
keyboard_and_mouse/stan/plot_user_length_10_steps.py
Normal file
88
keyboard_and_mouse/stan/plot_user_length_10_steps.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||
|
||||
args = parser.parse_args()
|
||||
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||
|
||||
rate_user_data_list = []
|
||||
for r in range(0,101,10): # rate = range(0,101,10)
|
||||
# read data
|
||||
user_data_list = []
|
||||
for i in range(args.user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) + "_rate__" + str(r) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
if i == 4:
|
||||
print(model_data_list.shape, model_data_list)
|
||||
user_data_list.append(model_data_list)
|
||||
model_data_list_total = np.stack(user_data_list)
|
||||
print(model_data_list_total.shape)
|
||||
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||
print(mean_user_data.shape)
|
||||
rate_user_data_list.append(mean_user_data)
|
||||
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(10) # all sample rate: 10; 3 row: 8
|
||||
fig.set_figwidth(20)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(rate_user_data_list)):
|
||||
y.append(rate_user_data_list[i][j+ax*7][0])
|
||||
y_low.append(rate_user_data_list[i][j+ax*7][2])
|
||||
y_high.append(rate_user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
axs[ax].set_xticks(range(0,101,10))
|
||||
axs[ax].set_ylabel('prob', fontsize=20)
|
||||
|
||||
|
||||
axs[0].text(-0.125, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 20)
|
||||
axs[ax].text(-0.125, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 20, color=color[ax])
|
||||
axs[ax].tick_params(axis='both', which='major', labelsize=16)
|
||||
|
||||
|
||||
plt.xlabel('Percentage of occurred actions in one action sequence', fontsize= 20)
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.xlim([0, 101])
|
||||
plt.ylim([0, 1])
|
||||
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_rate_full.png", bbox_inches='tight')
|
||||
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
8
keyboard_and_mouse/stan/plot_user_length_10_steps.sh
Normal file
8
keyboard_and_mouse/stan/plot_user_length_10_steps.sh
Normal file
|
@ -0,0 +1,8 @@
|
|||
python3 plot_user_length_10_steps.py \
|
||||
--model_type lstmlast_ \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--hidden_size 128 \
|
||||
--N 1 \
|
||||
--user 5
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||
|
||||
args = parser.parse_args()
|
||||
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||
act_series = 5
|
||||
|
||||
for act in range(1,act_series+1):
|
||||
rate_user_data_list = []
|
||||
for r in range(0,101,10):
|
||||
# read data
|
||||
user_data_list = []
|
||||
for i in range(args.user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) + "_rate__" + str(r) + "_act_" + str(act) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
user_data_list.append(model_data_list)
|
||||
model_data_list_total = np.stack(user_data_list)
|
||||
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||
rate_user_data_list.append(mean_user_data)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(14) # was 10
|
||||
fig.set_figwidth(20)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(rate_user_data_list)):
|
||||
y.append(rate_user_data_list[i][j+ax*7][0])
|
||||
y_low.append(rate_user_data_list[i][j+ax*7][2])
|
||||
y_high.append(rate_user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
axs[ax].set_xticks(range(0,101,10))
|
||||
axs[ax].set_ylabel('prob', fontsize=26) # was 20
|
||||
|
||||
axs[0].text(-0.15, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36) # was -0.125 20
|
||||
axs[ax].text(-0.15, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax]) # -0.125 20
|
||||
axs[ax].tick_params(axis='y', which='major', labelsize=24) # was 16
|
||||
axs[ax].tick_params(axis='x', which='major', labelsize=24) # was 16
|
||||
for tick in axs[ax].xaxis.get_major_ticks():
|
||||
tick.set_pad(20)
|
||||
|
||||
plt.xlabel('Percentage of occurred actions in one action sequence', fontsize= 36) # was 20
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.xlim([0, 101])
|
||||
plt.ylim([0, 1])
|
||||
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_rate_ful_all_individuall.png", bbox_inches='tight')
|
||||
|
||||
#plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
python3 plot_user_length_10_steps_all_individual.py \
|
||||
--model_type lstmlast_ \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--hidden_size 128 \
|
||||
--N 1 \
|
||||
--user 5
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
model_type = "lstmlast_"
|
||||
batch_size = 8
|
||||
lr = 1e-4
|
||||
hidden_size = 128
|
||||
N = 1
|
||||
user = 5
|
||||
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||
act_series = 5
|
||||
|
||||
for act in range(1,act_series+1):
|
||||
rate_user_data_list = []
|
||||
for r in range(0,101,10):
|
||||
# read data
|
||||
print(r)
|
||||
user_data_list = []
|
||||
for i in range(user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(N) + "/" + model_type + "bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_result_user" + str(i) + "_rate__" + str(r) + "_act_" + str(act) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
user_data_list.append(model_data_list)
|
||||
model_data_list_total = np.stack(user_data_list)
|
||||
print(model_data_list_total.shape)
|
||||
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||
print(mean_user_data.shape)
|
||||
rate_user_data_list.append(mean_user_data)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(14) # was 10
|
||||
fig.set_figwidth(20)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(rate_user_data_list)):
|
||||
y.append(rate_user_data_list[i][j+ax*7][0])
|
||||
y_low.append(rate_user_data_list[i][j+ax*7][2])
|
||||
y_high.append(rate_user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print(legend[ax])
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
axs[ax].set_xticks(range(0,101,10))
|
||||
axs[ax].set_ylabel('prob', fontsize=26) # was 20
|
||||
axs[ax].set_title(legend[ax], color=color[ax], fontsize=26)
|
||||
|
||||
|
||||
#axs[0].text(-0.15, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36) # was -0.125 20
|
||||
#axs[ax].text(-0.15, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax]) # -0.125 20
|
||||
axs[ax].tick_params(axis='y', which='major', labelsize=18) # was 16
|
||||
axs[ax].tick_params(axis='x', which='major', labelsize=18) # was 16
|
||||
for tick in axs[ax].xaxis.get_major_ticks():
|
||||
tick.set_pad(20)
|
||||
|
||||
plt.xlabel('Percentage of occurred actions in one action sequence', fontsize= 26) # was 20
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.xlim([0, 101])
|
||||
plt.ylim([0, 1.1])
|
||||
plt.tight_layout()
|
||||
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_rate_ful_all_individuall_chiw.png", bbox_inches='tight')
|
||||
|
||||
#plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
BIN
keyboard_and_mouse/stan/strategy_inference_model
Executable file
BIN
keyboard_and_mouse/stan/strategy_inference_model
Executable file
Binary file not shown.
26
keyboard_and_mouse/stan/strategy_inference_model.stan
Executable file
26
keyboard_and_mouse/stan/strategy_inference_model.stan
Executable file
|
@ -0,0 +1,26 @@
|
|||
data {
|
||||
int<lower=1> I; // number of question options (22)
|
||||
int<lower=0> N; // number of questions being asked by the user
|
||||
int<lower=1> K; // number of strategies
|
||||
// observed "true" questions of the user
|
||||
int q[N];
|
||||
// array of predicted probabilities of questions given strategies
|
||||
// coming from the forward neural network
|
||||
matrix[I, K] P_q_S[N];
|
||||
}
|
||||
parameters {
|
||||
// probabiliy vector of the strategies being applied by the user
|
||||
// to be inferred by the model here
|
||||
simplex[K] P_S;
|
||||
}
|
||||
model {
|
||||
for (n in 1:N) {
|
||||
// marginal probability vector of the questions being asked
|
||||
vector[I] theta = P_q_S[n] * P_S;
|
||||
// categorical likelihood
|
||||
target += categorical_lpmf(q[n] | theta);
|
||||
}
|
||||
// priors
|
||||
target += dirichlet_lpdf(P_S | rep_vector(1.0, K));
|
||||
}
|
||||
|
157
keyboard_and_mouse/stan/strategy_inference_test.R
Normal file
157
keyboard_and_mouse/stan/strategy_inference_test.R
Normal file
|
@ -0,0 +1,157 @@
|
|||
library(tidyverse)
|
||||
library(cmdstanr)
|
||||
library(dplyr)
|
||||
|
||||
|
||||
model_type <- "lstmlast"
|
||||
batch_size <- "8"
|
||||
lr <- "0.0001"
|
||||
hidden_size <- "128"
|
||||
model_type <- paste0(model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size)
|
||||
print(model_type)
|
||||
set.seed(9736734)
|
||||
|
||||
user_num <- 5
|
||||
user <-c(0:(user_num-1))
|
||||
strategies <- c(0:6) # 7 tasks
|
||||
print(strategies)
|
||||
print(length(strategies))
|
||||
N <- 1
|
||||
|
||||
# read data from csv
|
||||
sel <- vector("list", length(strategies))
|
||||
for (u in seq_along(user)){
|
||||
dat <- vector("list", length(strategies))
|
||||
print(paste0('user: ', u))
|
||||
for (i in seq_along(strategies)) {
|
||||
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_pred", ".csv"))
|
||||
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||
dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
|
||||
dat[[i]]$id <- dat[[i]][,1] # sample based on intention
|
||||
}
|
||||
|
||||
# reset N after inference
|
||||
N = 1
|
||||
|
||||
# select one action series from one intention
|
||||
if (user[[u]] == 0){
|
||||
sel[[1]]<-dat[[1]] %>%
|
||||
group_by(task_name) %>%
|
||||
sample_n(N)
|
||||
sel[[1]] <- data.frame(sel[[1]])
|
||||
}
|
||||
|
||||
# filter data from the selected action series, N series per intention
|
||||
for (i in seq_along(strategies)) {
|
||||
dat[[i]]<-subset(dat[[i]], dat[[i]]$action_id == sel[[1]]$action_id[1])
|
||||
}
|
||||
row.names(dat) <- NULL
|
||||
|
||||
# create save path
|
||||
dir.create(file.path("result"), showWarnings = FALSE)
|
||||
dir.create(file.path(paste0("result/", "N", N)), showWarnings = FALSE)
|
||||
save_path <- paste0("result/", "N", N, "/", model_type, "_N", N, "_", "result","_user", user[[u]], ".csv")
|
||||
|
||||
dat <- do.call(rbind, dat) %>%
|
||||
mutate(index = as.numeric(as.factor(id))) %>%
|
||||
rename(true_strategy = task_name) %>%
|
||||
mutate(
|
||||
true_strategy = factor(
|
||||
true_strategy, levels = 0:6,
|
||||
labels = strategies
|
||||
),
|
||||
q_type = case_when(
|
||||
gt %in% c(3,4,5) ~ 0,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 1,
|
||||
gt %in% c(1,2,3,4) ~ 2,
|
||||
gt %in% c(1,4,5,6,7) ~ 3,
|
||||
gt %in% c(1,2,3,6,7) ~ 4,
|
||||
gt %in% c(2,3,4,5,6,7) ~ 5,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 6,
|
||||
)
|
||||
)
|
||||
|
||||
dat_obs <- dat %>% filter(assumed_strategy == strategies[[i]])
|
||||
N <- nrow(dat_obs)
|
||||
print(c("N: ", N))
|
||||
q <- dat_obs$gt
|
||||
true_strategy <- dat_obs$true_strategy
|
||||
|
||||
K <- length(unique(dat$assumed_strategy))
|
||||
print(c("K: ", K))
|
||||
I <- 7
|
||||
|
||||
P_q_S <- array(dim = c(N, I, K))
|
||||
for (n in 1:N) {
|
||||
#print(n)
|
||||
P_q_S[n, , ] <- dat %>%
|
||||
filter(index == n) %>%
|
||||
select(matches("^act[[:digit:]]+$")) %>%
|
||||
as.matrix() %>%
|
||||
t()
|
||||
for (k in 1:K) {
|
||||
# normalize probabilities
|
||||
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||
}
|
||||
}
|
||||
print(c('dim P_q_S',dim(P_q_S)))
|
||||
|
||||
mod <- cmdstan_model("strategy_inference_model.stan")
|
||||
|
||||
sub <- which(true_strategy == 0) # "0"
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_0 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_0$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 1)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_1 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_1$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 2)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_2 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_2$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 3)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_3 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_3$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 4)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_4 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_4$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 5)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_5 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_5$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 6)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_6 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_6$summary(NULL, c("mean","sd")))
|
||||
|
||||
# save csv
|
||||
df <-rbind(fit_0$summary(), fit_1$summary(), fit_2$summary(), fit_3$summary(), fit_4$summary(), fit_5$summary(), fit_6$summary())
|
||||
write.csv(df,file=save_path,quote=FALSE)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
library(tidyverse)
|
||||
library(cmdstanr)
|
||||
library(dplyr)
|
||||
|
||||
# using every action sequence from each user
|
||||
model_type <- "lstmlast"
|
||||
batch_size <- "8"
|
||||
lr <- "0.0001"
|
||||
hidden_size <- "128"
|
||||
model_type <- paste0(model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size)
|
||||
rates <- c("_0", "_10", "_20", "_30", "_40", "_50", "_60", "_70", "_80", "_90", "_100")
|
||||
|
||||
user_num <- 5
|
||||
user <-c(0:(user_num-1))
|
||||
strategies <- c(0:6) # 7 tasks
|
||||
print('strategies')
|
||||
print(strategies)
|
||||
print('strategies length')
|
||||
print(length(strategies))
|
||||
N <- 1
|
||||
unique_act_id <- c(1:5)
|
||||
print('unique_act_id')
|
||||
print(unique_act_id)
|
||||
set.seed(9746234)
|
||||
|
||||
for (act_id in seq_along(unique_act_id)){
|
||||
for (u in seq_along(user)){
|
||||
print('user')
|
||||
print(u)
|
||||
for (rate in rates) {
|
||||
N <- 1
|
||||
dat <- vector("list", length(strategies))
|
||||
for (i in seq_along(strategies)) {
|
||||
if (rate=="_0"){
|
||||
# read data from csv
|
||||
dat[[i]] <- read.csv(paste0("../prediction/single_act/task", strategies[[i]], "/", model_type, "/user", user[[u]], "/rate_10", "_act_", unique_act_id[act_id], "_pred", ".csv"))
|
||||
} else{
|
||||
dat[[i]] <- read.csv(paste0("../prediction/single_act/task", strategies[[i]], "/", model_type, "/user", user[[u]], "/rate", rate, "_act_", unique_act_id[act_id], "_pred", ".csv"))
|
||||
}
|
||||
# strategy assumed for prediction
|
||||
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||
dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
|
||||
dat[[i]]$id <- dat[[i]][,1] # sample based on intention
|
||||
}
|
||||
|
||||
save_path <- paste0("result/", "N", N, "/", model_type, "_N", N, "_", "result","_user", user[[u]], "_rate_", rate, "_act_", unique_act_id[act_id], ".csv")
|
||||
|
||||
dat_act <- do.call(rbind, dat) %>%
|
||||
mutate(index = as.numeric(as.factor(id))) %>%
|
||||
rename(true_strategy = task_name) %>%
|
||||
mutate(
|
||||
true_strategy = factor(
|
||||
true_strategy, levels = 0:6,
|
||||
labels = strategies
|
||||
),
|
||||
q_type = case_when(
|
||||
gt %in% c(3,4,5) ~ 0,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 1,
|
||||
gt %in% c(1,2,3,4) ~ 2,
|
||||
gt %in% c(1,4,5,6,7) ~ 3,
|
||||
gt %in% c(1,2,3,6,7) ~ 4,
|
||||
gt %in% c(2,3,4,5,6,7) ~ 5,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 6,
|
||||
)
|
||||
)
|
||||
|
||||
dat_obs <- dat_act %>% filter(assumed_strategy == strategies[[i]])
|
||||
N <- nrow(dat_obs)
|
||||
print(c("N: ", N))
|
||||
print(c("dim dat_act: ", dim(dat_act)))
|
||||
q <- dat_obs$gt
|
||||
true_strategy <- dat_obs$true_strategy
|
||||
|
||||
K <- length(unique(dat_act$assumed_strategy))
|
||||
I <- 7
|
||||
|
||||
P_q_S <- array(dim = c(N, I, K))
|
||||
for (n in 1:N) {
|
||||
print(n)
|
||||
P_q_S[n, , ] <- dat_act %>%
|
||||
filter(index == n) %>%
|
||||
select(matches("^act[[:digit:]]+$")) %>%
|
||||
as.matrix() %>%
|
||||
t()
|
||||
for (k in 1:K) {
|
||||
# normalize probabilities
|
||||
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||
}
|
||||
}
|
||||
|
||||
print(c("dim(P_q_S)", dim(P_q_S)))
|
||||
# read stan model
|
||||
mod <- cmdstan_model(paste0(getwd(),"/strategy_inference_model.stan"))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 0) # "0"
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_0 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_0$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 1)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_1 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_1$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 2)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_2 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_2$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 3)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_3 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_3$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 4)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_4 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_4$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 5)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_5 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_5$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 6)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_6 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_6$summary(NULL, c("mean","sd")))
|
||||
|
||||
# save csv
|
||||
df <-rbind(fit_0$summary(), fit_1$summary(), fit_2$summary(), fit_3$summary(), fit_4$summary(), fit_5$summary(), fit_6$summary())
|
||||
write.csv(df,file=save_path,quote=FALSE)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
238
keyboard_and_mouse/stan/strategy_inference_test_full_length.R
Normal file
238
keyboard_and_mouse/stan/strategy_inference_test_full_length.R
Normal file
|
@ -0,0 +1,238 @@
|
|||
library(tidyverse)
|
||||
library(cmdstanr)
|
||||
library(dplyr)
|
||||
|
||||
# index order of the strategies assumed throughout
|
||||
model_type <- "lstmlast"
|
||||
batch_size <- "8"
|
||||
lr <- "0.0001"
|
||||
hidden_size <- "128"
|
||||
model_type <- paste0(model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size)
|
||||
rates <- c("_0", "_10", "_20", "_30", "_40", "_50", "_60", "_70", "_80", "_90", "_100")
|
||||
|
||||
user_num <- 5
|
||||
user <-c(0:(user_num-1))
|
||||
strategies <- c(0:6) # 7 tasks
|
||||
print(strategies)
|
||||
print(length(strategies))
|
||||
N <- 1
|
||||
|
||||
set.seed(9736754)
|
||||
|
||||
#read data from csv
|
||||
sel <- vector("list", length(strategies))
|
||||
for (u in seq_along(user)){
|
||||
print('user')
|
||||
print(u)
|
||||
for (rate in rates) {
|
||||
dat <- vector("list", length(strategies))
|
||||
for (i in seq_along(strategies)) {
|
||||
if (rate=="_0"){
|
||||
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_rate_10", "_pred", ".csv"))
|
||||
} else if (rate=="_100"){
|
||||
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_pred", ".csv"))
|
||||
} else{
|
||||
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_rate", rate, "_pred", ".csv"))
|
||||
}
|
||||
# strategy assumed for prediction
|
||||
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||
dat[[i]]$index <- dat[[i]]$action_id
|
||||
dat[[i]]$id <- dat[[i]][,1]
|
||||
}
|
||||
|
||||
# reset N after inference
|
||||
N <- 1
|
||||
|
||||
# select all action series and infer every one
|
||||
if (rate == "_0"){
|
||||
sel[[1]]<-dat[[1]] %>%
|
||||
group_by(task_name) %>%
|
||||
sample_n(N)
|
||||
sel[[1]] <- data.frame(sel[[1]])
|
||||
unique_act_id <- unique(sel[[1]]$action_id)
|
||||
}
|
||||
print(sel[[1]]$action_id)
|
||||
print(sel[[1]]$task_name)
|
||||
print(dat[[1]]$task_name)
|
||||
|
||||
|
||||
for (i in seq_along(strategies)) {
|
||||
dat[[i]]<-subset(dat[[i]], dat[[i]]$action_id == sel[[1]]$action_id[1])
|
||||
}
|
||||
row.names(dat) <- NULL
|
||||
print(c('action id', dat[[1]]$action_id))
|
||||
print(c('action id', dat[[2]]$action_id))
|
||||
print(c('action id', dat[[3]]$action_id))
|
||||
|
||||
dir.create(file.path(paste0("result/", "N", N)), showWarnings = FALSE)
|
||||
save_path <- paste0("result/", "N", N, "/", model_type, "_N", N, "_", "result","_user", user[[u]], "_rate_", rate, ".csv")
|
||||
|
||||
dat_act <- do.call(rbind, dat) %>%
|
||||
mutate(index = as.numeric(as.factor(id))) %>%
|
||||
rename(true_strategy = task_name) %>%
|
||||
mutate(
|
||||
true_strategy = factor(
|
||||
true_strategy, levels = 0:6,
|
||||
labels = strategies
|
||||
),
|
||||
q_type = case_when(
|
||||
gt %in% c(3,4,5) ~ 0,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 1,
|
||||
gt %in% c(1,2,3,4) ~ 2,
|
||||
gt %in% c(1,4,5,6,7) ~ 3,
|
||||
gt %in% c(1,2,3,6,7) ~ 4,
|
||||
gt %in% c(2,3,4,5,6,7) ~ 5,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 6,
|
||||
)
|
||||
)
|
||||
|
||||
dat_obs <- dat_act %>% filter(assumed_strategy == strategies[[i]]) # put_fridge, was num
|
||||
N <- nrow(dat_obs)
|
||||
print(c("N: ", N))
|
||||
print(c("dim dat_act: ", dim(dat_act)))
|
||||
|
||||
q <- dat_obs$gt
|
||||
true_strategy <- dat_obs$true_strategy
|
||||
|
||||
K <- length(unique(dat_act$assumed_strategy))
|
||||
I <- 7
|
||||
|
||||
P_q_S <- array(dim = c(N, I, K))
|
||||
for (n in 1:N) {
|
||||
print(n)
|
||||
P_q_S[n, , ] <- dat_act %>%
|
||||
filter(index == n) %>%
|
||||
select(matches("^act[[:digit:]]+$")) %>%
|
||||
as.matrix() %>%
|
||||
t()
|
||||
for (k in 1:K) {
|
||||
# normalize probabilities
|
||||
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||
}
|
||||
}
|
||||
print(c("dim(P_q_S)", dim(P_q_S)))
|
||||
|
||||
mod <- cmdstan_model(paste0(getwd(),"/strategy_inference_model.stan"))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 0) # "0"
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_0 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_0$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 1)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_1 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_1$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 2)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_2 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_2$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 3)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_3 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_3$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 4)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_4 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_4$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 5)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_5 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_5$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 6)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_6 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_6$summary(NULL, c("mean","sd")))
|
||||
|
||||
# save csv
|
||||
df <-rbind(fit_0$summary(), fit_1$summary(), fit_2$summary(), fit_3$summary(), fit_4$summary(), fit_5$summary(), fit_6$summary())
|
||||
write.csv(df,file=save_path,quote=FALSE)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
10
keyboard_and_mouse/temp.py
Normal file
10
keyboard_and_mouse/temp.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
import torch
|
||||
import matplotlib as plt
|
||||
import pickle
|
||||
print(pickle.format_version)
|
||||
import pandas
|
||||
|
||||
print(torch.__version__)
|
||||
print('matplotlib: {}'.format(plt.__version__))
|
||||
print(pandas.__version__)
|
||||
|
158
keyboard_and_mouse/test.py
Normal file
158
keyboard_and_mouse/test.py
Normal file
|
@ -0,0 +1,158 @@
|
|||
import pickle
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
import shutil
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
from networks import ActionDemo2Predicate
|
||||
from pathlib import Path
|
||||
from termcolor import colored
|
||||
import pandas as pd
|
||||
|
||||
|
||||
print('torch version: ',torch.__version__)
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(DEVICE)
|
||||
torch.manual_seed(256)
|
||||
|
||||
class test_dataset(Dataset):
|
||||
def __init__(self, x, label, action_id):
|
||||
self.x = x
|
||||
self.idx = action_id
|
||||
self.labels = label
|
||||
|
||||
def __getitem__(self, index):
|
||||
x = self.x[index]
|
||||
label = self.labels[index]
|
||||
action_idx = self.idx[index]
|
||||
return x, label, action_idx
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
|
||||
def test_model(model, test_dataloader, DEVICE):
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
test_acc = []
|
||||
logits = []
|
||||
labels = []
|
||||
action_ids = []
|
||||
for iter, (x, label, action_id) in enumerate(test_dataloader):
|
||||
with torch.no_grad():
|
||||
x = torch.tensor(x).to(DEVICE)
|
||||
label = torch.tensor(label).to(DEVICE)
|
||||
logps = model(x)
|
||||
logps = F.softmax(logps, 1)
|
||||
logits.append(logps.cpu().numpy())
|
||||
labels.append(label.cpu().numpy())
|
||||
action_ids.append(action_id)
|
||||
|
||||
argmax_Y = torch.max(logps, 1)[1].view(-1, 1)
|
||||
test_acc.append((label.float().view(-1, 1) == argmax_Y.float()).sum().item() / len(label.float().view(-1, 1)) * 100)
|
||||
|
||||
test_acc = np.mean(np.array(test_acc))
|
||||
print('test acc {:.4f}'.format(test_acc))
|
||||
logits = np.concatenate(logits, axis=0)
|
||||
labels = np.concatenate(labels, axis=0)
|
||||
action_ids = np.concatenate(action_ids, axis=0)
|
||||
return logits, labels, action_ids
|
||||
|
||||
def main():
|
||||
# parsing parameters
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--resume', type=bool, default=False, help='resume training')
|
||||
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-1, help='learning rate')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
parser.add_argument('--hidden_size', type=int, default=256, help='hidden_size')
|
||||
parser.add_argument('--epochs', type=int, default=100, help='training epoch')
|
||||
parser.add_argument('--dataset_path', type=str, default='dataset/strategy_dataset/', help='dataset path')
|
||||
parser.add_argument('--weight_decay', type=float, default=0.9, help='wight decay for Adam optimizer')
|
||||
parser.add_argument('--demo_hidden', type=int, default=512, help='demo_hidden')
|
||||
parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
|
||||
parser.add_argument('--checkpoint', type=str, default='checkpoints/', help='checkpoints path')
|
||||
|
||||
args = parser.parse_args()
|
||||
path = args.checkpoint+args.model_type+'_bs_'+str(args.batch_size)+'_lr_'+str(args.lr)+'_hidden_size_'+str(args.hidden_size)
|
||||
|
||||
# read models
|
||||
models = []
|
||||
for i in range(7): # 7 tasks
|
||||
net = ActionDemo2Predicate(args)
|
||||
model_path = path + '/task' + str(i) + '_checkpoint.ckpt' # _checkpoint
|
||||
net.load(model_path)
|
||||
models.append(net)
|
||||
|
||||
for u in range(5):
|
||||
task_pred = []
|
||||
task_target = []
|
||||
task_act = []
|
||||
task_task_name = []
|
||||
for i in range(7): # 7 tasks
|
||||
test_loader = []
|
||||
# # read dataset test data
|
||||
with open(args.dataset_path + 'test_data_' + str(i) + '.pkl', 'rb') as f:
|
||||
data_x = pickle.load(f)
|
||||
with open(args.dataset_path + 'test_label_' + str(i) + '.pkl', 'rb') as f:
|
||||
data_y = pickle.load(f)
|
||||
with open(args.dataset_path + 'test_action_id_' + str(i) + '.pkl', 'rb') as f:
|
||||
act_idx = pickle.load(f)
|
||||
|
||||
x = data_x[u]
|
||||
y = data_y[u]
|
||||
act = act_idx[u]
|
||||
test_set = test_dataset(np.array(x), np.array(y)-1, np.array(act))
|
||||
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4, drop_last=True)
|
||||
|
||||
preds = []
|
||||
targets = []
|
||||
actions = []
|
||||
task_names = []
|
||||
for j in range(7): # logits from all models
|
||||
pred, target, action = test_model(models[j], test_loader, DEVICE)
|
||||
preds.append(pred)
|
||||
targets.append(target)
|
||||
actions.append(action)
|
||||
task_names.append(np.full(target.shape, i)) #assumed intention
|
||||
|
||||
task_pred.append(preds)
|
||||
task_target.append(targets)
|
||||
task_act.append(actions)
|
||||
task_task_name.append(task_names)
|
||||
|
||||
for i in range(7):
|
||||
preds = []
|
||||
targets = []
|
||||
actions = []
|
||||
task_names = []
|
||||
for j in range(7):
|
||||
preds.append(task_pred[j][i])
|
||||
targets.append(task_target[j][i]+1) # gt value add one
|
||||
actions.append(task_act[j][i])
|
||||
task_names.append(task_task_name[j][i])
|
||||
|
||||
preds = np.concatenate(preds, axis=0)
|
||||
targets = np.concatenate(targets, axis=0)
|
||||
actions = np.concatenate(actions, axis=0)
|
||||
task_names = np.concatenate(task_names, axis=0)
|
||||
write_data = np.concatenate((np.reshape(actions, (-1, 1)), preds, np.reshape(task_names, (-1, 1)), np.reshape(targets, (-1, 1))), axis=1)
|
||||
|
||||
output_path = 'prediction/' + 'task' +str(i) + '/' + args.model_type+'_bs_'+str(args.batch_size)+'_lr_'+str(args.lr)+'_hidden_size_'+str(args.hidden_size)
|
||||
Path(output_path).mkdir(parents=True, exist_ok=True)
|
||||
output_path = output_path + '/user' + str(u) + '_pred.csv'
|
||||
print(write_data.shape)
|
||||
|
||||
head = []
|
||||
for j in range(7):
|
||||
head.append('act'+str(j+1))
|
||||
head.append('task_name')
|
||||
head.append('gt')
|
||||
head.insert(0,'action_id')
|
||||
pd.DataFrame(write_data).to_csv(output_path, header=head)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
12
keyboard_and_mouse/test.sh
Normal file
12
keyboard_and_mouse/test.sh
Normal file
|
@ -0,0 +1,12 @@
|
|||
python3 test.py \
|
||||
--resume False \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--model_type lstmlast \
|
||||
--epochs 100 \
|
||||
--demo_hidden 128 \
|
||||
--hidden_size 128 \
|
||||
--dropout 0.5 \
|
||||
--dataset_path dataset/strategy_dataset/ \
|
||||
--checkpoint checkpoints/ \
|
||||
--weight_decay 1e-4
|
145
keyboard_and_mouse/train.py
Normal file
145
keyboard_and_mouse/train.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
import pickle
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import shutil
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
from networks import ActionDemo2Predicate
|
||||
|
||||
print('torch version: ',torch.__version__)
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(DEVICE)
|
||||
torch.manual_seed(256)
|
||||
|
||||
class train_dataset(Dataset):
|
||||
def __init__(self, x, label):
|
||||
self.x = x
|
||||
self.labels = label
|
||||
|
||||
def __getitem__(self, index):
|
||||
x = self.x[index]
|
||||
label = self.labels[index]
|
||||
return x, label #, img_idx
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
|
||||
class test_dataset(Dataset):
|
||||
def __init__(self, x, label):
|
||||
self.x = x
|
||||
self.labels = label
|
||||
|
||||
def __getitem__(self, index):
|
||||
x = self.x[index]
|
||||
label = self.labels[index]
|
||||
return x, label #, img_idx
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
|
||||
def train_model(model, train_dataloader, criterion, optimizer, num_epochs, DEVICE, path, resume):
|
||||
running_loss = 0
|
||||
train_losses = 10
|
||||
is_best_acc = False
|
||||
is_best_train_loss = False
|
||||
|
||||
best_train_acc = 0
|
||||
best_train_loss = 10
|
||||
|
||||
start_epoch = 0
|
||||
accuracy = 0
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
for epoch in range(start_epoch, num_epochs):
|
||||
epoch_losses = []
|
||||
train_acc = []
|
||||
epoch_loss = 0
|
||||
for iter, (x, labels) in enumerate(train_dataloader):
|
||||
x = torch.tensor(x).to(DEVICE)
|
||||
labels = torch.tensor(labels).to(DEVICE)
|
||||
optimizer.zero_grad()
|
||||
logps = model(x)
|
||||
loss = criterion(logps, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
epoch_loss += loss.detach().item()
|
||||
argmax_Y = torch.max(logps, 1)[1].view(-1, 1)
|
||||
train_acc.append((labels.float().view(-1, 1) == argmax_Y.float()).sum().item() / len(labels.float().view(-1, 1)) * 100)
|
||||
epoch_loss /= (iter + 1)
|
||||
epoch_losses.append(epoch_loss)
|
||||
train_acc = np.mean(np.array(train_acc))
|
||||
print('Epoch {}, train loss {:.4f}, train acc {:.4f}'.format(epoch, epoch_loss, train_acc))
|
||||
|
||||
is_best_acc = train_acc > best_train_acc
|
||||
best_train_acc = max(train_acc, best_train_acc)
|
||||
|
||||
is_best_train_loss = best_train_loss < epoch_loss
|
||||
best_train_loss = min(epoch_loss, best_train_loss)
|
||||
|
||||
if is_best_acc:
|
||||
model.save(path + '_model_best.ckpt')
|
||||
model.save(path + '_checkpoint.ckpt')
|
||||
#scheduler.step()
|
||||
|
||||
def save_checkpoint(state, is_best, path, filename='_checkpoint.pth.tar'):
|
||||
torch.save(state, path + filename)
|
||||
if is_best:
|
||||
shutil.copyfile(path + filename, path +'_model_best.pth.tar')
|
||||
|
||||
def main():
|
||||
# parsing parameters
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--resume', type=bool, default=False, help='resume training')
|
||||
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-1, help='learning rate')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
parser.add_argument('--hidden_size', type=int, default=256, help='hidden_size')
|
||||
parser.add_argument('--epochs', type=int, default=100, help='training epoch')
|
||||
parser.add_argument('--dataset_path', type=str, default='dataset/strategy_dataset/', help='dataset path')
|
||||
parser.add_argument('--weight_decay', type=float, default=0.9, help='wight decay for Adam optimizer')
|
||||
parser.add_argument('--demo_hidden', type=int, default=512, help='demo_hidden')
|
||||
parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
|
||||
parser.add_argument('--checkpoint', type=str, default='checkpoints/', help='checkpoints path')
|
||||
|
||||
args = parser.parse_args()
|
||||
# create checkpoints path
|
||||
from pathlib import Path
|
||||
path = args.checkpoint+args.model_type+'_bs_'+str(args.batch_size)+'_lr_'+str(args.lr)+'_hidden_size_'+str(args.hidden_size)
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
print('total epochs for training: ', args.epochs)
|
||||
|
||||
# read dataset
|
||||
train_loader = []
|
||||
test_loader = []
|
||||
loss_funcs = []
|
||||
optimizers = []
|
||||
models = []
|
||||
parameters = []
|
||||
for i in range(7): # 7 tasks
|
||||
# train data
|
||||
with open(args.dataset_path + 'train_data_' + str(i) + '.pkl', 'rb') as f:
|
||||
data_x = pickle.load(f)
|
||||
with open(args.dataset_path + 'train_label_' + str(i) + '.pkl', 'rb') as f:
|
||||
data_y = pickle.load(f)
|
||||
train_set = train_dataset(np.array(data_x), np.array(data_y)-1)
|
||||
train_loader.append(DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4))
|
||||
print('task', str(i), 'train data size: ', len(train_set))
|
||||
|
||||
net = ActionDemo2Predicate(args)
|
||||
models.append(net)
|
||||
parameter = net.parameters()
|
||||
loss_funcs.append(nn.CrossEntropyLoss())
|
||||
optimizers.append(optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay))
|
||||
|
||||
for i in range(7):
|
||||
path_save = path + '/task' + str(i)
|
||||
print('checkpoint save path: ', path_save)
|
||||
train_model(models[i], train_loader[i], loss_funcs[i], optimizers[i], args.epochs, DEVICE, path_save, args.resume)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
12
keyboard_and_mouse/train.sh
Normal file
12
keyboard_and_mouse/train.sh
Normal file
|
@ -0,0 +1,12 @@
|
|||
python3 train.py \
|
||||
--resume False \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--model_type lstmlast \
|
||||
--epochs 100 \
|
||||
--demo_hidden 128 \
|
||||
--hidden_size 128 \
|
||||
--dropout 0.5 \
|
||||
--dataset_path dataset/strategy_dataset/ \
|
||||
--checkpoint checkpoints/ \
|
||||
--weight_decay 1e-4
|
46
watch_and_help/README.md
Normal file
46
watch_and_help/README.md
Normal file
|
@ -0,0 +1,46 @@
|
|||
# Watch And Help Dataset
|
||||
|
||||
Codes to reproduce results on WAH dataset[^1]
|
||||
|
||||
[^1]: Modified based on WAH train and test codes (https://github.com/xavierpuigf/watch_and_help)[https://github.com/xavierpuigf/watch_and_help].
|
||||
|
||||
## Data
|
||||
|
||||
Extact `dataset/watch_data.zip`
|
||||
|
||||
|
||||
## Neural Network
|
||||
|
||||
Run `sh scripts/train_watch_strategy_full.sh` to train the model
|
||||
|
||||
To test model, either use trained model or extract checkpoints `checkpoints/train_strategy_full/lstmlast.zip`
|
||||
|
||||
Run `sh scripts/test_watch_strategy_full.sh` to test the model
|
||||
|
||||
|
||||
|
||||
## Prediction Split
|
||||
|
||||
Create artificial users and sample predictions from 10% to 90%
|
||||
|
||||
```
|
||||
cd stan
|
||||
sh split_user.sh
|
||||
sh sampler_user.sh
|
||||
```
|
||||
|
||||
|
||||
## Bayesian Inference
|
||||
|
||||
|
||||
Run inference to get results of user intention prediction and action length (0% to 100%) for all users
|
||||
|
||||
```
|
||||
Rscript strategy_inference_test.R
|
||||
```
|
||||
|
||||
Plot intention prediction results and 10% to 100% of actions results
|
||||
|
||||
```
|
||||
sh plot_user_length.sh
|
||||
sh plot_user_length_10_steps.sh
|
BIN
watch_and_help/checkpoints/train_strategy_full/lstmlast.zip
Normal file
BIN
watch_and_help/checkpoints/train_strategy_full/lstmlast.zip
Normal file
Binary file not shown.
BIN
watch_and_help/dataset/watch_data.zip
Normal file
BIN
watch_and_help/dataset/watch_data.zip
Normal file
Binary file not shown.
13
watch_and_help/scripts/test_watch_strategy_full.sh
Normal file
13
watch_and_help/scripts/test_watch_strategy_full.sh
Normal file
|
@ -0,0 +1,13 @@
|
|||
python3 watch_strategy_full/predicate-train-strategy.py \
|
||||
--testset test_task \
|
||||
--gpu_id 0 \
|
||||
--batch_size 32 \
|
||||
--demo_hidden 512 \
|
||||
--model_type lstmlast \
|
||||
--dropout 0 \
|
||||
--inputtype actioninput \
|
||||
--inference 2 \
|
||||
--single 1 \
|
||||
--resume '' \
|
||||
--loss_type ce \
|
||||
--checkpoint checkpoints/train_strategy_full/lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob
|
13
watch_and_help/scripts/train_watch_strategy_full.sh
Normal file
13
watch_and_help/scripts/train_watch_strategy_full.sh
Normal file
|
@ -0,0 +1,13 @@
|
|||
python3 watch_strategy_full/predicate-train-strategy.py \
|
||||
--gpu_id 0 \
|
||||
--model_lr_rate 3e-4 \
|
||||
--batch_size 32 \
|
||||
--demo_hidden 512 \
|
||||
--model_type lstmlast \
|
||||
--inputtype actioninput \
|
||||
--dropout 0 \
|
||||
--single 1 \
|
||||
--resume '' \
|
||||
--checkpoint checkpoints/train_strategy_full/lstmlast \
|
||||
--train_iters 2000 \
|
||||
--loss_type ce\
|
132
watch_and_help/stan/plot_user_length.py
Normal file
132
watch_and_help/stan/plot_user_length.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
def main(args):
|
||||
if args.task_type == 'new_test_task':
|
||||
user = 9
|
||||
N = 1
|
||||
if args.task_type == 'test_task':
|
||||
user = 92
|
||||
N = 1
|
||||
rate = 100
|
||||
|
||||
widths = [-0.1, 0, 0.1]
|
||||
user_table = [6, 13, 15, 19, 20, 23, 27, 30, 33, 44, 46, 49, 50, 51, 52, 53, 54, 56, 65, 71, 84]
|
||||
|
||||
# read data
|
||||
model_data_list = []
|
||||
user_list = []
|
||||
if not args.plot_user_list:
|
||||
for i in range(user):
|
||||
path = "result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/N"+ str(N) + "/" + args.model_type + "_N" + str(N) + "_result_" + str(rate) + "_user" + str(i) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
data = data[[1,2,3,5,6,7,9,10,11],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data)
|
||||
if args.task_type == 'test_task':
|
||||
user_list.append(np.transpose(data[:,[0]]))
|
||||
else:
|
||||
for i in range(user):
|
||||
for t in user_table:
|
||||
if t == i+1:
|
||||
path = "result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/N"+ str(N) + "/" + args.model_type + "_N" + str(N) + "_result_" + str(rate) + "_user" + str(i) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
data = data[[1,2,3,5,6,7,9,10,11],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data)
|
||||
user_list.append(np.transpose(data[:,[0]]))
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato']
|
||||
legend = ['put fridge', 'put\n dishwasher', 'read book']
|
||||
fig, axs = plt.subplots(3, sharex=True, sharey=True)
|
||||
fig.set_figheight(10) # all sample rate: 10; 3 row: 8
|
||||
fig.set_figwidth(20)
|
||||
|
||||
for ax in range(3):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(3):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(model_data_list)):
|
||||
y.append(model_data_list[i][j+ax*3][0])
|
||||
y_low.append(model_data_list[i][j+ax*3][2])
|
||||
y_high.append(model_data_list[i][j+ax*3][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(3):
|
||||
if args.plot_type == 'line':
|
||||
axs[ax].plot(range(user), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
if args.plot_type == 'bar':
|
||||
if args.task_type == 'new_test_task':
|
||||
widths = [-0.25, 0, 0.25]
|
||||
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||
axs[0].text(-0.19, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36)
|
||||
axs[ax].bar(np.arange(user)+widths[i],y_total[i], width=0.2, yerr=yerror, color=color[i], label=legend[i])
|
||||
axs[ax].tick_params(axis='x', which='both', pad=15, length=0)
|
||||
plt.xticks(range(user), range(1,user+1))
|
||||
axs[ax].set_ylabel('prob', fontsize= 36) # was 22
|
||||
axs[ax].text(-0.19, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax])
|
||||
plt.xlabel('user', fontsize= 40) # was 22
|
||||
for k, x in enumerate(np.arange(user)+widths[i]):
|
||||
y = y_total[i][k] + yerror[1][k]
|
||||
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-15, 3), fontsize=14)
|
||||
|
||||
|
||||
if args.task_type == 'test_task':
|
||||
if not args.plot_user_list:
|
||||
axs[0].text(-0.19, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36)
|
||||
axs[ax].errorbar(np.arange(user)+widths[i],y_total[i], yerr=[np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])],markerfacecolor=color[i], ecolor=color[i], markeredgecolor=color[i], label=legend[i],fmt='.k')
|
||||
axs[ax].tick_params(axis='x', which='both', pad=15, length=0)
|
||||
plt.xticks(range(user)[::5], range(1,user+1)[::5])
|
||||
axs[ax].set_ylabel('prob', fontsize= 36)
|
||||
axs[ax].text(-0.19, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax])
|
||||
plt.xlabel('user', fontsize= 40)
|
||||
else:
|
||||
axs[0].text(-0.19, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36)
|
||||
axs[ax].errorbar(np.arange(len(model_data_list))+widths[i],y_total[i], yerr=[np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])],markerfacecolor=color[i], ecolor=color[i], markeredgecolor=color[i], label=legend[i],fmt='.k')
|
||||
axs[ax].tick_params(axis='x', which='both', pad=15, length=0)
|
||||
plt.xticks(range(len(model_data_list)), user_table)
|
||||
axs[ax].set_ylabel('prob', fontsize= 36)
|
||||
#axs[ax].set_yticks(range(0.0,1.0, 0.25))
|
||||
axs[ax].text(-0.19, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax])
|
||||
plt.xlabel('user', fontsize= 40)
|
||||
|
||||
axs[ax].tick_params(axis='both', which='major', labelsize=30)
|
||||
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.ylim([0, 1.08])
|
||||
plt.tight_layout()
|
||||
pathlib.Path("result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/figure/").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if args.task_type == 'test_task':
|
||||
if not args.plot_user_list:
|
||||
plt.savefig("result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/figure/"+"N"+ str(N)+"_"+args.model_type+"_rate_"+str(rate)+"_"+args.plot_type+"_test_set_1.png", bbox_inches='tight')
|
||||
else:
|
||||
plt.savefig("result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/figure/"+"N"+ str(N)+"_"+args.model_type+"_rate_"+str(rate)+"_"+args.plot_type+"_test_set_1_user_analysis.png", bbox_inches='tight')
|
||||
if args.task_type == 'new_test_task':
|
||||
plt.savefig("result/"+args.ask_type+"/user"+str(user)+"/"+args.loss_type+"/figure/"+"N"+ str(N)+"_"+args.model_type+"_rate_"+str(rate)+"_"+args.plot_type+"_test_set_2.png", bbox_inches='tight')
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--loss_type', type=str, default='ce')
|
||||
parser.add_argument('--model_type', type=str, default="lstmlast" )
|
||||
parser.add_argument('--plot_type', type=str, default='bar') # bar or line
|
||||
parser.add_argument('--task_type', type=str, default='test_task')
|
||||
parser.add_argument('--plot_user_list', action='store_true') # plot user_table or not
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
5
watch_and_help/stan/plot_user_length.sh
Normal file
5
watch_and_help/stan/plot_user_length.sh
Normal file
|
@ -0,0 +1,5 @@
|
|||
python3 plot_user_length.py \
|
||||
--loss_type ce \
|
||||
--model_type lstmlast \
|
||||
--plot_type bar \
|
||||
--task_type test_task
|
88
watch_and_help/stan/plot_user_length_10_steps.py
Normal file
88
watch_and_help/stan/plot_user_length_10_steps.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--loss_type', type=str, default='ce')
|
||||
parser.add_argument('--model_type', type=str, default="lstmlast" )
|
||||
parser.add_argument('--task_type', type=str, default='test_task')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.task_type == 'new_test_task':
|
||||
user = 9
|
||||
N = 1
|
||||
if args.task_type == 'test_task':
|
||||
user = 92
|
||||
N = 1
|
||||
|
||||
#rate = range(0,101,10)
|
||||
rate_user_data_list = []
|
||||
for r in range(0,101,10): # rate = range(0,101,10)
|
||||
# read data
|
||||
print(r)
|
||||
model_data_list = []
|
||||
for i in range(user):
|
||||
path = "result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/N"+ str(N) + "/" + args.model_type + "_N" + str(N) + "_result_" + str(r) + "_user" + str(i) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
data = data[[1,2,3,5,6,7,9,10,11],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data)
|
||||
#print(type(data))
|
||||
model_data_list_total = np.stack(model_data_list)
|
||||
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||
rate_user_data_list.append(mean_user_data)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato']
|
||||
legend = ['put fridge', 'put\n dishwasher', 'read book']
|
||||
fig, axs = plt.subplots(3, sharex=True, sharey=True)
|
||||
fig.set_figheight(10) # all sample rate: 10; 3 row: 8
|
||||
fig.set_figwidth(20)
|
||||
axs[0].text(-0.145, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 25) # all: -0.3,0.5 3rows: -0.5,0.5
|
||||
|
||||
for ax in range(3):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(3):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(rate_user_data_list)):
|
||||
y.append(rate_user_data_list[i][j+ax*3][0])
|
||||
y_low.append(rate_user_data_list[i][j+ax*3][2])
|
||||
y_high.append(rate_user_data_list[i][j+ax*3][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(3):
|
||||
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
axs[ax].set_xticks(range(0,101,10))
|
||||
axs[ax].set_ylabel('probability', fontsize=22)
|
||||
|
||||
axs[ax].text(-0.145, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 25, color=color[ax])
|
||||
axs[ax].tick_params(axis='both', which='major', labelsize=18)
|
||||
|
||||
plt.xlabel('Percentage of observed actions in one action sequence', fontsize= 22)
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.xlim([0, 101])
|
||||
plt.ylim([0, 1])
|
||||
pathlib.Path("result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/figure/").mkdir(parents=True, exist_ok=True)
|
||||
if args.task_type == 'test_task':
|
||||
plt.savefig("result/"+args.task_type+"/user"+str(user)+ "/"+args.loss_type+"/figure/N"+ str(N) + "_"+args.model_type+"_rate_full_test_set_1.png", bbox_inches='tight')
|
||||
if args.task_type == 'new_test_task':
|
||||
plt.savefig("result/"+args.task_type+"/user"+str(user)+ "/"+args.loss_type+"/figure/N"+ str(N) + "_"+args.model_type+"_rate_full_test_set_2.png", bbox_inches='tight')
|
||||
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
4
watch_and_help/stan/plot_user_length_10_steps.sh
Normal file
4
watch_and_help/stan/plot_user_length_10_steps.sh
Normal file
|
@ -0,0 +1,4 @@
|
|||
python3 plot_user_length_10_steps.py \
|
||||
--loss_type ce \
|
||||
--model_type lstmlast \
|
||||
--task_type test_task
|
64
watch_and_help/stan/sampler_user.py
Normal file
64
watch_and_help/stan/sampler_user.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import csv
|
||||
import pandas
|
||||
import argparse
|
||||
|
||||
def sample_predciton(path, rate):
|
||||
data = pandas.read_csv(path).values
|
||||
task_list = [0, 1, 2]
|
||||
start = 0
|
||||
stop = 0
|
||||
num_unique = np.unique(data[:,1])
|
||||
#print('unique number', num_unique)
|
||||
|
||||
samples = []
|
||||
for j in task_list:
|
||||
for i in num_unique:
|
||||
inx = np.where((data[:,1] == i) & (data[:,-2] == j))
|
||||
samples.append(data[inx])
|
||||
|
||||
for i in range(len(samples)):
|
||||
n = int(len(samples[i])*(100-rate)/100)
|
||||
samples[i] = samples[i][:-n]
|
||||
|
||||
return np.vstack(samples)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--LOSS', type=str, default='ce')
|
||||
parser.add_argument('--MODEL_TYPE', type=str, default="lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob" )
|
||||
parser.add_argument('--EPOCHS', type=int, default=50)
|
||||
parser.add_argument('--TASK', type=str, default='test_task')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
|
||||
|
||||
task = ['put_fridge', 'put_dishwasher', 'read_book']
|
||||
sets = [args.TASK]
|
||||
rate = [10, 20, 30, 40, 50, 60, 70, 80, 90]
|
||||
|
||||
for i in task:
|
||||
for j in rate:
|
||||
for k in sets:
|
||||
if k == 'test_task':
|
||||
user_num = 92
|
||||
if k == 'new_test_task':
|
||||
user_num = 9
|
||||
|
||||
for l in range(user_num):
|
||||
pred_path = "prediction/" + k + "/" + "user" + str(user_num) + "/ce/" + i + "/" + "loss_weight_" + args.MODEL_TYPE + "_prediction_" + i + "_user" + str(l) + ".csv"
|
||||
save_path = "prediction/" + k + "/" + "user" + str(user_num) + "/ce/" + i + "/" + "loss_weight_" + args.MODEL_TYPE + "_prediction_" + i + "_user" + str(l) + "_rate_" + str(j) + ".csv"
|
||||
data = sample_predciton(pred_path, j)
|
||||
|
||||
head = []
|
||||
for r in range(79):
|
||||
head.append('act'+str(r+1))
|
||||
head.append('task_name')
|
||||
head.append('gt')
|
||||
head.insert(0,'action_id')
|
||||
pandas.DataFrame(data[:,1:]).to_csv(save_path, header=head)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
5
watch_and_help/stan/sampler_user.sh
Normal file
5
watch_and_help/stan/sampler_user.sh
Normal file
|
@ -0,0 +1,5 @@
|
|||
python3 sampler_user.py \
|
||||
--TASK test_task \
|
||||
--LOSS ce \
|
||||
--MODEL_TYPE lstmlast \
|
||||
--EPOCHS 50
|
76
watch_and_help/stan/save_act_series.R
Normal file
76
watch_and_help/stan/save_act_series.R
Normal file
|
@ -0,0 +1,76 @@
|
|||
library(tidyverse)
|
||||
library(cmdstanr)
|
||||
library(dplyr)
|
||||
|
||||
strategies <- c("put_fridge", "put_dishwasher", "read_book")
|
||||
model_type <- "lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob"
|
||||
rate <- "_0"
|
||||
task_type <- "new_test_task" # new_test_task test_task
|
||||
loss_type <- "ce"
|
||||
set.seed(9746234)
|
||||
if (task_type=="test_task"){
|
||||
user_num <- 92
|
||||
user <-c(0:(user_num-1))
|
||||
N <- 1
|
||||
}
|
||||
if (task_type=="new_test_task"){
|
||||
user_num <- 9
|
||||
user <-c(0:(user_num-1))
|
||||
N <- 1
|
||||
}
|
||||
total_user_act1 <- vector("list", length(user_num))
|
||||
total_user_act2 <- vector("list", length(user_num))
|
||||
|
||||
sel <- vector("list", length(strategies))
|
||||
act_series <- vector("list", user_num)
|
||||
for (u in seq_along(user)){
|
||||
print('user')
|
||||
print(u)
|
||||
dat <- vector("list", length(strategies))
|
||||
for (i in seq_along(strategies)) {
|
||||
if (rate=="_0"){
|
||||
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], "_rate_", "90", ".csv"))
|
||||
} else if (rate=="_100"){
|
||||
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], ".csv"))
|
||||
} else{
|
||||
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], "_rate", rate, ".csv"))
|
||||
}
|
||||
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||
dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
|
||||
dat[[i]]$id <- dat[[i]][,1] # sample based on intention
|
||||
}
|
||||
|
||||
N <- 1
|
||||
# select all action series and infer every one
|
||||
sel[[1]]<-dat[[1]] %>%
|
||||
group_by(task_name) %>%
|
||||
filter(task_name==1)
|
||||
sel[[1]] <- data.frame(sel[[1]])
|
||||
unique_act_id_t1 <- unique(sel[[1]]$action_id)
|
||||
write.csv(unique_act_id_t1, paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act", "/", "action_series_", "user_",u, "_put_dishwasher", ".csv"))
|
||||
total_user_act1[[u]] <- unique_act_id_t1
|
||||
|
||||
sel[[1]]<-dat[[1]] %>%
|
||||
group_by(task_name) %>%
|
||||
filter(task_name==2)
|
||||
sel[[1]] <- data.frame(sel[[1]])
|
||||
unique_act_id_t1 <- unique(sel[[1]]$action_id)
|
||||
write.csv(unique_act_id_t1, paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act", "/", "action_series_", "user_",u, "_read_book", ".csv"))
|
||||
total_user_act2[[u]] <- unique_act_id_t1
|
||||
}
|
||||
|
||||
write.csv(total_user_act1, paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act", "/", "action_series_", "_put_dishwasher_total", ".csv"))
|
||||
write.csv(total_user_act2, paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act", "/", "action_series_", "read_book_total", ".csv"))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
87
watch_and_help/stan/split_user.py
Normal file
87
watch_and_help/stan/split_user.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
import numpy as np
|
||||
import pathlib
|
||||
import argparse
|
||||
|
||||
np.random.seed(seed=100)
|
||||
|
||||
def sample_user(data, num_users, split_inx):
|
||||
np.random.seed(seed=100)
|
||||
num_unique3 = np.unique(data[:,1])
|
||||
num_unique2 = num_unique3[0:split_inx[1]]
|
||||
num_unique = num_unique3[0:split_inx[0]]
|
||||
|
||||
user_list1 = [np.random.choice(num_unique, int(len(num_unique)/num_users), replace=False) for i in range(num_users)]
|
||||
user_list2 = [np.random.choice(num_unique2, int(len(num_unique2)/num_users), replace=False) for i in range(num_users)]
|
||||
user_list3 = [np.random.choice(num_unique3, int(len(num_unique3)/num_users), replace=False) for i in range(num_users)]
|
||||
|
||||
user_data = []
|
||||
|
||||
for i in range(num_users): # len(user_list)
|
||||
user_idx1 = [int(item) for item in user_list1[i]]
|
||||
user_idx2 = [int(item) for item in user_list2[i]]
|
||||
user_idx3 = [int(item) for item in user_list3[i]]
|
||||
|
||||
data_list = []
|
||||
for j in range(len(user_idx1)):
|
||||
inx = np.where((data[:,1] == user_idx1[j]) & (data[:,-2]==0))
|
||||
data_list.append(data[inx])
|
||||
|
||||
for j in range(len(user_idx2)):
|
||||
inx = np.where((data[:,1] == user_idx2[j]) & (data[:,-2]==1))
|
||||
data_list.append(data[inx])
|
||||
|
||||
for j in range(len(user_idx3)):
|
||||
inx = np.where((data[:,1] == user_idx3[j]) & (data[:,-2]==2))
|
||||
data_list.append(data[inx])
|
||||
|
||||
user_data.append(np.vstack(data_list))
|
||||
|
||||
return user_data
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--LOSS', type=str, default='ce')
|
||||
parser.add_argument('--MODEL_TYPE', type=str, default="lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob" )
|
||||
parser.add_argument('--EPOCHS', type=int, default=50)
|
||||
parser.add_argument('--TASK', type=str, default='test_task')
|
||||
args = parser.parse_args()
|
||||
|
||||
pref = ['put_fridge', 'put_dishwasher', 'read_book']
|
||||
|
||||
if args.TASK == 'new_test_task':
|
||||
NUM_USER = 9 # 9 for 1 user 1 action
|
||||
SPLIT_INX = [NUM_USER, 45]
|
||||
if args.TASK == 'test_task':
|
||||
NUM_USER = 92
|
||||
SPLIT_INX = [NUM_USER, 229]
|
||||
|
||||
head = []
|
||||
for j in range(79):
|
||||
head.append('act'+str(j+1))
|
||||
head.append('task_name')
|
||||
head.append('gt')
|
||||
head.insert(0,'action_id')
|
||||
head.insert(0,'')
|
||||
|
||||
for i in pref:
|
||||
path = "prediction/"+args.TASK+"/" + args.MODEL_TYPE + "/model_" + i + "_strategy_put_fridge" +".csv"
|
||||
data = np.genfromtxt(path, skip_header=1, delimiter=',')
|
||||
data_task_name = np.genfromtxt(path, skip_header=1, delimiter=',', usecols=-2, dtype=None)
|
||||
data_task_name[data_task_name==b'put_fridge'] = 0
|
||||
data_task_name[data_task_name==b'put_dishwasher'] = 1
|
||||
data_task_name[data_task_name==b'read_book'] = 2
|
||||
data[:,-2] = data_task_name.astype(np.float)
|
||||
print("data length: ", len(data))
|
||||
users_data = sample_user(data, NUM_USER, SPLIT_INX)
|
||||
|
||||
length = 0
|
||||
pathlib.Path("prediction/"+args.TASK+"/user" + str(NUM_USER) + "/" + args.LOSS + "/" + i).mkdir(parents=True, exist_ok=True)
|
||||
for j in range(len(users_data)):
|
||||
save_path = "prediction/"+args.TASK+"/user" + str(NUM_USER) + "/" + args.LOSS + "/" + i +"/loss_weight_"+ args.MODEL_TYPE + "_prediction_"+ i + "_user"+str(j)+".csv"
|
||||
length = length + len(users_data[j])
|
||||
np.savetxt(save_path, users_data[j], delimiter=',', header=','.join(head))
|
||||
print("user data length: ", length)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
5
watch_and_help/stan/split_user.sh
Normal file
5
watch_and_help/stan/split_user.sh
Normal file
|
@ -0,0 +1,5 @@
|
|||
python3 split_user.py \
|
||||
--TASK test_task \
|
||||
--LOSS ce \
|
||||
--MODEL_TYPE lstmlast \
|
||||
--EPOCHS 50
|
BIN
watch_and_help/stan/strategy_inference_model
Executable file
BIN
watch_and_help/stan/strategy_inference_model
Executable file
Binary file not shown.
26
watch_and_help/stan/strategy_inference_model.stan
Normal file
26
watch_and_help/stan/strategy_inference_model.stan
Normal file
|
@ -0,0 +1,26 @@
|
|||
data {
|
||||
int<lower=1> I; // number of question options (22)
|
||||
int<lower=0> N; // number of questions being asked by the user
|
||||
int<lower=1> K; // number of strategies
|
||||
// observed "true" questions of the user
|
||||
int q[N];
|
||||
// array of predicted probabilities of questions given strategies
|
||||
// coming from the forward neural network
|
||||
matrix[I, K] P_q_S[N];
|
||||
}
|
||||
parameters {
|
||||
// probabiliy vector of the strategies being applied by the user
|
||||
// to be inferred by the model here
|
||||
simplex[K] P_S;
|
||||
}
|
||||
model {
|
||||
for (n in 1:N) {
|
||||
// marginal probability vector of the questions being asked
|
||||
vector[I] theta = P_q_S[n] * P_S;
|
||||
// categorical likelihood
|
||||
target += categorical_lpmf(q[n] | theta);
|
||||
}
|
||||
// priors
|
||||
target += dirichlet_lpdf(P_S | rep_vector(1.0, K));
|
||||
}
|
||||
|
190
watch_and_help/stan/strategy_inference_test.R
Normal file
190
watch_and_help/stan/strategy_inference_test.R
Normal file
|
@ -0,0 +1,190 @@
|
|||
library(tidyverse)
|
||||
library(cmdstanr)
|
||||
library(dplyr)
|
||||
|
||||
# index order of the strategies assumed throughout
|
||||
strategies <- c("put_fridge", "put_dishwasher", "read_book")
|
||||
model_type <- "lstmlast"
|
||||
rates <- c("_0", "_10", "_20", "_30", "_40", "_50", "_60", "_70", "_80", "_90", "_100")
|
||||
task_type <- "test_task" # new_test_task test_task
|
||||
loss_type <- "ce"
|
||||
set.seed(9746234)
|
||||
if (task_type=="test_task"){
|
||||
user_num <- 92
|
||||
user <-c(38:(user_num-1))
|
||||
N <- 1
|
||||
}
|
||||
if (task_type=="new_test_task"){
|
||||
user_num <- 9
|
||||
user <-c(0:(user_num-1))
|
||||
N <- 1
|
||||
}
|
||||
|
||||
# read data from csv
|
||||
sel <- vector("list", length(strategies))
|
||||
act_series <- vector("list", user_num)
|
||||
for (u in seq_along(user)){
|
||||
for (rate in rates) {
|
||||
dat <- vector("list", length(strategies))
|
||||
for (i in seq_along(strategies)) {
|
||||
if (rate=="_0"){
|
||||
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], "_rate_", "10", ".csv")) # _60
|
||||
} else if (rate=="_100"){
|
||||
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], ".csv")) # _60
|
||||
} else{
|
||||
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], "_rate", rate, ".csv")) # _60
|
||||
}
|
||||
# strategy assumed for prediction
|
||||
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||
dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
|
||||
dat[[i]]$id <- dat[[i]][,1] # sample based on intention
|
||||
}
|
||||
|
||||
# reset N after inference
|
||||
if (task_type=="test_task"){
|
||||
N <- 1
|
||||
}
|
||||
if (task_type=="new_test_task"){
|
||||
N <- 1
|
||||
}
|
||||
|
||||
# select one action series from one intention
|
||||
if (rate == "_0"){
|
||||
sel[[1]]<-dat[[1]] %>%
|
||||
group_by(task_name) %>%
|
||||
sample_n(N)
|
||||
|
||||
sel[[1]] <- data.frame(sel[[1]])
|
||||
act_series[[u]] <- sel[[1]]$action_id
|
||||
#print(typeof(sel[[1]]))
|
||||
#print(typeof(dat[[1]]))
|
||||
#print(sel[[1]]$action_id[2])
|
||||
}
|
||||
|
||||
print(c('unique action id', sel[[1]]$action_id))
|
||||
|
||||
# filter data from the selected action series, N series per intention
|
||||
dat[[1]]<-subset(dat[[1]], dat[[1]]$action_id == sel[[1]]$action_id[1] | dat[[1]]$action_id == sel[[1]]$action_id[2] | dat[[1]]$action_id == sel[[1]]$action_id[3])
|
||||
dat[[2]]<-subset(dat[[2]], dat[[2]]$action_id == sel[[1]]$action_id[1] | dat[[2]]$action_id == sel[[1]]$action_id[2] | dat[[2]]$action_id == sel[[1]]$action_id[3])
|
||||
dat[[3]]<-subset(dat[[3]], dat[[3]]$action_id == sel[[1]]$action_id[1] | dat[[3]]$action_id == sel[[1]]$action_id[2] | dat[[3]]$action_id == sel[[1]]$action_id[3])
|
||||
row.names(dat) <- NULL
|
||||
print(c('task name 1', dat[[1]]$task_name))
|
||||
print(c('task name 2', dat[[2]]$task_name))
|
||||
print(c('task name 3', dat[[3]]$task_name))
|
||||
print(c('action id 1', dat[[1]]$action_id))
|
||||
print(c('action id 2', dat[[2]]$action_id))
|
||||
print(c('action id 3', dat[[3]]$action_id))
|
||||
|
||||
# create save path
|
||||
dir.create(file.path(paste0("result/", task_type, "/user", user_num, "/", loss_type, "/N", N)), showWarnings = FALSE, recursive = TRUE)
|
||||
dir.create(file.path("temp"), showWarnings = FALSE)
|
||||
save_path <- paste0("result/", task_type, "/user", user_num, "/", loss_type, "/N", N, "/", model_type, "_N", N, "_", "result", rate,"_user", user[[u]], ".csv")
|
||||
|
||||
if(task_type=="test_task"){
|
||||
dat <- do.call(rbind, dat) %>%
|
||||
mutate(index = as.numeric(as.factor(id))) %>%
|
||||
rename(true_strategy = task_name) %>%
|
||||
mutate(
|
||||
true_strategy = factor(
|
||||
#true_strategy, levels = 0:3,
|
||||
true_strategy, levels = 0:2,
|
||||
labels = strategies
|
||||
),
|
||||
q_type = case_when(
|
||||
gt %in% c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 19, 20, 22, 23, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 42, 43, 44, 58, 59, 64, 65, 68, 69, 70, 71, 72, 73, 74) ~ "put_fridge",
|
||||
gt %in% c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 25, 29,30, 31, 32, 33, 34, 37, 38, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57) ~ "put_dishwasher",
|
||||
gt %in% c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45) ~ "read_book",
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
if(task_type=="new_test_task"){
|
||||
dat <- do.call(rbind, dat) %>%
|
||||
mutate(index = as.numeric(as.factor(id))) %>%
|
||||
rename(true_strategy = task_name) %>%
|
||||
mutate(
|
||||
true_strategy = factor(
|
||||
true_strategy, levels = 0:2,
|
||||
labels = strategies
|
||||
),
|
||||
q_type = case_when(
|
||||
# new_test_set
|
||||
gt %in% c(1, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 19, 20, 22, 23, 25, 29, 30, 31, 32, 33, 34, 35, 40, 42, 43, 44, 46, 47, 52, 53, 55, 56, 58, 59, 60, 64, 65, 68, 69, 70, 71, 72, 73, 74, 75, 77, 78) ~ "put_fridge",
|
||||
gt %in% c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74) ~ "put_dishwasher",
|
||||
gt %in% c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 60, 75, 76, 77, 78) ~ "read_book",
|
||||
)
|
||||
)
|
||||
}
|
||||
#print(nrow(dat))
|
||||
#print(dat)
|
||||
|
||||
dat_obs <- dat %>% filter(assumed_strategy == strategies[[i]])
|
||||
N <- nrow(dat_obs)
|
||||
print(c("N: ", N))
|
||||
q <- dat_obs$gt
|
||||
true_strategy <- dat_obs$true_strategy
|
||||
|
||||
K <- length(unique(dat$assumed_strategy))
|
||||
I <- 79
|
||||
|
||||
P_q_S <- array(dim = c(N, I, K))
|
||||
for (n in 1:N) {
|
||||
P_q_S[n, , ] <- dat %>%
|
||||
filter(index == n) %>%
|
||||
select(matches("^act[[:digit:]]+$")) %>%
|
||||
as.matrix() %>%
|
||||
t()
|
||||
for (k in 1:K) {
|
||||
# normalize probabilities
|
||||
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||
}
|
||||
}
|
||||
|
||||
mod <- cmdstan_model(paste0(getwd(),"/strategy_inference_model.stan"))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == "put_fridge")
|
||||
}
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_put_fridge <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_put_fridge$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == "put_dishwasher")
|
||||
}
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_put_dishwasher <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_put_dishwasher$summary(NULL, c("mean","sd")))
|
||||
|
||||
# read_book strategy (should favor index 3)
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == "read_book")
|
||||
}
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_read_book <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_read_book$summary(NULL, c("mean","sd")))
|
||||
|
||||
# save csv
|
||||
df <-rbind(fit_put_fridge$summary(), fit_put_dishwasher$summary(), fit_read_book$summary())
|
||||
write.csv(df,file=save_path,quote=FALSE)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Binary file not shown.
203
watch_and_help/watch_strategy_full/helper.py
Normal file
203
watch_and_help/watch_strategy_full/helper.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn.modules.rnn import RNNCellBase
|
||||
|
||||
def to_cpu(list_of_tensor):
|
||||
if isinstance(list_of_tensor[0], list):
|
||||
list_list_of_tensor = list_of_tensor
|
||||
list_of_tensor = [to_cpu(list_of_tensor)
|
||||
for list_of_tensor in list_list_of_tensor]
|
||||
else:
|
||||
list_of_tensor = [tensor.cpu() for tensor in list_of_tensor]
|
||||
return list_of_tensor
|
||||
|
||||
|
||||
def average_over_list(l):
|
||||
return sum(l) / len(l)
|
||||
|
||||
def _LayerNormGRUCell(input, hidden, w_ih, w_hh, ln, b_ih=None, b_hh=None):
|
||||
gi = F.linear(input, w_ih, b_ih)
|
||||
gh = F.linear(hidden, w_hh, b_hh)
|
||||
i_r, i_i, i_n = gi.chunk(3, 1)
|
||||
h_r, h_i, h_n = gh.chunk(3, 1)
|
||||
|
||||
# use layernorm here
|
||||
resetgate = torch.sigmoid(ln['resetgate'](i_r + h_r))
|
||||
inputgate = torch.sigmoid(ln['inputgate'](i_i + h_i))
|
||||
newgate = torch.tanh(ln['newgate'](i_n + resetgate * h_n))
|
||||
hy = newgate + inputgate * (hidden - newgate)
|
||||
return hy
|
||||
|
||||
class CombinedEmbedding(nn.Module):
|
||||
def __init__(self, pretrained_embedding, embedding):
|
||||
super(CombinedEmbedding, self).__init__()
|
||||
self.pretrained_embedding = pretrained_embedding
|
||||
self.embedding = embedding
|
||||
self.pivot = pretrained_embedding.num_embeddings
|
||||
|
||||
def forward(self, input):
|
||||
outputs = []
|
||||
mask = input < self.pivot
|
||||
outputs.append(self.pretrained_embedding(torch.clamp(input, 0, self.pivot-1)) * mask.unsqueeze(1).float())
|
||||
mask = input >= self.pivot
|
||||
outputs.append(self.embedding(torch.clamp(input, self.pivot) - self.pivot) * mask.unsqueeze(1).float())
|
||||
return sum(outputs)
|
||||
|
||||
|
||||
class writer_helper(object):
|
||||
def __init__(self, writer):
|
||||
self.writer = writer
|
||||
self.all_steps = {}
|
||||
|
||||
def get_step(self, tag):
|
||||
if tag not in self.all_steps.keys():
|
||||
self.all_steps.update({tag: 0})
|
||||
|
||||
step = self.all_steps[tag]
|
||||
self.all_steps[tag] += 1
|
||||
return step
|
||||
|
||||
def scalar_summary(self, tag, value, step=None):
|
||||
if step is None:
|
||||
step = self.get_step(tag)
|
||||
self.writer.add_scalar(tag, value, step)
|
||||
|
||||
def text_summary(self, tag, value, step=None):
|
||||
if step is None:
|
||||
step = self.get_step(tag)
|
||||
self.writer.add_text(tag, value, step)
|
||||
|
||||
|
||||
class Constant():
|
||||
def __init__(self, v):
|
||||
self.v = v
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
|
||||
class LinearStep():
|
||||
def __init__(self, max, min, steps):
|
||||
self.steps = float(steps)
|
||||
self.max = max
|
||||
self.min = min
|
||||
self.cur_step = 0
|
||||
self.v = self.max
|
||||
|
||||
def update(self):
|
||||
v = max(self.max - (self.max - self.min) *
|
||||
self.cur_step / self.steps, self.min)
|
||||
self.cur_step += 1
|
||||
self.v = v
|
||||
|
||||
|
||||
class fc_block(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, norm, activation_fn):
|
||||
super(fc_block, self).__init__()
|
||||
|
||||
block = nn.Sequential()
|
||||
block.add_module('linear', nn.Linear(in_channels, out_channels))
|
||||
if norm:
|
||||
block.add_module('batchnorm', nn.BatchNorm1d(out_channels))
|
||||
if activation_fn is not None:
|
||||
block.add_module('activation', activation_fn())
|
||||
|
||||
self.block = block
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class conv_block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
norm,
|
||||
activation_fn):
|
||||
super(conv_block, self).__init__()
|
||||
|
||||
block = nn.Sequential()
|
||||
block.add_module(
|
||||
'conv',
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride))
|
||||
if norm:
|
||||
block.add_module('batchnorm', nn.BatchNorm2d(out_channels))
|
||||
if activation_fn is not None:
|
||||
block.add_module('activation', activation_fn())
|
||||
|
||||
self.block = block
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
def get_conv_output_shape(shape, block):
|
||||
B = 1
|
||||
input = torch.rand(B, *shape)
|
||||
output = block(input)
|
||||
n_size = output.data.view(B, -1).size(1)
|
||||
return n_size
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, input):
|
||||
return input.view(input.size(0), -1)
|
||||
|
||||
|
||||
def BHWC_to_BCHW(tensor):
|
||||
tensor = torch.transpose(tensor, 1, 3) # BCWH
|
||||
tensor = torch.transpose(tensor, 2, 3) # BCHW
|
||||
return tensor
|
||||
|
||||
|
||||
def LCS(X, Y):
|
||||
|
||||
# find the length of the strings
|
||||
m = len(X)
|
||||
n = len(Y)
|
||||
|
||||
# declaring the array for storing the dp values
|
||||
L = [[None] * (n + 1) for i in range(m + 1)]
|
||||
longest_L = [[[]] * (n + 1) for i in range(m + 1)]
|
||||
longest = 0
|
||||
lcs_set = []
|
||||
|
||||
for i in range(m + 1):
|
||||
for j in range(n + 1):
|
||||
if i == 0 or j == 0:
|
||||
L[i][j] = 0
|
||||
longest_L[i][j] = []
|
||||
elif X[i - 1] == Y[j - 1]:
|
||||
L[i][j] = L[i - 1][j - 1] + 1
|
||||
longest_L[i][j] = longest_L[i - 1][j - 1] + [X[i - 1]]
|
||||
if L[i][j] > longest:
|
||||
lcs_set = []
|
||||
lcs_set.append(longest_L[i][j])
|
||||
longest = L[i][j]
|
||||
elif L[i][j] == longest and longest != 0:
|
||||
lcs_set.append(longest_L[i][j])
|
||||
else:
|
||||
if L[i - 1][j] > L[i][j - 1]:
|
||||
L[i][j] = L[i - 1][j]
|
||||
longest_L[i][j] = longest_L[i - 1][j]
|
||||
else:
|
||||
L[i][j] = L[i][j - 1]
|
||||
longest_L[i][j] = longest_L[i][j - 1]
|
||||
|
||||
if len(lcs_set) > 0:
|
||||
return lcs_set[0]
|
||||
else:
|
||||
return lcs_set
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue