Loss explodes when using tf.Variable vs. a constant input

Hi, I am working on a VAE model where one of the loss terms (the KL loss) has a multiplier that I want to change based on the epoch. If I use a plain 0.5 multiplier in the loss function, there is no issue. If I instead use a tf.Variable set to 0.5, the loss increases until it hits NaN. I am wondering what the cause is, since the model trains properly with a fixed multiplier but not when the multiplier is stored as a tf.Variable.

The exact difference would be:

Equation:
total_loss = reconstruction_loss + self.klConst * kl_loss

Correct loss:
self.klConst = 0.5

NaN loss:
self.klConst = tf.Variable(0.5)

I believe tf.Variable is my best option, as a Keras constant cannot be updated, and rebinding the attribute with a simple "=" does not change the value used during training.

I.e., setting self.klConst = 0.2 in a callback does not actually change the value seen by the training loop.

Any help would be great!

Hi @jludeman

Welcome to the TensorFlow Forum!

Could you please share a minimal reproducible example so we can replicate the issue and understand it better? Thank you.