第三章 3.3 使用 PyTorch 的数据集

下面的代码比较复杂，可以将代码放到 AI（如文心一言）中进行解读。

代码:

# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch

###################  Chapter Three #######################################

# 第三章  读取数据集并显示

from torchvision import datasets
import torch
########################################################################

# Directory where FashionMNIST is stored, or downloaded to on first use.
# Any writable directory works.
data_folder = '~/data/FMNIST'
# download=True fetches the dataset only when it is not already present under
# data_folder; with download=False torchvision raises an error on a machine
# that does not have the files yet, which contradicts the intent above.
fmnist = datasets.FashionMNIST(data_folder, download=True, train=True)

########################################################################
tr_images = fmnist.data      # training images
tr_targets = fmnist.targets  # training labels (one class index per image)

########################################################################
# Inspect the tensors: their shapes, the distinct label values present in
# the training set, and the human-readable class names.
unique_values = tr_targets.unique()
print(f'train_images_shape: {tr_images.shape}')
print(f'train_targets_shape: {tr_targets.shape}')
print(f'train_targets_unique_values:{unique_values}')
print(f'train_targets_classes: {fmnist.classes}')

########################################################################
import matplotlib.pyplot as plt
# In a Jupyter notebook, run `%matplotlib inline` in a cell to render figures
# inline. It is IPython magic, not Python, so it cannot appear in a .py script.
import numpy as np

# Draw a grid with one row per label value found in the training set and
# C randomly chosen training images of that label in each row.
R, C = len(unique_values), 10
# squeeze=False keeps `ax` two-dimensional even when R or C is 1, so the
# nested loop below always iterates rows first and then cells.
fig, ax = plt.subplots(R, C, figsize=(10, 10), squeeze=False)
# Pair each row with an actual label value instead of assuming that the
# label values are exactly 0..R-1.
for label_class, plot_row in zip(unique_values.tolist(), ax):
    # Indices of all training samples that carry this label.
    label_x_rows = np.where(tr_targets == label_class)[0]
    for plot_cell in plot_row:
        plot_cell.axis('off')  # hides ticks, spines, labels and grid lines
        ix = np.random.choice(label_x_rows)
        plot_cell.imshow(tr_images[ix], cmap='gray')
plt.tight_layout()

plt.show()

 

posted @ 2024-12-13 12:01  辛河  阅读(9)  评论(0)  编辑  收藏  举报