pytorch 加载对应数据
Background
有两个相同的图像数据集,仅图像的分辨率不同(如一批为128一批为64),数据集中的图像一一对应,
为两个数据集分别设置两个Dataloader,现在想要在加载batch时得到一一对应且乱序的图像。
如果仅在Dataloader中设置shuffle=True,得到的两个batch并不一一对应
如:A:[[2,4,0], [1,3]] B:[[1,2,0], [4,3]]
如果仅在Dataloader中设置shuffle=False,可实现一一对应但无法实现每个epoch的乱序
如:A:[[0,1,2], [3,4]] B:[[0,1,2], [3,4]]
Method
通过重载RandomSampler来为两个Dataloader提供相同的indices
由于两个Dataloader共用一个sampler,请求indices时用的是同一个__iter__()
因此考虑用计数器来控制新一组indices的生成
"""Shared-sampler trick: keep two DataLoaders over parallel datasets
(same images at different resolutions) in identical shuffled order.

Both loaders share a single sampler instance; the sampler draws a fresh
random permutation only on every second ``__iter__()`` call, so within
one epoch both loaders consume the same index order, while the order
still changes from epoch to epoch.
"""
import torch
from torch.utils.data import DataLoader


def r(n):
    """Yield an endless stream of random permutations of ``range(n)``.

    NOTE(review): the seed is pinned to 0, so every run reproduces the
    same sequence of epoch orderings — drop ``manual_seed`` if真 random
    shuffling across runs is wanted.
    """
    torch.manual_seed(0)
    while True:
        yield torch.randperm(n).tolist()


class mySampler(torch.utils.data.sampler.RandomSampler):
    """RandomSampler variant that refreshes its permutation once per
    PAIR of ``__iter__()`` calls.

    Two DataLoaders sharing one instance therefore receive the same
    shuffled indices within an epoch (call 1 and call 2 see the same
    list; call 3 starts a new permutation).
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        self.r = r(len(self.data_source))  # infinite permutation generator
        self.count = 0                     # counts __iter__() calls
        self.rand_list = []                # indices for the current epoch

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    def __iter__(self):
        # Two parallel datasets share this sampler, so a new permutation
        # is drawn only on every second call. The increment must be
        # unconditional: if it sat inside the ``if`` the counter would
        # stick at 1 and the indices would never refresh again.
        if self.count % 2 == 0:
            self.rand_list = next(self.r)
        self.count += 1
        return iter(self.rand_list)


if __name__ == "__main__":
    # Demo: two parallel image folders (same images, 16px vs 32px).
    # torchvision is imported here so the module stays importable without it.
    from torchvision import datasets, transforms

    dataset16 = datasets.ImageFolder('./test/16/', transform=transforms.ToTensor())
    dataset32 = datasets.ImageFolder('./test/32/', transform=transforms.ToTensor())

    # The parallel datasets are the same size, so one sampler serves both.
    ms = mySampler(dataset16)

    d16 = DataLoader(dataset16, batch_size=2, sampler=ms, drop_last=False)
    d32 = DataLoader(dataset32, batch_size=2, sampler=ms, drop_last=False)