matplotlib.pyplot展示MNIST图片

import torch
import torch.utils.data as Data
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
 
#torch.manual_seed(1)    # reproducible
 
# Hyper-parameters for this demo
EPOCH = 1               # passes over the training data; kept at 1 to save time
batch_size = 4          # value handed to the DataLoader as its batch_size
LR = 0.001              # learning rate
DOWNLOAD_MNIST = True   # switch to False once the dataset has been downloaded

# Printable names for the ten digit classes, indexed by the integer label.
classes = [str(digit) for digit in range(10)]


# MNIST training split. ToTensor converts each PIL.Image / numpy.ndarray into
# a torch.FloatTensor of shape (C x H x W) scaled to the range [0.0, 1.0].
train_dataset = torchvision.datasets.MNIST(
    '../../data/',
    train=True,
    transform=transforms.ToTensor(),
    download=DOWNLOAD_MNIST,  # fetch the files if they are not available yet
)

# Shuffled mini-batch loader over the training split.
train_loader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

print(type(train_dataset))
def imshow(img, unnormalize=False):
    """Display a channel-first (C x H x W) image tensor with matplotlib.

    Args:
        img: torch tensor laid out as (C, H, W), e.g. the grid returned by
            ``torchvision.utils.make_grid``.
        unnormalize: when True, apply ``img / 2 + 0.5`` first, which maps
            values from [-1, 1] back to [0, 1] (the inverse of a
            mean=0.5 / std=0.5 normalization). Leave False for data that
            only went through ``ToTensor``; it is already in [0, 1], and
            rescaling it would squeeze pixels into [0.5, 1] and wash out
            the picture.
    """
    if unnormalize:
        img = img / 2 + 0.5
    npimg = img.numpy()
    # matplotlib wants channel-last (H, W, C) arrays.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Pull a single shuffled mini-batch from the loader.
dataiter = iter(train_loader)
# Use the built-in next(): DataLoader iterators in current PyTorch releases
# no longer provide the Python-2 style ``.next()`` method.
images, labels = next(dataiter)
print(images.shape)  # [batch_size, 1, 28, 28]

# Tile the batch into one image and show it.
imshow(torchvision.utils.make_grid(images))
# Print one right-aligned class name per image actually in the batch,
# rather than assuming exactly four.
print(''.join('%5s' % classes[label] for label in labels.tolist()))

  

posted @ 2019-01-11 19:42  youqia  阅读(1127)  评论(0)  编辑  收藏  举报