I’m building an RNN model in which I want to apply weight updates only when a condition based on the state of the RNN is met:
```python
from tensorflow import keras

class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data
        # Perform forward pass and compute gradients
        ...
        if some_condition_is_met:  # placeholder for the real condition
            # Apply gradients
            ...
```
I tried an `if` condition that applies the update only based on the value of `t_step`, a 1 x 1 `tf.Variable` (`if t_step > 5:`), but got the error below:
```
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
```
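The same class of error can be reproduced without Keras. The minimal sketch below (the function name `python_if_on_graph_tensor` is hypothetical) builds a tensor inside a `tf.Graph`, where it is symbolic and has no concrete value yet, and then uses it in a Python `if`. Python needs an actual `bool` at that moment, so TensorFlow raises `tf.errors.OperatorNotAllowedInGraphError`; the exact message wording may differ from the one above because AutoGraph is not involved here.

```python
import tensorflow as tf


def python_if_on_graph_tensor():
    """Return the error raised when a graph-mode tensor is used in a Python `if`."""
    graph = tf.Graph()
    with graph.as_default():
        # In graph mode this is a symbolic tensor: no value exists until the graph runs.
        t_step = tf.constant(7)
        try:
            # Python must turn `t_step > 5` into a bool right now, at graph-construction time.
            if t_step > 5:
                pass
        except tf.errors.OperatorNotAllowedInGraphError as err:
            return err
    return None


print(type(python_if_on_graph_tensor()).__name__)
```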
`t_step` is essentially the time step of the input, but when eager execution is disabled it seems that I cannot use it in an `if` condition.
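One common way to express a data-dependent branch in graph code is `tf.cond`, which traces both branches into the graph and selects one at run time, so the tensor is never converted to a Python `bool`. The sketch below is a minimal illustration under stated assumptions, not a drop-in Keras `train_step`: the toy linear model (`w`), `learning_rate`, and `train_step` are hypothetical, the update is plain SGD written with `assign_sub` so that nothing inside the branch creates new variables, and `t_step` mirrors the 1 x 1 variable from the question, reshaped to a scalar because `tf.cond` expects a scalar predicate.

```python
import tensorflow as tf

# Hypothetical toy setup: one weight vector trained with plain SGD.
w = tf.Variable([1.0, -2.0])
learning_rate = 0.1
# 1 x 1 step counter, mirroring the t_step variable in the question.
t_step = tf.Variable([[0]])


@tf.function
def train_step(x, y):
    # The forward pass and gradients are always computed.
    with tf.GradientTape() as tape:
        pred = tf.reduce_sum(w * x)
        loss = tf.square(pred - y)
    grad = tape.gradient(loss, w)

    def apply_update():
        # Manual SGD step; no variables are created inside the branch.
        w.assign_sub(learning_rate * grad)
        return tf.constant(True)

    def skip_update():
        # Leave the weights untouched.
        return tf.constant(False)

    # Flatten the 1 x 1 variable to a scalar boolean predicate.
    should_update = tf.reshape(t_step, []) > 5
    # Both branches are part of the graph; only the selected one runs per call.
    return tf.cond(should_update, apply_update, skip_update)


# Usage: the weights change only on the call where t_step > 5.
x = tf.constant([1.0, 1.0])
y = tf.constant(0.0)
t_step.assign([[3]])
print(train_step(x, y).numpy(), w.numpy())
t_step.assign([[7]])
print(train_step(x, y).numpy(), w.numpy())
```

Inside a Keras `train_step`, the same idea would replace the Python `if` around the gradient application; that adaptation, including how a Keras optimizer behaves inside a `tf.cond` branch, is not shown here.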