如何显示测试样本的决策树路径?

2024-04-26 07:06:02 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用scikit learn的DecisionTreeClassifier对一些多类数据进行分类。我发现了很多描述如何显示决策树路径的文章,比如herehere,和{a4}。但是,它们都描述了如何为训练后的数据显示树。这是有道理的,因为^{}只需要一个合适的模型。在

我的问题是如何在测试样本上可视化树(最好是export_graphviz)。一、 在用clf.fit(X[train], y[train])拟合模型,然后用clf.predict(X[test])预测测试数据的结果之后,我想可视化用于预测样本X[test]的决策路径。有办法吗?在

编辑:

我看到可以使用decision_path打印路径。如果有一种方法可以从export_graphviz获得一个DOT输出来显示它,那就太好了。在


Tags: 数据模型test路径here可视化分类train
1条回答
网友
1楼 · 发布于 2024-04-26 07:06:02

为了获得决策树中特定样本的路径,可以使用^{}。它返回一个稀疏矩阵,其中包含所提供样本的决策路径。在

然后可以使用这些决策路径为通过pydot生成的树加上颜色/标签。这需要重写颜色和标签(这会导致一些难看的代码)。在

注意事项

  • decision_path可以从训练集或新值中获取样本
  • 您可以随意使用颜色,并根据采样数或其他可能需要的可视化效果来更改颜色

示例

在下面的示例中,一个访问过的节点是绿色的,所有其他节点都是白色的。在

enter image description here

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)

# empty all nodes, i.e.set color to white and number of samples to zero
for node in graph.get_node_list():
    if node.get_attributes().get('label') is None:
        continue
    if 'samples = ' in node.get_attributes()['label']:
        labels = node.get_attributes()['label'].split('<br/>')
        for i, label in enumerate(labels):
            if label.startswith('samples = '):
                labels[i] = 'samples = 0'
        node.set('label', '<br/>'.join(labels))
        node.set_fillcolor('white')

samples = iris.data[129:130]
decision_paths = clf.decision_path(samples)

for decision_path in decision_paths:
    for n, node_value in enumerate(decision_path.toarray()[0]):
        if node_value == 0:
            continue
        node = graph.get_node(str(n))[0]            
        node.set_fillcolor('green')
        labels = node.get_attributes()['label'].split('<br/>')
        for i, label in enumerate(labels):
            if label.startswith('samples = '):
                labels[i] = 'samples = {}'.format(int(label.split('=')[1]) + 1)

        node.set('label', '<br/>'.join(labels))

filename = 'tree.png'
graph.write_png(filename)

相关问题 更多 >