In keras.layers.RNN, the RNN layer has a reset_states() method that can reset the state of the RNN. However, this method takes a list of NumPy arrays rather than Tensors, so if I invoke it inside the call of a model subclass when eager execution is disabled, like this:
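```python
def call(self, x, training=False, **kwargs):
    ...
    # Snapshot the current RNN state as NumPy arrays (fails in graph mode)
    original_state = [i.numpy() for i in self.rnn_layer.states]
    ...
    # Put the RNN state back to the snapshot
    self.rnn_layer.reset_states(states=original_state)
    ...
```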
I get an error saying "numpy() is only available when eager execution is enabled". Is there a way to do what I want above regardless of whether eager execution is enabled or disabled? I thought the add_update() method might be the solution (per this StackOverflow post), but I was unable to get it to reset the RNN to its original state.
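One possible direction, sketched under assumptions rather than as a confirmed fix: TensorFlow 2.x with eager execution left on globally, and the snapshot/restore code running inside a tf.function, which is where .numpy() is unavailable. It also assumes that a built stateful Keras RNN keeps its state as variables in rnn_layer.states. Instead of converting those to NumPy and calling reset_states(), the sketch copies them with tf.identity and writes them back with assign, both of which work inside a graph. The names peek, BATCH, TIMESTEPS, FEATURES and UNITS are hypothetical.

```python
import tensorflow as tf

# Hypothetical sizes for illustration.
BATCH, TIMESTEPS, FEATURES, UNITS = 2, 3, 5, 4
tf.random.set_seed(0)

# A stateful RNN needs a fixed batch size. Calling it once eagerly builds it,
# which (assumption) creates its state as variables in `rnn_layer.states`
# and also advances that state away from zeros.
rnn_layer = tf.keras.layers.LSTM(UNITS, stateful=True)
rnn_layer(tf.random.normal((BATCH, TIMESTEPS, FEATURES)))


@tf.function  # traced into a graph, where .numpy() would fail
def peek(x):
    """Run the RNN on x, then restore the state it had before the call."""
    # Snapshot the state as tensors instead of NumPy arrays.
    original_state = [tf.identity(s) for s in rnn_layer.states]
    out = rnn_layer(x)  # a stateful call overwrites rnn_layer.states
    # Write the snapshot back. tf.function orders these assignments after
    # the layer's own state updates because they touch the same variables.
    for var, saved in zip(rnn_layer.states, original_state):
        var.assign(saved)
    return out
```

After peek(x), rnn_layer.states should hold the same values as before the call, while a plain rnn_layer(x) advances them. The same snapshot and assign lines would presumably carry over to a model subclass's call(), which Keras usually traces into a graph as well, but that has not been checked against every Keras version. With eager execution disabled entirely (TF1-style graph mode), assign only builds an op that still has to be executed, and this sketch does not handle that case.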