pytorch数据集MNIST训练与测试实例

 

 

复制代码
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.optim import Adam
import torch.nn as nn
import torch.nn.functional as F


BATCH_SIZE = 128
TEST_BATCH_SIZE = 516
#1、准备数据集
def get_dataloader(train=True,batch_size=BATCH_SIZE):
    transform_fn = Compose([ToTensor(),Normalize(mean=(0.1307,),std=(0.3081,))]) #mean和std的形状和通道数相同
    dataset = MNIST(root='./data',train=train,download=False,transform=transform_fn)
    data_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
    return data_loader
    # for i in enumerate(data_loader):
    #     print(i)

#2.构建模型
class MnisModel(nn.Module):
    def __init__(self):
        super(MnisModel,self).__init__()
        self.fc1 = nn.Linear(1*28*28,28)
        self.fc2 = nn.Linear(28,10)

    def forward(self,input):
        """
        input:[batch_size,1,28,28]
        """
        #1.修改形状
        x = input.view([input.size(0),1*28*28]) # input.view(-1,1*28*28)
        #2.进行全连接的操作
        x = self.fc1(x)
        #3.进行激活函数处理,形状不会发生变化
        x = F.relu(x)
        #4.输出层
        out = self.fc2(x)
        return F.log_softmax(out,dim=-1)

# 1.实例化模型
model = MnisModel()
#2.实例优化器类
optimizer = Adam(model.parameters(),lr=0.001)
if os.path.exists("./model/model.pt"):
    model.load_state_dict(torch.load("./model/model.pt"))  #加载模型
    optimizer.load_state_dict(torch.load("./model/optimizer.pt"))  #加载优化器

def train(epoch):
    """
    实现训练过程
    """
    #3.加载数据集,遍历
    data_loader = get_dataloader()
    for idex,(input,target) in enumerate(data_loader):
        optimizer.zero_grad()  #4.梯度置为0
        output = model(input)  #5.调用模型,得到预测值
        loss = F.nll_loss(output,target)  #6.计算损失
        loss.backward()  #7.反向传播
        optimizer.step()  #8.梯度的更新
        if idex % 100 == 0:
            print(loss.item())

        if idex % 100 == 0:
            torch.save(model.state_dict(),"./model/model.pt")  #保存模型参数
            torch.save(optimizer.state_dict(),"./model/optimizer.pt") #保存优化器参数

def test():  #测试数据
    loss_list = []
    acc_list = []
    test_dataloader = get_dataloader(False,batch_size=TEST_BATCH_SIZE)  #获取测试数据集
    for idx,(input,target) in enumerate(test_dataloader):
        # print(idx,target,input)
        # break
        with torch.no_grad():
            output = model(input)
            cur_loss = F.nll_loss(output,target)
            loss_list.append(cur_loss)
            #计算准备率
            #output [batch_size,10] target:[batch_size]
            pred = output.max(dim = -1)[-1]
            cur_acc = pred.eq(target).float().mean()
            acc_list.append(cur_acc)

    print("平均准确率,平均损失",np.mean(acc_list),np.mean(loss_list))


if __name__ == '__main__':
    # for i  in range(3):  #训练三轮
    #     train(i)
    test()
复制代码

 

posted @   ziff123  阅读(96)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· winform 绘制太阳,地球,月球 运作规律
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
历史上的今天:
2023-02-05 npm通过--registry来指定安装源
2023-02-05 npm 的配置文件 .npmrc
点击右上角即可分享
微信分享提示