Model summary in PyTorch, based on the original torchsummary.

Detailed description of the torch-summary Python project


Torch Summary

Python 3.5+

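Torch-summary provides information complementary to what print(your_model) gives you in PyTorch, similar to Tensorflow's model.summary() API for viewing a visualization of the model, which is helpful while debugging your network. In this project, we implement similar functionality in PyTorch and provide a clean, simple interface to use in your projects.

This is a complete rewrite of the original torchsummary and torchsummaryX projects by @sksq96 and @nmhkahn. This project addresses all of the issues and pull requests left open on the original projects by introducing a completely new API.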

Usage

pip install torch-summary

or

git clone https://github.com/tyleryep/torch-summary.git

How To Use

from torchsummary import summary

model = ConvNet()  # ConvNet stands for your own nn.Module taking 1x28x28 inputs
summary(model, (1, 28, 28))

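This version now supports:

  • RNNs, LSTMs, and other recursive layers
  • Sequentials and ModuleLists
  • Branching output for exploring model layers down to a specified depth
  • Returning a ModelStatistics object containing all summary data fields
  • Configurable columns

Other new features:

  • Verbose mode that shows weight and bias layers
  • Accepts either input data or simply the input shape
  • Customizable column widths and batch dimension
  • Comprehensive unit/output testing, linting, and code coverage testing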

Documentation

"""
Summarize the given PyTorch model. Summarized information includes:
    1) Layer names,
    2) input/output shapes,
    3) kernel shape,
    4) # of parameters,
    5) # of operations (Mult-Adds)

Args:
    model (nn.Module):
            PyTorch model to summarize

    input_data (Sequence of Sizes or Tensors):
            Example input tensor of the model (dtypes inferred from model input).
            - OR -
            Shape of input data as a List/Tuple/torch.Size
            (dtypes must match model input, default is FloatTensors).
            You should NOT include batch size in the tuple.
            - OR -
            If input_data is not provided, no forward pass through the network is
            performed, and the provided model information is limited to layer names.
            Default: None

    batch_dim (int):
            Batch_dimension of input data. If batch_dim is None, the input data
            is assumed to contain the batch dimension.
            WARNING: in a future version, the default will change to None.
            Default: 0

    branching (bool):
            Whether to use the branching layout for the printed output.
            Default: True

    col_names (Sequence[str]):
            Specify which columns to show in the output. Currently supported:
            ("input_size", "output_size", "num_params", "kernel_size", "mult_adds")
            If input_data is not provided, only "num_params" is used.
            Default: ("output_size", "num_params")

    col_width (int):
            Width of each column.
            Default: 25

    depth (int):
            Number of nested layers to traverse (e.g. Sequentials).
            Default: 3

    device (torch.Device):
            Uses this torch device for model and input_data.
            If not specified, uses result of torch.cuda.is_available().
            Default: None

    dtypes (List[torch.dtype]):
            For multiple inputs, specify the size of both inputs, and
            also specify the types of each parameter here.
            Default: None

    verbose (int):
            0 (quiet): No output
            1 (default): Print model summary
            2 (verbose): Show weight and bias layers in full detail
            Default: 1

    *args, **kwargs:
            Other arguments used in `model.forward` function.

Return:
    ModelStatistics object
            See torchsummary/model_statistics.py for more information.
"""
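To make the batch_dim argument above more concrete, here is a minimal sketch of how a shape given without a batch size could be expanded before it is displayed. The helper name with_batch_dim and the -1 placeholder are assumptions chosen to mirror the printed output shapes; this is not taken from torchsummary's source.

```python
from typing import Optional, Sequence, Tuple


def with_batch_dim(shape: Sequence[int], batch_dim: Optional[int] = 0,
                   placeholder: int = -1) -> Tuple[int, ...]:
    """Insert a placeholder batch dimension into `shape` at `batch_dim`.

    Hypothetical helper for illustration only. If batch_dim is None, the
    shape is assumed to already contain the batch dimension and is
    returned unchanged.
    """
    dims = list(shape)
    if batch_dim is None:
        return tuple(dims)
    dims.insert(batch_dim, placeholder)
    return tuple(dims)


print(with_batch_dim((1, 28, 28)))           # (-1, 1, 28, 28)
print(with_batch_dim((8, 1, 28, 28), None))  # (8, 1, 28, 28)
```

Under this reading, passing (1, 28, 28) with the default batch_dim=0 corresponds to output shapes that start with -1, as in the tables shown in the examples.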

Examples

Get Model Summary as String

from torchsummary import summary

model_stats = summary(your_model, (3, 28, 28), verbose=0)
summary_str = str(model_stats)
# summary_str contains the string representation of the summary. See below for examples.

ResNet

import torchvision
from torchsummary import summary

model = torchvision.models.resnet50()
summary(model, (3, 224, 224), depth=3)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [-1, 64, 112, 112]        9,408
├─BatchNorm2d: 1-2                       [-1, 64, 112, 112]        128
├─ReLU: 1-3                              [-1, 64, 112, 112]        --
├─MaxPool2d: 1-4                         [-1, 64, 56, 56]          --
├─Sequential: 1-5                        [-1, 256, 56, 56]         --
|    └─Bottleneck: 2-1                   [-1, 256, 56, 56]         --
|    |    └─Conv2d: 3-1                  [-1, 64, 56, 56]          4,096
|    |    └─BatchNorm2d: 3-2             [-1, 64, 56, 56]          128
|    |    └─ReLU: 3-3                    [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-4                  [-1, 64, 56, 56]          36,864
|    |    └─BatchNorm2d: 3-5             [-1, 64, 56, 56]          128
|    |    └─ReLU: 3-6                    [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-7                  [-1, 256, 56, 56]         16,384
|    |    └─BatchNorm2d: 3-8             [-1, 256, 56, 56]         512
|    |    └─Sequential: 3-9              [-1, 256, 56, 56]         --
|    |    └─ReLU: 3-10                   [-1, 256, 56, 56]         --

  ...
  ...
  ...

├─AdaptiveAvgPool2d: 1-9                 [-1, 2048, 1, 1]          --
├─Linear: 1-10                           [-1, 1000]                2,049,000
==========================================================================================
Total params: 60,192,808
Trainable params: 60,192,808
Non-trainable params: 0
Total mult-adds (G): 11.63
==========================================================================================
Input size (MB): 0.57
Forward/backward pass size (MB): 344.16
Params size (MB): 229.62
Estimated Total Size (MB): 574.35
==========================================================================================
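Several numbers in the table above can be checked by hand. The sketch below assumes 32-bit floats (4 bytes per value) and 1 MB = 1024 * 1024 bytes; these assumptions reproduce the printed figures, but they are inferred from the output rather than taken from the library's code. The function names are illustrative.

```python
BYTES_PER_FLOAT32 = 4      # assumption: all tensors are float32
BYTES_PER_MB = 1024 ** 2   # assumption: binary megabytes


def conv2d_params(in_ch, out_ch, kernel, bias=True):
    """Weights (out_ch * in_ch * k * k) plus an optional bias per output channel."""
    return out_ch * in_ch * kernel * kernel + (out_ch if bias else 0)


def linear_params(in_features, out_features, bias=True):
    """Weight matrix plus an optional bias per output feature."""
    return in_features * out_features + (out_features if bias else 0)


def size_mb(num_values):
    """Memory taken by `num_values` float32 values, in MB."""
    return num_values * BYTES_PER_FLOAT32 / BYTES_PER_MB


# First ResNet-50 convolution: 3 -> 64 channels, 7x7 kernel, no bias.
print(conv2d_params(3, 64, 7, bias=False))  # 9408
# Final fully connected layer: 2048 -> 1000.
print(linear_params(2048, 1000))            # 2049000
# One 3x224x224 input and the reported parameter total.
print(round(size_mb(3 * 224 * 224), 2))     # 0.57
print(round(size_mb(60192808), 2))          # 229.62
```

The estimated total is then the sum of the three size lines: 0.57 + 344.16 + 229.62 = 574.35 MB.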

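Multiple Inputs with Different Data Types

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


class MultipleInputNetDifferentDtypes(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)

        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        # The second input arrives as a long tensor and is cast to float.
        x2 = x2.type(torch.float)
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)


model = MultipleInputNetDifferentDtypes()
summary(model, [(1, 300), (1, 300)], dtypes=[torch.float, torch.long])

Alternatively, you can pass in the input data itself, and torchsummary will automatically infer the data types.

input_data = torch.randn(1, 300)
other_input_data = torch.randn(1, 300).long()
model = MultipleInputNetDifferentDtypes()

summary(model, input_data, other_input_data, ...)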

Explore Different Configurations

import torch
import torch.nn as nn
from torchsummary import summary


class LSTMNet(nn.Module):
    """ Batch-first LSTM model. """

    def __init__(self, vocab_size=20, embed_dim=300, hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embed = self.embedding(x)
        out, hidden = self.encoder(embed)
        out = self.decoder(out)
        out = out.view(-1, out.size(2))
        return out, hidden


summary(
    LSTMNet(),
    (100,),
    dtypes=[torch.long],
    branching=False,
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
)
========================================================================================================================
Layer (type:depth-idx)                   Kernel Shape         Output Shape         Param #              Mult-Adds
========================================================================================================================
Embedding: 1-1                           [300, 20]            [-1, 100, 300]       6,000                6,000
LSTM: 1-2                                --                   [-1, 100, 512]       3,768,320            3,760,128
  weight_ih_l0                           [2048, 300]
  weight_hh_l0                           [2048, 512]
  weight_ih_l1                           [2048, 512]
  weight_hh_l1                           [2048, 512]
Linear: 1-3                              [512, 20]            [-1, 100, 20]        10,260               10,240
========================================================================================================================
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 3.78
========================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 1.03
Params size (MB): 14.44
Estimated Total Size (MB): 15.46
========================================================================================================================
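The LSTM row can be reproduced from the weight shapes listed under it. The sketch below relies on the standard PyTorch LSTM parameter layout (four gates, so 4 * hidden_size rows per weight matrix, plus two bias vectors per layer); the helper name lstm_layer_params is illustrative. The match between the weight-only count and the Mult-Adds column is an observation from the printed table, not a statement about how the library computes it.

```python
def lstm_layer_params(input_size, hidden_size):
    """Return (weight_count, bias_count) for one unidirectional LSTM layer."""
    gate_rows = 4 * hidden_size            # input, forget, cell, output gates
    weight_ih = gate_rows * input_size     # e.g. [2048, 300] for layer 0
    weight_hh = gate_rows * hidden_size    # e.g. [2048, 512]
    biases = 2 * gate_rows                 # bias_ih and bias_hh
    return weight_ih + weight_hh, biases


vocab_size, embed_dim, hidden_dim = 20, 300, 512

w0, b0 = lstm_layer_params(embed_dim, hidden_dim)   # layer 0 reads embeddings
w1, b1 = lstm_layer_params(hidden_dim, hidden_dim)  # layer 1 reads layer 0 output

embedding_params = vocab_size * embed_dim
lstm_params = w0 + b0 + w1 + b1
linear_params = hidden_dim * vocab_size + vocab_size

print(embedding_params)                                # 6000
print(lstm_params)                                     # 3768320
print(w0 + w1)                                         # 3760128 (weights only)
print(linear_params)                                   # 10260
print(embedding_params + lstm_params + linear_params)  # 3784580
```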

Sequentials and ModuleLists

import torch.nn as nn
from torchsummary import summary


class ContainerModule(nn.Module):
    """ Model using ModuleList. """

    def __init__(self):
        super().__init__()
        self._layers = nn.ModuleList()
        self._layers.append(nn.Linear(5, 5))
        self._layers.append(ContainerChildModule())
        self._layers.append(nn.Linear(5, 5))

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return x


class ContainerChildModule(nn.Module):
    """ Model using Sequential in different ways. """

    def __init__(self):
        super().__init__()
        self._sequential = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
        self._between = nn.Linear(5, 5)

    def forward(self, x):
        # Call the Sequential as a whole, then layer by layer, then both again.
        out = self._sequential(x)
        out = self._between(out)
        for l in self._sequential:
            out = l(out)

        out = self._sequential(x)
        for l in self._sequential:
            out = l(out)
        return out


summary(ContainerModule(), (5,))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─ModuleList: 1                          []                        --
|    └─Linear: 2-1                       [-1, 5]                   30
|    └─ContainerChildModule: 2-2         [-1, 5]                   --
|    |    └─Sequential: 3-1              [-1, 5]                   --
|    |    |    └─Linear: 4-1             [-1, 5]                   30
|    |    |    └─Linear: 4-2             [-1, 5]                   30
|    |    └─Linear: 3-2                  [-1, 5]                   30
|    |    └─Sequential: 3                []                        --
|    |    |    └─Linear: 4-3             [-1, 5]                   (recursive)
|    |    |    └─Linear: 4-4             [-1, 5]                   (recursive)
|    |    └─Sequential: 3-3              [-1, 5]                   (recursive)
|    |    |    └─Linear: 4-5             [-1, 5]                   (recursive)
|    |    |    └─Linear: 4-6             [-1, 5]                   (recursive)
|    |    |    └─Linear: 4-7             [-1, 5]                   (recursive)
|    |    |    └─Linear: 4-8             [-1, 5]                   (recursive)
|    └─Linear: 2-3                       [-1, 5]                   30
==========================================================================================
Total params: 150
Trainable params: 150
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
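The "(recursive)" entries above mark layers that are called more than once in the forward pass, and the total of 150 counts each layer's parameters only once. A minimal sketch of that bookkeeping follows, using a hand-written call trace in place of real forward hooks; the names call_trace and count_unique_params are hypothetical, and this is an illustration rather than the library's implementation.

```python
def count_unique_params(call_trace):
    """Sum parameters over a trace of (layer_name, num_params) calls,
    counting each layer only the first time it appears."""
    seen = set()
    total = 0
    for name, num_params in call_trace:
        if name in seen:
            continue  # a repeated call, shown as "(recursive)" in the table
        seen.add(name)
        total += num_params
    return total


LINEAR_5_5 = 5 * 5 + 5  # 30 parameters per nn.Linear(5, 5)

# Order of Linear calls in ContainerModule.forward, written out by hand.
call_trace = [
    ("layers.0", LINEAR_5_5),
    ("seq.0", LINEAR_5_5), ("seq.1", LINEAR_5_5),  # self._sequential(x)
    ("between", LINEAR_5_5),
    ("seq.0", LINEAR_5_5), ("seq.1", LINEAR_5_5),  # for l in self._sequential
    ("seq.0", LINEAR_5_5), ("seq.1", LINEAR_5_5),  # self._sequential(x) again
    ("seq.0", LINEAR_5_5), ("seq.1", LINEAR_5_5),  # second loop
    ("layers.2", LINEAR_5_5),
]

print(count_unique_params(call_trace))     # 150
print(sum(n for _, n in call_trace))       # 330 if every call were counted
```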

Other Examples

================================================================
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [-1, 1, 16, 16]              10
              ReLU-2            [-1, 1, 16, 16]               0
            Conv2d-3            [-1, 1, 28, 28]              10
              ReLU-4            [-1, 1, 28, 28]               0
================================================================
Total params: 20
Trainable params: 20
Non-trainable params: 0
================================================================
Input size (MB): 0.77
Forward/backward pass size (MB): 0.02
Params size (MB): 0.00
Estimated Total Size (MB): 0.78
================================================================

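Contributing

All issues and pull requests are much appreciated! If you are wondering how to build the project:

  • torch-summary is actively developed using the latest version of Python.
    • Changes should be backward compatible with Python 3.5, but this is subject to change in the future.
    • Run pip install -r requirements-dev.txt. We use the latest versions of all dev packages.
    • First, be sure to run ./scripts/install-hooks
    • To run all tests and use the auto-formatting tools, check out scripts/run-tests.
    • To run only the unit tests, run pytest.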

References

  • Thanks to @sksq96, @nmhkahn, and @sangyx for providing the original code this project is based on.
  • For model size estimation: @jacokimmel (details here)
