import numpy as np
import keras
import sys
import os
from keras.layers import (Layer, Input, Multiply, Dropout, TimeDistributed, LSTM,
                          Activation, Lambda, Conv2D, Dense, GlobalAveragePooling2D,
                          MaxPooling2D, ZeroPadding2D, UpSampling2D, BatchNormalization,
                          Concatenate, Add, DepthwiseConv2D)
import keras.backend as K
from keras.models import Model
import tensorflow as tf
from keras.utils import Sequence
import cv2
import scipy.io
import math
from attentive_convlstm_new import AttentiveConvLSTM2D
from dcn_resnet_new import dcn_resnet
from gaussian_prior_new import LearningPrior
from sal_imp_utilities import *
from xception_custom import Xception_wrapper
#from keras.applications import keras_modules_injection
from keras.regularizers import l2


def decoder_block(x, dil_rate=(2, 2), print_shapes=True, dec_filt=1024):
    """Dilated-conv decoder: three conv stages with 2x/2x/4x bilinear
    upsampling (16x total), ending in a single-channel ReLU heatmap.

    Args:
        x: 4D encoder feature tensor (NHWC).
        dil_rate: dilation rate for every 3x3 conv.
            BUGFIX: an earlier revision hard-coded (2, 2) and ignored this
            parameter; it is now honored (the default keeps old behavior).
        print_shapes: if True, print intermediate tensor shapes.
        dec_filt: filter count of the first stage; halved then quartered.

    Returns:
        One-channel heatmap tensor.
    """
    # Stage 1: full decoder width, then 2x upsampling.
    x = Conv2D(dec_filt, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Conv2D(dec_filt, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = UpSampling2D((2, 2), interpolation='bilinear')(x)
    # Stage 2: half the filters, another 2x upsampling.
    x = Conv2D(dec_filt // 2, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Conv2D(dec_filt // 2, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = UpSampling2D((2, 2), interpolation='bilinear')(x)
    # Stage 3: quarter the filters, final 4x upsampling.
    x = Conv2D(dec_filt // 4, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Conv2D(dec_filt // 4, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = UpSampling2D((4, 4), interpolation='bilinear')(x)
    if print_shapes:
        print('Shape after last ups:', x.shape)
    # Final conv to get to a heatmap
    x = Conv2D(1, kernel_size=1, padding='same', activation='relu')(x)
    if print_shapes:
        print('Shape after 1x1 conv:', x.shape)
    return x


def decoder_block_simple(x, dil_rate=(2, 2), print_shapes=True, dec_filt=1024):
    """Lighter variant of `decoder_block`: one undilated conv per stage.

    `dil_rate` is accepted for signature parity with `decoder_block` but is
    unused here (no conv is dilated).
    """
    x = Conv2D(dec_filt, 3, padding='same', activation='relu')(x)
    x = UpSampling2D((2, 2), interpolation='bilinear')(x)
    x = Conv2D(dec_filt // 2, 3, padding='same', activation='relu')(x)
    x = UpSampling2D((2, 2), interpolation='bilinear')(x)
    x = Conv2D(dec_filt // 4, 3, padding='same', activation='relu')(x)
    x = UpSampling2D((4, 4), interpolation='bilinear')(x)
    if print_shapes:
        print('Shape after last ups:', x.shape)
    # Final conv to get to a heatmap
    x = Conv2D(1, kernel_size=1, padding='same', activation='relu')(x)
    if print_shapes:
        print('Shape after 1x1 conv:', x.shape)
    return x


def decoder_block_dp(x, dil_rate=(2, 2), print_shapes=True, dec_filt=1024, dp=0.3):
    """`decoder_block` variant with dropout after every conv stage.

    NOTE(review): the last stage appears twice (two dec_filt//4 convs, each
    followed by its own 4x upsampling), giving 64x total upsampling versus
    16x in `decoder_block`. Possibly a copy-paste slip, but kept as-is —
    confirm against callers before changing the output resolution.
    """
    # Dilated convolutions
    x = Conv2D(dec_filt, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Conv2D(dec_filt, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Dropout(dp)(x)
    x = UpSampling2D((2, 2), interpolation='bilinear')(x)
    x = Conv2D(dec_filt // 2, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Conv2D(dec_filt // 2, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Dropout(dp)(x)
    x = UpSampling2D((2, 2), interpolation='bilinear')(x)
    x = Conv2D(dec_filt // 4, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Dropout(dp)(x)
    x = UpSampling2D((4, 4), interpolation='bilinear')(x)
    x = Conv2D(dec_filt // 4, 3, padding='same', activation='relu', dilation_rate=dil_rate)(x)
    x = Dropout(dp)(x)
    x = UpSampling2D((4, 4), interpolation='bilinear')(x)
    if print_shapes:
        print('Shape after last ups:', x.shape)
    # Final conv to get to a heatmap
    x = Conv2D(1, kernel_size=1, padding='same', activation='relu')(x)
    if print_shapes:
        print('Shape after 1x1 conv:', x.shape)
    return x


######### ENCODER DECODER MODELS #############

def xception_decoder(input_shape=(shape_r, shape_c, 3), verbose=True, print_shapes=True,
                     n_outs=1, ups=8, dil_rate=(2, 2)):
    """Xception encoder + dilated-conv decoder producing a saliency heatmap.

    The single decoded map is replicated `n_outs` times so multi-output
    losses can each receive a copy. `ups` is unused (kept for API parity).
    """
    inp = Input(shape=input_shape)
    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet',
                                input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception:', xception.output.shape)
    ## DECODER ##
    outs_dec = decoder_block(xception.output, dil_rate=dil_rate,
                             print_shapes=print_shapes, dec_filt=512)
    outs_final = [outs_dec] * n_outs
    # Building model
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m


def resnet_decoder(input_shape=(shape_r, shape_c, 3), verbose=True, print_shapes=True,
                   n_outs=1, ups=8, dil_rate=(2, 2)):
    """Dilated ResNet (DCN) encoder + dilated-conv decoder; see `xception_decoder`."""
    inp = Input(shape=input_shape)
    ### ENCODER ###
    dcn = dcn_resnet(input_tensor=inp)
    if print_shapes:
        print('resnet output shape:', dcn.output.shape)
    ## DECODER ##
    outs_dec = decoder_block(dcn.output, dil_rate=dil_rate,
                             print_shapes=print_shapes, dec_filt=512)
    outs_final = [outs_dec] * n_outs
    # Building model
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m


def fcn_vgg16(input_shape=(shape_r, shape_c, 3), verbose=True, print_shapes=True, n_outs=1,
              ups=8, dil_rate=(2, 2), freeze_enc=False, freeze_cl=True, internal_filts=256,
              num_classes=4, dp=0.3, weight_decay=0., batch_shape=None):
    """FCN-16s-style VGG16: fc7 scores are upsampled 2x, fused with pool4
    scores, then upsampled 32x back toward input resolution. VGG weights are
    loaded by layer name from a fixed checkpoint path.

    NOTE(review): freeze_enc, freeze_cl, internal_filts, num_classes, dp and
    dil_rate are currently unused; kept for interface compatibility.
    """
    if batch_shape:
        img_input = Input(batch_shape=batch_shape)
    else:
        img_input = Input(shape=input_shape)
    # Block 1
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1',
               kernel_regularizer=l2(weight_decay))(img_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2',
               kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
    # Block 2
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1',
               kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2',
               kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
    # Block 3
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1',
               kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2',
               kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3',
               kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
    # Block 4
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1',
               kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2',
               kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3',
               kernel_regularizer=l2(weight_decay))(x)
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)  # kept for skip fusion
    # Block 5
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1',
               kernel_regularizer=l2(weight_decay))(pool4)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2',
               kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3',
               kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
    print("pool5 shape", x.shape)
    # Convolutional layers transferred from fully-connected layers
    x = Conv2D(4096, (7, 7), activation='relu', padding='same', name='fc1',
               kernel_regularizer=l2(weight_decay))(x)
    x = Dropout(0.5)(x)
    x = Conv2D(4096, (1, 1), activation='relu', padding='same', name='fc2',
               kernel_regularizer=l2(weight_decay))(x)
    x = Dropout(0.5)(x)
    # classification layer from fc7
    classif_layer_fc7 = Conv2D(1, (1, 1), kernel_initializer='he_normal',
                               activation='linear', padding='valid', strides=(1, 1))(x)
    print("classif_layer_fc7 shape", classif_layer_fc7.shape)
    # Upsampling fc7 classif layer to sum with pool4 classif layer
    classif_layer_fc7_ups = UpSampling2D(size=(2, 2), interpolation="bilinear")(classif_layer_fc7)
    print("classif_layer_fc7_ups shape", classif_layer_fc7_ups.shape)

    # Lambda: prepend one zero row (axis=1) so the upsampled fc7 map matches
    # pool4's spatial height.
    # NOTE(review): the closure ignores its `fc7` argument and captures
    # classif_layer_fc7_ups from the enclosing scope — works (the same tensor
    # is passed in), but brittle.
    def concat_one(fc7):
        shape_fc7 = K.shape(fc7)
        shape_zeros = (shape_fc7[0], 1, shape_fc7[2], shape_fc7[3])
        return K.concatenate([K.zeros(shape=shape_zeros), classif_layer_fc7_ups], axis=1)

    classif_layer_fc7_ups = Lambda(concat_one)(classif_layer_fc7_ups)
    print("classif_layer_fc7_ups shape after lambda:", classif_layer_fc7_ups.shape)
    # Classification layer from pool4
    classif_layer_pool4 = Conv2D(1, (1, 1), kernel_initializer='he_normal',
                                 activation='linear', padding='valid', strides=(1, 1))(pool4)
    x = Add()([classif_layer_pool4, classif_layer_fc7_ups])
    outs_up = UpSampling2D(size=(32, 32), interpolation="bilinear")(x)
    outs_final = [outs_up] * n_outs
    model = Model(img_input, outs_final)
    weights_path = '../../predimportance_shared/models/ckpt/fcn_vgg16/fcn_vgg16_weights_tf_dim_ordering_tf_kernels.h5'
    model.load_weights(weights_path, by_name=True)
    if verbose:
        model.summary()
    return model


############# SAM BASED MODELS ###############

def sam_simple(input_shape=(224, 224, 3), in_conv_filters=512, verbose=True,
               print_shapes=True, n_outs=1, ups=8):
    '''Simple network that uses an attentive convlstm and a few convolutions.'''
    inp = Input(shape=input_shape)
    x = Conv2D(filters=in_conv_filters, kernel_size=(3, 3), strides=(2, 2), padding='same',
               data_format=None, dilation_rate=(1, 1))(inp)
    if print_shapes:
        print('after first conv')
    x = MaxPooling2D(pool_size=(4, 4))(x)
    if print_shapes:
        print('after maxpool', x.shape)
    # `repeat`/`repeat_shape` come from sal_imp_utilities (star import);
    # presumably they tile the features along a new time axis for the ConvLSTM.
    x = Lambda(repeat, repeat_shape)(x)
    if print_shapes:
        print('after repeat', x.shape)
    x = AttentiveConvLSTM2D(filters=512, attentive_filters=512, kernel_size=(3, 3),
                            attentive_kernel_size=(3, 3), padding='same',
                            return_sequences=False)(x)
    if print_shapes:
        print('after ACLSTM', x.shape)
    x = UpSampling2D(size=(ups, ups), interpolation='bilinear')(x)
    outs_up = Conv2D(filters=1, kernel_size=(3, 3), strides=(1, 1), padding='same',
                     data_format=None, dilation_rate=(1, 1))(x)
    if print_shapes:
        print('output shape', outs_up.shape)
    outs_final = [outs_up] * n_outs
    att_convlstm = Model(inputs=inp, outputs=outs_final)
    if verbose:
        att_convlstm.summary()
    return att_convlstm


def sam_resnet_nopriors(input_shape=(224, 224, 3), conv_filters=128, lstm_filters=512,
                        att_filters=512, verbose=True, print_shapes=True, n_outs=1, ups=8):
    '''Sam ResNet with no priors.'''
    inp = Input(shape=input_shape)
    dcn = dcn_resnet(input_tensor=inp)
    conv_feat = Conv2D(conv_filters, 3, padding='same', activation='relu')(dcn.output)
    if print_shapes:
        print('Shape after first conv after dcn_resnet:', conv_feat.shape)
    # Attentive ConvLSTM
    att_convlstm = Lambda(repeat, repeat_shape)(conv_feat)
    att_convlstm = AttentiveConvLSTM2D(filters=lstm_filters, attentive_filters=att_filters,
                                       kernel_size=(3, 3), attentive_kernel_size=(3, 3),
                                       padding='same', return_sequences=False)(att_convlstm)
    # Dilated convolutions (priors would go here)
    dil_conv1 = Conv2D(conv_filters, 5, padding='same', activation='relu',
                       dilation_rate=(4, 4))(att_convlstm)
    dil_conv2 = Conv2D(conv_filters, 5, padding='same', activation='relu',
                       dilation_rate=(4, 4))(dil_conv1)
    # Final conv to get to a heatmap
    outs = Conv2D(1, kernel_size=1, padding='same', activation='relu')(dil_conv2)
    if print_shapes:
        print('Shape after 1x1 conv:', outs.shape)
    # Upsampling back to input shape
    outs_up = UpSampling2D(size=(ups, ups), interpolation='bilinear')(outs)
    if print_shapes:
        print('shape after upsampling', outs_up.shape)
    outs_final = [outs_up] * n_outs
    # Building model
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m


def sam_resnet_new(input_shape=(shape_r, shape_c, 3), conv_filters=512, lstm_filters=512,
                   att_filters=512, verbose=True, print_shapes=True, n_outs=1, ups=8,
                   nb_gaussian=nb_gaussian):
    '''SAM-ResNet ported from the original code.'''
    inp = Input(shape=input_shape)
    dcn = dcn_resnet(input_tensor=inp)
    conv_feat = Conv2D(conv_filters, 3, padding='same', activation='relu')(dcn.output)
    if print_shapes:
        print('Shape after first conv after dcn_resnet:', conv_feat.shape)
    # Attentive ConvLSTM
    att_convlstm = Lambda(repeat, repeat_shape)(conv_feat)
    att_convlstm = AttentiveConvLSTM2D(filters=lstm_filters, attentive_filters=att_filters,
                                       kernel_size=(3, 3), attentive_kernel_size=(3, 3),
                                       padding='same', return_sequences=False)(att_convlstm)
    # Learned Prior (1): Gaussian prior maps concatenated before each dilated conv.
    priors1 = LearningPrior(nb_gaussian=nb_gaussian)(att_convlstm)
    concat1 = Concatenate(axis=-1)([att_convlstm, priors1])
    dil_conv1 = Conv2D(conv_filters, 5, padding='same', activation='relu',
                       dilation_rate=(4, 4))(concat1)
    # Learned Prior (2)
    priors2 = LearningPrior(nb_gaussian=nb_gaussian)(att_convlstm)
    concat2 = Concatenate(axis=-1)([dil_conv1, priors2])
    dil_conv2 = Conv2D(conv_filters, 5, padding='same', activation='relu',
                       dilation_rate=(4, 4))(concat2)
    # Final conv to get to a heatmap
    outs = Conv2D(1, kernel_size=1, padding='same', activation='relu')(dil_conv2)
    if print_shapes:
        print('Shape after 1x1 conv:', outs.shape)
    # Upsampling back to input shape
    outs_up = UpSampling2D(size=(ups, ups), interpolation='bilinear')(outs)
    if print_shapes:
        print('shape after upsampling', outs_up.shape)
    outs_final = [outs_up] * n_outs
    # Building model
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m


def sam_xception_new(input_shape=(shape_r, shape_c, 3), conv_filters=512, lstm_filters=512,
                     att_filters=512, verbose=True, print_shapes=True, n_outs=1, ups=8,
                     nb_gaussian=nb_gaussian):
    '''SAM with a custom Xception as encoder.'''
    from xception_custom import Xception
    from keras.applications import keras_modules_injection

    #@keras_modules_injection
    def Xception_wrapper(*args, **kwargs):
        # Local wrapper shadowing the module-level import; the injection
        # decorator is disabled, so this is a plain pass-through.
        return Xception(*args, **kwargs)

    # (A duplicate, unused `inp = Input(...)` that preceded the local imports
    # was removed — it created a dangling graph node.)
    inp = Input(shape=input_shape)
    dcn = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception:', dcn.output.shape)
    conv_feat = Conv2D(conv_filters, 3, padding='same', activation='relu')(dcn.output)
    if print_shapes:
        print('Shape after first conv after dcn_resnet:', conv_feat.shape)
    # Attentive ConvLSTM
    att_convlstm = Lambda(repeat, repeat_shape)(conv_feat)
    att_convlstm = AttentiveConvLSTM2D(filters=lstm_filters, attentive_filters=att_filters,
                                       kernel_size=(3, 3), attentive_kernel_size=(3, 3),
                                       padding='same', return_sequences=False)(att_convlstm)
    # Learned Prior (1)
    priors1 = LearningPrior(nb_gaussian=nb_gaussian)(att_convlstm)
    concat1 = Concatenate(axis=-1)([att_convlstm, priors1])
    dil_conv1 = Conv2D(conv_filters, 5, padding='same', activation='relu',
                       dilation_rate=(4, 4))(concat1)
    # Learned Prior (2)
    priors2 = LearningPrior(nb_gaussian=nb_gaussian)(att_convlstm)
    concat2 = Concatenate(axis=-1)([dil_conv1, priors2])
    dil_conv2 = Conv2D(conv_filters, 5, padding='same', activation='relu',
                       dilation_rate=(4, 4))(concat2)
    # Final conv to get to a heatmap
    outs = Conv2D(1, kernel_size=1, padding='same', activation='relu')(dil_conv2)
    if print_shapes:
        print('Shape after 1x1 conv:', outs.shape)
    # Upsampling back to input shape
    outs_up = UpSampling2D(size=(ups, ups), interpolation='bilinear')(outs)
    if print_shapes:
        print('shape after upsampling', outs_up.shape)
    outs_final = [outs_up] * n_outs
    # Building model
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m


def xception_se_lstm_singledur(input_shape=(shape_r, shape_c, 3), conv_filters=256,
                               lstm_filters=512, verbose=True, print_shapes=True, n_outs=1,
                               ups=8, freeze_enc=False, return_sequences=False):
    """Xception encoder gated by an SE-LSTM block, then the shared decoder."""
    inp = Input(shape=input_shape)
    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet',
                                input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False
    ### LSTM over SE representation ###
    # nb_timestep comes from sal_imp_utilities (star import).
    x = se_lstm_block(xception.output, nb_timestep, lstm_filters=lstm_filters,
                      return_sequences=return_sequences)
    ### DECODER ###
    outs_dec = decoder_block(x, dil_rate=(2, 2), print_shapes=print_shapes,
                             dec_filt=conv_filters)
    outs_final = [outs_dec] * n_outs
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m


def se_lstm_block(inp, nb_timestep, units=512, print_shapes=True, lstm_filters=512,
                  return_sequences=False):
    """Squeeze-and-excitation-style channel gate where the excitation MLP's
    second layer is replaced by an LSTM over `nb_timestep` repeats of the
    pooled features. Returns `inp` rescaled channel-wise by a sigmoid gate.
    """
    # Tile the feature map along a new time axis so the LSTM sees a sequence.
    inp_rep = Lambda(lambda y: K.repeat_elements(K.expand_dims(y, axis=1), nb_timestep, axis=1),
                     lambda s: (s[0], nb_timestep) + s[1:])(inp)
    x = TimeDistributed(GlobalAveragePooling2D())(inp_rep)
    if print_shapes:
        print('shape after AvgPool', x.shape)
    x = TimeDistributed(Dense(units, activation='relu'))(x)
    if print_shapes:
        print('shape after first dense', x.shape)
    # Normally se block would feed into another fully connected. Instead, we feed it to an LSTM.
    x = LSTM(lstm_filters, return_sequences=return_sequences, unroll=True,
             activation='relu')(x)
    if print_shapes:
        print('shape after lstm', x.shape)
    # TF1-era Keras: static channel count via `.value`.
    x = Dense(inp.shape[-1].value, activation='sigmoid')(x)
    if print_shapes:
        print('shape after second dense:', x.shape)
    # Reshape the gate to (bs, 1, 1, C) so it broadcasts over spatial dims.
    x = Lambda(lambda y: K.expand_dims(K.expand_dims(y, axis=1), axis=1),
               lambda s: (s[0], 1, 1, s[-1]))(x)
    if print_shapes:
        print('shape before mult', x.shape)
    out = Multiply()([x, inp])
    print('shape out', out.shape)
    # out is (bs, r, c, 2048)
    return out


def xception_aspp(input_shape=(shape_r, shape_c, 3), conv_filters=256, lstm_filters=512,
                  verbose=True, print_shapes=True, n_outs=1, ups=8, freeze_enc=False,
                  return_sequences=False):
    """Unfinished stub.

    NOTE(review): `x` is undefined below, so calling this raises NameError.
    Kept only as a sketch of the intended ASPP head over an Xception encoder.
    """
    # Xception
    # Conv1,2,ASPP
    dil_conv1 = Conv2D(conv_filters, 3, padding='same', activation='relu',
                       dilation_rate=(2, 2))(x)
    dil_conv2 = Conv2D(conv_filters, 3, padding='same', activation='relu',
                       dilation_rate=(4, 4))(x)
    dil_conv3 = Conv2D(conv_filters, 3, padding='same', activation='relu',
                       dilation_rate=(8, 8))(x)
    pass


############# UMSI MODELS ###############

def _aspp_block(enc_out):
    """ASPP head over encoder features: one 1x1 conv branch plus
    depthwise-separable 3x3 branches at dilation rates 6/12/18, each
    BN+ReLU'd, concatenated along channels.

    Layer names are fixed so pretrained UMSI checkpoints still load by name.
    """
    c0 = Conv2D(256, (1, 1), padding="same", use_bias=False, name="aspp_csep0")(enc_out)
    c6 = DepthwiseConv2D((3, 3), dilation_rate=(6, 6), padding="same", use_bias=False,
                         name="aspp_csepd6_depthwise")(enc_out)
    c12 = DepthwiseConv2D((3, 3), dilation_rate=(12, 12), padding="same", use_bias=False,
                          name="aspp_csepd12_depthwise")(enc_out)
    c18 = DepthwiseConv2D((3, 3), dilation_rate=(18, 18), padding="same", use_bias=False,
                          name="aspp_csepd18_depthwise")(enc_out)
    c6 = BatchNormalization(name="aspp_csepd6_depthwise_BN")(c6)
    c12 = BatchNormalization(name="aspp_csepd12_depthwise_BN")(c12)
    c18 = BatchNormalization(name="aspp_csepd18_depthwise_BN")(c18)
    c6 = Activation("relu", name="activation_2")(c6)
    c12 = Activation("relu", name="activation_4")(c12)
    c18 = Activation("relu", name="activation_6")(c18)
    # Pointwise halves of the separable convs.
    c6 = Conv2D(256, (1, 1), padding="same", use_bias=False, name="aspp_csepd6_pointwise")(c6)
    c12 = Conv2D(256, (1, 1), padding="same", use_bias=False, name="aspp_csepd12_pointwise")(c12)
    c18 = Conv2D(256, (1, 1), padding="same", use_bias=False, name="aspp_csepd18_pointwise")(c18)
    c0 = BatchNormalization(name='aspp0_BN')(c0)
    c6 = BatchNormalization(name='aspp_csepd6_pointwise_BN')(c6)
    c12 = BatchNormalization(name='aspp_csepd12_pointwise_BN')(c12)
    c18 = BatchNormalization(name='aspp_csepd18_pointwise_BN')(c18)
    c0 = Activation("relu", name="aspp0_activation")(c0)
    c6 = Activation("relu", name="activation_3")(c6)
    c12 = Activation("relu", name="activation_5")(c12)
    c18 = Activation("relu", name="activation_7")(c18)
    return Concatenate(name="concatenate_1")([c0, c6, c12, c18])


def _classif_features(enc_out):
    """Global classification trunk: strided conv + BN/ReLU + GAP + dense.

    Returns the 256-d, dropout-regularized feature vector that the softmax
    and regression heads share. Layer names are fixed for by-name loading.
    """
    x = Conv2D(256, (3, 3), strides=(3, 3), padding="same", use_bias=False,
               name="global_conv")(enc_out)
    x = BatchNormalization(name="global_BN")(x)
    x = Activation("relu", name="activation_1")(x)
    x = Dropout(.3, name="dropout_1")(x)
    x = GlobalAveragePooling2D(name="global_average_pooling2d_1")(x)
    x = Dense(256, name="global_dense")(x)
    return Dropout(.3, name="dropout_2")(x)


def _fused_spatial_classif(classif):
    """Project the classification vector to 256-d and tile it into a
    30x40x256 map for concatenation with the ASPP features.

    NOTE(review): 30x40 (and the 256 width) are hard-coded to the encoder's
    output resolution for the default (shape_r, shape_c) input — confirm
    before changing the input size.
    """
    x = Dense(256, name="dense_fusion")(classif)

    def lambda_layer_function(t):
        t = tf.reshape(t, (tf.shape(t)[0], 1, 1, 256))
        rows = tf.concat([t for _ in range(30)], axis=1)
        return tf.concat([rows for _ in range(40)], axis=2)

    return Lambda(lambda_layer_function, name="lambda_1")(x)


def UMSI(input_shape=(shape_r, shape_c, 3), conv_filters=256, verbose=True, print_shapes=True,
         n_outs=1, ups=8, freeze_enc=False, return_sequences=False):
    """UMSI: Xception encoder + ASPP fused with a global classification branch.

    Returns a model with outputs [importance heatmap, 6-way softmax].
    """
    inp = Input(shape=input_shape)
    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet',
                                input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False
    # ASPP
    concat1 = _aspp_block(xception.output)
    ### classification module ###
    classif = _classif_features(xception.output)
    out_classif = Dense(6, activation="softmax", name="out_classif")(classif)
    x = _fused_spatial_classif(classif)
    concat2 = Concatenate(name="concatenate_2")([concat1, x])
    ### DECODER ###
    x = Conv2D(256, (1, 1), padding="same", use_bias=False, name="concat_projection")(concat2)
    x = BatchNormalization(name="concat_projection_BN")(x)
    x = Activation("relu", name="activation_8")(x)
    x = Dropout(.3, name="dropout_3")(x)
    x = Conv2D(256, (3, 3), padding="same", use_bias=False, name="dec_c1")(x)
    x = Conv2D(256, (3, 3), padding="same", use_bias=False, name="dec_c2")(x)
    x = Dropout(.3, name="dec_dp1")(x)
    x = UpSampling2D(size=(2, 2), interpolation='bilinear', name="dec_ups1")(x)
    x = Conv2D(128, (3, 3), padding="same", use_bias=False, name="dec_c3")(x)
    x = Conv2D(128, (3, 3), padding="same", use_bias=False, name="dec_c4")(x)
    x = Dropout(.3, name="dec_dp2")(x)
    x = UpSampling2D(size=(2, 2), interpolation='bilinear', name="dec_ups2")(x)
    x = Conv2D(64, (3, 3), padding="same", use_bias=False, name="dec_c5")(x)
    x = Dropout(.3, name="dec_dp3")(x)
    x = UpSampling2D(size=(4, 4), interpolation='bilinear', name="dec_ups3")(x)
    out_heatmap = Conv2D(1, (1, 1), padding="same", use_bias=False, name="dec_c_cout")(x)
    # Building model
    outs_final = [out_heatmap, out_classif]
    print(out_heatmap.shape)
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m


############# MODELS FOR RecallNet ###############

def RecallNet_UMSI(input_shape=(shape_r, shape_c, 3), conv_filters=256, verbose=True,
                   print_shapes=True, n_outs=1, ups=8, freeze_enc=False,
                   return_sequences=False):
    """UMSI backbone with scalar recall heads instead of the heatmap decoder.

    Outputs: [out_type0_acc, out_mean_acc, out_classif] (order preserved).
    """
    inp = Input(shape=input_shape)
    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet',
                                input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False
    # ASPP
    concat1 = _aspp_block(xception.output)
    ### classification module ###
    classif = _classif_features(xception.output)
    out_classif = Dense(6, activation="softmax", name="out_classif")(classif)
    x = _fused_spatial_classif(classif)
    concat2 = Concatenate(name="concatenate_2")([concat1, x])
    ### regression heads over pooled fused features ###
    flatten = GlobalAveragePooling2D(name="global_average_pooling2d_2")(concat2)
    mean_acc = Dense(256, name="mean_dense")(flatten)
    mean_acc = Dense(1, name='out_mean_acc')(mean_acc)
    type0_acc = Dense(256, name="type0_dense")(flatten)
    type0_acc = Dense(1, name='out_type0_acc')(type0_acc)
    # Building model
    outs_final = [type0_acc, mean_acc, out_classif]
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m


#Model for RecallNet
def RecallNet_xception_aspp(input_shape=(shape_r, shape_c, 3), conv_filters=256, verbose=True,
                            print_shapes=True, n_outs=1, ups=8, freeze_enc=False,
                            return_sequences=False):
    """Like `RecallNet_UMSI` but without the classification output.

    Outputs: [out_mean_acc, out_type0_acc] (order preserved).
    """
    inp = Input(shape=input_shape)
    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet',
                                input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False
    # ASPP
    concat1 = _aspp_block(xception.output)
    ### classification module (features only; no softmax head) ###
    classif = _classif_features(xception.output)
    #out_classif = Dense(6, activation="softmax", name="out_classif")(classif)
    x = _fused_spatial_classif(classif)
    concat2 = Concatenate(name="concatenate_2")([concat1, x])
    ### regression heads ###
    flatten = GlobalAveragePooling2D(name="global_average_pooling2d_2")(concat2)
    mean_acc = Dense(256, name="mean_dense")(flatten)
    mean_acc = Dense(1, name='out_mean_acc')(mean_acc)
    type0_acc = Dense(256, name="type0_dense")(flatten)
    type0_acc = Dense(1, name='out_type0_acc')(type0_acc)
    # Building model
    outs_final = [mean_acc, type0_acc]
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m


#Model for VMQA
def RecallNet_xception(input_shape=(shape_r, shape_c, 3), conv_filters=256, verbose=True,
                       print_shapes=True, n_outs=1, ups=8, freeze_enc=False,
                       return_sequences=False):
    """Minimal RecallNet: Xception + global classification trunk, two scalar
    heads directly on the pooled features (no ASPP, no fusion).

    Outputs: [out_mean_acc, out_type0_acc].
    """
    inp = Input(shape=input_shape)
    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet',
                                input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False
    classif = _classif_features(xception.output)
    # out_classif = Dense(6, activation="softmax", name="out_classif")(classif)
    mean_acc = Dense(1, name='out_mean_acc')(classif)
    type_0 = Dense(1, name='out_type0_acc')(classif)
    # Building model
    outs_final = [mean_acc, type_0]
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m