树的动态剪枝

3 投票
2 回答
1117 浏览
提问于 2025-04-18 11:06

我的问题是这样的:我想找到所有长度为 n 的组合,这些组合由 m 个可能的数字组成,并且这些数字的平均值要大于一个阈值 X

举个例子,假设长度 n=3,数字是 {1, 2},阈值是 1.5。所有可能的组合总数是 2*2*2 == 2**3 = 8,具体组合如下:

222 - avg 2.000000 > 1.500000 -> include in acceptable set
221 - avg 1.666667 > 1.500000 -> include in acceptable set
212 - avg 1.666667 > 1.500000 -> include in acceptable set
211 - avg 1.333333 < 1.500000 -> adding 1 and below to the exclude list
122 - avg 1.666667 > 1.500000 -> include in acceptable set
121 - avg 1.333333 < 1.500000 -> skipping this vote combo
112 - avg 1.333333 < 1.500000 -> skipping this vote combo
111 - avg 1.000000 < 1.500000 -> skipping this vote combo

final list of valid  votecombos
[[2, 2, 2], [2, 2, 1], [2, 1, 2], [1,2,2]]

我认为解决这个问题的方法是想象一个包含所有可能组合的树,然后在遍历的过程中动态地剪枝,去掉那些不可能的解。比如想象一个 n=3 的树,结构大概是这样的:

                            root
                         /        \
                        1          2
                    /      \    /     \
                   1       2    1     2
                 /  \    /  \  /  \  /  \
                1   2   1   2  1  2  1   2

每条到达叶子节点的路径就是一个可能的组合。可以想象,对于 n=3m=5,节点的数量是 N == m**n == 5**3 == 125 个节点。很明显,即使是 m=5n=20 的情况下,树的规模也会非常庞大,大约有 96 万亿个节点。所以这棵树不能全部存储在内存中。但其实也不需要,因为它的结构是非常有规律的。

获取所有可能有效组合的方法是通过深度优先搜索(DFS)遍历这棵树,采用一种前序遍历的方式,同时在遍历的过程中进行剪枝。比如在上面的例子中,前面三个组合 {222, 221, 212} 是有效的,但 211 就不行。这也意味着,从这个点开始,任何包含两个 1 的组合都不会有效。因此,我们几乎可以剪掉树的左半部分,除了 122!这样可以帮助我们避免检查这三个组合。

为此,我写了一个简单的 Python 脚本。

import string
import itertools
import numpy as np
import re

chars = '21'
grad_thr = 1.5
seats = 3

excllist = []
validlist = []

for word in itertools.product(chars, repeat = seats):
    # form the string of digits
    votestr = ''.join(word)
    print (votestr)

    # convert string into list of chars
    liststr = list(votestr)
    #print liststr

    # map list of chars to list of ints
    listint = map(int, liststr)

    if len(list(set(listint) & set(excllist))) == 0:
        # if there are no excluded votes in this votecombo then proceed

        # compute a function over the digits; func can be average/bayesian score/something else.
        y_mean = np.mean(listint)
        print 'avg %f' %y_mean
        #y_bayes = bayesian score

        if y_mean >= grad_thr:
            # if function result is greater than grad threshold then save result to a list of valid votes
            validlist.append(listint)
            print 'geq than %f -> include in acceptable set' %grad_thr

        elif y_mean < grad_thr:
            # if function result is not greater than grad threshold then add logic to stop searching the tree further
            # prune unnecessary part of the tree

            if listint[-1] not in excllist:
                excllist = [int(d) for d in range(listint[-1] + 1)]
                print 'adding %d and below to the exclude list' %listint[-1]
            else:
                print '%d already present in exclude list' %listint[-1]

    else:
        print 'skipping this vote combo'

print '\nfinal valid list of votecombos'
print validvotelist
print 'exclude list'
print excllist
print '\n'

通过这种方式,我可以遍历每一个可能的组合,并且 跳过 一些组合,以避免计算平均值。不过,我还是得在进入循环后检查每一个可能的组合。

有没有可能完全不检查某个组合?也就是说,我们知道组合 121 不行,但我们还是得进入循环,然后跳过这个组合。有没有可能不这样做呢?

2 个回答

0

看起来你把事情搞得太复杂了。我会这样做:先拿你的数字集合 A、字符串长度 n 和阈值 T,然后解决下面这个优化问题:

Minimize the sum of n elements of A (with repeats) such that the sum still exceeds
the threshold value T.

你可以用结果的 argmin 来生成一个多重集合,这样你就可以从中抽取出有效的字符串,而且是不重复抽取的。例如,任何包含两个 2 的字符串,其平均数字值都会超过你的阈值,所以多重集合 M = [1, 2, 2, 2] 中任何三个元素的排列都是有效的。

补充说明:这里有一种生成最小有效多重集合的方法。partitionfunc 的定义是借用自 这个帖子,然后我只是过滤掉那些所有元素都在 digit_set 中的列表。min_sum 需要进行向上取整的操作,因为我假设数字必须是整数,所以它们的和也会是整数。因此,为了超过阈值,数字和的值必须不小于 ceil(num_digits * threshold)。希望这对你有帮助!

from math import ceil

def partitionfunc(n,k,l=1):
'''n is the integer to partition, 
   k is the length of partitions, 
   l is the min partition element size'''
if k < 1:
    raise StopIteration
if k == 1:
    if n >= l:
        yield (n,)
    raise StopIteration
for i in range(l,n+1):
    for result in partitionfunc(n-i,k-1,i):
        yield (i,)+result

def find_min_sets(num_digits, digit_set, threshold):
  min_sum = ceil(num_digits * threshold)
  min_sets = [l for l in partitionfunc(min_sum, num_digits) if
              all(map(lambda x: x in digit_set, l))]
  return min_sets
1

一些建议:

  1. 构建多重集合而不是有序列表。一个有序列表的平均值不受数字顺序的影响,而每个多重集合对应着许多有序列表。因此,你可以只保存多重集合,等需要的时候再从中生成所有对应的有序列表,这样可以节省很多内存。
  2. 与其从一个空的多重集合开始,然后在每次深度优先搜索(DFS)中添加一个数字,不如从一个包含n个最大数字的满多重集合开始,在每次DFS中将其中一个数字减1。(这假设可用数字之间没有“空缺”。)这样做的好处是,我们知道向下遍历DFS边缘只会降低平均值,因此如果这样做的结果平均值低于阈值,我们就可以完全剪枝,因为所有更深的子节点的平均值一定更低。
  3. 其实你根本不需要进行任何除法:你只需要将阈值x乘以n,得到一个最小的数字,然后可以用这个和去比较多重集合的和。此外,按照之前的建议生成子节点时,子节点的和总是比父节点的和少1,所以我们甚至不需要循环来计算和——这是一种常数时间的操作。

避免重复

不过,上面提到的生成子节点的规则确实带来了一个困难:我们怎么确保不重复生成同一个子节点呢?例如,如果树中有一个节点包含多重集合{5, 8}(在这个例子中就是一个普通集合),那么它会生成子节点{4, 8}和{5, 7};但如果树中还有另一个节点是集合{4, 9},那么它会生成子节点{3, 9}和{4, 8}——这样子节点{4, 8}就会被生成两次。

解决这个问题的方法是找出一个规则,让每个子节点可以“选择”一个独特的父节点,然后安排父节点只生成它们会成为“被选中”的父节点的子节点。例如,我们可以规定,子节点应该选择在所有可以生成它的父节点中,按字典顺序最大的那个作为它的唯一父节点。(你也可以选择字典顺序最小的,但选择最大的计算效率更高。)对于多重集合{4, 8},可以生成它的两个父节点是{5, 8}和{4, 9};在这两个中,{5, 8}按字典顺序更大,因此我们将它选为父节点。

但是在DFS中,我们是从父节点生成子节点,而不是反过来,所以我们仍然需要将这个“选择父节点”的规则转化为一种方法,来判断当我们在一个可能是某个子节点的父节点的节点时,它是否确实是那个子节点的“被选中”父节点。为此,考虑某个子节点v的所有潜在父节点。首先,有多少个呢?如果v有r个不同的数字小于最大数字值,那么就有r个可能的父节点,每个父节点都等于v,但有1个不同的数字比它大1。

假设v中最小的数字是d,并且有k >= 1个这样的数字。在v的r个潜在父节点中,除了一个父节点u会有k-1个d,其他所有父节点都会有k个d,因为在这个父节点中,必须将数字d+1减1变成d(从而将d的数量从k-1增加到k)来生成v。现在如果我们将v的r个潜在父节点按数字升序列出,注意到除了u之外,所有父节点都会以k个d开头,而u则以k-1(可能为0)个d开头,后面跟着至少1个d+1。因此u在字典顺序上大于其他r-1个潜在父节点。

这告诉我们从潜在父节点的角度来看,成为被选中父节点的标准。假设u中最小的数字是d。那么某个节点v只有在以下两种情况下才会将u视为它的被选中父节点:要么是将u中的一个d数字减小到d-1,要么是将u中的一个(d+1)数字减小到d。这转化为生成子节点的两个简单高效的规则:

假设我们在某个节点u,并想根据上述规则生成它的所有子节点,以确保每个满足条件的多重集合在树中只生成一次。和之前一样,设d为u中的最小数字。那么:

  • 如果d > 最小数字,生成一个与u相同的子节点,只是将u中的一个d数字减小到d-1。例如:{3, 3, 3, 4, 6, 6}应该生成子节点{2, 3, 3, 4, 6, 6}。
  • 如果u中包含一个d+1的数字,生成一个与u相同的子节点,只是将u中的一个(d+1)数字减小到d。例如:{3, 3, 3, 4, 6, 6}应该生成子节点{3, 3, 3, 3, 6, 6}。

所以在上面的例子中,节点u = {3, 3, 3, 4, 6, 6}将生成2个子节点。(没有节点会生成超过2个子节点。)

如果多重集合以排序列表或按排序顺序的数字频率计数表示,那么只需扫描初始部分就可以高效检查这两个条件。

示例

在你的例子中(记住我们这里只生成多重集合;在单独的步骤中生成它们的每个排列以找到所有有序列表):

sum_threshold = 1.5*3 = 4.5

                                       SUM
                        222             6
                      /
                    122                 5
                  /
                112                     4
               PRUNE

在一个稍大的例子中,数字={1, 2, 3},n=3,x=0(以显示所有多重集合将被生成,且每个只生成一次):

                                       SUM
                       333              9
                     /
                   233                  8
                 /     \
               133     223              7
                      /  \
                    123  222            6
                    /      \
                  113      122          5
                             \
                             112        4
                               \
                               111      3

撰写回答