import sys
import os
import math

import numpy as np
import cv2
import scipy.io

import keras
import keras.backend as K
import tensorflow as tf
from keras.layers import (Layer, Input, Multiply, Dropout, DepthwiseConv2D, TimeDistributed,
                          LSTM, Activation, Lambda, Conv2D, Dense, GlobalAveragePooling2D,
                          MaxPooling2D, ZeroPadding2D, UpSampling2D, BatchNormalization,
                          Concatenate)
from keras.models import Model
from keras.utils import Sequence
from keras.applications import keras_modules_injection

from attentive_convlstm_new import AttentiveConvLSTM2D
from dcn_resnet_new import dcn_resnet
from gaussian_prior_new import LearningPrior
from sal_imp_utilities import *
from multiduration_models import decoder_block_timedist
from xception_custom import Xception_wrapper


def xception_cl(input_shape=(None, None, 3),
                verbose=True,
                print_shapes=True,
                n_outs=1,
                ups=8,
                freeze_enc=False,
                dil_rate=(2, 2),
                freeze_cl=True,
                append_classif=True,
                num_classes=5):
    """Xception encoder with a saliency decoder and a classification head."""
    inp = Input(shape=input_shape)

    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False

    ### CLASSIFIER ###
    cl = GlobalAveragePooling2D(name='gap_cl')(xception.output)
    cl = Dense(512, name='dense_cl')(cl)
    cl = Dropout(0.3, name='dropout_cl')(cl)
    cl = Dense(num_classes, activation='softmax', name='dense_cl_out')(cl)

    ### DECODER ###
    outs_dec = decoder_block(xception.output, dil_rate=dil_rate, print_shapes=print_shapes,
                             dec_filt=512, prefix='decoder')

    outs_final = [outs_dec] * n_outs
    if append_classif:
        outs_final.append(cl)

    # Building model
    m = Model(inp, outs_final)  # Last element of outs_final is the classification vector
    if verbose:
        m.summary()

    if freeze_cl:
        print('Freezing classification dense layers')
        m.get_layer('dense_cl').trainable = False
        m.get_layer('dense_cl_out').trainable = False

    return m


def xception_cl_fus(input_shape=(None, None, 3),
                    verbose=True,
                    print_shapes=True,
                    n_outs=1,
                    ups=8,
                    dil_rate=(2, 2),
                    freeze_enc=False,
                    freeze_cl=True,
                    internal_filts=256,
                    num_classes=5,
                    dp=0.3):
    """Xception with classification capabilities that fuses representations from both tasks."""
    inp = Input(shape=input_shape)

    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False

    ### GLOBAL FEATURES ###
    g_n = global_net(xception.output, nfilts=internal_filts, dp=dp)
    if print_shapes:
        print('g_n shapes:', g_n.shape)

    ### CLASSIFIER ###
    # We potentially need another layer here
    out_classif = Dense(num_classes, activation='softmax', name='out_classif')(g_n)

    ### ASPP (MID LEVEL FEATURES) ###
    # Note: this variant uses app(), the BN-free atrous pyramid, rather than aspp().
    aspp_out = app(xception.output, internal_filts)
    if print_shapes:
        print('aspp out shapes:', aspp_out.shape)

    ### FUSION ###
    dense_f = Dense(internal_filts, name='dense_fusion')(g_n)
    if print_shapes:
        print('dense_f shapes:', dense_f.shape)

    # Tile the (batch, C) global-feature vector to (batch, H, W, C) so it can be
    # concatenated with the spatial ASPP features.
    reshap = Lambda(
        lambda x: K.repeat_elements(
            K.expand_dims(
                K.repeat_elements(K.expand_dims(x, axis=1), K.int_shape(aspp_out)[2], axis=1),
                axis=1),
            K.int_shape(aspp_out)[1], axis=1),
        lambda s: (s[0], K.int_shape(aspp_out)[1], K.int_shape(aspp_out)[2], s[1]))(dense_f)
    if print_shapes:
        print('after lambda shapes:', reshap.shape)

    conc = Concatenate()([aspp_out, reshap])

    ### PROJECTION ###
    x = Conv2D(internal_filts, (1, 1), padding='same',
               use_bias=False, name='concat_projection')(conc)
    x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
    x = Activation('relu')(x)
    x = Dropout(dp)(x)

    ### DECODER ###
    outs_dec = decoder_block(x, dil_rate=dil_rate, print_shapes=print_shapes,
                             dec_filt=internal_filts, dp=dp)

    outs_final = [outs_dec] * n_outs
    outs_final.append(out_classif)

    # Building model
    m = Model(inp, outs_final)  # Last element of outs_final is the classification vector

    if freeze_cl:
        m.get_layer('out_classif').trainable = False
        # for l in g_n.layers:
        #     l.trainable = False

    if verbose:
        m.summary()

    return m
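
# --- Usage sketch (illustrative, not from the original training code) ---
# A minimal sketch of building xception_cl and inspecting its outputs. The
# 240x320 input size is an assumption consistent with the skip-layer shape
# comments further down in this file, and predict() requires the ImageNet
# weights to be available to Xception_wrapper.
def _example_xception_cl():
    m = xception_cl(input_shape=(240, 320, 3), verbose=False, print_shapes=False)
    dummy = np.zeros((1, 240, 320, 3), dtype=np.float32)
    outs = m.predict(dummy)
    # With n_outs=1 and append_classif=True: [saliency_map, class_probs]
    print([o.shape for o in outs])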
def xception_cl_fus_aspp(input_shape=(None, None, 3),
                         verbose=True,
                         print_shapes=True,
                         n_outs=1,
                         ups=8,
                         dil_rate=(2, 2),
                         freeze_enc=False,
                         freeze_cl=True,
                         internal_filts=256,
                         num_classes=6,
                         dp=0.3,
                         lambda_layer_for_save=False):
    """Like xception_cl_fus, but uses the batch-normalized aspp() pyramid for mid-level features."""
    inp = Input(shape=input_shape)

    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False

    ### GLOBAL FEATURES ###
    g_n = global_net(xception.output, nfilts=internal_filts, dp=dp)
    if print_shapes:
        print('g_n shapes:', g_n.shape)

    ### CLASSIFIER ###
    # We potentially need another layer here
    out_classif = Dense(num_classes, activation='softmax', name='out_classif')(g_n)

    ### ASPP (MID LEVEL FEATURES) ###
    aspp_out = aspp(xception.output, internal_filts)
    if print_shapes:
        print('aspp out shapes:', aspp_out.shape)

    ### FUSION ###
    dense_f = Dense(internal_filts, name='dense_fusion')(g_n)
    if print_shapes:
        print('dense_f shapes:', dense_f.shape)

    if not lambda_layer_for_save:
        reshap = Lambda(
            lambda x: K.repeat_elements(
                K.expand_dims(
                    K.repeat_elements(K.expand_dims(x, axis=1), K.int_shape(aspp_out)[2], axis=1),
                    axis=1),
                K.int_shape(aspp_out)[1], axis=1),
            lambda s: (s[0], K.int_shape(aspp_out)[1], K.int_shape(aspp_out)[2], s[1]))(dense_f)
    else:
        # Use this lambda layer if you want to be able to use model.save()
        # (set lambda_layer_for_save to True); the 30x40 tile size is hardcoded
        # so the lambda does not close over aspp_out.
        print("Using lambda layer adapted to model.save()")
        reshap = Lambda(
            lambda x: K.repeat_elements(
                K.expand_dims(K.repeat_elements(K.expand_dims(x, axis=1), 40, axis=1), axis=1),
                30, axis=1),
            lambda s: (s[0], 30, 40, s[1]))(dense_f)
        # reshap = FusionReshape()(dense_f)
    if print_shapes:
        print('after lambda shapes:', reshap.shape)

    conc = Concatenate()([aspp_out, reshap])

    ### PROJECTION ###
    x = Conv2D(internal_filts, (1, 1), padding='same',
               use_bias=False, name='concat_projection')(conc)
    x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
    x = Activation('relu')(x)
    x = Dropout(dp)(x)

    ### DECODER ###
    outs_dec = decoder_block(x, dil_rate=dil_rate, print_shapes=print_shapes,
                             dec_filt=internal_filts, dp=dp)

    outs_final = [outs_dec] * n_outs
    outs_final.append(out_classif)

    # Building model
    m = Model(inp, outs_final, name='xception_cl_fus_aspp')  # Last element of outs_final is the classification vector
    if freeze_cl:
        m.get_layer('out_classif').trainable = False
        # for l in g_n.layers:
        #     l.trainable = False
    if verbose:
        m.summary()
    return m


def umsi(input_shape=(None, None, 3),
         verbose=True,
         print_shapes=True,
         n_outs=1,
         ups=8,
         dil_rate=(2, 2),
         freeze_enc=False,
         freeze_cl=True,
         internal_filts=256,
         num_classes=6,
         dp=0.3,
         lambda_layer_for_save=False):
    """UMSI: the fused model above, plus encoder skip connections into the decoder."""
    inp = Input(shape=input_shape)

    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False
    # xception.summary()

    skip_layers = ['block3_sepconv2_bn', 'block1_conv1_act']  # sizes: 119x159x32, 59x79x256
    skip_feature_maps = [xception.get_layer(n).output for n in skip_layers]

    ### GLOBAL FEATURES ###
    g_n = global_net(xception.output, nfilts=internal_filts, dp=dp)
    if print_shapes:
        print('g_n shapes:', g_n.shape)

    ### CLASSIFIER ###
    # We potentially need another layer here
    out_classif = Dense(num_classes, activation='softmax', name='out_classif')(g_n)

    ### ASPP (MID LEVEL FEATURES) ###
    aspp_out = aspp(xception.output, internal_filts)
    if print_shapes:
        print('aspp out shapes:', aspp_out.shape)

    ### FUSION ###
    dense_f = Dense(internal_filts, name='dense_fusion')(g_n)
    if print_shapes:
        print('dense_f shapes:', dense_f.shape)

    if not lambda_layer_for_save:
        reshap = Lambda(
            lambda x: K.repeat_elements(
                K.expand_dims(
                    K.repeat_elements(K.expand_dims(x, axis=1), K.int_shape(aspp_out)[2], axis=1),
                    axis=1),
                K.int_shape(aspp_out)[1], axis=1),
            lambda s: (s[0], K.int_shape(aspp_out)[1], K.int_shape(aspp_out)[2], s[1]))(dense_f)
    else:
        # Use this lambda layer if you want to be able to use model.save()
        # (set lambda_layer_for_save to True); the 30x40 tile size is hardcoded.
        print("Using lambda layer adapted to model.save()")
        reshap = Lambda(
            lambda x: K.repeat_elements(
                K.expand_dims(K.repeat_elements(K.expand_dims(x, axis=1), 40, axis=1), axis=1),
                30, axis=1),
            lambda s: (s[0], 30, 40, s[1]))(dense_f)
        # reshap = FusionReshape()(dense_f)
    if print_shapes:
        print('after lambda shapes:', reshap.shape)

    conc = Concatenate()([aspp_out, reshap])

    ### PROJECTION ###
    x = Conv2D(internal_filts, (1, 1), padding='same',
               use_bias=False, name='concat_projection')(conc)
    x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
    x = Activation('relu')(x)
    x = Dropout(dp)(x)

    ### DECODER ###
    # outs_dec = decoder_block(x, dil_rate=dil_rate, print_shapes=print_shapes, dec_filt=internal_filts, dp=dp)
    outs_dec = decoder_with_skip(x, skip_feature_maps, print_shapes=print_shapes,
                                 dec_filt=internal_filts, dp=dp)

    outs_final = [outs_dec] * n_outs
    outs_final.append(out_classif)

    # Building model
    m = Model(inp, outs_final, name='umsi')  # Last element of outs_final is the classification vector
    if freeze_cl:
        m.get_layer('out_classif').trainable = False
        # for l in g_n.layers:
        #     l.trainable = False
    if verbose:
        m.summary()
    return m
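
# A hedged sketch of compiling UMSI for joint saliency + classification
# training. The loss choices (MSE on the map, categorical cross-entropy on the
# class head) and the 1:1 weighting are illustrative assumptions; the original
# training script may use different losses and weights.
def _example_compile_umsi():
    m = umsi(input_shape=(240, 320, 3), verbose=False, print_shapes=False)
    # One loss per model output: n_outs saliency maps followed by the classifier.
    m.compile(optimizer='adam',
              loss=['mse', 'categorical_crossentropy'],
              loss_weights=[1.0, 1.0])
    return m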
def xception_cl_fus_skipdec(input_shape=(None, None, 3),
                            verbose=True,
                            print_shapes=True,
                            n_outs=1,
                            ups=8,
                            dil_rate=(2, 2),
                            freeze_enc=False,
                            freeze_cl=True,
                            internal_filts=256,
                            num_classes=5,
                            dp=0.3):
    """Variant of xception_cl_fus intended for experimenting with a skip-connection decoder."""
    inp = Input(shape=input_shape)

    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    xception.summary()
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False

    ### GLOBAL FEATURES ###
    g_n = global_net(xception.output, nfilts=internal_filts, dp=dp)
    if print_shapes:
        print('g_n shapes:', g_n.shape)

    ### CLASSIFIER ###
    # We potentially need another layer here
    out_classif = Dense(num_classes, activation='softmax', name='out_classif')(g_n)

    ### ASPP (MID LEVEL FEATURES) ###
    aspp_out = aspp(xception.output, internal_filts)
    if print_shapes:
        print('aspp out shapes:', aspp_out.shape)

    ### FUSION ###
    dense_f = Dense(internal_filts, name='dense_fusion')(g_n)
    if print_shapes:
        print('dense_f shapes:', dense_f.shape)

    reshap = Lambda(
        lambda x: K.repeat_elements(
            K.expand_dims(
                K.repeat_elements(K.expand_dims(x, axis=1), K.int_shape(aspp_out)[2], axis=1),
                axis=1),
            K.int_shape(aspp_out)[1], axis=1),
        lambda s: (s[0], K.int_shape(aspp_out)[1], K.int_shape(aspp_out)[2], s[1]))(dense_f)
    if print_shapes:
        print('after lambda shapes:', reshap.shape)

    conc = Concatenate()([aspp_out, reshap])

    ### PROJECTION ###
    x = Conv2D(internal_filts, (1, 1), padding='same',
               use_bias=False, name='concat_projection')(conc)
    x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
    x = Activation('relu')(x)
    x = Dropout(dp)(x)

    ### DECODER ###
    outs_dec = decoder_block(x, dil_rate=dil_rate, print_shapes=print_shapes,
                             dec_filt=internal_filts, dp=dp)
    # outs_dec = decoder_with_skip(x, dil_rate=dil_rate, print_shapes=print_shapes, dec_filt=internal_filts, dp=dp)

    outs_final = [outs_dec] * n_outs
    outs_final.append(out_classif)

    # Building model
    m = Model(inp, outs_final)  # Last element of outs_final is the classification vector
    if freeze_cl:
        m.get_layer('out_classif').trainable = False
        # for l in g_n.layers:
        #     l.trainable = False
    if verbose:
        m.summary()
    return m
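
# The fusion Lambda used above is just vector broadcasting: it tiles the
# (batch, C) global-feature vector across the HxW grid of the ASPP output so
# the two can be concatenated channel-wise. A numpy equivalent of what the
# nested K.expand_dims / K.repeat_elements calls compute (H=30, W=40 match the
# sizes hardcoded elsewhere in this file):
def _example_fusion_broadcast():
    h, w = 30, 40
    v = np.random.rand(2, 256).astype(np.float32)       # (batch, C)
    tiled = np.tile(v[:, None, None, :], (1, h, w, 1))  # (batch, H, W, C)
    print(tiled.shape)  # (2, 30, 40, 256)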
def global_net(x, nfilts=512, dp=0.1, print_shapes=True):
    """Strided conv + global average pooling producing a global feature vector."""
    x = Conv2D(nfilts, (3, 3), strides=3, padding='same', use_bias=False, name='global_conv')(x)
    if print_shapes:
        print('Shape after global net conv:', x.shape)
    x = BatchNormalization(name='global_BN', epsilon=1e-5)(x)
    x = Activation('relu')(x)
    x = Dropout(dp)(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(nfilts, name='global_dense')(x)
    x = Dropout(dp)(x)
    return x


def app(x, nfilts=256, prefix='app', dils=[6, 12, 18]):
    """Atrous pyramid without batch norm: a 1x1 branch plus three dilated 3x3 branches."""
    x1 = Conv2D(nfilts, 1, padding='same', activation='relu', dilation_rate=(1, 1),
                name=prefix + '_c1x1')(x)
    x2 = Conv2D(nfilts, 3, padding='same', activation='relu', dilation_rate=(dils[0], dils[0]),
                name=prefix + '_c3x3d' + str(dils[0]))(x)
    x3 = Conv2D(nfilts, 3, padding='same', activation='relu', dilation_rate=(dils[1], dils[1]),
                name=prefix + '_c3x3d' + str(dils[1]))(x)
    x4 = Conv2D(nfilts, 3, padding='same', activation='relu', dilation_rate=(dils[2], dils[2]),
                name=prefix + '_c3x3d' + str(dils[2]))(x)
    x = Concatenate()([x1, x2, x3, x4])
    return x


def aspp(x, nfilts=256, prefix='aspp', dils=[6, 12, 18]):
    """ASPP with batch-normalized separable convolutions (DeepLab v3 style)."""
    x1 = Conv2D(nfilts, (1, 1), padding='same', use_bias=False, name=prefix + '_csep0')(x)
    x1 = BatchNormalization(name='aspp0_BN', epsilon=1e-5)(x1)
    x1 = Activation('relu', name='aspp0_activation')(x1)
    # rate = 6
    x2 = SepConv_BN(x, nfilts, prefix + '_csepd' + str(dils[0]), rate=dils[0],
                    depth_activation=True, epsilon=1e-5)
    # rate = 12 (24)
    x3 = SepConv_BN(x, nfilts, prefix + '_csepd' + str(dils[1]), rate=dils[1],
                    depth_activation=True, epsilon=1e-5)
    # rate = 18 (36)
    x4 = SepConv_BN(x, nfilts, prefix + '_csepd' + str(dils[2]), rate=dils[2],
                    depth_activation=True, epsilon=1e-5)
    x = Concatenate()([x1, x2, x3, x4])
    return x


def decoder_with_skip(x, skip_tensors, dil_rate=1, print_shapes=True, dec_filt=1024,
                      dp=0.2, ups=16, prefix='decskip'):
    """Decoder that upsamples and fuses encoder skip connections at each stage."""
    # sizes of input skip connections from Xception: 119x159x32, 117x157x128, 59x79x256
    for i, sk in enumerate(skip_tensors, start=1):
        # Upsample
        x = UpSampling2D((2, 2), interpolation='bilinear', name=prefix + '_ups%d' % i)(x)
        if x.shape[1] != sk.shape[1] or x.shape[2] != sk.shape[2]:
            # Bind sk via a default argument so the lambda does not close over the loop variable
            x = Lambda(lambda t, sk=sk: tf.image.resize(t, (K.int_shape(sk)[1], K.int_shape(sk)[2])))(x)
        # Concatenate
        x = Concatenate()([x, sk])
        # Convolve to reduce feature dimensionality
        x = Conv2D(dec_filt // 2**i, (1, 1), padding='same', use_bias=False,
                   name=prefix + '_proj_%d' % i)(x)
        x = BatchNormalization(name=prefix + '_bn_%d' % i, epsilon=1e-5)(x)
        x = Activation('relu', name=prefix + '_act_%d' % i)(x)
        # Convolve with depthwise separable convs
        x = SepConv_BN(x, dec_filt // 2**i, kernel_size=3, depth_activation=True, epsilon=1e-5,
                       rate=dil_rate, prefix=prefix + '_sepconvA_%d' % i)
        x = SepConv_BN(x, dec_filt // 2**i, kernel_size=3, depth_activation=True, epsilon=1e-5,
                       rate=dil_rate, prefix=prefix + '_sepconvB_%d' % i)
        x = Dropout(dp, name=prefix + '_dp%d' % i)(x)
        if print_shapes:
            print("shape after block %d of dec:" % i, x.shape)

    # Upsampling and normal conv
    # i += 1
    # x = UpSampling2D((2,2), interpolation='bilinear', name=prefix+'_ups_prefinal')(x)
    # x = Conv2D(dec_filt//2**i, (3, 3), padding='same', use_bias=True, name=prefix+'_conv_%d'%i)(x)
    # x = BatchNormalization(name=prefix+'_bn_%d'%i, epsilon=1e-5)(x)
    # x = Activation('relu', name=prefix+'_act_%d'%i)(x)

    # Final upsample to get to desired output size (480x640)
    x = UpSampling2D((4, 4), interpolation='bilinear', name=prefix + '_ups_final')(x)
    if x.shape[1] != shape_r_out or x.shape[2] != shape_c_out:
        x = Lambda(lambda t: tf.image.resize(t, (shape_r_out, shape_c_out)))(x)
    if print_shapes:
        print('Shape after last ups and resize:', x.shape)

    # Final conv to get to a heatmap
    x = Conv2D(1, kernel_size=1, padding='same', activation='relu', name=prefix + '_c_out')(x)
    if print_shapes:
        print('Shape after 1x1 conv:', x.shape)
    return x
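
# A small shape check for the aspp()/app() pyramids above: four parallel
# branches are concatenated, so the output carries 4 * nfilts channels. The
# 30x40x728 input is an illustrative assumption (728 is the channel width of
# the middle flow in the stock Keras Xception; the custom wrapper may differ).
def _example_aspp_shapes():
    t = Input(shape=(30, 40, 728))
    y = aspp(t, nfilts=256)
    print(K.int_shape(y))  # expected: (None, 30, 40, 1024)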
def decoder_block(x, dil_rate=(2, 2), print_shapes=True, dec_filt=1024, dp=0.2, ups=16, prefix='dec'):
    """Plain decoder: stacks of dilated convolutions with progressive upsampling."""
    # Dilated convolutions
    x = Conv2D(dec_filt, 3, padding='same', activation='relu', dilation_rate=dil_rate,
               name=prefix + '_c1')(x)
    x = Conv2D(dec_filt, 3, padding='same', activation='relu', dilation_rate=dil_rate,
               name=prefix + '_c2')(x)
    x = Dropout(dp, name=prefix + '_dp1')(x)
    x = UpSampling2D((2, 2), interpolation='bilinear', name=prefix + '_ups1')(x)

    x = Conv2D(dec_filt // 2, 3, padding='same', activation='relu', dilation_rate=dil_rate,
               name=prefix + '_c3')(x)
    x = Conv2D(dec_filt // 2, 3, padding='same', activation='relu', dilation_rate=dil_rate,
               name=prefix + '_c4')(x)
    x = Dropout(dp, name=prefix + '_dp2')(x)
    x = UpSampling2D((2, 2), interpolation='bilinear', name=prefix + '_ups2')(x)

    x = Conv2D(dec_filt // 4, 3, padding='same', activation='relu', dilation_rate=dil_rate,
               name=prefix + '_c5')(x)
    x = Dropout(dp, name=prefix + '_dp3')(x)
    x = UpSampling2D((4, 4), interpolation='bilinear', name=prefix + '_ups3')(x)
    if print_shapes:
        print('Shape after last ups:', x.shape)

    # Final conv to get to a heatmap
    x = Conv2D(1, kernel_size=1, padding='same', activation='relu', name=prefix + '_c_out')(x)
    if print_shapes:
        print('Shape after 1x1 conv:', x.shape)
    return x


class FusionReshape(Layer):
    """Saveable alternative to the fusion Lambda: tiles a (batch, C) vector to (batch, 30, 40, C)."""

    def __init__(self, **kwargs):
        super(FusionReshape, self).__init__(**kwargs)

    def build(self, input_shape):
        super(FusionReshape, self).build(input_shape)

    def call(self, x):
        return K.repeat_elements(
            K.expand_dims(K.repeat_elements(K.expand_dims(x, axis=1), 40, axis=1), axis=1),
            30, axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], 30, 40, input_shape[1])
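
# A minimal check of FusionReshape: it broadcasts a (batch, channels) vector
# to a (batch, 30, 40, channels) map so it can be concatenated with the 30x40
# ASPP features. The 256-channel input is an illustrative assumption; the
# spatial sizes are the ones hardcoded in the class.
def _example_fusion_reshape():
    v = Input(shape=(256,))
    out = FusionReshape()(v)
    m = Model(v, out)
    y = m.predict(np.ones((2, 256), dtype=np.float32))
    print(y.shape)  # (2, 30, 40, 256)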

##### DEEPLAB V3 CODE #####

def SepConv_BN(x, filters, prefix='scb', stride=1, kernel_size=3, rate=1,
               depth_activation=False, epsilon=1e-3):
    """Separable conv with BN between depthwise & pointwise convolutions.

    Optionally adds an activation after BN, and implements the right "same"
    padding for even kernel sizes.

    Args:
        x: input tensor
        filters: number of filters in the pointwise convolution
        prefix: prefix before layer names
        stride: stride of the depthwise convolution
        kernel_size: kernel size for the depthwise convolution
        rate: atrous rate for the depthwise convolution
        depth_activation: flag to use activation between depthwise & pointwise convs
        epsilon: epsilon to use in the BN layers
    """
    if stride == 1:
        depth_padding = 'same'
    else:
        # Pad explicitly so the strided 'valid' conv behaves like 'same'
        kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
        pad_total = kernel_size_effective - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        x = ZeroPadding2D((pad_beg, pad_end))(x)
        depth_padding = 'valid'

    if not depth_activation:
        x = Activation('relu')(x)
    x = DepthwiseConv2D((kernel_size, kernel_size), strides=(stride, stride),
                        dilation_rate=(rate, rate), padding=depth_padding,
                        use_bias=False, name=prefix + '_depthwise')(x)
    x = BatchNormalization(name=prefix + '_depthwise_BN', epsilon=epsilon)(x)
    if depth_activation:
        x = Activation('relu')(x)
    x = Conv2D(filters, (1, 1), padding='same', use_bias=False, name=prefix + '_pointwise')(x)
    x = BatchNormalization(name=prefix + '_pointwise_BN', epsilon=epsilon)(x)
    if depth_activation:
        x = Activation('relu')(x)
    return x
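
# A quick sanity check for SepConv_BN's stride-2 padding arithmetic: with
# kernel_size=3 and rate=2, the effective kernel is 3 + (3-1)*(2-1) = 5, so
# pad_total=4 is split 2/2 before the 'valid' depthwise conv, halving the
# spatial size exactly. The 64x64x32 input is an illustrative assumption.
def _example_sepconv_bn():
    t = Input(shape=(64, 64, 32))
    y = SepConv_BN(t, filters=64, prefix='demo_scb', stride=2, rate=2, depth_activation=True)
    print(K.int_shape(y))  # (None, 32, 32, 64)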