pytorch不定长数据的dataloader读取
参考资料:
https://pytorch.org/docs/stable/data.html#dataloader-collate-fn
https://blog.csdn.net/anshiquanshu/article/details/112868740
在使用Pytorch深度学习框架的时候,一定绕不开的就是dataset和dataloader,后者依赖于前者,并给出了高效加载数据的解决方案(多线程,batch训练等)。
以RGB图片为例,dataset出来的数据形状是(3, H, W),而dataloader出来的数据形状是(batch_size, 3, H, W)。很明显,多了一维即batch维度。这显然是dataloader将数据给“叠”了起来。事实上,dataloader是有一个参数为collate_fn的,它的默认值是None,即当你在使用dataloader并不指定collate_fn的时候,实际上pytorch调用了默认的collate_fn函数,将数据“叠”起来之后再交给你。
然而,当你的数据是不定长的数据的时候,它就没有办法成功把数据叠起来,比如我就遇到了如下报错:
RuntimeError: stack expects each tensor to be equal size, but got [2, 4] at entry 0 and [5, 4] at entry 1
一个数据长度为2,一个数据长度为5,显然无法直接stack?此时在面对不定长数据的时候就需要自定义collate_fn进行填充了。譬如,pytorch文档上有这么一段话:
A custom collate_fn can be used to customize collation, e.g., padding sequential data to max length of a batch.
那么,如何自定义一个collate_fn?这个collate_fn的输入和输出又是什么?我们来看一下这个例子:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | def padding_collate_fn(data_batch): batch_bbox_list = [item[ 'bbox' ] for item in data_batch] batch_label_list = [item[ 'label' ] for item in data_batch] batch_filename_list = [item[ 'filename' ] for item in data_batch] padding_bbox = pad_sequence(batch_bbox_list, batch_first = True , padding_value = 0 ) padding_label = pad_sequence(batch_bbox_list, batch_first = True , padding_value = 5 ) result = dict () result[ "bbox" ] = padding_bbox result[ "label" ] = padding_label result[ "filename" ] = batch_filename_list return result |
首先我原始的dataset输出是一个字典,上述代码就是把字典中的值取出来再叠起来,最后放到大字典中返回。其中pad_sequence这个函数在torch.nn.utils.rnn这个包里,很好用。
实际上,batch就是你的dataset[index] ~ dataset[index + batch_size] 构成的列表,知道这一点后问题就迎刃而解了。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· .NET Core 中如何实现缓存的预热?
· 三行代码完成国际化适配,妙~啊~
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?