PyTorch torch.utils.data 模块 结构化数据

torch.utils.data 模块中的一些函数,PyTorch 官方文档

1. Dataset

Dataset 类创建 Map-style 数据集,通过 __getitem__()__len__() 方法来从数据集中采样,样本可以表示为数据集的索引或键值(indices / keys)的映射(map)。

引入

from torch.utils.data import Dataset

主要作用: 规范化模型的数据,结合 DataLoader 类,根据索引,在每次训练的过程中取出数据,基本结构

class MyDataset(Dataset):
    def __init__(self, params):  
        # 传入必要的参数,原始数据集,等
        super(MyDataset, self).__init__()  # 父类初始化模型
        ...
        return None
        
    def __len__(self):
        # 返回数据集样本总数
        return data_len  
    
    def __getitem__(self, idx):
        # 根据索引 idx,确定每个(或 batch size)需要输入模型的样本
        # 返回值可根据具体情况调整
        return input, output       

1.1 TensorDataset() 函数

对于不需要任何加工的向量,TensorDataset() 函数可以直接将数据(torch.Tensor 数据类型)直接打包成 Dataset 类。

mydataset = TensorDataset(X, Y)  # X, Y 应为 torch.Tensor 类型,且数量相等(即X.shape[0]=Y.shape[0])

实例: 获取数据集的大小,从数据集中选取样本

# 加载 iris 数据集
from sklearn.datasets import load_iris
data = load_iris()
X, Y = data.data, data.target
print(X.shape, Y.shape)

import torch
import torch.utils.data as tud
X, Y = torch.tensor(X), torch.tensor(Y, dtype=torch.long)  # 数据转换为torch.Tensor 类型

mydataset = tud.TensorDataset(X, Y)        # 构建 Dataset 实例
print(mydataset.__len__())                 # 获取 Dataset 样本总数,len() 函数也可
print(mydataset.__getitem__(0))            # 获取 Dataset 中第一个样本(索引为 0)
print(mydataset.__getitem__([1, 2, 3]))    # 获取 Dataset 中第二、三个样本(索引为 1, 2, 3)

print(len(mydataset))        # 获取 Dataset 样本总数
print(mydataset[0])          # 获取 Dataset 中第一个样本(索引为 0)
print(mydataset[[1, 2, 3]])  # 获取 Dataset 中第二、三、四个样本(索引为 1, 2, 3)
print(mydataset[1:4])        # 获取 Dataset 中第二个到四个(第二、三、四)样本

2. DataLoader

引入

from torch.utils.data import DataLoader

主要参数

  • dataset:上文中 Dataset 类型

  • batch_sizeint 类型,默认为 1

  • shufflebool 类型,默认为 False;在每一 epoch 训练模型前,是否 shuffle 数据。

  • drop_lastbool 类型,默认为 False;是否丢弃(不参与训练)每一 epoch 最后 1 个 batch_size 的样本。这是由于样本总量(sample size)不能整除 batch size,因此,最后一批的样本数量多数情况下会小于 batch size,设置 drop_last=True,这最后这一批次的样本不参与模型训练。

  • collate_fn

实例:

n_sample = len(mydataset)
batch_size = 16

dataloader = tud.DataLoader(mydataset, batch_size=batch_size)   # DataLoader 实例
for i, (X_, Y_) in enumerate(dataloader):
    # print(X_, Y_)
    # X_, Y_ 为 Dataset 中 __getitem__() 方法的返回值,样本数量(X.shape[0])由 batch_size 决定
    print(X_.size(), Y_.size())

# 等价于,但无法 shuffle
for i in range((n_sample - 1) // batch_size + 1):
    X_, Y_ = mydataset[i * batch_size : (i + 1) * batch_size]
    print(i, X_.size(), Y_.size())

2.1 关于 GPU 加速

若使用 GPU 加速,需要将数据加载到 GPU 设备上。可以在重构 Dataset 类时,在 __getitem__() 方法的最后,将返回的数据加载到 GPU 设备上。

也可以在 DataLoader 类中,过设置 collate_fn 参数实现,代码如下:

import torch
from torch.utils.data.dataloader import default_collate

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

dataloader = tud.DataLoader(
    mydataset, batch_size=16, shuffle=True,
    collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))  # 将加载的数据置于 GPU 上

3. 基本工具

3.1 子集提取

以下函数返回的结果均为 Dataset 类(或 Subset 类型),可直接传入 DataLoader 中实现数据加载。

(1) random_split() 函数

random_split(dataset, lengths, generator=<torch._C.Generator object>) :随机划分数据集

主要参数:

  • dataset
  • lengths:list 类型,每个子集的样本数量
  • generator

实例:

random_split(MyDataset, [3, 7], generator=torch.Generator().manual_seed(42))

实例:

n_sample = X.size()[0]
n_train = int(n_sample*0.7)
n_valid = int(n_sample*0.2)
n_test = n_sample - n_train - n_valid
lens = [n_train, n_valid, n_test]
print(n_sample, lens)

d1, d2, d3 = tud.random_split(mydataset, lens, generator=None)
print(d1.__len__(), d2.__len__(), d3.__len__())
# Output: 105  30  15

由于 PyTorch 未提供非随机(即按顺序)划分样本的方法,可借助 Subset() 函数实现:

d1 = tud.Subset(mydataset, range(n_train))
d2 = tud.Subset(mydataset, range(n_train, n_train + n_valid))
d3 = tud.Subset(mydataset, range(n_train + n_valid, n_sample))
print(len(d1), len(d2), len(d3))
# Output: 105  30  15

(2) ConcatDataset() 函数

ConcatDataset(datasets):合并 dataset

  • 参数:datasets:List of dataset
  • 也可直接用 + 合并两个数据集

实例:

# 方式一
dc1 = d1 + d2 + d3 
# 方式二:与方式一等价
dc2 = tud.ConcatDataset([d1, d2, d3]) 
print(dc1.__len__(), dc1.__len__())
# Output: 105  105

(3) Subset() 函数

Subset(dataset, indices):从 dataset 中提取子集

注意: Dataset.__getitem__() 方法返回的是具体数据(tuple 类型),而 Subset() 函数返回的是 Dataset

实例:

d4 = tud.Subset(mydataset, [1,2,3,4])
print(d4.__len__())
# Output: 4

3.2 随机采样

SubsetRandomSampler(indices, generator=None)

WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)

代码实例

参考资料

文中代码:Colab, Github

PyTorch, TORCH.UTILS.DATA, site

5-1, Dataset和DataLoader, 20天吃掉那只Pytorch, site

posted @ 2022-05-22 11:56  veager  阅读(365)  评论(0编辑  收藏  举报