在numpy/pandas中进行矩阵乘法的多线程?

9 投票
2 回答
5657 浏览
提问于 2025-04-18 01:16

我真的很想知道如何在numpy/pandas中利用多核处理来进行矩阵乘法。

我现在尝试的代码如下:

M = pd.DataFrame(...) # super high dimensional square matrix.
A = M.T.dot(M) 

这个过程耗时很长,因为需要进行很多的乘法和加法。我觉得对于大矩阵的乘法,使用多线程应该很简单。所以我仔细在网上查找,但找不到如何在numpy/pandas中做到这一点。我是不是需要手动编写多线程代码,使用一些Python自带的线程库呢?

2 个回答

2

首先,我建议你把数据转换成“波浪数组”,然后使用numpy的点乘函数。如果你有一个MKL版本的库,这个库目前是最快的实现之一,你可以试着设置一个环境变量叫做 OMP_NUM_THREADS。这样可以激活你系统的其他处理器核心。在我的MAC上,这样做似乎效果很好。此外,我还建议你使用 np.einsum,因为它的速度似乎比 np.dot 更快。

不过要注意!如果你编译了一个使用OpenMP进行并行处理的多线程库(比如MKL),你需要知道,所有苹果系统上的“默认gcc”其实不是gcc,而是Clang/LLVM,而Clang目前不支持OpenMP,除非你使用的是仍在实验阶段的OpenMP版本。所以你需要安装英特尔编译器或者其他支持OpenMP的编译器。

3

在NumPy中,想要实现多线程的矩阵乘法,可以使用一种叫做BLAS的基础线性代数子程序的多线程版本。你需要做到以下几点:

  1. 首先,你得有这样的BLAS实现,比如OpenBLAS、ATLAS和MKL都支持多线程的矩阵乘法。
  2. 然后,你需要确保你的NumPy是用这种实现编译的。
  3. 最后,确保你要相乘的矩阵的类型是float32float64(并且满足某些对齐要求;我建议使用NumPy 1.7.1或更高版本,因为这些要求在新版本中有所放宽)。

不过,有几点需要注意:

  • 旧版本的OpenBLAS在用GCC编译时,如果程序使用了multiprocessing(这包括大多数使用joblib的应用),会出现问题,程序可能会卡住。原因是GCC中的一个bug(或者说缺少某个功能)。虽然已经提交了一个补丁,但还没有被纳入主版本中。
  • 在典型的Linux发行版中找到的ATLAS包,可能没有编译成支持多线程的版本。

至于Pandas:我不太确定它是如何进行点乘的。为了确保,可以先转换为NumPy数组,然后再转换回来。

撰写回答