Unable to plot the tree from a pipeline

Posted 2024-05-28 19:17:41


I have the decision tree classification code below. I can see this model's predictions, but I cannot plot the tree.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.compose import make_column_selector as selector
from sklearn.tree import plot_tree
from sklearn.tree import DecisionTreeClassifier

# Scale numeric values
num_piepline = Pipeline([
    ("imputer", SimpleImputer(missing_values=np.nan, strategy="median")),
    ("scalar1", StandardScaler()),
])

# One-hot encode categorical values
cat_pipeline = Pipeline([('onehot', OneHotEncoder(handle_unknown='ignore'))])

full_pipeline = ColumnTransformer(
    transformers=[
        ('num', num_piepline, ['a', 'b', 'c', 'd']),
        ('cat', cat_pipeline, ['e'])
        
    ])

decisiontree_entropy_model = Pipeline(steps=[
    ('dt_preprocessor', full_pipeline),
    ('dt_classifier', DecisionTreeClassifier(random_state=2021, max_depth=3, criterion='entropy'))])

decisiontree_entropy_model.fit(X_train, y_train)

dte_y_pred = decisiontree_entropy_model.predict(X_train)

fig = plt.figure(figsize=(25,20))
plot_tree(decisiontree_entropy_model_clf)

I get the following error stack trace:

---------------------------------------------------------------------------
NotFittedError                            Traceback (most recent call last)
<ipython-input-151-da85340c2477> in <module>
      1 from sklearn.tree import plot_tree
      2 fig = plt.figure(figsize=(25,20))
----> 3 plot_tree(decisiontree_entropy_model_clf)
      4 
      5 # from IPython.display import Image

~\Anaconda3\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~\Anaconda3\lib\site-packages\sklearn\tree\_export.py in plot_tree(decision_tree, max_depth, feature_names, class_names, label, filled, impurity, node_ids, proportion, rotate, rounded, precision, ax, fontsize)
    178     """
    179 
--> 180     check_is_fitted(decision_tree)
    181 
    182     if rotate != 'deprecated':

~\Anaconda3\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~\Anaconda3\lib\site-packages\sklearn\utils\validation.py in check_is_fitted(estimator, attributes, msg, all_or_any)
   1017 
   1018     if not attrs:
-> 1019         raise NotFittedError(msg % {'name': type(estimator).__name__})
   1020 
   1021 

NotFittedError: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

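Here I have run fit on the model and I can see the classification report for it, yet plotting raises the NotFittedError. Does the pipeline instance not persist after we call fit once? I am not sure why plotting the tree fails when performance metrics can still be derived from the classification report.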


1 answer
User
#1 · Posted 2024-05-28 19:17:41

Nothing in the code shown is named decisiontree_entropy_model_clf; to plot the decision tree from the pipeline, you should use

plot_tree(decisiontree_entropy_model['dt_classifier'])

after the pipeline has been fitted (before fitting, the tree does not even exist).
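
As a minimal sketch of this, assuming a small synthetic data set in place of X_train and y_train (the toy columns and the names preprocessor, model, and fitted_tree are illustrative and not taken from the question), the fitted classifier can be pulled out of the pipeline by its step name and passed to plot_tree:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Hypothetical toy data standing in for X_train / y_train
rng = np.random.default_rng(0)
X_train = pd.DataFrame({
    "a": rng.normal(size=60),
    "b": rng.normal(size=60),
    "e": rng.choice(["x", "y", "z"], size=60),
})
y_train = (X_train["a"] + X_train["b"] > 0).astype(int)

# Numeric columns: impute then scale; categorical column: one-hot encode
preprocessor = ColumnTransformer(transformers=[
    ("num", Pipeline([("imputer", SimpleImputer(strategy="median")),
                      ("scaler", StandardScaler())]), ["a", "b"]),
    ("cat", OneHotEncoder(handle_unknown="ignore"), ["e"]),
])

model = Pipeline(steps=[
    ("dt_preprocessor", preprocessor),
    ("dt_classifier", DecisionTreeClassifier(random_state=2021, max_depth=3,
                                             criterion="entropy")),
])

# The tree only exists after fit() has been called on the pipeline
model.fit(X_train, y_train)

# Index the pipeline by step name to get the fitted DecisionTreeClassifier
fitted_tree = model["dt_classifier"]

fig, ax = plt.subplots(figsize=(12, 8))
annotations = plot_tree(fitted_tree, filled=True, ax=ax)
```

Passing the whole pipeline (or a different, unfitted pipeline object) to plot_tree is what triggers the NotFittedError; the step lookup hands it the tree estimator itself.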

For general information on accessing the various attributes of a pipeline, see Getting model attributes from pipeline.
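
For illustration, here is a short sketch of the common ways to reach a step inside a fitted pipeline. It assumes the iris data set and hypothetical step names "scaler" and "clf" rather than the pipeline from the question:

```python
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Hypothetical two-step pipeline used only to demonstrate attribute access
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", DecisionTreeClassifier(max_depth=2, random_state=0)),
])
pipe.fit(X, y)

# Four equivalent ways to get the fitted final estimator
by_key = pipe["clf"]                     # index by step name
by_named_steps = pipe.named_steps["clf"] # attribute-style container of steps
by_position = pipe[-1]                   # index by position
by_steps = pipe.steps[-1][1]             # raw (name, estimator) list

same_object = by_key is by_named_steps is by_position is by_steps
print(same_object)

# Fitted attributes then live on the individual steps, not on the pipeline
depth = by_key.get_depth()
scaler_mean = pipe["scaler"].mean_
print(depth, scaler_mean.shape)
```

All four lookups return the same fitted estimator object, so whichever reads best in your code can be handed to plot_tree.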
