使用 PyTorch 训练 AlexNet
一、AlexNet网络结构和参数
二、训练部分
model.py
import torch.nn as nn
import torch


class AlexNet(nn.Module):
    """AlexNet-style CNN: five convolutions followed by three linear layers.

    The classifier's first linear layer takes ``128 * 6 * 6`` features, which
    is what the convolutional stack produces for a ``[3, 224, 224]`` input.

    Args:
        num_classes: Size of the final linear layer's output.
        init_weights: When true, re-initialise every conv/linear layer via
            ``_initialize_weights`` right after construction.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super().__init__()

        # Shape notes are per sample, as [channels, height, width].
        conv_stack = [
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),   # [3, 224, 224] -> [48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # -> [48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),            # -> [128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # -> [128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),           # -> [192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),           # -> [192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),           # -> [128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # -> [128, 6, 6]
        ]
        self.features = nn.Sequential(*conv_stack)

        dense_stack = [
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        ]
        self.classifier = nn.Sequential(*dense_stack)

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class scores for a batch of images."""
        feature_maps = self.features(x)
        # Keep the batch dimension and collapse everything after it.
        flat = torch.flatten(feature_maps, start_dim=1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """Re-initialise conv layers (Kaiming normal) and linear layers (N(0, 0.01))."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
                continue
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0, std=0.01)
                nn.init.zeros_(layer.bias)
train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time


def main():
    """Train AlexNet on the flower dataset and keep the best checkpoint.

    Expects ``<cwd>/../../data_set/flower_data`` to contain ``train`` and
    ``val`` sub-folders in ``ImageFolder`` layout. Writes the index-to-class
    mapping to ``class_indices.json`` and saves the model's ``state_dict`` to
    ``./AlexNet.pth`` whenever validation accuracy improves.

    Raises:
        FileNotFoundError: If the flower dataset directory does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        # Resize needs the full (224, 224) tuple here; a bare int would only
        # scale the shorter edge and leave non-square images non-square.
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    # Raise explicitly instead of using assert, which is stripped under -O.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # class_to_idx maps name -> index, e.g.
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4};
    # invert it so a predicted index can be turned back into a class name.
    flower_list = train_dataset.class_to_idx
    cla_dict = {idx: name for name, idx in flower_list.items()}
    # write dict into json file
    with open('class_indices.json', 'w', encoding='utf-8') as json_file:
        json_file.write(json.dumps(cla_dict, indent=4))

    batch_size = 32
    # os.cpu_count() may return None, so fall back to a single CPU.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # Pass the computed worker count so the message above matches reality.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    # Sample order does not affect accuracy, so validation is not shuffled.
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images fot validation.".format(train_num,
                                                                           val_num))

    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(10):
        # train
        net.train()
        running_loss = 0.0
        t1 = time.perf_counter()
        for step, (images, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # Convert to a Python float once; used for both the running sum
            # and the progress line.
            batch_loss = loss.item()
            running_loss += batch_loss

            # print train process
            rate = (step + 1) / train_steps
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, batch_loss), end="")
        print()
        print(time.perf_counter() - t1)

        # validate
        net.eval()
        acc = 0  # number of correctly classified validation images this epoch
        with torch.no_grad():
            for val_images, val_labels in validate_loader:
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

        # Average over the number of batches. The last zero-based step index
        # would over-estimate the mean and divide by zero for a single batch.
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / max(train_steps, 1), val_accurate))

    print('Finished Training')


if __name__ == '__main__':
    main()