神经网络-AlexNet 21

 

训练的数据集:

 含有数据集的:链接:https://pan.baidu.com/s/1u8N_yRnxrNoIMc4aP55rcQ 提取码:6wfe
 不含数据集的:链接:https://pan.baidu.com/s/1BNVj2XSajJx8u1ZlKadnmw 提取码:xrng

model.py

 1 import numpy as np
 2 import cv2
 3 import torch
 4 import torch.nn as nn
 5 import torch.optim as optim
 6 import torch.nn.functional as F
 7  
 8 class AlexNet(nn.Module):
 9     def __init__(self,num_classes=1000,init_weights=False):
10         super(AlexNet, self).__init__()
11         self.features = nn.Sequential(              #Sequential能将层结构打包
12             nn.Conv2d(3,48,kernel_size=11,stride=4,padding=2),          #input_channel=3,output_channel=48
13             nn.ReLU(inplace=True),
14             nn.MaxPool2d(kernel_size=3,stride=2),
15  
16             nn.Conv2d(48, 128, kernel_size=5,  padding=2),  # input_channel=3,output_channel=48
17             nn.ReLU(inplace=True),
18             nn.MaxPool2d(kernel_size=3, stride=2),
19  
20             nn.Conv2d(128, 192, kernel_size=3,  padding=1),  # input_channel=3,output_channel=48
21             nn.ReLU(inplace=True),
22  
23             nn.Conv2d(192, 192, kernel_size=3,  padding=1),  # input_channel=3,output_channel=48
24             nn.ReLU(inplace=True),
25  
26             nn.Conv2d(192, 128, kernel_size=3,  padding=1),  # input_channel=3,output_channel=48
27             nn.ReLU(inplace=True),
28             nn.MaxPool2d(kernel_size=3, stride=2),
29         )
30         self.classifier = nn.Sequential(
31             nn.Dropout(p=0.5),                           #默认随机失活
32             nn.Linear(128*6*6,2048),
33             nn.ReLU(inplace=True),
34             nn.Dropout(p=0.5),  # 默认随机失活
35             nn.Linear(2048, 2048),
36             nn.ReLU(inplace=True),
37             nn.Linear(2048,num_classes),
38         )
39         if init_weights:
40             self._initialize_weights()
41  
42     def forward(self,x):
43         x = self.features(x)
44         x = torch.flatten(x,start_dim=1)
45         x = self.classifier(x)
46         return x
47  
48     def _initialize_weights(self):
49         for m in self.modules():
50             if isinstance(m,nn.Conv2d):
51                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')    #凯明初始化,国人大佬
52                 if m.bias is not None:
53                     nn.init.constant_(m.bias,0)
54             elif isinstance(m,nn.Linear):
55                 nn.init.normal_(m.weight,0,0.01)
56                 nn.init.constant_(m.bias,0)
View Code

