PyTorch 创建自己的图片 Dataset

class mydataset(Dataset):
    """Map-style dataset over the files directly inside one directory.

    Each item is one file under ``root``, opened with PIL. Only regular
    files are used (subdirectories are skipped), and the paths are sorted so
    a given index always refers to the same file.

    Args:
        root: Directory containing the images.
        transform: Optional callable applied to the loaded PIL image. When it
            is not given (or is falsy), the image is converted to a
            ``torch.Tensor`` via NumPy instead.
    """

    def __init__(self, root, transform=None):
        super().__init__()
        # os.listdir returns names in arbitrary order; sort them so that
        # index -> file is reproducible across runs and platforms.
        names = sorted(os.listdir(root))
        # Paths are ``root`` joined with each name (absolute only if ``root``
        # is absolute). Skip subdirectories, which Image.open cannot read.
        paths = [os.path.join(root, name) for name in names]
        self.imgs = [path for path in paths if os.path.isfile(path)]
        self.transform = transform

    def __getitem__(self, index):
        """Load the image at ``index``.

        Returns:
            ``self.transform(image)`` when a transform is set; otherwise a
            ``torch.Tensor`` built from NumPy's conversion of the image.
        """
        img_path = self.imgs[index]
        # Open through our own file object and force the pixel data to load
        # while it is open, so the handle is closed here rather than lingering
        # until the image is garbage-collected (torchvision pil_loader style).
        with open(img_path, "rb") as f:
            pil_img = Image.open(f)
            pil_img.load()
        if self.transform:
            return self.transform(pil_img)
        # np.array copies into a writable buffer; np.asarray on a PIL image
        # is read-only, which torch.from_numpy warns about.
        return torch.from_numpy(np.array(pil_img))

    def __len__(self):
        """Return the number of image files found under ``root``."""
        return len(self.imgs)
posted @ 2020-08-03 23:17  林胜联府  阅读(576)  评论(0)  编辑  收藏  举报