深部残差神经网络包装库
baggingrnet的Python项目详细描述
baggingrnet:深部残余神经网络的bagging库
简介
这个包提供了用于深度残差神经网络(baggingrnet)打包的python库。当前版本只支持Keras的深度学习包,并将在未来扩展到其他版本。本软件包提供了以下功能:*模型multbagging:基于深度剩余网络的自动编码器的主要类到并行打包。你可以设置它的最佳效果。有关详细信息,请参阅类及其成员函数的帮助。 resautoencoder:基于深度残差网络的自动编码器基本模型的主要类。详情请看具体情况。ensprediction:主要的类到系综预测和独立测试的可选评估。 *util p metrics:主要指标包括rsquare和rmse等。
- 数据数据:访问两个样本数据的功能,通过bagging测试和演示多个模型的并行训练和预测。simdata:模拟数据集进行测试的函数。
软件包的安装
对于最新版本,您可以使用以下命令直接安装此软件包:
pip install baggingrnet
您还可以克隆存储库,然后安装:
git clone --recursive https://github.com/lspatial/baggingrnet.git pip install ./setup.py install
建模框架
该模型建立在基于编解码的深度残差多层感知器(mlp)的编解码自动编码的基础上。利用编码层到解码层的残差连接来提高学习效率,利用bagging来实现具有不确定性度量(标准差)的稳定和改进的集成预测。
相关论文将在此处发布并在发布后更新。
示例1:模拟数据的回归
数据集的模拟公式如下:frombaggingrnet.dataimportdatasim_train=data('sim_train')sim_train['gindex']=np.array([iforiinrange(sim_train.shape[0])])
knitr::kable(py$sim_train[c(1:5),],format="html")<表><广告>
# Load the major class for parallel bagging trainingfrombaggingrnet.model.baggingimportmultBaggingfeasList=['x'+str(i)foriinrange(1,9)]#List of the covariates used in training target='y'# Name of the target variable bagpath='/tmp/sim_bagging/res'# Path used to chkpath(bagpath)mbag=multBagging(bagpath)mbag.getInputSample(sim_train,feasList,None,'gindex',target)
3)定义模型的参数并将其附加到建模职责列表中:
name=str(0)# model name as unique identifier nodes=[32,16,8,4]# List of number of nodes for the encoding and coding layers, adjustable optionally; minibatch=512# Size for mini batch isresidual=True# Whether to use residual connections in the model nepoch=200#Number of epoches sampling_fea=False# Whether to bootstrap the predictors/features noutput=1# Number of the output node islog=False# Whether to make the log transformation # The following is to add the model's arguments to the list of duties. mbag.addTask(name,noutput,sampling_fea,nepoch,nodes,minibatch,isresidual,islog)
4)开始培训:
mbag.startMProcess(1)
在这里,一个模型只使用一个核心。
5)使用训练模型的预测和对训练模型的可选评估:
frombaggingrnet.model.baggingpreimportensPrediction# Load the test dataset sim_test=data('sim_test')sim_test['gindex']=np.array([iforiinrange(sim_test.shape[0])])# Generate the unique id for merging the predicitons of multiple models # Setup the path and target variable prepath="/tmp/sim_bagging/res_pre"chkpath(prepath)#Load the prdiction classmbagpre=ensPrediction(bagpath,prepath)#Load the test data mbagpre.getInputSample(sim_test,feasList,'gindex')#Start to make predictions for multiple trained models. mbagpre.startMProcess(1)#Obtain the ensemble predictions from those of multiple models and optional evaluation of the models. mbagpre.aggPredict(isval=True,tfld='y')
以上五个步骤说明了加载数据、训练、测试和预测的过程。为了与残差模型的结果进行比较,下面的代码将获得非残差模型的结果。
mbag.removeTask(name)bagpath='/tmp/sim_bagging/nores'chkpath(bagpath)mbag_nores=multBagging(bagpath)mbag_nores.getInputSample(sim_train,feasList,None,'gindex','y')isresidual=False# This is to set no use of residual connections in the models. mbag_nores.addTask(name,noutput,sampling_fea,nepoch,nodes,minibatch,isresidual,islog)mbag_nores.startMProcess(1)prepath="/tmp/sim_bagging/nores_pre"chkpath(prepath)mbagpre=ensPrediction(bagpath,prepath)mbagpre.getInputSample(sim_test,feasList,'gindex')mbagpre.startMProcess(1)mbagpre.aggPredict(isval=True,tfld='y')
残差模型和非残差模型的训练/学习曲线比较:
残差模型和非残差模型独立检验的比较:性能(r2和rmse)
pip install baggingrnet
0
独立测试中残差模型与非残差模型的分散性比较:
例2:pm的时空估计2.5
该数据集是京津唐地区2015年pm2.5及相关协变量的真实数据集。由于数据安全原因,它添加了小高斯噪声。
1)加载输入数据:
这里使用pm2.5数据集来测试建议的方法。
pip install baggingrnet
1
<表><广告> pip install baggingrnet
2
pip install baggingrnet
3
pip install baggingrnet
4
pip install baggingrnet
5
3)定义多个模型(这里是100个模型)的参数,并将它们附加到建模职责列表中:
pip install baggingrnet
6
4)开始培训:
使用10核启动并行程序
pip install baggingrnet
7
5)使用训练模型的预测和对训练模型的可选评估:
pip install baggingrnet
8
最后,得到以下结果。
结果如下所示:
1)非残差与残差模型的典型学习曲线如下所示:
2)独立数据集的多个非残差与残差模型预测的平均性能(r2和rmse):
3)基于独立数据集多个模型的集合预测的性能(r2和rmse):
pip install baggingrnet
9
4)非残差与残差模型集合预测的散点图:
5)集合预测与单个模型预测的比较:
对多个模型的预测和集合预测的性能进行了统计。下面显示r2和rmse、条形图和散点图。
性能数字:
git clone --recursive https://github.com/lspatial/baggingrnet.git
pip install ./setup.py install
0
与单一模型相比,箱形图通过套袋(6%的r2和-5.72μg/m3)显示出显著的改善。
以下是观测到的pm2.5与集合预测/残差的散点图:
联系人
对于本图书馆及其相关的完整应用,欢迎与李连发博士联系。电子邮件:lspatial@gmail.com或lilf@lrees.ac.cn