如何高效计数3D numpy数组中相邻元素
我有一个三维的numpy数组,里面填满了从1到7的整数。
我想统计每个单元格周围邻居单元格中独特元素的数量。比如,在一个二维数组中:
a=[[1,1,1,7,4],
[1,1,1,3,2],
[1,1,1,2,2],
[1,3,1,4,2],
[1,1,1,4,2]]
会得到这样的结果:
[[1,1,2,3,2],
[1,1,2,3,3],
[1,2,2,4,1],
[2,1,3,3,2],
[1,2,2,3,2]]
我现在是逐个检查数组中的每个单元格,然后一个一个查看它的邻居。
temp = np.zeros(6)
if (x>0):
temp[0] = model[x-1,y,z]
if (x<x_len-1):
temp[1] = model[x+1,y,z]
if (y>0):
temp[2] = model[x,y-1,z]
if (y<y_len-1):
temp[3] = model[x,y+1,z]
if (z>0):
temp[4] = model[x,y,z-1]
if (z<z_len-1):
temp[5] = model[x,y,z+1]
result[x,y,z] = np.count_nonzero(np.unique(temp))
我发现这样做速度很慢,也不太高效。有没有更快、更有效的方法来完成这个任务呢?
谢谢。
2 个回答
[[1 2 3]
[2 2 4]
[1 3 3]]
你可以尝试以下方法,这个方法不一定是最优的,如果你的数据量太大可能会出现问题,但可以试试。
import numpy as np
from sklearn.feature_extraction.image import extract_patches
a = np.array([[1,1,1,7,4],
[1,1,1,3,2],
[1,1,1,2,2],
[1,3,1,4,2],
[1,1,1,4,2]])
patches = extract_patches(a, patch_shape=(3, 3), extraction_step=(1, 1))
neighbor_template = np.array([[0, 1, 0],
[1, 0, 1],
[0, 1, 0]]).astype(np.bool)
centers = patches[:, :, 1, 1]
neighbors = patches[:, :, neighbor_template]
possible_values = np.arange(1, 8)
counts = (neighbors[..., np.newaxis] ==
possible_values[np.newaxis, np.newaxis, np.newaxis]).sum(2)
nonzero_counts = counts > 0
unique_counter = nonzero_counts.sum(-1)
print unique_counter
这个结果会给你一个数组的中间部分,这是你期望得到的结果。如果想要得到完整的数组,包括边界部分,就需要单独处理这些边界。使用numpy 1.8,你可以用 np.pad
这个函数,选择 reflect
模式来给边界加上一像素的填充。这样也能正确处理边界。
现在我们来看看三维数据,确保我们不会用太多内存。
# first we generate a neighbors template
from scipy.ndimage import generate_binary_structure
neighbors = generate_binary_structure(3, 1)
neighbors[1, 1, 1] = False
neighbor_coords = np.array(np.where(neighbors)).T
data = np.random.randint(1, 8, (384, 384, 100))
data_neighbors = np.zeros((neighbors.sum(),) + tuple(np.array(data.shape) - 2), dtype=np.uint8)
# extract_patches only generates a strided view
data_view = extract_patches(data, patch_shape=(3, 3, 3), extraction_step=(1, 1, 1))
for neigh_coord, data_neigh in zip(neighbor_coords, data_neighbors):
sl = [slice(None)] * 3 + list(neigh_coord)
data_neigh[:] = data_view[sl]
indicator = (data_neigh[np.newaxis] == possible_values[:, np.newaxis, np.newaxis, np.newaxis]).sum(1) > 0
uniques = indicator.sum(0)
和之前一样,你需要找出 uniques
中唯一的条目数量。使用像 generate_binary_structure
这样的工具来自scipy,以及 extract_patches
中的滑动窗口方法,可以让这个方法变得通用:如果你想要一个26邻域而不是6邻域,只需要把 generate_binary_structure(3, 1)
改成 generate_binary_structure(3, 2)
。这个方法也可以很简单地扩展到更多维度,只要生成的数据量能够适应你电脑的内存。
好吧,可能有一种方法:
- 创建6个偏移数组(左、右、上、下、前、后)
- 把这些数组组合成一个四维数组,大小是(R-2, C-2, D-2, 6)
- 根据最后一个维度(大小为6的维度)对这个四维数组进行排序
现在你有了一个四维数组,可以为每个单元格选择一个排序好的邻居向量。接下来,你可以通过以下方式计算不同的邻居数量:
- 对第四个轴(排序后的数组)使用
diff
- 计算第四个轴上非零差值的总和
这样你就能得到不同邻居的数量减去1。
第一部分可能比较清楚。如果一个单元格的邻居是(1, 2, 4, 2, 2, 3),那么邻居向量会排序成(1, 2, 2, 2, 3, 4)。差值向量则是(1, 0, 0, 1, 1),非零元素的总和((diff(v) != 0).sum(axis=4)
)是3。所以,有4个独特的邻居。
当然,这种方法没有考虑到边缘。你可以通过使用 numpy.pad
和 reflect
模式,将初始数组在每个方向上填充1个单元来解决这个问题。(这种模式实际上是唯一一种保证不会在邻域中引入新值的方法,试试用二维数组来理解为什么。)
例如:
import numpy as np
# create some fictional data
dat = np.random.randint(1, 8, (6, 7, 8))
# pad the data by 1
datp = np.pad(dat, 1, mode='reflect')
# create the neighbouring 4D array
neigh = np.concatenate((
datp[2:,1:-1,1:-1,None], datp[:-2,1:-1,1:-1,None],
datp[1:-1,2:,1:-1,None], datp[1:-1,:-2,1:-1,None],
datp[1:-1,1:-1,2:,None], datp[1:-1,1:-1,:-2,None]), axis=3)
# sort the 4D array
neigh.sort(axis=3)
# calculate the number of unique samples
usamples = (diff(neigh, axis=3) != 0).sum(axis=3) + 1
上面的解决方案相当通用,适用于任何可以排序的东西。不过,它消耗了很多内存(6个数组的副本),而且性能不是很高。如果我们只满足于一个只适用于这种特殊情况(值是非常小的整数)的解决方案,我们可以做一些位运算的魔法。
- 创建一个数组,每个项目用位掩码表示(1 = 00000001,2 = 00000010,3 = 00000100,等等)
- 把邻居数组进行按位或操作
- 使用查找表计算按位或结果中的位数
.
import numpy as np
# create a "number of ones" lookup table
no_ones = np.array([bin(i).count("1") for i in range(256)], dtype='uint8')
# create some fictional data
dat = np.random.randint(1, 8, (6, 7, 8))
# create a bit mask of the cells
datb = 1 << dat.astype('uint8')
# pad the data by 1
datb = np.pad(datb, 1, mode='reflect')
# or the padded data together
ored = (datb[ 2:, 1:-1, 1:-1] |
datb[:-2, 1:-1, 1:-1] |
datb[1:-1, 2:, 1:-1] |
datb[1:-1, :-2, 1:-1] |
datb[1:-1, 1:-1, 2:] |
datb[1:-1, 1:-1, :-2])
# get the number of neighbours from the LUT
usamples = no_ones[ored]
性能影响相当明显。第一种版本在我这台机器上处理一个384 x 384 x 100的表格需要2.57秒,而第二种版本只需283毫秒(不包括创建随机数据的时间)。这分别转化为每个单元19纳秒和174纳秒。
不过,这个解决方案仅限于不同值数量合理且已知的情况。如果不同可能值的数量超过64,位运算的魔法就失去了效果。(另外,当不同值达到20个左右时,查找部分必须分成多个操作,因为查找表的内存消耗。查找表应该适合CPU缓存,否则会变慢。)
另一方面,扩展解决方案以使用完整的26个邻居是简单且相当快速的。