实现自定义scikit-learn估计器的完整规范是什么?
我正在自己做一个预测器,想像使用scikit库里的其他工具(比如RandomForestRegressor)那样来使用它。我有一个类,里面有fit
和predict
这两个方法,看起来都能正常工作。但是,当我尝试使用一些scikit的方法,比如交叉验证时,却出现了错误,内容像这样:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in cross_val_
score
for train, test in cv)
File "C:\Python27\lib\site-packages\sklearn\externals\joblib\parallel.py", line 516, in __
call__
for function, args, kwargs in iterable:
File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in <genexpr>
for train, test in cv)
File "C:\Python27\lib\site-packages\sklearn\base.py", line 43, in clone
% (repr(estimator), type(estimator)))
TypeError: Cannot clone object '<__main__.Custom instance at 0x033A6990>' (type <type 'inst
ance'>): it does not seem to be a scikit-learn estimator a it does not implement a 'get_para
ms' methods.
我发现它希望我实现一些方法(可能是get_params
,还有可能是set_params
和score
),但我不太确定这些方法应该怎么写。关于这个话题,有没有什么资料可以参考?谢谢。
1 个回答
18
完整的说明可以在scikit-learn文档中找到,API背后的原理在这篇论文中有详细介绍。简单来说,除了fit
方法,你的估计器还需要get_params
和set_params
这两个方法。get_params
用来返回估计器的超参数(以dict
的形式),而set_params
则用来设置这些超参数(通过关键字参数)。这些超参数是学习算法本身的参数,而不是它学习的数据参数。这些参数应该和__init__
方法中的参数一致。
这两个方法可以通过继承sklearn.base
中的类来获得,但如果你不想让你的代码依赖于scikit-learn,也可以自己实现这两个方法。
需要注意的是,输入验证应该在fit
方法中进行,而不是在构造函数中,因为如果在set_params
中设置了无效的参数,fit
可能会以意想不到的方式失败。