Keras神经网络的部分依赖图

2024-05-14 19:18:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个紧密连接的神经网络,它是用Keras Sequential API构建的。我正在尝试创建一些部分依赖图(PDP),以用于敏感性分析。我正试图使用scikit-learn plot_partial_dependence函数来实现这一点。我得到了以下错误:ValueError: 'estimator' must be a fitted regressor or classifier.。当它第一次出现时,我添加了KerasClassifier的用法。我过去在scikit-learn GridSearchCV中成功地使用了它来使用我的Keras模型。我还是会犯同样的错误。我也试过了

有谁能告诉我出了什么问题,我该如何解决?我是否绝对需要使用scikit learn基于决策树的函数才能使用PDP函数?如果是,Keras神经网络和决策树之间最大的实现差异是什么?(我从未使用过决策树。我的机器学习经验仅限于Keras。)

下面是我的相关代码,我正在GoogleColab的GPU上运行python。我确信最后一行中有几个问题,但我无法通过这一行来解决它们

from sklearn.inspection import plot_partial_dependence
from keras.wrappers.scikit_learn import KerasClassifier
from keras.wrappers.scikit_learn import KerasRegressor

def create_model():
  def swish(x):
    return (x*sigmoid(x))

  from keras.utils.generic_utils import get_custom_objects
  from keras.layers import Activation
  get_custom_objects().update({'swish':(swish)})

  model=Sequential()

  model.add(Dense(1024,activation='swish',input_shape=(6,)))
  model.add(Dropout(.1))

  model.add(Dense(512,activation='swish'))

  model.add(Dense(256,activation='swish'))
  model.add(Dropout(.1))

  model.add(Dense(128,activation='swish'))

  model.add(Dense(64,activation='swish'))
  model.add(Dropout(.1))

  model.add(Dense(32,activation='swish'))

  model.add(Dense(16,activation='swish'))
  model.add(Dropout(.1))

  model.add(Dense(12, activation='softmax'))

  opt=optimizers.Adam(lr=0.05)

  model.compile(loss='categorical_crossentropy',optimizer='adam', metrics=['accuracy'])

  return model

from keras.callbacks import LearningRateScheduler
from keras.callbacks import EarlyStopping
import math
def scheduler(epoch, lr):
  if epoch < 20:
    return lr
  else:
    return lr * math.exp(-0.1)

callback=LearningRateScheduler(scheduler, verbose=1)

weightsCallback=EarlyStopping(patience=30,monitor='accuracy',restore_best_weights=True, min_delta=1*10**-5, verbose=1)

modelClassified=KerasClassifier(build_fn=create_model)

modelClassified.fit(X_train, Y_train, batch_size=50, epochs=50, callbacks=[callback,weightsCallback], verbose=1)

disp=plot_partial_dependence(modelClassified, X_holdout,target=1, verbose =1, features=[0,1,2,3,4,5], feature_names=['aspect ratio','diel inner radius','diel outer radius','diel seperation','diel height','diel constant'])

Tags: fromimportaddmodelreturnscikitactivationlearn
1条回答
网友
1楼 · 发布于 2024-05-14 19:18:30

我发现这个错误实际上是一个bug。无论如何,我的程序应该运行得很好。plot_partial_dependence函数源代码中存在错误

有关更多详细信息以及我用来使其工作的解决方案,请参阅指向另一个StackOverflow问题的链接:https://stackoverflow.com/a/61485502/13822019

相关问题 更多 >

    热门问题