{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "## USE for Multi GPU Systems\n", "#import os\n", "#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import math\n", "\n", "import tensorflow as tf\n", "\n", "%matplotlib inline\n", "\n", "# Importing SK-learn to calculate precision and recall\n", "import sklearn\n", "from sklearn import metrics\n", "\n", "target_names = [\"Knuckle\", \"Finger\"]" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 1 2 9 6 4 14 17 16 12 3 10 18 5] [13 8 11 15 7]\n" ] } ], "source": [ "df = pd.read_pickle(\"DataStudyCollection/df_statistics.pkl\")\n", "\n", "lst = df.userID.unique()\n", "np.random.seed(42)\n", "np.random.shuffle(lst)\n", "test_ids = lst[-5:]\n", "train_ids = lst[:-5]\n", "print(train_ids, test_ids)\n", "df[\"GestureId\"] = df.TaskID % 17\n", "\n", "#df_train = dfAll[dfAll.userID.isin(train_ids)]\n", "#df_test = dfAll[dfAll.userID.isin(test_ids)]\n", "\n", "x = np.concatenate(df.Blobs.values).reshape(-1,27,15,1)\n", "x = x / 255.0\n", "\n", "# convert class vectors to binary class matrices (one-hot notation)\n", "num_classes = len(df.TaskID.unique())\n", "y = tf.keras.utils.to_categorical(df.TaskID, num_classes)\n", "\n", "labels = sorted(df.TaskID.unique())" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "x = np.stack(df.Blobs)\n", "x = x.reshape(-1, 27, 15, 1)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# convert class vectors to binary class matrices (one-hot notation)\n", "num_classes = 2\n", "y = tf.keras.utils.to_categorical(df.InputMethod, num_classes)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Label for image 1 is: [1. 0.]')" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAALEAAAEICAYAAAAQmxXMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAD1tJREFUeJzt3X2wXHV9x/H35yaBQIiFAGYgPIRmUp2UGeIMgm3BQhEE+hD8oymM0GhpY1XG2vpQdNoGqVXGKVUZLaNAIEWBQShD1BQIGRjaKRUCA5oUNBASSBoSIASChIck3/5xflc2l717d++evXu/8HnN3Nlz9jx9d+9nf3vO2bP7U0RgltlAvwsw65ZDbOk5xJaeQ2zpOcSWnkNs6fUsxJLulvTndS+rytWSnpd0X3dVgqQjJL0kaUK36xov6npMkq6R9JqkdTWV1sm2f6M8hl0j5WjEEEtaJ+kD9ZXXtROAU4HDIuK4blcWEU9GxH4Rsav70npH0smS7pL0wkihqvkxfS0iZjbUMV/Sf0t6WdLdna5M0l9LelrSi5IWS9q72XwR8YuI2A/4z5HWmXF34khgXUT8stMFJU3sQT1j5ZfAYuBzfa5jK/AN4JJOF5T0QeBC4BSq/+OvA1/qtqBRh1jSAZJ+JOmZ8tb+I0mHDZltlqT7yqvuVknTGpZ/X3lFb5P0sKST2tjm+cCVwG+Vt5ovlfv/QtJjkrZKWirp0IZlQtInJa0B1jRZ58wyz8QyfrekL5faXpL0Q0kHSvp+eRz3S5rZsPw3JT1Vpj0g6cSGaftIWlKen0ckfV7Shobph0q6uTyHT0j61HCPPSLui4hrgbVtPE9DH9NHJK2VtL1s58MjraNFHXdGxI3A/41i8QXAVRGxOiKeB/4R+MhoaxnUTUs8AFxN9Yo6AtgBfGvIPH8K/BlwCLATuAxA0gzgx8CXgWnAZ4GbJR3caoMRcRXwl8C95e1ykaTfA74KzC/bWQ/cMGTRs4DjgTltPrazgfOAGcAs4N7yWKcBjwCLGua9H5hbpl0H/EDS5DJtETCTqsU5FTh3cCFJA8APgYfLdk4BPl1aq9pImkL1vJ8REVOB3wYeKtOOKI3IEXVus4XfpHq8gx4Gpks6sJuVjjrEEfFcRNwcES9HxHbgn4DfHTLbtRGxqrz1/z0wvxxsnAssi4hlEbE7IpYDK4EzR1HKh4HFEfFgRLwKfIGqpZ7ZMM9XI2JrROxoc51XR8TjEfEC8B/A46UF2gn8AHjP4IwR8b3yXOyMiEuBvYF3lcnzga9ExPMRsYHyIi7eCxwcERdHxGsRsRa4guoFVLfdwNGS9omITRGxutT+ZETsHxFP9mCbzewHvNAwPjg8tZuVdrM7sa+k70haL+lF4B5g/yFHxE81DK8HJgEHUbXef1xagW2StlEdsB0yilIOLesGICJeAp6jat2a1dGOzQ3DO5qM7zc4IumzZVfhhfI4fo3qMQ7W1rjtxuEjgUOHPAdfBKZ3WGtLpQH5E6p3sE2Sfizp3XVuowMvAe9oGB8c3t7NSrvZnfgMVYtzfES8A3h/uV8N8xzeMHwE8DrwLNU/89rSCgz+TYmIjg8WqPbNjhwcKW+fBwIbG+bpyaV6Zf/381Qt7gERsT9V6zL4HGwCGo8TGp+Pp4AnhjwHUyNiNO9GLUXE7RFxKlUj8ShVi98Pq4FjGsaPATZHxHPdrLTdEE+SNLnhbyLVW8AOYFs5YFvUZLlzJc2RtC9wMXBTOe3zPeAPJX1Q0oSyzpOaHBi243rgo5LmltM1XwF+EhHrRrGuTk2l2td/Bpgo6R/Ys6W5EfhCOQieAVzQMO0+YLukvy0HgBMkHS3pvc02JGmg7GtPqkY1WdJeIxUoabqkeeXF/SpVa7h7NA+2rG9CqWMiMFDqmNTm4v8GnF8ysT/wd8A1o61lULshXkYV2MG/i6hOs+xD1bL+D3Bbk+WuLUU+DUwGPgUQEU8B86jePp+hapU+10E9vxIRd1Ltb99M1fLNojf7lc3cTvW4f0G1S/MKe+4yXAxsAJ4A7gRuogoS5cX8B1QHhU9QPY9XUu2ONPN+qud+GW8cSN/RRo0DwN9QvWNtpTpu+Tjs8aFIJwd255VtXw6cWIZ/1bKX9Z3YbMGIuA34GnAX8CTVc7aoYdnVozlzIl8UP3YkfRw4OyKGHgCPS5KuAM6hesufNcbbnk115mcv4BMRcc2w8zrEvSPpEKrTa/cCs6lOK34rIr7R18LeYjJ/gpXBXsB3gKOAbVTnr/+1rxW9BbkltvQyXjthtocx353YS3vHZKYMP4M0/DRAE0a6urD1O0vsHNcXq70lbOf5ZyOi5SUEdaolxJJOB74JTACubPWhxWSmcPyE04Zf16TWJQ3sP9wZqGJX65Du2rqt9fK7HfJu3Rk3rR95rvp0vTtRPmb+NnAG1QU250hq90Ibs67VsU98HPBYRKyNiNeojsDn1bBes7bUEeIZ7Pkp1Qb2vPgGSQslrZS08vXqAyuz2ozJ2YmI+G5EHBsRx06i6bdRzEatjhBvZM+rsw5jzyvIzHqqjhDfD8yWdFS5qupsYGkN6zVrS9en2CJip6QLqK7omkD1LYvVrRca/krAgb1H2N2Y2uIcM8D2Eb4/6lNobzm1nCeOiGVUlwiajTl/7GzpOcSWnkNs6TnElp5DbOk5xJbemF9PrIEBBvbZZ/gZZrT+7ZDHzz2o5fRJ21tfj3z4ZS+1nL775ZdbTrfxxy2xpecQW3oOsaXnEFt6DrGl5xBbeg6xpTfufsYqRvhdiZ9/9PKu1v/7N/xRy+m7143Vj6ZbXdwSW3oOsaXnEFt6DrGl5xBbeg6xpecQW3rj7jzxwPMvtpx+9GWfaDn95UNa92717h0jdo1sybgltvQcYkvPIbb0HGJLzyG29BxiS88htvTG/Dxx7N7N7leG77dDr7zScvkjb9jQ3fZ3tF6/5VNXP3brgO3ALmBnRBxbx3rN2lFnS3xyRDxb4/rM2uJ9YkuvrhAHcIekByQtrGmdZm2pa3fihIjYKOmdwHJJj0bEPYMTS7AXAkxm35o2aVappSWOiI3ldgtwC1VXuY3T3Rmj9UwdHZRPkTR1cBg4DVjV7XrN2lXH7sR04BZJg+u7LiJua7lEi77kdj23tfXWRppubzt1dMa4FjimhlrMRsWn2Cw9h9jSc4gtPYfY0nOILT2H2NJziC09h9jSc4gtPYfY0nOILT2H2NJziC09h9jSc4gtPYfY0nOILT2H2NJziC09h9jSc4gtPYfY0nOILT2H2NJziC09h9jSc4gtPYfY0nOILT2H2NJziC09h9jSazvEkhZL2iJpVcN90yQtl7Sm3B7QmzLNhtdJS3wNcPqQ+y4EVkTEbGBFGTcbU22HuHTpNbTDjHnAkjK8BDirprrM2tZtnx3TI2JTGX6aqhOaN3E/dtZLtR3YRURQ9SzabJr7sbOe6TbEmyUdAlBut3Rfkllnug3xUmBBGV4A3Nrl+sw61skptuuBe4F3Sdog6XzgEuBUSWuAD5RxszHV9oFdRJwzzKRTaqrFbFT8iZ2l5xBbeg6xpecQW3oOsaXnEFt6DrGl5xBbeg6xpecQW3oOsaXnEFt6DrGl5xBbeg6xpecQW3oOsaXnEFt6DrGl5xBbeg6xpecQW3oOsaXnEFt6DrGl5xBbeg6xpecQW3oOsaXnEFt6DrGl120/dhdJ2ijpofJ3Zm/KNBtet/3YAXw9IuaWv2X1lGXWvm77sTPruzr2iS+Q9NOyu9G0W1xJCyWtlLTydV6tYZNmb+g2xJcDs4C5wCbg0mYzuR8766WuQhwRmyNiV0TsBq4AjqunLLP2dRXiwY4Yiw8Bq4ab16xX2u4CrPRjdxJwkKQNwCLgJElzqbrDXQd8rAc1mrXUbT92V9VYi9mo+BM7S88htvQcYkvPIbb0HGJLzyG29BxiS88htvQcYkvPIbb0HGJLzyG29BxiS88htvQcYkvPIbb0HGJLzyG29BxiS88htvQcYkvPIbb0HGJLzyG29BxiS88htvQcYkvPIbb0HGJLzyG29BxiS6+TfuwOl3SXpP+VtFrSX5X7p0laLmlNuW3a+YxZr3TSEu8EPhMRc4D3AZ+UNAe4EFgREbOBFWXcbMx00o/dpoh4sAxvBx4BZgDzgCVltiXAWXUXadZK290dNJI0E3gP8BNgekRsKpOeBqY3mX8hsBBgMvuOZpNmw+r4wE7SfsDNwKcj4sXGaRERVJ3QMOR+92NnPdNRiCVNogrw9yPi38vdmwe7Aiu3W+ot0ay1Ts5OiKq3pEci4l8aJi0FFpThBcCt9ZVnNrJO9ol/BzgP+Jmkh8p9XwQuAW6UdD6wHphfb4lmrXXSj91/ARpm8in1lGPWOX9iZ+k5xJaeQ2zpOcSWnkNs6TnElp5DbOk5xJaeQ2zpOcSWnkNs6TnElp5DbOk5xJaeQ2zpOcSWnkNs6TnElp5DbOk5xJaeQ2zpOcSWnkNs6TnElp5DbOk5xJaeQ2zpOcSWnkNs6TnElp5DbOk5xJZeWyFu0RHjRZI2Snqo/J3Z23LN3qzdX4of7IjxQUlTgQckLS/Tvh4R/9yb8sxG1laISz91m8rwdkmDHTGa9d1o+rGbyRsdMQJcIOmnkhYP16+zpIWSVkpa+TqvjrpYs2Y67cduaEeMlwOzgLlULfWlzZZzZ4zWS530Y/emjhgjYnNE7IqI3cAVwHG9KdNseO2enWjaEeNgT6LFh4BV9ZZnNrJ2z04M1xHjOZLmUvXnvA74WO0Vmo2g3bMTw3XEuKzecsw650/sLD2H2NJziC09h9jSc4gtPYfY0lNEjO0GpWeA9UPuPgh4dkwL6Yzr68yREXHwWG1szEPctAhpZUQc2+86huP6xjfvTlh6DrGlN15C/N1+FzAC1zeOjYt9YrNujJeW2GzUHGJLr68hlnS6pJ9LekzShf2spRlJ6yT9rPwcwcp+1wNQvsu4RdKqhvumSVouaU25bfpdx7eqvoVY0gTg28AZwByqC+zn9KueFk6OiLnj6DzsNcDpQ+67EFgREbOBFWX8baOfLfFxwGMRsTYiXgNuAOb1sZ4UIuIeYOuQu+cBS8rwEuCsMS2qz/oZ4hnAUw3jGxh/v2URwB2SHpC0sN/FtDC9/DYIwNPA9H4WM9ba/Y7d29UJEbFR0juB5ZIeLS3huBURIeltdd60ny3xRuDwhvHDyn3jRkRsLLdbgFsYvz9JsHnwm+fldkuf6xlT/Qzx/cBsSUdJ2gs4G1jax3r2IGlK+d05JE0BTmP8/iTBUmBBGV4A3NrHWsZc33YnImKnpAuA24EJwOKIWN2vepqYDtxS/eQGE4HrIuK2/pYEkq4HTgIOkrQBWARcAtwo6Xyqy1zn96/CseePnS09f2Jn6TnElp5DbOk5xJaeQ2zpOcSWnkNs6f0/kZhtd/D2o3sAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "i = 1\n", "plt.imshow(x[i].reshape(27, 15)) #np.sqrt(784) = 28\n", "plt.title(\"Label for image %i is: %s\" % (i, y[i]))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# If GPU is not available: \n", "# GPU_USE = '/cpu:0'\n", "#config = tf.ConfigProto(device_count = {\"GPU\": 1})\n", "\n", "\n", "# If GPU is available: \n", "config = tf.ConfigProto()\n", "config.log_device_placement = True\n", "config.allow_soft_placement = True\n", "config.gpu_options.allow_growth=True\n", "config.gpu_options.allocator_type = 'BFC'\n", "\n", "# Limit the maximum memory used\n", "config.gpu_options.per_process_gpu_memory_fraction = 0.4\n", "\n", "# set session config\n", "tf.keras.backend.set_session(tf.Session(config=config))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "loadpath = \"./ModelSnapshots/CNN-33767.h5\"\n", "model = tf.keras.models.load_model(loadpath)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1min 45s, sys: 12 s, total: 1min 57s\n", "Wall time: 1min 12s\n" ] } ], "source": [ "%%time\n", "lst = []\n", "batch = 100\n", "for i in range(0, len(x), batch):\n", " _x = x[i: i+batch]\n", " lst.extend(model.predict(_x))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "df[\"InputMethodPred\"] = lst\n", "df.InputMethodPred = df.InputMethodPred.apply(lambda x: np.argmax(x))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "df = df.groupby([\"userID\", \"TaskID\", \"VersionID\"])[[\"InputMethodPred\", \"InputMethod\"]].agg(lambda x: x.tolist()).reset_index()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from collections import Counter" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "df.InputMethod = df.InputMethod.apply(lambda x: Counter(x).most_common()[0][0])\n", "df.InputMethodPred = df.InputMethodPred.apply(lambda x: Counter(x).most_common()[0][0])" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
userIDTaskIDVersionID
InputMethodInputMethodPred
00442544254425
1124124124
10838383
1593559355935
\n", "
" ], "text/plain": [ " userID TaskID VersionID\n", "InputMethod InputMethodPred \n", "0 0 4425 4425 4425\n", " 1 124 124 124\n", "1 0 83 83 83\n", " 1 5935 5935 5935" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.groupby([\"InputMethod\", \"InputMethodPred\"]).count()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "df_train = df[df.userID.isin(train_ids)]\n", "df_test = df[df.userID.isin(test_ids)]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[3260 82]\n", " [ 47 4274]]\n", "[[0.97546379 0.02453621]\n", " [0.01087711 0.98912289]]\n", "Accuray: 0.983\n", "Recall: 0.989\n", "Precision: 0.977\n", "F1-Score: 0.985\n", " precision recall f1-score support\n", "\n", " Knuckle 0.99 0.98 0.98 3342\n", " Finger 0.98 0.99 0.99 4321\n", "\n", " micro avg 0.98 0.98 0.98 7663\n", " macro avg 0.98 0.98 0.98 7663\n", "weighted avg 0.98 0.98 0.98 7663\n", "\n" ] } ], "source": [ "print(sklearn.metrics.confusion_matrix(df_train.InputMethod.values, df_train.InputMethodPred.values, labels=[0, 1]))\n", "cm = sklearn.metrics.confusion_matrix(df_train.InputMethod.values, df_train.InputMethodPred.values, labels=[0, 1], )\n", "cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n", "print(cm)\n", "print(\"Accuray: %.3f\" % sklearn.metrics.accuracy_score(df_train.InputMethod.values, df_train.InputMethodPred.values))\n", "print(\"Recall: %.3f\" % metrics.recall_score(df_train.InputMethod.values, df_train.InputMethodPred.values))\n", "print(\"Precision: %.3f\" % metrics.average_precision_score(df_train.InputMethod.values, df_train.InputMethodPred.values))\n", "print(\"F1-Score: %.3f\" % metrics.f1_score(df_train.InputMethod.values, df_train.InputMethodPred.values))\n", "print(sklearn.metrics.classification_report(df_train.InputMethod.values, df_train.InputMethodPred.values, target_names=target_names))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1165 42]\n", " [ 36 1661]]\n", "[[0.96520298 0.03479702]\n", " [0.02121391 0.97878609]]\n", "Accuray: 0.973\n", "Recall: 0.979\n", "Precision: 0.967\n", "F1-Score: 0.977\n", " precision recall f1-score support\n", "\n", " Knuckle 0.97 0.97 0.97 1207\n", " Finger 0.98 0.98 0.98 1697\n", "\n", " micro avg 0.97 0.97 0.97 2904\n", " macro avg 0.97 0.97 0.97 2904\n", "weighted avg 0.97 0.97 0.97 2904\n", "\n" ] } ], "source": [ "print(sklearn.metrics.confusion_matrix(df_test.InputMethod.values, df_test.InputMethodPred.values, labels=[0, 1]))\n", "cm = sklearn.metrics.confusion_matrix(df_test.InputMethod.values, df_test.InputMethodPred.values, labels=[0, 1], )\n", "cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n", "print(cm)\n", "print(\"Accuray: %.3f\" % sklearn.metrics.accuracy_score(df_test.InputMethod.values, df_test.InputMethodPred.values))\n", "print(\"Recall: %.3f\" % metrics.recall_score(df_test.InputMethod.values, df_test.InputMethodPred.values))\n", "print(\"Precision: %.3f\" % metrics.average_precision_score(df_test.InputMethod.values, df_test.InputMethodPred.values))\n", "print(\"F1-Score: %.3f\" % metrics.f1_score(df_test.InputMethod.values, df_test.InputMethodPred.values))\n", "print(sklearn.metrics.classification_report(df_test.InputMethod.values, df_test.InputMethodPred.values, target_names=target_names))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }