[PyTorch] 自定义数据集

自己学深度学习的时候想要记下来的,看到代码这样用就这样记了,所以我也不知道这个定义过程规不规范,反正就是能用。其他类型的数据按代码改一下应该没问题,下边是一个图片数据集!
步骤

  • 自定义Dataset实例
    • 定义 __init__ 方法:返回feature和label两个部分的数据;
    • 定义 __getitem__ 方法
    • 定义 __len__ 方法
  • 使用 torch.utils.data.DataLoader 加载数据

示例

  • 目的:载入自定义的目标检测数据集——“banana-detection”;

  • 数据集格式:包含两个文件夹“bananas_train”和“bananans_val”,

    • 文件夹:两个文件夹包含内容都一样,分别为一个“label.csv”文件和一个“images”文件夹(存放的是图片);

    • label.csv:存放的是每张图片的文件名和目标边框信息,字段如下:

      字段名 img_name label xmin ymin xmax ymax
      含义 图片文件名 具体含义我不太清楚,全为0,好像没用到 目标左上角的x轴坐标 目标左上角的y轴坐标 目标右下角的x轴坐标 目标右下角的y轴坐标
# %% import os import pandas as pd import torch import torchvision from d2l import torch as d2l import numpy as np # %% def read_data_bananas(is_train = True): """读取香蕉数据集中的图像和标签""" data_dir = '..\\data\\banana-detection' csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv') csv_data = pd.read_csv(csv_fname) # 指定索引,下边的for循环取到数的index(第一个参数img_name)就为该字段,其他字段的数据就会放到target中 csv_data = csv_data.set_index('img_name') images, targets = [], [] # 将图片读取到内存中(数据集大的时候不能用该方法) for img_name, target in csv_data.iterrows(): # pandas.DataFrame的iterrows()函数是在数据框中的行进行迭代的一个生成器,它返回每行的索引及一个包含行本身的对象。 images.append( # 读取图片,返回一个三维tensor(channels, height, width) torchvision.io.read_image( os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'images', f'{img_name}') # python的print字符串前面加f表示格式化字符串,加f后可以在字符串里面使用用花括号括起来的变量和表达式 ) ) tmp = torchvision.io.read_image( os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'images', f'{img_name}') # python的print字符串前面加f表示格式化字符串,加f后可以在字符串里面使用用花括号括起来的变量和表达式 ) targets.append(list(target)) print(tmp) print(tmp.shape) # images包含所有图片的张量 # images是一个列表,每个元素为一个多维tensor(一张图片的张量),转换为tensor会报错。如果是其他数据的话,如果能转换为tensor的话,不知道这个代码会不会报错,但是target这里也转换成tensor了,也没问题。有人知道的话请告诉我 # targets转换成tensor,每个元素除以256。实际上target是列表也不会报错((torch.tensor(targets).unsqueeze(1) / 256).numpy().tolist()) return images, torch.tensor(targets).unsqueeze(1) / 256 # target扩一个维度,然后每个数字除以256 # %% # 创建一个自定义Dataset实例 class BananasDataset(torch.utils.data.Dataset): """一个用于加载香蕉数据集的自定义数据集""" def __init__(self, is_train): self.features, self.labels = read_data_bananas(is_train) print('read' + str(len(self.features)) + (f' training examples' if is_train else f' validation examples')) def __getitem__(self, idx): return (self.features[idx].float(), self.labels[idx]) def __len__(self): return len(self.features) # %% # 训练集和测试集返回两个数据加载器实例 def load_data_bananas(batch_size): """加载香蕉检测数据集""" train_iter = torch.utils.data.DataLoader(BananasDataset(is_train = True), batch_size, shuffle = True) val_iter = torch.utils.data.DataLoader(BananasDataset(is_train = False), batch_size) return train_iter, val_iter # %% # 读取一个小批量,并打印其中的图像和标签的形状 batch_size, edge_size = 32, 256 train_iter, _ = load_data_bananas(batch_size) batch = next(iter(train_iter)) # 迭代train_iter,取第一轮的结果,也就是一个结果 # 输出第一个batch的维度(batch_size, channels, height, width),第二个是 batch[0].shape, batch[1].shape # %% # 画出图片和目标框(用的是深度学习模块 d2l) # permute()是维度换位,(0,2,3,1)表示把原来的 # 0维 -> 0维 # 2维 -> 1维 # 3维 -> 2维 # 1维 -> 3维 imgs = (batch[0][0 : 10].permute(0, 2, 3, 1)) / 255 axes = d2l.show_images(imgs, 2, 5, scale=2) for ax, label in zip(axes, batch[1][0:10]): d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors = ['w'])

感谢李沐老哥哥赞助的代码和数据集

posted @   小贼的自由  阅读(138)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· Obsidian + DeepSeek:免费 AI 助力你的知识管理,让你的笔记飞起来!
· 分享4款.NET开源、免费、实用的商城系统
· 解决跨域问题的这6种方案,真香!
· 一套基于 Material Design 规范实现的 Blazor 和 Razor 通用组件库
· 5. Nginx 负载均衡配置案例(附有详细截图说明++)
点击右上角即可分享
微信分享提示