ndarray矩阵点积出现段错误
- 我正在对一个有50000行和100列的矩阵与它的转置进行点乘。这个矩阵里的值是浮点数。
A(50000, 100) B(100, 50000)
- 其实我是在对一个更大的稀疏矩阵进行奇异值分解(SVD)后得到这个矩阵的。
- 这个矩阵是numpy.ndarray类型。
我使用numpy的dot方法来乘这两个矩阵,但出现了段错误。
numpy.dot(A, B)
在处理30000行的矩阵时,点乘可以正常工作,但在50000行时就失败了。
- numpy的点乘有没有什么限制呢?
- 在使用点乘时,有没有什么问题呢?
- 有没有其他好的Python线性代数工具,能在处理大矩阵时更高效呢?
1 个回答
3
正如你所听说的,存在一个内存问题。你想要做的是:
numpy.dot(A, A.T)
这个操作需要大量的内存来存储结果(而不是操作数)。不过,这个操作可以分成小块来进行,比较简单。你可以使用循环的方式,一次生成一行输出:
def trans_multi(A):
rows = A.shape[0]
result = numpy.empty((rows, rows), dtype=A.dtype)
for r in range(rows):
result[r,:] = numpy.dot(A, A[r,:].T)
return result
这样做虽然速度较慢,但内存消耗是一样的(numpy.dot
已经经过很好的优化)。不过,你可能更想把结果写入一个文件,因为你没有足够的内存来保存结果:
def trans_multi(A, filename):
with open(filename, "wb") as f:
rows = A.shape[0]
for r in range(rows):
f.write(numpy.dot(A, A[r,:].T).tostring())
是的,这样做的速度并不是特别快。不过,这可能是你能期待的最快速度了。顺序写入通常是经过很好的优化的。我试过:
a=numpy.random.random((50000,100)).astype('float32')
trans_multi(a,"/tmp/large.dat")
这个过程大约花了60秒,但具体时间还得看你的硬盘性能。
那为什么不使用memmap呢?
我喜欢mmap
,而numpy.memmap
也是个很不错的工具。不过,numpy.memmap
是为了处理大表格并从中计算小结果而优化的。例如,有memmap.dot
,它是专门用来计算内存映射数组的点积的。这里的情况是,操作数是内存映射的,但结果却是在内存中。正好相反。
内存映射在你需要随机访问时非常有用。但在这里,访问并不是随机的,而是顺序写入。而且,如果你尝试使用numpy.memmap
来创建一个(50000,50000)的float32数组,会花费一些时间(我不太明白为什么,也许是因为它在初始化数据,尽管其实没必要)。
不过,文件创建好之后,使用numpy.memmap
来分析这个巨大的表格是个很好的主意,因为它提供了最佳的随机读取性能和非常方便的接口。