Deep Learning in Practice: ResNet18
The name ResNet18 means that the network's basic architecture is ResNet and its depth is 18 layers. Here, depth counts only the layers that carry trainable weights, i.e. the convolutional and fully-connected (linear) layers; batch normalization layers and pooling layers are not counted.
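As a quick sanity check of this counting convention, the short sketch below (an illustration only; it assumes torchvision is installed, which the rest of this post does not need) counts the weight-carrying layers in the reference torchvision resnet18. The conventional 18 comes from 17 main-path convolutions plus the final fully-connected layer; the three 1x1 shortcut projections, like batch norm and pooling, are not counted.

# Sketch: count the weight layers in torchvision's reference resnet18.
# Assumption: torchvision is available; it is not used elsewhere in this post.
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18()
convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
fcs = [m for m in model.modules() if isinstance(m, nn.Linear)]

# 20 Conv2d modules = 17 main-path convolutions + 3 1x1 shortcut projections;
# the "18" in the name = 17 main-path convolutions + 1 fully-connected layer.
print("conv layers:", len(convs))    # 20
print("linear layers:", len(fcs))    # 1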
Model Implementation
import torch
from torch import nn
import torch.nn.functional as F


class ResBlk(nn.Module):
    '''
    ResNet basic block: two 3x3 conv + batch-norm layers,
    plus a shortcut branch that is added to the output.
    '''

    def __init__(self, ch_in, ch_out):
        '''
        :param ch_in:  number of input channels
        :param ch_out: number of output channels
        '''
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Shortcut branch: identity when the channel counts match,
        # otherwise a 1x1 conv + batch norm to project ch_in to ch_out.
        self.extra = nn.Sequential()
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        '''
        :param x: [b, ch_in, h, w]
        :return:  [b, ch_out, h, w]
        '''
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Residual connection: add the shortcut, then apply the final ReLU.
        out = self.extra(x) + out
        out = F.relu(out)
        return out

class ResNet18(nn.Module):
    '''
    Simplified ResNet-18-style network for 32x32 CIFAR-10 images.
    '''
    def __init__(self):
        super(ResNet18, self).__init__()

        # Stem: 3 input channels -> 64 feature maps.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )

        # Four residual blocks; only the channel count grows,
        # the spatial resolution stays 32x32 throughout.
        self.blk1 = ResBlk(64, 64)
        self.blk2 = ResBlk(64, 128)
        self.blk3 = ResBlk(128, 256)
        self.blk4 = ResBlk(256, 512)

        # Flattened [b, 512, 32, 32] feature map -> 10 class logits.
        self.outlayer = nn.Linear(512 * 32 * 32, 10)

    def forward(self, x):
        '''
        :param x: [b, 3, 32, 32]
        :return:  [b, 10] class logits
        '''
        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Flatten and classify.
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x

def main():
    # Shape check for a single block that changes the channel count.
    blk = ResBlk(64, 128)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print("blk:", out.shape)

    # Shape check for the full network on a fake CIFAR-10 batch.
    model = ResNet18()
    tmp = torch.randn(2, 3, 32, 32)
    out = model(tmp)
    print("resnet:", out.shape)


if __name__ == '__main__':
    main()
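One design choice worth noting: none of the blocks above downsamples, so the feature map stays at 32x32 all the way to the classifier, and the final nn.Linear(512*32*32, 10) alone holds 512*32*32*10 ≈ 5.2 million weights. A minimal sketch to check the parameter count (assuming the code above is saved as resnet.py, as the training script below expects):

# Sketch: report parameter counts for the model defined above.
# Assumption: the model code is saved as resnet.py in the same directory.
from resnet import ResNet18

model = ResNet18()
total = sum(p.numel() for p in model.parameters())
fc = sum(p.numel() for p in model.outlayer.parameters())
print(f"total parameters: {total:,}")
print(f"final linear layer: {fc:,}")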
Training and Testing
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn, optim
from resnet import ResNet18


def main():

    batch_size = 32
    epochs = 1000
    learn_rate = 1e-3

    # CIFAR-10 training set and loader.
    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)

    cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

    # CIFAR-10 test set and loader (no shuffling needed for evaluation).
    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)

    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

    # Inspect one batch to confirm shapes.
    x, label = next(iter(cifar_train))
    print("x:", x.shape, "label:", label.shape)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = ResNet18().to(device)
    print(model)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learn_rate)

    for epoch in range(epochs):
        # Training phase.
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)

            # logits: [b, 10], label: [b]
            logits = model(x)
            loss = criteon(logits, label)

            # Backpropagation and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, loss.item())

        # Evaluation phase: no gradients needed.
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)

                logits = model(x)
                pred = logits.argmax(dim=1)

                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)

            acc = total_correct / total_num
            print("epoch:", epoch, "acc:", acc)


if __name__ == '__main__':
    main()
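Before launching the full 1000-epoch run, it can be useful to verify the whole pipeline on random data; the minimal sketch below (again assuming the model code is saved as resnet.py) performs a single forward/backward step without downloading CIFAR-10.

# Sketch: one optimization step on random data to smoke-test the pipeline.
# Assumption: the model code above is saved as resnet.py; CUDA is optional.
import torch
from torch import nn, optim
from resnet import ResNet18

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet18().to(device)
criteon = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(4, 3, 32, 32, device=device)        # fake CIFAR-10 batch
label = torch.randint(0, 10, (4,), device=device)    # fake class labels

loss = criteon(model(x), label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("smoke-test loss:", loss.item())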