torchkeras
Detailed description of the torchkeras Python project
1. Introduction
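The torchkeras library is a simple tool for training neural networks in PyTorch in a Keras style. With torchkeras you do not need to write a training loop with many lines of code; all you need to do is follow these three steps:
(i) create your network and wrap it with torchkeras.Model, like this: model = torchkeras.Model(net)
(ii) compile the model to bind the loss function, the optimizer and the metric functions.
(iii) fit the model with the training data and the validation data.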
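The project may look fairly powerful, but its source code is very simple: fewer than 300 lines of Python.
If you want to understand or modify any details of this project, feel free to read and change the source code!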
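To make the three steps concrete, here is a minimal, framework-free sketch of the Keras-style pattern: a wrapper whose compile() stores the loss, optimizer setting and metric functions, and whose fit() runs the epoch/batch loop and returns one history row per epoch. Everything here (the ToyModel class, the one-weight model y = w * x, the hand-written SGD update standing in for an optimizer) is hypothetical and only illustrates the idea; it is not torchkeras's actual implementation.

```python
class ToyModel:
    """Hypothetical Keras-style wrapper around a one-weight model y = w * x."""

    def __init__(self, w=0.0):
        self.w = w  # the single trainable parameter

    def forward(self, x):
        return self.w * x

    def compile(self, loss_func, lr, metrics_dict=None):
        # Step (ii): bind the loss, the optimizer setting and the metrics.
        self.loss_func = loss_func
        self.lr = lr  # plain SGD learning rate stands in for an optimizer
        self.metrics_dict = metrics_dict or {}

    def fit(self, epochs, batches):
        # Step (iii): the training loop the wrapper writes for you.
        history = []
        for epoch in range(1, epochs + 1):
            for xs, ys in batches:
                # Gradient of the batch mean squared error w.r.t. w, by hand.
                grad = sum(2 * (self.forward(x) - y) * x for x, y in zip(xs, ys)) / len(xs)
                self.w -= self.lr * grad  # SGD update
            # Epoch-level loss and metrics over all batches.
            preds = [self.forward(x) for xs, _ in batches for x in xs]
            targets = [y for _, ys in batches for y in ys]
            row = {"epoch": epoch, "loss": self.loss_func(preds, targets)}
            for name, fn in self.metrics_dict.items():
                row[name] = fn(preds, targets)
            history.append(row)
        return history


def mse(preds, targets):
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def mae(preds, targets):
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


# Data generated from y = 3x, delivered in two batches.
batches = [([1.0, 2.0], [3.0, 6.0]), ([3.0, 4.0], [9.0, 12.0])]
model = ToyModel()                                            # step (i): wrap
model.compile(loss_func=mse, lr=0.02, metrics_dict={"mae": mae})  # step (ii)
history = model.fit(50, batches)                              # step (iii)
print(round(model.w, 3), history[-1])
```

The point of the design is the separation: the model only knows how to compute outputs, while the loop, loss bookkeeping and history live in the wrapper.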
2. Example
You can install torchkeras with pip:
pip install torchkeras
Here is a complete example using torchkeras!
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

import torchkeras  # Attention: this line
(1) Prepare the data
^{pr2}$
# split samples into train and valid data.
ds = TensorDataset(X, Y)
ds_train, ds_valid = torch.utils.data.random_split(
    ds, [int(len(ds) * 0.7), len(ds) - int(len(ds) * 0.7)])
dl_train = DataLoader(ds_train, batch_size=100, shuffle=True, num_workers=2)
dl_valid = DataLoader(ds_valid, batch_size=100, num_workers=2)
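The code that creates X, Y, Xp and Xn is not shown in this description; the later scatter plots only tell us it is a two-class, two-dimensional dataset. As an assumption, the stdlib-only sketch below draws positive samples on a noisy ring of radius about 5 and negative samples on a ring of radius about 8 (the helper make_rings, the radii and the sample counts are all hypothetical). It then computes the 70/30 split sizes with the same arithmetic as above, which guarantees the two lengths add up to len(ds), as torch.utils.data.random_split requires when given integer lengths.

```python
import math
import random


def make_rings(n_pos, n_neg, r_pos=5.0, r_neg=8.0, seed=0):
    """Hypothetical generator: two noisy concentric rings in 2-D.

    Positive samples (label 1.0) lie near radius r_pos,
    negative samples (label 0.0) near radius r_neg.
    """
    rng = random.Random(seed)
    X, Y = [], []
    for n, radius, label in ((n_pos, r_pos, 1.0), (n_neg, r_neg, 0.0)):
        for _ in range(n):
            r = radius + rng.gauss(0.0, 1.0)     # noisy radius
            theta = 2 * math.pi * rng.random()   # uniform angle
            X.append((r * math.cos(theta), r * math.sin(theta)))
            Y.append(label)
    return X, Y


def split_sizes(n, train_frac=0.7):
    # Same arithmetic as the random_split call above: the validation
    # size is the remainder, so the two lengths always sum to n.
    n_train = int(n * train_frac)
    return n_train, n - n_train


X, Y = make_rings(2000, 2000)
n_train, n_valid = split_sizes(len(X))
# Number of training batches per epoch with batch_size=100.
print(len(X), n_train, n_valid, math.ceil(n_train / 100))
```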
(2) Create the model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)
        self.fc2 = nn.Linear(4, 8)
        self.fc3 = nn.Linear(8, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        y = nn.Sigmoid()(self.fc3(x))
        return y

net = Net()

### Attention here
model = torchkeras.Model(net)
model.summary(input_shape=(2,))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Linear-1 [-1, 4] 12
Linear-2 [-1, 8] 40
Linear-3 [-1, 1] 9
================================================================
Total params: 61
Trainable params: 61
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.000008
Forward/backward pass size (MB): 0.000099
Params size (MB): 0.000233
Estimated Total Size (MB): 0.000340
----------------------------------------------------------------
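The parameter counts in this summary follow from the fact that an nn.Linear(in_features, out_features) layer holds an out_features × in_features weight matrix plus out_features biases. A small stdlib check of the arithmetic, assuming 4-byte float32 parameters and MB meaning 1024 × 1024 bytes (which is what reproduces the "Params size" line):

```python
def linear_params(in_features, out_features):
    # weight matrix (out x in) plus one bias per output unit
    return in_features * out_features + out_features


layers = [(2, 4), (4, 8), (8, 1)]            # fc1, fc2, fc3 of Net
counts = [linear_params(i, o) for i, o in layers]
total = sum(counts)
params_mb = total * 4 / 1024 ** 2            # float32 = 4 bytes
print(counts, total, round(params_mb, 6))
```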
(3) Train the model
# define metric
def accuracy(y_pred, y_true):
    y_pred = torch.where(y_pred > 0.5,
                         torch.ones_like(y_pred, dtype=torch.float32),
                         torch.zeros_like(y_pred, dtype=torch.float32))
    acc = torch.mean(1 - torch.abs(y_true - y_pred))
    return acc

# if gpu is available, use gpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.compile(loss_func=nn.BCELoss(),
              optimizer=torch.optim.Adam(model.parameters(), lr=0.01),
              metrics_dict={"accuracy": accuracy},
              device=device)

# 100 epochs, logging every 10 steps, as in the training log below
dfhistory = model.fit(100, dl_train=dl_train, dl_val=dl_valid, log_step_freq=10)
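The accuracy function above first turns probabilities into hard 0/1 predictions with a 0.5 threshold. For 0/1 values, 1 - |y_true - y_pred| is 1 for a correct prediction and 0 for a wrong one, so its mean is simply the fraction of correct predictions. A plain-Python restatement of the same logic (the name binary_accuracy is hypothetical):

```python
def binary_accuracy(y_prob, y_true, threshold=0.5):
    # Hard predictions: 1.0 if the probability is strictly above the
    # threshold, else 0.0 (same as torch.where(y_pred > 0.5, ...) above).
    y_pred = [1.0 if p > threshold else 0.0 for p in y_prob]
    # 1 - |t - p| is 1 when the labels agree and 0 when they differ.
    return sum(1 - abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


print(binary_accuracy([0.9, 0.2, 0.6, 0.4], [1.0, 0.0, 0.0, 1.0]))
```

Here two of the four predictions are right, so the accuracy is 0.5. Note that a probability of exactly 0.5 counts as negative under this rule.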
Start Training ...
================================================================================2020-06-21 20:40:23
{'step': 10, 'loss': 0.217, 'accuracy': 0.905}
{'step': 20, 'loss': 0.215, 'accuracy': 0.914}
+-------+-------+----------+----------+--------------+
| epoch | loss | accuracy | val_loss | val_accuracy |
+-------+-------+----------+----------+--------------+
| 1 | 0.212 | 0.914 | 0.186 | 0.927 |
+-------+-------+----------+----------+--------------+
================================================================================2020-06-21 20:40:23
{'step': 10, 'loss': 0.211, 'accuracy': 0.912}
{'step': 20, 'loss': 0.193, 'accuracy': 0.919}
+-------+-------+----------+----------+--------------+
| epoch | loss | accuracy | val_loss | val_accuracy |
+-------+-------+----------+----------+--------------+
| 2 | 0.194 | 0.919 | 0.188 | 0.935 |
+-------+-------+----------+----------+--------------+
================================================================================2020-06-21 20:40:23
{'step': 10, 'loss': 0.217, 'accuracy': 0.913}
{'step': 20, 'loss': 0.205, 'accuracy': 0.92}
+-------+-------+----------+----------+--------------+
| epoch | loss | accuracy | val_loss | val_accuracy |
+-------+-------+----------+----------+--------------+
| 3 | 0.195 | 0.921 | 0.176 | 0.931 |
+-------+-------+----------+----------+--------------+
================================================================================2020-06-21 20:40:23
{'step': 10, 'loss': 0.164, 'accuracy': 0.932}
{'step': 20, 'loss': 0.197, 'accuracy': 0.917}
+-------+-------+----------+----------+--------------+
| epoch | loss | accuracy | val_loss | val_accuracy |
+-------+-------+----------+----------+--------------+
| 4 | 0.197 | 0.917 | 0.178 | 0.935 |
+-------+-------+----------+----------+--------------+
================================================================================2020-06-21 20:40:24
{'step': 10, 'loss': 0.192, 'accuracy': 0.926}
{'step': 20, 'loss': 0.182, 'accuracy': 0.931}
+-------+-------+----------+----------+--------------+
| epoch | loss | accuracy | val_loss | val_accuracy |
+-------+-------+----------+----------+--------------+
| 5 | 0.193 | 0.924 | 0.188 | 0.928 |
+-------+-------+----------+----------+--------------+
================================================================================2020-06-21 20:40:44
{'step': 10, 'loss': 0.175, 'accuracy': 0.932}
{'step': 20, 'loss': 0.188, 'accuracy': 0.924}
+-------+-------+----------+----------+--------------+
| epoch | loss | accuracy | val_loss | val_accuracy |
+-------+-------+----------+----------+--------------+
| 97 | 0.184 | 0.923 | 0.176 | 0.935 |
+-------+-------+----------+----------+--------------+
================================================================================2020-06-21 20:40:44
{'step': 10, 'loss': 0.21, 'accuracy': 0.913}
{'step': 20, 'loss': 0.192, 'accuracy': 0.918}
+-------+------+----------+----------+--------------+
| epoch | loss | accuracy | val_loss | val_accuracy |
+-------+------+----------+----------+--------------+
| 98 | 0.19 | 0.922 | 0.179 | 0.934 |
+-------+------+----------+----------+--------------+
================================================================================2020-06-21 20:40:45
{'step': 10, 'loss': 0.186, 'accuracy': 0.923}
{'step': 20, 'loss': 0.181, 'accuracy': 0.928}
+-------+-------+----------+----------+--------------+
| epoch | loss | accuracy | val_loss | val_accuracy |
+-------+-------+----------+----------+--------------+
| 99 | 0.182 | 0.926 | 0.178 | 0.938 |
+-------+-------+----------+----------+--------------+
================================================================================2020-06-21 20:40:45
{'step': 10, 'loss': 0.16, 'accuracy': 0.93}
{'step': 20, 'loss': 0.173, 'accuracy': 0.93}
+-------+-------+----------+----------+--------------+
| epoch | loss | accuracy | val_loss | val_accuracy |
+-------+-------+----------+----------+--------------+
| 100 | 0.185 | 0.925 | 0.174 | 0.936 |
+-------+-------+----------+----------+--------------+
================================================================================2020-06-21 20:40:45
Finished Training...
# visualize the results
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
ax1.scatter(Xp[:, 0], Xp[:, 1], c="r")
ax1.scatter(Xn[:, 0], Xn[:, 1], c="g")
ax1.legend(["positive", "negative"])
ax1.set_title("y_true")

# predicted probabilities for every sample, computed once
with torch.no_grad():
    y_prob = torch.squeeze(model.forward(X))
Xp_pred = X[y_prob >= 0.5]
Xn_pred = X[y_prob < 0.5]

ax2.scatter(Xp_pred[:, 0], Xp_pred[:, 1], c="r")
ax2.scatter(Xn_pred[:, 0], Xn_pred[:, 1], c="g")
ax2.legend(["positive", "negative"])
ax2.set_title("y_pred")
(4) Evaluate the model
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt

def plot_metric(dfhistory, metric):
    train_metrics = dfhistory[metric]
    val_metrics = dfhistory['val_' + metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.plot(epochs, val_metrics, 'ro-')
    plt.title('Training and validation ' + metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_" + metric, 'val_' + metric])
    plt.show()
plot_metric(dfhistory,"loss")
plot_metric(dfhistory,"accuracy")
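plot_metric relies on dfhistory holding one row per epoch with the columns shown in the training log (epoch, loss, accuracy, val_loss, val_accuracy). As a sketch, the rows below copy the last four epochs printed in that log and pick the best epoch by each validation metric; a plain list of dicts stands in for the DataFrame, an assumption made to keep the example dependency-free.

```python
# Last four epochs, copied from the training log above.
history = [
    {"epoch": 97, "loss": 0.184, "accuracy": 0.923, "val_loss": 0.176, "val_accuracy": 0.935},
    {"epoch": 98, "loss": 0.190, "accuracy": 0.922, "val_loss": 0.179, "val_accuracy": 0.934},
    {"epoch": 99, "loss": 0.182, "accuracy": 0.926, "val_loss": 0.178, "val_accuracy": 0.938},
    {"epoch": 100, "loss": 0.185, "accuracy": 0.925, "val_loss": 0.174, "val_accuracy": 0.936},
]

# The "best" epoch depends on which validation metric you optimize.
best_by_loss = min(history, key=lambda row: row["val_loss"])
best_by_acc = max(history, key=lambda row: row["val_accuracy"])
print(best_by_loss["epoch"], best_by_acc["epoch"])
```

With a real pandas DataFrame the same lookup is dfhistory.loc[dfhistory["val_loss"].idxmin()].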
model.evaluate(dl_valid)
{'val_loss': 0.13576620258390903, 'val_accuracy': 0.9441666702429453}
(5) Use the model
model.predict(dl_valid)[0:10]
tensor([[0.8767],
[0.0154],
[0.9976],
[0.9990],
[0.9984],
[0.0071],
[0.3529],
[0.4061],
[0.9938],
[0.9997]])
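model.predict returns the sigmoid outputs, i.e. probabilities, not class labels. To get labels, apply the same 0.5 threshold the accuracy metric uses; the sketch below does this in plain Python on the ten values printed above (on a tensor, (probs > 0.5).float() does the same).

```python
# The ten probabilities printed by model.predict above.
probs = [0.8767, 0.0154, 0.9976, 0.9990, 0.9984,
         0.0071, 0.3529, 0.4061, 0.9938, 0.9997]

labels = [1 if p > 0.5 else 0 for p in probs]
print(labels, sum(labels))
```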
for features, labels in dl_valid:
    with torch.no_grad():
        predictions = model.forward(features)
        print(predictions[0:10])
    break
tensor([[0.9979],
[0.0011],
[0.9782],
[0.9675],
[0.9653],
[0.9906],
[0.1774],
[0.9994],
[0.9178],
[0.9579]])
(6) Save the model
# save the model parameters
torch.save(model.state_dict(), "model_parameter.pkl")

# rebuild the network, then load the saved parameters into it
model_clone = torchkeras.Model(Net())
model_clone.load_state_dict(torch.load("model_parameter.pkl"))

# the optimizer must wrap the clone's own parameters
model_clone.compile(loss_func=nn.BCELoss(),
                    optimizer=torch.optim.Adam(model_clone.parameters(), lr=0.01),
                    metrics_dict={"accuracy": accuracy})
model_clone.evaluate(dl_valid)
{'val_loss': 0.17422042911251387, 'val_accuracy': 0.9358333299557368}
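Saving the state_dict stores only the parameters, so the architecture has to be rebuilt (Net() above) before loading. A minimal sketch of the same round trip with a plain nn.Module; the in-memory io.BytesIO buffer replaces the file only so the example leaves nothing on disk, and the small nn.Sequential network is a hypothetical stand-in for Net.

```python
import io

import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))

# Save only the parameters, not the module object.
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)  # rewind before reading

# Rebuild the same architecture, then load the saved parameters into it.
net_clone = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))
net_clone.load_state_dict(torch.load(buffer))

x = torch.randn(5, 2)
with torch.no_grad():
    same = torch.equal(net(x), net_clone(x))
print(same)
```

If the parameters were saved from a GPU model, torch.load(path, map_location="cpu") lets you load them on a CPU-only machine.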