Speeding Up Dataset Access in PyTorch with LMDB
The code below lets you wrap a dataset in an LMDataset object to speed up access: it reads every sample of the wrapped dataset ahead of time and stores them in an LMDB database. In a test on ImageNet, this sped up image reading by a factor of 4.25.
Usage is as follows:
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageNet

from LMDataset import LMDataset

if __name__ == "__main__":
    print("Begin Caching")
    dataset_train: Dataset[torch.Tensor] = ImageNet(
        "~/dataset/ImageNet-1k", split="train"
    )
    cached_dataset = LMDataset(dataset_train, "~/dataset/ImageNet-1k", "train")
    for X, y in cached_dataset:
        ...  # consume the cached samples
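If the cached samples still need augmentation, the transform argument of LMDataset can be used, and the cached dataset works with an ordinary DataLoader. The snippet below is only a sketch: the augmentation pipeline and batch size are illustrative placeholders, and it assumes each cached record is the (image, target) tuple produced by the original ImageNet dataset.

from torch.utils.data import DataLoader
from torchvision import transforms

# Hypothetical augmentation applied to the image part of each cached record.
augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
])

cached_dataset = LMDataset(
    dataset_train,
    "~/dataset/ImageNet-1k",
    "train",
    transform=lambda sample: (augment(sample[0]), sample[1]),
)
loader = DataLoader(cached_dataset, batch_size=256, shuffle=True)
for X, y in loader:
    ...  # training step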
The source of LMDataset.py is as follows:
# -*- coding: utf-8 -*-
import lmdb
import pickle
from tqdm import tqdm
import multiprocessing as mp
from os import path, makedirs
from torch.utils.data import Dataset, DataLoader
from typing import TypeVar, Optional, Literal, Sized, Tuple, Callable
from torch.nn import Identity
__all__ = ["LMDataset"]
# Begin Configurations
MAX_SIZE = 1024**4 # 1TB
NUM_WORKERS = max(8, mp.cpu_count())  # workers used while building the cache
MAX_READERS = max(128, 2 * mp.cpu_count())  # concurrent readers for the LMDB environment
PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
INDEX_SIZE = 4  # bytes per little-endian key; caps the cache at 2**32 entries
BATCH_SIZE = 256  # batch size of the caching DataLoader
# End Configurations
dtype = TypeVar("dtype", covariant=True)
transform_type = TypeVar("transform_type", covariant=True)

class _PickleWrapper(Dataset[bytes]):
    def __init__(self, dataset: Dataset[dtype]) -> None:
        super().__init__()
        assert isinstance(dataset, Sized)
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[bytes, bytes]:
        value = pickle.dumps(self.dataset[index], protocol=PICKLE_PROTOCOL)
        key = index.to_bytes(INDEX_SIZE, "little")
        return key, value
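
# For illustration: sample 0 of the wrapped dataset would be stored under the key
# b"\x00\x00\x00\x00", with the pickled sample (e.g. an (image, target) tuple) as its value.
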
class LMDataset(Dataset[transform_type]):
    def __init__(
        self,
        dataset: Optional[Dataset[dtype]],
        root: str,
        split: Literal["train", "val", "test"],
        transform: Callable[[dtype], transform_type] = Identity(),
        desc: Optional[str] = "Caching Dataset",
    ) -> None:
        assert dataset is None or isinstance(dataset, Sized)
        self.root = path.join(path.expanduser(root), ".LMDB", f"{split}.mdb")
        self.dataset = dataset
        makedirs(path.dirname(self.root), exist_ok=True)
        # Cache dataset if not already cached
        if not path.isfile(self.root):
            if dataset is not None:
                self._cache_dataset(dataset, desc=desc)
        # Open read-only LMDB environment
        self.env = lmdb.Environment(
            self.root,
            map_size=MAX_SIZE,
            subdir=False,
            readonly=True,
            max_readers=MAX_READERS,
            max_dbs=4,
            lock=False,
        )
        # Check that the cache is complete
        self.length = self._len_db()
        if dataset is not None:
            assert len(dataset) == self.length, "Dataset Length Mismatch!"
        self.transform = transform

    def _cache_dataset(
        self, dataset: Dataset[dtype], desc: Optional[str] = None
    ) -> None:
        assert isinstance(dataset, Sized)
        makedirs(path.dirname(self.root), exist_ok=True)
        pickle_dataset = _PickleWrapper(dataset)
        data_loader = DataLoader(
            pickle_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            prefetch_factor=8,
            persistent_workers=True,
        )
        if desc is not None:
            data_loader = tqdm(data_loader, desc=desc)
        # Durability-related flags (sync, metasync) are disabled to speed up the
        # one-off bulk write; the cache is rebuilt from scratch if the file is missing.
        with lmdb.Environment(
            self.root,
            map_size=MAX_SIZE,
            subdir=False,
            readonly=False,
            max_readers=1,
            readahead=False,
            sync=False,
            map_async=True,
            metasync=False,
            meminit=False,
            lock=True,
        ) as write_env:
            with write_env.begin(write=True, buffers=True) as txn:
                for batch in data_loader:
                    for key, value in zip(*batch):
                        txn.put(key, value)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> transform_type:
        key = index.to_bytes(INDEX_SIZE, "little")
        with self.env.begin(write=False, buffers=True) as txn:
            value: bytes = txn.get(key)  # type: ignore
            return self.transform(pickle.loads(value))

    def _len_db(self) -> int:
        with self.env.begin(write=False, buffers=True) as txn:
            stat = txn.stat(db=None)
            return stat["entries"]

    def __getattr__(self, name: str):
        # Delegate unknown attributes to the wrapped dataset.
        return getattr(self.dataset, name)
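Once the cache file exists, a later run can presumably reopen it without constructing the original ImageNet dataset at all, since the constructor accepts dataset=None (the length check is then skipped). A minimal sketch:

from LMDataset import LMDataset

# Reuse an existing cache; this fails if the LMDB file has not been built yet.
cached_dataset = LMDataset(None, "~/dataset/ImageNet-1k", "train")
print(len(cached_dataset))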
Except for quoted material, the copyright of this article belongs to the author. Commercial reproduction is strictly prohibited; non-commercial reproduction must credit the author and link to the original article in a prominent place on the page.