# Trainable Initial State RNN

Treat the initial state(s) of TensorFlow Keras recurrent neural network (RNN) layers as parameters to be learned during training, as recommended in, e.g., [1].

Ordinary RNNs use an all-zero initial state by default. Why not let the neural network learn a smarter initial state?

The `trainable_initial_state_rnn` package provides a class `TrainableInitialStateRNN` that can wrap any `tf.keras` RNN (or bidirectional RNN) and manage new initial state variables in addition to the RNN's weights.

Typical usage looks as follows.

```
import tensorflow as tf
from trainable_initial_state_rnn import TrainableInitialStateRNN
base_rnn = tf.keras.layers.LSTM(256)
rnn = TrainableInitialStateRNN(base_rnn) # Treats initial state as a variable!
model = tf.keras.Model(...) # Use rnn like any other tf.keras layer in your model
model.compile(...)
history = model.fit(...)
```
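To make the idea concrete, here is a minimal sketch of how a trainable initial state can be implemented by hand in `tf.keras`, without this package. The class name `ManualTrainableInitLSTM` is hypothetical and not part of `trainable_initial_state_rnn`; the sketch simply stores one `(h, c)` pair as trainable weights and tiles it to the batch size before each call:

```
import tensorflow as tf

class ManualTrainableInitLSTM(tf.keras.layers.Layer):
    """Sketch: an LSTM whose initial state is learned, not fixed at zero."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.lstm = tf.keras.layers.LSTM(units)

    def build(self, input_shape):
        # Trainable initial hidden state h0 and cell state c0, shared
        # across the batch (shape (1, units), tiled per example below).
        self.h0 = self.add_weight(
            name="h0", shape=(1, self.units),
            initializer="zeros", trainable=True)
        self.c0 = self.add_weight(
            name="c0", shape=(1, self.units),
            initializer="zeros", trainable=True)
        super().build(input_shape)

    def call(self, inputs):
        # Tile the single learned state to the dynamic batch size and
        # pass it to the wrapped LSTM via initial_state.
        batch_size = tf.shape(inputs)[0]
        h = tf.tile(self.h0, [batch_size, 1])
        c = tf.tile(self.c0, [batch_size, 1])
        return self.lstm(inputs, initial_state=[h, c])
```

Because `h0` and `c0` are created with `add_weight(..., trainable=True)`, gradients flow into them during `model.fit` just like any other weight. `TrainableInitialStateRNN` handles this bookkeeping generically for any wrapped RNN, including bidirectional ones.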