statmodels 回归图无法与 pandas 数据类型配合使用
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.graphics as smg
data = pd.DataFrame({'Y': np.random.rand(1000), 'X':np.random.rand(1000)})
这个可以正常工作
smg.regressionplots.plot_fit(sm.OLS(data.Y.values, data.X.values).fit(), 0, y_true=None)
这个不行
smg.regressionplots.plot_fit(sm.OLS(data.Y, data.X).fit(), 0, y_true=None)
smg.regressionplots.plot_fit(sm.OLS(data['Y'], data['X']).fit(), 0, y_true=None)
2 个回答
4
我查了一下,确实是在plot_fit
这段代码里有个bug。在稳定版本中,你会看到这一行:
prstd, iv_l, iv_u = wls_prediction_std(results)
这行代码返回了iv_l
和iv_u
,大概是用来表示拟合值的标准差的上下限,返回的是pandas的Series格式。这导致后面调用ax.fill_between
的时候出错了。
这个问题在开发版本中似乎已经修复了,你可以在这个链接找到相关代码:https://github.com/statsmodels/statsmodels/blob/master/statsmodels/graphics/regressionplots.py。在那儿你会看到一个不同的调用:
prstd, iv_l, iv_u = wls_prediction_std(results._results)
现在iv_l
和iv_u
变成了numpy数组,如果你这样做的话就不会再出错了:
smg.regressionplots.plot_fit(sm.OLS(data['Y'], data['X']).fit(), 0, y_true=None)
目前你只能接受这个结果:
smg.regressionplots.plot_fit(sm.OLS(data.Y.values, data.X.values).fit(), 0, y_true=None)
尽管这和通常的线性回归调用不太一致。
3
这个错误信息告诉我们发生了什么。简单来说:
/usr/lib/pymodules/python2.7/matplotlib/axes.pyc in fill_between(self, x, y1, y2, where, interpolate, **kwargs)
6542 start = xslice[0], y2slice[0]
-> 6543 end = xslice[-1], y2slice[-1]
[...]
/usr/local/lib/python2.7/dist-packages/pandas-0.11.0.dev_fc8de6d-py2.7-linux-i686.egg/pandas/core/index.pyc in get_value(self, series, key)
725 try:
--> 726 return self._engine.get_value(series, key)
727 except KeyError, e1:
728 if len(self) > 0 and self.inferred_type == 'integer':
[...]
KeyError: -1L
data.X
和 data.Y
是一种叫做 Series
的对象,而你不能用 [-1]
来获取最后一个元素。如果可以这样做,当你的索引中有一个元素是 -1
时,就会出现问题:你是想要最后一个元素,还是想要和 -1
相关联的那个元素呢?
pandas
遵循了“在模糊的情况下,拒绝猜测”的原则,因此不允许这样做,它更看重标签而不是位置。所以你会得到一个 KeyError
,而不是 IndexError
,这也提示了我们这个问题。你可以查看文档中关于 使用整数标签的高级索引 的讨论,了解更多信息。