PyTorch 创建自己的图片 Dataset
class mydataset(Dataset):
    """Dataset that serves every regular file found directly inside a folder.

    Each sample is one image. When a ``transform`` is supplied, its result is
    returned as-is; otherwise the decoded image is converted to a tensor via
    NumPy.

    Args:
        root: Directory to scan. Only its direct entries are considered.
        transform: Optional callable applied to the opened PIL image.

    Attributes:
        imgs: ``root`` joined with each file name, sorted by file name. The
            paths are absolute only if ``root`` itself is absolute.
        transform: The callable passed to the constructor, or ``None``.
    """

    def __init__(self, root, transform=None):
        super().__init__()
        # os.listdir() returns entries in arbitrary order; sort so that a
        # given index always refers to the same image across runs.
        candidates = (os.path.join(root, name) for name in sorted(os.listdir(root)))
        # Skip sub-directories: they cannot be opened as images and would
        # only fail later inside __getitem__.
        self.imgs = [path for path in candidates if os.path.isfile(path)]
        self.transform = transform

    def __getitem__(self, index):
        """Return the sample stored at position ``index`` of ``self.imgs``.

        Returns:
            The output of ``self.transform`` when one was given, otherwise a
            tensor built from the image's pixel array.
        """
        img_path = self.imgs[index]
        # Image.open() is lazy and keeps the file open; the context manager
        # releases the handle once the pixel data has been consumed.
        with Image.open(img_path) as pil_img:
            # Force decoding while the file is still open, so a transform
            # that hands the image back unchanged stays usable afterwards.
            pil_img.load()
            if self.transform is not None:
                return self.transform(pil_img)
            # np.array() copies into a writable buffer; torch.from_numpy()
            # warns when it is given a read-only array.
            return torch.from_numpy(np.array(pil_img))

    def __len__(self):
        """Return the number of image paths collected from ``root``."""
        return len(self.imgs)