In keras.layers.RNN, the RNN layer has a reset_states() method that resets the state of the RNN. However, this method takes a list of NumPy arrays rather than Tensors, so if I invoke it inside the call() of a model subclass while eager execution is disabled, like this:
    def call(self, x, training=False, **kwargs):
        ...
        original_state = [i.numpy() for i in self.rnn_layer.states]
        ...
        self.rnn_layer.reset_states(states=original_state)
        ...
I get an error saying:

    numpy() is only available when eager execution is enabled.

Is there some way to do what I want above regardless of whether eager execution is enabled or disabled? I thought the add_update() method might be the solution (per this StackOverflow post), but I was unable to get it to reset to the original state.