高效计算numpy数组中三点间的角度

2 投票
1 回答
86 浏览
提问于 2025-04-14 17:49

假设我们有一个大小为 M x N 的 numpy 数组 A,我们可以把它看作是 M 个维度为 N 的向量。对于三个向量 a,b,c,我们想要计算它们之间形成的角度的余弦值:

cos(angle(a,b,c)) = np.dot((a-b)/norm(a-b), (c-b)/norm(c-b))

我们想要计算这个值,针对 A 中的三元组,应该有 (M choose 2)*(M-2)独特 的三元组(因为 ac 是对称的;如果我说错了请纠正我)。当然,我们可以用三重嵌套的循环来完成这个任务,但我希望能用一种向量化的方式来实现。我很确定可以使用一些广播技巧来计算一个包含所需输出和更多内容的数组,但我希望有人能提供一个方法,能够准确计算出 独特 的值,最好是没有额外的计算。谢谢。

编辑。 为了完整性,使用循环的简单实现:

angles = []
for i in range(len(A)):
    for j in range(len(A)):
        for k in range(i+1, len(A)):
            if j not in (i,k):
                d1 = A[i] - A[j]
                d2 = A[k] - A[j]
                ang = np.dot(d1/np.linalg.norm(d1), d2/np.linalg.norm(d2))
                angles.append(ang)

1 个回答

4

应该有 (M 选 2)*(M-2) 个独特的三元组(因为 a 和 c 的对称性;如果我说错了,请纠正我)

我觉得没错。我计算了 M * ((M-1) 选 2),这也是等价的。

我希望有人能提供一个方法,准确计算出独特的数量,最好是没有额外的计算。

好吧,我们先从简单的开始 - 将你的循环向量化,假设我们已经生成了索引数组 ijk

def cosang1(A, i, j, k):
    d1 = A[i] - A[j]
    d2 = A[k] - A[j]
    d1_hat = np.linalg.norm(d1, axis=1, keepdims=True)
    d2_hat = np.linalg.norm(d2, axis=1, keepdims=True)
    # someone will almost certainly suggest a better way to do this
    ang = np.einsum("ij, ij -> i", d1/d1_hat, d2/d2_hat)
    return ang

这将问题简化为计算索引数组,假设索引数组的计算时间占总计算时间的比例很小。我看不出有什么办法可以避免冗余计算,而不做这种处理。

然后,如果我们愿意允许冗余计算,生成索引的最简单方法就是使用 np.meshgrid

def cosang2(A):
    i = np.arange(len(A))
    i, j, k = np.meshgrid(i, i, i)
    i, j, k = i.ravel(), j.ravel(), k.ravel()
    return cosang1(A, i, j, k)

在 Colab 上,对于形状为 (30, 3) 的 A,使用 Python 循环的方法花了 160 毫秒,而这个解决方案只花了 7 毫秒。

如果我们可以使用 Numba,快速生成独特的索引集合非常简单。这基本上就是将你的代码拆分成一个函数:

from numba import jit
# generate unique tuples of indices of vectors
@jit(nopython=True)
def get_ijks(M):
    ijks = []
    for i in range(M):
        for j in range(M):
            for k in range(i+1, M):
                if j not in (i, k):
                    ijks.append((i, j, k))
    return ijks

(当然,我们也可以在你的整个循环上使用 Numba。)

这比冗余的向量化解决方案花费的时间少了一半。

使用纯 NumPy 高效生成索引可能是可行的。最开始,我以为这会简单得多:

i = np.arange(M)
j, k = np.triu_indices(M, 1)
i, j, k = np.broadcast_arrays(i, j[:, None], k[:, None])
i, j, k = i.ravel(), j.ravel(), k.ravel()

这并不完全正确,但可以从这里开始,并通过对 range(M) 的循环来修正这些索引(总比三重嵌套要好!)。像这样:

# generates the same set of indices as `get_ijks`,
# but currently a bit slower.
def get_ijks2(M):
    i = np.arange(M)
    j, k = np.triu_indices(M-1, 1)
    i, j, k = np.broadcast_arrays(i[:, None], j, k)
    i, j, k = i.ravel(), j.ravel(), k.ravel()
    for ii in range(M):
        # this can be improved by using slices
        # instead of masks where possible
        mask0 = i == ii
        mask1 = (j >= ii) & mask0
        mask2 = (k >= ii) & mask0
        j[mask1] += 1
        k[mask1 | mask2] += 1
    return j, i, k  # intentionally swapped due to the way I think about this

