管理tensorflow模型的训练结果、权重和数据流
mlpipe-trainer的Python项目详细描述
MLPIPE培训师
使用mlpipe管理数据管道和tensorflow&keras模型。它不是TensorFlow的另一个“包装器”,而是在MongoDB的帮助下添加一些实用程序来设置一个控制数据流和管理训练模型(权重和结果)的环境。
>> pip install mlpipe-trainer
安装-安装MongoDB
mongodb数据库用于存储训练后的模型,包括模型的权重和结果。此外,还实现了MongoDB的数据读取器(基本上只是一个生成器,如您所知,并且喜欢使用Keras)。currenlty是唯一一个“开箱即用”的数据读取器。 按照mongodb网站上的说明进行安装,例如,对于linux:https://docs.mongodb.com/manual/administration/install-on-linux/
代码示例
配置
# The config is used to specify the localhost connections# for saving trained models to the mongoDB as well as fetching training datafrommlpipe.utilsimportConfigConfig.add_config('./path_to/config.ini')
每个连接配置都由.ini文件中的这些字段组成
[example_mongo_db_connection]db_type=MongoDBurl=localhostport=27017user=read_writepwd=rw
数据管道
frommlpipe.processors.i_processorimportIPreProcessorfrommlpipe.data_reader.mongodbimportMongoDBGeneratorclassPreProcessData(IPreProcessor):defprocess(self,raw_data,input_data,ground_truth,piped_params=None):# Process raw_data to output input_data and ground_truth# which will be the input for the model...returnraw_data,input_data,ground_truth,piped_paramstrain_data=[...]# consists of MongoDB ObjectIds that are used for trainingprocessors=[PreProcessData()]# Chain of Processors (in our case its just one)# Generator that can be used e.g. with keras' fit_generator()train_gen=MongoDBGenerator(("connection_name","cifar10","train"),# specify data source from a MongoDBtrain_data,batch_size=128,processors=processors)
数据生成器从tf.keras.utils.Sequence
继承。查看这个tensorflow docu以了解如何编写自定义生成器(例如,对于除mongodb之外的其他数据源)。
型号
只要最后有一个keras(tensorflow.keras)模型,此步骤就没有限制
model=Sequential()model.add(Conv2D(32,(3,3),padding='same',input_shape=(32,32,3)))...model.add(Dense(10,activation='softmax'))opt=optimizers.RMSprop(lr=0.0001,decay=1e-6)model.compile(optimizer=opt,loss='categorical_crossentropy',metrics=["accuracy"])
培训和回访
frommlpipe.callbacksimportSaveToMongoDBsave_to_mongodb_cb=SaveToMongoDB(("localhost_mongo_db","models"),"test",model)model.fit_generator(generator=train_gen,validation_data=val_gen,epochs=10,verbose=1,callbacks=[save_to_mongodb_cb],initial_epoch=0,)
SaveToMongoDB
是一个自定义的keras回调类,如tensorflow docu中所述。同样,可以根据任何特定需要创建自定义回调。
如果不是fit_generator()
,而是对每个批进行逐个训练(例如使用本机tensorflow模型),则仍然可以在生成器上循环。只需记住在特定步骤调用回调方法,例如on_batch_end()
。
完整的cifar10示例可以在示例文件夹here
路线图
- 创建并生成mkdocs文档和宿主文档
- 添加测试
- 设置CI