Keras: how to update a metric according to a value stored on a file?

Hello all,

In my scenario, I use a model to generate some files (PDDL files, used in planning), and I would like to use the quality of these generated PDDL files as a metric of my model.

So, to put things clearly, I need to read a file, at each batch iteration, and to update a custom metric of the model with the value stored in this file.

I have made a simple “POC” that you can find here and experiment with.

Here is the code so far (mostly based on Keras documentation):

import tensorflow as tf
import keras
from keras.layers import *
from keras.models import Model
import random

# Functional-API classifier: 784 input features -> two 64-unit ReLU layers
# -> 10 outputs. The last layer has no activation, so it emits raw logits
# (the loss below is configured with from_logits=True).
inputs = Input(shape=(784,), name="digits")
hidden = inputs
for _ in range(2):
    hidden = Dense(64, activation="relu")(hidden)
outputs = Dense(10, name="predictions")(hidden)
model = Model(inputs=inputs, outputs=outputs)

from keras.optimizers import SGD
from keras.losses import SparseCategoricalCrossentropy
from keras import datasets
import numpy as np
from keras.callbacks import Callback, CallbackList

# Loss works on the raw logits produced by the model's last layer.
loss_fn = SparseCategoricalCrossentropy(from_logits=True)
# Plain stochastic gradient descent with a fixed learning rate.
optimizer = SGD(learning_rate=1e-3)

# My custom metric
def my_metric_fn(y_true, y_pred):
    """Return the number currently stored in ``test.txt`` as the metric value.

    ``y_true`` and ``y_pred`` are part of the signature Keras uses when it
    calls a metric function, but they are not used here: the result depends
    only on the content of the file.

    NOTE(review): this is ordinary Python file I/O. If Keras executes the
    metric inside a compiled (traced) step, the file is presumably read only
    while tracing and the value is reused afterwards -- to be confirmed,
    e.g. by compiling the model with ``run_eagerly=True``.
    """
    with open("test.txt", "r") as file1:  # I open the file where the value is stored
        read_content = file1.read()
    # float() tolerates surrounding whitespace such as a trailing newline.
    result_file = float(read_content)
    return tf.reduce_mean(result_file)  # I return the value stored in the file

# Attach the optimizer, the loss and the file-based custom metric to the model.
model.compile(optimizer=optimizer, loss=loss_fn, metrics=[my_metric_fn])

# Load MNIST and flatten every image into a vector of 784 values.
batch_size = 64
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

epochs = 15

# Manual training loop: one optimizer step per batch, plus a single
# evaluate() call per epoch (on the batch with index 3) so that the
# metric value is computed and printed.
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    for step, batch in enumerate(train_dataset):
        x_batch_train, y_batch_train = batch
        model.train_on_batch(x_batch_train, y_batch_train)
        if step != 3:
            continue
        logs_eval = model.evaluate(
            x_batch_train, y_batch_train, return_dict=True
        )

But with this code I get this:

As you can see (sorry for the bad quality of the GIF), the my_metric_fn value does not change: it stays at 12 instead of changing to 55.

Any idea?