Custom F1 metric Keras

I have to define a custom F1 metric in Keras for a multiclass classification problem. Since it is a streaming metric, the idea is to keep track of the true positives, false negatives and false positives so that the F1 score is updated gradually, batch after batch. Here’s the code:

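import numpy as np
import tensorflow as tf
from tensorflow import keras
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
X = data.data
y = data.target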
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
def compute_confusion_matrix(true, pred, K):
    result = tf.zeros((K, K), dtype=tf.int32)
    for i in range(len(true)):
        result = tf.tensor_scatter_nd_add(tensor = result, indices=tf.constant([[true[i], pred[i]]]),
                                          updates=tf.constant([1]))
    return result

def f1_function(y_true, y_pred):
    k = 3
    y_pred_lab = np.argmax(y_pred, axis=1)  
    y_true = np.ravel(y_true)
    conf_mat= compute_confusion_matrix(y_true, y_pred_lab, K = k)
    tp = tf.linalg.tensor_diag_part(conf_mat)   
    fp = tf.reduce_sum(conf_mat, axis = 0) - tp
    fn = tf.reduce_sum(conf_mat, axis = 1) - tp
    support = tf.reduce_sum(conf_mat, axis = 1)
    return tp, fp, fn, support

class F1Metric(keras.metrics.Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.f1_fn = f1_function
        self.tp_count = self.add_weight("tp_count", initializer="zeros", shape = (3,), dtype=tf.float32)
        self.fp_count = self.add_weight("fp_count", initializer="zeros", shape = (3,), dtype=tf.float32)
        self.fn_count = self.add_weight("fn_count", initializer="zeros", shape = (3,), dtype=tf.float32)
        self.support_total = self.add_weight("support_total", initializer = "zeros", shape = (3,), dtype=tf.float32)
    def update_state(self, y_true, y_pred, sample_weight=None):
        tp, fp, fn, support = self.f1_fn(y_true, y_pred)
        self.tp_count.assign_add(tf.cast(tp, dtype=tf.float32))
        self.fp_count.assign_add(tf.cast(fp, dtype=tf.float32))
        self.fn_count.assign_add(tf.cast(fn, dtype=tf.float32))
        self.support_total.assign_add(tf.cast(support, dtype=tf.float32))
    def result(self):
        precisions = self.tp_count / (self.tp_count + self.fp_count)
        recalls = self.tp_count / (self.tp_count + self.fn_count)
        f1 = tf.constant(2, dtype=tf.float32) * (precisions*recalls) / (precisions + recalls)
        weighted_f1 = (f1 * self.support_total) / tf.reduce_sum(tf.cast(self.support_total, dtype=tf.float32))
        return  recalls

model = keras.models.Sequential([
    keras.layers.Dense(200, activation = "relu", input_shape = X_train.shape[1:]),
    keras.layers.Dense(4, activation = "softmax")
])

early_stopping_cb = keras.callbacks.EarlyStopping(patience=10)

#compile the model
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=[F1Metric()])

#fit the model
history = model.fit(X_train, y_train, epochs = 100,
                    callbacks = [early_stopping_cb])

It gives the following error:
“Cannot assign to variable tp_count:0 due to variable shape (3,) and value shape () are incompatible”

Alternatively, I tried to use the tfa F1 metric, but I can’t use it in a grid search. I want to find the optimal model architecture with the F1 metric as the scorer, and passing the tfa metric gives the following error:
“ValueError: The list/tuple elements must be unique strings of predefined scorers. One or more of the elements were callables. Use a dict of score name mapped to the scorer callable. Got [<tensorflow_addons.metrics.f_scores.F1Score object at 0x7f8ac9516be0>]”
Any idea? Thank you

Did you try using the version from TensorFlow Addons?

Yes, and it gives the error shown above.

The F1 metric was removed from Keras some years ago:

Also, in Addons it is used in a callback (see point 2):

I’ve read the issue #825 in the second link and it says that there are no problems related to the tfa implementation of the F1 metric when used together with tf.keras instead of multi-backend keras. However, I still haven’t figured out how to make it work in a grid search and that’s the reason why I tried a custom implementation. Is there a way to solve the problem? Thank you again.
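
For what it's worth, the ValueError quoted above looks like it comes from scikit-learn's `scoring` argument rather than from TensorFlow: a Keras/TFA metric object is not a scikit-learn scorer. A minimal sketch of the dict form the message asks for is below. The `LogisticRegression` estimator, the parameter grid and the name `weighted_f1` are only stand-ins chosen so the snippet runs on its own; a Keras model would first have to be wrapped in a scikit-learn compatible estimator (for example the wrapper from SciKeras), which is not shown here.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True)

# Turn sklearn's f1_score into a scorer; "weighted" averages the per-class
# F1 scores by class support.
weighted_f1 = make_scorer(f1_score, average="weighted")

search = GridSearchCV(
    estimator=LogisticRegression(max_iter=1000),  # stand-in estimator (assumption)
    param_grid={"C": [0.1, 1.0, 10.0]},           # illustrative grid (assumption)
    scoring={"weighted_f1": weighted_f1},         # dict: score name -> scorer callable
    refit="weighted_f1",                          # which score selects the best model
    cv=3,
)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```

With this setup the scorer only sees predicted class labels, so the F1 computation happens in scikit-learn and no Keras metric object needs to be passed to the search.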

This is an example search with KerasTuner:

I think you can try the same with TFA, or use your custom implementation.
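
If you go the custom route, here is a rough sketch of a streaming weighted F1 metric that stays entirely in TensorFlow ops. One likely cause of the shape error in the original code is the use of `np.argmax`, `np.ravel` and a Python loop over `len(true)` inside `update_state`, which do not behave as expected on symbolic tensors; I have not verified that against the exact setup, so treat it as a guess. The class name `StreamingWeightedF1` and the `num_classes` argument are my own, and `sample_weight` is ignored for simplicity.

```python
import tensorflow as tf
from tensorflow import keras


class StreamingWeightedF1(keras.metrics.Metric):
    """Support-weighted F1 accumulated over batches (hypothetical helper)."""

    def __init__(self, num_classes, name="weighted_f1", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        # Running confusion matrix: rows are true classes, columns predictions.
        self.conf_mat = self.add_weight(
            name="conf_mat",
            shape=(num_classes, num_classes),
            initializer="zeros",
            dtype="float32",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight is ignored in this sketch.
        labels = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        preds = tf.cast(tf.argmax(y_pred, axis=-1), tf.int32)
        batch_cm = tf.math.confusion_matrix(
            labels, preds, num_classes=self.num_classes, dtype=tf.float32
        )
        self.conf_mat.assign_add(batch_cm)

    def result(self):
        cm = tf.convert_to_tensor(self.conf_mat)
        tp = tf.linalg.diag_part(cm)
        fp = tf.reduce_sum(cm, axis=0) - tp
        fn = tf.reduce_sum(cm, axis=1) - tp
        support = tp + fn
        # Per-class F1 = 2TP / (2TP + FP + FN); 0 where a class never appears.
        f1 = tf.math.divide_no_nan(2.0 * tp, 2.0 * tp + fp + fn)
        return tf.math.divide_no_nan(
            tf.reduce_sum(f1 * support), tf.reduce_sum(support)
        )

    def reset_state(self):
        self.conf_mat.assign(tf.zeros((self.num_classes, self.num_classes)))


# Feed two batches to check that the state accumulates across updates.
metric = StreamingWeightedF1(num_classes=3)
metric.update_state([0, 1], [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1]])
metric.update_state([2, 2], [[0.1, 0.8, 0.1], [0.1, 0.1, 0.8]])
print(float(metric.result()))
```

It should then be usable as `metrics=[StreamingWeightedF1(num_classes=3)]` in `model.compile`, provided the last layer outputs one score per class.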