Speeding Up PyTorch Dataset Access with LMDB
The code below provides an LMDataset object that speeds up dataset access. It reads every sample of the wrapped dataset ahead of time and stores the pickled samples in an LMDB database. In a test on ImageNet, this sped up image reading by a factor of 4.25.
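The storage scheme is simple: each sample is serialized with pickle and stored under its integer index, encoded as a few little-endian bytes. A minimal standalone sketch of that scheme (the /tmp path and the dummy sample are illustrative assumptions, not part of the library):

import pickle
import lmdb

# Open (or create) a single-file LMDB database.
env = lmdb.Environment("/tmp/demo.mdb", map_size=1 << 30, subdir=False)
with env.begin(write=True) as txn:
    # Key: sample index as 4 little-endian bytes; value: the pickled sample.
    txn.put((0).to_bytes(4, "little"), pickle.dumps(("image-bytes", 0)))
with env.begin() as txn:
    sample = pickle.loads(txn.get((0).to_bytes(4, "little")))
print(sample)  # ('image-bytes', 0)
env.close()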
Usage example:

import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageNet

from LMDataset import LMDataset

if __name__ == "__main__":
    print("Begin Caching")
    dataset_train: Dataset[torch.Tensor] = ImageNet(
        "~/dataset/ImageNet-1k", split="train"
    )
    cached_dataset = LMDataset(dataset_train, "~/dataset/ImageNet-1k", "train")
    for X, y in cached_dataset:
        ...  # train on (X, y)
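To check the speedup figure on your own machine, a rough timing pass over both datasets could be appended to the script above. This is a sketch only (time_epoch is a hypothetical helper; real numbers depend on storage, image decoding cost, and the OS page cache):

import time

def time_epoch(ds) -> float:
    # One full sequential pass over the dataset, returning elapsed seconds.
    start = time.perf_counter()
    for _ in ds:
        pass
    return time.perf_counter() - start

raw_seconds = time_epoch(dataset_train)
cached_seconds = time_epoch(cached_dataset)
print(f"speedup: {raw_seconds / cached_seconds:.2f}x")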
The source code, LMDataset.py, is as follows:
# -*- coding: utf-8 -*-
import lmdb
import pickle
from tqdm import tqdm
import multiprocessing as mp
from os import path, makedirs
from torch.utils.data import Dataset, DataLoader
from typing import TypeVar, Optional, Literal, Sized, Tuple, Callable
from torch.nn import Identity

__all__ = ["LMDataset"]

# Begin Configurations
MAX_SIZE = 1024**4  # 1 TiB maximum map size
NUM_WORKERS = max(8, mp.cpu_count())
MAX_READERS = max(128, 2 * mp.cpu_count())
PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
INDEX_SIZE = 4  # bytes per little-endian integer key
BATCH_SIZE = 256
# End Configurations

dtype = TypeVar("dtype", covariant=True)
transform_type = TypeVar("transform_type", covariant=True)


class _PickleWrapper(Dataset[Tuple[bytes, bytes]]):
    """Present each sample of a dataset as a (key, value) pair of bytes."""

    def __init__(self, dataset: Dataset[dtype]) -> None:
        super().__init__()
        assert isinstance(dataset, Sized)
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[bytes, bytes]:
        value = pickle.dumps(self.dataset[index], protocol=PICKLE_PROTOCOL)
        key = index.to_bytes(INDEX_SIZE, "little")
        return key, value


class LMDataset(Dataset[transform_type]):
    """A dataset backed by an LMDB cache of pickled samples."""

    def __init__(
        self,
        dataset: Optional[Dataset[dtype]],
        root: str,
        split: Literal["train", "val", "test"],
        transform: Callable[[dtype], transform_type] = Identity(),
        desc: Optional[str] = "Caching Dataset",
    ) -> None:
        # dataset may be None when an existing cache is reopened.
        assert dataset is None or isinstance(dataset, Sized)
        self.root = path.join(path.expanduser(root), ".LMDB", f"{split}.mdb")
        self.dataset = dataset
        makedirs(path.dirname(self.root), exist_ok=True)
        # Cache the dataset if it is not already cached
        if not path.isfile(self.root):
            if dataset is not None:
                self._cache_dataset(dataset, desc=desc)
        # Open a read-only LMDB environment
        self.env = lmdb.Environment(
            self.root,
            map_size=MAX_SIZE,
            subdir=False,
            readonly=True,
            max_readers=MAX_READERS,
            max_dbs=4,
            lock=False,
        )
        # Check that the cache is complete
        self.length = self._len_db()
        if dataset is not None:
            assert len(dataset) == self.length, "Dataset Length Mismatch!"
        self.transform = transform

    def _cache_dataset(
        self, dataset: Dataset[dtype], desc: Optional[str] = None
    ) -> None:
        assert isinstance(dataset, Sized)
        makedirs(path.dirname(self.root), exist_ok=True)
        pickle_dataset = _PickleWrapper(dataset)
        # Workers pickle samples in parallel; the main process only writes.
        data_loader = DataLoader(
            pickle_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            prefetch_factor=8,
            persistent_workers=True,
        )
        if desc is not None:
            data_loader = tqdm(data_loader, desc=desc)
        with lmdb.Environment(
            self.root,
            map_size=MAX_SIZE,
            subdir=False,
            readonly=False,
            max_readers=1,
            readahead=False,
            sync=False,
            map_async=True,
            metasync=False,
            meminit=False,
            lock=True,
        ) as write_env:
            with write_env.begin(write=True, buffers=True) as txn:
                for batch in data_loader:
                    # Default collation leaves bytes untouched, so each batch
                    # is a pair (keys, values); zip restores (key, value).
                    for key, value in zip(*batch):
                        txn.put(key, value)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> transform_type:
        key = index.to_bytes(INDEX_SIZE, "little")
        with self.env.begin(write=False, buffers=True) as txn:
            value: bytes = txn.get(key)  # type: ignore
            # Deserialize inside the transaction: with buffers=True the
            # returned buffer is only valid while the transaction is open.
            return self.transform(pickle.loads(value))

    def _len_db(self) -> int:
        with self.env.begin(write=False, buffers=True) as txn:
            stat = txn.stat(db=None)
        return stat["entries"]

    def __getattr__(self, name: str):
        # Delegate unknown attributes (e.g. classes) to the wrapped dataset.
        return getattr(self.dataset, name)
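Because the constructor accepts dataset=None once the cache file exists, a cache can be reopened later without rebuilding the original dataset, and the caching pass and length check are skipped. A minimal sketch (the val split here is an assumption; it presumes a val cache was built earlier):

from LMDataset import LMDataset

# Reopen an existing cache without the source dataset; the transform is
# applied to each sample as it is read back from LMDB.
val_cache = LMDataset(None, "~/dataset/ImageNet-1k", "val")
print(len(val_cache), "cached samples")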
Except for the quoted portions, the copyright of this article belongs to the author. Commercial reproduction is strictly prohibited; non-commercial reproduction must display the author's name and a link to the original article prominently on the page.