Model cannot be loaded if the loss function contains tf.where

I am trying to save a model built by the functional API with a loss function that has tf.where, but it cannot be loaded again. Here is my code:

#!/usr/bin/env python3
import tensorflow as tf


def loss(output_0, output_1, weights, label_0, label_1):
    q_0 = tf.where(condition=(label_0 == 0), x=tf.fill(tf.shape(output_0), 0.0), y=output_0)
    q_0 = tf.where(condition=(label_0 == 1), x=tf.fill(tf.shape(output_0), 1.0), y=q_0)
    q_1 = tf.where(condition=(label_1 == 0), x=tf.fill(tf.shape(output_1), 0.0), y=output_1)
    q_1 = tf.where(condition=(label_1 == 1), x=tf.fill(tf.shape(output_1), 1.0), y=q_1)
    L = weights * tf.square(q_0 - q_1)
    return tf.reduce_mean(L)


if __name__ == '__main__':
    x_0 = tf.keras.layers.Input(shape=(2,))
    x_1 = tf.keras.layers.Input(shape=(2,))
    w = tf.keras.layers.Input(shape=(1,))
    l_0 = tf.keras.layers.Input(shape=(1,))
    l_1 = tf.keras.layers.Input(shape=(1,))
    layer = tf.keras.layers.Dense(name='dense1', units=10)
    y_0 = layer(x_0)
    y_1 = layer(x_1)
    model = tf.keras.Model(
        inputs=[x_0, x_1, w, l_0, l_1],
        outputs=[y_0, y_1, w, l_0, l_1])
    model.add_loss(
        loss(output_0=y_0, output_1=y_1, weights=w,
             label_0=l_0, label_1=l_1))
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5))
    model.save('test_save')
    test_reload = tf.keras.models.load_model('test_save')

If I run the code above, then I will get the following error messages:

Traceback (most recent call last):
  File "/home/hanatok/HDD/Documents/playground/python_tf_bug/./test.py", line 31, in <module>
    test_reload = tf.keras.models.load_model('test_save')
  File "/home/hanatok/mambaforge/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/hanatok/mambaforge/lib/python3.10/site-packages/tensorflow/python/util/dispatch.py", line 1076, in op_dispatch_handler
    result = api_dispatcher.Dispatch(args, kwargs)
TypeError: Missing required positional argument

Any ideas?
Update:
I solved the problem myself by replacing

label_0 == 0

with

tf.equal(label_0, 0)
1 Like

Just posting to say thank you for your update, I had the same incredibly cryptic error, and found a similar change that needed to be made so the model would load.