使用Numpy高效计算欧几里得距离矩阵

32 投票
5 回答
61021 浏览
提问于 2025-04-18 00:21

我有一组在二维空间中的点,需要计算每个点到其他点的距离。

我的点的数量相对较少,最多大概100个。但是因为我需要频繁而快速地计算这些移动点之间的关系,所以我知道如果一个个遍历这些点,复杂度可能会达到O(n^2),这让我想找一些方法来利用numpy的矩阵运算(或者scipy)。

在我的代码中,每个对象的坐标存储在它的类里面。不过,我也可以在更新类的坐标时,把它们更新到一个numpy数组里。

class Cell(object):
    """Represents one object in the field."""
    def __init__(self,id,x=0,y=0):
        self.m_id = id
        self.m_x = x
        self.m_y = y

我想到可以创建一个欧几里得距离矩阵来避免重复计算,但也许你有更聪明的数据结构。

我也欢迎推荐一些巧妙的算法。

另外,我注意到有一些类似的问题讨论欧几里得距离和numpy,但没有找到直接解决如何高效填充完整距离矩阵的问题。

5 个回答

4

如果你想要计算效率最高的方法,可以使用SciPy的cdist()(如果你只需要成对距离的向量,而不是完整的距离矩阵,可以用pdist())。正如Tweakimp的评论所说,这种方法比基于向量化和广播的方法快得多,这种方法是RichPauloo和shx2提出的。原因在于,SciPy的cdist()pdist()在底层使用了for循环和一些用C语言实现的计算,这些实现比向量化还要快。

顺便提一下,如果你可以使用SciPy,但还是想用广播的方法,你不需要自己去实现,因为distance_matrix()函数是纯Python实现的,它利用了广播和向量化的特性(源代码, 文档)。

值得一提的是,cdist()/pdist()在内存使用上也更高效,因为它是一个一个地计算距离,避免了创建n*n*d个元素的数组,其中n是点的数量,d是点的维度。

实验

我做了一些简单的实验来比较SciPy的cdist()distance_matrix()和NumPy中的广播实现的性能。我使用了Python的时间模块中的perf_counter_ns()来测量时间,所有结果都是在10000个二维空间点上运行10次的平均值,使用np.float64数据类型(在Python 3.8.10、Windows 10、Ryzen 2700和16 GB RAM上测试):

  • cdist() - 0.6724秒
  • distance_matrix() - 3.0128秒
  • 我的NumPy实现 - 3.6931秒

如果有人想重现实验,这里是代码:

from scipy.spatial import *
import numpy as np
from time import perf_counter_ns


def dist_mat_custom(a, b):
    return np.sqrt(np.sum(np.square(a[:, np.newaxis, :] - b[np.newaxis, :, :]), axis=-1))


results = []
size = 10000
it_num = 10
for i in range(it_num):
    a = np.random.normal(size=(size, 2))
    b = np.random.normal(size=(size, 2))
    start = perf_counter_ns()
    c = distance_matrix(a, b)
    #c = dist_mat_custom(a, b)
    #c = distance.cdist(a, b)
    results.append(perf_counter_ns() - start)
print(np.mean(results) / 1e9)
8

下面是你可以使用numpy来实现的方法:

import numpy as np

x = np.array([0,1,2])
y = np.array([2,4,6])

# take advantage of broadcasting, to make a 2dim array of diffs
dx = x[..., np.newaxis] - x[np.newaxis, ...]
dy = y[..., np.newaxis] - y[np.newaxis, ...]
dx
=> array([[ 0, -1, -2],
          [ 1,  0, -1],
          [ 2,  1,  0]])

# stack in one array, to speed up calculations
d = np.array([dx,dy])
d.shape
=> (2, 3, 3)

现在只需要沿着0轴计算L2范数(就像在这里讨论的那样):

(d**2).sum(axis=0)**0.5
=> array([[ 0.        ,  2.23606798,  4.47213595],
          [ 2.23606798,  0.        ,  2.23606798],
          [ 4.47213595,  2.23606798,  0.        ]])
13

Jake Vanderplas在他的《Python数据科学手册》中给出了一个使用广播的例子,这个例子和@shx2提到的非常相似。

import numpy as np
rand = random.RandomState(42)
X = rand.rand(3, 2)  
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis = -1)

dist_sq
array([[0.        , 0.18543317, 0.81602495],
       [0.18543317, 0.        , 0.22819282],
       [0.81602495, 0.22819282, 0.        ]])
14

如果你不需要完整的距离矩阵,使用kd树会更好。可以考虑使用 scipy.spatial.cKDTreesklearn.neighbors.KDTree。这是因为kd树可以在O(n log n)的时间内找到k个最近邻居,这样你就可以避免计算所有n个点之间距离时的O(n**2)复杂度。

43

你可以利用 complex 类型:

# build a complex array of your cells
z = np.array([complex(c.m_x, c.m_y) for c in cells])

第一种解决方案

# mesh this array so that you will have all combinations
m, n = np.meshgrid(z, z)
# get the distance via the norm
out = abs(m-n)

第二种解决方案

这里的关键是“网格化”。不过 numpy 很聪明,所以你不需要自己生成 mn。只需要用转置后的 z 来计算差值。网格会自动生成:

out = abs(z[..., np.newaxis] - z)

第三种解决方案

如果 z 直接设置为一个二维数组,你可以用 z.T 来代替那种奇怪的 z[..., np.newaxis]。所以最后,你的代码看起来会是这样的:

z = np.array([[complex(c.m_x, c.m_y) for c in cells]]) # notice the [[ ... ]]
out = abs(z.T-z)

示例

>>> z = np.array([[0.+0.j, 2.+1.j, -1.+4.j]])
>>> abs(z.T-z)
array([[ 0.        ,  2.23606798,  4.12310563],
       [ 2.23606798,  0.        ,  4.24264069],
       [ 4.12310563,  4.24264069,  0.        ]])

另外,你可能想在之后去掉重复项,取上三角部分:

>>> np.triu(out)
array([[ 0.        ,  2.23606798,  4.12310563],
       [ 0.        ,  0.        ,  4.24264069],
       [ 0.        ,  0.        ,  0.        ]])

一些基准测试

>>> timeit.timeit('abs(z.T-z)', setup='import numpy as np;z = np.array([[0.+0.j, 2.+1.j, -1.+4.j]])')
4.645645342274779
>>> timeit.timeit('abs(z[..., np.newaxis] - z)', setup='import numpy as np;z = np.array([0.+0.j, 2.+1.j, -1.+4.j])')
5.049334864854522
>>> timeit.timeit('m, n = np.meshgrid(z, z); abs(m-n)', setup='import numpy as np;z = np.array([0.+0.j, 2.+1.j, -1.+4.j])')
22.489568296184686

撰写回答