Custom Loss Function for TensorFlow Decision Forests

Hello,

How does one create and use a custom loss function with a decision forest model, such as a Random Forest, in the TensorFlow Decision Forests (TF-DF) library?
The example below, taken from the documentation, seems to be written for neural networks, and it is not clear how to adapt it to TF-DF models:

============
import tensorflow as tf


class MyModel(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        super(MyModel, self).__init__(*args, **kwargs)
        # Tracks the running mean of the custom loss.
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')

    def compute_loss(self, x, y, y_pred, sample_weight):
        # Mean squared error plus any losses registered via add_loss().
        loss = tf.reduce_mean(tf.math.squared_difference(y_pred, y))
        loss += tf.add_n(self.losses)
        self.loss_tracker.update_state(loss)
        return loss

    def reset_metrics(self):
        self.loss_tracker.reset_states()

    @property
    def metrics(self):
        return [self.loss_tracker]


tensors = tf.random.uniform((10, 10)), tf.random.uniform((10,))
dataset = tf.data.Dataset.from_tensor_slices(tensors).repeat().batch(1)

inputs = tf.keras.layers.Input(shape=(10,), name='my_input')
outputs = tf.keras.layers.Dense(10)(inputs)
model = MyModel(inputs, outputs)
model.add_loss(tf.reduce_sum(outputs))

optimizer = tf.keras.optimizers.SGD()
model.compile(optimizer, loss='mse', steps_per_execution=10)
model.fit(dataset, epochs=2, steps_per_epoch=10)
print('My custom loss: ', model.loss_tracker.result().numpy())

==========

Hi,
TF-DF developer here.

Unfortunately, it is not (yet) possible to use custom losses within TF-DF. TF-DF does provide a library of the most common losses for the tasks it supports (RMSE for regression, NDCG for ranking, …). Since those are deeply ingrained in the forest’s computation, the library currently does not expose a way to add other losses.

We are interested in adding new losses that can be useful to the community, so feel free to tell us which losses you’re missing and how those are useful to you and your projects.


Hi,
Thanks for your quick reply.
We are interested in adding our own custom regularization terms to existing loss functions, such as Mean Squared Error. Is there a place in the source code where we can easily incorporate them into a local copy of TF-DF?

TL;DR: Defining your own loss function needs some C++ work. You may be able to use the regularization mechanisms we already have.

All performance-sensitive code in TF-DF is implemented in C++ in a separate project called Yggdrasil Decision Forests (YDF, see GitHub). To define your own loss function or modify an existing one, you need to change the functions in this folder. This is definitely some engineering work, but doable. You even have an example, since we’re currently working on adding Poisson Log Loss to the library. If you want to implement a loss, look at the CLs titled “Add Poisson Log Loss” for guidance.

After modifying YDF, you can work directly with the C++ interface, which is documented here. Building YDF is usually not an issue and YDF models are fully compatible with TF-DF and vice versa.

You can also rebuild TF-DF with your changes to YDF, but this is a bit more complicated. The compilation process on Linux is easiest in the TF docker and is documented here. Feel free to email with any questions. Again, this is certainly doable, but expect a bit of engineering work.

If your goal is to explore regularization techniques, you may be able to use methods that TF-DF already supports. A full list of hyper-parameters is available here.
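As a rough illustration of that last point, here is a minimal sketch of setting the built-in regularization hyper-parameters on a gradient boosted trees regressor. The values are arbitrary placeholders and `build_regularized_gbt` is a hypothetical helper name; check the hyper-parameter list for what your TF-DF version actually accepts.

```python
# Sketch only: the values below are illustrative, not recommendations.
REGULARIZATION_HPARAMS = {
    "l1_regularization": 0.1,  # L1 regularization applied to the training loss
    "l2_regularization": 1.0,  # L2 regularization applied to the training loss
    "shrinkage": 0.05,         # learning rate applied to each new tree
    "subsample": 0.8,          # ratio of examples sampled to train each tree
    "max_depth": 4,            # shallower trees act as structural regularization
}


def build_regularized_gbt(**overrides):
    """Return an untrained TF-DF gradient boosted trees regressor.

    Hypothetical helper; keyword overrides replace the defaults above.
    """
    # Imported lazily so the defaults can be inspected without TF-DF installed.
    import tensorflow_decision_forests as tfdf

    params = {**REGULARIZATION_HPARAMS, **overrides}
    return tfdf.keras.GradientBoostedTreesModel(
        task=tfdf.keras.Task.REGRESSION, **params
    )
```

The returned model is trained like any other TF-DF model, for example with `model.fit(train_ds)` on a dataset built with `tfdf.keras.pd_dataframe_to_tf_dataset`. This only tunes regularization that already exists; it does not add a new term to the loss.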

Hi @rstz, I’m curious to know whether the ability to register custom loss functions in TF-DF is still on your radar. I have a very specific use case where the loss function would hardly be useful to anyone else, so it would not make sense for it to end up in YDF.

Hi, yes, this is still planned as an upcoming feature, but we will likely implement it first as part of the Yggdrasil Decision Forests API, which can export the model to TF-DF.

Update:

We’ve added custom losses to Yggdrasil Decision Forests! The models are fully compatible with TF-DF, and you can even define your loss functions with JAX for auto-differentiation.

See the tutorial here: Custom loss - Yggdrasil Decision Forests' documentation
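For the regularization use case raised earlier in this thread, a minimal sketch could look like the following. It assumes the `ydf.RegressionLoss` interface as described in that tutorial (an initial prediction, a gradient-and-Hessian function, and a loss function); treat the exact signatures as something to verify there. The penalty term, the `PENALTY` value, and the helper `train_with_custom_loss` are made up for illustration.

```python
import numpy as np

PENALTY = 0.1  # hypothetical regularization strength


def loss(labels, predictions, weights):
    """Mean squared error plus a quadratic penalty on the predictions.

    Weights are ignored in this sketch.
    """
    return float(np.mean((predictions - labels) ** 2 + PENALTY * predictions**2))


def initial_predictions(labels, weights):
    """Constant prediction that minimizes the loss above (weights ignored)."""
    return float(np.mean(labels) / (1.0 + PENALTY))


def gradient_and_hessian(labels, predictions):
    """First and second derivatives of the loss w.r.t. each prediction."""
    n = len(labels)
    gradient = (2.0 * (predictions - labels) + 2.0 * PENALTY * predictions) / n
    hessian = np.full_like(predictions, (2.0 + 2.0 * PENALTY) / n)
    return gradient, hessian


def train_with_custom_loss(train_df, label):
    """Hypothetical helper: train a YDF regressor with the loss above."""
    import ydf  # imported lazily; requires the ydf package

    custom_loss = ydf.RegressionLoss(
        initial_predictions=initial_predictions,
        gradient_and_hessian=gradient_and_hessian,
        loss=loss,
        activation=ydf.Activation.IDENTITY,
    )
    learner = ydf.GradientBoostedTreesLearner(
        label=label, task=ydf.Task.REGRESSION, loss=custom_loss
    )
    return learner.train(train_df)
```

The three NumPy functions can be checked on their own (for instance with a finite-difference test of the gradient) before handing them to the learner, which is worth doing because a wrong derivative usually shows up only as poor model quality.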
