【PyTorch】PyTorch使用LMDB数据库加速文件读取

PyTorch使用LMDB数据库加速文件读取

原始文档:https://www.yuque.com/lart/ugkv9f/hbnym1

对于数据库的了解较少,文章中大部分的介绍主要来自于各种博客和LMDB的文档,但是文档中的介绍,默认是已经了解了数据库的许多知识,这导致目前只能囫囵吞枣,待之后仔细了解后再重新补充内容。

背景介绍

文章https://blog.csdn.net/jyl1999xxxx/article/details/53942824中介绍了使用LMDB的原因:

Caffe使用LMDB来存放训练/测试用的数据集,以及使用网络提取出的feature(为了方便,以下还是统称数据集)。数据集的结构很简单,就是大量的矩阵/向量数据平铺开来。数据之间没有什么关联,数据内没有复杂的对象结构,就是向量和矩阵。既然数据并不复杂,Caffe就选择了LMDB这个简单的数据库来存放数据。

LMDB的全称是Lightning Memory-Mapped Database,闪电般的内存映射数据库。它文件结构简单,一个文件夹,里面一个数据文件,一个锁文件。数据随意复制,随意传输。它的访问简单,不需要运行单独的数据库管理进程,只要在访问数据的代码里引用LMDB库,访问时给文件路径即可。

图像数据集归根究底从图像文件而来。引入数据库存放数据集,是为了减少IO开销。读取大量小文件的开销是非常大的,尤其是在机械硬盘上。LMDB的整个数据库放在一个文件里,避免了文件系统寻址的开销。LMDB使用内存映射的方式访问文件,使得文件内寻址的开销非常小,使用指针运算就能实现。数据库单文件还能减少数据集复制/传输过程的开销。一个几万,几十万文件的数据集,不管是直接复制,还是打包再解包,过程都无比漫长而痛苦。LMDB数据库只有一个文件,你的介质有多块,就能复制多快,不会因为文件多而慢如蜗牛。

在文章http://shuokay.com/2018/05/14/python-lmdb/中类似提到:

为什么要把图像数据转换成大的二进制文件?
简单来说,是因为读写小文件的速度太慢。那么,不禁要问,图像数据也是二进制文件,单个大的二进制文件例如 LMDB 文件也是二进制文件,为什么单个图像读写速度就慢了呢?这里分两种情况解释。

  1. 机械硬盘的情况:机械硬盘的每次读写启动时间比较长,例如磁头的寻道时间占比很高,因此,如果单个小文件读写,尤其是随机读写单个小文件的时候,这个寻道时间占比就会很高,最后导致大量读写小文件的时候时间会很浪费;
  2. NFS 的情况:在 NFS 的场景下,系统的一次读写首先要进行上百次的网络通讯,并且这个通讯次数和文件的大小无关。因此,如果是读写小文件,这个网络通讯时间占据了整个读写时间的大部分。
    固态硬盘的情况下应该也会有一些类似的开销,目前没有研究过。

总而言之,使用LMDB可以为我们的数据读取进行加速。

具体操作

LMDB主要类

pip install imdb

lmdb.Environment

lmdb.open() 这个方法实际上是 class lmdb.Environment(path, map_size=10485760, subdir=True, readonly=False, metasync=True, sync=True, map_async=False, mode=493, create=True, readahead=True, writemap=False, meminit=True, max_readers=126, max_dbs=0, max_spare_txns=1, lock=True) 的一个别名(shortcut),二者是等价的。关于这个类:https://lmdb.readthedocs.io/en/release/#environment-class

这是数据库环境的结构。 一个环境可能包含多个数据库,所有数据库都驻留在同一共享内存映射和基础磁盘文件中。要写入环境,必须创建事务(Transaction)。 允许同时进行一次写入事务,但是即使存在写入事务,读取事务的数量也没有限制。

几个重要的实例方法:

  • begin(db=None, parent=None, write=False, buffers=False): 可以调用事务类 lmdb.Transaction
  • open_db(key=None, txn=None, reverse_key=False, dupsort=False, create=True, integerkey=False, integerdup=False, dupfixed=False): 打开一个数据库,返回一个不透明的句柄。重复Environment.open_db() 调用相同的名称将返回相同的句柄。作为一个特殊情况,主数据库总是开放的。命名数据库是通过在主数据库中存储一个特殊的描述符来实现的。环境中的所有数据库共享相同的文件。因为描述符存在于主数据库中,所以如果已经存在与数据库名称匹配的 key ,创建命名数据库的尝试将失败。此外,查找和枚举可以看到key 。如果主数据库keyspace与命名数据库使用的名称冲突,则将主数据库的内容移动到另一个命名数据库。
