AlexNet网络的实现

model.py

 1 import torch.nn as nn
 2 import torch
 3 
 4 class AlexNet(nn.Module):
 5     def __init__(self, num_classes=1000, init_weights=False):
 6         super(AlexNet, self).__init__()
 7         self.features = nn.Sequential(             #Sequential将一系列的层结构进行打包,组合成一个新的结构
 8             nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),   # padding=(1,2)   在上下方各补一行0,在左右两侧各补两列0
 9             nn.ReLU(inplace=True),
10             nn.MaxPool2d(kernel_size=3, stride=2),  # output[48, 27, 27]
11             nn.Conv2d(48, 128, kernel_size=5, padding=2),  # output[128, 27, 27]
12             nn.ReLU(inplace=True),
13             nn.MaxPool2d(kernel_size=3, stride=2),  # output[128, 13, 13]
14             nn.Conv2d(128, 192, kernel_size=3, padding=1),  # output[192, 13, 13]
15             nn.ReLU(inplace=True),
16             nn.Conv2d(192, 192, kernel_size=3, padding=1),  # output[192, 13, 13]
17             nn.ReLU(inplace=True),
18             nn.Conv2d(192, 128, kernel_size=3, padding=1),  # output[128, 13, 13]
19             nn.ReLU(inplace=True),
20             nn.MaxPool2d(kernel_size=3, stride=2),  # output[128, 6, 6]
21         )
22 
23         self.classifier = nn.Sequential(
24             nn.Dropout(p=0.5),
25             nn.Linear(128 * 6 * 6, 2048),
26             nn.ReLU(inplace=True),
27             nn.Dropout(p=0.5),
28             nn.Linear(2048, 2048),
29             nn.ReLU(inplace=True),
30             nn.Linear(2048, num_classes),
31         )
32         if init_weights:
33             self._initialize_weights()   #初始化权重函数
34 
35     def forward(self, x):
36         x = self.features(x)
37         x = torch.flatten(x, start_dim=1)  #从第1维开始展平,第0维为batch
38         x = self.classifier(x)
39         return x
40 
41     def _initialize_weights(self):
42         for m in self.modules():    #遍历网络中所有的模块
43             if isinstance(m, nn.Conv2d):      #判断属于那个类别,  判断m是否属于nn.Conv2d这个类别
44                 nn.init.kaiming_normal_(m.weight, mode='fan_out',nonlinearity='relu')  #kaiming_normal_ 初始化变量方法
45                 if m.bias is not None:    #bias 偏置不为空,用0对其初始化
46                     nn.init.constant_(m.bias, 0)
47             elif isinstance(m, nn.Linear):
48                 nn.init.normal_(m.weight, 0, 0.01)   #均值为0,方差为0.01
49                 nn.init.constant_(m.bias, 0)

train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt     #绘制图像的包
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time
import sys

def _train_one_epoch(net, train_loader, loss_function, optimizer, device):
    """Run one optimisation pass over *train_loader* and return the mean batch loss.

    Prints a single-line text progress bar that is redrawn after every batch.
    """
    net.train()  # enable Dropout (net.eval() in _evaluate disables it)
    running_loss = 0.0
    num_steps = len(train_loader)
    for step, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()  # clear gradients accumulated by the previous step

        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # text progress bar
        rate = (step + 1) / num_steps
        done = "*" * int(rate * 50)
        todo = "." * int((1 - rate) * 50)
        print("\r train loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), done, todo, loss.item()),
              end="")
    print()
    # Average over the number of batches, not the last 0-based step index.
    return running_loss / num_steps


def _evaluate(net, val_loader, device):
    """Return the number of correctly classified samples in *val_loader*."""
    net.eval()  # disable Dropout for evaluation
    correct = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for val_images, val_labels in val_loader:
            outputs = net(val_images.to(device))
            predict_y = torch.argmax(outputs, dim=1)  # index of the highest logit
            correct += (predict_y == val_labels.to(device)).sum().item()
    return correct


def main():
    """Train AlexNet on ./flower_data and keep the best checkpoint.

    Reads ``flower_data/train`` and ``flower_data/val`` under the current
    working directory with ``torchvision.datasets.ImageFolder``. Writes the
    label-index -> class-name mapping to ``class_indices.json`` and saves the
    weights of the epoch with the highest validation accuracy to
    ``./AlexNet.pth``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        # Training: random crop + horizontal flip for augmentation.
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        # Validation: deterministic resize so accuracy is comparable across epochs.
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    }

    data_root = os.path.abspath(os.getcwd())  # get data root path
    print(data_root)
    image_path = os.path.join(data_root, "flower_data")

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # class_to_idx maps folder name -> label index; invert it so a class name
    # can be looked up from a predicted index. json turns the int keys into strings.
    cla_dict = {idx: name for name, idx in train_dataset.class_to_idx.items()}
    with open('class_indices.json', 'w', encoding='utf-8') as json_file:
        json.dump(cla_dict, json_file, indent=4)

    batch_size = 32
    # num_workers=0 loads batches in the main process.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=0)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=0)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Size the output layer from the dataset instead of hard-coding it.
    net = AlexNet(num_classes=len(cla_dict), init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    save_path = './AlexNet.pth'  # where the best weights are written
    best_acc = 0.0  # best validation accuracy seen so far

    for epoch in range(10):
        t1 = time.perf_counter()  # time one training pass
        train_loss = _train_one_epoch(net, train_loader, loss_function, optimizer, device)
        print(time.perf_counter() - t1)

        val_accurate = _evaluate(net, validate_loader, device) / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d]  train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, train_loss, val_accurate))

    print('Finished Training')


if __name__ == '__main__':
    main()

predict.py

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
from model import AlexNet

# Preprocessing mirrors the "val" transform used during training.
data_transform = transforms.Compose([transforms.Resize((224, 224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image; convert to RGB so grayscale/RGBA/palette files also give 3 channels
img = Image.open("1.jpg").convert("RGB")
plt.imshow(img)
img = data_transform(img)          # [C, H, W]
img = torch.unsqueeze(img, dim=0)  # add batch dimension -> [N, C, H, W]

# read class_indict: maps str(label index) -> class name
try:
    with open('./class_indices.json', 'r', encoding='utf-8') as json_file:
        class_indict = json.load(json_file)
except (OSError, ValueError) as e:  # missing/unreadable file or invalid JSON
    print(e)
    raise SystemExit(-1)

# create model with one output per class in the saved mapping
model = AlexNet(num_classes=len(class_indict))
# load model weights; the model and input live on the CPU here, so remap any
# tensors saved from a GPU run onto the CPU
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path, map_location="cpu"))
model.eval()  # disable Dropout for inference
with torch.no_grad():
    # predict class; [0] selects the only sample in the batch -> [num_classes]
    output = model(img)[0]
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).item()
print(class_indict[str(predict_cla)], predict[predict_cla].item())

plt.show()

 

posted @ 2022-06-20 20:48  Hello'world  阅读(68)  评论(0)  编辑  收藏  举报