我如何用sklearn按组进行多元单变量回归?

2024-04-19 19:22:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图复制这个解决方案Python pandas: how to run multiple univariate regression by group,但是使用sklearn线性回归而不是statsmodels。你知道吗

import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression

df = pd.DataFrame({
  'y': np.random.randn(20),
  'x1': np.random.randn(20), 
  'x2': np.random.randn(20),
  'grp': ['a', 'b'] * 10})


def ols_res(x, y):
    return pd.Series(LinearRegression.fit(x,y).predict(x))


results = df.groupby('grp').apply(lambda x : x[['x1', 'x2']].apply(ols_res, y=x['y']))

print(results)

我得到:

TypeError: ("fit() missing 1 required positional argument: 'y'", 'occurred at index x1')

结果应与第一篇链接的文章相同,即:

             x1        x2
grp                      
a   0 -0.102766 -0.205196
    1 -0.073282 -0.102290
    2  0.023832  0.033228
    3  0.059369 -0.017519
    4  0.003281 -0.077150
        ...       ...
b   5  0.072874 -0.002919
    6  0.180362  0.000502
    7  0.005274  0.050313
    8 -0.065506 -0.005163
    9  0.003419 -0.013829


Tags: importpandasdfasnpresrandomsklearn
1条回答
网友
1楼 · 发布于 2024-04-19 19:22:52

您的代码有两个小问题:

  1. 您没有实例化LinearRegression对象,因此您的代码实际上尝试调用LinearRegression的未绑定fit方法。

  2. 即使修复了此问题,LinearRegression实例也将无法执行fittransform,因为它需要一个2D数组,而得到一个1D数组。因此,您还需要重塑每个Series中包含的数组。

import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression

df = pd.DataFrame({
  'y': np.random.randn(20),
  'x1': np.random.randn(20), 
  'x2': np.random.randn(20),
  'grp': ['a', 'b'] * 10})

def ols_res(x, y):
    x_2d = x.values.reshape(len(x), -1)
    return pd.Series(LinearRegression().fit(x_2d, y).predict(x_2d))

results = df.groupby('grp').apply(lambda df: df[['x1', 'x2']].apply(ols_res, y=df['y']))

print(results)

输出:

             x1        x2
grp                      
a   0 -0.126680  0.137907
    1 -0.441300 -0.595972
    2 -0.285903 -0.385033
    3 -0.252434  0.560938
    4 -0.046632 -0.718514
    5 -0.267396 -0.693155
    6 -0.364425 -0.476643
    7 -0.221493 -0.779082
    8 -0.203781  0.722860
    9 -0.106912 -0.090262
b   0 -0.015384  0.092137
    1  0.478447  0.032881
    2  0.366102  0.059832
    3 -0.055907  0.055388
    4 -0.221876  0.013941
    5 -0.054299  0.048263
    6  0.043979  0.024594
    7 -0.307831  0.059972
    8 -0.226570 -0.024809
    9  0.394460  0.038921

相关问题 更多 >