pytorch 可训练数据集创建(torch.utils.data)

pytorch 是应用非常广泛的深度学习框架,模型训练的第一步就是数据集的创建。

pytorch 可训练数据集创建的操作步骤如下:

1.创建一个Dataset对象

2.创建一个DataLoader对象

3.循环遍历这个DataLoader对象,将data、label加载到模型中训练

其中Dataset和DataLoader对象的创建要用到pytorch中torch.utils.data 模块的Dataset类和DataLoader类。

首先看一下torch.utils.data.Dataset 的源码

class Dataset:
    """Abstract base class for map-style datasets.

    Concrete datasets derive from this class and must provide two methods:
    ``__len__`` returning the number of samples, and ``__getitem__``
    accepting any integer index from 0 up to (but excluding) ``len(self)``.
    Adding two datasets with ``+`` concatenates them via ``ConcatDataset``.
    """

    def __getitem__(self, index):
        # Subclasses are responsible for sample lookup.
        raise NotImplementedError

    def __len__(self):
        # Subclasses are responsible for reporting their size.
        raise NotImplementedError

    def __add__(self, other):
        # ``a + b`` chains the two datasets end to end.
        parts = [self, other]
        return ConcatDataset(parts)

torch.utils.data.Dataset 是代表自定义数据集的抽象类,我们可以定义自己的数据类来继承这个类,只需要重写__len__和__getitem__这两个方法即可。

作用:

__len__(self)获取数据集的长度

__getitem__(self, index)根据索引号获取图片和标签

通常我们按如下方式定义自己的数据类

# Import the required packages
import os
import random

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image


class VeriDataset(data.Dataset):
    """Image dataset whose samples are listed in a plain-text file.

    Each sample is returned as ``(image, target, camid)``. Image file names are
    expected to contain ``_``-separated fields; the camera id is the second
    field of the file name.

    Training and validation lists are stored differently:

    * ``is_train=True``:  each line is ``<file_name> <label>``.
    * ``is_train=False``: each line is just ``<file_name>``; the label is the
      first ``_``-separated field of the file name.
    """

    def __init__(self, data_dir, train_list, train_data_transform=None, is_train=True):
        """
        Args:
            data_dir: root directory containing the image files.
            train_list: path to the .txt file listing the image names.
            train_data_transform: optional callable applied to each loaded
                PIL image before it is returned.
            is_train: True for the training-list format, False for the
                validation-list format (see class docstring).
        """
        super(VeriDataset, self).__init__()
        self.is_train = is_train
        self.data_dir = data_dir
        self.train_data_transform = train_data_transform

        # ``with`` guarantees the list file is closed even if reading fails.
        with open(train_list, 'r', encoding='utf-8') as f:
            # Drop blank lines (e.g. a trailing newline): splitting them below
            # would otherwise raise IndexError.
            lines = [line.strip() for line in f if line.strip()]

        self.names = []
        self.labels = []  # kept as strings; converted to int in __getitem__
        self.cams = []
        if is_train:
            for i, line in enumerate(lines):
                if i % 10000 == 0:
                    # Periodic sanity print of a raw list entry.
                    print(line)
                # split() tolerates repeated spaces/tabs between the fields.
                fields = line.split()
                self.names.append(fields[0])
                self.labels.append(fields[1])
                self.cams.append(fields[0].split('_')[1])
        else:
            for line in lines:
                parts = line.split('_')
                self.names.append(line)
                self.labels.append(parts[0])
                self.cams.append(parts[1])

    def __getitem__(self, index):
        """Return ``(image, target, camid)`` for the sample at ``index``.

        ``index`` is chosen by the caller (e.g. a DataLoader's sampler) and
        must be a valid index into the file list. The image is loaded as RGB,
        ``target`` is the label converted to ``int`` and ``camid`` is the
        camera-id string.
        """
        path = os.path.join(self.data_dir, self.names[index])
        # Close the underlying file handle as soon as the pixels are decoded.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        target = int(self.labels[index])
        camid = self.cams[index]
        if self.train_data_transform is not None:
            img = self.train_data_transform(img)
        return img, target, camid

    def __len__(self):
        """Return the number of samples in the list file."""
        return len(self.names)

 

 创建一个DataLoader对象

DataLoader也是pytorch的重要接口,该接口可以将自定义的Dataset 按照batch_size大小、是否shuffle等设置,封装成一个个batch_size大小的Tensor,用于后续训练。

看一下DataLoader的源码

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    Raises:
        ValueError: if ``timeout`` or ``num_workers`` is negative, if ``sampler``
            is combined with ``shuffle``, or if ``batch_sampler`` is combined with
            ``batch_size > 1``, ``shuffle``, ``sampler`` or ``drop_last``. Also
            raised when ``batch_size``, ``sampler`` or ``drop_last`` is reassigned
            after construction.

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    # Class-level default so that ``__setattr__`` can read it while ``__init__``
    # is still assigning attributes (name-mangled to ``_DataLoader__initialized``).
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            # A user-supplied batch_sampler decides batch composition itself.
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    # Shuffle the sample order (reshuffled at every epoch).
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            # Group single indices from the sampler into batches of batch_size.
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        # Freeze the sampling configuration once construction has finished:
        # batch_sampler was already built from these values.
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        # DataLoader is an iterable, not an iterator: each call builds a fresh
        # _DataLoaderIter that performs one pass over the data.
        return _DataLoaderIter(self)

    def __len__(self):
        # Number of batches per epoch, as reported by the batch sampler.
        return len(self.batch_sampler)

 

