
Speeding Up Dataset Access in PyTorch with LMDB

The code below lets you wrap any dataset in an LMDataset object to speed up access. LMDataset reads every sample of the wrapped dataset ahead of time and stores the pickled samples in an LMDB database, so later reads are served from LMDB instead of the original files. In a test on ImageNet, this sped up image reading by a factor of 4.25.
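At its core, the caching step is a pickle round-trip through LMDB: each sample is serialized with pickle, stored under its index as a little-endian byte key, and deserialized again on read. A minimal sketch of that idea (the path /tmp/demo.mdb and the toy sample are illustrative, not part of LMDataset):

import lmdb
import pickle

# Illustrative only: one sample, pickled and stored under a 4-byte key.
env = lmdb.Environment("/tmp/demo.mdb", subdir=False, map_size=1024**3)
with env.begin(write=True) as txn:
    txn.put((0).to_bytes(4, "little"), pickle.dumps(("sample", 0)))
with env.begin(write=False) as txn:
    sample = pickle.loads(txn.get((0).to_bytes(4, "little")))
print(sample)  # ('sample', 0)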

Usage example:

import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageNet

from LMDataset import LMDataset

if __name__ == "__main__":
    print("Begin Caching")
    dataset_train: Dataset[torch.Tensor] = ImageNet(
        "~/dataset/ImageNet-1k", split="train"
    )
    cached_dataset = LMDataset(dataset_train, "~/dataset/ImageNet-1k", "train")
    for X, y in cached_dataset:
        ...  # training loop elided
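Since LMDataset is itself a map-style Dataset, the cached data can be consumed through a DataLoader like any other dataset, and per-sample processing can be attached via the transform parameter, which is applied after each sample is unpickled. A minimal sketch, assuming the cache already exists (the augmentation choices, batch size, and the lambda are illustrative assumptions, not part of the original post):

from torch.utils.data import DataLoader
from torchvision import transforms
from LMDataset import LMDataset

# With dataset=None, LMDataset opens an existing cache without
# needing the original dataset at all.
to_tensor = transforms.Compose([
    transforms.RandomResizedCrop(224),  # illustrative augmentation
    transforms.ToTensor(),
])
cached = LMDataset(
    None, "~/dataset/ImageNet-1k", "train",
    transform=lambda sample: (to_tensor(sample[0]), sample[1]),
)
# num_workers=0 keeps all LMDB reads in this process; the read-only
# environment held by LMDataset is not safe to use across a fork.
loader = DataLoader(cached, batch_size=256, shuffle=True, num_workers=0)
for X, y in loader:
    ...  # training step elided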

The source code, LMDataset.py, follows:

# -*- coding: utf-8 -*-
import lmdb
import pickle
from tqdm import tqdm
import multiprocessing as mp
from os import path, makedirs
from torch.utils.data import Dataset, DataLoader
from typing import TypeVar, Optional, Literal, Sized, Tuple, Callable
from torch.nn import Identity

__all__ = ["LMDataset"]

# Begin Configurations
MAX_SIZE = 1024**4  # Maximum LMDB map size: 1 TiB
NUM_WORKERS = max(8, mp.cpu_count())
MAX_READERS = max(128, 2 * mp.cpu_count())
PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
INDEX_SIZE = 4  # Keys are 4-byte little-endian indices (up to 2**32 samples)
BATCH_SIZE = 256
# End Configurations

dtype = TypeVar("dtype", covariant=True)
transform_type = TypeVar("transform_type", covariant=True)


class _PickleWrapper(Dataset[bytes]):
    """Exposes a dataset as (key, value) byte pairs ready for LMDB."""

    def __init__(self, dataset: Dataset[dtype]) -> None:
        super().__init__()
        assert isinstance(dataset, Sized)
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[bytes, bytes]:
        value = pickle.dumps(self.dataset[index], protocol=PICKLE_PROTOCOL)
        key = index.to_bytes(INDEX_SIZE, "little")
        return key, value


class LMDataset(Dataset[transform_type]):
    def __init__(
        self,
        dataset: Optional[Dataset[dtype]],
        root: str,
        split: Literal["train", "val", "test"],
        transform: Callable[[dtype], transform_type] = Identity(),
        desc: Optional[str] = "Caching Dataset",
    ) -> None:
        assert dataset is None or isinstance(dataset, Sized)
        self.root = path.join(path.expanduser(root), ".LMDB", f"{split}.mdb")
        self.dataset = dataset
        makedirs(path.dirname(self.root), exist_ok=True)
        # Cache the dataset if it has not been cached yet
        if not path.isfile(self.root) and dataset is not None:
            self._cache_dataset(dataset, desc=desc)
        # Open a read-only LMDB environment
        self.env = lmdb.Environment(
            self.root,
            map_size=MAX_SIZE,
            subdir=False,
            readonly=True,
            max_readers=MAX_READERS,
            max_dbs=4,
            lock=False,
        )
        # Check that the cache is complete
        self.length = self._len_db()
        if dataset is not None:
            assert len(dataset) == self.length, "Dataset Length Mismatch!"
        self.transform = transform

    def _cache_dataset(
        self, dataset: Dataset[dtype], desc: Optional[str] = None
    ) -> None:
        assert isinstance(dataset, Sized)
        makedirs(path.dirname(self.root), exist_ok=True)
        # Pickle samples in parallel worker processes; the default collate
        # function leaves bytes untouched, so each batch is a pair
        # (list of keys, list of values).
        pickle_dataset = _PickleWrapper(dataset)
        data_loader = DataLoader(
            pickle_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            prefetch_factor=8,
            persistent_workers=True,
        )
        if desc is not None:
            data_loader = tqdm(data_loader, desc=desc)
        # Durability is disabled (sync/metasync off) for write speed;
        # the cache can always be rebuilt from the original dataset.
        with lmdb.Environment(
            self.root,
            map_size=MAX_SIZE,
            subdir=False,
            readonly=False,
            max_readers=1,
            readahead=False,
            sync=False,
            map_async=True,
            metasync=False,
            meminit=False,
            lock=True,
        ) as write_env:
            with write_env.begin(write=True, buffers=True) as txn:
                for batch in data_loader:
                    for key, value in zip(*batch):
                        txn.put(key, value)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> transform_type:
        # Raise IndexError for out-of-range indices so that plain
        # iteration over the dataset terminates correctly.
        if not 0 <= index < self.length:
            raise IndexError(index)
        key = index.to_bytes(INDEX_SIZE, "little")
        with self.env.begin(write=False, buffers=True) as txn:
            value: bytes = txn.get(key)  # type: ignore
            # Deserialize inside the transaction: with buffers=True the
            # returned buffer is only valid while the transaction lives.
            return self.transform(pickle.loads(value))

    def _len_db(self) -> int:
        # Entry count of the main database
        return self.env.stat()["entries"]

    def __getattr__(self, name: str):
        # Delegate unknown attributes to the wrapped dataset
        return getattr(self.dataset, name)
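To sanity-check the reported speedup on your own machine, you can time sequential reads from the raw dataset against reads from the cache. This is a rough sketch, not the benchmark behind the 4.25x figure; the sample count and the use of the val split are assumptions:

import time
from torchvision.datasets import ImageNet
from LMDataset import LMDataset

def time_reads(ds, n: int = 1000) -> float:
    # Time n sequential __getitem__ calls.
    start = time.perf_counter()
    for i in range(min(n, len(ds))):
        ds[i]
    return time.perf_counter() - start

if __name__ == "__main__":
    raw = ImageNet("~/dataset/ImageNet-1k", split="val")
    cached = LMDataset(raw, "~/dataset/ImageNet-1k", "val")
    print(f"raw:    {time_reads(raw):.2f} s")
    print(f"cached: {time_reads(cached):.2f} s")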