CNN可学习参数个数的计算
博客搬家至 Mun: https://kiddie92.github.io
简书同步更新
前几天听室友给我讲算法岗的面经,其中面试官就问了一个小问题,“给出CNN网络的参数(可学习的)个数如何计算”,今天就来计算一下好了。
问题描述
可学习参数顾名思义就是指CNN中需要学习/更新的变量,因为CNN的网络架构设计中会引入很多需要被学习出来的变量,比如:hidden layer
中的神经元个数便直接和仿射变换的参数个数相关,而现在的问题是把这些可学习的变量的个数统计出来。
卷积神经网络架构
卷积神经网络在计算机视觉领域有比较多的应用,下图便是一个图片识别的网络架构示例图(工业界使用的模型更复杂)。
上图描述了卷积神经网络在进行正向计算/正向传播时的流程/架构,如图所示,当输入一个"小轿车"的图片时,我们希望经过一个函数各种计算后,能够输出“CAR”这个词。那么该如何计算呢?
-
输入层(input):
就是读取图片,将图片用数字化的矩阵来表示。 -
卷积(convolution):
选用卷积核(filter,可以是多个)对图片的多个通道进行卷积操作(element-wise的相乘)。卷积计算会使图片的长宽变小,但是"高度"变大(如图中的图片逐渐变"厚"),这是因为使用的卷积核(filter)较多,使得计算得到的图片通道数(channels)也会增加。
卷积操作其实可以理解为简化版的"连接层",部分神经元才可以和下一层的部分神经元进行连接。
-
激活(activation):
该操作主要是对之前的卷积计算结果做非线性处理,万能逼近原理告诉我们这种线性和非线性计算的组合可以拟合所有复杂的函数。常用的非线性处理函数/激活函数有Sigmoid、Relu、Leaky ReLU、tanh等,更多内容可以参考这里。 -
池化(pooling):
对非线性化后的高维矩阵进行"减采样",同样以一定步长逐步将矩阵中的"元素块(例如:)"仅使用一个数来代表,比如:取"元素块"中的最大值、平均值等计算方式。
降采样可以减少后续的计算量还可以一定程度防止过拟合。
-
拉平(Flatten):
将高维矩阵"拉平",转换为一维矩阵,元素依次排序。 -
全连接(Fully Connected):
设置下一层神经元的个数,并使用仿射变换得到下一层神经元的值,因为两层之间的神经元会全部连接起来,所及叫做全连接。 -
输出层计算分类概率(Softmax):
对最后一层的神经元进行概率输出计算,即:给出各分类标签的概率,比如这里预期"Car"的概率一定要大于其他分类标签的概率值,所以最后一层的神经元个数和分类的标签个数需要一致。下面是softmax函数的表达式: -
Batch-Normalization:
这一层其实在每一次卷积、全连接后都可以进行计算,但是图片中没有反映处这个处理过程。感兴趣可以查看我之前的博客卷积神经网络之Batch Normalization(一):How?
参数个数计算
按照上面的步骤描述,可知:
- input层是没有引进变量的;
- 卷积层则引入了"卷积核/filter",假设卷积核大小为,图片有个通道(channels)/维度,而选用的"卷积核/filter"有个,再加上bias,可以得到引进的参数有:个;
- 激活、池化层仅仅对原来的矩阵做了一个变换,不会引进新的参数;
- 拉平操作仅仅对矩阵进行了reshape,也不会引进新的变量;
- 全连接层就是对前后神经元做了仿射变换,引进的参数有权重和偏置,假设n个神经元连接m个神经元,则引入的参数有;
- 输出层其实和全连接层没啥区别,只是输出的神经元个数要求是分类的标签个数,所以引入的变量也是,这里m是分类的标签个数;
- BN层引入的参数则和输入层神经元个数相关,假设输入神经元个数为n,则该层引进的参数为
综合以上,计算一个CNN架构的所有可学习参数/变量的个数可以分解成每一个步骤的参数变化量和引入参数个数两个相关的小问题,就像这样:
# name size parameters
--- -------- ------------------------- ------------------------
0 input 1x28x28 0
1 conv2d1 (28-(5-1))=24 -> 32x24x24 (5*5*1+1)*32 = 832
2 maxpool1 32x12x12 0
3 conv2d2 (12-(3-1))=10 -> 32x10x10 (3*3*32+1)*32 = 9'248
4 maxpool2 32x5x5 0
5 dense 256 (32*5*5+1)*256 = 205'056
6 output 10 (256+1)*10 = 2'570
最后将每一个步骤的参数相加便得到所有参数的个数。
参考资料
- https://stackoverflow.com/questions/42786717/how-to-calculate-the-number-of-parameters-for-convolutional-neural-network
- https://medium.freecodecamp.org/an-intuitive-guide-to-convolutional-neural-networks-260c2de0a050