网络中的参数量(param)和浮点计算量(FLOPs)的计算

0️⃣前言

本文介绍了卷积层和全连接层的参数量和计算量,同时还提到了浮点计算量(FLOPs)的计算。

1️⃣概念

  • 网络中的参数量(param)对应与空间Space概念,及空间复杂度。
  • 浮点计算量(FLOPs)对应与时间Time概念,对应于时间复杂度。

即,网络参数量(param)和显存密切相关;浮点计算量(FLOPs)和GPU的计算速度相关。

读前需知(为了简单方便而已)

  1. 这篇文章都是在不考虑偏置项的情况下分析参数量;
  2. 在stride=1,不考虑卷积中的加法操作分析计算量;
  3. 特别在计算量的时候会有提到

2️⃣如何计算网络中的参数量(param)

网络中的参数量(param)的计算
网络中参数计算需要分为

🐱‍👤2.1卷积层:

需要关注的参数为(kernel_size,in_channel,out_channel)

计算公式:
完全版:conv_param=(kernel_size∗in_channel+bias)∗out_channel,默认bias=1,out_channel是filter(代表卷积核个数),且每个卷积核都有对应的bias。

简略版:conv_param=kernel_size∗in_channel∗out_channel,因为bias不会影响数量级的变化,一般可省略。

🐱‍👤2.2池化层:

池化层不需要参数。例如 max_pooling:直接最大化池化就可以,无需参数。

🐱‍👤2.3全连接层:

全连接层有两种情况,一种是卷积层到全连接层,一种是全连接层到全连接层,因此需要分情况来讨论:

  • CONV->FC 及计算公式

Conv_FC_param=feturemap_size∗in_channel∗out_neural
feturemap_size : 前一层特征图尺寸
in_channel : 前一层卷积核个数
out_neural : 全连接层神经元个数

  • FC->FC 及计算公式

FC_FC_param=in_neura∗∗out_neural−bias
bias = out_neural,每个神经元都有一个bias。一般可忽略bias。

3️⃣如何计算网络中的计算量

🐱‍👓3.1一次卷积的计算量,如何计算呢?

以VGG-16为例,Conv1-1,输入 224×224×3 ,64个 3×3 filter,输出feature map 224×224×64

feature map中的每一个像素点,都是64个 3×3 filter 共同作用于原图计算一次得到的,所以它的计算量为 3×3×3

这里解释一下为什么是3×3×3,因为其他地方可能会写成3×3×64。因为3×3 filter作用于原图,计算一次,应该再乘上输入层通道数3,而不应该再承上卷积核的个数64(也就是输出通道数),64已经在输出feature map的通道数64被考虑到了(自己薄见,欢迎指正)

已经知道单个像素的计算量,那乘以feature map所有像素,就是一次卷积的计算量:3×3×3×224×224×64

计算公式:计算量 = 输出的feature map * 当前层filter大小 * 输入层的通道数

这仅仅是单个样本前向传播计算量,实际计算量还应乘以batch size


其实我这里还是不够严谨,这里提两个延伸:

1.如果考虑加法运算,那么64个 3×3 filter共同作用于原图的计算量就是 [3×3×3 + 3×3×3-1]=53,后面还有一个3×3×3-1 是因为加法操作导致的。那么最后的计算量就是[3×3×3 + 3×3×3-1]×224×224×64 = 53×224×224×64

2.如果考虑步长不是等于1呢?这个问题自己暂时也没有搞清楚,欢迎评论


🐱‍👓3.2全连接层的计算量

VGG-16最后一次卷积得到的feature map为 7×7×512,全连接层是将feature map展开成一维向量 1×4096 。则FC层的计算量为7×7×512×4096

通过以上讨论可以发现:我们需要减少网络参数时主要针对全连接层;进行计算优化时,重点放在卷积层。

4️⃣计算参数量和计算量(pytorch库)

例如torchstat、thop、ptflops、torchinfo等等都可以计算

🐱‍🚀4.1Torchstat

from torchstat import stat

# 导入模型,输入一张输入图片的尺寸
stat(model, (1, 28, 28))

🐱‍🚀4.2Torchinfo

import torch
import torchvision
from torchinfo import summary

# Model
print('==> Building model..')
model = torchvision.models.alexnet(pretrained=False)

dummy_input = torch.randn(1, 3, 224, 224)
print(summary(model, dummy_input, show_input=False, show_hierarchical=False))

🐱‍🚀4.3Thop

第一步:安装模块

pip install thop

第二步:计算

from thop import profile

 # 导入模型,输入一张输入图片的尺寸
input = torch.randn(1, 3, 300, 300).cuda()  # 输入input的第一维度是批量(batch size),批量的大小不回影响参数量, 计算量是batch_size=1的倍数
flop, para = profile(model, inputs=(input, ))  # 必须加上逗号,否者会报错
print('Flops:',"%.2fM" % (flop/1e6), 'Params:',"%.2fM" % (para/1e6))
total = sum([param.nelement() for param in model.parameters()])
print('Number of parameter: %.2fM' % (total/1e6))

🐱‍🚀4.4Ptflops

# -- coding: utf-8 --
import torchvision
from ptflops import get_model_complexity_info

model = torchvision.models.alexnet(pretrained=False)
flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
print('flops: ', flops, 'params: ', params)

🐱‍🚀参数总量和可训练参数总量

import torch
import torchvision
from pytorch_model_summary import summary

# Model
print('==> Building model..')
model = torchvision.models.alexnet(pretrained=False)

pytorch_total_params = sum(p.numel() for p in model.parameters())
trainable_pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print('Total - ', pytorch_total_params)
print('Trainable - ', trainable_pytorch_total_params)

5️⃣输入数据对模型的参数量和计算量的影响

# -- coding: utf-8 --
import torch
import torchvision
from thop import profile

# Model
print('==> Building model..')
model = torchvision.models.alexnet(pretrained=False)

dummy_input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, (dummy_input,))
print('flops: ', flops, 'params: ', params)
print('flops: %.2f M, params: %.2f M' % (flops / 1000000.0, params / 1000000.0))
  • 输入数据:(1, 3, 224, 224),一张224*224的RGB图像
flops:  714691904.0 params:  61100840.0
flops: 714.69 M, params: 61.10 M
  • 输入数据:(1, 3, 512, 512),一张512*512的RGB图像
flops:  3710034752.0 params:  61100840.0
flops: 3710.03 M params: 61.10 M
  • 输入数据:(8, 3, 224, 224),八张224*224的RGB图像
flops:  5717535232.0 params:  61100840.0
flops: 5717.54 M params: 61.10 M
输入数据计算量(flops)参数量(params)
(1, 3, 224, 224)714.69 M61.10 M
(1, 3, 512, 512)3710.03 M61.10 M
(8, 3, 224, 224)5717.54 M61.10 M

可见输入数据影响到计算量,对于参数量并没有影响

6️⃣引用

[0] https://blog.csdn.net/qq_44554428/article/details/123121220
[1] https://zhuanlan.zhihu.com/p/77471991
[2] https://blog.csdn.net/qq_40507857/article/details/118764782
[3] https://www.cnblogs.com/lllcccddd/p/10671879.html
[4] https://blog.csdn.net/Caesar6666/article/details/109842379

posted @ 2022-04-24 08:30  小Aer  阅读(321)  评论(0编辑  收藏  举报  来源