Tensorflow - self.metrics does not contain compiled metrics in train_step()

In the code below, I overridded train_step(), but self.metrics does not contain compiled metrics.

As a result the verbose of fit() does not show the compiled metrics and the history does not contains the metrics.

How can I modified this train_step() to reproduce exactly the same behaviour as tensorflow?

import tensorflow as tf
from tensorflow.keras.layers import Dense

print(f"{tf.__version__ = }")
tf.config.run_functions_eagerly(True)

# Dataset
(x_tra, y_tra), (x_tst, y_tst) = tf.keras.datasets.mnist.load_data()

x_tra = x_tra.reshape(-1, 784).astype("float32") / 255
x_tst = x_tst.reshape(-1, 784).astype("float32") / 255

y_tra = tf.one_hot(y_tra.astype("float32"), depth=10)
y_tst = tf.one_hot(y_tst.astype("float32"), depth=10)


# Model
class MModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.dense1 = Dense(64, activation="relu")
        self.dense2 = Dense(64, activation="relu")
        self.softmax = Dense(10, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.softmax(x)

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            ŷ = self(x, training=True)
            loss_value = self.compiled_loss(y, ŷ, regularization_losses=self.losses)

        # Compute gradients and Update weights
        gradients = tape.gradient(loss_value, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))


        """
        
         With a debugger, you can see that self.metrics list contains only the loss (Mean(name=loss,dtype=float32)) but
         not the compiled metrics
         Also, self.compiled_metrics.metrics list is empty.
        """

        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss_value)
            else:
                metric.update_state(y, ŷ)

        return {m.name: m.result() for m in self.metrics}


model = MModel()

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()]
)

history = model.fit(
    x_tra, y_tra,
    epochs=1,
    validation_data=(x_tst, y_tst)
)

print(f"{model.evaluate( x_tst, y_tst) = }")


"""
1875/1875 [==============================] - 12s 5ms/step - loss: 0.2838 - val_loss: 0.1499 - val_categorical_accuracy: 0.9525

"""

Please note that the validation metrics is present because I did not overide train_test()

To fix the issue, modify your train_step() method to explicitly update all compiled metrics using self.compiled_metrics.update_state(y, ŷ) and ensure the returned dictionary includes results for each metric by combining self.metrics and self.compiled_metrics.metrics. This will ensure proper tracking and display of metrics during training and evaluation.

Thanks, it works. But I also removed the

for metric in self.metrics :

loop because self.compiled_loss add (in the first train_step) / update the loss in self.metrics.

Regarding your second advice, I don’t see in which case self.compiled_metrics.metrics would contain metrics that are not in self.metrics but I keep that in the back of my mind. Thanks.

This also seems to mean that the official documentation at https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit#providing_your_own_evaluation_step is (now) wrong.

Please find below a working code:

def train_step(self, data):
    x, y = data

    with tf.GradientTape() as tape:
        ŷ = self(x, training=True)
        loss_value = self.compiled_loss(y, ŷ, regularization_losses=self.losses)

    gradients = tape.gradient(loss_value, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

    self.compiled_metrics.update_state(y, ŷ)

    return {m.name: m.result() for m in self.metrics}

def test_step(self, data):
    x, y = data

    ŷ = self(x, training=False)
    _ = self.compiled_loss(y, ŷ, regularization_losses=self.losses)

    self.compiled_metrics.update_state(y, ŷ)

    return {m.name: m.result() for m in self.metrics}