有效地乘法Numpy/Scipy稀疏稠密矩阵

2024-03-28 14:50:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在努力实现以下等式:

X =(Y.T * Y + Y.T * C * Y) ^ -1

Y是(n x f)矩阵,C是(n x n)对角矩阵;n约为300k,f在100到200之间变化。作为优化过程的一部分,这个方程将被使用近1亿次,因此必须处理得非常快。

Y是随机初始化的,C是一个非常稀疏的矩阵,在300k的对角线上只有几个数不同于0。由于Numpy的对角线函数创建了稠密矩阵,所以我创建了C作为稀疏csr矩阵。但是当你试图解方程的第一部分时:

r = dot(C, Y)

由于内存限制,计算机崩溃。我决定将Y转换成csr_matrix并执行相同的操作:

r = dot(C, Ysparse)

该方法耗时1.38ms。但是这个解决方案有点“棘手”,因为我使用稀疏矩阵来存储稠密矩阵,我想知道这到底有多有效。

所以我的问题是,是否有某种方法可以将稀疏C和稠密Y相乘,而不必将Y变为稀疏并提高性能?如果C可以被表示为对角线密集型而不消耗大量内存,那么这可能会导致非常高效的性能,但我不知道这是否可能。

谢谢你的帮助!


Tags: 方法函数内存numpy过程计算机矩阵性能
3条回答

当计算r=dot(C,Y)时,dot乘积会遇到内存问题,原因是numpy的dot函数不支持处理稀疏矩阵。现在发生的是numpy将稀疏矩阵C看作一个python对象,而不是numpy数组。如果你在小范围内检查,你可以直接看到问题:

>>> from numpy import dot, array
>>> from scipy import sparse
>>> Y = array([[1,2],[3,4]])
>>> C = sparse.csr_matrix(array([[1,0], [0,2]]))
>>> dot(C,Y)
array([[  (0, 0)    1
  (1, 1)    2,   (0, 0) 2
  (1, 1)    4],
  [  (0, 0) 3
  (1, 1)    6,   (0, 0) 4
  (1, 1)    8]], dtype=object)

很明显,以上不是你感兴趣的结果。相反,您要做的是使用scipy的sparse.csr_matrix.dot函数进行计算:

r = sparse.csr_matrix.dot(C, Y)

或者更紧凑

r = C.dot(Y)

首先,你真的确定你的问题需要进行一个完整的矩阵反演吗?大多数情况下,只需要计算x=A^-1y,这是一个更容易解决的问题。

如果真是这样,我会考虑计算逆矩阵的近似值,而不是全矩阵的逆。因为矩阵反演的成本很高。有关逆矩阵的有效近似,请参见Lanczos algorithm示例。近似值可以作为奖励稀疏存储。另外,它只需要矩阵向量运算,所以您甚至不必存储整个矩阵来进行逆运算。

另一种方法是,使用pyoperators,还可以使用to.todense方法使用有效的矩阵向量运算来计算要求逆的矩阵。对角矩阵有一个特殊的稀疏容器。

对于Lanczos算法的实现,您可以查看pyoperators(免责声明:我是这一软件的合著者之一)。

尝试:

import numpy as np
from scipy import sparse

f = 100
n = 300000

Y = np.random.rand(n, f)
Cdiag = np.random.rand(n) # diagonal of C
Cdiag[np.random.rand(n) < 0.99] = 0

# Compute Y.T * C * Y, skipping zero elements
mask = np.flatnonzero(Cdiag)
Cskip = Cdiag[mask]

def ytcy_fast(Y):
    Yskip = Y[mask,:]
    CY = Cskip[:,None] * Yskip  # broadcasting
    return Yskip.T.dot(CY)

%timeit ytcy_fast(Y)

# For comparison: all-sparse matrices
C_sparse = sparse.spdiags([Cdiag], [0], n, n)
Y_sparse = sparse.csr_matrix(Y)
%timeit Y_sparse.T.dot(C_sparse * Y_sparse)

我的时间安排:

In [59]: %timeit ytcy_fast(Y)
100 loops, best of 3: 16.1 ms per loop

In [18]: %timeit Y_sparse.T.dot(C_sparse * Y_sparse)
1 loops, best of 3: 282 ms per loop

相关问题 更多 >