# -*- coding: utf-8 -*-
"""Convolutional-recurrent layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from keras import backend as K
from keras import activations
from keras import initializers
from keras import regularizers
from keras import constraints
# from keras.layers.recurrent import _generate_dropout_mask
# from keras.layers.recurrent import _standardize_args

import numpy as np
import warnings

from keras.engine.base_layer import InputSpec, Layer
from keras.utils import conv_utils
# from keras.legacy import interfaces
# from keras.legacy.layers import Recurrent, ConvRecurrent2D
from keras.layers.recurrent import RNN
from keras.utils.generic_utils import has_arg
from keras.utils.generic_utils import to_list
from keras.utils.generic_utils import transpose_shape
from tensorflow.python.keras.layers.convolutional_recurrent import ConvRNN2D


def _generate_dropout_mask(ones, rate, training=None, count=1):
    """Local copy of the dropout-mask helper from `keras.layers.recurrent`."""
    def dropped_inputs():
        return K.dropout(ones, rate)

    if count > 1:
        return [K.in_train_phase(dropped_inputs, ones, training=training)
                for _ in range(count)]
    return K.in_train_phase(dropped_inputs, ones, training=training)


class AttentiveConvLSTM2DCell(Layer):
    """Cell class for the `AttentiveConvLSTM2D` layer."""

    def __init__(self,
                 filters,
                 attentive_filters,
                 kernel_size,
                 attentive_kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 attentive_activation='tanh',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 attentive_initializer='zeros',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 attentive_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 attentive_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 attentive_dropout=0.,
                 **kwargs):
        super(AttentiveConvLSTM2DCell, self).__init__(**kwargs)
        self.filters = filters
        self.attentive_filters = attentive_filters
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
        self.attentive_kernel_size = conv_utils.normalize_tuple(
            attentive_kernel_size, 2, 'attentive_kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = K.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2,
                                                        'dilation_rate')
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.attentive_activation = activations.get(attentive_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.attentive_initializer = initializers.get(attentive_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.attentive_regularizer = regularizers.get(attentive_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.attentive_constraint = constraints.get(attentive_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        if K.backend() == 'theano' and (dropout or recurrent_dropout):
            warnings.warn(
                'RNN dropout is no longer supported with the Theano backend '
                'due to technical limitations. '
                'You can either set `dropout` and `recurrent_dropout` to 0, '
                'or use the TensorFlow backend.')
            dropout = 0.
            recurrent_dropout = 0.
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.attentive_dropout = min(1., max(0., attentive_dropout))
        self.state_size = (self.filters, self.filters)
        self._dropout_mask = None
        self._recurrent_dropout_mask = None
        self._attentive_dropout_mask = None

    def build(self, input_shape):

        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters * 4)

        self.kernel_shape = kernel_shape
        recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)
        input_attentive_kernel_shape = self.attentive_kernel_size + (
            input_dim, self.attentive_filters)
        hidden_attentive_kernel_shape = self.attentive_kernel_size + (
            self.filters, self.attentive_filters)
        squeeze_attentive_kernel_shape = self.attentive_kernel_size + (
            self.attentive_filters, 1)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        self.recurrent_kernel = self.add_weight(
            shape=recurrent_kernel_shape,
            initializer=self.recurrent_initializer,
            name='recurrent_kernel',
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

        self.input_attentive_kernel = self.add_weight(
            shape=input_attentive_kernel_shape,
            initializer=self.attentive_initializer,
            name='input_attentive_kernel',
            regularizer=self.attentive_regularizer,
            constraint=self.attentive_constraint)

        self.hidden_attentive_kernel = self.add_weight(
            shape=hidden_attentive_kernel_shape,
            initializer=self.attentive_initializer,
            name='hidden_attentive_kernel',
            regularizer=self.attentive_regularizer,
            constraint=self.attentive_constraint)

        self.squeeze_attentive_kernel = self.add_weight(
            shape=squeeze_attentive_kernel_shape,
            initializer=self.attentive_initializer,
            name='squeeze_attentive_kernel',
            regularizer=self.attentive_regularizer,
            constraint=self.attentive_constraint)

        if self.use_bias:
            if self.unit_forget_bias:

                def bias_initializer(_, *args, **kwargs):
                    return K.concatenate([
                        self.bias_initializer((self.filters,), *args, **kwargs),
                        initializers.Ones()((self.filters,), *args, **kwargs),
                        self.bias_initializer((self.filters * 2,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(
                shape=(self.filters * 4,),
                name='bias',
                initializer=bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint)

            # The attentive bias always uses the plain bias initializer: the
            # `unit_forget_bias` initializer above produces a vector of length
            # `filters * 4`, which would not match this weight's shape.
            self.attentive_bias = self.add_weight(
                shape=(self.attentive_filters * 2,),
                name='attentive_bias',
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint)

        else:
            self.bias = None
            self.attentive_bias = None

        self.kernel_i = self.kernel[:, :, :, :self.filters]
        self.recurrent_kernel_i = self.recurrent_kernel[:, :, :, :self.filters]
        self.kernel_f = self.kernel[:, :, :, self.filters: self.filters * 2]
        self.recurrent_kernel_f = (self.recurrent_kernel[:, :, :, self.filters:
                                                         self.filters * 2])
        self.kernel_c = self.kernel[:, :, :, self.filters * 2: self.filters * 3]
        self.recurrent_kernel_c = (self.recurrent_kernel[:, :, :, self.filters * 2:
                                                         self.filters * 3])
        self.kernel_o = self.kernel[:, :, :, self.filters * 3:]
        self.recurrent_kernel_o = self.recurrent_kernel[:, :, :, self.filters * 3:]

        if self.use_bias:
            self.bias_i = self.bias[:self.filters]
            self.bias_f = self.bias[self.filters: self.filters * 2]
            self.bias_c = self.bias[self.filters * 2: self.filters * 3]
            self.bias_o = self.bias[self.filters * 3:]
            self.bias_wa = self.attentive_bias[:self.attentive_filters]
            self.bias_ua = self.attentive_bias[self.attentive_filters:
                                               self.attentive_filters * 2]
        else:
            self.bias_i = None
            self.bias_f = None
            self.bias_c = None
            self.bias_o = None
            # Also clear the attentive biases so `call` can pass them safely.
            self.bias_wa = None
            self.bias_ua = None

        self.built = True

    def call(self, inputs, states, training=None):
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[1]),
                self.recurrent_dropout,
                training=training,
                count=4)
        # if (0 < self.attentive_dropout < 1 and
        #         self._attentive_dropout_mask is None):
        #     self._attentive_dropout_mask = _generate_dropout_mask(
        #         K.ones_like(inputs),
        #         self.attentive_dropout,
        #         training=training,
        #         count=4)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask
        # dropout matrices for attentive units
        # att_dp_mask = self._attentive_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state
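
        # The block below implements the attention step: an attention map is
        # computed from the previous hidden state and the current input,
        #     e = conv(attentive_activation(conv(h_tm1; hidden_attentive_kernel)
        #                                   + conv(inputs; input_attentive_kernel)
        #                                   + biases); squeeze_attentive_kernel)
        #     a = softmax of e over all spatial locations,
        # and the input is then re-weighted by `a` (repeated over channels)
        # before the usual ConvLSTM gate convolutions are applied.
        # NOTE: the channel repeat via `K.repeat_elements(..., -1)` assumes
        # `channels_last` data format.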
        ##### ATTENTION MECHANISM

        h_and_x = (self.input_conv(h_tm1, self.hidden_attentive_kernel,
                                   self.bias_ua, padding='same') +
                   self.input_conv(inputs, self.input_attentive_kernel,
                                   self.bias_wa, padding='same'))

        e = self.recurrent_conv(self.attentive_activation(h_and_x),
                                self.squeeze_attentive_kernel)

        a = K.reshape(K.softmax(K.batch_flatten(e)), K.shape(e))
        inputs = inputs * K.repeat_elements(a, inputs.shape[-1], -1)

        ##### END OF ATTENTION MECHANISM

        if 0 < self.dropout < 1.:
            inputs_i = inputs * dp_mask[0]
            inputs_f = inputs * dp_mask[1]
            inputs_c = inputs * dp_mask[2]
            inputs_o = inputs * dp_mask[3]
        else:
            inputs_i = inputs
            inputs_f = inputs
            inputs_c = inputs
            inputs_o = inputs

        if 0 < self.recurrent_dropout < 1.:
            h_tm1_i = h_tm1 * rec_dp_mask[0]
            h_tm1_f = h_tm1 * rec_dp_mask[1]
            h_tm1_c = h_tm1 * rec_dp_mask[2]
            h_tm1_o = h_tm1 * rec_dp_mask[3]
        else:
            h_tm1_i = h_tm1
            h_tm1_f = h_tm1
            h_tm1_c = h_tm1
            h_tm1_o = h_tm1

        x_i = self.input_conv(inputs_i, self.kernel_i, self.bias_i,
                              padding=self.padding)
        x_f = self.input_conv(inputs_f, self.kernel_f, self.bias_f,
                              padding=self.padding)
        x_c = self.input_conv(inputs_c, self.kernel_c, self.bias_c,
                              padding=self.padding)
        x_o = self.input_conv(inputs_o, self.kernel_o, self.bias_o,
                              padding=self.padding)
        h_i = self.recurrent_conv(h_tm1_i, self.recurrent_kernel_i)
        h_f = self.recurrent_conv(h_tm1_f, self.recurrent_kernel_f)
        h_c = self.recurrent_conv(h_tm1_c, self.recurrent_kernel_c)
        h_o = self.recurrent_conv(h_tm1_o, self.recurrent_kernel_o)

        i = self.recurrent_activation(x_i + h_i)
        f = self.recurrent_activation(x_f + h_f)
        c = f * c_tm1 + i * self.activation(x_c + h_c)
        o = self.recurrent_activation(x_o + h_o)
        h = o * self.activation(c)

        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True

        return h, [h, c]

    def input_conv(self, x, w, b=None, padding='valid'):
        conv_out = K.conv2d(x, w, strides=self.strides,
                            padding=padding,
                            data_format=self.data_format,
                            dilation_rate=self.dilation_rate)
        if b is not None:
            conv_out = K.bias_add(conv_out, b,
                                  data_format=self.data_format)
        return conv_out

    def recurrent_conv(self, x, w):
        conv_out = K.conv2d(x, w, strides=(1, 1),
                            padding='same',
                            data_format=self.data_format)
        return conv_out

    def get_config(self):
        config = {'filters': self.filters,
                  'attentive_filters': self.attentive_filters,
                  'kernel_size': self.kernel_size,
                  'attentive_kernel_size': self.attentive_kernel_size,
                  'strides': self.strides,
                  'padding': self.padding,
                  'data_format': self.data_format,
                  'dilation_rate': self.dilation_rate,
                  'activation': activations.serialize(self.activation),
                  'recurrent_activation': activations.serialize(
                      self.recurrent_activation),
                  'attentive_activation': activations.serialize(
                      self.attentive_activation),
                  'use_bias': self.use_bias,
                  'kernel_initializer': initializers.serialize(
                      self.kernel_initializer),
                  'recurrent_initializer': initializers.serialize(
                      self.recurrent_initializer),
                  'attentive_initializer': initializers.serialize(
                      self.attentive_initializer),
                  'bias_initializer': initializers.serialize(self.bias_initializer),
                  'unit_forget_bias': self.unit_forget_bias,
                  'kernel_regularizer': regularizers.serialize(
                      self.kernel_regularizer),
                  'recurrent_regularizer': regularizers.serialize(
                      self.recurrent_regularizer),
                  'attentive_regularizer': regularizers.serialize(
                      self.attentive_regularizer),
                  'bias_regularizer': regularizers.serialize(self.bias_regularizer),
                  'kernel_constraint': constraints.serialize(
                      self.kernel_constraint),
                  'recurrent_constraint': constraints.serialize(
                      self.recurrent_constraint),
                  'attentive_constraint': constraints.serialize(
                      self.attentive_constraint),
                  'bias_constraint': constraints.serialize(self.bias_constraint),
                  'dropout': self.dropout,
                  'recurrent_dropout': self.recurrent_dropout,
                  'attentive_dropout': self.attentive_dropout}
        base_config = super(AttentiveConvLSTM2DCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class AttentiveConvLSTM2D(ConvRNN2D):
    """Attentive Convolutional LSTM.

    It is similar to an LSTM layer, but the input transformations
    and recurrent transformations are both convolutional, and the input
    at each timestep is re-weighted by an attention map computed from the
    previous hidden state and the current input.

    Arguments:
        filters: Integer, the dimensionality of the output space
            (i.e. the number of output filters in the convolution).
        attentive_filters: Integer, the number of filters used by the
            attention convolutions.
        kernel_size: An integer or tuple/list of n integers, specifying the
            dimensions of the convolution window.
        attentive_kernel_size: An integer or tuple/list of n integers,
            specifying the dimensions of the attention convolution window.
        strides: An integer or tuple/list of n integers,
            specifying the strides of the convolution.
            Specifying any stride value != 1 is incompatible with specifying
            any `dilation_rate` value != 1.
        padding: One of `"valid"` or `"same"` (case-insensitive).
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, time, ..., channels)`
            while `channels_first` corresponds to
            inputs with shape `(batch, time, channels, ...)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        dilation_rate: An integer or tuple/list of n integers, specifying
            the dilation rate to use for dilated convolution.
            Currently, specifying any `dilation_rate` value != 1 is
            incompatible with specifying any `strides` value != 1.
        activation: Activation function to use.
            If you don't specify anything, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        recurrent_activation: Activation function to use
            for the recurrent step.
        attentive_activation: Activation function to use
            for the attention step.
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs.
        recurrent_initializer: Initializer for the `recurrent_kernel`
            weights matrix,
            used for the linear transformation of the recurrent state.
        attentive_initializer: Initializer for the attention kernels.
        bias_initializer: Initializer for the bias vector.
        unit_forget_bias: Boolean.
            If True, add 1 to the bias of the forget gate at initialization.
            Use in combination with `bias_initializer="zeros"`.
            This is recommended in [Jozefowicz et al.]
            (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix.
        recurrent_regularizer: Regularizer function applied to
            the `recurrent_kernel` weights matrix.
        attentive_regularizer: Regularizer function applied to
            the attention kernels.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix.
        recurrent_constraint: Constraint function applied to
            the `recurrent_kernel` weights matrix.
        attentive_constraint: Constraint function applied to
            the attention kernels.
        bias_constraint: Constraint function applied to the bias vector.
        return_sequences: Boolean. Whether to return the last output
            in the output sequence, or the full sequence.
        go_backwards: Boolean (default False).
            If True, process the input sequence backwards.
        stateful: Boolean (default False). If True, the last state
            for each sample at index i in a batch will be used as initial
            state for the sample of index i in the following batch.
        dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the inputs.
        recurrent_dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the recurrent state.
        attentive_dropout: Float between 0 and 1.
            Fraction of the units to drop for the attention step
            (currently not applied in `call`).

    Input shape:
        - if data_format='channels_first'
            5D tensor with shape:
            `(samples, time, channels, rows, cols)`
        - if data_format='channels_last'
            5D tensor with shape:
            `(samples, time, rows, cols, channels)`

    Output shape:
        - if `return_sequences`
            - if data_format='channels_first'
                5D tensor with shape:
                `(samples, time, filters, output_row, output_col)`
            - if data_format='channels_last'
                5D tensor with shape:
                `(samples, time, output_row, output_col, filters)`
        - else
            - if data_format='channels_first'
                4D tensor with shape:
                `(samples, filters, output_row, output_col)`
            - if data_format='channels_last'
                4D tensor with shape:
                `(samples, output_row, output_col, filters)`
            where `output_row` and `output_col` depend on the shape of
            the filter and the padding.

    Raises:
        ValueError: in case of invalid constructor arguments.

    References:
        - [Convolutional LSTM Network: A Machine Learning Approach for
          Precipitation Nowcasting](http://arxiv.org/abs/1506.04214v1)
          The current implementation does not include the feedback loop on the
          cells output.
    """

    def __init__(self,
                 filters,
                 attentive_filters,
                 kernel_size,
                 attentive_kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 attentive_activation='tanh',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 attentive_initializer='zeros',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 attentive_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 attentive_constraint=None,
                 bias_constraint=None,
                 return_sequences=False,
                 go_backwards=False,
                 stateful=False,
                 dropout=0.,
                 recurrent_dropout=0.,
                 attentive_dropout=0.,
                 **kwargs):
        cell = AttentiveConvLSTM2DCell(filters=filters,
                                       attentive_filters=attentive_filters,
                                       kernel_size=kernel_size,
                                       attentive_kernel_size=attentive_kernel_size,
                                       strides=strides,
                                       padding=padding,
                                       data_format=data_format,
                                       dilation_rate=dilation_rate,
                                       activation=activation,
                                       recurrent_activation=recurrent_activation,
                                       attentive_activation=attentive_activation,
                                       use_bias=use_bias,
                                       kernel_initializer=kernel_initializer,
                                       recurrent_initializer=recurrent_initializer,
                                       attentive_initializer=attentive_initializer,
                                       bias_initializer=bias_initializer,
                                       unit_forget_bias=unit_forget_bias,
                                       kernel_regularizer=kernel_regularizer,
                                       recurrent_regularizer=recurrent_regularizer,
                                       attentive_regularizer=attentive_regularizer,
                                       bias_regularizer=bias_regularizer,
                                       kernel_constraint=kernel_constraint,
                                       recurrent_constraint=recurrent_constraint,
                                       attentive_constraint=attentive_constraint,
                                       bias_constraint=bias_constraint,
                                       dropout=dropout,
                                       recurrent_dropout=recurrent_dropout,
                                       attentive_dropout=attentive_dropout)

        super(AttentiveConvLSTM2D, self).__init__(cell,
                                                  return_sequences=return_sequences,
                                                  go_backwards=go_backwards,
                                                  stateful=stateful,
                                                  **kwargs)
        self.activity_regularizer = regularizers.get(activity_regularizer)

    def call(self, inputs, mask=None, training=None, initial_state=None):
        return super(AttentiveConvLSTM2D, self).call(inputs,
                                                     mask=mask,
                                                     training=training,
                                                     initial_state=initial_state)

    @property
    def filters(self):
        return self.cell.filters

    @property
    def attentive_filters(self):
        return self.cell.attentive_filters

    @property
    def kernel_size(self):
        return self.cell.kernel_size

    @property
    def attentive_kernel_size(self):
        return self.cell.attentive_kernel_size

    @property
    def strides(self):
        return self.cell.strides

    @property
    def padding(self):
        return self.cell.padding

    @property
    def data_format(self):
        return self.cell.data_format

    @property
    def dilation_rate(self):
        return self.cell.dilation_rate

    @property
    def activation(self):
        return self.cell.activation

    @property
    def recurrent_activation(self):
        return self.cell.recurrent_activation

    @property
    def attentive_activation(self):
        return self.cell.attentive_activation

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def attentive_initializer(self):
        return self.cell.attentive_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    @property
    def unit_forget_bias(self):
        return self.cell.unit_forget_bias

    @property
    def kernel_regularizer(self):
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        return self.cell.recurrent_regularizer

    @property
    def attentive_regularizer(self):
        return self.cell.attentive_regularizer

    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def attentive_constraint(self):
        return self.cell.attentive_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    @property
    def attentive_dropout(self):
        return self.cell.attentive_dropout

    def get_config(self):
        config = {'filters': self.filters,
                  'attentive_filters': self.attentive_filters,
                  'kernel_size': self.kernel_size,
                  'attentive_kernel_size': self.attentive_kernel_size,
                  'strides': self.strides,
                  'padding': self.padding,
                  'data_format': self.data_format,
                  'dilation_rate': self.dilation_rate,
                  'activation': activations.serialize(self.activation),
                  'recurrent_activation': activations.serialize(
                      self.recurrent_activation),
                  'attentive_activation': activations.serialize(
                      self.attentive_activation),
                  'use_bias': self.use_bias,
                  'kernel_initializer': initializers.serialize(
                      self.kernel_initializer),
                  'recurrent_initializer': initializers.serialize(
                      self.recurrent_initializer),
                  'attentive_initializer': initializers.serialize(
                      self.attentive_initializer),
                  'bias_initializer': initializers.serialize(self.bias_initializer),
                  'unit_forget_bias': self.unit_forget_bias,
                  'kernel_regularizer': regularizers.serialize(
                      self.kernel_regularizer),
                  'recurrent_regularizer': regularizers.serialize(
                      self.recurrent_regularizer),
                  'attentive_regularizer': regularizers.serialize(
                      self.attentive_regularizer),
                  'bias_regularizer': regularizers.serialize(self.bias_regularizer),
                  'activity_regularizer': regularizers.serialize(
                      self.activity_regularizer),
                  'kernel_constraint': constraints.serialize(
                      self.kernel_constraint),
                  'recurrent_constraint': constraints.serialize(
                      self.recurrent_constraint),
                  'attentive_constraint': constraints.serialize(
                      self.attentive_constraint),
                  'bias_constraint': constraints.serialize(self.bias_constraint),
                  'dropout': self.dropout,
                  'recurrent_dropout': self.recurrent_dropout,
                  'attentive_dropout': self.attentive_dropout}
        base_config = super(AttentiveConvLSTM2D, self).get_config()
        del base_config['cell']
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        return cls(**config)
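

# Minimal usage sketch (not part of the original module): it builds a toy
# model just to illustrate the expected input/output shapes of
# `AttentiveConvLSTM2D`. The layer sizes below are arbitrary example values,
# and this assumes a Keras 2.2.x environment in which the imports at the top
# of this file resolve.
if __name__ == '__main__':
    from keras.layers import Input
    from keras.models import Model

    # A batch of 4-frame sequences of 32x32 feature maps with 64 channels
    # (`channels_last` data format).
    frames = Input(shape=(4, 32, 32, 64))
    out = AttentiveConvLSTM2D(filters=32,
                              attentive_filters=32,
                              kernel_size=(3, 3),
                              attentive_kernel_size=(3, 3),
                              padding='same',
                              return_sequences=False)(frames)
    # With `return_sequences=False` and `padding='same'`, `out` should have
    # shape (None, 32, 32, 32).
    model = Model(frames, out)
    model.summary()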