Custom Sparse MAE Metric for Regression

I’m struggling to write a custom metric due to limitations surrounding symbolic tf.Tensor objects.


I’m training a regression network to predict multiple values (think: price, height, weight, …). I have normalized all input values to [0, 1] and all target values in the dataset to [-1, 1]. A de-normalization function then scales the [-1, 1] outputs back to realistic values.
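A minimal sketch of the kind of linear min/max scaling described above, assuming each target is scaled independently. The names `normalize`, `denormalize`, `TARGET_MIN`, `TARGET_MAX` and the example ranges are hypothetical, not the actual function used in this project.

```python
import numpy as np

# Hypothetical per-target ranges (price, height, weight); substitute the real
# minimum and maximum observed for each target in the training data.
TARGET_MIN = np.array([10_000.0, 150.0, 40.0])
TARGET_MAX = np.array([500_000.0, 210.0, 150.0])


def normalize(y):
    """Map real-world target values linearly into [-1, 1]."""
    return 2.0 * (y - TARGET_MIN) / (TARGET_MAX - TARGET_MIN) - 1.0


def denormalize(y_scaled):
    """Inverse of normalize: map [-1, 1] back to real-world units."""
    return (y_scaled + 1.0) / 2.0 * (TARGET_MAX - TARGET_MIN) + TARGET_MIN


y = np.array([[255_000.0, 180.0, 95.0]])
round_trip = denormalize(normalize(y))  # recovers y up to floating-point error
```

Because both functions are plain elementwise arithmetic, a missing target stored as NaN stays NaN after scaling, so missing values remain identifiable after normalization.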

Unfortunately my dataset is very sparse.

  • For the input side I’m using one-hot encoded variables (each variable has a corresponding one-hot flag that indicates whether it is present). This seems to work in my experiments with synthetic data.
  • For the output side I put all complete samples (where every target value is present) in the training set, and the sparse ones (where some target values are missing) in the validation set.
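The two bullets above can be sketched in NumPy as follows, assuming missing values are stored as NaN in plain arrays. `add_presence_flags` and `split_complete_rows` are hypothetical helpers, and filling missing inputs with 0.0 is one assumed choice among several (the flag is what tells the model a 0.0 was actually absent).

```python
import numpy as np


def add_presence_flags(x):
    """Pair every input column with a 0/1 presence flag: [v0, f0, v1, f1, ...].
    Missing inputs (NaN) are filled with 0.0 and flagged as 0."""
    present = ~np.isnan(x)
    out = np.empty((x.shape[0], 2 * x.shape[1]), dtype=float)
    out[:, 0::2] = np.where(present, x, 0.0)  # values, NaN replaced by 0.0
    out[:, 1::2] = present.astype(float)      # 1.0 where the value existed
    return out


def split_complete_rows(x, y):
    """Rows whose targets are all present go to training; the rest to validation."""
    complete = ~np.isnan(y).any(axis=1)
    return (x[complete], y[complete]), (x[~complete], y[~complete])


x = np.array([[0.2, np.nan],
              [np.nan, 0.7]])
y = np.array([[0.5, -0.1],
              [np.nan, 0.3]])
x_flagged = add_presence_flags(x)
(train_x, train_y), (val_x, val_y) = split_complete_rows(x_flagged, y)
```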

This means that my validation curve on the MAE-vs-epoch graph is non-existent, because the missing targets (NaNs) turn every validation MAE value into NaN.
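A NumPy analogue of why the validation curve disappears: MAE is a mean over absolute errors, so a single NaN target turns the whole mean into NaN. `np.nanmean` appears here only to illustrate the desired behaviour; it is a NumPy function and cannot be applied to symbolic TensorFlow tensors during training.

```python
import numpy as np

y_true = np.array([0.5, np.nan, -0.2])  # second target value is missing
y_pred = np.array([0.4, 0.1, -0.1])

abs_err = np.abs(y_true - y_pred)
naive_mae = np.mean(abs_err)      # NaN: one missing value poisons the mean
masked_mae = np.nanmean(abs_err)  # averages only the two present entries
```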

My idea was to write a custom metric function for each target value: one for price, one for height, one for weight, etc. Each of these metrics would compute the MAE only over the samples where that target is present in the dataset.
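One way to implement the one-metric-per-target idea is a small factory that closes over the column index. This is a sketch assuming targets are the columns of a 2-D `y_true` with missing values stored as NaN; `make_masked_mae` and the column order (0 = price, 1 = height, 2 = weight) are hypothetical. It relies on `tf.math.is_nan` and `tf.boolean_mask`, which are TensorFlow ops and therefore work on symbolic tensors.

```python
import tensorflow as tf


def make_masked_mae(column, name):
    """Build a metric function computing MAE on one target column,
    ignoring samples whose y_true is NaN (i.e. missing)."""

    def metric(y_true, y_pred):
        t = tf.cast(y_true[:, column], y_pred.dtype)
        p = y_pred[:, column]
        present = tf.logical_not(tf.math.is_nan(t))  # elementwise, graph-safe
        t = tf.boolean_mask(t, present)
        p = tf.boolean_mask(p, present)
        total = tf.reduce_sum(tf.abs(t - p))
        count = tf.cast(tf.size(t), p.dtype)
        # divide_no_nan returns 0 instead of NaN when no value is present
        return tf.math.divide_no_nan(total, count)

    metric.__name__ = name  # used as the metric's display name
    return metric


# Hypothetical column order: 0 = price, 1 = height, 2 = weight.
sparse_mae_metrics = [
    make_masked_mae(0, "mae_price"),
    make_masked_mae(1, "mae_height"),
    make_masked_mae(2, "mae_weight"),
]
# model.compile(optimizer="adam", loss="mse", metrics=sparse_mae_metrics)
```

Setting `__name__` matters because Keras typically derives the logged metric name (for example `val_mae_price`) from the function name.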

But I can’t seem to figure out how to drop the data-points where there are NaNs.


How do I write a sparse mean absolute error metric?

One that only uses data points that are present in the y_true tensor.
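Written out, the quantity being asked for is a masked mean. For target column $j$ over $N$ samples, with true values $y_{ij}$, predictions $\hat{y}_{ij}$, and a mask $m_{ij}$:

```latex
\mathrm{MAE}_j = \frac{\sum_{i=1}^{N} m_{ij}\,\lvert y_{ij} - \hat{y}_{ij} \rvert}{\sum_{i=1}^{N} m_{ij}},
\qquad
m_{ij} = \begin{cases} 1 & \text{if } y_{ij} \text{ is present} \\ 0 & \text{if } y_{ij} \text{ is NaN} \end{cases}
```

In code the mask has to select values (for example with `tf.where` or `tf.boolean_mask`) rather than multiply them, because in IEEE floating point NaN × 0 is still NaN.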

import numpy as np
import tensorflow as tf

def met_mae_price(y_true, y_pred):
   y_true = y_true[:,0]
   y_pred = y_pred[:,0]

   # I'm trying to drop from both tensors any values that are NaNs in y_true
   mask = tf.vectorized_map(lambda x: not np.isnan(x), y_true)
   y_true = tf.boolean_mask(y_true, mask)
   y_pred = tf.boolean_mask(y_pred, mask)

   mae = tf.keras.metrics.MeanAbsoluteError()
   mae.update_state(y_true, y_pred)
   return mae.result()
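A sketch of how the attempt above could be made graph-compatible, assuming column 0 holds price and missing targets are NaN. It swaps the `tf.vectorized_map`/`np.isnan` mask for the elementwise `tf.math.is_nan`, and computes the mean directly instead of creating a `tf.keras.metrics.MeanAbsoluteError` object: a stateful metric instantiated inside a function that Keras calls on every batch gets fresh variables on each call, which `tf.function` may reject, and would only ever report the current batch anyway.

```python
import tensorflow as tf


def met_mae_price(y_true, y_pred):
    y_true = tf.cast(y_true[:, 0], y_pred.dtype)  # column 0 = price (assumed)
    y_pred = y_pred[:, 0]

    # tf.math.is_nan is a TensorFlow op, so it works on symbolic tensors
    # inside model.fit; np.isnan and Python `not` need concrete values.
    present = tf.logical_not(tf.math.is_nan(y_true))

    # Select zero error for missing targets instead of multiplying by the
    # mask, because NaN * 0 is still NaN.
    abs_err = tf.where(present, tf.abs(y_true - y_pred), tf.zeros_like(y_pred))
    n_present = tf.reduce_sum(tf.cast(present, y_pred.dtype))

    # divide_no_nan returns 0 for a batch with no price values instead of NaN
    return tf.math.divide_no_nan(tf.reduce_sum(abs_err), n_present)
```

Unlike `tf.boolean_mask`, `tf.where` keeps the tensor shape fixed, which tends to be friendlier to graph tracing.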

I’ve tried lots of different approaches, but always get errors: mostly that I can’t iterate over a tensor, but also that I can’t pass Tensors to NumPy calls. I keep reading the API documentation trying to find a way to do this, but can’t figure it out.
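These errors are consistent with Keras running metrics inside a `tf.function`: the tensors are then symbolic, so Python iteration, `not x`, and `np.isnan(x)` have no concrete values to work on. Separately, a plain metric function is, as far as I know, averaged per batch, so a batch containing one present price counts as much as a batch containing many. A stateful `tf.keras.metrics.Metric` subclass can avoid that by summing errors and counts over the whole epoch. A sketch, with `MaskedMAE` as a hypothetical name and the same NaN-for-missing assumption:

```python
import tensorflow as tf


class MaskedMAE(tf.keras.metrics.Metric):
    """MAE for one target column, ignoring NaN targets, accumulated over all
    batches seen since the last reset rather than averaged per batch."""

    def __init__(self, column, name="masked_mae", **kwargs):
        super().__init__(name=name, **kwargs)
        self.column = column
        self.total_error = self.add_weight(name="total_error", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        t = tf.cast(y_true[:, self.column], tf.float32)
        p = tf.cast(y_pred[:, self.column], tf.float32)
        present = tf.logical_not(tf.math.is_nan(t))
        abs_err = tf.where(present, tf.abs(t - p), tf.zeros_like(p))
        self.total_error.assign_add(tf.reduce_sum(abs_err))
        self.count.assign_add(tf.reduce_sum(tf.cast(present, tf.float32)))

    def result(self):
        return tf.math.divide_no_nan(self.total_error, self.count)

    def reset_state(self):
        self.total_error.assign(0.0)
        self.count.assign(0.0)


# Hypothetical usage (column order assumed):
# model.compile(optimizer="adam", loss="mse",
#               metrics=[MaskedMAE(0, name="mae_price"), MaskedMAE(1, name="mae_height")])
```

With two batches whose price errors are {1} and {3, 0}, this reports 4/3 for the epoch, whereas averaging per-batch means would give (1 + 1.5) / 2 = 1.25.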

Can someone help please?