PyTorch学习记录(二):Dataset数据读取
深度学习关于数据层面的处理
- 数据下载 [官网/谷歌云盘/百度网盘]
- 数据处理 [将官方下载的原始数据转为易于模型训练的数据]
- 数据读取 [dataset.py包括Dataset类的具体实现]
Dataset
|
|----TrainDataset
| |----train
| |----validate
|
|----TestDataset
|----test
TrainDataset要兼容train和validate两个子集,TestDataset用于网络模型测试
TrainDataset和TestDataset两个数据集分开写
(不确定两者是不是只有标签不同,很有可能数据读取的路径也存在差异)
类名首字母大写/实例化的变量小写,并用下划线分开
from torch.utils.data import Dataset
import os
import cv2
import numpy as np
from PIL import Image
class MyDataset(Dataset):
def __init__(self, data_path, mode)
super(MyDataset, self).__init__()
self.data_path = data_path
self.mode = mode
assert self.mode in ["train", "validate"]
self.metas = xxx #只保存路径,不读图片
def __len__(self):
return len(self.metas)
def __getitem__(self, idx):
# 训练集:返回图片和标签
pass
if __name__ == "__main__":
# test code HERE
data_path = ""
train_dataset = TrainDataset(data_path, mode="train")
# opencv visualization
# TODO
TODO: torch.tensor --> np.ndarray