Numpy,长数组问题

14 投票
3 回答
28372 浏览
提问于 2025-04-15 15:47

我有两个数组(a 和 b),里面各有 n 个整数,范围是从 0 到 N。

这里有个小错误:数组里有 2^n 个整数,最大的整数是 N = 3^n。

我想计算 a 和 b 中每个元素组合的和(也就是 sum_ij_ = a_i_ + b_j_,对于所有的 i,j)。然后对这个和取模 N(也就是 sum_ij_ = sum_ij_ % N),最后统计不同和的出现频率。

为了快速完成这个计算,我想用 numpy,不想用循环,所以我尝试使用 meshgrid 和 bincount 函数。

A,B = numpy.meshgrid(a,b)
A = A + B
A = A % N
A = numpy.reshape(A,A.size)
result = numpy.bincount(A)

现在的问题是,我的输入数组很长。当我使用 2^13 个元素的输入时,meshgrid 会让我遇到内存错误。我希望能计算包含 2^15 到 2^20 个元素的数组。

也就是 n 的范围是 15 到 20。

有没有什么聪明的办法可以用 numpy 来做到这一点?

任何帮助都非常感谢。

-- jon

3 个回答

1

检查一下你的计算,这个空间要求可真是太大了:

2的20次方乘以2的20次方等于2的40次方,也就是1 099 511 627 776。

如果你每个元素只占一个字节,那这已经是一太字节的内存了。

再加上几个循环,这个问题并不适合让你的内存用到极限,同时还要减少计算量。

2
wa=np.bincount(a)
wb=np.bincount(b)
pv(wa)
pv(wb)
# wa: [24 28 28 20]
# wb: [21 34 20 25]
result=np.zeros(N,dtype='object')

根据jonalm的评论进行编辑:

jonalm: N~3^n,而不是n~3^N。N是数组a中的最大元素,n是数组a中的元素数量。

n大约是2的20次方。如果N大约是3的n次方,那么N大约是3的(2的20次方),这比10的500207次方还要大。科学家估计(http://www.stormloader.com/ajy/reallife.html)宇宙中大约只有10的87次方的粒子。所以,计算机根本无法处理大小为10的500207次方的整数。

jonalm: 不过我对你定义的pv()函数有点好奇。我没法运行它,因为text.find()没有定义(我猜它在另一个模块里)。这个函数是怎么工作的,有什么好处呢?

pv是我写的一个小助手函数,用来调试变量的值。它的工作方式类似于print(),不过当你调用pv(x)时,它会打印出变量的名字(或表达式字符串)、一个冒号,然后是变量的值。

如果你在脚本中放入

#!/usr/bin/env python
import traceback
def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))
x=1
pv(x)

你应该会得到

x: 1

使用pv而不是print的一个小好处是省去了打字的麻烦。你不需要写

print('x: %s'%x)

只需写

pv(x)

当需要跟踪多个变量时,给变量加标签是很有帮助的。我只是厌倦了写那么多。

pv函数通过使用traceback模块来查看调用pv函数的那行代码。 (查看http://docs.python.org/library/traceback.html#module-traceback) 那行代码作为字符串存储在变量text中。 text.find()是调用常规字符串方法find()。例如,如果

text='pv(x)'

那么

text.find('(') == 2               # The index of the '(' in string text
text[text.find('(')+1:-1] == 'x'  # Everything in between the parentheses

我假设n大约是3的N次方,而n大约是2的20次方。

这个想法是对N取模。这可以减少数组的大小。第二个想法(当n非常大时很重要)是使用numpy的ndarray类型为'object',因为如果你使用整数类型,可能会超出允许的最大整数大小。

#!/usr/bin/env python
import traceback
import numpy as np

def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))

你可以把n改成2的20次方,但下面我会展示小n的情况,这样输出会更容易阅读。

n=100
N=int(np.exp(1./3*np.log(n)))
pv(N)
# N: 4

a=np.random.randint(N,size=n)
b=np.random.randint(N,size=n)
pv(a)
pv(b)
# a: [1 0 3 0 1 0 1 2 0 2 1 3 1 0 1 2 2 0 2 3 3 3 1 0 1 1 2 0 1 2 3 1 2 1 0 0 3
#  1 3 2 3 2 1 1 2 2 0 3 0 2 0 0 2 2 1 3 0 2 1 0 2 3 1 0 1 1 0 1 3 0 2 2 0 2
#  0 2 3 0 2 0 1 1 3 2 2 3 2 0 3 1 1 1 1 2 3 3 2 2 3 1]
# b: [1 3 2 1 1 2 1 1 1 3 0 3 0 2 2 3 2 0 1 3 1 0 0 3 3 2 1 1 2 0 1 2 0 3 3 1 0
#  3 3 3 1 1 3 3 3 1 1 0 2 1 0 0 3 0 2 1 0 2 2 0 0 0 1 1 3 1 1 1 2 1 1 3 2 3
#  3 1 2 1 0 0 2 3 1 0 2 1 1 1 1 3 3 0 2 2 3 2 0 1 3 1]

wa表示数组a中0、1、2、3的数量,

wb表示数组b中0、1、2、3的数量。

把0想象成一个代币或筹码,1、2、3也是如此。

把wa=[24 28 28 20]理解为有一个袋子,里面有24个0筹码、28个1筹码、28个2筹码和20个3筹码。

你有一个wa袋子和一个wb袋子。当你从每个袋子中抽出一个筹码时,你“加”在一起形成一个新筹码。你对结果进行“取模”(模N)。

想象一下从wb袋子中拿出一个1筹码,并将其与wa袋子中的每个筹码相加。

1-chip + 0-chip = 1-chip
1-chip + 1-chip = 2-chip
1-chip + 2-chip = 3-chip
1-chip + 3-chip = 4-chip = 0-chip  (we are mod'ing by N=4)

因为wb袋子里有34个1筹码,当你把它们与wa=[24 28 28 20]袋子里的所有筹码相加时,你会得到

34*24 1-chips
34*28 2-chips
34*28 3-chips
34*20 0-chips

这只是由于34个1筹码而产生的部分计数。你还需要处理wb袋子中的其他类型筹码,但这展示了下面使用的方法:

for i,count in enumerate(wb):
    partial_count=count*wa
    pv(partial_count)
    shifted_partial_count=np.roll(partial_count,i)
    pv(shifted_partial_count)
    result+=shifted_partial_count
# partial_count: [504 588 588 420]
# shifted_partial_count: [504 588 588 420]
# partial_count: [816 952 952 680]
# shifted_partial_count: [680 816 952 952]
# partial_count: [480 560 560 400]
# shifted_partial_count: [560 400 480 560]
# partial_count: [600 700 700 500]
# shifted_partial_count: [700 700 500 600]

pv(result)    
# result: [2444 2504 2520 2532]

这是最终结果:2444个0,2504个1,2520个2,2532个3。

# This is a test to make sure the result is correct.
# This uses a very memory intensive method.
# c is too huge when n is large.
if n>1000:
    print('n is too large to run the check')
else:
    c=(a[:]+b[:,np.newaxis])
    c=c.ravel()
    c=c%N
    result2=np.bincount(c)
    pv(result2)
    assert(all(r1==r2 for r1,r2 in zip(result,result2)))
# result2: [2444 2504 2520 2532]
7

试着把它分块处理。你的网格是一个NxN的矩阵,把它分成10x10的小块,也就是每块的大小是N/10xN/10,然后只计算100个小块,最后再把它们加起来。这样做只会用到大约1%的内存,跟一次性处理整个矩阵相比,节省了很多内存。

撰写回答