用xml定义pytorch神经网络

xml2pytorch的Python项目详细描述


xml2季度

使用xml定义pytorch神经网络

它能做什么

使用xml2pytorch,可以很容易地用xml定义神经网络,然后用pytorch声明它们。

当前不支持RNN和LSTM。

安装

环境

操作系统独立。Python3。(未在python2上测试,但应该有效。)

安装要求

火炬>;=0.4.1 numpy=1.15.1

通过pip3安装

pip3安装xml2pytorch

快速启动

如何声明由xml文件定义的cnn

import torch
import xml2pytorch as xm

# declare the net defined in .xml
net = xm.convertXML(xml_filename)    

# input a random tensor
x = torch.randn(1, 3, 32, 32)
y = net(x)
print(y)

如何用xml定义一个简单的cnn

<graph>
	<net>
		<layer>
			<net_style>Conv2d</net_style>
			<in_channels>3</in_channels>
			<out_channels>6</out_channels>
			<kernel_size>5</kernel_size>
		</layer>	
		<layer>
			<net_style>ELU</net_style>
		</layer>	
		<layer>
			<net_style>MaxPool2d</net_style>
			<kernel_size>2</kernel_size>
			<stride>2</stride>
			<activation>logsigmoid</activation>
		</layer>
		<layer>
			<net_style>Conv2d</net_style>
			<in_channels>6</in_channels>
			<out_channels>16</out_channels>
			<kernel_size>5</kernel_size>
			<activation>relu</activation>
		</layer>	
		<layer>
			<net_style>MaxPool2d</net_style>
			<kernel_size>2</kernel_size>
			<stride>2</stride>
			<activation>relu</activation>
		</layer>
		<layer>
			<net_style>reshape</net_style>
			<dimensions>[-1, 16*5*5]</dimensions>
		</layer>
		<layer>
			<net_style>Linear</net_style>
			<in_features>400</in_features> 
			<out_features>120</out_features>
			<activation>tanh</activation>
		</layer>
		<layer>
			<net_style>Linear</net_style>
			<in_features>120</in_features> 
			<out_features>84</out_features>
			<activation>sigmoid</activation>
		</layer>
		<layer>
			<net_style>Linear</net_style>
			<in_features>84</in_features>
			<out_features>10</out_features>
			<activation>softmax</activation>
		</layer>
	</net>
</graph>

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
JAVAutil。整数java的扫描器键盘输入   java通知运行后立即崩溃   java如何在一个只能由类修改而不能由其实例修改的类中生成静态变量?   数据库Java字段猜测   返回值周围的java括号为什么?   java Android更新通讯录中的联系人   一个消费者正在读取数据   java是否可以通过编程方式为蓝牙配对设置pin?   java Spring引导和buildResponseEntity()   java为什么序列化可以在没有实现可序列化的情况下工作   Java同步无助于相互排斥   twitter Java Twitter4J未在推文下显示源标签   为什么Javasocket不支持中断处理?