>>> env = lmdb.open('/tmp/test', max_dbs=2)
>>> with env.begin(write=True) as txn
...     txn.put('somename', 'somedata')
>>> # Error: database cannot share name of existing key!
>>> subdb = env.open_db('somename')

lmdb.Transaction

这和事务对象有关。

class lmdb.Transaction(env, db=None, parent=None, write=False, buffers=False)

关于这个类的参数:https://lmdb.readthedocs.io/en/release/#transaction-class

所有操作都需要事务句柄,事务可以是只读或读写的。写事务可能不会跨越线程。事务对象实现了上下文管理器协议,因此即使面对未处理的异常,也可以可靠地释放事务:

# Transaction aborts correctly:
with env.begin(write=True) as txn:
   crash()
# Transaction commits automatically:
with env.begin(write=True) as txn:
   txn.put('a', 'b')

这个类的实例包含着很多有用的操作方法。

  • abort(): 中止挂起的事务。重复调用 abort() 在之前成功的 commit()abort() 后或者在相关环境关闭后是没有效果的。
  • commit(): 提交挂起的事务。
  • cursor(db=None): Shortcut for lmdb.Cursor(db, self)
  • delete(key, value='', db=None): Delete a key from the database.
    • key: The key to delete.
    • value:如果数据库是以 dupsort = True 打开的,并且 value 不是空的 bytestring ,则删除仅与此 (key, value) 对匹配的元素,否则该 key 的所有值都将被删除。
    • Returns True if at least one key was deleted.
  • drop(db, delete=True): 删除命名数据库中的所有键,并可选地删除命名数据库本身。删除命名数据库会导致其不可用,并使现有cursors无效。
  • get(key, default=None, db=None): 获取匹配键的第一个值,如果键不存在,则返回默认值。cursor必须用于获取 dupsort = True 数据库中的 key 的所有值。
  • id(): 返回事务的ID。这将返回与此事务相关联的标识符。对于只读事务,这对应于正在读取的快照; 并发读取器通常具有相同的事务ID。
  • pop(key, db=None): 使用临时cursor调用 Cursor.pop()
    • db: 要操作的命名数据库。如果未指定,默认为事务构造函数被给定的数据库。
  • put(key, value, dupdata=True, overwrite=True, append=False, db=None): 存储一条记录(record),如果记录被写入,则返回 True ,否则返回 False ,以指示key已经存在并且 overwrite = False 。成功后,cursor位于新记录上。
    • key: Bytestring key to store.
    • value: Bytestring value to store.
    • dupdata: 如果 True ,并且数据库是用 dupsort = True 打开的,如果给定 key 已经存在,则添加键值对作为副本。否则覆盖任何现有匹配的 key
    • overwrite: If False , do not overwrite any existing matching key.
    • append: 如果为 True ,则将对附加到数据库末尾,而不首先比较其顺序。附加不大于现有最高 keykey 将导致损坏。
    • db: 要操作的命名数据库。如果未指定,默认为事务构造函数被给定的数据库。
  • replace(key, value, db=None): 使用临时cursor调用 Cursor.replace() .
  • db: Named database to operate on. If unspecified, defaults to the database given to the Transaction constructor.
  • stat(db): Return statistics like Environment.stat() , except for a single DBI. db must be a database handle returned by open_db() .

Imdb.Cursor

class lmdb.Cursor(db, txn) 是用于在数据库中导航(navigate)的结构。

  • db: Database to navigate.
  • txn: Transaction to navigate.

As a convenience, Transaction.cursor() can be used to quickly return a cursor:

>>> env = lmdb.open('/tmp/foo')
>>> child_db = env.open_db('child_db')
>>> with env.begin() as txn:
...     cursor = txn.cursor()           # Cursor on main database.
...     cursor2 = txn.cursor(child_db)  # Cursor on child database.

游标以未定位的状态开始。如果在这种状态下使用 iternext()iterprev() ,那么迭代将分别从开始处和结束处开始。迭代器直接使用游标定位,这意味着在同一游标上存在多个迭代器时会产生奇怪的行为

