Added notebooks and models

This commit is contained in:
mayersn 2019-08-07 17:57:12 -04:00
parent f4ce991790
commit 0fb6efadd5
28 changed files with 37002 additions and 0 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
python/Models/CNN.pb Normal file

Binary file not shown.

BIN
python/Models/LSTM.pb Normal file

Binary file not shown.

View File

@ -0,0 +1,340 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from scipy.odr import *\n",
"from scipy.stats import *\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"import time\n",
"import matplotlib.pyplot as plt\n",
"from multiprocessing import Pool"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def cast_to_int(row):\n",
" try:\n",
" return np.array([a if float(a) >= 0 else 0 for a in row[2:-1]], dtype=np.uint8)\n",
" except Exception as e:\n",
" return None\n",
" \n",
"def load_csv(file):\n",
" temp_df = pd.read_csv(file, header=None, names = [\"UserID\", \"Age\", \"Gender\"], delimiter=\";\")\n",
" return temp_df"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 298 ms, sys: 443 ms, total: 741 ms\n",
"Wall time: 937 ms\n"
]
}
],
"source": [
"%%time\n",
"pool = Pool(os.cpu_count() - 2)\n",
"data_files = [\"DataStudyCollection/%s\" % file for file in os.listdir(\"DataStudyCollection\") if file.endswith(\".csv\") and \"userData\" in file]\n",
"df_lst = pool.map(load_csv, data_files)\n",
"dfAll = pd.concat(df_lst)\n",
"dfAll = dfAll.sort_values(\"UserID\")\n",
"dfAll = dfAll.reset_index(drop=True)\n",
"pool.close()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"24.166666666666668"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfAll.Age.mean()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.4245742398014511"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfAll.Age.std()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"21"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfAll.Age.min()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"26"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfAll.Age.max()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>UserID</th>\n",
" <th>Age</th>\n",
" <th>Gender</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>23</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>24</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>25</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>25</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>26</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>6</td>\n",
" <td>23</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>7</td>\n",
" <td>21</td>\n",
" <td>female</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>8</td>\n",
" <td>24</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>9</td>\n",
" <td>24</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>10</td>\n",
" <td>24</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>11</td>\n",
" <td>25</td>\n",
" <td>female</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>12</td>\n",
" <td>26</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>13</td>\n",
" <td>22</td>\n",
" <td>female</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>14</td>\n",
" <td>24</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>15</td>\n",
" <td>24</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>16</td>\n",
" <td>26</td>\n",
" <td>female</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>17</td>\n",
" <td>26</td>\n",
" <td>male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>18</td>\n",
" <td>23</td>\n",
" <td>male</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" UserID Age Gender\n",
"0 1 23 male\n",
"1 2 24 male\n",
"2 3 25 male\n",
"3 4 25 male\n",
"4 5 26 male\n",
"5 6 23 male\n",
"6 7 21 female\n",
"7 8 24 male\n",
"8 9 24 male\n",
"9 10 24 male\n",
"10 11 25 female\n",
"11 12 26 male\n",
"12 13 22 female\n",
"13 14 24 male\n",
"14 15 24 male\n",
"15 16 26 female\n",
"16 17 26 male\n",
"17 18 23 male"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfAll"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,338 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## This notebook creates one dataframe from all participants data\n",
"## It also removes 1% of the data as this is corrupted"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from scipy.odr import *\n",
"from scipy.stats import *\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"import time\n",
"import matplotlib.pyplot as plt\n",
"import ast\n",
"from multiprocessing import Pool, cpu_count\n",
"\n",
"import scipy\n",
"\n",
"from IPython import display\n",
"from matplotlib.patches import Rectangle\n",
"\n",
"from sklearn.metrics import mean_squared_error\n",
"import json\n",
"\n",
"import scipy.stats as st\n",
"from sklearn.metrics import r2_score\n",
"\n",
"\n",
"from matplotlib import cm\n",
"from mpl_toolkits.mplot3d import axes3d\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import copy\n",
"\n",
"from sklearn.model_selection import LeaveOneOut, LeavePOut\n",
"\n",
"from multiprocessing import Pool"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def cast_to_int(row):\n",
" try:\n",
" return np.array([a if float(a) >= 0 else 0 for a in row[2:-1]], dtype=np.uint8)\n",
" except Exception as e:\n",
" return None\n",
" \n",
"def load_csv(file):\n",
" temp_df = pd.read_csv(file, delimiter=\";\")\n",
" temp_df.Image = temp_df.Image.str.split(',')\n",
" temp_df.Image = temp_df.Image.apply(cast_to_int)\n",
" return temp_df"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['DataStudyCollection/17_studyData.csv', 'DataStudyCollection/2_studyData.csv', 'DataStudyCollection/12_studyData.csv', 'DataStudyCollection/15_studyData.csv', 'DataStudyCollection/5_studyData.csv', 'DataStudyCollection/1_studyData.csv', 'DataStudyCollection/14_studyData.csv', 'DataStudyCollection/10_studyData.csv', 'DataStudyCollection/13_studyData.csv', 'DataStudyCollection/18_studyData.csv', 'DataStudyCollection/6_studyData.csv', 'DataStudyCollection/16_studyData.csv', 'DataStudyCollection/3_studyData.csv', 'DataStudyCollection/7_studyData.csv', 'DataStudyCollection/8_studyData.csv', 'DataStudyCollection/9_studyData.csv', 'DataStudyCollection/11_studyData.csv', 'DataStudyCollection/4_studyData.csv']\n",
"CPU times: user 1.86 s, sys: 1.03 s, total: 2.89 s\n",
"Wall time: 17.3 s\n"
]
}
],
"source": [
"%%time\n",
"pool = Pool(cpu_count() - 2)\n",
"data_files = [\"DataStudyCollection/%s\" % file for file in os.listdir(\"DataStudyCollection\") if file.endswith(\".csv\") and \"studyData\" in file]\n",
"print(data_files)\n",
"df_lst = pool.map(load_csv, data_files)\n",
"dfAll = pd.concat(df_lst)\n",
"pool.close()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1010014"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = dfAll[dfAll.Image.notnull()]\n",
"len(df)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loaded 1013841 values\n",
"removed 3827 values (thats 0.377%)\n",
"new df has size 1010014\n"
]
}
],
"source": [
"print(\"loaded %s values\" % len(dfAll))\n",
"print(\"removed %s values (thats %s%%)\" % (len(dfAll) - len(df), round((len(dfAll) - len(df)) / len(dfAll) * 100, 3)))\n",
"print(\"new df has size %s\" % len(df))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"df = df.reset_index(drop=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>userID</th>\n",
" <th>Timestamp</th>\n",
" <th>Current_Task</th>\n",
" <th>Task_amount</th>\n",
" <th>TaskID</th>\n",
" <th>VersionID</th>\n",
" <th>RepetitionID</th>\n",
" <th>Actual_Data</th>\n",
" <th>Is_Pause</th>\n",
" <th>Image</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>17</td>\n",
" <td>1547138602677</td>\n",
" <td>0</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>[1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>17</td>\n",
" <td>1547138602697</td>\n",
" <td>0</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>[1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>17</td>\n",
" <td>1547138602796</td>\n",
" <td>0</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>[1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>17</td>\n",
" <td>1547138602817</td>\n",
" <td>0</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>[1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>17</td>\n",
" <td>1547138602863</td>\n",
" <td>0</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>[1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" userID Timestamp Current_Task Task_amount TaskID VersionID \\\n",
"0 17 1547138602677 0 34 0 0 \n",
"1 17 1547138602697 0 34 0 0 \n",
"2 17 1547138602796 0 34 0 0 \n",
"3 17 1547138602817 0 34 0 0 \n",
"4 17 1547138602863 0 34 0 0 \n",
"\n",
" RepetitionID Actual_Data Is_Pause \\\n",
"0 0 False False \n",
"1 0 False False \n",
"2 0 False False \n",
"3 0 False False \n",
"4 0 False False \n",
"\n",
" Image \n",
"0 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ... \n",
"1 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ... \n",
"2 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ... \n",
"3 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ... \n",
"4 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, ... "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"df.to_pickle(\"DataStudyCollection/AllData.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sorted(df.userID.unique())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,890 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"## USE for Multi GPU Systems\n",
"#import os\n",
"#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
"\n",
"%matplotlib inline\n",
"\n",
"from scipy.odr import *\n",
"from scipy.stats import *\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"import time\n",
"import matplotlib.pyplot as plt\n",
"import ast\n",
"from multiprocessing import Pool\n",
"\n",
"import scipy\n",
"\n",
"from IPython import display\n",
"from matplotlib.patches import Rectangle\n",
"\n",
"from sklearn.metrics import mean_squared_error\n",
"import json\n",
"\n",
"import scipy.stats as st\n",
"from sklearn.metrics import r2_score\n",
"\n",
"\n",
"from matplotlib import cm\n",
"from mpl_toolkits.mplot3d import axes3d\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Ellipse\n",
"\n",
"import copy\n",
"\n",
"from sklearn.model_selection import LeaveOneOut, LeavePOut\n",
"\n",
"from multiprocessing import Pool\n",
"import cv2\n",
"\n",
"import sklearn\n",
"import random\n",
"from sklearn import neighbors\n",
"from sklearn import svm\n",
"from sklearn import tree\n",
"from sklearn import ensemble\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.metrics import classification_report\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import math\n",
"\n",
"# Importing matplotlib to plot images.\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"%matplotlib inline\n",
"\n",
"# Importing SK-learn to calculate precision and recall\n",
"import sklearn\n",
"from sklearn import metrics\n",
"from sklearn.model_selection import train_test_split, cross_val_score, LeaveOneGroupOut\n",
"from sklearn.utils import shuffle\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.metrics.pairwise import euclidean_distances\n",
"from sklearn.metrics import confusion_matrix\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"import pickle as pkl\n",
"import h5py\n",
"\n",
"from pathlib import Path\n",
"import os.path\n",
"import sys\n",
"import datetime\n",
"import time\n",
"\n",
"import skimage\n",
"\n",
"target_names = [\"Knuckle\", \"Finger\"]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from skimage import measure\n",
"from skimage.measure import find_contours, approximate_polygon, \\\n",
" subdivide_polygon, EllipseModel, LineModelND"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def getEllipseParams(img):\n",
" points = np.argwhere(img > 40)\n",
" \n",
" contours = skimage.measure.find_contours(img, 40)\n",
" points_to_approx = []\n",
" highest_val = 0\n",
" for n, contour in enumerate(contours):\n",
" if (len(contour) > highest_val):\n",
" points_to_approx = contour\n",
" highest_val = len(contour) \n",
" \n",
" try:\n",
" contour = np.fliplr(points_to_approx)\n",
" except Exception as inst:\n",
" return [-1, -1, -1, -1, -1]\n",
" \n",
"\n",
" ellipse = skimage.measure.fit.EllipseModel()\n",
" ellipse.estimate(contour)\n",
" try:\n",
" xc, yc, a, b, theta = ellipse.params \n",
" except Exception as int:\n",
" return [-1, -1, -1, -1, -1]\n",
" \n",
" return [xc, yc, a, b, theta]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1 2 9 6 4 14 17 16 12 3 10 18 5] [13 8 11 15 7]\n",
"13 : 5\n",
"0.7222222222222222 : 0.2777777777777778\n"
]
}
],
"source": [
"# the data, split between train and test sets\n",
"df = pd.read_pickle(\"DataStudyCollection/df_statistics.pkl\")\n",
"\n",
"lst = df.userID.unique()\n",
"np.random.seed(42)\n",
"np.random.shuffle(lst)\n",
"test_ids = lst[-5:]\n",
"train_ids = lst[:-5]\n",
"\n",
"df[\"Set\"] = \"Test\"\n",
"df.loc[df.userID.isin(train_ids), \"Set\"] = \"Train\"\n",
"print(train_ids, test_ids)\n",
"print(len(train_ids), \":\", len(test_ids))\n",
"print(len(train_ids) / len(lst), \":\", len(test_ids)/ len(lst))\n",
"\n",
"#df_train = df[df.userID.isin(train_ids)]\n",
"#df_test = df[df.userID.isin(test_ids) & (df.Version == \"Normal\")]\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.patches.Ellipse at 0x7ff60430b668>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAJ4AAAD8CAYAAACGuR0qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADetJREFUeJzt3XuQVPWZxvHvO8AwMAy3gCNBRbCQDV4yuoSoYArvqNkga0K8ZXHXCmrFXbWyu2VlqzS7+YdaNWx243pJQImV6LoaSjZFokhijHeZaBBULuE+AgPKZYIyzDDv/tFnKsNlunu6e/rt6Xk+VVPdfW79UvVwus85/TuvuTsixVYRXYD0TgqehFDwJISCJyEUPAmh4EkIBU9CKHgSQsGTEH2L+WaV1t+rqO58AbO061ufDP9PDrWlnZ3xGo2u4uStid273H1kpuXyCp6ZTQd+APQBfuzuc9MtX0U1X6y4uPPtVVamfb+KmkFp5/uf9qefnyFY3tycdr5k9oI/vSmb5XL+qDWzPsADwOXAROBaM5uY6/akd8nnO95kYJ27r3f3g8CTwIzClCXlLp/gjQa2dHi9NZl2GDObY2bLzWx5C/ook5RuP6p190fcfZK7T+pH/+5+O+kh8gleA3Bih9cnJNNEMsoneG8B481srJlVAtcAiwtTlpS7nE+nuHurmd0GPEfqdMoCd1+VdiWztKdMKgYOTL96htMp1jf9P6d1R2Pa+VI8eZ3Hc/clwJIC1SK9iC6ZSQgFT0IoeBKiqD8SKKR+lX353NljOHnCKCoqjLY2Z/PKzbxfv4HmT1uiy5MMelzwaoYO5Ou3XsSV151L1cCjj5BbDrbyxtKVPPEfv2T9Kp1WLFU9KnjnXnIa3773GqprBgDwx/caWP2HzTR/2kJlVT8mnHEC404bzdQr65h6ZR2/fuZNHvjOU3zSdCC4cjlSaQVv5PBOZ13x9S/yrXtm0Keigpc2bOTe3/2O9xp3HrbMwHpj5OBqbpz2l8w67/NcePVkTvnSeG6b/ywbd+7msw/uTfv2bQcU0GLpEQcXZ08Zzz/860z6VFQw7+VX+Ntnfn5U6Nrt3Lefexe/xNX3Pc77DY2MGTmM+bd+lVHDaopctaRT8sEbMqyaf5w7C4B5r7zKD19/I6v1Nu/aw+wf/g9vrtvCcUMG8dA3Z1IzLM2vn6WoSj541956IcNH1rDirfX8d5aha/fpwVbuePT/WLttF+NqP8Mt/35DN1UpXVXSwauuqWL6174AwIPfW0xbDmMimg40c/ujiznQ0srF10+lbpp+JF0KSjp4508/g6oBlbz96lo2rNme83a2fLSXh5e+DqC9Xoko6eCdOXkcAK++kP5HL9l47MV6Ptq+h7Gnn8iZ5/9F3tuT/JR08M6YNBaAFW9tyHtbrYfaWDL/1wB85ZZL8t6e5Keo5/EMsDRjZ33An69EVPbvy8hRQzl4sJVNDXvwAZWsveHBtNuvbz6Ydv4Pvnoj3/iXv2by5XX0H11LS3PrYfPbtupKR7GU7B6vZnDq6kTT3k8yjofN1kc79vLHVQ30r6rkc2efXJBtSm5KNniDaqoA2N9U2JFpq95aD8CEujEF3a50TckG75NPUh+bAwcVdmTahxtTVzxGjBpa0O1K15Rs8PZ8nLodxdACX23Ys7Mptd0R6cdvSPcq2eC1HGxl7+799O3Xh5HHDynYdv+071MABiXfISVGyQYPYM17HwJwWt1JBdtmdRK4/fqpVKiSDt47b6YOBL4wZXzBtlkzNDWEsmnPJwXbpnRdUc/jOelvFVbRcPi419eefIVv3nkZUy+cyEN3/oSz/+3WtNs/OCT9/fVOOrCasafWArBt/XZcv78LU9J7vA83NFL/m1VUDazk0uvOK8g2zzg3tfdc+dragmxPclPSwQNYPP9FAGb9/WUMqa7Ka1ujxx3HSaeO4tP9B1j7h80FqE5yVfLBe+P5d1nxyhqGjqjhjplfymtbV825EIDfLlpOa8uhQpQnOSr54AH85z/9lIPNLcw873S+ck5uv6ebcPwILrt+Cm1tbSx6eFmBK5Su6hHB27puB4/c/TQAd19/KeefPrZL6w/o15f7rrmCyv79+OXjL7Np9bbuKFO6oEcED+AXj/6W+c+9Sd8+Fcy7eQbXXXBWVusNGVjFj2+6mlNqP8Om1dt45O7/7eZKJRulNbwxg/969mUM+LvLJvPPX7uAS846lYeXvMbrHxx9oGAGl585gTsvm8LoYUP4cPc+vnfjQ7rLQImwYnboHmzDPV27gYpBGe5/V9kPgClfPovb77+BwcNTy2/ftIt3X1vD9i0f4w4jPzuMuvMnMGrMCADWvbuFe/7mIXatTn8kq3G1+XvBn65390mZlsu3z8VGoAk4BLRm84aF8Mov3ubtF9/nr26axsxbLub4MSM4PglZR7s+3M3j9y/hhafeoC1D8xUprrz2eEnwJrn7rmyWL9Qe77B1Koxxp5/AaeeMZ8iIwZgZ+3bv54P6Dax5ZxNtbX/+97Xt0Z0EultR9niloK3NWbdiC+tWbMH6HR1MKU35HtU68LyZ1ZvZnEIUJL1Dvnu8qe7eYGbHAUvN7AN3f6njAkkg5wBUkf7m2tJ75LXHc/eG5LERWESqzdSRy6jBihwlnyZ61WZW0/4cuBRYWajCpLzl81FbCyxKxsn2BX7m7r/KuFaao+i2pqY8ypGeJJ8GK+uBzxewFulFesy1WikvCp6EUPAkhIInIRQ8CaHgSQgFT0IoeBJCwZMQCp6EUPAkhIInIRQ8CaHgSQgFT0IoeBJCwZMQCp6EUPAkhIInIRQ8CaHgSQgFT0IoeBJCwZMQCp6EUPAkhIInIRQ8CaHgSQgFT0JkDJ6ZLTCzRjNb2WHacDNbamZrk8dh3VumlJts9niPAdOPmHYXsMzdxwPLktciWcsYvOQu7h8fMXkGsDB5vhC4qsB1SZnL9Tterbu3997cTup+yCJZy/vgwlM9qTq9o7aZzTGz5Wa2vIXmfN9OykSuwdthZqMAksfGzhZUnws5llyDtxiYnTyfDTxbmHKkt8jmdMoTwGvABDPbamY3AXOBS8xsLXBx8lokaxn7XLj7tZ3MuqjAtUgvoisXEkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQuTaYOW7ZtZgZu8kf1d0b5lSbnJtsAIwz93rkr8lhS1Lyl2uDVZE8pLPd7zbzGxF8lHcaS8z9bmQY8k1eA8CpwB1wDbg/s4WVJ8LOZacgufuO9z9kLu3AT8CJhe2LCl3OQWvvatPYiawsrNlRY4lY5+LpMHKNGCEmW0F7gGmmVkdqR5mG4Gbu7FGKUO5NliZ3w21SC+iKxcSQsGTEAqehFDwJISCJyEUPAmh4EkIBU9CKHgSQsGTEAqehFDwJISCJyEUPAmh4EkIBU9CKHgSQsGTEAqehFDwJISCJyEUPAmh4EkIBU9CKHgSQsGTEAqehFDwJISCJyEUPAmh4EmIbPpcnGhmvzGz98xslZndnkwfbmZLzWxt8tjpDbhFjpTNHq8V+La7TwTOAb5lZhOBu4Bl7j4eWJa8FslKNn0utrn775PnTcD7wGhgBrAwWWwhcFV3FSnlJ+OtaDsys5OBs4A3gFp335bM2g7UdrLOHGAOQBUDc61TykzWBxdmNgh4BrjD3fd1nOfuTupG3EdRnws5lqyCZ2b9SIXup+7+82Tyjva2A8ljY/eUKOUom6NaI3WX9/fd/fsdZi0GZifPZwPPFr48KVfZfMebAnwDeNfM3kmmfQeYCzxlZjcBm4BZ3VOilKNs+ly8DFgnsy8qbDnSW+jKhYRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EkLBkxAKnoRQ8CSEgichFDwJoeBJCAVPQih4EiKfPhffNbMGM3sn+bui+8uVcpHNHUHb+1z83sxqgHozW5rMm+fu93VfeVKusrkj6DZgW/K8ycza+1yI5KxL3/GO6HMBcJuZrTCzBWopJV2RT5+LB4FTgDpSe8T7O1lvjpktN7PlLTQXoGQpBzn3uXD3He5+yN3bgB8Bk4+1rhqsyLHk3OeivblKYiawsvDlSbnKp8/FtWZWR6qV1Ebg5m6pUMpSPn0ulhS+HOktdOVCQih4EkLBkxAKnoRQ8CSEgichFDwJYe5evDcz20mqqXK7EcCuohXQdaVeH5RejWPcfWSmhYoavKPe3Gy5u08KKyCDUq8PekaNx6KPWgmh4EmI6OA9Evz+mZR6fdAzajxK6Hc86b2i93jSS4UEz8ymm9lqM1tnZndF1JCJmW00s3eToZvLS6CeBWbWaGYrO0wbbmZLzWxt8thjxr0UPXhm1gd4ALgcmEjqB6UTi11Hli5w97oSOV3xGDD9iGl3AcvcfTywLHndI0Ts8SYD69x9vbsfBJ4EZgTU0aO4+0vAx0dMngEsTJ4vBK4qalF5iAjeaGBLh9dbKc1xug48b2b1ZjYnuphO1CbjngG2A7WRxXRFNmMuequp7t5gZscBS83sg2SvU5Lc3c2sx5yiiNjjNQAndnh9QjKtpLh7Q/LYCCyik+GbwXa0j/ZLHhuD68laRPDeAsab2VgzqwSuARYH1NEpM6tO7hODmVUDl1KawzcXA7OT57OBZwNr6ZKif9S6e6uZ3QY8B/QBFrj7qmLXkUEtsCg1pJi+wM/c/VeRBZnZE8A0YISZbQXuAeYCT5nZTaR+9TMrrsKu0ZULCaErFxJCwZMQCp6EUPAkhIInIRQ8CaHgSQgFT0L8Pyp94TKMVNfuAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1)\n",
"img = df.iloc[0].Blobs\n",
"xc, yc, a, b, theta = getEllipseParams(img)\n",
"ax.imshow(img)\n",
"e = Ellipse(xy=[xc,yc], width=a*2, height=b*2, angle=math.degrees(theta), fill=False, lw=2, edgecolor='w')\n",
"ax.add_artist(e)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"lst = df.Blobs.apply(lambda x: getEllipseParams(x))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"lst2 = np.vstack(lst.values)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(618012, 5)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lst2.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"df[\"XC\"] = lst2[:,0]\n",
"df[\"YC\"] = lst2[:,1]\n",
"df[\"EllipseW\"] = lst2[:,2]\n",
"df[\"EllipseH\"] = lst2[:,3]\n",
"df[\"EllipseTheta\"] = lst2[:,4]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"df[\"Area\"] = df[\"EllipseW\"] * df[\"EllipseH\"] * np.pi\n",
"df[\"AvgCapa\"] = df.Blobs.apply(lambda x: np.mean(x))\n",
"df[\"SumCapa\"] = df.Blobs.apply(lambda x: np.sum(x))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[8, 11, 6, 7, 16, 15, 14, 10, 9, 2, 3, 13, 17, 5, 12, 1, 4]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lst = list(range(1, df.userID.max()))\n",
"SEED = 42#448\n",
"random.seed(SEED)\n",
"random.shuffle(lst)\n",
"lst"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"dfY = df[df.Set == \"Train\"].copy(deep=True)\n",
"dfT = df[(df.Set == \"Test\") & (df.Version == \"Normal\")].copy(deep=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"minmax = min(len(dfY[dfY.Input == \"Finger\"]), len(dfY[dfY.Input == \"Knuckle\"]))\n",
"dfX = dfY[dfY.Input == \"Finger\"].sample(minmax)\n",
"dfZ = dfY[dfY.Input == \"Knuckle\"].sample(minmax)\n",
"dfY = pd.concat([dfX,dfZ])\n",
"\n",
"minmax = min(len(dfT[dfT.Input == \"Finger\"]), len(dfT[dfT.Input == \"Knuckle\"]))\n",
"dfX = dfT[dfT.Input == \"Finger\"].sample(minmax)\n",
"dfZ = dfT[dfT.Input == \"Knuckle\"].sample(minmax)\n",
"dfT = pd.concat([dfX,dfZ])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>userID</th>\n",
" <th>Timestamp</th>\n",
" <th>Current_Task</th>\n",
" <th>Task_amount</th>\n",
" <th>TaskID</th>\n",
" <th>VersionID</th>\n",
" <th>RepetitionID</th>\n",
" <th>Actual_Data</th>\n",
" <th>Is_Pause</th>\n",
" <th>Image</th>\n",
" <th>...</th>\n",
" <th>InputMethod</th>\n",
" <th>Set</th>\n",
" <th>XC</th>\n",
" <th>YC</th>\n",
" <th>EllipseW</th>\n",
" <th>EllipseH</th>\n",
" <th>EllipseTheta</th>\n",
" <th>Area</th>\n",
" <th>AvgCapa</th>\n",
" <th>SumCapa</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Input</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Finger</th>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>...</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Knuckle</th>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>...</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" <td>9421</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2 rows × 31 columns</p>\n",
"</div>"
],
"text/plain": [
" userID Timestamp Current_Task Task_amount TaskID VersionID \\\n",
"Input \n",
"Finger 9421 9421 9421 9421 9421 9421 \n",
"Knuckle 9421 9421 9421 9421 9421 9421 \n",
"\n",
" RepetitionID Actual_Data Is_Pause Image ... InputMethod Set \\\n",
"Input ... \n",
"Finger 9421 9421 9421 9421 ... 9421 9421 \n",
"Knuckle 9421 9421 9421 9421 ... 9421 9421 \n",
"\n",
" XC YC EllipseW EllipseH EllipseTheta Area AvgCapa SumCapa \n",
"Input \n",
"Finger 9421 9421 9421 9421 9421 9421 9421 9421 \n",
"Knuckle 9421 9421 9421 9421 9421 9421 9421 9421 \n",
"\n",
"[2 rows x 31 columns]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfT.groupby(\"Input\").count()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# FEATURE SET: sum of capacitance, avg of capacitance, ellipse area, ellipse width, height and theta."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"features = [\"SumCapa\", \"AvgCapa\", \"Area\", \"EllipseW\", \"EllipseH\", \"EllipseTheta\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ZeroR"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"dfT[\"InputMethodPred\"] = 1"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0 9421]\n",
" [ 0 9421]]\n",
"Accuray: 0.50\n",
"Recall: 0.50\n",
"Precision: 0.50\n",
"F1-Score: 0.33\n",
" precision recall f1-score support\n",
"\n",
" Knuckle 0.00 0.00 0.00 9421\n",
" Finger 0.50 1.00 0.67 9421\n",
"\n",
" micro avg 0.50 0.50 0.50 18842\n",
" macro avg 0.25 0.50 0.33 18842\n",
"weighted avg 0.25 0.50 0.33 18842\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/sklearn/metrics/classification.py:1143: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.\n",
" 'precision', 'predicted', average, warn_for)\n",
"/usr/local/lib/python3.6/dist-packages/sklearn/metrics/classification.py:1143: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n",
" 'precision', 'predicted', average, warn_for)\n"
]
}
],
"source": [
"print(confusion_matrix(dfT.InputMethod.values, dfT.InputMethodPred.values, labels=[0, 1]))\n",
"print(\"Accuray: %.2f\" % accuracy_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(\"Recall: %.2f\" % metrics.recall_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
"print(\"Precision: %.2f\" % metrics.average_precision_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
"print(\"F1-Score: %.2f\" % metrics.f1_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
"print(classification_report(dfT.InputMethod.values, dfT.InputMethodPred.values, target_names=target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DecisionTreeClassifier"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 240 candidates, totalling 1200 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=30)]: Using backend LokyBackend with 30 concurrent workers.\n",
"[Parallel(n_jobs=30)]: Done 140 tasks | elapsed: 10.4s\n",
"[Parallel(n_jobs=30)]: Done 390 tasks | elapsed: 31.4s\n",
"[Parallel(n_jobs=30)]: Done 740 tasks | elapsed: 1.3min\n",
"[Parallel(n_jobs=30)]: Done 1200 out of 1200 | elapsed: 2.4min finished\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'max_depth': 22, 'min_samples_split': 2} 0.8120637794585754\n",
"[[7409 2012]\n",
" [3096 6325]]\n",
"Accuray: 0.73\n",
"Recall: 0.73\n",
"Precision: 0.67\n",
"F1-Score: 0.73\n",
" precision recall f1-score support\n",
"\n",
" Knuckle 0.71 0.79 0.74 9421\n",
" Finger 0.76 0.67 0.71 9421\n",
"\n",
" micro avg 0.73 0.73 0.73 18842\n",
" macro avg 0.73 0.73 0.73 18842\n",
"weighted avg 0.73 0.73 0.73 18842\n",
"\n",
"CPU times: user 7.26 s, sys: 3.38 s, total: 10.6 s\n",
"Wall time: 2min 29s\n"
]
}
],
"source": [
"%%time\n",
"param_grid = {'max_depth': range(2,32,1),\n",
" 'min_samples_split':range(2,10,1)}\n",
"#TODO: Create Baseline for different ML stuff\n",
"clf = GridSearchCV(tree.DecisionTreeClassifier(), \n",
" param_grid,\n",
" cv=5 , n_jobs=os.cpu_count()-2, verbose=1)\n",
"clf.fit(dfY[features].values, dfY.InputMethod.values)\n",
"print(clf.best_params_, clf.best_score_)\n",
"dfT[\"InputMethodPred\"] = clf.predict(dfT[features].values) \n",
"\n",
"print(confusion_matrix(dfT.InputMethod.values, dfT.InputMethodPred.values, labels=[0, 1]))\n",
"print(\"Accuray: %.3f\" % accuracy_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(\"Recall: %.3f\" % metrics.recall_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
"print(\"Precision: %.3f\" % metrics.average_precision_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
"print(\"F1-Score: %.3f\" % metrics.f1_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
"print(classification_report(dfT.InputMethod.values, dfT.InputMethodPred.values, target_names=target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# RandomForestClassifier"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 180 candidates, totalling 900 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=94)]: Using backend LokyBackend with 94 concurrent workers.\n",
"[Parallel(n_jobs=94)]: Done 12 tasks | elapsed: 1.2min\n",
"[Parallel(n_jobs=94)]: Done 262 tasks | elapsed: 4.0min\n",
"[Parallel(n_jobs=94)]: Done 612 tasks | elapsed: 9.2min\n",
"[Parallel(n_jobs=94)]: Done 900 out of 900 | elapsed: 12.8min finished\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'max_depth': 60, 'n_estimators': 63} 0.8669582104371696\n",
"[[8175 1246]\n",
" [2765 6656]]\n",
"Accuray: 0.79\n",
"Recall: 0.71\n",
"Precision: 0.74\n",
"F1-Score: 0.77\n",
" precision recall f1-score support\n",
"\n",
" Knuckle 0.75 0.87 0.80 9421\n",
" Finger 0.84 0.71 0.77 9421\n",
"\n",
" micro avg 0.79 0.79 0.79 18842\n",
" macro avg 0.79 0.79 0.79 18842\n",
"weighted avg 0.79 0.79 0.79 18842\n",
"\n",
"CPU times: user 42.1 s, sys: 834 ms, total: 42.9 s\n",
"Wall time: 13min 28s\n"
]
}
],
"source": [
"%%time\n",
"param_grid = {'n_estimators': range(55,64,1),\n",
" 'max_depth': range(50,70,1)}\n",
"#TODO: Create Baseline for different ML stuff\n",
"clf = GridSearchCV(ensemble.RandomForestClassifier(), \n",
" param_grid,\n",
" cv=5 , n_jobs=os.cpu_count()-2, verbose=1)\n",
"clf.fit(dfY[features].values, dfY.InputMethod.values)\n",
"print(clf.best_params_, clf.best_score_)\n",
"dfT[\"InputMethodPred\"] = clf.predict(dfT[features].values) \n",
"\n",
"print(confusion_matrix(dfT.InputMethod.values, dfT.InputMethodPred.values, labels=[0, 1]))\n",
"print(\"Accuray: %.2f\" % accuracy_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(\"Recall: %.2f\" % metrics.recall_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(\"Precision: %.2f\" % metrics.average_precision_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(\"F1-Score: %.2f\" % metrics.f1_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(classification_report(dfT.InputMethod.values, dfT.InputMethodPred.values, target_names=target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# kNN"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 62 candidates, totalling 310 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=94)]: Using backend LokyBackend with 94 concurrent workers.\n",
"[Parallel(n_jobs=94)]: Done 12 tasks | elapsed: 17.7s\n",
"[Parallel(n_jobs=94)]: Done 310 out of 310 | elapsed: 1.5min finished\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'n_neighbors': 2} 0.800546827088748\n",
"[[8187 1234]\n",
" [4318 5103]]\n",
"Accuray: 0.71\n",
"Recall: 0.54\n",
"Precision: 0.67\n",
"F1-Score: 0.65\n",
" precision recall f1-score support\n",
"\n",
" Knuckle 0.65 0.87 0.75 9421\n",
" Finger 0.81 0.54 0.65 9421\n",
"\n",
" micro avg 0.71 0.71 0.71 18842\n",
" macro avg 0.73 0.71 0.70 18842\n",
"weighted avg 0.73 0.71 0.70 18842\n",
"\n",
"CPU times: user 1.74 s, sys: 300 ms, total: 2.04 s\n",
"Wall time: 1min 30s\n"
]
}
],
"source": [
"%%time\n",
"param_grid = {'n_neighbors': range(2,64,1),\n",
" #weights': ['uniform', 'distance']\n",
" }\n",
"#TODO: Create Baseline for different ML stuff\n",
"clf = GridSearchCV(neighbors.KNeighborsClassifier(),\n",
" param_grid,\n",
" cv=5 , n_jobs=os.cpu_count()-2, verbose=1)\n",
"clf.fit(dfY[features].values, dfY.InputMethod.values)\n",
"print(clf.best_params_, clf.best_score_)\n",
"dfT[\"InputMethodPred\"] = clf.predict(dfT[features].values) \n",
"\n",
"print(confusion_matrix(dfT.InputMethod.values, dfT.InputMethodPred.values, labels=[0, 1]))\n",
"print(\"Accuray: %.2f\" % accuracy_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(\"Recall: %.2f\" % metrics.recall_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
"print(\"Precision: %.2f\" % metrics.average_precision_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
"print(\"F1-Score: %.2f\" % metrics.f1_score(dfT.InputMethod.values, dfT.InputMethodPred.values, average=\"macro\"))\n",
"print(classification_report(dfT.InputMethod.values, dfT.InputMethodPred.values, target_names=target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SVM"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 9 candidates, totalling 45 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=94)]: Using backend LokyBackend with 94 concurrent workers.\n",
"[Parallel(n_jobs=94)]: Done 42 out of 45 | elapsed: 1056.5min remaining: 75.5min\n",
"[Parallel(n_jobs=94)]: Done 45 out of 45 | elapsed: 1080.5min finished\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'C': 10.0, 'gamma': 10.0} 0.8256943024851795\n",
"CPU times: user 2h 42min 9s, sys: 23.6 s, total: 2h 42min 33s\n",
"Wall time: 20h 43min 1s\n"
]
}
],
"source": [
"%%time\n",
"C_range = np.logspace(1, 3,3)\n",
"gamma_range = np.logspace(-1, 1, 3)\n",
"param_grid = dict(gamma=gamma_range, C=C_range)\n",
"clf = GridSearchCV(sklearn.svm.SVC(), \n",
" param_grid,\n",
" cv=5 , n_jobs=os.cpu_count()-2, verbose=1)\n",
"clf.fit(dfY[features].values, dfY.InputMethod.values)\n",
"print(clf.best_params_, clf.best_score_)\n",
"\n",
"dfT[\"InputMethodPred\"] = clf.predict(dfT[features].values)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'C': 10.0, 'gamma': 10.0} 0.8256943024851795\n",
"[[7106 2315]\n",
" [2944 6477]]\n",
"Accuray: 0.72\n",
"Recall: 0.69\n",
"Precision: 0.66\n",
"F1-Score: 0.71\n",
" precision recall f1-score support\n",
"\n",
" Knuckle 0.71 0.75 0.73 9421\n",
" Finger 0.74 0.69 0.71 9421\n",
"\n",
" micro avg 0.72 0.72 0.72 18842\n",
" macro avg 0.72 0.72 0.72 18842\n",
"weighted avg 0.72 0.72 0.72 18842\n",
"\n"
]
}
],
"source": [
"print(clf.best_params_, clf.best_score_)\n",
"print(confusion_matrix(dfT.InputMethod.values, dfT.InputMethodPred.values, labels=[0, 1]))\n",
"print(\"Accuray: %.2f\" % accuracy_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(\"Recall: %.2f\" % metrics.recall_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(\"Precision: %.2f\" % metrics.average_precision_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(\"F1-Score: %.2f\" % metrics.f1_score(dfT.InputMethod.values, dfT.InputMethodPred.values))\n",
"print(classification_report(dfT.InputMethod.values, dfT.InputMethodPred.values, target_names=target_names))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

13240
python/Step_07_CNN.ipynb Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Filtering the data for the LSTM: removes all the rows, where we used the revert button, when the participant performed a wrong gesture\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from scipy.odr import *\n",
"from scipy.stats import *\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"import time\n",
"import matplotlib.pyplot as plt\n",
"import ast\n",
"from multiprocessing import Pool, cpu_count\n",
"\n",
"import scipy\n",
"\n",
"from IPython import display\n",
"from matplotlib.patches import Rectangle\n",
"\n",
"from sklearn.metrics import mean_squared_error\n",
"import json\n",
"\n",
"import scipy.stats as st\n",
"from sklearn.metrics import r2_score\n",
"\n",
"\n",
"from matplotlib import cm\n",
"from mpl_toolkits.mplot3d import axes3d\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import copy\n",
"\n",
"from sklearn.model_selection import LeaveOneOut, LeavePOut\n",
"\n",
"from multiprocessing import Pool\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dfAll = pd.read_pickle(\"DataStudyCollection/AllData.pkl\")\n",
"dfAll.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_actual = dfAll[(dfAll.Actual_Data == True) & (dfAll.Is_Pause == False)]\n",
"df_actual.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"all: %s, actual data: %s\" % (len(dfAll), len(df_actual)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"# filter out all gestures, where the revert button was pressed during the study and the gestrue was repeated\n",
"def is_max(df):\n",
" df_temp = df.copy(deep=True)\n",
" max_version = df_temp.RepetitionID.max()\n",
" df_temp[\"IsMax\"] = np.where(df_temp.RepetitionID == max_version, True, False)\n",
" df_temp[\"MaxRepetition\"] = [max_version] * len(df_temp)\n",
" return df_temp\n",
"\n",
"df_filtered = df_actual.copy(deep=True)\n",
"df_grp = df_filtered.groupby([df_filtered.userID, df_filtered.TaskID, df_filtered.VersionID])\n",
"pool = Pool(cpu_count() - 1)\n",
"result_lst = pool.map(is_max, [grp for name, grp in df_grp])\n",
"df_filtered = pd.concat(result_lst)\n",
"df_filtered = df_filtered[df_filtered.IsMax == True]\n",
"pool.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_filtered.to_pickle(\"DataStudyCollection/df_lstm.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"actual: %s, filtered data: %s\" % (len(df_actual), len(df_filtered)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

2550
python/Step_11_LSTM.ipynb Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,494 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_pickle(\"DataStudyCollection/df_statistics.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAA8lJREFUeJzt2D1qVGEYhmFHJ/4VCppOrCyt7azF3hXYuAB7EdcjQgo7N+AKBMHCysJmCiMJYnKsD2M1zOFLbq5rAS9PcW4+OKtpmq4AHVdHDwD2S9QQI2qIETXEiBpi1kscfX7/1YX6pX765NHoCVvOr61GT5i5/fnb6AkzZ5vN6AkX3qfz9//9iLzUECNqiBE1xIgaYkQNMaKGGFFDjKghRtQQI2qIETXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTGihhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDjKghRtQQI2qIETXErJc4urpxfYmzO7v35vvoCVueHX4ZPWHm6MXT0RPmNpvRCy4tLzXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTGihhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDjKghRtQQI2qIETXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTGihpj1IlcPDhY5u6t3D49GT9jy+Pqt0RNmPty5OXoCe+KlhhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDjKghRtQQI2qIETXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTGihhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDjKghZr3E0enkZImzO3v59vXoCVv+3F2NnjDz4OeP0RNm/o4ecIl5qSFG1BAjaogRNcSIGmJEDTGihhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDjKghRtQQI2qIETXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTGihhhRQ4yoIUbUECNqiFkvcXQ6/r3E2Z0dfvw6esK21Wr0gpmzX8ejJ7AnXmqIETXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTGihhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDjKghRtQQI2qIETXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTHrJY6en54ucXZ3F20PLMhLDTGihhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDjKghRtQQI2qIETXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTGihhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDzGqaptEbgD3yUkOMqCFG1BAjaogRNcSIGmJEDTGihhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDjKghRtQQI2qIETXEiBpiRA0x/wB+HS+RtgmxWAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAA91JREFUeJzt2DFqVGEYhlEnDsSAChZaWFiK2LkJSzegpZXrcA8uwspaa8FObLRUG9s4KCSGybW+jEUIuf76cE45xccLw8MPdzVN0yWgY2/0AOBiiRpiRA0xooYYUUPMeomjD/cf/1Of1PcOroyesGt/f/SCuePj0Qtmtj9+jp6w63Q7esHM69OXqz/97qWGGFFDjKghRtQQI2qIETXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTGihhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDjKghRtQQI2qIETXEiBpiRA0xooYYUUOMqCFmvcTRabtd4uy5nd69M3rCjs+Pro2eMHPz/enoCTPX3nwcPWHHdrMZPeFMvNQQI2qIETXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTGihhhRQ4yoIUbUECNqiBE1xIgaYkQNMaKGGFFDjKghRtQQI2qIETXEiBpiRA0xooYYUUOMqCFG1BAjaogRNcSIGmLWSxxdXb68xNlzO7x3dfSEHZ+evhg9Yeb+2yejJ8xcf/fv/WeXNpvRC87ESw0xooYYUUOMqCFG1BAjaogRNcSIGmJEDTGihhhR