大规模稀疏CSR矩阵乘法结果中的错误

2 投票
1 回答
45 浏览
提问于 2025-04-14 16:40

这让我很困惑,这是一个已知的bug吗?还是我漏掉了什么?如果真有bug,有没有办法绕过它呢?


假设我有一个相对较小的二进制矩阵(只有0和1),它的大小是n x q,使用的是scipy.sparse.csr_matrix,比如:

import numpy as np
from scipy import sparse

def get_dummies(vec, vec_max):
    vec_size = vec.size
    Z = sparse.csr_matrix((np.ones(vec_size), (np.arange(vec_size), vec)), shape=(vec_size, vec_max), dtype=np.uint8)
    return Z

q = 100
ns = np.round(np.random.random(q)*100).astype(np.int16)
Z_idx = np.repeat(np.arange(q), ns)
Z = get_dummies(Z_idx, q)
Z
<5171x100 sparse matrix of type '<class 'numpy.uint8'>'
    with 5171 stored elements in Compressed Sparse Row format>

这里的Z是一个标准的虚拟变量矩阵,包含5171个观察值和100个变量:

Z[:5, :5].toarray()
array([[1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0]], dtype=uint8)

例如,如果前5个变量的...

ns[:5]
array([21, 22, 37, 24, 99], dtype=int16)

频率,我们在Z的列总和中也能看到这些频率:

Z[:, :5].sum(axis=0)
matrix([[21, 22, 37, 24, 99]], dtype=uint64)

现在,按照预期,如果我计算Z.T @ Z,我应该得到一个q x q的对角矩阵,矩阵对角线上是这q个变量的频率:

print((Z.T @ Z).shape)
print((Z.T @ Z)[:5, :5].toarray()
(100, 100)
[[21  0  0  0  0]
 [ 0 22  0  0  0]
 [ 0  0 37  0  0]
 [ 0  0  0 24  0]
 [ 0  0  0  0 99]]

现在说说这个bug:如果n真的很大(对我来说,大约在n = 100K时就会出现):

q = 1000
ns = np.round(np.random.random(q)*1000).astype(np.int16)
Z_idx = np.repeat(np.arange(q), ns)
Z = get_dummies(Z_idx, q)
Z
<495509x1000 sparse matrix of type '<class 'numpy.uint8'>'
    with 495509 stored elements in Compressed Sparse Row format>

频率很大,Z的列总和也是正常的:

print(ns[:5])
Z[:, :5].sum(axis=0)
array([485, 756, 380,  87, 454], dtype=int16)
matrix([[485, 756, 380,  87, 454]], dtype=uint64)

但是Z.T @ Z的结果就出问题了!也就是说,我在对角线上没有得到正确的频率:

print((Z.T @ Z).shape)
print((Z.T @ Z)[:5, :5].toarray())
(1000, 1000)
[[229   0   0   0   0]
 [  0 244   0   0   0]
 [  0   0 124   0   0]
 [  0   0   0  87   0]
 [  0   0   0   0 198]]

令人惊讶的是,结果和真实频率之间还有某种关系:

import matplotlib.pyplot as plt

plt.scatter(ns, (Z.T @ Z).diagonal())
plt.xlabel('real frequencies')
plt.ylabel('values on ZZ diagonal')
plt.show()

在这里输入图片描述

这是怎么回事呢?

我使用的是标准的colab:

import scipy as sc

print(np.__version__)
print(sc.__version__)
1.25.2
1.11.4

附言:显然,如果我只是想要Z.T @ Z的输出矩阵,还有更简单的方法可以得到,这只是一个非常简化的问题,谢谢。

1 个回答

2

你在使用 uint8 类型的变量在 get_dummies 函数中。

print(Z.dtype, (Z.T @ Z)[:5, :5].dtype)
# uint8,  uint8

所以结果出现了溢出。这个现象是因为结果是实际结果对256取余。如果你把数据类型改成 uint16,这个问题就会消失。

让人有点意外的是,sum 函数会提高数据类型的精度,但这已经在文档中说明,而且和 NumPy 的 sum 函数的行为是一致的。

撰写回答