visrecall/RecallNet/notebooks/.ipynb_checkpoints/train_RecallNet-checkpoint.ipynb
2022-05-09 14:32:31 +02:00

113 KiB

In [1]:
import numpy as np
import keras
import sys
import os, glob
from keras.models import Model
import tensorflow as tf
from keras.utils import Sequence
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
#from IPython.display import clear_output
import tqdm 
import math

sys.path.append('../src')

from util import get_model_by_name

from sal_imp_utilities import *
from cb import InteractivePlot

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

%load_ext autoreload
%autoreload 2
Using TensorFlow backend.
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
In [2]:
%%bash
nvidia-smi
Mon May  9 14:23:41 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:29:00.0  On |                  N/A |
| 30%   41C    P2    38W / 184W |    975MiB /  7959MiB |      5%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1377      G   /usr/lib/xorg/Xorg                 95MiB |
|    0   N/A  N/A      2243      G   /usr/lib/xorg/Xorg                432MiB |
|    0   N/A  N/A      2378      G   /usr/bin/gnome-shell               72MiB |
|    0   N/A  N/A      3521      G   ...AAAAAAAAA= --shared-files       16MiB |
|    0   N/A  N/A      5912      G   ...RendererForSitePerProcess      103MiB |
|    0   N/A  N/A     18975      G   ...sktop/bin/mendeleydesktop       42MiB |
|    0   N/A  N/A    168057      G   ...178058406726361824,131072       78MiB |
|    0   N/A  N/A    178813      C   .../envs/tf-cuda9/bin/python       87MiB |
+-----------------------------------------------------------------------------+
In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
os.environ["CUDA_VISIBLE_DEVICES"]
Out[3]:
'0'

ENV

In [4]:
data_path = '/your/path/VisRecall/'
# 0: T, 1: FE, 2: F, 3: RV, 4: U
TYPE = 0
split = 0

Load data

In [5]:
TYPE_Q = ['T','FE','F','RV','U']
bp_imp = data_path + 'merged/src/'
training_set = np.load(data_path + 'training_data/%s-question/train_split%d.npy'%(TYPE_Q[TYPE], split),allow_pickle=True)
val_set = np.load(data_path + 'training_data/%s-question/val_split%d.npy'%(TYPE_Q[TYPE], split),allow_pickle=True)
In [6]:
train_filename = []
train_label = []
train_mean = []
train_type = []

for data in training_set:
    one_hot = [0,0,0,0,0,0]
    one_hot[data['vistype']] = 1
    train_filename.append(bp_imp+data['name'])
    train_label.append(one_hot)
    train_mean.append(data['norm_mean_acc_withD'])
    train_type.append(data['norm_%d_withD'%(TYPE)])
In [7]:
val_filename = []
val_label = []
val_mean = []
val_type = []
for data in val_set:
    one_hot = [0,0,0,0,0,0]
    one_hot[data['vistype']] = 1
    val_filename.append(bp_imp+data['name'])
    val_label.append(one_hot)
    val_mean.append(data['norm_mean_acc_withD'])
    val_type.append(data['norm_%d_withD'%(TYPE)])
#val_filename

Model and training params

In [8]:
# FILL THESE IN: set training parameters 
ckpt_savedir = "ckpt"

load_weights = False
weightspath = ""

batch_size = 4
init_lr = 0.002
lr_reduce_by = .1
reduce_at_epoch = 3
n_epochs = 15

opt = Adam(lr=init_lr) 


model_name = "RecallNet_xception"

model_inp_size = (240, 320)
In [9]:
input_shape = model_inp_size + (3,)
In [10]:
# get model 
model_params = {
    'input_shape': input_shape,
    #'n_outs': len(losses),
    'n_outs': 2
}
model_func, mode = get_model_by_name(model_name)
#assert mode == "simple", "%s is a multi-duration model! Please use the multi-duration notebook to train." % model_name
model = model_func(**model_params)

if load_weights: 
    model.load_weights(weightspath)
    print("load")
WARNING:tensorflow:From /netpool/homes/wangyo/.conda/envs/tf-cuda9/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.

