打印模型摘要的实用函数。
torch-inspect的Python项目详细描述
特点
- 提供打印Keras样式模型摘要的帮助函数summary。在
- 提供帮助函数inspect,该函数返回具有网络摘要信息的对象以进行编程访问。在
- RNN/LSTM支持。在
- 库有测试和合理的代码覆盖率。在
简单示例
importtorch.nnasnnimporttorch.nn.functionalasFimporttorch_inspectasticlassSimpleNet(nn.Module):def__init__(self):super(SimpleNet,self).__init__()self.conv1=nn.Conv2d(1,6,3)self.conv2=nn.Conv2d(6,16,3)self.fc1=nn.Linear(16*6*6,120)self.fc2=nn.Linear(120,84)self.fc3=nn.Linear(84,10)defforward(self,x):x=F.max_pool2d(F.relu(self.conv1(x)),(2,2))x=F.max_pool2d(F.relu(self.conv2(x)),2)x=x.view(-1,self.num_flat_features(x))x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)returnxdefnum_flat_features(self,x):size=x.size()[1:]num_features=1forsinsize:num_features*=sreturnnum_featuresnet=SimpleNet()ti.summary(net,(1,32,32))
将产生以下输出:
^{pr2}$对于网络信息的编程访问,有inspect函数:
info=ti.inspect(net,(1,32,32))print(info)
[LayerInfo(name='Conv2d-1', input_shape=[100, 1, 32, 32], output_shape=[100, 6, 30, 30], trainable_params=60, non_trainable_params=0), LayerInfo(name='Conv2d-2', input_shape=[100, 6, 15, 15], output_shape=[100, 16, 13, 13], trainable_params=880, non_trainable_params=0), LayerInfo(name='Linear-3', input_shape=[100, 576], output_shape=[100, 120], trainable_params=69240, non_trainable_params=0), LayerInfo(name='Linear-4', input_shape=[100, 120], output_shape=[100, 84], trainable_params=10164, non_trainable_params=0), LayerInfo(name='Linear-5', input_shape=[100, 84], output_shape=[100, 10], trainable_params=850, non_trainable_params=0)]
安装
安装过程简单,只需:
$ pip install torch-inspect
推荐和感谢
此包基于pytorch-summary和PyTorchissue。与 pytorch-summary,pytorch inspect支持RNN/LSTMs,还提供程序化的 访问网络摘要信息。更模块化的结构和测试的存在 它更容易扩展和支持更多的功能。在
变化
0.0.3(2019-09-22)
- 增加了LSTM支持
- 固定多输入/输出支持
- 添加了更多的网络测试用例
- 默认情况下,批大小不再为-1
0.0.2(2019-09-22)
- 增加批量定额支持
- 已删除设备参数
0.0.1(2019-09-1)
- 初始版本。在
- 项目
标签: