PyTorch学习记录(二):Dataset数据读取

深度学习关于数据层面的处理

  • 数据下载 [官网/谷歌云盘/百度网盘]
  • 数据处理 [将官方下载的原始数据转为易于模型训练的数据]
  • 数据读取 [dataset.py包括Dataset类的具体实现]
Dataset
  |
  |----TrainDataset
  |        |----train  
  |        |----validate
  |
  |----TestDataset
            |----test

TrainDataset要兼容train和validate两个子集,TestDataset用于网络模型测试
TrainDataset和TestDataset两个数据集分开写
(不确定两者是不是只有标签不同,很有可能数据读取的路径也存在差异)
类名首字母大写/实例化的变量小写,并用下划线分开

from torch.utils.data import Dataset
import os
import cv2
import numpy as np
from PIL import Image

class MyDataset(Dataset):
  def __init__(self, data_path, mode)
    super(MyDataset, self).__init__()
    self.data_path = data_path
    self.mode = mode

    assert self.mode in ["train", "validate"]

    self.metas = xxx  #只保存路径,不读图片
    
  def __len__(self):
    return len(self.metas)

  def __getitem__(self, idx):
    # 训练集:返回图片和标签
    pass

if __name__ == "__main__":
  # test code HERE
  data_path = ""
  train_dataset = TrainDataset(data_path, mode="train")

  # opencv visualization
  # TODO

TODO: torch.tensor --> np.ndarray

posted @ 2022-07-15 20:29  达可奈特  阅读(210)  评论(0编辑  收藏  举报