PyTorch: implement a custom DataLoader in which every batch keeps the class counts balanced
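The implementation has two parts: a Dataset that scans the image folder and records, for every class, the contiguous [start, end) range its samples occupy in the flat sample list, and a custom BatchSampler whose __iter__ yields one list of indices per batch, drawing floor(batch_size * class_weight[c]) indices from class c. The DataLoader is then driven through its batch_sampler argument. Under distributed training each rank is assigned a disjoint slice of every class range, so different ranks do not sample the same data.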
#!/usr/bin/python3
# _*_ coding:utf-8 _*_
'''
A custom Dataset/BatchSampler that keeps every batch class-balanced:
each batch is sampled according to a user-defined per-class ratio, with
support for multi-worker loading and distributed training.
'''
from check_pkgs import *  # author's helper; presumably wildcard-imports the deps used below (os, osp, glob, math, random, argparse, np, Image, torch, Dataset, Sampler, DataLoader, TF)
import torch.distributed as dist

IMG_EXT = ['.png', '.jpg']


class MyClassBalanceDataset(Dataset):

    def __init__(self, root, transform=None):
        super(MyClassBalanceDataset, self).__init__()
        assert osp.exists(root)
        classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
        classes_to_idx = {name: idx for idx, name in enumerate(classes)}
        idxs = list(classes_to_idx.values())
        class_idx_num = {i: 0 for i in idxs}       # sample count per class
        class_idx_samples = {i: [] for i in idxs}  # [start, end) range of each class inside `samples`
        samples = []
        start, end = 0, 0
        for cls in classes:
            _idx = classes_to_idx[cls]
            # scan only this class's sub-directory so the samples stay grouped by class
            for f in [i for i in glob.glob(f'{root}/{cls}/**/*.*', recursive=True)
                      if osp.splitext(i)[-1] in IMG_EXT]:
                class_idx_num[_idx] += 1
                samples.append((f, _idx))
            end = len(samples)
            class_idx_samples[_idx] = [start, end]
            start = end
        print(f'number of each category: {class_idx_num}')
        print(f'class_idx_samples: {class_idx_samples}')
        self.samples = samples
        self.class_idx_samples = class_idx_samples
        self.transform = transform

    def __len__(self):
        total = len(self.samples)
        # in distributed mode each rank only reports its own shard
        if dist.is_available() and dist.is_initialized():
            num_replicas = dist.get_world_size()
            total = math.ceil(total / num_replicas)
        return total

    def __getitem__(self, index):
        _path, _target = self.samples[index]
        ########### DEBUG: return the raw index so the sampling pattern can be inspected ###########
        return index, _target
        ########### DEBUG ###########
        _sample = Image.open(_path).convert('RGB')
        if self.transform is not None:
            _sample = self.transform(_sample)
        else:
            _sample = np.asarray(_sample)
        _target = torch.tensor(_target)
        return _sample, _target


# A hand-rolled batch sampler that controls exactly how many samples of each class go into a batch.
class MyBatchSampler(Sampler):

    def __init__(self, data_source, batch_size, class_weight):
        super(MyBatchSampler, self).__init__(data_source)
        self.data_source = data_source
        assert isinstance(class_weight, list)
        assert abs(1 - sum(class_weight)) < 1e-5
        self.batch_size = batch_size
        _num = len(class_weight)
        number_in_batch = {i: 0 for i in range(_num)}
        for c in range(_num):
            number_in_batch[c] = math.floor(batch_size * class_weight[c])
        # hand the remainder left over by flooring to a random class so the batch stays full
        _remain_num = batch_size - sum(number_in_batch.values())
        number_in_batch[random.choice(range(_num))] += _remain_num
        self.number_in_batch = number_in_batch
        self.offset_per_class = {i: 0 for i in range(_num)}
        print(f'setting number_in_batch: {number_in_batch}')
        print('my sampler is inited.')
        # in distributed mode, give each rank a disjoint slice of every class range to avoid duplicate sampling
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
            for c, (start, end) in self.data_source.class_idx_samples.items():
                total = end - start
                num_samples = math.ceil(total / num_replicas)
                start_rank = rank * num_samples + start
                end_rank = min(start_rank + num_samples, end)
                # update idx range
                self.data_source.class_idx_samples[c] = [start_rank, end_rank]
            print('using torch distributed mode.')
            print(f'current rank data sample setting: {self.data_source.class_idx_samples}')

    def __iter__(self):
        print('======= start __iter__ =======')
        batch = []
        i = 0
        while i < len(self):
            for c, num in self.number_in_batch.items():
                start, end = self.data_source.class_idx_samples[c]
                for _ in range(num):
                    idx = start + self.offset_per_class[c]
                    if idx >= end:
                        # wrap around once this class's samples are exhausted
                        self.offset_per_class[c] = 0
                        idx = start + self.offset_per_class[c]
                    batch.append(idx)
                    self.offset_per_class[c] += 1
            assert len(batch) == self.batch_size
            # random.shuffle(batch)
            yield batch
            batch = []
            i += 1

    def __len__(self):
        return len(self.data_source) // self.batch_size


# single-GPU version
def test_1():
    root = 'G:/Project/DataSet/flower_photos/flower_photos'
    assert osp.exists(root)
    batch_size = 32
    num_workers = 8
    transform = TF.Compose([
        TF.Resize((5, 5)),
        TF.Grayscale(1),
        TF.ToTensor(),
    ])
    n_class = 5
    class_weight = [0.5, 0.2, 0.1, 0.1, 0.1]
    ds = MyClassBalanceDataset(root, transform)
    _batchSampler = MyBatchSampler(ds, batch_size, class_weight)
    # the batch composition is fully controlled by batch_sampler, so batch_size is left at its default
    data_loader = DataLoader(ds, num_workers=num_workers, pin_memory=True,
                             batch_sampler=_batchSampler)
    print(f'dataloader total: {len(data_loader)}')
    for epoch in range(3):
        for step, (x, y) in enumerate(data_loader):
            # print(step)
            # with the DEBUG return in __getitem__, x holds the raw sample indices of the batch
            print(step, x)
            # print('batch hist:', torch.histc(y.float(), n_class, 0, n_class - 1))


# multi-GPU distributed version
def test_2():
    '''
    CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --master_port=29734 \
        dataset_customized.py --distributed=1
    :return:
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0, help='node rank for distributed parallel')
    parser.add_argument('--distributed', type=int, default=0, help='distributed mode')
    args = parser.parse_args()
    assert torch.distributed.is_nccl_available()
    torch.cuda.set_device(args.local_rank)
    device_num = torch.cuda.device_count()
    distributed_mode = device_num >= 2 and args.distributed
    if distributed_mode:
        dist.init_process_group('nccl', world_size=device_num, rank=args.local_rank)
        rank = dist.get_rank()
        num_rep = dist.get_world_size()
        print(rank, num_rep)
        print('torch distributed is initialized.')
    root = 'G:/Project/DataSet/flower_photos/flower_photos'
    assert osp.exists(root)
    batch_size = 32
    num_workers = 8
    transform = TF.Compose([
        TF.Resize((5, 5)),
        TF.Grayscale(1),
        TF.ToTensor(),
    ])
    n_class = 5
    class_weight = [0.5, 0.2, 0.1, 0.1, 0.1]
    ds = MyClassBalanceDataset(root, transform)
    _batchSampler = MyBatchSampler(ds, batch_size, class_weight)
    data_loader = DataLoader(ds, num_workers=num_workers, pin_memory=True,
                             batch_sampler=_batchSampler)
    print(f'dataloader total: {len(data_loader)}')
    for epoch in range(3):
        for step, (x, y) in enumerate(data_loader):
            print(step, x)
            print('batch hist:', torch.histc(y.float(), n_class, 0, n_class - 1))


if __name__ == '__main__':
    test_1()
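A quick way to confirm that every batch really follows the configured ratio is to iterate the sampler directly and count the class labels behind the yielded indices, without decoding any images. A minimal sketch (the helper name check_batch_balance is made up for illustration; root is the same placeholder path used above):

from collections import Counter

def check_batch_balance(root, batch_size=32, class_weight=(0.5, 0.2, 0.1, 0.1, 0.1)):
    ds = MyClassBalanceDataset(root, transform=None)
    sampler = MyBatchSampler(ds, batch_size, list(class_weight))
    for batch_idx, batch in enumerate(sampler):
        # map every sampled index back to its class label via the dataset's sample list
        labels = [ds.samples[i][1] for i in batch]
        print(batch_idx, Counter(labels))
        if batch_idx >= 2:  # a few batches are enough to see the per-class counts
            break

With batch_size=32 and the weights above, each printed Counter should show 16/6/3/3/3 samples per class, with the one extra index left over by flooring assigned to a randomly chosen class.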