API Reference

Module trainable_initial_state_rnn

TensorFlow Keras RNNs with trainable initial states

class trainable_initial_state_rnn.TrainableInitialStateRNN(layer, *, initializer=None, regularizer=None, constraint=None, **kwargs)

Bases: tensorflow.python.keras.layers.wrappers.Wrapper

Wrapper class for RNNs with trainable initial state variables.

Parameters:
  • layer (tf.keras.layers.Layer) – An RNN layer instance, or a Bidirectional instance wrapping an RNN layer instance.
  • initializer (str, callable, or list/tuple of str/callable, optional) – The initializer for the initial state tensor(s), or a list or tuple of initializers, one for each initial state. If not specified, initial states are initialized to zeros.
  • regularizer (str, callable, or list/tuple of str/callable, optional) – The regularizer for the initial state tensor(s), or a list or tuple of regularizers, one for each initial state. If not specified, initial states are not regularized.
  • constraint (str, callable, or list/tuple of str/callable, optional) – The constraint for the initial state tensor(s), or a list or tuple of constraints, one for each initial state. If not specified, initial states are not constrained.
  • **kwargs (keyword arguments) – Additional keyword arguments for tf.keras.layers.Wrapper (e.g., name).
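The initializer, regularizer, and constraint arguments each take either a single spec or a list/tuple with one spec per initial state. The plain-Python sketch below illustrates that rule, assuming a single spec is reused for every state. expand_per_state is a hypothetical helper, not part of this module, and the state counts assume standard Keras layers: an LSTM carries two states (h and c), a GRU carries one, and a Bidirectional wrapper needs a forward and a backward copy of each.

```python
from typing import Any, List, Sequence, Union


# Hypothetical helper (not part of trainable_initial_state_rnn): expands an
# initializer/regularizer/constraint argument into one spec per initial state.
def expand_per_state(spec: Union[Any, Sequence[Any]], num_states: int) -> List[Any]:
    if isinstance(spec, (list, tuple)):
        # A list/tuple must supply exactly one spec per initial state.
        if len(spec) != num_states:
            raise ValueError(f"expected {num_states} specs, got {len(spec)}")
        return list(spec)
    # Assumption: a single spec (or None) is reused for every initial state.
    return [spec] * num_states


# Assumed state counts for standard Keras layers.
LSTM_STATES = 2  # hidden state h and cell state c
GRU_STATES = 1  # hidden state only
BIDIRECTIONAL_LSTM_STATES = 2 * LSTM_STATES  # forward and backward copies

print(expand_per_state("zeros", LSTM_STATES))  # → ['zeros', 'zeros']
print(expand_per_state(["zeros", "ones"], LSTM_STATES))  # → ['zeros', 'ones']
print(expand_per_state(None, BIDIRECTIONAL_LSTM_STATES))  # → [None, None, None, None]
```

Reusing a single spec keeps the common case short, while the length check catches a per-state list that does not match the wrapped layer.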

build(input_shape=None)

Build this layer’s underlying RNN and initial state tensors.

Parameters: input_shape (TensorShape, or list of TensorShape) – Shape(s) of the input to the RNN.
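The number and size of the initial state variables created during build follow from the wrapped cell's state_size, which in Keras is a single int (for example SimpleRNNCell or GRUCell) or a list of ints (for example LSTMCell). The sketch below is a hypothetical illustration, assuming one unbatched vector per state and a duplicated set for a Bidirectional wrapper; initial_state_shapes is not part of this module, and the module's actual variable shapes may differ.

```python
from typing import List, Sequence, Tuple, Union


# Hypothetical sketch: derive the shapes of the trainable initial-state
# variables from a Keras-style cell `state_size`, which is either a single
# int (e.g. GRUCell) or a list of ints (e.g. LSTMCell).
def initial_state_shapes(
    state_size: Union[int, Sequence[int]], bidirectional: bool = False
) -> List[Tuple[int, ...]]:
    sizes = [state_size] if isinstance(state_size, int) else list(state_size)
    # Assumption: one unbatched vector per state; the batch dimension is
    # added later, when the layer is called.
    shapes = [(size,) for size in sizes]
    # A Bidirectional wrapper needs a separate set for the backward RNN.
    return shapes * 2 if bidirectional else shapes


print(initial_state_shapes(32))  # GRU(32) → [(32,)]
print(initial_state_shapes([64, 64]))  # LSTM(64) → [(64,), (64,)]
print(initial_state_shapes([64, 64], bidirectional=True))  # → four (64,) shapes
```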

call(inputs, *args, **kwargs)

Run the RNN on input data.

Parameters:
  • inputs (tensor-like, or a nested list of tensor-likes) – Inputs to the underlying RNN. Do not include the initial state here.
  • *args (positional arguments) – Additional positional arguments to pass to the underlying RNN.
  • **kwargs (keyword arguments) – Keyword arguments for the underlying RNN. If an initial_state keyword argument is present, its value is used instead of this layer’s initial_state variable(s).
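The override rule for initial_state can be sketched in plain Python as follows. This is a hypothetical illustration, not this module's code: resolve_initial_state is an invented helper, and nested lists stand in for tensors. It assumes that, when no initial_state keyword is given, each learned state vector is repeated across the batch dimension so that every sequence starts from the same trainable state.

```python
from typing import List, Optional, Sequence

Vector = List[float]
Batch = List[Vector]


# Hypothetical sketch of how call() picks its initial states: an explicit
# `initial_state` argument wins; otherwise each learned vector is repeated
# across the batch dimension so every sequence starts from the same state.
def resolve_initial_state(
    learned: Sequence[Vector],
    batch_size: int,
    initial_state: Optional[Sequence[Batch]] = None,
) -> List[Batch]:
    if initial_state is not None:
        # Caller-supplied states replace the trainable ones for this call.
        return list(initial_state)
    # Copy each learned vector once per batch element.
    return [[list(vec) for _ in range(batch_size)] for vec in learned]


learned_states = [[0.1, 0.2], [0.3, 0.4]]  # e.g. LSTM h and c, units=2
tiled = resolve_initial_state(learned_states, batch_size=3)
print(len(tiled), len(tiled[0]), tiled[0][0])  # → 2 3 [0.1, 0.2]
```

In TensorFlow, the repetition across the batch would typically use tf.tile or tf.broadcast_to; how this module does it is not stated here.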
Returns:

The output of the underlying RNN.

Return type:

tf.Tensor

get_config() → dict

Get the configuration of the layer as a JSON-serializable dict.