扩展numpy.digitize以处理多维数据

11 投票
1 回答
5582 浏览
提问于 2025-04-18 12:34

我有一组很大的数组(每个大约有600万个元素),我想对它们进行类似于np.digitize的操作,但需要在多个维度上进行。我希望能得到一些建议,既包括如何有效地完成这个操作,也包括如何存储结果。

我需要数组A中所有的索引(或者所有的值,或者一个掩码),条件是数组B的值在一个范围内,数组C的值在另一个范围内,数组D的值在第三个范围内。我想要这些值、索引或掩码,以便我可以对数组A中每个区间的值进行一些尚未决定的统计分析。我还需要每个区间中的元素数量,不过用len()就可以做到这一点。

这是我想出的一个看起来还不错的例子:

import itertools
import numpy as np

A = np.random.random_sample(1e4)
B = (np.random.random_sample(1e4) + 10)*20
C = (np.random.random_sample(1e4) + 20)*40
D = (np.random.random_sample(1e4) + 80)*80

# make the edges of the bins
Bbins = np.linspace(B.min(), B.max(), 10)
Cbins = np.linspace(C.min(), C.max(), 12) # note different number
Dbins = np.linspace(D.min(), D.max(), 24) # note different number

B_Bidx = np.digitize(B, Bbins)
C_Cidx = np.digitize(C, Cbins)
D_Didx = np.digitize(D, Dbins)

a_bins = []
for bb, cc, dd in itertools.product(np.unique(B_Bidx), 
                                    np.unique(C_Cidx), 
                                    np.unique(D_Didx)):
    a_bins.append([(bb, cc, dd), [A[np.bitwise_and((B_Bidx==bb),
                                                   (C_Cidx==cc),
                                                   (D_Didx==dd))]]])

不过,这让我有点担心,因为在处理大数组时可能会耗尽内存。

我也可以这样做:

b_inds = np.empty((len(A), 10), dtype=np.bool)
c_inds = np.empty((len(A), 12), dtype=np.bool)
d_inds = np.empty((len(A), 24), dtype=np.bool)
for i in range(10):
    b_inds[:,i] = B_Bidx = i     
for i in range(12):
    c_inds[:,i] = C_Cidx = i     
for i in range(24):
    d_inds[:,i] = D_Didx = i     
# get the A data for the 1,2,3 B,C,D bin
print A[b_inds[:,1] & c_inds[:,2] & d_inds[:,3]]

至少在这里,输出的大小是已知且固定的。

有没有人有更好的想法,能让我更聪明地完成这个任务?或者需要进一步的解释吗?


根据HYRY的回答,这是我决定采取的路径。

import numpy as np
import pandas as pd

np.random.seed(42)
A =  np.random.random_sample(1e7)
B = (np.random.random_sample(1e7) + 10)*20
C = (np.random.random_sample(1e7) + 20)*40
D = (np.random.random_sample(1e7) + 80)*80
# make the edges of the bins we want
Bbins = np.linspace(B.min(), B.max(), 9)
Cbins = np.linspace(C.min(), C.max(), 10) # note different number
Dbins = np.linspace(D.min(), D.max(), 11) # note different number
sA = pd.Series(A)
cB = pd.cut(B, Bbins, include_lowest=True)
cC = pd.cut(C, Cbins, include_lowest=True)
cD = pd.cut(D, Dbins, include_lowest=True)

dat = pd.DataFrame({'A':A, 'cB':cB.labels, 'cC':cC.labels, 'cD':cD.labels})
g = sA.groupby([cB.labels, cC.labels, cD.labels]).indices
# this then gives all the indices that match the group 
print g[0,1,2]
# this is all the array A data for that B,C,D bin
print sA[g[0,1,2]]

这种方法即使在处理超大数组时也显得非常快速。

1 个回答

6

你可以试试在Pandas中使用 groupby。首先,修正你代码中的一些问题:

import itertools
import numpy as np

np.random.seed(42)

A = np.random.random_sample(1e4)
B = (np.random.random_sample(1e4) + 10)*20
C = (np.random.random_sample(1e4) + 20)*40
D = (np.random.random_sample(1e4) + 80)*80

# make the edges of the bins
Bbins = np.linspace(B.min(), B.max(), 10)
Cbins = np.linspace(C.min(), C.max(), 12) # note different number
Dbins = np.linspace(D.min(), D.max(), 24) # note different number

B_Bidx = np.digitize(B, Bbins)
C_Cidx = np.digitize(C, Cbins)
D_Didx = np.digitize(D, Dbins)

a_bins = []
for bb, cc, dd in itertools.product(np.unique(B_Bidx), 
                                    np.unique(C_Cidx), 
                                    np.unique(D_Didx)):
    a_bins.append([(bb, cc, dd), A[(B_Bidx==bb) & (C_Cidx==cc) & (D_Didx==dd)]])

a_bins[1000]

输出结果:

[(4, 6, 17), array([ 0.70723863,  0.907611  ,  0.46214047])]

下面是用Pandas返回相同结果的代码:

import pandas as pd

cB = pd.cut(B, 9)
cC = pd.cut(C, 11)
cD = pd.cut(D, 23)

sA = pd.Series(A)
g = sA.groupby([cB.labels, cC.labels, cD.labels])
g.get_group((3, 5, 16))

输出结果:

800     0.707239
2320    0.907611
9388    0.462140
dtype: float64

如果你想计算每个组的一些统计数据,可以调用 g 的方法,比如:

g.mean()

返回结果:

0  0  0     0.343566
      1     0.410979
      2     0.700007
      3     0.189936
      4     0.452566
      5     0.565330
      6     0.539565
      7     0.530867
      8     0.568120
      9     0.587762
      11    0.352453
      12    0.484903
      13    0.477969
      14    0.484328
      15    0.467357
...
8  10  8     0.559859
       9     0.570652
       10    0.656718
       11    0.353938
       12    0.628980
       13    0.372350
       14    0.404543
       15    0.387920
       16    0.742292
       17    0.530866
       18    0.389236
       19    0.628461
       20    0.387384
       21    0.541831
       22    0.573023
Length: 2250, dtype: float64

撰写回答