How to reset state of an RNN layer?

According to the keras.layers.RNN docs, the RNN layer has a reset_states() method that resets the state of the RNN. However, this method takes a list of numpy arrays rather than Tensors. So if I invoke it inside the call() of a model subclass while eager execution is disabled, like this:

def call(self, x, training=False, **kwargs):
    ...
    original_state = [i.numpy() for i in self.rnn_layer.states]
    ...
    self.rnn_layer.reset_states(states=original_state)
    ...

I get an error saying numpy() is only available when eager execution is enabled. Is there a way to do the above regardless of whether eager execution is enabled or disabled? I thought the add_update() method might be the solution (per this StackOverflow post), but I was unable to get it to reset to the original state.

IIUC, self.states returns the variables that hold the states.

You should be able to use var.assign(x) to update them.
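
A minimal sketch of that idea, assuming TF 2.x and a stateful LSTM (the layer type, shapes, and variable names below are placeholders, not from the original question). It snapshots the state with tf.identity instead of .numpy() and restores it with assign, so only tensor ops are involved:

```python
import numpy as np
import tensorflow as tf

# Stateful layers need a fixed batch size; these shapes are arbitrary.
batch, steps, features, units = 2, 3, 5, 4
rnn_layer = tf.keras.layers.LSTM(units, stateful=True)

x = tf.random.normal((batch, steps, features))
rnn_layer(x)  # first call builds the layer and creates the state variables

# Snapshot the current state values as tensors (no .numpy() needed).
original_state = [tf.identity(s) for s in rnn_layer.states]

rnn_layer(x)  # a second call advances the internal state
changed_state = [tf.identity(s) for s in rnn_layer.states]

# Restore by assigning the saved values back into the existing variables.
for var, value in zip(rnn_layer.states, original_state):
    var.assign(value)
```

I have only tried this pattern in eager mode. Since the snapshot and the assign are ordinary TensorFlow ops, it should not depend on .numpy() being available, but treat the graph-mode behaviour as something to verify in your own setup.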

But the code may be clearer if you use stateful=False and return_state=True. Then you have direct control of the states, and you don't have to reach into the RNN to edit them.
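
Roughly what I mean, as a sketch assuming an LSTM layer (shapes and names are made up for illustration). The layer hands the states back to you, and "resetting" is just passing an earlier state as initial_state again:

```python
import numpy as np
import tensorflow as tf

batch, steps, features, units = 2, 3, 5, 4
# stateful defaults to False, so the layer keeps no state between calls.
rnn_layer = tf.keras.layers.LSTM(units, return_state=True)

x = tf.random.normal((batch, steps, features))

# With return_state=True an LSTM returns (output, hidden state, cell state).
output, h, c = rnn_layer(x)

# Continue from the returned state by passing it explicitly.
output_next, h_next, c_next = rnn_layer(x, initial_state=[h, c])

# "Reset" to the earlier state by simply passing that same state again.
output_again, h_again, c_again = rnn_layer(x, initial_state=[h, c])
```

Because the states are plain tensors that you pass around yourself, there is nothing inside the layer to snapshot or restore.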