从Python绑定的角度来看,一旦任何扫描或查找方法(例如 next()prev_nodup()set_range() )返回 False 或引发异常,游标将返回未定位状态。这主要是为了确保在面对任何错误条件时语义的安全性和一致性。
当游标返回到未定位的状态时,它的 key()value() 返回空字符串,表示没有活动的位置,尽管在内部,LMDB游标可能仍然有一个有效的位置。
这可能会导致在迭代 dupsort=True 数据库的 key 时出现一些令人吃惊的行为,因为 iternext_dup() 等方法将导致游标显示为未定位,尽管它返回 False 只是为了表明当前键没有更多的值。在这种情况下,简单地调用 next() 将导致在下一个可用键处继续迭代。
This behaviour may change in future.

Iterator methods such as iternext() and iterprev() accept keys and values arguments. If both are True , then the value of item() is yielded on each iteration. If only keys is True , key() is yielded, otherwise only value() is yielded.

在迭代之前,游标可能定位在数据库中的任何位置

>>> with env.begin() as txn:
...     cursor = txn.cursor()
...     if not cursor.set_range('5'): # Position at first key >= '5'.
...         print('Not found!')
...     else:
...         for key, value in cursor: # Iterate from first key >= '5'.
...             print((key, value))

不需要迭代来导航,有时会导致丑陋或低效的代码。在迭代顺序不明显的情况下,或者与正在读取的数据相关的情况下,使用 set_key()set_range()key()value()item() 可能是更好的选择。

>>> # Record the path from a child to the root of a tree.
>>> path = ['child14123']
>>> while path[-1] != 'root':
...     assert cursor.set_key(path[-1]), \
...         'Tree is broken! Path: %s' % (path,)
...     path.append(cursor.value())

几个实例方法:

  • set_key(key): Seek exactly to key, returning True on success or False if the exact key was not found. 对于 set_key() ,空字节串是错误的。对于使用 dupsort=True 打开的数据库,移动到键的第一个值(复制)。
  • set_range(key): Seek to the first key greater than or equal to key , returning True on success, or False to indicate key was past end of database. Behaves like first() if key is the empty bytestring. 对于使用 dupsort=True 打开的数据库,移动到键的第一个值(复制)。
  • get(key, default=None): Equivalent to set_key() , except value() is returned when key is found, otherwise default.
  • item(): Return the current (key, value) pair.
  • key(): Return the current key.
  • value(): Return the current value.

操作流程

概况地讲,操作LMDB的流程是:

  • 通过 env = lmdb.open() 打开环境
  • 通过 txn = env.begin() 建立事务
  • 通过 txn.put(key, value) 进行插入和修改
  • 通过 txn.delete(key) 进行删除
  • 通过 txn.get(key) 进行查询
  • 通过 txn.cursor() 进行遍历
  • 通过 txn.commit() 提交更改

这里要注意:

  1. putdelete 后一定注意要 commit ,不然根本没有存进去
  2. 每一次 commit 后,需要再定义一次 txn=env.begin(write=True)

来自https://github.com/kophy/py4db的代码:

#!/usr/bin/env python

import lmdb
import os, sys

def initialize():
	env = lmdb.open("students");
	return env;

def insert(env, sid, name):
	txn = env.begin(write = True);
	txn.put(str(sid), name);
	txn.commit();

def delete(env, sid):
	txn = env.begin(write = True);
	txn.delete(str(sid));
	txn.commit();

def update(env, sid, name):
	txn = env.begin(write = True);
	txn.put(str(sid), name);
	txn.commit();

def search(env, sid):
	txn = env.begin();
	name = txn.get(str(sid));
	return name;

def display(env):
	txn = env.begin();
	cur = txn.cursor();
	for key, value in cur:
		print (key, value);

env = initialize();

print "Insert 3 records."
insert(env, 1, "Alice");
insert(env, 2, "Bob");
insert(env, 3, "Peter");
display(env);

print "Delete the record where sid = 1."
delete(env, 1);
display(env);

print "Update the record where sid = 3."
update(env, 3, "Mark");
display(env);

print "Get the name of student whose sid = 3."
name = search(env, 3);
print name;

env.close();

os.system("rm -r students");

创建图像数据集

这里主要借鉴自https://github.com/open-mmlab/mmsr/blob/master/codes/data_scripts/create_lmdb.py的代码。

