# 龙良曲 PyTorch 学习笔记：加载宝可梦数据集
import torch
import os
import glob
import random
import csv

from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from PIL import Image


class Pokemon(Dataset):
    """Pokemon image-classification dataset backed by a cached CSV index.

    Expected layout: ``root/<class_name>/<image>.{png,jpg,jpeg}``. Every
    sub-directory of ``root`` is one class; labels are assigned in sorted
    directory-name order so they are stable across runs.

    Args:
        root: root directory holding one sub-directory per class.
        resize: side length (pixels) of the square images returned.
        mode: 'train' (first 60% of the shuffled index), 'val' (next 20%)
            or 'test' (last 20%). Any other value keeps all samples.
            Random augmentation is applied only in 'train' mode.
    """

    # ImageNet channel statistics shared by Normalize and denormalize.
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize
        self.mode = mode

        # Class name -> integer label. os.listdir order is arbitrary, so
        # sort it to make the mapping deterministic.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        # Parallel lists: image path and its integer label.
        self.images, self.labels = self.load_csv('images.csv')

        # The cached index is already shuffled; slice it with one shared
        # slice object so images and labels stay aligned.
        n = len(self.images)
        if mode == 'train':    # 60%
            split = slice(0, int(0.6 * n))
        elif mode == 'val':    # 20%
            split = slice(int(0.6 * n), int(0.8 * n))
        elif mode == 'test':   # remaining 20% (80% -> 100%)
            split = slice(int(0.8 * n), None)
        else:                  # unknown mode: keep everything
            split = slice(None)
        self.images = self.images[split]
        self.labels = self.labels[split]

    def load_csv(self, filename):
        """Return ``(image_paths, labels)`` read from ``root/filename``.

        If the file does not exist yet, the class sub-directories are
        scanned for .png/.jpg/.jpeg files, the list is shuffled once and
        written as ``path,label`` rows. Later calls reuse that file, which
        keeps the train/val/test split fixed across runs.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            samples = []
            for name, label in self.name2label.items():
                # e.g. 'pokemon/mewtwo/00001.png' -> label of 'mewtwo'.
                # The label is known here, so no need to re-parse the path.
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    for img in glob.glob(os.path.join(self.root, name, pattern)):
                        samples.append((img, label))

            random.shuffle(samples)
            with open(csv_path, mode='w', newline='') as f:
                csv.writer(f).writerows(samples)
            print('writen into csv file:', filename)

        # Read back the cached index.
        images, labels = [], []
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:  # tolerate blank lines
                    continue
                img, label = row  # 'pokemon/bulbasaur/00000000.png', '0'
                images.append(img)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        return len(self.images)

    def denormalize(self, x_hat):
        """Invert Normalize: ``x = x_hat * std + mean``.

        mean/std are reshaped to [3, 1, 1], so they broadcast over a
        tensor whose last three dims are [c, h, w] (e.g. [c, h, w] or
        [b, c, h, w]). They are created on ``x_hat``'s device so GPU
        inputs work too.
        """
        mean = torch.tensor(self.MEAN, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(self.STD, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    @staticmethod
    def _load_rgb(path):
        """Open an image file, convert it to RGB and close the file handle."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def _build_transform(self):
        """Build the path -> normalized tensor pipeline for this split."""
        big = int(self.resize * 1.25)
        steps = [
            self._load_rgb,                    # string path --> image data
            transforms.Resize((big, big)),
        ]
        # Random augmentation only for training; val/test stay deterministic.
        if self.mode == 'train':
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ]
        return transforms.Compose(steps)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        image: float tensor [3, resize, resize], normalized with MEAN/STD.
        label: 0-d integer tensor.
        """
        img, label = self.images[idx], self.labels[idx]
        tf = self._build_transform()
        return tf(img), torch.tensor(label)