first commit
This commit is contained in:
commit
83b04e2133
109 changed files with 12081 additions and 0 deletions
29
README.md
Normal file
29
README.md
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Inferring Human Intentions from Predicted Action Probabilities
|
||||||
|
|
||||||
|
*Lei Shi, Paul Bürkner, Andreas Bulling*
|
||||||
|
|
||||||
|
*University of Stuttgart, Stuttgart, Germany*
|
||||||
|
|
||||||
|
Accepted by [Workshop on Theory of Mind in Human-AI Interaction at CHI 2024](https://theoryofmindinhaichi2024.wordpress.com/)
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
The code is test in Ubuntu 20.04.
|
||||||
|
|
||||||
|
```
|
||||||
|
pytorch 1.11.0
|
||||||
|
matplotlib 3.3.2
|
||||||
|
pickle 4.0
|
||||||
|
pandas 1.4.3
|
||||||
|
R 4.2.1
|
||||||
|
RStan 2.26.3
|
||||||
|
```
|
||||||
|
To install R, [see here](https://cran.r-project.org/bin/linux/ubuntu/fullREADME.html)
|
||||||
|
|
||||||
|
To install RStan, [see here](https://mc-stan.org/users/interfaces/rstan.html)
|
||||||
|
|
||||||
|
## Experiments
|
||||||
|
|
||||||
|
To train and evaluate the method on Watch-And-Help dataset, see [here](watch_and_help/README.md)
|
||||||
|
|
||||||
|
To train and evaluate the method on Keyboard and Mouse Interaction dataset, see [here](keyboard_and_mouse/README.md)
|
||||||
|
|
69
keyboard_and_mouse/README.MD
Normal file
69
keyboard_and_mouse/README.MD
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
# Keyboard And Mouse Interactive Dataset
|
||||||
|
|
||||||
|
|
||||||
|
# Neural Network
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
The code is test in Ubuntu 20.04.
|
||||||
|
|
||||||
|
pytorch 1.11.0
|
||||||
|
matplotlib 3.3.2
|
||||||
|
pickle 4.0
|
||||||
|
pandas 1.4.3
|
||||||
|
|
||||||
|
## Train
|
||||||
|
|
||||||
|
Set training parameters in train.sh
|
||||||
|
|
||||||
|
Run `sh train.sh` to train the model
|
||||||
|
|
||||||
|
|
||||||
|
## Test
|
||||||
|
|
||||||
|
Run `sh test.sh` to run test on trained model
|
||||||
|
|
||||||
|
Predictions are saved under `prediction/task$i$/`
|
||||||
|
|
||||||
|
|
||||||
|
# Bayesian Inference
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
R 4.2.1
|
||||||
|
RStan [](https://mc-stan.org/users/interfaces/rstan.html)
|
||||||
|
|
||||||
|
|
||||||
|
Run `sh sampler_user.sh` to split prediction to 10% to 90%
|
||||||
|
|
||||||
|
Run `Rscript stan/strategy_inference_test.R` to get results of intention prediction for all users
|
||||||
|
Run `sh stan/plot_user.sh` to plot the bar chart for user intention prediction results of all action sequences
|
||||||
|
|
||||||
|
Run `Rscript stan/strategy_inference_test_full_length.R` to get results of intention prediction (0% to 100%) for all users
|
||||||
|
Run `sh stan/plot_user_length_10_steps.sh` to plot the bar chart for user intention prediction results (0% to 100%) of all action sequences
|
||||||
|
|
||||||
|
Run `sh sampler_single_act.sh` to get the predictions for each individual action sequence.
|
||||||
|
Run `Rscript stan/strategy_inference_test_all_individual_act.R` to get all action sequences (0% to 100%) of all users for intention prediction
|
||||||
|
Run `sh plot_user_all_individual.sh` to plot the bar chart for user intention prediction results of all action sequences
|
||||||
|
Run `sh plot_user_length_10_steps_all_individual.sh` to plot the user intention prediction results (0% to 100%) of all action sequences
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Set training and test parameters in train.sh and test.sh
|
||||||
|
|
||||||
|
Run sh train.sh to train the model.
|
||||||
|
|
||||||
|
Run sh test.sh to run test on trained model.
|
||||||
|
Predictions are saved under prediction/task$i$/
|
||||||
|
|
||||||
|
Run sh sampler_user.sh to split prediction to 10% to 90%
|
||||||
|
|
||||||
|
Run stan/strategy_inference_test.R to get results of intention prediction for all users
|
||||||
|
Run stan/plot_user.py to plot the bar chart for user intention prediction results of all action sequences
|
||||||
|
|
||||||
|
Run stan/strategy_inference_test_full_length.R to get results of intention prediction (0% to 100%) for all users
|
||||||
|
Run stan/plot_user_length_10_users.py to plot the bar chart for user intention prediction results (0% to 100%) of all action sequences
|
||||||
|
|
||||||
|
|
||||||
|
Run stan/strategy_inference_test_all_individual_act.R to get all action sequences (0% to 100%) of all users for intention prediction
|
||||||
|
Run stan/plot_user_all_individual.py to plot the bar chart for user intention prediction results of all action sequences
|
||||||
|
Run stan/plot_user_length_10_steps_all_individual.py to plot the user intention prediction results (0% to 100%) of all action sequences
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,852 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "f5cb2ecf",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"2021-09-28 16:10:28.497166: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import os, pdb\n",
|
||||||
|
"from sklearn.model_selection import GridSearchCV \n",
|
||||||
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||||
|
"from sklearn.metrics import accuracy_score\n",
|
||||||
|
"from tensorflow import keras\n",
|
||||||
|
"from keras.preprocessing.sequence import pad_sequences"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "459ad77b",
|
||||||
|
"metadata": {
|
||||||
|
"scrolled": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"study_data_path = \"../IntentData/\"\n",
|
||||||
|
"data = pd.read_pickle(study_data_path + \"/Preprocessing_data/clean_data.pkl\")\n",
|
||||||
|
"Task_IDs = np.arange(7).tolist()\n",
|
||||||
|
"StartIndexOffset = 0 #if set to 5 ignore first 5 elements\n",
|
||||||
|
"EndIndexOffset = 0 #if set to 5 ignore last 5 elements"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "4ab8c0cc",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"array(['Cmd', 'Toolbar'], dtype=object)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"data.Rule.unique()\n",
|
||||||
|
"data.columns\n",
|
||||||
|
"data.Type.unique()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "05550387",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# grouping by part is needed to have one ruleset for the whole part\n",
|
||||||
|
"# Participant [1,16]\n",
|
||||||
|
"# Repeat for 5 times [1,5]\n",
|
||||||
|
"# ???????? [0,6]\n",
|
||||||
|
"g = data.groupby([\"PID\", \"Part\", \"TaskID\"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "7819da48",
|
||||||
|
"metadata": {
|
||||||
|
"scrolled": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"param_grid = {'n_estimators': [10,50,100], \n",
|
||||||
|
" 'max_depth': [10,20,30]}\n",
|
||||||
|
"\n",
|
||||||
|
"grid = GridSearchCV(RandomForestClassifier(), param_grid, refit = True, verbose = 0, return_train_score=True) "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "e64a6920",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def createTrainTest(test_IDs, task_IDs, start_index_offset, end_index_offset, shapes=False):\n",
|
||||||
|
" assert isinstance(test_IDs, list)\n",
|
||||||
|
" assert isinstance(task_IDs, list)\n",
|
||||||
|
" # Fill data arrays\n",
|
||||||
|
" y_train = []\n",
|
||||||
|
" x_train = []\n",
|
||||||
|
" y_test = []\n",
|
||||||
|
" x_test = []\n",
|
||||||
|
" for current in g.groups.keys():\n",
|
||||||
|
" c = g.get_group(current)\n",
|
||||||
|
" if (c.TaskID.isin(task_IDs).all()):\n",
|
||||||
|
" new_rule = c.Rule.unique()[0]\n",
|
||||||
|
" if end_index_offset == 0:\n",
|
||||||
|
" new_data = c.Event.values[start_index_offset:]\n",
|
||||||
|
" else:\n",
|
||||||
|
" new_data = c.Event.values[start_index_offset:-end_index_offset]\n",
|
||||||
|
" if (c.PID.isin(test_IDs).all()):\n",
|
||||||
|
" y_test.append(new_rule)\n",
|
||||||
|
" x_test.append(new_data)\n",
|
||||||
|
" else:\n",
|
||||||
|
" y_train.append(new_rule)\n",
|
||||||
|
" x_train.append(new_data)\n",
|
||||||
|
" x_train = np.array(x_train)\n",
|
||||||
|
" y_train = np.array(y_train)\n",
|
||||||
|
" x_test = np.array(x_test)\n",
|
||||||
|
" y_test = np.array(y_test)\n",
|
||||||
|
" print('x_train\\n',x_train)\n",
|
||||||
|
" print('y_train\\n',y_train)\n",
|
||||||
|
" print('x_test\\n',x_test)\n",
|
||||||
|
" print('y_test\\n',y_test)\n",
|
||||||
|
" pdb.set_trace()\n",
|
||||||
|
" if (shapes):\n",
|
||||||
|
" print(x_train.shape)\n",
|
||||||
|
" print(y_train.shape)\n",
|
||||||
|
" print(x_test.shape)\n",
|
||||||
|
" print(y_test.shape)\n",
|
||||||
|
" print(np.unique(y_test))\n",
|
||||||
|
" print(np.unique(y_train))\n",
|
||||||
|
" return (x_train, y_train, x_test, y_test)\n",
|
||||||
|
"\n",
|
||||||
|
"def runSVMS(train_test, maxlen=None, plots=False, last_elements=False):\n",
|
||||||
|
" x_train, y_train, x_test, y_test = train_test\n",
|
||||||
|
" # Get maxlen to pad and pad\n",
|
||||||
|
" if (maxlen==None):\n",
|
||||||
|
" maxlen = 0\n",
|
||||||
|
" for d in np.concatenate((x_train,x_test)):\n",
|
||||||
|
" if len(d) > maxlen:\n",
|
||||||
|
" maxlen = len(d)\n",
|
||||||
|
" \n",
|
||||||
|
" truncating_elements = \"post\"\n",
|
||||||
|
" if last_elements:\n",
|
||||||
|
" truncating_elements = \"pre\"\n",
|
||||||
|
"\n",
|
||||||
|
" x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen, dtype='int32', padding='post', truncating=truncating_elements, value=0)\n",
|
||||||
|
" x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen, dtype='int32', padding='post', truncating=truncating_elements, value=0)\n",
|
||||||
|
"\n",
|
||||||
|
" # fitting the model for grid search \n",
|
||||||
|
" grid.fit(x_train, y_train) \n",
|
||||||
|
"\n",
|
||||||
|
" # print how our model looks after hyper-parameter tuning\n",
|
||||||
|
" if (plots==True):\n",
|
||||||
|
" print(grid.best_estimator_) \n",
|
||||||
|
"\n",
|
||||||
|
" # Predict with best SVM\n",
|
||||||
|
" pred = grid.predict(x_test)\n",
|
||||||
|
"\n",
|
||||||
|
" return accuracy_score(pred, y_test), pred, y_test "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "50dac7db",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/tmp/ipykernel_97850/1264473745.py:23: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
|
||||||
|
" x_train = np.array(x_train)\n",
|
||||||
|
"/tmp/ipykernel_97850/1264473745.py:25: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
|
||||||
|
" x_test = np.array(x_test)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"x_train\n",
|
||||||
|
" [array([2, 7, 7, 7, 7, 7, 7, 2, 6, 6, 6, 2, 2, 2])\n",
|
||||||
|
" array([4, 1, 4, 1, 1, 1, 1, 1, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([5, 7, 5, 7, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 3, 3, 3, 3, 3, 6, 6, 5, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([5, 3, 5, 3, 3, 5, 3, 5, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 2, 6, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 3, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 6, 6, 2, 2, 2, 7, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 4, 4, 7, 4, 1, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 5, 7, 7, 5, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 6, 3, 6, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 5, 3, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 3, 4, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 1, 4, 7, 4, 1, 4, 1, 7, 7, 7, 7, 7, 4, 4])\n",
|
||||||
|
" array([5, 7, 7, 1, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 6, 3, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([5, 3, 3, 4, 3, 5, 3, 5, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 6, 2, 2, 7, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 1, 4, 7, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([5, 7, 7, 1, 5, 7, 5, 7, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 3, 5, 2, 2, 6, 3, 3, 6, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([5, 3, 3, 4, 3, 5, 3, 5, 3, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 2, 1, 1, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 1, 4, 7, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 5, 7, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 6, 6, 3, 6, 3, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 5, 3, 5, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 2, 2, 6, 2, 2, 6, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 2, 2, 2, 2, 6, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 4, 4, 7, 1, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 3, 5, 3, 5, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([5, 7, 7, 1, 7, 7, 7, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 6, 1, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 1, 4, 7, 1, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([3, 2, 2, 4, 2, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 3, 5, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 3, 3, 6, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 5, 5, 7, 1, 7, 7, 5, 7, 5, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 2, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 4, 4, 7, 1, 4, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 3, 3, 3, 5, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 6, 3, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([5, 7, 7, 1, 5, 7, 7, 7, 5, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 2, 6, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 4, 4, 7, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 5, 3, 3, 3, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 3, 6, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 7, 7, 5, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 2, 2, 6, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 4, 4, 7, 1, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 5, 3, 3, 3, 5, 3, 4, 4, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 6, 3, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 7, 5, 5, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 5, 3, 4, 4, 4, 3, 5, 3, 5, 5, 3, 3])\n",
|
||||||
|
" array([6, 3, 5, 5, 5, 5, 5, 5, 6, 3, 6, 3, 3, 3, 3])\n",
|
||||||
|
" array([2, 2, 3, 2, 2, 2, 3, 3, 2, 2, 2, 3, 2, 2, 2, 3])\n",
|
||||||
|
" array([1, 7, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 6, 2, 1, 1, 1, 6, 2, 1, 2, 6, 1, 2, 1])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 7, 1, 1, 1, 1, 1, 1, 5, 7, 5, 7])\n",
|
||||||
|
" array([3, 3, 5, 3, 3, 5, 3, 3, 3, 3, 5])\n",
|
||||||
|
" array([2, 3, 3, 5, 3, 3, 3, 3, 2]) array([2, 3, 2, 2, 2, 2, 2, 3])\n",
|
||||||
|
" array([1, 7, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 2, 6, 7, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 5, 7, 7, 5, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 3, 3, 5, 3, 5, 3, 5, 3])\n",
|
||||||
|
" array([6, 3, 3, 5, 6, 3, 3, 6, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 3, 2, 2, 2, 3, 2, 3, 2]) array([1, 7, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 2, 7, 7, 7, 2, 2, 7, 6, 7, 7, 2])\n",
|
||||||
|
" array([5, 7, 7, 1, 1, 7, 5, 7, 1, 1, 7, 1, 7, 1])\n",
|
||||||
|
" array([3, 5, 3, 3, 5, 3, 5, 3, 5, 3])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 6, 6, 3, 3, 6, 3, 6, 5, 5, 5, 5, 5, 6, 6, 6])\n",
|
||||||
|
" array([2, 2, 3, 2, 2, 3, 2, 2, 3]) array([1, 7, 7, 1, 7, 7, 7, 7, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 2, 2, 2, 2, 7, 7, 4, 7, 7, 7])\n",
|
||||||
|
" array([7, 7, 5, 1, 7, 5, 7, 5, 7, 7, 1, 1, 1, 1, 1, 7])\n",
|
||||||
|
" array([3, 5, 3, 5, 3, 3, 5, 3, 3])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 3, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 3, 2, 2, 2, 2, 2, 3]) array([1, 7, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 6, 2, 1, 1, 1, 1, 1, 1, 1, 2, 6, 6, 2, 2, 6, 2])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 6, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 7, 7, 5, 7, 5, 7, 5])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([4, 1, 4, 4, 4, 4, 4, 7, 4, 1, 4, 1, 4, 1, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 5, 7, 5, 7, 7, 7])\n",
|
||||||
|
" array([6, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 2, 6])\n",
|
||||||
|
" array([4, 5, 3, 3, 3, 3, 3, 3, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 6, 2, 7, 7, 7, 7, 7, 7, 6, 2, 2, 2])\n",
|
||||||
|
" array([6, 3, 3, 5, 5, 5, 5, 5, 5, 3, 6, 3, 3, 6, 3])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([4, 1, 4, 7, 4, 4, 4, 1, 4, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 5, 7, 7, 1, 7, 7, 7, 7, 5, 5])\n",
|
||||||
|
" array([6, 2, 2, 1, 1, 2, 1, 2, 6, 1, 2, 1, 2, 6, 1])\n",
|
||||||
|
" array([5, 4, 3, 4, 5, 3, 4, 3, 4, 3, 4, 5, 3, 4])\n",
|
||||||
|
" array([6, 2, 7, 7, 6, 7, 2, 6, 7, 7, 6, 7])\n",
|
||||||
|
" array([6, 3, 3, 5, 5, 1, 1, 3, 5, 6, 3, 5, 6, 3, 5, 6, 3])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([4, 1, 4, 7, 7, 4, 1, 7, 4, 7, 4, 1, 4, 1, 7, 7])\n",
|
||||||
|
" array([7, 4, 5, 7, 1, 1, 7, 1, 7, 5, 5, 7, 2, 7, 7, 7])\n",
|
||||||
|
" array([6, 2, 2, 1, 2, 1, 2, 6, 6, 2, 6, 2, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 3, 4, 3, 4, 4, 3, 4, 5, 3, 4, 4, 3])\n",
|
||||||
|
" array([6, 2, 7, 7, 2, 7, 7, 6, 2, 2, 7, 7])\n",
|
||||||
|
" array([1, 3, 1, 6, 3, 3, 5, 3, 5, 5, 6, 3, 5, 3, 5, 5, 3])\n",
|
||||||
|
" array([2, 3, 4, 3, 4, 2, 3, 2, 3, 4, 4, 2, 3, 3, 4, 3, 3, 4, 3, 2, 3, 2])\n",
|
||||||
|
" array([4, 1, 4, 7, 7, 4, 1, 4, 7, 7, 4, 7, 4, 1, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 7, 5, 1, 7, 7, 7, 1, 1, 1, 7, 5])\n",
|
||||||
|
" array([6, 2, 2, 1, 1, 1, 1, 1, 1, 2, 6, 2, 2, 2])\n",
|
||||||
|
" array([5, 3, 3, 4, 3, 5, 3, 5, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([6, 3, 3, 5, 3, 6, 6, 3, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([4, 1, 4, 7, 7, 7, 7, 7, 7, 4, 1, 4, 4, 1, 4, 1])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 7, 5, 1, 7, 7, 5, 1, 1, 1, 7])\n",
|
||||||
|
" array([6, 2, 2, 1, 1, 2, 2, 1, 1, 2, 1, 2, 6, 1])\n",
|
||||||
|
" array([5, 3, 3, 4, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([6, 3, 3, 5, 5, 3, 6, 3, 5, 5, 6, 3, 5, 5, 3, 5, 5])\n",
|
||||||
|
" array([6, 2, 7, 7, 7, 7, 7, 7, 6, 6, 2, 2])\n",
|
||||||
|
" array([5, 3, 4, 3, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([5, 7, 1, 7, 1, 3, 1, 1, 3, 1, 1, 1, 1, 7, 7, 7, 7, 5, 5])\n",
|
||||||
|
" array([1, 7, 4, 4, 1, 1, 7, 7, 7, 7, 7, 4, 4, 4, 1, 4])\n",
|
||||||
|
" array([6, 2, 1, 2, 2, 2, 2, 2, 6, 6])\n",
|
||||||
|
" array([6, 3, 5, 5, 6, 6, 6, 6, 6, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 2, 2, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([3, 5, 4, 3, 3, 3, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 7, 7, 7, 7, 5, 5, 5, 5])\n",
|
||||||
|
" array([1, 4, 7, 4, 7, 7, 7, 7, 7, 1, 1])\n",
|
||||||
|
" array([6, 2, 2, 1, 2, 2, 2, 2, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 3, 3, 5, 3, 6, 6, 3, 3, 3, 6])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 7, 2, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 3, 3, 3, 5, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 5, 7, 7, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([1, 4, 4, 7, 4, 4, 4, 4, 7, 7, 7, 7, 7, 1, 1, 1])\n",
|
||||||
|
" array([2, 6, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 6, 6])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 3, 3, 3, 3, 5, 5, 5, 5, 5, 6])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 2, 6, 2, 7, 6, 2, 2, 6, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([5, 3, 3, 4, 3, 3, 5, 5, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([7, 5, 1, 7, 2, 2, 3, 7, 7, 7, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([1, 4, 4, 7, 1, 4, 7, 7, 7, 7, 7, 4, 1, 1, 4, 4])\n",
|
||||||
|
" array([6, 2, 2, 1, 2, 2, 2, 2, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 5, 3, 3, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([5, 3, 3, 4, 3, 3, 3, 3, 5, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 5, 5, 7, 4, 7, 7, 7])\n",
|
||||||
|
" array([4, 1, 4, 7, 7, 7, 7, 7, 7, 4, 1, 4, 1, 4, 4])\n",
|
||||||
|
" array([6, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 6, 6])\n",
|
||||||
|
" array([6, 3, 3, 5, 3, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 5, 4, 3, 3, 5, 5, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 3, 2, 2, 3, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 3, 3, 2, 3, 3, 2, 3, 3])\n",
|
||||||
|
" array([2, 6, 2, 2, 2, 6, 2, 6, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([1, 4, 4, 1, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 1, 7, 1, 7, 5, 1, 7, 1, 7, 1, 7, 5, 1])\n",
|
||||||
|
" array([2, 2, 2, 2, 7, 7, 7, 7, 7, 7, 2, 2, 7, 2, 6, 2, 6, 6, 2, 2, 6])\n",
|
||||||
|
" array([5, 3, 5, 3, 3, 3, 5, 3, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 2, 4, 3, 3, 2, 4, 2, 4, 4, 2, 2, 3, 4, 4, 2, 3, 4])\n",
|
||||||
|
" array([2, 3, 3, 3, 4, 2, 3, 4, 3, 4, 2, 3, 4])\n",
|
||||||
|
" array([6, 2, 1, 2, 1, 2, 6, 1, 2, 6, 1, 2, 1, 2, 6, 1])\n",
|
||||||
|
" array([4, 1, 7, 4, 1, 1, 7, 1, 4, 7, 4, 7, 4, 7, 1, 4, 7])\n",
|
||||||
|
" array([7, 5, 1, 7, 1, 7, 7, 1, 7, 5, 1, 7, 5, 1, 7, 1])\n",
|
||||||
|
" array([6, 7, 2, 7, 2, 7, 6, 7, 2, 2, 2, 2, 7, 6])\n",
|
||||||
|
" array([3, 5, 4, 3, 4, 4, 3, 5, 4, 3, 5, 4, 3, 4, 3, 4])\n",
|
||||||
|
" array([2, 3, 4, 2, 4, 2, 3, 4, 2, 4, 2, 4, 2, 3, 4])\n",
|
||||||
|
" array([6, 3, 3, 4, 3, 4, 6, 3, 4, 3, 4, 4, 3])\n",
|
||||||
|
" array([2, 6, 1, 2, 1, 2, 6, 1, 2, 1, 2, 6, 1, 2, 1, 2])\n",
|
||||||
|
" array([1, 4, 7, 4, 7, 1, 4, 7, 1, 4, 7, 4, 7, 7, 1, 4])\n",
|
||||||
|
" array([7, 4, 5, 1, 7, 7, 1, 7, 7, 1, 7, 5, 1, 7, 1, 7, 1])\n",
|
||||||
|
" array([6, 7, 2, 7, 6, 7, 2, 7, 2, 7, 6, 7])\n",
|
||||||
|
" array([3, 5, 3, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 3, 2, 2, 2, 2, 3, 4, 4, 4, 4, 2, 4, 4])\n",
|
||||||
|
" array([2, 3, 3, 2, 3, 2, 3, 2, 3, 3, 4, 4, 2, 6, 2, 6, 2, 6, 2, 6])\n",
|
||||||
|
" array([2, 6, 1, 2, 1, 2, 1, 2, 2, 6, 1, 1, 1, 1, 2, 1])\n",
|
||||||
|
" array([4, 1, 7, 1, 7, 1, 7, 1, 4, 7, 1, 4, 7, 1, 7, 7])\n",
|
||||||
|
" array([7, 5, 1, 7, 1, 7, 5, 1, 7, 1, 7, 5, 1, 7, 5, 1])\n",
|
||||||
|
" array([6, 7, 2, 7, 6, 7, 6, 7, 6, 7, 2, 7])\n",
|
||||||
|
" array([3, 5, 3, 3, 3, 5, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 3, 2, 2, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 3, 3, 5, 5, 5, 5, 3, 5, 3, 5, 3, 6, 5, 3, 6, 5])\n",
|
||||||
|
" array([2, 6, 1, 2, 1, 2, 6, 1, 2, 1, 2, 2, 2, 1, 2, 1])\n",
|
||||||
|
" array([1, 4, 7, 4, 7, 1, 4, 7, 4, 7, 1, 4, 7, 4, 7])\n",
|
||||||
|
" array([7, 5, 1, 7, 1, 7, 5, 1, 7, 1, 7, 1, 7, 5])\n",
|
||||||
|
" array([6, 6, 7, 2, 7, 6, 7, 6, 7, 6, 7, 2, 7])\n",
|
||||||
|
" array([2, 3, 2, 6, 3, 5, 5, 3, 6, 3, 6, 3, 6, 3, 5, 5, 5, 5])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 5, 3, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 2, 1, 1, 1, 1, 1, 1, 2, 2, 6, 2, 6, 2])\n",
|
||||||
|
" array([2, 3, 3, 4, 2, 3, 3, 3, 2, 2, 2, 4, 4, 4, 4, 4, 2, 3, 3, 2, 3, 2])\n",
|
||||||
|
" array([4, 4, 1, 7, 4, 4, 4, 4, 1, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 7, 7, 5, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 3, 3, 5, 3, 3, 6, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([5, 3, 3, 4, 4, 4, 4, 4, 4, 3, 5, 3, 3, 3, 5])\n",
|
||||||
|
" array([6, 2, 2, 1, 2, 6, 2, 2, 6, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([4, 1, 4, 7, 4, 1, 1, 4, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([6, 2, 7, 2, 6, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 7, 7, 5, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 3, 5, 2, 2, 3, 6, 3, 3, 2, 6, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([5, 3, 3, 4, 3, 5, 3, 5, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 2, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 2, 2, 3, 2, 4, 2, 3, 2, 2, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([4, 1, 4, 7, 4, 4, 4, 4, 1, 7, 7, 7, 7, 7, 1])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 7, 1, 7, 5, 1, 7, 1, 7, 1])\n",
|
||||||
|
" array([6, 3, 3, 5, 3, 6, 5, 5, 3, 6, 5, 3, 5, 3, 5])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 3, 5, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 6, 2, 6, 2, 2, 6, 3, 3, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 4, 4, 2, 3, 4, 2, 3, 3, 4, 2, 3, 4])\n",
|
||||||
|
" array([3, 3, 1, 4, 4, 7, 4, 4, 4, 4, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 7, 7, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 7, 5, 5, 5, 7, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 3, 5, 5, 3, 5, 5, 3, 6, 5, 3, 5, 3, 6, 5])\n",
|
||||||
|
" array([5, 3, 3, 3, 4, 4, 3, 5, 4, 3, 4, 3, 5, 4])\n",
|
||||||
|
" array([6, 2, 2, 1, 2, 2, 6, 2, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 2, 3, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([1, 4, 4, 7, 1, 4, 1, 4, 1, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([6, 2, 7, 7, 2, 7, 2, 7, 6, 7, 6, 7, 7, 6])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 1, 1, 7, 1, 7, 5, 1, 7, 5, 1])\n",
|
||||||
|
" array([7, 5, 1, 7, 2, 2, 5, 5, 5, 5, 5, 2, 2, 2, 2, 1, 5, 5, 5, 5, 2, 7,\n",
|
||||||
|
" 2, 7, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 4, 4, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4, 2])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 2, 2, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 5, 4, 3, 4, 3, 5, 4, 3, 3, 4, 3, 5, 4, 4, 3, 3, 5, 4])\n",
|
||||||
|
" array([2, 6, 7, 2, 7, 2, 2, 2, 6, 7, 2, 6, 7, 2, 7, 2, 6, 7])\n",
|
||||||
|
" array([2, 3, 5, 3, 3, 5, 2, 3, 5, 2, 3, 3, 5, 3, 4, 5, 2, 3, 5])\n",
|
||||||
|
" array([1, 4, 7, 4, 7, 4, 1, 7, 4, 7, 1, 4, 7, 4, 7])\n",
|
||||||
|
" array([5, 7, 1, 7, 1, 5, 7, 1, 7, 5, 1, 7, 1, 7, 1])\n",
|
||||||
|
" array([2, 3, 1, 2, 1, 2, 1, 2, 1, 3, 2, 3, 1])\n",
|
||||||
|
" array([2, 6, 1, 6, 1, 6, 1, 2, 6, 1, 2, 6, 1, 6, 1])\n",
|
||||||
|
" array([5, 3, 4, 3, 3, 3, 4, 3, 4, 3, 4, 4, 3, 5, 4, 3, 5, 4])\n",
|
||||||
|
" array([6, 7, 2, 7, 2, 7, 6, 7, 6, 7, 2, 7])\n",
|
||||||
|
" array([2, 3, 5, 3, 5, 3, 5, 3, 5, 6, 3, 2, 6, 5, 5, 3])\n",
|
||||||
|
" array([1, 4, 7, 4, 7, 4, 1, 7, 1, 4, 7, 4, 7, 4, 7])\n",
|
||||||
|
" array([5, 7, 1, 7, 1, 5, 7, 1, 7, 5, 1, 7, 1, 1, 5, 7])\n",
|
||||||
|
" array([3, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 3, 4])\n",
|
||||||
|
" array([2, 6, 1, 2, 1, 2, 6, 1, 2, 6, 1, 2, 6, 1, 2, 1])\n",
|
||||||
|
" array([3, 5, 4, 5, 4, 5, 4, 5, 4, 5, 3, 4, 4, 5])\n",
|
||||||
|
" array([7, 2, 6, 7, 7, 2, 2, 6, 7, 2, 7, 2, 7, 6, 7])\n",
|
||||||
|
" array([6, 3, 5, 3, 5, 6, 5, 5, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5])\n",
|
||||||
|
" array([1, 4, 7, 4, 7, 4, 1, 7, 4, 1, 7, 4, 7, 4, 1, 7])\n",
|
||||||
|
" array([5, 7, 1, 7, 1, 7, 1, 5, 7, 1, 5, 7, 1, 7, 1])\n",
|
||||||
|
" array([3, 2, 4, 2, 4, 4, 2, 4, 2, 3, 4, 2, 3, 4, 4, 2, 3])\n",
|
||||||
|
" array([2, 6, 1, 2, 1, 2, 6, 1, 2, 6, 1, 2, 1, 2, 1])\n",
|
||||||
|
" array([5, 3, 5, 5, 4, 3, 4, 3, 4, 3, 5, 4, 3, 5, 4, 4, 3, 5])\n",
|
||||||
|
" array([3, 6, 3, 7, 2, 7, 6, 7, 2, 7, 6, 2, 6, 7])\n",
|
||||||
|
" array([6, 3, 5, 3, 5, 6, 3, 5, 5, 3, 5, 6, 3, 5, 5, 3, 6])\n",
|
||||||
|
" array([1, 4, 7, 4, 7, 4, 1, 7, 4, 1, 7, 4, 7, 4, 7])\n",
|
||||||
|
" array([5, 7, 1, 7, 1, 7, 1, 7, 7, 1, 5, 7, 1, 5, 7, 1])\n",
|
||||||
|
" array([3, 2, 4, 2, 4, 3, 2, 4, 2, 4, 3, 2, 4, 2, 4])\n",
|
||||||
|
" array([2, 6, 1, 2, 1, 2, 6, 1, 2, 6, 1, 2, 1, 2, 6, 1])\n",
|
||||||
|
" array([5, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 5, 4])\n",
|
||||||
|
" array([6, 7, 2, 7, 2, 6, 2, 6, 2, 2, 5, 7, 6, 7, 2, 7, 2, 7])\n",
|
||||||
|
" array([6, 3, 5, 3, 5, 3, 6, 5, 5, 3, 5, 3, 6, 5, 6, 3, 5])\n",
|
||||||
|
" array([1, 4, 7, 4, 7, 4, 7, 4, 7, 4, 7, 1, 4, 7])\n",
|
||||||
|
" array([7, 5, 5, 7, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 5, 5, 3, 3, 3, 3, 3, 3, 3])\n",
|
||||||
|
" array([7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 2, 2])\n",
|
||||||
|
" array([2, 3, 2, 3, 2, 3, 2, 2, 2])\n",
|
||||||
|
" array([1, 1, 1, 4, 1, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 6, 2, 6, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 3, 6, 3, 6, 3, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 5, 3, 3, 3, 3, 3, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 6, 2, 2, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 3, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([1, 4, 1, 4, 1, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([6, 2, 6, 2, 6, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([6, 3, 6, 3, 6, 3, 3, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 5, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 6, 6, 6, 2, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 3, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([1, 1, 1, 4, 4, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([6, 2, 2, 6, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 3, 6, 3, 3, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([5, 7, 5, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 3, 5, 3, 5, 3, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 6, 6, 6, 2, 2, 7, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 3, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([1, 1, 1, 4, 4, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 2, 2, 2, 2, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 3, 3, 3, 3, 3, 6, 6, 7, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 5, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 3, 3, 5, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 6, 6, 6, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([4, 4, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 2, 2, 2, 2, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 5, 5, 5, 5, 5, 3, 3, 3, 3, 3, 3, 6, 6, 6, 6])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 7, 7, 7, 7, 5, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 6, 1, 2, 1, 2, 1, 2, 6, 2, 6, 1, 1, 1, 2])\n",
|
||||||
|
" array([6, 2, 7, 7, 2, 6, 2, 2, 7, 7, 7, 7, 2, 6])\n",
|
||||||
|
" array([5, 3, 3, 4, 4, 3, 4, 3, 3, 4, 3, 5, 4, 3, 5, 4])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 3, 6, 3, 6, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([4, 1, 4, 7, 7, 4, 7, 4, 7, 1, 4, 7, 4, 6, 6, 7, 6])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 7, 5, 1, 7, 1, 7, 5, 1, 7, 1, 5])\n",
|
||||||
|
" array([2, 6, 2, 1, 1, 2, 1, 2, 1, 2, 6, 1, 2, 6, 1])\n",
|
||||||
|
" array([6, 2, 7, 7, 6, 7, 2, 7, 6, 7, 2, 7])\n",
|
||||||
|
" array([5, 3, 3, 4, 4, 3, 5, 3, 3, 5, 3, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 3, 1, 1, 2, 4, 4, 2, 4, 2, 3, 2, 3, 4, 4, 2, 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 5, 3, 6, 5, 3, 5, 6, 3, 5, 3, 5])\n",
|
||||||
|
" array([1, 4, 4, 7, 7, 1, 4, 5, 5, 7, 4, 7, 4, 7, 1, 4, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 5, 1, 7, 5, 1, 7, 5, 1, 7, 1, 7, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 1, 2, 6, 2, 1, 1, 2, 1, 2, 1])\n",
|
||||||
|
" array([6, 2, 7, 7, 6, 2, 7, 6, 7, 2, 7, 6, 7])\n",
|
||||||
|
" array([5, 3, 3, 4, 4, 3, 4, 3, 5, 3, 5, 4, 4, 3, 4])\n",
|
||||||
|
" array([2, 3, 2, 4, 4, 2, 3, 4, 2, 4, 2, 3, 4, 2, 3, 3, 4])\n",
|
||||||
|
" array([6, 3, 3, 5, 5, 3, 3, 6, 3, 3, 5, 5, 5, 5])\n",
|
||||||
|
" array([4, 1, 4, 7, 7, 4, 7, 4, 7, 4, 7, 1, 4, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 5, 1, 7, 1, 5, 7, 1, 7, 1, 1, 5, 1, 7, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 1, 2, 1, 2, 1, 2, 6, 1, 2, 1])\n",
|
||||||
|
" array([6, 2, 7, 7, 6, 7, 2, 7, 6, 7, 2, 7])\n",
|
||||||
|
" array([5, 3, 3, 4, 5, 3, 4, 5, 3, 4, 3, 4, 3, 4, 3, 5, 4])\n",
|
||||||
|
" array([2, 3, 2, 4, 4, 2, 3, 4, 2, 4, 2, 3, 4, 2, 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 5, 3, 5, 3, 6, 5, 3, 6, 5, 3, 3, 5])\n",
|
||||||
|
" array([4, 1, 4, 7, 4, 7, 7, 4, 7, 4, 1, 1, 4, 7, 7])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 7, 1, 7, 5, 5, 1, 7, 1, 7, 5, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 1, 1, 1, 1, 1, 2, 6, 2, 2, 2, 6])\n",
|
||||||
|
" array([6, 2, 7, 7, 6, 7, 6, 7, 2, 7, 2, 7])\n",
|
||||||
|
" array([3, 5, 3, 4, 4, 5, 3, 4, 3, 4, 5, 3, 4, 3, 4])\n",
|
||||||
|
" array([2, 3, 2, 4, 4, 2, 4, 2, 4, 2, 3, 4, 2, 4])\n",
|
||||||
|
" array([6, 3, 3, 5, 5, 3, 6, 3, 5, 5, 3, 6, 5, 3, 6, 5])\n",
|
||||||
|
" array([1, 4, 4, 7, 7, 4, 1, 7, 4, 4, 4, 7, 7, 7, 1])\n",
|
||||||
|
" array([2, 3, 2, 2, 3, 2, 3, 2, 2])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 1, 7, 7, 7, 5, 7, 5])\n",
|
||||||
|
" array([2, 6, 1, 2, 1, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 3, 3, 3, 3, 5, 5, 3, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 6, 3, 6, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 4, 4, 7, 7, 7, 7, 7, 7, 1, 4, 4, 4])\n",
|
||||||
|
" array([2, 3, 2, 4, 4, 4, 4, 4, 4, 2, 3, 2, 3, 2, 2])\n",
|
||||||
|
" array([1, 7, 5, 7, 7, 5, 7, 5, 7, 5, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([4, 3, 5, 3, 4, 3, 5, 5, 3, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 3, 3, 5, 3, 3, 3, 6, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 2, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 4, 7, 3, 7, 3, 1, 1, 4, 4, 1, 4, 4, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 2, 2, 3, 3, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 5, 7, 7, 5, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 6, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 3, 5, 4, 3, 5, 3, 5, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([5, 6, 3, 3, 3, 3, 3, 3, 5, 5, 5, 5, 5, 6, 6, 6])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 1, 4, 7, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 4, 3, 2, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([1, 7, 5, 7, 7, 7, 7, 7, 5, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 6, 2, 6, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([4, 3, 3, 5, 3, 3, 5, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 3, 3, 3, 6, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 4, 1, 7, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 2, 3, 2, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([1, 7, 5, 7, 1, 1, 1, 1, 1, 7, 7, 7, 7, 5])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 3, 4, 3, 5, 3, 5, 5, 3, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 3, 3, 5, 2, 3, 2, 6, 6, 3, 2, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 1, 4, 7, 4, 4, 4, 4, 1, 7, 7, 7, 7, 7, 4])\n",
|
||||||
|
" array([4, 4, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7, 1, 1])\n",
|
||||||
|
" array([2, 6, 2, 2, 2, 6, 2, 6, 2, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([4, 4, 4, 4, 4, 4, 4, 5, 2, 5, 5, 5, 5, 5, 2, 2, 2])\n",
|
||||||
|
" array([6, 3, 3, 6, 3, 3, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 7, 7, 7, 7, 5, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 3, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 6, 6, 6, 2, 2, 4, 4, 4, 4, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 4, 4, 4, 4, 4, 1, 1, 1, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([6, 2, 2, 6, 2, 2, 6, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 4, 5, 5, 3, 3, 3, 5, 5, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 3, 3, 3, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 5, 7, 5, 7, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 2, 2, 3, 2, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 6, 6, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 4, 4, 4, 4, 4, 4, 1, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 6, 2, 2, 6, 2, 2, 6, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 3, 3, 3, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 3, 3, 3, 6, 3, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 7, 5, 7, 7, 5, 7, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 2, 3, 2, 3, 3, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 2, 6, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 4, 4, 1, 4, 4, 4, 1, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 6, 2, 2, 2, 2, 6, 2, 6, 1, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([5, 3, 3, 3, 5, 3, 3, 3, 5, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 3, 3, 3, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 7, 7, 7, 5, 7, 5, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 7, 2, 7, 2, 2, 6, 6, 6, 7, 6, 7, 7, 7])\n",
|
||||||
|
" array([1, 4, 4, 4, 4, 1, 1, 1, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 6, 2, 2, 2, 6, 2, 2, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 5, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 6, 3, 6, 3, 3, 6, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 7, 5, 7, 5, 7, 7, 5, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 2, 3, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 6, 6, 2, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([7, 6, 2, 6, 2, 2, 2, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 4, 1, 7, 4, 4, 4, 4, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 7, 5, 7, 5, 7, 1, 1, 1, 1, 1, 5, 7, 7, 7])\n",
|
||||||
|
" array([4, 3, 3, 5, 3, 4, 4, 4, 3, 3, 3, 5, 5, 4, 4, 4])\n",
|
||||||
|
" array([6, 3, 3, 5, 2, 3, 2, 3, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 2, 6, 1, 1, 1, 1, 1, 1, 2, 6, 2, 6, 2, 2])\n",
|
||||||
|
" array([2, 2, 3, 4, 2, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 7, 6, 2, 6, 2, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 1, 4, 4, 4, 7, 4, 1, 4, 1, 4, 4, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 7, 5, 7, 7, 7, 5, 5, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 4, 5, 3, 3, 5, 4, 4, 4, 4, 4, 3, 3, 3, 5])\n",
|
||||||
|
" array([3, 3, 6, 5, 3, 3, 6, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([1, 2, 2, 6, 2, 2, 6, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 3, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 7, 2, 2, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 4, 7, 1, 4, 1, 4, 4, 1, 1, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 7, 7, 5, 7, 5, 7, 7, 7, 1, 1, 1, 1, 1, 5])\n",
|
||||||
|
" array([3, 3, 5, 4, 3, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([5, 6, 3, 3, 3, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 1, 2, 6, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 4, 2, 3, 2, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 7, 2, 2, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 4, 1, 1, 7, 4, 1, 4, 1, 4, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([7, 7, 5, 1, 7, 5, 7, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 5, 3, 3, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 3, 6, 5, 3, 6, 3, 6, 3, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 2, 6, 1, 2, 2, 2, 2, 1, 2, 2, 6, 1, 1, 1, 1])\n",
|
||||||
|
" array([2, 4, 2, 3, 2, 3, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([2, 6, 7, 2, 6, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([4, 4, 1, 7, 4, 4, 1, 4, 1, 4, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 7, 7, 5, 7, 5, 7, 5, 7, 5, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([4, 5, 3, 3, 3, 5, 3, 3, 5, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([5, 6, 3, 3, 3, 6, 6, 3, 3, 6, 3, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 1, 2, 6, 2, 2, 6, 2, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([1, 2, 3, 1, 2, 2, 2, 3, 2, 3, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 2, 2, 6, 6, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 1, 1, 1, 7, 7, 7, 7, 7, 7, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 3, 3, 3, 3, 3, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([5, 2, 2, 6, 3, 3, 3, 3, 6, 6, 3, 3, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 6, 2, 2, 2, 2, 2, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([7, 5, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1, 5, 5])\n",
|
||||||
|
" array([2, 3, 2, 2, 3, 2, 2, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 7, 7, 7, 7, 7, 6, 6, 2, 2])\n",
|
||||||
|
" array([1, 1, 1, 7, 7, 7, 7, 7, 7, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([5, 3, 3, 3, 3, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4, 4, 4, 5, 5, 4, 5, 5,\n",
|
||||||
|
" 4])\n",
|
||||||
|
" array([3, 6, 3, 5, 5, 5, 5, 5, 5, 6, 3, 3, 6, 3, 3])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 6, 2, 6, 2])\n",
|
||||||
|
" array([1, 7, 7, 1, 7, 1, 7, 1, 7, 1, 7, 1, 5, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 3, 2, 2, 3, 2, 3, 2, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 2, 2, 6, 6, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 7, 1, 7, 7, 7, 7, 7, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5])\n",
|
||||||
|
" array([3, 6, 3, 3, 3, 6, 3, 6, 3, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([6, 2, 2, 1, 2, 2, 2, 2, 6, 2, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([7, 1, 1, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 5, 5, 5])\n",
|
||||||
|
" array([2, 3, 2, 2, 2, 2, 3, 2, 3, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 6, 2, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 7, 1, 1, 1, 7, 7, 7, 7, 7, 4, 4, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 3, 3, 3, 3, 3, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 6, 3, 3, 6, 3, 3, 6, 3, 6, 3, 6, 5, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 6, 2, 2, 2, 2, 6, 2, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([7, 1, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 5, 5])\n",
|
||||||
|
" array([2, 3, 2, 2, 3, 2, 2, 3, 2, 3, 2, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 3, 6, 2, 2, 2, 6, 6, 6, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([1, 7, 7, 7, 7, 7, 7, 1, 1, 4, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 3, 3, 3, 3, 3, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 5, 4])\n",
|
||||||
|
" array([3, 6, 3, 3, 3, 3, 3, 6, 5, 5, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([2, 6, 2, 2, 6, 2, 2, 2, 2, 6, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([7, 7, 1, 7, 7, 7, 7, 1, 1, 1, 1, 1, 5, 5])\n",
|
||||||
|
" array([5, 3, 6, 3, 5, 3, 5, 3, 5, 6, 3, 5, 3, 6, 5])\n",
|
||||||
|
" array([2, 1, 2, 6, 1, 2, 1, 2, 6, 1, 2, 6, 2, 2, 2, 6, 6, 6, 2, 6, 1, 2,\n",
|
||||||
|
" 1])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 7, 5, 7, 5, 7, 7, 5])\n",
|
||||||
|
" array([2, 2, 3, 3, 3, 2, 2, 3, 4, 4, 4, 4, 4, 2, 2])\n",
|
||||||
|
" array([3, 3, 5, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 5, 5])\n",
|
||||||
|
" array([1, 4, 7, 5, 5, 7, 4, 1, 4, 4, 1, 4, 4, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 2, 6, 7, 7, 7, 7, 7, 7, 6, 6, 6, 2, 2, 6])\n",
|
||||||
|
" array([2, 3, 6, 2, 6, 5, 5, 5, 5, 5, 5, 6, 3, 3, 6, 3, 6, 6])\n",
|
||||||
|
" array([2, 6, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 6, 6])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1, 7, 5, 7, 5, 7, 7])\n",
|
||||||
|
" array([2, 2, 3, 4, 4, 4, 4, 4, 4, 2, 2, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3])\n",
|
||||||
|
" array([5, 3, 3, 4, 4, 4, 4, 4, 4, 3, 5, 3, 5, 3, 3])\n",
|
||||||
|
" array([1, 4, 4, 7, 7, 4, 1, 7, 4, 7, 4, 1, 7, 4, 7])\n",
|
||||||
|
" array([6, 2, 7, 7, 7, 7, 7, 7, 6, 2, 6, 2])\n",
|
||||||
|
" array([2, 3, 2, 5, 5, 5, 5, 5, 5, 2, 3, 2, 3, 2, 2])\n",
|
||||||
|
" array([2, 2, 6, 1, 2, 2, 2, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 7, 7, 7, 7, 5])\n",
|
||||||
|
" array([2, 3, 4, 2, 4, 4, 4, 4, 4, 2, 3, 2, 2, 2])\n",
|
||||||
|
" array([3, 5, 3, 4, 4, 4, 4, 4, 4, 3, 5, 3, 5, 3, 5, 3])\n",
|
||||||
|
" array([4, 1, 4, 7, 7, 7, 7, 7, 7, 4, 4, 1, 4, 1, 1, 1, 4])\n",
|
||||||
|
" array([6, 2, 7, 7, 7, 7, 7, 7, 2, 2, 6, 6])\n",
|
||||||
|
" array([6, 3, 3, 3, 5, 3, 5, 5, 5, 5, 5, 3, 6, 3, 3, 3])\n",
|
||||||
|
" array([6, 2, 2, 1, 1, 1, 1, 1, 1, 2, 6, 2, 2, 2])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 7, 1, 1, 1, 1, 1, 7, 7, 7, 5, 7, 5])\n",
|
||||||
|
" array([2, 3, 2, 4, 4, 4, 4, 4, 4, 2, 3, 2, 2, 2, 2])\n",
|
||||||
|
" array([5, 3, 3, 4, 4, 4, 4, 4, 4, 3, 5, 3, 5, 3, 5, 3])\n",
|
||||||
|
" array([4, 1, 4, 7, 7, 7, 7, 7, 7, 4, 1, 1, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 7, 7, 7, 7, 7, 2, 6, 6, 2])\n",
|
||||||
|
" array([6, 3, 3, 5, 5, 5, 5, 5, 5, 3, 3, 3, 3, 6])\n",
|
||||||
|
" array([6, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 6, 6])\n",
|
||||||
|
" array([7, 5, 7, 1, 1, 1, 1, 1, 1, 7, 5, 7, 5, 7, 7, 5])\n",
|
||||||
|
" array([2, 3, 2, 4, 4, 4, 4, 4, 4, 2, 3, 2, 3, 2, 2])\n",
|
||||||
|
" array([5, 3, 3, 4, 4, 4, 4, 4, 4, 3, 5, 3, 5, 3, 5, 3])\n",
|
||||||
|
" array([4, 1, 4, 7, 7, 7, 7, 7, 7, 4, 1, 4, 1, 1, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 7, 7, 7, 7, 7, 2, 2, 6, 6])]\n",
|
||||||
|
"y_train\n",
|
||||||
|
" [5. 3. 6. 1. 0. 4. 2. 5. 3. 6. 1. 0. 4. 2. 5. 3. 6. 1. 0. 4. 2. 5. 3. 6.\n",
|
||||||
|
" 1. 0. 4. 2. 5. 3. 6. 1. 0. 4. 2. 4. 5. 3. 2. 0. 1. 6. 4. 5. 3. 2. 0. 1.\n",
|
||||||
|
" 6. 4. 5. 3. 2. 0. 1. 6. 4. 5. 3. 2. 0. 1. 6. 4. 5. 3. 2. 0. 1. 6. 0. 1.\n",
|
||||||
|
" 2. 3. 4. 5. 6. 0. 1. 2. 3. 4. 5. 6. 0. 1. 2. 3. 4. 5. 6. 0. 1. 2. 3. 4.\n",
|
||||||
|
" 5. 6. 0. 1. 2. 3. 4. 5. 6. 2. 3. 6. 4. 0. 5. 1. 2. 3. 6. 4. 0. 5. 1. 2.\n",
|
||||||
|
" 3. 6. 4. 0. 5. 1. 2. 3. 6. 4. 0. 5. 1. 2. 3. 6. 4. 0. 5. 1. 5. 0. 6. 3.\n",
|
||||||
|
" 4. 1. 2. 5. 0. 6. 3. 4. 1. 2. 5. 0. 6. 3. 4. 1. 2. 5. 0. 6. 3. 4. 1. 2.\n",
|
||||||
|
" 5. 0. 6. 3. 4. 1. 2. 0. 2. 1. 4. 3. 6. 5. 0. 2. 1. 4. 3. 6. 5. 0. 2. 1.\n",
|
||||||
|
" 4. 3. 6. 5. 0. 2. 1. 4. 3. 6. 5. 0. 2. 1. 4. 3. 6. 5. 1. 0. 4. 2. 3. 5.\n",
|
||||||
|
" 6. 1. 0. 4. 2. 3. 5. 6. 1. 0. 4. 2. 3. 5. 6. 1. 0. 4. 2. 3. 5. 6. 1. 0.\n",
|
||||||
|
" 4. 2. 3. 5. 6. 6. 2. 4. 0. 5. 1. 3. 6. 2. 4. 0. 5. 1. 3. 6. 2. 4. 0. 5.\n",
|
||||||
|
" 1. 3. 6. 2. 4. 0. 5. 1. 3. 6. 2. 4. 0. 5. 1. 3. 6. 0. 5. 2. 3. 4. 1. 6.\n",
|
||||||
|
" 0. 5. 2. 3. 4. 1. 6. 0. 5. 2. 3. 4. 1. 6. 0. 5. 2. 3. 4. 1. 6. 0. 5. 2.\n",
|
||||||
|
" 3. 4. 1. 6. 4. 5. 0. 2. 1. 3. 6. 4. 5. 0. 2. 1. 3. 6. 4. 5. 0. 2. 1. 3.\n",
|
||||||
|
" 6. 4. 5. 0. 2. 1. 3. 6. 4. 5. 0. 2. 1. 3. 2. 6. 4. 0. 1. 5. 3. 2. 6. 4.\n",
|
||||||
|
" 0. 1. 5. 3. 2. 6. 4. 0. 1. 5. 3. 2. 6. 4. 0. 1. 5. 3. 2. 6. 4. 0. 1. 5.\n",
|
||||||
|
" 3. 3. 4. 0. 1. 6. 2. 5. 3. 4. 0. 1. 6. 2. 5. 3. 4. 0. 1. 6. 2. 5. 3. 4.\n",
|
||||||
|
" 0. 1. 6. 2. 5. 3. 4. 0. 1. 6. 2. 5. 2. 5. 3. 6. 0. 1. 4. 2. 5. 3. 6. 0.\n",
|
||||||
|
" 1. 4. 2. 5. 3. 6. 0. 1. 4. 2. 5. 3. 6. 0. 1. 4. 2. 5. 3. 6. 0. 1. 4. 2.\n",
|
||||||
|
" 5. 3. 0. 1. 4. 6. 2. 5. 3. 0. 1. 4. 6. 2. 5. 3. 0. 1. 4. 6. 2. 5. 3. 0.\n",
|
||||||
|
" 1. 4. 6. 2. 5. 3. 0. 1. 4. 6. 1. 4. 6. 2. 0. 3. 5. 1. 4. 6. 2. 0. 3. 5.\n",
|
||||||
|
" 1. 4. 6. 2. 0. 3. 5. 1. 4. 6. 2. 0. 3. 5. 1. 4. 6. 2. 0. 3. 5.]\n",
|
||||||
|
"x_test\n",
|
||||||
|
" [array([4, 1, 1, 4, 4, 7, 7, 7, 7, 7, 7, 4, 1, 4, 4, 4])\n",
|
||||||
|
" array([2, 3, 2, 4, 4, 4, 4, 4, 4, 2, 2, 2, 3, 2, 3])\n",
|
||||||
|
" array([5, 3, 3, 4, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 2, 2, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 3, 6, 5, 3, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 7, 5, 1, 7, 7, 5, 7, 7, 1, 1, 1, 1, 1, 5, 5])\n",
|
||||||
|
" array([1, 4, 7, 4, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 2, 4, 3, 2, 2, 3, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 5, 3, 4, 4, 3, 3, 3, 3, 5, 5, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 2, 2, 7, 7, 7, 6, 6, 7, 6, 6, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 2, 6, 1, 1, 2, 2, 2, 2, 6, 6, 6, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 3, 5, 2, 2, 3, 6, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([1, 5, 7, 7, 7, 5, 7, 5, 7, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([7, 4, 4, 7, 4, 4, 4, 4, 7, 1, 7, 7, 7, 7, 1, 1, 7])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 3, 2, 2, 2, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([3, 5, 3, 4, 3, 3, 3, 3, 5, 5, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 7, 7, 7, 7, 7, 2, 2, 6, 6])\n",
|
||||||
|
" array([2, 2, 1, 6, 2, 2, 2, 2, 6, 6, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 3, 3, 3, 6, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 7, 7, 7, 7, 7, 5, 1, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([4, 1, 4, 4, 4, 4, 4, 1, 1, 1, 7, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 4, 2, 2, 2, 2, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([5, 3, 3, 4, 5, 3, 3, 3, 3, 5, 3, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 2, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 6, 2, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 6, 6])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 3, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([5, 7, 7, 1, 7, 7, 5, 7, 5, 7, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([1, 4, 4, 7, 7, 4, 4, 4, 4, 1, 1, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 3, 2, 4, 4, 4, 4, 4, 4, 2, 3, 2, 2, 2])\n",
|
||||||
|
" array([4, 3, 5, 3, 3, 5, 3, 3, 3, 5, 4, 4, 4, 4, 4])\n",
|
||||||
|
" array([6, 2, 7, 6, 6, 6, 2, 7, 7, 7, 7, 7])\n",
|
||||||
|
" array([2, 2, 6, 1, 2, 6, 2, 6, 2, 2, 1, 1, 1, 1, 1])\n",
|
||||||
|
" array([3, 6, 3, 5, 3, 6, 3, 6, 3, 3, 6, 5, 5, 5, 5, 5])\n",
|
||||||
|
" array([7, 5, 7, 1, 7, 5, 1, 1, 1, 1, 1, 7, 7, 7, 5])]\n",
|
||||||
|
"y_test\n",
|
||||||
|
" [3. 2. 0. 5. 4. 1. 6. 3. 2. 0. 5. 4. 1. 6. 3. 2. 0. 5. 4. 1. 6. 3. 2. 0.\n",
|
||||||
|
" 5. 4. 1. 6. 3. 2. 0. 5. 4. 1. 6.]\n",
|
||||||
|
"> \u001b[0;32m/tmp/ipykernel_97850/1264473745.py\u001b[0m(32)\u001b[0;36mcreateTrainTest\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"\u001b[0;32m 30 \u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'y_test\\n'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 31 \u001b[0;31m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m---> 32 \u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mshapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 33 \u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 34 \u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"ipdb> x_train.shape\n",
|
||||||
|
"(525,)\n",
|
||||||
|
"ipdb> x_test.shape\n",
|
||||||
|
"(35,)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"accuracies_full = dict()\n",
|
||||||
|
"accuracies_small = dict()\n",
|
||||||
|
"accuracies_last = dict()\n",
|
||||||
|
"\n",
|
||||||
|
"for current_PID in sorted(data.PID.unique()):\n",
|
||||||
|
" accuracies_full[current_PID], pred_label, test_label = runSVMS(createTrainTest([current_PID], Task_IDs, StartIndexOffset, EndIndexOffset, shapes=True))\n",
|
||||||
|
" # Only the first 5\n",
|
||||||
|
" accuracies_small[current_PID], pred_label, test_label = runSVMS(createTrainTest([current_PID], Task_IDs, StartIndexOffset, EndIndexOffset, shapes=True), 5)\n",
|
||||||
|
" # Only the last 5\n",
|
||||||
|
" accuracies_last[current_PID], pred_label, test_label = runSVMS(createTrainTest([current_PID], Task_IDs, StartIndexOffset, EndIndexOffset, shapes=True), 5, last_elements=True)\n",
|
||||||
|
" #pdb.set_trace()\n",
|
||||||
|
"print(accuracies_full)\n",
|
||||||
|
"print(accuracies_small)\n",
|
||||||
|
"print(accuracies_last)\n",
|
||||||
|
"print(\"mean full\", np.array(list(accuracies_full.values())).mean())\n",
|
||||||
|
"print(\"mean small\", np.array(list(accuracies_small.values())).mean())\n",
|
||||||
|
"print(\"mean last\", np.array(list(accuracies_last.values())).mean())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "fdd4c915",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"len(g.groups.keys())\n",
|
||||||
|
"g.groups.keys()"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
|
@ -0,0 +1,687 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "3aed8aec",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"2021-09-27 15:31:30.518074: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import datetime\n",
|
||||||
|
"import time,pdb\n",
|
||||||
|
"import json\n",
|
||||||
|
"import random\n",
|
||||||
|
"import statistics\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"\n",
|
||||||
|
"import tensorflow as tf\n",
|
||||||
|
"from tensorflow import keras\n",
|
||||||
|
"from sklearn import svm\n",
|
||||||
|
"from sklearn.model_selection import GridSearchCV \n",
|
||||||
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||||
|
"from sklearn.metrics import accuracy_score\n",
|
||||||
|
"from tensorflow.keras.layers import *\n",
|
||||||
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"from tensorflow.keras.models import Sequential\n",
|
||||||
|
"from tensorflow.keras.optimizers import *\n",
|
||||||
|
"from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, Callback\n",
|
||||||
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
||||||
|
"from sklearn.neighbors import KNeighborsClassifier\n",
|
||||||
|
"from sklearn.metrics import mean_squared_error\n",
|
||||||
|
"from sklearn.metrics import accuracy_score\n",
|
||||||
|
"import tqdm\n",
|
||||||
|
"from multiprocessing import Pool\n",
|
||||||
|
"import os\n",
|
||||||
|
"from tensorflow.compat.v1.keras.layers import Bidirectional, CuDNNLSTM"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 18,
|
||||||
|
"id": "817f7108",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"available PIDs [ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16.]\n",
|
||||||
|
"available TaskIDs [0. 1. 2. 3. 4. 5. 6.]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"<div>\n",
|
||||||
|
"<style scoped>\n",
|
||||||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||||||
|
" vertical-align: middle;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe tbody tr th {\n",
|
||||||
|
" vertical-align: top;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe thead th {\n",
|
||||||
|
" text-align: right;\n",
|
||||||
|
" }\n",
|
||||||
|
"</style>\n",
|
||||||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: right;\">\n",
|
||||||
|
" <th></th>\n",
|
||||||
|
" <th>Timestamp</th>\n",
|
||||||
|
" <th>Event</th>\n",
|
||||||
|
" <th>TaskID</th>\n",
|
||||||
|
" <th>Part</th>\n",
|
||||||
|
" <th>PID</th>\n",
|
||||||
|
" <th>TextRule</th>\n",
|
||||||
|
" <th>Rule</th>\n",
|
||||||
|
" <th>Type</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>0</th>\n",
|
||||||
|
" <td>1.575388e+12</td>\n",
|
||||||
|
" <td>4</td>\n",
|
||||||
|
" <td>0.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>{'Title': ['1', 'Indent', 'and', 'Italic'], 'S...</td>\n",
|
||||||
|
" <td>3.0</td>\n",
|
||||||
|
" <td>Cmd</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>1</th>\n",
|
||||||
|
" <td>1.575388e+12</td>\n",
|
||||||
|
" <td>1</td>\n",
|
||||||
|
" <td>0.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>{'Title': ['1', 'Indent', 'and', 'Italic'], 'S...</td>\n",
|
||||||
|
" <td>3.0</td>\n",
|
||||||
|
" <td>Toolbar</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>2</th>\n",
|
||||||
|
" <td>1.575388e+12</td>\n",
|
||||||
|
" <td>1</td>\n",
|
||||||
|
" <td>0.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>{'Title': ['1', 'Indent', 'and', 'Italic'], 'S...</td>\n",
|
||||||
|
" <td>3.0</td>\n",
|
||||||
|
" <td>Cmd</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>3</th>\n",
|
||||||
|
" <td>1.575388e+12</td>\n",
|
||||||
|
" <td>4</td>\n",
|
||||||
|
" <td>0.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>{'Title': ['1', 'Indent', 'and', 'Italic'], 'S...</td>\n",
|
||||||
|
" <td>3.0</td>\n",
|
||||||
|
" <td>Cmd</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>4</th>\n",
|
||||||
|
" <td>1.575388e+12</td>\n",
|
||||||
|
" <td>4</td>\n",
|
||||||
|
" <td>0.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>{'Title': ['1', 'Indent', 'and', 'Italic'], 'S...</td>\n",
|
||||||
|
" <td>3.0</td>\n",
|
||||||
|
" <td>Cmd</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>...</th>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>8376</th>\n",
|
||||||
|
" <td>1.603898e+12</td>\n",
|
||||||
|
" <td>7</td>\n",
|
||||||
|
" <td>6.0</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>16.0</td>\n",
|
||||||
|
" <td>{'Title': ['Size', 'Big'], 'Subtitle': ['Bold'...</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>Toolbar</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>8377</th>\n",
|
||||||
|
" <td>1.603898e+12</td>\n",
|
||||||
|
" <td>2</td>\n",
|
||||||
|
" <td>6.0</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>16.0</td>\n",
|
||||||
|
" <td>{'Title': ['Size', 'Big'], 'Subtitle': ['Bold'...</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>Cmd</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>8378</th>\n",
|
||||||
|
" <td>1.603898e+12</td>\n",
|
||||||
|
" <td>2</td>\n",
|
||||||
|
" <td>6.0</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>16.0</td>\n",
|
||||||
|
" <td>{'Title': ['Size', 'Big'], 'Subtitle': ['Bold'...</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>Cmd</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>8379</th>\n",
|
||||||
|
" <td>1.603898e+12</td>\n",
|
||||||
|
" <td>6</td>\n",
|
||||||
|
" <td>6.0</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>16.0</td>\n",
|
||||||
|
" <td>{'Title': ['Size', 'Big'], 'Subtitle': ['Bold'...</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>Toolbar</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>8380</th>\n",
|
||||||
|
" <td>1.603898e+12</td>\n",
|
||||||
|
" <td>6</td>\n",
|
||||||
|
" <td>6.0</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>16.0</td>\n",
|
||||||
|
" <td>{'Title': ['Size', 'Big'], 'Subtitle': ['Bold'...</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>Toolbar</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table>\n",
|
||||||
|
"<p>8381 rows × 8 columns</p>\n",
|
||||||
|
"</div>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
" Timestamp Event TaskID Part PID \\\n",
|
||||||
|
"0 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||||
|
"1 1.575388e+12 1 0.0 1.0 1.0 \n",
|
||||||
|
"2 1.575388e+12 1 0.0 1.0 1.0 \n",
|
||||||
|
"3 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||||
|
"4 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||||
|
"... ... ... ... ... ... \n",
|
||||||
|
"8376 1.603898e+12 7 6.0 5.0 16.0 \n",
|
||||||
|
"8377 1.603898e+12 2 6.0 5.0 16.0 \n",
|
||||||
|
"8378 1.603898e+12 2 6.0 5.0 16.0 \n",
|
||||||
|
"8379 1.603898e+12 6 6.0 5.0 16.0 \n",
|
||||||
|
"8380 1.603898e+12 6 6.0 5.0 16.0 \n",
|
||||||
|
"\n",
|
||||||
|
" TextRule Rule Type \n",
|
||||||
|
"0 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"1 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||||
|
"2 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"3 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"4 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"... ... ... ... \n",
|
||||||
|
"8376 {'Title': ['Size', 'Big'], 'Subtitle': ['Bold'... 5.0 Toolbar \n",
|
||||||
|
"8377 {'Title': ['Size', 'Big'], 'Subtitle': ['Bold'... 5.0 Cmd \n",
|
||||||
|
"8378 {'Title': ['Size', 'Big'], 'Subtitle': ['Bold'... 5.0 Cmd \n",
|
||||||
|
"8379 {'Title': ['Size', 'Big'], 'Subtitle': ['Bold'... 5.0 Toolbar \n",
|
||||||
|
"8380 {'Title': ['Size', 'Big'], 'Subtitle': ['Bold'... 5.0 Toolbar \n",
|
||||||
|
"\n",
|
||||||
|
"[8381 rows x 8 columns]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 18,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"study_data_path = \"../IntentData/\"\n",
|
||||||
|
"data = pd.read_pickle(study_data_path + \"/Preprocessing_data/clean_data.pkl\")\n",
|
||||||
|
"#val_data = pd.read_pickle(study_data_path + \"/Preprocessing_data/clean_data_condition2.pkl\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"available PIDs\", data.PID.unique())\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"available TaskIDs\", data.TaskID.unique())\n",
|
||||||
|
"\n",
|
||||||
|
"data.Event.unique()\n",
|
||||||
|
"data"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "ab778228",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"count 560.000000\n",
|
||||||
|
"mean 14.966071\n",
|
||||||
|
"std 2.195440\n",
|
||||||
|
"min 8.000000\n",
|
||||||
|
"25% 14.000000\n",
|
||||||
|
"50% 15.000000\n",
|
||||||
|
"75% 16.000000\n",
|
||||||
|
"max 28.000000\n",
|
||||||
|
"Name: Event, dtype: float64"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"data.groupby([\"PID\", \"Part\", \"TaskID\"])[\"Event\"].count().describe()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "32550f71",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"\n",
|
||||||
|
"Task_IDs = list(range(0,7))\n",
|
||||||
|
"\n",
|
||||||
|
"# grouping by part is needed to have one ruleset for the whole part\n",
|
||||||
|
"g = data.groupby([\"PID\", \"Part\", \"TaskID\"])\n",
|
||||||
|
"df_all = []"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "f6fecc2f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def createTrainTestalaSven(test_IDs, task_IDs, window_size, stride, shapes=False, val_IDs=None):\n",
|
||||||
|
" if not isinstance(test_IDs, list):\n",
|
||||||
|
" raise ValueError(\"Test_IDs are not a list\")\n",
|
||||||
|
" if not isinstance(task_IDs, list):\n",
|
||||||
|
" raise ValueError(\"Task_IDs are not a list\")\n",
|
||||||
|
" # Fill data arrays\n",
|
||||||
|
" all_elem = []\n",
|
||||||
|
" for current in g.groups.keys():\n",
|
||||||
|
" c = g.get_group(current)\n",
|
||||||
|
" if (c.TaskID.isin(task_IDs).all()):\n",
|
||||||
|
" \n",
|
||||||
|
" new_data = c.Event.values\n",
|
||||||
|
" stepper = 0\n",
|
||||||
|
" while stepper <= (len(new_data)-window_size-1):\n",
|
||||||
|
" tmp = new_data[stepper:stepper + window_size]\n",
|
||||||
|
" x = tmp[:-1]\n",
|
||||||
|
" y = tmp[-1]\n",
|
||||||
|
" stepper += stride\n",
|
||||||
|
" \n",
|
||||||
|
" if (c.PID.isin(test_IDs).all()):\n",
|
||||||
|
" all_elem.append([\"Test\", x, y])\n",
|
||||||
|
" elif (c.PID.isin(val_IDs).all()):\n",
|
||||||
|
" all_elem.append([\"Val\", x, y])\n",
|
||||||
|
" else:\n",
|
||||||
|
" all_elem.append([\"Train\", x, y])\n",
|
||||||
|
" df_tmp = pd.DataFrame(all_elem, columns =[\"Split\", \"X\", \"Y\"])\n",
|
||||||
|
" turbo = []\n",
|
||||||
|
" for s in df_tmp.Split.unique():\n",
|
||||||
|
" dfX = df_tmp[df_tmp.Split == s]\n",
|
||||||
|
" max_amount = dfX.groupby([\"Y\"]).count().max().X\n",
|
||||||
|
" for y in dfX.Y.unique():\n",
|
||||||
|
" df_turbotmp = dfX[dfX.Y == y]\n",
|
||||||
|
" turbo.append(df_turbotmp)\n",
|
||||||
|
" turbo.append(df_turbotmp.sample(max_amount-len(df_turbotmp), replace=True))\n",
|
||||||
|
" # if len(df_turbotmp) < max_amount:\n",
|
||||||
|
"\n",
|
||||||
|
" df_tmp = pd.concat(turbo)\n",
|
||||||
|
" x_train, y_train = df_tmp[df_tmp.Split == \"Train\"].X.values, df_tmp[df_tmp.Split == \"Train\"].Y.values\n",
|
||||||
|
" x_test, y_test = df_tmp[df_tmp.Split == \"Test\"].X.values, df_tmp[df_tmp.Split == \"Test\"].Y.values\n",
|
||||||
|
" x_val, y_val = df_tmp[df_tmp.Split == \"Val\"].X.values, df_tmp[df_tmp.Split == \"Val\"].Y.values\n",
|
||||||
|
" \n",
|
||||||
|
" x_train = np.expand_dims(np.stack(x_train), axis=2)\n",
|
||||||
|
" y_train = np.array(y_train)\n",
|
||||||
|
" x_test = np.expand_dims(np.stack(x_test), axis=2)\n",
|
||||||
|
" y_test = np.array(y_test)\n",
|
||||||
|
" if len(x_val) > 0:\n",
|
||||||
|
" x_val = np.expand_dims(np.stack(x_val), axis=2)\n",
|
||||||
|
" y_val = np.array(y_val)\n",
|
||||||
|
" return(x_train, y_train, x_test, y_test, x_val, y_val)\n",
|
||||||
|
" return(x_train, y_train, x_test, y_test)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"id": "b8f92bc1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def createTrainTest(test_IDs, task_IDs, window_size, stride, shapes=False, val_IDs=None):\n",
|
||||||
|
" if not isinstance(test_IDs, list):\n",
|
||||||
|
" raise ValueError(\"Test_IDs are not a list\")\n",
|
||||||
|
" if not isinstance(task_IDs, list):\n",
|
||||||
|
" raise ValueError(\"Task_IDs are not a list\")\n",
|
||||||
|
" # Fill data arrays\n",
|
||||||
|
" y_train = []\n",
|
||||||
|
" x_train = []\n",
|
||||||
|
" y_test = []\n",
|
||||||
|
" x_test = []\n",
|
||||||
|
" x_val = []\n",
|
||||||
|
" y_val = []\n",
|
||||||
|
" \n",
|
||||||
|
" for current in g.groups.keys():\n",
|
||||||
|
" c = g.get_group(current)\n",
|
||||||
|
" if (c.TaskID.isin(task_IDs).all()):\n",
|
||||||
|
" \n",
|
||||||
|
" new_data = c.Event.values\n",
|
||||||
|
" stepper = 0\n",
|
||||||
|
" while stepper <= (len(new_data)-window_size-1):\n",
|
||||||
|
" tmp = new_data[stepper:stepper + window_size]\n",
|
||||||
|
" pdb.set_trace()\n",
|
||||||
|
" x = tmp[:-1]\n",
|
||||||
|
" y = tmp[-1]\n",
|
||||||
|
" stepper += stride\n",
|
||||||
|
" if (c.PID.isin(test_IDs).all()):\n",
|
||||||
|
" if y == 6:\n",
|
||||||
|
" y_test.append(y)\n",
|
||||||
|
" x_test.append(x)\n",
|
||||||
|
" y_test.append(y)\n",
|
||||||
|
" x_test.append(x)\n",
|
||||||
|
" elif (c.PID.isin(val_IDs).all()):\n",
|
||||||
|
" if y == 6:\n",
|
||||||
|
" y_val.append(y)\n",
|
||||||
|
" x_val.append(x)\n",
|
||||||
|
" y_val.append(y)\n",
|
||||||
|
" x_val.append(x)\n",
|
||||||
|
" else:\n",
|
||||||
|
" if y == 6:\n",
|
||||||
|
" y_train.append(y)\n",
|
||||||
|
" x_train.append(x)\n",
|
||||||
|
" y_train.append(y)\n",
|
||||||
|
" x_train.append(x)\n",
|
||||||
|
" x_train = np.array(x_train)\n",
|
||||||
|
" y_train = np.array(y_train)\n",
|
||||||
|
" x_test = np.array(x_test)\n",
|
||||||
|
" y_test = np.array(y_test)\n",
|
||||||
|
" x_val = np.array(x_val)\n",
|
||||||
|
" y_val = np.array(y_val)\n",
|
||||||
|
" pdb.set_trace()\n",
|
||||||
|
" if (shapes):\n",
|
||||||
|
" print(x_train.shape)\n",
|
||||||
|
" print(y_train.shape)\n",
|
||||||
|
" print(x_test.shape)\n",
|
||||||
|
" print(y_test.shape)\n",
|
||||||
|
" print(x_val.shape)\n",
|
||||||
|
" print(y_val.shape)\n",
|
||||||
|
" print(np.unique(y_test))\n",
|
||||||
|
" print(np.unique(y_train))\n",
|
||||||
|
" if len(x_val) > 0:\n",
|
||||||
|
" return(x_train, y_train, x_test, y_test, x_val, y_val)\n",
|
||||||
|
" return (x_train, y_train, x_test, y_test)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"id": "e56fbc58",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"maxlen = 1000\n",
|
||||||
|
"lens = []\n",
|
||||||
|
"for current in g.groups.keys():\n",
|
||||||
|
" c = g.get_group(current)\n",
|
||||||
|
" lens.append(len(c.Event.values))\n",
|
||||||
|
" maxlen = min(maxlen, len(c.Event.values))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"id": "c02cbdae",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Number of trees in random forest\n",
|
||||||
|
"n_estimators = np.arange(5,100, 5)\n",
|
||||||
|
"# Number of features to consider at every split\n",
|
||||||
|
"max_features = ['sqrt']\n",
|
||||||
|
"# Maximum number of levels in tree\n",
|
||||||
|
"max_depth = np.arange(5,100, 5)\n",
|
||||||
|
"# Minimum number of samples required to split a node\n",
|
||||||
|
"min_samples_split = np.arange(2,10, 1)\n",
|
||||||
|
"# Minimum number of samples required at each leaf node\n",
|
||||||
|
"min_samples_leaf = np.arange(2,5, 1)\n",
|
||||||
|
"# Method of selecting samples for training each tree\n",
|
||||||
|
"bootstrap = [True, False]\n",
|
||||||
|
"\n",
|
||||||
|
"# Create the random grid\n",
|
||||||
|
"param_grid = {'n_estimators': n_estimators,\n",
|
||||||
|
" 'max_features': max_features,\n",
|
||||||
|
" 'max_depth': max_depth,\n",
|
||||||
|
" 'min_samples_split': min_samples_split,\n",
|
||||||
|
" 'min_samples_leaf': min_samples_leaf,\n",
|
||||||
|
" 'bootstrap': bootstrap}\n",
|
||||||
|
"\n",
|
||||||
|
"grid = GridSearchCV(RandomForestClassifier(), param_grid, refit = True, verbose = 0, return_train_score=True) "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"id": "c2bcfe7f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def doTrainSlideWindowNoPad(currentPid):\n",
|
||||||
|
" print(f\"doTrain: {currentPid}\")\n",
|
||||||
|
" dfs = []\n",
|
||||||
|
" for window_size in range(8, 15): \n",
|
||||||
|
" (x_train, y_train, x_test, y_test) = createTrainTest([currentPid], Task_IDs, window_size, 1, False, [200])\n",
|
||||||
|
" print(f\"doTrain: created TrainTestsplit\")\n",
|
||||||
|
"\n",
|
||||||
|
" # print(\"window_size\", 5, \"PID\", currentPid, \"samples\", x_train.shape[0], \"generated_samples\", \"samples\", x_train_window.shape[0])\n",
|
||||||
|
"\n",
|
||||||
|
" grid.fit(x_train, y_train)\n",
|
||||||
|
" print(\"fitted\")\n",
|
||||||
|
" # y_pred = grid.predict(x_test)\n",
|
||||||
|
"\n",
|
||||||
|
" df_params = pd.DataFrame(grid.cv_results_[\"params\"])\n",
|
||||||
|
" df_params[\"Mean_test\"] = grid.cv_results_[\"mean_test_score\"]\n",
|
||||||
|
" df_params[\"Mean_train\"] = grid.cv_results_[\"mean_train_score\"]\n",
|
||||||
|
" df_params[\"STD_test\"] = grid.cv_results_[\"std_test_score\"]\n",
|
||||||
|
" df_params[\"STD_train\"] = grid.cv_results_[\"std_train_score\"]\n",
|
||||||
|
" df_params['Window_Size'] = window_size\n",
|
||||||
|
" df_params['PID'] = currentPid\n",
|
||||||
|
" # df_params[\"Accuracy\"] = accuracy_score(y_pred, y_test)\n",
|
||||||
|
" dfs.append(df_params)\n",
|
||||||
|
"\n",
|
||||||
|
" return pd.concat(dfs)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"id": "9e3d86f1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"doTrain: 1\n",
|
||||||
|
"> \u001b[0;32m/tmp/ipykernel_90176/2602038955.py\u001b[0m(23)\u001b[0;36mcreateTrainTest\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"\u001b[0;32m 21 \u001b[0;31m \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstepper\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 22 \u001b[0;31m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m---> 23 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 24 \u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 25 \u001b[0;31m \u001b[0mstepper\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mstride\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\n",
|
||||||
|
"ipdb> tmp\n",
|
||||||
|
"array([4, 1, 1, 4, 4, 7, 7, 7])\n",
|
||||||
|
"ipdb> new_data\n",
|
||||||
|
"array([4, 1, 1, 4, 4, 7, 7, 7, 7, 7, 7, 4, 1, 4, 4, 4])\n",
|
||||||
|
"ipdb> current\n",
|
||||||
|
"(1.0, 1.0, 0.0)\n",
|
||||||
|
"ipdb> print(c)\n",
|
||||||
|
" Timestamp Event TaskID Part PID \\\n",
|
||||||
|
"0 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||||
|
"1 1.575388e+12 1 0.0 1.0 1.0 \n",
|
||||||
|
"2 1.575388e+12 1 0.0 1.0 1.0 \n",
|
||||||
|
"3 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||||
|
"4 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||||
|
"5 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||||
|
"6 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||||
|
"7 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||||
|
"8 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||||
|
"9 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||||
|
"10 1.575388e+12 7 0.0 1.0 1.0 \n",
|
||||||
|
"11 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||||
|
"12 1.575388e+12 1 0.0 1.0 1.0 \n",
|
||||||
|
"13 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||||
|
"14 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||||
|
"15 1.575388e+12 4 0.0 1.0 1.0 \n",
|
||||||
|
"\n",
|
||||||
|
" TextRule Rule Type \n",
|
||||||
|
"0 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"1 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||||
|
"2 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"3 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"4 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"5 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||||
|
"6 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||||
|
"7 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||||
|
"8 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||||
|
"9 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||||
|
"10 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||||
|
"11 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"12 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Toolbar \n",
|
||||||
|
"13 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"14 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"15 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S... 3.0 Cmd \n",
|
||||||
|
"ipdb> print(c.TextRule)\n",
|
||||||
|
"0 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"1 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"2 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"3 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"4 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"5 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"6 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"7 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"8 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"9 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"10 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"11 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"12 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"13 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"14 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"15 {'Title': ['1', 'Indent', 'and', 'Italic'], 'S...\n",
|
||||||
|
"Name: TextRule, dtype: object\n",
|
||||||
|
"ipdb> print(c.Event)\n",
|
||||||
|
"0 4\n",
|
||||||
|
"1 1\n",
|
||||||
|
"2 1\n",
|
||||||
|
"3 4\n",
|
||||||
|
"4 4\n",
|
||||||
|
"5 7\n",
|
||||||
|
"6 7\n",
|
||||||
|
"7 7\n",
|
||||||
|
"8 7\n",
|
||||||
|
"9 7\n",
|
||||||
|
"10 7\n",
|
||||||
|
"11 4\n",
|
||||||
|
"12 1\n",
|
||||||
|
"13 4\n",
|
||||||
|
"14 4\n",
|
||||||
|
"15 4\n",
|
||||||
|
"Name: Event, dtype: int64\n",
|
||||||
|
"ipdb> val\n",
|
||||||
|
"*** NameError: name 'val' is not defined\n",
|
||||||
|
"ipdb> val_IDs\n",
|
||||||
|
"[200]\n",
|
||||||
|
"--KeyboardInterrupt--\n",
|
||||||
|
"\n",
|
||||||
|
"KeyboardInterrupt: Interrupted by user\n",
|
||||||
|
"> \u001b[0;32m/tmp/ipykernel_90176/2602038955.py\u001b[0m(22)\u001b[0;36mcreateTrainTest\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"\u001b[0;32m 20 \u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mstepper\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 21 \u001b[0;31m \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstepper\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m---> 22 \u001b[0;31m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 23 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 24 \u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\n",
|
||||||
|
"--KeyboardInterrupt--\n",
|
||||||
|
"\n",
|
||||||
|
"KeyboardInterrupt: Interrupted by user\n",
|
||||||
|
"> \u001b[0;32m/tmp/ipykernel_90176/2602038955.py\u001b[0m(23)\u001b[0;36mcreateTrainTest\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"\u001b[0;32m 21 \u001b[0;31m \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstepper\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 22 \u001b[0;31m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m---> 23 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 24 \u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\u001b[0;32m 25 \u001b[0;31m \u001b[0mstepper\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mstride\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0m\n",
|
||||||
|
"ipdb> q\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "BdbQuit",
|
||||||
|
"evalue": "",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mBdbQuit\u001b[0m Traceback (most recent call last)",
|
||||||
|
"\u001b[0;32m/tmp/ipykernel_90176/1128965594.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdoTrainSlideWindowNoPad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||||
|
"\u001b[0;32m/tmp/ipykernel_90176/2629087375.py\u001b[0m in \u001b[0;36mdoTrainSlideWindowNoPad\u001b[0;34m(currentPid)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mdfs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mwindow_size\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m15\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreateTrainTest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcurrentPid\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTask_IDs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"doTrain: created TrainTestsplit\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;32m/tmp/ipykernel_90176/2602038955.py\u001b[0m in \u001b[0;36mcreateTrainTest\u001b[0;34m(test_IDs, task_IDs, window_size, stride, shapes, val_IDs)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstepper\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mstepper\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mstride\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;32m/tmp/ipykernel_90176/2602038955.py\u001b[0m in \u001b[0;36mcreateTrainTest\u001b[0;34m(test_IDs, task_IDs, window_size, stride, shapes, val_IDs)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstepper\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstepper\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mstepper\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mstride\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;32m~/miniconda3/envs/intentPrediction/lib/python3.9/bdb.py\u001b[0m in \u001b[0;36mtrace_dispatch\u001b[0;34m(self, frame, event, arg)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;31m# None\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'line'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_line\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'call'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;32m~/miniconda3/envs/intentPrediction/lib/python3.9/bdb.py\u001b[0m in \u001b[0;36mdispatch_line\u001b[0;34m(self, frame)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstop_here\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbreak_here\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_line\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 113\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquitting\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mBdbQuit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 114\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_dispatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;31mBdbQuit\u001b[0m: "
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"doTrainSlideWindowNoPad(1)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_0.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_0.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_1.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_1.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_2.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_2.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_3.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_3.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_4.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_4.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_5.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_5.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_6.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_action_id_6.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_0.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_0.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_1.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_1.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_2.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_2.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_3.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_3.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_4.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_4.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_5.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_5.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_6.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_data_6.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_0.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_0.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_1.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_1.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_2.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_2.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_3.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_3.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_4.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_4.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_5.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_5.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_6.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/test_label_6.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_0.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_0.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_1.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_1.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_2.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_2.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_3.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_3.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_4.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_4.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_5.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_5.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_6.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_data_6.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_0.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_0.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_1.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_1.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_2.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_2.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_3.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_3.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_4.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_4.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_5.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_5.pkl
Normal file
Binary file not shown.
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_6.pkl
Normal file
BIN
keyboard_and_mouse/dataset/strategy_dataset/train_label_6.pkl
Normal file
Binary file not shown.
167
keyboard_and_mouse/networks.py
Normal file
167
keyboard_and_mouse/networks.py
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class fc_block(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, norm, activation_fn):
|
||||||
|
super(fc_block, self).__init__()
|
||||||
|
block = nn.Sequential()
|
||||||
|
block.add_module('linear', nn.Linear(in_channels, out_channels))
|
||||||
|
if norm:
|
||||||
|
block.add_module('batchnorm', nn.BatchNorm1d(out_channels))
|
||||||
|
if activation_fn is not None:
|
||||||
|
block.add_module('activation', activation_fn())
|
||||||
|
|
||||||
|
self.block = block
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.block(x)
|
||||||
|
|
||||||
|
class ActionDemoEncoder(nn.Module):
|
||||||
|
def __init__(self, args, pooling):
|
||||||
|
super(ActionDemoEncoder, self).__init__()
|
||||||
|
hidden_size = args.demo_hidden
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.bs = args.batch_size
|
||||||
|
|
||||||
|
len_action_predicates = 35 # max_action_len
|
||||||
|
self.action_embed = nn.Embedding(len_action_predicates, hidden_size)
|
||||||
|
|
||||||
|
feat2hidden = nn.Sequential()
|
||||||
|
feat2hidden.add_module(
|
||||||
|
'fc_block1', fc_block(hidden_size, hidden_size, False, nn.ReLU))
|
||||||
|
self.feat2hidden = feat2hidden
|
||||||
|
|
||||||
|
self.pooling = pooling
|
||||||
|
|
||||||
|
if 'lstm' in self.pooling:
|
||||||
|
self.lstm = nn.LSTM(hidden_size, hidden_size)
|
||||||
|
|
||||||
|
def forward(self, batch_data):
|
||||||
|
batch_data = batch_data.view(-1,1)
|
||||||
|
stacked_demo_feat = self.action_embed(batch_data)
|
||||||
|
stacked_demo_feat = self.feat2hidden(stacked_demo_feat)
|
||||||
|
batch_demo_feat = []
|
||||||
|
start = 0
|
||||||
|
|
||||||
|
for length in range(0,batch_data.shape[0]):
|
||||||
|
if length == 0:
|
||||||
|
feat = stacked_demo_feat[0:1, :]
|
||||||
|
else:
|
||||||
|
feat = stacked_demo_feat[(length-1):length, :]
|
||||||
|
if len(feat.size()) == 3:
|
||||||
|
feat = feat.unsqueeze(0)
|
||||||
|
|
||||||
|
if self.pooling == 'max':
|
||||||
|
feat = torch.max(feat, 0)[0]
|
||||||
|
elif self.pooling == 'avg':
|
||||||
|
feat = torch.mean(feat, 0)
|
||||||
|
elif self.pooling == 'lstmavg':
|
||||||
|
lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
|
||||||
|
lstm_out = lstm_out.view(len(feat), -1)
|
||||||
|
feat = torch.mean(lstm_out, 0)
|
||||||
|
elif self.pooling == 'lstmlast':
|
||||||
|
lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
|
||||||
|
lstm_out = lstm_out.view(len(feat), -1)
|
||||||
|
feat = lstm_out[-1]
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
|
batch_demo_feat.append(feat)
|
||||||
|
|
||||||
|
demo_emb = torch.stack(batch_demo_feat, 0)
|
||||||
|
demo_emb = demo_emb.view(self.bs, 35, -1)
|
||||||
|
return demo_emb
|
||||||
|
|
||||||
|
class PredicateClassifier(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, args,):
|
||||||
|
super(PredicateClassifier, self).__init__()
|
||||||
|
hidden_size = args.demo_hidden
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
classifier = nn.Sequential()
|
||||||
|
classifier.add_module('fc_block1', fc_block(hidden_size*35, hidden_size, False, nn.Tanh))
|
||||||
|
classifier.add_module('dropout', nn.Dropout(args.dropout))
|
||||||
|
classifier.add_module('fc_block2', fc_block(hidden_size, 7, False, None)) # 7 is all possible actions
|
||||||
|
|
||||||
|
self.classifier = classifier
|
||||||
|
|
||||||
|
def forward(self, input_emb):
|
||||||
|
input_emb = input_emb.view(-1, self.hidden_size*35)
|
||||||
|
return self.classifier(input_emb)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionDemo2Predicate(nn.Module):
|
||||||
|
def __init__(self, args, **kwargs):
|
||||||
|
super(ActionDemo2Predicate, self).__init__()
|
||||||
|
|
||||||
|
print('------------------------------------------------------------------------------------------')
|
||||||
|
print('ActionDemo2Predicate')
|
||||||
|
print('------------------------------------------------------------------------------------------')
|
||||||
|
|
||||||
|
model_type = args.model_type
|
||||||
|
print('model_type', model_type)
|
||||||
|
|
||||||
|
if model_type.lower() == 'max':
|
||||||
|
demo_encoder = ActionDemoEncoder(args, 'max')
|
||||||
|
elif model_type.lower() == 'avg':
|
||||||
|
demo_encoder = ActionDemoEncoder(args, 'avg')
|
||||||
|
elif model_type.lower() == 'lstmavg':
|
||||||
|
demo_encoder = ActionDemoEncoder(args, 'lstmavg')
|
||||||
|
elif model_type.lower() == 'bilstmavg':
|
||||||
|
demo_encoder = ActionDemoEncoder(args, 'bilstmavg')
|
||||||
|
elif model_type.lower() == 'lstmlast':
|
||||||
|
demo_encoder = ActionDemoEncoder(args, 'lstmlast')
|
||||||
|
elif model_type.lower() == 'bilstmlast':
|
||||||
|
demo_encoder = ActionDemoEncoder(args, 'bilstmlast')
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
demo_encoder = torch.nn.DataParallel(demo_encoder)
|
||||||
|
|
||||||
|
predicate_decoder = PredicateClassifier(args)
|
||||||
|
|
||||||
|
# for quick save and load
|
||||||
|
all_modules = nn.Sequential()
|
||||||
|
all_modules.add_module('demo_encoder', demo_encoder)
|
||||||
|
all_modules.add_module('predicate_decoder', predicate_decoder)
|
||||||
|
|
||||||
|
self.demo_encoder = demo_encoder
|
||||||
|
self.predicate_decoder = predicate_decoder
|
||||||
|
self.all_modules = all_modules
|
||||||
|
self.to_cuda_fn = None
|
||||||
|
|
||||||
|
def set_to_cuda_fn(self, to_cuda_fn):
|
||||||
|
self.to_cuda_fn = to_cuda_fn
|
||||||
|
|
||||||
|
def forward(self, data, **kwargs):
|
||||||
|
'''
|
||||||
|
Note: The order of the `data` won't change in this function
|
||||||
|
'''
|
||||||
|
if self.to_cuda_fn:
|
||||||
|
data = self.to_cuda_fn(data)
|
||||||
|
|
||||||
|
batch_demo_emb = self.demo_encoder(data)
|
||||||
|
pred = self.predicate_decoder(batch_demo_emb)
|
||||||
|
return pred
|
||||||
|
|
||||||
|
def write_summary(self, writer, info, postfix):
|
||||||
|
model_name = 'Demo2Predicate-{}/'.format(postfix)
|
||||||
|
for k in self.summary_keys:
|
||||||
|
if k in info.keys():
|
||||||
|
writer.scalar_summary(model_name + k, info[k])
|
||||||
|
|
||||||
|
def save(self, path, verbose=False):
|
||||||
|
if verbose:
|
||||||
|
print(colored('[*] Save model at {}'.format(path), 'magenta'))
|
||||||
|
torch.save(self.all_modules.state_dict(), path)
|
||||||
|
|
||||||
|
def load(self, path, verbose=False):
|
||||||
|
if verbose:
|
||||||
|
print(colored('[*] Load model at {}'.format(path), 'magenta'))
|
||||||
|
self.all_modules.load_state_dict(
|
||||||
|
torch.load(
|
||||||
|
path,
|
||||||
|
map_location=lambda storage,
|
||||||
|
loc: storage))
|
||||||
|
|
207
keyboard_and_mouse/process_data.py
Normal file
207
keyboard_and_mouse/process_data.py
Normal file
|
@ -0,0 +1,207 @@
|
||||||
|
import pickle
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from pathlib import Path
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def view_clean_data():
|
||||||
|
with open('dataset/clean_data.pkl', 'rb') as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
print(type(data), len(data))
|
||||||
|
print(data.keys())
|
||||||
|
print('length of data:',len(data))
|
||||||
|
print('event', data['Event'], 'length of event', len(data['Event']))
|
||||||
|
print('rule', data['Rule'], 'length of event', len(data['Rule']))
|
||||||
|
|
||||||
|
print('rule unique', data.Rule.unique())
|
||||||
|
print('task id unique', data.TaskID.unique())
|
||||||
|
print('pid unique', data.PID.unique())
|
||||||
|
print('event unique', data.Event.unique())
|
||||||
|
|
||||||
|
def split_org_data():
|
||||||
|
# generate train, test data by split user, aggregate action sequence for next action prediction
|
||||||
|
# orignial action seq: a = [a_0 ... a_n]
|
||||||
|
# new action seq: for a: a0 = [a_0], a1 = [a_0, a_1] ...
|
||||||
|
|
||||||
|
# split original data into train and test based on user
|
||||||
|
with open('dataset/clean_data.pkl', 'rb') as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
|
||||||
|
print('original data keys', data.keys())
|
||||||
|
print('len of original data', len(data))
|
||||||
|
print('rule unique', data.Rule.unique())
|
||||||
|
print('event unique', data.Event.unique())
|
||||||
|
|
||||||
|
data_train = data[data['PID']<=11]
|
||||||
|
data_test = data[data['PID']>11]
|
||||||
|
print('train set len', len(data_train))
|
||||||
|
print('test set len', len(data_test))
|
||||||
|
|
||||||
|
# split data by task
|
||||||
|
train_data_intent = []
|
||||||
|
test_data_intent = []
|
||||||
|
for i in range(7):
|
||||||
|
# 7 different rules, each as an intention
|
||||||
|
train_data_intent.append(data_train[data_train['Rule']==i])
|
||||||
|
test_data_intent.append(data_test[data_test['Rule']==i])
|
||||||
|
|
||||||
|
# generate train set
|
||||||
|
max_len = 0 # max len is 35
|
||||||
|
for i in range(7): # 7 tasks/rules
|
||||||
|
train_data = [] # [task]
|
||||||
|
train_label = []
|
||||||
|
for u in range(1,12):
|
||||||
|
user_data = train_data_intent[i][train_data_intent[i]['PID']==u]
|
||||||
|
for j in range(1,6): # 5 parts == 5 trials
|
||||||
|
part_data = user_data[user_data['Part']==j]
|
||||||
|
for l in range(1,len(part_data['Event'])-1):
|
||||||
|
print(part_data['Event'][:l].tolist())
|
||||||
|
train_data.append(part_data['Event'][:l].tolist())
|
||||||
|
train_label.append(part_data['Event'].iat[l+1])
|
||||||
|
if len(part_data['Event'])>max_len:
|
||||||
|
max_len = len(part_data['Event'])
|
||||||
|
|
||||||
|
for k in range(len(train_data)):
|
||||||
|
while len(train_data[k])<35:
|
||||||
|
train_data[k].append(0) # padding with 0
|
||||||
|
|
||||||
|
print('x_len', len(train_data), type(train_data[0]), len(train_data[0]))
|
||||||
|
print('y_len', len(train_label), type(train_label[0]))
|
||||||
|
|
||||||
|
Path("dataset/strategy_dataset").mkdir(parents=True, exist_ok=True)
|
||||||
|
with open('dataset/strategy_dataset/train_label_'+str(i)+'.pkl', 'wb') as f:
|
||||||
|
pickle.dump(train_label, f)
|
||||||
|
with open('dataset/strategy_dataset/train_data_'+str(i)+'.pkl', 'wb') as f:
|
||||||
|
pickle.dump(train_data, f)
|
||||||
|
print('max_len', max_len)
|
||||||
|
|
||||||
|
# generate test set
|
||||||
|
max_len = 0 # max len is 33, total max is 35
|
||||||
|
for i in range(7): # 7 tasks/rules
|
||||||
|
test_data = [] # [task][user]
|
||||||
|
test_label = []
|
||||||
|
test_action_id = []
|
||||||
|
for u in range(12,17):
|
||||||
|
user_data = test_data_intent[i][test_data_intent[i]['PID']==u]
|
||||||
|
test_data_user = []
|
||||||
|
test_label_user = []
|
||||||
|
test_action_id_user = []
|
||||||
|
for j in range(1,6): # 5 parts == 5 trials
|
||||||
|
part_data = user_data[user_data['Part']==j]
|
||||||
|
|
||||||
|
for l in range(1,len(part_data['Event'])-1):
|
||||||
|
test_data_user.append(part_data['Event'][:l].tolist())
|
||||||
|
test_label_user.append(part_data['Event'].iat[l+1])
|
||||||
|
test_action_id_user.append(part_data['Part'].iat[l])
|
||||||
|
|
||||||
|
if len(part_data['Event'])>max_len:
|
||||||
|
max_len = len(part_data['Event'])
|
||||||
|
|
||||||
|
for k in range(len(test_data_user)):
|
||||||
|
while len(test_data_user[k])<35:
|
||||||
|
test_data_user[k].append(0) # padding with 0
|
||||||
|
|
||||||
|
test_data.append(test_data_user)
|
||||||
|
test_label.append(test_label_user)
|
||||||
|
test_action_id.append(test_action_id_user)
|
||||||
|
|
||||||
|
|
||||||
|
print('x_len', len(test_data), type(test_data[0]), len(test_data[0]))
|
||||||
|
print('y_len', len(test_label), type(test_label[0]))
|
||||||
|
with open('dataset/strategy_dataset/test_label_'+str(i)+'.pkl', 'wb') as f:
|
||||||
|
pickle.dump(test_label, f)
|
||||||
|
with open('dataset/strategy_dataset/test_data_'+str(i)+'.pkl', 'wb') as f:
|
||||||
|
pickle.dump(test_data, f)
|
||||||
|
with open('dataset/strategy_dataset/test_action_id_'+str(i)+'.pkl', 'wb') as f:
|
||||||
|
pickle.dump(test_action_id, f)
|
||||||
|
print('max_len', max_len)
|
||||||
|
|
||||||
|
def calc_gt_prob():
|
||||||
|
# train set unique label
|
||||||
|
for i in range(7):
|
||||||
|
with open('dataset/strategy_dataset/train_label_'+str(i)+'.pkl', 'rb') as f:
|
||||||
|
y = pickle.load(f)
|
||||||
|
y = np.array(y)
|
||||||
|
print('task ', i)
|
||||||
|
print('unique train label', np.unique(y))
|
||||||
|
|
||||||
|
def plot_gt_dist():
|
||||||
|
|
||||||
|
full_data = []
|
||||||
|
for i in range(7):
|
||||||
|
with open('dataset/strategy_dataset/' + 'test' + '_label_' + str(i) + '.pkl', 'rb') as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
#print(len(data))
|
||||||
|
full_data.append(data)
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(7)
|
||||||
|
fig.set_figheight(10)
|
||||||
|
fig.set_figwidth(16)
|
||||||
|
act_name = ["Italic", "Bold", "Underline", "Indent", "Align", "FontSize", "FontFamily"]
|
||||||
|
x = np.arange(7)
|
||||||
|
|
||||||
|
width = 0.1
|
||||||
|
for i in range(7):
|
||||||
|
for u in range(len(data)): # 5 users
|
||||||
|
values, counts = np.unique(full_data[i][u], return_counts=True)
|
||||||
|
counts_vis = [0]*7
|
||||||
|
for j in range(len(values)):
|
||||||
|
counts_vis[values[j]-1] = counts[j]
|
||||||
|
print('task', i, 'actions', values, 'num', counts)
|
||||||
|
|
||||||
|
axs[i].set_title('Intention '+str(i))
|
||||||
|
axs[i].set_xlabel('action id')
|
||||||
|
axs[i].set_ylabel('num of actions')
|
||||||
|
axs[i].bar(x+u*width, counts_vis, width=0.1, label='user '+str(u))
|
||||||
|
axs[i].set_xticks(np.arange(len(x)))
|
||||||
|
axs[i].set_xticklabels(act_name)
|
||||||
|
axs[i].set_ylim([0,80])
|
||||||
|
|
||||||
|
axs[0].legend(loc='upper right', ncol=1)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('dataset/'+'test'+'_gt_dist.png')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
def plot_act():
|
||||||
|
full_data = []
|
||||||
|
for i in range(7):
|
||||||
|
with open('dataset/strategy_dataset/' + 'test' + '_label_' + str(i) + '.pkl', 'rb') as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
full_data.append(data)
|
||||||
|
|
||||||
|
width = 0.1
|
||||||
|
for i in range(7):
|
||||||
|
fig, axs = plt.subplots(5)
|
||||||
|
fig.set_figheight(10)
|
||||||
|
fig.set_figwidth(16)
|
||||||
|
act_name = ["Italic", "Bold", "Underline", "Indent", "Align", "FontSize", "FontFamily"]
|
||||||
|
for u in range(len(full_data[i])): # 5 users
|
||||||
|
x = np.arange(len(full_data[i][u]))
|
||||||
|
axs[u].set_xlabel('action id')
|
||||||
|
axs[u].set_ylabel('num of actions')
|
||||||
|
axs[u].plot(x, full_data[i][u])
|
||||||
|
|
||||||
|
axs[0].legend(loc='upper right', ncol=1)
|
||||||
|
plt.tight_layout()
|
||||||
|
#plt.savefig('test'+'_act.png')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("func", help="select what function to run. view_clean_data, split_org_data, calc_gt_prob, plot_gt_dist, plot_act", type=str)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.func == 'view_clean_data':
|
||||||
|
view_clean_data() # view original keyboad and mouse interaction dataset
|
||||||
|
if args.func == 'split_org_data':
|
||||||
|
split_org_data() # split the original keyboad and mouse interaction dataset. User 1-11 for training, rest for testing
|
||||||
|
if args.func == 'calc_gt_prob':
|
||||||
|
calc_gt_prob() # see unique label in train set
|
||||||
|
if args.func == 'plot_gt_dist':
|
||||||
|
plot_gt_dist() # plot the label distribution of test set
|
||||||
|
if args.func == 'plot_act':
|
||||||
|
plot_act() # plot the label of test set
|
||||||
|
|
||||||
|
|
||||||
|
|
56
keyboard_and_mouse/sampler_single_act.py
Normal file
56
keyboard_and_mouse/sampler_single_act.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import csv
|
||||||
|
import pandas
|
||||||
|
from pathlib import Path
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def sample_single_act(pred_path, save_path, j):
|
||||||
|
data = pandas.read_csv(pred_path).values
|
||||||
|
total_data = []
|
||||||
|
|
||||||
|
for u in range(1,6):
|
||||||
|
act_data = data[data[:,1]==u]
|
||||||
|
final_save_path = save_path + "/rate_" + str(j) + "_act_" + str(int(u)) + "_pred.csv"
|
||||||
|
head = []
|
||||||
|
for r in range(7):
|
||||||
|
head.append('act'+str(r+1))
|
||||||
|
head.append('task_name')
|
||||||
|
head.append('gt')
|
||||||
|
head.insert(0,'action_id')
|
||||||
|
pandas.DataFrame(act_data[:,1:]).to_csv(final_save_path, header=head)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parsing parameters
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||||
|
parser.add_argument('--hidden_size', type=int, default=64, help='hidden_size')
|
||||||
|
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
task = np.arange(7)
|
||||||
|
user_num = 5
|
||||||
|
bs = args.batch_size
|
||||||
|
lr = args.lr # 1e-4
|
||||||
|
hs = args.hidden_size #128
|
||||||
|
model_type = args.model_type #'lstmlast'
|
||||||
|
|
||||||
|
rate = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
|
||||||
|
|
||||||
|
for i in task:
|
||||||
|
for j in rate:
|
||||||
|
for l in range(user_num):
|
||||||
|
pred_path = "prediction/task" + str(i) + "/" + model_type + "_bs_" + str(bs) + "_lr_" + str(lr) + "_hidden_size_" + str(hs) + "/user" + str(l) + "_rate_" + str(j) + "_pred.csv"
|
||||||
|
if j == 100:
|
||||||
|
pred_path = "prediction/task" + str(i) + "/" + model_type + "_bs_" + str(bs) + "_lr_" + str(lr) + "_hidden_size_" + str(hs) + "/user" + str(l) + "_pred.csv"
|
||||||
|
save_path = "prediction/single_act/task" + str(i) + "/" + model_type + "_bs_" + str(bs) + "_lr_" + str(lr) + "_hidden_size_" + str(hs) + "/user" + str(l)
|
||||||
|
Path(save_path).mkdir(parents=True, exist_ok=True)
|
||||||
|
data = sample_single_act(pred_path, save_path, j)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# split the prediction by action sequence id, from 10% to 90%
|
||||||
|
main()
|
||||||
|
|
5
keyboard_and_mouse/sampler_single_act.sh
Normal file
5
keyboard_and_mouse/sampler_single_act.sh
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
python3 sampler_single_act.py \
|
||||||
|
--batch_size 8 \
|
||||||
|
--lr 1e-4 \
|
||||||
|
--model_type lstmlast \
|
||||||
|
--hidden_size 128
|
68
keyboard_and_mouse/sampler_user.py
Normal file
68
keyboard_and_mouse/sampler_user.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import csv
|
||||||
|
import pandas
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def sample_predciton(path, rate):
|
||||||
|
data = pandas.read_csv(path).values
|
||||||
|
task_list = [0, 1, 2, 3, 4, 5, 6]
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
stop = 0
|
||||||
|
num_unique = np.unique(data[:,1])
|
||||||
|
|
||||||
|
samples = []
|
||||||
|
for j in task_list:
|
||||||
|
for i in num_unique:
|
||||||
|
inx = np.where((data[:,1] == i) & (data[:,-2] == j))
|
||||||
|
samples.append(data[inx])
|
||||||
|
|
||||||
|
for i in range(len(samples)):
|
||||||
|
n = int(len(samples[i])*(100-rate)/100)
|
||||||
|
if n == 0:
|
||||||
|
n = 1
|
||||||
|
samples[i] = samples[i][:-n]
|
||||||
|
if len(samples[i]) == 0:
|
||||||
|
print('len of after sampling',len(samples[i]))
|
||||||
|
|
||||||
|
return np.vstack(samples)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parsing parameters
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||||
|
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||||
|
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
task = np.arange(7)
|
||||||
|
user_num = 5
|
||||||
|
bs = args.batch_size
|
||||||
|
lr = args.lr # 1e-4
|
||||||
|
hs = args.hidden_size #128
|
||||||
|
model_type = args.model_type #'lstmlast'
|
||||||
|
|
||||||
|
rate = [10, 20, 30, 40, 50, 60, 70, 80, 90]
|
||||||
|
|
||||||
|
for i in task:
|
||||||
|
for j in rate:
|
||||||
|
for l in range(user_num):
|
||||||
|
pred_path = "prediction/task" + str(i) + "/" + model_type + "_bs_" + str(bs) + "_lr_" + str(lr) + "_hidden_size_" + str(hs) + "/user" + str(l) + "_pred.csv"
|
||||||
|
save_path = "prediction/task" + str(i) + "/" + model_type + "_bs_" + str(bs) + "_lr_" + str(lr) + "_hidden_size_" + str(hs) + "/user" + str(l) + "_rate_" + str(j) + "_pred.csv"
|
||||||
|
data = sample_predciton(pred_path, j)
|
||||||
|
|
||||||
|
head = []
|
||||||
|
for r in range(7):
|
||||||
|
head.append('act'+str(r+1))
|
||||||
|
head.append('task_name')
|
||||||
|
head.append('gt')
|
||||||
|
head.insert(0,'action_id')
|
||||||
|
pandas.DataFrame(data[:,1:]).to_csv(save_path, header=head)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# split the prediction by length, from 10% to 90%
|
||||||
|
main()
|
||||||
|
|
5
keyboard_and_mouse/sampler_user.sh
Normal file
5
keyboard_and_mouse/sampler_user.sh
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
python3 sampler_user.py \
|
||||||
|
--batch_size 8 \
|
||||||
|
--lr 1e-4 \
|
||||||
|
--model_type lstmlast \
|
||||||
|
--hidden_size 128
|
88
keyboard_and_mouse/stan/plot_user.py
Normal file
88
keyboard_and_mouse/stan/plot_user.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from pathlib import Path
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||||
|
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||||
|
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||||
|
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||||
|
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
plot_type = 'bar' # line bar
|
||||||
|
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
# read data
|
||||||
|
user_data_list = []
|
||||||
|
for i in range(args.user):
|
||||||
|
model_data_list = []
|
||||||
|
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) +".csv"
|
||||||
|
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||||
|
for j in range(7):
|
||||||
|
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||||
|
model_data_list.append(data_temp)
|
||||||
|
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||||
|
user_data_list.append(model_data_list)
|
||||||
|
|
||||||
|
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||||
|
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||||
|
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||||
|
fig.set_figheight(14)
|
||||||
|
fig.set_figwidth(25)
|
||||||
|
|
||||||
|
for ax in range(7):
|
||||||
|
y_total = []
|
||||||
|
y_low_total = []
|
||||||
|
y_high_total = []
|
||||||
|
for j in range(7):
|
||||||
|
y= []
|
||||||
|
y_low = []
|
||||||
|
y_high = []
|
||||||
|
for i in range(len(user_data_list)):
|
||||||
|
y.append(user_data_list[i][j+ax*7][0])
|
||||||
|
y_low.append(user_data_list[i][j+ax*7][2])
|
||||||
|
y_high.append(user_data_list[i][j+ax*7][3])
|
||||||
|
y_total.append(y)
|
||||||
|
y_low_total.append(y_low)
|
||||||
|
y_high_total.append(y_high)
|
||||||
|
print()
|
||||||
|
print("user mean of mean prob: ", np.mean(y))
|
||||||
|
print("user mean of sd prob: ", np.std(y))
|
||||||
|
|
||||||
|
for i in range(7):
|
||||||
|
if plot_type == 'line':
|
||||||
|
axs[ax].plot(range(args.user), y_total[i], color=color[i], label=legend[i])
|
||||||
|
axs[ax].fill_between(range(args.user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||||
|
if plot_type == 'bar':
|
||||||
|
width = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
|
||||||
|
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||||
|
axs[ax].bar(np.arange(args.user)+width[i], y_total[i], width=0.08, yerr=yerror, label=legend[i], color=color[i])
|
||||||
|
axs[ax].tick_params(axis='x', which='both', length=0)
|
||||||
|
axs[ax].set_ylabel('prob', fontsize=22)
|
||||||
|
for k,x in enumerate(np.arange(args.user)+width[i]):
|
||||||
|
y = y_total[i][k] + yerror[1][k]
|
||||||
|
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-18,3), fontsize=16)
|
||||||
|
|
||||||
|
axs[0].text(-0.1, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 22) # all: -0.3,0.5 3rows: -0.5,0.5
|
||||||
|
axs[ax].text(-0.1, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 22, color=color[ax])
|
||||||
|
axs[ax].tick_params(axis='both', which='major', labelsize=16)
|
||||||
|
|
||||||
|
plt.xticks(range(args.user),('1', '2', '3', '4', '5'))
|
||||||
|
plt.xlabel('user', fontsize= 22)
|
||||||
|
handles, labels = axs[0].get_legend_handles_labels()
|
||||||
|
plt.ylim([0, 1])
|
||||||
|
Path("figure").mkdir(parents=True, exist_ok=True)
|
||||||
|
if plot_type == 'line':
|
||||||
|
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_line.png", bbox_inches='tight')
|
||||||
|
if plot_type == 'bar':
|
||||||
|
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_bar.png", bbox_inches='tight')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
8
keyboard_and_mouse/stan/plot_user.sh
Normal file
8
keyboard_and_mouse/stan/plot_user.sh
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
python3 plot_user.py \
|
||||||
|
--model_type lstmlast_ \
|
||||||
|
--batch_size 8 \
|
||||||
|
--lr 1e-4 \
|
||||||
|
--hidden_size 128 \
|
||||||
|
--N 1 \
|
||||||
|
--user 5
|
||||||
|
|
99
keyboard_and_mouse/stan/plot_user_all_individual.py
Normal file
99
keyboard_and_mouse/stan/plot_user_all_individual.py
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||||
|
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||||
|
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||||
|
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||||
|
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
plot_type = 'bar' # line bar
|
||||||
|
act_series = 5
|
||||||
|
|
||||||
|
# read data
|
||||||
|
plot_list = []
|
||||||
|
for act in range(1,act_series+1):
|
||||||
|
user_data_list = []
|
||||||
|
for i in range(args.user):
|
||||||
|
model_data_list = []
|
||||||
|
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) + "_rate__100" + "_act_" + str(act) +".csv"
|
||||||
|
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||||
|
for j in range(7):
|
||||||
|
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||||
|
model_data_list.append(data_temp)
|
||||||
|
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||||
|
print(model_data_list.shape)
|
||||||
|
user_data_list.append(model_data_list)
|
||||||
|
|
||||||
|
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||||
|
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||||
|
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||||
|
fig.set_figheight(14)
|
||||||
|
fig.set_figwidth(25)
|
||||||
|
|
||||||
|
for ax in range(7):
|
||||||
|
y_total = []
|
||||||
|
y_low_total = []
|
||||||
|
y_high_total = []
|
||||||
|
for j in range(7):
|
||||||
|
y= []
|
||||||
|
y_low = []
|
||||||
|
y_high = []
|
||||||
|
for i in range(len(user_data_list)):
|
||||||
|
y.append(user_data_list[i][j+ax*7][0])
|
||||||
|
y_low.append(user_data_list[i][j+ax*7][2])
|
||||||
|
y_high.append(user_data_list[i][j+ax*7][3])
|
||||||
|
y_total.append(y)
|
||||||
|
y_low_total.append(y_low)
|
||||||
|
y_high_total.append(y_high)
|
||||||
|
print()
|
||||||
|
print("user mean of mean prob: ", np.mean(y))
|
||||||
|
print("user mean of sd prob: ", np.std(y))
|
||||||
|
|
||||||
|
for i in range(7):
|
||||||
|
if plot_type == 'line':
|
||||||
|
axs[ax].plot(range(args.user), y_total[i], color=color[i], label=legend[i])
|
||||||
|
axs[ax].fill_between(range(args.user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||||
|
if plot_type == 'bar':
|
||||||
|
width = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
|
||||||
|
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||||
|
axs[ax].bar(np.arange(args.user)+width[i], y_total[i], width=0.08, yerr=[np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])], label=legend[i], color=color[i])
|
||||||
|
axs[ax].tick_params(axis='x', which='both', length=0)
|
||||||
|
axs[ax].set_ylabel('prob', fontsize=36) # was 22,
|
||||||
|
for k,x in enumerate(np.arange(args.user)+width[i]):
|
||||||
|
y = y_total[i][k] + yerror[1][k]
|
||||||
|
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-18,3), fontsize=16) #was 16
|
||||||
|
|
||||||
|
axs[0].text(-0.17, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 46) # was -0.1 0.9 25
|
||||||
|
axs[ax].text(-0.17, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 46, color=color[ax]) # was 25
|
||||||
|
axs[ax].tick_params(axis='both', which='major', labelsize=42) # was 18
|
||||||
|
for tick in axs[ax].xaxis.get_major_ticks():
|
||||||
|
tick.set_pad(20)
|
||||||
|
|
||||||
|
plt.xticks(range(args.user),('1', '2', '3', '4', '5'))
|
||||||
|
plt.xlabel('user', fontsize= 42) # was 22
|
||||||
|
handles, labels = axs[0].get_legend_handles_labels()
|
||||||
|
|
||||||
|
plt.ylim([0, 1])
|
||||||
|
plt.tight_layout()
|
||||||
|
if plot_type == 'line':
|
||||||
|
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_line_all_individual.png", bbox_inches='tight')
|
||||||
|
if plot_type == 'bar':
|
||||||
|
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_bar_all_individual.png", bbox_inches='tight')
|
||||||
|
|
||||||
|
if plot_type == 'line':
|
||||||
|
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_line_all_individual.eps", bbox_inches='tight', format='eps')
|
||||||
|
if plot_type == 'bar':
|
||||||
|
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_bar_all_individual.eps", bbox_inches='tight', format='eps')
|
||||||
|
#plt.show()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
8
keyboard_and_mouse/stan/plot_user_all_individual.sh
Normal file
8
keyboard_and_mouse/stan/plot_user_all_individual.sh
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
python3 plot_user_all_individual.py \
|
||||||
|
--model_type lstmlast_ \
|
||||||
|
--batch_size 8 \
|
||||||
|
--lr 1e-4 \
|
||||||
|
--hidden_size 128 \
|
||||||
|
--N 1 \
|
||||||
|
--user 5
|
||||||
|
|
93
keyboard_and_mouse/stan/plot_user_all_individual_chiw.py
Normal file
93
keyboard_and_mouse/stan/plot_user_all_individual_chiw.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
model_type = "lstmlast_"
|
||||||
|
batch_size = 8
|
||||||
|
lr = 1e-4
|
||||||
|
hidden_size = 128
|
||||||
|
N = 1
|
||||||
|
user = 5
|
||||||
|
plot_type = 'bar' # line bar
|
||||||
|
act_series = 5
|
||||||
|
|
||||||
|
# read data
|
||||||
|
plot_list = []
|
||||||
|
for act in range(1,act_series+1):
|
||||||
|
user_data_list = []
|
||||||
|
for i in range(user):
|
||||||
|
model_data_list = []
|
||||||
|
path = "result/"+"N"+ str(N) + "/" + model_type + "bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_result_user" + str(i) + "_rate__100" + "_act_" + str(act) +".csv"
|
||||||
|
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||||
|
for j in range(7):
|
||||||
|
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||||
|
model_data_list.append(data_temp)
|
||||||
|
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||||
|
print(model_data_list.shape)
|
||||||
|
user_data_list.append(model_data_list)
|
||||||
|
|
||||||
|
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||||
|
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||||
|
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||||
|
fig.set_figheight(14)
|
||||||
|
fig.set_figwidth(25)
|
||||||
|
|
||||||
|
for ax in range(7):
|
||||||
|
y_total = []
|
||||||
|
y_low_total = []
|
||||||
|
y_high_total = []
|
||||||
|
for j in range(7):
|
||||||
|
y= []
|
||||||
|
y_low = []
|
||||||
|
y_high = []
|
||||||
|
for i in range(len(user_data_list)):
|
||||||
|
y.append(user_data_list[i][j+ax*7][0])
|
||||||
|
y_low.append(user_data_list[i][j+ax*7][2])
|
||||||
|
y_high.append(user_data_list[i][j+ax*7][3])
|
||||||
|
y_total.append(y)
|
||||||
|
y_low_total.append(y_low)
|
||||||
|
y_high_total.append(y_high)
|
||||||
|
print()
|
||||||
|
print(legend[ax])
|
||||||
|
print("user mean of mean prob: ", np.mean(y))
|
||||||
|
print("user mean of sd prob: ", np.std(y))
|
||||||
|
|
||||||
|
for i in range(7):
|
||||||
|
if plot_type == 'line':
|
||||||
|
axs[ax].plot(range(user), y_total[i], color=color[i], label=legend[i])
|
||||||
|
axs[ax].fill_between(range(user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||||
|
if plot_type == 'bar':
|
||||||
|
width = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
|
||||||
|
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||||
|
axs[ax].bar(np.arange(user)+width[i], y_total[i], width=0.08, yerr=[np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])], label=legend[i], color=color[i])
|
||||||
|
axs[ax].tick_params(axis='x', which='both', length=0)
|
||||||
|
axs[ax].set_ylabel('prob', fontsize=26) # was 22,
|
||||||
|
axs[ax].set_title(legend[ax], color=color[ax], fontsize=26)
|
||||||
|
for k,x in enumerate(np.arange(user)+width[i]):
|
||||||
|
y = y_total[i][k] + yerror[1][k]
|
||||||
|
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-18,3), fontsize=16) #was 16
|
||||||
|
|
||||||
|
#axs[0].text(-0.17, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 46) # was -0.1 0.9 25
|
||||||
|
#axs[ax].text(-0.17, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 46, color=color[ax]) # was 25
|
||||||
|
axs[ax].tick_params(axis='both', which='major', labelsize=18) # was 18
|
||||||
|
for tick in axs[ax].xaxis.get_major_ticks():
|
||||||
|
tick.set_pad(20)
|
||||||
|
|
||||||
|
plt.xticks(range(user),('1', '2', '3', '4', '5'))
|
||||||
|
plt.xlabel('user', fontsize= 26) # was 22
|
||||||
|
handles, labels = axs[0].get_legend_handles_labels()
|
||||||
|
|
||||||
|
plt.ylim([0, 1.2])
|
||||||
|
plt.tight_layout()
|
||||||
|
if plot_type == 'line':
|
||||||
|
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_line_all_individual_chiw.png", bbox_inches='tight')
|
||||||
|
if plot_type == 'bar':
|
||||||
|
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_bar_all_individual_chiw.png", bbox_inches='tight')
|
||||||
|
#plt.show()
|
||||||
|
if plot_type == 'line':
|
||||||
|
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_line_all_individual_chiw.eps", bbox_inches='tight', format='eps')
|
||||||
|
if plot_type == 'bar':
|
||||||
|
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_bar_all_individual_chiw.eps", bbox_inches='tight', format='eps')
|
||||||
|
|
||||||
|
|
||||||
|
|
88
keyboard_and_mouse/stan/plot_user_length_10_steps.py
Normal file
88
keyboard_and_mouse/stan/plot_user_length_10_steps.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||||
|
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||||
|
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||||
|
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||||
|
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
rate_user_data_list = []
|
||||||
|
for r in range(0,101,10): # rate = range(0,101,10)
|
||||||
|
# read data
|
||||||
|
user_data_list = []
|
||||||
|
for i in range(args.user):
|
||||||
|
model_data_list = []
|
||||||
|
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) + "_rate__" + str(r) +".csv"
|
||||||
|
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||||
|
for j in range(7):
|
||||||
|
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||||
|
model_data_list.append(data_temp)
|
||||||
|
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||||
|
if i == 4:
|
||||||
|
print(model_data_list.shape, model_data_list)
|
||||||
|
user_data_list.append(model_data_list)
|
||||||
|
model_data_list_total = np.stack(user_data_list)
|
||||||
|
print(model_data_list_total.shape)
|
||||||
|
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||||
|
print(mean_user_data.shape)
|
||||||
|
rate_user_data_list.append(mean_user_data)
|
||||||
|
|
||||||
|
|
||||||
|
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||||
|
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||||
|
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||||
|
fig.set_figheight(10) # all sample rate: 10; 3 row: 8
|
||||||
|
fig.set_figwidth(20)
|
||||||
|
|
||||||
|
for ax in range(7):
|
||||||
|
y_total = []
|
||||||
|
y_low_total = []
|
||||||
|
y_high_total = []
|
||||||
|
for j in range(7):
|
||||||
|
y= []
|
||||||
|
y_low = []
|
||||||
|
y_high = []
|
||||||
|
for i in range(len(rate_user_data_list)):
|
||||||
|
y.append(rate_user_data_list[i][j+ax*7][0])
|
||||||
|
y_low.append(rate_user_data_list[i][j+ax*7][2])
|
||||||
|
y_high.append(rate_user_data_list[i][j+ax*7][3])
|
||||||
|
y_total.append(y)
|
||||||
|
y_low_total.append(y_low)
|
||||||
|
y_high_total.append(y_high)
|
||||||
|
print()
|
||||||
|
print("user mean of mean prob: ", np.mean(y))
|
||||||
|
print("user mean of sd prob: ", np.std(y))
|
||||||
|
|
||||||
|
for i in range(7):
|
||||||
|
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||||
|
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||||
|
axs[ax].set_xticks(range(0,101,10))
|
||||||
|
axs[ax].set_ylabel('prob', fontsize=20)
|
||||||
|
|
||||||
|
|
||||||
|
axs[0].text(-0.125, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 20)
|
||||||
|
axs[ax].text(-0.125, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 20, color=color[ax])
|
||||||
|
axs[ax].tick_params(axis='both', which='major', labelsize=16)
|
||||||
|
|
||||||
|
|
||||||
|
plt.xlabel('Percentage of occurred actions in one action sequence', fontsize= 20)
|
||||||
|
handles, labels = axs[0].get_legend_handles_labels()
|
||||||
|
|
||||||
|
plt.xlim([0, 101])
|
||||||
|
plt.ylim([0, 1])
|
||||||
|
|
||||||
|
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_rate_full.png", bbox_inches='tight')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
8
keyboard_and_mouse/stan/plot_user_length_10_steps.sh
Normal file
8
keyboard_and_mouse/stan/plot_user_length_10_steps.sh
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
python3 plot_user_length_10_steps.py \
|
||||||
|
--model_type lstmlast_ \
|
||||||
|
--batch_size 8 \
|
||||||
|
--lr 1e-4 \
|
||||||
|
--hidden_size 128 \
|
||||||
|
--N 1 \
|
||||||
|
--user 5
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||||
|
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||||
|
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||||
|
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||||
|
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||||
|
act_series = 5
|
||||||
|
|
||||||
|
for act in range(1,act_series+1):
|
||||||
|
rate_user_data_list = []
|
||||||
|
for r in range(0,101,10):
|
||||||
|
# read data
|
||||||
|
user_data_list = []
|
||||||
|
for i in range(args.user):
|
||||||
|
model_data_list = []
|
||||||
|
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) + "_rate__" + str(r) + "_act_" + str(act) +".csv"
|
||||||
|
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||||
|
for j in range(7):
|
||||||
|
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||||
|
model_data_list.append(data_temp)
|
||||||
|
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||||
|
user_data_list.append(model_data_list)
|
||||||
|
model_data_list_total = np.stack(user_data_list)
|
||||||
|
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||||
|
rate_user_data_list.append(mean_user_data)
|
||||||
|
|
||||||
|
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||||
|
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||||
|
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||||
|
fig.set_figheight(14) # was 10
|
||||||
|
fig.set_figwidth(20)
|
||||||
|
|
||||||
|
for ax in range(7):
|
||||||
|
y_total = []
|
||||||
|
y_low_total = []
|
||||||
|
y_high_total = []
|
||||||
|
for j in range(7):
|
||||||
|
y= []
|
||||||
|
y_low = []
|
||||||
|
y_high = []
|
||||||
|
for i in range(len(rate_user_data_list)):
|
||||||
|
y.append(rate_user_data_list[i][j+ax*7][0])
|
||||||
|
y_low.append(rate_user_data_list[i][j+ax*7][2])
|
||||||
|
y_high.append(rate_user_data_list[i][j+ax*7][3])
|
||||||
|
y_total.append(y)
|
||||||
|
y_low_total.append(y_low)
|
||||||
|
y_high_total.append(y_high)
|
||||||
|
print()
|
||||||
|
print("user mean of mean prob: ", np.mean(y))
|
||||||
|
print("user mean of sd prob: ", np.std(y))
|
||||||
|
|
||||||
|
for i in range(7):
|
||||||
|
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||||
|
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||||
|
axs[ax].set_xticks(range(0,101,10))
|
||||||
|
axs[ax].set_ylabel('prob', fontsize=26) # was 20
|
||||||
|
|
||||||
|
axs[0].text(-0.15, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36) # was -0.125 20
|
||||||
|
axs[ax].text(-0.15, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax]) # -0.125 20
|
||||||
|
axs[ax].tick_params(axis='y', which='major', labelsize=24) # was 16
|
||||||
|
axs[ax].tick_params(axis='x', which='major', labelsize=24) # was 16
|
||||||
|
for tick in axs[ax].xaxis.get_major_ticks():
|
||||||
|
tick.set_pad(20)
|
||||||
|
|
||||||
|
plt.xlabel('Percentage of occurred actions in one action sequence', fontsize= 36) # was 20
|
||||||
|
handles, labels = axs[0].get_legend_handles_labels()
|
||||||
|
|
||||||
|
plt.xlim([0, 101])
|
||||||
|
plt.ylim([0, 1])
|
||||||
|
|
||||||
|
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_rate_ful_all_individuall.png", bbox_inches='tight')
|
||||||
|
|
||||||
|
#plt.show()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
python3 plot_user_length_10_steps_all_individual.py \
|
||||||
|
--model_type lstmlast_ \
|
||||||
|
--batch_size 8 \
|
||||||
|
--lr 1e-4 \
|
||||||
|
--hidden_size 128 \
|
||||||
|
--N 1 \
|
||||||
|
--user 5
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
model_type = "lstmlast_"
|
||||||
|
batch_size = 8
|
||||||
|
lr = 1e-4
|
||||||
|
hidden_size = 128
|
||||||
|
N = 1
|
||||||
|
user = 5
|
||||||
|
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||||
|
act_series = 5
|
||||||
|
|
||||||
|
for act in range(1,act_series+1):
|
||||||
|
rate_user_data_list = []
|
||||||
|
for r in range(0,101,10):
|
||||||
|
# read data
|
||||||
|
print(r)
|
||||||
|
user_data_list = []
|
||||||
|
for i in range(user):
|
||||||
|
model_data_list = []
|
||||||
|
path = "result/"+"N"+ str(N) + "/" + model_type + "bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_result_user" + str(i) + "_rate__" + str(r) + "_act_" + str(act) +".csv"
|
||||||
|
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||||
|
for j in range(7):
|
||||||
|
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||||
|
model_data_list.append(data_temp)
|
||||||
|
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||||
|
user_data_list.append(model_data_list)
|
||||||
|
model_data_list_total = np.stack(user_data_list)
|
||||||
|
print(model_data_list_total.shape)
|
||||||
|
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||||
|
print(mean_user_data.shape)
|
||||||
|
rate_user_data_list.append(mean_user_data)
|
||||||
|
|
||||||
|
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||||
|
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||||
|
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||||
|
fig.set_figheight(14) # was 10
|
||||||
|
fig.set_figwidth(20)
|
||||||
|
|
||||||
|
for ax in range(7):
|
||||||
|
y_total = []
|
||||||
|
y_low_total = []
|
||||||
|
y_high_total = []
|
||||||
|
for j in range(7):
|
||||||
|
y= []
|
||||||
|
y_low = []
|
||||||
|
y_high = []
|
||||||
|
for i in range(len(rate_user_data_list)):
|
||||||
|
y.append(rate_user_data_list[i][j+ax*7][0])
|
||||||
|
y_low.append(rate_user_data_list[i][j+ax*7][2])
|
||||||
|
y_high.append(rate_user_data_list[i][j+ax*7][3])
|
||||||
|
y_total.append(y)
|
||||||
|
y_low_total.append(y_low)
|
||||||
|
y_high_total.append(y_high)
|
||||||
|
print()
|
||||||
|
print(legend[ax])
|
||||||
|
print("user mean of mean prob: ", np.mean(y))
|
||||||
|
print("user mean of sd prob: ", np.std(y))
|
||||||
|
|
||||||
|
for i in range(7):
|
||||||
|
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||||
|
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||||
|
axs[ax].set_xticks(range(0,101,10))
|
||||||
|
axs[ax].set_ylabel('prob', fontsize=26) # was 20
|
||||||
|
axs[ax].set_title(legend[ax], color=color[ax], fontsize=26)
|
||||||
|
|
||||||
|
|
||||||
|
#axs[0].text(-0.15, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36) # was -0.125 20
|
||||||
|
#axs[ax].text(-0.15, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax]) # -0.125 20
|
||||||
|
axs[ax].tick_params(axis='y', which='major', labelsize=18) # was 16
|
||||||
|
axs[ax].tick_params(axis='x', which='major', labelsize=18) # was 16
|
||||||
|
for tick in axs[ax].xaxis.get_major_ticks():
|
||||||
|
tick.set_pad(20)
|
||||||
|
|
||||||
|
plt.xlabel('Percentage of occurred actions in one action sequence', fontsize= 26) # was 20
|
||||||
|
handles, labels = axs[0].get_legend_handles_labels()
|
||||||
|
|
||||||
|
plt.xlim([0, 101])
|
||||||
|
plt.ylim([0, 1.1])
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_rate_ful_all_individuall_chiw.png", bbox_inches='tight')
|
||||||
|
|
||||||
|
#plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
BIN
keyboard_and_mouse/stan/strategy_inference_model
Executable file
BIN
keyboard_and_mouse/stan/strategy_inference_model
Executable file
Binary file not shown.
26
keyboard_and_mouse/stan/strategy_inference_model.stan
Executable file
26
keyboard_and_mouse/stan/strategy_inference_model.stan
Executable file
|
@ -0,0 +1,26 @@
|
||||||
|
data {
|
||||||
|
int<lower=1> I; // number of question options (22)
|
||||||
|
int<lower=0> N; // number of questions being asked by the user
|
||||||
|
int<lower=1> K; // number of strategies
|
||||||
|
// observed "true" questions of the user
|
||||||
|
int q[N];
|
||||||
|
// array of predicted probabilities of questions given strategies
|
||||||
|
// coming from the forward neural network
|
||||||
|
matrix[I, K] P_q_S[N];
|
||||||
|
}
|
||||||
|
parameters {
|
||||||
|
// probabiliy vector of the strategies being applied by the user
|
||||||
|
// to be inferred by the model here
|
||||||
|
simplex[K] P_S;
|
||||||
|
}
|
||||||
|
model {
|
||||||
|
for (n in 1:N) {
|
||||||
|
// marginal probability vector of the questions being asked
|
||||||
|
vector[I] theta = P_q_S[n] * P_S;
|
||||||
|
// categorical likelihood
|
||||||
|
target += categorical_lpmf(q[n] | theta);
|
||||||
|
}
|
||||||
|
// priors
|
||||||
|
target += dirichlet_lpdf(P_S | rep_vector(1.0, K));
|
||||||
|
}
|
||||||
|
|
157
keyboard_and_mouse/stan/strategy_inference_test.R
Normal file
157
keyboard_and_mouse/stan/strategy_inference_test.R
Normal file
|
@ -0,0 +1,157 @@
|
||||||
|
library(tidyverse)
|
||||||
|
library(cmdstanr)
|
||||||
|
library(dplyr)
|
||||||
|
|
||||||
|
|
||||||
|
model_type <- "lstmlast"
|
||||||
|
batch_size <- "8"
|
||||||
|
lr <- "0.0001"
|
||||||
|
hidden_size <- "128"
|
||||||
|
model_type <- paste0(model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size)
|
||||||
|
print(model_type)
|
||||||
|
set.seed(9736734)
|
||||||
|
|
||||||
|
user_num <- 5
|
||||||
|
user <-c(0:(user_num-1))
|
||||||
|
strategies <- c(0:6) # 7 tasks
|
||||||
|
print(strategies)
|
||||||
|
print(length(strategies))
|
||||||
|
N <- 1
|
||||||
|
|
||||||
|
# read data from csv
|
||||||
|
sel <- vector("list", length(strategies))
|
||||||
|
for (u in seq_along(user)){
|
||||||
|
dat <- vector("list", length(strategies))
|
||||||
|
print(paste0('user: ', u))
|
||||||
|
for (i in seq_along(strategies)) {
|
||||||
|
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_pred", ".csv"))
|
||||||
|
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||||
|
dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
|
||||||
|
dat[[i]]$id <- dat[[i]][,1] # sample based on intention
|
||||||
|
}
|
||||||
|
|
||||||
|
# reset N after inference
|
||||||
|
N = 1
|
||||||
|
|
||||||
|
# select one action series from one intention
|
||||||
|
if (user[[u]] == 0){
|
||||||
|
sel[[1]]<-dat[[1]] %>%
|
||||||
|
group_by(task_name) %>%
|
||||||
|
sample_n(N)
|
||||||
|
sel[[1]] <- data.frame(sel[[1]])
|
||||||
|
}
|
||||||
|
|
||||||
|
# filter data from the selected action series, N series per intention
|
||||||
|
for (i in seq_along(strategies)) {
|
||||||
|
dat[[i]]<-subset(dat[[i]], dat[[i]]$action_id == sel[[1]]$action_id[1])
|
||||||
|
}
|
||||||
|
row.names(dat) <- NULL
|
||||||
|
|
||||||
|
# create save path
|
||||||
|
dir.create(file.path("result"), showWarnings = FALSE)
|
||||||
|
dir.create(file.path(paste0("result/", "N", N)), showWarnings = FALSE)
|
||||||
|
save_path <- paste0("result/", "N", N, "/", model_type, "_N", N, "_", "result","_user", user[[u]], ".csv")
|
||||||
|
|
||||||
|
dat <- do.call(rbind, dat) %>%
|
||||||
|
mutate(index = as.numeric(as.factor(id))) %>%
|
||||||
|
rename(true_strategy = task_name) %>%
|
||||||
|
mutate(
|
||||||
|
true_strategy = factor(
|
||||||
|
true_strategy, levels = 0:6,
|
||||||
|
labels = strategies
|
||||||
|
),
|
||||||
|
q_type = case_when(
|
||||||
|
gt %in% c(3,4,5) ~ 0,
|
||||||
|
gt %in% c(1,2,3,4,5,6,7) ~ 1,
|
||||||
|
gt %in% c(1,2,3,4) ~ 2,
|
||||||
|
gt %in% c(1,4,5,6,7) ~ 3,
|
||||||
|
gt %in% c(1,2,3,6,7) ~ 4,
|
||||||
|
gt %in% c(2,3,4,5,6,7) ~ 5,
|
||||||
|
gt %in% c(1,2,3,4,5,6,7) ~ 6,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dat_obs <- dat %>% filter(assumed_strategy == strategies[[i]])
|
||||||
|
N <- nrow(dat_obs)
|
||||||
|
print(c("N: ", N))
|
||||||
|
q <- dat_obs$gt
|
||||||
|
true_strategy <- dat_obs$true_strategy
|
||||||
|
|
||||||
|
K <- length(unique(dat$assumed_strategy))
|
||||||
|
print(c("K: ", K))
|
||||||
|
I <- 7
|
||||||
|
|
||||||
|
P_q_S <- array(dim = c(N, I, K))
|
||||||
|
for (n in 1:N) {
|
||||||
|
#print(n)
|
||||||
|
P_q_S[n, , ] <- dat %>%
|
||||||
|
filter(index == n) %>%
|
||||||
|
select(matches("^act[[:digit:]]+$")) %>%
|
||||||
|
as.matrix() %>%
|
||||||
|
t()
|
||||||
|
for (k in 1:K) {
|
||||||
|
# normalize probabilities
|
||||||
|
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
print(c('dim P_q_S',dim(P_q_S)))
|
||||||
|
|
||||||
|
mod <- cmdstan_model("strategy_inference_model.stan")
|
||||||
|
|
||||||
|
sub <- which(true_strategy == 0) # "0"
|
||||||
|
print(c('sub', sub))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
fit_0 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_0$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
sub <- which(true_strategy == 1)
|
||||||
|
print(c('sub', sub))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
fit_1 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_1$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
sub <- which(true_strategy == 2)
|
||||||
|
print(c('sub', sub))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
fit_2 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_2$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
sub <- which(true_strategy == 3)
|
||||||
|
print(c('sub', sub))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
fit_3 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_3$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
sub <- which(true_strategy == 4)
|
||||||
|
print(c('sub', sub))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
fit_4 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_4$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
sub <- which(true_strategy == 5)
|
||||||
|
print(c('sub', sub))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
fit_5 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_5$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
sub <- which(true_strategy == 6)
|
||||||
|
print(c('sub', sub))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
fit_6 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_6$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
# save csv
|
||||||
|
df <-rbind(fit_0$summary(), fit_1$summary(), fit_2$summary(), fit_3$summary(), fit_4$summary(), fit_5$summary(), fit_6$summary())
|
||||||
|
write.csv(df,file=save_path,quote=FALSE)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,239 @@
|
||||||
|
library(tidyverse)
|
||||||
|
library(cmdstanr)
|
||||||
|
library(dplyr)
|
||||||
|
|
||||||
|
# using every action sequence from each user
|
||||||
|
model_type <- "lstmlast"
|
||||||
|
batch_size <- "8"
|
||||||
|
lr <- "0.0001"
|
||||||
|
hidden_size <- "128"
|
||||||
|
model_type <- paste0(model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size)
|
||||||
|
rates <- c("_0", "_10", "_20", "_30", "_40", "_50", "_60", "_70", "_80", "_90", "_100")
|
||||||
|
|
||||||
|
user_num <- 5
|
||||||
|
user <-c(0:(user_num-1))
|
||||||
|
strategies <- c(0:6) # 7 tasks
|
||||||
|
print('strategies')
|
||||||
|
print(strategies)
|
||||||
|
print('strategies length')
|
||||||
|
print(length(strategies))
|
||||||
|
N <- 1
|
||||||
|
unique_act_id <- c(1:5)
|
||||||
|
print('unique_act_id')
|
||||||
|
print(unique_act_id)
|
||||||
|
set.seed(9746234)
|
||||||
|
|
||||||
|
for (act_id in seq_along(unique_act_id)){
|
||||||
|
for (u in seq_along(user)){
|
||||||
|
print('user')
|
||||||
|
print(u)
|
||||||
|
for (rate in rates) {
|
||||||
|
N <- 1
|
||||||
|
dat <- vector("list", length(strategies))
|
||||||
|
for (i in seq_along(strategies)) {
|
||||||
|
if (rate=="_0"){
|
||||||
|
# read data from csv
|
||||||
|
dat[[i]] <- read.csv(paste0("../prediction/single_act/task", strategies[[i]], "/", model_type, "/user", user[[u]], "/rate_10", "_act_", unique_act_id[act_id], "_pred", ".csv"))
|
||||||
|
} else{
|
||||||
|
dat[[i]] <- read.csv(paste0("../prediction/single_act/task", strategies[[i]], "/", model_type, "/user", user[[u]], "/rate", rate, "_act_", unique_act_id[act_id], "_pred", ".csv"))
|
||||||
|
}
|
||||||
|
# strategy assumed for prediction
|
||||||
|
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||||
|
dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
|
||||||
|
dat[[i]]$id <- dat[[i]][,1] # sample based on intention
|
||||||
|
}
|
||||||
|
|
||||||
|
save_path <- paste0("result/", "N", N, "/", model_type, "_N", N, "_", "result","_user", user[[u]], "_rate_", rate, "_act_", unique_act_id[act_id], ".csv")
|
||||||
|
|
||||||
|
dat_act <- do.call(rbind, dat) %>%
|
||||||
|
mutate(index = as.numeric(as.factor(id))) %>%
|
||||||
|
rename(true_strategy = task_name) %>%
|
||||||
|
mutate(
|
||||||
|
true_strategy = factor(
|
||||||
|
true_strategy, levels = 0:6,
|
||||||
|
labels = strategies
|
||||||
|
),
|
||||||
|
q_type = case_when(
|
||||||
|
gt %in% c(3,4,5) ~ 0,
|
||||||
|
gt %in% c(1,2,3,4,5,6,7) ~ 1,
|
||||||
|
gt %in% c(1,2,3,4) ~ 2,
|
||||||
|
gt %in% c(1,4,5,6,7) ~ 3,
|
||||||
|
gt %in% c(1,2,3,6,7) ~ 4,
|
||||||
|
gt %in% c(2,3,4,5,6,7) ~ 5,
|
||||||
|
gt %in% c(1,2,3,4,5,6,7) ~ 6,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dat_obs <- dat_act %>% filter(assumed_strategy == strategies[[i]])
|
||||||
|
N <- nrow(dat_obs)
|
||||||
|
print(c("N: ", N))
|
||||||
|
print(c("dim dat_act: ", dim(dat_act)))
|
||||||
|
q <- dat_obs$gt
|
||||||
|
true_strategy <- dat_obs$true_strategy
|
||||||
|
|
||||||
|
K <- length(unique(dat_act$assumed_strategy))
|
||||||
|
I <- 7
|
||||||
|
|
||||||
|
P_q_S <- array(dim = c(N, I, K))
|
||||||
|
for (n in 1:N) {
|
||||||
|
print(n)
|
||||||
|
P_q_S[n, , ] <- dat_act %>%
|
||||||
|
filter(index == n) %>%
|
||||||
|
select(matches("^act[[:digit:]]+$")) %>%
|
||||||
|
as.matrix() %>%
|
||||||
|
t()
|
||||||
|
for (k in 1:K) {
|
||||||
|
# normalize probabilities
|
||||||
|
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
print(c("dim(P_q_S)", dim(P_q_S)))
|
||||||
|
# read stan model
|
||||||
|
mod <- cmdstan_model(paste0(getwd(),"/strategy_inference_model.stan"))
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 0) # "0"
|
||||||
|
}
|
||||||
|
#print(sub)
|
||||||
|
#print(length(sub))
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_0 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||||
|
print(fit_0$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 1)
|
||||||
|
}
|
||||||
|
#print(sub)
|
||||||
|
#print(length(sub))
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_1 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||||
|
print(fit_1$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 2)
|
||||||
|
}
|
||||||
|
#print(sub)
|
||||||
|
#print(length(sub))
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_2 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||||
|
print(fit_2$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 3)
|
||||||
|
}
|
||||||
|
#print(sub)
|
||||||
|
#print(length(sub))
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_3 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||||
|
print(fit_3$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 4)
|
||||||
|
}
|
||||||
|
#print(sub)
|
||||||
|
#print(length(sub))
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_4 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||||
|
print(fit_4$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 5)
|
||||||
|
}
|
||||||
|
#print(sub)
|
||||||
|
#print(length(sub))
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_5 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||||
|
print(fit_5$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 6)
|
||||||
|
}
|
||||||
|
#print(sub)
|
||||||
|
#print(length(sub))
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_6 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||||
|
print(fit_6$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
# save csv
|
||||||
|
df <-rbind(fit_0$summary(), fit_1$summary(), fit_2$summary(), fit_3$summary(), fit_4$summary(), fit_5$summary(), fit_6$summary())
|
||||||
|
write.csv(df,file=save_path,quote=FALSE)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
238
keyboard_and_mouse/stan/strategy_inference_test_full_length.R
Normal file
238
keyboard_and_mouse/stan/strategy_inference_test_full_length.R
Normal file
|
@ -0,0 +1,238 @@
|
||||||
|
library(tidyverse)
|
||||||
|
library(cmdstanr)
|
||||||
|
library(dplyr)
|
||||||
|
|
||||||
|
# index order of the strategies assumed throughout
|
||||||
|
model_type <- "lstmlast"
|
||||||
|
batch_size <- "8"
|
||||||
|
lr <- "0.0001"
|
||||||
|
hidden_size <- "128"
|
||||||
|
model_type <- paste0(model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size)
|
||||||
|
rates <- c("_0", "_10", "_20", "_30", "_40", "_50", "_60", "_70", "_80", "_90", "_100")
|
||||||
|
|
||||||
|
user_num <- 5
|
||||||
|
user <-c(0:(user_num-1))
|
||||||
|
strategies <- c(0:6) # 7 tasks
|
||||||
|
print(strategies)
|
||||||
|
print(length(strategies))
|
||||||
|
N <- 1
|
||||||
|
|
||||||
|
set.seed(9736754)
|
||||||
|
|
||||||
|
#read data from csv
|
||||||
|
sel <- vector("list", length(strategies))
|
||||||
|
for (u in seq_along(user)){
|
||||||
|
print('user')
|
||||||
|
print(u)
|
||||||
|
for (rate in rates) {
|
||||||
|
dat <- vector("list", length(strategies))
|
||||||
|
for (i in seq_along(strategies)) {
|
||||||
|
if (rate=="_0"){
|
||||||
|
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_rate_10", "_pred", ".csv"))
|
||||||
|
} else if (rate=="_100"){
|
||||||
|
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_pred", ".csv"))
|
||||||
|
} else{
|
||||||
|
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_rate", rate, "_pred", ".csv"))
|
||||||
|
}
|
||||||
|
# strategy assumed for prediction
|
||||||
|
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||||
|
dat[[i]]$index <- dat[[i]]$action_id
|
||||||
|
dat[[i]]$id <- dat[[i]][,1]
|
||||||
|
}
|
||||||
|
|
||||||
|
# reset N after inference
|
||||||
|
N <- 1
|
||||||
|
|
||||||
|
# select all action series and infer every one
|
||||||
|
if (rate == "_0"){
|
||||||
|
sel[[1]]<-dat[[1]] %>%
|
||||||
|
group_by(task_name) %>%
|
||||||
|
sample_n(N)
|
||||||
|
sel[[1]] <- data.frame(sel[[1]])
|
||||||
|
unique_act_id <- unique(sel[[1]]$action_id)
|
||||||
|
}
|
||||||
|
print(sel[[1]]$action_id)
|
||||||
|
print(sel[[1]]$task_name)
|
||||||
|
print(dat[[1]]$task_name)
|
||||||
|
|
||||||
|
|
||||||
|
for (i in seq_along(strategies)) {
|
||||||
|
dat[[i]]<-subset(dat[[i]], dat[[i]]$action_id == sel[[1]]$action_id[1])
|
||||||
|
}
|
||||||
|
row.names(dat) <- NULL
|
||||||
|
print(c('action id', dat[[1]]$action_id))
|
||||||
|
print(c('action id', dat[[2]]$action_id))
|
||||||
|
print(c('action id', dat[[3]]$action_id))
|
||||||
|
|
||||||
|
dir.create(file.path(paste0("result/", "N", N)), showWarnings = FALSE)
|
||||||
|
save_path <- paste0("result/", "N", N, "/", model_type, "_N", N, "_", "result","_user", user[[u]], "_rate_", rate, ".csv")
|
||||||
|
|
||||||
|
dat_act <- do.call(rbind, dat) %>%
|
||||||
|
mutate(index = as.numeric(as.factor(id))) %>%
|
||||||
|
rename(true_strategy = task_name) %>%
|
||||||
|
mutate(
|
||||||
|
true_strategy = factor(
|
||||||
|
true_strategy, levels = 0:6,
|
||||||
|
labels = strategies
|
||||||
|
),
|
||||||
|
q_type = case_when(
|
||||||
|
gt %in% c(3,4,5) ~ 0,
|
||||||
|
gt %in% c(1,2,3,4,5,6,7) ~ 1,
|
||||||
|
gt %in% c(1,2,3,4) ~ 2,
|
||||||
|
gt %in% c(1,4,5,6,7) ~ 3,
|
||||||
|
gt %in% c(1,2,3,6,7) ~ 4,
|
||||||
|
gt %in% c(2,3,4,5,6,7) ~ 5,
|
||||||
|
gt %in% c(1,2,3,4,5,6,7) ~ 6,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dat_obs <- dat_act %>% filter(assumed_strategy == strategies[[i]]) # put_fridge, was num
|
||||||
|
N <- nrow(dat_obs)
|
||||||
|
print(c("N: ", N))
|
||||||
|
print(c("dim dat_act: ", dim(dat_act)))
|
||||||
|
|
||||||
|
q <- dat_obs$gt
|
||||||
|
true_strategy <- dat_obs$true_strategy
|
||||||
|
|
||||||
|
K <- length(unique(dat_act$assumed_strategy))
|
||||||
|
I <- 7
|
||||||
|
|
||||||
|
P_q_S <- array(dim = c(N, I, K))
|
||||||
|
for (n in 1:N) {
|
||||||
|
print(n)
|
||||||
|
P_q_S[n, , ] <- dat_act %>%
|
||||||
|
filter(index == n) %>%
|
||||||
|
select(matches("^act[[:digit:]]+$")) %>%
|
||||||
|
as.matrix() %>%
|
||||||
|
t()
|
||||||
|
for (k in 1:K) {
|
||||||
|
# normalize probabilities
|
||||||
|
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
print(c("dim(P_q_S)", dim(P_q_S)))
|
||||||
|
|
||||||
|
mod <- cmdstan_model(paste0(getwd(),"/strategy_inference_model.stan"))
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 0) # "0"
|
||||||
|
}
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_0 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_0$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 1)
|
||||||
|
}
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_1 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_1$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 2)
|
||||||
|
}
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_2 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_2$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 3)
|
||||||
|
}
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_3 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_3$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 4)
|
||||||
|
}
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_4 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_4$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 5)
|
||||||
|
}
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_5 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_5$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == 6)
|
||||||
|
}
|
||||||
|
if (length(sub) == 1){
|
||||||
|
temp <- P_q_S[sub, , ]
|
||||||
|
dim(temp) <- c(1, dim(temp))
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||||
|
} else{
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
}
|
||||||
|
fit_6 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_6$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
# save csv
|
||||||
|
df <-rbind(fit_0$summary(), fit_1$summary(), fit_2$summary(), fit_3$summary(), fit_4$summary(), fit_5$summary(), fit_6$summary())
|
||||||
|
write.csv(df,file=save_path,quote=FALSE)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
10
keyboard_and_mouse/temp.py
Normal file
10
keyboard_and_mouse/temp.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
import torch
|
||||||
|
import matplotlib as plt
|
||||||
|
import pickle
|
||||||
|
print(pickle.format_version)
|
||||||
|
import pandas
|
||||||
|
|
||||||
|
print(torch.__version__)
|
||||||
|
print('matplotlib: {}'.format(plt.__version__))
|
||||||
|
print(pandas.__version__)
|
||||||
|
|
158
keyboard_and_mouse/test.py
Normal file
158
keyboard_and_mouse/test.py
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import shutil
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import argparse
|
||||||
|
from networks import ActionDemo2Predicate
|
||||||
|
from pathlib import Path
|
||||||
|
from termcolor import colored
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
print('torch version: ',torch.__version__)
|
||||||
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(DEVICE)
|
||||||
|
torch.manual_seed(256)
|
||||||
|
|
||||||
|
class test_dataset(Dataset):
|
||||||
|
def __init__(self, x, label, action_id):
|
||||||
|
self.x = x
|
||||||
|
self.idx = action_id
|
||||||
|
self.labels = label
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
x = self.x[index]
|
||||||
|
label = self.labels[index]
|
||||||
|
action_idx = self.idx[index]
|
||||||
|
return x, label, action_idx
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.labels)
|
||||||
|
|
||||||
|
def test_model(model, test_dataloader, DEVICE):
|
||||||
|
model.to(DEVICE)
|
||||||
|
model.eval()
|
||||||
|
test_acc = []
|
||||||
|
logits = []
|
||||||
|
labels = []
|
||||||
|
action_ids = []
|
||||||
|
for iter, (x, label, action_id) in enumerate(test_dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
|
x = torch.tensor(x).to(DEVICE)
|
||||||
|
label = torch.tensor(label).to(DEVICE)
|
||||||
|
logps = model(x)
|
||||||
|
logps = F.softmax(logps, 1)
|
||||||
|
logits.append(logps.cpu().numpy())
|
||||||
|
labels.append(label.cpu().numpy())
|
||||||
|
action_ids.append(action_id)
|
||||||
|
|
||||||
|
argmax_Y = torch.max(logps, 1)[1].view(-1, 1)
|
||||||
|
test_acc.append((label.float().view(-1, 1) == argmax_Y.float()).sum().item() / len(label.float().view(-1, 1)) * 100)
|
||||||
|
|
||||||
|
test_acc = np.mean(np.array(test_acc))
|
||||||
|
print('test acc {:.4f}'.format(test_acc))
|
||||||
|
logits = np.concatenate(logits, axis=0)
|
||||||
|
labels = np.concatenate(labels, axis=0)
|
||||||
|
action_ids = np.concatenate(action_ids, axis=0)
|
||||||
|
return logits, labels, action_ids
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parsing parameters
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--resume', type=bool, default=False, help='resume training')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-1, help='learning rate')
|
||||||
|
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||||
|
parser.add_argument('--hidden_size', type=int, default=256, help='hidden_size')
|
||||||
|
parser.add_argument('--epochs', type=int, default=100, help='training epoch')
|
||||||
|
parser.add_argument('--dataset_path', type=str, default='dataset/strategy_dataset/', help='dataset path')
|
||||||
|
parser.add_argument('--weight_decay', type=float, default=0.9, help='wight decay for Adam optimizer')
|
||||||
|
parser.add_argument('--demo_hidden', type=int, default=512, help='demo_hidden')
|
||||||
|
parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
|
||||||
|
parser.add_argument('--checkpoint', type=str, default='checkpoints/', help='checkpoints path')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
path = args.checkpoint+args.model_type+'_bs_'+str(args.batch_size)+'_lr_'+str(args.lr)+'_hidden_size_'+str(args.hidden_size)
|
||||||
|
|
||||||
|
# read models
|
||||||
|
models = []
|
||||||
|
for i in range(7): # 7 tasks
|
||||||
|
net = ActionDemo2Predicate(args)
|
||||||
|
model_path = path + '/task' + str(i) + '_checkpoint.ckpt' # _checkpoint
|
||||||
|
net.load(model_path)
|
||||||
|
models.append(net)
|
||||||
|
|
||||||
|
for u in range(5):
|
||||||
|
task_pred = []
|
||||||
|
task_target = []
|
||||||
|
task_act = []
|
||||||
|
task_task_name = []
|
||||||
|
for i in range(7): # 7 tasks
|
||||||
|
test_loader = []
|
||||||
|
# # read dataset test data
|
||||||
|
with open(args.dataset_path + 'test_data_' + str(i) + '.pkl', 'rb') as f:
|
||||||
|
data_x = pickle.load(f)
|
||||||
|
with open(args.dataset_path + 'test_label_' + str(i) + '.pkl', 'rb') as f:
|
||||||
|
data_y = pickle.load(f)
|
||||||
|
with open(args.dataset_path + 'test_action_id_' + str(i) + '.pkl', 'rb') as f:
|
||||||
|
act_idx = pickle.load(f)
|
||||||
|
|
||||||
|
x = data_x[u]
|
||||||
|
y = data_y[u]
|
||||||
|
act = act_idx[u]
|
||||||
|
test_set = test_dataset(np.array(x), np.array(y)-1, np.array(act))
|
||||||
|
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4, drop_last=True)
|
||||||
|
|
||||||
|
preds = []
|
||||||
|
targets = []
|
||||||
|
actions = []
|
||||||
|
task_names = []
|
||||||
|
for j in range(7): # logits from all models
|
||||||
|
pred, target, action = test_model(models[j], test_loader, DEVICE)
|
||||||
|
preds.append(pred)
|
||||||
|
targets.append(target)
|
||||||
|
actions.append(action)
|
||||||
|
task_names.append(np.full(target.shape, i)) #assumed intention
|
||||||
|
|
||||||
|
task_pred.append(preds)
|
||||||
|
task_target.append(targets)
|
||||||
|
task_act.append(actions)
|
||||||
|
task_task_name.append(task_names)
|
||||||
|
|
||||||
|
for i in range(7):
|
||||||
|
preds = []
|
||||||
|
targets = []
|
||||||
|
actions = []
|
||||||
|
task_names = []
|
||||||
|
for j in range(7):
|
||||||
|
preds.append(task_pred[j][i])
|
||||||
|
targets.append(task_target[j][i]+1) # gt value add one
|
||||||
|
actions.append(task_act[j][i])
|
||||||
|
task_names.append(task_task_name[j][i])
|
||||||
|
|
||||||
|
preds = np.concatenate(preds, axis=0)
|
||||||
|
targets = np.concatenate(targets, axis=0)
|
||||||
|
actions = np.concatenate(actions, axis=0)
|
||||||
|
task_names = np.concatenate(task_names, axis=0)
|
||||||
|
write_data = np.concatenate((np.reshape(actions, (-1, 1)), preds, np.reshape(task_names, (-1, 1)), np.reshape(targets, (-1, 1))), axis=1)
|
||||||
|
|
||||||
|
output_path = 'prediction/' + 'task' +str(i) + '/' + args.model_type+'_bs_'+str(args.batch_size)+'_lr_'+str(args.lr)+'_hidden_size_'+str(args.hidden_size)
|
||||||
|
Path(output_path).mkdir(parents=True, exist_ok=True)
|
||||||
|
output_path = output_path + '/user' + str(u) + '_pred.csv'
|
||||||
|
print(write_data.shape)
|
||||||
|
|
||||||
|
head = []
|
||||||
|
for j in range(7):
|
||||||
|
head.append('act'+str(j+1))
|
||||||
|
head.append('task_name')
|
||||||
|
head.append('gt')
|
||||||
|
head.insert(0,'action_id')
|
||||||
|
pd.DataFrame(write_data).to_csv(output_path, header=head)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
12
keyboard_and_mouse/test.sh
Normal file
12
keyboard_and_mouse/test.sh
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
python3 test.py \
|
||||||
|
--resume False \
|
||||||
|
--batch_size 8 \
|
||||||
|
--lr 1e-4 \
|
||||||
|
--model_type lstmlast \
|
||||||
|
--epochs 100 \
|
||||||
|
--demo_hidden 128 \
|
||||||
|
--hidden_size 128 \
|
||||||
|
--dropout 0.5 \
|
||||||
|
--dataset_path dataset/strategy_dataset/ \
|
||||||
|
--checkpoint checkpoints/ \
|
||||||
|
--weight_decay 1e-4
|
145
keyboard_and_mouse/train.py
Normal file
145
keyboard_and_mouse/train.py
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import shutil
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import argparse
|
||||||
|
from networks import ActionDemo2Predicate
|
||||||
|
|
||||||
|
print('torch version: ',torch.__version__)
|
||||||
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(DEVICE)
|
||||||
|
torch.manual_seed(256)
|
||||||
|
|
||||||
|
class train_dataset(Dataset):
|
||||||
|
def __init__(self, x, label):
|
||||||
|
self.x = x
|
||||||
|
self.labels = label
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
x = self.x[index]
|
||||||
|
label = self.labels[index]
|
||||||
|
return x, label #, img_idx
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.labels)
|
||||||
|
|
||||||
|
class test_dataset(Dataset):
|
||||||
|
def __init__(self, x, label):
|
||||||
|
self.x = x
|
||||||
|
self.labels = label
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
x = self.x[index]
|
||||||
|
label = self.labels[index]
|
||||||
|
return x, label #, img_idx
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.labels)
|
||||||
|
|
||||||
|
def train_model(model, train_dataloader, criterion, optimizer, num_epochs, DEVICE, path, resume):
|
||||||
|
running_loss = 0
|
||||||
|
train_losses = 10
|
||||||
|
is_best_acc = False
|
||||||
|
is_best_train_loss = False
|
||||||
|
|
||||||
|
best_train_acc = 0
|
||||||
|
best_train_loss = 10
|
||||||
|
|
||||||
|
start_epoch = 0
|
||||||
|
accuracy = 0
|
||||||
|
|
||||||
|
model.to(DEVICE)
|
||||||
|
model.train()
|
||||||
|
for epoch in range(start_epoch, num_epochs):
|
||||||
|
epoch_losses = []
|
||||||
|
train_acc = []
|
||||||
|
epoch_loss = 0
|
||||||
|
for iter, (x, labels) in enumerate(train_dataloader):
|
||||||
|
x = torch.tensor(x).to(DEVICE)
|
||||||
|
labels = torch.tensor(labels).to(DEVICE)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
logps = model(x)
|
||||||
|
loss = criterion(logps, labels)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
epoch_loss += loss.detach().item()
|
||||||
|
argmax_Y = torch.max(logps, 1)[1].view(-1, 1)
|
||||||
|
train_acc.append((labels.float().view(-1, 1) == argmax_Y.float()).sum().item() / len(labels.float().view(-1, 1)) * 100)
|
||||||
|
epoch_loss /= (iter + 1)
|
||||||
|
epoch_losses.append(epoch_loss)
|
||||||
|
train_acc = np.mean(np.array(train_acc))
|
||||||
|
print('Epoch {}, train loss {:.4f}, train acc {:.4f}'.format(epoch, epoch_loss, train_acc))
|
||||||
|
|
||||||
|
is_best_acc = train_acc > best_train_acc
|
||||||
|
best_train_acc = max(train_acc, best_train_acc)
|
||||||
|
|
||||||
|
is_best_train_loss = best_train_loss < epoch_loss
|
||||||
|
best_train_loss = min(epoch_loss, best_train_loss)
|
||||||
|
|
||||||
|
if is_best_acc:
|
||||||
|
model.save(path + '_model_best.ckpt')
|
||||||
|
model.save(path + '_checkpoint.ckpt')
|
||||||
|
#scheduler.step()
|
||||||
|
|
||||||
|
def save_checkpoint(state, is_best, path, filename='_checkpoint.pth.tar'):
|
||||||
|
torch.save(state, path + filename)
|
||||||
|
if is_best:
|
||||||
|
shutil.copyfile(path + filename, path +'_model_best.pth.tar')
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parsing parameters
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--resume', type=bool, default=False, help='resume training')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-1, help='learning rate')
|
||||||
|
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||||
|
parser.add_argument('--hidden_size', type=int, default=256, help='hidden_size')
|
||||||
|
parser.add_argument('--epochs', type=int, default=100, help='training epoch')
|
||||||
|
parser.add_argument('--dataset_path', type=str, default='dataset/strategy_dataset/', help='dataset path')
|
||||||
|
parser.add_argument('--weight_decay', type=float, default=0.9, help='wight decay for Adam optimizer')
|
||||||
|
parser.add_argument('--demo_hidden', type=int, default=512, help='demo_hidden')
|
||||||
|
parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
|
||||||
|
parser.add_argument('--checkpoint', type=str, default='checkpoints/', help='checkpoints path')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
# create checkpoints path
|
||||||
|
from pathlib import Path
|
||||||
|
path = args.checkpoint+args.model_type+'_bs_'+str(args.batch_size)+'_lr_'+str(args.lr)+'_hidden_size_'+str(args.hidden_size)
|
||||||
|
Path(path).mkdir(parents=True, exist_ok=True)
|
||||||
|
print('total epochs for training: ', args.epochs)
|
||||||
|
|
||||||
|
# read dataset
|
||||||
|
train_loader = []
|
||||||
|
test_loader = []
|
||||||
|
loss_funcs = []
|
||||||
|
optimizers = []
|
||||||
|
models = []
|
||||||
|
parameters = []
|
||||||
|
for i in range(7): # 7 tasks
|
||||||
|
# train data
|
||||||
|
with open(args.dataset_path + 'train_data_' + str(i) + '.pkl', 'rb') as f:
|
||||||
|
data_x = pickle.load(f)
|
||||||
|
with open(args.dataset_path + 'train_label_' + str(i) + '.pkl', 'rb') as f:
|
||||||
|
data_y = pickle.load(f)
|
||||||
|
train_set = train_dataset(np.array(data_x), np.array(data_y)-1)
|
||||||
|
train_loader.append(DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4))
|
||||||
|
print('task', str(i), 'train data size: ', len(train_set))
|
||||||
|
|
||||||
|
net = ActionDemo2Predicate(args)
|
||||||
|
models.append(net)
|
||||||
|
parameter = net.parameters()
|
||||||
|
loss_funcs.append(nn.CrossEntropyLoss())
|
||||||
|
optimizers.append(optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay))
|
||||||
|
|
||||||
|
for i in range(7):
|
||||||
|
path_save = path + '/task' + str(i)
|
||||||
|
print('checkpoint save path: ', path_save)
|
||||||
|
train_model(models[i], train_loader[i], loss_funcs[i], optimizers[i], args.epochs, DEVICE, path_save, args.resume)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
12
keyboard_and_mouse/train.sh
Normal file
12
keyboard_and_mouse/train.sh
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
python3 train.py \
|
||||||
|
--resume False \
|
||||||
|
--batch_size 8 \
|
||||||
|
--lr 1e-4 \
|
||||||
|
--model_type lstmlast \
|
||||||
|
--epochs 100 \
|
||||||
|
--demo_hidden 128 \
|
||||||
|
--hidden_size 128 \
|
||||||
|
--dropout 0.5 \
|
||||||
|
--dataset_path dataset/strategy_dataset/ \
|
||||||
|
--checkpoint checkpoints/ \
|
||||||
|
--weight_decay 1e-4
|
46
watch_and_help/README.md
Normal file
46
watch_and_help/README.md
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
# Watch And Help Dataset
|
||||||
|
|
||||||
|
Codes to reproduce results on WAH dataset[^1]
|
||||||
|
|
||||||
|
[^1]: Modified based on WAH train and test codes (https://github.com/xavierpuigf/watch_and_help)[https://github.com/xavierpuigf/watch_and_help].
|
||||||
|
|
||||||
|
## Data
|
||||||
|
|
||||||
|
Extact `dataset/watch_data.zip`
|
||||||
|
|
||||||
|
|
||||||
|
## Neural Network
|
||||||
|
|
||||||
|
Run `sh scripts/train_watch_strategy_full.sh` to train the model
|
||||||
|
|
||||||
|
To test model, either use trained model or extract checkpoints `checkpoints/train_strategy_full/lstmlast.zip`
|
||||||
|
|
||||||
|
Run `sh scripts/test_watch_strategy_full.sh` to test the model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Prediction Split
|
||||||
|
|
||||||
|
Create artificial users and sample predictions from 10% to 90%
|
||||||
|
|
||||||
|
```
|
||||||
|
cd stan
|
||||||
|
sh split_user.sh
|
||||||
|
sh sampler_user.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Bayesian Inference
|
||||||
|
|
||||||
|
|
||||||
|
Run inference to get results of user intention prediction and action length (0% to 100%) for all users
|
||||||
|
|
||||||
|
```
|
||||||
|
Rscript strategy_inference_test.R
|
||||||
|
```
|
||||||
|
|
||||||
|
Plot intention prediction results and 10% to 100% of actions results
|
||||||
|
|
||||||
|
```
|
||||||
|
sh plot_user_length.sh
|
||||||
|
sh plot_user_length_10_steps.sh
|
BIN
watch_and_help/checkpoints/train_strategy_full/lstmlast.zip
Normal file
BIN
watch_and_help/checkpoints/train_strategy_full/lstmlast.zip
Normal file
Binary file not shown.
BIN
watch_and_help/dataset/watch_data.zip
Normal file
BIN
watch_and_help/dataset/watch_data.zip
Normal file
Binary file not shown.
13
watch_and_help/scripts/test_watch_strategy_full.sh
Normal file
13
watch_and_help/scripts/test_watch_strategy_full.sh
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
python3 watch_strategy_full/predicate-train-strategy.py \
|
||||||
|
--testset test_task \
|
||||||
|
--gpu_id 0 \
|
||||||
|
--batch_size 32 \
|
||||||
|
--demo_hidden 512 \
|
||||||
|
--model_type lstmlast \
|
||||||
|
--dropout 0 \
|
||||||
|
--inputtype actioninput \
|
||||||
|
--inference 2 \
|
||||||
|
--single 1 \
|
||||||
|
--resume '' \
|
||||||
|
--loss_type ce \
|
||||||
|
--checkpoint checkpoints/train_strategy_full/lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob
|
13
watch_and_help/scripts/train_watch_strategy_full.sh
Normal file
13
watch_and_help/scripts/train_watch_strategy_full.sh
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
python3 watch_strategy_full/predicate-train-strategy.py \
|
||||||
|
--gpu_id 0 \
|
||||||
|
--model_lr_rate 3e-4 \
|
||||||
|
--batch_size 32 \
|
||||||
|
--demo_hidden 512 \
|
||||||
|
--model_type lstmlast \
|
||||||
|
--inputtype actioninput \
|
||||||
|
--dropout 0 \
|
||||||
|
--single 1 \
|
||||||
|
--resume '' \
|
||||||
|
--checkpoint checkpoints/train_strategy_full/lstmlast \
|
||||||
|
--train_iters 2000 \
|
||||||
|
--loss_type ce\
|
132
watch_and_help/stan/plot_user_length.py
Normal file
132
watch_and_help/stan/plot_user_length.py
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
if args.task_type == 'new_test_task':
|
||||||
|
user = 9
|
||||||
|
N = 1
|
||||||
|
if args.task_type == 'test_task':
|
||||||
|
user = 92
|
||||||
|
N = 1
|
||||||
|
rate = 100
|
||||||
|
|
||||||
|
widths = [-0.1, 0, 0.1]
|
||||||
|
user_table = [6, 13, 15, 19, 20, 23, 27, 30, 33, 44, 46, 49, 50, 51, 52, 53, 54, 56, 65, 71, 84]
|
||||||
|
|
||||||
|
# read data
|
||||||
|
model_data_list = []
|
||||||
|
user_list = []
|
||||||
|
if not args.plot_user_list:
|
||||||
|
for i in range(user):
|
||||||
|
path = "result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/N"+ str(N) + "/" + args.model_type + "_N" + str(N) + "_result_" + str(rate) + "_user" + str(i) +".csv"
|
||||||
|
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||||
|
data = data[[1,2,3,5,6,7,9,10,11],:][:,[2,4,6,7]]
|
||||||
|
model_data_list.append(data)
|
||||||
|
if args.task_type == 'test_task':
|
||||||
|
user_list.append(np.transpose(data[:,[0]]))
|
||||||
|
else:
|
||||||
|
for i in range(user):
|
||||||
|
for t in user_table:
|
||||||
|
if t == i+1:
|
||||||
|
path = "result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/N"+ str(N) + "/" + args.model_type + "_N" + str(N) + "_result_" + str(rate) + "_user" + str(i) +".csv"
|
||||||
|
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||||
|
data = data[[1,2,3,5,6,7,9,10,11],:][:,[2,4,6,7]]
|
||||||
|
model_data_list.append(data)
|
||||||
|
user_list.append(np.transpose(data[:,[0]]))
|
||||||
|
|
||||||
|
color = ['royalblue', 'lightgreen', 'tomato']
|
||||||
|
legend = ['put fridge', 'put\n dishwasher', 'read book']
|
||||||
|
fig, axs = plt.subplots(3, sharex=True, sharey=True)
|
||||||
|
fig.set_figheight(10) # all sample rate: 10; 3 row: 8
|
||||||
|
fig.set_figwidth(20)
|
||||||
|
|
||||||
|
for ax in range(3):
|
||||||
|
y_total = []
|
||||||
|
y_low_total = []
|
||||||
|
y_high_total = []
|
||||||
|
for j in range(3):
|
||||||
|
y= []
|
||||||
|
y_low = []
|
||||||
|
y_high = []
|
||||||
|
for i in range(len(model_data_list)):
|
||||||
|
y.append(model_data_list[i][j+ax*3][0])
|
||||||
|
y_low.append(model_data_list[i][j+ax*3][2])
|
||||||
|
y_high.append(model_data_list[i][j+ax*3][3])
|
||||||
|
y_total.append(y)
|
||||||
|
y_low_total.append(y_low)
|
||||||
|
y_high_total.append(y_high)
|
||||||
|
print()
|
||||||
|
print("user mean of mean prob: ", np.mean(y))
|
||||||
|
print("user mean of sd prob: ", np.std(y))
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
if args.plot_type == 'line':
|
||||||
|
axs[ax].plot(range(user), y_total[i], color=color[i], label=legend[i])
|
||||||
|
axs[ax].fill_between(range(user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||||
|
if args.plot_type == 'bar':
|
||||||
|
if args.task_type == 'new_test_task':
|
||||||
|
widths = [-0.25, 0, 0.25]
|
||||||
|
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||||
|
axs[0].text(-0.19, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36)
|
||||||
|
axs[ax].bar(np.arange(user)+widths[i],y_total[i], width=0.2, yerr=yerror, color=color[i], label=legend[i])
|
||||||
|
axs[ax].tick_params(axis='x', which='both', pad=15, length=0)
|
||||||
|
plt.xticks(range(user), range(1,user+1))
|
||||||
|
axs[ax].set_ylabel('prob', fontsize= 36) # was 22
|
||||||
|
axs[ax].text(-0.19, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax])
|
||||||
|
plt.xlabel('user', fontsize= 40) # was 22
|
||||||
|
for k, x in enumerate(np.arange(user)+widths[i]):
|
||||||
|
y = y_total[i][k] + yerror[1][k]
|
||||||
|
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-15, 3), fontsize=14)
|
||||||
|
|
||||||
|
|
||||||
|
if args.task_type == 'test_task':
|
||||||
|
if not args.plot_user_list:
|
||||||
|
axs[0].text(-0.19, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36)
|
||||||
|
axs[ax].errorbar(np.arange(user)+widths[i],y_total[i], yerr=[np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])],markerfacecolor=color[i], ecolor=color[i], markeredgecolor=color[i], label=legend[i],fmt='.k')
|
||||||
|
axs[ax].tick_params(axis='x', which='both', pad=15, length=0)
|
||||||
|
plt.xticks(range(user)[::5], range(1,user+1)[::5])
|
||||||
|
axs[ax].set_ylabel('prob', fontsize= 36)
|
||||||
|
axs[ax].text(-0.19, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax])
|
||||||
|
plt.xlabel('user', fontsize= 40)
|
||||||
|
else:
|
||||||
|
axs[0].text(-0.19, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36)
|
||||||
|
axs[ax].errorbar(np.arange(len(model_data_list))+widths[i],y_total[i], yerr=[np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])],markerfacecolor=color[i], ecolor=color[i], markeredgecolor=color[i], label=legend[i],fmt='.k')
|
||||||
|
axs[ax].tick_params(axis='x', which='both', pad=15, length=0)
|
||||||
|
plt.xticks(range(len(model_data_list)), user_table)
|
||||||
|
axs[ax].set_ylabel('prob', fontsize= 36)
|
||||||
|
#axs[ax].set_yticks(range(0.0,1.0, 0.25))
|
||||||
|
axs[ax].text(-0.19, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax])
|
||||||
|
plt.xlabel('user', fontsize= 40)
|
||||||
|
|
||||||
|
axs[ax].tick_params(axis='both', which='major', labelsize=30)
|
||||||
|
|
||||||
|
handles, labels = axs[0].get_legend_handles_labels()
|
||||||
|
|
||||||
|
plt.ylim([0, 1.08])
|
||||||
|
plt.tight_layout()
|
||||||
|
pathlib.Path("result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/figure/").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if args.task_type == 'test_task':
|
||||||
|
if not args.plot_user_list:
|
||||||
|
plt.savefig("result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/figure/"+"N"+ str(N)+"_"+args.model_type+"_rate_"+str(rate)+"_"+args.plot_type+"_test_set_1.png", bbox_inches='tight')
|
||||||
|
else:
|
||||||
|
plt.savefig("result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/figure/"+"N"+ str(N)+"_"+args.model_type+"_rate_"+str(rate)+"_"+args.plot_type+"_test_set_1_user_analysis.png", bbox_inches='tight')
|
||||||
|
if args.task_type == 'new_test_task':
|
||||||
|
plt.savefig("result/"+args.ask_type+"/user"+str(user)+"/"+args.loss_type+"/figure/"+"N"+ str(N)+"_"+args.model_type+"_rate_"+str(rate)+"_"+args.plot_type+"_test_set_2.png", bbox_inches='tight')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--loss_type', type=str, default='ce')
|
||||||
|
parser.add_argument('--model_type', type=str, default="lstmlast" )
|
||||||
|
parser.add_argument('--plot_type', type=str, default='bar') # bar or line
|
||||||
|
parser.add_argument('--task_type', type=str, default='test_task')
|
||||||
|
parser.add_argument('--plot_user_list', action='store_true') # plot user_table or not
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
|
|
5
watch_and_help/stan/plot_user_length.sh
Normal file
5
watch_and_help/stan/plot_user_length.sh
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
python3 plot_user_length.py \
|
||||||
|
--loss_type ce \
|
||||||
|
--model_type lstmlast \
|
||||||
|
--plot_type bar \
|
||||||
|
--task_type test_task
|
88
watch_and_help/stan/plot_user_length_10_steps.py
Normal file
88
watch_and_help/stan/plot_user_length_10_steps.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--loss_type', type=str, default='ce')
|
||||||
|
parser.add_argument('--model_type', type=str, default="lstmlast" )
|
||||||
|
parser.add_argument('--task_type', type=str, default='test_task')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.task_type == 'new_test_task':
|
||||||
|
user = 9
|
||||||
|
N = 1
|
||||||
|
if args.task_type == 'test_task':
|
||||||
|
user = 92
|
||||||
|
N = 1
|
||||||
|
|
||||||
|
#rate = range(0,101,10)
|
||||||
|
rate_user_data_list = []
|
||||||
|
for r in range(0,101,10): # rate = range(0,101,10)
|
||||||
|
# read data
|
||||||
|
print(r)
|
||||||
|
model_data_list = []
|
||||||
|
for i in range(user):
|
||||||
|
path = "result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/N"+ str(N) + "/" + args.model_type + "_N" + str(N) + "_result_" + str(r) + "_user" + str(i) +".csv"
|
||||||
|
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||||
|
data = data[[1,2,3,5,6,7,9,10,11],:][:,[2,4,6,7]]
|
||||||
|
model_data_list.append(data)
|
||||||
|
#print(type(data))
|
||||||
|
model_data_list_total = np.stack(model_data_list)
|
||||||
|
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||||
|
rate_user_data_list.append(mean_user_data)
|
||||||
|
|
||||||
|
color = ['royalblue', 'lightgreen', 'tomato']
|
||||||
|
legend = ['put fridge', 'put\n dishwasher', 'read book']
|
||||||
|
fig, axs = plt.subplots(3, sharex=True, sharey=True)
|
||||||
|
fig.set_figheight(10) # all sample rate: 10; 3 row: 8
|
||||||
|
fig.set_figwidth(20)
|
||||||
|
axs[0].text(-0.145, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 25) # all: -0.3,0.5 3rows: -0.5,0.5
|
||||||
|
|
||||||
|
for ax in range(3):
|
||||||
|
y_total = []
|
||||||
|
y_low_total = []
|
||||||
|
y_high_total = []
|
||||||
|
for j in range(3):
|
||||||
|
y= []
|
||||||
|
y_low = []
|
||||||
|
y_high = []
|
||||||
|
for i in range(len(rate_user_data_list)):
|
||||||
|
y.append(rate_user_data_list[i][j+ax*3][0])
|
||||||
|
y_low.append(rate_user_data_list[i][j+ax*3][2])
|
||||||
|
y_high.append(rate_user_data_list[i][j+ax*3][3])
|
||||||
|
y_total.append(y)
|
||||||
|
y_low_total.append(y_low)
|
||||||
|
y_high_total.append(y_high)
|
||||||
|
print()
|
||||||
|
print("user mean of mean prob: ", np.mean(y))
|
||||||
|
print("user mean of sd prob: ", np.std(y))
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||||
|
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||||
|
axs[ax].set_xticks(range(0,101,10))
|
||||||
|
axs[ax].set_ylabel('probability', fontsize=22)
|
||||||
|
|
||||||
|
axs[ax].text(-0.145, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 25, color=color[ax])
|
||||||
|
axs[ax].tick_params(axis='both', which='major', labelsize=18)
|
||||||
|
|
||||||
|
plt.xlabel('Percentage of observed actions in one action sequence', fontsize= 22)
|
||||||
|
handles, labels = axs[0].get_legend_handles_labels()
|
||||||
|
|
||||||
|
plt.xlim([0, 101])
|
||||||
|
plt.ylim([0, 1])
|
||||||
|
pathlib.Path("result/"+args.task_type+"/user"+str(user)+"/"+args.loss_type+"/figure/").mkdir(parents=True, exist_ok=True)
|
||||||
|
if args.task_type == 'test_task':
|
||||||
|
plt.savefig("result/"+args.task_type+"/user"+str(user)+ "/"+args.loss_type+"/figure/N"+ str(N) + "_"+args.model_type+"_rate_full_test_set_1.png", bbox_inches='tight')
|
||||||
|
if args.task_type == 'new_test_task':
|
||||||
|
plt.savefig("result/"+args.task_type+"/user"+str(user)+ "/"+args.loss_type+"/figure/N"+ str(N) + "_"+args.model_type+"_rate_full_test_set_2.png", bbox_inches='tight')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
4
watch_and_help/stan/plot_user_length_10_steps.sh
Normal file
4
watch_and_help/stan/plot_user_length_10_steps.sh
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
python3 plot_user_length_10_steps.py \
|
||||||
|
--loss_type ce \
|
||||||
|
--model_type lstmlast \
|
||||||
|
--task_type test_task
|
64
watch_and_help/stan/sampler_user.py
Normal file
64
watch_and_help/stan/sampler_user.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy import genfromtxt
|
||||||
|
import csv
|
||||||
|
import pandas
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def sample_predciton(path, rate):
|
||||||
|
data = pandas.read_csv(path).values
|
||||||
|
task_list = [0, 1, 2]
|
||||||
|
start = 0
|
||||||
|
stop = 0
|
||||||
|
num_unique = np.unique(data[:,1])
|
||||||
|
#print('unique number', num_unique)
|
||||||
|
|
||||||
|
samples = []
|
||||||
|
for j in task_list:
|
||||||
|
for i in num_unique:
|
||||||
|
inx = np.where((data[:,1] == i) & (data[:,-2] == j))
|
||||||
|
samples.append(data[inx])
|
||||||
|
|
||||||
|
for i in range(len(samples)):
|
||||||
|
n = int(len(samples[i])*(100-rate)/100)
|
||||||
|
samples[i] = samples[i][:-n]
|
||||||
|
|
||||||
|
return np.vstack(samples)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--LOSS', type=str, default='ce')
|
||||||
|
parser.add_argument('--MODEL_TYPE', type=str, default="lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob" )
|
||||||
|
parser.add_argument('--EPOCHS', type=int, default=50)
|
||||||
|
parser.add_argument('--TASK', type=str, default='test_task')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
task = ['put_fridge', 'put_dishwasher', 'read_book']
|
||||||
|
sets = [args.TASK]
|
||||||
|
rate = [10, 20, 30, 40, 50, 60, 70, 80, 90]
|
||||||
|
|
||||||
|
for i in task:
|
||||||
|
for j in rate:
|
||||||
|
for k in sets:
|
||||||
|
if k == 'test_task':
|
||||||
|
user_num = 92
|
||||||
|
if k == 'new_test_task':
|
||||||
|
user_num = 9
|
||||||
|
|
||||||
|
for l in range(user_num):
|
||||||
|
pred_path = "prediction/" + k + "/" + "user" + str(user_num) + "/ce/" + i + "/" + "loss_weight_" + args.MODEL_TYPE + "_prediction_" + i + "_user" + str(l) + ".csv"
|
||||||
|
save_path = "prediction/" + k + "/" + "user" + str(user_num) + "/ce/" + i + "/" + "loss_weight_" + args.MODEL_TYPE + "_prediction_" + i + "_user" + str(l) + "_rate_" + str(j) + ".csv"
|
||||||
|
data = sample_predciton(pred_path, j)
|
||||||
|
|
||||||
|
head = []
|
||||||
|
for r in range(79):
|
||||||
|
head.append('act'+str(r+1))
|
||||||
|
head.append('task_name')
|
||||||
|
head.append('gt')
|
||||||
|
head.insert(0,'action_id')
|
||||||
|
pandas.DataFrame(data[:,1:]).to_csv(save_path, header=head)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
5
watch_and_help/stan/sampler_user.sh
Normal file
5
watch_and_help/stan/sampler_user.sh
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
python3 sampler_user.py \
|
||||||
|
--TASK test_task \
|
||||||
|
--LOSS ce \
|
||||||
|
--MODEL_TYPE lstmlast \
|
||||||
|
--EPOCHS 50
|
76
watch_and_help/stan/save_act_series.R
Normal file
76
watch_and_help/stan/save_act_series.R
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
library(tidyverse)
|
||||||
|
library(cmdstanr)
|
||||||
|
library(dplyr)
|
||||||
|
|
||||||
|
strategies <- c("put_fridge", "put_dishwasher", "read_book")
|
||||||
|
model_type <- "lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob"
|
||||||
|
rate <- "_0"
|
||||||
|
task_type <- "new_test_task" # new_test_task test_task
|
||||||
|
loss_type <- "ce"
|
||||||
|
set.seed(9746234)
|
||||||
|
if (task_type=="test_task"){
|
||||||
|
user_num <- 92
|
||||||
|
user <-c(0:(user_num-1))
|
||||||
|
N <- 1
|
||||||
|
}
|
||||||
|
if (task_type=="new_test_task"){
|
||||||
|
user_num <- 9
|
||||||
|
user <-c(0:(user_num-1))
|
||||||
|
N <- 1
|
||||||
|
}
|
||||||
|
total_user_act1 <- vector("list", length(user_num))
|
||||||
|
total_user_act2 <- vector("list", length(user_num))
|
||||||
|
|
||||||
|
sel <- vector("list", length(strategies))
|
||||||
|
act_series <- vector("list", user_num)
|
||||||
|
for (u in seq_along(user)){
|
||||||
|
print('user')
|
||||||
|
print(u)
|
||||||
|
dat <- vector("list", length(strategies))
|
||||||
|
for (i in seq_along(strategies)) {
|
||||||
|
if (rate=="_0"){
|
||||||
|
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], "_rate_", "90", ".csv"))
|
||||||
|
} else if (rate=="_100"){
|
||||||
|
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], ".csv"))
|
||||||
|
} else{
|
||||||
|
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], "_rate", rate, ".csv"))
|
||||||
|
}
|
||||||
|
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||||
|
dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
|
||||||
|
dat[[i]]$id <- dat[[i]][,1] # sample based on intention
|
||||||
|
}
|
||||||
|
|
||||||
|
N <- 1
|
||||||
|
# select all action series and infer every one
|
||||||
|
sel[[1]]<-dat[[1]] %>%
|
||||||
|
group_by(task_name) %>%
|
||||||
|
filter(task_name==1)
|
||||||
|
sel[[1]] <- data.frame(sel[[1]])
|
||||||
|
unique_act_id_t1 <- unique(sel[[1]]$action_id)
|
||||||
|
write.csv(unique_act_id_t1, paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act", "/", "action_series_", "user_",u, "_put_dishwasher", ".csv"))
|
||||||
|
total_user_act1[[u]] <- unique_act_id_t1
|
||||||
|
|
||||||
|
sel[[1]]<-dat[[1]] %>%
|
||||||
|
group_by(task_name) %>%
|
||||||
|
filter(task_name==2)
|
||||||
|
sel[[1]] <- data.frame(sel[[1]])
|
||||||
|
unique_act_id_t1 <- unique(sel[[1]]$action_id)
|
||||||
|
write.csv(unique_act_id_t1, paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act", "/", "action_series_", "user_",u, "_read_book", ".csv"))
|
||||||
|
total_user_act2[[u]] <- unique_act_id_t1
|
||||||
|
}
|
||||||
|
|
||||||
|
write.csv(total_user_act1, paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act", "/", "action_series_", "_put_dishwasher_total", ".csv"))
|
||||||
|
write.csv(total_user_act2, paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act", "/", "action_series_", "read_book_total", ".csv"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
87
watch_and_help/stan/split_user.py
Normal file
87
watch_and_help/stan/split_user.py
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
import numpy as np
|
||||||
|
import pathlib
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
np.random.seed(seed=100)
|
||||||
|
|
||||||
|
def sample_user(data, num_users, split_inx):
|
||||||
|
np.random.seed(seed=100)
|
||||||
|
num_unique3 = np.unique(data[:,1])
|
||||||
|
num_unique2 = num_unique3[0:split_inx[1]]
|
||||||
|
num_unique = num_unique3[0:split_inx[0]]
|
||||||
|
|
||||||
|
user_list1 = [np.random.choice(num_unique, int(len(num_unique)/num_users), replace=False) for i in range(num_users)]
|
||||||
|
user_list2 = [np.random.choice(num_unique2, int(len(num_unique2)/num_users), replace=False) for i in range(num_users)]
|
||||||
|
user_list3 = [np.random.choice(num_unique3, int(len(num_unique3)/num_users), replace=False) for i in range(num_users)]
|
||||||
|
|
||||||
|
user_data = []
|
||||||
|
|
||||||
|
for i in range(num_users): # len(user_list)
|
||||||
|
user_idx1 = [int(item) for item in user_list1[i]]
|
||||||
|
user_idx2 = [int(item) for item in user_list2[i]]
|
||||||
|
user_idx3 = [int(item) for item in user_list3[i]]
|
||||||
|
|
||||||
|
data_list = []
|
||||||
|
for j in range(len(user_idx1)):
|
||||||
|
inx = np.where((data[:,1] == user_idx1[j]) & (data[:,-2]==0))
|
||||||
|
data_list.append(data[inx])
|
||||||
|
|
||||||
|
for j in range(len(user_idx2)):
|
||||||
|
inx = np.where((data[:,1] == user_idx2[j]) & (data[:,-2]==1))
|
||||||
|
data_list.append(data[inx])
|
||||||
|
|
||||||
|
for j in range(len(user_idx3)):
|
||||||
|
inx = np.where((data[:,1] == user_idx3[j]) & (data[:,-2]==2))
|
||||||
|
data_list.append(data[inx])
|
||||||
|
|
||||||
|
user_data.append(np.vstack(data_list))
|
||||||
|
|
||||||
|
return user_data
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--LOSS', type=str, default='ce')
|
||||||
|
parser.add_argument('--MODEL_TYPE', type=str, default="lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob" )
|
||||||
|
parser.add_argument('--EPOCHS', type=int, default=50)
|
||||||
|
parser.add_argument('--TASK', type=str, default='test_task')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
pref = ['put_fridge', 'put_dishwasher', 'read_book']
|
||||||
|
|
||||||
|
if args.TASK == 'new_test_task':
|
||||||
|
NUM_USER = 9 # 9 for 1 user 1 action
|
||||||
|
SPLIT_INX = [NUM_USER, 45]
|
||||||
|
if args.TASK == 'test_task':
|
||||||
|
NUM_USER = 92
|
||||||
|
SPLIT_INX = [NUM_USER, 229]
|
||||||
|
|
||||||
|
head = []
|
||||||
|
for j in range(79):
|
||||||
|
head.append('act'+str(j+1))
|
||||||
|
head.append('task_name')
|
||||||
|
head.append('gt')
|
||||||
|
head.insert(0,'action_id')
|
||||||
|
head.insert(0,'')
|
||||||
|
|
||||||
|
for i in pref:
|
||||||
|
path = "prediction/"+args.TASK+"/" + args.MODEL_TYPE + "/model_" + i + "_strategy_put_fridge" +".csv"
|
||||||
|
data = np.genfromtxt(path, skip_header=1, delimiter=',')
|
||||||
|
data_task_name = np.genfromtxt(path, skip_header=1, delimiter=',', usecols=-2, dtype=None)
|
||||||
|
data_task_name[data_task_name==b'put_fridge'] = 0
|
||||||
|
data_task_name[data_task_name==b'put_dishwasher'] = 1
|
||||||
|
data_task_name[data_task_name==b'read_book'] = 2
|
||||||
|
data[:,-2] = data_task_name.astype(np.float)
|
||||||
|
print("data length: ", len(data))
|
||||||
|
users_data = sample_user(data, NUM_USER, SPLIT_INX)
|
||||||
|
|
||||||
|
length = 0
|
||||||
|
pathlib.Path("prediction/"+args.TASK+"/user" + str(NUM_USER) + "/" + args.LOSS + "/" + i).mkdir(parents=True, exist_ok=True)
|
||||||
|
for j in range(len(users_data)):
|
||||||
|
save_path = "prediction/"+args.TASK+"/user" + str(NUM_USER) + "/" + args.LOSS + "/" + i +"/loss_weight_"+ args.MODEL_TYPE + "_prediction_"+ i + "_user"+str(j)+".csv"
|
||||||
|
length = length + len(users_data[j])
|
||||||
|
np.savetxt(save_path, users_data[j], delimiter=',', header=','.join(head))
|
||||||
|
print("user data length: ", length)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
5
watch_and_help/stan/split_user.sh
Normal file
5
watch_and_help/stan/split_user.sh
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
python3 split_user.py \
|
||||||
|
--TASK test_task \
|
||||||
|
--LOSS ce \
|
||||||
|
--MODEL_TYPE lstmlast \
|
||||||
|
--EPOCHS 50
|
BIN
watch_and_help/stan/strategy_inference_model
Executable file
BIN
watch_and_help/stan/strategy_inference_model
Executable file
Binary file not shown.
26
watch_and_help/stan/strategy_inference_model.stan
Normal file
26
watch_and_help/stan/strategy_inference_model.stan
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
data {
|
||||||
|
int<lower=1> I; // number of question options (22)
|
||||||
|
int<lower=0> N; // number of questions being asked by the user
|
||||||
|
int<lower=1> K; // number of strategies
|
||||||
|
// observed "true" questions of the user
|
||||||
|
int q[N];
|
||||||
|
// array of predicted probabilities of questions given strategies
|
||||||
|
// coming from the forward neural network
|
||||||
|
matrix[I, K] P_q_S[N];
|
||||||
|
}
|
||||||
|
parameters {
|
||||||
|
// probabiliy vector of the strategies being applied by the user
|
||||||
|
// to be inferred by the model here
|
||||||
|
simplex[K] P_S;
|
||||||
|
}
|
||||||
|
model {
|
||||||
|
for (n in 1:N) {
|
||||||
|
// marginal probability vector of the questions being asked
|
||||||
|
vector[I] theta = P_q_S[n] * P_S;
|
||||||
|
// categorical likelihood
|
||||||
|
target += categorical_lpmf(q[n] | theta);
|
||||||
|
}
|
||||||
|
// priors
|
||||||
|
target += dirichlet_lpdf(P_S | rep_vector(1.0, K));
|
||||||
|
}
|
||||||
|
|
190
watch_and_help/stan/strategy_inference_test.R
Normal file
190
watch_and_help/stan/strategy_inference_test.R
Normal file
|
@ -0,0 +1,190 @@
|
||||||
|
library(tidyverse)
|
||||||
|
library(cmdstanr)
|
||||||
|
library(dplyr)
|
||||||
|
|
||||||
|
# index order of the strategies assumed throughout
|
||||||
|
strategies <- c("put_fridge", "put_dishwasher", "read_book")
|
||||||
|
model_type <- "lstmlast"
|
||||||
|
rates <- c("_0", "_10", "_20", "_30", "_40", "_50", "_60", "_70", "_80", "_90", "_100")
|
||||||
|
task_type <- "test_task" # new_test_task test_task
|
||||||
|
loss_type <- "ce"
|
||||||
|
set.seed(9746234)
|
||||||
|
if (task_type=="test_task"){
|
||||||
|
user_num <- 92
|
||||||
|
user <-c(38:(user_num-1))
|
||||||
|
N <- 1
|
||||||
|
}
|
||||||
|
if (task_type=="new_test_task"){
|
||||||
|
user_num <- 9
|
||||||
|
user <-c(0:(user_num-1))
|
||||||
|
N <- 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# read data from csv
|
||||||
|
sel <- vector("list", length(strategies))
|
||||||
|
act_series <- vector("list", user_num)
|
||||||
|
for (u in seq_along(user)){
|
||||||
|
for (rate in rates) {
|
||||||
|
dat <- vector("list", length(strategies))
|
||||||
|
for (i in seq_along(strategies)) {
|
||||||
|
if (rate=="_0"){
|
||||||
|
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], "_rate_", "10", ".csv")) # _60
|
||||||
|
} else if (rate=="_100"){
|
||||||
|
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], ".csv")) # _60
|
||||||
|
} else{
|
||||||
|
dat[[i]] <- read.csv(paste0("prediction/", task_type, "/user", user_num, "/", loss_type, "/", strategies[[i]], "/loss_weight_", model_type, "_prediction_", strategies[[i]], "_user", user[[u]], "_rate", rate, ".csv")) # _60
|
||||||
|
}
|
||||||
|
# strategy assumed for prediction
|
||||||
|
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||||
|
dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
|
||||||
|
dat[[i]]$id <- dat[[i]][,1] # sample based on intention
|
||||||
|
}
|
||||||
|
|
||||||
|
# reset N after inference
|
||||||
|
if (task_type=="test_task"){
|
||||||
|
N <- 1
|
||||||
|
}
|
||||||
|
if (task_type=="new_test_task"){
|
||||||
|
N <- 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# select one action series from one intention
|
||||||
|
if (rate == "_0"){
|
||||||
|
sel[[1]]<-dat[[1]] %>%
|
||||||
|
group_by(task_name) %>%
|
||||||
|
sample_n(N)
|
||||||
|
|
||||||
|
sel[[1]] <- data.frame(sel[[1]])
|
||||||
|
act_series[[u]] <- sel[[1]]$action_id
|
||||||
|
#print(typeof(sel[[1]]))
|
||||||
|
#print(typeof(dat[[1]]))
|
||||||
|
#print(sel[[1]]$action_id[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
print(c('unique action id', sel[[1]]$action_id))
|
||||||
|
|
||||||
|
# filter data from the selected action series, N series per intention
|
||||||
|
dat[[1]]<-subset(dat[[1]], dat[[1]]$action_id == sel[[1]]$action_id[1] | dat[[1]]$action_id == sel[[1]]$action_id[2] | dat[[1]]$action_id == sel[[1]]$action_id[3])
|
||||||
|
dat[[2]]<-subset(dat[[2]], dat[[2]]$action_id == sel[[1]]$action_id[1] | dat[[2]]$action_id == sel[[1]]$action_id[2] | dat[[2]]$action_id == sel[[1]]$action_id[3])
|
||||||
|
dat[[3]]<-subset(dat[[3]], dat[[3]]$action_id == sel[[1]]$action_id[1] | dat[[3]]$action_id == sel[[1]]$action_id[2] | dat[[3]]$action_id == sel[[1]]$action_id[3])
|
||||||
|
row.names(dat) <- NULL
|
||||||
|
print(c('task name 1', dat[[1]]$task_name))
|
||||||
|
print(c('task name 2', dat[[2]]$task_name))
|
||||||
|
print(c('task name 3', dat[[3]]$task_name))
|
||||||
|
print(c('action id 1', dat[[1]]$action_id))
|
||||||
|
print(c('action id 2', dat[[2]]$action_id))
|
||||||
|
print(c('action id 3', dat[[3]]$action_id))
|
||||||
|
|
||||||
|
# create save path
|
||||||
|
dir.create(file.path(paste0("result/", task_type, "/user", user_num, "/", loss_type, "/N", N)), showWarnings = FALSE, recursive = TRUE)
|
||||||
|
dir.create(file.path("temp"), showWarnings = FALSE)
|
||||||
|
save_path <- paste0("result/", task_type, "/user", user_num, "/", loss_type, "/N", N, "/", model_type, "_N", N, "_", "result", rate,"_user", user[[u]], ".csv")
|
||||||
|
|
||||||
|
if(task_type=="test_task"){
|
||||||
|
dat <- do.call(rbind, dat) %>%
|
||||||
|
mutate(index = as.numeric(as.factor(id))) %>%
|
||||||
|
rename(true_strategy = task_name) %>%
|
||||||
|
mutate(
|
||||||
|
true_strategy = factor(
|
||||||
|
#true_strategy, levels = 0:3,
|
||||||
|
true_strategy, levels = 0:2,
|
||||||
|
labels = strategies
|
||||||
|
),
|
||||||
|
q_type = case_when(
|
||||||
|
gt %in% c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 19, 20, 22, 23, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 42, 43, 44, 58, 59, 64, 65, 68, 69, 70, 71, 72, 73, 74) ~ "put_fridge",
|
||||||
|
gt %in% c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 25, 29,30, 31, 32, 33, 34, 37, 38, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57) ~ "put_dishwasher",
|
||||||
|
gt %in% c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45) ~ "read_book",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if(task_type=="new_test_task"){
|
||||||
|
dat <- do.call(rbind, dat) %>%
|
||||||
|
mutate(index = as.numeric(as.factor(id))) %>%
|
||||||
|
rename(true_strategy = task_name) %>%
|
||||||
|
mutate(
|
||||||
|
true_strategy = factor(
|
||||||
|
true_strategy, levels = 0:2,
|
||||||
|
labels = strategies
|
||||||
|
),
|
||||||
|
q_type = case_when(
|
||||||
|
# new_test_set
|
||||||
|
gt %in% c(1, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 19, 20, 22, 23, 25, 29, 30, 31, 32, 33, 34, 35, 40, 42, 43, 44, 46, 47, 52, 53, 55, 56, 58, 59, 60, 64, 65, 68, 69, 70, 71, 72, 73, 74, 75, 77, 78) ~ "put_fridge",
|
||||||
|
gt %in% c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74) ~ "put_dishwasher",
|
||||||
|
gt %in% c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 60, 75, 76, 77, 78) ~ "read_book",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#print(nrow(dat))
|
||||||
|
#print(dat)
|
||||||
|
|
||||||
|
dat_obs <- dat %>% filter(assumed_strategy == strategies[[i]])
|
||||||
|
N <- nrow(dat_obs)
|
||||||
|
print(c("N: ", N))
|
||||||
|
q <- dat_obs$gt
|
||||||
|
true_strategy <- dat_obs$true_strategy
|
||||||
|
|
||||||
|
K <- length(unique(dat$assumed_strategy))
|
||||||
|
I <- 79
|
||||||
|
|
||||||
|
P_q_S <- array(dim = c(N, I, K))
|
||||||
|
for (n in 1:N) {
|
||||||
|
P_q_S[n, , ] <- dat %>%
|
||||||
|
filter(index == n) %>%
|
||||||
|
select(matches("^act[[:digit:]]+$")) %>%
|
||||||
|
as.matrix() %>%
|
||||||
|
t()
|
||||||
|
for (k in 1:K) {
|
||||||
|
# normalize probabilities
|
||||||
|
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mod <- cmdstan_model(paste0(getwd(),"/strategy_inference_model.stan"))
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == "put_fridge")
|
||||||
|
}
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
fit_put_fridge <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_put_fridge$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == "put_dishwasher")
|
||||||
|
}
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
fit_put_dishwasher <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_put_dishwasher$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
# read_book strategy (should favor index 3)
|
||||||
|
if (rate=="_0"){
|
||||||
|
sub <- integer(0)
|
||||||
|
} else {
|
||||||
|
sub <- which(true_strategy == "read_book")
|
||||||
|
}
|
||||||
|
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||||
|
fit_read_book <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||||
|
print(fit_read_book$summary(NULL, c("mean","sd")))
|
||||||
|
|
||||||
|
# save csv
|
||||||
|
df <-rbind(fit_put_fridge$summary(), fit_put_dishwasher$summary(), fit_read_book$summary())
|
||||||
|
write.csv(df,file=save_path,quote=FALSE)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Binary file not shown.
203
watch_and_help/watch_strategy_full/helper.py
Normal file
203
watch_and_help/watch_strategy_full/helper.py
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
from torch.nn.modules.rnn import RNNCellBase
|
||||||
|
|
||||||
|
def to_cpu(list_of_tensor):
|
||||||
|
if isinstance(list_of_tensor[0], list):
|
||||||
|
list_list_of_tensor = list_of_tensor
|
||||||
|
list_of_tensor = [to_cpu(list_of_tensor)
|
||||||
|
for list_of_tensor in list_list_of_tensor]
|
||||||
|
else:
|
||||||
|
list_of_tensor = [tensor.cpu() for tensor in list_of_tensor]
|
||||||
|
return list_of_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def average_over_list(l):
|
||||||
|
return sum(l) / len(l)
|
||||||
|
|
||||||
|
def _LayerNormGRUCell(input, hidden, w_ih, w_hh, ln, b_ih=None, b_hh=None):
|
||||||
|
gi = F.linear(input, w_ih, b_ih)
|
||||||
|
gh = F.linear(hidden, w_hh, b_hh)
|
||||||
|
i_r, i_i, i_n = gi.chunk(3, 1)
|
||||||
|
h_r, h_i, h_n = gh.chunk(3, 1)
|
||||||
|
|
||||||
|
# use layernorm here
|
||||||
|
resetgate = torch.sigmoid(ln['resetgate'](i_r + h_r))
|
||||||
|
inputgate = torch.sigmoid(ln['inputgate'](i_i + h_i))
|
||||||
|
newgate = torch.tanh(ln['newgate'](i_n + resetgate * h_n))
|
||||||
|
hy = newgate + inputgate * (hidden - newgate)
|
||||||
|
return hy
|
||||||
|
|
||||||
|
class CombinedEmbedding(nn.Module):
|
||||||
|
def __init__(self, pretrained_embedding, embedding):
|
||||||
|
super(CombinedEmbedding, self).__init__()
|
||||||
|
self.pretrained_embedding = pretrained_embedding
|
||||||
|
self.embedding = embedding
|
||||||
|
self.pivot = pretrained_embedding.num_embeddings
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
outputs = []
|
||||||
|
mask = input < self.pivot
|
||||||
|
outputs.append(self.pretrained_embedding(torch.clamp(input, 0, self.pivot-1)) * mask.unsqueeze(1).float())
|
||||||
|
mask = input >= self.pivot
|
||||||
|
outputs.append(self.embedding(torch.clamp(input, self.pivot) - self.pivot) * mask.unsqueeze(1).float())
|
||||||
|
return sum(outputs)
|
||||||
|
|
||||||
|
|
||||||
|
class writer_helper(object):
|
||||||
|
def __init__(self, writer):
|
||||||
|
self.writer = writer
|
||||||
|
self.all_steps = {}
|
||||||
|
|
||||||
|
def get_step(self, tag):
|
||||||
|
if tag not in self.all_steps.keys():
|
||||||
|
self.all_steps.update({tag: 0})
|
||||||
|
|
||||||
|
step = self.all_steps[tag]
|
||||||
|
self.all_steps[tag] += 1
|
||||||
|
return step
|
||||||
|
|
||||||
|
def scalar_summary(self, tag, value, step=None):
|
||||||
|
if step is None:
|
||||||
|
step = self.get_step(tag)
|
||||||
|
self.writer.add_scalar(tag, value, step)
|
||||||
|
|
||||||
|
def text_summary(self, tag, value, step=None):
|
||||||
|
if step is None:
|
||||||
|
step = self.get_step(tag)
|
||||||
|
self.writer.add_text(tag, value, step)
|
||||||
|
|
||||||
|
|
||||||
|
class Constant():
|
||||||
|
def __init__(self, v):
|
||||||
|
self.v = v
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LinearStep():
|
||||||
|
def __init__(self, max, min, steps):
|
||||||
|
self.steps = float(steps)
|
||||||
|
self.max = max
|
||||||
|
self.min = min
|
||||||
|
self.cur_step = 0
|
||||||
|
self.v = self.max
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
v = max(self.max - (self.max - self.min) *
|
||||||
|
self.cur_step / self.steps, self.min)
|
||||||
|
self.cur_step += 1
|
||||||
|
self.v = v
|
||||||
|
|
||||||
|
|
||||||
|
class fc_block(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, norm, activation_fn):
|
||||||
|
super(fc_block, self).__init__()
|
||||||
|
|
||||||
|
block = nn.Sequential()
|
||||||
|
block.add_module('linear', nn.Linear(in_channels, out_channels))
|
||||||
|
if norm:
|
||||||
|
block.add_module('batchnorm', nn.BatchNorm1d(out_channels))
|
||||||
|
if activation_fn is not None:
|
||||||
|
block.add_module('activation', activation_fn())
|
||||||
|
|
||||||
|
self.block = block
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.block(x)
|
||||||
|
|
||||||
|
|
||||||
|
class conv_block(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
norm,
|
||||||
|
activation_fn):
|
||||||
|
super(conv_block, self).__init__()
|
||||||
|
|
||||||
|
block = nn.Sequential()
|
||||||
|
block.add_module(
|
||||||
|
'conv',
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride))
|
||||||
|
if norm:
|
||||||
|
block.add_module('batchnorm', nn.BatchNorm2d(out_channels))
|
||||||
|
if activation_fn is not None:
|
||||||
|
block.add_module('activation', activation_fn())
|
||||||
|
|
||||||
|
self.block = block
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.block(x)
|
||||||
|
|
||||||
|
|
||||||
|
def get_conv_output_shape(shape, block):
|
||||||
|
B = 1
|
||||||
|
input = torch.rand(B, *shape)
|
||||||
|
output = block(input)
|
||||||
|
n_size = output.data.view(B, -1).size(1)
|
||||||
|
return n_size
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(nn.Module):
|
||||||
|
def forward(self, input):
|
||||||
|
return input.view(input.size(0), -1)
|
||||||
|
|
||||||
|
|
||||||
|
def BHWC_to_BCHW(tensor):
|
||||||
|
tensor = torch.transpose(tensor, 1, 3) # BCWH
|
||||||
|
tensor = torch.transpose(tensor, 2, 3) # BCHW
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def LCS(X, Y):
|
||||||
|
|
||||||
|
# find the length of the strings
|
||||||
|
m = len(X)
|
||||||
|
n = len(Y)
|
||||||
|
|
||||||
|
# declaring the array for storing the dp values
|
||||||
|
L = [[None] * (n + 1) for i in range(m + 1)]
|
||||||
|
longest_L = [[[]] * (n + 1) for i in range(m + 1)]
|
||||||
|
longest = 0
|
||||||
|
lcs_set = []
|
||||||
|
|
||||||
|
for i in range(m + 1):
|
||||||
|
for j in range(n + 1):
|
||||||
|
if i == 0 or j == 0:
|
||||||
|
L[i][j] = 0
|
||||||
|
longest_L[i][j] = []
|
||||||
|
elif X[i - 1] == Y[j - 1]:
|
||||||
|
L[i][j] = L[i - 1][j - 1] + 1
|
||||||
|
longest_L[i][j] = longest_L[i - 1][j - 1] + [X[i - 1]]
|
||||||
|
if L[i][j] > longest:
|
||||||
|
lcs_set = []
|
||||||
|
lcs_set.append(longest_L[i][j])
|
||||||
|
longest = L[i][j]
|
||||||
|
elif L[i][j] == longest and longest != 0:
|
||||||
|
lcs_set.append(longest_L[i][j])
|
||||||
|
else:
|
||||||
|
if L[i - 1][j] > L[i][j - 1]:
|
||||||
|
L[i][j] = L[i - 1][j]
|
||||||
|
longest_L[i][j] = longest_L[i - 1][j]
|
||||||
|
else:
|
||||||
|
L[i][j] = L[i][j - 1]
|
||||||
|
longest_L[i][j] = longest_L[i][j - 1]
|
||||||
|
|
||||||
|
if len(lcs_set) > 0:
|
||||||
|
return lcs_set[0]
|
||||||
|
else:
|
||||||
|
return lcs_set
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue