How to predict using TensorFlow Decision Tree?

Hi,
How can I predict labels using TensorFlow's Decision Tree?
I have attached all the steps I've done so far.

Thank you

If you call model.predict(), it should output class probabilities, with one column per class. Apply tf.math.argmax() to that output with axis=1 (the default axis is 0, which would compare down the columns instead). It returns the index of the largest value, the highest probability, in each row. Presumably you have a list of labels for all classes, so use those indexes to look up the class label for each sample in your dataset.
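Here is a minimal sketch of the multi-class case. To keep it self-contained it uses a hard-coded NumPy array in place of real model.predict() output. The class names are hypothetical and are assumed to be in the same order as the integer labels the model was trained on. np.argmax(..., axis=1) picks the same per-row indexes as tf.math.argmax(..., axis=1).

```python
import numpy as np

# Stand-in for model.predict(test_ds) on a 3-class problem:
# one row per sample, one column per class (hypothetical values).
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.5, 0.3],
])

# Hypothetical label list, ordered to match the model's class indexes
# (i.e. the integer label each class was encoded as during training).
class_names = ["cat", "dog", "bird"]

# Index of the highest probability in each row.
# Equivalent to tf.math.argmax(probs, axis=1).numpy().
pred_idx = np.argmax(probs, axis=1)

# Map each index back to its label.
pred_labels = [class_names[i] for i in pred_idx]
print(pred_labels)  # ['cat', 'bird', 'dog']
```

The whole approach depends on the order of class_names. If it doesn't match the label encoding used for training, the predicted labels will be silently wrong.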
If your classification task is binary, model.predict() should output a single value between 0 and 1 for each sample, typically the probability of class 1. In that case you can use 0.5 or some other threshold to decide whether it's class 0 or class 1.
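A similar sketch for the binary case. Again, a hard-coded array stands in for model.predict(). Its (n, 1) shape is an assumption, and ravel() flattens it so the comparison works either way. The threshold value and the choice of ">=" (a value of exactly 0.5 counts as class 1) are illustrative decisions, not fixed rules.

```python
import numpy as np

# Stand-in for model.predict() on a binary task: one probability per sample,
# assumed shape (n, 1) (hypothetical values).
probs = np.array([[0.91], [0.12], [0.50], [0.49]])

# 0.5 is the usual default; raise or lower it to trade precision for recall.
threshold = 0.5

# Flatten to shape (n,), compare against the threshold, and convert
# True/False to class 1/class 0.
pred_class = (probs.ravel() >= threshold).astype(int)
print(pred_class.tolist())  # [1, 0, 1, 0]
```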
