What is wrong with my code that makes the error grow with every iteration of gradient descent?

Posted 2024-04-23 21:07:51


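The code below reads a CSV file (the data file for the multivariable linear regression exercise, ex1, in Andrew Ng's ML course) and tries to fit a linear model to the dataset with a learning rate alpha = 0.01. Gradient descent updates the parameters (the θ vector) 400 times; both α and the number of iterations were given in the problem statement. I wrote a vectorized implementation to find the optimal parameter values, but gradient descent does not converge: the error keeps increasing.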
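For reference, assuming the usual squared-error cost from the course, these are the cost, its gradient, and the vectorized update that `gradient_descent` below computes. Here $X$ is the $m \times 3$ design matrix with the bias column, $y$ the $m \times 1$ target vector, and $Xθ - y$ corresponds to `error_vector` in the code:

$$
J(\theta) = \frac{1}{2m}\,(X\theta - y)^\top (X\theta - y),
\qquad
\nabla_\theta J(\theta) = \frac{1}{m}\,X^\top (X\theta - y),
\qquad
\theta \leftarrow \theta - \alpha\,\frac{1}{m}\,X^\top (X\theta - y)
$$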

# Imports


```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
```

# Model Preparation

## Gradient descent


```python
def gradient_descent(m, theta, alpha, num_of_iterations, X, Y):
#     print(m, theta, alpha, num_of_iterations)
    for i in range(num_of_iterations):
        htheta_vector = np.dot(X,theta)
#         print(X.shape, theta.shape, htheta_vector.shape)
        error_vector = htheta_vector - Y
        gradient_vector = (1/m) * (np.dot(X.T, error_vector)) # each element in gradient_vector corresponds to each theta
        theta = theta - alpha * gradient_vector

    return theta
```
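A common way to catch divergence early is to record the cost after every update instead of only looking at the final θ. The sketch below is a hypothetical variant of the function above (`gradient_descent_with_history` and `compute_cost` are names introduced here, not part of the question); it uses the same update and returns the cost history alongside θ.

```python
import numpy as np


def compute_cost(X, Y, theta):
    """Squared-error cost J(theta) = 1/(2m) * sum((X @ theta - Y) ** 2)."""
    m = X.shape[0]
    return float(np.sum((X @ theta - Y) ** 2)) / (2 * m)


def gradient_descent_with_history(theta, alpha, num_of_iterations, X, Y):
    """Same vectorized update as gradient_descent, but also records J(theta)
    after every step so that divergence is visible immediately."""
    m = X.shape[0]
    cost_history = []
    for _ in range(num_of_iterations):
        error_vector = X @ theta - Y
        theta = theta - (alpha / m) * (X.T @ error_vector)
        cost_history.append(compute_cost(X, Y, theta))
    return theta, cost_history
```

Usage would be something like `theta, J = gradient_descent_with_history(theta, alpha, 400, X_with_bias, Y)`; if `J` increases instead of decreasing, α is too large for the current scale of the features.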

# Main


```python
def main():
    df = pd.read_csv('data2.csv', header = None) #loading data
    data = df.values # converting dataframe to numpy array

    X = data[:, 0:2]
#     print(X.shape)
    Y = data[:, -1]

    m = (X.shape)[0] # number of training examples

    Y = Y.reshape(m, 1)

    ones = np.ones(shape = (m,1))
    X_with_bias = np.concatenate([ones, X], axis = 1)

    theta = np.zeros(shape = (3,1)) # two features, so three parameters

    alpha = 0.001
    num_of_iterations = 400

    theta = gradient_descent(m, theta, alpha, num_of_iterations, X_with_bias, Y) # calling gradient descent
#     print('Parameters learned: ' + str(theta))

if __name__ == '__main__':
    main()
```

Error:

    /home/krish-thorcode/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:8: RuntimeWarning: invalid value encountered in subtract

Error values for different iterations:

Iteration 1 [[-399900.] [-329900.] [-369000.] [-232000.] [-539900.] [-299900.] [-314900.] [-198999.] [-212000.] [-242500.] [-239999.] [-347000.] [-329999.] [-699900.] [-259900.] [-449900.] [-299900.] [-199900.] [-499998.] [-599000.] [-252900.] [-255000.] [-242900.] [-259900.] [-573900.] [-249900.] [-464500.] [-469000.] [-475000.] [-299900.] [-349900.] [-169900.] [-314900.] [-579900.] [-285900.] [-249900.] [-229900.] [-345000.] [-549000.] [-287000.] [-368500.] [-329900.] [-314000.] [-299000.] [-179900.] [-299900.] [-239500.]]

Iteration 2 [[1.60749981e+09] [1.22240841e+09] [1.83373661e+09] [1.08189071e+09] [2.29209231e+09] [1.51666004e+09] [1.17198560e+09] [1.09033113e+09] [1.05440030e+09] [1.14148964e+09] [1.48233053e+09] [1.52807496e+09] [1.44402895e+09] [3.42143452e+09] [9.68760976e+08] [1.75723592e+09] [1.00845873e+09] [9.44366284e+08] [1.99332644e+09] [2.31572369e+09] [1.35010833e+09] [1.44257442e+09] [1.22555224e+09] [1.49912323e+09] [2.97220331e+09] [8.40383843e+08] [1.11375611e+09] [1.92992696e+09] [1.68078878e+09] [2.01492327e+09] [1.40503327e+09] [7.64040689e+08] [1.55867654e+09] [2.39674784e+09] [1.38370165e+09] [1.09792232e+09] [9.46628911e+08] [1.62895368e+09] [3.22059730e+09] [1.65193796e+09] [1.27127807e+09] [1.70997383e+09] [1.96141565e+09] [9.16755655e+08] [6.50928858e+08] [1.41502023e+09] [9.19107783e+08]]

Iteration 3 [[-7.42664624e+12] [-5.64764378e+12] [-8.47145714e+12] [-4.99816153e+12] [-1.05893224e+13] [-7.00660901e+12] [-5.41467917e+12] [-5.03699402e+12] [-4.87109500e+12] [-5.27348843e+12] [-6.84776945e+12] [-7.05955046e+12] [-6.67127611e+12] [-1.58063228e+13] [-4.47576119e+12] [-8.11848565e+12] [-4.65930400e+12] [-4.36280860e+12] [-9.20918360e+12] [-1.06987452e+13] [-6.23711474e+12] [-6.66421140e+12] [-5.66176276e+12] [-6.92542434e+12] [-1.37308096e+13] [-3.88276038e+12] [-5.14641706e+12] [-8.91620784e+12] [-7.76550392e+12] [-9.30801176e+12] [-6.49125293e+12] [-3.52977344e+12] [-7.20074619e+12] [-1.10728954e+13] [-6.39242960e+12] [-5.07229174e+12] [-4.37339793e+12] [-7.52548475e+12] [-1.48779889e+13] [-7.63137769e+12] [-5.87354379e+12] [-7.89963490e+12] [-9.06093321e+12] [-4.23573710e+12] [-3.00737309e+12] [-6.53715005e+12] [-4.24632634e+12]]

Iteration 4 [[3.43099835e+16] [2.60912608e+16] [3.91368523e+16] [2.30907512e+16] [4.89210695e+16] [3.23694753e+16] [2.50149995e+16] [2.32701516e+16] [2.25037231e+16] [2.43627199e+16] [3.16356608e+16] [3.26140566e+16] [3.08202877e+16] [7.30228235e+16] [2.06773403e+16] [3.75061770e+16] [2.15252802e+16] [2.01555166e+16] [4.25450367e+16] [4.94265862e+16] [2.88145280e+16] [3.07876502e+16] [2.61564888e+16] [3.19944145e+16] [6.34342666e+16] [1.79377661e+16] [2.37756683e+16] [4.11915330e+16] [3.58754545e+16] [4.30016088e+16] [2.99886077e+16] [1.63070200e+16] [3.32663597e+16] [5.11551035e+16] [2.95320591e+16] [2.34332215e+16] [2.02044376e+16] [3.47666027e+16] [6.87340617e+16] [3.52558124e+16] [2.71348846e+16] [3.64951201e+16] [4.18601431e+16] [1.95684650e+16] [1.38936092e+16] [3.02006457e+16] [1.96173860e+16]]

Iteration 5 [[-1.58506940e+20] [-1.20537683e+20] [-1.80806345e+20] [-1.06675782e+20] [-2.26007951e+20] [-1.49542086e+20] [-1.15565519e+20] [-1.07504585e+20] [-1.03963801e+20] [-1.12552086e+20] [-1.46151974e+20] [-1.50672014e+20] [-1.42385073e+20] [-3.37354413e+20] [-9.55261885e+19] [-1.73272871e+20] [-9.94435428e+19] [-9.31154420e+19] [-1.96551642e+20] [-2.28343362e+20] [-1.33118767e+20] [-1.42234293e+20] [-1.20839027e+20] [-1.47809362e+20] [-2.93056729e+20] [-8.28697695e+19] [-1.09839996e+20] [-1.90298660e+20] [-1.65739180e+20] [-1.98660937e+20] [-1.38542837e+20] [-7.53359691e+19] [-1.53685556e+20] [-2.36328850e+20] [-1.36433652e+20] [-1.08257943e+20] [-9.33414495e+19] [-1.60616452e+20] [-3.17540981e+20] [-1.62876527e+20] [-1.25359067e+20] [-1.68601941e+20] [-1.93387537e+20] [-9.04033523e+19] [-6.41863754e+19] [-1.39522421e+20] [-9.06293597e+19]]

Iteration 83 [[-1.09904300e+306] [-8.35774743e+305] [-1.25366087e+306] [-7.39660179e+305] [-1.56707622e+306] [-1.03688320e+306] [-8.01299137e+305] [-7.45406868e+305] [-7.20856058e+305] [-7.80404831e+305] [-1.01337710e+306] [-1.04471781e+306] [-9.87258464e+305] [-2.33912159e+306] [-6.62352000e+305] [-1.20142586e+306] [-6.89513844e+305] [-6.45636555e+305] [-1.36283437e+306] [-1.58326931e+306] [-9.23008472e+305] [-9.86212994e+305] [-8.37864174e+305] [-1.02486897e+306] [-2.03197378e+306] [-5.74595914e+305] [-7.61599955e+305] [-1.31947793e+306] [-1.14918934e+306] [-1.37745963e+306] [-9.60617469e+305] [-5.22358639e+305] [-1.06561287e+306] [-1.63863846e+306] [-9.45992963e+305] [-7.50630445e+305] [-6.47203628e+305] [-1.11366977e+306] [-2.20174077e+306] [-1.12934050e+306] [-8.69204879e+305] [-1.16903893e+306] [-1.34089535e+306] [-6.26831680e+305] [-4.45050460e+305] [-9.67409627e+305] [-6.28398753e+305]]

Iteration84 [[inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf]
[inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf]]
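The printed values already show what is going on. The sketch below takes the first entry of the error vector from iterations 1 to 5 (copied from the log above) and computes the ratio between consecutive iterations. It assumes the printed vector is `htheta_vector - Y` at the start of each iteration, which is consistent with iteration 1 being exactly `-Y` for θ = 0.

```python
import numpy as np

# First entry of the printed error vector for iterations 1-5 (copied from the log).
first_errors = np.array([-399900.0, 1.60749981e09, -7.42664624e12,
                         3.43099835e16, -1.58506940e20])

# Ratio between consecutive iterations, i.e. the per-iteration growth factor.
ratios = first_errors[1:] / first_errors[:-1]
print(ratios)  # roughly [-4020, -4620, -4620, -4620]

alpha = 0.001  # learning rate used in main()

# For linear regression the error vector obeys e_next = (I - (alpha/m) X X^T) e,
# so once the dominant direction takes over, the ratio approaches
# 1 - alpha * lambda_max, where lambda_max is the largest eigenvalue of (1/m) X^T X.
growth = ratios[-1]
lambda_max = (1 - growth) / alpha
print(f"implied lambda_max ~ {lambda_max:.3g}")
print(f"largest stable alpha ~ {2 / lambda_max:.2g}")
```

The ratio settles at about -4620: every iteration multiplies the error by roughly -4620, so it alternates in sign and grows by more than three orders of magnitude per step, consistent with the values reaching about 1e306 at iteration 83 and `inf` at iteration 84. With these unscaled features α would have to be below roughly 4e-7, which is why rescaling the features is the practical fix.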


2 Answers

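Try feature normalization. Your feature values are large numbers, and with large values the cost function (squared error) grows very quickly: the gradient is huge, so a step of α times the gradient overshoots the minimum instead of approaching it. That is why the error in your output flips sign and grows by roughly the same factor on every iteration. In general, perform mean normalization and feature scaling before minimizing a cost function with gradient descent.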
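A sketch of why scaling helps, using synthetic data because data2.csv is not available here (the feature ranges are assumptions chosen to resemble house sizes in the thousands and bedroom counts in single digits). For the squared-error cost the Hessian is (1/m) XᵀX, and batch gradient descent converges only when α < 2/λ_max of that matrix. `largest_stable_alpha` is a hypothetical helper name.

```python
import numpy as np


def largest_stable_alpha(X_with_bias):
    """Return 2 / lambda_max, where lambda_max is the largest eigenvalue of
    (1/m) X^T X, the Hessian of J(theta) = 1/(2m) ||X theta - y||^2.
    Batch gradient descent converges only for 0 < alpha < this value."""
    m = X_with_bias.shape[0]
    hessian = X_with_bias.T @ X_with_bias / m
    return 2.0 / np.linalg.eigvalsh(hessian).max()


# Hypothetical synthetic features: size in the thousands, bedrooms 1-5.
rng = np.random.default_rng(42)
m = 47
size = rng.uniform(850, 4500, m)
bedrooms = rng.integers(1, 6, m).astype(float)
X = np.column_stack([size, bedrooms])
ones = np.ones((m, 1))

X_raw = np.concatenate([ones, X], axis=1)

# Mean normalization + feature scaling: subtract each column's mean, divide by its std.
X_scaled = np.concatenate([ones, (X - X.mean(axis=0)) / X.std(axis=0)], axis=1)

print(largest_stable_alpha(X_raw))     # far below 0.001
print(largest_stable_alpha(X_scaled))  # at least 1
```

After standardization the Hessian has eigenvalues 1 and 1 ± ρ, where ρ is the correlation between the two features, so the bound is at least 1 and both α = 0.001 and α = 0.01 are safe.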

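Do feature normalization. Assuming this is your dataset, the first column of X is in the thousands, the second in single units, and Y in the hundreds of thousands. Use `sklearn.preprocessing.scale` to standardize all data columns and the target (zero mean, unit variance), or `sklearn.preprocessing.minmax_scale` to map them to [0, 1]. You can also use this quick-and-dirty normalization: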

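```python
X = X.astype(float)  # df.values is an integer array for an all-integer CSV; in-place division would truncate
X[:, 0] = X[:, 0] / np.max(X[:, 0])
X[:, 1] = X[:, 1] / np.max(X[:, 1])
Y = Y / np.max(Y)
```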
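A minimal sketch of how the max scaling might slot into `main()`, assuming integer data like the course file; synthetic numbers stand in for data2.csv, and the data-generating formula and the sample house are made up for illustration. It keeps the maxima so that a prediction can be converted back to the original price units.

```python
import numpy as np


def gradient_descent(m, theta, alpha, num_of_iterations, X, Y):
    # Same vectorized update as in the question.
    for _ in range(num_of_iterations):
        error_vector = X @ theta - Y
        theta = theta - alpha * (1 / m) * (X.T @ error_vector)
    return theta


def compute_cost(X, Y, theta):
    m = X.shape[0]
    return float(np.sum((X @ theta - Y) ** 2)) / (2 * m)


# Hypothetical integer data standing in for data2.csv (size, bedrooms, price).
rng = np.random.default_rng(0)
m = 47
size = rng.integers(850, 4500, m)
bedrooms = rng.integers(1, 6, m)
price = 50_000 + 120 * size + 10_000 * bedrooms + rng.integers(-20_000, 20_000, m)
data = np.column_stack([size, bedrooms, price])  # int64, like df.values for an all-integer CSV

X = data[:, 0:2].astype(float)  # cast first so that the divisions below are not truncated
Y = data[:, -1].astype(float).reshape(m, 1)

# Max normalization from the answer; keep the maxima to undo it later.
x_max = np.max(X, axis=0)
y_max = np.max(Y)
X = X / x_max
Y = Y / y_max

X_with_bias = np.concatenate([np.ones((m, 1)), X], axis=1)
theta0 = np.zeros((3, 1))
theta = gradient_descent(m, theta0, 0.01, 400, X_with_bias, Y)

# Predict for a new house: scale its features the same way, then rescale the output.
new_house = np.array([1650.0, 3.0])
predicted_price = y_max * (theta[0, 0] + (new_house / x_max) @ theta[1:, 0])
print(theta.ravel(), predicted_price)
```

Dividing by the maximum keeps every feature in (0, 1], so α = 0.01 is well inside the stable range; centering the features as well (mean normalization) usually makes convergence faster, because the uncentered columns are strongly correlated with the bias column.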

I reran your code with this normalization, and θ converged to [[0.81705857], [0.98398577], [0.98398577]].

For future questions, please include a link to the data file or a summary of the DataFrame.
