For `tf.keras.layers.RNN`, there's a `reset_states` method that resets the recorded state to some pre-defined values. Is there a way to read what the current recorded state is?
The documentation mentions a `states` attribute, which looks like it's readable. There's also a `return_state` constructor argument if you want the layer to return its final state alongside its output.
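Here is a minimal sketch of both approaches, assuming TensorFlow 2.x with `tf.keras`. The input sizes are arbitrary toy values. The layer is created with `stateful=True`, since `states` only holds recorded values for a stateful layer, and with `return_state=True` so the two results can be compared:

```python
import numpy as np
import tensorflow as tf

# Toy input (arbitrary sizes): batch of 2 sequences, 5 timesteps, 3 features.
x = tf.random.normal((2, 5, 3))

# stateful=True keeps the final state between calls;
# return_state=True also returns that state as extra outputs.
lstm = tf.keras.layers.LSTM(4, stateful=True, return_state=True)
output, h, c = lstm(x)

# For an LSTM, `states` holds two variables: hidden state h and cell state c.
current_h, current_c = [np.asarray(s) for s in lstm.states]
print(current_h.shape, current_c.shape)  # (2, 4) (2, 4)

# The recorded state matches the state returned by the call.
assert np.allclose(current_h, np.asarray(h))
assert np.allclose(current_c, np.asarray(c))

# reset_states() with no argument sets the recorded state back to zeros.
lstm.reset_states()
zeroed_h = np.asarray(lstm.states[0])
```

Reading `states` gives you the value the layer will carry into its next call, while `return_state` is the more convenient choice when you just want the final state as part of the model's outputs.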
You can also take a look at this example, which shows how to retrieve the current state of an RNN layer such as `LSTM` and use it to initialize another:
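The linked example isn't reproduced here, but the general pattern can be sketched as follows, assuming TensorFlow 2.x. The names `encoder`/`decoder` and the tensor sizes are illustrative only. The first layer exposes its final state via `return_state=True`, and that state is passed to the second layer through the `initial_state` argument when calling it:

```python
import tensorflow as tf

# Toy inputs (arbitrary sizes): batch of 2, feature size 3.
x_first = tf.random.normal((2, 5, 3))
x_second = tf.random.normal((2, 7, 3))

# First layer: return_state=True gives the final hidden state h and cell state c.
encoder = tf.keras.layers.LSTM(4, return_state=True)
_, h, c = encoder(x_first)

# Second layer: must use the same number of units (4) so the state shapes match.
decoder = tf.keras.layers.LSTM(4, return_sequences=True)

# initial_state seeds the second layer with the first layer's final state
# instead of the default all-zeros state.
second_out = decoder(x_second, initial_state=[h, c])
```

For an `LSTM` the state is the pair `[h, c]`; a `GRU` or `SimpleRNN` has a single state tensor, so the list passed to `initial_state` would have one element instead.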