# -*- coding: utf-8 -*-
"""Convolutional-recurrent layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from keras import backend as K
from keras import activations
from keras import initializers
from keras import regularizers
from keras import constraints
#from keras.layers.recurrent import _generate_dropout_mask
#from keras.layers.recurrent import _standardize_args

import numpy as np
import warnings

from keras.engine.base_layer import InputSpec, Layer
from keras.utils import conv_utils
#from keras.legacy import interfaces
#from keras.legacy.layers import Recurrent, ConvRecurrent2D
from keras.layers.recurrent import RNN
from keras.utils.generic_utils import has_arg
from keras.utils.generic_utils import to_list
from keras.utils.generic_utils import transpose_shape
from tensorflow.python.keras.layers.convolutional_recurrent import ConvRNN2D


def _generate_dropout_mask(ones, rate, training=None, count=1):
    # Local re-implementation of the private helper from
    # `keras.layers.recurrent`, which is no longer importable (see the
    # commented-out imports above).
    def dropped_inputs():
        return K.dropout(ones, rate)

    if count > 1:
        return [K.in_train_phase(
            dropped_inputs,
            ones,
            training=training) for _ in range(count)]
    return K.in_train_phase(
        dropped_inputs,
        ones,
        training=training)


class AttentiveConvLSTM2DCell(Layer):

    def __init__(self, filters,
                 attentive_filters,
                 kernel_size,
                 attentive_kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 attentive_activation='tanh',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 attentive_initializer='zeros',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 attentive_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 attentive_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 attentive_dropout=0.,
                 **kwargs):
        super(AttentiveConvLSTM2DCell, self).__init__(**kwargs)
        self.filters = filters
        self.attentive_filters = attentive_filters
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
        self.attentive_kernel_size = conv_utils.normalize_tuple(
            attentive_kernel_size, 2, 'attentive_kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = K.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate')
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.attentive_activation = activations.get(attentive_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.attentive_initializer = initializers.get(attentive_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.attentive_regularizer = regularizers.get(attentive_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.attentive_constraint = constraints.get(attentive_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        if K.backend() == 'theano' and (dropout or recurrent_dropout):
            warnings.warn(
                'RNN dropout is no longer supported with the Theano backend '
                'due to technical limitations. '
                'You can either set `dropout` and `recurrent_dropout` to 0, '
                'or use the TensorFlow backend.')
            dropout = 0.
            recurrent_dropout = 0.

        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.attentive_dropout = min(1., max(0., attentive_dropout))
        self.state_size = (self.filters, self.filters)
        self._dropout_mask = None
        self._recurrent_dropout_mask = None
        self._attentive_dropout_mask = None

    def build(self, input_shape):
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
        self.kernel_shape = kernel_shape
        print('kernel_shape', kernel_shape)
        recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)
        input_attentive_kernel_shape = self.attentive_kernel_size + (input_dim, self.attentive_filters)
        hidden_attentive_kernel_shape = self.attentive_kernel_size + (self.filters, self.attentive_filters)
        squeeze_attentive_kernel_shape = self.attentive_kernel_size + (self.attentive_filters, 1)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=recurrent_kernel_shape,
            initializer=self.recurrent_initializer,
            name='recurrent_kernel',
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)
        self.input_attentive_kernel = self.add_weight(
            shape=input_attentive_kernel_shape,
            initializer=self.attentive_initializer,
            name='input_attentive_kernel',
            regularizer=self.attentive_regularizer,
            constraint=self.attentive_constraint)
        self.hidden_attentive_kernel = self.add_weight(
            shape=hidden_attentive_kernel_shape,
            initializer=self.attentive_initializer,
            name='hidden_attentive_kernel',
            regularizer=self.attentive_regularizer,
            constraint=self.attentive_constraint)
        self.squeeze_attentive_kernel = self.add_weight(
            shape=squeeze_attentive_kernel_shape,
            initializer=self.attentive_initializer,
            name='squeeze_attentive_kernel',
            regularizer=self.attentive_regularizer,
            constraint=self.attentive_constraint)

        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(_, *args, **kwargs):
                    return K.concatenate([
                        self.bias_initializer((self.filters,), *args, **kwargs),
                        initializers.Ones()((self.filters,), *args, **kwargs),
                        self.bias_initializer((self.filters * 2,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(
                shape=(self.filters * 4,),
                name='bias',
                initializer=bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint)
            # Use the plain bias initializer here: the `unit_forget_bias`
            # closure above produces a tensor of shape (filters * 4,) and only
            # applies to the gate bias.
            self.attentive_bias = self.add_weight(
                shape=(self.attentive_filters * 2,),
                name='attentive_bias',
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint)
        else:
            self.bias = None
            self.attentive_bias = None

        self.kernel_i = self.kernel[:, :, :, :self.filters]
        self.recurrent_kernel_i = self.recurrent_kernel[:, :, :, :self.filters]
        self.kernel_f = self.kernel[:, :, :, self.filters: self.filters * 2]
        self.recurrent_kernel_f = self.recurrent_kernel[:, :, :, self.filters: self.filters * 2]
        self.kernel_c = self.kernel[:, :, :, self.filters * 2: self.filters * 3]
        self.recurrent_kernel_c = self.recurrent_kernel[:, :, :, self.filters * 2: self.filters * 3]
        self.kernel_o = self.kernel[:, :, :, self.filters * 3:]
        self.recurrent_kernel_o = self.recurrent_kernel[:, :, :, self.filters * 3:]

        if self.use_bias:
            self.bias_i = self.bias[:self.filters]
            self.bias_f = self.bias[self.filters: self.filters * 2]
            self.bias_c = self.bias[self.filters * 2: self.filters * 3]
            self.bias_o = self.bias[self.filters * 3:]
            self.bias_wa = self.attentive_bias[:self.attentive_filters]
            self.bias_ua = self.attentive_bias[self.attentive_filters: self.attentive_filters * 2]
        else:
            self.bias_i = None
            self.bias_f = None
            self.bias_c = None
            self.bias_o = None
            # The attentive biases are referenced in `call`, so they must also
            # be defined when `use_bias` is False.
            self.bias_wa = None
            self.bias_ua = None
        self.built = True

    def call(self, inputs, states, training=None):
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[1]),
                self.recurrent_dropout,
                training=training,
                count=4)
        # if (0 < self.attentive_dropout < 1 and
        #         self._attentive_dropout_mask is None):
        #     self._attentive_dropout_mask = _generate_dropout_mask(
        #         K.ones_like(inputs),
        #         self.attentive_dropout,
        #         training=training,
        #         count=4)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask
        # dropout matrices for attentive units
        # att_dp_mask = self._attentive_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        ##### ATTENTION MECHANISM
        # Project the previous hidden state and the current input into a
        # common attentive feature space, squeeze it to a single-channel map,
        # and apply a spatial softmax to obtain attention weights that
        # re-scale the input.
        h_and_x = (self.input_conv(h_tm1, self.hidden_attentive_kernel,
                                   self.bias_ua, padding='same')
                   + self.input_conv(inputs, self.input_attentive_kernel,
                                     self.bias_wa, padding='same'))
        e = self.recurrent_conv(self.attentive_activation(h_and_x),
                                self.squeeze_attentive_kernel)
        a = K.reshape(K.softmax(K.batch_flatten(e)), K.shape(e))
        # Broadcast the single-channel attention map over the input channels
        # (assumes `channels_last` data format).
        inputs = inputs * K.repeat_elements(a, inputs.shape[-1], -1)
        ##### END OF ATTENTION MECHANISM

        if 0 < self.dropout < 1.:
            inputs_i = inputs * dp_mask[0]
            inputs_f = inputs * dp_mask[1]
            inputs_c = inputs * dp_mask[2]
            inputs_o = inputs * dp_mask[3]
        else:
            inputs_i = inputs
            inputs_f = inputs
            inputs_c = inputs
            inputs_o = inputs

        if 0 < self.recurrent_dropout < 1.:
            h_tm1_i = h_tm1 * rec_dp_mask[0]
            h_tm1_f = h_tm1 * rec_dp_mask[1]
            h_tm1_c = h_tm1 * rec_dp_mask[2]
            h_tm1_o = h_tm1 * rec_dp_mask[3]
        else:
            h_tm1_i = h_tm1
            h_tm1_f = h_tm1
            h_tm1_c = h_tm1
            h_tm1_o = h_tm1

        x_i = self.input_conv(inputs_i, self.kernel_i, self.bias_i, padding=self.padding)
        x_f = self.input_conv(inputs_f, self.kernel_f, self.bias_f, padding=self.padding)
        x_c = self.input_conv(inputs_c, self.kernel_c, self.bias_c, padding=self.padding)
        x_o = self.input_conv(inputs_o, self.kernel_o, self.bias_o, padding=self.padding)
        h_i = self.recurrent_conv(h_tm1_i, self.recurrent_kernel_i)
        h_f = self.recurrent_conv(h_tm1_f, self.recurrent_kernel_f)
        h_c = self.recurrent_conv(h_tm1_c, self.recurrent_kernel_c)
        h_o = self.recurrent_conv(h_tm1_o, self.recurrent_kernel_o)

        # Standard LSTM gate equations on the attention-weighted input.
        i = self.recurrent_activation(x_i + h_i)
        f = self.recurrent_activation(x_f + h_f)
        c = f * c_tm1 + i * self.activation(x_c + h_c)
        o = self.recurrent_activation(x_o + h_o)
        h = o * self.activation(c)

        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
        return h, [h, c]

    def input_conv(self, x, w, b=None, padding='valid'):
        conv_out = K.conv2d(x, w, strides=self.strides,
                            padding=padding,
                            data_format=self.data_format,
                            dilation_rate=self.dilation_rate)
        if b is not None:
            conv_out = K.bias_add(conv_out, b,
                                  data_format=self.data_format)
        return conv_out
    def recurrent_conv(self, x, w):
        conv_out = K.conv2d(x, w, strides=(1, 1),
                            padding='same',
                            data_format=self.data_format)
        return conv_out

    def get_config(self):
        config = {'filters': self.filters,
                  'attentive_filters': self.attentive_filters,
                  'kernel_size': self.kernel_size,
                  'attentive_kernel_size': self.attentive_kernel_size,
                  'strides': self.strides,
                  'padding': self.padding,
                  'data_format': self.data_format,
                  'dilation_rate': self.dilation_rate,
                  'activation': activations.serialize(self.activation),
                  'recurrent_activation': activations.serialize(
                      self.recurrent_activation),
                  'attentive_activation': activations.serialize(
                      self.attentive_activation),
                  'use_bias': self.use_bias,
                  'kernel_initializer': initializers.serialize(
                      self.kernel_initializer),
                  'recurrent_initializer': initializers.serialize(
                      self.recurrent_initializer),
                  'attentive_initializer': initializers.serialize(
                      self.attentive_initializer),
                  'bias_initializer': initializers.serialize(self.bias_initializer),
                  'unit_forget_bias': self.unit_forget_bias,
                  'kernel_regularizer': regularizers.serialize(
                      self.kernel_regularizer),
                  'recurrent_regularizer': regularizers.serialize(
                      self.recurrent_regularizer),
                  'attentive_regularizer': regularizers.serialize(
                      self.attentive_regularizer),
                  'bias_regularizer': regularizers.serialize(self.bias_regularizer),
                  'kernel_constraint': constraints.serialize(
                      self.kernel_constraint),
                  'recurrent_constraint': constraints.serialize(
                      self.recurrent_constraint),
                  'attentive_constraint': constraints.serialize(
                      self.attentive_constraint),
                  'bias_constraint': constraints.serialize(self.bias_constraint),
                  'dropout': self.dropout,
                  'recurrent_dropout': self.recurrent_dropout,
                  'attentive_dropout': self.attentive_dropout}
        base_config = super(AttentiveConvLSTM2DCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class AttentiveConvLSTM2D(ConvRNN2D):
    """Attentive Convolutional LSTM.

    It is similar to an LSTM layer, but the input transformations and
    recurrent transformations are both convolutional, and an attention
    mechanism driven by the previous hidden state re-weights the input
    at every time step.

    Arguments:
        filters: Integer, the dimensionality of the output space
            (i.e. the number of output filters in the convolution).
        attentive_filters: Integer, the number of filters used by the
            attention convolutions over the input and the hidden state.
        kernel_size: An integer or tuple/list of n integers, specifying the
            dimensions of the convolution window.
        attentive_kernel_size: An integer or tuple/list of n integers,
            specifying the dimensions of the attention convolution window.
        strides: An integer or tuple/list of n integers,
            specifying the strides of the convolution.
            Specifying any stride value != 1 is incompatible with specifying
            any `dilation_rate` value != 1.
        padding: One of `"valid"` or `"same"` (case-insensitive).
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, time, ..., channels)`
            while `channels_first` corresponds to
            inputs with shape `(batch, time, channels, ...)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        dilation_rate: An integer or tuple/list of n integers, specifying
            the dilation rate to use for dilated convolution.
            Currently, specifying any `dilation_rate` value != 1 is
            incompatible with specifying any `strides` value != 1.
        activation: Activation function to use.
            Defaults to hyperbolic tangent (`tanh`).
            If you pass `None`, no activation is applied
            (i.e. "linear" activation: `a(x) = x`).
        recurrent_activation: Activation function to use
            for the recurrent step.
        attentive_activation: Activation function applied to the combined
            attention features before they are squeezed to a single-channel
            attention map.
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs.
        recurrent_initializer: Initializer for the `recurrent_kernel`
            weights matrix,
            used for the linear transformation of the recurrent state.
        attentive_initializer: Initializer for the attention kernels.
        bias_initializer: Initializer for the bias vector.
        unit_forget_bias: Boolean.
            If True, add 1 to the bias of the forget gate at initialization.
            Use in combination with `bias_initializer="zeros"`.
            This is recommended in [Jozefowicz et al.]
            (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix.
        recurrent_regularizer: Regularizer function applied to
            the `recurrent_kernel` weights matrix.
        attentive_regularizer: Regularizer function applied to
            the attention kernels.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix.
        recurrent_constraint: Constraint function applied to
            the `recurrent_kernel` weights matrix.
        attentive_constraint: Constraint function applied to
            the attention kernels.
        bias_constraint: Constraint function applied to the bias vector.
        return_sequences: Boolean. Whether to return the last output
            in the output sequence, or the full sequence.
        go_backwards: Boolean (default False).
            If True, process the input sequence backwards.
        stateful: Boolean (default False). If True, the last state
            for each sample at index i in a batch will be used as initial
            state for the sample of index i in the following batch.
        dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the inputs.
        recurrent_dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the recurrent state.
        attentive_dropout: Float between 0 and 1.
            Accepted for API symmetry; the attentive dropout mask is
            currently not applied.

    Input shape:
        - if data_format='channels_first'
            5D tensor with shape:
            `(samples, time, channels, rows, cols)`
        - if data_format='channels_last'
            5D tensor with shape:
            `(samples, time, rows, cols, channels)`

    Output shape:
        - if `return_sequences`
            - if data_format='channels_first'
                5D tensor with shape:
                `(samples, time, filters, output_row, output_col)`
            - if data_format='channels_last'
                5D tensor with shape:
                `(samples, time, output_row, output_col, filters)`
        - else
            - if data_format='channels_first'
                4D tensor with shape:
                `(samples, filters, output_row, output_col)`
            - if data_format='channels_last'
                4D tensor with shape:
                `(samples, output_row, output_col, filters)`
        where `output_row` and `output_col` depend on the shape of the filter
        and the padding.

    Raises:
        ValueError: in case of invalid constructor arguments.

    References:
        - [Convolutional LSTM Network: A Machine Learning Approach for
          Precipitation Nowcasting](http://arxiv.org/abs/1506.04214v1)
          The current implementation does not include the feedback loop
          on the cell's output.
""" def __init__(self, filters, attentive_filters, kernel_size, attentive_kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation='tanh', recurrent_activation='hard_sigmoid', attentive_activation='tanh', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', attentive_initializer='zeros', bias_initializer='zeros', unit_forget_bias=True, kernel_regularizer=None, recurrent_regularizer=None, attentive_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, attentive_constraint=None, bias_constraint=None, return_sequences=False, go_backwards=False, stateful=False, dropout=0., recurrent_dropout=0., attentive_dropout=0., **kwargs): cell = AttentiveConvLSTM2DCell(filters=filters, attentive_filters=attentive_filters, kernel_size=kernel_size, attentive_kernel_size=attentive_kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, recurrent_activation=recurrent_activation, attentive_activation=attentive_activation, use_bias=use_bias, kernel_initializer=kernel_initializer, recurrent_initializer=recurrent_initializer, attentive_initializer=attentive_initializer, bias_initializer=bias_initializer, unit_forget_bias=unit_forget_bias, kernel_regularizer=kernel_regularizer, recurrent_regularizer=recurrent_regularizer, attentive_regularizer=attentive_regularizer, bias_regularizer=bias_regularizer, kernel_constraint=kernel_constraint, recurrent_constraint=recurrent_constraint, attentive_constraint=attentive_constraint, bias_constraint=bias_constraint, dropout=dropout, recurrent_dropout=recurrent_dropout, attentive_dropout=attentive_dropout) super(AttentiveConvLSTM2D, self).__init__(cell, return_sequences=return_sequences, go_backwards=go_backwards, stateful=stateful, **kwargs) self.activity_regularizer = regularizers.get(activity_regularizer) def call(self, inputs, mask=None, training=None, initial_state=None): return super(AttentiveConvLSTM2D, self).call(inputs, mask=mask, training=training, initial_state=initial_state) @property def filters(self): return self.cell.filters @property def attentive_filters(self): return self.cell.attentive_filters @property def kernel_size(self): return self.cell.kernel_size @property def attentive_kernel_size(self): return self.cell.attentive_kernel_size @property def strides(self): return self.cell.strides @property def padding(self): return self.cell.padding @property def data_format(self): return self.cell.data_format @property def dilation_rate(self): return self.cell.dilation_rate @property def activation(self): return self.cell.activation @property def recurrent_activation(self): return self.cell.recurrent_activation @property def attentive_activation(self): return self.cell.attentive_activation @property def use_bias(self): return self.cell.use_bias @property def kernel_initializer(self): return self.cell.kernel_initializer @property def recurrent_initializer(self): return self.cell.recurrent_initializer @property def attentive_initializer(self): return self.cell.attentive_initializer @property def bias_initializer(self): return self.cell.bias_initializer @property def unit_forget_bias(self): return self.cell.unit_forget_bias @property def kernel_regularizer(self): return self.cell.kernel_regularizer @property def recurrent_regularizer(self): return self.cell.recurrent_regularizer @property def attentive_regularizer(self): return self.cell.attentive_regularizer 
    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def attentive_constraint(self):
        return self.cell.attentive_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    @property
    def attentive_dropout(self):
        return self.cell.attentive_dropout

    def get_config(self):
        config = {'filters': self.filters,
                  'attentive_filters': self.attentive_filters,
                  'kernel_size': self.kernel_size,
                  'attentive_kernel_size': self.attentive_kernel_size,
                  'strides': self.strides,
                  'padding': self.padding,
                  'data_format': self.data_format,
                  'dilation_rate': self.dilation_rate,
                  'activation': activations.serialize(self.activation),
                  'recurrent_activation': activations.serialize(
                      self.recurrent_activation),
                  'attentive_activation': activations.serialize(
                      self.attentive_activation),
                  'use_bias': self.use_bias,
                  'kernel_initializer': initializers.serialize(
                      self.kernel_initializer),
                  'recurrent_initializer': initializers.serialize(
                      self.recurrent_initializer),
                  'attentive_initializer': initializers.serialize(
                      self.attentive_initializer),
                  'bias_initializer': initializers.serialize(self.bias_initializer),
                  'unit_forget_bias': self.unit_forget_bias,
                  'kernel_regularizer': regularizers.serialize(
                      self.kernel_regularizer),
                  'recurrent_regularizer': regularizers.serialize(
                      self.recurrent_regularizer),
                  'attentive_regularizer': regularizers.serialize(
                      self.attentive_regularizer),
                  'bias_regularizer': regularizers.serialize(self.bias_regularizer),
                  'activity_regularizer': regularizers.serialize(
                      self.activity_regularizer),
                  'kernel_constraint': constraints.serialize(
                      self.kernel_constraint),
                  'recurrent_constraint': constraints.serialize(
                      self.recurrent_constraint),
                  'attentive_constraint': constraints.serialize(
                      self.attentive_constraint),
                  'bias_constraint': constraints.serialize(self.bias_constraint),
                  'dropout': self.dropout,
                  'recurrent_dropout': self.recurrent_dropout,
                  'attentive_dropout': self.attentive_dropout}
        base_config = super(AttentiveConvLSTM2D, self).get_config()
        del base_config['cell']
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        return cls(**config)
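

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the layer definition).
# It assumes a Keras 2.2.x / TF 1.x environment matching the imports above and
# `channels_last` data; the shapes below (3 timesteps of 32x32 feature maps
# with 64 channels) are arbitrary placeholder values.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from keras.layers import Input
    from keras.models import Model

    # 5D input: (batch, time, rows, cols, channels)
    x = Input(shape=(3, 32, 32, 64))

    # Attentive ConvLSTM that keeps the spatial resolution (`padding='same'`)
    # and returns only the last hidden state as a 4D tensor.
    y = AttentiveConvLSTM2D(filters=64,
                            attentive_filters=64,
                            kernel_size=(3, 3),
                            attentive_kernel_size=(3, 3),
                            padding='same',
                            return_sequences=False)(x)

    model = Model(inputs=x, outputs=y)
    model.summary()  # expected output shape: (None, 32, 32, 64)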