Pytorch基础(5)——批数据训练

一、知识点:

  • 相关包:torch.utils.data

import torch
import torch.utils.data as Data
  • 包装数据类:TensorDataset

【包装数据和目标张量的数据集,通过沿着第一个维度索引两个张量来】

class torch.utils.data.TensorDataset(data_tensor, target_tensor)
#data_tensor (Tensor) - 包含样本数据
#target_tensor (Tensor) - 包含样本目标(标签)

 

  • 加载数据类:DataLoader

【数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。】

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
#num_workers (int, optional) – 用多少个子进程加载数据
#drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

 

二、利用torch.utils.data进行批数据训练:

导入包:

import torch
import torch.utils.data as Data

设置参数并创建数据:

Batch_size = 5

x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

将数据包装到TensorDataset中:

torch_dataset = Data.TensorDataset(x , y)

加载数据:

loader = Data.DataLoader(
    dataset = torch_dataset,
    batch_size = Batch_size,
    shuffle=True,
    num_workers = 2,  #采用两个进程来提取
)

epoch 3次,每次epoch的训练步数steps = 2【batch_size = 5,总数据量为10】:

若最后不够一个batch_size,就只拿剩下的。

for epoch in range(3):
    for step , (batch_x,batch_y) in enumerate(loader):
        #training……
        print('epoch:',epoch,
              '| step:',step,
              '| batch_x:',batch_x.numpy(),
              '| batch_y:',batch_y.numpy()    
             )

 

结果:

 

posted on 2018-12-18 20:21  吱吱了了  阅读(3915)  评论(0编辑  收藏  举报

导航