高效选择numpy数组的子部分

1 投票
1 回答
917 浏览
提问于 2025-04-28 12:41

我想把一个叫做 x 的 numpy 数组分成三个不同的数组,这个分割是基于一些逻辑比较。这个数组的形状如下,但里面的内容会有所不同:(为了回应 Saullo Castro 的评论,我这里用了一个稍微不同的数组 x。)

array([[ 0.46006547,  0.5580928 ,  0.70164242,  0.84519205,  1.4       ],
      [ 0.00912908,  0.00912908,  0.05      ,  0.05      ,  0.05      ]])

这个数组的值在每一列中是单调递增的。我还有两个其他的数组,分别叫 lowest_gridpointshighest_gridpoints。这两个数组的内容也会变化,但它们的形状总是和下面这个一样:

 array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

我想要的选择过程如下:

  • 所有包含小于 lowest_gridpoints 中任何值的列都应该从 x 中移除,并组成数组 temp1
  • 所有包含大于 highest_gridpoints 中任何值的列也应该从 x 中移除,并组成数组 temp2
  • 所有既不在 temp1 也不在 temp2 中的 x 的列组成新的数组 x_new

我写的以下代码可以完成这个任务。

if np.any( x[:,-1] > highest_gridpoints ) or np.any( x[:,0] < lowest_gridpoints ):
    for idx, sample, in enumerate(x.T):
        if np.any( sample > highest_gridpoints):
            max_idx = idx
            break
        elif np.any( sample < lowest_gridpoints ):
            min_idx = idx 
    temp1, temp2 = np.array([[],[]]), np.array([[],[]])
    if 'min_idx' in locals():
        temp1 = x[:,0:min_idx+1]
    if 'max_idx' in locals():
        temp2 = x[:,max_idx:]
    if 'min_idx' in locals() or 'max_idx' in locals():
        if 'min_idx' not in locals():
            min_idx = -1
        if 'max_idx' not in locals():
            max_idx = x.shape[1]
        x_new = x[:,min_idx+1:max_idx]

不过,我怀疑这段代码效率很低,因为用了很多循环。此外,我觉得代码的写法有点冗长。

有没有人能提供一个更高效或者更简洁的代码来完成上述任务呢?

暂无标签

1 个回答

1

只有你问题的第一部分

from numpy import *

x = array([[ 0.46006547,  0.5580928 ,  0.70164242,  0.84519205,  1.4       ],
           [ 0.00912908,  0.00912908,  0.05      ,  0.05      ,  0.05      ]])

low, high = array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

# construct an array of two rows of bools expressing your conditions
indices1 = array((x[0,:]<low[0], x[1,:]<low[1]))
print indices1

# do an or of the values along the first axis
indices1 = any(indices1, axis=0)
# now it's a single row array
print indices1

# use the indices1 to extract what you want,
# the double transposition because the elements
# of a 2d array are  the rows
tmp1 = x.T[indices1].T
print tmp1

# [[ True  True False False False]
#  [ True  True False False False]]
# [ True  True False False False]
# [[ 0.46006547  0.5580928 ]
#  [ 0.00912908  0.00912908]]

接下来类似地构造 indices2tmp2,剩余部分的索引是前两个索引的“或”操作的否定。(也就是说,numpy.logical_not(numpy.logical_or(i1,i2)))。

补充说明

另一种方法,如果你有成千上万的条目,可能会更快,涉及到 numpy.searchsorted

from numpy import *

x = array([[ 0.46006547,  0.5580928 ,  0.70164242,  0.84519205,  1.4       ],
           [ 0.00912908,  0.00912908,  0.05      ,  0.05      ,  0.05      ]])

low, high = array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

l0r = searchsorted(x[0,:], low[0], side='right')
l1r = searchsorted(x[1,:], low[1], side='right')

h0l = searchsorted(x[0,:], high[0], side='left')
h1l = searchsorted(x[1,:], high[1], side='left')

lr = max(l0r, l1r)
hl = min(h0l, h1l)

print lr, hl
print x[:,:lr]
print x[:,lr:hl]
print x[:,hl]

# 2 4
# [[ 0.46006547  0.5580928 ]
#  [ 0.00912908  0.00912908]]
# [[ 0.70164242  0.84519205]
#  [ 0.05        0.05      ]]
# [ 1.4   0.05]

排除重叠可以通过 hl = max(lr, hl) 来实现。注意,在之前的方法中,数组切片被复制到新的对象中,而在这里你得到的是 x 的视图,如果你想要新的对象,你必须明确说明。

编辑 一个不必要的优化

如果我们在第二对 sortedsearch 中只使用 x 的上半部分(如果你看看代码就会明白我是什么意思……),我们会得到两个好处,1)搜索速度稍微加快(sortedsearch 总是足够快),2)重叠的情况会自动处理。

作为额外内容,提供了将 x 的片段复制到新数组的代码。 注意 x 被更改以强制重叠

from numpy import *

# I changed x to force overlap
x = array([[ 0.46006547,  1.4 ,        1.4,   1.4,  1.4       ],
           [ 0.00912908,  0.00912908,  0.05,  0.05, 0.05      ]])

low, high = array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

l0r = searchsorted(x[0,:], low[0], side='right')
l1r = searchsorted(x[1,:], low[1], side='right')
lr = max(l0r, l1r)

h0l = searchsorted(x[0,lr:], high[0], side='left')
h1l = searchsorted(x[1,lr:], high[1], side='left')

hl = min(h0l, h1l) + lr

t1 = x[:,range(lr)]
xn = x[:,range(lr,hl)]
ncol = shape(x)[1]
t2 = x[:,range(hl,ncol)]

print x
del(x)
print
print t1
print
# note that xn is a void array 
print xn
print
print t2

# [[ 0.46006547  1.4         1.4         1.4         1.4       ]
#  [ 0.00912908  0.00912908  0.05        0.05        0.05      ]]
# 
# [[ 0.46006547  1.4       ]
#  [ 0.00912908  0.00912908]]
# 
# []
# 
# [[ 1.4   1.4   1.4 ]
#  [ 0.05  0.05  0.05]]

撰写回答