train.py

  1 import os
  2 import sys
  3 import json
  4  
  5 import torch
  6 import torch.nn as nn
  7 from torchvision import transforms, datasets, utils
  8 import matplotlib.pyplot as plt
  9 import numpy as np
 10 import torch.optim as optim
 11 from tqdm import tqdm
 12  
 13 from model import AlexNet
 14  
 15  
 16 def main():
 17     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 18     print("using {} device.".format(device))
 19  
 20     data_transform = {
 21         "train": transforms.Compose([transforms.RandomResizedCrop(224),
 22                                      transforms.RandomHorizontalFlip(),
 23                                      transforms.ToTensor(),
 24                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
 25         "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
 26                                    transforms.ToTensor(),
 27                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
 28  
 29     data_root = os.path.abspath(os.path.join(os.getcwd(), "./"))  # get data root path
 30     image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
 31     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
 32     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
 33                                          transform=data_transform["train"])
 34     train_num = len(train_dataset)
 35  
 36     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
 37     flower_list = train_dataset.class_to_idx
 38     cla_dict = dict((val, key) for key, val in flower_list.items())
 39     # write dict into json file
 40     json_str = json.dumps(cla_dict, indent=4)
 41     with open('class_indices.json', 'w') as json_file:
 42         json_file.write(json_str)
 43  
 44     batch_size = 32
 45     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
 46     print('Using {} dataloader workers every process'.format(nw))
 47  
 48     train_loader = torch.utils.data.DataLoader(train_dataset,
 49                                                batch_size=batch_size, shuffle=True,
 50                                                num_workers=nw)
 51  
 52     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
 53                                             transform=data_transform["val"])
 54     val_num = len(validate_dataset)
 55     validate_loader = torch.utils.data.DataLoader(validate_dataset,
 56                                                   batch_size=4, shuffle=False,
 57                                                   num_workers=nw)
 58  
 59     print("using {} images for training, {} images for validation.".format(train_num,
 60                                                                            val_num))
 61     # test_data_iter = iter(validate_loader)
 62     # test_image, test_label = test_data_iter.next()
 63     #
 64     # def imshow(img):
 65     #     img = img / 2 + 0.5  # unnormalize
 66     #     npimg = img.numpy()
 67     #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
 68     #     plt.show()
 69     #
 70     # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
 71     # imshow(utils.make_grid(test_image))
 72  
 73     net = AlexNet(num_classes=5, init_weights=True)
 74  
 75     net.to(device)
 76     loss_function = nn.CrossEntropyLoss()
 77     # pata = list(net.parameters())
 78     optimizer = optim.Adam(net.parameters(), lr=0.0002)
 79  
 80     epochs = 10
 81     save_path = './AlexNet.pth'
 82     best_acc = 0.0
 83     train_steps = len(train_loader)
 84     for epoch in range(epochs):
 85         # train
 86         net.train()
 87         running_loss = 0.0
 88         train_bar = tqdm(train_loader, file=sys.stdout)
 89         for step, data in enumerate(train_bar):
 90             images, labels = data
 91             optimizer.zero_grad()
 92             outputs = net(images.to(device))
 93             loss = loss_function(outputs, labels.to(device))
 94             loss.backward()
 95             optimizer.step()
 96  
 97             # print statistics
 98             running_loss += loss.item()
 99  
100             train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
101                                                                      epochs,
102                                                                      loss)
103  
104         # validate
105         net.eval()
106         acc = 0.0  # accumulate accurate number / epoch
107         with torch.no_grad():
108             val_bar = tqdm(validate_loader, file=sys.stdout)
109             for val_data in val_bar:
110                 val_images, val_labels = val_data
111                 outputs = net(val_images.to(device))
112                 predict_y = torch.max(outputs, dim=1)[1]
113                 acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
114  
115         val_accurate = acc / val_num
116         print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
117               (epoch + 1, running_loss / train_steps, val_accurate))
118  
119         if val_accurate > best_acc:
120             best_acc = val_accurate
121             torch.save(net.state_dict(), save_path)
122  
123     print('Finished Training')
124  
125  
if __name__ == "__main__":
    # Entry point: start training only when this file is executed directly,
    # not when it is imported as a module.
    main()
View Code

predict.py

 1 import os
 2 import json
 3  
 4 import torch
 5 from PIL import Image
 6 from torchvision import transforms
 7 import matplotlib.pyplot as plt
 8  
 9 from model import AlexNet
10  
11  
def main():
    """Classify one image with the trained AlexNet and display the result.

    Reads the image ``./1.png``, the index -> class-name mapping in
    ``./class_indices.json`` (written by train.py) and the weights in
    ``./AlexNet.pth``. Prints the probability of every class and shows the
    image titled with the top prediction.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Must match the validation preprocessing used during training.
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "./1.png"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    # PNG files are often RGBA or palette images; the 3-channel Normalize and
    # the model's first conv layer need exactly 3 (RGB) channels.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")

    plt.imshow(img)
    # [C, H, W] -> [N, C, H, W]: add the batch dimension the model expects.
    img = torch.unsqueeze(data_transform(img), dim=0)

    # read class_indict (keys are stringified class indices)
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model: one output per class listed in the mapping file
    model = AlexNet(num_classes=len(class_indict)).to(device)

    # load model weights; map_location lets GPU-saved weights load on a CPU-only machine
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # predict class; [0] selects the single image -> shape [num_classes]
        output = model(img.to(device))[0].cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  prob.item()))
    plt.show()
60  
61  
if __name__ == "__main__":
    # Entry point: run the prediction only when this file is executed
    # directly, not when it is imported as a module.
    main()
View Code

 

posted @ 2022-10-31 20:04  赵家小伙儿  阅读(70)  评论(0编辑  收藏  举报