用xml定义pytorch神经网络
xml2pytorch的Python项目详细描述
xml2季度
使用xml定义pytorch神经网络
它能做什么
使用xml2pytorch,可以很容易地用xml定义神经网络,然后用pytorch声明它们。
当前不支持RNN和LSTM。
安装
环境
操作系统独立。Python3。(未在python2上测试,但应该有效。)
安装要求
火炬>;=0.4.1 numpy=1.15.1
通过pip3安装
pip3安装xml2pytorch
快速启动
如何声明由xml文件定义的cnn
import torch
import xml2pytorch as xm
# declare the net defined in .xml
net = xm.convertXML(xml_filename)
# input a random tensor
x = torch.randn(1, 3, 32, 32)
y = net(x)
print(y)
如何用xml定义一个简单的cnn
<graph>
<net>
<layer>
<net_style>Conv2d</net_style>
<in_channels>3</in_channels>
<out_channels>6</out_channels>
<kernel_size>5</kernel_size>
</layer>
<layer>
<net_style>ELU</net_style>
</layer>
<layer>
<net_style>MaxPool2d</net_style>
<kernel_size>2</kernel_size>
<stride>2</stride>
<activation>logsigmoid</activation>
</layer>
<layer>
<net_style>Conv2d</net_style>
<in_channels>6</in_channels>
<out_channels>16</out_channels>
<kernel_size>5</kernel_size>
<activation>relu</activation>
</layer>
<layer>
<net_style>MaxPool2d</net_style>
<kernel_size>2</kernel_size>
<stride>2</stride>
<activation>relu</activation>
</layer>
<layer>
<net_style>reshape</net_style>
<dimensions>[-1, 16*5*5]</dimensions>
</layer>
<layer>
<net_style>Linear</net_style>
<in_features>400</in_features>
<out_features>120</out_features>
<activation>tanh</activation>
</layer>
<layer>
<net_style>Linear</net_style>
<in_features>120</in_features>
<out_features>84</out_features>
<activation>sigmoid</activation>
</layer>
<layer>
<net_style>Linear</net_style>
<in_features>84</in_features>
<out_features>10</out_features>
<activation>softmax</activation>
</layer>
</net>
</graph>