Loading

加速Pytorch数据读取--LMDB

背景

在深度学习的时候,如果你的batch size调的很大,或者你每次获取一个batch需要许多的预操作,那么pytorch的Dataloader获取一个batch就会花费较多的时间,那么训练的时候就会出现GPU等CPU的情况,训练的效率就会下降。

为了应对这种情况,Tensorflow有TFrecord,但是Pytorch没有对应的数据格式,在查询各类资料之后,我决定使用LMDB这个数据库

LMDB是一种数据库,可以实现多进程访问,访问简单,而且不需要把全部文件读入内存,总而言之就是速度很快

方法

我们想用LMDB来读取的话,首先就需要将我们原始的数据集转换为LMDB的数据格式,然后在训练的时候读取这个文件就行了。首先我们来实现将原始数据集转换为LMDB的过程:

转换为LMDB格式

相必大家在寻找一个更加快速的Dataloader的时候,已经写好了Pytorch常规的Dataloader,我们这里就可以利用上这个已有的Dataloader

首先打开一个lmdb数据库文件,如果之前有用过其他文件数据库的话,会发现这有点相似~

db = lmdb.open(lmdb_path, subdir=isdir,
               map_size=1099511627776 * 2, readonly=False,
               meminit=False, map_async=True)

然后准备向其中写入数据:

txn = db.begin(write=True)

然后就是将图片放到lmdb中

txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow(
    (pic.numpy())
))

最后

txn.commit()

总体的操作与数据库一样,总体就是打开数据库,放入数据,最后commit

最后放出完整的方法:

def data2lmdb(dataloader, db_path='data_lmdb', name="train", write_frequency=50):
    """
    Args:
        dataloader: the general dataloader of the dataset, e.g torch.utils.data.DataLoader
        db_path: the path you want to save the lmdb file
        name: train or test
        write_frequency: Write once every ? rounds

    Returns:
        None
    """
    if not os.path.exists(db_path):
        os.makedirs(db_path)
    lmdb_path = os.path.join(db_path, "%s.lmdb" % name)
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)

    txn = db.begin(write=True)
    for idx, data in enumerate(dataloader):
        # get data from dataloader
        if name == 'train':
            pic = data
        elif name == 'test':
            pic = data
        else:
            raise 'unexpect name :{}'.format(name)

        # put data to lmdb dataset
        # {idx, (in_LDRs, in_HDRs, ref_HDRs)}
        txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow(
            (pic.numpy())
        ))
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(dataloader)))
            txn.commit()
            txn = db.begin(write=True)

    # finish iterating through dataset
    txn.commit()
    keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', dumps_pyarrow(keys))
        txn.put(b'__len__', dumps_pyarrow(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()

上面的函数适合我的使用场景,如果你需要搬去用的话,需要修改dataloader输出的地方,根据自己的dataloader来写

读取LMDB文件

这里放上IMDB的Dataloader

class ImageFolderLMDBTest(data.Dataset):
    def __init__(self, db_path):
        self.db_path = db_path
        self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            # self.length = txn.stat()['entries'] - 1
            self.length = pa.deserialize(txn.get(b'__len__'))
            self.keys = pa.deserialize(txn.get(b'__keys__'))

    def __getitem__(self, index):
        env = self.env
        with env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        unpacked = pa.deserialize(byteflow)

        # load image
        # 这里写你写入lmdb时的数据,上面的我写入了pic,这里展开就还是pic
        pic = unpacked

        return pic

    def to_tensor(self, img):
        img_t = torch.from_numpy(img.copy())
        if isinstance(img_t, torch.ByteTensor):
            return img_t.float().div(255)
        else:
            return img_t

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'

这个Dataloader的结构和一般的Dataloader一样,需要注意的我也写在了注释里面,其实就是根据你写入的东西不同,从LMDB里取出的东西也不一样

后记

在用上这个LMDB文件格式之后,数据的读取速度也快了很多,GPU也终于不会歇着了🤣

posted @ 2021-11-12 22:41  _CHENBIN  阅读(2126)  评论(0编辑  收藏  举报