在numpy/pandas中进行矩阵乘法的多线程?
我真的很想知道如何在numpy/pandas中利用多核处理来进行矩阵乘法。
我现在尝试的代码如下:
M = pd.DataFrame(...) # super high dimensional square matrix.
A = M.T.dot(M)
这个过程耗时很长,因为需要进行很多的乘法和加法。我觉得对于大矩阵的乘法,使用多线程应该很简单。所以我仔细在网上查找,但找不到如何在numpy/pandas中做到这一点。我是不是需要手动编写多线程代码,使用一些Python自带的线程库呢?
2 个回答
2
首先,我建议你把数据转换成“波浪数组”,然后使用numpy的点乘函数。如果你有一个MKL版本的库,这个库目前是最快的实现之一,你可以试着设置一个环境变量叫做 OMP_NUM_THREADS
。这样可以激活你系统的其他处理器核心。在我的MAC上,这样做似乎效果很好。此外,我还建议你使用 np.einsum
,因为它的速度似乎比 np.dot
更快。
不过要注意!如果你编译了一个使用OpenMP进行并行处理的多线程库(比如MKL),你需要知道,所有苹果系统上的“默认gcc”其实不是gcc,而是Clang/LLVM,而Clang目前不支持OpenMP,除非你使用的是仍在实验阶段的OpenMP版本。所以你需要安装英特尔编译器或者其他支持OpenMP的编译器。
3
在NumPy中,想要实现多线程的矩阵乘法,可以使用一种叫做BLAS的基础线性代数子程序的多线程版本。你需要做到以下几点:
- 首先,你得有这样的BLAS实现,比如OpenBLAS、ATLAS和MKL都支持多线程的矩阵乘法。
- 然后,你需要确保你的NumPy是用这种实现编译的。
- 最后,确保你要相乘的矩阵的类型是
float32
或float64
(并且满足某些对齐要求;我建议使用NumPy 1.7.1或更高版本,因为这些要求在新版本中有所放宽)。
不过,有几点需要注意:
- 旧版本的OpenBLAS在用GCC编译时,如果程序使用了
multiprocessing
(这包括大多数使用joblib
的应用),会出现问题,程序可能会卡住。原因是GCC中的一个bug(或者说缺少某个功能)。虽然已经提交了一个补丁,但还没有被纳入主版本中。 - 在典型的Linux发行版中找到的ATLAS包,可能没有编译成支持多线程的版本。
至于Pandas:我不太确定它是如何进行点乘的。为了确保,可以先转换为NumPy数组,然后再转换回来。