visrecall/RecallNet/notebooks/train_RecallNet.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
"/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
]
}
],
"source": [
"import numpy as np\n",
"import keras\n",
"import sys\n",
"import os, glob\n",
"from keras.models import Model\n",
"import tensorflow as tf\n",
"from keras.utils import Sequence\n",
"from keras.optimizers import Adam\n",
"from keras.callbacks import ModelCheckpoint, LearningRateScheduler\n",
"#from IPython.display import clear_output\n",
"import tqdm \n",
"import math\n",
"\n",
"sys.path.append('../src')\n",
"\n",
"from util import get_model_by_name\n",
"\n",
"from sal_imp_utilities import *\n",
"from cb import InteractivePlot\n",
"\n",
"config = tf.ConfigProto()\n",
"config.gpu_options.allow_growth = True\n",
"sess = tf.Session(config=config)\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mon May 9 14:27:53 2022 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 470.103.01 Driver Version: 470.103.01 CUDA Version: 11.4 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 NVIDIA GeForce ... Off | 00000000:29:00.0 On | N/A |\n",
"| 30% 41C P2 38W / 184W | 1027MiB / 7959MiB | 2% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| 0 N/A N/A 1377 G /usr/lib/xorg/Xorg 95MiB |\n",
"| 0 N/A N/A 2243 G /usr/lib/xorg/Xorg 432MiB |\n",
"| 0 N/A N/A 2378 G /usr/bin/gnome-shell 72MiB |\n",
"| 0 N/A N/A 3521 G ...AAAAAAAAA= --shared-files 16MiB |\n",
"| 0 N/A N/A 5912 G ...RendererForSitePerProcess 87MiB |\n",
"| 0 N/A N/A 18975 G ...sktop/bin/mendeleydesktop 42MiB |\n",
"| 0 N/A N/A 168057 G ...178058406726361824,131072 146MiB |\n",
"| 0 N/A N/A 179935 C .../envs/tf-cuda9/bin/python 87MiB |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
"source": [
"%%bash\n",
"nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ENV"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"data_path = '/your/path/VisRecall/'\n",
"# 0: T, 1: FE, 2: F, 3: RV, 4: U\n",
"TYPE = 0\n",
"split = 0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"TYPE_Q = ['T','FE','F','RV','U']\n",
"bp_imp = data_path + 'merged/src/'\n",
"training_set = np.load(data_path + 'training_data/%s-question/train_split%d.npy'%(TYPE_Q[TYPE], split),allow_pickle=True)\n",
"val_set = np.load(data_path + 'training_data/%s-question/val_split%d.npy'%(TYPE_Q[TYPE], split),allow_pickle=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"train_filename = []\n",
"train_label = []\n",
"train_mean = []\n",
"train_type = []\n",
"\n",
"for data in training_set:\n",
" one_hot = [0,0,0,0,0,0]\n",
" one_hot[data['vistype']] = 1\n",
" train_filename.append(bp_imp+data['name'])\n",
" train_label.append(one_hot)\n",
" train_mean.append(data['norm_mean_acc_withD'])\n",
" train_type.append(data['norm_%d_withD'%(TYPE)])\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"val_filename = []\n",
"val_label = []\n",
"val_mean = []\n",
"val_type = []\n",
"for data in val_set:\n",
" one_hot = [0,0,0,0,0,0]\n",
" one_hot[data['vistype']] = 1\n",
" val_filename.append(bp_imp+data['name'])\n",
" val_label.append(one_hot)\n",
" val_mean.append(data['norm_mean_acc_withD'])\n",
" val_type.append(data['norm_%d_withD'%(TYPE)])\n",
"#val_filename"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model and training params"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# FILL THESE IN: set training parameters \n",
"ckpt_savedir = \"ckpt\"\n",
"\n",
"load_weights = False\n",
"weightspath = \"\"\n",
"\n",
"batch_size = 4\n",
"init_lr = 0.002\n",
"lr_reduce_by = .1\n",
"reduce_at_epoch = 3\n",
"n_epochs = 15\n",
"\n",
"opt = Adam(lr=init_lr) \n",
"\n",
"\n",
"model_name = \"RecallNet_xception\"\n",
"\n",
"model_inp_size = (240, 320)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"input_shape = model_inp_size + (3,)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.\n",
"\n",
"xception output shapes: (?, 30, 40, 2048)\n",
"Model: \"model_1\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input_1 (InputLayer) (None, 240, 320, 3) 0 \n",
"__________________________________________________________________________________________________\n",
"block1_conv1 (Conv2D) (None, 119, 159, 32) 864 input_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1_conv1_bn (BatchNormaliza (None, 119, 159, 32) 128 block1_conv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1_conv1_act (Activation) (None, 119, 159, 32) 0 block1_conv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1_conv2 (Conv2D) (None, 117, 157, 64) 18432 block1_conv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1_conv2_bn (BatchNormaliza (None, 117, 157, 64) 256 block1_conv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1_conv2_act (Activation) (None, 117, 157, 64) 0 block1_conv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2_sepconv1 (SeparableConv2 (None, 117, 157, 128 8768 block1_conv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2_sepconv1_bn (BatchNormal (None, 117, 157, 128 512 block2_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2_sepconv2_act (Activation (None, 117, 157, 128 0 block2_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2_sepconv2 (SeparableConv2 (None, 117, 157, 128 17536 block2_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2_sepconv2_bn (BatchNormal (None, 117, 157, 128 512 block2_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_1 (Conv2D) (None, 59, 79, 128) 8192 block1_conv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2_pool (MaxPooling2D) (None, 59, 79, 128) 0 block2_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_1 (BatchNor (None, 59, 79, 128) 512 conv2d_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_1 (Add) (None, 59, 79, 128) 0 block2_pool[0][0] \n",
" batch_normalization_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3_sepconv1_act (Activation (None, 59, 79, 128) 0 add_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3_sepconv1 (SeparableConv2 (None, 59, 79, 256) 33920 block3_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3_sepconv1_bn (BatchNormal (None, 59, 79, 256) 1024 block3_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3_sepconv2_act (Activation (None, 59, 79, 256) 0 block3_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3_sepconv2 (SeparableConv2 (None, 59, 79, 256) 67840 block3_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3_sepconv2_bn (BatchNormal (None, 59, 79, 256) 1024 block3_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_2 (Conv2D) (None, 30, 40, 256) 32768 add_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3_pool (MaxPooling2D) (None, 30, 40, 256) 0 block3_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_2 (BatchNor (None, 30, 40, 256) 1024 conv2d_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_2 (Add) (None, 30, 40, 256) 0 block3_pool[0][0] \n",
" batch_normalization_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4_sepconv1_act (Activation (None, 30, 40, 256) 0 add_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4_sepconv1 (SeparableConv2 (None, 30, 40, 728) 188672 block4_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4_sepconv1_bn (BatchNormal (None, 30, 40, 728) 2912 block4_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4_sepconv2_act (Activation (None, 30, 40, 728) 0 block4_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4_sepconv2 (SeparableConv2 (None, 30, 40, 728) 536536 block4_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4_sepconv2_bn (BatchNormal (None, 30, 40, 728) 2912 block4_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_3 (Conv2D) (None, 30, 40, 728) 186368 add_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4_pool (MaxPooling2D) (None, 30, 40, 728) 0 block4_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_3 (BatchNor (None, 30, 40, 728) 2912 conv2d_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_3 (Add) (None, 30, 40, 728) 0 block4_pool[0][0] \n",
" batch_normalization_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5_sepconv1_act (Activation (None, 30, 40, 728) 0 add_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5_sepconv1 (SeparableConv2 (None, 30, 40, 728) 536536 block5_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5_sepconv1_bn (BatchNormal (None, 30, 40, 728) 2912 block5_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5_sepconv2_act (Activation (None, 30, 40, 728) 0 block5_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5_sepconv2 (SeparableConv2 (None, 30, 40, 728) 536536 block5_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5_sepconv2_bn (BatchNormal (None, 30, 40, 728) 2912 block5_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5_sepconv3_act (Activation (None, 30, 40, 728) 0 block5_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5_sepconv3 (SeparableConv2 (None, 30, 40, 728) 536536 block5_sepconv3_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5_sepconv3_bn (BatchNormal (None, 30, 40, 728) 2912 block5_sepconv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_4 (Add) (None, 30, 40, 728) 0 block5_sepconv3_bn[0][0] \n",
" add_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6_sepconv1_act (Activation (None, 30, 40, 728) 0 add_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6_sepconv1 (SeparableConv2 (None, 30, 40, 728) 536536 block6_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6_sepconv1_bn (BatchNormal (None, 30, 40, 728) 2912 block6_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6_sepconv2_act (Activation (None, 30, 40, 728) 0 block6_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6_sepconv2 (SeparableConv2 (None, 30, 40, 728) 536536 block6_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6_sepconv2_bn (BatchNormal (None, 30, 40, 728) 2912 block6_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6_sepconv3_act (Activation (None, 30, 40, 728) 0 block6_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6_sepconv3 (SeparableConv2 (None, 30, 40, 728) 536536 block6_sepconv3_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6_sepconv3_bn (BatchNormal (None, 30, 40, 728) 2912 block6_sepconv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_5 (Add) (None, 30, 40, 728) 0 block6_sepconv3_bn[0][0] \n",
" add_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7_sepconv1_act (Activation (None, 30, 40, 728) 0 add_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7_sepconv1 (SeparableConv2 (None, 30, 40, 728) 536536 block7_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7_sepconv1_bn (BatchNormal (None, 30, 40, 728) 2912 block7_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7_sepconv2_act (Activation (None, 30, 40, 728) 0 block7_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7_sepconv2 (SeparableConv2 (None, 30, 40, 728) 536536 block7_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7_sepconv2_bn (BatchNormal (None, 30, 40, 728) 2912 block7_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7_sepconv3_act (Activation (None, 30, 40, 728) 0 block7_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7_sepconv3 (SeparableConv2 (None, 30, 40, 728) 536536 block7_sepconv3_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7_sepconv3_bn (BatchNormal (None, 30, 40, 728) 2912 block7_sepconv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_6 (Add) (None, 30, 40, 728) 0 block7_sepconv3_bn[0][0] \n",
" add_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"block8_sepconv1_act (Activation (None, 30, 40, 728) 0 add_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"block8_sepconv1 (SeparableConv2 (None, 30, 40, 728) 536536 block8_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block8_sepconv1_bn (BatchNormal (None, 30, 40, 728) 2912 block8_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block8_sepconv2_act (Activation (None, 30, 40, 728) 0 block8_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block8_sepconv2 (SeparableConv2 (None, 30, 40, 728) 536536 block8_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block8_sepconv2_bn (BatchNormal (None, 30, 40, 728) 2912 block8_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block8_sepconv3_act (Activation (None, 30, 40, 728) 0 block8_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block8_sepconv3 (SeparableConv2 (None, 30, 40, 728) 536536 block8_sepconv3_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block8_sepconv3_bn (BatchNormal (None, 30, 40, 728) 2912 block8_sepconv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_7 (Add) (None, 30, 40, 728) 0 block8_sepconv3_bn[0][0] \n",
" add_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"block9_sepconv1_act (Activation (None, 30, 40, 728) 0 add_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"block9_sepconv1 (SeparableConv2 (None, 30, 40, 728) 536536 block9_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block9_sepconv1_bn (BatchNormal (None, 30, 40, 728) 2912 block9_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block9_sepconv2_act (Activation (None, 30, 40, 728) 0 block9_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block9_sepconv2 (SeparableConv2 (None, 30, 40, 728) 536536 block9_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block9_sepconv2_bn (BatchNormal (None, 30, 40, 728) 2912 block9_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block9_sepconv3_act (Activation (None, 30, 40, 728) 0 block9_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block9_sepconv3 (SeparableConv2 (None, 30, 40, 728) 536536 block9_sepconv3_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block9_sepconv3_bn (BatchNormal (None, 30, 40, 728) 2912 block9_sepconv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_8 (Add) (None, 30, 40, 728) 0 block9_sepconv3_bn[0][0] \n",
" add_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"block10_sepconv1_act (Activatio (None, 30, 40, 728) 0 add_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"block10_sepconv1 (SeparableConv (None, 30, 40, 728) 536536 block10_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block10_sepconv1_bn (BatchNorma (None, 30, 40, 728) 2912 block10_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block10_sepconv2_act (Activatio (None, 30, 40, 728) 0 block10_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block10_sepconv2 (SeparableConv (None, 30, 40, 728) 536536 block10_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block10_sepconv2_bn (BatchNorma (None, 30, 40, 728) 2912 block10_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block10_sepconv3_act (Activatio (None, 30, 40, 728) 0 block10_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block10_sepconv3 (SeparableConv (None, 30, 40, 728) 536536 block10_sepconv3_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block10_sepconv3_bn (BatchNorma (None, 30, 40, 728) 2912 block10_sepconv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_9 (Add) (None, 30, 40, 728) 0 block10_sepconv3_bn[0][0] \n",
" add_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"block11_sepconv1_act (Activatio (None, 30, 40, 728) 0 add_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"block11_sepconv1 (SeparableConv (None, 30, 40, 728) 536536 block11_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block11_sepconv1_bn (BatchNorma (None, 30, 40, 728) 2912 block11_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block11_sepconv2_act (Activatio (None, 30, 40, 728) 0 block11_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block11_sepconv2 (SeparableConv (None, 30, 40, 728) 536536 block11_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block11_sepconv2_bn (BatchNorma (None, 30, 40, 728) 2912 block11_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block11_sepconv3_act (Activatio (None, 30, 40, 728) 0 block11_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block11_sepconv3 (SeparableConv (None, 30, 40, 728) 536536 block11_sepconv3_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block11_sepconv3_bn (BatchNorma (None, 30, 40, 728) 2912 block11_sepconv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_10 (Add) (None, 30, 40, 728) 0 block11_sepconv3_bn[0][0] \n",
" add_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"block12_sepconv1_act (Activatio (None, 30, 40, 728) 0 add_10[0][0] \n",
"__________________________________________________________________________________________________\n",
"block12_sepconv1 (SeparableConv (None, 30, 40, 728) 536536 block12_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block12_sepconv1_bn (BatchNorma (None, 30, 40, 728) 2912 block12_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block12_sepconv2_act (Activatio (None, 30, 40, 728) 0 block12_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block12_sepconv2 (SeparableConv (None, 30, 40, 728) 536536 block12_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block12_sepconv2_bn (BatchNorma (None, 30, 40, 728) 2912 block12_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block12_sepconv3_act (Activatio (None, 30, 40, 728) 0 block12_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block12_sepconv3 (SeparableConv (None, 30, 40, 728) 536536 block12_sepconv3_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block12_sepconv3_bn (BatchNorma (None, 30, 40, 728) 2912 block12_sepconv3[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_11 (Add) (None, 30, 40, 728) 0 block12_sepconv3_bn[0][0] \n",
" add_10[0][0] \n",
"__________________________________________________________________________________________________\n",
"block13_sepconv1_act (Activatio (None, 30, 40, 728) 0 add_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"block13_sepconv1 (SeparableConv (None, 30, 40, 728) 536536 block13_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block13_sepconv1_bn (BatchNorma (None, 30, 40, 728) 2912 block13_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block13_sepconv2_act (Activatio (None, 30, 40, 728) 0 block13_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block13_sepconv2 (SeparableConv (None, 30, 40, 1024) 752024 block13_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block13_sepconv2_bn (BatchNorma (None, 30, 40, 1024) 4096 block13_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_4 (Conv2D) (None, 30, 40, 1024) 745472 add_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"block13_pool (MaxPooling2D) (None, 30, 40, 1024) 0 block13_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_4 (BatchNor (None, 30, 40, 1024) 4096 conv2d_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_12 (Add) (None, 30, 40, 1024) 0 block13_pool[0][0] \n",
" batch_normalization_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"block14_sepconv1 (SeparableConv (None, 30, 40, 1536) 1582080 add_12[0][0] \n",
"__________________________________________________________________________________________________\n",
"block14_sepconv1_bn (BatchNorma (None, 30, 40, 1536) 6144 block14_sepconv1[0][0] \n",
"__________________________________________________________________________________________________\n",
"block14_sepconv1_act (Activatio (None, 30, 40, 1536) 0 block14_sepconv1_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block14_sepconv2 (SeparableConv (None, 30, 40, 2048) 3159552 block14_sepconv1_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"block14_sepconv2_bn (BatchNorma (None, 30, 40, 2048) 8192 block14_sepconv2[0][0] \n",
"__________________________________________________________________________________________________\n",
"block14_sepconv2_act (Activatio (None, 30, 40, 2048) 0 block14_sepconv2_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_conv (Conv2D) (None, 10, 14, 256) 4718592 block14_sepconv2_act[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_BN (BatchNormalization) (None, 10, 14, 256) 1024 global_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_1 (Activation) (None, 10, 14, 256) 0 global_BN[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_1 (Dropout) (None, 10, 14, 256) 0 activation_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling2d_1 (Glo (None, 256) 0 dropout_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_dense (Dense) (None, 256) 65792 global_average_pooling2d_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_2 (Dropout) (None, 256) 0 global_dense[0][0] \n",
"__________________________________________________________________________________________________\n",
"out_mean_acc (Dense) (None, 1) 257 dropout_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"out_type0_acc (Dense) (None, 1) 257 dropout_2[0][0] \n",
"==================================================================================================\n",
"Total params: 25,647,402\n",
"Trainable params: 25,592,362\n",
"Non-trainable params: 55,040\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"# get model \n",
"model_params = {\n",
" 'input_shape': input_shape,\n",
" #'n_outs': len(losses),\n",
" 'n_outs': 2\n",
"}\n",
"model_func, mode = get_model_by_name(model_name)\n",
"#assert mode == \"simple\", \"%s is a multi-duration model! Please use the multi-duration notebook to train.\" % model_name\n",
"model = model_func(**model_params)\n",
"\n",
"if load_weights: \n",
" model.load_weights(weightspath)\n",
" print(\"load\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Checkpoints will be saved with format ckpt/RecallNet_xception/RecallNet_xception_ep{epoch:02d}_valloss{val_loss:.4f}.hdf5\n"
]
}
],
"source": [
"# set up data generation and checkpoints\n",
"if not os.path.exists(ckpt_savedir): \n",
" os.makedirs(ckpt_savedir)\n",
"\n",
"\n",
"# Generators\n",
"gen_train = RecallNet_Generator(train_filename,\n",
" train_label,\n",
" train_mean,\n",
" train_type,\n",
" batch_size = 4)\n",
"\n",
"gen_val = RecallNet_Generator(val_filename,\n",
" val_label,\n",
" val_mean,\n",
" val_type,\n",
" 1)\n",
"\n",
"# Callbacks\n",
"\n",
"# where to save checkpoints\n",
"#filepath = os.path.join(ckpt_savedir, dataset_sal + \"_\" + l_str + '_ep{epoch:02d}_valloss{val_loss:.4f}.hdf5')\n",
"filepath = os.path.join(ckpt_savedir, model_name, model_name+'_ep{epoch:02d}_valloss{val_loss:.4f}.hdf5')\n",
"\n",
"print(\"Checkpoints will be saved with format %s\" % filepath)\n",
"\n",
"cb_chk = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_weights_only=False, period=1)\n",
"cb_plot = InteractivePlot()\n",
"\n",
"def step_decay(epoch):\n",
" lrate = init_lr * math.pow(lr_reduce_by, math.floor((1+epoch)/reduce_at_epoch))\n",
" if epoch%reduce_at_epoch:\n",
" print('Reducing lr. New lr is:', lrate)\n",
" return lrate\n",
"cb_sched = LearningRateScheduler(step_decay)\n",
"\n",
"cbs = [cb_chk, cb_sched, cb_plot]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[array([0.74688797, 0.84647303, 0.93360996, 0.8340249 ]), array([0.7826087 , 0.95652174, 1. , 0.86956522])]\n"
]
}
],
"source": [
"#test the generator \n",
"batch_img, out= gen_train.__getitem__(1)\n",
"print(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAzIAAAI/CAYAAACs3OxHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAACo1ElEQVR4nOzdaWBcB30u/OfMvmi075Itebdsx3ZiJwESsrNDAoUWaGkLXShQlvd24fYWbtvb7bbQ2/belrKUFgqUUnZStgCBLCYksZPYiffd1mppJGskzXJmzvJ+OHNGo9Es58ycmTkjPb8vbWxrPCT2aJ75b4KqqiAiIiIiImokjno/ASIiIiIiIrMYZIiIiIiIqOEwyBARERERUcNhkCEiIiIioobDIENERERERA2HQYaIiIiIiBqOq16/cWdnpzo8PFyv356IiIiIiGzumWeeCauq2pXv5+oWZIaHh3HkyJF6/fZERERERGRzgiBcKfRzbC0jIiIiIqKGwyBDREREREQNh0GGiIiIiIgaTt1mZIiIiIiIyJxUKoWxsTEkEol6PxVL+Xw+DA4Owu12G/4aBhkiIiIiogYxNjaGUCiE4eFhCIJQ76djCVVVMTs7i7GxMWzatMnw17G1jIiIiIioQSQSCXR0dKyZEAMAgiCgo6PDdJWJQYaIiIiIqIGspRCjK+d/E4MMEREREREZ1tTUVO+nAIBBhoiIiIiIGhCDDBERERERmaaqKn7/938fe/bswQ033ID//M//BABMTk7ijjvuwP79+7Fnzx48/vjjkGUZb3/72zO/9u/+7u8q/v25tYyIiIiIiEz7+te/jqNHj+LYsWMIh8O4+eabcccdd+CLX/wiXvGKV+BDH/oQZFlGLBbD0aNHMT4+juPHjwMA5ufnK/79GWSIiIiIiBrQ//qvEzg5sWDpY+7qb8Yfv263oV976NAhvPWtb4XT6URPTw/uvPNOHD58GDfffDN+7dd+DalUCq9//euxf/9+bN68GRcvXsT73vc+vOY1r8HLX/7yip8rW8uIiIiIiMgyd9xxBx577DEMDAzg7W9/Oz73uc+hra0Nx44dw1133YVPfOIT+I3f+I2Kfx9WZIiIiIiIGpDRykm1vPSlL8UnP/lJ/Oqv/irm5ubw2GOP4aMf/SiuXLmCwcFB/OZv/iZEUcSzzz6LV7/61fB4PHjjG9+IHTt24G1ve1vFvz+DDBERERERmfaGN7wBP/vZz7Bv3z4IgoCPfOQj6O3txb/927/hox/9KNxuN5qamvC5z30O4+PjeMc73gFFUQAA//t//++Kf39BVdWKH6QcBw8eVI8cOVKX35uIiIiIqBGdOnUKIyMj9X4aVZHvf5sgCM+oqnow36/njAwRERERETUcBhkiIiIiImo4DDJERERERNRwGGSIiIiIiBpIvWbcq6mc/00MMkREREREDcLn82F2dnZNhRlVVTE7Owufz2fq67h+mYiIiIioQQwODmJsbAwzMzP1fiqW8vl8GBwcNPU1DDI2Iysq3v+l5/Abt2/CjRvb6v10iIiIiMhG3G43Nm3aVO+nYQtsLbOZhXgK33l+Ek9dmqv3UyEiIiIisi0GGZsRJe3aaTL9f4mIiIiIaDUGGZsRJXnF/yUiIiIiotUYZGyGFRkiIiIiotIYZGxGTDHIEBERERGVwiBjM3pLWVJmkCEiIiIiKoRBxmb01jKRFRkiIiIiooIYZGwmU5FhkCEiIiIiKohBxmb0GRlWZIiIiIiICmOQsRluLSMiIiIiKo1BxmbYWkZEREREVBqDjM1kKjLcWkZEREREVBCDjM3wjgwRERERUWkMMjbD1jIiIiIiotIYZGwmydYyIiIiIqKSGGRsJnMQMyXX+ZkQEREREdkXg4zNcNifiIiIiKg0Bhmb0WdkeBCTiIiIiKgwBhmb4dYyIiIiIqLSGGRsJru1TFXVOj8bIiIiIiJ7YpCxGb21TFUBSWGQISIiIiLKh0HGZrJnY9heRkRERESUH4OMzegzMgAH/omIiIiICmGQsRm9tQxgRYaIiIiIqBAGGZthaxkRERERUWkMMjYjSgoEQfv/k7Jc/BcTEREREa1TDDI2I6ZkNHld2v/PigwRERERUV4MMjYjSgqafW4AbC0jIiIiIiqEQcZmRElByMeKDBERERFRMQwyNiNKMisyREREREQlMMjYiKKoSMlqpiLDIENERERElB+DjI0kZS24ZIKMzCBDRERERJQPg4yNiCk9yLC1jIiIiIioGAYZGxEl7W4MW8uIiIiIiIpjkLERfUuZXpER2VpGRERERJQXg4yN5FZkxJRcz6dDRERERGRbDDI2kkjPyDT70zMyrMgQEREREeXFIGMjy61lnJEhIiIiIiqGQcZG9NYyv9sJl0NgkCEiIiIiKoBBxkb0iozX5YDH5WCQISIiIiIqgEHGRvQ7Mh49yHBGhoiIiIgoLwYZG9Fby7wuJzxOVmSIiIiIiAphkLGR3NYykUGGiIiIiCgvBhkbyQQZtwNezsgQERERERXEIGMj+gFMr8sJj8vJigwRERERUQEMMjayamsZh/2JiIiIiPIyFGQEQXilIAhnBEE4LwjCHxT5dW8UBEEVBOGgdU9x/UhmBRmv04FkevifiIiIiIhWKhlkBEFwAvgYgFcB2AXgrYIg7Mrz60IAPgDgKauf5HohSgo8LgcEQeAdGSIiIiKiIoxUZG4BcF5V1YuqqiYBfAnAA3l+3Z8B+GsACQuf37oiSjK8Lu0/CbeWEREREREVZiTIDAAYzfrnsfSPZQiCcBOADaqqfsfC57buiJICr8sJANxaRkRERERURMXD/oIgOAD8LYDfNfBr3ykIwhFBEI7MzMxU+luvOWJKWVGR4bA/EREREVF+RoLMOIANWf88mP4xXQjAHgCPCIJwGcCLADyYb+BfVdVPqap6UFXVg11dXeU/6zVKlGR43ekg42RFhoiIiIioECNB5jCAbYIgbBIEwQPgLQAe1H9SVdWIqqqdqqoOq6o6DOBJAPerqnqkKs94DctuLeOwPxERERFRYSWDjKqqEoD3AngIwCkAX1ZV9YQgCH8qCML91X6C64kWZLJayxhkiIiIiIjychn5RaqqfhfAd3N+7I8K/Nq7Kn9a65OYytlaxhkZIiIiIqK8Kh72J+uIkgKvO721LD0jo6pqnZ8VEREREZH9MMjYiCgp8Di1/yR6oOHmMiIiIiKi1RhkbCR3axkAzskQEREREeXBIGMjuXdkAAYZIiIiIqJ8GGRsJHf9MsDWMiIiIiKifBhkbESUsraWsbWMiIiIiKggBhkbSUrK8owMW8uIiIiIiApikLEJVVXztpaJDDJERERERKswyNiEPgujt5Z5GWSIiIiIiApikLEJPbBwaxkRERERUWkMMjYhptJBJn0I08utZUREREREBTHI2IQoyQCyKjJOLdCwIkNEREREtBqDjE2wtYyIiIiIyDgGGZvItJat2lom1+05ERERERHZFYOMTWRay3hHhoiIiIioJAYZm8htLeOwPxERERFRYQwyNrEcZFa2lrEiQ
0RERES0GoOMTYip3K1lPIhJRERERFQIg4xNrNpa5mRFhoiIiIioEAYZm8htLXM4BLidAmdkiIiIiIjyYJCxidytZYBWldHXMhMRERER0TIGGZtYviOz/J/E63YiKfOODBERERFRLgYZm9BbyPTWMkCryHBGhoiIiIhoNQYZm9ArMp6siozHxSBDRERERJQPg4xNiJIMt1OA0yFkfszjcnDYn4iIiIgoDwYZmxAlZUVbGcDWMiIiIiKiQhhkbEKU5BWD/oBWkeFBTCIiIiKi1RhkbEJMKQwyREREREQGMcjYhCgp8LpXtpZ5OexPRERERJQXg4xN5GstY5AhIiIiIsqPQcYmtGH/1a1l3FpGRERERLQag4xNaDMy3FpGRERERGQEg4xNiJK84hgmwIOYRERERESFMMjYRKHWMlGS6/SMiIiIiIjsi0HGJrStZTlBxulkRYaIiIiIKA8GGZvQtpblrF92c9ifiIiIiCgfBhmbyHsQ0+lASlahKGqdnhURERERkT0xyNhEUs4/I6P/HBE
"text/plain": [
"<Figure size 1008x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<keras.callbacks.callbacks.History at 0x7ff9ed6f3950>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#model.compile(optimizer=opt, loss={'out_mean_acc': 'mean_squared_error', 'out_type0_acc':'mean_squared_error','out_classif':'binary_crossentropy'}, loss_weights={'out_mean_acc': 3, 'out_type0_acc': 3, 'out_classif': 1})\n",
"model.compile(optimizer=opt, loss={'out_mean_acc': 'mean_squared_error', 'out_type0_acc':'mean_squared_error'}, loss_weights={'out_mean_acc': 1, 'out_type0_acc': 1})\n",
"\n",
"print('Ready to train')\n",
"model.fit_generator(gen_train, epochs=n_epochs, verbose=1, callbacks=cbs, validation_data=gen_val, max_queue_size=10, workers=5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## validation"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"load\n"
]
}
],
"source": [
"data_path = '/your/path/VisRecall/'\n",
"\n",
"if True:\n",
" W = \"./ckpt/RecallNet/U-question/Fold4/RecallNet_xception_ep10_valloss0.0313.hdf5\"\n",
" model.load_weights(W)\n",
" print('load')"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"avg, type: 0.02553353417062307 0.02617591625785007\n"
]
}
],
"source": [
"from sklearn.metrics import mean_squared_error\n",
"\n",
"mean_list = []\n",
"type_list = []\n",
"\n",
"num_t = 0.0\n",
"\n",
"for i in range(len(val_filename)):\n",
" img = preprocess_images([val_filename[i]], 240, 320)\n",
" preds = model.predict(img)\n",
" #if(np.argmax(preds[2][0]) == np.argmax(val_label[i])):\n",
" # num_t += 1\n",
" mean_acc = preds[0][0]\n",
" type_acc = preds[1][0]\n",
" mean_list.append(mean_squared_error([val_mean[i]], mean_acc))\n",
" type_list.append(mean_squared_error([val_type[i]], type_acc))\n",
" #print(val_label[i])\n",
"\n",
"#print('classificaion: ', num_t/len(val_filename))\n",
"print('avg, type: ', np.mean(mean_list),np.mean(type_list))\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"load\n",
"[(0.1154690275375791, 0.07620564377135046)]\n",
"load\n",
"[(0.1154690275375791, 0.07620564377135046), (0.10878741427577963, 0.2667425108749781)]\n",
"load\n",
"[(0.1154690275375791, 0.07620564377135046), (0.10878741427577963, 0.2667425108749781), (0.5932855723937102, 0.38129302771926993)]\n",
"load\n",
"[(0.1154690275375791, 0.07620564377135046), (0.10878741427577963, 0.2667425108749781), (0.5932855723937102, 0.38129302771926993), (0.015003968232081454, 0.06507362662601399)]\n",
"load\n",
"[(0.1154690275375791, 0.07620564377135046), (0.10878741427577963, 0.2667425108749781), (0.5932855723937102, 0.38129302771926993), (0.015003968232081454, 0.06507362662601399), (0.02553353417062307, 0.02617591625785007)]\n"
]
}
],
"source": [
"# 0: T, 1: FE, 2: F, 3: RV, 4: U\n",
"TYPE = 4\n",
"final = []\n",
"for split in range(5):\n",
" TYPE_Q = ['T','FE','F','RV','U']\n",
" bp_imp = data_path + 'merged/src/'\n",
" training_set = np.load(data_path + 'training_data/%s-question/train_split%d.npy'%(TYPE_Q[TYPE], split),allow_pickle=True)\n",
" val_set = np.load(data_path + 'training_data/%s-question/val_split%d.npy'%(TYPE_Q[TYPE], split),allow_pickle=True)\n",
" train_filename = []\n",
" train_label = []\n",
" train_mean = []\n",
" train_type = []\n",
"\n",
" for data in training_set:\n",
" one_hot = [0,0,0,0,0,0]\n",
" one_hot[data['vistype']] = 1\n",
" train_filename.append(bp_imp+data['name'])\n",
" train_label.append(one_hot)\n",
" train_mean.append(data['norm_mean_acc_withD'])\n",
" train_type.append(data['norm_%d_withD'%(TYPE)])\n",
" \n",
" val_filename = []\n",
" val_label = []\n",
" val_mean = []\n",
" val_type = []\n",
" for data in val_set:\n",
" one_hot = [0,0,0,0,0,0]\n",
" one_hot[data['vistype']] = 1\n",
" val_filename.append(bp_imp+data['name'])\n",
" val_label.append(one_hot)\n",
" val_mean.append(data['norm_mean_acc_withD'])\n",
" val_type.append(data['norm_%d_withD'%(TYPE)])\n",
"\n",
" mean_list = []\n",
" type_list = []\n",
"\n",
" num_t = 0.0\n",
" if True:\n",
" W = \"./ckpt/RecallNet_xception/U-question/Fold\"+str(split)+\"/RecallNet_xception_ep10.hdf5\"\n",
" model.load_weights(W)\n",
" print('load')\n",
"\n",
" for i in range(len(val_filename)):\n",
" img = preprocess_images([val_filename[i]], 240, 320)\n",
" preds = model.predict(img)\n",
" #if(np.argmax(preds[2][0]) == np.argmax(val_label[i])):\n",
" # num_t += 1\n",
" mean_acc = preds[0][0]\n",
" type_acc = preds[1][0]\n",
" mean_list.append(mean_squared_error([val_mean[i]], mean_acc))\n",
" type_list.append(mean_squared_error([val_type[i]], type_acc))\n",
" final.append((np.mean(mean_list),np.mean(type_list)))\n",
" print(final)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0.1716159 , 0.16309815]), array([0.21483601, 0.13730111]))"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(final ,axis=0),np.std(final ,axis=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/bar4.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/pie4.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/bar1.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/pie2.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/bar3.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/line1.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/bar2.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/pie1.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/line2.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/line4.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/line3.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/pie3.png']\n"
]
}
],
"source": [
"#W_T = \"./ckpt/VMQA_UMSI/T-question/UMSI_weight1:1/Fold0/VMQA_UMSI_ep10_valloss0.6198.hdf5\"\n",
"W_T = \"./ckpt/RecallNet_xception/T-question/Fold4/RecallNet_xception_ep10.hdf5\"\n",
"\n",
"model.load_weights(W_T)\n",
"\n",
"image_path = \"/your/path/to/src\"\n",
"list_img_targets = glob.glob(image_path + '/*.png')\n",
"print(list_img_targets)\n",
"images = preprocess_images(list_img_targets, model_inp_size[0], model_inp_size[1])\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"preds_recallNet = model.predict(images)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(12, 1) [array([[0.6729137 ],\n",
" [0.62224036],\n",
" [0.6644782 ],\n",
" [0.63536835],\n",
" [0.6595377 ],\n",
" [0.6628296 ],\n",
" [0.600067 ],\n",
" [0.64756936],\n",
" [0.600191 ],\n",
" [0.6459451 ],\n",
" [0.6242438 ],\n",
" [0.60450727]], dtype=float32), array([[0.7803583 ],\n",
" [0.67262685],\n",
" [0.75816834],\n",
" [0.718869 ],\n",
" [0.7442306 ],\n",
" [0.76404536],\n",
" [0.8003189 ],\n",
" [0.7172563 ],\n",
" [0.7985541 ],\n",
" [0.81216127],\n",
" [0.6953638 ],\n",
" [0.6564585 ]], dtype=float32)]\n"
]
}
],
"source": [
"print(preds_recallNet[0].shape, preds_recallNet)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}