numpy einsum的替代品

2024-05-16 19:00:37 发布

您现在位置:Python中文网/ 问答频道 /正文

当我计算含有N行和n列的矩阵X的三阶矩时,我通常使用einsum

M3 = sp.einsum('ij,ik,il->jkl',X,X,X) /N

这通常工作正常,但现在我使用更大的值,即n = 120和{},并且einsum返回以下错误:

ValueError: iterator is too large

做3个嵌套循环的替代方案是不可行的,所以我想知道是否有任何替代方案。在


Tags: is错误方案矩阵jklspilik
1条回答
网友
1楼 · 发布于 2024-05-16 19:00:37

请注意,计算这个值至少需要进行~n3×n=1730亿个操作(不考虑对称性),因此除非numpy可以访问GPU或其他东西,否则计算速度会很慢。在一台拥有~3ghz CPU的现代计算机上,假设没有SIMD/并行加速,整个计算预计需要60秒才能完成。在


对于测试,让我们从N=1000开始。我们将用这个来检查正确性和性能:

#!/usr/bin/env python3

import numpy
import time

numpy.random.seed(0)

n = 120
N = 1000
X = numpy.random.random((N, n))

start_time = time.time()

M3 = numpy.einsum('ij,ik,il->jkl', X, X, X)

end_time = time.time()

print('check:', M3[2,4,6], '= 125.401852515?')
print('check:', M3[4,2,6], '= 125.401852515?')
print('check:', M3[6,4,2], '= 125.401852515?')
print('check:', numpy.sum(M3), '= 218028826.631?')
print('total time =', end_time - start_time)

这大约需要8秒。这是基线。在

让我们从3个嵌套循环开始作为替代:

^{pr2}$

大概要半分钟,不行!一个原因是因为这实际上是四个嵌套循环:numpy.sum也可以被视为一个循环。在

我们注意到,总和可以转化为点积,以消除第四个循环:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        for l in range(n):
            M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l]
# 14 seconds

现在好多了,但还是很慢。但我们注意到,点积可以转换为矩阵乘法,以消除一个循环:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        M3[j,k] = X[:,j] * X[:,k] @ X
# ~0.5 seconds

嗯?现在这甚至比einsum更有效!我们还可以检查答案是否确实正确。在

我们能走得更远吗?对!我们可以通过以下方法消除k循环:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    Y = numpy.repeat(X[:,j], n).reshape((N, n))
    M3[j] = (Y * X).T @ X
# ~0.3 seconds

我们还可以使用广播(即X的每一行a * [b,c] == [a*b, a*c])来避免执行numpy.repeat(谢谢@Divakar):

M3 = numpy.zeros((n, n, n))
for j in range(n):
    Y = X[:,j].reshape((N, 1))
    ## or, equivalently: 
    # Y = X[:, numpy.newaxis, j]
    M3[j] = (Y * X).T @ X
# ~0.16 seconds

如果我们将其扩展到N=100000,程序预计需要16秒,这在理论限制内,因此消除j可能没有太大帮助(但这可能会使代码真正难以理解)。我们可以接受这是最后的解决办法。在


注意:如果您使用的是python2,a @ b相当于a.dot(b)。在

相关问题 更多 >