Numpy:检查多维数组中元素是否在元组内

3 投票
4 回答
3405 浏览
提问于 2025-04-17 05:43

看起来我还是对 numpy中的“in”运算符 有点困惑。情况是这样的:

>>> a = np.random.randint(1, 10, (2, 2, 3))
>>> a
array([[[9, 8, 8],
        [4, 9, 1]],

       [[6, 6, 3],
        [9, 3, 5]]])

我想找出那些三元组中,第二个元素在 (6, 8) 这个范围内的索引。我直观上尝试的方法是:

>>> a[:, :, 1] in (6, 8)
ValueError: The truth value of an array with more than one element...

我最终的目标是把那些位置插入前面那个数字乘以二的结果。 根据上面的例子,a 应该变成:

array([[[9, 18, 8],   #8 @ pos #2 --> replaced by 9 @ pos #1 by 2
        [4, 9, 1]],

       [[6, 12, 3],   #6 @ pos #2 --> replaced by 6 @ pos #1 by 2
        [9, 3, 5]]])

提前谢谢你的建议和时间!

4 个回答

1

还有一种方法是使用查找表,这个方法是我从Cellprofiler的一个开发者那里学到的。首先,你需要创建一个查找表(LUT),它的大小要和你数组中最大的数字一样。对于数组中每一个可能的值,查找表里要么是True,要么是False。

# create a large volume image with random numbers
a = np.random.randint(1, 1000, (50, 1000 , 1000))
labels_to_find=np.unique(np.random.randint(1,1000,500))

# create filter mask LUT 
def find_mask_LUT(inputarr, obs):
    keep = np.zeros(np.max(inputarr)+1, bool)
    keep[np.array(obs)] = True
    return keep[inputarr]

# This will return a mask that is the 
# same shape as a, with True is a is one of the 
# labels we look for, False otherwise
find_mask_LUT(a, labels_to_find)

这个方法运行得非常快(比np.in1d快得多,而且速度和对象的数量无关)。

2

这里有一个方法,可以处理任意长度的元组。它使用了 numpy.in1d 这个函数。

import numpy as np
np.random.seed(1)

a = np.random.randint(1, 10, (2, 2, 3))
print(a)

check_tuple = (6, 9, 1)

bool_array = np.in1d(a[:,:,1], check_tuple)
ind = np.where(bool_array)[0]
a0 = a[:,:,0].reshape((len(bool_array), ))
a1 = a[:,:,1].reshape((len(bool_array), ))
a1[ind] = a0[ind] * 2

print(a)

这是输出结果:

[[[6 9 6]
  [1 1 2]]

 [[8 7 3]
  [5 6 3]]]

[[[ 6 12  6]
  [ 1  2  2]]

 [[ 8  7  3]
  [ 5 10  3]]]
1
import numpy as np
a = np.array([[[9, 8, 8],
               [4, 9, 1]],

              [[6, 6, 3],
               [9, 3, 5]]])

ind=(a[:,:,1]<=8) & (a[:,:,1]>=6)
a[ind,1]=a[ind,0]*2
print(a)

产生

[[[ 9 18  8]
  [ 4  9  1]]

 [[ 6 12  3]
  [ 9  3  5]]]

如果你想检查一个集合中是否包含某个元素,而这个集合不是简单的范围,那么我觉得可以参考mac的想法,使用Python循环,或者bellamyj的想法,使用np.in1d。哪种方法更快,取决于check_tuple的大小:

test.py:

import numpy as np
np.random.seed(1)

N = 10
a = np.random.randint(1, 1000, (2, 2, 3))
check_tuple = np.random.randint(1, 1000, N)

def using_in1d(a):
    idx = np.in1d(a[:,:,1], check_tuple)
    idx=idx.reshape(a[:,:,1].shape)
    a[idx,1] = a[idx,0] * 2
    return a

def using_in(a):
    idx = np.zeros(a[:,:,0].shape,dtype=bool)
    for n in check_tuple:
        idx |= a[:,:,1]==n
    a[idx,1] = a[idx,0]*2
    return a

assert np.allclose(using_in1d(a),using_in(a))    

当N = 10时,using_in稍微快一点:

% python -m timeit -s'import test' 'test.using_in1d(test.a)'
10000 loops, best of 3: 156 usec per loop
% python -m timeit -s'import test' 'test.using_in(test.a)'
10000 loops, best of 3: 143 usec per loop

当N = 100时,using_in1d快得多:

% python -m timeit -s'import test' 'test.using_in1d(test.a)'
10000 loops, best of 3: 171 usec per loop
% python -m timeit -s'import test' 'test.using_in(test.a)'
1000 loops, best of 3: 1.15 msec per loop

撰写回答