从torchvision加载FashionMNIST数据并用matplotlib画出来

环境:

Python版本: 3.8
IDE:Spyder 5.2.2

代码:

from torchvision import datasets
from matplotlib import pyplot as plt

# 下载数据
from torchvision import datasets
from matplotlib import pyplot as plt
import math

# 下载数据
training_data = datasets.FashionMNIST(
    root='data',  #这个root指的是数据存放在本地电脑的路径
    train=True,   #拿的是训练的数据
    download=True  #是否要下载
)

test_data = datasets.FashionMNIST(
    root='data',
    train=False,  
    download=True
    )

tdata = test_data.data.numpy() # tensor转换成numpy
tclasses = test_data.classes


# 抽取前面的10张画出来
plt.figure(figsize=(3,3))
for i in range(9):
    print(i)
    plt.subplot(3,3,i+1)
    plt.axis('off')   #去掉坐标轴
    plt.imshow(tdata[i])
    plt.title(tclasses[i])

# 让子标题不会和轴重合
plt.tight_layout()

效果

posted @ 2022-02-21 14:44  裏表異体  阅读(107)  评论(0编辑  收藏  举报