How to extract the decision rules from scikit-learn decision-tree?
I believe that this answer is more correct than the other answers here:
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    # Leaves have feature == TREE_UNDEFINED (-2), so don't look up a name for them
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Decision node: emit the split, then recurse into both children
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: return the stored value array
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)
This prints out a valid Python function. Here's an example output for a tree that is trying to return its input, a number between 0 and 10.
def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]
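Because the generated source is ordinary Python, one way to check that claim is to capture it as a string, exec it, and compare its answers with predict. The sketch below uses a hypothetical variant named tree_to_source (not part of scikit-learn) that collects lines instead of printing them and writes leaf values with tolist() so they are plain Python literals. It assumes a single-output DecisionTreeRegressor and, like tree_to_code, relies on the private sklearn.tree._tree module for TREE_UNDEFINED.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, _tree


def tree_to_source(tree, feature_names):
    """Hypothetical variant of tree_to_code that returns the source as a string."""
    tree_ = tree.tree_
    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Decision node: same split logic as tree_to_code
            name = feature_names[tree_.feature[node]]
            threshold = float(tree_.threshold[node])
            lines.append("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # tolist() yields a literal such as [[3.0]] that exec can parse
            lines.append("{}return {}".format(indent, tree_.value[node].tolist()))

    recurse(0, 1)
    return "\n".join(lines)


# A tree that learns to return its input, a number between 0 and 10
X = np.arange(11, dtype=float).reshape(-1, 1)
y = X.ravel()
reg = DecisionTreeRegressor(random_state=0).fit(X, y)

source = tree_to_source(reg, ["f0"])
namespace = {}
exec(source, namespace)  # defines namespace["tree"]
generated = namespace["tree"]

# The generated function should agree with the fitted tree on every input
for x in X.ravel():
    assert generated(x)[0][0] == reg.predict([[x]])[0]
print(source)
```

Integer inputs are used on purpose: scikit-learn compares features as float32 internally, so inputs that sit extremely close to a threshold could in principle be routed differently by the float64 comparisons in the generated code.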
Here are some stumbling blocks that I see in other answers:
- Using tree_.threshold == -2 to decide whether a node is a leaf isn't a good idea. What if it's a real decision node with a threshold of -2? Instead, you should look at tree_.feature or tree_.children_*.
- The line features = [feature_names[i] for i in tree_.feature] crashes with my version of sklearn, because some values of tree.tree_.feature are -2 (specifically for leaf nodes). When feature_names is long enough, the negative index may instead silently pick the wrong name.
- There is no need to have multiple if statements in the recursive function; just one is fine.
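To make the first two points concrete, here is a small sketch that checks the structural leaf test against the feature array and builds a name list without ever indexing with -2. It uses the private sklearn.tree._tree constants TREE_LEAF (-1) and TREE_UNDEFINED (-2), which are internal and could change between releases; the iris data and max_depth=3 are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, _tree

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)
t = clf.tree_

# Structural leaf test: a leaf has no children in either child array
is_leaf = t.children_left == _tree.TREE_LEAF
assert np.array_equal(is_leaf, t.children_right == _tree.TREE_LEAF)

# Leaves also carry TREE_UNDEFINED in the feature array, so both tests agree
assert np.array_equal(is_leaf, t.feature == _tree.TREE_UNDEFINED)

# Guarded lookup: leaves map to None instead of feature_names[-2]
names = [
    iris.feature_names[i] if i != _tree.TREE_UNDEFINED else None
    for i in t.feature
]
print(int(is_leaf.sum()), "leaves")
```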
I created my own function to extract the rules from the decision trees created by sklearn:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})

# create decision tree (.iloc replaces .ix, which recent pandas versions no longer have)
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.iloc[:, :2], df.dv)
This function first starts with the leaf nodes (identified by -1 in the child arrays) and then recursively finds their parents. I call this a node's 'lineage'. Along the way, I grab the values I need to create if/then/else SAS logic:
def get_lineage(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # leaves have feature == -2; map them to None instead of indexing with -2
    features = [feature_names[i] if i >= 0 else None for i in tree.tree_.feature]

    # get ids of leaf nodes (nodes without a left child)
    idx = np.argwhere(left == -1)[:, 0]

    def recurse(left, right, child, lineage=None):
        if lineage is None:
            lineage = [child]
        if child in left:
            parent = np.where(left == child)[0].item()
            split = 'l'
        else:
            parent = np.where(right == child)[0].item()
            split = 'r'

        # float() keeps the printed tuple free of numpy scalar reprs
        lineage.append((parent, split, float(threshold[parent]), features[parent]))

        if parent == 0:
            lineage.reverse()
            return lineage
        else:
            return recurse(left, right, parent, lineage)

    for child in idx:
        for node in recurse(left, right, child):
            print(node)
The sets of tuples below contain everything I need to create SAS if/then/else statements. I do not like using do blocks in SAS, which is why I create logic describing a node's entire path. The single integer after the tuples is the ID of the terminal node in a path. All of the preceding tuples combine to create that node.
In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6
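If you want the SAS statements themselves, the tuples for each leaf can be joined into one IF/THEN line. The sketch below is a hypothetical alternative (leaf_paths and to_sas are illustrative names, not library functions): it builds a child-to-parent map once instead of calling np.where at every step, walks up from each leaf, and joins the conditions with and. The statement shape, assigning the leaf id to a variable called node, is an assumption; adapt it to your DATA step.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier


def leaf_paths(tree, feature_names):
    """Map each leaf id to its root-first list of (parent, 'l'/'r', threshold, feature)."""
    t = tree.tree_
    parent_of = {}  # child id -> (parent id, side), built in one pass over the nodes
    for node in range(t.node_count):
        if t.children_left[node] != -1:  # internal node
            parent_of[int(t.children_left[node])] = (node, "l")
            parent_of[int(t.children_right[node])] = (node, "r")

    paths = {}
    for leaf in range(t.node_count):
        if t.children_left[leaf] != -1:
            continue  # only leaves start a lineage
        path, child = [], leaf
        while child in parent_of:  # climb until the root (which has no parent)
            parent, side = parent_of[child]
            feat = feature_names[t.feature[parent]]
            path.append((parent, side, float(t.threshold[parent]), feat))
            child = parent
        paths[leaf] = path[::-1]  # root first, matching get_lineage's order
    return paths


def to_sas(paths):
    """Hypothetical helper: one SAS-style IF/THEN statement per leaf."""
    stmts = []
    for leaf, path in paths.items():
        conds = [f"{feat} {'<=' if side == 'l' else '>'} {thr}" for _, side, thr, feat in path]
        stmts.append(f"if {' and '.join(conds)} then node = {leaf};")
    return stmts


df = pd.DataFrame({"col1": [0, 1, 2, 3], "col2": [3, 4, 5, 6], "dv": [0, 1, 0, 1]})
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1).fit(df.iloc[:, :2], df.dv)

paths = leaf_paths(dt, list(df.columns))
for stmt in to_sas(paths):
    print(stmt)
```

Each leaf's full condition ends up on a single line, which is the path-per-node style described above, without nesting any do blocks.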
I modified the code submitted by Zelazny7 to print some pseudocode:
def get_code(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # leaves have feature == -2; map them to None instead of indexing with -2
    features = [feature_names[i] if i >= 0 else None for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node):
        # A node is a leaf when it has no children; testing threshold == -2
        # would misfire on a real split whose threshold happens to be -2
        if left[node] != -1:
            print("if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
            recurse(left, right, threshold, features, left[node])
            print("} else {")
            recurse(left, right, threshold, features, right[node])
            print("}")
        else:
            print("return " + str(value[node]))

    recurse(left, right, threshold, features, 0)
If you call get_code(dt, df.columns) on the same example, you will obtain:
if ( col1 <= 0.5 ) {
return [[ 1. 0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0. 1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1. 0.]]
} else {
return [[ 0. 1.]]
}
}
}
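The return statements above print the raw value arrays. If you would rather see the predicted class, each leaf can be mapped through tree.classes_ with an argmax. The sketch below is a hypothetical get_code_lines (not part of scikit-learn) that does this, indents nested blocks, and returns the lines instead of printing them. It assumes a single-output classifier; depending on the scikit-learn version, value holds class counts or class fractions, and the argmax is the same either way.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier


def get_code_lines(tree, feature_names):
    """Hypothetical variant of get_code: indented pseudocode with class labels."""
    t = tree.tree_
    lines = []

    def recurse(node, depth):
        pad = "    " * depth
        if t.children_left[node] != -1:  # internal node
            name = feature_names[t.feature[node]]
            lines.append(f"{pad}if ( {name} <= {float(t.threshold[node])} ) {{")
            recurse(t.children_left[node], depth + 1)
            lines.append(f"{pad}}} else {{")
            recurse(t.children_right[node], depth + 1)
            lines.append(f"{pad}}}")
        else:
            # Most frequent (or highest-weight) class in this leaf
            label = tree.classes_[np.argmax(t.value[node])]
            lines.append(f"{pad}return {label}")

    recurse(0, 0)
    return lines


df = pd.DataFrame({"col1": [0, 1, 2, 3], "col2": [3, 4, 5, 6], "dv": [0, 1, 0, 1]})
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1).fit(df.iloc[:, :2], df.dv)

lines = get_code_lines(dt, list(df.columns))
print("\n".join(lines))
```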