Implementing ResNet-18 the hard way

Every convolution, batch norm, and skip connection is written out explicitly below, with no reusable block modules, so the whole network can be read top to bottom. The layer numbers in the comments count the 18 weighted layers: 17 convolutions plus the final fully connected layer (the 1x1 projection shortcuts are not counted).

import torch


class myResNet(torch.nn.Module):
    def __init__(self, in_channels=3, num_classes=10):
        super(myResNet, self).__init__()
        # Layer 1: 7x7 stem convolution + max pool.
        # bias=False throughout: every conv is followed by BatchNorm, which
        # cancels any bias (this also matches the reference ResNet-18).
        self.conv0_1 = torch.nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn0_1 = torch.nn.BatchNorm2d(64)
        self.relu0_1 = torch.nn.ReLU()
        self.dmp = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # downsampling max pool

        # Layers 2-3: first residual block (64 channels, identity shortcut)
        self.conv1_1 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1_1 = torch.nn.BatchNorm2d(64)
        self.relu1_1 = torch.nn.ReLU()
        self.conv1_2 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1_2 = torch.nn.BatchNorm2d(64)
        self.relu1_2 = torch.nn.ReLU()

        # Layers 4-5: second residual block (64 channels, identity shortcut)
        self.conv2_1 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2_1 = torch.nn.BatchNorm2d(64)
        self.relu2_1 = torch.nn.ReLU()
        self.conv2_2 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2_2 = torch.nn.BatchNorm2d(64)
        self.relu2_2 = torch.nn.ReLU()

        # Layers 6-7: third residual block (64 -> 128, stride 2).
        # conv3_0/bn3_0 form the 1x1 projection shortcut; the reference
        # ResNet-18 batch-normalizes the shortcut branch as well.
        self.conv3_0 = torch.nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False)
        self.bn3_0 = torch.nn.BatchNorm2d(128)
        self.conv3_1 = torch.nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn3_1 = torch.nn.BatchNorm2d(128)
        self.relu3_1 = torch.nn.ReLU()
        self.conv3_2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3_2 = torch.nn.BatchNorm2d(128)
        self.relu3_2 = torch.nn.ReLU()

        # Layers 8-9: fourth residual block (128 channels, identity shortcut)
        self.conv4_1 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn4_1 = torch.nn.BatchNorm2d(128)
        self.relu4_1 = torch.nn.ReLU()
        self.conv4_2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn4_2 = torch.nn.BatchNorm2d(128)
        self.relu4_2 = torch.nn.ReLU()

        # Layers 10-11: fifth residual block (128 -> 256, stride 2) with projection shortcut
        self.conv5_0 = torch.nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False)
        self.bn5_0 = torch.nn.BatchNorm2d(256)
        self.conv5_1 = torch.nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn5_1 = torch.nn.BatchNorm2d(256)
        self.relu5_1 = torch.nn.ReLU()
        self.conv5_2 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn5_2 = torch.nn.BatchNorm2d(256)
        self.relu5_2 = torch.nn.ReLU()

        # Layers 12-13: sixth residual block (256 channels, identity shortcut)
        self.conv6_1 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn6_1 = torch.nn.BatchNorm2d(256)
        self.relu6_1 = torch.nn.ReLU()
        self.conv6_2 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn6_2 = torch.nn.BatchNorm2d(256)
        self.relu6_2 = torch.nn.ReLU()

        # Layers 14-15: seventh residual block (256 -> 512, stride 2) with projection shortcut
        self.conv7_0 = torch.nn.Conv2d(256, 512, kernel_size=1, stride=2, bias=False)
        self.bn7_0 = torch.nn.BatchNorm2d(512)
        self.conv7_1 = torch.nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn7_1 = torch.nn.BatchNorm2d(512)
        self.relu7_1 = torch.nn.ReLU()
        self.conv7_2 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn7_2 = torch.nn.BatchNorm2d(512)
        self.relu7_2 = torch.nn.ReLU()

        # Layers 16-17: eighth residual block (512 channels, identity shortcut)
        self.conv8_1 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn8_1 = torch.nn.BatchNorm2d(512)
        self.relu8_1 = torch.nn.ReLU()
        self.conv8_2 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn8_2 = torch.nn.BatchNorm2d(512)
        self.relu8_2 = torch.nn.ReLU()

        # Layer 18: classifier head
        self.fc = torch.nn.Linear(512, num_classes)

    def forward(self, x):  # batch_size, 3, 224, 224
        x = self.conv0_1(x)   # bs, 64, 112, 112
        x = self.bn0_1(x)
        x = self.relu0_1(x)
        x1 = self.dmp(x)  # bs, 64, 56, 56

        x = self.conv1_1(x1)  # bs, 64, 56, 56
        x = self.bn1_1(x)
        x = self.relu1_1(x)
        x = self.conv1_2(x)
        x = self.bn1_2(x)
        x = x + x1  # identity shortcut (residual connection)
        x2 = self.relu1_2(x)

        x = self.conv2_1(x2)
        x = self.bn2_1(x)
        x = self.relu2_1(x)
        x = self.conv2_2(x)
        x = self.bn2_2(x)
        x = x + x2
        x = self.relu2_2(x)  # bs, 64, 56, 56

        x3 = self.bn3_0(self.conv3_0(x))  # projection shortcut: bs, 128, 28, 28
        x = self.conv3_1(x)
        x = self.bn3_1(x)
        x = self.relu3_1(x)
        x = self.conv3_2(x)
        x = self.bn3_2(x)
        x = x + x3
        x4 = self.relu3_2(x)

        x = self.conv4_1(x4)
        x = self.bn4_1(x)
        x = self.relu4_1(x)
        x = self.conv4_2(x)
        x = self.bn4_2(x)
        x = x + x4
        x = self.relu4_2(x)  # bs, 128, 28, 28

        x5 = self.bn5_0(self.conv5_0(x))  # projection shortcut: bs, 256, 14, 14
        x = self.conv5_1(x)
        x = self.bn5_1(x)
        x = self.relu5_1(x)
        x = self.conv5_2(x)
        x = self.bn5_2(x)
        x = x + x5
        x6 = self.relu5_2(x)

        x = self.conv6_1(x6)
        x = self.bn6_1(x)
        x = self.relu6_1(x)
        x = self.conv6_2(x)
        x = self.bn6_2(x)
        x = x + x6
        x = self.relu6_2(x)  # bs, 256, 14, 14

        x7 = self.bn7_0(self.conv7_0(x))  # projection shortcut: bs, 512, 7, 7
        x = self.conv7_1(x)
        x = self.bn7_1(x)
        x = self.relu7_1(x)
        x = self.conv7_2(x)
        x = self.bn7_2(x)
        x = x + x7
        x8 = self.relu7_2(x)

        x = self.conv8_1(x8)
        x = self.bn8_1(x)
        x = self.relu8_1(x)
        x = self.conv8_2(x)
        x = self.bn8_2(x)
        x = x + x8
        x = self.relu8_2(x)  # bs, 512, 7, 7

        x = torch.nn.functional.adaptive_avg_pool2d(x, 1)  # global average pool: bs, 512, 1, 1
        x = torch.flatten(x, 1)  # bs, 512
        x = self.fc(x)
        return x


if __name__ == "__main__":
    tx = torch.randn((4, 3, 224, 224))
    algo = myResNet()
    pred = algo(tx)
    print(pred.shape)  # expected: torch.Size([4, 10])
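
As an extra sanity check, the parameter count can be compared against torchvision's reference implementation (this snippet assumes torchvision is installed). Given bias=False on the convolutions and batch norm on the projection shortcuts above, the two totals should match:

from torchvision.models import resnet18  # assumes torchvision is installed

mine = myResNet(num_classes=10)
ref = resnet18(num_classes=10)  # randomly initialized, no pretrained weights
print(sum(p.numel() for p in mine.parameters()))
print(sum(p.numel() for p in ref.parameters()))  # should print the same total

Note that the attribute names here differ from torchvision's (conv0_1 vs. conv1, and so on), so a pretrained state dict cannot be loaded directly without remapping its keys.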

Reference: https://mp.weixin.qq.com/s/eWeVWcEMLC9FIiFqKy5wqA
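
For contrast with the hard way: all eight residual blocks above share one structure, so they collapse into a reusable module. The BasicBlock below is an illustrative sketch of that refactor (the class name and its parameters are hypothetical, mirroring the layers written out above), not code from the referenced article:

import torch


class BasicBlock(torch.nn.Module):
    """One residual block: two 3x3 convs plus a 1x1 projection
    shortcut whenever the spatial size or channel count changes."""

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(out_ch)
        self.conv2 = torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(out_ch)
        self.relu = torch.nn.ReLU()
        self.shortcut = torch.nn.Identity()  # identity shortcut by default
        if stride != 1 or in_ch != out_ch:
            # projection shortcut, as in conv3_0/bn3_0 above
            self.shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        return self.relu(out + self.shortcut(x))

With such a module, the body of myResNet reduces to eight BasicBlock instances, two per stage at widths 64, 128, 256, and 512, with stride 2 on the first block of each stage after the first.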
