Handwritten digit recognition with a multi-layer fully connected network (PyTorch)

For details, see 深度学习之PyTorch (廖星宇).

# PyTorch implementation of handwritten digit recognition with a deep fully connected network
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms  # datasets and preprocessing transforms

# Three-layer fully connected network (plain linear layers, no activations)
class simpleNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(simpleNet, self).__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
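
# Quick sanity check (my addition, not from the book): push a dummy batch through
# simpleNet and confirm the output carries one logit per digit class.
_net = simpleNet(28 * 28, 300, 100, 10)
assert _net(torch.randn(4, 28 * 28)).shape == (4, 10)  # 4 fake flattened images -> 4x10 logits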

# The same network with ReLU activations after each hidden layer
class Activation_Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Activation_Net, self).__init__()
        # nn.Sequential chains modules so they run in order, like stacking blocks;
        # nn.ReLU(True) applies ReLU in place to save memory
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.ReLU(True))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.ReLU(True))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
        
# Batch normalization: apply standardization inside every hidden layer
class Batch_Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Batch_Net, self).__init__()
        # Linear -> BatchNorm1d -> ReLU inside each hidden layer
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x) 
        x = self.layer3(x)
        return x
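
# Sketch of what BatchNorm1d computes (my addition): in training mode each feature
# is normalized with the batch mean and biased variance, then scaled and shifted
# by learnable parameters (initially 1 and 0).
_bn = nn.BatchNorm1d(3)
_x = torch.randn(8, 3)
_manual = (_x - _x.mean(0)) / torch.sqrt(_x.var(0, unbiased=False) + _bn.eps)
assert torch.allclose(_bn(_x), _manual, atol=1e-5)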

# Hyperparameters and data preprocessing
batch_size = 64
learning_rate = 1e-2
num_epoches = 20

data_tf = transforms.Compose(
    [transforms.ToTensor(),  # convert PIL image to tensor and scale pixels to [0, 1]
     transforms.Normalize([0.5], [0.5])]  # standardize: x -> (x - 0.5) / 0.5, giving [-1, 1]
)
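
# Worked example (my addition): after ToTensor() pixels lie in [0, 1], and
# Normalize([0.5], [0.5]) then maps each value x to (x - 0.5) / 0.5, i.e. into [-1, 1].
_pix = torch.tensor([[[0.0, 0.5, 1.0]]])  # fake 1-channel 1x3 image already scaled to [0, 1]
assert torch.equal(transforms.Normalize([0.5], [0.5])(_pix), torch.tensor([[[-1.0, 0.0, 1.0]]]))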

# Download the MNIST training and test sets
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
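
# Shape check (my addition): each batch from the loader is an (images, labels) pair;
# images have shape (batch_size, 1, 28, 28), labels have shape (batch_size,).
_img, _label = next(iter(train_loader))
assert _img.shape == (batch_size, 1, 28, 28) and _label.shape == (batch_size,)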

# Instantiate the model: MNIST images are 28x28, flattened to 784 inputs, 10 digit classes out
model = Batch_Net(28 * 28, 300, 100, 10)

if torch.cuda.is_available():
    model = model.cuda()

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()  # expects raw logits; applies log-softmax internally
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
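
# Sketch (my addition): CrossEntropyLoss consumes raw logits and is equivalent to
# LogSoftmax followed by NLLLoss, which is why the networks above end in a bare
# Linear layer with no softmax.
_logits = torch.randn(4, 10)
_targets = torch.randint(0, 10, (4,))
assert torch.allclose(nn.CrossEntropyLoss()(_logits, _targets),
                      nn.NLLLoss()(nn.LogSoftmax(dim=1)(_logits), _targets))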

# Train the network
for epoch in range(num_epoches):
    model.train()  # training mode: BatchNorm uses per-batch statistics
    for data in train_loader:  # one batch of batch_size images at a time
        img, label = data  # img shape: (batch_size, 1, 28, 28)
        img = img.view(img.size(0), -1)  # flatten to (batch_size, 784)

        if torch.cuda.is_available():
            img = img.cuda()
            label = label.cuda()
        
        # Forward pass
        out = model(img)
        loss = criterion(out, label)
        
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('Epoch:{}, Loss:{:.4f}'.format(epoch + 1, loss.item()))  # with SGD, one epoch takes many gradient steps, so this prints once per batch
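
# Why model.eval() matters below (my addition): BatchNorm behaves differently in the
# two modes. Train mode normalizes with the current batch's statistics; eval mode
# uses the running averages accumulated during training.
_bn_demo = nn.BatchNorm1d(2)
_ones = torch.ones(4, 2)  # a zero-variance batch
assert torch.allclose(_bn_demo(_ones), torch.zeros(4, 2))  # train mode: (x - mean) / std -> 0
_bn_demo.eval()
assert not torch.allclose(_bn_demo(_ones), torch.zeros(4, 2))  # eval mode: running stats give nonzero output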

# Evaluate on the test set
model.eval()  # evaluation mode: BatchNorm switches to its running statistics
eval_loss = 0
eval_acc = 0
with torch.no_grad():  # no gradients needed at test time (replaces the deprecated volatile=True Variables)
    for data in test_loader:
        img, label = data
        img = img.view(img.size(0), -1)  # flatten to (batch_size, 784)

        if torch.cuda.is_available():
            img = img.cuda()
            label = label.cuda()

        out = model(img)
        loss = criterion(out, label)

        eval_loss += loss.item() * label.size(0)  # weight by batch size; the last batch may be smaller than batch_size
        _, predict = torch.max(out, 1)  # predicted class = index of the largest logit
        num_correct = (predict == label).sum()
        eval_acc += num_correct.item()

print('Loss:{:.6f}, Acc:{:.6f}'.format(eval_loss / len(test_dataset), eval_acc / len(test_dataset)))

 
