第三章 3.3 使用pytorch的数据集
比较复杂,可以将代码放在AI(温馨一眼)中做解读
代码:
# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch # https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch ################### Chapter Three ####################################### # 第三章 读取数据集并显示 from torchvision import datasets import torch ######################################################################## data_folder = '~/data/FMNIST' # This can be any directory you want # to download FMNIST to fmnist = datasets.FashionMNIST(data_folder, download=False, train=True) ######################################################################## tr_images = fmnist.data #图像 tr_targets = fmnist.targets #标签 ######################################################################## #检查张量 unique_values = tr_targets.unique() print(f'train_images_shape: {tr_images.shape}') print(f'train_targets_shape: {tr_targets.shape}') print(f'train_targets_unique_values:{unique_values}') print(f'train_targets_classes: {fmnist.classes}') ######################################################################## import matplotlib.pyplot as plt #matplotlib inline import numpy as np R, C = len(tr_targets.unique()), 10 fig, ax = plt.subplots(R, C, figsize=(10,10)) for label_class, plot_row in enumerate(ax): label_x_rows = np.where(tr_targets == label_class)[0] for plot_cell in plot_row: plot_cell.grid(False); plot_cell.axis('off') ix = np.random.choice(label_x_rows) x, y = tr_images[ix], tr_targets[ix] plot_cell.imshow(x, cmap='gray') plt.tight_layout() plt.show()