多任务学习中的数据分布问题

今天这个专题源于我在做联邦/分布式多任务学习实验时在选取数据集的时候的疑惑,以下我们讨论多任务学习中(尤其是在分布式的环境下)如何选择数据集和定义任务。

多任务学习最初的定义是:"多任务学习是一种归纳迁移机制,基本目标是提高泛化性能。多任务学习通过相关任务训练信号中的领域特定信息来提高泛化能力,利用共享表示采用并行训练的方法学习多个任务"。然而其具体实现手段却有许多(如基于神经网络的和不基于神经网络的,这也是容易让人糊涂的地方),但是不管如何,其关键点——共享表示是核心。

1.经典(非神经网络的)多任务学习

经典(非神经网络的)多任务学习我们已经在博文《基于正则化的多任务学习》中详细讨论,此处不再赘述。在这种模式中给定T个学习任务{Tt}t=1T,每个任务各对应一个数据集Dt={(xti,yti)i=1mt}(其中xtiRdytiR),然后根据根据T个任务的训练集学习T个函数{ft(x)}t=1T。在这种模式下,每个任务的模型假设(比如都是线性函数)都常常是相同,导致每个任务的模型(权重)不同的原因归根结底在于每个任务的数据集不同(每个任务的损失函数默认相同,但其实可同可不同)。 此模式优化的目标函数可以写作:

(1)minWt=1TE(xti,yti)Dt[L(yti,f(xti;wt))]+λg(W)=t=1T[1mti=1mtL(yti,f(xti;wt))]+λg(W)

(此处W=(w1,w2,...,wT)为所有任务参数构成的矩阵,g(W)编码了任务的相关性)

而联邦/分布式多任务学习,采用的数据分布假设也大多来自这种情况(参见我的博客《分布式多任务学习及联邦学习个性化》)。

2.基于神经网络的多任务学习中的数据分布

基于神经网络的多任务学习(也就是大多数在CV、NLP)中使用的那种,分类和定义其实非常会乱,下面我们来看其中的一些常见方式。

2.1 同样的输入数据,不同的loss

大多数基于神经网络的多任务学习采用的方式是各任务基于同样的输入数据(或者可以看做将不同任务的数据混在一起使用),用不同的loss定义不同任务的。

如CV中使用的深度关系多任务学习模型:

CV多任务学习

NLP中的Joint learning:

NLP多任务学习

推荐系统中的用户序列多任务模型:

NLP多任务学习

2.2 不同的输入数据,不同的loss

我们也可以保持共享表示层这一关键特性不变,但是每个任务有不同的输入数据和不同的loss,如下图所示:
NLP多任务学习
在这种架构中,Input x表示不同任务的输入数据,绿色部分表示不同任务之间共享的层,紫色表示每个任务特定的层,Task x表示不同任务对应的损失函数层。在多任务深度网络中,低层次语义信息的共享有助于减少计算量,同时共享表示层可以使得几个有共性的任务更好的结合相关性信息,任务特定层则可以单独建模任务特定的信息,实现共享信息和任务特定信息的统一。

(注意,在深度网络中,多任务的语义信息还可以从不同的层次输出,例如GoogLeNet中的两个辅助损失层。另外一个例子比如衣服图像检索系统,颜色这类的信息可以从较浅层的时候就进行输出判断,而衣服的样式风格这类的信息,更接近高层语义,需要从更高的层次进行输出,这里的输出指的是每个任务对应的损失层的前一层。)

2.3 不同的输入数据,相同的loss

我们想一下,每个任务对应不同的输入数据,相同的loss的情况。比如我们同一个图像分类网络和交叉熵损失,但一个任务的数据集是男人和女人,一个任务数据集是人和狗,我们将这两个数据集进行联合学习,这是否算是多任务学习?如果是,是否能同时提升人-人分类器的精度和人-狗分类器的精度?(如下图所示)
NLP多任务学习

第一个问题,按照经典多任务学习的分类,这种应该是算的,因为每个任务的数据集不同,直接导致了学得的模型不同,又由于有共享表示这一关键特性,也可以算是多任务学习。至于第二个问题,我觉得是可以的,因为这两个任务虽然数据集不同,但是是互相关联的,比如人的话可能会检测头发,狗的话可能会检测耳朵,但是都有一个检测局部特征的相似性在里面。

posted @   orion-orion  阅读(902)  评论(1编辑  收藏  举报
编辑推荐:
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· .NET Core 托管堆内存泄露/CPU异常的常见思路
· PostgreSQL 和 SQL Server 在统计信息维护中的关键差异
阅读排行:
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~
点击右上角即可分享
微信分享提示