【pytorch】土堆pytorch教程学习(二)加载数据
Pytorch加载数据初认识
pytorch 中加载数据主要涉及两个类:Dataset
和 Dataloader
。
-
Dataset
提供一种方式去获取数据及其label -
Dataloader
构建可迭代的数据装载器,为网络提供不同的数据形式
Dataset
Dataset
实现的功能:
- 获取每个数据及其label
- 获取数据长度
每个数据集都需要继承 torch.utils.data.Dataset
类,并且重写 __getitem__
和 __len__
。
数据存放在 dataset/train
里,分为两个目录 ants
和 bees
,也分别是数据的标签,如下图所示:
from PIL import Image
from torch.utils.data import Dataset
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 根目录
self.label_dir = label_dir # 标签目录
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path) # 图片路径列表
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.path, img_name)
img = Image.open(img_item_path) # 获取数据
label = self.label_dir # 获取label
return img, label
def __len__(self):
return len(self.img_path) # 获取数据集长度
# test
root_dir = 'dataset/train'
ants_label_dir = 'ants'
bees_label_dir = 'bees'
# 生成两个数据集
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset # 拼接两个数据集
img1, label1 = ants_dataset[0]
img1.show()
print('label1:', label1)
img2, label2 = train_dataset[130]
img2.show()
print('label2:', label2)
本文来自博客园,作者:hzyuan,转载请注明原文链接:https://www.cnblogs.com/hzyuan/p/17344219.html
分类:
ai / pytorch
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)