龙良曲pytorch学习笔记_ResNet18
主函数流程：main → dataloader → train → test
与 LeNet5 的主函数相比，只是把模型换成了 ResNet18，其他部分没有变化。
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim

from resnet import ResNet18


def main():
    """Train ResNet18 on CIFAR-10 and print test accuracy after every epoch.

    Builds the train/test DataLoaders, sanity-checks the shape of one
    training batch, then alternates one training epoch with one evaluation
    pass over the test set.
    """
    batch_size = 32

    # Train and test data must go through the SAME preprocessing; otherwise
    # the model is evaluated on a different input distribution than the one
    # it was trained on.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        # Per-channel mean/std (the widely used ImageNet statistics).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transform, download=True)
    # DataLoader yields several images per step (one batch).
    cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transform, download=True)
    # Evaluation order does not affect accuracy, so no shuffling is needed.
    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

    # Check the batch shapes once the data has loaded.
    # (`iterator.next()` only exists in Python 2; use the builtin next().)
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNet18().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):
        model.train()
        for x, label in cifar_train:
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)
            logits = model(x)  # logits: [b, 10]
            loss = criteon(logits, label)

            # Backprop.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        # Evaluation needs no gradient bookkeeping.
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)  # logits: [b, 10]
                pred = logits.argmax(dim=1)
                # Accumulate the number of correct predictions batch by batch.
                total_correct += torch.eq(pred, label).float().sum().item()
                # x.size(0) is this batch's size (the last batch may be smaller).
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, acc)


if __name__ == '__main__':
    main()
ResNet18（resnet.py）
import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + BN + ReLU layers plus a shortcut.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the first 3x3 conv. A stride > 1 shrinks the
            spatial size (h, w), which keeps memory usage down.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default.
        self.extra = nn.Sequential()
        # A projection shortcut is required whenever the main path changes the
        # tensor shape: a different channel count OR a spatial downsample.
        # Checking only the channels breaks e.g. ResBlk(512, 512, stride=2),
        # where the identity x would keep its h, w while f(x) is downsampled.
        if stride != 1 or ch_out != ch_in:
            # [b, ch_in, h, w] --> [b, ch_out, h', w']
            self.extra = nn.Sequential(
                # x must match f(x) in size, so use the same stride here;
                # the 1x1 conv makes the channel counts agree.
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """Return shortcut(x) + f(x) with shape [b, ch_out, h', w']."""
        out = F.relu(self.bn1(self.conv1(x)))
        # Where to apply this ReLU (before or after the addition) is a design
        # choice; this implementation applies it before the residual add.
        out = F.relu(self.bn2(self.conv2(out)))
        # Shortcut: element-wise add of f(x) and the (possibly projected) x,
        # which requires both to have identical shapes.
        out = self.extra(x) + out
        return out


class ResNet18(nn.Module):
    """Small CIFAR-style ResNet: stem conv, 4 residual blocks, linear head.

    Maps an input of shape [b, 3, h, w] to class logits of shape [b, 10].
    """

    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
        )
        # Followed by 4 blocks, each with stride 2 (for a 32x32 input:
        # 32 -> 16 -> 8 -> 4 -> 2 in h and w).
        # [b, 64, h, w] --> [b, 128, h/2, w/2]
        self.blk1 = ResBlk(64, 128, stride=2)
        # [b, 128, ...] --> [b, 256, ...]
        self.blk2 = ResBlk(128, 256, stride=2)
        # [b, 256, ...] --> [b, 512, ...]
        self.blk3 = ResBlk(256, 512, stride=2)
        # [b, 512, ...] --> [b, 512, ...] (still downsamples spatially)
        self.blk4 = ResBlk(512, 512, stride=2)

        # After global average pooling each image is a 512-dim feature vector.
        self.outlayer = nn.Linear(512 * 1 * 1, 10)

    def forward(self, x):
        """Return logits of shape [b, 10] for input x of shape [b, 3, h, w]."""
        x = F.relu(self.conv1(x))
        # [b, 64, h, w] --> [b, 512, h', w']
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # For a 32x32 input x is [b, 512, 2, 2] here.
        # [b, 512, h', w'] --> [b, 512, 1, 1]
        x = F.adaptive_avg_pool2d(x, [1, 1])
        # [b, 512, 1, 1] --> [b, 512]
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x


def main():
    """Smoke-test the output shapes of ResBlk and ResNet18 on random input."""
    blk = ResBlk(64, 128, stride=4)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print('block:', out.shape)

    x = torch.randn(2, 3, 32, 32)
    model = ResNet18()
    out = model(x)
    print('resnet:', out.shape)


if __name__ == '__main__':
    main()