Numpy,长数组的问题

2024-03-28 23:51:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个数组(a和b),其中n个整数元素在范围(0,n)内。

输入错误:数组中有2^n个整数,其中最大的整数取n=3^n

我想计算a和b中每个元素组合的和(sum_ij_u=a_I_u+b_j_u,for allI,j)。然后取模N(sum_ij_u=sum_ij_uu%N),最后计算不同和的频率。

为了在没有任何循环的情况下快速使用numpy,我尝试使用meshgrid和bincount函数。

A,B = numpy.meshgrid(a,b)
A = A + B
A = A % N
A = numpy.reshape(A,A.size)
result = numpy.bincount(A)

现在,问题是我的输入数组太长了。当我使用包含2^13个元素的输入时,meshgrid会给我内存错误。我想为2^15-2^20个元素的数组计算这个值。

在15到20之间的n

有没有什么聪明的把戏来对付努比?

任何帮助都将不胜感激。

-- 乔恩


Tags: 函数numpy元素for错误情况整数数组
3条回答

检查你的数学,这是你要的空间的一个

2^20*2^20=2^40=1099 511 627 776

如果每个元素只有一个字节,那已经是一个太字节的内存了。

加一两个圈。这个问题不适合于最大化内存和最小化计算。

针对jonalm的评论进行编辑:

jonalm: N~3^n not n~3^N. N is max element in a and n is number of elements in a.

n是~2^20。如果N是~3^N,则N是~3^(2^20)>;10^(500207)。 科学家估计宇宙中只有大约10^87个粒子。因此,计算机无法(天真地)处理10^(500207)大小的整数。

jonalm: I am however a bit curios about the pv() function you define. (I do not manage to run it as text.find() is not defined (guess its in another module)). How does this function work and what is its advantage?

pv是我为调试变量值而编写的一个小助手函数。它工作起来像 print()除了当您说pv(x)时,它同时打印文本变量名(或表达式字符串)、冒号和变量值。

如果你把

#!/usr/bin/env python
import traceback
def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))
x=1
pv(x)

在剧本里你应该得到

x: 1

与打印相比,使用pv的适度优势在于它可以节省您的打字时间。而不是不得不 写

print('x: %s'%x)

你可以拍下来

pv(x)

当有多个变量要跟踪时,标记这些变量会很有帮助。 我只是厌倦了写出来。

pv函数通过使用回溯模块查看代码行来工作 用于调用pv函数本身。(参见http://docs.python.org/library/traceback.html#module-traceback)这行代码作为字符串存储在变量文本中。 find()是对常用字符串方法find()的调用。例如,如果

text='pv(x)'

那么

text.find('(') == 2               # The index of the '(' in string text
text[text.find('(')+1:-1] == 'x'  # Everything in between the parentheses

我假设n~3^n,和n~2**20

我们的想法是让模块N工作,这样可以减少数组的大小。 第二个想法(n很大时很重要)是使用“object”类型的numpy ndarrays,因为如果使用整数类型,则可能会溢出所允许的最大整数的大小。

#!/usr/bin/env python
import traceback
import numpy as np

def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))

你可以将n改为2**20,但下面我将展示小n的情况 所以输出更容易阅读。

n=100
N=int(np.exp(1./3*np.log(n)))
pv(N)
# N: 4

a=np.random.randint(N,size=n)
b=np.random.randint(N,size=n)
pv(a)
pv(b)
# a: [1 0 3 0 1 0 1 2 0 2 1 3 1 0 1 2 2 0 2 3 3 3 1 0 1 1 2 0 1 2 3 1 2 1 0 0 3
#  1 3 2 3 2 1 1 2 2 0 3 0 2 0 0 2 2 1 3 0 2 1 0 2 3 1 0 1 1 0 1 3 0 2 2 0 2
#  0 2 3 0 2 0 1 1 3 2 2 3 2 0 3 1 1 1 1 2 3 3 2 2 3 1]
# b: [1 3 2 1 1 2 1 1 1 3 0 3 0 2 2 3 2 0 1 3 1 0 0 3 3 2 1 1 2 0 1 2 0 3 3 1 0
#  3 3 3 1 1 3 3 3 1 1 0 2 1 0 0 3 0 2 1 0 2 2 0 0 0 1 1 3 1 1 1 2 1 1 3 2 3
#  3 1 2 1 0 0 2 3 1 0 2 1 1 1 1 3 3 0 2 2 3 2 0 1 3 1]

wa在a中包含0、1、2、3的数字 wb保存b中0、1、2、3的数目

wa=np.bincount(a)
wb=np.bincount(b)
pv(wa)
pv(wb)
# wa: [24 28 28 20]
# wb: [21 34 20 25]
result=np.zeros(N,dtype='object')

把0当作一个代币或筹码。对于1,2,3也一样。

把wa=[24 28 28 20]看作是一个包,里面有24个0片,28个1片,28个2片,20个3片。

你有一个wa包和一个wb包。当你从每个袋子里抽出一个芯片时,你把它们“加”在一起,形成一个新的芯片。你“修改”答案(模N)。

想象一下从wb包中取出一个芯片,并将其与wa包中的每个芯片一起添加。

1-chip + 0-chip = 1-chip
1-chip + 1-chip = 2-chip
1-chip + 2-chip = 3-chip
1-chip + 3-chip = 4-chip = 0-chip  (we are mod'ing by N=4)

由于wb包中有34个1-芯片,当您将它们与wa=[24 28 28 20]包中的所有芯片相加时,您将得到

34*24 1-chips
34*28 2-chips
34*28 3-chips
34*20 0-chips

这只是34个1芯片的部分计数。你还得处理另一个 wb袋中的芯片类型,但这向您展示了以下使用的方法:

for i,count in enumerate(wb):
    partial_count=count*wa
    pv(partial_count)
    shifted_partial_count=np.roll(partial_count,i)
    pv(shifted_partial_count)
    result+=shifted_partial_count
# partial_count: [504 588 588 420]
# shifted_partial_count: [504 588 588 420]
# partial_count: [816 952 952 680]
# shifted_partial_count: [680 816 952 952]
# partial_count: [480 560 560 400]
# shifted_partial_count: [560 400 480 560]
# partial_count: [600 700 700 500]
# shifted_partial_count: [700 700 500 600]

pv(result)    
# result: [2444 2504 2520 2532]

这是最终结果:24440s,25041s,25202s,25323s

# This is a test to make sure the result is correct.
# This uses a very memory intensive method.
# c is too huge when n is large.
if n>1000:
    print('n is too large to run the check')
else:
    c=(a[:]+b[:,np.newaxis])
    c=c.ravel()
    c=c%N
    result2=np.bincount(c)
    pv(result2)
    assert(all(r1==r2 for r1,r2 in zip(result,result2)))
# result2: [2444 2504 2520 2532]

试着把它分块。你的网格是一个NxN矩阵,块高达10x10n/10xN/10,只需计算100个箱子,在最后加起来。这只占用大约1%的内存。

相关问题 更多 >