PyTorch ResNet
一、ResNet网络结构
1.1 ResNet特点
- 深层网络结构
- 残差模块
- Batch Normalization加速训练
对一个batch中的feature map按通道进行标准化，使其满足均值为0、方差为1的分布（随后再通过可学习的参数γ、β进行缩放和平移）。
在ResNet中，Batch Normalization缓解了网络层数增加带来的梯度消失和梯度爆炸问题；残差结构则解决了网络退化（degradation）问题，即网络加深后训练误差反而上升的现象。
1.2 网络结构
residual block中的虚线表示主分支和shortcut的输出shape不同（通道数或特征图尺寸不一致），因此要在shortcut上加入1×1卷积（stride与主分支相同）和BN层，使其输出shape与主分支相同，才能进行相加。
1.3 参数列表
二、模型和训练代码
2.1 model.py
import torch.nn as nn
import torch


def _conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 (a BatchNorm follows it)."""
    return nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
                     kernel_size=3, stride=stride, padding=1, bias=False)


def _conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 convolution (channel squeeze/expand or shortcut projection)."""
    return nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
                     kernel_size=1, stride=stride, bias=False)


# (attribute name, base channel width, stride of the first block) per stage.
# The attribute names become state_dict prefixes ("layer1.0.conv1.weight", ...).
_STAGES = (
    ("layer1", 64, 1),
    ("layer2", 128, 2),
    ("layer3", 256, 2),
    ("layer4", 512, 2),
)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions (used by ``resnet34``).

    Args:
        in_channel: number of channels of the block input.
        out_channel: number of channels of both convolutions; the block
            output has ``out_channel * expansion`` (= ``out_channel``) channels.
        stride: stride of the first convolution; values > 1 downsample spatially.
        downsample: optional module applied to the input so the shortcut has
            the same shape as the main branch before the addition.
    """
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super().__init__()
        # Keep this registration order: it determines state_dict key order and
        # the order in which ResNet re-initialises the convolution weights.
        self.conv1 = _conv3x3(in_channel, out_channel, stride)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = _conv3x3(out_channel, out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)


class Bottleneck(nn.Module):
    """Residual block 1x1 (squeeze) -> 3x3 -> 1x1 (expand) (used by ``resnet101``).

    The stride is applied by the middle 3x3 convolution. The output has
    ``out_channel * expansion`` channels.

    Args:
        in_channel: number of channels of the block input.
        out_channel: width of the inner 1x1/3x3 convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional shortcut projection matching the output shape.
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super().__init__()
        width = out_channel
        expanded = out_channel * self.expansion

        self.conv1 = _conv1x1(in_channel, width)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = _conv3x3(width, width, stride)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = _conv1x1(width, expanded)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)


