sklearn决策树的遍历

2024-04-24 23:18:47 发布

您现在位置:Python中文网/ 问答频道 /正文

如何进行sklearn决策树的广度优先搜索遍历

在我的代码中,我尝试了sklearn.tree_uu库,并使用了各种函数,如tree_uu.feature和tree_uuu.threshold来理解树的结构。但是这些函数执行dfs树遍历如果我想执行bfs,我应该怎么做

假设

clf1 = DecisionTreeClassifier( max_depth = 2 )
clf1 = clf1.fit(x_train, y_train)

这是我的分类器,生成的决策树是

enter image description here

然后,我使用以下函数遍历了树

def encoding(clf, features):
l1 = list()
l2 = list()

for i in range(len(clf.tree_.feature)):
    if(clf.tree_.feature[i]>=0):
        l1.append( features[clf.tree_.feature[i]])
        l2.append(clf.tree_.threshold[i])
    else:
        l1.append(None)
        print(np.max(clf.tree_.value))
        l2.append(np.argmax(clf.tree_.value[i]))

l = [l1 , l2]

return np.array(l)

生产的产品是

array([['address', 'age', None, None, 'age', None, None],
       [0.5, 17.5, 2, 1, 15.5, 1, 1]], dtype=object)

如果第一个数组是节点的特征,或者如果它没有叶,那么它被标记为无,第二个数组是特征节点的阈值,对于类节点,它是类,但这是树的dfs遍历我想做bfs遍历我应该做什么

由于我不熟悉stack overflow,请建议如何改进问题描述,以及我应该添加哪些其他信息来进一步解释我的问题

X_列车(样本) X_train

y_列车(样本) y_train


Tags: 函数none决策树treel1threshold节点np
1条回答
网友
1楼 · 发布于 2024-04-24 23:18:47

这应该做到:

from collections import deque

tree = clf.tree_

stack = deque()
stack.append(0)  # push tree root to stack

while stack:
    current_node = stack.popleft()

    # do whatever you want with current node
    # ...

    left_child = tree.children_left[current_node]
    if left_child >= 0:
        stack.append(left_child)

    right_child = tree.children_right[current_node]
    if right_child >= 0:
        stack.append(right_child)

这使用了一个deque来保存下一个要处理的节点堆栈。由于我们从左侧移除元素并将其添加到右侧,因此这应该表示宽度优先遍历


为了实际使用,我建议您将其转换为发电机:

from collections import deque

def breadth_first_traversal(tree):
    stack = deque()
    stack.append(0)

    while stack:
        current_node = stack.popleft()

        yield current_node

        left_child = tree.children_left[current_node]
        if left_child >= 0:
            stack.append(left_child)

        right_child = tree.children_right[current_node]
        if right_child >= 0:
            stack.append(right_child)

然后,您只需要对原始函数进行最小的更改:

def encoding(clf, features):
    l1 = list()
    l2 = list()

    for i in breadth_first_traversal(clf.tree_):
        if(clf.tree_.feature[i]>=0):
            l1.append( features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:
            l1.append(None)
            print(np.max(clf.tree_.value))
            l2.append(np.argmax(clf.tree_.value[i]))

    l = [l1 , l2]

    return np.array(l)

相关问题 更多 >