How to get standard deviation (or other measure of confidence) from tfdf.RandomForestModel regression?

I’m working on an application where I’d like to retrieve the standard deviation of the predictions made by the trees within an ensemble (currently a tfdf.keras.RandomForestModel) to use as an estimate of the confidence of a given prediction. These are regression predictions rather than categorical, so I’m assuming the best way would be to look at the distribution of predictions within the ensemble, but other ideas are very welcome!

It looks like I could do this by running a prediction on each individual tree with inspector.iterate_on_nodes(), but is there a better way to do this via the main predict method that I’ve missed in the documentation, or some other recommended approach?

Hi Jamie,

I am copying the answer from GitHub :slight_smile:.

As you correctly noted, the API does not allow obtaining the individual tree predictions directly. Please feel free to create a feature request :). If we see traction, we will prioritize it.

In the meantime, there are two alternative solutions:

  1. Training multiple Random Forest models, each with one tree (while making sure to change the random seed); a minimal sketch follows this list.
  2. Training a single Random Forest model and dividing it per trees using the model inspector and model builder.
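
For solution 1, a minimal sketch might look like this (assuming the same train_ds training dataset used below; random_seed is the model constructor argument that controls the randomness of training):

import tensorflow_decision_forests as tfdf

# Sketch of solution 1: train several single-tree forests, each with a
# different random seed so the trees differ (e.g. different bootstrap samples).
single_tree_models = []
for seed in range(10):
  single_tree_model = tfdf.keras.RandomForestModel(num_trees=1, random_seed=seed)
  single_tree_model.fit(train_ds)
  single_tree_models.append(single_tree_model)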

Using the model builder to generate the individual trees might be easier than running the inference manually in Python.

While faster than solution 1, solution 2 can still be slow on large models and datasets, as model deserialization and re-serialization in Python is relatively slow. It would look like this:

import os

import tensorflow as tf
import tensorflow_decision_forests as tfdf

# Train a Random Forest with 10 trees.
model = tfdf.keras.RandomForestModel(num_trees=10)
model.fit(train_ds)

# Extract each of the 10 trees into a separate model.
inspector = model.make_inspector()

# TODO: Run in parallel.
models = []
for tree_idx, tree in enumerate(inspector.extract_all_trees()):
  print(f"Extract and export tree #{tree_idx}")

  # Create a RF model with a single tree.
  path = os.path.join("/tmp/model", str(tree_idx))
  builder = tfdf.builder.RandomForestBuilder(
      path=path,
      objective=inspector.objective(),
      import_dataspec=inspector.dataspec)
  builder.add_tree(tree)
  builder.close()

  models.append(tf.keras.models.load_model(path))

# Compute the predictions of all the trees together.
class CombinedModel(tf.keras.Model):
  def call(self, inputs):
    # We assume a model that returns a single value per example, e.g. a
    # regression model or a binary classification model returning a single
    # probability. In the case of multi-class classification, use tf.stack
    # instead of tf.concat.
    return tf.concat([submodel(inputs) for submodel in models], axis=1)

print("Prediction of all the trees")
combined_model = CombinedModel()
all_trees_predictions = combined_model.predict(test_with_cast_ds)

See this colab for a full example.

P.S.: Make sure to use all_trees_predictions correctly when computing the prediction confidence interval, for example using the method of Wager et al.
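
As an illustration only (a naive per-example spread, not the Wager et al. estimator), the mean and standard deviation across trees can be computed from all_trees_predictions, assuming a single-output regression model where the array has shape (num_examples, num_trees):

import numpy as np

# all_trees_predictions has shape (num_examples, num_trees) for a
# single-output model, with one column per tree.
mean_prediction = np.mean(all_trees_predictions, axis=1)
std_prediction = np.std(all_trees_predictions, axis=1)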

Cheers,
M.

Edit: Added the import_dataspec constructor argument to the model builder. This helps with some situations involving categorical features. See this page for an explanation.


Thanks @Mathieu - incredibly helpful, and I appreciate you taking the time to put together the example! I’ll give your #2 route a go in my application since I’m already training a Random Forest with multiple trees.

Also, thanks for the link to the Wager et al. paper - I was aware of the technique, having seen references to it in the R ranger package and this scikit-learn contrib package, but I hadn’t seen the source paper, so I look forward to reading it to understand the method more deeply.

Sorry for asking on both GitHub and the forum - after posting there, I wasn’t sure if GitHub was only for code-focused issues and the forum for help requests, so I thought here might be more appropriate.

Will add a feature request on GitHub. Do you prefer that it sit at the TF-DF level? In this case I think it would probably involve changes to YDF (i.e. a custom reducer similar to here), but it looks like most activity is on the TF-DF repository.

Happy to help :slight_smile:

Yes. Posting the feature request in TF-DF is better for the reasons you mentioned.

P.S.: Happy to see an R user.