
How to explore a decision tree built using scikit learn


You need to use the predict method.

After training the tree, you feed the X values to the predict method to obtain their predicted classes.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data)

output:

>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

To get details on the tree structure, we can use tree.tree_.__getstate__().

Tree structure translated into an "ASCII art" picture

              0
          _____________
        1           2
               ______________
               3            12
            _______      _______
            4     7      13   16
           ___   ______        _____
           5 6   8    9        14 15
                      _____
                      10 11

Tree structure as an array:

In [38]: tree.tree_.__getstate__()['nodes']
Out[38]:
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
       (-1, -1, -2, -2.0, 0.0, 50, 50.0),
       (3, 12, 3, 1.75, 0.5, 100, 100.0),
       (4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
       (5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
       (-1, -1, -2, -2.0, 0.0, 47, 47.0),
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
       (-1, -1, -2, -2.0, 0.0, 3, 3.0),
       (10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0),
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
       (14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0),
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (-1, -1, -2, -2.0, 0.0, 43, 43.0)],
      dtype=[('left_child', '<i8'), ('right_child', '<i8'),
             ('feature', '<i8'), ('threshold', '<f8'),
             ('impurity', '<f8'), ('n_node_samples', '<i8'),
             ('weighted_n_node_samples', '<f8')])

Where:

  • The first node [0] is the root node.
  • Internal nodes have left_child and right_child referring to nodes with positive indices, greater than the current node's index.
  • Leaves have the value -1 for both the left and right child nodes.
  • Nodes 1, 5, 6, 8, 10, 11, 14, 15 and 16 are leaves.
  • The node structure is built using a depth-first search.
  • The feature field tells us which of the iris.data features was used at the node to determine the path for a sample.
  • The threshold field tells us the value the chosen feature is compared against to decide the direction.
  • impurity reaches 0 at the leaves, since all the samples in a leaf belong to the same class.
  • n_node_samples tells us how many samples reach each node.

Using this information we could trivially track each sample X to the leaf where it eventually lands, by following the classification rules and thresholds in a script. Additionally, n_node_samples would allow us to write unit tests checking that each node receives the correct number of samples. Then, using the output of tree.predict, we could map each leaf to its associated class.
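
As a rough sketch of that idea (reusing the clf and iris objects from the snippet above), the code below walks each sample from the root to its leaf using the public tree_ arrays (children_left, children_right, feature, threshold), which hold the same information as the __getstate__() dump, maps each leaf to a class via tree.predict-style output, and checks the leaf counts against n_node_samples:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_  # exposes children_left, children_right, feature, threshold, n_node_samples

def find_leaf(x):
    """Follow the split rules from the root until a leaf (child == -1) is reached."""
    node = 0
    while t.children_left[node] != -1:
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]
        else:
            node = t.children_right[node]
    return node

leaves = np.array([find_leaf(x) for x in iris.data])

# Map each leaf to the class predicted for the samples that land in it
leaf_to_class = {leaf: clf.predict(iris.data[leaves == leaf])[0]
                 for leaf in np.unique(leaves)}
print(leaf_to_class)

# Unit-test style check: the number of training samples reaching each leaf
# should match the n_node_samples recorded in the tree structure
for leaf, count in zip(*np.unique(leaves, return_counts=True)):
    assert count == t.n_node_samples[leaf]

The <= comparison matches how scikit-learn sends samples to the left child.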


NOTE: This is not an answer, only a hint on possible solutions.

I encountered a similar problem recently in my project. My goal is to extract the corresponding chain of decisions for some particular samples. I think your problem is a subset of mine, since you just need to record the last step in the decision chain.

Up to now, the only viable solution seems to be writing a custom predict method in Python to keep track of the decisions along the way. The reason is that the predict method provided by scikit-learn cannot do this out of the box (as far as I know). To make matters worse, it is a wrapper around a C implementation, which is pretty hard to customize.

Customization is fine for my problem, since I'm dealing with an unbalanced dataset, and the samples I care about (positive ones) are rare. So I can filter them out first using the sklearn predict and then obtain the decision chain with my custom method.

However, this may not work for you if you have a large dataset, because parsing the tree and predicting in Python runs at Python speed and will not (easily) scale. You may have to fall back to customizing the C implementation.
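
To make the idea concrete, here is a minimal sketch of such a Python-level traversal that records the whole decision chain for a sample, not just the final leaf. It assumes a tree fitted on iris as in the first answer; the iris setup and the choice of class 2 as the "interesting" class are placeholders for your own data and positive class:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_

def decision_chain(x):
    """Return the (node, feature, threshold, direction) decisions taken for one sample."""
    node, chain = 0, []
    while t.children_left[node] != -1:                 # -1 marks a leaf
        go_left = x[t.feature[node]] <= t.threshold[node]
        chain.append((node, t.feature[node], t.threshold[node],
                      "<=" if go_left else ">"))
        node = t.children_left[node] if go_left else t.children_right[node]
    chain.append((node, "leaf"))
    return chain

# Filter the samples of interest with the fast C-level predict first,
# then trace only those few samples through the tree in Python
interesting = iris.data[clf.predict(iris.data) == 2]
for step in decision_chain(interesting[0]):
    print(step)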


The code below should produce a plot of your top ten features:

import numpy as np
import matplotlib.pyplot as plt

importances = clf.feature_importances_
indices = np.argsort(importances)[::-1]
# A single tree has no ensemble to compute error bars over (so the yerr from the
# original forest example is dropped), and the dataset may have fewer than ten
# features (iris has 4), so cap the ranking length
n = min(10, len(importances))

# Print the feature ranking
print("Feature ranking:")
for f in range(n):
    print("%d. feature %d (%f)" % (f + 1, indices[f], importances[indices[f]]))

# Plot the feature importances of the tree
plt.figure()
plt.title("Feature importances")
plt.bar(range(n), importances[indices[:n]], color="r", align="center")
plt.xticks(range(n), indices[:n])
plt.xlim([-1, n])
plt.show()

Taken from here and modified slightly to fit the DecisionTreeClassifier.

This doesn't exactly help you explore the tree, but it does tell you about the tree.