使用numpy最快获取查找表索引的方法
这个问题是基于另一个问题,主要是为了加速以下代码的运行。我写了一段代码(有帮助的朋友一起完成的),这段代码从一个m x n x 3的numpy.ndarray
(也就是一张RGB图像)中提取像素值,和一个(位置)查找表中的像素值进行比较,然后输出查找表中对应的像素值索引,具体如下:
img = np.random.randint(0, 9 , (3, 3, 3))
lut2 = img[0,0:2,:]
output = []
for x in xrange(lut2.shape[0]):
if lut2[x] in img:
output.append(np.concatenate(np.where( (img == lut2[x]).sum(axis=2) == 3 )))
print np.array(output)
输出结果:
[[0 0]
[0 1]]
问题:如果在img
中有多个地方出现lut[x]
,那么输出结果就不正确了。也许更好的方法是把输出格式化为一个嵌套列表,因为任何可以索引的东西都可以使用。例如:
[[x_px1, y_px1], [[x_px2a, y_px2a], [x_px2b, y_px2b]]]
这里的px2a
和px2b
是img
中具有相同颜色值的不同像素。
我一直在尝试使用numpy
的索引功能,但没有找到更好的方法。虽然上面的代码部分有效,但由于对数组的迭代,它的速度非常慢。如果我给它一个常见大小的图像,完成的时间会非常不可接受。
有人能给我指个更快的解决方案吗?
编辑:我理想中的未来算法总结:
an image all pixels replaced all indices are
(here, single -----> by index in lut -----> reused in digital_scale
black pixel) (here, first pixel) to output an array of scalars
[[[0, 0, 0]]] [[[0, 0]]] [[0]]
array shape: array shape: array shape:
m X n X 3 m X n X 2 m X n
最后的m X n
数组将用于两件事:
- 在不同的时间间隔上绘制数据,通过把数组存储在数据库中。
- 对由于某些外部事件导致的数组变化进行统计分析。
注意:基础图像已经是一种图表,但我无法访问数值数据,所以我需要反向处理这些图表。
一些(可选)关于我代码目的的详细说明:
查找表通常是基于像素值,而不是位置。不过在这种情况下,位置可能更好(我觉得)。上面代码中的lut
数组是一个虚拟的线性颜色尺度,代表我需要数字化的速度,如下所示:
由于每个级别的颜色都是均匀的,这个尺度实际上可以简化为一个1 X height X 3
的数组(适用于RGB、HSV等)。我实际上可以去掉第一维,直接在一个height X 3
的数组上进行迭代。我用来数字化这个尺度的东西有:
- 它是围绕零点对称的
- 它是对称的
- 它是线性的
- 最大值和最小值是已知的
因此,这个尺度的数字表示将是:
digital_scale = np.linspace(-max_speed, max_speed, lut.shape[0])
我只需要在lut
中找到我想要的像素值的y
索引,然后取digital_scale
的第y
个元素,就能输出我需要的值(一个标量,见上面的算法总结)。
1 个回答
1
你可以使用Kd树,这里有一个示例:
import numpy as np
from scipy import spatial
H, W = 200, 100
np.random.seed(1)
a = np.random.randint(0, 20, (H, W, 3))
b = np.random.randint(0, 20, (20, 3))
tree = spatial.cKDTree(a.reshape(-1, 3))
res = tree.query_ball_point(b, 0.5, p=1)
print res
输出结果是:
[[] [577, 17471] [14636, 4515, 13693, 10988, 15013] [16935, 8576, 13286]
[2443] [7743, 5914] [] [7469, 19736, 13395, 14992, 9083, 15514]
[1167, 11416] [3903, 4968] [16504, 2996, 10805, 2264] [] [6725]
[14437, 5888] [17667] [4681, 2545, 6442] [15067, 4533]
[7876, 2235, 10152, 3288] [15404, 5691, 17216]
[15586, 9916, 16938, 15931, 4828, 4069]]
要查看索引2的结果:
rows, cols = np.where(np.all(a == b[None, None, 2], axis=-1))
assert np.all(rows * W + cols == sorted(res[2]))