xception output shapes: (?, 30, 40, 2048)
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 240, 320, 3)  0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 119, 159, 32) 864         input_1[0][0]                    
__________________________________________________________________________________________________
block1_conv1_bn (BatchNormaliza (None, 119, 159, 32) 128         block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_conv1_act (Activation)   (None, 119, 159, 32) 0           block1_conv1_bn[0][0]            
__________________________________________________________________________________________________
block1_conv2 (Conv2D)           (None, 117, 157, 64) 18432       block1_conv1_act[0][0]           
__________________________________________________________________________________________________
block1_conv2_bn (BatchNormaliza (None, 117, 157, 64) 256         block1_conv2[0][0]               
__________________________________________________________________________________________________
block1_conv2_act (Activation)   (None, 117, 157, 64) 0           block1_conv2_bn[0][0]            
__________________________________________________________________________________________________
block2_sepconv1 (SeparableConv2 (None, 117, 157, 128 8768        block1_conv2_act[0][0]           
__________________________________________________________________________________________________
block2_sepconv1_bn (BatchNormal (None, 117, 157, 128 512         block2_sepconv1[0][0]            
__________________________________________________________________________________________________
block2_sepconv2_act (Activation (None, 117, 157, 128 0           block2_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block2_sepconv2 (SeparableConv2 (None, 117, 157, 128 17536       block2_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block2_sepconv2_bn (BatchNormal (None, 117, 157, 128 512         block2_sepconv2[0][0]            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 59, 79, 128)  8192        block1_conv2_act[0][0]           
__________________________________________________________________________________________________
block2_pool (MaxPooling2D)      (None, 59, 79, 128)  0           block2_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 59, 79, 128)  512         conv2d_1[0][0]                   
__________________________________________________________________________________________________
add_1 (Add)                     (None, 59, 79, 128)  0           block2_pool[0][0]                
                                                                 batch_normalization_1[0][0]      
__________________________________________________________________________________________________
block3_sepconv1_act (Activation (None, 59, 79, 128)  0           add_1[0][0]                      
__________________________________________________________________________________________________
block3_sepconv1 (SeparableConv2 (None, 59, 79, 256)  33920       block3_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block3_sepconv1_bn (BatchNormal (None, 59, 79, 256)  1024        block3_sepconv1[0][0]            
__________________________________________________________________________________________________
block3_sepconv2_act (Activation (None, 59, 79, 256)  0           block3_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block3_sepconv2 (SeparableConv2 (None, 59, 79, 256)  67840       block3_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block3_sepconv2_bn (BatchNormal (None, 59, 79, 256)  1024        block3_sepconv2[0][0]            
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 30, 40, 256)  32768       add_1[0][0]                      
__________________________________________________________________________________________________
block3_pool (MaxPooling2D)      (None, 30, 40, 256)  0           block3_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 30, 40, 256)  1024        conv2d_2[0][0]                   
__________________________________________________________________________________________________
add_2 (Add)                     (None, 30, 40, 256)  0           block3_pool[0][0]                
                                                                 batch_normalization_2[0][0]      
__________________________________________________________________________________________________
block4_sepconv1_act (Activation (None, 30, 40, 256)  0           add_2[0][0]                      
__________________________________________________________________________________________________
block4_sepconv1 (SeparableConv2 (None, 30, 40, 728)  188672      block4_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block4_sepconv1_bn (BatchNormal (None, 30, 40, 728)  2912        block4_sepconv1[0][0]            
__________________________________________________________________________________________________
block4_sepconv2_act (Activation (None, 30, 40, 728)  0           block4_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block4_sepconv2 (SeparableConv2 (None, 30, 40, 728)  536536      block4_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block4_sepconv2_bn (BatchNormal (None, 30, 40, 728)  2912        block4_sepconv2[0][0]            
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 30, 40, 728)  186368      add_2[0][0]                      
__________________________________________________________________________________________________
block4_pool (MaxPooling2D)      (None, 30, 40, 728)  0           block4_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 30, 40, 728)  2912        conv2d_3[0][0]                   
__________________________________________________________________________________________________
add_3 (Add)                     (None, 30, 40, 728)  0           block4_pool[0][0]                
                                                                 batch_normalization_3[0][0]      
__________________________________________________________________________________________________
block5_sepconv1_act (Activation (None, 30, 40, 728)  0           add_3[0][0]                      
__________________________________________________________________________________________________
block5_sepconv1 (SeparableConv2 (None, 30, 40, 728)  536536      block5_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block5_sepconv1_bn (BatchNormal (None, 30, 40, 728)  2912        block5_sepconv1[0][0]            
__________________________________________________________________________________________________
block5_sepconv2_act (Activation (None, 30, 40, 728)  0           block5_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block5_sepconv2 (SeparableConv2 (None, 30, 40, 728)  536536      block5_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block5_sepconv2_bn (BatchNormal (None, 30, 40, 728)  2912        block5_sepconv2[0][0]            
__________________________________________________________________________________________________
block5_sepconv3_act (Activation (None, 30, 40, 728)  0           block5_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block5_sepconv3 (SeparableConv2 (None, 30, 40, 728)  536536      block5_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block5_sepconv3_bn (BatchNormal (None, 30, 40, 728)  2912        block5_sepconv3[0][0]            
__________________________________________________________________________________________________
add_4 (Add)                     (None, 30, 40, 728)  0           block5_sepconv3_bn[0][0]         
                                                                 add_3[0][0]                      
__________________________________________________________________________________________________
block6_sepconv1_act (Activation (None, 30, 40, 728)  0           add_4[0][0]                      
__________________________________________________________________________________________________
block6_sepconv1 (SeparableConv2 (None, 30, 40, 728)  536536      block6_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block6_sepconv1_bn (BatchNormal (None, 30, 40, 728)  2912        block6_sepconv1[0][0]            
__________________________________________________________________________________________________
block6_sepconv2_act (Activation (None, 30, 40, 728)  0           block6_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block6_sepconv2 (SeparableConv2 (None, 30, 40, 728)  536536      block6_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block6_sepconv2_bn (BatchNormal (None, 30, 40, 728)  2912        block6_sepconv2[0][0]            
__________________________________________________________________________________________________
block6_sepconv3_act (Activation (None, 30, 40, 728)  0           block6_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block6_sepconv3 (SeparableConv2 (None, 30, 40, 728)  536536      block6_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block6_sepconv3_bn (BatchNormal (None, 30, 40, 728)  2912        block6_sepconv3[0][0]            
__________________________________________________________________________________________________
add_5 (Add)                     (None, 30, 40, 728)  0           block6_sepconv3_bn[0][0]         
                                                                 add_4[0][0]                      
__________________________________________________________________________________________________
block7_sepconv1_act (Activation (None, 30, 40, 728)  0           add_5[0][0]                      
__________________________________________________________________________________________________
block7_sepconv1 (SeparableConv2 (None, 30, 40, 728)  536536      block7_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block7_sepconv1_bn (BatchNormal (None, 30, 40, 728)  2912        block7_sepconv1[0][0]            
__________________________________________________________________________________________________
block7_sepconv2_act (Activation (None, 30, 40, 728)  0           block7_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block7_sepconv2 (SeparableConv2 (None, 30, 40, 728)  536536      block7_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block7_sepconv2_bn (BatchNormal (None, 30, 40, 728)  2912        block7_sepconv2[0][0]            
__________________________________________________________________________________________________
block7_sepconv3_act (Activation (None, 30, 40, 728)  0           block7_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block7_sepconv3 (SeparableConv2 (None, 30, 40, 728)  536536      block7_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block7_sepconv3_bn (BatchNormal (None, 30, 40, 728)  2912        block7_sepconv3[0][0]            
__________________________________________________________________________________________________
add_6 (Add)                     (None, 30, 40, 728)  0           block7_sepconv3_bn[0][0]         
                                                                 add_5[0][0]                      
__________________________________________________________________________________________________
block8_sepconv1_act (Activation (None, 30, 40, 728)  0           add_6[0][0]                      
__________________________________________________________________________________________________
block8_sepconv1 (SeparableConv2 (None, 30, 40, 728)  536536      block8_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block8_sepconv1_bn (BatchNormal (None, 30, 40, 728)  2912        block8_sepconv1[0][0]            
__________________________________________________________________________________________________
block8_sepconv2_act (Activation (None, 30, 40, 728)  0           block8_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block8_sepconv2 (SeparableConv2 (None, 30, 40, 728)  536536      block8_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block8_sepconv2_bn (BatchNormal (None, 30, 40, 728)  2912        block8_sepconv2[0][0]            
__________________________________________________________________________________________________
block8_sepconv3_act (Activation (None, 30, 40, 728)  0           block8_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block8_sepconv3 (SeparableConv2 (None, 30, 40, 728)  536536      block8_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block8_sepconv3_bn (BatchNormal (None, 30, 40, 728)  2912        block8_sepconv3[0][0]            
__________________________________________________________________________________________________
add_7 (Add)                     (None, 30, 40, 728)  0           block8_sepconv3_bn[0][0]         
                                                                 add_6[0][0]                      
__________________________________________________________________________________________________
block9_sepconv1_act (Activation (None, 30, 40, 728)  0           add_7[0][0]                      
__________________________________________________________________________________________________
block9_sepconv1 (SeparableConv2 (None, 30, 40, 728)  536536      block9_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block9_sepconv1_bn (BatchNormal (None, 30, 40, 728)  2912        block9_sepconv1[0][0]            
__________________________________________________________________________________________________
block9_sepconv2_act (Activation (None, 30, 40, 728)  0           block9_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block9_sepconv2 (SeparableConv2 (None, 30, 40, 728)  536536      block9_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block9_sepconv2_bn (BatchNormal (None, 30, 40, 728)  2912        block9_sepconv2[0][0]            
__________________________________________________________________________________________________
block9_sepconv3_act (Activation (None, 30, 40, 728)  0           block9_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block9_sepconv3 (SeparableConv2 (None, 30, 40, 728)  536536      block9_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block9_sepconv3_bn (BatchNormal (None, 30, 40, 728)  2912        block9_sepconv3[0][0]            
__________________________________________________________________________________________________
add_8 (Add)                     (None, 30, 40, 728)  0           block9_sepconv3_bn[0][0]         
                                                                 add_7[0][0]                      
__________________________________________________________________________________________________
block10_sepconv1_act (Activatio (None, 30, 40, 728)  0           add_8[0][0]                      
__________________________________________________________________________________________________
block10_sepconv1 (SeparableConv (None, 30, 40, 728)  536536      block10_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block10_sepconv1_bn (BatchNorma (None, 30, 40, 728)  2912        block10_sepconv1[0][0]           
__________________________________________________________________________________________________
block10_sepconv2_act (Activatio (None, 30, 40, 728)  0           block10_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block10_sepconv2 (SeparableConv (None, 30, 40, 728)  536536      block10_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block10_sepconv2_bn (BatchNorma (None, 30, 40, 728)  2912        block10_sepconv2[0][0]           
__________________________________________________________________________________________________
block10_sepconv3_act (Activatio (None, 30, 40, 728)  0           block10_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
block10_sepconv3 (SeparableConv (None, 30, 40, 728)  536536      block10_sepconv3_act[0][0]       
__________________________________________________________________________________________________
block10_sepconv3_bn (BatchNorma (None, 30, 40, 728)  2912        block10_sepconv3[0][0]           
__________________________________________________________________________________________________
add_9 (Add)                     (None, 30, 40, 728)  0           block10_sepconv3_bn[0][0]        
                                                                 add_8[0][0]                      
__________________________________________________________________________________________________
block11_sepconv1_act (Activatio (None, 30, 40, 728)  0           add_9[0][0]                      
__________________________________________________________________________________________________
block11_sepconv1 (SeparableConv (None, 30, 40, 728)  536536      block11_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block11_sepconv1_bn (BatchNorma (None, 30, 40, 728)  2912        block11_sepconv1[0][0]           
__________________________________________________________________________________________________
block11_sepconv2_act (Activatio (None, 30, 40, 728)  0           block11_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block11_sepconv2 (SeparableConv (None, 30, 40, 728)  536536      block11_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block11_sepconv2_bn (BatchNorma (None, 30, 40, 728)  2912        block11_sepconv2[0][0]           
__________________________________________________________________________________________________
block11_sepconv3_act (Activatio (None, 30, 40, 728)  0           block11_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
block11_sepconv3 (SeparableConv (None, 30, 40, 728)  536536      block11_sepconv3_act[0][0]       
__________________________________________________________________________________________________
block11_sepconv3_bn (BatchNorma (None, 30, 40, 728)  2912        block11_sepconv3[0][0]           
__________________________________________________________________________________________________
add_10 (Add)                    (None, 30, 40, 728)  0           block11_sepconv3_bn[0][0]        
                                                                 add_9[0][0]                      
__________________________________________________________________________________________________
block12_sepconv1_act (Activatio (None, 30, 40, 728)  0           add_10[0][0]                     
__________________________________________________________________________________________________
block12_sepconv1 (SeparableConv (None, 30, 40, 728)  536536      block12_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block12_sepconv1_bn (BatchNorma (None, 30, 40, 728)  2912        block12_sepconv1[0][0]           
__________________________________________________________________________________________________
block12_sepconv2_act (Activatio (None, 30, 40, 728)  0           block12_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block12_sepconv2 (SeparableConv (None, 30, 40, 728)  536536      block12_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block12_sepconv2_bn (BatchNorma (None, 30, 40, 728)  2912        block12_sepconv2[0][0]           
__________________________________________________________________________________________________
block12_sepconv3_act (Activatio (None, 30, 40, 728)  0           block12_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
block12_sepconv3 (SeparableConv (None, 30, 40, 728)  536536      block12_sepconv3_act[0][0]       
__________________________________________________________________________________________________
block12_sepconv3_bn (BatchNorma (None, 30, 40, 728)  2912        block12_sepconv3[0][0]           
__________________________________________________________________________________________________
add_11 (Add)                    (None, 30, 40, 728)  0           block12_sepconv3_bn[0][0]        
                                                                 add_10[0][0]                     
__________________________________________________________________________________________________
block13_sepconv1_act (Activatio (None, 30, 40, 728)  0           add_11[0][0]                     
__________________________________________________________________________________________________
block13_sepconv1 (SeparableConv (None, 30, 40, 728)  536536      block13_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block13_sepconv1_bn (BatchNorma (None, 30, 40, 728)  2912        block13_sepconv1[0][0]           
__________________________________________________________________________________________________
block13_sepconv2_act (Activatio (None, 30, 40, 728)  0           block13_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block13_sepconv2 (SeparableConv (None, 30, 40, 1024) 752024      block13_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block13_sepconv2_bn (BatchNorma (None, 30, 40, 1024) 4096        block13_sepconv2[0][0]           
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 30, 40, 1024) 745472      add_11[0][0]                     
__________________________________________________________________________________________________
block13_pool (MaxPooling2D)     (None, 30, 40, 1024) 0           block13_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 30, 40, 1024) 4096        conv2d_4[0][0]                   
__________________________________________________________________________________________________
add_12 (Add)                    (None, 30, 40, 1024) 0           block13_pool[0][0]               
                                                                 batch_normalization_4[0][0]      
__________________________________________________________________________________________________
block14_sepconv1 (SeparableConv (None, 30, 40, 1536) 1582080     add_12[0][0]                     
__________________________________________________________________________________________________
block14_sepconv1_bn (BatchNorma (None, 30, 40, 1536) 6144        block14_sepconv1[0][0]           
__________________________________________________________________________________________________
block14_sepconv1_act (Activatio (None, 30, 40, 1536) 0           block14_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block14_sepconv2 (SeparableConv (None, 30, 40, 2048) 3159552     block14_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block14_sepconv2_bn (BatchNorma (None, 30, 40, 2048) 8192        block14_sepconv2[0][0]           
__________________________________________________________________________________________________
block14_sepconv2_act (Activatio (None, 30, 40, 2048) 0           block14_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
global_conv (Conv2D)            (None, 10, 14, 256)  4718592     block14_sepconv2_act[0][0]       
__________________________________________________________________________________________________
global_BN (BatchNormalization)  (None, 10, 14, 256)  1024        global_conv[0][0]                
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 10, 14, 256)  0           global_BN[0][0]                  
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 10, 14, 256)  0           activation_1[0][0]               
__________________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 256)          0           dropout_1[0][0]                  
__________________________________________________________________________________________________
global_dense (Dense)            (None, 256)          65792       global_average_pooling2d_1[0][0] 
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 256)          0           global_dense[0][0]               
__________________________________________________________________________________________________
out_mean_acc (Dense)            (None, 1)            257         dropout_2[0][0]                  
__________________________________________________________________________________________________
out_type0_acc (Dense)           (None, 1)            257         dropout_2[0][0]                  
==================================================================================================
Total params: 25,647,402
Trainable params: 25,592,362
Non-trainable params: 55,040
__________________________________________________________________________________________________
In [11]:
# set up data generation and checkpoints
if not os.path.exists(ckpt_savedir): 
    os.makedirs(ckpt_savedir)


# Generators
gen_train = RecallNet_Generator(train_filename,
                           train_label,
                           train_mean,
                           train_type,
                           batch_size = 4)

gen_val = RecallNet_Generator(val_filename,
                           val_label,
                           val_mean,
                           val_type,
                           1)

# Callbacks

# where to save checkpoints
#filepath = os.path.join(ckpt_savedir, dataset_sal + "_" + l_str + '_ep{epoch:02d}_valloss{val_loss:.4f}.hdf5')
filepath = os.path.join(ckpt_savedir, model_name, model_name+'_ep{epoch:02d}_valloss{val_loss:.4f}.hdf5')

print("Checkpoints will be saved with format %s" % filepath)

cb_chk = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_weights_only=False, period=1)
cb_plot = InteractivePlot()

def step_decay(epoch):
    lrate = init_lr * math.pow(lr_reduce_by, math.floor((1+epoch)/reduce_at_epoch))
    if epoch%reduce_at_epoch:
        print('Reducing lr. New lr is:', lrate)
    return lrate
cb_sched = LearningRateScheduler(step_decay)

cbs = [cb_chk, cb_sched, cb_plot]
Checkpoints will be saved with format ckpt/RecallNet_xception/RecallNet_xception_ep{epoch:02d}_valloss{val_loss:.4f}.hdf5
In [12]:
#test the generator 
batch_img, out= gen_train.__getitem__(1)
print(out)
[array([0.74688797, 0.84647303, 0.93360996, 0.8340249 ]), array([0.7826087 , 0.95652174, 1.        , 0.86956522])]

Train

In [14]:
#model.compile(optimizer=opt, loss={'out_mean_acc': 'mean_squared_error', 'out_type0_acc':'mean_squared_error','out_classif':'binary_crossentropy'}, loss_weights={'out_mean_acc': 3, 'out_type0_acc': 3, 'out_classif': 1})
model.compile(optimizer=opt, loss={'out_mean_acc': 'mean_squared_error', 'out_type0_acc':'mean_squared_error'}, loss_weights={'out_mean_acc': 1, 'out_type0_acc': 1})

print('Ready to train')
model.fit_generator(gen_train, epochs=n_epochs, verbose=1, callbacks=cbs, validation_data=gen_val, max_queue_size=10, workers=5)
Out[14]:
<keras.callbacks.callbacks.History at 0x7ff9ed6f3950>

validation

In [77]:
data_path = '/your/path/VisRecall/'

if True:
    W = "./ckpt/RecallNet/U-question/Fold4/RecallNet_xception_ep10_valloss0.0313.hdf5"
    model.load_weights(W)
    print('load')
load
In [78]:
from sklearn.metrics import mean_squared_error

mean_list = []
type_list = []

num_t = 0.0

for i in range(len(val_filename)):
    img = preprocess_images([val_filename[i]], 240, 320)
    preds = model.predict(img)
    #if(np.argmax(preds[2][0]) == np.argmax(val_label[i])):
    #    num_t += 1
    mean_acc = preds[0][0]
    type_acc = preds[1][0]
    mean_list.append(mean_squared_error([val_mean[i]], mean_acc))
    type_list.append(mean_squared_error([val_type[i]], type_acc))
    #print(val_label[i])

#print('classificaion: ', num_t/len(val_filename))
print('avg, type: ', np.mean(mean_list),np.mean(type_list))
avg, type:  0.02553353417062307 0.02617591625785007
In [13]:
from sklearn.metrics import mean_squared_error
In [24]:
# 0: T, 1: FE, 2: F, 3: RV, 4: U
TYPE = 4
final = []
for split in range(5):
    TYPE_Q = ['T','FE','F','RV','U']
    bp_imp = data_path + 'merged/src/'
    training_set = np.load(data_path + 'training_data/%s-question/train_split%d.npy'%(TYPE_Q[TYPE], split),allow_pickle=True)
    val_set = np.load(data_path + 'training_data/%s-question/val_split%d.npy'%(TYPE_Q[TYPE], split),allow_pickle=True)
    train_filename = []
    train_label = []
    train_mean = []
    train_type = []

    for data in training_set:
        one_hot = [0,0,0,0,0,0]
        one_hot[data['vistype']] = 1
        train_filename.append(bp_imp+data['name'])
        train_label.append(one_hot)
        train_mean.append(data['norm_mean_acc_withD'])
        train_type.append(data['norm_%d_withD'%(TYPE)])
    
    val_filename = []
    val_label = []
    val_mean = []
    val_type = []
    for data in val_set:
        one_hot = [0,0,0,0,0,0]
        one_hot[data['vistype']] = 1
        val_filename.append(bp_imp+data['name'])
        val_label.append(one_hot)
        val_mean.append(data['norm_mean_acc_withD'])
        val_type.append(data['norm_%d_withD'%(TYPE)])

    mean_list = []
    type_list = []

    num_t = 0.0
    if True:
        W = "./ckpt/RecallNet_xception/U-question/Fold"+str(split)+"/RecallNet_xception_ep10.hdf5"
        model.load_weights(W)
        print('load')

    for i in range(len(val_filename)):
        img = preprocess_images([val_filename[i]], 240, 320)
        preds = model.predict(img)
        #if(np.argmax(preds[2][0]) == np.argmax(val_label[i])):
        #    num_t += 1
        mean_acc = preds[0][0]
        type_acc = preds[1][0]
        mean_list.append(mean_squared_error([val_mean[i]], mean_acc))
        type_list.append(mean_squared_error([val_type[i]], type_acc))
    final.append((np.mean(mean_list),np.mean(type_list)))
    print(final)
load
[(0.1154690275375791, 0.07620564377135046)]
load
[(0.1154690275375791, 0.07620564377135046), (0.10878741427577963, 0.2667425108749781)]
load
[(0.1154690275375791, 0.07620564377135046), (0.10878741427577963, 0.2667425108749781), (0.5932855723937102, 0.38129302771926993)]
load
[(0.1154690275375791, 0.07620564377135046), (0.10878741427577963, 0.2667425108749781), (0.5932855723937102, 0.38129302771926993), (0.015003968232081454, 0.06507362662601399)]
load
[(0.1154690275375791, 0.07620564377135046), (0.10878741427577963, 0.2667425108749781), (0.5932855723937102, 0.38129302771926993), (0.015003968232081454, 0.06507362662601399), (0.02553353417062307, 0.02617591625785007)]
In [25]:
np.mean(final ,axis=0),np.std(final ,axis=0)
Out[25]:
(array([0.1716159 , 0.16309815]), array([0.21483601, 0.13730111]))

Test

In [36]:
#W_T = "./ckpt/VMQA_UMSI/T-question/UMSI_weight1:1/Fold0/VMQA_UMSI_ep10_valloss0.6198.hdf5"
W_T = "./ckpt/RecallNet_xception/T-question/Fold4/RecallNet_xception_ep10.hdf5"

model.load_weights(W_T)

image_path = "/your/path/to/src"
list_img_targets = glob.glob(image_path + '/*.png')
print(list_img_targets)
images = preprocess_images(list_img_targets, model_inp_size[0], model_inp_size[1])
['/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/bar4.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/pie4.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/bar1.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/pie2.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/bar3.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/line1.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/bar2.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/pie1.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/line2.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/line4.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/line3.png', '/netpool/homes/wangyo/Projects/2020_luwei_VisQA/Dataset/application/src/pie3.png']
In [37]:
preds_recallNet = model.predict(images)
In [38]:
print(preds_recallNet[0].shape, preds_recallNet)
(12, 1) [array([[0.6729137 ],
       [0.62224036],
       [0.6644782 ],
       [0.63536835],
       [0.6595377 ],
       [0.6628296 ],
       [0.600067  ],
       [0.64756936],
       [0.600191  ],
       [0.6459451 ],
       [0.6242438 ],
       [0.60450727]], dtype=float32), array([[0.7803583 ],
       [0.67262685],
       [0.75816834],
       [0.718869  ],
       [0.7442306 ],
       [0.76404536],
       [0.8003189 ],
       [0.7172563 ],
       [0.7985541 ],
       [0.81216127],
       [0.6953638 ],
       [0.6564585 ]], dtype=float32)]