zChainer
Scikit-learn-like interface and stacked autoencoder for Chainer
Requirements
- NumPy
- scikit-learn
- chainer=1.5
Installation
pip install zChainer
Usage
Autoencoder
```python
import numpy as np
import chainer.functions as F
import chainer.links as L
from chainer import ChainList, optimizers
from zChainer import NNAutoEncoder, utility

# Replace (..) with your training data, e.g. an (n_samples, 784) array.
data = (..).astype(np.float32)

# Layer i of the encoder is paired with layer i of the decoder:
# decoder[i] maps the output of encoder[i] back to encoder[i]'s input size.
encoder = ChainList(
    L.Linear(784, 200),
    L.Linear(200, 100))
decoder = ChainList(
    L.Linear(200, 784),
    L.Linear(100, 200))

# You can set your own forward function. Default is as below.
#def forward(self, x):
#    h = F.dropout(F.relu(self.model[0](x)))
#    return F.dropout(F.relu(self.model[1](h)))
#
#NNAutoEncoder.forward = forward

ae = NNAutoEncoder(encoder, decoder, optimizers.Adam(),
    epoch=100, batch_size=100,
    log_path="./ae_log_" + utility.now() + ".csv",
    export_path="./ae_" + utility.now() + ".model")

ae.fit(data)
```
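The encoder/decoder pairing above (784→200 undone by 200→784, 200→100 undone by 100→200) is consistent with greedy layer-wise pretraining, the usual way a stacked autoencoder is trained. The NumPy sketch below illustrates that technique at toy scale. It is not zChainer's implementation: the helper names (`train_layer`, `pretrain_stack`), the sigmoid activations, the squared-error loss, full-batch gradient descent, and the learning rate are all assumptions chosen for brevity (zChainer's default forward uses ReLU with dropout and trains with whatever optimizer you pass in, such as Adam).

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def reconstruction_loss(x, x_hat):
    # 0.5 * squared error summed over features, averaged over samples.
    return 0.5 * np.sum((x_hat - x) ** 2) / x.shape[0]


def encode_decode(x, W_enc, b_enc, W_dec, b_dec):
    h = sigmoid(x @ W_enc + b_enc)      # encoder layer: n_in -> n_hidden
    x_hat = sigmoid(h @ W_dec + b_dec)  # paired decoder layer: n_hidden -> n_in
    return h, x_hat


def train_layer(x, n_hidden, epochs=300, lr=0.5, seed=0):
    """Train one encoder/decoder pair on x with full-batch gradient descent
    (an assumption of this sketch). Returns the encoder parameters and the
    reconstruction loss before and after training."""
    rng = np.random.default_rng(seed)
    n, n_in = x.shape
    W_enc = rng.normal(0.0, 0.1, (n_in, n_hidden))
    b_enc = np.zeros(n_hidden)
    W_dec = rng.normal(0.0, 0.1, (n_hidden, n_in))
    b_dec = np.zeros(n_in)

    loss_before = reconstruction_loss(x, encode_decode(x, W_enc, b_enc, W_dec, b_dec)[1])
    for _ in range(epochs):
        h, x_hat = encode_decode(x, W_enc, b_enc, W_dec, b_dec)
        # Backpropagate the loss through the decoder and encoder sigmoids.
        d_out = (x_hat - x) * x_hat * (1.0 - x_hat) / n
        d_hid = (d_out @ W_dec.T) * h * (1.0 - h)
        W_dec -= lr * (h.T @ d_out)
        b_dec -= lr * d_out.sum(axis=0)
        W_enc -= lr * (x.T @ d_hid)
        b_enc -= lr * d_hid.sum(axis=0)
    loss_after = reconstruction_loss(x, encode_decode(x, W_enc, b_enc, W_dec, b_dec)[1])
    return (W_enc, b_enc), loss_before, loss_after


def pretrain_stack(data, hidden_sizes):
    """Greedy layer-wise pretraining: layer k learns to reconstruct the codes
    produced by the already trained layers 0..k-1. Only the encoder halves
    are kept."""
    encoders, losses = [], []
    x = data
    for k, n_hidden in enumerate(hidden_sizes):
        (W, b), before, after = train_layer(x, n_hidden, seed=k)
        encoders.append((W, b))
        losses.append((before, after))
        x = sigmoid(x @ W + b)  # these codes are the next layer's input
    return encoders, losses, x


# Hypothetical toy stand-in for the elided `data`: 64 samples, 20 features in (0, 1).
rng = np.random.default_rng(42)
data = sigmoid(rng.normal(size=(64, 3)) @ rng.normal(size=(3, 20)))

encoders, losses, codes = pretrain_stack(data, [10, 5])
print(codes.shape)  # (64, 5)
```

With 784-dimensional input and `hidden_sizes=[200, 100]`, the same loop would mirror the layer sizes of the zChainer example.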
Training and testing
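```python
import pickle

import numpy as np
import chainer.functions as F
import chainer.links as L
from chainer import ChainList, optimizers
from zChainer import NNManager, utility

# Replace (..) with your data: float32 features, int32 class labels.
X_train = (..).astype(np.float32)
y_train = (..).astype(np.int32)
X_test = (..).astype(np.float32)
y_test = (..).astype(np.int32)

# Create a new network
model = ChainList(
    L.Linear(784, 200),
    L.Linear(200, 100),
    L.Linear(100, 10))

# or load a serialized model (e.g. one exported by NNAutoEncoder)
# and add an output layer on top of it; pickles must be opened in binary mode
#with open("./ae_2015-12-01_11-26-45.model", "rb") as f:
#    model = pickle.load(f)
#model.add_link(L.Linear(100, 10))

def forward(self, x):
    h = F.relu(self.model[0](x))
    h = F.relu(self.model[1](h))
    # Return raw class scores (logits): F.softmax_cross_entropy applies
    # the softmax itself, so no activation is needed on the last layer.
    return self.model[2](h)

def output(self, y):
    # Convert per-class scores to int32 label predictions.
    y_trimmed = y.data.argmax(axis=1)
    return np.array(y_trimmed, dtype=np.int32)

NNManager.forward = forward
NNManager.output = output
nn = NNManager(model, optimizers.Adam(), F.softmax_cross_entropy,
    epoch=100, batch_size=100,
    log_path="./training_log_" + utility.now() + ".csv")

nn.fit(X_train, y_train, is_classification=True)
nn.predict(X_test, y_test)
```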
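The `output` override reduces each row of class scores to a single int32 label with `argmax(axis=1)`, while training minimizes `F.softmax_cross_entropy` on those same scores. The NumPy sketch below shows both computations on a made-up batch of three samples and three classes, to make the shapes concrete. `softmax_cross_entropy_np`, `scores_to_labels`, and `accuracy` are hypothetical helpers written for illustration, not part of zChainer or Chainer.

```python
import numpy as np


def softmax_cross_entropy_np(logits, labels):
    """Mean softmax cross-entropy over a batch, using the log-sum-exp trick.

    logits: float array of shape (batch, n_classes)
    labels: int array of shape (batch,)
    """
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoid overflow in exp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def scores_to_labels(scores):
    # Same reduction as the `output` hook: index of the highest score per row.
    return np.array(scores.argmax(axis=1), dtype=np.int32)


def accuracy(scores, labels):
    return float((scores_to_labels(scores) == labels).mean())


# Hypothetical batch of scores and true labels.
scores = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0],
                   [1.0, 1.5, 0.0]], dtype=np.float32)
labels = np.array([0, 2, 0], dtype=np.int32)

print(scores_to_labels(scores))                  # [0 2 1]
print(accuracy(scores, labels))                  # 0.6666666666666666
print(softmax_cross_entropy_np(scores, labels))  # mean loss, a positive float
```

Because softmax is monotonic, argmax over the raw scores yields the same labels as argmax over the softmax probabilities, so the network's last layer can return unnormalized scores.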