class ResNet(nn.Module):
    """Generic ResNet: 7x7 stem + max-pool, four residual stages, optional head.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` attribute sets each stage's output channel multiplier.
        blocks_num: sequence giving the number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        include_top: when False, no pooling/linear head is created and
            ``forward`` returns the layer4 feature map.
    """

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super().__init__()
        self.include_top = include_top
        # Running input width; _make_layer reads and advances it stage by stage.
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        for index, (name, width, first_stride) in enumerate(_STAGES):
            stage = self._make_layer(block, width, blocks_num[index], stride=first_stride)
            setattr(self, name, stage)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        convs = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` blocks and advance ``self.in_channel``.

        Only the first block may change the shape (stride or channel count),
        so only it receives a 1x1-conv + BN shortcut projection.
        """
        out_channels = channel * block.expansion
        needs_projection = stride != 1 or self.in_channel != out_channels
        downsample = (
            nn.Sequential(_conv1x1(self.in_channel, out_channels, stride),
                          nn.BatchNorm2d(out_channels))
            if needs_projection else None
        )

        head = block(self.in_channel, channel, downsample=downsample, stride=stride)
        self.in_channel = out_channels
        tail = [block(self.in_channel, channel) for _ in range(1, block_num)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Return class logits, or the layer4 feature map when include_top is False."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for name, _, _ in _STAGES:
            x = getattr(self, name)(x)

        if not self.include_top:
            return x
        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)


def resnet34(num_classes=1000, include_top=True):
    """Build a ResNet-34 (BasicBlock stages with 3, 4, 6, 3 blocks)."""
    return ResNet(block=BasicBlock, blocks_num=[3, 4, 6, 3],
                  num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    """Build a ResNet-101 (Bottleneck stages with 3, 4, 23, 3 blocks)."""
    return ResNet(block=Bottleneck, blocks_num=[3, 4, 23, 3],
                  num_classes=num_classes, include_top=include_top)
2.2 train.py（带迁移学习）
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101

import torchvision.models.resnet


def main():
    """Fine-tune a pretrained ResNet-34 on the flower dataset (transfer learning).

    Expects ``<cwd>/../../data_set/flower_data/{train,val}`` in ImageFolder
    layout and the weights file ``./resnet34-pre.pth``
    (download: https://download.pytorch.org/models/resnet34-333f7ec4.pth).
    Writes ``class_indices.json`` (class index -> class name) and saves the
    weights with the best validation accuracy to ``./resNet34.pth``.

    Raises:
        FileNotFoundError: if the dataset directory or the weights file is missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Train: random crop + flip augmentation; val: deterministic resize + center crop.
    # Both normalise with the standard ImageNet mean/std.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # e.g. {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = {val: key for key, val in flower_list.items()}
    # write the index -> class-name mapping so predict.py can decode outputs
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json.dumps(cla_dict, indent=4))
    num_classes = len(cla_dict)

    batch_size = 16
    # os.cpu_count() may return None, hence the `or 1` fallback.
    # If worker processes cause problems on your platform, set nw = 0.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = resnet34()
    # load pretrained weights (transfer learning)
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    if not os.path.exists(model_weight_path):
        raise FileNotFoundError("file {} does not exist.".format(model_weight_path))
    # map_location="cpu": load regardless of the device the checkpoint was saved on.
    # strict=False tolerates key mismatches, so report them instead of hiding them.
    missing_keys, unexpected_keys = net.load_state_dict(
        torch.load(model_weight_path, map_location="cpu"), strict=False)
    if missing_keys or unexpected_keys:
        print("missing keys: {}\nunexpected keys: {}".format(missing_keys, unexpected_keys))
    # To train only the new classifier, freeze the backbone here by setting
    # requires_grad = False on every parameter in net.parameters().

    # replace the classifier with one sized for this dataset's classes
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, num_classes)
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    best_acc = 0.0
    save_path = './resNet34.pth'
    train_steps = len(train_loader)
    for epoch in range(3):
        # train
        net.train()
        running_loss = 0.0
        for step, (images, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            # print train progress bar
            rate = (step + 1) / train_steps
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss.item()), end="")
        print()

        # validate
        net.eval()
        acc = 0  # number of correct predictions this epoch
        with torch.no_grad():
            for val_images, val_labels in validate_loader:
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        # average over all train_steps batches (the last `step` index is train_steps - 1)
        print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

    print('Finished Training')


if __name__ == '__main__':
    main()
2.3 predict.py
import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
import sys


def main():
    """Classify ``../rose.jpg`` with the fine-tuned ResNet-34 and display it.

    Reads the class-index -> class-name mapping from ``./class_indices.json``
    and the weights from ``./resNet34.pth`` (both written by train.py). It then
    prints the predicted class name and its softmax probability, and shows
    the image.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image; convert("RGB") so grayscale / RGBA / palette files still give
    # the 3 channels Normalize expects. The with-block closes the file handle.
    with Image.open("../rose.jpg") as raw_img:
        img = raw_img.convert("RGB")
    plt.imshow(img)
    img = data_transform(img)  # [C, H, W]
    img = torch.unsqueeze(img, dim=0)  # add batch dimension -> [N, C, H, W]

    # read class_indict (keys are the class indices as strings)
    try:
        with open('./class_indices.json', 'r') as json_file:
            class_indict = json.load(json_file)
    except (OSError, ValueError) as e:  # missing/unreadable file or invalid JSON
        print(e)
        sys.exit(-1)

    # create model; size the head from the mapping train.py wrote
    model = resnet34(num_classes=len(class_indict)).to(device)
    # load model weights
    model_weight_path = "./resNet34.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # drop only the batch dim and bring the result back to the CPU
        output = torch.squeeze(model(img.to(device)), dim=0).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()
    print(class_indict[str(predict_cla)], predict[predict_cla].item())
    plt.show()


if __name__ == '__main__':
    main()