高效选择numpy数组的子部分
我想把一个叫做 x
的 numpy 数组分成三个不同的数组,这个分割是基于一些逻辑比较。这个数组的形状如下,但里面的内容会有所不同:(为了回应 Saullo Castro 的评论,我这里用了一个稍微不同的数组 x。)
array([[ 0.46006547, 0.5580928 , 0.70164242, 0.84519205, 1.4 ],
[ 0.00912908, 0.00912908, 0.05 , 0.05 , 0.05 ]])
这个数组的值在每一列中是单调递增的。我还有两个其他的数组,分别叫 lowest_gridpoints
和 highest_gridpoints
。这两个数组的内容也会变化,但它们的形状总是和下面这个一样:
array([ 0.633, 0.01 ]), array([ 1.325, 0.99 ])
我想要的选择过程如下:
- 所有包含小于
lowest_gridpoints
中任何值的列都应该从x
中移除,并组成数组temp1
。 - 所有包含大于
highest_gridpoints
中任何值的列也应该从x
中移除,并组成数组temp2
。 - 所有既不在
temp1
也不在temp2
中的x
的列组成新的数组x_new
。
我写的以下代码可以完成这个任务。
if np.any( x[:,-1] > highest_gridpoints ) or np.any( x[:,0] < lowest_gridpoints ):
for idx, sample, in enumerate(x.T):
if np.any( sample > highest_gridpoints):
max_idx = idx
break
elif np.any( sample < lowest_gridpoints ):
min_idx = idx
temp1, temp2 = np.array([[],[]]), np.array([[],[]])
if 'min_idx' in locals():
temp1 = x[:,0:min_idx+1]
if 'max_idx' in locals():
temp2 = x[:,max_idx:]
if 'min_idx' in locals() or 'max_idx' in locals():
if 'min_idx' not in locals():
min_idx = -1
if 'max_idx' not in locals():
max_idx = x.shape[1]
x_new = x[:,min_idx+1:max_idx]
不过,我怀疑这段代码效率很低,因为用了很多循环。此外,我觉得代码的写法有点冗长。
有没有人能提供一个更高效或者更简洁的代码来完成上述任务呢?
1 个回答
1
只有你问题的第一部分
from numpy import *
x = array([[ 0.46006547, 0.5580928 , 0.70164242, 0.84519205, 1.4 ],
[ 0.00912908, 0.00912908, 0.05 , 0.05 , 0.05 ]])
low, high = array([ 0.633, 0.01 ]), array([ 1.325, 0.99 ])
# construct an array of two rows of bools expressing your conditions
indices1 = array((x[0,:]<low[0], x[1,:]<low[1]))
print indices1
# do an or of the values along the first axis
indices1 = any(indices1, axis=0)
# now it's a single row array
print indices1
# use the indices1 to extract what you want,
# the double transposition because the elements
# of a 2d array are the rows
tmp1 = x.T[indices1].T
print tmp1
# [[ True True False False False]
# [ True True False False False]]
# [ True True False False False]
# [[ 0.46006547 0.5580928 ]
# [ 0.00912908 0.00912908]]
接下来类似地构造 indices2
和 tmp2
,剩余部分的索引是前两个索引的“或”操作的否定。(也就是说,numpy.logical_not(numpy.logical_or(i1,i2))
)。
补充说明
另一种方法,如果你有成千上万的条目,可能会更快,涉及到 numpy.searchsorted
from numpy import *
x = array([[ 0.46006547, 0.5580928 , 0.70164242, 0.84519205, 1.4 ],
[ 0.00912908, 0.00912908, 0.05 , 0.05 , 0.05 ]])
low, high = array([ 0.633, 0.01 ]), array([ 1.325, 0.99 ])
l0r = searchsorted(x[0,:], low[0], side='right')
l1r = searchsorted(x[1,:], low[1], side='right')
h0l = searchsorted(x[0,:], high[0], side='left')
h1l = searchsorted(x[1,:], high[1], side='left')
lr = max(l0r, l1r)
hl = min(h0l, h1l)
print lr, hl
print x[:,:lr]
print x[:,lr:hl]
print x[:,hl]
# 2 4
# [[ 0.46006547 0.5580928 ]
# [ 0.00912908 0.00912908]]
# [[ 0.70164242 0.84519205]
# [ 0.05 0.05 ]]
# [ 1.4 0.05]
排除重叠可以通过 hl = max(lr, hl)
来实现。注意,在之前的方法中,数组切片被复制到新的对象中,而在这里你得到的是 x
的视图,如果你想要新的对象,你必须明确说明。
编辑 一个不必要的优化
如果我们在第二对 sortedsearch
中只使用 x
的上半部分(如果你看看代码就会明白我是什么意思……),我们会得到两个好处,1)搜索速度稍微加快(sortedsearch
总是足够快),2)重叠的情况会自动处理。
作为额外内容,提供了将 x
的片段复制到新数组的代码。 注意 x
被更改以强制重叠
from numpy import *
# I changed x to force overlap
x = array([[ 0.46006547, 1.4 , 1.4, 1.4, 1.4 ],
[ 0.00912908, 0.00912908, 0.05, 0.05, 0.05 ]])
low, high = array([ 0.633, 0.01 ]), array([ 1.325, 0.99 ])
l0r = searchsorted(x[0,:], low[0], side='right')
l1r = searchsorted(x[1,:], low[1], side='right')
lr = max(l0r, l1r)
h0l = searchsorted(x[0,lr:], high[0], side='left')
h1l = searchsorted(x[1,lr:], high[1], side='left')
hl = min(h0l, h1l) + lr
t1 = x[:,range(lr)]
xn = x[:,range(lr,hl)]
ncol = shape(x)[1]
t2 = x[:,range(hl,ncol)]
print x
del(x)
print
print t1
print
# note that xn is a void array
print xn
print
print t2
# [[ 0.46006547 1.4 1.4 1.4 1.4 ]
# [ 0.00912908 0.00912908 0.05 0.05 0.05 ]]
#
# [[ 0.46006547 1.4 ]
# [ 0.00912908 0.00912908]]
#
# []
#
# [[ 1.4 1.4 1.4 ]
# [ 0.05 0.05 0.05]]