第三章 3.3 使用 PyTorch 的数据集



# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch
# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch

###################  Chapter Three #######################################

# 第三章  读取数据集并显示

from torchvision import datasets
import torch

# Root directory FashionMNIST is stored under; any writable location works.
data_folder = '~/data/FMNIST'
# Load the training split. With download=True the archives are fetched only
# when they are missing under data_folder (an existing copy is reused);
# download=False would make this call fail on any machine that has not
# downloaded the dataset beforehand.
fmnist = datasets.FashionMNIST(data_folder, download=True, train=True)

tr_images = fmnist.data      # training images (tensor)
tr_targets = fmnist.targets  # training labels (tensor)

# Distinct label values present in the training set.
unique_values = tr_targets.unique()
print(f'train_images_shape: {tr_images.shape}')
print(f'train_targets_shape: {tr_targets.shape}')
print(f'train_targets_classes: {fmnist.classes}')

import matplotlib.pyplot as plt
#matplotlib inline
import numpy as np

# Show a grid of random training samples: one row per class, C per row.
R, C = len(tr_targets.unique()), 10
# squeeze=False keeps `ax` a 2-D array of Axes even when R or C is 1, so the
# row/cell loops below work for any grid size.
fig, ax = plt.subplots(R, C, figsize=(10, 10), squeeze=False)
# The row index doubles as the label value, i.e. this assumes the class
# labels are the integers 0..R-1.
for label_class, plot_row in enumerate(ax):
    # Indices of every training sample whose label equals this row's class.
    label_x_rows = np.where(tr_targets == label_class)[0]
    for plot_cell in plot_row:
        plot_cell.grid(False)
        plot_cell.axis('off')
        # Draw one random sample of this class (with replacement, so the same
        # image can appear more than once in a row).
        ix = np.random.choice(label_x_rows)
        plot_cell.imshow(tr_images[ix], cmap='gray')
# Required to actually display the figure when run as a plain script;
# harmless under a notebook inline backend.
plt.show()



posted @ 2024-12-13 12:01  辛河  阅读(9)  评论(0)  编辑  收藏  举报