Regression error in two-class classification

Posted 2024-04-26 04:51:09


I am trying logistic regression with g1 as class 0 and g2 as class 1:

ft = np.vstack((g1, g2))  # data stacked on top of each other
cls = np.hstack((np.zeros(f), np.ones(f)))  # class labels in a 1-D array (f samples per class)
clc = np.reshape(cls, (2 * f, 1))  # class labels reshaped to a column vector
w = np.zeros((2, 1))  # weight matrix

for n in range(5000):
    s = np.dot(ft, w)

    prediction = 1 / (1 + np.exp(-s))  # sigmoid function

    gr = np.dot(ft.T, cls - prediction)  # gradient of loss function
    w += 0.01 * gr
print(w)

I checked my result with sklearn:

from sklearn.linear_model import LogisticRegression

and got:

w = [[6.77812323]
 [2.91052504]]

coef_ = [[1.22724506 1.10456893]]

Why don't the weights match? Is there something wrong with my math?


1 Answer

#1 · Posted 2024-04-26 04:51:09

You are just missing the step that averages the gradient. Note also that the labels used here are the column vector clc, not the 1-D label array:
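For reference (standard logistic-regression notation, my own addition rather than part of the original answer): with labels $y \in \{0,1\}^N$, design matrix $X$, and predictions $\hat p = \sigma(Xw)$, this update is gradient ascent on the average log-likelihood:

```latex
\ell(w) = \frac{1}{N}\sum_{i=1}^{N}\Bigl[y_i \log \sigma(x_i^\top w)
        + (1-y_i)\log\bigl(1-\sigma(x_i^\top w)\bigr)\Bigr],
\qquad
\nabla_w \ell = \frac{1}{N}\, X^\top (y - \hat p),
\qquad
w \leftarrow w + \eta\,\nabla_w \ell
```

The $\frac{1}{N}$ factor is exactly the gradient average that the code below adds with `gr /= N`.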

N = len(ft)  # number of samples
for _ in range(5000):
    s = np.dot(ft, w)
    prediction = 1 / (1 + np.exp(-s))  # sigmoid function

    gr = np.dot(ft.T, clc - prediction)  # gradient of loss function
    gr /= N  # calculate gradient average

    w += 0.01 * gr  # update weights
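The switch to the column-vector labels is not cosmetic: subtracting the (N, 1) predictions from a 1-D label array broadcasts to an N×N matrix and corrupts the gradient. A minimal shape check (toy values, my own illustration):

```python
import numpy as np

# Toy shapes only: 4 samples, 2 features (the values do not matter here)
ft = np.ones((4, 2))
w = np.zeros((2, 1))
prediction = 1 / (1 + np.exp(-np.dot(ft, w)))  # shape (4, 1)

cls = np.array([0., 0., 1., 1.])  # 1-D labels, shape (4,)
clc = cls.reshape(4, 1)           # column vector, shape (4, 1)

# 1-D labels broadcast against the (4, 1) predictions into a (4, 4) matrix;
# the column vector keeps the expected (4, 1) shape.
print((cls - prediction).shape)               # (4, 4) -- wrong
print((clc - prediction).shape)               # (4, 1) -- right
print(np.dot(ft.T, clc - prediction).shape)   # (2, 1) -- matches w
```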

Here is the complete code:

import numpy as np
from sklearn.linear_model import LogisticRegression

np.random.seed(1)  # for reproducibility

f = 100
mean1 = [-5, -3]
cov1 = [[5, 0], [0, 3]]
mean2 = [4, 3]
cov2 = [[3, 0], [0, 2]]
g1 = np.random.multivariate_normal(mean1, cov1, f)
g2 = np.random.multivariate_normal(mean2, cov2, f)

ft = np.vstack((g1, g2))  # data stacked on each other
cls = np.hstack((np.zeros(f), np.ones(f)))  # class labels in a 1-D array
clc = np.reshape(cls, (2 * f, 1))  # class labels reshaped to a column vector
w = np.zeros((2, 1))  # weights matrix

N = len(ft)
for _ in range(5000):
    s = np.dot(ft, w)
    prediction = 1 / (1 + np.exp(-s))  # sigmoid function

    gr = np.dot(ft.T, clc - prediction)  # gradient of loss function
    gr /= N  # calculate gradient average

    w += 0.01 * gr  # update weights

print("w = {}".format(w))

lr = LogisticRegression(fit_intercept=False)
lr.fit(ft, cls)
print("lr.coef_ = {}".format(lr.coef_))

Output:

w = [[1.28459432]
 [1.07186532]]
lr.coef_ = [[1.23311932 1.0363586 ]]
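The remaining small gap between w and lr.coef_ is expected: sklearn's LogisticRegression applies an L2 penalty by default (C=1.0), which shrinks the coefficients slightly. As a sanity check (my own addition, reusing the answer's setup), you can compare the two weight vectors on the predictions they produce:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

np.random.seed(1)  # same seed and data as the answer above
f = 100
g1 = np.random.multivariate_normal([-5, -3], [[5, 0], [0, 3]], f)
g2 = np.random.multivariate_normal([4, 3], [[3, 0], [0, 2]], f)
ft = np.vstack((g1, g2))
cls = np.hstack((np.zeros(f), np.ones(f)))
clc = cls.reshape(-1, 1)

# Averaged-gradient ascent, as in the answer
w = np.zeros((2, 1))
for _ in range(5000):
    p = 1 / (1 + np.exp(-np.dot(ft, w)))
    w += 0.01 * np.dot(ft.T, clc - p) / len(ft)

lr = LogisticRegression(fit_intercept=False).fit(ft, cls)

manual = (1 / (1 + np.exp(-np.dot(ft, w))) >= 0.5).ravel()
sk = lr.predict(ft).astype(bool)
print("agreement:", (manual == sk).mean())
```

On this well-separated data both weight vectors should classify essentially all points the same way, even though their magnitudes differ.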
