API Reference
Module trainable_initial_state_rnn
TensorFlow Keras RNNs with trainable initial states
class trainable_initial_state_rnn.TrainableInitialStateRNN(layer, *, initializer=None, regularizer=None, constraint=None, **kwargs)
Bases: tensorflow.python.keras.layers.wrappers.Wrapper
Wrapper class for RNNs with trainable initial state variables.
Parameters: - layer (tf.keras.layers.Layer) – An RNN instance, or a Bidirectional instance wrapping an RNN instance.
- initializer (str, callable, or list/tuple of str/callable, optional) – The initializer for the initial state tensor(s): either a single initializer, or a list or tuple with one initializer for each initial state. If not specified, the initial states are initialized to zeros.
- regularizer (str, callable, or list/tuple of str/callable, optional) – The regularizer for the initial state tensor(s): either a single regularizer, or a list or tuple with one regularizer for each initial state. If not specified, the initial states are not regularized.
- constraint (str, callable, or list/tuple of str/callable, optional) – The constraint for the initial state tensor(s): either a single constraint, or a list or tuple with one constraint for each initial state. If not specified, the initial states are not constrained.
- **kwargs (keyword arguments) – Additional keyword arguments for tf.keras.layers.Wrapper (e.g., name).
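Each of initializer, regularizer, and constraint accepts either one value or a list/tuple with one entry per initial state. For reference, a Keras LSTM has two states (h and c) while a GRU or SimpleRNN has one. The sketch below shows one plausible way to normalize such an argument into a per-state list, assuming a single value is shared by every state; expand_per_state is a hypothetical helper written for illustration, not part of this module's API.

```python
from typing import Any, List


def expand_per_state(value: Any, num_states: int) -> List[Any]:
    """Normalize a per-state option into a list with one entry per state.

    Hypothetical helper for illustration; not part of trainable_initial_state_rnn.
    """
    if isinstance(value, (list, tuple)):
        # A list/tuple must supply exactly one entry per initial state.
        if len(value) != num_states:
            raise ValueError(
                f"expected {num_states} entries (one per initial state), "
                f"got {len(value)}"
            )
        return list(value)
    # A single str/callable (or None, meaning "use the default") is
    # assumed to be shared by every initial state.
    return [value] * num_states


# An LSTM has two initial states (h and c).
print(expand_per_state("glorot_uniform", 2))  # ['glorot_uniform', 'glorot_uniform']
print(expand_per_state(["zeros", "ones"], 2))  # ['zeros', 'ones']
print(expand_per_state(None, 2))  # [None, None]
```

Raising on a length mismatch, rather than silently truncating or padding, surfaces a misconfigured list (for example, one initializer supplied for an LSTM) as early as possible.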
build(input_shape=None)
Build this layer’s underlying RNN and initial state tensors.
Parameters: input_shape (TensorShape, or list of TensorShape) – Shape(s) of the input to the RNN.
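Conceptually, building the initial states means creating one trainable vector per RNN state, sized to match that state. In Keras those sizes are exposed by the cell's state_size attribute (for an LSTMCell it lists units twice, one entry for h and one for c). The NumPy sketch below stands in for variable creation; build_initial_states is a hypothetical name, initializers are assumed to be callables that take a shape, and the per-state shape (state_size,) is an assumption rather than a documented detail of this module.

```python
from typing import Callable, List, Optional, Sequence

import numpy as np

# Hypothetical initializer type: takes a shape tuple, returns an array.
Initializer = Callable[[tuple], np.ndarray]


def build_initial_states(
    state_sizes: Sequence[int],
    initializers: Optional[Sequence[Optional[Initializer]]] = None,
) -> List[np.ndarray]:
    """Create one initial-state vector per RNN state (hypothetical sketch)."""
    if initializers is None:
        initializers = [None] * len(state_sizes)
    if len(initializers) != len(state_sizes):
        raise ValueError("need one initializer (or None) per state")
    states = []
    for size, init in zip(state_sizes, initializers):
        shape = (size,)
        if init is None:
            # Default: zeros, matching the documented behavior when no
            # initializer is given.
            value = np.zeros(shape, dtype=np.float32)
        else:
            value = np.asarray(init(shape), dtype=np.float32)
        states.append(value)
    return states


# LSTM with 4 units: two states (h and c), each of size 4.
h0, c0 = build_initial_states([4, 4], [None, lambda shape: np.full(shape, 0.1)])
print(h0)  # [0. 0. 0. 0.]
print(c0)  # [0.1 0.1 0.1 0.1]
```

In a real Keras layer these arrays would instead be trainable weights, so the optimizer can update the starting state along with the rest of the model.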
call(inputs, *args, **kwargs)
Run the RNN on input data.
Parameters: - inputs (tensor-like, or nested structure of lists or tensor-likes) – Inputs to the underlying RNN. These should not include the initial state.
- *args (positional arguments) – Additional positional arguments to pass to the underlying RNN.
- **kwargs (keyword arguments) – Keyword arguments for the underlying RNN. If an initial_state keyword argument is present, its value is used instead of this layer’s initial_state variable(s).
Returns: The output of the underlying RNN.
Return type: tf.Tensor
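At call time, a single learned vector per state has to be expanded to one row per batch example before it reaches the RNN, unless the caller passes an initial_state keyword argument, in which case that value is used instead. The NumPy sketch below illustrates both paths with a hand-written SimpleRNN-style recurrence, h_t = tanh(x_t W + h_{t-1} U + b). resolve_initial_state and simple_rnn are hypothetical helpers, and the recurrence is only a stand-in for whichever RNN layer is wrapped.

```python
from typing import List, Optional, Sequence

import numpy as np


def resolve_initial_state(
    batch_size: int,
    learned_states: Sequence[np.ndarray],
    initial_state: Optional[Sequence[np.ndarray]] = None,
) -> List[np.ndarray]:
    """Pick the initial state for one forward pass (hypothetical sketch)."""
    if initial_state is not None:
        # An explicit initial_state argument overrides the learned states.
        return list(initial_state)
    # Repeat each learned (state_size,) vector into (batch_size, state_size).
    return [np.tile(s, (batch_size, 1)) for s in learned_states]


def simple_rnn(x: np.ndarray, h0: np.ndarray, W: np.ndarray,
               U: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Run h_t = tanh(x_t W + h_{t-1} U + b) over x of shape (batch, time, features)."""
    h = h0
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W + h @ U + b)
    return h  # final hidden state, shape (batch, units)


rng = np.random.default_rng(0)
batch, time, features, units = 3, 5, 2, 4
x = rng.normal(size=(batch, time, features))
W = rng.normal(size=(features, units))
U = rng.normal(size=(units, units))
b = np.zeros(units)
learned_h0 = rng.normal(size=(units,))  # would be a trainable variable in the real layer

(h0,) = resolve_initial_state(batch, [learned_h0])
print(h0.shape)  # (3, 4)
print(simple_rnn(x, h0, W, U, b).shape)  # (3, 4)

# Supplying initial_state bypasses the learned vector.
(h0_override,) = resolve_initial_state(batch, [learned_h0], [np.zeros((batch, units))])
print(np.allclose(h0_override, 0.0))  # True
```

Because every row of the tiled state is the same learned vector, all sequences in a batch start from the same point, and gradients from every example accumulate into that one vector during training.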