向量化:数组的索引太多

2024-04-19 18:37:01 发布

您现在位置:Python中文网/ 问答频道 /正文

a=b=np.arange(9).reshape(3,3)
i=np.arange(3)
mask=a<i[:,None,None]+3

以及

^{pr2}$

现在我想把它矢量化,然后把它们打印在一起,我试着

b[np.where(mask[i])]和{}

它们都显示IndexError: too many indices for array


Tags: nonefornpmaskwherearray矢量化many
2条回答

当试图打印一个矢量时,它只能存在于x、y和z维中。你有4个。在

In [165]: a
Out[165]: 
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
In [166]: mask
Out[166]: 
array([[[ True,  True,  True],
        [False, False, False],
        [False, False, False]],

       [[ True,  True,  True],
        [ True, False, False],
        [False, False, False]],

       [[ True,  True,  True],
        [ True,  True, False],
        [False, False, False]]], dtype=bool)

所以a(和b)是(3,3),而{}是(3,3,3)。在

应用于数组的布尔掩码生成1d(通过where应用时也是如此):

^{pr2}$

2d掩码上的where生成一个2元素元组,它可以索引2d数组:

In [173]: np.where(mask[1,:,:])
Out[173]: (array([0, 0, 0, 1], dtype=int32), array([0, 1, 2, 0], dtype=int32))

3d掩码上的where是一个3元素元组-因此出现too many indices错误:

In [174]: np.where(mask)
Out[174]: 
(array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2], dtype=int32),
 array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1], dtype=int32),
 array([0, 1, 2, 0, 1, 2, 0, 0, 1, 2, 0, 1], dtype=int32))

让我们尝试将a扩展到3d并应用掩码

In [176]: np.tile(a[None,:],(3,1,1)).shape
Out[176]: (3, 3, 3)
In [177]: np.tile(a[None,:],(3,1,1))[mask]
Out[177]: array([0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4])

价值观是存在的,但它们是连接在一起的。在

我们可以计算出mask的每个平面中True的数量,并将其用于split蒙版瓷砖:

In [185]: mask.sum(axis=(1,2))
Out[185]: array([3, 4, 5])
In [186]: cnt=np.cumsum(mask.sum(axis=(1,2)))
In [187]: cnt
Out[187]: array([ 3,  7, 12], dtype=int32)

In [189]: np.split(np.tile(a[None,:],(3,1,1))[mask], cnt[:-1])
Out[189]: [array([0, 1, 2]), array([0, 1, 2, 3]), array([0, 1, 2, 3, 4])]

在内部,np.split使用Python级别的迭代。所以在mask平面上的迭代可能也一样好(在这个小例子中,迭代速度要快6倍)。在

In [190]: [a[m] for m in mask]
Out[190]: [array([0, 1, 2]), array([0, 1, 2, 3]), array([0, 1, 2, 3, 4])]

这指出了一个基本的问题,即理想的“矢量化”,单个数组是(3,),(4,)和(5,)形状。不同大小的数组是一个强有力的指示,真正的“矢量化”是困难的,如果不是不可能的话。在

相关问题 更多 >