numpy矩阵乘法的基础是什么?我的理解是它使用BLAS,但是我在Jupyter笔记本中执行矩阵乘法时得到的运行时与在c++中直接调用BLAS::gemm时得到的运行时非常不同。例如,此代码:
import numpy
import time
N = 1000
mat = numpy.random.rand(N, N)
start = time.perf_counter()
m = mat@mat
print(time.perf_counter()-start)
在我的笔记本电脑上运行大约0.25秒。此代码:
float *A_ = new float[n*n], *C_ = new float[n*n];
for (long i = 0; i < n*n; i++)
A_[i] = 1.1;
const char tran = 'N'; float alpha = 1, beta = 0;
st = clock();
blas::gemm(&tran, &tran, &n, &n, &n, &alpha, A_, &n, A_, &n, &beta, C_, &n);
fn = clock();
cout << float(fn-st)/float(CLOCKS_PER_SEC) << endl;
运行大约需要0.85秒。所以,我猜numpy在使用blas::gemm之外的其他东西。有人知道它在引擎盖下用什么吗?提前谢谢
目前没有回答
相关问题 更多 >
编程相关推荐