AlexNet网络结构与pytorch代码实现
一 网络结构
二 pytorch代码实现
方法一
import time import torch from torch import nn, optim import torchvision device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class AlexNet(nn.Module): def __init__(self): super(AlexNet, self).__init__() self.conv = nn.Sequential( nn.Conv2d(1, 96, 11, 4), # in_channels, out_channels, kernel_size, stride, padding nn.ReLU(), nn.MaxPool2d(3, 2), # kernel_size, stride # 减小卷积窗口,使用填充为2来使得输入与输出的高和宽一致,且增大输出通道数 nn.Conv2d(96, 256, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(3, 2), # 连续3个卷积层,且使用更小的卷积窗口。除了最后的卷积层外,进一步增大了输出通道数。 # 前两个卷积层后不使用池化层来减小输入的高和宽 nn.Conv2d(256, 384, 3, 1, 1), nn.ReLU(), nn.Conv2d(384, 384, 3, 1, 1), nn.ReLU(), nn.Conv2d(384, 256, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(3, 2) ) # 这里全连接层的输出个数比LeNet中的大数倍。使用丢弃层来缓解过拟合 self.fc = nn.Sequential( nn.Linear(256*5*5, 4096), nn.ReLU(), nn.Dropout(0.5), nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5), # 输出层。由于这里使用Fashion-MNIST,所以用类别数为10,而非论文中的1000 nn.Linear(4096, 10), ) def forward(self, img): feature = self.conv(img) output = self.fc(feature.view(img.shape[0], -1)) return output
方法二
import torch
import torch.nn as nn


class Conv2dReLU(nn.Module):
    """A 2-D convolution immediately followed by an in-place ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=(1, 1), padding=(0, 0), bias=True):
        super().__init__()
        # Keep the constructor arguments available as plain attributes.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution and then the ReLU.

        Args:
            x: [N,C,H,W]
        """
        return self.relu(self.conv(x))


class Features(nn.Module):
    """Convolutional feature extractor: five conv+ReLU stages and three max-pools.

    The shape comments below are for a [1,3,224,224] input.
    """

    def __init__(self):
        super().__init__()
        self.conv2drelu1 = Conv2dReLU(3, 64, kernel_size=11, stride=4, padding=2)  # [1,64,55,55]
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)  # [1,64,27,27]
        self.conv2drelu2 = Conv2dReLU(64, 192, kernel_size=5, padding=2)  # [1,192,27,27]
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)  # [1,192,13,13]
        self.conv2drelu3 = Conv2dReLU(192, 384, kernel_size=3, padding=1)  # [1,384,13,13]
        self.conv2drelu4 = Conv2dReLU(384, 256, kernel_size=3, padding=1)  # [1,256,13,13]
        self.conv2drelu5 = Conv2dReLU(256, 256, kernel_size=3, padding=1)  # [1,256,13,13]
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2)  # [1,256,6,6]

    def forward(self, x):
        """Run the input through every stage in order.

        Args:
            x: [N,C,H,W]
        """
        pipeline = (
            self.conv2drelu1,
            self.maxpool1,
            self.conv2drelu2,
            self.maxpool2,
            self.conv2drelu3,
            self.conv2drelu4,
            self.conv2drelu5,
            self.maxpool3,
        )
        out = x
        for stage in pipeline:
            out = stage(out)
        return out


class Classifier(nn.Module):
    """Fully connected head: dropout -> fc -> ReLU -> dropout -> fc -> ReLU -> fc."""

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.dropout1 = nn.Dropout()
        self.fc1 = nn.Linear(9216, 4096)
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout2 = nn.Dropout()
        self.fc2 = nn.Linear(4096, 4096)
        self.relu2 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Map flattened features (9216 per sample) to ``num_classes`` logits."""
        hidden = self.relu1(self.fc1(self.dropout1(x)))
        hidden = self.relu2(self.fc2(self.dropout2(hidden)))
        return self.fc3(hidden)


class AlexNet(nn.Module):
    """AlexNet assembled from `Features`, adaptive average pooling and `Classifier`."""

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.features = Features()
        self.aavgpool = nn.AdaptiveAvgPool2d(output_size=6)
        self.flatten = nn.Flatten()
        self.classifier = Classifier(num_classes=num_classes)

    def forward(self, x):
        """Return the classifier logits for an input batch ``x``."""
        pooled = self.aavgpool(self.features(x))
        return self.classifier(self.flatten(pooled))


if __name__ == '__main__':
    # Smoke test: list the parameter shapes, then run one random image through.
    model = AlexNet(4)
    for param_name, param in model.named_parameters():
        print(param_name, ':', param.size())
    sample = torch.rand(1, 3, 224, 224)
    logits = model(sample)
    print(logits)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 零经验选手,Compose 一天开发一款小游戏!
· 一起来玩mcp_server_sqlite,让AI帮你做增删改查!!