一般我们关心的有以下几个参数

dataset:传入的数据集

batch_size:每个batch有多少个样本

shuffle:在每个epoch开始的时候,对数据进行重排

num_workers:用几个子进程来加载数据(0 表示在主进程中加载)

生成DataLoader之后,我们通常用for循环遍历它,逐个batch取出数据进行训练。

因为DataLoader只有__iter__()而没有实现__next__(),所以DataLoader是一个iterable而不是iterator。因此__iter__()需要返回一个迭代器,即_DataLoaderIter。

 在_DataLoaderIter中实现了__next__()方法。

class _DataLoaderIter(object):
    """Iterate once over the DataLoader's dataset, in the order given by the sampler.

    Instances are what ``DataLoader.__iter__`` returns. The class defines both
    ``__next__`` and an ``__iter__`` that returns ``self``, so it is an
    iterator (whereas ``DataLoader`` itself is only an iterable).

    Two modes, selected by ``loader.num_workers``:

    * ``0``: each batch is built synchronously inside ``__next__`` in the
      calling process.
    * ``> 0``: ``num_workers`` ``multiprocessing.Process`` objects running
      ``_worker_loop`` are started and given ``index_queue`` and
      ``worker_result_queue`` (presumably inbound index batches / outbound
      results). ``__next__`` hands batches back in batch-index order,
      buffering any that arrive early in ``reorder_dict``.

    NOTE(review): this is an excerpt. Methods called below (``_put_indices``,
    ``_get_batch``, ``_process_next_batch``, ``_shutdown_workers``) and the
    module-level helpers (``_worker_loop``, ``_worker_manager_loop``,
    ``pin_memory_batch``, ``_update_worker_pids``, ``_set_SIGCHLD_handler``)
    are not shown, so their exact behavior is not documented here.
    """

    def __init__(self, loader):
        # Copy the loader's configuration onto the iterator.
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        # Pinning is used only if requested AND CUDA is available.
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        # Handed to the worker-manager thread started below.
        self.done_event = threading.Event()

        # Each item drawn from this is the sequence of dataset indices for one
        # batch (see the single-process path in __next__).
        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            # Inter-process queues shared with every worker.
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            # Decremented in __next__ each time a batch is received; presumably
            # incremented by _put_indices when a batch is requested.
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            # Presumably: index of the next batch to request / to return to the
            # caller (advanced by helpers not shown in this excerpt).
            self.send_idx = 0
            self.rcvd_idx = 0
            # Batches that arrived before their turn, keyed by batch index.
            self.reorder_dict = {}

            # A single random base seed; worker i receives base_seed + i.
            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                # A daemon thread running _worker_manager_loop gets both
                # worker_result_queue and an in-process queue.Queue; batches
                # are then read from data_queue (presumably relayed, and pinned
                # if requested, by that thread).
                self.data_queue = queue.Queue()
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          torch.cuda.current_device()))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                # No intermediate thread: read worker results directly.
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            # Register the worker PIDs and install a SIGCHLD handler (helpers
            # not shown) - presumably so a crashed worker is detected.
            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime the prefetch loop: call _put_indices 2 * num_workers times,
            # presumably so work is already queued before the first __next__.
            for _ in range(2 * self.num_workers):
                self._put_indices()


    def __next__(self):
        """Return the next batch; raise StopIteration when the pass is done."""
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            # Fetch every sample of the batch and merge them with collate_fn.
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # The batch we are waiting for may already have arrived out of order.
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        # Nothing left in flight: shut the workers down and end the iteration.
        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        # Pull results until the one with the expected index (rcvd_idx) shows up.
        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    def __iter__(self):
        # An iterator is its own iterable.
        return self

 

 实例如下

import torch
from torch.utils.data import DataLoader

# ``train_set`` (e.g. a VeriDataset instance) and ``model`` are assumed to be
# defined earlier in the training script.
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=16, shuffle=True)

# Use the GPU when present, otherwise fall back to the CPU instead of failing
# on ``.cuda()``.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for i, (image, target, camid) in enumerate(train_loader):
    batch_size = image.size(0)

    # ``target`` is a batched tensor (it supports .to/.cuda), not a tuple.
    target = target.to(device)
    # ``torch.autograd.Variable(..., volatile=True)`` is deprecated and has no
    # effect any more; plain tensors are moved to the device directly.
    image = image.to(device)

    # Gradients are tracked here, as needed for training. For evaluation only,
    # wrap this forward pass in ``with torch.no_grad():``.
    output, feat = model(image)

 

参考博客:

https://www.cnblogs.com/ytxwzqin/p/13086436.html

https://blog.csdn.net/qq_36653505/article/details/83351808

https://blog.csdn.net/tsq292978891/article/details/79414512?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param

posted @ 2020-10-20 23:49  learningcaiji  阅读(2197)  评论(0编辑  收藏  举报