如何解释scikit-learn中的决策树

39 投票
4 回答
40155 浏览
提问于 2025-04-18 05:56

我在理解scikit-learn中的决策树结果时遇到了两个问题。比如,这是我其中一棵决策树的图:

enter image description here

我的问题是:我该如何使用这棵树呢?

第一个问题是:如果一个样本满足某个条件,它就会走到左边的分支(如果有的话),否则就走右边。在我的例子中,如果样本的X[7]大于63521.3984,那么这个样本就会进入绿色框里。这样理解对吗?

第二个问题是:当一个样本到达叶子节点时,我怎么知道它属于哪个类别?在这个例子中,我有三个类别要分类。在红色框里,有91、212和113个样本分别满足条件。但我该怎么决定它的类别呢?我知道有一个函数clf.predict(sample)可以告诉我类别。我能从这个图中做到这一点吗??非常感谢。

4 个回答

0

在tree.export_graphviz中添加feature_names=X.columns,其中X是训练数据。

我的代码如下

with open("lectureGini.txt", "w") as f:
    f = tree.export_graphviz(lectureGini, out_file=f,feature_names=X.columns)
# copy contents of file LectureGini.txt into WebGraphviz - http://webgraphviz.com/

lectureGini是我用DecisionTreeClassifier得到的结果。

这是我发现的一个简单方法,可以加到我研究的所有关于基尼指数的网页示例中。所有的网页示例都很好地解释了这个方法,但没有一个展示如何找到类别。我还没有安装Graphviz,所以我从jupyter导出一个文本文件,然后把文本复制到Webgraphwiz中。

3

根据《学习scikit-learn:Python中的机器学习》这本书,决策树就像是根据训练数据做出的一系列决策。

!(https://i.stack.imgur.com/2omYY.png)

要对一个实例进行分类,我们需要在每个节点回答一个问题。例如,问“性别是否小于等于0.5?”(我们是在谈论女性吗?)。 如果答案是“是”,那就往树的左边走;如果答案是“否”,那就往右边走。你会继续回答问题(她是否在第三班?她是否在第一班?她是否小于13岁?),直到你到达一个叶子节点。 当你到达那里时,预测结果就是出现次数最多的目标类别

14

第一个问题: 是的,你的逻辑是对的。左边的节点是对的(True),右边的节点是错的(False)。这可能有点反直觉;有时候,真的(True)可以对应一个更小的样本。

第二个问题: 解决这个问题的最好方法是用pydotplus把树形结构可视化成图形。tree.export_graphviz()中的'class_names'属性会在每个节点的主要类别上添加一个类别声明。代码是在iPython笔记本中执行的。

from sklearn.datasets import load_iris  
from sklearn import tree  
iris = load_iris()  
clf2 = tree.DecisionTreeClassifier()  
clf2 = clf2.fit(iris.data, iris.target)  

with open("iris.dot", 'w') as f:  
    f = tree.export_graphviz(clf, out_file=f)  
    
import os  
os.unlink('iris.dot')  

import pydotplus  
dot_data = tree.export_graphviz(clf2, out_file=None)  
graph2 = pydotplus.graph_from_dot_data(dot_data)  
graph2.write_pdf("iris.pdf")  

from IPython.display import Image  
dot_data = tree.export_graphviz(clf2, out_file=None,  
                     feature_names=iris.feature_names,  
                     class_names=iris.target_names,  
                     filled=True, rounded=True,  # leaves_parallel=True, 
                     special_characters=True)  
graph2 = pydotplus.graph_from_dot_data(dot_data)

## Color of nodes
nodes = graph2.get_node_list()

for node in nodes:
    if node.get_label():
        values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')];
        color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],}
        values = color[values.index(max(values))]; # print(values)
        color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color)
        node.set_fillcolor(color )
#

Image(graph2.create_png() ) 

在这里输入图片描述

至于如何确定叶子节点的类别,你的例子中没有只有单一类别的叶子节点,就像鸢尾花数据集那样。这是很常见的,可能需要对模型进行过拟合才能得到这样的结果。对于许多交叉验证的模型来说,类别的离散分布是最佳结果。

34

每个框里的 value 行告诉你在这个节点上有多少个样本属于每个类别,顺序是有讲究的。也就是说,在每个框里,value 中的数字加起来正好等于 sample 显示的数字。例如,在你的红色框里,91+212+113=416。这就意味着如果你到达这个节点,类别1有91个数据点,类别2有212个,类别3有113个。

如果你要预测一个新的数据点在决策树的这个叶子节点的结果,你会预测类别2,因为在这个节点上,类别2的样本数量是最多的。

撰写回答