Making an ML model compatible with scikit-learn

Posted 2024-06-16 10:09:56


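I want to make this ML model scikit-learn compatible: https://github.com/manifoldai/merf

To do this, I followed the instructions at https://danielhnyk.cz/creating-your-own-estimator-scikit-learn/, added the import from sklearn.base import BaseEstimator, RegressorMixin, and inherited from both like this: class MERF(BaseEstimator, RegressorMixin):

However, when I check scikit-learn compatibility: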

from sklearn.utils.estimator_checks import check_estimator

import merf
check_estimator(merf)

I get this error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 500, in check_estimator
    for estimator, check in checks_generator:
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in _generate_instance_checks
    yield from ((estimator, partial(check, name))
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in <genexpr>
    yield from ((estimator, partial(check, name))
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 232, in _yield_all_checks
    tags = estimator._get_tags()
AttributeError: module 'merf' has no attribute '_get_tags'

How can I make this model compatible with scikit-learn?


1 answer
User
Answer 1 · posted 2024-06-16 10:09:56

According to the docs, check_estimator is used to "Check if estimator adheres to scikit-learn conventions."

This estimator will run an extensive test-suite for input validation, shapes, etc, making sure that the estimator complies with scikit-learn conventions as detailed in Rolling your own estimator. Additional tests for classifiers, regressors, clustering or transformers will be run if the Estimator class inherits from the corresponding mixin from sklearn.base.

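So check_estimator is more than a compatibility check: it also verifies that you follow all of the conventions.

You can read "Rolling your own estimator" carefully to make sure you follow them.

Then you need to pass an instance of the estimator class to check_estimator, for example check_estimator(MERF()). The AttributeError in your traceback comes from passing the merf module itself, which has none of the estimator methods. To make the class truly follow all the conventions, you have to work through every error the check raises and fix them one by one.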
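As a rough sketch of that workflow (MeanRegressor is a made-up stand-in, not MERF), you could wrap the call so the first failing check is reported instead of raised; fix it, rerun, and repeat. I am not assuming the toy class passes every check, and which check fails first may depend on your scikit-learn version.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.estimator_checks import check_estimator


# Hypothetical stand-in for MERF: always predicts the training mean.
# The mixin is listed before BaseEstimator, the order recommended by
# scikit-learn's developer guide.
class MeanRegressor(RegressorMixin, BaseEstimator):
    def __init__(self, shrink=1.0):
        self.shrink = shrink

    def fit(self, X, y):
        self.mean_ = self.shrink * float(np.mean(y))
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


# Pass an instance, not the module and not the class.
estimator = MeanRegressor()

try:
    check_estimator(estimator)
    result = "all checks passed"
except Exception as exc:  # failed checks surface as ordinary exceptions
    result = f"{type(exc).__name__}: {exc}"

print(result)
```

Each rerun should get a little further through the test suite until nothing is raised.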

For example, one such check is that the __init__ method only sets the attributes it accepts as parameters.

The MERF class violates this rule:

    def __init__(
        self,
        fixed_effects_model=RandomForestRegressor(n_estimators=300, n_jobs=-1),
        gll_early_stop_threshold=None,
        max_iterations=20,
    ):
        self.gll_early_stop_threshold = gll_early_stop_threshold
        self.max_iterations = max_iterations

        self.cluster_counts = None
        # Note fixed_effects_model must already be instantiated when passed in.
        self.fe_model = fixed_effects_model
        self.trained_fe_model = None
        self.trained_b = None

        self.b_hat_history = []
        self.sigma2_hat_history = []
        self.D_hat_history = []
        self.gll_history = []
        self.val_loss_history = []

It sets attributes such as self.b_hat_history even though they are not parameters.
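A constructor that follows the convention might look like the sketch below. This is my own hypothetical rework, not code from the merf repository: the class name MERFSketch is invented, the fit body is a placeholder that only fits the fixed-effects model and ignores MERF's random-effects logic, and the trailing-underscore attribute names are assumptions. The point is only that __init__ stores each argument unchanged under its own name, while defaults and learned state are handled in fit.

```python
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.ensemble import RandomForestRegressor


class MERFSketch(RegressorMixin, BaseEstimator):
    """Hypothetical constructor rework; not the real MERF implementation."""

    def __init__(
        self,
        fixed_effects_model=None,
        gll_early_stop_threshold=None,
        max_iterations=20,
    ):
        # Only store the arguments, each under the parameter's own name.
        self.fixed_effects_model = fixed_effects_model
        self.gll_early_stop_threshold = gll_early_stop_threshold
        self.max_iterations = max_iterations

    def fit(self, X, y):
        # Resolve the default here instead of in the signature, so no
        # estimator object is shared between instances.
        fe_model = self.fixed_effects_model
        if fe_model is None:
            fe_model = RandomForestRegressor(n_estimators=300, n_jobs=-1)
        # Placeholder for the real training loop: fit the fixed-effects part only.
        self.trained_fe_model_ = clone(fe_model).fit(X, y)
        # Learned state and histories are created during fit.
        self.gll_history_ = []
        return self

    def predict(self, X):
        return self.trained_fe_model_.predict(X)


model = MERFSketch(max_iterations=5)
params = model.get_params()  # works because attribute names match parameters
cloned = clone(model)        # clone() relies on get_params() as well
print(sorted(params))
```

With the original constructor, get_params() would look for an attribute named fixed_effects_model and not find it, because the value is stored as self.fe_model.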

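There are many more checks like this.

My personal advice: unless you really need to, don't try to satisfy all of these conditions. Just inherit from the mixin and the base class, implement the required methods, and use the model.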
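To illustrate that lighter route, here is a minimal sketch with an invented toy model (ShrunkenMeanRegressor and its shrinkage parameter are hypothetical). It only inherits from the two base classes and implements fit and predict, yet that is enough for it to be used with GridSearchCV, which needs get_params, set_params and score.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import GridSearchCV
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class ShrunkenMeanRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical toy model: predicts a shrunken mean of the training targets."""

    def __init__(self, shrinkage=1.0):
        self.shrinkage = shrinkage

    def fit(self, X, y):
        X, y = check_X_y(X, y)  # basic input validation
        self.n_features_in_ = X.shape[1]
        self.mean_ = self.shrinkage * float(np.mean(y))
        return self

    def predict(self, X):
        check_is_fitted(self)  # raises NotFittedError before fit()
        X = check_array(X)
        return np.full(X.shape[0], self.mean_)


# Synthetic data: targets scattered around 5.0, independent of X.
rng = np.random.default_rng(0)
X = rng.normal(size=(60, 3))
y = 5.0 + rng.normal(size=60)

# score() comes from RegressorMixin, get_params()/set_params() from BaseEstimator.
search = GridSearchCV(ShrunkenMeanRegressor(), {"shrinkage": [0.0, 0.5, 1.0]}, cv=3)
search.fit(X, y)
print(search.best_params_)
```

None of this required running check_estimator; the full test suite matters mostly if you plan to contribute the estimator or need guarantees across all scikit-learn utilities.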
