PyTorch from Getting Started to Giving Up: Loading Data

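PyTorch has two basic data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset. A Dataset stores the samples and their corresponding labels. A DataLoader wraps a Dataset in an iterable, which makes it easy to load, iterate over, and batch the data.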
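To make this split of responsibilities concrete, here is a minimal sketch of a map-style Dataset, meaning a subclass that implements `__len__` and `__getitem__`, wrapped in a DataLoader. The class name `SquaresDataset` and its toy data are hypothetical and exist only for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the float i, its label is i squared."""

    def __init__(self, n: int):
        self.x = torch.arange(n, dtype=torch.float32)
        self.y = self.x ** 2

    def __len__(self) -> int:
        # DataLoader uses this to know how many samples exist
        return len(self.x)

    def __getitem__(self, idx: int):
        # Return a single (sample, label) pair; batching is the DataLoader's job
        return self.x[idx], self.y[idx]


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4)

print(len(dataset))  # 10
print(dataset[3])    # (tensor(3.), tensor(9.))
for xb, yb in loader:
    print(xb.shape, yb.shape)  # batches of 4, 4, then 2
```

The Dataset only knows how to return one item by index; the DataLoader collects items into batches (here 4, 4 and a final partial batch of 2).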

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

  

PyTorch offers datasets for different domains through libraries such as TorchText, TorchVision, and TorchAudio. In this tutorial we use TorchVision.

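The torchvision.datasets module contains many vision datasets, such as CIFAR and COCO (see the torchvision.datasets documentation for the full list). In this tutorial we use the FashionMNIST dataset. Every TorchVision dataset accepts two arguments, transform and target_transform, which modify the samples and the labels respectively.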
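As a sketch of how these two arguments are typically wired up, the hypothetical `ToyImageDataset` below mimics the torchvision convention: `__getitem__` applies `transform` to the sample and `target_transform` to the label. The two transforms chosen here (scaling uint8 pixels to [0, 1] and one-hot encoding the label with `torch.nn.functional.one_hot`) are illustrative assumptions, not what FashionMNIST applies by default.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    """Hypothetical dataset following the torchvision convention of
    accepting optional `transform` and `target_transform` callables."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images  # uint8 tensor, shape [N, 28, 28]
        self.labels = labels  # int64 tensor, shape [N]
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)         # modifies the sample
        if self.target_transform is not None:
            label = self.target_transform(label)  # modifies the label
        return image, label


def scale_image(img):
    # Add a channel dimension and scale pixels from [0, 255] to [0.0, 1.0]
    return img.unsqueeze(0).float() / 255.0


def one_hot(label):
    # Integer class index -> one-hot float vector over 10 classes
    return F.one_hot(label, num_classes=10).float()


images = torch.randint(0, 256, (5, 28, 28), dtype=torch.uint8)
labels = torch.tensor([0, 3, 9, 1, 3])
ds = ToyImageDataset(images, labels, transform=scale_image, target_transform=one_hot)

img, target = ds[1]
print(img.shape, img.dtype)  # torch.Size([1, 28, 28]) torch.float32
print(target)                # one-hot vector with a 1 at index 3
```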

# Download the training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
 
# Download the test data.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
Output:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|##########| 26421880/26421880 [00:01<00:00, 13914460.04it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|##########| 29515/29515 [00:00<00:00, 326673.50it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|##########| 4422102/4422102 [00:00<00:00, 6109664.51it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|##########| 5148/5148 [00:00<00:00, 61868988.52it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
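The transform=ToTensor() argument is what turns each FashionMNIST image into the [1, 28, 28] float tensor that shows up later in the batch shapes. Roughly, for a grayscale uint8 image it adds a channel dimension and rescales pixel values from [0, 255] to [0.0, 1.0]. The function `to_tensor_sketch` below is a hypothetical torch-only approximation of that behaviour for uint8 tensors, not torchvision's actual implementation (which also accepts PIL images and handles other input types).

```python
import torch


def to_tensor_sketch(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical approximation of ToTensor for uint8 image tensors.

    Accepts [H, W] (grayscale) or [H, W, C] uint8 input and returns a
    float32 tensor of shape [C, H, W] with values in [0.0, 1.0].
    """
    if img.dim() == 2:
        img = img.unsqueeze(-1)  # [H, W] -> [H, W, 1]
    img = img.permute(2, 0, 1)   # channels-last -> channels-first
    return img.float().div(255.0)


gray = torch.full((28, 28), 255, dtype=torch.uint8)
t = to_tensor_sketch(gray)
print(t.shape, t.dtype, t.max().item())  # torch.Size([1, 28, 28]) torch.float32 1.0
```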

 

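The Dataset is passed as an argument to DataLoader, which wraps an iterable around it and supports automatic batching, sampling, shuffling, and multiprocess data loading. Here we set the batch size to 64, so each iteration over the dataloader yields a batch of 64 features and 64 labels.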
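With batch_size=64, the number of batches is the dataset size divided by 64, rounded up, and the last batch holds whatever remains. FashionMNIST has 60,000 training and 10,000 test images, so the sketch below uses stand-in TensorDatasets of those sizes (an assumption, so it runs offline and uses little memory) to check the batch counts. It also shows shuffle=True, commonly used on the training loader so each epoch visits the samples in a new order, and drop_last=True, which discards the incomplete final batch.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 64

# Stand-ins with the same sizes as FashionMNIST's train/test splits;
# one scalar feature per sample keeps memory use tiny.
train_stub = TensorDataset(torch.arange(60000), torch.zeros(60000, dtype=torch.int64))
test_stub = TensorDataset(torch.arange(10000), torch.zeros(10000, dtype=torch.int64))

train_loader = DataLoader(train_stub, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_stub, batch_size=batch_size)

print(len(train_loader), math.ceil(60000 / batch_size))  # 938 938
print(len(test_loader))                                  # 157

# The final test batch is partial: 10000 - 156 * 64 = 16
last_x, last_y = list(test_loader)[-1]
print(last_x.shape)  # torch.Size([16])

# drop_last=True discards that partial batch
print(len(DataLoader(test_stub, batch_size=batch_size, drop_last=True)))  # 156
```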

batch_size = 64
 
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
 
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
Output:
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
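The for ... break loop above only serves to grab the first batch; next(iter(loader)) is an equivalent and common idiom. The sketch below uses random stand-in tensors shaped like FashionMNIST (an assumption, so it runs without downloading anything) and unpacks the [N, C, H, W] dimensions: N is the batch size, C=1 because the images are grayscale, and H=W=28 pixels.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Random stand-in for FashionMNIST: 200 grayscale 28x28 images, labels in 0..9
images = torch.rand(200, 1, 28, 28)
labels = torch.randint(0, 10, (200,))
loader = DataLoader(TensorDataset(images, labels), batch_size=64)

# Same effect as `for X, y in loader: ...; break`
X, y = next(iter(loader))
N, C, H, W = X.shape
print(f"N={N}, C={C}, H={H}, W={W}")  # N=64, C=1, H=28, W=28
print(y.shape, y.dtype)              # torch.Size([64]) torch.int64
```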

 

To learn more, see the PyTorch tutorial on loading data.

Posted by 陈景安