大规模稀疏CSR矩阵乘法结果中的错误
这让我很困惑,这是一个已知的bug吗?还是我漏掉了什么?如果真有bug,有没有办法绕过它呢?
假设我有一个相对较小的二进制矩阵(只有0和1),它的大小是n x q,使用的是scipy.sparse.csr_matrix
,比如:
import numpy as np
from scipy import sparse
def get_dummies(vec, vec_max):
vec_size = vec.size
Z = sparse.csr_matrix((np.ones(vec_size), (np.arange(vec_size), vec)), shape=(vec_size, vec_max), dtype=np.uint8)
return Z
q = 100
ns = np.round(np.random.random(q)*100).astype(np.int16)
Z_idx = np.repeat(np.arange(q), ns)
Z = get_dummies(Z_idx, q)
Z
<5171x100 sparse matrix of type '<class 'numpy.uint8'>'
with 5171 stored elements in Compressed Sparse Row format>
这里的Z
是一个标准的虚拟变量矩阵,包含5171个观察值和100个变量:
Z[:5, :5].toarray()
array([[1, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[1, 0, 0, 0, 0]], dtype=uint8)
例如,如果前5个变量的...
ns[:5]
array([21, 22, 37, 24, 99], dtype=int16)
频率,我们在Z
的列总和中也能看到这些频率:
Z[:, :5].sum(axis=0)
matrix([[21, 22, 37, 24, 99]], dtype=uint64)
现在,按照预期,如果我计算Z.T @ Z
,我应该得到一个q x q的对角矩阵,矩阵对角线上是这q个变量的频率:
print((Z.T @ Z).shape)
print((Z.T @ Z)[:5, :5].toarray()
(100, 100)
[[21 0 0 0 0]
[ 0 22 0 0 0]
[ 0 0 37 0 0]
[ 0 0 0 24 0]
[ 0 0 0 0 99]]
现在说说这个bug:如果n真的很大(对我来说,大约在n = 100K时就会出现):
q = 1000
ns = np.round(np.random.random(q)*1000).astype(np.int16)
Z_idx = np.repeat(np.arange(q), ns)
Z = get_dummies(Z_idx, q)
Z
<495509x1000 sparse matrix of type '<class 'numpy.uint8'>'
with 495509 stored elements in Compressed Sparse Row format>
频率很大,Z
的列总和也是正常的:
print(ns[:5])
Z[:, :5].sum(axis=0)
array([485, 756, 380, 87, 454], dtype=int16)
matrix([[485, 756, 380, 87, 454]], dtype=uint64)
但是Z.T @ Z
的结果就出问题了!也就是说,我在对角线上没有得到正确的频率:
print((Z.T @ Z).shape)
print((Z.T @ Z)[:5, :5].toarray())
(1000, 1000)
[[229 0 0 0 0]
[ 0 244 0 0 0]
[ 0 0 124 0 0]
[ 0 0 0 87 0]
[ 0 0 0 0 198]]
令人惊讶的是,结果和真实频率之间还有某种关系:
import matplotlib.pyplot as plt
plt.scatter(ns, (Z.T @ Z).diagonal())
plt.xlabel('real frequencies')
plt.ylabel('values on ZZ diagonal')
plt.show()
这是怎么回事呢?
我使用的是标准的colab:
import scipy as sc
print(np.__version__)
print(sc.__version__)
1.25.2
1.11.4
附言:显然,如果我只是想要Z.T @ Z
的输出矩阵,还有更简单的方法可以得到,这只是一个非常简化的问题,谢谢。
1 个回答
2
你在使用 uint8
类型的变量在 get_dummies
函数中。
print(Z.dtype, (Z.T @ Z)[:5, :5].dtype)
# uint8, uint8
所以结果出现了溢出。这个现象是因为结果是实际结果对256取余。如果你把数据类型改成 uint16
,这个问题就会消失。
让人有点意外的是,sum
函数会提高数据类型的精度,但这已经在文档中说明,而且和 NumPy 的 sum
函数的行为是一致的。