机器学习：完整的模型训练套路（PyTorch 环境）

一个例子

import torch
import torchvision

# 准备数据集
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import TuDui

# Prepare the CIFAR-10 train/test splits (downloaded into ./dataset if absent),
# converting every image to a tensor with ToTensor().
train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True,
                                          transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=True)

# Number of samples in each split.
train_data_size = len(train_data)
test_data_size = len(test_data)

# Wrap the datasets in DataLoaders that yield mini-batches of 64 samples.
# The training loader reshuffles every epoch so SGD does not see the batches
# in the same fixed order each pass; evaluation order does not matter, so the
# test loader is left unshuffled.
train_dataLoader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataLoader = DataLoader(test_data, batch_size=64)

# Loss function: cross-entropy, comparing the network outputs against the
# integer class targets yielded by the DataLoader.
loss_fun = nn.CrossEntropyLoss()

# The network to train (TuDui comes from the local `model` module).
tudui = TuDui()

# Optimizer: plain SGD over all of the network's parameters with a fixed
# learning rate.
train_rate = 0.01
optimizer = torch.optim.SGD(params=tudui.parameters(), lr=train_rate)

# Training bookkeeping:
#   total_train_step - running count of training iterations
#   total_test_step  - running count of evaluation passes
#   epoch            - number of passes over the training set
total_train_step, total_test_step = 0, 0
epoch = 10

# TensorBoard writer; event files go to the ./log directory.
write = SummaryWriter(log_dir="log")

# Run `epoch` passes over the training set. After each pass, evaluate on the
# test set, log loss/accuracy to TensorBoard and save a checkpoint. The
# try/finally guarantees the SummaryWriter is flushed and closed even if
# training is interrupted by an exception.
try:
    for i in range(epoch):
        # ---- training phase ----
        tudui.train()  # switch layers such as dropout/batchnorm (if any) to training mode
        for data in train_dataLoader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fun(outputs, targets)

            # Clear gradients left from the previous step, backpropagate,
            # then apply the parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log the current batch loss every 100 training iterations.
            if total_train_step % 100 == 0:
                write.add_scalar("train_loss", loss.item(), total_train_step)
            total_train_step += 1

        # ---- evaluation phase ----
        tudui.eval()
        # Sum of the per-batch losses returned by loss_fun over the test set.
        total_test_loss = 0.0
        # Number of test samples predicted correctly.
        total_accuracy = 0
        with torch.no_grad():  # no gradient tracking needed for evaluation
            for data in test_dataLoader:
                imgs, targets = data
                outputs = tudui(imgs)
                loss = loss_fun(outputs, targets)
                # Count samples whose highest-scoring output index equals the target.
                accuracy = (outputs.argmax(1) == targets).sum().item()
                total_test_loss += loss.item()
                total_accuracy += accuracy

        test_accuracy = total_accuracy / test_data_size
        print("epoch {}: test loss {:.4f}, test accuracy {:.4f}".format(
            i + 1, total_test_loss, test_accuracy))

        # add_scalar's signature is (tag, scalar_value, global_step): the
        # value comes before the step.
        write.add_scalar("test_loss", total_test_loss, total_test_step)
        write.add_scalar("test_accuracy", test_accuracy, total_test_step)
        total_test_step += 1

        # NOTE: this pickles the whole module object, so loading the file
        # requires the TuDui class to be importable. Saving
        # tudui.state_dict() is the more portable alternative.
        torch.save(tudui, "tudui{}.pth".format(i))
finally:
    write.close()

待补充……

posted @ 2021-09-05 16:55  EA2218764AB  阅读(220)  评论(0)  编辑  收藏  举报