学习笔记16:残差网络

产生背景

随着网络深度的增加,会出现网络退化的现象。
网络退化现象形象化解释是在训练集上的loss不增反降。
这说明,浅层网络的训练效果要好于深层网络
一个想法就是,如果将浅层网络的特征传到深层网络,那么深层网络的训练效果不会比浅层网络差
举个例子,就是假设总共有50层,20层的训练结果就比50层的好了,因此可以将18层与98层之间连接一个直接映射
这样随着网络的加深,训练效果就不会降低了

残差块

残差块的数学表示:
\(x_{l + 1} = x_l + F(x_l, W_l)\)
\(x_l\)相当于是一个直接映射,\(F(x_l, W_l)\)是残差部分

在这个网络结构中,右侧指的就是残差部分,左侧是直接映射

代码实现

class ResnetbasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channnels, out_channels, kernel_size = 3, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(out_channels)
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = F.relu(self.bn1(out), inplace = True)
        out = self.conv2(out)
        out = F.relu(self.bn2(out), inplace = True)
        out = out + residual
        return F.relu(out)
posted @ 2021-02-03 17:21  pbc的成长之路  阅读(112)  评论(0编辑  收藏  举报