pytorch.utils.data
概览
torch.utils.data
主要是负责容纳数据集、数据打散、分批等操作。
这里面有三个概念:数据集dataset,抽样器sampler,数据加载器dataloader。其中第三个就是最终对外的接口,也是最重要的。
它们之间的关系是:首先需要根据源数据创建数据集dataset,然后根据dataset创建抽样器sampler,最后同时通过dataset和sampler来创建dataloader,这就是我们最终需要的。这个在训练、测试的时候,会得到batch数据。
dataset
第一个是dataset,就是常规理解的数据集。
数据集主要分为两种:map-style和iterable-style
map-style数据集,一般都是继承Dataset类 ,必须要实现__getitem__()
和__len__()
方法,表示从索引或者key到数据样本的映射
iterable-style数据集,一般都是继承IterableDataset类,必须实现__iter__()
方法,表示在数据样本上迭代。一般从一些流中实时获取数据(比如数据库、远程服务器或者日志),是无法进行随机读取的,这时就主要使用迭代式数据集。
一般如果数据量小,使用map-style就可以了,如果数据量很大,需要从数据流中获取,那就使用iterable-style
对应到具体的类,有以下六个:
- torch.utils.data.Dataset
- torch.utils.data.IterableDataset
- torch.utils.data.TensorDataset
- torch.utils.data.ConcatDataset
- torch.utils.data.ChainDataset
- torch.utils.data.Subset
除此之外,torch.utils.data还包含了两个函数
- torch.utils.data.get_worker_info()
- torch.utils.data.random_split()
sampler
sampler是抽样器,作用在dataset上面
抽样的方式也有几个方式:
按顺序抽样,随机抽样,在子集合中随机抽样,带权重的抽样等等
包括以下类:
- class Sampler
- class SequentialSampler
- class RandomSampler
- class SubsetRandomSampler
- class WeightedRandomSampler
- class BatchSampler
- class DistributedSampler
生成sampler的最终目的就是为了创建dataloader。
dataLoader
DataLoader是核心。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
构建DataLoader有几个重要的参数:
- dataset是数据集,
- batch_size
- shuffle 是否每一轮都将数据进行打散,最好通过sampler来打散,否则使用SequentialSampler的时候也会被打散。
- sampler 生成indices
- collate_fn
- pin_memory 含义参考pytorch pinned memory
实例1:通过TensorDataset快速生成dataloader
数据中有字符串类型的时候慎用。
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset, RandomSampler
import numpy as np
# 创建TensorDataset
feature = torch.tensor(np.arange(100))
dataset = TensorDataset([feature, feature])
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, batch_size=5, sampler=sampler)
for epoch in range(2):
print('epoch=', epoch)
for index, batch in enumerate(dataloader):
print(batch)
if index > 10:
break
epoch= 0
[tensor([79, 6, 81, 35, 21], dtype=torch.int32), tensor([79, 6, 81, 35, 21], dtype=torch.int32)]
[tensor([43, 98, 86, 23, 68], dtype=torch.int32), tensor([43, 98, 86, 23, 68], dtype=torch.int32)]
[tensor([ 0, 36, 60, 1, 91], dtype=torch.int32), tensor([ 0, 36, 60, 1, 91], dtype=torch.int32)]
[tensor([71, 59, 72, 75, 52], dtype=torch.int32), tensor([71, 59, 72, 75, 52], dtype=torch.int32)]
[tensor([45, 2, 73, 46, 95], dtype=torch.int32), tensor([45, 2, 73, 46, 95], dtype=torch.int32)]
[tensor([82, 37, 24, 12, 16], dtype=torch.int32), tensor([82, 37, 24, 12, 16], dtype=torch.int32)]
[tensor([90, 11, 70, 31, 53], dtype=torch.int32), tensor([90, 11, 70, 31, 53], dtype=torch.int32)]
[tensor([15, 7, 64, 22, 65], dtype=torch.int32), tensor([15, 7, 64, 22, 65], dtype=torch.int32)]
[tensor([ 3, 87, 4, 17, 99], dtype=torch.int32), tensor([ 3, 87, 4, 17, 99], dtype=torch.int32)]
[tensor([83, 20, 19, 89, 42], dtype=torch.int32), tensor([83, 20, 19, 89, 42], dtype=torch.int32)]
[tensor([97, 58, 8, 38, 30], dtype=torch.int32), tensor([97, 58, 8, 38, 30], dtype=torch.int32)]
[tensor([54, 56, 48, 27, 57], dtype=torch.int32), tensor([54, 56, 48, 27, 57], dtype=torch.int32)]
epoch= 1
[tensor([66, 15, 37, 82, 47], dtype=torch.int32), tensor([66, 15, 37, 82, 47], dtype=torch.int32)]
[tensor([75, 70, 5, 99, 33], dtype=torch.int32), tensor([75, 70, 5, 99, 33], dtype=torch.int32)]
[tensor([80, 76, 55, 29, 41], dtype=torch.int32), tensor([80, 76, 55, 29, 41], dtype=torch.int32)]
[tensor([79, 17, 63, 92, 74], dtype=torch.int32), tensor([79, 17, 63, 92, 74], dtype=torch.int32)]
[tensor([52, 53, 58, 38, 87], dtype=torch.int32), tensor([52, 53, 58, 38, 87], dtype=torch.int32)]
[tensor([84, 59, 77, 48, 71], dtype=torch.int32), tensor([84, 59, 77, 48, 71], dtype=torch.int32)]
[tensor([56, 16, 27, 81, 60], dtype=torch.int32), tensor([56, 16, 27, 81, 60], dtype=torch.int32)]
[tensor([50, 73, 46, 28, 32], dtype=torch.int32), tensor([50, 73, 46, 28, 32], dtype=torch.int32)]
[tensor([45, 40, 10, 25, 9], dtype=torch.int32), tensor([45, 40, 10, 25, 9], dtype=torch.int32)]
[tensor([12, 49, 22, 51, 20], dtype=torch.int32), tensor([12, 49, 22, 51, 20], dtype=torch.int32)]
[tensor([ 6, 68, 72, 24, 67], dtype=torch.int32), tensor([ 6, 68, 72, 24, 67], dtype=torch.int32)]
[tensor([57, 96, 23, 97, 98], dtype=torch.int32), tensor([57, 96, 23, 97, 98], dtype=torch.int32)]
自定义Dataset
import torch
import torch.nn as nn
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
class ToyDataset(Dataset):
def __init__(self):
self.Data = np.arange(32).reshape(16, 2).tolist()
self.Target = np.random.randint(0, 2, (16,1)).tolist()
def __getitem__(self, index):
txt = torch.LongTensor(self.Data[index])
label = torch.LongTensor(self.Target[index])
return txt, label
def __len__(self):
return len(self.Data)