PyTorch 实现 ResNet-18 并在 CIFAR-10 数据集上进行验证

1. 在 PyTorch 上搭建 ResNet-18

 1 import  torch
 2 from    torch import  nn
 3 from    torch.nn import functional as F
 4 
 5 
 6 class ResBlk(nn.Module):
 7     """
 8     resnet block子模块
 9     """
10     def __init__(self, ch_in, ch_out, stride=1):
11    
12         super(ResBlk, self).__init__()
13 
14         self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
15         self.bn1 = nn.BatchNorm2d(ch_out)
16         self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
17         self.bn2 = nn.BatchNorm2d(ch_out)
18 
19         self.extra = nn.Sequential()
20         # 如果输入和输出的通道不一致,或其步长不为 1,需要将二者转成一致
21         if ch_out != ch_in:
22             # [b, ch_in, h, w] => [b, ch_out, h, w]
23             self.extra = nn.Sequential(
24                 nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
25                 nn.BatchNorm2d(ch_out)
26             )
27 
28     def forward(self, x):
29         
30         out = F.relu(self.bn1(self.conv1(x)))
31         out = self.bn2(self.conv2(out))
32         
33         out = self.extra(x) + out
34         out = F.relu(out)        
35         return out
36 
37 
class ResNet18(nn.Module):
    """Main module: a compact ResNet-style image classifier.

    Layout: 3x3 / stride-3 conv stem (64 channels) + BN + ReLU, then four ResBlk
    stages (64 -> 128 -> 256 -> 512 -> 512 channels, each with stride 2), then
    global average pooling and a linear classifier.

    NOTE: despite the name, this is a reduced variant. It has one ResBlk per
    stage, whereas the standard ResNet-18 stacks two basic blocks per stage.
    The original architecture is kept unchanged here.

    Args:
        num_classes: number of output logits (default 10, for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super(ResNet18, self).__init__()

        # Stem: [b, 3, h, w] => [b, 64, floor(h/3), floor(w/3)]   (32x32 -> 10x10)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64),
        )
        # Followed by 4 blocks; each stride-2 block maps h -> ceil(h/2)
        # (for a 32x32 input: 10 -> 5 -> 3 -> 2 -> 1).
        self.blk1 = ResBlk(64, 128, stride=2)    # [b, 64, h, w]  => [b, 128, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)   # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)   # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk4 = ResBlk(512, 512, stride=2)   # [b, 512, h, w] => [b, 512, h/2, w/2]

        # Fully connected classifier on the globally pooled 512-d feature.
        self.outlayer = nn.Linear(512 * 1 * 1, num_classes)

    def forward(self, x):
        """Map a batch of images [b, 3, h, w] to class logits [b, num_classes]."""
        x = F.relu(self.conv1(x))

        # [b, 64, h, w] => [b, 512, h', w']
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Global average pooling: the head does not depend on the spatial size.
        x = F.adaptive_avg_pool2d(x, [1, 1])     # [b, 512, h', w'] => [b, 512, 1, 1]
        x = x.view(x.size(0), -1)                # [b, 512]
        x = self.outlayer(x)                     # [b, num_classes]

        return x

举个例子测试一下：

 1 if __name__ == '__main__':
 2 
 3     blk = ResBlk(64, 128, stride=4)
 4     tmp = torch.randn(2, 64, 32, 32)
 5     out = blk(tmp)
 6     print('block:', out.shape)                      #block: torch.Size([2, 128, 8, 8])
 7 
 8     x = torch.randn(2, 3, 32, 32)
 9     model = ResNet18()
10     out = model(x)
11     print('resnet:', out.shape)                     #resnet: torch.Size([2, 10])

2. 在 CIFAR-10 数据集上训练

所选数据集为 CIFAR-10，该数据集共有 60000 张带标签的彩色图像，图像尺寸为 32×32，分为 10 个类，每类 6000 张。其中 50000 张用于训练（每类 5000 张），另外 10000 张用于测试（每类 1000 张）。

 1 import  torch
 2 from    torch.utils.data import DataLoader
 3 from    torchvision import datasets,transforms 
 4 from    torch import nn, optim
 5 
 6 from    resnet import ResNet18
 7 
 8 
 9 def main():
10     batchsz = 128
11     
12     #训练集
13     cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
14         transforms.Resize((32, 32)),
15         transforms.ToTensor(),
16         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
17     ]))
18     cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
19     
20     
21     #测试集
22     cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
23         transforms.Resize((32, 32)),
24         transforms.ToTensor(),
25         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26     ]))
27     cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
28 
29 
30     x, label = iter(cifar_train).next()
31     print('x:', x.shape, 'label:', label.shape)         #x: torch.Size([128, 3, 32, 32])  label: torch.Size([128])
32     
33     #定义模型-ResNet
34     model = ResNet18()
35     
36     #定义损失函数和优化方式
37     criteon = nn.CrossEntropyLoss()
38     optimizer = optim.Adam(model.parameters(), lr=1e-3)
39     print(model)
40     
41     #训练网络
42     for epoch in range(1000):
43 
44         model.train()                                   #训练模式
45         for batchidx, (x, label) in enumerate(cifar_train):
46             #x: [b, 3, 32, 32]
47             #label: [b]
48             
49             logits = model(x)                           #logits: [b, 10]
50             loss = criteon(logits, label)               #标量
51 
52             optimizer.zero_grad()
53             loss.backward()
54             optimizer.step()
55 
56         print(epoch, 'loss:', loss.item())
57 
58 
59         model.eval()                                    #测试模式
60         with torch.no_grad():
61             
62             total_correct = 0                           #预测正确的个数
63             total_num = 0
64             for x, label in cifar_test:
65                 #x: [b, 3, 32, 32]
66                 #label: [b]
67                               
68                 logits = model(x)                       #[b, 10]            
69                 pred = logits.argmax(dim=1)             #[b]
70                 
71                 # [b] vs [b] => scalar tensor
72                 correct = torch.eq(pred, label).float().sum().item()
73                 total_correct += correct
74                 total_num += x.size(0)
75                 
76             acc = total_correct / total_num
77             print(epoch, 'test acc:', acc)
78 
79 
80 if __name__ == '__main__':
81     main()

代码设定训练 1000 个 epoch，耗时太久，这里只展示前 5 个 epoch 的输出。

0 loss: 1.0912220478057861
0 test acc: 0.5583
1 loss: 0.8604468107223511
1 test acc: 0.6592
2 loss: 0.6625195145606995
2 test acc: 0.6827
3 loss: 0.7064175009727478
3 test acc: 0.6904
4 loss: 0.5687283277511597
4 test acc: 0.7059

posted @ 2020-07-19 15:44  最咸的鱼  阅读(2191)  评论(0)  编辑  收藏  举报