找出能通过加减法得到的所有数字
我对我的应用程序进行了性能分析,发现它90%的时间都花在了plus_minus_variations
这个函数上。
这个函数的作用是通过加法和减法,从一组数字中找出可以得到的各种数字。
举个例子:
输入
1, 2
输出
1+2=3
1-2=-1
-1+2=1
-1-2=-3
这是我现在的代码。我觉得在速度上可以有很大的提升。
def plus_minus_variations(nums):
result = dict()
for i, ops in zip(xrange(2 ** len(nums)), \
itertools.product([-1, 1], repeat=len(nums))):
total = sum(map(operator.mul, ops, nums))
result[total] = ops
return result
我主要想找一种不同的算法来解决这个问题。现在的算法效率似乎不太高。不过,如果你有关于代码本身的优化建议,我也很乐意听听。
补充说明:
- 如果结果缺少一些答案(或者多出一些不必要的答案),只要能更快完成也是可以的。
- 如果有多种方法可以得到一个数字,哪种都可以。
- 对于我使用的列表大小,99.9%的方法都会产生重复的数字。
- 如果结果不包含数字是怎么得来的,只要能更快完成也是可以的。
6 个回答
5
编辑:
哦!
这段代码是用Python 3写的,灵感来自tyz:
from functools import reduce # only in Python 3
def process(old, num):
new = set(map(num.__add__, old)) # use itertools.imap for Python 2
new.update(map((-num).__add__, old))
return new
def pmv(nums):
n = iter(nums)
x = next(n)
result = {x, -x} # set([x, -x]) for Python 2
return reduce(process, n, result)
我没有采用分半和递归的方法,而是用reduce
这个函数一个一个地计算,这样大大减少了函数调用的次数。
计算256个数字的时间不到1秒。
为什么先乘积再相乘?
def pmv(nums):
return {sum(i):i for i in itertools.product(*((num, -num) for num in nums))}
这样做可能会更快,而不管这些数字是怎么产生的:
def pmv(nums):
return set(map(sum, itertools.product(*((num, -num) for num in nums))))
6
如果不需要追踪生成的数字,那就没必要每次都重新计算数字组合的总和。你可以把中间结果存起来:
def combine(l,r):
res = set()
for x in l:
for y in r:
res.add( x+y )
res.add( x-y )
res.add( -x+y )
res.add( -x-y )
return list(res)
def pmv(nums):
if len(nums) > 1:
l = pmv( nums[:len(nums)/2] )
r = pmv( nums[len(nums)/2:] )
return combine( l, r )
return nums
编辑: 如果数字生成的方式很重要,你可以使用这个方法:
def combine(l,r):
res = dict()
for x,q in l.iteritems():
for y,w in r.iteritems():
if not res.has_key(x+y):
res[x+y] = w+q
res[-x-y] = [-i for i in res[x+y]]
if not res.has_key(x-y):
res[x-y] = w+[-i for i in q]
res[-x+y] = [-i for i in res[x-y]]
return res
def pmv(nums):
if len(nums) > 1:
l = pmv( nums[:len(nums)/2] )
r = pmv( nums[len(nums)/2:] )
return combine( l, r )
return {nums[0]:[1]}
我的测试显示,这种方法仍然比其他解决方案更快。
4
对于大随机列表,这种方法似乎快了很多。我想你可以进一步进行微调,但我更喜欢代码的可读性。
我把列表分成小块,然后为每一块创建不同的组合。因为你得到的组合数量远少于 2 ** len(chunk)
,所以速度会更快。这里的块长度是6,你可以尝试不同的长度,看看哪个长度效果最好。
def pmv(nums):
chunklen=6
res = dict()
res[0] = ()
for i in xrange(0, len(nums), chunklen):
part = plus_minus_variations(nums[i:i+chunklen])
resnew = dict()
for (i,j) in itertools.product(res, part):
resnew[i + j] = tuple(list(res[i]) + list(part[j]))
res = resnew
return res