pytorch实现批训练
代码:
#进行批训练 import torch import torch.utils.data as Data BATCH_SIZE = 5 #每批5个数据 if __name__ == '__main__': x = torch.linspace(1, 10, 10) #x是从1到10共10个数据 y = torch.linspace(10, 1, 10) #y是从10到1共10个数据 #torch_dataset = Data.TensorDataset(data_tensor = x, target_tensor=y)会报错 torch_dataset = Data.TensorDataset(x,y) loader = Data.DataLoader( #使我们的训练变成一小批一小批的 dataset = torch_dataset, #将所有数据放入dataset中 batch_size= BATCH_SIZE, shuffle=True, #true训练的时候随机打乱数据,false不打乱 num_workers=2, #每次训练用两个线程或进程进行提取 ) for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): #利用enumerate可以同时获得索引(step)和值 print('Epoch:', epoch, '| Step:', step, '| batch_x:', batch_x.numpy(), '| batch_y:', batch_y.numpy())
过程中遇到了问题,问题及解决办法都在https://blog.csdn.net/thunderf/article/details/94733747
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步