How to use tfdf.builder.CARTBuilder to build/train a decision tree by hand

Expectation

Use tfdf.builder.CARTBuilder to build a decision tree structure, train it on the actual dataset, and optimize the tree structure based on its performance.

The process is like manually replicating the training process of tfdf.keras.CartModel. The benefit is that I can adjust the tree structure as needed instead of focusing only on model performance, which is helpful when intuitive rules are needed.

Sample code

I tried to use tfdf.builder.CARTBuilder to build the structure and then fit/predict, but the results are not what I expected: fitting does not change the predictions of the leaves.

Below is some sample code with a sample dataset, running in Colab:

```python
import pandas as pd
import tensorflow as tf
import tensorflow_decision_forests as tfdf

# Download the dataset ("!" runs a shell command in Colab/Jupyter).
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv

# Load the dataset into a pandas DataFrame.
dataset_df = pd.read_csv("/tmp/penguins.csv")

model_trial_idx = 10

# Create the model builder. Each trial writes the model to a new directory.
model_trial_idx += 1
model_path = f"/tmp/manual_model/{model_trial_idx}"

builder = tfdf.builder.CARTBuilder(
    path=model_path,
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="species_binary", classes=["Adelie", "Non-Adelie"]))

# Create some aliases
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue

builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="bill_length_mm", type=ColumnType.NUMERICAL),
                threshold=40.0,
                missing_evaluation=False),
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="island", type=ColumnType.CATEGORICAL),
                    mask=["Dream", "Torgersen"],
                    missing_evaluation=False),
                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.2], num_examples=10)),
                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.9], num_examples=20))),
            neg_child=LeafNode(value=ProbabilityValue(probability=[0.2, 0.8], num_examples=30)))))

builder.close()

manual_model = tf.keras.models.load_model(model_path)

# Encode the label as an integer index into `classes` above: 0 = "Adelie", 1 = "Non-Adelie".
dataset_df["species_binary"] = (dataset_df["species"] != "Adelie").astype(int)

# Convert the pandas DataFrame into a TensorFlow dataset.
dataset_tf_2 = tfdf.keras.pd_dataframe_to_tf_dataset(
    dataset_df[["bill_length_mm", "island", "species_binary"]], label="species_binary")

# Compile and fit. fit() runs, but the model has no trainable weights,
# so the leaf values defined above are not updated.
manual_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                     loss=tf.keras.losses.BinaryCrossentropy(),
                     metrics=[tf.keras.metrics.BinaryAccuracy(),
                              tf.keras.metrics.FalseNegatives()])

manual_model.fit(dataset_tf_2)
```
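It may help to spell out what the hand-built tree actually computes. The plain-Python sketch below (not TF-DF code) mirrors it under a few assumptions: `NumericalHigherThanCondition` tests `value >= threshold`, `pos_child` is taken when a condition holds, and `missing_evaluation` is used as the condition's result when the feature value is missing. Prediction is a lookup of the three fixed leaf values, and nothing in it writes back to the leaves.

```python
from typing import List, Optional, Sequence

# Plain-Python mirror of the tree built with CARTBuilder above (not TF-DF code).
# Assumptions: "higher than" means value >= threshold, pos_child is taken when
# the condition is true, and missing_evaluation is the result for a missing value.

def bill_length_higher_than(bill_length_mm: Optional[float], threshold: float = 40.0,
                            missing_evaluation: bool = False) -> bool:
    if bill_length_mm is None:
        return missing_evaluation
    return bill_length_mm >= threshold

def island_is_in(island: Optional[str], mask: Sequence[str] = ("Dream", "Torgersen"),
                 missing_evaluation: bool = False) -> bool:
    if island is None:
        return missing_evaluation
    return island in mask

def predict(bill_length_mm: Optional[float], island: Optional[str]) -> List[float]:
    """Returns the fixed leaf probabilities [Adelie, Non-Adelie]."""
    if bill_length_higher_than(bill_length_mm):
        if island_is_in(island):
            return [0.8, 0.2]
        return [0.1, 0.9]
    return [0.2, 0.8]

print(predict(45.0, "Dream"))   # [0.8, 0.2]
print(predict(45.0, "Biscoe"))  # [0.1, 0.9]
print(predict(None, "Dream"))   # [0.2, 0.8] (missing value -> condition is False)
```

Calling `predict` on any number of rows leaves the three leaf values untouched, which matches what the question observes after `manual_model.fit()`.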

Questions

  • The code above runs without error, but the tree does not reflect the fitting results: the predicted probabilities and numbers of samples stay the same. This is very strange; manual_model looks like a completely static model. How are we supposed to define the probabilities and numbers of samples before running the model on any data?

  • I assumed that tfdf.builder.CARTBuilder is used to build a shell, and that the performance of each node would be reflected after fitting/prediction. I am very confused about why it requires me to define the leaf values in the first place, and why those values remain the same after fitting/prediction. Did I miss anything?

  • What, after all, is the best practice for using tfdf.builder.CARTBuilder to build a decision tree by hand?

Reference:

Hi Vincent,

Different machine learning models (neural networks, trees, SVMs, etc.) use different training algorithms. In your particular situation, note that decision forest algorithms like CART do not use the same algorithm as neural networks. TensorFlow is primarily a neural network library, so not all TensorFlow operations are meaningful when using decision forests (e.g. TensorFlow Decision Forests).

The guide Migrating from Neural Networks can give you more details.

In your particular example:

  • The model builder is a way to create a model by hand. This is notably used to import models from other formats, or to apply post-training transformations (e.g. pruning).
  • Calling “fit” on a model created with the builder will not change the decision forest model. You are likely seeing warning messages telling you that there are no trainable weights. Calling fit on a manual model is only meaningful if you add extra “TensorFlow variables”, e.g. fine-tuning a neural network that takes the tree's output as input features, but this is more advanced.
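As an illustration of the kind of post-training transformation mentioned above, here is a plain-Python sketch of pruning. It is not the TF-DF API: `Leaf`, `Split`, `merge_leaves`, `collapse`, and `prune_to_depth` are hypothetical stand-ins. It collapses every subtree below a given depth into a single leaf, assuming the merged probability should be the average of the children's probabilities weighted by `num_examples` (which is what re-counting would give if each leaf's probability is the class frequency of the examples that reached it).

```python
from dataclasses import dataclass
from typing import List, Union

# Hypothetical stand-ins for py_tree's LeafNode/NonLeafNode (not the TF-DF classes).
@dataclass
class Leaf:
    probability: List[float]
    num_examples: int

@dataclass
class Split:
    condition: str  # human-readable description only
    pos_child: "Node"
    neg_child: "Node"

Node = Union[Leaf, Split]

def merge_leaves(a: Leaf, b: Leaf) -> Leaf:
    """Merges two leaves, weighting their class probabilities by num_examples."""
    n = a.num_examples + b.num_examples
    if n == 0:
        raise ValueError("cannot weight leaves that have no examples")
    probability = [(pa * a.num_examples + pb * b.num_examples) / n
                   for pa, pb in zip(a.probability, b.probability)]
    return Leaf(probability=probability, num_examples=n)

def collapse(node: Node) -> Leaf:
    """Collapses a whole subtree into a single leaf."""
    if isinstance(node, Leaf):
        return node
    return merge_leaves(collapse(node.pos_child), collapse(node.neg_child))

def prune_to_depth(node: Node, max_depth: int) -> Node:
    """Returns a new tree in which every subtree below max_depth becomes one leaf."""
    if isinstance(node, Leaf):
        return node
    if max_depth == 0:
        return collapse(node)
    return Split(node.condition,
                 prune_to_depth(node.pos_child, max_depth - 1),
                 prune_to_depth(node.neg_child, max_depth - 1))

# The tree from the question, rebuilt with the stand-in classes.
tree = Split("bill_length_mm >= 40",
             Split("island in {Dream, Torgersen}",
                   Leaf([0.8, 0.2], num_examples=10),
                   Leaf([0.1, 0.9], num_examples=20)),
             Leaf([0.2, 0.8], num_examples=30))
pruned = prune_to_depth(tree, max_depth=1)
print(pruned.pos_child)  # a single leaf with num_examples=30
```

The original tree is not modified; `prune_to_depth` returns a new structure, which is the shape of result you would then hand to a builder.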

Can you give more details on what you are trying to achieve? Maybe this can be done with TF-DF, just not with the same algorithm as for a neural network.

Cheers,
M.

Hi @Mathieu ,

Thank you for the reply!

I agree with you that the builder is perfect for expert-designed trees, and this is what I am looking for: I want to build my tree from features and thresholds under my control, which helps me build intuitive and compliant rules (without using features that violate regulations and laws).

To this end, the CARTBuilder does address part of my needs, but what I don't follow is why the builder requires adding the probability and number of samples before running any data through the model. How could we know the probability and number of samples before running the model on a dataset?

Also, I do not expect the tree structure to be updated after fitting/prediction, but I am confused that the probability and number of samples of each node (which are the results) do not get updated after running the model on some data. Is there, in fact, a way to use the CARTBuilder to simply design an expert tree structure without specifying the probability and number of samples?

Given a dataset, we usually do not know the probabilities and numbers of samples before running the model on any data. I can set them to arbitrary values, but could those values be updated after running the model on some data? If so, that would make a lot of sense and be very helpful!

If you manually build the tree structure, you also have to specify the probabilities. TF-DF does not allow one without the other.

You are the expert on your problem, but if you manually create the structure of the tree, you probably also know, or are able to compute, the values of the leaves of a CART (without TF-DF). You can check the Decision Forest ML class for some insights.
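As a concrete version of that suggestion, here is a plain-Python sketch (not TF-DF code) that computes CART-style leaf values for a fixed structure: each leaf's probability is the class frequency among the examples that reach it, and its `num_examples` is their count. The hypothetical `route` function hard-codes the question's tree and assumes that "higher than" means `>=` and that missing values evaluate to `False` (the `missing_evaluation=False` setting). The rows are made up for illustration only.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Optional, Tuple

CLASSES = ["Adelie", "Non-Adelie"]

def route(bill_length_mm: Optional[float], island: Optional[str]) -> str:
    """Returns the id of the leaf an example reaches in the question's tree.
    Assumes 'higher than' means >= and that missing values evaluate to False."""
    if bill_length_mm is not None and bill_length_mm >= 40.0:
        if island is not None and island in ("Dream", "Torgersen"):
            return "pos/pos"
        return "pos/neg"
    return "neg"

def leaf_values(rows: List[Tuple[Optional[float], Optional[str], str]]
                ) -> Dict[str, Tuple[List[float], int]]:
    """Maps each reached leaf id to (class probabilities, num_examples)."""
    counts: Dict[str, Counter] = defaultdict(Counter)
    for bill_length_mm, island, species in rows:
        counts[route(bill_length_mm, island)][species] += 1
    values = {}
    for leaf, counter in counts.items():
        n = sum(counter.values())
        values[leaf] = ([counter[c] / n for c in CLASSES], n)
    return values

# Made-up rows for illustration only: (bill_length_mm, island, species).
rows = [
    (45.1, "Dream", "Adelie"),
    (46.0, "Dream", "Non-Adelie"),
    (41.3, "Torgersen", "Adelie"),
    (48.7, "Biscoe", "Non-Adelie"),
    (38.2, "Biscoe", "Adelie"),
    (36.7, None, "Adelie"),
]
for leaf, (probability, n) in sorted(leaf_values(rows).items()):
    print(leaf, probability, n)
```

The resulting numbers could then be passed to `ProbabilityValue(probability=..., num_examples=...)` for the matching leaves before calling `builder.add_tree`. A leaf that no example reaches does not appear in the result, so its value would have to be chosen by hand.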

However, if your goal is to create a good predictive model that follows some compliance rules, I would recommend starting with the classical TF-DF training and injecting those constraints through the hyperparameters. For example, if you don’t want your model to use a particular feature, you can remove this feature from the dataset (or use the “features=” model argument). If you don’t want leaves to depend on one (or a small number of) data points, increase the “minimum number of observations per leaf” parameter. Not all constraints can be expressed this way, but this should already get you moving.
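To make the effect of a minimum-observations-per-leaf constraint concrete, here is a plain-Python sketch of a CART-style search for the best threshold on one numerical feature. It uses Gini impurity and skips any split that would leave fewer than `min_examples` examples in either child. This is an illustration under those assumptions, not TF-DF's actual split-finding code; `gini`, `best_threshold`, and the toy data are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple

def gini(labels: List[str]) -> float:
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def best_threshold(values: Sequence[float], labels: Sequence[str],
                   min_examples: int) -> Optional[Tuple[float, float]]:
    """Best 'value >= threshold' split by weighted Gini impurity.

    Only splits leaving at least `min_examples` examples in each child are
    considered. Returns (threshold, weighted_gini), or None if none qualifies.
    """
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    m = max(1, min_examples)
    best = None
    # i is the number of examples sent to the negative (value < threshold) child.
    for i in range(m, n - m + 1):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # equal values cannot be separated by a threshold
        threshold = (pairs[i - 1][0] + pairs[i][0]) / 2
        neg = [label for _, label in pairs[:i]]
        pos = [label for _, label in pairs[i:]]
        score = (len(neg) * gini(neg) + len(pos) * gini(pos)) / n
        if best is None or score < best[1]:
            best = (threshold, score)
    return best

values = [1, 2, 3, 4, 5, 6]
labels = ["A", "N", "N", "N", "N", "N"]
print(best_threshold(values, labels, min_examples=1))  # (1.5, 0.0)
print(best_threshold(values, labels, min_examples=2))  # (2.5, 0.16666666666666666)
```

With `min_examples=1` the search isolates the single "A" example in its own pure leaf; raising it to 2 forbids that one-example leaf and accepts a slightly less pure split instead, which is the trade-off the hyperparameter controls.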

I hope this helps.
M.


Thanks for your reply!

Restricting features and tuning hyperparameters is what I already do with regular CART models in packages like scikit-learn, TensorFlow, etc. TensorFlow is the first package that allows users to build a tree by hand, and in fact it offers far more flexibility than a regular CART model.

I just wonder: even though the probability and number of samples have to be specified up front when using the CARTBuilder, is there any way to update them after running the model on some dataset?

On top of that, over the past few days I have been thinking about the best practice for this builder, and I think the following workflow would make sense:

  1. ML: Build a tree model with a regular TF-DF Keras model, say tfdf.keras.CartModel, which provides the tree structure along with the actual numbers of samples and probabilities.
  2. Insert the ML model into the builder: extract the tree from the previous step and add it to a builder, along these lines:

```python
# `cart_model` is assumed to be the tfdf.keras.CartModel trained in step 1.
inspector = cart_model.make_inspector()
sample_tree = inspector.extract_tree(tree_idx=0)

# Create the model builder
model_trial_idx = 1
model_trial_idx += 1
model_path = f"/tmp/manual_model/{model_trial_idx}"

builder = tfdf.builder.CARTBuilder(
    path=model_path,
    # The class order must match the label encoding used to train `cart_model`.
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="species", classes=["Adelie", "Gentoo", "Chinstrap"]))

builder.add_tree(sample_tree)
```

  3. Human expertise: this is where fine-tuning with domain expertise comes into play. Tweak the root, nodes, and leaves if necessary, then call builder.close().
  4. Activate the ML-expertise fusion model: run this model on the dataset and refresh the probabilities and numbers of samples so they reflect the actual performance of each node.

This makes more sense than building the tree from scratch, since it starts from a structure learned by an ML model, which generally works better.

However, going through the APIs, it looks like some classes, such as Tree, do not have setters or similar functions; the code snippet below throws an error:

```python
sample_tree.root = NonLeafNode(
    condition=NumericalHigherThanCondition(
        feature=SimpleColumnSpec(name="bill_length_mm", type=ColumnType.NUMERICAL),
        threshold=40.0,
        missing_evaluation=False),
    pos_child=NonLeafNode(
        condition=CategoricalIsInCondition(
            feature=SimpleColumnSpec(name="island", type=ColumnType.CATEGORICAL),
            mask=["Dream", "Torgersen"],
            missing_evaluation=False),
        pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.2], num_examples=10)),
        neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.9], num_examples=20))),
    neg_child=LeafNode(value=ProbabilityValue(probability=[0.2, 0.8], num_examples=30)))
```

The error is:

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-35-242d44038491> in <cell line: 1>()
----> 1 sample_tree.root = NonLeafNode(
      2             condition=NumericalHigherThanCondition(
      3                 feature=SimpleColumnSpec(name="bill_length_mm", type=ColumnType.NUMERICAL),
      4                 threshold=40.0,
      5                 missing_evaluation=False),

AttributeError: can't set attribute 'root'
```

Also, the builder class does not seem to have APIs for tweaking individual nodes. Something like the following, used while working with the builder, would also be helpful:

```python
# Proposed API (hypothetical; the builder does not appear to offer this):
builder.get_node[node_idx] = NonLeafNode(
    condition=CategoricalIsInCondition(
        feature=SimpleColumnSpec(name="island", type=ColumnType.CATEGORICAL),
        mask=["Dream", "Torgersen"],
        missing_evaluation=False),
    pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.2], num_examples=10)),
    neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.9], num_examples=20)))
```
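If no such API is available, one workaround is to rebuild the path to the node you want to change instead of mutating the tree in place. The sketch below is plain Python with hypothetical stand-in classes and a hypothetical `replace_subtree` helper (none of this is TF-DF API). A node is addressed by its path of `"pos"`/`"neg"` steps from the root, and only the nodes on that path are copied. With TF-DF, the analogous move would be to assemble the modified root from `NonLeafNode`/`LeafNode` objects and wrap it in a new `Tree(...)` before calling `builder.add_tree`, as in the first example.

```python
from dataclasses import dataclass, replace
from typing import Sequence, Union

# Hypothetical immutable stand-ins for py_tree nodes (not the TF-DF classes).
@dataclass(frozen=True)
class Leaf:
    probability: tuple
    num_examples: int

@dataclass(frozen=True)
class Split:
    condition: str  # human-readable description only
    pos_child: "Node"
    neg_child: "Node"

Node = Union[Leaf, Split]

def replace_subtree(node: Node, path: Sequence[str], new_subtree: Node) -> Node:
    """Returns a copy of `node` with the subtree at `path` (e.g. ["pos", "neg"]) replaced."""
    if not path:
        return new_subtree
    if not isinstance(node, Split):
        raise ValueError("path continues below a leaf")
    step, rest = path[0], path[1:]
    if step == "pos":
        return replace(node, pos_child=replace_subtree(node.pos_child, rest, new_subtree))
    if step == "neg":
        return replace(node, neg_child=replace_subtree(node.neg_child, rest, new_subtree))
    raise ValueError(f"unknown step: {step!r}")

# The question's tree, then a new tree whose "neg" leaf becomes a split.
tree = Split("bill_length_mm >= 40",
             Split("island in {Dream, Torgersen}",
                   Leaf((0.8, 0.2), 10), Leaf((0.1, 0.9), 20)),
             Leaf((0.2, 0.8), 30))
# Placeholder leaf values for the new split.
new_tree = replace_subtree(tree, ["neg"],
                           Split("island in {Biscoe}",
                                 Leaf((0.3, 0.7), 15), Leaf((0.1, 0.9), 15)))
print(new_tree.neg_child.condition)  # island in {Biscoe}
```

Because untouched subtrees are shared rather than copied, the original tree stays intact and can still be compared against the edited one.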

Perhaps I have missed some APIs, as the Python API documentation does not seem to be fully aligned with the code. In any case, please advise on my proposed workflow and my questions about these setter methods. I appreciate it!

Note: This question was replied to on the TensorFlow Decision Forest GitHub repo (How to use tfdf.builder.CARTBuilder to build/train a decision tree by hand · Issue #184 · tensorflow/decision-forests · GitHub).
