基于HDF5的高维数据有效读写

代码:

https://github.com/JiJingYu/concat_dataset

 

Demo特点

该代码基于自行编写的H5Imageset类与pytorch中的ConcatDataset接口,主要有以下特点:

  1. 有效利用了hdf5读取数据时直接与硬盘交互,无需载入整个数据集到内存中的优势,降低内存开销。

  2. 重载了python内置的__getitem__()方法,使得数据动态生成,无需独立保存数据,降低磁盘开销。

  3. 利用pytorch内置的ConcatDataset类,高效合并多组H5Imageset数据集,统一调用,统一索引。

H5Imageset

重载是个好东西,可以多用。该类重载了__getitem__()方法,并维护一个索引列表idx_list, 调用__getitem__方法时会自动从idx_list中读取预设的下标,从图像中读取下标对应的区域。

 1 class H5Imageset(Dataset):
 2     """Dataset wrapping data and target tensors.
 3     Each sample will be retrieved by indexing both tensors along the first
 4     dimension.
 5     Arguments:
 6         data_tensor (Tensor): H,W,C.
 7         target_tensor (Tensor): H,W,C.
 8     """
 9 
10     def __init__(self, data_tensor, target_tensor, patch_size, stride):
11         assert data_tensor.shape[0] == target_tensor.shape[0]
12         self.data_tensor = data_tensor
13         self.target_tensor = target_tensor
14         ## idx_list的生成方法为该类的核心,可自行根据需要替换为其他函数。
15         self.idx_list = self.get_idx_list(patch_size, stride)
16 
17     def get_idx_list(self, patch_size, stride):
18         """
19         idx_list的生成方法为该类的核心,可自行根据需要替换为其他函数。
20         """
21         H, W, _ = self.data_tensor.shape
22         idx_list = []
23         for h in np.arange(start=0, stop=H-patch_size, step=stride):
24             for w in np.arange(start=0, stop=W - patch_size, step=stride):
25                 idx_list.append((h, w, patch_size, patch_size))
26         return idx_list
27 
28     def __getitem__(self, index):
29         # print(index)
30         h, w, patch_size, patch_size = self.idx_list[index]
31         return self.data_tensor[h:h+patch_size, w:w+patch_size], \
32                self.target_tensor[h:h+patch_size, w:w+patch_size]
33 
34     def __len__(self):
35      return len(self.idx_list)

ConcatDataset

pytorch提供的ConcatDataset类,做了很好的示范

 1 class ConcatDataset(Dataset):
 2     """
 3     Dataset to concatenate multiple datasets.
 4     Purpose: useful to assemble different existing datasets, possibly
 5     large-scale datasets as the concatenation operation is done in an
 6     on-the-fly manner.
 7     Arguments:
 8         datasets (iterable): List of datasets to be concatenated
 9     """
10 
11     @staticmethod
12     def cumsum(sequence):
13         r, s = [], 0
14         for e in sequence:
15             l = len(e)
16             r.append(l + s)
17             s += l
18         return r
19 
20     def __init__(self, datasets):
21         super(ConcatDataset, self).__init__()
22         assert len(datasets) > 0, 'datasets should not be an empty iterable'
23         self.datasets = list(datasets)
24         self.cummulative_sizes = self.cumsum(self.datasets)
25 
26     def __len__(self):
27         return self.cummulative_sizes[-1]
28 
29     def __getitem__(self, idx):
30         dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx)
31         if dataset_idx == 0:
32             sample_idx = idx
33         else:
34             sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
35         return self.datasets[dataset_idx][sample_idx]

 

使用场景

该代码通常用于高维数据,如光场图像(4维),高光谱图像(3维),该类数据有数据量大,处理速度有限等特点。 传统的直接处理数据集、直接生成数据集、保存数据集的方法会使得数据量暴涨。例如ICVL数据集原始数据约30GB, patch=64, stride=16分割之后,数据集会暴涨至500GB,给磁盘、IO和内存带来巨大压力。 用该代码可在不增加磁盘占用,不损失数据集IO时间的前提下,对数据集做有效的预处理,如按patch分割等, 同时可以大幅度降低内存占用。

其他

由于空间有限,此处以少量RGB图像为例,简单展示demo用途。

made by 法师漂流

posted @ 2017-11-18 10:00  法师漂流  阅读(3306)  评论(0编辑  收藏  举报