火把:火把模型分析仪。
torchstat的Python项目详细描述
火炬塔
这是一个基于pytorch的轻量级神经网络分析仪。 它的目的是使您的网络建设迅速和容易,并有能力调试他们。 注意:此存储库当前正在开发中。因此,一些api可能会被更改。
此工具可以显示
- 网络参数总数
- 浮点运算理论量(flops)
- 理论乘加量(madd)
- 内存使用量
安装
有两种方法可以将torchstat安装到您的环境中。
- 通过PIP安装。
$ pip install torchstat
- 克隆此存储库后,使用setup.py安装并更新。
$ python3 setup.py install
一个简单的例子
<>如果你想运行TrCHSTAT ASAP,如果你的网络存在于脚本中,你可以把它称为CLI工具。 否则需要将torchstat作为模块导入。cli工具
$ torchstat masato$ torchstat -f example.py -m Net [MAdd]: Dropout2d is not supported! [Flops]: Dropout2d is not supported! [Memory]: Dropout2d is not supported! module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)0 conv1 322422410220220760.0 1.85 72,600,000.0 36,784,000.0 605152.0 1936000.0 57.49% 2541152.0 1 conv2 10110110201061065020.0 0.86 112,360,000.0 56,404,720.0 504080.0 898880.0 26.62% 1402960.0 2 conv2_drop 20106106201061060.0 0.86 0.0 0.0 0.0 0.0 4.09% 0.0 3 fc1 56180502809050.0 0.00 5,617,950.0 2,809,000.0 11460920.0 200.0 11.58% 11461120.0 4 fc2 5010510.0 0.00 990.0 500.0 2240.0 40.0 0.22% 2280.0 total 2815340.0 3.56 190,578,940.0 95,998,220.0 2240.0 40.0 100.00% 15407512.0 =============================================================================================================================================== Total params: 2,815,340 ----------------------------------------------------------------------------------------------------------------------------------------------- Total memory: 3.56MB Total MAdd: 190.58MMAdd Total Flops: 96.0MFlops Total MemR+W: 14.69MB
如果不确定如何使用特定命令,请使用-h或-help开关运行该命令。 您将看到使用信息和可用于该命令的选项列表。
模块
fromtorchstatimportstatimporttorchvision.modelsasmodelsmodel=models.resnet18()stat(model,(3,224,224))
功能和待办事项
注意:这些功能仅适用于nn.module。尚不支持torch.nn.函数中的模块。
- [X]触发器
- [X]参数数
- [X]总内存
- [X]MADD(FMA)
- [X]内存读取
- [X]memwrite
- []模型摘要(细节,分层)
- []导出分数表
- []任意输入形状
对于支持的层,请签出the details。
要求
- Python3.6+
- 喷灯0.4.0+
- 熊猫0.23.4+
- 纽比1.14.3+
参考文献
感谢@sovrasov提供了flops计算的初始版本,@ceykmc提供了脚本的主干。