Mini-batch training
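When you have a large amount of data, you can feed it to the neural network in batches for training.
The batching code is as follows: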
import torch
import torch.utils.data as Data

if __name__ == '__main__':
    BATCH_SIZE = 5  # number of samples per batch
    x = torch.linspace(1, 10, 10)  # inputs: 1, 2, ..., 10
    y = torch.linspace(10, 1, 10)  # targets: 10, 9, ..., 1
    torch_dataset = Data.TensorDataset(x, y)  # pairs x[i] with y[i]
    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,   # reshuffle the data at every epoch
        num_workers=2,  # load batches with 2 worker subprocesses
    )
    for epoch in range(3):  # pass over the whole dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # enumerate adds a step index to each batch
            # placeholder for a real training step
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
Output:
Epoch: 0 | Step: 0 | batch x: [3. 1. 9. 6. 2.] | batch y: [ 8. 10. 2. 5. 9.]
Epoch: 0 | Step: 1 | batch x: [10. 8. 5. 4. 7.] | batch y: [1. 3. 6. 7. 4.]
Epoch: 1 | Step: 0 | batch x: [ 4. 2. 10. 6. 9.] | batch y: [7. 9. 1. 5. 2.]
Epoch: 1 | Step: 1 | batch x: [3. 8. 5. 1. 7.] | batch y: [ 8. 3. 6. 10. 4.]
Epoch: 2 | Step: 0 | batch x: [4. 8. 2. 6. 7.] | batch y: [7. 3. 9. 5. 4.]
Epoch: 2 | Step: 1 | batch x: [ 3. 10. 9. 1. 5.] | batch y: [ 8. 1. 2. 10. 6.]
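There are 10 data pairs and BATCH_SIZE=5, so the data is split into 10/5=2 batches. Each full pass over the data (one epoch) takes 2 steps, and 3 epochs were run in total.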
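The step count can also be read directly from the loader instead of being worked out by hand. The following is a minimal sketch, assuming the same 10-pair dataset as above. It uses `num_workers=0` only to keep the snippet short, and the names `steps_per_epoch`, `seen`, and `covers_all` are illustrative, not part of the original code.

```python
import torch
import torch.utils.data as Data

# Same toy data as in the example above.
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
dataset = Data.TensorDataset(x, y)

# num_workers=0 loads data in the main process (an assumption made for brevity).
loader = Data.DataLoader(dataset, batch_size=5, shuffle=True, num_workers=0)

# len(loader) is the number of batches produced in one pass over the dataset.
steps_per_epoch = len(loader)

# Collect every batch_x from one epoch and check that each sample appears exactly once.
seen = torch.cat([batch_x for batch_x, _ in loader])
covers_all = torch.equal(seen.sort().values, x)

print(steps_per_epoch, covers_all)
```

Shuffling changes the order in which samples appear, but each epoch should still visit every pair once.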
With BATCH_SIZE=8, the output becomes:
Epoch: 0 | Step: 0 | batch x: [ 4. 10. 2. 6. 5. 7. 9. 1.] | batch y: [ 7. 1. 9. 5. 6. 4. 2. 10.]
Epoch: 0 | Step: 1 | batch x: [8. 3.] | batch y: [3. 8.]
Epoch: 1 | Step: 0 | batch x: [10. 1. 4. 7. 8. 2. 5. 3.] | batch y: [ 1. 10. 7. 4. 3. 9. 6. 8.]
Epoch: 1 | Step: 1 | batch x: [9. 6.] | batch y: [2. 5.]
Epoch: 2 | Step: 0 | batch x: [10. 4. 8. 7. 5. 6. 3. 9.] | batch y: [1. 7. 3. 4. 6. 5. 8. 2.]
Epoch: 2 | Step: 1 | batch x: [1. 2.] | batch y: [10. 9.]
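There are still 10 data pairs, but with BATCH_SIZE=8 the division 10/8 gives 1 with a remainder of 2, which rounds up to 2 batches. The first batch holds 8 samples, and the second, which cannot be filled, holds the remaining 2. Each epoch therefore still takes 2 steps, and 3 epochs were run in total.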
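The arithmetic behind both cases can be sketched in plain Python. `batch_sizes` is a hypothetical helper written for illustration, not a PyTorch function. Its `drop_last` flag is meant to mirror the `DataLoader` option of the same name, which discards the incomplete final batch.

```python
import math


def batch_sizes(n_samples, batch_size, drop_last=False):
    """Return the size of each batch in one epoch (illustrative helper)."""
    full, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    # A non-zero remainder becomes a smaller final batch unless it is dropped.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


print(batch_sizes(10, 5))                  # [5, 5]
print(batch_sizes(10, 8))                  # [8, 2]
print(batch_sizes(10, 8, drop_last=True))  # [8]
print(math.ceil(10 / 8))                   # 2 steps per epoch when the last batch is kept
```

If a model or loss is sensitive to batch size, dropping the short final batch is one option, at the cost of skipping a few samples in each epoch.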