Pytorch基础(5)——批数据训练
一、知识点:
-
相关包:torch.utils.data
import torch
import torch.utils.data as Data
-
包装数据类:TensorDataset
【包装数据和目标张量的数据集,通过沿着第一个维度索引两个张量来】
class torch.utils.data.TensorDataset(data_tensor, target_tensor)
#data_tensor (Tensor) - 包含样本数据
#target_tensor (Tensor) - 包含样本目标(标签)
-
加载数据类:DataLoader
【数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。】
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
#num_workers (int, optional) – 用多少个子进程加载数据
#drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
二、利用torch.utils.data进行批数据训练:
导入包:
import torch
import torch.utils.data as Data
设置参数并创建数据:
Batch_size = 5
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
将数据包装到TensorDataset中:
torch_dataset = Data.TensorDataset(x , y)
加载数据:
loader = Data.DataLoader(
dataset = torch_dataset,
batch_size = Batch_size,
shuffle=True,
num_workers = 2, #采用两个进程来提取
)
epoch 3次,每次epoch的训练步数steps = 2【batch_size = 5,总数据量为10】:
若最后不够一个batch_size,就只拿剩下的。
for epoch in range(3):
for step , (batch_x,batch_y) in enumerate(loader):
#training……
print('epoch:',epoch,
'| step:',step,
'| batch_x:',batch_x.numpy(),
'| batch_y:',batch_y.numpy()
)
结果: