随笔 - 34  文章 - 0  评论 - 0  阅读 - 3316

深度学习-图像分类

 

 

 

 

复制代码
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time



# Preprocessing applied to every image: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split (50000 images).
# On first use, set download=True so the dataset is fetched automatically.
train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=False,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=36,
    shuffle=True,
    num_workers=0,
)

# Validation split (10000 images).
# On first use, set download=True so the dataset is fetched automatically.
val_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=False,
    transform=transform,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=4,
    shuffle=False,
    num_workers=0,
)

# Pull a single batch from the validation loader to use as a preview.
val_data_iter = iter(val_loader)
val_image, val_label = next(val_data_iter)

# Human-readable class names, indexed by the integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

def imshow(img):
    """Display a normalized, channel-first image tensor with matplotlib.

    Args:
        img: A 3-axis tensor whose first axis is moved last before plotting
            (i.e. channel-first in, channel-last out). Values are expected to
            have been normalized with mean 0.5 / std 0.5, as done by the
            ``transform`` pipeline in this script.

    The call blocks in ``plt.show()`` until the figure window is closed when
    an interactive matplotlib backend is in use.
    """
    # Undo Normalize(mean=0.5, std=0.5): x * 0.5 + 0.5.
    img = img / 2 + 0.5
    # detach() and cpu() are no-ops for a plain CPU tensor, but without them
    # .numpy() raises for tensors that require grad or live on another device.
    npimg = img.detach().cpu().numpy()
    # Move the first axis to the end so the array is laid out the way
    # plt.imshow expects for colour images.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Print the class name of every image in the preview batch. Iterating over the
# labels themselves (instead of a hard-coded range(4)) keeps this line correct
# if the batch size of val_loader is ever changed.
print(' '.join(f'{classes[label]:5s}' for label in val_label.tolist()))
# Show the whole batch as a single image grid.
imshow(torchvision.utils.make_grid(val_image))


#     net = LeNet()
#     loss_function = nn.CrossEntropyLoss()
#     optimizer = optim.Adam(net.parameters(), lr=0.001)
#
#     for epoch in range(5):  # loop over the dataset multiple times
#
#         running_loss = 0.0
#         for step, data in enumerate(train_loader, start=0):
#             # get the inputs; data is a list of [inputs, labels]
#             inputs, labels = data
#
#             # zero the parameter gradients
#             optimizer.zero_grad()
#             # forward + backward + optimize
#             outputs = net(inputs)
#             loss = loss_function(outputs, labels)
#             loss.backward()
#             optimizer.step()
#
#             # print statistics
#             running_loss += loss.item()
#             if step % 500 == 499:    # print every 500 mini-batches
#                 with torch.no_grad():
#                     outputs = net(val_image)  # [batch, 10]
#                     predict_y = torch.max(outputs, dim=1)[1]
#                     accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)
#
#                     print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
#                           (epoch + 1, step + 1, running_loss / 500, accuracy))
#                     running_loss = 0.0
#
#     print('Finished Training')
#
#     save_path = './Lenet.pth'
#     torch.save(net.state_dict(), save_path)
#
#
# if __name__ == '__main__':
#     main()
复制代码

 

 

 

posted on   ccxwyyjy  阅读(32)  评论(0)  编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5

点击右上角即可分享
微信分享提示