【pytorch】土堆pytorch教程学习(二)加载数据

Pytorch加载数据初认识

pytorch 中加载数据主要涉及两个类:DatasetDataloader

  • Dataset 提供一种方式去获取数据及其label

  • Dataloader 构建可迭代的数据装载器,为网络提供不同的数据形式

Dataset

Dataset 实现的功能:

  • 获取每个数据及其label
  • 获取数据长度

每个数据集都需要继承 torch.utils.data.Dataset 类,并且重写 __getitem____len__
数据存放在 dataset/train里,分为两个目录 antsbees,也分别是数据的标签,如下图所示:

from PIL import Image
from torch.utils.data import Dataset
import os

class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir  # 根目录
        self.label_dir = label_dir  # 标签目录
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)  # 图片路径列表

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path) # 获取数据 
        label = self.label_dir # 获取label
        return img, label

    def __len__(self):
        return len(self.img_path) # 获取数据集长度

# test
root_dir = 'dataset/train'
ants_label_dir = 'ants'
bees_label_dir = 'bees'
# 生成两个数据集
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset  # 拼接两个数据集

img1, label1 = ants_dataset[0]
img1.show()
print('label1:', label1)
img2, label2 = train_dataset[130]
img2.show()
print('label2:', label2)
posted @   hzyuan  阅读(70)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)

喜欢请打赏

扫描二维码打赏

支付宝打赏

点击右上角即可分享
微信分享提示