PyTorch梯度解算器和SciPy稀疏矩阵求解器的结果不同。

2024-04-19 09:37:19 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在pytorch中实现baseline als减法,这样我就可以在我的GPU上运行它,但是我遇到了一些问题,因为皮托奇.gesv给出的结果与西皮.linalg.spsolve. 这是我为scipy编写的代码:

def baseline_als(y, lam, p, niter=10):
  L = len(y)
  D = sparse.diags([1,-2,1],[0,-1,-2], shape=(L,L-2))
  w = np.ones(L)
  for i in range(niter):
    W = sparse.spdiags(w, 0, L, L)
    Z = W + lam * D.dot(D.transpose())
    z = spsolve(Z, w*y)
    w = p * (y > z) + (1-p) * (y < z)
  return z

这是我的密码

^{pr2}$

抱歉,Pythorch代码看起来很糟糕,我才刚开始。在

我已经确认了,对于scipy和pytorch来说Z,w,y都是一样的,在我试图解方程组之后,它们之间的Z是不同的。在

感谢您的评论,下面是一个例子:

我用10万计算lam,0.001计算p

使用虚拟输入:y=(5,5,5,5,5,10,10,5,5,5,5,10,10,10,10,10,5,5,5,5,5,5)

我从scipy得到(3.68010263,4.90344214,6.12679489,7.35022406,8.57384278,9.79774074,11.021971199,12.2465927,13.47164891,14.69711435,15.92287813,17.14873257,18.37456982,19.60038184,20.82626043,22.0521557,23.27805103,24.50400438,25.73010693,26.95625922)

(6.4938312、6.46912395、6.44440175、6.41963499、6.39477958、6.36977727、6.34455582、6.31907933、6.29334844、6.26735058、6.24106029、6.21443939、6.18748732、6.16024137、6.13277694、6.10515785、6.07743658、6.04965455、6.02184242、5.99402035)。在

这只是循环的一次迭代。西皮是对的,皮托克不是。在

有趣的是,如果我使用一个较短的虚拟输入(5,5,5,5,10,10,5,5,5,5),两者的答案都是一样的。我的实际输入是1011维。在


Tags: 代码lengpudefpytorchscipysparselinalg
1条回答
网友
1楼 · 发布于 2024-04-19 09:37:19

你的pytorch函数是错误的(你从来没有在for循环的第一行更新W),而且我得到了你说的从Scipy得到的pytorch的结果。在

Scipy版本

def baseline_als(y, lam=100000, p=1e-3, niter=1):
    L = len(y)
    D = sparse.diags([1,-2,1],[0,-1,-2], shape=(L,L-2))
    w = np.ones(L)
    for i in range(niter):
        W = sparse.spdiags(w, 0, L, L)
        Z = W + lam * D.dot(D.transpose())
        z = spsolve(Z, w*y)
        w = p * (y > z) + (1-p) * (y < z)
    return z

相当于Pythorch

^{pr2}$

当我用y = np.array([5,5,5,5,5,10,10,5,5,5,10,10,10,5,5,5,5,5,5,5], dtype='float64')喂它们时:

西皮:

array([6.4938312 , 6.46912395, 6.44440175, 6.41963499, 6.39477958,
       6.36977727, 6.34455582, 6.31907933, 6.29334844, 6.26735058,
       6.24106029, 6.21443939, 6.18748732, 6.16024137, 6.13277694,
       6.10515785, 6.07743658, 6.04965455, 6.02184242, 5.99402035])

Pythorch公司:

tensor([6.4938, 6.4691, 6.4444, 6.4196, 6.3948, 6.3698, 6.3446, 6.3191, 6.2933,
        6.2674, 6.2411, 6.2144, 6.1875, 6.1602, 6.1328, 6.1052, 6.0774, 6.0497,
        6.0218, 5.9940], dtype=torch.float64)

如果我将n_iter增加到10:

西皮:

array([5.00202571, 5.00199038, 5.00195504, 5.00191963, 5.0018841 ,
       5.00184837, 5.00181235, 5.00177598, 5.00173927, 5.00170221,
       5.00166475, 5.00162685, 5.00158851, 5.00154979, 5.00151077,
       5.00147155, 5.0014322 , 5.00139276, 5.00135329, 5.0013138 ])

Pythorch公司:

tensor([5.0020, 5.0020, 5.0020, 5.0019, 5.0019, 5.0018, 5.0018, 5.0018, 5.0017,
        5.0017, 5.0017, 5.0016, 5.0016, 5.0015, 5.0015, 5.0015, 5.0014, 5.0014,
        5.0014, 5.0013], dtype=torch.float64)

它与你在问题中链接的基线als的代码相匹配。在

相关问题 更多 >