神经网络性能评估

1 torchstat

该工具包可通过pip直接安装:

pip install torchstat    

使用方法

import torchvision.models as models
#pretrained=True就可以使用预训练的模型
#resnet18 = models.resnet18(pretrained=True)
resnet18 = models.resnet18()

from torchstat import stat
# 第一个参数为待分析的模型,另一个参数表示输入图片的大小
stat(resnet18, (3, 224, 224)) 

分析的效果如下:

从分析结果可以看出,torchstat的功能非常强大,不仅可以实现FLOPs、参数量、MAdd、显卡内存占用量等模型参数的分析,还可以看到模型每一层的分析结果,工具包不支持的layer也会列在分析结果前提醒使用者。

虽然torchstat的功能十分强大,但是也有一些缺陷:

  1. 限制模型输入仅能为图片

  2. 限制模型每一个layer的输入须为单个变量

  3. 对Pytorch-0.4.1及以下版本的支持不足

以上这些缺陷是在实践中发现的,具体表现为程序报错。如果修改模型也无法适配torchstat,这时就要考虑另选分析工具。

2 thop

对于torchstat无法适用的模型某一个layer的输入为多个变量和Pytorch-0.4.1版本等情况,可以尝试使用thop工具包进行模型分析。

安装

pip install thop                          

thop工具包相对torchstat而言,功能较为简单,仅支持FLOPs和参数量的计算(或者是我没有发现,不过我看源码是只返回这俩参量)。thop工具包的使用方法如下

from thop import profile                            
from thop import clever_format
import torchvision.models as models

resnet18 = models.resnet18()

input = torch.randn(1, 3, 224, 224)                        
flops, params = profile(resnet18, inputs=(input, ))                   
print(flops,params) 
flops,params = clever_format([flops, params],"%.3f")
print(flops,params)   

结果如下:

推荐首选torchstat进行模型分析,如果出现无法解决的程序报错,再尝试使用thop

3 ptflops

使用方法如下:

from ptflops import get_model_complexity_info
import torchvision.models as models

resnet18 = models.resnet18()
flops, params = get_model_complexity_info(resnet18, (3, 224, 224), as_strings=True, print_per_layer_stat=True)                             
print('flops: ', flops, 'params: ', params)

结果如下:

4 pytorch_model_summary

安装

 pip install pytorch_model_summary

使用

from pytorch_model_summary import summary
import torchvision.models as models

resnet18 = models.resnet18()
nc, nh, nw = 3, 513, 513

batch_size = 1  # 批处理大小
input_shape = (nc, nh, nw)  # 输入数据
# set the model to inference mode
resnet18.eval()
inputdata = torch.randn(1, *input_shape)  # 生成张量

print(summary(resnet18, inputdata, show_input=False, show_hierarchical=False))
posted @ 2022-08-11 17:16  Truman001  阅读(510)  评论(0编辑  收藏  举报