这个梯度下降算法有什么问题?

2024-04-26 06:13:22 发布

您现在位置:Python中文网/ 问答频道 /正文

X_train已经使用StandardScaler()进行了规范化,并且分类列已经转换为一个热编码。你知道吗

X_train.shape=(32000, 37)

我使用下面的代码使用梯度下降法计算w的值

w = np.zeros(len(X_train.columns))
learning_rate = 0.001    
for t in range(1000):
    Yhat = X.dot(w)
    delta = Yhat - Y_train
    w = w - learning_rate*X_train.T.dot(delta)

我的w向量爆炸(即增长非常快),并且w的每个条目变成NaN。我试着把历元数减少到10、15、20等等,我发现w的每个元素都是发散的而不是汇聚的。你知道吗

我试着使用正态方程,在这种情况下w确实很好(为了可读性增加了换行符):

w_found_using_normal_eqns = [ 3.53175449e-14  1.27924991e-14 -5.42441539e-14
9.91098366e-16 -2.31752259e-14 -6.21205773e-13  1.66139358e-13
2.72739782e-13 -1.65076881e-13 -1.25280166e-14 -1.98905983e-14  3.78837632e-13
-1.39424696e-12 -6.48511452e-15  1.58136412e-14  1.39778439e-12
-1.06142667e-14  3.00624557e-14 -1.70159700e-15 -6.91500349e-15 -4.04842208e-15
2.37516654e-16  3.25211677e+01 -2.86074823e+01 -2.86074823e+01
-2.86074823e+01 -2.86074823e+01 -2.86074823e+01 -2.86074823e+01 -2.86074823e+01 
3.55024823e+01  3.55024823e+01 3.55024823e+01  3.55024823e+01  
3.55024823e+01  3.55024823e+01 3.55024823e+01]

如果我用正规方程来解w,那么r^2错误就是1。你知道吗


Tags: 代码编码ratenp分类train规范化dot
1条回答
网友
1楼 · 发布于 2024-04-26 06:13:22

梯度下降权值更新公式通过训练集大小标准化。你知道吗

在最后一行中,您需要将学习速率除以训练集大小。你知道吗

修复代码:

w = w - (learning_rate/X_train.shape) * X_train.T.dot(delta)

相关问题 更多 >