如何获取scikitlearn决策树中所有节点的pos/neg实例计数?

2024-03-29 08:25:03 发布

您现在位置:Python中文网/ 问答频道 /正文

我训练了一个sklearn决策树。在

from sklearn.tree import DecisionTreeClassifier
c=DecisionTreeClassifier(class_weight="auto")
c.fit([[0,0],
       [0,1],
       [1,1],
      ],[0,1,0])

现在我想检查每个节点有多少个正/负样本。因此一个图

^{pr2}$

如何从一个经过训练的决策树中得到这个(左数)?在

我可以看到一个c.tree_变量,但是内容似乎没有什么帮助。有零,权重。。。而且很难猜出怎样才能把计数恢复过来。在


Tags: fromimport决策树tree内容auto节点sklearn
0条回答
网友
1楼 · 发布于 2024-03-29 08:25:03

每个类的样本数存储在tree_.value中,但是它只存储叶的节点值,所以我使用后序遍历来获取所有节点的值。在

import numpy as np

def get_value(dt):
    left = dt.tree_.children_left
    right = dt.tree_.children_right
    value = dt.tree_.value
    leaves = np.argwhere(left == -1)[:, 0]

    def visit(node):
        if node in leaves:
            return
        visit(left[node])
        visit(right[node])
        value[node, :] = value[left[node], :] + value[right[node], :]

    visit(0)
    return value

例如

^{pr2}$

输出:

[[[ 2.  1.]]

 [[ 1.  1.]]

 [[ 1.  0.]]

 [[ 0.  1.]]

 [[ 1.  0.]]]

更新1

我想知道为什么tree_.value只存储叶节点的值,然后我找到了https://stackoverflow.com/questions/27417809/show-values-at-each-node-level-of-scikit-learn-decision-tree和{a2}。在

原来在scikit learn 0.17.dev0中,tree_.value已经返回了所有节点的值。在

In [1]: from sklearn.tree import DecisionTreeClassifier

In [2]: dt = DecisionTreeClassifier()

In [3]: dt.fit([[0,0],
   ...:         [0,1],
   ...:         [1,1]], [0,1,0])
Out[3]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            random_state=None, splitter='best')

In [4]: dt.tree_.value
Out[4]:
array([[[ 2.,  1.]],

       [[ 1.,  1.]],

       [[ 1.,  0.]],

       [[ 0.,  1.]],

       [[ 1.,  0.]]])

更新2

虽然我认为当给定class_weight时“撤销权重”是有意义的,但实现这一点是有可能的。在

class_weight的计算公式为

In [1]: from sklearn.utils import compute_class_weight

In [2]: compute_class_weight('auto', [0, 1], [0, 1, 0])
Out[2]: array([ 0.66666667,  1.33333333])

因此,您可以在if node in leaves:之后添加value[node, :] /= class_weight,以重新计算叶节点的值。在

相关问题 更多 >