# 利用torch.utils.data.Dataset自定义数据加载类 (Defining a custom data-loading class with torch.utils.data.Dataset)

import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np

import torchvision.transforms as T

 

# Default preprocessing pipeline, applied in order to each PIL image.
# Per the torchvision transforms documentation:
transforms = T.Compose(
    [
        # Resize with an int: the shorter side becomes 224 px, aspect ratio kept.
        T.Resize(224),
        # Keep only the central 224x224 patch.
        T.CenterCrop(224),
        # PIL image -> channels-first float tensor (uint8 pixels scaled to [0, 1]).
        T.ToTensor(),
        # (x - 0.5) / 0.5 on each of the 3 channels maps [0, 1] onto [-1, 1].
        T.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

 

# A subclass of torch.utils.data.Dataset must override __getitem__() and __len__().
class CatDog(data.Dataset):
    """Cat-vs-dog image dataset read from one flat directory.

    Every entry of ``root`` is treated as one sample. The label is inferred
    from the file name: names containing ``"dog"`` get label 1, all others
    get label 0 (presumably cats -- TODO confirm the directory only holds
    cat/dog images).

    Args:
        root: Directory containing the image files.
        transforms: Optional callable applied to each loaded PIL image, e.g.
            a ``torchvision.transforms.Compose`` pipeline. When ``None``,
            ``__getitem__`` returns the RGB PIL image itself.
    """

    def __init__(self, root, transforms=None):
        # Temporary variables (like this listing) don't need a ``self.`` prefix.
        # Sorted because os.listdir order is arbitrary; a fixed order keeps
        # the index -> file mapping reproducible across runs and machines.
        names = sorted(os.listdir(root))
        self.imgs = [os.path.join(root, name) for name in names]

        self.transforms = transforms

    @staticmethod
    def _label_for(img_path):
        """Return 1 if the file name of *img_path* contains "dog", else 0."""
        # Inspect only the file name, so a "dog" in a parent directory name
        # cannot mislabel every sample.
        return 1 if "dog" in os.path.basename(img_path) else 0

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at position *index*."""
        img_path = self.imgs[index]
        label = self._label_for(img_path)

        # Open inside ``with`` so the file handle is released immediately.
        # convert("RGB") returns a new, fully loaded image that outlives the
        # closed file and guarantees 3 channels even for grayscale/palette files.
        with Image.open(img_path) as img:
            sample = img.convert("RGB")

        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample, label

    def __len__(self):
        """Return the number of files found in ``root``."""
        return len(self.imgs)

# posted @ 2020-02-17 10:06  6+0  阅读(2654)  评论(1)  编辑  收藏  举报