Update learning rate in concrete API


My model is exported and loaded in C (no Python).
TensorFlow v2.9

The train concrete function is called from C. Its last parameter is a one-element float tensor (shape [1]) containing the learning rate to use.

The concrete function is traced into a graph, so it runs without eager execution.
How should I update the learning rate here?

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 4, board_height, board_width], dtype=tf.float32),
        tf.TensorSpec(shape=[None, board_height * board_width], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.float32)])
    def train(self, state_batch, mcts_probs, winner_batch, lr):
        # Fails: lr.eval() needs a session, which TF2 no longer provides
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr.eval())
        with tf.GradientTape() as tape:
            predictions = self.model(state_batch, training=True)  # Forward pass
            # the loss function is configured in `compile()`
            loss = self.model.compiled_loss([mcts_probs, winner_batch], predictions,
                                            regularization_losses=self.model.losses)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.model.optimizer.apply_gradients(
            zip(gradients, self.model.trainable_variables))

        entropy = tf.negative(tf.reduce_mean(
            tf.reduce_sum(tf.exp(predictions[0][0]) * predictions[0][0], 1)))

        return (loss, entropy)

I tried

    tf.keras.backend.set_value(self.model.optimizer.learning_rate, tf.gather(lr, 0))

but it failed with: Cannot convert a symbolic tf.Tensor (GatherV2:0) to a numpy array

I also tried the following, but it reports that no session is available (sessions were removed in TF2):

    tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr.eval())

I also suspect this call is wrong inside a graph.

Can tf.keras.backend.set_value be used inside a graph? How can I pass the learning rate on each call?
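A minimal sketch of why the two attempts behave differently, assuming TF 2.x: inside a traced tf.function, lr is a symbolic tensor with no NumPy value, and tf.keras.backend.set_value converts the new value with NumPy before assigning it, which is consistent with the error above. Calling assign() on a tf.Variable is instead recorded as an op in the graph, so it runs on every call of the concrete function. The names rate and set_rate are hypothetical and used only for illustration.

```python
import tensorflow as tf

# Hypothetical stand-in for a learning-rate variable, created eagerly.
rate = tf.Variable(0.002, trainable=False, dtype=tf.float32)


@tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
def set_rate(lr):
    # lr is symbolic here: it has no NumPy value at trace time, so a helper
    # that calls np.asarray() on it cannot work inside the graph.
    # Variable.assign() becomes a graph op and executes on every call.
    rate.assign(lr[0])  # lr[0] is equivalent to tf.gather(lr, 0)
    return rate.read_value()


# Call the traced concrete function directly, similar to an external caller.
concrete = set_rate.get_concrete_function()
concrete(tf.constant([0.01]))
print(rate.numpy())
```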

Have you also tried with:

Yes, I see that a scheduler lets you change the learning rate from a callback. But how can I pass the value in through a function call?

You can test it, but I think you will hit the same error, since internally it works the same way:

I solved the issue with the following code:

    self.lr = tf.Variable(0.002, trainable=False, dtype=tf.dtypes.float32)

    self.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=self.lr),
                       loss=[self.action_loss, tf.keras.losses.MeanSquaredError()])

And in my train method:

    self.lr.assign(tf.gather(lr, 0))
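A self-contained sketch of the same fix, assuming TF 2.9, where the Keras (OptimizerV2) Adam keeps the tf.Variable passed as learning_rate and reads it on every step. The Trainer class, the 3-input Dense model, and the hand-written MSE loss are hypothetical stand-ins for the question's network and compiled loss. Newer Keras optimizers (TF 2.11 and later) may copy the variable's initial value into their own variable instead of keeping a reference; in that case assigning to model.optimizer.learning_rate directly is the safer equivalent.

```python
import tensorflow as tf


class Trainer(tf.Module):
    """Hypothetical minimal stand-in for the question's model wrapper."""

    def __init__(self):
        super().__init__()
        # Learning rate held in a non-trainable variable owned by the wrapper.
        self.lr = tf.Variable(0.002, trainable=False, dtype=tf.float32)
        # Made-up tiny model: 3 inputs -> 1 output.
        self.model = tf.keras.Sequential(
            [tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
        # Assumption (TF 2.9): the optimizer reads self.lr on every step.
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=self.lr),
            loss="mse")

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 3], dtype=tf.float32),
        tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.float32)])
    def train(self, x, y, lr):
        # Graph-safe update; automatic control dependencies run it
        # before the optimizer step below reads the same variable.
        self.lr.assign(lr[0])
        with tf.GradientTape() as tape:
            predictions = self.model(x, training=True)
            loss = tf.reduce_mean(tf.square(y - predictions))  # plain MSE
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.model.optimizer.apply_gradients(
            zip(gradients, self.model.trainable_variables))
        return loss


trainer = Trainer()
loss = trainer.train(tf.ones([4, 3]), tf.zeros([4, 1]), tf.constant([0.01]))
print(float(loss), trainer.lr.numpy())
```

Each call updates trainer.lr inside the graph before the optimizer step, so the new rate should also take effect when the exported train signature is invoked from C rather than from Python.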