改写为:

import glob
import os
import pickle
import sys

import cv2
import lmdb
import numpy as np
from tqdm import tqdm


def main(mode):
    proj_root = '/home/lart/coding/TIFNet'
    datasets_root = '/home/lart/Datasets/'
    lmdb_path = os.path.join(proj_root, 'datasets/ECSSD.lmdb')
    data_path = os.path.join(datasets_root, 'RGBSaliency', 'ECSSD/Image')
    
    if mode == 'creating':
        opt = {
            'name': 'TrainSet',
            'img_folder': data_path,
            'lmdb_save_path': lmdb_path,
            'commit_interval': 100,  # After commit_interval images, lmdb commits
            'num_workers': 8,
        }
        general_image_folder(opt)
    elif mode == 'testing':
        test_lmdb(lmdb_path, index=1)


def general_image_folder(opt):
    """
    Create lmdb for general image folders
    If all the images have the same resolution, it will only store one copy of resolution info.
        Otherwise, it will store every resolution info.
    """
    img_folder = opt['img_folder']
    lmdb_save_path = opt['lmdb_save_path']
    meta_info = {'name': opt['name']}
    
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with 'lmdb'.")
    if os.path.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)
    
    # read all the image paths to a list
    
    print('Reading image path list ...')
    all_img_list = sorted(glob.glob(os.path.join(img_folder, '*')))
    # cache the filename, 这里的文件名必须是ascii字符
    keys = []
    for img_path in all_img_list:
        keys.append(os.path.basename(img_path))
    
    # create lmdb environment
    
    # 估算大概的映射空间大小
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
    # map_size:
    # Maximum size database may grow to; used to size the memory mapping. If database grows larger
    # than map_size, an exception will be raised and the user must close and reopen Environment.
    
    # write data to lmdb
    
    txn = env.begin(write=True)
    resolutions = []
    tqdm_iter = tqdm(enumerate(zip(all_img_list, keys)), total=len(all_img_list), leave=False)
    for idx, (path, key) in tqdm_iter:
        tqdm_iter.set_description('Write {}'.format(key))
        
        key_byte = key.encode('ascii')
        data = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        
        if data.ndim == 2:
            H, W = data.shape
            C = 1
        else:
            H, W, C = data.shape
        resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
        
        txn.put(key_byte, data)
        if (idx + 1) % opt['commit_interval'] == 0:
            txn.commit()
            # commit 之后需要再次 begin
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')
    
    # create meta information
    
    # check whether all the images are the same size
    assert len(keys) == len(resolutions)
    if len(set(resolutions)) <= 1:
        meta_info['resolution'] = [resolutions[0]]
        meta_info['keys'] = keys
        print('All images have the same resolution. Simplify the meta info.')
    else:
        meta_info['resolution'] = resolutions
        meta_info['keys'] = keys
        print('Not all images have the same resolution. Save meta info for each image.')
    
    pickle.dump(meta_info, open(os.path.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')


def test_lmdb(dataroot, index=1):
    env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
    meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), "rb"))
    print('Name: ', meta_info['name'])
    print('Resolution: ', meta_info['resolution'])
    print('# keys: ', len(meta_info['keys']))
    
    # read one image
    key = meta_info['keys'][index]
    print('Reading {} for test.'.format(key))
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    img_flat = np.frombuffer(buf, dtype=np.uint8)
    
    C, H, W = [int(s) for s in meta_info['resolution'][index].split('_')]
    img = img_flat.reshape(H, W, C)
    
    cv2.namedWindow('Test')
    cv2.imshow('Test', img)
    cv2.waitKeyEx()


if __name__ == "__main__":
    # mode = creating or testing
    main(mode='creating')

配合DataLoader

这里仅对训练集进行LMDB处理,测试机依旧使用的原始的读取图片的方式。

import os
import pickle

import lmdb
import numpy as np
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from utils import joint_transforms


def _get_paths_from_lmdb(dataroot):
    """get image path list from lmdb meta info"""
    meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'),
                                 'rb'))
    paths = meta_info['keys']
    sizes = meta_info['resolution']
    if len(sizes) == 1:
        sizes = sizes * len(paths)
    return paths, sizes


