Trainable Initial State RNN
Treat the initial state(s) of TensorFlow Keras recurrent neural network (RNN) layers as parameters to be learned during training, as recommended in, e.g., .
Ordinary RNNs use an all-zero initial state by default. Why not let the neural network learn a smarter initial state?
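Conceptually, the initial state enters the unrolled recurrence just like any weight matrix, so backpropagation through time yields a gradient for it and gradient descent can improve on the all-zero default. A minimal NumPy sketch of this idea (toy shapes, toy loss, and a hand-rolled vanilla RNN chosen for illustration; this is not the package's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hidden, n_steps = 3, 4, 5

# Fixed toy weights and inputs for a vanilla RNN: h_t = tanh(x_t W_x + h_{t-1} W_h)
W_x = rng.normal(scale=0.1, size=(n_in, n_hidden))
W_h = rng.normal(scale=0.1, size=(n_hidden, n_hidden))
x = rng.normal(size=(n_steps, n_in))

# The initial state: ordinarily fixed at zero, here treated as a parameter.
h0 = np.zeros(n_hidden)

def forward(h0):
    """Run the recurrence from h0 and return a toy loss on the final state."""
    h = h0
    for t in range(n_steps):
        h = np.tanh(x[t] @ W_x + h @ W_h)
    return 0.5 * np.sum(h ** 2)

def grad_h0(h0):
    """Backpropagation through time, keeping only d(loss)/d(h0)."""
    hs = [h0]
    h = h0
    for t in range(n_steps):
        h = np.tanh(x[t] @ W_x + h @ W_h)
        hs.append(h)
    g = hs[-1]  # d(loss)/d(h_T) for the toy loss above
    for t in reversed(range(n_steps)):
        # Chain rule through h_{t+1} = tanh(...): tanh' = 1 - tanh^2
        g = (g * (1 - hs[t + 1] ** 2)) @ W_h.T
    return g

# One gradient-descent step on h0 alone, with all weights held fixed.
h0_new = h0 - 0.1 * grad_h0(h0)
```

The same mechanics apply inside a Keras model: once the initial state is exposed as a trainable variable, the optimizer updates it alongside the RNN's weights.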
The `trainable_initial_state_rnn` package provides a class `TrainableInitialStateRNN` that can wrap any `tf.keras` RNN (or bidirectional RNN) and manage new initial state variables in addition to the RNN's weights.
Typical usage looks as follows.
```python
import tensorflow as tf

from trainable_initial_state_rnn import TrainableInitialStateRNN

base_rnn = tf.keras.layers.LSTM(256)
rnn = TrainableInitialStateRNN(base_rnn)  # Treats initial state as a variable!
model = tf.keras.Model(...)  # Use rnn like any other tf.keras layer in your model
model.compile(...)
history = model.fit(...)
```