用torchvision.datasets.ImageFolder加载图片数据集
一、项目结构
二、代码
1 data_loader = torch.utils.data.DataLoader( 2 torchvision.datasets.ImageFolder('traing_dataset', 3 transform=torchvision.transforms.Compose([ 4 torchvision.transforms.Resize([28, 28]), # 裁剪图片 5 torchvision.transforms.Grayscale(1), # 单通道 6 torchvision.transforms.ToTensor(), # 将图片数据转成tensor格式 7 torchvision.transforms.Normalize( # 归一化 8 (0.1307,), (0.3081,)) 9 ])), 10 batch_size=10, shuffle=False) # 10张图片
三、显示效果
1 def plot_image(img, label, name): 2 fig = plt.figure() 3 for i in range(6): # 只显示6张 4 plt.subplot(2, 3, i+1) # 2行3列第i+1张 5 plt.tight_layout() 6 plt.imshow(img[i][0]*0.3081+0.1307, cmap='Greys', interpolation='none') 7 plt.title("{}:{}".format(name, label[i].item())) # 标题名称 8 plt.xticks([]) 9 plt.yticks([]) 10 plt.show() 11 12 x, y = next(iter(data_loader)) # 文件夹的名称即为图片的label 13 print(x.shape, y.shape) 14 plot_image(x, y, 'image')