PyTorch IterableDataset for Large-Scale Streaming Data (with Shuffle)

For a small dataset that fits entirely in memory, the usual practice is to subclass torch.utils.data.Dataset. For several TB of data or more, however, we clearly cannot load everything into memory, so we implement torch.utils.data.IterableDataset instead, which is designed for large-scale or streaming data.
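
For intuition, here is a minimal sketch (my own illustration, not part of the original code; the class names MapStyle and Streaming are just placeholders) contrasting the two styles: a map-style Dataset must know its length and support random access, while an IterableDataset only has to yield examples one at a time:

from torch.utils.data import Dataset, IterableDataset

class MapStyle(Dataset):
    # Map-style: needs random access and a known length
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

class Streaming(IterableDataset):
    # Iterable-style: only needs to produce examples in order
    def __init__(self, filename):
        self.filename = filename
    def __iter__(self):
        with open(self.filename) as f:
            for line in f:
                yield line.rstrip('\n')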

Core ideas of the IterableDataset implementation

  1. The class must implement __iter__ and return an iterator (here, a generator), which the DataLoader then consumes.
  2. Because we do not know the total length of a stream, shuffling has to use reservoir sampling: we allocate a fixed-size buffer as the reservoir and randomly swap incoming examples with buffered ones (see the standalone sketch right after this list).
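
Item 2 is easiest to see on a plain Python iterator first. Below is a minimal, self-contained sketch of the buffer shuffle (my own illustration; buffered_shuffle is a hypothetical helper, using the stdlib random module instead of numpy for brevity):

import random

def buffered_shuffle(iterable, buffer_size, seed=0):
    rng = random.Random(seed)
    buffer = []
    for x in iterable:
        if len(buffer) == buffer_size:
            # Buffer full: emit a random buffered element and replace it with x
            i = rng.randrange(buffer_size)
            yield buffer[i]
            buffer[i] = x
        else:
            # Still filling the reservoir
            buffer.append(x)
    # Source exhausted: shuffle and drain whatever is left in the buffer
    rng.shuffle(buffer)
    yield from buffer

print(list(buffered_shuffle(range(10), buffer_size=4)))  # order depends on the seed and buffer size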

Code

Dataset file `iter_ds.txt`

0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line

Implementation

import itertools

import torch
from torch.utils.data import IterableDataset, DataLoader

import numpy as np
from copy import deepcopy


class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename, buffersize, seed, shuffle=False):

        # Store the filename and shuffle configuration on the object
        self.filename = filename
        self.buffer_size = buffersize
        self.generator = np.random.default_rng(seed=seed)
        self.shuffle = shuffle

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip('\n')
        ###

        return text_pp

    def line_mapper(self, line):

        # Split the line into text and label; maxsplit=1 keeps any further '-' inside the label
        text, label = line.split('-', 1)
        text = self.preprocess(text)
        label = self.preprocess(label)

        return text, label

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info:
            worker_total_num = worker_info.num_workers
            worker_id = worker_info.id
        else:
            worker_id = 0
            worker_total_num = 1

        # Create an iterator over the file (the handle stays open for the iterator's lifetime)
        file_itr = open(self.filename)

        # Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)

        # Shard the stream across workers: worker k keeps lines k, k + num_workers, ...
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)

        if self.shuffle:
            return self._shuffle(mapped_itr)
        else:
            return mapped_itr

    @staticmethod
    def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batch_size=1000):
        while True:
            yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))

    def _shuffle(self, ex_iterable):
        buffer_size = self.buffer_size
        rng = deepcopy(self.generator)  # copy the RNG so re-iterating the dataset reproduces the same order
        indices_iterator = self._iter_random_indices(rng, buffer_size)
        # this is the shuffle buffer that we keep in memory
        mem_buffer = []
        for x in ex_iterable:
            if len(mem_buffer) == buffer_size:  # if the buffer is full, pick an example from it
                i = next(indices_iterator)
                yield mem_buffer[i]
                mem_buffer[i] = x  # replace the picked example by a new one
            else:  # otherwise, keep filling the buffer
                mem_buffer.append(x)
        # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
        rng.shuffle(mem_buffer)
        yield from mem_buffer


# base_dataset = CustomIterableDatasetv1("iter_ds.txt",4,1,False)
base_dataset = CustomIterableDatasetv1("iter_ds.txt", 4, 1, True)
dataloader = DataLoader(base_dataset, batch_size=3, num_workers=0)
for X, y in dataloader:
    print(X, y)
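
A note on multiple workers: with num_workers > 0, each DataLoader worker calls __iter__ on its own replica of the dataset, and the itertools.islice call above shards the stream so that worker k only keeps lines k, k + num_workers, and so on. A quick way to see this is to rerun the same script with more workers (my own usage example; output order across workers is not guaranteed):

dataloader = DataLoader(base_dataset, batch_size=3, num_workers=2)
for X, y in dataloader:
    print(X, y)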

References:

  1. Basic IterableDataset usage: https://stackoverflow.com/questions/69778356/iterable-pytorch-dataset-with-multiple-workers
  2. IterableDataset with multiple workers: https://blog.csdn.net/autoliuweijie/article/details/121693112
  3. Hugging Face streaming processing: