使用numba计算内积的正确方法

3 投票
1 回答
3358 浏览
提问于 2025-04-18 18:49

我正在尝试计算两个大矩阵的内积。看起来在用 numpy 计算点积时,它会创建矩阵的副本,这让我遇到了一些内存问题。经过一番搜索,我发现 numba 这个包很有前景。不过我没能让它正常工作。以下是我的代码:

import numpy as np
from numba import jit
import time, contextlib



@contextlib.contextmanager
def timeit():
    t=time.time()
    yield
    print(time.time()-t,"sec")


def dot1(a,b):
    return np.dot(a,b)

@jit(nopython=True)
def dot2(a,b):
    n = a.shape[0]
    m = b.shape[1]
    K = b.shape[0]
    c = np.zeros((n,m))
    for i in xrange(n):
        for j in xrange(m):
            for k in range(K):
                c[i,j] += a[i,k]*b[k,j]

    return c



def main():
    a = np.random.random((200,1000))
    b = np.random.random((1000,400))

    with timeit():
        c1 = dot1(a,b)
    with timeit():
        c2 = dot2(a,b)

运行时间如下:

dot1:
(0.034691810607910156, 'sec')

dot2:
(0.9215810298919678, 'sec')

有没有人能告诉我我在这里漏掉了什么?

1 个回答

1

你的算法是个简单粗暴的算法。BLAS实现了一种更快的算法。

引用维基百科的矩阵乘法页面:

不过,这种算法在一些库中出现,比如BLAS,对于维度大于100的矩阵,它的效率要高得多。

撰写回答