如何在groupby中应用statsmodels的OLS
我正在按月份对产品进行普通最小二乘法(OLS)分析。对于单个产品来说,这个方法很好用,但我的数据表里有很多产品。如果我创建一个分组对象,OLS就会报错。
linear_regression_df:
product_desc period_num TOTALS
0 product_a 1 53
3 product_a 2 52
6 product_a 3 50
1 product_b 1 44
4 product_b 2 43
7 product_b 3 41
2 product_c 1 36
5 product_c 2 35
8 product_c 3 34
from pandas import DataFrame, Series
import statsmodels.api as sm
linear_regression_grouped = linear_regression_df.groupby(['product_desc'])
X = linear_regression_grouped['period_num']
y = linear_regression_grouped['TOTALS']
model = sm.OLS(y, X)
results = model.fit()
在sm.OLS()这一行,我收到了这个错误:
ValueError: unrecognized data structures: <class 'pandas.core.groupby.SeriesGroupBy'>
那么,我该如何遍历我的数据表,并对每个产品描述应用sm.OLS()呢?
2 个回答
2
使用 get_group
来获取每个单独的组,然后对每个组进行普通最小二乘法(OLS)模型的分析:
for group in linear_regression_grouped.groups.keys():
df= linear_regression_grouped.get_group(group)
X = df['period_num']
y = df['TOTALS']
model = sm.OLS(y, X)
results = model.fit()
print results.summary()
但是在实际情况中,你可能还想要一个截距项,所以模型的定义需要稍微改一下:
for group in linear_regression_grouped.groups.keys():
df= linear_regression_grouped.get_group(group)
df['constant']=1
X = df[['period_num','constant']]
y = df['TOTALS']
model = sm.OLS(y,X)
results = model.fit()
print results.summary()
有截距项和没有截距项的结果,肯定是非常不同的。
2
你可以这样做...
import pandas as pd
import statsmodels.api as sm
for products in linear_regression_df.product_desc.unique():
tempdf = linear_regression_df[linear_regression_df.product_desc == products]
X = tempdf['period_num']
y = tempdf['TOTALS']
model = sm.OLS(y, X)
results = model.fit()
print results.params # Or whatever summary info you want