如何在numpy中加速乘法和求和运算

2024-05-28 21:50:06 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要解决一个有限元方法问题,并且必须从AB用一个大的MM>1M)计算以下C。比如说,

import numpy as np
M=4000000
A=np.random.rand(4, M, 3)
B=np.random.rand(M,3)
C = (A * B).sum(axis = -1) # need to be optimized

谁能想出一个比(A * B).sum(axis = -1)更快的代码?您可以自由地重塑或重新排列ABC的轴


Tags: to方法代码importnumpyasnprandom
3条回答

一般来说,为了加速numpy乘法,一种可能的方法是使用ctypes。然而,据我所知,这种方法可能会给您带来有限的性能改进(如果有的话)

您可以像这样使用NumExpr来实现3倍的加速:

import numpy as np
import numexpr as ne

M=40000
A=np.random.rand(4, M, 3)
B=np.random.rand(M,3)

%timeit out = (A * B).sum(axis = -1)
2.12 ms ± 57.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit me = ne.evaluate('sum(A*B,2)')
662 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


out = (A * B).sum(axis = -1)
me = numexpr.evaluate('sum(A*B,2)')
np.allclose(out,me)
Out[29]: True

在性能和内存使用方面,您可以使用^{}来实现稍微更高效的方法:

M=40000
A=np.random.rand(4, M, 3)
B=np.random.rand(M,3)
out = (A * B).sum(axis = -1) # need to be optimized

%timeit (A * B).sum(axis = -1) # need to be optimized
# 5.23 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit np.einsum('ijk,jk->ij', A, B)
# 1.31 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

np.allclose(out, np.einsum('ijk,jk->ij', A, B))
# True

相关问题 更多 >

    热门问题