PyTorch: batch_train
import torch
import torch.utils.data as Data

torch.manual_seed(1)  # reproducible: fixes the shuffle order across runs

# 10 samples in batches of 5 gives exactly 2 full batches per epoch.
# Try 8 instead: 10 is not divisible by 8, so each epoch ends with a
# smaller batch of 2 (DataLoader keeps the partial batch by default).
BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)  # x data: 1., 2., ..., 10.
y = torch.linspace(10, 1, 10)  # y data: 10., 9., ..., 1.

torch_dataset = Data.TensorDataset(x, y)  # item i is the pair (x[i], y[i])
loader = Data.DataLoader(
    dataset=torch_dataset,  # torch TensorDataset format
    batch_size=BATCH_SIZE,  # mini-batch size
    shuffle=True,  # reshuffle the sample order at every epoch
    num_workers=2,  # subprocesses for loading data
)


def show_batch(epochs=3, data_loader=None):
    """Iterate over a DataLoader for several epochs and print every batch.

    Args:
        epochs: Number of full passes over the dataset (default 3).
        data_loader: Iterable yielding ``(batch_x, batch_y)`` tensor pairs.
            Defaults to the module-level ``loader``.
    """
    if data_loader is None:
        data_loader = loader
    for epoch in range(epochs):  # one epoch = one full pass over the data
        for step, (batch_x, batch_y) in enumerate(data_loader):
            # A real training step would go here; this demo only prints.
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())


# The guard matters because num_workers > 0: on platforms that start worker
# processes with "spawn", each worker re-imports this module, and the
# iteration must not run again inside the workers.
if __name__ == '__main__':
    show_batch()