pytorch 加载对应数据

Background

有两个相同的图像数据集,仅图像的分辨率不同(如一批为128一批为64),数据集中的图像一一对应,

为两个数据集分别设置两个Dataloader,现在想要在加载batch时得到一一对应且乱序的图像。

如果仅在Dataloader中设置shuffle=True,得到的两个batch并不一一对应

如:A:[[2,4,0], [1,3]]    B:[[1,2,0], [4,3]]

如果仅在Dataloader中设置shuffle=Flase,可实现一一对应但无法实现每个epoch的乱序

如:A:[[0,1,2], [3,4]]    B:[[0,1,2], [3,4]]

 

Method

通过重载RandomSampler来为两个Dataloader提供相同的indice

由于两个Dataloader共用一个sampler,请求indice时用的是同一个__iter__()

因此考虑用计数器来控制新一组indice的生成

 1 def r(n):
 2 
 3     torch.manual_seed(0)
 4 
 5     while True:
 6 
 7         yield torch.randperm(n).tolist()
 8 
 9  
10 
11 class mySampler(torch.utils.data.sampler.RandomSampler):
12 
13     def __init__(self, data_source, replacement=False, num_samples=None):
14 
15         self.data_source = data_source
16 
17         self.replacement = replacement
18 
19         self._num_samples = num_samples
20 
21  
22 
23         self.r = r(len(self.data_source))  # 提供标号
24 
25         self.count = 0
26 
27         self.rand_list = []  # 保存标号
28 
29  
30 
31         if not isinstance(self.replacement, bool):
32 
33             raise ValueError("replacement should be a boolean value, but got "
34 
35                              "replacement={}".format(self.replacement))
36 
37  
38 
39         if self._num_samples is not None and not replacement:
40 
41             raise ValueError("With replacement=False, num_samples should not be specified, "
42 
43                              "since a random permute will be performed.")
44 
45  
46 
47         if not isinstance(self.num_samples, int) or self.num_samples <= 0:
48 
49             raise ValueError("num_samples should be a positive integer "
50 
51                              "value, but got num_samples={}".format(self.num_samples))
52 
53  
54 
55     def __iter__(self):
56 
57         # 共两个平行数据集,因此
58 
59         # 每两次iter更新一次indice
60 
61         if self.count % 2 == 0:
62 
63             self.rand_list = self.r.__next__()
64 
65         self.count += 1
66 
67  
68 
69         return iter(self.rand_list)
70 
71  
72 
73 dataset16 = datasets.ImageFolder('./test/16/', transform=transforms.ToTensor())
74 
75 dataset32 = datasets.ImageFolder('./test/32/', transform=transforms.ToTensor())
76 
77  
78 
79 # 平行数据集大小相同,因此可共用sampler
80 
81 ms = mySampler(dataset16)
82 
83  
84 
85 d16 = DataLoader(dataset16, batch_size=2, sampler=ms, drop_last=False)
86 
87 d32 = DataLoader(dataset32, batch_size=2, sampler=ms, drop_last=False)

 

 

posted @ 2020-09-25 18:06  Junzhao  阅读(280)  评论(0编辑  收藏  举报