Python矩阵乘法与BLA

2024-04-19 18:17:16 发布

您现在位置:Python中文网/ 问答频道 /正文

numpy矩阵乘法的基础是什么?我的理解是它使用BLAS,但是我在Jupyter笔记本中执行矩阵乘法时得到的运行时与在c++中直接调用BLAS::gemm时得到的运行时非常不同。例如,此代码:

import numpy
import time

N = 1000
mat = numpy.random.rand(N, N)
start = time.perf_counter()
m = mat@mat
print(time.perf_counter()-start)

在我的笔记本电脑上运行大约0.25秒。此代码:

    float *A_ = new float[n*n], *C_ = new float[n*n];

    for (long i = 0; i < n*n; i++)
        A_[i] = 1.1;  

    const char tran = 'N';  float alpha = 1, beta = 0;
    st = clock();  
    blas::gemm(&tran, &tran, &n, &n, &n, &alpha, A_, &n, A_, &n, &beta, C_, &n);
    fn = clock();
    cout << float(fn-st)/float(CLOCKS_PER_SEC) << endl;

运行大约需要0.85秒。所以,我猜numpy在使用blas::gemm之外的其他东西。有人知道它在引擎盖下用什么吗?提前谢谢


Tags: 代码importnumpynewtimecounter矩阵float