深度学习-图像分类
import torch import torchvision import torch.nn as nn from model import LeNet import torch.optim as optim import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as np import time transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # # # 50000张训练图片 # # 第一次使用时要将download设置为True才会自动去下载数据集 train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform) train_loader = torch.utils.data.DataLoader(train_set, batch_size=36, shuffle=True, num_workers=0) # # 10000张验证图片 # # 第一次使用时要将download设置为True才会自动去下载数据集 val_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform) val_loader = torch.utils.data.DataLoader(val_set, batch_size=4, shuffle=False, num_workers=0) val_data_iter = iter(val_loader) val_image, val_label = next(val_data_iter) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') def imshow(img): img = img / 2 + 0.5 # unnormalize npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show() # print labels print(' '.join(f'{classes[val_label[j]]:5s}' for j in range(4))) # show images imshow(torchvision.utils.make_grid(val_image)) # net = LeNet() # loss_function = nn.CrossEntropyLoss() # optimizer = optim.Adam(net.parameters(), lr=0.001) # # for epoch in range(5): # loop over the dataset multiple times # # running_loss = 0.0 # for step, data in enumerate(train_loader, start=0): # # get the inputs; data is a list of [inputs, labels] # inputs, labels = data # # # zero the parameter gradients # optimizer.zero_grad() # # forward + backward + optimize # outputs = net(inputs) # loss = loss_function(outputs, labels) # loss.backward() # optimizer.step() # # # print statistics # running_loss += loss.item() # if step % 500 == 499: # print every 500 mini-batches # with torch.no_grad(): # outputs = net(val_image) # [batch, 10] # predict_y = torch.max(outputs, dim=1)[1] # accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0) # # print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' % # (epoch + 1, step + 1, running_loss / 500, accuracy)) # running_loss = 0.0 # # print('Finished Training') # # save_path = './Lenet.pth' # torch.save(net.state_dict(), save_path) # # # if __name__ == '__main__': # main()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】