ResNet网络——BasicBlock与Bottleneck
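ResNet的常见网络深度有18、34、50、101和152层。50层以下（ResNet-18/34）的基础块是BasicBlock，50层及以上（ResNet-50/101/152）的基础块是Bottleneck。下面的实现采用原始ResNet（v1）的“后激活”结构：卷积→BN→ReLU，残差相加之后再做ReLU；而ResNetV2（预激活结构）则把BN和ReLU放到卷积之前。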
BasicBlock
图示如下（结构：3×3卷积 → BN → ReLU → 3×3卷积 → BN，再与shortcut（恒等映射，或在形状变化时经downsample变换的输入）相加，最后经过ReLU）。
代码实现
class BasicBlock(nn.Module):
    """Two-convolution residual block used by the shallower ResNet variants.

    Main branch: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The branch output is
    added to the shortcut and passed through a final ReLU. ``conv3x3`` is a
    helper defined elsewhere in the project (presumably a padded 3x3
    convolution; confirm against its definition).

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels produced by both convolutions.
        stride: forwarded as the third argument of the first ``conv3x3`` call.
        downsample: optional module applied to the input to build the shortcut
            when the main branch changes the shape; ``None`` means the input
            itself is used as the shortcut.
    """

    # Output channels = out_channel * expansion; this block does not widen.
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super().__init__()
        # Submodules are registered in the same order as before so that
        # ``children()`` / ``state_dict()`` ordering is unchanged.
        self.conv1 = conv3x3(in_channel, out_channel, stride)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channel, out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample
        self.stride = stride  # kept as an attribute; forward() does not read it

    def forward(self, x):
        """Run the main branch, add the shortcut, and apply the final ReLU."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))

        # The shortcut is computed after the main branch, as in the original.
        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(hidden + shortcut)
Bottleneck
图示如下（结构：1×1卷积降维 → BN → ReLU → 3×3卷积 → BN → ReLU → 1×1卷积升维到4倍通道 → BN，再与shortcut相加，最后经过ReLU）。
代码实现
1 class Bottleneck(nn.Module): 2 3 expansion = 4 4 5 def __init__(self, in_channel, out_channel, stride=1, downsample=None): 6 super(Bottleneck, self).__init__() 7 8 self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False) 9 self.bn1 = nn.BatchNorm2d(out_channel) 10 11 self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False) # stride = 3 12 self.bn2 = nn.BatchNorm2d(out_channel) 13 14 self.conv3 = nn.Conv2d(out_channel, out_channel * 4, kernel_size=1, bias=False) 15 self.bn3 = nn.BatchNorm2d(out_channel * 4) 16 17 self.relu = nn.ReLU(inplace=True) 18 self.stride = stride 19 self.downsample =downsample 20 21 22 def forward(self, x): 23 residual = x 24 25 out = self.conv1(x) 26 out = self.bn1(out) 27 out = self.relu(out) 28 29 out = self.conv2(out) 30 out = self.bn2(out) 31 out = self.relu(out) 32 33 out = self.conv3(out) 34 out = self.bn3(out) 35 36 if self.downsample is not None: 37 residual = self.downsample(x) 38 39 out = out + residual 40 out = self.relu(out) 41 42 return out