网格搜索Keras时出错

2024-04-19 22:25:11 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用网格搜索优化技术来提高Python和Keras中深度学习模型的准确性。 下面的脚本我正在使用

# encode class values as integers
encoder = LabelEncoder()
encoder.fit(train["Group"])
encoder_name_mapping = dict(zip(encoder.classes_, encoder.transform(encoder.classes_)))
print(encoder_name_mapping)
encoded_Y = encoder.transform(train["Group"])

# convert integers to dummy variables (i.e. one hot encoded)
train_y = np_utils.to_categorical(encoded_Y)

    def create_model():
        model = Sequential()
        model.add(Dense(10, input_dim=train_data_features.shape[1], activation='relu'))
        model.add(Dense(len(list(set(train["Group"]))), activation='softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
        return model

    # fix random seed for reproducibility
    seed = 7
    numpy.random.seed(seed)

    # create model
    model = KerasClassifier(build_fn=create_model)

    # define the grid search parameters
    batch_size = [10, 20, 40, 60, 80, 100]
    epochs = [10, 50, 100]
    param_grid = dict(batch_size=batch_size, epochs=epochs)
    grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)

    grid_result = grid.fit(train_data_features, train_y)

但是,我是在错误下面。有人能帮我吗。在

^{pr2}$

Tags: integersnameencodersizemodelparamcreatebatch
2条回答

下面是一个如何将KerasClassifier与GridSearchCV一起使用的示例。 我认为很清楚,你如何适应它。在

def create_model(optimizer='adam'):
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

optimizer = ['SGD', 'RMSprop', 'Adagrad', 'Adadelta', 'Adam', 'Adamax', 'Nadam']
param_grid = dict(optimizer=optimizer)
model = KerasClassifier(build_fn=create_model, epochs=100, batch_size=10, verbose=0)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)
grid_result = grid.fit(X, Y)

如果使用GPU进行神经网络训练,请在GridSearchCV中设置n_jobs=1。你可能只有一个GPU,这个参数是针对CPU线程的。在

相关问题 更多 >