如何高效计数3D numpy数组中相邻元素

2 投票
2 回答
2498 浏览
提问于 2025-04-18 16:18

我有一个三维的numpy数组,里面填满了从1到7的整数。
我想统计每个单元格周围邻居单元格中独特元素的数量。比如,在一个二维数组中:

a=[[1,1,1,7,4],
   [1,1,1,3,2],
   [1,1,1,2,2],
   [1,3,1,4,2],
   [1,1,1,4,2]]  

会得到这样的结果:

[[1,1,2,3,2],
 [1,1,2,3,3],
 [1,2,2,4,1],
 [2,1,3,3,2],
 [1,2,2,3,2]]  

我现在是逐个检查数组中的每个单元格,然后一个一个查看它的邻居。

temp = np.zeros(6)
if (x>0):
    temp[0] = model[x-1,y,z]
if (x<x_len-1):
    temp[1] = model[x+1,y,z]
if (y>0):
    temp[2] = model[x,y-1,z]
if (y<y_len-1):
    temp[3] = model[x,y+1,z]
if (z>0):
    temp[4] = model[x,y,z-1]
if (z<z_len-1):
    temp[5] = model[x,y,z+1]
result[x,y,z] = np.count_nonzero(np.unique(temp))  

我发现这样做速度很慢,也不太高效。有没有更快、更有效的方法来完成这个任务呢?

谢谢。

2 个回答

1
[[1 2 3]
 [2 2 4]
 [1 3 3]]

你可以尝试以下方法,这个方法不一定是最优的,如果你的数据量太大可能会出现问题,但可以试试。

import numpy as np
from sklearn.feature_extraction.image import extract_patches

a = np.array([[1,1,1,7,4],
              [1,1,1,3,2],
              [1,1,1,2,2],
              [1,3,1,4,2],
              [1,1,1,4,2]])

patches = extract_patches(a, patch_shape=(3, 3), extraction_step=(1, 1))

neighbor_template = np.array([[0, 1, 0],
                              [1, 0, 1],
                              [0, 1, 0]]).astype(np.bool)
centers = patches[:, :, 1, 1]
neighbors = patches[:, :, neighbor_template]

possible_values = np.arange(1, 8)
counts = (neighbors[..., np.newaxis] ==
          possible_values[np.newaxis, np.newaxis, np.newaxis]).sum(2)

nonzero_counts = counts > 0
unique_counter = nonzero_counts.sum(-1)

print unique_counter

这个结果会给你一个数组的中间部分,这是你期望得到的结果。如果想要得到完整的数组,包括边界部分,就需要单独处理这些边界。使用numpy 1.8,你可以用 np.pad 这个函数,选择 reflect 模式来给边界加上一像素的填充。这样也能正确处理边界。

现在我们来看看三维数据,确保我们不会用太多内存。

# first we generate a neighbors template
from scipy.ndimage import generate_binary_structure

neighbors = generate_binary_structure(3, 1)
neighbors[1, 1, 1] = False
neighbor_coords = np.array(np.where(neighbors)).T

data = np.random.randint(1, 8, (384, 384, 100))
data_neighbors = np.zeros((neighbors.sum(),) + tuple(np.array(data.shape) - 2), dtype=np.uint8)

# extract_patches only generates a strided view
data_view = extract_patches(data, patch_shape=(3, 3, 3), extraction_step=(1, 1, 1))

for neigh_coord, data_neigh in zip(neighbor_coords, data_neighbors):
    sl = [slice(None)] * 3 + list(neigh_coord)
    data_neigh[:] = data_view[sl]

indicator = (data_neigh[np.newaxis] == possible_values[:, np.newaxis, np.newaxis, np.newaxis]).sum(1) > 0

uniques = indicator.sum(0)

和之前一样,你需要找出 uniques 中唯一的条目数量。使用像 generate_binary_structure 这样的工具来自scipy,以及 extract_patches 中的滑动窗口方法,可以让这个方法变得通用:如果你想要一个26邻域而不是6邻域,只需要把 generate_binary_structure(3, 1) 改成 generate_binary_structure(3, 2)。这个方法也可以很简单地扩展到更多维度,只要生成的数据量能够适应你电脑的内存。

