果果又哭了

  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

这是我的代码和注释,你可以通过直接复制代码到你的pycharm中跑起来。

你不需要另外去准备数据集,当本地没有数据集运行代码就会自动下载

这是一个很小的项目,你不需要准备GPU

mnist_train.py

  1 import torch
  2 
  3 # nn包用来完成神经网络的搭建
  4 from torch import nn
  5 
  6 # functional包含常用的函数
  7 from torch.nn import functional as F
  8 
  9 # optim优化数据包,用来更新权重
 10 from torch import optim
 11 
 12 # 视觉相关的工具包
 13 import torchvision
 14 
 15 # 导入画图工具包
 16 from matplotlib import pyplot as plt
 17 
 18 # 从utils 包里导入所需工具
 19 from utils import plot_image, plot_curve, one_hot
 20 
 21 # step1. load dataset 加载数据集
 22 
 23 # 这里设定一次处理多少张图片
 24 batch_size = 512
 25 
 26 # 加载训练集
 27 train_loader = torch.utils.data.DataLoader(
 28     # 加载MNIST数据集(1.图片路径,2.指定下载的图片为text还是train,3.download若1本地没有则去网上下载,
 29     # 4.transform格式转换,网上图片一般为numpy格式,转为totensor格式)
 30     torchvision.datasets.MNIST('mnist_data', train=True, download=True,
 31                                transform=torchvision.transforms.Compose([
 32                                    torchvision.transforms.ToTensor(),
 33                                    torchvision.transforms.Normalize(
 34                                        (0.1307,), (0.3081,))
 35                                    # 这个参数是正则化,防止过拟合,防止参数过多或过大,避免模型过复杂。有L1正则化和L2正则化,这里是让参数维持在0的附近均匀的分配
 36                                ])),  # 0.3081是均差
 37     batch_size=batch_size, shuffle=True)  # 加载数据并随机打散数据
 38 
 39 # 加载测试集
 40 test_loader = torch.utils.data.DataLoader(
 41     torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
 42                                transform=torchvision.transforms.Compose([
 43                                    torchvision.transforms.ToTensor(),
 44                                    torchvision.transforms.Normalize(
 45                                        (0.1307,), (0.3081,))
 46                                ])),
 47     batch_size=batch_size, shuffle=False)
 48 
 49 
 50 # # 查看图片
 51 x, y = next(iter(train_loader))
 52 print(x.shape, y.shape, x.min(), x.max())
 53 plot_image(x, y, 'image sample')
 54 
 55 
 56 # 设置网络层
 57 
 58 class Net(nn.Module):
 59 
 60     def __init__(self):
 61         super(Net, self).__init__()
 62 
 63         # xw + b
 64         # 第一层,第一个参数为图像大小,第二个参数根据经验值设置输出层大小
 65         self.fc1 = nn.Linear(28 * 28, 256)
 66         # 第二层,第一个个参数为上一层的输出大小,第二个大小根据经验设置输出层大小
 67         self.fc2 = nn.Linear(256, 64)
 68         # 最后一层,第一个值为上一层输出大小,第二个参数为输出的种类数
 69         self.fc3 = nn.Linear(64, 10)
 70 
 71     # 计算函数
 72     def forward(self, x):
 73         # x:[b,1,28,28]  #relu将线性函数调整变种为非线性函数
 74         # h1=relu(xw1 +b1)
 75         x = F.relu(self.fc1(x))
 76         # h2=relu(h1w2+b20
 77         x = F.relu(self.fc2(x))
 78         # 第三层为输出层,一般输出概率值
 79         x = self.fc3(x)
 80 
 81         return x
 82 
 83 
 84 # 对创建的神经网络进行初始化
 85 net = Net()
 86 
 87 # 设置对计算后的梯度进行梯度更新方法,这里采用SGD随机梯度下降,lr是学习率
 88 optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
 89 
 90 train_loss = []
 91 
 92 for epoch in range(3):
 93     # 对整个数据集迭代三次
 94     for batch_idx, (x, y) in enumerate(train_loader):
 95         # 对整个数据集迭代一次
 96 
 97         # x :
 98         # print(x.shape,y.shape)
 99 
100         # 输入
101         x = x.view(x.size(0), 28 * 28)
102 
103         # 输出
104         out = net(x)  # 我们的目的是将输出更加接近于y
105 
106         # 将真实的y转为独热编码
107         y_onehot = one_hot(y)
108 
109         # 通过mse_loss计算误差值,也就是均方差
110         loss = F.mse_loss(out, y_onehot)
111 
112         # 清零梯度
113         optimizer.zero_grad()
114         # 计算梯度
115         loss.backward()
116         # 更新梯度
117         optimizer.step()
118 
119         # 最后我们会得到较为合适的[w1,b1,w2,b2,w3,b3]
120 
121         # 将loss数据收集,以便用matplotlib将其变化图示化
122         train_loss.append(loss.item())
123 
124 
125         # 查看loss下降的变化
126         if batch_idx % 10 == 0:
127 
128             print(epoch, batch_idx, loss.item())
129 
130 
131  plot_curve(train_loss)
132 
133 # 我们最终想要看到的并不是loss而是准确率
134 # 准确度的测试
135 # 在test测试集取数据然后进行测试
136 total_correct = 0
137 for x,y in test_loader:
138     x  = x.view(x.size(0), 28*28)
139     out = net(x)
140     # out: [b, 10] => pred: [b]
141     pred = out.argmax(dim=1)
142     correct = pred.eq(y).sum().float().item()
143     total_correct += correct
144 145 total_num = len(test_loader.dataset)
146 acc = total_correct / total_num
147 print('test acc:', acc)
148 149 x, y = next(iter(test_loader))
150 out = net(x.view(x.size(0), 28*28))
151 pred = out.argmax(dim=1)
152 plot_image(x, pred, 'test')

 

 

 

utils.py文件中包含的是画图函数,和独热编码的函数,可以直接调用,比如上面的代码就调用了它。将它一并放入你的pycharm中

import torch
from matplotlib import pyplot as plt


# 画一条曲线
def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()



# 可视化查看识别结果
def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)  # 生成独热编码
    return out

 

posted on 2020-04-14 20:39  果果又哭了  阅读(227)  评论(0编辑  收藏  举报