如何在Python中绘制CART树,像在R中一样?

6 投票
2 回答
8994 浏览
提问于 2025-04-18 01:05

在R语言中,我可以直接通过一个API画出决策树的图形表示,这个决策树是对应于一个CART模型的。比如说,使用prp这个函数,就能生成类似下面这样的图:

但是我在Python中找不到类似的API来实现同样的功能。比如说,sklearn库里的RandomForestClassifierDecisionTreeClassifier似乎都没有绘制树的相关方法。

那么,我该如何在Python中获取CART或随机森林树的图形表示呢?

2 个回答

1

这个函数可以让图形在Jupyter笔记本中显示出来:

# Imports
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.externals.six import StringIO
from IPython.display import Image, display
import pydotplus

def jupyter_graphviz(m, **kwargs):
    dot_data = StringIO()
    export_graphviz(m, dot_data, **kwargs)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
    display(Image(graph.create_png()))

比如说:

import sklearn.datasets as datasets
import pandas as pd

iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target
dtree = DecisionTreeClassifier(random_state=42)
dtree.fit(df, y)

jupyter_graphviz(dtree, filled=True, rounded=True, special_characters=True)

树形图示

这里有一个笔记本的示例,改编自这篇文章

7

使用 export_graphviz 这个函数。

from sklearn.tree import DecisionTreeClassifier, export_graphviz
np.random.seed(0)
X = np.random.randn(10, 4)
y = array(["foo", "bar", "baz"])[np.random.randint(0, 3, 10)]
clf = DecisionTreeClassifier(random_state=42).fit(X, y)
export_graphviz(clf)

现在运行 dotty tree.dot 应该会显示类似下面的内容:

树的可视化

这里有一个 示例笔记本

撰写回答