决策树发现当遍历树时,常量预测是如何变化的

2024-04-23 14:54:08 发布

您现在位置:Python中文网/ 问答频道 /正文

假设我有以下DecisionTreeClassifier模型:

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer

bunch = load_breast_cancer()

X, y = bunch.data, bunch.target

model = DecisionTreeClassifier(random_state=100)
model.fit(X, y)

我想遍历这个树中的每个节点(叶子和决策),并确定预测值在遍历树时是如何变化的。基本上,对于给定的样本,我想知道最终预测(由.predict返回的内容)是如何确定的。因此,样本可能最终会被预测1,但会遍历四个节点,在每个节点上,其“常量”(scikit文档中使用的语言)预测又会从100再到1。你知道吗

我如何从model.tree_.value获得这些信息还不是很清楚,这被描述为:

 |  value : array of double, shape [node_count, n_outputs, max_n_classes]
 |      Contains the constant prediction value of each node.

在这个模型中:

>>> model.tree_.value.shape
(43, 1, 2)
>>> model.tree_.value
array([[[212., 357.]],

       [[ 33., 346.]],

       [[  5., 328.]],

       [[  4., 328.]],

       [[  2., 317.]],

       [[  1.,   6.]],

       [[  1.,   0.]],

       [[  0.,   6.]],

       [[  1., 311.]],

       [[  0., 292.]],

       [[  1.,  19.]],

       [[  1.,   0.]],

       [[  0.,  19.]],

有人知道我是怎么做到的吗?上面43个节点的类预测是否只是每个列表的argmax?所以1,1,1,1,1,1,0,0,…,从上到下?你知道吗


Tags: from模型importtreemodel节点valueload
1条回答
网友
1楼 · 发布于 2024-04-23 14:54:08

一种解决方案是直接走到树中的决策路径。 您可以调整this solution,它像打印子句一样打印整个决策树。 下面是一个简单的例子:

def tree_path(instance, values, left, right, threshold, features, node, depth):
    spacer = '    ' * depth
    if (threshold[node] != _tree.TREE_UNDEFINED):
        if instance[features[node]] <= threshold[node]:
            path = f'{spacer}{features[node]} ({round(instance[features[node]], 2)}) <= {round(threshold[node], 2)}'
            next_node = left[node]
        else:
            path = f'{spacer}{features[node]} ({round(instance[features[node]], 2)}) > {round(threshold[node], 2)}'
            next_node = right[node]
        return path + '\n' + tree_path(instance, values, left, right, threshold, features, next_node, depth+1)
    else:
        target = values[node]
        for i, v in zip(np.nonzero(target)[1],
                        target[np.nonzero(target)]):
            target_count = int(v)
            return spacer + "==> " + str(round(target[0][0], 2)) + \
                   " ( " + str(target_count) + " examples )"

def get_path_code(tree, feature_names, instance):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    values = tree.tree_.value
    return tree_path(instance, values, left, right, threshold, features, 0, 0)

# print the decision path of the first intance of a panda dataframe df
print(get_path_code(tree, df.columns, df.iloc[0]))

相关问题 更多 >