Python如何提高numpy数组的性能?

2024-03-29 12:37:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个全球数字阵列数据这是一个200*200*3的三维数组,在三维空间中包含40000个点。在

我的目标是计算每个点到单位立方体四个角点的距离((0,0,0),(1,0,0),(0,1,0),(0,0,1)),这样我就可以确定哪个角点离它最近。在

def dist(*point):
    return np.linalg.norm(data - np.array(rgb), axis=2)

buffer = np.stack([dist(0, 0, 0), dist(1, 0, 0), dist(0, 1, 0), dist(0, 0, 1)]).argmin(axis=0)

我写了这段代码并对其进行了测试,每次运行大约花费10毫秒。 我的问题是如何提高这段代码的性能,在不到1ms的时间内更好地运行。在


Tags: 数据代码距离目标distdefnp单位
1条回答
网友
1楼 · 发布于 2024-03-29 12:37:08

您可以使用^{}-

# unit cube coordinates as array
uc = np.array([[0, 0, 0],[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# buffer output
buf = cdist(data.reshape(-1,3), uc).argmin(1).reshape(data.shape[0],-1)

运行时测试

^{pr2}$

时间安排-

In [170]: data = np.random.rand(200,200,3)

In [171]: %timeit org_app()
100 loops, best of 3: 4.24 ms per loop

In [172]: %timeit cdist(data.reshape(-1,3), uc).argmin(1).reshape(data.shape[0],-1)
1000 loops, best of 3: 1.25 ms per loop

相关问题 更多 >