# visrecall/RecallNet/src/singleduration_models.py
# (scraped page metadata: 823 lines, 33 KiB, Python)
import numpy as np
import keras
import sys
import os
from keras.layers import Layer, Input, Multiply, Dropout, TimeDistributed, LSTM, Activation, Lambda, Conv2D, Dense, GlobalAveragePooling2D, MaxPooling2D, ZeroPadding2D, UpSampling2D, BatchNormalization, Concatenate, Add, DepthwiseConv2D
import keras.backend as K
from keras.models import Model
import tensorflow as tf
from keras.utils import Sequence
import cv2
import scipy.io
import math
from attentive_convlstm_new import AttentiveConvLSTM2D
from dcn_resnet_new import dcn_resnet
from gaussian_prior_new import LearningPrior
from sal_imp_utilities import *
from xception_custom import Xception_wrapper
#from keras.applications import keras_modules_injection
from keras.regularizers import l2
def decoder_block(x, dil_rate=(2,2), print_shapes=True, dec_filt=1024):
    """Dilated-convolution decoder head producing a single-channel heatmap.

    Three conv stages (filters dec_filt, dec_filt//2, dec_filt//4) interleaved
    with bilinear upsampling of 2x, 2x and 4x (16x total), then a 1x1 relu conv.

    Args:
        x: 4D feature tensor (batch, h, w, channels).
        dil_rate: dilation rate applied to every 3x3 conv.
        print_shapes: if True, print intermediate tensor shapes.
        dec_filt: filter count of the first stage.

    Returns:
        Tensor of shape (batch, 16*h, 16*w, 1).
    """
    # BUG FIX: the dilation rate was hard-coded to (2, 2) in every conv below,
    # silently ignoring the `dil_rate` argument; it now honors the argument,
    # consistent with decoder_block_dp. The default preserves old behavior.
    x = Conv2D(dec_filt, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Conv2D(dec_filt, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = UpSampling2D((2,2), interpolation='bilinear')(x)
    x = Conv2D(dec_filt//2, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Conv2D(dec_filt//2, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = UpSampling2D((2,2), interpolation='bilinear')(x)
    x = Conv2D(dec_filt//4, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Conv2D(dec_filt//4, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = UpSampling2D((4,4), interpolation='bilinear')(x)
    if print_shapes: print('Shape after last ups:',x.shape)
    # Final conv to get to a heatmap
    x = Conv2D(1, kernel_size=1, padding='same', activation='relu')(x)
    if print_shapes: print('Shape after 1x1 conv:',x.shape)
    return x
def decoder_block_simple(x, dil_rate=(2,2), print_shapes=True, dec_filt=1024):
    """Plain (non-dilated) decoder: three conv + bilinear-upsample stages
    (2x, 2x, 4x) with halving filter counts, ending in a 1x1 relu heatmap.

    `dil_rate` is accepted for signature parity with decoder_block but unused.
    """
    filters = dec_filt
    for scale in (2, 2, 4):
        x = Conv2D(filters, 3, padding='same', activation='relu')(x)
        x = UpSampling2D((scale, scale), interpolation='bilinear')(x)
        filters //= 2
    if print_shapes:
        print('Shape after last ups:', x.shape)
    # Project down to a single-channel heatmap.
    x = Conv2D(1, kernel_size=1, padding='same', activation='relu')(x)
    if print_shapes:
        print('Shape after 1x1 conv:', x.shape)
    return x
def decoder_block_dp(x, dil_rate=(2,2), print_shapes=True, dec_filt=1024, dp=0.3):
    """Dilated-conv decoder with dropout after each stage.

    NOTE(review): unlike decoder_block (2x * 2x * 4x = 16x total upsampling),
    this variant ends with TWO (4,4) upsamplings, i.e. 2*2*4*4 = 64x total,
    and splits the last conv pair across them. Looks like a possible
    copy/paste slip — confirm the intended output resolution before reuse.
    """
    # Dilated convolutions
    x = Conv2D(dec_filt, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Conv2D(dec_filt, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Dropout(dp)(x)
    x = UpSampling2D((2,2), interpolation='bilinear')(x)
    x = Conv2D(dec_filt//2, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Conv2D(dec_filt//2, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Dropout(dp)(x)
    x = UpSampling2D((2,2), interpolation='bilinear')(x)
    x = Conv2D(dec_filt//4, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Dropout(dp)(x)
    x = UpSampling2D((4,4), interpolation='bilinear')(x)
    x = Conv2D(dec_filt//4, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Dropout(dp)(x)
    x = UpSampling2D((4,4), interpolation='bilinear')(x)
    if print_shapes: print('Shape after last ups:',x.shape)
    # Final conv to get to a heatmap
    x = Conv2D(1, kernel_size=1, padding='same', activation='relu')(x)
    if print_shapes: print('Shape after 1x1 conv:',x.shape)
    return x
######### ENCODER DECODER MODELS #############
def xception_decoder(input_shape = (shape_r, shape_c, 3),
                     verbose=True,
                     print_shapes=True,
                     n_outs=1,
                     ups=8,
                     dil_rate = (2,2)):
    """ImageNet-pretrained Xception encoder + dilated-conv decoder.

    Returns a Model with n_outs identical heatmap outputs (one per loss).
    """
    inp = Input(shape=input_shape)
    # Encoder: Xception backbone without the classification head.
    enc = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception:', enc.output.shape)
    # Decoder: shared heatmap, replicated once per output.
    heatmap = decoder_block(enc.output, dil_rate=dil_rate, print_shapes=print_shapes, dec_filt=512)
    model = Model(inp, [heatmap] * n_outs)
    if verbose:
        model.summary()
    return model
def resnet_decoder(input_shape = (shape_r, shape_c, 3),
                   verbose=True,
                   print_shapes=True,
                   n_outs=1,
                   ups=8,
                   dil_rate = (2,2)):
    """Dilated ResNet encoder + dilated-conv decoder.

    Returns a Model with n_outs identical heatmap outputs (one per loss).
    """
    inp = Input(shape=input_shape)
    # Encoder: dilated ResNet backbone.
    backbone = dcn_resnet(input_tensor=inp)
    if print_shapes:
        print('resnet output shape:', backbone.output.shape)
    # Decoder: shared heatmap, replicated once per output.
    heatmap = decoder_block(backbone.output, dil_rate=dil_rate, print_shapes=print_shapes, dec_filt=512)
    model = Model(inp, [heatmap] * n_outs)
    if verbose:
        model.summary()
    return model
def fcn_vgg16(input_shape=(shape_r, shape_c, 3),
              verbose=True,
              print_shapes=True,
              n_outs=1,
              ups=8,
              dil_rate=(2,2),
              freeze_enc=False,
              freeze_cl=True,
              internal_filts=256,
              num_classes=4,
              dp=0.3,
              weight_decay=0.,
              batch_shape=None):
    """FCN-16s-style single-channel head on a VGG16 encoder.

    Builds the five VGG16 conv blocks (L2 weight decay on every conv), turns
    fc6/fc7 into convolutions, scores fc7 and pool4 with 1x1 convs, fuses them
    (fc7 score upsampled 2x and zero-padded by one row to match pool4), and
    upsamples the sum 32x back to input resolution. Pretrained weights are
    loaded by layer name from a hard-coded checkpoint path.

    Args:
        input_shape: (rows, cols, channels); used when batch_shape is None.
        verbose: print the model summary.
        print_shapes: unused here (shapes are always printed below).
        n_outs: number of (identical) outputs, one per loss.
        weight_decay: L2 regularization factor for conv kernels.
        batch_shape: optional full batch shape overriding input_shape.
        ups, dil_rate, freeze_enc, freeze_cl, internal_filts, num_classes, dp:
            accepted for signature parity with the other builders; unused.

    Returns:
        A keras Model with weights loaded by name.
    """
    if batch_shape:
        img_input = Input(batch_shape=batch_shape)
        image_size = batch_shape[1:3]
    else:
        img_input = Input(shape=input_shape)
        image_size = input_shape[0:2]
    # Block 1
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', kernel_regularizer=l2(weight_decay))(img_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
    # Block 2
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2', kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
    # Block 3
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3', kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
    # Block 4 — pool4 is kept for the skip connection below.
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3', kernel_regularizer=l2(weight_decay))(x)
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
    # Block 5
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1', kernel_regularizer=l2(weight_decay))(pool4)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3', kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
    print("pool5 shape", x.shape)
    # Convolutional layers transfered from fully-connected layers
    x = Conv2D(4096, (7, 7), activation='relu', padding='same', name='fc1', kernel_regularizer=l2(weight_decay))(x)
    x = Dropout(0.5)(x)
    x = Conv2D(4096, (1, 1), activation='relu', padding='same', name='fc2', kernel_regularizer=l2(weight_decay))(x)
    x = Dropout(0.5)(x)
    # classification layer from fc7
    classif_layer_fc7 = Conv2D(1, (1, 1), kernel_initializer='he_normal', activation='linear',
                               padding='valid', strides=(1, 1))(x)
    print("classif_layer_fc7 shape", classif_layer_fc7.shape)
    # Upsampling fc7 classif layer to sum with pool4 classif layer
    classif_layer_fc7_ups = UpSampling2D(size=(2,2), interpolation="bilinear")(classif_layer_fc7)
    print("classif_layer_fc7_ups shape", classif_layer_fc7_ups.shape)

    def concat_one(fc7):
        # Prepend one row of zeros along the height axis so the upsampled fc7
        # score matches pool4's spatial shape.
        # BUG FIX: this previously concatenated the outer closure variable
        # `classif_layer_fc7_ups` instead of the Lambda's own input `fc7`
        # (same tensor at this call site, but a latent bug if the layer is
        # ever reused); it now uses the argument, making the layer
        # self-contained.
        shape_fc7 = K.shape(fc7)
        shape_zeros = (shape_fc7[0], 1, shape_fc7[2], shape_fc7[3])
        return K.concatenate([K.zeros(shape=shape_zeros), fc7], axis=1)

    classif_layer_fc7_ups = Lambda(concat_one)(classif_layer_fc7_ups)
    print("classif_layer_fc7_ups shape after lambda:", classif_layer_fc7_ups.shape)
    # Classification layer from pool4
    classif_layer_pool4 = Conv2D(1, (1, 1), kernel_initializer='he_normal', activation='linear',
                                 padding='valid', strides=(1, 1))(pool4)
    # Fuse the two score maps and upsample back to input resolution.
    x = Add()([classif_layer_pool4, classif_layer_fc7_ups])
    outs_up = UpSampling2D(size=(32, 32), interpolation="bilinear")(x)
    outs_final = [outs_up]*n_outs
    model = Model(img_input, outs_final)
    # NOTE(review): hard-coded relative checkpoint path — breaks when run from
    # a different working directory; consider making it a parameter.
    weights_path = '../../predimportance_shared/models/ckpt/fcn_vgg16/fcn_vgg16_weights_tf_dim_ordering_tf_kernels.h5'
    model.load_weights(weights_path, by_name=True)
    if verbose:
        model.summary()
    return model
############# SAM BASED MODELS ###############
def sam_simple(input_shape = (224, 224, 3), in_conv_filters=512,
               verbose=True, print_shapes=True, n_outs=1, ups=8):
    '''Simple network that uses an attentive convlstm and a few convolutions.'''
    inp = Input(shape=input_shape)
    # Shrink the input with a strided conv followed by a 4x4 max-pool.
    x = Conv2D(filters=in_conv_filters, kernel_size=(3,3), strides=(2, 2),
               padding='same', data_format=None, dilation_rate=(1,1))(inp)
    if print_shapes: print('after first conv')
    x = MaxPooling2D(pool_size=(4,4))(x)
    if print_shapes: print('after maxpool', x.shape)
    # Tile the static feature map along a new time axis for the ConvLSTM.
    x = Lambda(repeat, repeat_shape)(x)
    if print_shapes: print('after repeat', x.shape)
    x = AttentiveConvLSTM2D(filters=512, attentive_filters=512, kernel_size=(3,3),
                            attentive_kernel_size=(3,3), padding='same',
                            return_sequences=False)(x)
    if print_shapes: print('after ACLSTM', x.shape)
    # Upsample and project down to a single-channel map.
    x = UpSampling2D(size=(ups,ups), interpolation='bilinear')(x)
    heatmap = Conv2D(filters=1, kernel_size=(3,3), strides=(1, 1),
                     padding='same', data_format=None, dilation_rate=(1,1))(x)
    if print_shapes: print('output shape', heatmap.shape)
    att_convlstm = Model(inputs=inp, outputs=[heatmap] * n_outs)
    if verbose: att_convlstm.summary()
    return att_convlstm
def sam_resnet_nopriors(input_shape = (224, 224, 3), conv_filters=128, lstm_filters=512,
                        att_filters=512, verbose=True, print_shapes=True, n_outs=1, ups=8):
    '''Sam ResNet with no priors.'''
    inp = Input(shape=input_shape)
    # Encoder: dilated ResNet backbone + a feature-compression conv.
    backbone = dcn_resnet(input_tensor=inp)
    feats = Conv2D(conv_filters, 3, padding='same', activation='relu')(backbone.output)
    if print_shapes: print('Shape after first conv after dcn_resnet:', feats.shape)
    # Attentive ConvLSTM over the time-tiled features.
    x = Lambda(repeat, repeat_shape)(feats)
    x = AttentiveConvLSTM2D(filters=lstm_filters, attentive_filters=att_filters,
                            kernel_size=(3,3), attentive_kernel_size=(3,3),
                            padding='same', return_sequences=False)(x)
    # Dilated convolutions (priors would go here)
    x = Conv2D(conv_filters, 5, padding='same', activation='relu', dilation_rate=(4, 4))(x)
    x = Conv2D(conv_filters, 5, padding='same', activation='relu', dilation_rate=(4, 4))(x)
    # Final conv to get to a heatmap
    heatmap = Conv2D(1, kernel_size=1, padding='same', activation='relu')(x)
    if print_shapes: print('Shape after 1x1 conv:', heatmap.shape)
    # Upsampling back to input shape
    heatmap = UpSampling2D(size=(ups,ups), interpolation='bilinear')(heatmap)
    if print_shapes: print('shape after upsampling', heatmap.shape)
    m = Model(inp, [heatmap] * n_outs)
    if verbose:
        m.summary()
    return m
def sam_resnet_new(input_shape = (shape_r, shape_c, 3),
                   conv_filters=512,
                   lstm_filters=512,
                   att_filters=512,
                   verbose=True,
                   print_shapes=True,
                   n_outs=1,
                   ups=8,
                   nb_gaussian=nb_gaussian):
    '''SAM-ResNet: dilated ResNet encoder, attentive ConvLSTM, two learned
    Gaussian priors, and an upsampled single-channel heatmap head.'''
    inp = Input(shape=input_shape)
    # Encoder + feature-compression conv.
    backbone = dcn_resnet(input_tensor=inp)
    feats = Conv2D(conv_filters, 3, padding='same', activation='relu')(backbone.output)
    if print_shapes:
        print('Shape after first conv after dcn_resnet:', feats.shape)
    # Attentive ConvLSTM over time-tiled features.
    lstm_out = Lambda(repeat, repeat_shape)(feats)
    lstm_out = AttentiveConvLSTM2D(filters=lstm_filters,
                                   attentive_filters=att_filters,
                                   kernel_size=(3,3),
                                   attentive_kernel_size=(3,3),
                                   padding='same',
                                   return_sequences=False)(lstm_out)
    # Two learned Gaussian priors, each concatenated before a dilated conv.
    # Both priors are conditioned on the ConvLSTM output.
    prior_a = LearningPrior(nb_gaussian=nb_gaussian)(lstm_out)
    stage1 = Conv2D(conv_filters, 5, padding='same', activation='relu',
                    dilation_rate=(4, 4))(Concatenate(axis=-1)([lstm_out, prior_a]))
    prior_b = LearningPrior(nb_gaussian=nb_gaussian)(lstm_out)
    stage2 = Conv2D(conv_filters, 5, padding='same', activation='relu',
                    dilation_rate=(4, 4))(Concatenate(axis=-1)([stage1, prior_b]))
    # 1x1 conv -> single-channel map, then upsample toward input size.
    heatmap = Conv2D(1, kernel_size=1, padding='same', activation='relu')(stage2)
    if print_shapes:
        print('Shape after 1x1 conv:', heatmap.shape)
    heatmap = UpSampling2D(size=(ups,ups), interpolation='bilinear')(heatmap)
    if print_shapes:
        print('shape after upsampling', heatmap.shape)
    m = Model(inp, [heatmap] * n_outs)
    if verbose:
        m.summary()
    return m
def sam_xception_new(input_shape = (shape_r, shape_c, 3), conv_filters=512, lstm_filters=512, att_filters=512,
                     verbose=True, print_shapes=True, n_outs=1, ups=8, nb_gaussian=nb_gaussian):
    '''SAM with a custom Xception as encoder.

    Same architecture as sam_resnet_new but with an ImageNet-pretrained
    Xception backbone: feature conv, attentive ConvLSTM, two learned
    Gaussian priors, 1x1 heatmap conv, bilinear upsampling.
    '''
    # BUG FIX: the original created a throwaway duplicate Input, locally
    # re-imported Xception and keras_modules_injection (the latter import is
    # commented out at module level, i.e. unavailable in this environment),
    # and shadowed the module-level Xception_wrapper. Use the Xception_wrapper
    # already imported at the top of the file, as the other builders do.
    inp = Input(shape=input_shape)
    dcn = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes: print('xception:',dcn.output.shape)
    conv_feat = Conv2D(conv_filters, 3, padding='same', activation='relu')(dcn.output)
    if print_shapes:
        print('Shape after first conv after dcn_resnet:',conv_feat.shape)
    # Attentive ConvLSTM over the time-tiled feature map.
    att_convlstm = Lambda(repeat, repeat_shape)(conv_feat)
    att_convlstm = AttentiveConvLSTM2D(filters=lstm_filters, attentive_filters=att_filters, kernel_size=(3,3),
                                       attentive_kernel_size=(3,3), padding='same', return_sequences=False)(att_convlstm)
    # Learned Prior (1)
    priors1 = LearningPrior(nb_gaussian=nb_gaussian)(att_convlstm)
    concat1 = Concatenate(axis=-1)([att_convlstm, priors1])
    dil_conv1 = Conv2D(conv_filters, 5, padding='same', activation='relu', dilation_rate=(4, 4))(concat1)
    # Learned Prior (2) — also conditioned on the ConvLSTM output.
    priors2 = LearningPrior(nb_gaussian=nb_gaussian)(att_convlstm)
    concat2 = Concatenate(axis=-1)([dil_conv1, priors2])
    dil_conv2 = Conv2D(conv_filters, 5, padding='same', activation='relu', dilation_rate=(4, 4))(concat2)
    # Final conv to get to a heatmap
    outs = Conv2D(1, kernel_size=1, padding='same', activation='relu')(dil_conv2)
    if print_shapes:
        print('Shape after 1x1 conv:',outs.shape)
    # Upsampling back to input shape
    outs_up = UpSampling2D(size=(ups,ups), interpolation='bilinear')(outs)
    if print_shapes:
        print('shape after upsampling',outs_up.shape)
    outs_final = [outs_up]*n_outs
    # Building model
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m
def xception_se_lstm_singledur(input_shape = (shape_r, shape_c, 3),
                               conv_filters=256,
                               lstm_filters=512,
                               verbose=True,
                               print_shapes=True,
                               n_outs=1,
                               ups=8,
                               freeze_enc=False,
                               return_sequences=False):
    """Xception encoder -> SE-style LSTM channel gating -> dilated decoder."""
    inp = Input(shape=input_shape)
    # Encoder: ImageNet-pretrained Xception backbone.
    enc = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', enc.output.shape)
    if freeze_enc:
        # Keep the pretrained encoder weights fixed during training.
        for layer in enc.layers:
            layer.trainable = False
    # LSTM-driven squeeze-and-excitation gate over the encoder features.
    gated = se_lstm_block(enc.output, nb_timestep, lstm_filters=lstm_filters,
                          return_sequences=return_sequences)
    # Decoder: dilated-conv heatmap head.
    heatmap = decoder_block(gated, dil_rate=(2,2), print_shapes=print_shapes,
                            dec_filt=conv_filters)
    model = Model(inp, [heatmap] * n_outs)
    if verbose:
        model.summary()
    return model
def se_lstm_block(inp, nb_timestep, units=512, print_shapes=True, lstm_filters=512, return_sequences=False):
    """Squeeze-and-excitation-style channel gate computed by an LSTM.

    Repeats `inp` nb_timestep times along a new time axis, squeezes each step
    with global average pooling + dense, runs an LSTM over the sequence, and
    uses a sigmoid dense layer to rescale `inp`'s channels.
    Output has the same shape as `inp`.
    """
    # (bs, h, w, c) -> (bs, t, h, w, c): duplicate features along a time axis.
    repeated = Lambda(lambda y: K.repeat_elements(K.expand_dims(y, axis=1), nb_timestep, axis=1),
                      lambda s: (s[0], nb_timestep) + s[1:])(inp)
    squeezed = TimeDistributed(GlobalAveragePooling2D())(repeated)
    if print_shapes: print('shape after AvgPool', squeezed.shape)
    squeezed = TimeDistributed(Dense(units, activation='relu'))(squeezed)
    if print_shapes: print('shape after first dense', squeezed.shape)
    # Normally an SE block would feed into another fully connected layer;
    # here the sequence goes through an LSTM instead.
    gate = LSTM(lstm_filters, return_sequences=return_sequences, unroll=True, activation='relu')(squeezed)
    if print_shapes: print('shape after lstm', gate.shape)
    # Project back to the input channel count with a sigmoid gate.
    # (TF1-style `.value` access on the static shape.)
    gate = Dense(inp.shape[-1].value, activation='sigmoid')(gate)
    if print_shapes: print('shape after second dense:', gate.shape)
    # (bs, c) -> (bs, 1, 1, c) so the gate broadcasts over spatial dims.
    gate = Lambda(lambda y: K.expand_dims(K.expand_dims(y, axis=1), axis=1),
                  lambda s: (s[0], 1, 1, s[-1]))(gate)
    if print_shapes: print('shape before mult', gate.shape)
    out = Multiply()([gate, inp])
    print('shape out', out.shape)
    # out is (bs, r, c, 2048)
    return out
def xception_aspp(input_shape = (shape_r, shape_c, 3),
                  conv_filters=256,
                  lstm_filters=512,
                  verbose=True,
                  print_shapes=True,
                  n_outs=1,
                  ups=8,
                  freeze_enc=False,
                  return_sequences=False):
    """Unimplemented sketch of an Xception + ASPP model.

    BUG FIX: the original body referenced an undefined name `x` (NameError on
    first call) and then fell through to `pass`, returning None. Fail loudly
    and explicitly instead. The intended design, per the original inline
    notes, was: Xception encoder -> parallel 3x3 dilated convs at rates
    2/4/8 (ASPP) -> decoder. See RecallNet_xception_aspp for a working
    ASPP variant.
    """
    raise NotImplementedError(
        'xception_aspp is an unfinished stub; use RecallNet_xception_aspp instead')
############# UMSI MODELS ###############
def UMSI(input_shape = (shape_r, shape_c, 3),
         conv_filters=256,
         verbose=True,
         print_shapes=True,
         n_outs=1,
         ups=8,
         freeze_enc=False,
         return_sequences=False):
    """UMSI: Xception encoder + ASPP + global classification branch + decoder.

    Outputs [heatmap, class_probs]: a single-channel importance/saliency map
    and a 6-way softmax over input types.

    Args:
        input_shape: (rows, cols, channels) of the input image.
        freeze_enc: if True, freeze all Xception encoder layers.
        verbose: print the model summary.
        print_shapes: print the encoder output shape.
        conv_filters, n_outs, ups, return_sequences: accepted for signature
            parity with the other builders; unused in this body.

    NOTE(review): the Lambda tiling below hard-codes a 30x40 spatial grid, so
    this model assumes the encoder output is 30x40 (i.e. a fixed input size) —
    confirm against shape_r/shape_c before changing input_shape.
    """
    inp = Input(shape = input_shape)
    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes: print('xception output shapes:',xception.output.shape)
    if freeze_enc:
        # Keep the pretrained encoder weights fixed.
        for layer in xception.layers:
            layer.trainable = False
    # ASPP: parallel branches over the encoder output — one 1x1 conv (c0) and
    # three depthwise-separable 3x3 convs at dilation rates 6/12/18.
    c0 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csep0")(xception.output)
    c6 = DepthwiseConv2D((3,3),dilation_rate=(6,6),padding="same",use_bias=False,name="aspp_csepd6_depthwise")(xception.output)
    c12 = DepthwiseConv2D((3,3),dilation_rate=(12,12),padding="same",use_bias=False,name="aspp_csepd12_depthwise")(xception.output)
    c18 = DepthwiseConv2D((3,3),dilation_rate=(18,18),padding="same",use_bias=False,name="aspp_csepd18_depthwise")(xception.output)
    # BN + ReLU after each depthwise stage.
    c6 = BatchNormalization(name="aspp_csepd6_depthwise_BN")(c6)
    c12 = BatchNormalization(name="aspp_csepd12_depthwise_BN")(c12)
    c18 = BatchNormalization(name="aspp_csepd18_depthwise_BN")(c18)
    c6 = Activation("relu", name = "activation_2")(c6)
    c12 = Activation("relu", name = "activation_4")(c12)
    c18 = Activation("relu", name = "activation_6")(c18)
    # Pointwise 1x1 convs completing the separable convolutions.
    c6 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csepd6_pointwise")(c6)
    c12 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csepd12_pointwise")(c12)
    c18 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csepd18_pointwise")(c18)
    c0 = BatchNormalization(name='aspp0_BN')(c0)
    c6 = BatchNormalization(name='aspp_csepd6_pointwise_BN')(c6)
    c12 = BatchNormalization(name='aspp_csepd12_pointwise_BN')(c12)
    c18 = BatchNormalization(name='aspp_csepd18_pointwise_BN')(c18)
    c0 = Activation("relu", name = "aspp0_activation")(c0)
    c6 = Activation("relu", name = "activation_3")(c6)
    c12 = Activation("relu", name = "activation_5")(c12)
    c18 = Activation("relu", name = "activation_7")(c18)
    # Channel-wise fusion of the four ASPP branches.
    concat1 = Concatenate(name="concatenate_1")([c0,c6,c12,c18])
    ### classification module ###
    # Global branch: strided conv -> BN/ReLU -> dropout -> GAP -> dense.
    x = Conv2D(256, (3,3), strides = (3,3), padding="same",use_bias=False,name = "global_conv")(xception.output)
    x = BatchNormalization(name="global_BN")(x)
    x = Activation("relu", name = "activation_1")(x)
    x = Dropout(.3, name="dropout_1")(x)
    x = GlobalAveragePooling2D(name = "global_average_pooling2d_1")(x)
    x = Dense(256, name="global_dense")(x)
    classif = Dropout(.3, name="dropout_2")(x)
    # Auxiliary output: 6-way input-type classifier.
    out_classif = Dense(6, activation="softmax", name="out_classif")(classif)
    # Fuse the global embedding back into the spatial stream.
    x = Dense(256, name="dense_fusion")(classif)
    def lambda_layer_function(x):
        # Tile the (bs, 256) embedding to (bs, 30, 40, 256) so it can be
        # concatenated with the ASPP feature map.
        x = tf.reshape(x,(tf.shape(x)[0],1,1,256))
        con = [x for i in range(30)]
        con = tf.concat(con,axis=1)
        con = tf.concat([con for i in range(40)],axis=2)
        return con
    x = Lambda(lambda_layer_function, name = "lambda_1")(x)
    concat2 = Concatenate(name="concatenate_2")([concat1, x])
    ### DECODER ###
    # Project the fused features, then conv/dropout/upsample stages
    # (2x, 2x, 4x — 16x total) down to a 1-channel heatmap.
    x = Conv2D(256,(1,1),padding="same",use_bias=False,name = "concat_projection")(concat2)
    x = BatchNormalization(name="concat_projection_BN")(x)
    x = Activation("relu", name="activation_8")(x)
    x = Dropout(.3, name="dropout_3")(x)
    x = Conv2D(256,(3,3),padding="same",use_bias=False,name = "dec_c1")(x)
    x = Conv2D(256,(3,3),padding="same",use_bias=False,name = "dec_c2")(x)
    x = Dropout(.3, name="dec_dp1")(x)
    x = UpSampling2D(size=(2,2), interpolation='bilinear', name="dec_ups1")(x)
    x = Conv2D(128,(3,3),padding="same",use_bias=False,name = "dec_c3")(x)
    x = Conv2D(128,(3,3),padding="same",use_bias=False,name = "dec_c4")(x)
    x = Dropout(.3, name="dec_dp2")(x)
    x = UpSampling2D(size=(2,2), interpolation='bilinear', name="dec_ups2")(x)
    x = Conv2D(64,(3,3),padding="same",use_bias=False,name = "dec_c5")(x)
    x = Dropout(.3, name="dec_dp3")(x)
    x = UpSampling2D(size=(4,4), interpolation='bilinear', name="dec_ups3")(x)
    out_heatmap = Conv2D(1,(1,1),padding="same",use_bias=False,name = "dec_c_cout")(x)
    # Building model
    outs_final = [out_heatmap, out_classif]
    print(out_heatmap.shape)
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m
############# MODELS FOR RecallNet ###############
def RecallNet_UMSI(input_shape = (shape_r, shape_c, 3),
                   conv_filters=256,
                   verbose=True,
                   print_shapes=True,
                   n_outs=1,
                   ups=8,
                   freeze_enc=False,
                   return_sequences=False):
    """RecallNet variant of UMSI: same encoder/ASPP/classification trunk, but
    the spatial decoder is replaced by two scalar regression heads.

    Outputs [type0_acc, mean_acc, out_classif]: two scalar accuracy
    predictions plus the 6-way softmax type classifier.

    Args:
        input_shape: (rows, cols, channels) of the input image.
        freeze_enc: if True, freeze all Xception encoder layers.
        verbose: print the model summary.
        print_shapes: print the encoder output shape.
        conv_filters, n_outs, ups, return_sequences: accepted for signature
            parity with the other builders; unused in this body.

    NOTE(review): the Lambda tiling below hard-codes a 30x40 grid, so this
    assumes a fixed encoder output size — confirm against shape_r/shape_c.
    """
    inp = Input(shape = input_shape)
    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes: print('xception output shapes:',xception.output.shape)
    if freeze_enc:
        # Keep the pretrained encoder weights fixed.
        for layer in xception.layers:
            layer.trainable = False
    # ASPP: 1x1 conv branch (c0) plus depthwise-separable 3x3 convs at
    # dilation rates 6/12/18, each with BN + ReLU around both stages.
    c0 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csep0")(xception.output)
    c6 = DepthwiseConv2D((3,3),dilation_rate=(6,6),padding="same",use_bias=False,name="aspp_csepd6_depthwise")(xception.output)
    c12 = DepthwiseConv2D((3,3),dilation_rate=(12,12),padding="same",use_bias=False,name="aspp_csepd12_depthwise")(xception.output)
    c18 = DepthwiseConv2D((3,3),dilation_rate=(18,18),padding="same",use_bias=False,name="aspp_csepd18_depthwise")(xception.output)
    c6 = BatchNormalization(name="aspp_csepd6_depthwise_BN")(c6)
    c12 = BatchNormalization(name="aspp_csepd12_depthwise_BN")(c12)
    c18 = BatchNormalization(name="aspp_csepd18_depthwise_BN")(c18)
    c6 = Activation("relu", name = "activation_2")(c6)
    c12 = Activation("relu", name = "activation_4")(c12)
    c18 = Activation("relu", name = "activation_6")(c18)
    c6 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csepd6_pointwise")(c6)
    c12 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csepd12_pointwise")(c12)
    c18 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csepd18_pointwise")(c18)
    c0 = BatchNormalization(name='aspp0_BN')(c0)
    c6 = BatchNormalization(name='aspp_csepd6_pointwise_BN')(c6)
    c12 = BatchNormalization(name='aspp_csepd12_pointwise_BN')(c12)
    c18 = BatchNormalization(name='aspp_csepd18_pointwise_BN')(c18)
    c0 = Activation("relu", name = "aspp0_activation")(c0)
    c6 = Activation("relu", name = "activation_3")(c6)
    c12 = Activation("relu", name = "activation_5")(c12)
    c18 = Activation("relu", name = "activation_7")(c18)
    # Channel-wise fusion of the four ASPP branches.
    concat1 = Concatenate(name="concatenate_1")([c0,c6,c12,c18])
    ### classification module ###
    # Global branch: strided conv -> BN/ReLU -> dropout -> GAP -> dense.
    x = Conv2D(256, (3,3), strides = (3,3), padding="same",use_bias=False,name = "global_conv")(xception.output)
    x = BatchNormalization(name="global_BN")(x)
    x = Activation("relu", name = "activation_1")(x)
    x = Dropout(.3, name="dropout_1")(x)
    x = GlobalAveragePooling2D(name = "global_average_pooling2d_1")(x)
    x = Dense(256, name="global_dense")(x)
    classif = Dropout(.3, name="dropout_2")(x)
    # Auxiliary output: 6-way input-type classifier.
    out_classif = Dense(6, activation="softmax", name="out_classif")(classif)
    # Fuse the global embedding back into the spatial stream.
    x = Dense(256, name="dense_fusion")(classif)
    def lambda_layer_function(x):
        # Tile the (bs, 256) embedding to (bs, 30, 40, 256) to match the
        # ASPP feature map for concatenation.
        x = tf.reshape(x,(tf.shape(x)[0],1,1,256))
        con = [x for i in range(30)]
        con = tf.concat(con,axis=1)
        con = tf.concat([con for i in range(40)],axis=2)
        return con
    x = Lambda(lambda_layer_function, name = "lambda_1")(x)
    concat2 = Concatenate(name="concatenate_2")([concat1, x])
    ### DECODER ###
    # Regression heads: GAP over the fused map, then two dense towers
    # predicting mean accuracy and type-0 accuracy as scalars.
    flatten = GlobalAveragePooling2D(name = "global_average_pooling2d_2")(concat2)
    mean_acc = Dense(256, name="mean_dense")(flatten)
    mean_acc = Dense(1, name='out_mean_acc')(mean_acc)
    type0_acc = Dense(256, name="type0_dense")(flatten)
    type0_acc = Dense(1, name='out_type0_acc')(type0_acc)
    # Building model
    outs_final = [type0_acc, mean_acc, out_classif]
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m
#Model for RecallNet
def RecallNet_xception_aspp(input_shape = (shape_r, shape_c, 3),
                            conv_filters=256,
                            verbose=True,
                            print_shapes=True,
                            n_outs=1,
                            ups=8,
                            freeze_enc=False,
                            return_sequences=False):
    """RecallNet with Xception + ASPP trunk and two scalar regression heads.

    Identical to RecallNet_UMSI except the 6-way classifier output is dropped
    (its Dense layer is commented out) and the output order is
    [mean_acc, type0_acc].

    Args:
        input_shape: (rows, cols, channels) of the input image.
        freeze_enc: if True, freeze all Xception encoder layers.
        verbose: print the model summary.
        print_shapes: print the encoder output shape.
        conv_filters, n_outs, ups, return_sequences: accepted for signature
            parity with the other builders; unused in this body.

    NOTE(review): the Lambda tiling below hard-codes a 30x40 grid, so this
    assumes a fixed encoder output size — confirm against shape_r/shape_c.
    """
    inp = Input(shape = input_shape)
    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes: print('xception output shapes:',xception.output.shape)
    if freeze_enc:
        # Keep the pretrained encoder weights fixed.
        for layer in xception.layers:
            layer.trainable = False
    # ASPP: 1x1 conv branch (c0) plus depthwise-separable 3x3 convs at
    # dilation rates 6/12/18, each with BN + ReLU around both stages.
    c0 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csep0")(xception.output)
    c6 = DepthwiseConv2D((3,3),dilation_rate=(6,6),padding="same",use_bias=False,name="aspp_csepd6_depthwise")(xception.output)
    c12 = DepthwiseConv2D((3,3),dilation_rate=(12,12),padding="same",use_bias=False,name="aspp_csepd12_depthwise")(xception.output)
    c18 = DepthwiseConv2D((3,3),dilation_rate=(18,18),padding="same",use_bias=False,name="aspp_csepd18_depthwise")(xception.output)
    c6 = BatchNormalization(name="aspp_csepd6_depthwise_BN")(c6)
    c12 = BatchNormalization(name="aspp_csepd12_depthwise_BN")(c12)
    c18 = BatchNormalization(name="aspp_csepd18_depthwise_BN")(c18)
    c6 = Activation("relu", name = "activation_2")(c6)
    c12 = Activation("relu", name = "activation_4")(c12)
    c18 = Activation("relu", name = "activation_6")(c18)
    c6 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csepd6_pointwise")(c6)
    c12 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csepd12_pointwise")(c12)
    c18 = Conv2D(256,(1,1),padding="same",use_bias=False,name = "aspp_csepd18_pointwise")(c18)
    c0 = BatchNormalization(name='aspp0_BN')(c0)
    c6 = BatchNormalization(name='aspp_csepd6_pointwise_BN')(c6)
    c12 = BatchNormalization(name='aspp_csepd12_pointwise_BN')(c12)
    c18 = BatchNormalization(name='aspp_csepd18_pointwise_BN')(c18)
    c0 = Activation("relu", name = "aspp0_activation")(c0)
    c6 = Activation("relu", name = "activation_3")(c6)
    c12 = Activation("relu", name = "activation_5")(c12)
    c18 = Activation("relu", name = "activation_7")(c18)
    # Channel-wise fusion of the four ASPP branches.
    concat1 = Concatenate(name="concatenate_1")([c0,c6,c12,c18])
    ### classification module ###
    # Global branch: strided conv -> BN/ReLU -> dropout -> GAP -> dense.
    x = Conv2D(256, (3,3), strides = (3,3), padding="same",use_bias=False,name = "global_conv")(xception.output)
    x = BatchNormalization(name="global_BN")(x)
    x = Activation("relu", name = "activation_1")(x)
    x = Dropout(.3, name="dropout_1")(x)
    x = GlobalAveragePooling2D(name = "global_average_pooling2d_1")(x)
    x = Dense(256, name="global_dense")(x)
    classif = Dropout(.3, name="dropout_2")(x)
    # Classifier head intentionally disabled in this variant.
    #out_classif = Dense(6, activation="softmax", name="out_classif")(classif)
    # Fuse the global embedding back into the spatial stream.
    x = Dense(256, name="dense_fusion")(classif)
    def lambda_layer_function(x):
        # Tile the (bs, 256) embedding to (bs, 30, 40, 256) to match the
        # ASPP feature map for concatenation.
        x = tf.reshape(x,(tf.shape(x)[0],1,1,256))
        con = [x for i in range(30)]
        con = tf.concat(con,axis=1)
        con = tf.concat([con for i in range(40)],axis=2)
        return con
    x = Lambda(lambda_layer_function, name = "lambda_1")(x)
    concat2 = Concatenate(name="concatenate_2")([concat1, x])
    ### DECODER ###
    # Regression heads: GAP over the fused map, then two dense towers
    # predicting mean accuracy and type-0 accuracy as scalars.
    flatten = GlobalAveragePooling2D(name = "global_average_pooling2d_2")(concat2)
    mean_acc = Dense(256, name="mean_dense")(flatten)
    mean_acc = Dense(1, name='out_mean_acc')(mean_acc)
    type0_acc = Dense(256, name="type0_dense")(flatten)
    type0_acc = Dense(1, name='out_type0_acc')(type0_acc)
    # Building model
    outs_final = [mean_acc, type0_acc]
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m
#Model for VMQA
def RecallNet_xception(input_shape = (shape_r, shape_c, 3),
                       conv_filters=256,
                       verbose=True,
                       print_shapes=True,
                       n_outs=1,
                       ups=8,
                       freeze_enc=False,
                       return_sequences=False):
    """Xception backbone + small global head predicting two scalar accuracies.

    Outputs [mean_acc, type0_acc], each a single Dense(1) regression from a
    shared 256-d embedding.
    """
    inp = Input(shape=input_shape)
    # Encoder: ImageNet-pretrained Xception backbone.
    enc = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', enc.output.shape)
    if freeze_enc:
        # Keep the pretrained encoder weights fixed.
        for layer in enc.layers:
            layer.trainable = False
    # Global head: strided conv -> BN/ReLU -> dropout -> GAP -> dense embedding.
    h = Conv2D(256, (3,3), strides=(3,3), padding="same", use_bias=False, name="global_conv")(enc.output)
    h = BatchNormalization(name="global_BN")(h)
    h = Activation("relu", name="activation_1")(h)
    h = Dropout(.3, name="dropout_1")(h)
    h = GlobalAveragePooling2D(name="global_average_pooling2d_1")(h)
    h = Dense(256, name="global_dense")(h)
    embedding = Dropout(.3, name="dropout_2")(h)
    # out_classif = Dense(6, activation="softmax", name="out_classif")(classif)
    # Two scalar regression outputs from the shared embedding.
    mean_acc = Dense(1, name='out_mean_acc')(embedding)
    type_0 = Dense(1, name='out_type0_acc')(embedding)
    m = Model(inp, [mean_acc, type_0])
    if verbose:
        m.summary()
    return m