基于HDF5的高维数据有效读写
代码:
https://github.com/JiJingYu/concat_dataset
Demo特点
该代码基于自行编写的H5Imageset类与pytorch中的ConcatDataset接口,主要有以下特点:
-
有效利用了hdf5读取数据时直接与硬盘交互,无需载入整个数据集到内存中的优势,降低内存开销。
-
重载了python内置的__getitem__()方法,使得数据动态生成,无需独立保存数据,降低磁盘开销。
-
利用pytorch内置的ConcatDataset类,高效合并多组H5Imageset数据集,统一调用,统一索引。
H5Imageset
重载是个好东西,可以多用。该类重载了__getitem__()方法,并维护一个索引列表idx_list, 调用__getitem__方法时会自动从idx_list中读取预设的下标,从图像中读取下标对应的区域。
1 class H5Imageset(Dataset): 2 """Dataset wrapping data and target tensors. 3 Each sample will be retrieved by indexing both tensors along the first 4 dimension. 5 Arguments: 6 data_tensor (Tensor): H,W,C. 7 target_tensor (Tensor): H,W,C. 8 """ 9 10 def __init__(self, data_tensor, target_tensor, patch_size, stride): 11 assert data_tensor.shape[0] == target_tensor.shape[0] 12 self.data_tensor = data_tensor 13 self.target_tensor = target_tensor 14 ## idx_list的生成方法为该类的核心,可自行根据需要替换为其他函数。 15 self.idx_list = self.get_idx_list(patch_size, stride) 16 17 def get_idx_list(self, patch_size, stride): 18 """ 19 idx_list的生成方法为该类的核心,可自行根据需要替换为其他函数。 20 """ 21 H, W, _ = self.data_tensor.shape 22 idx_list = [] 23 for h in np.arange(start=0, stop=H-patch_size, step=stride): 24 for w in np.arange(start=0, stop=W - patch_size, step=stride): 25 idx_list.append((h, w, patch_size, patch_size)) 26 return idx_list 27 28 def __getitem__(self, index): 29 # print(index) 30 h, w, patch_size, patch_size = self.idx_list[index] 31 return self.data_tensor[h:h+patch_size, w:w+patch_size], \ 32 self.target_tensor[h:h+patch_size, w:w+patch_size] 33 34 def __len__(self): 35 return len(self.idx_list)
ConcatDataset
pytorch提供的ConcatDataset类,做了很好的示范
1 class ConcatDataset(Dataset): 2 """ 3 Dataset to concatenate multiple datasets. 4 Purpose: useful to assemble different existing datasets, possibly 5 large-scale datasets as the concatenation operation is done in an 6 on-the-fly manner. 7 Arguments: 8 datasets (iterable): List of datasets to be concatenated 9 """ 10 11 @staticmethod 12 def cumsum(sequence): 13 r, s = [], 0 14 for e in sequence: 15 l = len(e) 16 r.append(l + s) 17 s += l 18 return r 19 20 def __init__(self, datasets): 21 super(ConcatDataset, self).__init__() 22 assert len(datasets) > 0, 'datasets should not be an empty iterable' 23 self.datasets = list(datasets) 24 self.cummulative_sizes = self.cumsum(self.datasets) 25 26 def __len__(self): 27 return self.cummulative_sizes[-1] 28 29 def __getitem__(self, idx): 30 dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx) 31 if dataset_idx == 0: 32 sample_idx = idx 33 else: 34 sample_idx = idx - self.cummulative_sizes[dataset_idx - 1] 35 return self.datasets[dataset_idx][sample_idx]
使用场景
该代码通常用于高维数据,如光场图像(4维),高光谱图像(3维),该类数据有数据量大,处理速度有限等特点。 传统的直接处理数据集、直接生成数据集、保存数据集的方法会使得数据量暴涨。例如ICVL数据集原始数据约30GB, patch=64, stride=16分割之后,数据集会暴涨至500GB,给磁盘、IO和内存带来巨大压力。 用该代码可在不增加磁盘占用,不损失数据集IO时间的前提下,对数据集做有效的预处理,如按patch分割等, 同时可以大幅度降低内存占用。
其他
由于空间有限,此处以少量RGB图像为例,简单展示demo用途。
made by 法师漂流