statmodels 回归图无法与 pandas 数据类型配合使用

0 投票
2 回答
1445 浏览
提问于 2025-04-17 19:47
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.graphics as smg

data = pd.DataFrame({'Y': np.random.rand(1000), 'X':np.random.rand(1000)})

这个可以正常工作

smg.regressionplots.plot_fit(sm.OLS(data.Y.values, data.X.values).fit(), 0, y_true=None)

这个不行

smg.regressionplots.plot_fit(sm.OLS(data.Y, data.X).fit(), 0, y_true=None)
smg.regressionplots.plot_fit(sm.OLS(data['Y'], data['X']).fit(), 0, y_true=None)

2 个回答

4

我查了一下,确实是在plot_fit这段代码里有个bug。在稳定版本中,你会看到这一行:

prstd, iv_l, iv_u = wls_prediction_std(results)

这行代码返回了iv_liv_u,大概是用来表示拟合值的标准差的上下限,返回的是pandas的Series格式。这导致后面调用ax.fill_between的时候出错了。

这个问题在开发版本中似乎已经修复了,你可以在这个链接找到相关代码:https://github.com/statsmodels/statsmodels/blob/master/statsmodels/graphics/regressionplots.py。在那儿你会看到一个不同的调用:

prstd, iv_l, iv_u = wls_prediction_std(results._results)

现在iv_liv_u变成了numpy数组,如果你这样做的话就不会再出错了:

smg.regressionplots.plot_fit(sm.OLS(data['Y'], data['X']).fit(), 0, y_true=None)

目前你只能接受这个结果:

smg.regressionplots.plot_fit(sm.OLS(data.Y.values, data.X.values).fit(), 0, y_true=None)

尽管这和通常的线性回归调用不太一致。

3

这个错误信息告诉我们发生了什么。简单来说:

/usr/lib/pymodules/python2.7/matplotlib/axes.pyc in fill_between(self, x, y1, y2, where, interpolate, **kwargs)

   6542                 start = xslice[0], y2slice[0]
-> 6543                 end = xslice[-1], y2slice[-1]

[...]
/usr/local/lib/python2.7/dist-packages/pandas-0.11.0.dev_fc8de6d-py2.7-linux-i686.egg/pandas/core/index.pyc in get_value(self, series, key)

    725         try:
--> 726             return self._engine.get_value(series, key)
    727         except KeyError, e1:
    728             if len(self) > 0 and self.inferred_type == 'integer':

[...]

KeyError: -1L

data.Xdata.Y 是一种叫做 Series 的对象,而你不能用 [-1] 来获取最后一个元素。如果可以这样做,当你的索引中有一个元素是 -1 时,就会出现问题:你是想要最后一个元素,还是想要和 -1 相关联的那个元素呢?

pandas 遵循了“在模糊的情况下,拒绝猜测”的原则,因此不允许这样做,它更看重标签而不是位置。所以你会得到一个 KeyError,而不是 IndexError,这也提示了我们这个问题。你可以查看文档中关于 使用整数标签的高级索引 的讨论,了解更多信息。

撰写回答