Custom loss function with a conditional

I’m trying to write a custom loss function that adds an extra penalty when the true label is 17.0. As I understand it, the code below doesn’t work because of the “if” statement. I think I need to replace it with a tf.where() statement, but everything I’ve tried produces an error. Any suggestions?

import tensorflow as tf
from tensorflow.keras import backend as bk

def custom_error(y_true, y_pred):
  total_error = 0.0

  # do mse calc
  error = y_true - y_pred
  sqr_error = bk.square(error)
  mse = bk.mean(sqr_error)

  # This is an extra penalty I was trying to code
#  if y_true==17.0 and bk.round(y_pred)<17.0:
#    total_error += 25.0

  # add mse to the total error as well
  total_error += mse
  mse2 = tf.math.add(mse,25.0)
  total_error = tf.where(tf.equal(y_true,17.0) and tf.less(bk.round(y_pred),17.0), mse, mse2)
  return total_error

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1024, activation='relu', input_shape=(40,)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
])

model.compile(optimizer='adam', loss=custom_error)
model.summary()
model.fit(xtrain, ytrain, epochs=10)

Here’s the error message I’m getting:

Epoch 1/10
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-29-ec85575477da> in <module>()
     44 model.summary()
     45 
---> 46 model.fit(xtrain, ytrain, epochs=10
     47           )

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/ in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     53     ctx.ensure_initialized()
     54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55                                         inputs, attrs, num_outputs)
     56   except core._NotOkStatusException as e:
     57     if name is not None:

InvalidArgumentError: Graph execution error:

The second input must be a scalar, but it has shape [32]
	 [[{{node custom_error/cond/custom_error/Equal/_6}}]] [Op:__inference_train_function_433584]

The issue with your code is that a Python if statement is not differentiable as far as the Keras backend is concerned. If you think about the differentiation of this operation, though, it is rather straightforward: the gradient depends on one branch only when the conditional evaluates to True, and likewise on the other branch when it evaluates to False. So the fix should be simple. The Keras backend provides a switch() operation, I believe, which is essentially a differentiable form of a conditional statement. Try using that instead. switch() takes three arguments: the first is a conditional expression, the second a tensor from which values are taken where the condition is true, and the third a tensor from which values are taken where the condition is false.
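For example, here is a quick illustration of the three-argument form (the tensors are made up purely for demonstration):

```python
import tensorflow as tf
from tensorflow.keras import backend as bk

cond = tf.constant([True, False, True])
if_true = tf.constant([1.0, 2.0, 3.0])
if_false = tf.constant([10.0, 20.0, 30.0])

# Values are taken elementwise from if_true where cond is True,
# and from if_false where it is False.
out = bk.switch(cond, if_true, if_false)
print(out.numpy())  # [ 1. 20.  3.]
```

Note that because the selection is elementwise, both branch tensors need shapes compatible with the condition.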

Thank you! I was able to use tf.keras.backend.switch() to do the comparison.

Afterward, I also had to debug my logic a bit. mse is a scalar, but switch() needed branch tensors with the same shape as y_true, like sqr_error.
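Putting those pieces together, a minimal sketch of a corrected loss along those lines might look like this. It assumes the 25.0 penalty should be added per element before averaging, and uses tf.logical_and for the compound condition, since Python's `and` does not work on tensors:

```python
import tensorflow as tf
from tensorflow.keras import backend as bk

def custom_error(y_true, y_pred):
  # Elementwise squared error -- same shape as y_true, as switch() needs
  sqr_error = bk.square(y_true - y_pred)

  # Compound condition built with tf.logical_and, not Python's `and`
  cond = tf.logical_and(tf.equal(y_true, 17.0),
                        tf.less(bk.round(y_pred), 17.0))

  # Add the 25.0 penalty only where the condition holds, then average
  penalized = bk.switch(cond, sqr_error + 25.0, sqr_error)
  return bk.mean(penalized)
```

For a batch of two samples with y_true = [[17.0], [5.0]] and y_pred = [[10.0], [5.0]], the first element gets 49.0 + 25.0 and the second 0.0, so the loss averages to 37.0.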