如何在numpy 2D数组中选择唯一元素的所有位置并画出边界框?

3 投票
1 回答
1660 浏览
提问于 2025-04-17 04:02

我有一个二维的numpy数组,我想找到所有独特元素的每一个位置。我们可以用 numpy.unique(numpyarray.) 来找到这些独特的元素。接下来就有点复杂了。我需要知道每个独特元素的所有位置。我们来看一个例子。

array([[1, 1, 2, 2],\
       [1, 1, 2, 2],\
       [3, 3, 4, 4],\
       [3, 3, 4, 4]])

结果应该是

1, (0,0),(1,1)
2, (0,2),(1,2)
3, (2,0),(3,1)
4, (2,2),(3,3)

那么该怎么做呢?有什么合适的方法来存储和遍历这些值呢?

需要注意的是,所有独特的值会彼此相邻。它们之间唯一的空隙可能只有零。我们再考虑一个变体。

 array([[1, 0, 1, 2, 2],\
        [1, 0, 1, 2, 2],\
        [3, 0, 3, 4, 4],\
        [3, 0, 3, 4, 4]])

结果应该是

1, (0,0),(1,2)
2, (0,3),(1,4)
3, (2,0),(3,2)
4, (2,3),(3,4)

边界上的零可以忽略不计。

非常感谢!

1 个回答

3

最简单、最直接的方法就是使用 numpy.where

比如说,如果你只想要一个 边界框

import numpy as np

x = np.array([[1,1,2,2],
              [1,1,2,2],
              [3,3,4,4],
              [3,3,4,4]])

for val in np.unique(x):
    rows, cols = np.where(x == val)
    rowstart, rowstop = np.min(rows), np.max(rows)
    colstart, colstop = np.min(cols), np.max(cols)
    print val, (rowstart, colstart), (rowstop, colstop) 

这个方法在处理全是零的例子时也能用。

如果你的数组很大,而且你已经有 scipy 这个库了,可以考虑使用 scipy.ndimage.find_objects,就像 @unutbu 提到的那样。

在你这个例子的特殊情况下,假如你的唯一值是连续的整数,你可以直接使用 find_objects。它需要一个数组,其中每个非零的连续整数代表一个对象,它会返回这个对象的边界框。(0会被忽略,正好符合你的需求。)不过一般来说,你可能需要先处理一下,把任意的唯一值转换成连续的整数。

find_objects 会返回一个包含 slice 对象的元组列表。说实话,这些可能正是你想要的,如果你想要边界框的话。不过,打印出开始和结束的索引时,可能会显得有点乱。

import numpy as np
import scipy.ndimage as ndimage

x = np.array([[1, 0, 1, 2, 2],
              [1, 0, 1, 2, 2],
              [3, 0, 3, 4, 4],
              [3, 0, 3, 4, 4]])

for i, item in enumerate(ndimage.find_objects(x), start=1):
    print i, item

这看起来可能和你预期的稍有不同。这些是 slice 对象,所以“最大值”总是比之前例子中的“最大值”大一。这是为了让你可以直接用这个元组来切片,获取你需要的数据。

比如:

for i, item in enumerate(ndimage.find_objects(x), start=1):
    print i, ':'
    print x[item], '\n'

如果你真的想要开始和结束的索引,可以这样做:

    for i, (rowslice, colslice) in enumerate(ndimage.find_objects(x), start=1):
        print i, 
        print (rowslice.start, rowslice.stop - 1),
        print (colslice.start, colslice.stop - 1)

如果你的唯一值不是连续的整数,你就需要做一些预处理,正如我之前提到的。你可以这样做:

import numpy as np
import scipy.ndimage as ndimage

x = np.array([[1.1, 0.0, 1.1, 0.9, 0.9],
              [1.1, 0.0, 1.1, 0.9, 0.9],
              [3.3, 0.0, 3.3, 4.4, 4.4],
              [3.3, 0.0, 3.3, 4.4, 4.4]])
ignored_val = 0.0
labels = np.zeros(data.shape, dtype=np.int)

i = 1
for val in np.unique(x):
    if val != ignored_val:
        labels[x == val] = i
        i += 1

# Now we can use the "labels" array as input to find_objects
for i, item in enumerate(ndimage.find_objects(labels), start=1):
    print i, ':'
    print x[item], '\n'

撰写回答