加速Pytorch数据读取--LMDB
背景
在深度学习的时候,如果你的batch size调的很大,或者你每次获取一个batch需要许多的预操作,那么pytorch的Dataloader获取一个batch就会花费较多的时间,那么训练的时候就会出现GPU等CPU的情况,训练的效率就会下降。
为了应对这种情况,Tensorflow有TFrecord,但是Pytorch没有对应的数据格式,在查询各类资料之后,我决定使用LMDB这个数据库
LMDB是一种数据库,可以实现多进程访问,访问简单,而且不需要把全部文件读入内存,总而言之就是速度很快
方法
我们想用LMDB来读取的话,首先就需要将我们原始的数据集转换为LMDB的数据格式,然后在训练的时候读取这个文件就行了。首先我们来实现将原始数据集转换为LMDB的过程:
转换为LMDB格式
相必大家在寻找一个更加快速的Dataloader的时候,已经写好了Pytorch常规的Dataloader,我们这里就可以利用上这个已有的Dataloader
首先打开一个lmdb数据库文件,如果之前有用过其他文件数据库的话,会发现这有点相似~
db = lmdb.open(lmdb_path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False,
meminit=False, map_async=True)
然后准备向其中写入数据:
txn = db.begin(write=True)
然后就是将图片放到lmdb中
txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow(
(pic.numpy())
))
最后
txn.commit()
总体的操作与数据库一样,总体就是打开数据库,放入数据,最后commit
最后放出完整的方法:
def data2lmdb(dataloader, db_path='data_lmdb', name="train", write_frequency=50):
"""
Args:
dataloader: the general dataloader of the dataset, e.g torch.utils.data.DataLoader
db_path: the path you want to save the lmdb file
name: train or test
write_frequency: Write once every ? rounds
Returns:
None
"""
if not os.path.exists(db_path):
os.makedirs(db_path)
lmdb_path = os.path.join(db_path, "%s.lmdb" % name)
isdir = os.path.isdir(lmdb_path)
print("Generate LMDB to %s" % lmdb_path)
db = lmdb.open(lmdb_path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False,
meminit=False, map_async=True)
txn = db.begin(write=True)
for idx, data in enumerate(dataloader):
# get data from dataloader
if name == 'train':
pic = data
elif name == 'test':
pic = data
else:
raise 'unexpect name :{}'.format(name)
# put data to lmdb dataset
# {idx, (in_LDRs, in_HDRs, ref_HDRs)}
txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow(
(pic.numpy())
))
if idx % write_frequency == 0:
print("[%d/%d]" % (idx, len(dataloader)))
txn.commit()
txn = db.begin(write=True)
# finish iterating through dataset
txn.commit()
keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
with db.begin(write=True) as txn:
txn.put(b'__keys__', dumps_pyarrow(keys))
txn.put(b'__len__', dumps_pyarrow(len(keys)))
print("Flushing database ...")
db.sync()
db.close()
上面的函数适合我的使用场景,如果你需要搬去用的话,需要修改dataloader输出的地方,根据自己的dataloader来写
读取LMDB文件
这里放上IMDB的Dataloader
class ImageFolderLMDBTest(data.Dataset):
def __init__(self, db_path):
self.db_path = db_path
self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path),
readonly=True, lock=False,
readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
# self.length = txn.stat()['entries'] - 1
self.length = pa.deserialize(txn.get(b'__len__'))
self.keys = pa.deserialize(txn.get(b'__keys__'))
def __getitem__(self, index):
env = self.env
with env.begin(write=False) as txn:
byteflow = txn.get(self.keys[index])
unpacked = pa.deserialize(byteflow)
# load image
# 这里写你写入lmdb时的数据,上面的我写入了pic,这里展开就还是pic
pic = unpacked
return pic
def to_tensor(self, img):
img_t = torch.from_numpy(img.copy())
if isinstance(img_t, torch.ByteTensor):
return img_t.float().div(255)
else:
return img_t
def __len__(self):
return self.length
def __repr__(self):
return self.__class__.__name__ + ' (' + self.db_path + ')'
这个Dataloader的结构和一般的Dataloader一样,需要注意的我也写在了注释里面,其实就是根据你写入的东西不同,从LMDB里取出的东西也不一样
后记
在用上这个LMDB文件格式之后,数据的读取速度也快了很多,GPU也终于不会歇着了🤣