def _read_img_lmdb(env, key, size):
    """read image from lmdb with key (w/ and w/o fixed size)
    size: (C, H, W) tuple"""
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    img_flat = np.frombuffer(buf, dtype=np.uint8)
    C, H, W = size
    img = img_flat.reshape(H, W, C)
    return img


def _make_dataset(root, prefix=('.jpg', '.png')):
    img_path = os.path.join(root, 'Image')
    gt_path = os.path.join(root, 'Mask')
    img_list = [
        os.path.splitext(f)[0] for f in os.listdir(gt_path)
        if f.endswith(prefix[1])
    ]
    return [(os.path.join(img_path, img_name + prefix[0]),
             os.path.join(gt_path, img_name + prefix[1]))
            for img_name in img_list]


class TestImageFolder(Dataset):
    def __init__(self, root, in_size, prefix):
        self.imgs = _make_dataset(root, prefix=prefix)
        self.test_img_trainsform = transforms.Compose([
            # 输入的如果是一个tuple,则按照数据缩放,但是如果是一个数字,则按比例缩放到短边等于该值
            transforms.Resize((in_size, in_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    
    def __getitem__(self, index):
        img_path, gt_path = self.imgs[index]
        
        img = Image.open(img_path).convert('RGB')
        img_name = (img_path.split(os.sep)[-1]).split('.')[0]
        
        img = self.test_img_trainsform(img)
        return img, img_name
    
    def __len__(self):
        return len(self.imgs)


class TrainImageFolder(Dataset):
    def __init__(self, root, in_size, scale=1.5, use_bigt=False):
        self.use_bigt = use_bigt
        self.in_size = in_size
        self.root = root
        
        self.train_joint_transform = joint_transforms.Compose([
            joint_transforms.JointResize(in_size),
            joint_transforms.RandomHorizontallyFlip(),
            joint_transforms.RandomRotate(10)
        ])
        self.train_img_transform = transforms.Compose([
            transforms.ColorJitter(0.1, 0.1, 0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])  # 处理的是Tensor
        ])
        # ToTensor 操作会将 PIL.Image 或形状为 H×W×D,数值范围为 [0, 255] 的 np.ndarray 转换为形状为 D×H×W,
        # 数值范围为 [0.0, 1.0] 的 torch.Tensor。
        self.train_target_transform = transforms.ToTensor()
        
        self.gt_root = '/home/lart/coding/TIFNet/datasets/DUTSTR/DUTSTR_GT.lmdb'
        self.img_root = '/home/lart/coding/TIFNet/datasets/DUTSTR/DUTSTR_IMG.lmdb'
        self.paths_gt, self.sizes_gt = _get_paths_from_lmdb(self.gt_root)
        self.paths_img, self.sizes_img = _get_paths_from_lmdb(self.img_root)
        self.gt_env = lmdb.open(self.gt_root, readonly=True, lock=False, readahead=False,
                                meminit=False)
        self.img_env = lmdb.open(self.img_root, readonly=True, lock=False, readahead=False,
                                 meminit=False)
    
    
    def __getitem__(self, index):
        gt_path = self.paths_gt[index]
        img_path = self.paths_img[index]
        
        gt_resolution = [int(s) for s in self.sizes_gt[index].split('_')]
        img_resolution = [int(s) for s in self.sizes_img[index].split('_')]
        img_gt = _read_img_lmdb(self.gt_env, gt_path, gt_resolution)
        img_img = _read_img_lmdb(self.img_env, img_path, img_resolution)
        if img_img.shape[-1] != 3:
            img_img = np.repeat(img_img, repeats=3, axis=-1)
        img_img = img_img[:, :, [2, 1, 0]]  # bgr => rgb
        img_gt = np.squeeze(img_gt, axis=2)
        gt = Image.fromarray(img_gt, mode='L')
        img = Image.fromarray(img_img, mode='RGB')
        
        img, gt = self.train_joint_transform(img, gt)
        gt = self.train_target_transform(gt)
        img = self.train_img_transform(img)
        
        if self.use_bigt:
            gt = gt.ge(0.5).float()  # 二值化
        
        img_name = self.paths_img[index]
        return img, gt, img_name
    
    def __len__(self):
        return len(self.paths_img)


class DataLoaderX(DataLoader):
    def __iter__(self):
        return BackgroundGenerator(super(DataLoaderX, self).__iter__())

参考链接

posted @ 2019-11-25 14:03  lart  阅读(6241)  评论(1编辑  收藏  举报