How to extract the decision rules that define the final/terminal nodes of a decision tree classifier and print them as code that uses NumPy arrays

Posted 2024-05-15 21:31:07

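I am trying to extract the decision rules that lead to each terminal node, and to print code that uses pandas/NumPy arrays to assign the terminal node number to every row. I found a solution that pulls the rules at (How to extract the decision rules from scikit-learn decision-tree?), but I do not know how to extend it to produce what I need. The linked question has many answers; the code below is the one I am referring to, together with its description.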

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})
df
# create decision tree
dt = DecisionTreeClassifier(random_state=0, max_depth=5, min_samples_leaf=1)
dt.fit(df[['col1', 'col2']], df['dv'])

# This function starts with the leaf nodes (identified by -1 in the child arrays) and then recursively finds their parents.
# I call this a node's 'lineage'. Along the way, I grab the values I need to create if/then/else SAS logic:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print(node)

get_lineage(dt, df.columns)

Running the code produces the following:

(0, 'l', 3.5, 'col2')
1
(0, 'r', 3.5, 'col2')
(2, 'l', 1.5, 'col1')
3
(0, 'r', 3.5, 'col2')
(2, 'r', 1.5, 'col1')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 3.5, 'col2')
(2, 'r', 1.5, 'col1')
(4, 'r', 2.5, 'col1')
6
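
For reference, the bare integers in this output (1, 3, 5, 6) are the ids of the leaf nodes, and each tuple before one is a step on the path from the root to that leaf. A minimal sketch of how the ids can be cross-checked, assuming the same dummy data and tree as above (the variable names `leaf_ids` and `sample_leaves` are only illustrative): `apply` on a fitted scikit-learn tree returns the id of the leaf each row ends up in.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Same dummy data and tree settings as in the question.
df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})
X = df[['col1', 'col2']]
dt = DecisionTreeClassifier(random_state=0, max_depth=5, min_samples_leaf=1)
dt.fit(X, df['dv'])

# Leaves are the nodes whose left-child entry is -1.
leaf_ids = np.argwhere(dt.tree_.children_left == -1)[:, 0]

# apply() returns, for every row, the id of the leaf it falls into.
sample_leaves = dt.apply(X)

print(leaf_ids)
print(sample_leaves)
```

If the terminal node number is only needed inside Python, `dt.apply(X)` already provides it; generating code is mainly useful when the rules have to be reproduced elsewhere.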

How can I extend this to print something like the following:

df['Terminal_Node_Num'] = np.where(df.loc[:, 'col2'] <= 3.5, 1, 0)
df['Terminal_Node_Num'] = np.where(
    (df.loc[:, 'col2'] > 3.5) & (df.loc[:, 'col1'] <= 1.5),
    3, df['Terminal_Node_Num'])
df['Terminal_Node_Num'] = np.where(
    (df.loc[:, 'col2'] > 3.5) & (df.loc[:, 'col1'] > 1.5) & (df.loc[:, 'col1'] <= 2.5),
    5, df['Terminal_Node_Num'])
df['Terminal_Node_Num'] = np.where(
    (df.loc[:, 'col2'] > 3.5) & (df.loc[:, 'col1'] > 1.5) & (df.loc[:, 'col1'] > 2.5),
    6, df['Terminal_Node_Num'])
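
One possible way to generate statements of this shape is sketched below. It is not the code from the linked answer: instead of climbing from each leaf to the root, it walks the tree from the root down and collects the split conditions on the way. The function name `leaf_rules_to_code` and its `df_name` and `target` parameters are hypothetical, chosen only for this illustration. The sketch assumes the first generated statement may initialise the column with 0 and that later statements leave non-matching rows unchanged, as in the desired output above. Note that scikit-learn converts the inputs to float32 internally, so values lying extremely close to a threshold could in principle be assigned differently by the generated code.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier


def leaf_rules_to_code(tree, feature_names, df_name="df", target="Terminal_Node_Num"):
    """Return one np.where assignment (as a string) per leaf of a fitted tree.

    Hypothetical helper for illustration; not part of scikit-learn.
    """
    t = tree.tree_
    col = f"{df_name}['{target}']"
    statements = []

    def walk(node, conditions):
        left = int(t.children_left[node])
        right = int(t.children_right[node])
        if left == -1:  # a leaf has no children
            # A tree consisting of a single leaf has no conditions at all.
            cond = " & ".join(conditions) or "True"
            # The first statement initialises the column; later ones keep old values.
            fallback = "0" if not statements else col
            statements.append(f"{col} = np.where({cond}, {node}, {fallback})")
            return
        name = feature_names[t.feature[node]]
        thr = float(t.threshold[node])
        walk(left, conditions + [f"({df_name}['{name}'] <= {thr!r})"])
        walk(right, conditions + [f"({df_name}['{name}'] > {thr!r})"])

    walk(0, [])
    return statements


# Same dummy data and tree settings as in the question.
df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})
X = df[['col1', 'col2']]
dt = DecisionTreeClassifier(random_state=0, max_depth=5, min_samples_leaf=1)
dt.fit(X, df['dv'])

statements = leaf_rules_to_code(dt, list(X.columns))
for line in statements:
    print(line)
```

Only the feature columns are passed as `feature_names`, so the indices stored in `tree_.feature` line up with the names. The printed lines can be pasted into a script, and the resulting column can be compared against `dt.apply(X)` as a sanity check.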
