Pytorch卷积神经网络识别手写数字集

卷积神经网络目前被广泛地用在图片识别上, 已经有层出不穷的应用, 如果你对卷积神经网络充满好奇心,这里为你带来pytorch实现cnn一些入门的教程代码

#首先导入包

import torch
from torch.autograd import Variable
import torch.nn as nn
import torchvision
import torch.utils.data as Data

 

 

#一、数据准备

#训练数据:用了torchvision.datasets.MNIST,root是文件路径,train为True(这是训练数据),transform是把图像数据转换为张量,download(如果本地已有该文件选择false,没有就选择true)

train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=False)

#训练数据:同上,train为False(这是测试数据)

test_data = torchvision.datasets.MNIST(root='./mnist/',train=False)

# "训练数据加载器":dataset为训练数据,shuflle为打乱数据的顺序,batch_size是让数据50个为一组

train_loader = Data.DataLoader(dataset=train_data,shuffle=True,batch_size=50)

test_data.test_data.size()

torch.Size([10000, 28, 28])

#测试数据 test_data下的test_data为测试数据,因为下面conv2d输入的为4维数据,所以此处用torch.unsqueeze升维

test_x = Variable(torch.unsqueeze(test_data.test_data,dim=1),volatile=True).type(torch.FloatTensor)

#测试数据目标值

test_y = test_data.test_labels

 

 

#二、实现模型

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()

    #conv2d参数:输入1维,输出16维,5个卷积核(kernel),步长(stride)为1,padding是2(如果想要 con2d 出来的图片长宽没有变化, padding=(kernel_size-1)/2 当 stride=1)
    self.conv1 = nn.Sequential(nn.Conv2d(1,16,5,1,2),nn.ReLU(),nn.MaxPool2d(2))
    self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(2))

    #Linear参数:输入维数,输出分的种类数
    self.out = nn.Linear(32*7*7,10)
  def forward(self,x):
    x1 = self.conv1(x)
    x2 = self.conv2(x1)

    #这里给x3降为2维可以让linear函数使用
    x3 = x2.view(x2.size(0),-1)
    out = self.out(x3)
    return out

 

#自动调整参数,最优化模型

cnn = CNN()

optimizer = torch.optim.Adam(cnn.parameters(),lr = 0.02)
loss_func = nn.CrossEntropyLoss()

 

#三、训练模型

for step,(x,y) in enumerate(train_loader):
  x = Variable(x)
  y = Variable(y)
  out = cnn(x)
  loss = loss_func(out,y)

  #以下为固定操作,为了训练每一条数据,不断调整参数
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

 

 

 

#四、测试

predict = cnn(test_x[:10])
res = torch.max(predict,1)[1]

res #测试数据

tensor([7, 2, 1, 0, 4, 1, 4, 9, 9, 9])

test_y[:10] #真实数据

tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])

 

#在这里我们发现前十个数据分类准确率达到90

 

posted @ 2019-05-20 10:03  温祖斌  阅读(1137)  评论(0编辑  收藏  举报