笔记7:训练过程封装(代码模板)

相关包

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torchvision
from torchvision import datasets, transforms
%matplotlib inline

训练过程封装

def _run_epoch(model, loader, loss_func, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given, the model is put in training mode and its
    parameters are updated after every batch. Otherwise the model is put in
    evaluation mode and no gradients are tracked.

    ``mean_loss`` is the per-sample average loss over the pass. This assumes
    ``loss_func`` averages over the batch (``reduction='mean'``, e.g. the
    default ``nn.CrossEntropyLoss``), so each batch loss is re-weighted by its
    batch size. That also keeps a smaller last batch from being over-counted.
    ``accuracy`` is the fraction of samples where ``argmax(dim=1)`` of the
    model output equals ``y``.
    """
    training = optimizer is not None
    # train()/eval() matter for Dropout/BatchNorm layers. Without eval() the
    # evaluation pass would apply dropout and update BatchNorm statistics.
    model.train(training)

    correct = 0         # correctly classified samples in this pass
    total = 0           # samples seen in this pass
    running_loss = 0.0  # sum of per-sample losses in this pass
    with torch.set_grad_enabled(training):
        for x, y in loader:
            y_pred = model(x)
            loss = loss_func(y_pred, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = y.size(0)
            running_loss += loss.item() * batch_size
            correct += (torch.argmax(y_pred, dim=1) == y).sum().item()
            total += batch_size
    return running_loss / total, correct / total


def fit(epoch, model, trainloader, testloader, criterion=None, opt=None):
    """Train ``model`` for one epoch, then evaluate it and print the metrics.

    Args:
        epoch: Epoch index. It is only used in the printed progress line.
        model: Classifier whose output is scored with ``argmax(dim=1)``.
        trainloader: Iterable of ``(x, y)`` batches. Each batch triggers one
            optimizer step.
        testloader: Iterable of ``(x, y)`` batches used for evaluation only.
        criterion: Loss function called as ``criterion(y_pred, y)``.
            Defaults to the module-level ``loss_func``.
        opt: Optimizer over ``model``'s parameters. Defaults to the
            module-level ``optimizer``.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)``. The
        losses are per-sample averages (assuming a batch-mean loss) and the
        accuracies are fractions in [0, 1]. The model is left in eval mode.
    """
    # Fall back to the notebook-level globals so existing calls keep working.
    if criterion is None:
        criterion = loss_func
    if opt is None:
        opt = optimizer

    epoch_loss, epoch_acc = _run_epoch(model, trainloader, criterion, opt)
    epoch_test_loss, epoch_test_acc = _run_epoch(model, testloader, criterion)

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy: ', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy: ', round(epoch_test_acc, 3))

    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

这里的 correct 指的是每一轮(epoch)中分类正确的样本数,
total 指的是每一轮参与计算的样本总数,
running_loss 指的是一轮中损失值的累计总和。

初始化

# Training configuration: number of epochs, the model, and its Adam optimizer.
# NOTE(review): `fit` also reads a module-level `loss_func` that is not
# defined in this note; presumably it comes from an earlier note. Confirm.
epochs = 100
model = Model()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

模型训练

# Per-epoch history: each call to fit() appends one value to every list.
train_loss, train_acc = [], []
test_loss, test_acc = [], []

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(
        epoch, model, train_dl, test_dl
    )
    train_loss.append(epoch_loss)
    test_loss.append(epoch_test_loss)
    train_acc.append(epoch_acc)
    test_acc.append(epoch_test_acc)

训练可视化

# Plot the per-epoch curves collected by the training loop.
# Each chart gets its own figure. Otherwise, when both snippets run in the
# same cell or script, the accuracy curves are drawn onto the loss axes.
# The x-axis comes from the recorded history rather than `epochs`, so it
# still matches if training ran for a different number of epochs.
epoch_axis = range(1, len(train_loss) + 1)

plt.figure()
plt.plot(epoch_axis, train_loss, label='train_loss')
plt.plot(epoch_axis, test_loss, label='test_loss')
plt.xlabel('epoch')
plt.legend()

plt.figure()
plt.plot(epoch_axis, train_acc, label='train_acc')
plt.plot(epoch_axis, test_acc, label='test_acc')
plt.xlabel('epoch')
plt.legend()
posted @ 2021-01-27 16:04  pbc的成长之路  阅读(156)  评论(0)  编辑  收藏  举报