如何解释scikit-learn中的决策树
我在理解scikit-learn中的决策树结果时遇到了两个问题。比如,这是我其中一棵决策树的图:
我的问题是:我该如何使用这棵树呢?
第一个问题是:如果一个样本满足某个条件,它就会走到左边的分支(如果有的话),否则就走右边。在我的例子中,如果样本的X[7]大于63521.3984,那么这个样本就会进入绿色框里。这样理解对吗?
第二个问题是:当一个样本到达叶子节点时,我怎么知道它属于哪个类别?在这个例子中,我有三个类别要分类。在红色框里,有91、212和113个样本分别满足条件。但我该怎么决定它的类别呢?我知道有一个函数clf.predict(sample)可以告诉我类别。我能从这个图中做到这一点吗??非常感谢。
4 个回答
在tree.export_graphviz中添加feature_names=X.columns,其中X是训练数据。
我的代码如下
with open("lectureGini.txt", "w") as f:
f = tree.export_graphviz(lectureGini, out_file=f,feature_names=X.columns)
# copy contents of file LectureGini.txt into WebGraphviz - http://webgraphviz.com/
lectureGini是我用DecisionTreeClassifier得到的结果。
这是我发现的一个简单方法,可以加到我研究的所有关于基尼指数的网页示例中。所有的网页示例都很好地解释了这个方法,但没有一个展示如何找到类别。我还没有安装Graphviz,所以我从jupyter导出一个文本文件,然后把文本复制到Webgraphwiz中。
根据《学习scikit-learn:Python中的机器学习》这本书,决策树就像是根据训练数据做出的一系列决策。
!(https://i.stack.imgur.com/2omYY.png)
要对一个实例进行分类,我们需要在每个节点回答一个问题。例如,问“性别是否小于等于0.5?”(我们是在谈论女性吗?)。 如果答案是“是”,那就往树的左边走;如果答案是“否”,那就往右边走。你会继续回答问题(她是否在第三班?她是否在第一班?她是否小于13岁?),直到你到达一个叶子节点。 当你到达那里时,预测结果就是出现次数最多的目标类别。
第一个问题: 是的,你的逻辑是对的。左边的节点是对的(True),右边的节点是错的(False)。这可能有点反直觉;有时候,真的(True)可以对应一个更小的样本。
第二个问题: 解决这个问题的最好方法是用pydotplus把树形结构可视化成图形。tree.export_graphviz()中的'class_names'属性会在每个节点的主要类别上添加一个类别声明。代码是在iPython笔记本中执行的。
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf2 = tree.DecisionTreeClassifier()
clf2 = clf2.fit(iris.data, iris.target)
with open("iris.dot", 'w') as f:
f = tree.export_graphviz(clf, out_file=f)
import os
os.unlink('iris.dot')
import pydotplus
dot_data = tree.export_graphviz(clf2, out_file=None)
graph2 = pydotplus.graph_from_dot_data(dot_data)
graph2.write_pdf("iris.pdf")
from IPython.display import Image
dot_data = tree.export_graphviz(clf2, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True, # leaves_parallel=True,
special_characters=True)
graph2 = pydotplus.graph_from_dot_data(dot_data)
## Color of nodes
nodes = graph2.get_node_list()
for node in nodes:
if node.get_label():
values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')];
color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],}
values = color[values.index(max(values))]; # print(values)
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color)
node.set_fillcolor(color )
#
Image(graph2.create_png() )
至于如何确定叶子节点的类别,你的例子中没有只有单一类别的叶子节点,就像鸢尾花数据集那样。这是很常见的,可能需要对模型进行过拟合才能得到这样的结果。对于许多交叉验证的模型来说,类别的离散分布是最佳结果。
每个框里的 value
行告诉你在这个节点上有多少个样本属于每个类别,顺序是有讲究的。也就是说,在每个框里,value
中的数字加起来正好等于 sample
显示的数字。例如,在你的红色框里,91+212+113=416。这就意味着如果你到达这个节点,类别1有91个数据点,类别2有212个,类别3有113个。
如果你要预测一个新的数据点在决策树的这个叶子节点的结果,你会预测类别2,因为在这个节点上,类别2的样本数量是最多的。