2

好吧,可能有一种方法:

  • 创建6个偏移数组(左、右、上、下、前、后)
  • 把这些数组组合成一个四维数组,大小是(R-2, C-2, D-2, 6)
  • 根据最后一个维度(大小为6的维度)对这个四维数组进行排序

现在你有了一个四维数组,可以为每个单元格选择一个排序好的邻居向量。接下来,你可以通过以下方式计算不同的邻居数量:

  • 对第四个轴(排序后的数组)使用 diff
  • 计算第四个轴上非零差值的总和

这样你就能得到不同邻居的数量减去1。

第一部分可能比较清楚。如果一个单元格的邻居是(1, 2, 4, 2, 2, 3),那么邻居向量会排序成(1, 2, 2, 2, 3, 4)。差值向量则是(1, 0, 0, 1, 1),非零元素的总和((diff(v) != 0).sum(axis=4))是3。所以,有4个独特的邻居。

当然,这种方法没有考虑到边缘。你可以通过使用 numpy.padreflect 模式,将初始数组在每个方向上填充1个单元来解决这个问题。(这种模式实际上是唯一一种保证不会在邻域中引入新值的方法,试试用二维数组来理解为什么。)

例如:

import numpy as np

# create some fictional data
dat = np.random.randint(1, 8, (6, 7, 8))

# pad the data by 1
datp = np.pad(dat, 1, mode='reflect')

# create the neighbouring 4D array
neigh = np.concatenate((
    datp[2:,1:-1,1:-1,None], datp[:-2,1:-1,1:-1,None], 
    datp[1:-1,2:,1:-1,None], datp[1:-1,:-2,1:-1,None],
    datp[1:-1,1:-1,2:,None], datp[1:-1,1:-1,:-2,None]), axis=3)

# sort the 4D array
neigh.sort(axis=3)

# calculate the number of unique samples
usamples = (diff(neigh, axis=3) != 0).sum(axis=3) + 1

上面的解决方案相当通用,适用于任何可以排序的东西。不过,它消耗了很多内存(6个数组的副本),而且性能不是很高。如果我们只满足于一个只适用于这种特殊情况(值是非常小的整数)的解决方案,我们可以做一些位运算的魔法。

  • 创建一个数组,每个项目用位掩码表示(1 = 00000001,2 = 00000010,3 = 00000100,等等)
  • 把邻居数组进行按位或操作
  • 使用查找表计算按位或结果中的位数

.

import numpy as np

# create a "number of ones" lookup table
no_ones = np.array([bin(i).count("1") for i in range(256)], dtype='uint8')

# create some fictional data
dat = np.random.randint(1, 8, (6, 7, 8))

# create a bit mask of the cells
datb = 1 << dat.astype('uint8')

# pad the data by 1
datb = np.pad(datb, 1, mode='reflect')

# or the padded data together
ored = (datb[ 2:, 1:-1, 1:-1] |
        datb[:-2, 1:-1, 1:-1] |
        datb[1:-1,  2:, 1:-1] |
        datb[1:-1, :-2, 1:-1] |
        datb[1:-1, 1:-1,  2:] |
        datb[1:-1, 1:-1, :-2])

# get the number of neighbours from the LUT
usamples = no_ones[ored]

性能影响相当明显。第一种版本在我这台机器上处理一个384 x 384 x 100的表格需要2.57秒,而第二种版本只需283毫秒(不包括创建随机数据的时间)。这分别转化为每个单元19纳秒和174纳秒。

不过,这个解决方案仅限于不同值数量合理且已知的情况。如果不同可能值的数量超过64,位运算的魔法就失去了效果。(另外,当不同值达到20个左右时,查找部分必须分成多个操作,因为查找表的内存消耗。查找表应该适合CPU缓存,否则会变慢。)

另一方面,扩展解决方案以使用完整的26个邻居是简单且相当快速的。

撰写回答