ResNet详解与实现

1、前言

ResNet是何恺明等人于2015年提出的神经网络结构,该网络凭借其优秀的性能夺得了多项机器视觉领域竞赛的冠军,而后在2016年发表的论文《Deep Residual Learning for Image Recognition》也获得了CVPR2016最佳论文奖。本文整理了笔者对ResNet的理解,详细解释了ResNet34、ResNet50等具体结构,并使用PyTorch实现了一个使用ResNet训练CIFAR-10数据集的具体实例。

2、深度残差学习(Deep Residual Learning)

卷积神经网络在图像分类领域具有非常广泛的应用,从理论上讲,越深的网络结构,其拟合能力应该越强,如16层的VGG16的拟合能力要强于5层的LeNet。然而,何恺明等人通过实验发现,当网络的深度达到一定程度时,网络的性能不升反降,并且这种性能的下降不是由过拟合引起的,因为深度网络的训练误差和测试误差都比浅层网络高,如图1所示。

为改善上述问题,何恺明等人提出了深度残差学习框架。

2.1 残差学习

常规的神经网络,是将卷积层、全连接层等结构按照一定的顺序简单地连接到一起,每层结构仅接受来自上一层的信息,并在本层处理后传递给下一层。当网络层次加深后,这种单一的连接方式会导致神经网络性能退化。所谓的残差学习,就是在上述的单一连接方式的基础上,加入了“短连接”(shortcut connections),如图所示。

短连接能跨越几个层,将输入x直接映射到输出端(类似于电路中的短路),与输出相加。这样做导致的直接结果就是,在神经网络中加入上述的结构不再会导致神经网络的退化,考虑最坏的情况也不过是F(x)为0,这相当于网络没有加深,与之前一致。但是如果加入的层有效,就能使网络的性能得到提升。从数学的角度来看,上述结构的输入为x,输出H(x)为:

\[H(x) = F(x)+x \]

其中F(x)称为残差函数,变换上式可得:

\[F(x) = H(x)-x \]

显然,该结构中加入的短连接是没有参数的,因此学习H(x)的参数与学习F(x)的参数本质上是一样的,但是将函数优化为0或者在0附近显然更容易。因此,残差学习,学习的就是残差函数F(x)的参数。

2.2 ResNet基本单元

论文原文中给出了两种ResNet的基本单元结构,其中左边的单元用于较浅的网络,如ResNet18、ResNet34,右边的单元则用于较深的网络,如ResNet50、ResNet101等。

首先,无论是左边还是右边的单元,都是由卷积层和激活函数构成的。左边的单元由2个大小为3x3的卷积层与2个ReLU激活函数构成,右边的单元由2个1x1、1个3x3的卷积层和3个ReLU激活函数构成。至于卷积核的的通道(channel)数,则由单元的具体位置决定。通常卷积神经网络在提取图像特征的过程种,卷积核都会从“大卷积核、低通道数量“层层递进到到”小卷积核、多通道数量“。因此,在实际应用的过程种,短连接跨越的卷积层的通道数量有所改变是非常常见的。原论文中给出了直接连接、填零连接和映射连接三种具体的短连接形式。当实际单元的输入通道数与输出相同时,如上图左,则短连接直接连接即可。若输入通道数与输出不相同,如上图右,输入的维度为64,输出却是256,此时输入x无法与F(x)直接相加。填零连接就是将多出的维度全部填充0后再进行相加,这个方法不会引入额外的参数。映射连接则是通过矩阵变换,利用大小为1x1的卷积层扩展输入的维度后再进行相加。短连接跨越通道数不同的卷积层时,卷积执行的步长为2,这刚好缩小特征矩阵的大小,并增加了特征的通道数。具体的细节将在下文中的代码实现,此处不再赘述。

2.3 ResNet

原论文中给出了如下5个具体的ResNet网络,这些网络都由上述的基本单元构成。

为了便于阐述,我们直接来讨论ResNet34,该网络主要由16个基本单元、1个7x7卷积层与1个全连接层构成,共计34层。

其中,实线的短连接表示直接连接,前后通道数未发生改变。虚线的短连接则表示前后维度发生了改变,需要通过填0或者映射后进行连接。

3、代码实现

上面讲了这么多白开水般的原理,相信各位看官早已疲倦,接下来上紧张刺激的代码,毕竟看了那么多,最终目的还是为了应用。对于ResNet34,其基本单元实现代码如下。

class  ResidualBlock(nn.Module):
    expansion = 1 # 扩展系数
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        # 短连接方式
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1,
                          stride=stride,bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x = x + self.shortcut(x)
        out = F.relu(x)

        return out

只要理解了上文所叙述的原理,代码还是很容易理解的。其中的扩展系数expansion表示的是单元输出与输入张量的通道数之比,对于ResNet34,这个比是1。而对于ResNet50,这个比是4。如果无法理解,可以回顾上文2.2中的图。ResNet50的基本单元代码如下。

class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, self.expansion *
                               out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion*out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion*out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

有了建房子的”砖块“,就可以着手建房了,ResNet的实现代码如下。

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self.__make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self.__make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self.__make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self.__make_layer(block, 512, num_blocks[3], stride=2)
        self.fc = nn.Linear(512*block.expansion, num_blocks)

    def __make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels*block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.avg_pool2d(x, 4)
        out = self.fc(x)

        return out

该代码可以实现论文中的5个网络ResNet18、ResNet34、ResNet50、ResNet101、ResNet152。ResNet34的实例化代码如下:

resnet34 = ResNet(ResidualBlock, [3, 4, 6, 3])

具体的参数num_blocks可查阅2.3中的表确定。

搭建好了神经网络,就可以用于训练模型了。ResNet34训练CIFAR-10的例程请点击此处访问。本文代码参考了Gihub上的项目pytorch-cifar与torchvision.models中的代码,特此声明,以示感谢。

4、总结

本文介绍了一些我在学习过程中对于ResNet的理解,并给出了代码与具体的实例。从实质上讲,ResNet与传统卷积神经网络的主要区别就是引入了短连接,而这个短连接极大地提升了神经网络的性能。至于数学原理,本文并未深究,毕竟包括我在内的许多学习者都是关心如何应用,而不是深挖那些数学公式的来龙去脉。

由于本人才疏学浅,如有疏漏谬误之处,还请指出,万分感谢。

5、参考

文献:He K, Zhang X, Ren S, et al. Deep residual learning for image recognition[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.

代码:https://github.com/kuangliu/pytorch-cifar

posted @ 2022-03-21 10:52  菜鸡刘  阅读(4885)  评论(1编辑  收藏  举报