How to extract the decision rules from scikit-learn decision-tree?


I believe that this answer is more correct than the other answers here:

    from sklearn.tree import _tree

    def tree_to_code(tree, feature_names):
        tree_ = tree.tree_
        # Leaves have feature == TREE_UNDEFINED (-2), so never index the name list with it.
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        print("def tree({}):".format(", ".join(feature_names)))

        def recurse(node, depth):
            indent = "  " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                # Internal node: emit the split and recurse into both children.
                name = feature_name[node]
                threshold = tree_.threshold[node]
                print("{}if {} <= {}:".format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                print("{}else:  # if {} > {}".format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                # Leaf: return the stored node value.
                print("{}return {}".format(indent, tree_.value[node]))

        recurse(0, 1)

This prints out a valid Python function. Here's an example output for a tree that is trying to return its input, a number between 0 and 10.

    def tree(f0):
      if f0 <= 6.0:
        if f0 <= 1.5:
          return [[ 0.]]
        else:  # if f0 > 1.5
          if f0 <= 4.5:
            if f0 <= 3.5:
              return [[ 3.]]
            else:  # if f0 > 3.5
              return [[ 4.]]
          else:  # if f0 > 4.5
            return [[ 5.]]
      else:  # if f0 > 6.0
        if f0 <= 8.5:
          if f0 <= 7.5:
            return [[ 7.]]
          else:  # if f0 > 7.5
            return [[ 8.]]
        else:  # if f0 > 8.5
          return [[ 9.]]
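One way to check that the printout really is a valid Python function is to collect the lines into a string and exec it, then compare the result with predict(). The sketch below is a hypothetical variant, tree_to_source, that assumes a single-output regressor (so each leaf value is reduced to one float) and that the tree comes from trusted code, since exec runs whatever it is given.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, _tree


def tree_to_source(tree, feature_names):
    """Hypothetical helper: build the tree_to_code output as a string."""
    tree_ = tree.tree_
    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: same split logic as tree_to_code.
            name = feature_names[tree_.feature[node]]
            threshold = float(tree_.threshold[node])
            lines.append("{}if {} <= {!r}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else:".format(indent))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Assumes a single-output regressor: value[node] has shape (1, 1).
            leaf_value = float(tree_.value[node][0][0])
            lines.append("{}return {!r}".format(indent, leaf_value))

    recurse(0, 1)
    return "\n".join(lines)


# A regression tree that learns to return its input, as in the example above.
X = np.arange(10).reshape(-1, 1)
y = np.arange(10)
reg = DecisionTreeRegressor(random_state=0).fit(X, y)

source = tree_to_source(reg, ["f0"])
namespace = {}
exec(source, namespace)  # defines namespace["tree"]
generated = namespace["tree"]

# The generated function should agree with predict() on every training point.
for x in X[:, 0]:
    assert generated(float(x)) == reg.predict([[x]])[0]
print(source)
```

Returning a plain float instead of the [[ ... ]] array is a simplification for the comparison; the split structure is the same as in tree_to_code.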

Here are some stumbling blocks that I see in other answers:

  1. Using tree_.threshold == -2 to decide whether a node is a leaf isn't a good idea. What if it's a real decision node with a threshold of -2? Instead, you should look at tree_.feature or tree_.children_*.
  2. The line features = [feature_names[i] for i in tree_.feature] crashes with my version of sklearn, because some values of tree.tree_.feature are -2 (specifically for leaf nodes).
  3. There is no need to have multiple if statements in the recursive function; just one is enough.
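Points 1 and 2 can be illustrated with a short sketch that detects leaves from the child arrays rather than the threshold. It uses the bundled iris dataset purely as example data (an assumption, not part of the answer), and relies on scikit-learn marking missing children with _tree.TREE_LEAF (-1) and leaf features with _tree.TREE_UNDEFINED (-2).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, _tree

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
t = clf.tree_

# A node is a leaf exactly when it has no children (TREE_LEAF == -1).
is_leaf = t.children_left == _tree.TREE_LEAF

# tree_.feature carries the same information: TREE_UNDEFINED (-2) for leaves.
assert np.array_equal(is_leaf, t.feature == _tree.TREE_UNDEFINED)

# Look up feature names only for split nodes, so leaves never index with -2.
feature_names = ["sepal length", "sepal width", "petal length", "petal width"]
split_names = {
    node: feature_names[t.feature[node]]
    for node in range(t.node_count)
    if not is_leaf[node]
}
print(split_names)
```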


I created my own function to extract the rules from the decision trees created by sklearn:

    import pandas as pd
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    # dummy data:
    df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})

    # create decision tree (DataFrame.ix was removed from pandas; use iloc)
    dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
    dt.fit(df.iloc[:, :2], df.dv)

This function first starts with the leaf nodes (identified by -1 in the child arrays) and then recursively finds their parents. I call this a node's 'lineage'. Along the way, I grab the values I need to create if/then/else SAS logic:

    def get_lineage(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]

        # get ids of leaf nodes (no children, marked with -1)
        idx = np.argwhere(left == -1)[:, 0]

        def recurse(left, right, child, lineage=None):
            if lineage is None:
                lineage = [child]
            # find the parent and which side the child hangs from
            if child in left:
                parent = np.where(left == child)[0].item()
                split = 'l'
            else:
                parent = np.where(right == child)[0].item()
                split = 'r'

            lineage.append((parent, split, threshold[parent], features[parent]))

            if parent == 0:
                # reached the root: return the path root-first, leaf id last
                lineage.reverse()
                return lineage
            else:
                return recurse(left, right, parent, lineage)

        for child in idx:
            for node in recurse(left, right, child):
                print(node)
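get_lineage searches the child arrays with np.where at every step. A possible alternative, sketched below as a hypothetical get_lineage_paths, builds a child-to-parent map in one pass and returns the paths as a dict instead of printing them. The tuple layout matches the one get_lineage prints; the function name and return type are assumptions for illustration.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def get_lineage_paths(tree, feature_names):
    """Hypothetical variant of get_lineage returning {leaf_id: path}.

    Each path is a list of (parent, split, threshold, feature) tuples
    ordered from the root down to the leaf.
    """
    t = tree.tree_
    left, right = t.children_left, t.children_right

    # Single pass over all nodes: map each child id to (parent id, side).
    parent_of = {}
    for node in range(t.node_count):
        if left[node] != -1:  # internal node; -1 marks "no child"
            parent_of[int(left[node])] = (node, 'l')
            parent_of[int(right[node])] = (node, 'r')

    paths = {}
    for leaf in np.flatnonzero(left == -1):
        path = []
        node = int(leaf)
        while node in parent_of:  # stops at the root, which has no parent
            parent, split = parent_of[node]
            path.append((parent, split, float(t.threshold[parent]),
                         feature_names[t.feature[parent]]))
            node = parent
        path.reverse()
        paths[int(leaf)] = path
    return paths


# Same dummy data as above, without pandas.
X = np.array([[0, 3], [1, 4], [2, 5], [3, 6]])
y = np.array([0, 1, 0, 1])
clf = DecisionTreeClassifier(max_depth=5, random_state=0).fit(X, y)

paths = get_lineage_paths(clf, ["col1", "col2"])
for leaf, path in paths.items():
    print(path, leaf)
```

Because col2 is col1 + 3 in this data, the tree may pick either column at a given split; the path structure is what matters here.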

The sets of tuples below contain everything I need to create SAS if/then/else statements. I do not like using do blocks in SAS, which is why I create logic describing a node's entire path. The single integer after the tuples is the ID of the terminal node in a path. All of the preceding tuples combine to create that node.

    In [1]: get_lineage(dt, df.columns)
    (0, 'l', 0.5, 'col1')
    1
    (0, 'r', 0.5, 'col1')
    (2, 'l', 4.5, 'col2')
    3
    (0, 'r', 0.5, 'col1')
    (2, 'r', 4.5, 'col2')
    (4, 'l', 2.5, 'col1')
    5
    (0, 'r', 0.5, 'col1')
    (2, 'r', 4.5, 'col2')
    (4, 'r', 2.5, 'col1')
    6

[Figure: GraphViz output of the example tree]
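Turning one lineage into a single SAS-style condition could look like the sketch below. It assumes the tuple layout printed above ('l' means feature <= threshold, 'r' means feature > threshold); the function name lineage_to_sas and the output variable node are hypothetical choices, not part of the original answer.

```python
def lineage_to_sas(leaf_id, path, target="node"):
    """Hypothetical helper: one root-to-leaf path -> one SAS IF/THEN line.

    `path` is a list of (parent, split, threshold, feature) tuples;
    `target` is an illustrative output variable name.
    """
    conditions = []
    for _parent, split, threshold, feature in path:
        op = "<=" if split == "l" else ">"
        conditions.append("{} {} {}".format(feature, op, threshold))
    return "if {} then {} = {};".format(" and ".join(conditions), target, leaf_id)


# Paths copied from the get_lineage output above.
paths = {
    1: [(0, 'l', 0.5, 'col1')],
    3: [(0, 'r', 0.5, 'col1'), (2, 'l', 4.5, 'col2')],
    5: [(0, 'r', 0.5, 'col1'), (2, 'r', 4.5, 'col2'), (4, 'l', 2.5, 'col1')],
    6: [(0, 'r', 0.5, 'col1'), (2, 'r', 4.5, 'col2'), (4, 'r', 2.5, 'col1')],
}
for leaf_id, path in paths.items():
    print(lineage_to_sas(leaf_id, path))
```

Each leaf gets one self-contained statement, which matches the goal of avoiding nested do blocks.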


I modified the code submitted by Zelazny7 to print some pseudocode:

    def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value     = tree.tree_.value

        def recurse(left, right, threshold, features, node):
            # scikit-learn stores threshold == -2 for leaf nodes
            if threshold[node] != -2:
                print("if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                    recurse(left, right, threshold, features, left[node])
                print("} else {")
                if right[node] != -1:
                    recurse(left, right, threshold, features, right[node])
                print("}")
            else:
                print("return " + str(value[node]))

        recurse(left, right, threshold, features, 0)

If you call get_code(dt, df.columns) on the same example, you will obtain:

    if ( col1 <= 0.5 ) {
    return [[ 1.  0.]]
    } else {
    if ( col2 <= 4.5 ) {
    return [[ 0.  1.]]
    } else {
    if ( col1 <= 2.5 ) {
    return [[ 1.  0.]]
    } else {
    return [[ 0.  1.]]
    }
    }
    }
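For comparison, scikit-learn 0.21 and later ship sklearn.tree.export_text, which prints an indented text version of a fitted tree without a hand-written recursion. The sketch below uses a small single-feature dataset (an assumption, chosen so the split is unambiguous); the exact layout of the report may differ between versions.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# One feature, two cleanly separated classes: the only sensible split is 1.5.
X = np.array([[0], [1], [2], [3]])
y = np.array([0, 0, 1, 1])
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# feature_names must be a list of strings, one per input column.
report = export_text(clf, feature_names=["col1"])
print(report)
```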