PyTorch: Implementing ResNet-18 and Validating It on the CIFAR-10 Dataset
1. Building ResNet-18 in PyTorch
import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """
    ResNet basic block: two 3x3 convolutions plus a shortcut
    """
    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        # If the input and output channel counts differ, or the stride is not 1,
        # project the shortcut with a 1x1 convolution so both branches match.
        if ch_out != ch_in or stride != 1:
            # [b, ch_in, h, w] => [b, ch_out, h/stride, w/stride]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # element-wise addition of the shortcut and the residual branch
        out = self.extra(x) + out
        out = F.relu(out)
        return out


class ResNet18(nn.Module):
    """
    Main network: a stem convolution followed by four residual blocks
    """
    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64)
        )
        # followed by 4 blocks
        self.blk1 = ResBlk(64, 128, stride=2)   # [b, 64, h, w]  => [b, 128, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)  # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)  # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk4 = ResBlk(512, 512, stride=2)  # [b, 512, h, w] => [b, 512, h/2, w/2]

        self.outlayer = nn.Linear(512 * 1 * 1, 10)  # fully connected layer, 10 classes in total

    def forward(self, x):
        x = F.relu(self.conv1(x))

        # [b, 64, h, w] => [b, 512, h', w']
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        x = F.adaptive_avg_pool2d(x, [1, 1])  # [b, 512, h', w'] => [b, 512, 1, 1]
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x
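To see why the shortcut needs the 1x1 projection, the short sketch below (assuming the ResBlk class above is in scope) pushes one tensor through a block that changes both the channel count and the spatial size; the projected shortcut and the residual branch come out with the same shape, so they can be added.

# Minimal check of the shortcut projection (assumes the ResBlk class above).
import torch

blk = ResBlk(64, 128, stride=2)          # changes channels (64 -> 128) and halves h, w
x = torch.randn(2, 64, 16, 16)

residual = blk.bn2(blk.conv2(torch.relu(blk.bn1(blk.conv1(x)))))
shortcut = blk.extra(x)                  # 1x1 conv with stride 2, then BatchNorm
print(residual.shape, shortcut.shape)    # both torch.Size([2, 128, 8, 8])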
A quick example to test it:
if __name__ == '__main__':

    blk = ResBlk(64, 128, stride=4)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print('block:', out.shape)   # block: torch.Size([2, 128, 8, 8])

    x = torch.randn(2, 3, 32, 32)
    model = ResNet18()
    out = model(x)
    print('resnet:', out.shape)  # resnet: torch.Size([2, 10])
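To see where the 512x1x1 feature that feeds the final Linear layer comes from, the intermediate sizes can also be traced stage by stage. This is a minimal sketch assuming the ResNet18 class above; the exact spatial sizes follow from the stride-3 stem and the four stride-2 blocks.

# Trace the feature-map size after each stage (assumes the ResNet18 class above).
import torch
from torch.nn import functional as F

net = ResNet18()
x = torch.randn(2, 3, 32, 32)             # a fake CIFAR-10 batch of 2 images

x = F.relu(net.conv1(x))
print('stem :', x.shape)                  # torch.Size([2, 64, 10, 10])
for name in ('blk1', 'blk2', 'blk3', 'blk4'):
    x = getattr(net, name)(x)
    print(name, ':', x.shape)
x = F.adaptive_avg_pool2d(x, [1, 1])
print('pool :', x.shape)                  # torch.Size([2, 512, 1, 1])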
2. Training on the CIFAR-10 Dataset
The dataset used here is CIFAR-10, which contains 60,000 labeled color images of size 32x32, split into 10 classes with 6,000 images per class. Of these, 50,000 are used for training (5,000 per class) and the remaining 10,000 for testing (1,000 per class).
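These counts can be checked directly with torchvision. The following is a small sketch; it assumes the dataset may be downloaded into a local 'cifar' folder, the same root the training script below uses.

# Inspect CIFAR-10 sizes and class names via torchvision.
from torchvision import datasets, transforms

train_set = datasets.CIFAR10('cifar', train=True, download=True,
                             transform=transforms.ToTensor())
test_set = datasets.CIFAR10('cifar', train=False, download=True,
                            transform=transforms.ToTensor())

print(len(train_set), len(test_set))   # 50000 10000
print(train_set.classes)               # ['airplane', 'automobile', 'bird', ...]
img, label = train_set[0]
print(img.shape, label)                # torch.Size([3, 32, 32]) and an integer class index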
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn, optim

from resnet import ResNet18


def main():
    batchsz = 128

    # training set
    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    # test set
    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)

    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)  # x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])

    # define the model: ResNet-18
    model = ResNet18()

    # define the loss function and the optimizer
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    # train the network
    for epoch in range(1000):

        model.train()  # training mode
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32]
            # label: [b]

            logits = model(x)              # logits: [b, 10]
            loss = criteon(logits, label)  # scalar

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        model.eval()  # evaluation mode
        with torch.no_grad():

            total_correct = 0  # number of correct predictions
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32]
                # label: [b]

                logits = model(x)            # [b, 10]
                pred = logits.argmax(dim=1)  # [b]

                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()
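The script above runs on the CPU, which is one reason 1000 epochs would take very long. A common speed-up, sketched below under the assumption that a CUDA GPU is available, is to move the model and every batch to the GPU, and to save the weights so training does not have to be repeated (the file name resnet18_cifar10.pth is just an illustrative choice).

# Sketch of GPU training and checkpointing (falls back to the CPU if no GPU is found).
# Only the lines that differ from the script above are shown.
import torch
from resnet import ResNet18

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet18().to(device)

# inside both the training and the evaluation loop, move each batch as well:
#     x, label = x.to(device), label.to(device)

# after training (or after every epoch), keep the learned weights:
torch.save(model.state_dict(), 'resnet18_cifar10.pth')   # illustrative file name
# and reload them later with:
#     model.load_state_dict(torch.load('resnet18_cifar10.pth', map_location=device))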
Training for 1000 epochs takes too long, so only the output of the first 5 epochs is shown here.
0 loss: 1.0912220478057861
0 test acc: 0.5583
1 loss: 0.8604468107223511
1 test acc: 0.6592
2 loss: 0.6625195145606995
2 test acc: 0.6827
3 loss: 0.7064175009727478
3 test acc: 0.6904
4 loss: 0.5687283277511597
4 test acc: 0.7059
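Beyond the overall accuracy printed above, the test accuracy can also be broken down per class. The sketch below extends the evaluation loop; it assumes the trained model and the cifar_test DataLoader from the training script are still in scope, and uses torchvision's standard CIFAR-10 class order.

# Per-class test accuracy (assumes `model` and `cifar_test` from the training script).
import torch

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
correct = [0] * 10
total = [0] * 10

model.eval()
with torch.no_grad():
    for x, label in cifar_test:
        pred = model(x).argmax(dim=1)
        for c in range(10):
            mask = label == c
            total[c] += mask.sum().item()
            correct[c] += (pred[mask] == c).sum().item()

for c in range(10):
    print(f'{classes[c]:>10s}: {correct[c] / total[c]:.3f}')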