collate_fn的应用教程
作用
collate_fn:即用于collate的function,用于整理数据的函数。
说到整理数据,你当然要会用数据,即会用数据制作工具torch.utils.data.Dataset,虽然我们今天谈的是torch.utils.data.DataLoader。
collate_fn笼统的说就是用于整理数据,通常我们不需要使用,其应用的情形是:各个数据长度不一样的情况,比如第一张图片大小是2828,第二张是5050,这样的话就如果不自己写collate_fn,而使用默认的,就会报错。
基础
dataset
我们必须先看看torch.utils.data.Dataset如何使用,以一个例子为例:
import torch.utils.data as Data class mydataset(Data.Dataset): def __init__(self,train_inputs,train_targets):#必须有 super(mydataset,self).__init__() self.inputs=train_inputs self.targets=train_targets def __getitem__(self, index):#必须重写 return self.inputs[index],self.targets[index] def __len__(self):#必须重写 return len(self.targets)
#构造训练数据 datax=torch.randn(4,3)#构造4个输入 datay=torch.empty(4).random_(2)#构造4个标签
#制作dataset dataset=mydataset(datax,datay)
下面,可以对dataset进行一系列操作,这些操作返回的结果和你之前那个class的三个函数定义都息息相关。我想说,那三个函数非常自由,你想怎么定义就怎么定义,上述只是一种常见的而已,你可以定制一个特色的。
len(dataset)#调用了你上面定义的def __len__()那个函数 #4
上面的输出结果和你的定义有关,比如你完全可以把def getitem()改成:
def __getitem__(self, index): return self.inputs[index]#不输出标签
那么,
dataset[0]#此时当然变化。 #tensor([-1.1426, -1.3239, 1.8372])
dataloader
torch.utils.data.DataLoader
dataloader=Data.DataLoader(dataset,batch_size=2)
一共有4条数据,batch_size=2,所以一共有2个batch。
collate_fn如果你不指定,会调用pytorch内部的,也就是说这个函数是一定会调用的,而且调用这个函数时pytorch会往这个函数里面传入一个参数batch。
def my_collate(batch): return xxx
这个batch是什么?这个东西和你定义的dataset, batch_size息息相关。batch是一个列表[x, ... , x],长度就是batch_size,里面每一个元素是dataset的某一个元素,即dataset[i]。
在我们的例子中,由于我们没有对dataloader设置需要打乱数据即shuffle=True,那么第1个batch就是前两个数据,如下:
print(datax) print(datay) batch=[dataset[0],dataset[1]] # 所以才说和你dataset中get_item的定义有关。 print(batch)
对,你没有看错,上述代码展示的batch就会传入到pytorch默认的collate_fn中,然后经过默认的处理,输出如下:
it=iter(dataloader) nex=next(it)#我们展示第一个batch经过collate_fn之后的输出结果 print(nex)
其实,上面就是我们常用的,经典的输出结果,即输入和标签是分开的,第一项是输入tensor,第二项是标签tensor,输入的维度变成了(batch_size,input_size)。
但是我们乍一看,将第一个batch变成上述输出结果很容易呀,我们也会!我们下面就来自己写一个collate_fn实现这个功能。
# a simple custom collate function, just to show the idea # `batch` is a list of tuple where first element is input tensor and the second element is corresponding label def my_collate(batch): inputs=[data[0].tolist() for data in batch] target = torch.tensor([data[1] for data in batch]) return [data, target]
dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
print(datax) print(datay)
it=iter(dataloader) nex=next(it) print(nex)
这不就和默认的collate_fn的输出结果一样了嘛!无非就是默认的还把输入变成了tensor,标签变成了tensor,我上面是列表,我改就是了嘛!如下:
def my_collate(batch): inputs=[data[0].tolist() for data in batch] inputs=torch.tensor(inputs) target =[data[1].tolist() for data in batch] target=torch.tensor(target) return [inputs, target]
dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
it=iter(dataloader) nex=next(it) print(nex)
给大家的一个经验就是,一般dataset是不会报错的,而是根据dataset制作dataloader的时候容易报错,因为默认collate_fn把dataset的类型限制得比较死。
应用情形
假设我们还是4个输入,但是维度不固定的。
a=[[1,2],[3,4,5],[1],[3,4,9]] b=[1,0,0,1] dataset=mydataset(a,b) dataloader=Data.DataLoader(dataset,batch_size=2) it=iter(dataloader) nex=next(it) nex
使用默认的collate_fn,直接报错,要求相同维度。
这个时候,我们可以使用自己的collate_fn,避免报错。
不过话说回来,我个人感受是:
在这里避免报错好像也没有什么用,因为大多数的神经网络都是定长输入的,而且很多的操作也要求相同维度才能相加或相乘,所以:这里不报错,后面还是报错。如果后面解决这个问题的方法是:在不足维度上进行补0操作,那么我们为什么不在建立dataset之前先补好呢?所以,collate_fn这个东西的应用场景还是有限的。
https://www.jb51.net/article/237011.htm
https://mp.weixin.qq.com/s/Uc2LYM6tIOY8KyxB7aQrOw
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性