numpy矩阵乘法性能结果异常

2 投票
2 回答
843 浏览
提问于 2025-04-18 12:41

最近我发现了一个用numpy进行矩阵乘法的情况,表现得非常奇怪(至少对我来说是这样)。为了说明这个问题,我创建了一个这样的矩阵示例和一个简单的脚本来展示时间性能。你可以从这个仓库下载这两个文件,我这里不包含脚本,因为没有数据的话它没什么用。

这个脚本用不同的方法将两对矩阵相乘(每对矩阵在shapedtype上是相同的,只有数据不同),使用了dot函数和einsum。实际上,我注意到了一些异常现象:

  • 第一对矩阵(A * B)的乘法速度比第二对(C * D)快得多。
  • 当我把所有矩阵转换为float64类型时,两对矩阵的计算时间变得相同:比乘A * B的时间长,但比C * D的时间短。
  • 这些现象在einsum(我理解是numpy的实现)和dot(在我的机器上使用BLAS)中都存在。为了完整性,这个脚本在我笔记本上的输出是:
With np.dot:
A * B: 0.142910003662 s
C * D: 4.9057161808 s
A * D: 0.20524597168 s
C * B: 4.20220398903 s
A * B (to float32): 0.156805992126 s
C * D (to float32): 5.11792707443 s
A * B (to float64): 0.52608704567 s
C * D (to float64): 0.484733819962 s
A * B (to float64 to float32): 0.255760908127 s
C * D (to float64 to float32): 4.7677090168 s
With einsum:
A * B: 0.489732980728 s
C * D: 7.34477996826 s
A * D: 0.449800014496 s
C * B: 4.05954909325 s
A * B (to float32): 0.411967992783 s
C * D (to float32): 7.32073783875 s
A * B (to float64): 0.80580997467 s
C * D (to float64): 0.808521032333 s
A * B (to float64 to float32): 0.414498090744 s
C * D (to float64 to float32): 7.32472801208 s

这些结果怎么解释呢?怎么才能让C * D的乘法速度像A * B一样快呢?

2 个回答

1

Mark Dickinson已经回答了你的问题,不过为了好玩,试试这个:

Cp = np.array(list(C[:,0]))
Ap = np.array(list(A[:,0]))

这个方法去掉了拼接的延迟,并确保数组在内存中是相似的。

%timeit Cp * Cp   % 34.9 us per loop
%timeit Ap * Ap   % 3.59 us per loop

哎呀。

4

你看到的速度变慢是因为计算中涉及到了非标准数。很多处理器在进行包含非标准输入或输出的算术运算时会变得很慢。这里有几个相关的StackOverflow问题可以参考:可以看看这个关于C#的问题(特别是Eric Postpischil的回答),还有这个关于C++的问题的回答,获取更多信息。

在你的具体案例中,矩阵C(数据类型为float32)包含了几个非标准数。对于单精度浮点数,非标准数和标准数的分界线是2^-126,大约是1.18e-38。我看到的C是这样的:

>>> ((0 < abs(C)) & (abs(C) < 2.0**-126)).sum()  # number of subnormal entries
44694
>>> C.size
682450

所以大约6.5%的C的元素是非标准数,这已经足够让C*BC*D的乘法变慢了。相比之下,AB的值没有接近非标准数的边界:

>>> abs(A[A != 0]).min()
4.6801152e-12
>>> abs(B[B != 0]).min()
4.0640174e-07

因此,参与A*B矩阵乘法的中间值都不是非标准数,所以没有速度损失。

至于你问题的第二部分,我不太确定该建议什么。如果你努力尝试,并且使用的是x64/SSE2(而不是x87 FPU),你可以从Python中设置“清零”和“非标准数视为零”的标志。可以参考这个回答,这是一个粗糙且不便携的基于ctypes的解决方案;如果你真的想走这条路,写一个自定义的C扩展可能会更好。

我更倾向于尝试将C进行缩放,使其完全进入标准范围(同时也将C*D的各个乘积带入标准范围),但如果C的值在浮点数范围的上限附近,这可能就不太可能了。或者,简单地将C中的小值替换为零也许可以,但最终的准确性损失是否显著和/或可接受则取决于你的应用。

撰写回答