一个Kerascompatible生成器,用于创建平衡批次
keras-balanced-batch-generator的Python项目详细描述
keras平衡批处理生成器:用于创建平衡批的keras兼容生成器
安装
pip install keras-balanced-batch-generator
概述
该模块实现了一个过采样算法来解决类不平衡的问题。 它生成balanced batches,即每个类的样本数平均相同的批。 生成的批处理也会被洗牌。在
发电机可轻松与Keras型号配合使用
^{
目前,只支持单输入单输出模型的NumPy数组。在
美国石油学会
^{pr2}$x
(努比·恩达雷)输入数据。长度必须与y
相同。在y
(努比·恩达雷)目标数据。必须是二进制类矩阵(即形状(num_samples, num_classes)
)。 可以使用^{} 将类向量转换为二进制类矩阵。在batch_size
(int)批大小。在categorical
(bool)如果为true,则为批处理目标生成二进制类矩阵(即shape(num_samples, num_classes)
)。 否则,生成类向量(即shape(num_samples,)
)。在- ^{str1}$
seed
随机种子(参见docs)。在 - 返回一个Keras兼容的生成器,生成的批处理为
(x, y)
元组。在
使用
fromkeras.modelsimportSequentialfromkeras_balanced_batch_generatorimportmake_generatorx=...y=...batch_size=...steps_per_epoch=...model=Sequential(...)generator=make_generator(x,y,batch_size)model.fit(generator,steps_per_epoch=steps_per_epoch)
示例:多类分类
importnumpyasnpfromkeras.utilsimportto_categoricalfromkeras.modelsimportSequentialfromkeras.layersimportDensefromkeras_balanced_batch_generatorimportmake_generatornum_samples=100num_classes=3input_shape=(2,)batch_size=16x=np.random.rand(num_samples,*input_shape)y=np.random.randint(low=0,high=num_classes,size=num_samples)y=to_categorical(y)generator=make_generator(x,y,batch_size)model=Sequential()model.add(Dense(32,input_shape=input_shape,activation='relu'))model.add(Dense(num_classes,activation='softmax'))model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])model.fit(generator,steps_per_epoch=10,epochs=5)
示例:二进制分类
importnumpyasnpfromkeras.utilsimportto_categoricalfromkeras.modelsimportSequentialfromkeras.layersimportDensefromkeras_balanced_batch_generatorimportmake_generatornum_samples=100num_classes=2input_shape=(2,)batch_size=16x=np.random.rand(num_samples,*input_shape)y=np.random.randint(low=0,high=num_classes,size=num_samples)y=to_categorical(y)generator=make_generator(x,y,batch_size,categorical=False)model=Sequential()model.add(Dense(32,input_shape=input_shape,activation='relu'))model.add(Dense(1,activation='sigmoid'))model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])model.fit(generator,steps_per_epoch=10,epochs=5)
- 项目
标签: