How can I speed up this matrix multiplication?

Posted 2024-04-23 20:49:14

I am trying to reproduce matrix multiplication with numba. The code is as follows:

import numpy as np
import timeit
from numba import jit, float64, prange


@jit('float64[:,:](float64[:,:],float64[:,:])', parallel=True, nopython=True)
def matmul(A, B):
    C = np.zeros((A.shape[0], B.shape[1]))
    for i in prange(A.shape[0]):
        for j in prange(B.shape[1]):
            for k in range(A.shape[1]):  # inner dimension: columns of A == rows of B
                C[i,j] = C[i,j] + A[i,k]*B[k,j]
    return C



if __name__ == '__main__':
    m_size = 1000
    num_loops = 10
    A = np.random.rand(m_size, m_size)
    B = np.random.rand(m_size, m_size)

    # Numpy
    start = timeit.default_timer()
    for i in range(num_loops):
        A.dot(B)
    stop = timeit.default_timer()
    execution_time = stop - start
    print("Numpy Executed in ", execution_time)


    # Numba
    start = timeit.default_timer()
    for i in range(num_loops):
        matmul(A, B)
    stop = timeit.default_timer()
    execution_time = stop - start
    print("Numba Executed in ", execution_time) 

The output is as follows:

[output not preserved in the source]

In a related post, the performance of numba and numpy was very close. What am I doing wrong? How can I improve the performance of the matmul function?
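One possible direction, offered as a rough sketch rather than a definitive answer: parallelise only the outer loop and reorder the loops to i-k-j, so the innermost loop walks along rows of B and C, which are contiguous in memory for default C-ordered arrays. The name `matmul_ikj` and the `fastmath=True` setting are assumptions made for this illustration and are not part of the original code. The `try/except` fallback exists only so the snippet still runs (slowly, in pure Python) where numba is not installed. Because `A.dot(B)` typically calls an optimised BLAS routine, a hand-written triple loop should probably not be expected to match it, even after these changes.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Assumption for illustration: fall back to plain Python if numba is absent.
    def njit(*args, **kwargs):
        def wrap(func):
            return func
        return wrap
    prange = range


@njit(parallel=True, fastmath=True)
def matmul_ikj(A, B):
    """Hypothetical variant of matmul with a cache-friendlier loop order."""
    n, m = A.shape
    p = B.shape[1]
    C = np.zeros((n, p))
    for i in prange(n):            # only the outer loop is parallel
        for k in range(m):         # inner dimension: columns of A, rows of B
            a_ik = A[i, k]
            for j in range(p):     # contiguous walk along row k of B and row i of C
                C[i, j] += a_ik * B[k, j]
    return C


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    A = rng.random((200, 150))
    B = rng.random((150, 100))
    C = matmul_ikj(A, B)           # first call also triggers JIT compilation
    print("matches numpy:", np.allclose(C, A.dot(B)))
```

Since this version has no explicit signature, numba compiles it lazily on the first call. When timing it, a warm-up call before starting the timer keeps the compilation cost out of the measurement.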

