PyTorch from Getting Started to Giving Up: Loading Data
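PyTorch has two basic classes for working with data: `torch.utils.data.DataLoader` and `torch.utils.data.Dataset`. A `Dataset` stores the samples and their corresponding labels. A `DataLoader` wraps a `Dataset` in an iterable, which makes it easy to load the data, iterate over it, and split it into batches.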
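The split between the two classes is easier to see with a minimal map-style `Dataset`, which is a subclass that implements `__len__` and `__getitem__`. The sketch below is illustrative only: the class name `ToyDataset` and its random data are hypothetical. FashionMNIST, used later in this post, follows the same indexing protocol.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical map-style dataset: random 1x28x28 "images" with integer labels."""

    def __init__(self, num_samples=100, num_classes=10):
        g = torch.Generator().manual_seed(0)  # fixed seed so the data is reproducible
        self.images = torch.rand(num_samples, 1, 28, 28, generator=g)
        self.labels = torch.randint(0, num_classes, (num_samples,), generator=g)

    def __len__(self):
        # Number of samples; the default sampler draws indices from range(len(self))
        return len(self.labels)

    def __getitem__(self, idx):
        # Return a single (sample, label) pair
        return self.images[idx], self.labels[idx]


dataset = ToyDataset()
# DataLoader fetches items by index and stacks them into batches
loader = DataLoader(dataset, batch_size=32)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([32, 1, 28, 28])
print(labels.shape)  # torch.Size([32])
print(len(loader))   # 4 (100 samples -> 32 + 32 + 32 + 4)
```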
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
```
PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, all of which include datasets. In this tutorial we use TorchVision.
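The `torchvision.datasets` module contains datasets for many vision tasks, such as CIFAR and COCO (full list). In this tutorial we use the FashionMNIST dataset. Every TorchVision dataset accepts two arguments, `transform` and `target_transform`, which modify the samples and the labels respectively.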
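TorchVision datasets apply `transform` to the sample and `target_transform` to the label each time an item is fetched. The sketch below reproduces that convention in a hypothetical `ToyImageDataset` class so it runs without downloading anything. `to_float_chw` is only a rough stand-in for `ToTensor()` (the real one also accepts PIL images and NumPy arrays), and `one_hot` is an illustrative label transform, not something FashionMNIST applies by default.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    """Hypothetical dataset following the torchvision convention:
    `transform` is applied to the sample, `target_transform` to the label."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images  # uint8 tensor of shape [N, 28, 28], values 0..255
        self.labels = labels  # list of int class indices
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


def to_float_chw(img):
    # Rough stand-in for ToTensor(): add a channel dim, scale 0..255 to 0.0..1.0
    return img.unsqueeze(0).float() / 255.0


def one_hot(label, num_classes=10):
    # Example label transform: class index -> one-hot float vector
    return F.one_hot(torch.tensor(label), num_classes).float()


images = torch.randint(0, 256, (5, 28, 28), dtype=torch.uint8)
labels = [9, 0, 0, 3, 0]
ds = ToyImageDataset(images, labels, transform=to_float_chw, target_transform=one_hot)

x, y = ds[3]
print(x.shape, x.dtype)  # torch.Size([1, 28, 28]) torch.float32
print(y)                 # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
```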
```python
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
```
Output:
```
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|##########| 26421880/26421880 [00:01<00:00, 13914460.04it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|##########| 29515/29515 [00:00<00:00, 326673.50it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|##########| 4422102/4422102 [00:00<00:00, 6109664.51it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|##########| 5148/5148 [00:00<00:00, 61868988.52it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
```
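We pass each `Dataset` as an argument to `DataLoader`. The loader wraps the dataset in an iterable and supports automatic batching, sampling, shuffling, and multiprocess data loading. Here we set the batch size to 64, so each iteration over the dataloader yields a batch of 64 features and 64 labels.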
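Shuffling, dropping an incomplete final batch, and multiprocess loading are controlled by `DataLoader` keyword arguments (`shuffle`, `drop_last`, `num_workers`). The following sketch is separate from the FashionMNIST loaders and uses synthetic tensors; the variable names are illustrative. It keeps `num_workers=0` so it runs anywhere. On platforms that start workers with `spawn` (for example Windows and macOS), a value above 0 typically requires the loading code to sit under `if __name__ == "__main__":`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic data: 1000 samples of shape 1x28x28; labels are just the sample indices
features = torch.randn(1000, 1, 28, 28)
labels = torch.arange(1000)
dataset = TensorDataset(features, labels)

g = torch.Generator().manual_seed(42)  # makes the shuffled order reproducible

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,     # draw a new random order at the start of every epoch
    drop_last=False,  # keep the smaller final batch (1000 = 15 * 64 + 40)
    num_workers=0,    # values > 0 load batches in worker subprocesses
    generator=g,
)

batch_sizes = [y.shape[0] for _, y in loader]
print(len(loader), batch_sizes[-1])  # 16 40

# With drop_last=True the incomplete batch is discarded
strict = DataLoader(dataset, batch_size=64, drop_last=True)
print(len(strict))  # 15
```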
```python
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
```
Output:
```
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
```
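Every batch has N = 64 samples except possibly the last one. An epoch has ceil(N_total / 64) batches, and the final batch holds whatever remains. FashionMNIST has 60,000 training and 10,000 test images, and neither count is a multiple of 64. The sketch below checks this with placeholder datasets that contain only indices (no images) but have the same lengths as the real splits. The `results` dict is just an illustrative helper.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 64
results = {}

# FashionMNIST split sizes: 60,000 training images and 10,000 test images
for split, n in [("train", 60_000), ("test", 10_000)]:
    # Placeholder holding only indices, so nothing has to be downloaded
    loader = DataLoader(TensorDataset(torch.arange(n)), batch_size=batch_size)
    last_batch = None
    for (idx,) in loader:
        last_batch = idx
    results[split] = (len(loader), last_batch.shape[0])
    print(split, len(loader), math.ceil(n / batch_size), last_batch.shape[0])

# train 938 938 32
# test 157 157 16
```

If a model or loss function needs every batch to be the same size, passing `drop_last=True` to `DataLoader` discards that final partial batch.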
For more details, see Loading Data in PyTorch.