Implementing a PyTorch IterableDataset for Large-Scale Streaming Data (with Shuffling)
For small datasets that fit entirely into memory, the usual practice is to define a `torch.utils.data.Dataset` subclass. But for datasets of several TB or more, we obviously cannot load everything into memory, so we need `torch.utils.data.IterableDataset` instead. This class is designed for large-scale or streaming data.
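As a quick illustration of the contrast (a minimal sketch, not the implementation below; the class names `InMemoryDataset` and `StreamingDataset` are made up for this example), a map-style `Dataset` indexes into data held in memory, whereas an `IterableDataset` only has to yield examples one at a time:

```python
from torch.utils.data import Dataset, IterableDataset

class InMemoryDataset(Dataset):           # map-style: random access, needs __len__
    def __init__(self, items):
        self.items = items                # the whole dataset lives in memory
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

class StreamingDataset(IterableDataset):  # iterable-style: sequential access only
    def __init__(self, filename):
        self.filename = filename
    def __iter__(self):
        with open(self.filename) as f:    # lines are read lazily, never all at once
            for line in f:
                yield line
```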
Core ideas of the IterableDataset implementation
- The class must return a generator that the `DataLoader` then consumes, so we need to implement the `__iter__` method.
- We also need to shuffle the data. Because the length of a stream is unknown up front, we shuffle with a reservoir-sampling-style approach: we allocate a buffer that serves as the "reservoir" and draw shuffled examples out of it (see the sketch after this list).
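Before the full class, here is a minimal self-contained sketch of the buffer-based shuffling idea; `shuffle_stream` is an illustrative name, and it uses Python's `random` module rather than NumPy:

```python
import random

def shuffle_stream(iterable, buffer_size, seed=0):
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) == buffer_size:
            i = rng.randrange(buffer_size)  # pick a random slot from the full buffer
            yield buffer[i]                 # emit the chosen example
            buffer[i] = item                # refill that slot with the incoming example
        else:
            buffer.append(item)             # still warming up: just fill the buffer
    rng.shuffle(buffer)                     # stream exhausted: flush the rest, shuffled
    yield from buffer

print(list(shuffle_stream(range(10), buffer_size=4)))
```

Memory stays bounded at `buffer_size` items no matter how long the stream is; a larger buffer gives a more thorough shuffle.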
Code
Dataset file `iter_ds.txt`:

```
0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line
```
Implementation:
```python
import itertools
from copy import deepcopy

import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader


class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename, buffersize, seed, shuffle=False):
        # Store the filename in the object's memory
        self.filename = filename
        self.buffer_size = buffersize
        self.generator = np.random.default_rng(seed=seed)
        self.shuffle = shuffle

    def preprocess(self, text):
        ### Do something with the text here
        text_pp = text.lower().strip()
        ###
        return text_pp

    def line_mapper(self, line):
        # Split the line into text and label and apply preprocessing to both
        text, label = line.split('-')
        text = self.preprocess(text)
        label = self.preprocess(label)
        return text, label

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info:
            worker_total_num = worker_info.num_workers
            worker_id = worker_info.id
        else:
            worker_id = 0
            worker_total_num = 1
        # Create an iterator over the file
        file_itr = open(self.filename)
        # Map each element using line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        # Add multi-worker functionality: each worker keeps every
        # worker_total_num-th line, starting from its own worker_id
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)
        if self.shuffle:
            return self._shuffle(mapped_itr)
        else:
            return mapped_itr

    @staticmethod
    def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batch_size=1000):
        # Draw random indices in batches to amortize the cost of rng.integers
        while True:
            yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))

    def _shuffle(self, ex_iterable):
        buffer_size = self.buffer_size
        rng = deepcopy(self.generator)
        indices_iterator = self._iter_random_indices(rng, buffer_size)
        # This is the shuffle buffer that we keep in memory
        mem_buffer = []
        for x in ex_iterable:
            if len(mem_buffer) == buffer_size:  # if the buffer is full, pick an example from it
                i = next(indices_iterator)
                yield mem_buffer[i]
                mem_buffer[i] = x  # replace the picked example with the new one
            else:  # otherwise, keep filling the buffer
                mem_buffer.append(x)
        # When we run out of examples, shuffle the remaining buffer and yield from it
        rng.shuffle(mem_buffer)
        yield from mem_buffer


# base_dataset = CustomIterableDatasetv1("iter_ds.txt", 4, 1, False)
base_dataset = CustomIterableDatasetv1("iter_ds.txt", 4, 1, True)
dataloader = DataLoader(base_dataset, batch_size=3, num_workers=0)
for X, y in dataloader:
    print(X, y)
```
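To see what the `itertools.islice` call in `__iter__` does across workers, here is a small standalone demo (illustrative only, independent of the class above):

```python
import itertools

lines = [f"{i} - Dummy line" for i in range(10)]
num_workers = 2
for worker_id in range(num_workers):
    # Same slicing as in __iter__: start at worker_id, step by num_workers
    shard = itertools.islice(iter(lines), worker_id, None, num_workers)
    print(worker_id, list(shard))
# worker 0 keeps lines 0, 2, 4, 6, 8; worker 1 keeps lines 1, 3, 5, 7, 9
```

Note that with `shuffle=True`, each worker deep-copies the same seeded generator and shuffles only its own shard, so the shuffle is per-worker rather than global.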
References:
- Basic `IterableDataset` usage: https://stackoverflow.com/questions/69778356/iterable-pytorch-dataset-with-multiple-workers
- `IterableDataset` with multiple workers: https://blog.csdn.net/autoliuweijie/article/details/121693112
- Hugging Face streaming: