knuckletouch/python/Step_08_CNN-Report.ipynb

431 lines
17 KiB
Plaintext
Raw Permalink Normal View History

2019-08-07 23:57:12 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n",
"/usr/local/lib/python3.6/dist-packages/requests/__init__.py:91: RequestsDependencyWarning: urllib3 (1.25.2) or chardet (3.0.4) doesn't match a supported version!\n",
" RequestsDependencyWarning)\n"
]
}
],
"source": [
"import keras\n",
"from keras.models import load_model\n",
"from keras import utils\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import math\n",
"\n",
"import tensorflow as tf\n",
"\n",
"# Importing matplotlib to plot images.\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"%matplotlib inline\n",
"\n",
"# Importing SK-learn to calculate precision and recall\n",
"import sklearn\n",
"from sklearn import metrics\n",
"from sklearn.model_selection import train_test_split, cross_val_score, LeaveOneGroupOut\n",
"from sklearn.utils import shuffle \n",
"\n",
"# Used for graph export\n",
"from tensorflow.python.framework import graph_util\n",
"from tensorflow.python.framework import graph_io\n",
"from keras import backend as K\n",
"from keras import regularizers\n",
"\n",
"import pickle as pkl\n",
"import h5py\n",
"\n",
"from pathlib import Path\n",
"import os.path\n",
"import sys\n",
"import datetime\n",
"import time\n",
"\n",
"# Display names for the two classes, in label order: 0 = Knuckle, 1 = Finger.\n",
"target_names = [\"Knuckle\", \"Finger\"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1 2 9 6 4 14 17 16 12 3 10 18 5] [13 8 11 15 7]\n",
"13 : 5\n",
"0.7222222222222222 : 0.2777777777777778\n",
"503886\n"
]
}
],
"source": [
"# Load the blob data set and split the participants into a train and a test group.\n",
"df = pd.read_pickle(\"DataStudyCollection/df_blobs_area.pkl\")\n",
"\n",
"# Seeded in-place shuffle, so the same participants land in each group on every run.\n",
"user_ids = df.userID.unique()\n",
"np.random.seed(42)\n",
"np.random.shuffle(user_ids)\n",
"n_test_users = 5\n",
"train_ids = user_ids[:-n_test_users]\n",
"test_ids = user_ids[-n_test_users:]\n",
"print(train_ids, test_ids)\n",
"print(len(train_ids), \":\", len(test_ids))\n",
"print(len(train_ids) / len(user_ids), \":\", len(test_ids)/ len(user_ids))\n",
"\n",
"# Keep every sample of the training participants, but only the \"Normal\" version\n",
"# samples of the test participants. `&` binds tighter than `|`, so this is what the\n",
"# combined expression evaluates to; the two masks just make it explicit.\n",
"is_train_user = df.userID.isin(train_ids)\n",
"is_normal_test = df.userID.isin(test_ids) & (df.Version == \"Normal\")\n",
"df = df[is_train_user | is_normal_test]\n",
"print(len(df))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Stack the per-sample blobs into one array and shape it as (samples, 27, 15, 1).\n",
"x = np.vstack(df.Blobs).reshape(-1, 27, 15, 1)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# One-hot encode the InputMethod labels: one column per class.\n",
"num_classes = 2\n",
"y = utils.to_categorical(df.InputMethod, num_classes=num_classes)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Label for image 1 is: [1. 0.]')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALEAAAEICAYAAAAQmxXMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAD1tJREFUeJzt3X2wXHV9x/H35yaBQIiFAGYgPIRmUp2UGeIMgm3BQhEE+hD8oymM0GhpY1XG2vpQdNoGqVXGKVUZLaNAIEWBQShD1BQIGRjaKRUCA5oUNBASSBoSIASChIck3/5xflc2l717d++evXu/8HnN3Nlz9jx9d+9nf3vO2bP7U0RgltlAvwsw65ZDbOk5xJaeQ2zpOcSWnkNs6fUsxJLulvTndS+rytWSnpd0X3dVgqQjJL0kaUK36xov6npMkq6R9JqkdTWV1sm2f6M8hl0j5WjEEEtaJ+kD9ZXXtROAU4HDIuK4blcWEU9GxH4Rsav70npH0smS7pL0wkihqvkxfS0iZjbUMV/Sf0t6WdLdna5M0l9LelrSi5IWS9q72XwR8YuI2A/4z5HWmXF34khgXUT8stMFJU3sQT1j5ZfAYuBzfa5jK/AN4JJOF5T0QeBC4BSq/+OvA1/qtqBRh1jSAZJ+JOmZ8tb+I0mHDZltlqT7yqvuVknTGpZ/X3lFb5P0sKST2tjm+cCVwG+Vt5ovlfv/QtJjkrZKWirp0IZlQtInJa0B1jRZ58wyz8QyfrekL5faXpL0Q0kHSvp+eRz3S5rZsPw3JT1Vpj0g6cSGaftIWlKen0ckfV7Shobph0q6uTyHT0j61HCPPSLui4hrgbVtPE9DH9NHJK2VtL1s58MjraNFHXdGxI3A/41i8QXAVRGxOiKeB/4R+MhoaxnUTUs8AFxN9Yo6AtgBfGvIPH8K/BlwCLATuAxA0gzgx8CXgWnAZ4GbJR3caoMRcRXwl8C95e1ykaTfA74KzC/bWQ/cMGTRs4DjgTltPrazgfOAGcAs4N7yWKcBjwCLGua9H5hbpl0H/EDS5DJtETCTqsU5FTh3cCFJA8APgYfLdk4BPl1aq9pImkL1vJ8REVOB3wYeKtOOKI3IEXVus4XfpHq8gx4Gpks6sJuVjjrEEfFcRNwcES9HxHbgn4DfHTLbtRGxqrz1/z0wvxxsnAssi4hlEbE7IpYDK4EzR1HKh4HFEfFgRLwKfIGqpZ7ZMM9XI2JrROxoc51XR8TjEfEC8B/A46UF2gn8AHjP4IwR8b3yXOyMiEuBvYF3lcnzga9ExPMRsYHyIi7eCxwcERdHxGsRsRa4guoFVLfdwNGS9omITRGxutT+ZETsHxFP9mCbzewHvNAwPjg8tZuVdrM7sa+k70haL+lF4B5g/yFHxE81DK8HJgEHUbXef1xagW2StlEdsB0yilIOLesGICJeAp6jat2a1dGOzQ3DO5qM7zc4IumzZVfhhfI4fo3qMQ7W1rjtxuEjgUOHPAdfBKZ3WGtLpQH5E6p3sE2Sfizp3XVuowMvAe9oGB8c3t7NSrvZnfgMVYtzfES8A3h/uV8N8xzeMHwE8DrwLNU/89rSCgz+TYmIjg8WqPbNjhwcKW+fBwIbG+bpyaV6Zf/381Qt7gERsT9V6zL4HGwCGo8TGp+Pp4AnhjwHUyNiNO9GLUXE7RFxKlUj8ShVi98Pq4FjGsaPATZHxHPdrLTdEE+SNLnhbyLVW8AOYFs5YFvUZLlzJc2RtC9wMXBTOe3zPeAPJX1Q0oSyzpOaHBi243rgo5LmltM1XwF+EhHrRrGuTk2l2td/Bpgo6R/Ys6W5EfhCOQieAVzQMO0+YLukvy0HgBMkHS3pvc02JGmg7GtPqkY1WdJeIxUoabqkeeXF/SpVa7h7NA+2rG9CqWMiMFDqmNTm4v8GnF8ysT/wd8A1o61lULshXkYV2MG/i6hOs+xD1bL+D3Bbk+WuLUU+DUwGPg
UQEU8B86jePp+hapU+10E9vxIRd1Ltb99M1fLNojf7lc3cTvW4f0G1S/MKe+4yXAxsAJ4A7gRuogoS5cX8B1QHhU9QPY9XUu2ONPN+qud+GW8cSN/RRo0DwN9QvWNtpTpu+Tjs8aFIJwd255VtXw6cWIZ/1bKX9Z3YbMGIuA34GnAX8CTVc7aoYdnVozlzIl8UP3YkfRw4OyKGHgCPS5KuAM6hesufNcbbnk115mcv4BMRcc2w8zrEvSPpEKrTa/cCs6lOK34rIr7R18LeYjJ/gpXBXsB3gKOAbVTnr/+1rxW9BbkltvQyXjthtocx353YS3vHZKYMP4M0/DRAE0a6urD1O0vsHNcXq70lbOf5ZyOi5SUEdaolxJJOB74JTACubPWhxWSmcPyE04Zf16TWJQ3sP9wZqGJX65Du2rqt9fK7HfJu3Rk3rR95rvp0vTtRPmb+NnAG1QU250hq90Ibs67VsU98HPBYRKyNiNeojsDn1bBes7bUEeIZ7Pkp1Qb2vPgGSQslrZS08vXqAyuz2ozJ2YmI+G5EHBsRx06i6bdRzEatjhBvZM+rsw5jzyvIzHqqjhDfD8yWdFS5qupsYGkN6zVrS9en2CJip6QLqK7omkD1LYvVrRca/krAgb1H2N2Y2uIcM8D2Eb4/6lNobzm1nCeOiGVUlwiajTl/7GzpOcSWnkNs6TnElp5DbOk5xJbemF9PrIEBBvbZZ/gZZrT+7ZDHzz2o5fRJ21tfj3z4ZS+1nL775ZdbTrfxxy2xpecQW3oOsaXnEFt6DrGl5xBbeg6xpTfufsYqRvhdiZ9/9PKu1v/7N/xRy+m7143Vj6ZbXdwSW3oOsaXnEFt6DrGl5xBbeg6xpecQW3rj7jzxwPMvtpx+9GWfaDn95UNa92717h0jdo1sybgltvQcYkvPIbb0HGJLzyG29BxiS88htvTG/Dxx7N7N7leG77dDr7zScvkjb9jQ3fZ3tF6/5VNXP3brgO3ALmBnRBxbx3rN2lFnS3xyRDxb4/rM2uJ9YkuvrhAHcIekByQtrGmdZm2pa3fihIjYKOmdwHJJj0bEPYMTS7AXAkxm35o2aVappSWOiI3ldgtwC1VXuY3T3Rmj9UwdHZRPkTR1cBg4DVjV7XrN2lXH7sR04BZJg+u7LiJua7lEi77kdj23tfXWRppubzt1dMa4FjimhlrMRsWn2Cw9h9jSc4gtPYfY0nOILT2H2NJziC09h9jSc4gtPYfY0nOILT2H2NJziC09h9jSc4gtPYfY0nOILT2H2NJziC09h9jSc4gtPYfY0nOILT2H2NJziC09h9jSc4gtPYfY0nOILT2H2NJziC09h9jSazvEkhZL2iJpVcN90yQtl7Sm3B7QmzLNhtdJS3wNcPqQ+y4EVkTEbGBFGTcbU22HuHTpNbTDjHnAkjK8BDirprrM2tZtnx3TI2JTGX6aqhOaN3E/dtZLtR3YRURQ9SzabJr7sbOe6TbEmyUdAlBut3Rfkllnug3xUmBBGV4A3Nrl+sw61skptuuBe4F3Sdog6XzgEuBUSWuAD5RxszHV9oFdRJwzzKRTaqrFbFT8iZ2l5xBbeg6xpecQW3oOsaXnEFt6DrGl5xBbeg6xpecQW3oOsaXnEFt6DrGl5xBbeg6xpecQW3oOsaXnEFt6DrGl5xBbeg6xpecQW3oOsaXnEFt6DrGl5xBbeg6xpecQW3oOsaXnEFt6DrGl120/dhdJ2ijpofJ3Zm/KNBtet/3YAXw9IuaWv2X1lGXWvm77sTPruzr2iS+Q9NOyu9G0W1xJCyWtlLTydV6tYZNmb+g2xJcDs4C5wCbg0mYzuR8766WuQhwRmyNiV0TsBq4AjqunLLP2dRXiwY4Yiw8Bq4ab16xX2u4CrPRjdxJwkKQNwCLgJElzqbrDXQd8rAc1mrXUbT92V9VYi9mo+BM7S88htvQcYkvPIbb0HGJLzyG29BxiS88htvQcYkvPIbb0HGJLzyG29BxiS88htv
QcYkvPIbb0HGJLzyG29BxiS88htvQcYkvPIbb0HGJLzyG29BxiS88htvQcYkvPIbb0HGJLzyG29BxiS6+TfuwOl3S
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Sanity check: render one blob image together with its one-hot label.\n",
"sample_idx = 1\n",
"plt.imshow(x[sample_idx].reshape(27, 15))  # drop the channel axis for plotting\n",
"# Trailing semicolon keeps the Text(...) repr out of the cell output.\n",
"plt.title(\"Label for image %i is: %s\" % (sample_idx, y[sample_idx]));"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# If GPU is not available: \n",
"# GPU_USE = '/cpu:0'\n",
"#config = tf.ConfigProto(device_count = {\"GPU\": 1})\n",
"\n",
"\n",
"# If GPU is available: \n",
"config = tf.ConfigProto()\n",
"config.log_device_placement = True\n",
"config.allow_soft_placement = True\n",
"config.gpu_options.allow_growth=True\n",
"config.gpu_options.allocator_type = 'BFC'\n",
"\n",
"# Limit the maximum memory used\n",
"config.gpu_options.per_process_gpu_memory_fraction = 0.4\n",
"\n",
"# Register the session with the standalone Keras backend (`K`). The model is loaded\n",
"# through `keras.models.load_model` and the session is later fetched with\n",
"# `K.get_session()`, so the config has to be set on that backend; `tf.keras.backend`\n",
"# keeps its own, separate session.\n",
"K.set_session(tf.Session(config=config))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Load the saved CNN snapshot that is evaluated and exported in the cells below.\n",
"loadpath = \"./ModelSnapshots/CNN-33767.h5\"\n",
"model = load_model(loadpath)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1min 28s, sys: 9.52 s, total: 1min 37s\n",
"Wall time: 1min\n"
]
}
],
"source": [
"%%time\n",
"# Run inference over all samples and let Keras do the batching.\n",
"# The result is kept as a list with one score vector per sample, which is the\n",
"# form the following cell expects.\n",
"batch = 100\n",
"lst = list(model.predict(x, batch_size=batch))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Turn the per-sample class scores into a predicted class index in one vectorized\n",
"# step (argmax over the class axis), instead of storing the score arrays in an\n",
"# object column and reducing them row by row.\n",
"df[\"InputMethodPred\"] = np.argmax(np.asarray(lst), axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Split the frame (now including predictions) back into train and test participants.\n",
"in_train = df.userID.isin(train_ids)\n",
"in_test = df.userID.isin(test_ids)\n",
"is_normal = df.Version == \"Normal\"\n",
"df_train = df[in_train]\n",
"df_test = df[in_test & is_normal]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[124207 12765]\n",
" [ 6596 322276]]\n",
"[[0.90680577 0.09319423]\n",
" [0.02005644 0.97994356]]\n",
"Accuray: 0.958\n",
"Recall: 0.943\n",
"Precision: 0.957\n",
" precision recall f1-score support\n",
"\n",
" Knuckle 0.95 0.91 0.93 136972\n",
" Finger 0.96 0.98 0.97 328872\n",
"\n",
" micro avg 0.96 0.96 0.96 465844\n",
" macro avg 0.96 0.94 0.95 465844\n",
"weighted avg 0.96 0.96 0.96 465844\n",
"\n"
]
}
],
"source": [
"# Evaluation on the training participants.\n",
"y_true = df_train.InputMethod.values\n",
"y_pred = df_train.InputMethodPred.values\n",
"\n",
"# Confusion matrix (rows: true class, columns: predicted class), first as raw\n",
"# counts, then normalized per row.\n",
"cm = metrics.confusion_matrix(y_true, y_pred, labels=[0, 1])\n",
"print(cm)\n",
"cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
"print(cm)\n",
"print(\"Accuracy: %.3f\" % metrics.accuracy_score(y_true, y_pred))\n",
"print(\"Recall: %.3f\" % metrics.recall_score(y_true, y_pred, average=\"macro\"))\n",
"# precision_score, not average_precision_score: the latter is the area under the\n",
"# precision-recall curve and expects scores rather than hard class predictions.\n",
"print(\"Precision: %.3f\" % metrics.precision_score(y_true, y_pred, average=\"macro\"))\n",
"#print(\"F1-Score: %.3f\" % metrics.f1_score(y_true, y_pred, average=\"macro\"))\n",
"print(metrics.classification_report(y_true, y_pred, target_names=target_names))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 8384 1037]\n",
" [ 1028 27593]]\n",
"[[0.88992676 0.11007324]\n",
" [0.03591768 0.96408232]]\n",
"Accuray: 0.946\n",
"Recall: 0.927\n",
"Precision: 0.956\n",
" precision recall f1-score support\n",
"\n",
" Knuckle 0.89 0.89 0.89 9421\n",
" Finger 0.96 0.96 0.96 28621\n",
"\n",
" micro avg 0.95 0.95 0.95 38042\n",
" macro avg 0.93 0.93 0.93 38042\n",
"weighted avg 0.95 0.95 0.95 38042\n",
"\n"
]
}
],
"source": [
"# Evaluation on the held-out test participants.\n",
"y_true = df_test.InputMethod.values\n",
"y_pred = df_test.InputMethodPred.values\n",
"\n",
"# Confusion matrix (rows: true class, columns: predicted class), first as raw\n",
"# counts, then normalized per row.\n",
"cm = metrics.confusion_matrix(y_true, y_pred, labels=[0, 1])\n",
"print(cm)\n",
"cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
"print(cm)\n",
"print(\"Accuracy: %.3f\" % metrics.accuracy_score(y_true, y_pred))\n",
"print(\"Recall: %.3f\" % metrics.recall_score(y_true, y_pred, average=\"macro\"))\n",
"# precision_score, not average_precision_score: the latter is the area under the\n",
"# precision-recall curve and expects scores rather than hard class predictions.\n",
"print(\"Precision: %.3f\" % metrics.precision_score(y_true, y_pred, average=\"macro\"))\n",
"#print(\"F1-Score: %.3f\" % metrics.f1_score(y_true, y_pred, average=\"macro\"))\n",
"print(metrics.classification_report(y_true, y_pred, target_names=target_names))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Export"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"output nodes names are: ['output_node0']\n"
]
}
],
"source": [
"# Wrap each model output in an identity op with a known name, so the output\n",
"# nodes can be referred to by name when the graph is frozen.\n",
"output_node_prefix = \"output_node\"\n",
"num_output = 1\n",
"pred = []\n",
"pred_node_names = []\n",
"for idx in range(num_output):\n",
"    node_name = output_node_prefix + str(idx)\n",
"    pred_node_names.append(node_name)\n",
"    pred.append(tf.identity(model.outputs[idx], name=node_name))\n",
"print('output nodes names are: ', pred_node_names)\n",
"output_node_prefix = pred_node_names[0]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<tf.Tensor 'conv2d_1_input:0' shape=(?, 27, 15, 1) dtype=float32>]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the model's input tensors (name, shape and dtype).\n",
"model.inputs"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Grab the Keras backend session; it is handed to the graph freezing step below.\n",
"sess = K.get_session()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Directory and file name for the frozen graph written in the next cell.\n",
"output_path = \"./Models/\"\n",
"output_file = \"CNN.pb\""
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Froze 30 variables.\n",
"INFO:tensorflow:Converted 30 variables to const ops.\n",
"Saved the freezed graph at: ./Models/CNN.pb\n"
]
}
],
"source": [
"# Freeze the graph: convert the session's variables to constants, with the named\n",
"# output nodes as the graph outputs. graph_util / graph_io are already imported in\n",
"# the first cell, so they are not re-imported here.\n",
"constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)\n",
"\n",
"# Write the frozen graph as a binary (non-text) protobuf file.\n",
"graph_io.write_graph(constant_graph, output_path, output_file, as_text=False)\n",
"\n",
"# Join the path parts properly so the message stays correct even if output_path\n",
"# is given without a trailing slash.\n",
"print('Saved the freezed graph at: ', os.path.join(output_path, output_file))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}