如何在Python中绘制CART树,像在R中一样?
在R语言中,我可以直接通过一个API画出决策树的图形表示,这个决策树是对应于一个CART模型的。比如说,使用prp
这个函数,就能生成类似下面这样的图:
但是我在Python中找不到类似的API来实现同样的功能。比如说,sklearn
库里的RandomForestClassifier
和DecisionTreeClassifier
似乎都没有绘制树的相关方法。
那么,我该如何在Python中获取CART或随机森林树的图形表示呢?
2 个回答
1
这个函数可以让图形在Jupyter笔记本中显示出来:
# Imports
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.externals.six import StringIO
from IPython.display import Image, display
import pydotplus
def jupyter_graphviz(m, **kwargs):
dot_data = StringIO()
export_graphviz(m, dot_data, **kwargs)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
display(Image(graph.create_png()))
比如说:
import sklearn.datasets as datasets
import pandas as pd
iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target
dtree = DecisionTreeClassifier(random_state=42)
dtree.fit(df, y)
jupyter_graphviz(dtree, filled=True, rounded=True, special_characters=True)
7
使用 export_graphviz
这个函数。
from sklearn.tree import DecisionTreeClassifier, export_graphviz
np.random.seed(0)
X = np.random.randn(10, 4)
y = array(["foo", "bar", "baz"])[np.random.randint(0, 3, 10)]
clf = DecisionTreeClassifier(random_state=42).fit(X, y)
export_graphviz(clf)
现在运行 dotty tree.dot
应该会显示类似下面的内容:
这里有一个 示例笔记本。