在scikit-learn中可以打印决策树吗?

16 投票
3 回答
41543 浏览
提问于 2025-04-18 17:04

有没有办法在scikit-learn中打印出一个训练好的决策树?我想为我的论文训练一个决策树,并且想把这个树的图片放在论文里。这可能吗?

3 个回答

7

我知道有四种方法可以绘制scikit-learn的决策树:

  • 使用sklearn.tree.export_text方法打印树的文本表示
  • 使用sklearn.tree.plot_tree方法绘图(需要matplotlib库)
  • 使用sklearn.tree.export_graphviz方法绘图(需要graphviz库)
  • 使用dtreeviz包绘图(需要dtreevizgraphviz库)

最简单的方法是导出为文本表示。示例决策树看起来像这样:

|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2

如果你安装了matplotlib,可以使用sklearn.tree.plot_tree来绘图:

tree.plot_tree(clf) # the clf is your decision tree model

示例输出与使用export_graphviz得到的结果类似: sklearn决策树可视化

你还可以尝试dtreeviz包。它会提供更多的信息。示例:

dtreeviz示例决策树

你可以在这篇博客文章中找到不同的scikit-learn决策树可视化比较和代码片段:链接

9

虽然我来得有点晚,但下面这些详细的步骤可能对想要展示决策树输出的人有帮助:

安装必要的模块:

  1. 安装 graphviz。我使用的是conda的安装包,可以在这里找到(推荐这个方法,因为用 pip install graphviz 安装的话,里面没有实际的GraphViz 可执行文件
  2. 通过pip安装 pydot(使用命令 pip install pydot
  3. 将包含.exe文件(比如dot.exe)的graphviz文件夹路径添加到你的环境变量PATH中
  4. 运行EdChum上面的代码(注意:graph是一个包含pydot.Dot对象的list):

from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.externals.six import StringIO  
import pydot 

clf = tree.DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)

dot_data = StringIO() 
tree.export_graphviz(clf, out_file=dot_data) 
graph = pydot.graph_from_dot_data(dot_data.getvalue()) 

graph[0].write_pdf("iris.pdf")  # must access graph's first element

现在你会在你环境的默认目录下找到"iris.pdf"这个文件

17

有一种方法可以导出为graph_viz格式,具体可以查看这个链接:http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html

根据在线文档:

>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>>
>>> clf = tree.DecisionTreeClassifier()
>>> iris = load_iris()
>>>
>>> clf = clf.fit(iris.data, iris.target)
>>> tree.export_graphviz(clf,
...     out_file='tree.dot')    

然后你可以使用graph viz来加载这个文件,或者如果你安装了pydot,那么可以更直接地做到这一点:http://scikit-learn.org/stable/modules/tree.html

>>> from sklearn.externals.six import StringIO  
>>> import pydot 
>>> dot_data = StringIO() 
>>> tree.export_graphviz(clf, out_file=dot_data) 
>>> graph = pydot.graph_from_dot_data(dot_data.getvalue()) 
>>> graph.write_pdf("iris.pdf") 

这样会生成一个svg文件,但我这里无法显示,你需要点击这个链接查看:http://scikit-learn.org/stable/_images/iris.svg

更新

看起来自从我第一次回答这个问题以来,行为发生了变化,现在返回的是一个list,所以你会遇到这个错误:

AttributeError: 'list' object has no attribute 'write_pdf'

首先,当你看到这个错误时,值得打印一下这个对象并检查它,通常你想要的是第一个对象:

graph[0].write_pdf("iris.pdf")

感谢@NickBraunagel的评论

撰写回答