有没有只计算结果对角线元素的numpy/scipy点积?

59 投票
3 回答
17396 浏览
提问于 2025-04-17 15:08

想象一下你有两个numpy数组:

> A, A.shape = (n,p)
> B, B.shape = (p,p)

通常情况下,p是一个比较小的数字(p <= 200),而n可以非常大。

我正在做以下操作:

result = np.diag(A.dot(B).dot(A.T))

如你所见,我只保留了n个对角线上的元素,但在这个过程中计算了一个(n x n)的数组,然后只从中保留了对角线上的元素。

我希望有一个像diag_dot()这样的函数,它只计算结果的对角线元素,而不需要分配完整的内存。

一个结果可能是:

> result = diag_dot(A.dot(B), A.T)

有没有现成的功能可以做到这一点,并且能有效地完成,而不需要分配这个中间的(n x n)数组?

3 个回答

2

一个简单的回答,它避免了构建大型中间数组的方法是:

result=np.empty([n,], dtype=A.dtype )
for i in xrange(n):
    result[i]=A[i,:].dot(B).dot(A[i,:])
42

你可以通过numpy.einsum实现几乎所有你曾经梦想过的功能。刚开始接触的时候,它看起来就像是黑魔法一样,让人摸不着头脑...

>>> a = np.arange(15).reshape(5, 3)
>>> b = np.arange(9).reshape(3, 3)

>>> np.diag(np.dot(np.dot(a, b), a.T))
array([  60,  672, 1932, 3840, 6396])
>>> np.einsum('ij,ji->i', np.dot(a, b), a.T)
array([  60,  672, 1932, 3840, 6396])
>>> np.einsum('ij,ij->i', np.dot(a, b), a)
array([  60,  672, 1932, 3840, 6396])

编辑 其实你可以一次性搞定所有的事情,真是太夸张了...

>>> np.einsum('ij,jk,ki->i', a, b, a.T)
array([  60,  672, 1932, 3840, 6396])
>>> np.einsum('ij,jk,ik->i', a, b, a)
array([  60,  672, 1932, 3840, 6396])

编辑 不过你不想让它自己想太多... 还把提问者自己对问题的回答加上来做对比。

n, p = 10000, 200
a = np.random.rand(n, p)
b = np.random.rand(p, p)

In [2]: %timeit np.einsum('ij,jk,ki->i', a, b, a.T)
1 loops, best of 3: 1.3 s per loop

In [3]: %timeit np.einsum('ij,ij->i', np.dot(a, b), a)
10 loops, best of 3: 105 ms per loop

In [4]: %timeit np.diag(np.dot(np.dot(a, b), a.T))
1 loops, best of 3: 5.73 s per loop

In [5]: %timeit (a.dot(b) * a).sum(-1)
10 loops, best of 3: 115 ms per loop
68

我觉得我自己明白了,不过还是想分享一下我的解决方案:

因为只获取矩阵乘法的对角线

> Z = N.diag(X.dot(Y))

相当于把X的每一行和Y的每一列的点积单独相加,所以之前的说法可以理解为:

> Z = (X * Y.T).sum(-1)

对于原始变量来说,这意味着:

> result = (A.dot(B) * A).sum(-1)

如果我错了请纠正我,但应该就是这样……

撰写回答