我修改了scikit的BernoulliRBM类,学习使用softmax可见单元组。在这个过程中,我添加了一个额外的Numpy数组visible_config
作为类属性,它在构造函数中初始化如下:
self.visible_config = np.cumsum(np.concatenate((np.asarray([0]),
visible_config), axis=0))
其中visible_config
是作为输入传递给构造函数的Numpy数组。当我直接使用fit()
函数来训练模型时,代码运行没有错误。但是,当我使用GridSearchCV
结构时,会得到以下错误
Cannot clone object SoftmaxRBM(batch_size=100, learning_rate=0.01, n_components=100, n_iter=100,
random_state=0, verbose=True, visible_config=[ 0 21 42 63]), as the constructor does not seem to set parameter visible_config
这似乎是类实例与其由sklearn.base.clone创建的副本之间的相等性检查中的问题,因为visible_config
没有正确复制。我不知道怎么解决这个问题。文档中说sklearn.base.clone
使用deepcopy()
,所以visible_config
不应该也被复制吗?有人能解释一下我能在这里试什么吗?谢谢!
如果没有看到您的代码,很难准确地判断出什么地方出了问题,但是您违反了这里的scikit-learn-API约定。估计器中的构造函数只应为用户作为参数传递的值设置属性。所有计算都应在
fit
中进行,如果fit
需要存储计算结果,则应在带有尾随下划线(_
)的属性中进行。这个约定使得clone
和GridSearchCV
等元估计器工作。(*)如果您在主代码库中看到违反此规则的估计器:那将是一个bug,欢迎使用修补程序。
相关问题 更多 >
编程相关推荐