~~我觉得可以通过只使用切片而不使用掩码来加速这个过程,但今晚我做不到。~~

更新:正如评论中所指出的,最后一个循环并不是必要的!

def get_ijks3(M):
    i = np.arange(M)
    j, k = np.triu_indices(M-1, 1)
    i, j, k = np.broadcast_arrays(i[:, None], j, k)
    i, j, k = i.ravel(), j.ravel(), k.ravel()
    mask1 = (j >= i)
    mask2 = (k >= i)
    j[mask1] += 1
    k[mask1 | mask2] += 1
    return j, i, k  # intentionally swapped

这比 Numba 循环快了很多。实际上,我很惊讶这能成功!


所有代码在一起,以防你想运行它:

from numba import jit
import numpy as np
rng = np.random.default_rng(23942342)

M = 30
N = 3
A = rng.random((M, N))

# generate unique tuples of indices of vectors
@jit(nopython=True)
def get_ijks(M):
    ijks = []
    for i in range(M):
        for j in range(M):
            for k in range(i+1, M):
                if j not in (i, k):
                    ijks.append((i, j, k))
    return ijks

# attempt to generate the same integers efficiently
# without Numba
def get_ijks2(M):
    i = np.arange(M)
    j, k = np.triu_indices(M-1, 1)
    i, j, k = np.broadcast_arrays(i[:, None], j, k)
    i, j, k = i.ravel(), j.ravel(), k.ravel()
    for ii in range(M):
        # this probably doesn't need masks
        mask0 = i == ii
        mask1 = (j >= ii) & mask0
        mask2 = (k >= ii) & mask0
        j[mask1] += 1
        k[mask1 | mask2] += 1
    return j, i, k  # intentionally swapped due to the way I think about this

# proposed method 
def cosang1(A, i, j, k):
    d1 = A[i] - A[j]
    d2 = A[k] - A[j]
    d1_hat = np.linalg.norm(d1, axis=1, keepdims=True)
    d2_hat = np.linalg.norm(d2, axis=1, keepdims=True)
    ang = np.einsum("ij, ij -> i", d1/d1_hat, d2/d2_hat)
    return ang

# another naive implementation
def cosang2(A):
    i = np.arange(len(A))
    i, j, k = np.meshgrid(i, i, i)
    i, j, k = i.ravel(), j.ravel(), k.ravel()
    return cosang1(A, i, j, k)

# naive implementation provided by OP
def cosang0(A):
    angles = []
    for i in range(len(A)):
        for j in range(len(A)):
            for k in range(i+1, len(A)):
                if j not in (i,k):
                    d1 = A[i] - A[j]
                    d2 = A[k] - A[j]
                    ang = np.dot(d1/np.linalg.norm(d1), d2/np.linalg.norm(d2))
                    angles.append(ang)
    return angles

%timeit cosang0(A)

%timeit get_ijks(len(A))
ijks = np.asarray(get_ijks(M)).T
%timeit cosang1(A, *ijks)

%timeit cosang2(A)

# 180 ms ± 34.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 840 µs ± 68.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 2.19 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# <ipython-input-1-d2a3835710f2>:26: RuntimeWarning: invalid value encountered in divide
#   ang = np.einsum("ij, ij -> i", d1/d1_hat, d2/d2_hat)
# 8.13 ms ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

cosangs0 = cosang0(A)
cosangs1 = cosang1(A, *ijks)
cosangs2 = cosang2(A)
np.testing.assert_allclose(cosangs1, cosangs0)  # passes

%timeit get_ijks2(M)
# 1.73 ms ± 242 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
i, j, k = get_ijks2(M)
cosangs3 = cosang1(A, i, j, k)
np.testing.assert_allclose(np.sort(cosangs3), np.sort(cosangs0))  # passes

%timeit get_ijks3(M)
# 184 µs ± 25.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
i, j, k = get_ijks3(M)
cosangs4 = cosang1(A, i, j, k)
np.testing.assert_allclose(np.sort(cosangs4), np.sort(cosangs0))  # passes

撰写回答