ResNet 网络——BasicBlock 与 Bottleneck

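ResNet 的网络深度有 18、34、50、101、152 层。50 层以下（ResNet-18/34）的网络基础块是 BasicBlock，50 层及以上（ResNet-50/101/152）的网络基础块是 Bottleneck。注意：下文代码是原始 ResNet（V1）的“后激活”结构（BN 在相加之前、ReLU 在相加之后），而不是 ResNetV2 的“预激活”（BN→ReLU→卷积）结构。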

BasicBlock

图示如下（原文配图在导出时未保留）：

 

代码实现

 1 class BasicBlock(nn.Module):
 2     expansion = 1
 3     def __init__(self, in_channel, out_channel, stride=1, downsample=None):
 4         super(BasicBlock, self).__init__()
 5         self.conv1 = conv3x3(in_channel, out_channel, stride)
 6         self.bn1 = nn.BatchNorm2d(out_channel)
 7         self.relu = nn.ReLU(inplace=True)
 8         self.conv2 = conv3x3(out_channel, out_channel)
 9         self.bn2 = nn.BatchNorm2d(out_channel)
10         self.downsample = downsample
11         self.stride =stride
12 
13     def forward(self, x):
14         residual = x
15         out = self.conv1(x)
16         out = self.bn1(out)
17         out = self.relu(out)
18         out = self.conv2(out)
19         out = self.bn2(out)
20         if self.downsample is not None:
21             residual = self.downsample(x)
22 
23         out = out + residual
24         out = self.relu(out)
25 
26         return out

 

Bottleneck

图示如下（原文配图在导出时未保留）：

 

代码实现：

 1 class Bottleneck(nn.Module):
 2 
 3     expansion = 4
 4 
 5     def __init__(self, in_channel, out_channel, stride=1, downsample=None):
 6         super(Bottleneck, self).__init__()
 7 
 8         self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False)
 9         self.bn1 = nn.BatchNorm2d(out_channel)
10 
11         self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)  # stride = 3
12         self.bn2 = nn.BatchNorm2d(out_channel)
13 
14         self.conv3 = nn.Conv2d(out_channel, out_channel * 4, kernel_size=1, bias=False)
15         self.bn3 = nn.BatchNorm2d(out_channel * 4)
16 
17         self.relu = nn.ReLU(inplace=True)
18         self.stride = stride
19         self.downsample =downsample
20 
21 
22     def forward(self, x):
23         residual = x
24 
25         out = self.conv1(x)
26         out = self.bn1(out)
27         out = self.relu(out)
28 
29         out = self.conv2(out)
30         out = self.bn2(out)
31         out = self.relu(out)
32 
33         out = self.conv3(out)
34         out = self.bn3(out)
35 
36         if self.downsample is not None:
37             residual = self.downsample(x)
38 
39         out = out + residual
40         out = self.relu(out)
41 
42         return out

 

posted @ 2022-03-09 11:00  奋斗的小仔  阅读(3013)  评论(0)  编辑  收藏  举报