PyTorch批训练
PyTorch提供了一种将数据包装起来进行批训练的工具——DataLoader。使用的时候,只需要将我们的数据首先转换为torch的tensor形式,再转换成torch可以识别的Dataset格式,然后将Dataset放入DataLoader中就可以啦。
TensorDataset
classtorch.utils.data.TensorDataset(data_tensor,target_tensor)
TensorDataset类用来将样本及其标签打包成torch的Dataset,data_tensor,和target_tensor都是tensor。
DataLoaderwww.weixiu3721.com
classtorch.utils.data.DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,num_workers=0,collate_fn=<function default_collate>,pin_memory=False,drop_last=False)
dataset就是Torch的Dataset格式的对象;batch_size即每批训练的样本数量,默认为;shuffle表示是否需要随机取样本;num_workers表示读取样本的线程数。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步