Normalizing the validation set for a neural network in Keras

Posted 2024-04-20 09:58:17


So, I know that normalization is important for training a neural network.

I also understand that I have to normalize the validation and test sets with the parameters from the training set (see this discussion: https://stats.stackexchange.com/questions/77350/perform-feature-normalization-before-or-within-model-validation).

My question is: how do I do this in Keras?

What I am currently doing is this:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping

def Normalize(data):
    # Normalize with the mean and std of the whole array that is passed in
    mean_data = np.mean(data)
    std_data = np.std(data)
    norm_data = (data - mean_data) / std_data
    return norm_data

# Load features and targets, then normalize the *entire* dataset at once
input_data, targets = np.loadtxt(fname='data', delimiter=';')
norm_input = Normalize(input_data)

model = Sequential()
model.add(Dense(25, input_dim=20, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

early_stopping = EarlyStopping(monitor='val_acc', patience=50) 
model.fit(norm_input, targets, validation_split=0.2, batch_size=15, callbacks=[early_stopping], verbose=1)

But here I first normalize the whole dataset and only then split off the validation set, which, according to the discussion above, is wrong.

Saving the training set's mean and standard deviation (training_mean and training_std) is no big deal, but how do I then apply the training mean and training std to the validation set separately?
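Conceptually, what is wanted is something like the following sketch (x_train and x_val are hypothetical stand-ins for the two splits; the statistics come from the training split only):

import numpy as np

# Hypothetical stand-ins for a training split and a validation split (20 features each)
x_train = np.random.rand(100, 20)
x_val = np.random.rand(25, 20)

# Per-feature statistics computed on the training split only
training_mean = np.mean(x_train, axis=0)
training_std = np.std(x_train, axis=0)

# The same training statistics are applied to both splits
x_train_norm = (x_train - training_mean) / training_std
x_val_norm = (x_val - training_mean) / training_std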


2 Answers

The code below does exactly what you want:

import numpy as np
def normalize(x_train, x_test):
    # Per-feature statistics computed on the training data only...
    mu = np.mean(x_train, axis=0)
    std = np.std(x_train, axis=0)
    # ...then applied to both the training and the test/validation data
    x_train_normalized = (x_train - mu) / std
    x_test_normalized = (x_test - mu) / std
    return x_train_normalized, x_test_normalized

You can then use it with Keras.
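A minimal sketch of that usage, assuming the question's input_data, targets, model, and early_stopping objects and a manual 80/20 split in place of validation_split:

# input_data, targets, model and early_stopping are the objects defined in the question
split = int(0.8 * len(input_data))
x_train, x_val = input_data[:split], input_data[split:]
y_train, y_val = targets[:split], targets[split:]

# Both splits are normalized with statistics taken from the training split only
x_train_norm, x_val_norm = normalize(x_train, x_val)

model.fit(x_train_norm, y_train,
          validation_data=(x_val_norm, y_val),
          batch_size=15, callbacks=[early_stopping], verbose=1)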

Wilmar's answer is incorrect.

You can split the data into training and test sets manually with sklearn.model_selection.train_test_split before fitting the model. Then normalize the training and test data separately and call model.fit with the validation_data argument.

Code example

import numpy as np
from sklearn.model_selection import train_test_split

# Toy data: 20 samples with 20 features (to match input_dim=20 in the question's model)
# and binary targets
data = np.random.randint(0, 100, (20, 20))
target = np.random.randint(0, 2, 20)

X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)

# Normalize() is the function from the question; each split is normalized here
# with its own statistics
X_train = Normalize(X_train)
X_test = Normalize(X_test)

# model and early_stopping are the objects defined in the question
model.fit(X_train, y_train, validation_data=(X_test, y_test), batch_size=15,
          callbacks=[early_stopping], verbose=1)
