knuckletouch/python/Step_38_LSTM-Report.ipynb

379 lines
54 KiB
Plaintext
Raw Permalink Normal View History

2019-08-07 23:57:12 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/requests/__init__.py:91: RequestsDependencyWarning: urllib3 (1.25.2) or chardet (3.0.4) doesn't match a supported version!\n",
" RequestsDependencyWarning)\n"
]
},
{
"data": {
"text/plain": [
"'1.13.1'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"## USE for Multi GPU Systems\n",
"#import os\n",
"#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import math\n",
"\n",
"import tensorflow as tf\n",
"\n",
"%matplotlib inline\n",
"\n",
"# Importing SK-learn to calculate precision and recall\n",
"import sklearn\n",
"from sklearn import metrics\n",
"\n",
"from sklearn.utils.multiclass import unique_labels\n",
"\n",
"# Raw gesture identifiers. Assumed to be in TaskID order (index == TaskID % 17) -- TODO confirm.\n",
"task_keys = [\"tap\", \"twotap\", \"swipeleft\", \"swiperight\", \"swipeup\", \"swipedown\", \"twoswipeup\", \"twoswipedown\", \"circle\", \"arrowheadleft\", \"arrowheadright\", \"checkmark\", \"flashlight\", \"l\", \"lmirrored\", \"screenshot\", \"rotate\"]\n",
"\n",
"# Display names for reports and plots, in the same order as task_keys.\n",
"# Raw strings keep the LaTeX backslashes from being read as (invalid) Python escapes.\n",
"target_names = [\"Tap\", \"Two tap\", \"Swipe left\", \"Swipe right\", \"Swipe up\", \"Swipe down\",\n",
"                \"Two swipe up\", \"Two swipe down\", \"Circle\", \"Arrowhead left\", \"Arrowhead right\",\n",
"                r\"$\\checkmark$\", r\"$\\Gamma$\", \"L\", \"L mirrored\", \"S\", \"Rotate\"]\n",
"assert len(target_names) == len(task_keys) == 17, \"one display name per gesture class expected\"\n",
"\n",
"tf.__version__"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Number of gesture classes; TaskIDs are folded into 0..NUM_CLASSES-1 below.\n",
"NUM_CLASSES = 17\n",
"\n",
"# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.\n",
"df = pd.read_pickle(\"DataStudyEvaluation/df_lstm_norm50.pkl\")\n",
"\n",
"# Fold TaskIDs into class indices (modulo is idempotent, so re-running this cell is safe).\n",
"df.TaskID = df.TaskID % NUM_CLASSES\n",
"\n",
"# Stack all blobs into an array of shape (samples, 50, 27, 15, 1).\n",
"x = np.concatenate(df.Blobs.values).reshape(-1, 50, 27, 15, 1)\n",
"# Scale to [0, 1]; assumes 8-bit blob values -- TODO confirm.\n",
"x = x / 255.0\n",
"\n",
"# convert class vectors to binary class matrices (one-hot notation).\n",
"# Use the fixed class count so a class missing from this data cannot shrink the encoding.\n",
"num_classes = NUM_CLASSES\n",
"y = tf.keras.utils.to_categorical(df.TaskID, num_classes)\n",
"\n",
"# Evaluate over every class index (aligned with target_names), not only those present in the data.\n",
"labels = list(range(num_classes))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# TF 1.x session configuration for GPU use.\n",
"# For CPU-only execution, build the config instead with:\n",
"#     tf.ConfigProto(device_count={\"GPU\": 0})\n",
"\n",
"gpu_options = tf.GPUOptions(\n",
"    allow_growth=True,                    # allocate GPU memory on demand\n",
"    allocator_type='BFC',\n",
"    per_process_gpu_memory_fraction=0.3,  # upper bound: 30% of GPU memory for this process\n",
")\n",
"config = tf.ConfigProto(\n",
"    log_device_placement=True,            # log which device each op is placed on\n",
"    allow_soft_placement=True,            # let TF place ops on another device when needed\n",
"    gpu_options=gpu_options,\n",
")\n",
"\n",
"# Register the configured session as the Keras backend session.\n",
"sess = tf.Session(config=config)\n",
"tf.keras.backend.set_session(sess)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Colocations handled automatically by placer.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.cast instead.\n"
]
}
],
"source": [
"# Saved Keras model snapshot to evaluate.\n",
"MODEL_PATH = './ModelSnapshots/LSTM-v2-00398.h5'\n",
"model = tf.keras.models.load_model(MODEL_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 30.4 s, sys: 3.28 s, total: 33.7 s\n",
"Wall time: 19.6 s\n"
]
}
],
"source": [
"%%time\n",
"# Run inference in chunks of `batch` samples; lst gets one output row per sample.\n",
"batch = 100\n",
"lst = [\n",
"    row\n",
"    for start in range(0, len(x), batch)\n",
"    for row in model.predict(x[start:start + batch])\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Reduce each model output row to the index of its highest score (the predicted class).\n",
"df[\"TaskIDPred\"] = [np.argmax(scores) for scores in lst]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate on all rows of df. This is an alias, not a copy: mutating df_eval also mutates df.\n",
"df_eval = df"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[221 2 0 0 2 1 0 0 1 0 0 0 2 0 0 0 43]\n",
" [ 3 297 0 0 0 1 1 0 0 0 0 0 0 0 0 0 13]\n",
" [ 3 1 239 3 0 0 0 0 0 1 0 0 0 0 24 0 1]\n",
" [ 2 1 1 244 0 0 0 5 1 0 0 3 5 5 3 0 4]\n",
" [ 2 0 0 2 222 0 0 0 0 0 0 1 27 0 0 0 6]\n",
" [ 5 0 0 0 0 246 0 1 1 0 0 2 4 21 0 2 3]\n",
" [ 0 3 0 0 4 0 306 1 0 0 0 0 1 0 1 0 0]\n",
" [ 0 6 0 0 1 12 2 306 0 0 0 0 6 1 0 0 0]\n",
" [ 0 0 2 0 0 0 0 0 273 9 0 10 2 1 1 0 0]\n",
" [ 1 0 4 1 0 0 0 0 11 249 2 4 0 9 0 8 1]\n",
" [ 0 0 0 6 0 0 0 0 0 2 267 1 0 0 2 14 0]\n",
" [ 1 0 0 1 4 0 0 1 19 1 0 247 4 0 1 0 0]\n",
" [ 1 0 2 3 18 1 0 0 4 7 0 0 239 7 0 8 8]\n",
" [ 0 0 0 3 1 6 0 0 5 5 0 7 2 272 0 0 0]\n",
" [ 1 0 1 0 0 6 0 0 0 0 5 3 0 3 278 3 4]\n",
" [ 0 0 8 0 0 1 0 0 6 10 5 0 3 1 21 250 1]\n",
" [ 15 0 0 0 0 0 1 0 0 0 0 0 18 0 1 0 312]]\n",
"[[0.8 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.2]\n",
" [0. 0.9 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]\n",
" [0. 0. 0.9 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.1 0. 0. ]\n",
" [0. 0. 0. 0.9 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0.9 0. 0. 0. 0. 0. 0. 0. 0.1 0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0. 0.9 0. 0. 0. 0. 0. 0. 0. 0.1 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0. 0. 0. 0.9 0. 0. 0. 0. 0. 0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0.9 0. 0. 0. 0. 0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.9 0. 0. 0. 0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.9 0. 0. 0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0.1 0. 0. 0.9 0. 0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0.1 0. 0. 0. 0. 0. 0. 0. 0.8 0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.9 0. 0. 0. ]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.9 0. 0. ]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.1 0.8 0. ]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.1 0. 0. 0. 0.9]]\n",
"Accuray: 0.886\n",
"Recall: 0.885\n",
"F1-Score: 0.886\n",
" precision recall f1-score support\n",
"\n",
" Tap 0.87 0.81 0.84 272\n",
" Two tap 0.96 0.94 0.95 315\n",
" Swipe left 0.93 0.88 0.90 272\n",
" Swipe right 0.93 0.89 0.91 274\n",
" Swipe up 0.88 0.85 0.87 260\n",
" Swipe down 0.90 0.86 0.88 285\n",
" Two swipe up 0.99 0.97 0.98 316\n",
" Two swipe down 0.97 0.92 0.94 334\n",
" Circle 0.85 0.92 0.88 298\n",
" Arrowhead left 0.88 0.86 0.87 290\n",
"Arrowhead right 0.96 0.91 0.94 292\n",
" $\\checkmark$ 0.89 0.89 0.89 279\n",
" $\\Gamma$ 0.76 0.80 0.78 298\n",
" L 0.85 0.90 0.88 301\n",
" L mirrored 0.84 0.91 0.87 304\n",
" S 0.88 0.82 0.85 306\n",
" Rotate 0.79 0.90 0.84 347\n",
"\n",
" accuracy 0.89 5043\n",
" macro avg 0.89 0.88 0.89 5043\n",
" weighted avg 0.89 0.89 0.89 5043\n",
"\n"
]
}
],
"source": [
"y_true = df_eval.TaskID.values\n",
"y_pred = df_eval.TaskIDPred.values\n",
"\n",
"# Raw counts: row i = true class labels[i], column j = predicted class labels[j].\n",
"cm_counts = metrics.confusion_matrix(y_true, y_pred, labels=labels)\n",
"print(cm_counts)\n",
"\n",
"# Row-normalise (per-class recall); classes without samples get a zero row instead of NaN.\n",
"row_totals = cm_counts.sum(axis=1, keepdims=True)\n",
"cm = np.divide(cm_counts, row_totals, out=np.zeros(cm_counts.shape), where=row_totals != 0)\n",
"print(np.round(cm, 1))\n",
"\n",
"print(\"Accuracy: %.3f\" % metrics.accuracy_score(y_true, y_pred))\n",
"print(\"Recall: %.3f\" % metrics.recall_score(y_true, y_pred, average=\"macro\"))\n",
"# average_precision_score needs scores, not hard labels; macro precision is the label-based metric.\n",
"print(\"Precision: %.3f\" % metrics.precision_score(y_true, y_pred, average=\"macro\"))\n",
"print(\"F1-Score: %.3f\" % metrics.f1_score(y_true, y_pred, average=\"macro\"))\n",
"# Passing labels keeps target_names aligned even if a class is absent from both y_true and y_pred.\n",
"print(metrics.classification_report(y_true, y_pred, labels=labels, target_names=target_names))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Normalized confusion matrix\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0x7f64c079e0b8>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAGoCAYAAAATsnHAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsnXl4VNX5xz8vCZsgmyDIoojsmxACKIsCrkVUKqiAbFLFBXFvf7a1ltalKlSpWlux1n1rabWCigiKsu+IorIJyCKyI4EEksn7+2NucIhJmMycyRzC+3me82TmLt/5npM7eXPn3jlfUVUMwzAMwzfKJNuAYRiGYRSEFSjDMAzDS6xAGYZhGF5iBcowDMPwEitQhmEYhpdYgTIMwzC8xAqUYRiG4SVWoAzDMAwvsQJlGIZheElqsg0c75SpcKKWqVTLiVbbhjWc6IgTFfe4nPPE1z4aJY8dV9FzMCfXic7mjd+ye9eOow6XFagkU6ZSLSpf/EcnWrNfuMaJjoifbzOX03L52kej5LHjKnrWb9/vRKf/xd2j2s4+4jMMwzC8xAqUYRiG4SVWoAzDMAwvsQLlITdf3Jy5j/RhzsN9+MeobpQvW4brL2jKkj9fzp5XB1Ojcvlia95w/QhOq1eb9HZt4vY39YMptG3VjFbNGzP20Ye90Crt/fNVy0dPLrXsuCqc395xE13bNOTSnh2PWP7Kc3+jd/f29OmRztj7743LG6pqrZAGnAQsC9pWYHPE83IuXiOlxuladdDLh1vzURN1/ff7tPaw17TqoJf1v3PX601/n63dfz1Z29z6X92wbZ+ePvJfR+yT1w4cyi20TZ0+Q2fPX6QtW7YqcrsDh3I1M1sLbRlZOXp6o0b65cq1unf/QW3Tpq0u+WxFkfu40nLVv6L6mMz+HWtaPnqy4yqxWl9tyTjcXvrvFJ04ZZY2btbi8LIX/v2untWth362bqd+tSVDZy3/5oh98lqrtu01mr+PdgZVBKq6U1XbqWo74O/A43nPVfVQol43JUWoUC6FlDJCxfIpfLc7k+UbdvPtjtjvoOnW/RxqVI//NvSFCxZwxhmNOb1RI8qVK8eVVw9g8qT/JV2rtPfPRy0fPbnWsuOqcDqe1Y1q1asfseyNl/7B9bfcRbny4U95Tqp5cky+8rACFSMiMklEFovIChG5LliWKiJ7ROSJYPmHInJScXS/253JU+9+yRdP/JyVf+3HDwey+fjz7xLTiRjYsmUz9es3OPy8Xr36bN68OelarvC1fz5q+ejJtZYrfO2f67Fav3YNi+fP5upLejDkiov4fNnimLXAClQ8DFPVDkBH4E4RyftXoiowW1VbAXOB3+XfUURGisgiEVmUm/XDEeuqnlCO3h0acObtb9P8lv9QqXwqV3U9PcFdMQzDiJ+cUA579+zmjckf88vfPcgdNwyN63tmVqBi5w4R+YxwEaoPnBEszwH+HTx+BeiWf0dVnaCq6aqaXqZClSPW9Whdhw3bM9i57yA5IWXSwm/p1KRm4npRTOrWrcemTRsPP9+8eRP16tVLupYrfO2fj1o+enKt5Qpf++d6rOqcUo8Lel+GiNC2fTplypRh964dMetZgYoBETkfOAc4S1XPBJYDFQrZvFj/PmzauZ/0xjWpWC4FgHNb1WHVlh+OslfJkd6xI2vWrGb9unUcOnSIf7/5Bpf0uSzpWq7wtX8+avnoybWWK3ztn+uxOu/iPsyf/SkA69auJvvQIarXiP0fbJvqKDaqArtUNVNEWhH+mC+PVOAKYCIwCJhVHOHFa3fyzoJv+eTB3uSElM837OKFj1Zzw0XNuLVPS2pXrcjshy/hw2VbuPUf86LWHTZ4EJ9+OoOdO3bQ+PQG3HvfGIZf+4viWAMgNTWVx//yFJdechGhUIhhw0fQslWrYuu41irt/fNRy0dPrrXsuCqcu24azoK5M9mzayc9OjTllrt+yxUDhn
LvnTdxac+OlC1bjj/95Zm4pn8Sl/NQlWZEZAyQoarjRKQC8D+gAbCS8O3ovwHmATuAF4Hzge+Aq1V1Z2G6qSc1Uldz8X1nc/FFja99NEoeO66ix+VcfF98tsQmi3WFqo6JeJwFXJR/GxFJDdbfVnLODMMwSid2DcowDMPwEjuDcoiq5gDVku3DMAyjNGBnUIZhGIaX2BlUkmnbsAaznndzc0ONbr9yorN79lgnOq4p7RegjeRgx1X01Kla2LdpikfZlOjOjewMyjAMw/ASK1CGYRiGl1iBMgzDMLzECpRhGIbhJVagDMMwDC+xAuUxN44cwWn1a5PePra46VFXd2PRa3ex+PW7uGXAkZOq3zboHDLnj+WkqicUW9eHuOlEavnoyVctHz251PLRk89aoVCIc85O5+p+jibnTXas+rEQzQ7cCVRIhPf2aR10/8HcAtsH02borHmLtEXLVoVuE9kqdLr7cEsbMFa/WPOdVu/+a6109q90+vxV2vKKP2mFTndr4z7369S5X+uGLbu03gX3HbFfhU53exs3bTHmfmn56Mn6l1it3ftzimwP/Gms9rtygF54ce8it2vXvsPxE/leAtHsd1J4nEbCiCduunnD2ixc8S2ZB7MJhXKZufQb+vYIn4k9esdl/Papd2OaJNOXuOlEafnoyVctHz251PLRk89amzdvYuqU9xg6fERM+xdEqShQhSEivxaRm4PHT4rI1ODxhSLyYvB4sIh8LiJfiMhDBWjcAZwMzBSRacGyCUEi7goRuS9i200i8kigN19EGpVEPwtixTdb6drudGpUOYGK5ctycZfm1K9dlT7ntGLL9r18vjq2GHlf46YtxrzktXz05FLLR08+a/3mV3fyhwcfpkwZd2WlVBcoYCbQPXicBlQTkZRg2aciUh94AOgJtAe6ikifSAFVfRzYBnRX1fODxfeoajpwJnCBiLSM2GWXqrYBngEeK8hUZOT7jh3bnXQ0PyvXb+PPL33MpCev552/XMdnq7ZQrlwqvxrWiz8+MzUhr2kYxvHJlPcnU7PWybRr38GpbmkvUAuBjiJSDcgInqcRLlAzgc7AR6q6Q1WzgdcIJ+UejYEisgRYArQAIgvU68HPV4EuBe0cGfles2atGLoVHS9OWkjXYX/hghv/xp59B/jqm+85rW4NFrxyB1+/9WvqnVyVuS/dTu0aJ0at6WvctMWYl7yWj55cavnoyVet+XPnMOXdSbRtcQa/GHYNMz/5mJEjhsbkKZJSXaBU9SDhGyaGArMJF6XzgNNUdVUsmiLSBLgN6KWqbYEpHHl9ypsEyFrVKwHQoHY1Lu/RhlfeXcRpP/sDzX/+J5r//E9s3raXs4eO5/td+6LW9DVu2mLMS17LR08utXz05KvW7//4ECtWb2D5V2t57sVX6X5uTyb886WYPEVyPEwWOxO4m3CRWg2MJZx8CzAfGCciJwF7gQHAuAI09gEnAnuAKsHzH0TkFMLBhVMitr060BhIuCjGzLAhg5gZxE03adSAe383hmHFiJt+/eGh1KhaieycELePfYu9GVnx2AH8iZtOlJaPnnzV8tGTSy0fPfmslQhKXeR7ZDR78Pwi4B2gqqpmicg3wHhVfSJYPxj4P0CASar66wI07wBuBDYCFwAvEf54cAOwH5ioqq+IyCbgFaA3kAkMVNVvivKb1iFdZ81dGH/HgZO6l+7ZzA3DSC5Zh0JOdHp268zSJYuOv8j3yGj24PkHQPmI543yrX+FcFEpSvNx4PGIRUOK2PxhVb0nWr+GYRhGwZTqa1CGYRjGsUupO4NKJqpaP9keDMMwSgtWoEoRrq4dVe/z+NE3ipKd79zuTKtMGUs+NY4fdma4mAQHTqpczokOQIVyKU50on0r20d8hmEYhpdYgTIMwzC8xAqUYRiG4SVWoAzDMAwvsQJlGIZheIkVKI+JN1E3P/EkZ47+eXsWPzOURX8fwov3/IzyZVPo0a4Bc54axLy/XsP0P19Fo1OqFk
vTp/4lQud40PLRk0stHz0B/OPvT3Le2e3pdXY7/vG3J7zw5bJ/h0l2Gm6Uiba/BVYAywmn5HaOcr8/Auc7eP0ewOQ
"text/plain": [
"<Figure size 432x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def plot_confusion_matrix(y_true, y_pred, classes,\n",
"                          normalize=False,\n",
"                          title=None,\n",
"                          cmap=plt.cm.Blues,\n",
"                          labels=None,\n",
"                          save_path=\"./out/conf_matrix.pdf\"):\n",
"    \"\"\"\n",
"    Plot a confusion matrix with every cell annotated, and optionally save it.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    y_true, y_pred : array-like\n",
"        True and predicted class labels.\n",
"    classes : list of str\n",
"        Tick labels, one per class, in the same order as `labels`.\n",
"    normalize : bool\n",
"        If True, rows are normalised to fractions and annotated as whole\n",
"        percentages; otherwise raw counts are annotated.\n",
"    title : str, optional\n",
"        Currently unused (the figure is drawn without a title); kept for\n",
"        backward compatibility.\n",
"    cmap : matplotlib colormap\n",
"        Colormap for the cell shading.\n",
"    labels : array-like, optional\n",
"        Class values indexing the matrix rows/columns (passed to sklearn's\n",
"        confusion_matrix). Pass it so the matrix always has one row/column per\n",
"        entry of `classes`, even if a class never occurs.\n",
"    save_path : str or None\n",
"        File the figure is written to (parent directory is created).\n",
"        None or \"\" skips saving.\n",
"\n",
"    Returns\n",
"    -------\n",
"    matplotlib.axes.Axes\n",
"    \"\"\"\n",
"    import os\n",
"\n",
"    cm = sklearn.metrics.confusion_matrix(y_true, y_pred, labels=labels)\n",
"    if normalize:\n",
"        row_totals = cm.sum(axis=1, keepdims=True)\n",
"        # Rows without samples become zeros instead of NaN (0/0).\n",
"        cm = np.divide(cm, row_totals, out=np.zeros(cm.shape), where=row_totals != 0)\n",
"        print(\"Normalized confusion matrix\")\n",
"    else:\n",
"        print('Confusion matrix, without normalization')\n",
"\n",
"    fig, ax = plt.subplots(figsize=(6, 6))\n",
"    ax.imshow(cm, interpolation='nearest', cmap=cmap)\n",
"    # Show every tick and label it with the class name.\n",
"    ax.set(xticks=np.arange(cm.shape[1]),\n",
"           yticks=np.arange(cm.shape[0]),\n",
"           xticklabels=classes, yticklabels=classes)\n",
"    plt.setp(ax.get_xticklabels(), rotation=90, ha=\"right\",\n",
"             rotation_mode=\"anchor\")\n",
"\n",
"    # Annotate cells: whole percentages when normalised, raw counts otherwise.\n",
"    fmt, scale = ('.0f', 100) if normalize else ('d', 1)\n",
"    thresh = cm.max() / 2.\n",
"    for i in range(cm.shape[0]):\n",
"        for j in range(cm.shape[1]):\n",
"            ax.text(j, i, format(cm[i, j] * scale, fmt),\n",
"                    ha=\"center\", va=\"center\",\n",
"                    # light text on dark cells for contrast\n",
"                    color=\"white\" if cm[i, j] > thresh else \"black\")\n",
"    fig.tight_layout()\n",
"    if save_path:\n",
"        save_dir = os.path.dirname(save_path)\n",
"        if save_dir:\n",
"            os.makedirs(save_dir, exist_ok=True)\n",
"        fig.savefig(save_path, bbox_inches='tight', transparent=False, pad_inches=0)\n",
"    return ax\n",
"\n",
"plot_confusion_matrix(df_eval.TaskID.values, df_eval.TaskIDPred.values, classes=target_names, normalize=True,\n",
"                      title='Normalized confusion matrix', labels=labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}