如何在groupby中应用statsmodels的OLS

1 投票
2 回答
4999 浏览
提问于 2025-04-18 08:54

我正在按月份对产品进行普通最小二乘法(OLS)分析。对于单个产品来说,这个方法很好用,但我的数据表里有很多产品。如果我创建一个分组对象,OLS就会报错。

linear_regression_df:
  product_desc  period_num    TOTALS  
0    product_a     1          53  
3    product_a     2          52 
6    product_a     3          50 
1    product_b     1          44 
4    product_b     2          43 
7    product_b     3          41 
2    product_c     1          36   
5    product_c     2          35 
8    product_c     3          34 


from pandas import DataFrame, Series
import statsmodels.api as sm    

linear_regression_grouped = linear_regression_df.groupby(['product_desc'])
X = linear_regression_grouped['period_num'] 
y = linear_regression_grouped['TOTALS']

model = sm.OLS(y, X)
results = model.fit()

在sm.OLS()这一行,我收到了这个错误:

ValueError: unrecognized data structures: <class 'pandas.core.groupby.SeriesGroupBy'>

那么,我该如何遍历我的数据表,并对每个产品描述应用sm.OLS()呢?

2 个回答

2

使用 get_group 来获取每个单独的组,然后对每个组进行普通最小二乘法(OLS)模型的分析:

for group in linear_regression_grouped.groups.keys():
    df= linear_regression_grouped.get_group(group)
    X = df['period_num'] 
    y = df['TOTALS']
    model = sm.OLS(y, X)
    results = model.fit()
    print results.summary()

但是在实际情况中,你可能还想要一个截距项,所以模型的定义需要稍微改一下:

for group in linear_regression_grouped.groups.keys():
    df= linear_regression_grouped.get_group(group)
    df['constant']=1
    X = df[['period_num','constant']]
    y = df['TOTALS']
    model = sm.OLS(y,X)
    results = model.fit()
    print results.summary()

有截距项和没有截距项的结果,肯定是非常不同的。

2

你可以这样做...

import pandas as pd
import statsmodels.api as sm

for products in linear_regression_df.product_desc.unique():
    tempdf = linear_regression_df[linear_regression_df.product_desc == products]
    X = tempdf['period_num']
    y = tempdf['TOTALS']

    model = sm.OLS(y, X)
    results = model.fit()

    print results.params #  Or whatever summary info you want

撰写回答