# 深度学习（统计模型参数量）— Deep learning: counting a model's parameters

# 统计模型参数量，方便判断不同模型的大小：
# (Count a model's parameters so that the sizes of different models can be compared easily.)

import torch
import torch.nn as nn

class AlexNet(nn.Module):
    """AlexNet-style CNN whose layer sizes serve as a parameter-count example.

    Five convolutional layers are followed by three fully connected layers.
    The comment next to each layer gives its parameter count. For a
    convolution it is ``in * out * kernel_h * kernel_w + out``. For a linear
    layer it is ``in * out + out``. In both, the ``+ out`` term is the bias.

    Args:
        in_channels: Number of channels of the input image. Defaults to 1,
            the value the layer sizes were originally written for.
        num_classes: Number of output logits. Defaults to 2.

    With the defaults the model holds 58,266,306 parameters.
    """

    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels, 96, kernel_size=11, stride=4)     # in_channels*96*(11*11)+96 (=11712 for in_channels=1)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)   # 96*256*(5*5)+256=614656
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)  # 256*384*3*3+384=885120
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)  # 384*384*3*3+384=1327488
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)  # 384*256*3*3+256=884992

        # Fully connected head. fc1's input size assumes a 256 x 6 x 6 feature
        # map after the last pooling stage.
        self.fc1 = nn.Linear(256 * 6 * 6, 4096)    # 256*6*6*4096+4096=37752832
        self.fc2 = nn.Linear(4096, 4096)           # 4096*4096+4096=16781312
        self.fc3 = nn.Linear(4096, num_classes)    # 4096*num_classes+num_classes (=8194 for num_classes=2)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Batch of images of shape (batch, in_channels, H, W). H and W
                must reduce to 6 x 6 after the final pooling, as 227 x 227
                does. Any other final size makes ``fc1`` fail with a shape
                mismatch.

        Returns:
            Tensor of shape (batch, num_classes) holding raw (un-normalized)
            scores. No activation is applied after ``fc3``.
        """
        # The shape comments below assume a 227 x 227 input.
        x = torch.relu(self.conv1(x))             # in_channels*227*227 -> 96*55*55
        x = torch.max_pool2d(x, 3, stride=2)      # 96*55*55  -> 96*27*27
        x = torch.relu(self.conv2(x))             # 96*27*27  -> 256*27*27
        x = torch.max_pool2d(x, 3, stride=2)      # 256*27*27 -> 256*13*13
        x = torch.relu(self.conv3(x))             # 256*13*13 -> 384*13*13
        x = torch.relu(self.conv4(x))             # 384*13*13 -> 384*13*13
        x = torch.relu(self.conv5(x))             # 384*13*13 -> 256*13*13
        x = torch.max_pool2d(x, 3, stride=2)      # 256*13*13 -> 256*6*6
        x = torch.flatten(x, 1)                   # 256*6*6   -> 9216 (keeps the batch dim)
        x = torch.relu(self.fc1(x))               # 9216      -> 4096
        x = torch.relu(self.fc2(x))               # 4096      -> 4096
        x = self.fc3(x)                           # 4096      -> num_classes
        return x

# Instantiate the model with its default configuration.
net = AlexNet()

# Add up the element counts of every parameter tensor (all weights and biases).
# Tensor.numel() is the canonical name of Tensor.nelement(); both return the
# number of elements in the tensor.
total = sum(p.numel() for p in net.parameters())

print("Number of parameter:", total)
# posted @ 2023-10-27 21:01  Dsp Tian  阅读(53)  评论(0)  编辑  收藏  举报