深度学习(统计模型参数量)
统计模型的参数量，便于比较不同模型的大小。下面以 AlexNet 为例：
import torch
import torch.nn as nn


class AlexNet(nn.Module):
    """AlexNet-style CNN: five convolutions followed by three linear layers.

    The per-layer parameter counts in the comments below are
    ``in * out * kernel_h * kernel_w + out`` (weights + biases) and are
    computed for the default arguments.

    Args:
        in_channels: Number of channels of the input image (default 1).
        num_classes: Size of the final output layer (default 2).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 2) -> None:
        super().__init__()
        # 1*96*(11*11) + 96 = 11712
        self.conv1 = nn.Conv2d(in_channels, 96, kernel_size=11, stride=4)
        # 96*256*(5*5) + 256 = 614656
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)
        # 256*384*(3*3) + 384 = 885120
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
        # 384*384*(3*3) + 384 = 1327488
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)
        # 384*256*(3*3) + 256 = 884992
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)
        # 256*6*6*4096 + 4096 = 37752832
        self.fc1 = nn.Linear(256 * 6 * 6, 4096)
        # 4096*4096 + 4096 = 16781312
        self.fc2 = nn.Linear(4096, 4096)
        # 4096*2 + 2 = 8194
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        ``fc1`` expects a 256*6*6 feature map, so the input's spatial size
        must reduce to 6x6 after the last pooling step (e.g. 227x227; the
        shape comments below assume that size).

        Args:
            x: Batch of images, laid out as (batch, in_channels, H, W).

        Returns:
            Raw (un-normalized) class scores of shape (batch, num_classes).
        """
        x = torch.relu(self.conv1(x))        # 227*227 -> 96*55*55
        x = torch.max_pool2d(x, 3, stride=2)  # 96*55*55 -> 96*27*27
        x = torch.relu(self.conv2(x))        # 96*27*27 -> 256*27*27
        x = torch.max_pool2d(x, 3, stride=2)  # 256*27*27 -> 256*13*13
        x = torch.relu(self.conv3(x))        # 256*13*13 -> 384*13*13
        x = torch.relu(self.conv4(x))        # 384*13*13 -> 384*13*13
        x = torch.relu(self.conv5(x))        # 384*13*13 -> 256*13*13
        x = torch.max_pool2d(x, 3, stride=2)  # 256*13*13 -> 256*6*6
        # flatten() keeps the batch dimension and, unlike view(), also
        # works when the tensor is not contiguous in memory.
        x = torch.flatten(x, 1)              # 256*6*6 -> 9216
        x = torch.relu(self.fc1(x))          # 9216 -> 4096
        x = torch.relu(self.fc2(x))          # 4096 -> 4096
        x = self.fc3(x)                      # 4096 -> num_classes
        return x


def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Return the total number of scalar parameters in *model*.

    Args:
        model: Any ``nn.Module``; all parameters returned by
            ``model.parameters()`` are counted.
        trainable_only: When true, skip parameters whose ``requires_grad``
            is false.

    Returns:
        The summed element count of the selected parameter tensors.
    """
    return sum(
        param.numel()
        for param in model.parameters()
        if param.requires_grad or not trainable_only
    )


net = AlexNet()
total = count_parameters(net)
print("Number of parameter:", total)