pytorch批训练数据构造
这是对莫凡python的学习笔记。
1.创建数据
import torch import torch.utils.data as Data BATCH_SIZE = 8 x = torch.linspace(1,10,10) y = torch.linspace(10,1,10)
可以看到创建了两个一维数据,x:1~10,y:10~1
2.构造数据集对象,及数据加载器对象
torch_dataset = Data.TensorDataset(x,y) loader = Data.DataLoader( dataset = torch_dataset, batch_size = BATCH_SIZE, shuffle = False, num_workers = 2)
num_workers应该指的是多线程
3.输出数据集,这一步主要是看一下batch长什么样子
for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): print('Epoch:',epoch,'| Step:', step, '| batch x:', batch_x.numpy(), '| batch y:', batch_y.numpy())
输出如下
('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32)) ('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32)) ('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32)) ('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32)) ('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32)) ('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
可以看到,batch_size等于8,则第二个bacth的数据只有两个。
将batch_size改为5,输出如下
('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6.], dtype=float32)) ('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32)) ('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6.], dtype=float32)) ('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32)) ('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6.], dtype=float32)) ('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))