找出能通过加减法得到的所有数字

3 投票
6 回答
2332 浏览
提问于 2025-04-16 23:03

我对我的应用程序进行了性能分析,发现它90%的时间都花在了plus_minus_variations这个函数上。

这个函数的作用是通过加法和减法,从一组数字中找出可以得到的各种数字。

举个例子:
输入

1, 2

输出

1+2=3
1-2=-1
-1+2=1
-1-2=-3

这是我现在的代码。我觉得在速度上可以有很大的提升。

def plus_minus_variations(nums):
    result = dict()
    for i, ops in zip(xrange(2 ** len(nums)), \
            itertools.product([-1, 1], repeat=len(nums))):
        total = sum(map(operator.mul, ops, nums))
        result[total] = ops
    return result

我主要想找一种不同的算法来解决这个问题。现在的算法效率似乎不太高。不过,如果你有关于代码本身的优化建议,我也很乐意听听。

补充说明:

  • 如果结果缺少一些答案(或者多出一些不必要的答案),只要能更快完成也是可以的。
  • 如果有多种方法可以得到一个数字,哪种都可以。
  • 对于我使用的列表大小,99.9%的方法都会产生重复的数字。
  • 如果结果不包含数字是怎么得来的,只要能更快完成也是可以的。

6 个回答

5

编辑:

哦!
这段代码是用Python 3写的,灵感来自tyz:

from functools import reduce # only in Python 3

def process(old, num):
    new = set(map(num.__add__, old)) # use itertools.imap for Python 2
    new.update(map((-num).__add__, old))
    return new

def pmv(nums):
    n = iter(nums)
    x = next(n)
    result = {x, -x} # set([x, -x]) for Python 2
    return reduce(process, n, result)

我没有采用分半和递归的方法,而是用reduce这个函数一个一个地计算,这样大大减少了函数调用的次数。

计算256个数字的时间不到1秒。


为什么先乘积再相乘?

def pmv(nums):
    return {sum(i):i for i in itertools.product(*((num, -num) for num in nums))}

这样做可能会更快,而不管这些数字是怎么产生的:

def pmv(nums):
    return set(map(sum, itertools.product(*((num, -num) for num in nums))))
6

如果不需要追踪生成的数字,那就没必要每次都重新计算数字组合的总和。你可以把中间结果存起来:

def combine(l,r):
    res = set()
    for x in l:
        for y in r:
            res.add( x+y )
            res.add( x-y )
            res.add( -x+y )
            res.add( -x-y )
    return list(res)

def pmv(nums):
    if len(nums) > 1:
        l = pmv( nums[:len(nums)/2] )
        r = pmv( nums[len(nums)/2:] )
        return combine( l, r )
    return nums

编辑: 如果数字生成的方式很重要,你可以使用这个方法:

def combine(l,r):
    res = dict()
    for x,q in l.iteritems():
        for y,w in r.iteritems():
            if not res.has_key(x+y):
                res[x+y] = w+q
                res[-x-y] = [-i for i in res[x+y]]
            if not res.has_key(x-y):
                res[x-y] = w+[-i for i in q]
                res[-x+y] = [-i for i in res[x-y]]
    return res

def pmv(nums):
    if len(nums) > 1:
        l = pmv( nums[:len(nums)/2] )
        r = pmv( nums[len(nums)/2:] )
        return combine( l, r )
    return {nums[0]:[1]}

我的测试显示,这种方法仍然比其他解决方案更快。

4

对于大随机列表,这种方法似乎快了很多。我想你可以进一步进行微调,但我更喜欢代码的可读性。

我把列表分成小块,然后为每一块创建不同的组合。因为你得到的组合数量远少于 2 ** len(chunk),所以速度会更快。这里的块长度是6,你可以尝试不同的长度,看看哪个长度效果最好。

def pmv(nums):
    chunklen=6
    res = dict()
    res[0] = ()
    for i in xrange(0, len(nums), chunklen):
        part = plus_minus_variations(nums[i:i+chunklen])
        resnew = dict()
        for (i,j) in itertools.product(res, part):
            resnew[i + j] = tuple(list(res[i]) + list(part[j]))
        res = resnew
    return res

撰写回答