Added notebooks and models
This commit is contained in:
parent
f4ce991790
commit
0fb6efadd5
BIN
python/ModelSnapshots/CNN-33767.h5
Normal file
BIN
python/ModelSnapshots/CNN-33767.h5
Normal file
Binary file not shown.
BIN
python/ModelSnapshots/LSTM-v1-01605.h5
Normal file
BIN
python/ModelSnapshots/LSTM-v1-01605.h5
Normal file
Binary file not shown.
BIN
python/ModelSnapshots/LSTM-v2-00398.h5
Normal file
BIN
python/ModelSnapshots/LSTM-v2-00398.h5
Normal file
Binary file not shown.
BIN
python/Models/CNN.pb
Normal file
BIN
python/Models/CNN.pb
Normal file
Binary file not shown.
BIN
python/Models/LSTM.pb
Normal file
BIN
python/Models/LSTM.pb
Normal file
Binary file not shown.
340
python/Step_01_UserData.ipynb
Normal file
340
python/Step_01_UserData.ipynb
Normal file
|
@ -0,0 +1,340 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"from scipy.odr import *\n",
|
||||
"from scipy.stats import *\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from multiprocessing import Pool"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def cast_to_int(row):\n",
|
||||
" try:\n",
|
||||
" return np.array([a if float(a) >= 0 else 0 for a in row[2:-1]], dtype=np.uint8)\n",
|
||||
" except Exception as e:\n",
|
||||
" return None\n",
|
||||
" \n",
|
||||
"def load_csv(file):\n",
|
||||
" temp_df = pd.read_csv(file, header=None, names = [\"UserID\", \"Age\", \"Gender\"], delimiter=\";\")\n",
|
||||
" return temp_df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"CPU times: user 298 ms, sys: 443 ms, total: 741 ms\n",
|
||||
"Wall time: 937 ms\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"pool = Pool(os.cpu_count() - 2)\n",
|
||||
"data_files = [\"DataStudyCollection/%s\" % file for file in os.listdir(\"DataStudyCollection\") if file.endswith(\".csv\") and \"userData\" in file]\n",
|
||||
"df_lst = pool.map(load_csv, data_files)\n",
|
||||
"dfAll = pd.concat(df_lst)\n",
|
||||
"dfAll = dfAll.sort_values(\"UserID\")\n",
|
||||
"dfAll = dfAll.reset_index(drop=True)\n",
|
||||
"pool.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"24.166666666666668"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dfAll.Age.mean()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1.4245742398014511"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dfAll.Age.std()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"21"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dfAll.Age.min()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"26"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dfAll.Age.max()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>UserID</th>\n",
|
||||
" <th>Age</th>\n",
|
||||
" <th>Gender</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>23</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>24</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>26</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>6</td>\n",
|
||||
" <td>23</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>7</td>\n",
|
||||
" <td>21</td>\n",
|
||||
" <td>female</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>8</td>\n",
|
||||
" <td>24</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>9</td>\n",
|
||||
" <td>24</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>24</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>10</th>\n",
|
||||
" <td>11</td>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>female</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>11</th>\n",
|
||||
" <td>12</td>\n",
|
||||
" <td>26</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>12</th>\n",
|
||||
" <td>13</td>\n",
|
||||
" <td>22</td>\n",
|
||||
" <td>female</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>13</th>\n",
|
||||
" <td>14</td>\n",
|
||||
" <td>24</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>14</th>\n",
|
||||
" <td>15</td>\n",
|
||||
" <td>24</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>15</th>\n",
|
||||
" <td>16</td>\n",
|
||||
" <td>26</td>\n",
|
||||
" <td>female</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>16</th>\n",
|
||||
" <td>17</td>\n",
|
||||
" <td>26</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>17</th>\n",
|
||||
" <td>18</td>\n",
|
||||
" <td>23</td>\n",
|
||||
" <td>male</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" UserID Age Gender\n",
|
||||
"0 1 23 male\n",
|
||||
"1 2 24 male\n",
|
||||
"2 3 25 male\n",
|
||||
"3 4 25 male\n",
|
||||
"4 5 26 male\n",
|
||||
"5 6 23 male\n",
|
||||
"6 7 21 female\n",
|
||||
"7 8 24 male\n",
|
||||
"8 9 24 male\n",
|
||||
"9 10 24 male\n",
|
||||
"10 11 25 female\n",
|
||||
"11 12 26 male\n",
|
||||
"12 13 22 female\n",
|
||||
"13 14 24 male\n",
|
||||
"14 15 24 male\n",
|
||||
"15 16 26 female\n",
|
||||
"16 17 26 male\n",
|
||||
"17 18 23 male"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dfAll"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
338
python/Step_02_ReadData.ipynb
Normal file
338
python/Step_02_ReadData.ipynb
Normal file
|
@ -0,0 +1,338 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## This notebook creates one dataframe from all participants data\n",
|
||||
"## It also removes 1% of the data as this is corrupted"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"from scipy.odr import *\n",
|
||||
"from scipy.stats import *\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import ast\n",
|
||||
"from multiprocessing import Pool, cpu_count\n",
|
||||
"\n",
|
||||
"import scipy\n",
|
||||
"\n",
|
||||
"from IPython import display\n",
|
||||
"from matplotlib.patches import Rectangle\n",
|
||||
"\n",
|
||||
"from sklearn.metrics import mean_squared_error\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"import scipy.stats as st\n",
|
||||
"from sklearn.metrics import r2_score\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from matplotlib import cm\n",
|
||||
"from mpl_toolkits.mplot3d import axes3d\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"import copy\n",
|
||||
"\n",
|
||||
"from sklearn.model_selection import LeaveOneOut, LeavePOut\n",
|
||||
"\n",
|
||||
"from multiprocessing import Pool"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def cast_to_int(row):\n",
|
||||
" try:\n",
|
||||
" return np.array([a if float(a) >= 0 else 0 for a in row[2:-1]], dtype=np.uint8)\n",
|
||||
" except Exception as e:\n",
|
||||
" return None\n",
|
||||
" \n",
|
||||
"def load_csv(file):\n",
|
||||
" temp_df = pd.read_csv(file, delimiter=\";\")\n",
|
||||
" temp_df.Image = temp_df.Image.str.split(',')\n",
|
||||
" temp_df.Image = temp_df.Image.apply(cast_to_int)\n",
|
||||
" return temp_df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['DataStudyCollection/17_studyData.csv', 'DataStudyCollection/2_studyData.csv', 'DataStudyCollection/12_studyData.csv', 'DataStudyCollection/15_studyData.csv', 'DataStudyCollection/5_studyData.csv', 'DataStudyCollection/1_studyData.csv', 'DataStudyCollection/14_studyData.csv', 'DataStudyCollection/10_studyData.csv', 'DataStudyCollection/13_studyData.csv', 'DataStudyCollection/18_studyData.csv', 'DataStudyCollection/6_studyData.csv', 'DataStudyCollection/16_studyData.csv', 'DataStudyCollection/3_studyData.csv', 'DataStudyCollection/7_studyData.csv', 'DataStudyCollection/8_studyData.csv', 'DataStudyCollection/9_studyData.csv', 'DataStudyCollection/11_studyData.csv', 'DataStudyCollection/4_studyData.csv']\n",
|
||||
"CPU times: user 1.86 s, sys: 1.03 s, total: 2.89 s\n",
|
||||
"Wall time: 17.3 s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"pool = Pool(cpu_count() - 2)\n",
|
||||
"data_files = [\"DataStudyCollection/%s\" % file for file in os.listdir(\"DataStudyCollection\") if file.endswith(\".csv\") and \"studyData\" in file]\n",
|
||||
"print(data_files)\n",
|
||||
"df_lst = pool.map(load_csv, data_files)\n",
|
||||
"dfAll = pd.concat(df_lst)\n",
|
||||
"pool.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1010014"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df = dfAll[dfAll.Image.notnull()]\n",
|
||||
"len(df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"loaded 1013841 values\n",
|
||||
"removed 3827 values (thats 0.377%)\n",
|
||||
"new df has size 1010014\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"loaded %s values\" % len(dfAll))\n",
|
||||
"print(\"removed %s values (thats %s%%)\" % (len(dfAll) - len(df), round((len(dfAll) - len(df)) / len(dfAll) * 100, 3)))\n",
|
||||
"print(\"new df has size %s\" % len(df))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = df.reset_index(drop=True)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>userID</th>\n",
|
||||
" <th>Timestamp</th>\n",
|
||||
" <th>Current_Task</th>\n",
|
||||
" <th>Task_amount</th>\n",
|
||||
" <th>TaskID</th>\n",
|
||||
" <th>VersionID</th>\n",
|
||||
" <th>RepetitionID</th>\n",
|
||||
" <th>Actual_Data</th>\n",
|
||||
" <th>Is_Pause</th>\n",
|
||||
" <th>Image</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>17</td>\n",
|
||||
" <td>1547138602677</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>34</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>[1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>17</td>\n",
|
||||
" <td>1547138602697</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>34</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>[1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>17</td>\n",
|
||||
" <td>1547138602796</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>34</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>[1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>17</td>\n",
|
||||
" <td>1547138602817</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>34</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>[1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>17</td>\n",
|
||||
" <td>1547138602863</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>34</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>[1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ...</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" userID Timestamp Current_Task Task_amount TaskID VersionID \\\n",
|
||||
"0 17 1547138602677 0 34 0 0 \n",
|
||||
"1 17 1547138602697 0 34 0 0 \n",
|
||||
"2 17 1547138602796 0 34 0 0 \n",
|
||||
"3 17 1547138602817 0 34 0 0 \n",
|
||||
"4 17 1547138602863 0 34 0 0 \n",
|
||||
"\n",
|
||||
" RepetitionID Actual_Data Is_Pause \\\n",
|
||||
"0 0 False False \n",
|
||||
"1 0 False False \n",
|
||||
"2 0 False False \n",
|
||||
"3 0 False False \n",
|
||||
"4 0 False False \n",
|
||||
"\n",
|
||||
" Image \n",
|
||||
"0 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ... \n",
|
||||
"1 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ... \n",
|
||||
"2 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ... \n",
|
||||
"3 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ... \n",
|
||||
"4 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ... "
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df.to_pickle(\"DataStudyCollection/AllData.pkl\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sorted(df.userID.unique())"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
966
python/Step_05_CNN_PreprocessData.ipynb
Normal file
966
python/Step_05_CNN_PreprocessData.ipynb
Normal file
File diff suppressed because one or more lines are too long
890
python/Step_06_CNN_Baseline.ipynb
Normal file
890
python/Step_06_CNN_Baseline.ipynb
Normal file
|
@ -0,0 +1,890 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## USE for Multi GPU Systems\n",
|
||||
"#import os\n",
|
||||
"#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
|
||||
"\n",
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"from scipy.odr import *\n",
|
||||
"from scipy.stats import *\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import ast\n",
|
||||
"from multiprocessing import Pool\n",
|
||||
"\n",
|
||||
"import scipy\n",
|
||||
"\n",
|
||||
"from IPython import display\n",
|
||||
"from matplotlib.patches import Rectangle\n",
|
||||
"\n",
|
||||
"from sklearn.metrics import mean_squared_error\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"import scipy.stats as st\n",
|
||||
"from sklearn.metrics import r2_score\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from matplotlib import cm\n",
|
||||
"from mpl_toolkits.mplot3d import axes3d\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from matplotlib.patches import Ellipse\n",
|
||||
"\n",
|
||||
"import copy\n",
|
||||
"\n",
|
||||
"from sklearn.model_selection import LeaveOneOut, LeavePOut\n",
|
||||
"\n",
|
||||
"from multiprocessing import Pool\n",
|
||||
"import cv2\n",
|
||||
"\n",
|
||||
"import sklearn\n",
|
||||
"import random\n",
|
||||
"from sklearn import neighbors\n",
|
||||
"from sklearn import svm\n",
|
||||
"from sklearn import tree\n",
|
||||
"from sklearn import ensemble\n",
|
||||
"from sklearn.model_selection import GridSearchCV\n",
|
||||
"from sklearn.metrics import classification_report\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import pandas as pd\n",
|
||||
"import math\n",
|
||||
"\n",
|
||||
"# Importing matplotlib to plot images.\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"# Importing SK-learn to calculate precision and recall\n",
|
||||
"import sklearn\n",
|
||||
"from sklearn import metrics\n",
|
||||
"from sklearn.model_selection import train_test_split, cross_val_score, LeaveOneGroupOut\n",
|
||||
"from sklearn.utils import shuffle\n",
|
||||
"from sklearn.model_selection import GridSearchCV\n",
|
||||
"from sklearn.metrics.pairwise import euclidean_distances\n",
|
||||
"from sklearn.metrics import confusion_matrix\n",
|
||||
"from sklearn.metrics import accuracy_score\n",
|
||||
"\n",
|
||||
"import pickle as pkl\n",
|
||||
"import h5py\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"import os.path\n",
|
||||
"import sys\n",
|
||||
"import datetime\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"import skimage\n",
|
||||
"\n",
|
||||
"target_names = [\"Knuckle\", \"Finger\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from skimage import measure\n",
|
||||
"from skimage.measure import find_contours, approximate_polygon, \\\n",
|
||||
" subdivide_polygon, EllipseModel, LineModelND"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def getEllipseParams(img):\n",
|
||||
" points = np.argwhere(img > 40)\n",
|
||||
" \n",
|
||||
" contours = skimage.measure.find_contours(img, 40)\n",
|
||||
" points_to_approx = []\n",
|
||||
" highest_val = 0\n",
|
||||
" for n, contour in enumerate(contours):\n",
|
||||
" if (len(contour) > highest_val):\n",
|
||||
" points_to_approx = contour\n",
|
||||
" highest_val = len(contour) \n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" contour = np.fliplr(points_to_approx)\n",
|
||||
" except Exception as inst:\n",
|
||||
" return [-1, -1, -1, -1, -1]\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" ellipse = skimage.measure.fit.EllipseModel()\n",
|
||||
" ellipse.estimate(contour)\n",
|
||||
" try:\n",
|
||||
" xc, yc, a, b, theta = ellipse.params \n",
|
||||
" except Exception as int:\n",
|
||||
" return [-1, -1, -1, -1, -1]\n",
|
||||
" \n",
|
||||
" return [xc, yc, a, b, theta]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[ 1 2 9 6 4 14 17 16 12 3 10 18 5] [13 8 11 15 7]\n",
|
||||
"13 : 5\n",
|
||||
"0.7222222222222222 : 0.2777777777777778\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# the data, split between train and test sets\n",
|
||||
"df = pd.read_pickle(\"DataStudyCollection/df_statistics.pkl\")\n",
|
||||
"\n",
|
||||
"lst = df.userID.unique()\n",
|
||||
"np.random.seed(42)\n",
|
||||
"np.random.shuffle(lst)\n",
|
||||
"test_ids = lst[-5:]\n",
|
||||
"train_ids = lst[:-5]\n",
|
||||
"\n",
|
||||
"df[\"Set\"] = \"Test\"\n",
|
||||
"df.loc[df.userID.isin(train_ids), \"Set\"] = \"Train\"\n",
|
||||
"print(train_ids, test_ids)\n",
|
||||
"print(len(train_ids), \":\", len(test_ids))\n",
|
||||
"print(len(train_ids) / len(lst), \":\", len(test_ids)/ len(lst))\n",
|
||||
"\n",
|
||||
"#df_train = df[df.userID.isin(train_ids)]\n",
|
||||
"#df_test = df[df.userID.isin(test_ids) & (df.Version == \"Normal\")]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<matplotlib.patches.Ellipse at 0x7ff60430b668>"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAJ4AAAD8CAYAAACGuR0qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADetJREFUeJzt3XuQVPWZxvHvO8AwMAy3gCNBRbCQDV4yuoSoYArvqNkga0K8ZXHXCmrFXbWyu2VlqzS7+YdaNWx243pJQImV6LoaSjZFokhijHeZaBBULuE+AgPKZYIyzDDv/tFnKsNlunu6e/rt6Xk+VVPdfW79UvVwus85/TuvuTsixVYRXYD0TgqehFDwJISCJyEUPAmh4EkIBU9CKHgSQsGTEH2L+WaV1t+rqO58AbO061ufDP9PDrWlnZ3xGo2u4uStid273H1kpuXyCp6ZTQd+APQBfuzuc9MtX0U1X6y4uPPtVVamfb+KmkFp5/uf9qefnyFY3tycdr5k9oI/vSmb5XL+qDWzPsADwOXAROBaM5uY6/akd8nnO95kYJ27r3f3g8CTwIzClCXlLp/gjQa2dHi9NZl2GDObY2bLzWx5C/ook5RuP6p190fcfZK7T+pH/+5+O+kh8gleA3Bih9cnJNNEMsoneG8B481srJlVAtcAiwtTlpS7nE+nuHurmd0GPEfqdMoCd1+VdiWztKdMKgYOTL96htMp1jf9P6d1R2Pa+VI8eZ3Hc/clwJIC1SK9iC6ZSQgFT0IoeBKiqD8SKKR+lX353NljOHnCKCoqjLY2Z/PKzbxfv4HmT1uiy5MMelzwaoYO5Ou3XsSV151L1cCjj5BbDrbyxtKVPPEfv2T9Kp1WLFU9KnjnXnIa3773GqprBgDwx/caWP2HzTR/2kJlVT8mnHEC404bzdQr65h6ZR2/fuZNHvjOU3zSdCC4cjlSaQVv5PBOZ13x9S/yrXtm0Keigpc2bOTe3/2O9xp3HrbMwHpj5OBqbpz2l8w67/NcePVkTvnSeG6b/ywbd+7msw/uTfv2bQcU0GLpEQcXZ08Zzz/860z6VFQw7+VX+Ntnfn5U6Nrt3Lefexe/xNX3Pc77DY2MGTmM+bd+lVHDaopctaRT8sEbMqyaf5w7C4B5r7zKD19/I6v1Nu/aw+wf/g9vrtvCcUMG8dA3Z1IzLM2vn6WoSj541956IcNH1rDirfX8d5aha/fpwVbuePT/WLttF+NqP8Mt/35DN1UpXVXSwauuqWL6174AwIPfW0xbDmMimg40c/ujiznQ0srF10+lbpp+JF0KSjp4508/g6oBlbz96lo2rNme83a2fLSXh5e+DqC9Xoko6eCdOXkcAK++kP5HL9l47MV6Ptq+h7Gnn8iZ5/9F3tuT/JR08M6YNBaAFW9tyHtbrYfaWDL/1wB85ZZL8t6e5Keo5/EMsDRjZ33An69EVPbvy8hRQzl4sJVNDXvwAZWsveHBtNuvbz6Ydv4Pvnoj3/iXv2by5XX0H11LS3PrYfPbtupKR7GU7B6vZnDq6kTT3k8yjofN1kc79vLHVQ30r6rkc2efXJBtSm5KNniDaqoA2N9U2JFpq95aD8CEujEF3a50TckG75NPUh+bAwcVdmTahxtTVzxGjBpa0O1K15Rs8PZ8nLodxdACX23Ys7Mptd0R6cdvSPcq2eC1HGxl7+799O3Xh5HHDynYdv+071MABiXfISVGyQYPYM17HwJwWt1JBdtmdRK4/fqpVKiSDt47b6YOBL4wZXzBtlkzNDWEsmnPJwXbpnRdUc/jOelvFVbRcPi419eefIVv3nkZUy+cyEN3/oSz/+3WtNs/OCT9/fVOOrCasafWArBt/XZcv78LU9J7vA83NFL/m1VUDazk0uvOK8g2zzg3tfdc+dragmxPclPSwQNYPP9FAGb9/WUMqa7Ka1ujxx3HSaeO4tP9B1j7h80FqE5yVfLBe+P5d1nxyhqGjqjhjplfymtbV825EIDfLlpOa8uhQpQnOSr54AH85z/9lIPNLcw873S+ck5uv6ebcPwILrt+Cm1tbSx6eFmBK5Su6hHB27puB4/c/TQAd19/KeefPrZL6w/o15f7rrmCyv79+OXjL7Np9bbuKFO6oEcED+AXj/6W+c+9Sd8+Fcy7eQbXXXBWVusNGVjFj2+6mlNqP8Om1dt45O7/7eZKJRulNbwxg/969mUM+LvLJvPPX7uAS846lYeXvMbrHxx9oGAGl585gTsvm8LoYUP4cPc+vnfjQ7rLQImwYnboHmzDPV27gYpBGe5/V9kPgClfPovb77+BwcNTy2/ftIt3X1vD9i0f4w4jPzuMuvMnMGrMCADWvbuFe/7mIXatTn8kq3G1+XvBn65390mZlsu3z8VGoAk4BLRm84aF8Mov3ubtF9/nr26axsxbLub4MSM4PglZR7s+3M3j9y/hhafeoC1D8xUprrz2eEnwJrn7rmyWL9Qe77B1Koxxp5/AaeeMZ8iIwZgZ+3bv54P6Dax5ZxNtbX/+97Xt0Z0EultR9niloK3NWbdiC+tWbMH6HR1MKU35HtU68LyZ1ZvZnEIUJL1Dvnu8qe7eYGbHAUvN7AN3f6njAkkg5wBUkf7m2tJ75LXHc/eG5LERWESqzdSRy6jBihwlnyZ61WZW0/4cuBRYWajCpLzl81FbCyxKxsn2BX7m7r/KuFaao+i2pqY8ypGeJJ8GK+uBzxewFulFesy1WikvCp6EUPAkhIInIRQ8CaHgSQgFT0IoeBJCwZMQCp6EUPAkhIInIRQ8CaHgSQgFT0IoeBJCwZMQCp6EUPAkhIInIRQ8CaHgSQgFT0IoeBJCwZMQCp6EUPAkhIInIRQ8CaHgSQgFT0JkDJ6ZLTCzRjNb2WHacDNbamZrk8dh3VumlJts9niPAdOPmHYXsMzdxwPLktciWcsYvOQu7h8fMXkGsDB5vhC4qsB1SZnL9Tterbu3997cTup+yCJZy/vgwlM9qTq9o7aZzTGz5Wa2vIXmfN9OykSuwdthZqMAksfGzhZUnws5llyDtxiYnTyfDTxbmHKkt8jmdMoTwGvABDPbamY3AXOBS8xsLXBx8lokaxn7XLj7tZ3MuqjAtUgvoisXEkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQuTaYOW7ZtZgZu8kf1d0b5lSbnJtsAIwz93rkr8lhS1Lyl2uDVZE8pLPd7zbzGxF8lHcaS8z9bmQY8k1eA8CpwB1wDbg/s4WVJ8LOZacgufuO9z9kLu3AT8CJhe2LCl3OQWvvatPYiawsrNlRY4lY5+LpMHKNGCEmW0F7gGmmVkdqR5mG4Gbu7FGKUO5NliZ3w21SC+iKxcSQsGTEAqehFDwJISCJyEUPAmh4EkIBU9CKHgSQsGTEAqehFDwJISCJyEUPAmh4EkIBU9CKHgSQsGTEAqehFDwJISCJyEUPAmh4EkIBU9CKHgSQsGTEAqehFDwJISCJyEUPAmh4EmIbPpcnGhmvzGz98xslZndnkwfbmZLzWxt8tjpDbhFjpTNHq8V+La7TwTOAb5lZhOBu4Bl7j4eWJa8FslKNn0utrn775PnTcD7wGhgBrAwWWwhcFV3FSnlJ+OtaDsys5OBs4A3gFp335bM2g7UdrLOHGAOQBUDc61TykzWBxdmNgh4BrjD3fd1nOfuTupG3EdRnws5lqyCZ2b9SIXup+7+82Tyjva2A8ljY/eUKOUom6NaI3WX9/fd/fsdZi0GZifPZwPPFr48KVfZfMebAnwDeNfM3kmmfQeYCzxlZjcBm4BZ3VOilKNs+ly8DFgnsy8qbDnSW+jKhYRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EiKfPhffNbMGM3sn+bui+8uVcpHNHUHb+1z83sxqgHozW5rMm+fu93VfeVKusrkj6DZgW/K8ycza+1yI5KxL3/GO6HMBcJuZrTCzBWopJV2RT5+LB4FTgDpSe8T7O1lvjpktN7PlLTQXoGQpBzn3uXD3He5+yN3bgB8Bk4+1rhqsyLHk3OeivblKYiawsvDlSbnKp8/FtWZWR6qV1Ebg5m6pUMpSPn0ulhS+HOktdOVCQih4EkLBkxAKnoRQ8CSEgichFDwJYe5evDcz20mqqXK7EcCuohXQdaVeH5RejWPcfWSmhYoavKPe3Gy5u08KKyCDUq8PekaNx6KPWgmh4EmI6OA9Evz+mZR6fdAzajxK6Hc86b2i93jSS4UEz8ymm9lqM1tnZndF1JCJmW00s3eToZvLS6CeBWbWaGYrO0wbbmZLzWxt8thjxr0UPXhm1gd4ALgcmEjqB6UTi11Hli5w97oSOV3xGDD9iGl3AcvcfTywLHndI0Ts8SYD69x9vbsfBJ4EZgTU0aO4+0vAx0dMngEsTJ4vBK4qalF5iAjeaGBLh9dbKc1xug48b2b1ZjYnuphO1CbjngG2A7WRxXRFNmMuequp7t5gZscBS83sg2SvU5Lc3c2sx5yiiNjjNQAndnh9QjKtpLh7Q/LYCCyik+GbwXa0j/ZLHhuD68laRPDeAsab2VgzqwSuARYH1NEpM6tO7hODmVUDl1KawzcXA7OT57OBZwNr6ZKif9S6e6uZ3QY8B/QBFrj7qmLXkUEtsCg1pJi+wM/c/VeRBZnZE8A0YISZbQXuAeYCT5nZTaR+9TMrrsKu0ZULCaErFxJCwZMQCp6EUPAkhIInIRQ8CaHgSQgFT0L8Pyp94TKMVNfuAAAAAElFTkSuQmCC\n",
|
||||
"text/plain": [
|
||||
"<Figure size 432x288 with 1 Axes>"
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"needs_background": "light"
|
||||
},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"fig, ax = plt.subplots(1)\n",
|
||||
"img = df.iloc[0].Blobs\n",
|
||||
"xc, yc, a, b, theta = getEllipseParams(img)\n",
|
||||
"ax.imshow(img)\n",
|
||||
"e = Ellipse(xy=[xc,yc], width=a*2, height=b*2, angle=math.degrees(theta), fill=False, lw=2, edgecolor='w')\n",
|
||||
"ax.add_artist(e)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"lst = df.Blobs.apply(lambda x: getEllipseParams(x))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"lst2 = np.vstack(lst.values)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(618012, 5)"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"lst2.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df[\"XC\"] = lst2[:,0]\n",
|
||||
"df[\"YC\"] = lst2[:,1]\n",
|
||||
"df[\"EllipseW\"] = lst2[:,2]\n",
|
||||
"df[\"EllipseH\"] = lst2[:,3]\n",
|
||||
"df[\"EllipseTheta\"] = lst2[:,4]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df[\"Area\"] = df[\"EllipseW\"] * df[\"EllipseH\"] * np.pi\n",
|
||||
"df[\"AvgCapa\"] = df.Blobs.apply(lambda x: np.mean(x))\n",
|
||||
"df[\"SumCapa\"] = df.Blobs.apply(lambda x: np.sum(x))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[8, 11, 6, 7, 16, 15, 14, 10, 9, 2, 3, 13, 17, 5, 12, 1, 4]"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"lst = list(range(1, df.userID.max()))\n",
|
||||
"SEED = 42#448\n",
|
||||
"random.seed(SEED)\n",
|
||||
"random.shuffle(lst)\n",
|
||||
"lst"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dfY = df[df.Set == \"Train\"].copy(deep=True)\n",
|
||||
"dfT = df[(df.Set == \"Test\") & (df.Version == \"Normal\")].copy(deep=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"minmax = min(len(dfY[dfY.Input == \"Finger\"]), len(dfY[dfY.Input == \"Knuckle\"]))\n",
|
||||
"dfX = dfY[dfY.Input == \"Finger\"].sample(minmax)\n",
|
||||
"dfZ = dfY[dfY.Input == \"Knuckle\"].sample(minmax)\n",
|
||||
"dfY = pd.concat([dfX,dfZ])\n",
|
||||
"\n",
|
||||
"minmax = min(len(dfT[dfT.Input == \"Finger\"]), len(dfT[dfT.Input == \"Knuckle\"]))\n",
|
||||
"dfX = dfT[dfT.Input == \"Finger\"].sample(minmax)\n",
|
||||
"dfZ = dfT[dfT.Input == \"Knuckle\"].sample(minmax)\n",
|
||||
"dfT = pd.concat([dfX,dfZ])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>userID</th>\n",
|
||||
" <th>Timestamp</th>\n",
|
||||
" <th>Current_Task</th>\n",
|
||||
" <th>Task_amount</th>\n",
|
||||
" <th>TaskID</th>\n",
|
||||
" <th>VersionID</th>\n",
|
||||
" <th>RepetitionID</th>\n",
|
||||
" <th>Actual_Data</th>\n",
|
||||
" <th>Is_Pause</th>\n",
|
||||
" <th>Image</th>\n",
|
||||
" <th>...</th>\n",
|
||||
" <th>InputMethod</th>\n",
|
||||
" <th>Set</th>\n",
|
||||
" <th>XC</th>\n",
|
||||
" <th>YC</th>\n",
|
||||
" <th>EllipseW</th>\n",
|
||||
" <th>EllipseH</th>\n",
|
||||
" <th>EllipseTheta</th>\n",
|
||||
" <th>Area</th>\n",
|
||||
" <th>AvgCapa</th>\n",
|
||||
" <th>SumCapa</th>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>Input</th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" <th></th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>Finger</th>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>Knuckle</th>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" <td>9421</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>2 rows × 31 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" userID Timestamp Current_Task Task_amount TaskID VersionID \\\n",
|
||||
"Input \n",
|
||||
"Finger 9421 9421 9421 9421 9421 9421 \n",
|
||||
"Knuckle 9421 9421 9421 9421 9421 9421 \n",
|
||||
"\n",
|
||||
" RepetitionID Actual_Data Is_Pause Image ... InputMethod Set \\\n",
|
||||
"Input ... \n",
|
||||
"Finger 9421 9421 9421 9421 ... 9421 9421 \n",
|
||||
"Knuckle 9421 9421 9421 9421 ... 9421 9421 \n",
|
||||
"\n",
|
||||
" XC YC EllipseW EllipseH EllipseTheta Area AvgCapa SumCapa \n",
|
||||
"Input \n",
|
||||
"Finger 9421 9421 9421 9421 9421 9421 9421 9421 \n",
|
||||
"Knuckle 9421 9421 9421 9421 9421 9421 9421 9421 \n",
|
||||
"\n",
|
||||
"[2 rows x 31 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dfT.groupby(\"Input\").count()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# FEATURE SET: sum of capacitance, avg of capacitance, ellipse area, ellipse width, height and theta."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"features = [\"SumCapa\", \"AvgCapa\", \"Area\", \"EllipseW\", \"EllipseH\", \"EllipseTheta\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# ZeroR"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dfT[\"InputMethodPred\"] = 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[ 0 9421]\n",
|
||||
" [ 0 9421]]\n",
|
||||
"Accuray: 0.50\n",
|
||||
"Recall: 0.50\n",
|
||||
"Precision: 0.50\n",
|
||||
"F1-Score: 0.33\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" Knuckle 0.00 0.00 0.00 9421\n",
|
||||
" Finger 0.50 1.00 0.67 9421\n",
|
||||
"\n",
|
||||
" micro avg 0.50 0.50 0.50 18842\n",
|
||||
" macro avg 0.25 0.50 0.33 18842\n",
|
||||
"weighted avg 0.25 0.50 0.33 18842\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.6/dist-packages/sklearn/metrics/classification.py:1143: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.\n",
|
||||
" 'precision', 'predicted', average, warn_for)\n",
|
||||
"/usr/local/lib/python3.6/dist-packages/sklearn/metrics/classification.py:1143: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n",
|
||||
" 'precision', 'predicted', average, warn_for)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(confusion_matrix(dfT.InputMethod.values, dfT.InputMethodPred.values, labels=[0, 1]))\n",
|
||||
"print(\"Accuray: %.2f\" % accuracy_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(\"Recall: %.2f\" % metrics.recall_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
|
||||
"print(\"Precision: %.2f\" % metrics.average_precision_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
|
||||
"print(\"F1-Score: %.2f\" % metrics.f1_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
|
||||
"print(classification_report(dfT.InputMethod.values, dfT.InputMethodPred.values, target_names=target_names))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# DecisionTreeClassifier"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Fitting 5 folds for each of 240 candidates, totalling 1200 fits\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Parallel(n_jobs=30)]: Using backend LokyBackend with 30 concurrent workers.\n",
|
||||
"[Parallel(n_jobs=30)]: Done 140 tasks | elapsed: 10.4s\n",
|
||||
"[Parallel(n_jobs=30)]: Done 390 tasks | elapsed: 31.4s\n",
|
||||
"[Parallel(n_jobs=30)]: Done 740 tasks | elapsed: 1.3min\n",
|
||||
"[Parallel(n_jobs=30)]: Done 1200 out of 1200 | elapsed: 2.4min finished\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'max_depth': 22, 'min_samples_split': 2} 0.8120637794585754\n",
|
||||
"[[7409 2012]\n",
|
||||
" [3096 6325]]\n",
|
||||
"Accuray: 0.73\n",
|
||||
"Recall: 0.73\n",
|
||||
"Precision: 0.67\n",
|
||||
"F1-Score: 0.73\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" Knuckle 0.71 0.79 0.74 9421\n",
|
||||
" Finger 0.76 0.67 0.71 9421\n",
|
||||
"\n",
|
||||
" micro avg 0.73 0.73 0.73 18842\n",
|
||||
" macro avg 0.73 0.73 0.73 18842\n",
|
||||
"weighted avg 0.73 0.73 0.73 18842\n",
|
||||
"\n",
|
||||
"CPU times: user 7.26 s, sys: 3.38 s, total: 10.6 s\n",
|
||||
"Wall time: 2min 29s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"param_grid = {'max_depth': range(2,32,1),\n",
|
||||
" 'min_samples_split':range(2,10,1)}\n",
|
||||
"#TODO: Create Baseline for different ML stuff\n",
|
||||
"clf = GridSearchCV(tree.DecisionTreeClassifier(), \n",
|
||||
" param_grid,\n",
|
||||
" cv=5 , n_jobs=os.cpu_count()-2, verbose=1)\n",
|
||||
"clf.fit(dfY[features].values, dfY.InputMethod.values)\n",
|
||||
"print(clf.best_params_, clf.best_score_)\n",
|
||||
"dfT[\"InputMethodPred\"] = clf.predict(dfT[features].values) \n",
|
||||
"\n",
|
||||
"print(confusion_matrix(dfT.InputMethod.values, dfT.InputMethodPred.values, labels=[0, 1]))\n",
|
||||
"print(\"Accuray: %.3f\" % accuracy_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(\"Recall: %.3f\" % metrics.recall_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
|
||||
"print(\"Precision: %.3f\" % metrics.average_precision_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
|
||||
"print(\"F1-Score: %.3f\" % metrics.f1_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
|
||||
"print(classification_report(dfT.InputMethod.values, dfT.InputMethodPred.values, target_names=target_names))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# RandomForestClassifier"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Fitting 5 folds for each of 180 candidates, totalling 900 fits\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Parallel(n_jobs=94)]: Using backend LokyBackend with 94 concurrent workers.\n",
|
||||
"[Parallel(n_jobs=94)]: Done 12 tasks | elapsed: 1.2min\n",
|
||||
"[Parallel(n_jobs=94)]: Done 262 tasks | elapsed: 4.0min\n",
|
||||
"[Parallel(n_jobs=94)]: Done 612 tasks | elapsed: 9.2min\n",
|
||||
"[Parallel(n_jobs=94)]: Done 900 out of 900 | elapsed: 12.8min finished\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'max_depth': 60, 'n_estimators': 63} 0.8669582104371696\n",
|
||||
"[[8175 1246]\n",
|
||||
" [2765 6656]]\n",
|
||||
"Accuray: 0.79\n",
|
||||
"Recall: 0.71\n",
|
||||
"Precision: 0.74\n",
|
||||
"F1-Score: 0.77\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" Knuckle 0.75 0.87 0.80 9421\n",
|
||||
" Finger 0.84 0.71 0.77 9421\n",
|
||||
"\n",
|
||||
" micro avg 0.79 0.79 0.79 18842\n",
|
||||
" macro avg 0.79 0.79 0.79 18842\n",
|
||||
"weighted avg 0.79 0.79 0.79 18842\n",
|
||||
"\n",
|
||||
"CPU times: user 42.1 s, sys: 834 ms, total: 42.9 s\n",
|
||||
"Wall time: 13min 28s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"param_grid = {'n_estimators': range(55,64,1),\n",
|
||||
" 'max_depth': range(50,70,1)}\n",
|
||||
"#TODO: Create Baseline for different ML stuff\n",
|
||||
"clf = GridSearchCV(ensemble.RandomForestClassifier(), \n",
|
||||
" param_grid,\n",
|
||||
" cv=5 , n_jobs=os.cpu_count()-2, verbose=1)\n",
|
||||
"clf.fit(dfY[features].values, dfY.InputMethod.values)\n",
|
||||
"print(clf.best_params_, clf.best_score_)\n",
|
||||
"dfT[\"InputMethodPred\"] = clf.predict(dfT[features].values) \n",
|
||||
"\n",
|
||||
"print(confusion_matrix(dfT.InputMethod.values, dfT.InputMethodPred.values, labels=[0, 1]))\n",
|
||||
"print(\"Accuray: %.2f\" % accuracy_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(\"Recall: %.2f\" % metrics.recall_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(\"Precision: %.2f\" % metrics.average_precision_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(\"F1-Score: %.2f\" % metrics.f1_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(classification_report(dfT.InputMethod.values, dfT.InputMethodPred.values, target_names=target_names))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# kNN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Fitting 5 folds for each of 62 candidates, totalling 310 fits\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Parallel(n_jobs=94)]: Using backend LokyBackend with 94 concurrent workers.\n",
|
||||
"[Parallel(n_jobs=94)]: Done 12 tasks | elapsed: 17.7s\n",
|
||||
"[Parallel(n_jobs=94)]: Done 310 out of 310 | elapsed: 1.5min finished\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'n_neighbors': 2} 0.800546827088748\n",
|
||||
"[[8187 1234]\n",
|
||||
" [4318 5103]]\n",
|
||||
"Accuray: 0.71\n",
|
||||
"Recall: 0.54\n",
|
||||
"Precision: 0.67\n",
|
||||
"F1-Score: 0.65\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" Knuckle 0.65 0.87 0.75 9421\n",
|
||||
" Finger 0.81 0.54 0.65 9421\n",
|
||||
"\n",
|
||||
" micro avg 0.71 0.71 0.71 18842\n",
|
||||
" macro avg 0.73 0.71 0.70 18842\n",
|
||||
"weighted avg 0.73 0.71 0.70 18842\n",
|
||||
"\n",
|
||||
"CPU times: user 1.74 s, sys: 300 ms, total: 2.04 s\n",
|
||||
"Wall time: 1min 30s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"param_grid = {'n_neighbors': range(2,64,1),\n",
|
||||
" #weights': ['uniform', 'distance']\n",
|
||||
" }\n",
|
||||
"#TODO: Create Baseline for different ML stuff\n",
|
||||
"clf = GridSearchCV(neighbors.KNeighborsClassifier(),\n",
|
||||
" param_grid,\n",
|
||||
" cv=5 , n_jobs=os.cpu_count()-2, verbose=1)\n",
|
||||
"clf.fit(dfY[features].values, dfY.InputMethod.values)\n",
|
||||
"print(clf.best_params_, clf.best_score_)\n",
|
||||
"dfT[\"InputMethodPred\"] = clf.predict(dfT[features].values) \n",
|
||||
"\n",
|
||||
"print(confusion_matrix(dfT.InputMethod.values, dfT.InputMethodPred.values, labels=[0, 1]))\n",
|
||||
"print(\"Accuray: %.2f\" % accuracy_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(\"Recall: %.2f\" % metrics.recall_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
|
||||
"print(\"Precision: %.2f\" % metrics.average_precision_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
|
||||
"print(\"F1-Score: %.2f\" % metrics.f1_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
|
||||
"print(classification_report(dfT.InputMethod.values, dfT.InputMethodPred.values, target_names=target_names))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SVM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Fitting 5 folds for each of 9 candidates, totalling 45 fits\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Parallel(n_jobs=94)]: Using backend LokyBackend with 94 concurrent workers.\n",
|
||||
"[Parallel(n_jobs=94)]: Done 42 out of 45 | elapsed: 1056.5min remaining: 75.5min\n",
|
||||
"[Parallel(n_jobs=94)]: Done 45 out of 45 | elapsed: 1080.5min finished\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'C': 10.0, 'gamma': 10.0} 0.8256943024851795\n",
|
||||
"CPU times: user 2h 42min 9s, sys: 23.6 s, total: 2h 42min 33s\n",
|
||||
"Wall time: 20h 43min 1s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"C_range = np.logspace(1, 3,3)\n",
|
||||
"gamma_range = np.logspace(-1, 1, 3)\n",
|
||||
"param_grid = dict(gamma=gamma_range, C=C_range)\n",
|
||||
"clf = GridSearchCV(sklearn.svm.SVC(), \n",
|
||||
" param_grid,\n",
|
||||
" cv=5 , n_jobs=os.cpu_count()-2, verbose=1)\n",
|
||||
"clf.fit(dfY[features].values, dfY.InputMethod.values)\n",
|
||||
"print(clf.best_params_, clf.best_score_)\n",
|
||||
"\n",
|
||||
"dfT[\"InputMethodPred\"] = clf.predict(dfT[features].values)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'C': 10.0, 'gamma': 10.0} 0.8256943024851795\n",
|
||||
"[[7106 2315]\n",
|
||||
" [2944 6477]]\n",
|
||||
"Accuray: 0.72\n",
|
||||
"Recall: 0.69\n",
|
||||
"Precision: 0.66\n",
|
||||
"F1-Score: 0.71\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" Knuckle 0.71 0.75 0.73 9421\n",
|
||||
" Finger 0.74 0.69 0.71 9421\n",
|
||||
"\n",
|
||||
" micro avg 0.72 0.72 0.72 18842\n",
|
||||
" macro avg 0.72 0.72 0.72 18842\n",
|
||||
"weighted avg 0.72 0.72 0.72 18842\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(clf.best_params_, clf.best_score_)\n",
|
||||
"print(confusion_matrix(dfT.InputMethod.values, dfT.InputMethodPred.values, labels=[0, 1]))\n",
|
||||
"print(\"Accuray: %.2f\" % accuracy_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(\"Recall: %.2f\" % metrics.recall_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(\"Precision: %.2f\" % metrics.average_precision_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(\"F1-Score: %.2f\" % metrics.f1_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
|
||||
"print(classification_report(dfT.InputMethod.values, dfT.InputMethodPred.values, target_names=target_names))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
13240
python/Step_07_CNN.ipynb
Normal file
13240
python/Step_07_CNN.ipynb
Normal file
File diff suppressed because one or more lines are too long
430
python/Step_08_CNN-Report.ipynb
Normal file
430
python/Step_08_CNN-Report.ipynb
Normal file
File diff suppressed because one or more lines are too long
152
python/Step_09_LSTM_ReadData.ipynb
Normal file
152
python/Step_09_LSTM_ReadData.ipynb
Normal file
|
@ -0,0 +1,152 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Filtering the data for the LSTM: removes all the rows, where we used the revert button, when the participant performed a wrong gesture\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"from scipy.odr import *\n",
|
||||
"from scipy.stats import *\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import ast\n",
|
||||
"from multiprocessing import Pool, cpu_count\n",
|
||||
"\n",
|
||||
"import scipy\n",
|
||||
"\n",
|
||||
"from IPython import display\n",
|
||||
"from matplotlib.patches import Rectangle\n",
|
||||
"\n",
|
||||
"from sklearn.metrics import mean_squared_error\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"import scipy.stats as st\n",
|
||||
"from sklearn.metrics import r2_score\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from matplotlib import cm\n",
|
||||
"from mpl_toolkits.mplot3d import axes3d\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"import copy\n",
|
||||
"\n",
|
||||
"from sklearn.model_selection import LeaveOneOut, LeavePOut\n",
|
||||
"\n",
|
||||
"from multiprocessing import Pool\n",
|
||||
"import cv2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dfAll = pd.read_pickle(\"DataStudyCollection/AllData.pkl\")\n",
|
||||
"dfAll.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df_actual = dfAll[(dfAll.Actual_Data == True) & (dfAll.Is_Pause == False)]\n",
|
||||
"df_actual.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"all: %s, actual data: %s\" % (len(dfAll), len(df_actual)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"# filter out all gestures, where the revert button was pressed during the study and the gestrue was repeated\n",
|
||||
"def is_max(df):\n",
|
||||
" df_temp = df.copy(deep=True)\n",
|
||||
" max_version = df_temp.RepetitionID.max()\n",
|
||||
" df_temp[\"IsMax\"] = np.where(df_temp.RepetitionID == max_version, True, False)\n",
|
||||
" df_temp[\"MaxRepetition\"] = [max_version] * len(df_temp)\n",
|
||||
" return df_temp\n",
|
||||
"\n",
|
||||
"df_filtered = df_actual.copy(deep=True)\n",
|
||||
"df_grp = df_filtered.groupby([df_filtered.userID, df_filtered.TaskID, df_filtered.VersionID])\n",
|
||||
"pool = Pool(cpu_count() - 1)\n",
|
||||
"result_lst = pool.map(is_max, [grp for name, grp in df_grp])\n",
|
||||
"df_filtered = pd.concat(result_lst)\n",
|
||||
"df_filtered = df_filtered[df_filtered.IsMax == True]\n",
|
||||
"pool.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df_filtered.to_pickle(\"DataStudyCollection/df_lstm.pkl\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"actual: %s, filtered data: %s\" % (len(df_actual), len(df_filtered)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
1286
python/Step_10_LSTM_Preprocessing.ipynb
Normal file
1286
python/Step_10_LSTM_Preprocessing.ipynb
Normal file
File diff suppressed because one or more lines are too long
2550
python/Step_11_LSTM.ipynb
Normal file
2550
python/Step_11_LSTM.ipynb
Normal file
File diff suppressed because it is too large
Load diff
10107
python/Step_12_LSTM-WarmStart.ipynb
Normal file
10107
python/Step_12_LSTM-WarmStart.ipynb
Normal file
File diff suppressed because one or more lines are too long
1052
python/Step_13_LSTM-Report.ipynb
Normal file
1052
python/Step_13_LSTM-Report.ipynb
Normal file
File diff suppressed because it is too large
Load diff
427
python/Step_15_CNN-Ensemble-Report.ipynb
Normal file
427
python/Step_15_CNN-Ensemble-Report.ipynb
Normal file
File diff suppressed because one or more lines are too long
494
python/Step_20_Paper_Figure.ipynb
Normal file
494
python/Step_20_Paper_Figure.ipynb
Normal file
|
@ -0,0 +1,494 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = pd.read_pickle(\"DataStudyCollection/df_statistics.pkl\")"
|
||||
]
|
||||