联邦学习:联邦场景下的域泛化

1 导引

1.1 域泛化

域泛化(domain generalization, DG) [1][2]旨在从多个源域中学习一个能够泛化到未知目标域的模型。形式化地说,给定K个训练的源域数据集S={Skk=1,,K},其中第k个域的数据被表示为Sk={(xik,yik)}i=1nk。这些源域的数据分布各不相同:PXYkPXYl,1klK。域泛化的目标是从这K个源域的数据中学习一个具有强泛化能力的模型:h:XY,使其在一个未知的测试数据集T(即T在训练过程中不可访问且PXYTPXYk for k{1,,K})上具有最小的误差:

minhE(x,y)T[(h(x),y)]

这里E(,)分别为期望和损失函数。域泛化示意图如下图所示:

在对域泛化的理论分析方面,我们常常会在协变量偏移(即标签函数h或者说条件分布PYX在所有域中都相同)的假设下考虑特定目标域上的风险。设ϵ1,,ϵK为源域风险,ϵt为目标域风险。则在协变量偏移的假设下,每个域均可以通过数据X上的分布刻画,故域泛化的学习过程可以被认为是在源域分布的凸包Λ={k=1KπkPXkπΔK}内去找一个目标域分布PXt[22]的最优近似(优化变量π),其中ΔK(K1)维的单纯形,每个π表示一个归一化的混合权重。源域和目标域之间的差异可以通过Hdivergence来度量,Hdivergence同时包括了假设空间的影响。

域泛化的误差界γ:=minπΔMdH(PXt,k=1KπkPXk)为从凸包Λ到目标域特征分布PXt的距离,且PX:=k=1KπkPXk为在Λ内的最优近似(可以理解为PXt在凸包Λ中的投影)。设ρ:=supPX,PXΛdH(PX,PX)为凸包Λ的直径。则目标域T的风险ϵt(h)、源域k的风险ϵk(h)γρ之间满足如下的关系:

ϵt(h)k=1Kπkϵk(h)+γ+ρ2+λH,(PXt,PX)

这里λH,(PXt,PX)是目标域和最优近似分布PX的理想联合风险,在很多情况下我们假设它是一个极小的值,可以忽略不计。那么我们想要最小化目标域的风险,可以:

  • 最小化源域风险(对应上界的第一项);
  • 最小化源域和目标域之间的表征分布差异来在表征空间中减小γρ(对应上界的第二项)。

当然上述理论只是提供了一个视角,亦有文献[23]基于Mixup和领域不变表征学习提出了新的理论,他们的方法表明,域不变表征的Mixup本质上是在增大训练域的覆盖范围。还有许多学者进行了基于信息论[24]和对抗训练[22][24][25][26]的研究。

域泛化的方法 目前为了解决域泛化中的域偏移(domain shift) 问题,已经提出了许多方法,大致以分为下列三类:

  • 数据操作(data manipulation) 这种方法旨在通过数据增强(data augmentation)或数据生成(data generation)方法来丰富数据的多样性,从而辅助学习更有泛化能力的表征。其中数据增强方法常利用数据变换、对抗数据增强(adversarial data augmentation)[3]等手段来增强数据;数据生成方法则通过Mixup(也即对数据进行两两线性插值)[4]等手段来生成一些辅助样本。

  • 表征学习(representation learning) 这种方法旨在通过学习领域不变表征(domain-invariant representations),或者对领域共享(domain-shared)和领域特异(domain-specific)的特征进行特征解耦(feature disentangle),从而增强模型的泛化性能。该类方法我们在往期博客《寻找领域不变量:从生成模型到因果表征 》《跨域推荐:嵌入映射、联合训练和解耦表征》中亦有详细的论述。其中领域不变表征的学习手段包括了对抗学习[5]、显式表征对齐(如优化分布间的MMD距离)[6]等等,而特征解耦则常常通过优化含有互信息(信息瓶颈的思想)或KL散度[7]的损失项来达成,其中大多数会利用VAE等生成模型。

  • 学习策略(learning stategy) 这种方法包括了集成学习[8]、元学习[9]等学习范式。其中,以元学习为基础的方法则利用元学习自发地从构造的任务中学习元知识,这里的构造具体而言是指将源域数据集S按照域为单位来拆分成元训练(meta-train)部分S¯和元测试(meta-test)部分S˘以便对分布偏移进行模拟,最终能够在目标域T的final-test中取得良好的泛化表现。

1.2 联邦域泛化

然而,目前大多数域泛化方法需要将不同领域的数据进行集中收集。然而在现实场景下,由于隐私性的考虑,数据常常是分布式收集的。因此我们需要考虑联邦域泛化(federated domain generalization, FedDG) [21]方法。形式化的说,设S={S1,S2,,SK}表示在联邦场景下的K个分布式的源域数据集,每个源域数据集包含数据和标签对Sk={(xik,yik)}i=1nk,采样自域分布PXYk。联邦域泛化的目标是利用K个分布式的源域学习模型hθ:XY,该模型能够泛化到未知的测试域T。联邦域泛化的架构如下图所示:

这里需要注意的是,传统的域泛化方法常常要求直接对齐表征或操作数据,这在联邦场景下是违反数据隐私性的。此外对于跨域的联邦学习,由于客户端异构的数据分布/领域偏移(如不同的图像风格)所导致的模型偏差(bias),直接聚合本地模型的参数也会导致次优(sub-optimal)的全局模型,从而更难泛化到新的目标域。因此,许多传统域泛化方法在联邦场景下都不太可行,需要因地制宜进行修改,下面试举几例:

  • 对于数据操作的方法,我们常常需要用其它领域的数据来对某个领域的数据进行增强(或进行新数据的插值生成),而这显然违反了数据隐私。目前论文的解决方案是不直接传数据,而传数据的统计量来对数据进行增强[10],这里的统计量指图片的style(即图片逐通道计算的均值和方差)等等。

  • 对于表征学习的方法,也需要在对不同域的表征进行共享/对比的条件下获得领域不变表征(或对表征进行分解),而传送表征事实上也违反了数据隐私。目前论文采用的解决方案包括不显式对齐表征,而是使得所有领域的表征显式/隐式地对齐一个参考分布(reference distribution)[11][12],这个参考分布可以是高斯,也可以由GAN来自适应地生成。也有论文不直接对齐表征,而是对齐不同客户端的类别原型[15]

  • 基于学习策略的方法,如元学习也需要利用多个域的数据来构建meta-train和meta-test,并进行元更新(meta-update),而这也违反了数据隐私性。目前论文的解决方案是使用来自其它域的变换后数据来为当前域构造元学习数据集[13],这里的变换后数据指图像的幅度谱等等。此外,有的方法还针对联邦场景的特点,对联邦学习的策略如聚合方式等进行修改[16][18]

2 论文阅读

2.1 ICLR20 《Federated Adversarial Domain Adaptation》[14]

严格来说,本文属于联邦域自适应范畴(与域泛化的区别在于目标域在训练过程中可访问),不过其方法非常经典,对于联邦域泛化也有较强的指导意义,故在这里也记录一下。本篇论文采用了基于表征学习的方法。具体而言,本文采用对抗学习方法的方法来使得领域间的表征进行对齐,并进一步采用表征解耦来增强知识迁移。本文方法整体的架构如下图所示:

如上图所示,每个源域上都设置有特征提取器Gk,目标域T上亦设置有特征提取器GtGiGt都将做为GAN的生成器使用)。对于每个源域-目标域对(Sk,T),域识别器DI(做为GAN的判别器)负责去区分源域和目标域的表征,而生成器(Gk,Gt)则尽量去欺骗DI,从而以对抗的方式来完成表征分布的对齐。注意这里DI只能访问GiGt的输出表征,故并不违反联邦的隐私设置。事实上我们在博客《联邦学习:联邦场景下的多源知识图谱嵌入》中提到的联邦跨域知识图谱对齐方法也是基于GAN的思想。

接下来我们来看GAN是如何优化的。首先优化是判别器DIk。设第k个源域的数据为XSk,目标域数据为XT,则判别器DIk的目标函数定义如下:

LadvDIk(XSk,XT,Gk,Gt)=ExskXsk[logDIk(Gk(xsk))]ExtXt[log(1DIk(Gt(xt)))]

直观地理解,该目标函数使判别器将Gk产出的表征打高分,而将Gt产出的表征打低分,已完成对源域和目标域的表征对齐。

接下来,判别器DIk保持不动,按照下列目标函数来更新生成器GkGt(注意这里GkGt是在各自的计算节点上单独进行更新,这里为了方便写成一个目标函数):

LadvGk(XSk,XT,DIk)=ExskXsk[logDIk(Gk(xsk))]ExtXt[logDIk(Gt(xt))]

直观地理解,该目标函数使生成器GkGt产出的表征都获得较高的判别器得分,以欺骗判别器。

除了GAN模块之外,本文还设计了表征解耦模块,采用对抗性表征解耦来提取领域不变的特征,即将(Gi,Gt)提取到的特征进一步解耦为领域不变(domain-invariant)和领域特异(domain-specific)的特征。正如上面的框架图所示,解耦器Dk将提取到的特征分离为了fdi=Dk(Gk(xsk))(领域不变)和fds=Dk(Gk(xsk))(领域特异)这两个分支(branch)。

针对这两个branch的表征,作者首先设置一个分类器Ci与一个类识别器CIi来分别基于fdifds特征对标签进行预测,并采用下列的交叉熵损失函数进行训练:

Lcross-entropy =E(xsk,ysk)D^skc=1|C|1[c=ysk]log(Ck(fdi))E(xsk,ysk)D^skc=1|C|1[c=ysk]log(CIk(fds))

在下一步中,我们冻结类识别器CIk,并只训练特征解耦器Dk,通过生成领域特异的特征fds来欺骗类识别器CIk。而这可以通过最小化预测类别分布的负熵损失来达到,目标函数如下所示:

Lent=1nki=1nklogCIk(fds(i))=1nki=1nklogCIk(Dk(Gk(x(i))))

在这里,特征解耦通过保留fdi并消除fds来促进知识迁移。

最后,为了增强特征解耦,作者设计了一个互信息项来最小化领域不变特征fdi和领域特异特征fds之间的互信息I(fdi;fds),这里采用MINE来对互信息进行估计[20]

I(P;Q^)n=supθΘEPPO(n)[Tθ]log(EPP(n)P^O(n)[eTθ])=PPQn(p,q)T(p,q,θ)log(PPn(p)PQn(q)eT(p,q,θ))

关于互信息的上下界估计,大家可以参见我的博客《迁移学习:互信息的变分上下界》。为了避免计算积分,这里采用蒙特卡洛积分来计算该估计:

I(P,Q)=1ni=1nT(p,q,θ)log(1ni=1neT(p,q,θ))

2.2 CVPR21《FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space》[13]

本篇论文是联邦域泛化的第一篇工作。这篇论文属于基于学习策略(采用元学习)的域泛化方法,并通过传图像的幅度谱(amplitude spectrum),而非图像数据本身来构建本地的元学习任务,从而保证联邦场景下的数据隐私性。本文方法的框架示意图如下:

这里K为领域/客户端的个数。该方法使图像的低级特征——幅度谱在不同客户端间共享,而使高级语义特征——相位谱留在本地。这里再不同客户端间共享的幅度谱就可以作为多领域/多源数据分布供本地元学习训练使用。

接下来我们看本地的元学习部分。元学习的基本思想是通过模拟训练/测试数据集的领域偏移来学得具有泛化性的模型参数。而在本文中,本地客户端的领域偏移来自不同分布的频率空间。具体而言,对每轮迭代,我们考虑本地的原输入图片xik做为meta-train,它的训练搭档Tik则由来自其它客户端的频域产生,做为meta-test来表示分布偏移。

设客户端k中的图片xik由正向傅里叶变换F得到的幅度谱为AikRH×W×C,相位谱为PikRH×W×CC为图片通道数)。本文欲在客户端之间交换低级分布也即幅度谱信息,因此需要先构建一个供所有客户端共享的distribution bank A=[A1,,AK],这里Ak={Aik}i=1nk包含了来自第k个客户端所有图片的幅度谱信息,可视为代表了Xk的分布。

之后,作者通过在频域进行连续插值的手段,将distribution bank中的多源分布信息送到本地客户端。如上图所示,对于第k个客户端的图片幅度谱Aik,我们会将其与另外K1个客户端的幅度谱进行插值,其中与第l(lk)个外部客户端的图片幅度谱Aj插值的结果表示为:

Aikl=(1λ)Aik(1M)+λAjlM

这里M是一个控制幅度谱内低频成分比例的二值掩码,λ是插值率。然后以此通过反向傅里叶变换生成变换后的图片:

xikl=F1(Aikl,Pik)

就这样,对于第k个客户端的输入图片xik,我们就得到了属于不同分布的K1个变换后的图片数据Tik={xikl}lk,这些图片和xik共享了相同的语义标签。

接下来在元学习的每轮迭代中,我们将原始数据xik做为meta-train,并将其对应的K1个由频域产生的新数据Tik做为meta-test来表示分布偏移,从而完成在当前客户端的inner-loop的参数更新。

具体而言,元学习范式可以被分解为两步:

第一步 模型参数θk在meta-train上通过segmentaion Dice loss Lseg来更新:

θ^k=θkβθkLseg(xik;θk)

这里参数β表示内层更新的学习率。

第二步 在meta-test数据集Tik上使用元目标函数(meta objective)Lmeta对已更新的参数θ^k进行进一步元更新。

Lmeta=Lseg(Tik;θ^k)+γLboundary(xik,Tik;θ^k)

这里特别重要的是,第二步所要优化的目标函数由在第一部中所更新的参数θ^k计算,最终的优化结果覆盖掉原来的参数θk

如果我们将一二步合在一起看,则可以视为通过下面目标函数来一起优化关于参数θk的内层目标函数和元目标函数:

argminθk Lseg(xik;θk)+Lmeta(xik,Tik;θ^k)

最后,一旦本地训练完成,则来自所有客户端的本地参数θk会被服务器聚合并更新全局模型。

2.3 Arxiv21《Federated Learning with Domain Generalization 》[12]

本篇论文属于基于学习领域不变表征的域泛化方法,并通过使所有客户端的表征对齐一个由GAN自适应生成的参考分布,而非使客户端之间的表征互相对齐,来保证联邦场景下的数据隐私性。本文方法整体的架构如下图所示:

注意,这里所有客户端共享一个参考分布,而这通过共享同一个分布生成器(distribution generator)来实现。在训练过程一边使每个域(客户端)的数据分布会和参考分布对齐,一边最小化分布生成器的损失函数,使其产生的参考分布接近所有源数据分布的“中心”(这也就是”自适应“的体现)。一旦判别器很难区分从特征提取器中提取的特征和从分布生成器中所生成的特征,此时所提取的特征就被认为是跨多个源域不变的。这里的特征分布生成器的输入为噪声样本和标签的one-hot向量,它会按照一定的分布(即参考分布)生成特征。最后,作者还采用了随机投影层来使得判别器更难区分实际提取的特征和生成器生成的特征,使得对抗网络更稳定。在训练完成之后,参考分布和所有源域的数据分布会对齐,此时学得的特征表征被认为是通用(universal)的,能够泛化到未知的领域。

接下来我们来看GAN部分具体的细节。设F()为特征提取器,G()为分布生成器,D()为判别器。设由特征提取器所提取的特征h=F(x)(数据x的生成分布为p(h)),而由分布生成器所产生的特征为h=G(z)(噪声z的生成分布为p(h)。我们设特征提取器所提取的特征为负例,生成器所生成的特征为正例。

于是,我们可以将判别器的优化目标定义为使将特征提取器所生成的特征h判为正类的概率D(h|y)更小,而使将生成器所生成的特征h判为正类的概率D(h|y)更大。

Ladv_d=(Exp(h)[(1D(hy))2]+Ezp(h)[D(hy)2])

生成器尽量使判别器D()将其生成特征h判别为正类的概率D(hy)更大,以求以假乱真:

Ladvg=Ezp(h)[(1D(hy))2]

特征提取器也需要尽量使得其所生成的特征h能够以假乱真:

Ladv_f=Exp(h)[(1D(hy))2]

再加上图像分类本身的交叉熵损失Lerr,则总的损失定义为:

LFedADG=Ladv_d+Ladv_g+λ0Ladv_f+λ1Lerr

论文的最后,作者还对一个问题进行了探讨:关于这里的参考分布,我们为什么不用一个预先选好的确定的分布,要用一个自适应生成的分布呢?那是因为自适应生成的分布有一个重要的好处,那就是少对齐期间的失真(distortion)。作者对多个域/客户端的分布和参考分布进行了可视化,如下图所示:

(a)中为参考分布选择为固定的分布后,与各域特征对比的示意图,图(b)为参考分布选择为自适应生成的分布后,和各域特征对比的示意图。在这两幅图中,红色五角星表示参考分布的特征,除了五角星之外的每种形状代表一个域,每种颜色代表一个类别的样本。可以看到自适应生成的分布和多个源域数据分布的距离,相比固定参考分布和多个源域数据分布的距离更小,因此自适应生成的分布能够减少对齐期间提取特征表征的失真。而更好的失真也就意味着源域数据的关键信息被最大程度的保留,这让本文的方法所得到的表征拥有更好的泛化表现。

2.4 NIPS22 《FedSR: A Simple and Effective Domain Generalization Method for Federated Learning》[11]

本篇论文属于基于学习领域不变表征的域泛化方法,并通过使所有客户端的表征对齐一个高斯参考分布,而非使客户端之间的表征互相对齐,来保证联邦场景下的数据隐私性。本文的动机源于经典机器学习算法的思想,旨在学习一个“简单”(simple)的表征从而获得更好的泛化性能。

首先,作者以生成模型的视角,将表征z建模为从p(z|x)中的采样,然后在此基础上定义领域k的分类目标函数以学得表征:

fk¯(w)=Epk(x,y)[Ep(zx)[logp^(yz)]]1nki=1nklogp^(yk(i)zk(i))

这里领域k的样本表征zj(i)通过编码器+重参数化从p(z|xk(i))中采样产生。

接下来我们来看怎么使得表征更“简单”。本文采用了两个正则项,一个是关于表征的L2正则项来限制表征中所包含的信息;一个是在给定y的条件下,xz的条件互信息I(x,zy)(的上界)来使表征只学习重要的信息,而忽视诸如图片背景之类的伪相关性(spurious correlations)。

关于表征zL2正则项定义如下:

LkL2R=Epk(x)[Ep(zx)[z22]]1nki=1nkzk(i)22a

于是,上式的微妙之处在于可以和领域不变表征联系起来,事实上我们有LkL2R=Epk(x)[Ep(zx)[z22]]=Epk(x,z)[z22]=Epk(z)[z22]=2σ2Epk(z)[logq(z)]=2σ2H(pk(z),q(z)),这里H(pk(z),q(z))=H(pk(z))+DKL[pk(z)q(z)],参考分布q(z)=N(0,σ2I)。如果H(pi(z))在训练中并未发生大的改变,那么最小化lkL2R也就是在最小化DKL[pk(z)q(z)],也即在隐式地对齐一个参考的边缘分布q(z),而这就使得标准的边缘分布pk(z)是跨域不变的。注意该对齐是不需要显式地比较不同客户端分布的。

接下来我们来看条件互信息项。在信息瓶颈理论中,常对x和表征z之间的互信息项I(x,z)进行最小化以对z中所包含的信息进行加以正则,但是这样的约束在实践中如果系数没调整好,就很可能过于严格了,毕竟它迫使表征不包含数据的信息。因此,在这篇论文中,作者选择最小化给定yxz之间的条件互信息。领域k的条件互信息被计算为:

Ik(x,zy)=Epk(x,y,z)[logpk(x,zy)pk(xy)pk(zy)]

直观地看,f¯kIk(x,zy)共同作用,迫使表征z仅仅拥有预测标签y使所包含的信息,而没有关于x的额外(即和标签无关的)信息。

然而,这个互信息项是难解(intractable)的,这是由于计算pk(z|y)很难计算(由于需要对x进行积分将其边缘化消掉)。因此,作者导出了一个上界来对齐进行最小化:

LkCMI=Epk(x,y)[DKL[p(zx)r(zy)]]Ik(x,zy)

这里r(z|y)可以是一个输入y输出分布r(z|y)的神经网络,作者将其设置为高斯N(z;μy,σy2),这里uyσy2y=1,2,,C)是需要优化的神经网络参数,C是类别数量。

事实上,该正则项和域泛化中的条件分布对齐亦有着理论上的联系,这是因为LkCMI=Epk(x,y)[DKL[p(zx)r(zy)]]Epk(y)[DKL[pk(zy)r(zy)]]。因此,最小化LkCMI我们必然就能够最小化DKL[pk(zy)r(zy)](因为LkCMI是其上界),使得pk(z|y)r(z|y)互相接近,即:pk(z|y)r(z|y)。因此,模型会尝试迫使pk(zy)pl(zy)(r(zy))(对任意客户端/领域k,l)。这也就是说,我们是在做给定标签y时表征z的条件分布的隐式对齐,这在传统的领域泛化中是一种很常见与有效的技术,区别就是这里不需要显式地比较不同客户端的分布。

最后,每个客户端的总体目标函数可以表示为:

Lk=fk¯+αL2RLkL2R+αCMILkCMI

总结一下,这里L2范数正则项LkL2R和给定标签时数据和表征的条件互信息LkCMI(的上界)用于限制表征中所包含的信息。此外,LkL2R将边缘分布pk(z)对齐到一个聚集在0周围的高斯分布,而LiCMI则将条件分布pk(z|y)对齐到一个参考分布(在实验环节作者亦将其选择为高斯)。

2.5 WACV23 《Federated Domain Generalization for Image Recognition via Cross-客户端 Style Transfer》[10]

本篇论文属于基于数据操作的域泛化方法,并通过构造一个style bank供所有客户端共享(类似CVPR21那篇),以使客户端在不共享数据的条件下基于风格(style)来进行数据增强,从而保证联邦场景下的数据隐私性。本文方法整体的架构如下图所示:

如图所示,每个客户端的数据集都有自己的风格。且对于每个客户端而言,都会接受其余客户端的风格来进行数据增强。事实上,这样就可以使得分布式的客户端在不泄露数据的情况下拥有相似的数据分布 。在本方法中,所有客户端的本地模型都拥有一致的学习目标——那就是拟合来自于所有源域的styles,而这种一致性就避免了本地模型之间的模型偏差,从而避免了影响全局模型的效果。此外,本方法可和其它DG的方法结合使用,从而使得其它中心化的DG方法均能得到精度的提升。

关于本文采用的风格迁移模型,有下列要求:1、所有客户端共享的style不能够被用来对数据集进行重构,从而保证数据隐私性;2、用于风格迁移的方法需要是一个实时的任意风格迁移模型,以允许高效和直接的风格迁移。本文最终选择了AdaIN做为本地的风格迁移模型。整个跨客户端/领域风格迁移流程如下图所示:

可以看到,整个跨客户端/领域风格迁移流程分为了三个阶段:

1. Local style Computation

每个客户端需要计算它们的风格并上传到全局服务器。其中可选择单张图片风格(single image style)和整体领域风格(overall domain style )这两种风格来进行计算。

  • 单张图片风格 单张图片风格是图片VGG特征的像素级逐通道(channel-wise)均值和方差。比如我们设在第k个客户端上,随机选取的图片索引为i,其对应的VGG特征Fk(i)=Φ(Ik(i))(这里的Ik(i)表示图像内容,Φ为VGG的编码器),单张图片风格可以被计算为:

Sk(i)=(μ(Fk(i)),σ(Fk(i)))

如果单张图片风格被用于风格迁移,那么就需要将该客户端不同图片对应的多种风格都上传到服务器,从而避免单张图片的偏差并增加多样性。而这就需要建立本地图片的style bank Sksingle并将其上传到服务器。这里作者随机选择J张图像的style加入了本地style bank:

Sksingle={Sk(i1),,Sk(iJ)}

  • 整体领域风格 整体领域风格是领域层次的逐通道均值和方差,其中考虑了一个客户端中的所有图片。比如我们假设客户端k拥有Nk个训练图片和对应的VGG特征{Fk(1),Fk(2),,Fk(Nk)}。则该客户端的整体领域风格Skoverall为:

Skoverall=(μ(Fkall),σ(Fkall))Fkall=Stack(Fk(1),Fk(2),,Fk(Nk))

相比单张图片风格,整体领域风格的计算代价非常高。不过,由于每个客户端/领域只有一个领域风格Skoverall,选择上传整体领域风格到服务器的通信效率会更高。

2. Style Bank on Server

当服务器接收到来自各个客户端的风格时,它会将所有风格汇总为一个style bank B 并将其广播回所有客户端。在两种不同的风格共享模式下,style bank亦会有所不同。

  • 单图像风格的style bank B为:

Bsingle={Sksinglek=1,2,K}

  • 整体领域风格的style bank B为:

Boverall={Skoverallk=1,2,,K}

BsingleBoverall会消耗更多存储空间,因此后者会更加通信友好。

3. Local Style Transfer

当客户端k收到style bank B后,本地数据会通过迁移B中的风格来进行增强,而这就将其它领域的风格引入了当前客户端。作者设置了超参数L{1,2,,K}做为增强级别,意为从style bank B中随机选择L个域所对应的风格来对每个图片进行增强,因此L表明了增强数据集的多样性。设第k个客户端数据集大小为Nk,则在进行跨客户端的领域迁移之后,增强后数据集的大小会变为Nk×L。其中对客户端k中的每张图片Ik(i),其对应的每个被选中的域都会拥有一个style vectorS被作为图像生成器G的输入。这里关于style vector的获取有个细节需要注意:假设我们选了域k,如果迁移的是整体领域风格,则Skoverall直接即可做为style vector;如果迁移的是单图片风格,则还会进一步从选中Sksingle中随机选择一个风格Sk(i)做为域k的style vector。对以上两种风格模式而言,如果一个域被选中,则其对应的风格化图片就会被直接加入增强后的数据集中。

2.6 CVPR23 《Rethinking Federated Learning with Domain Shift: A Prototype View》[15]

本文属于考虑了领域漂移的异构联邦学习,而不属于域泛化,不过两个领域有很多相似之处,故在这里也记录一下。本篇论文采用了基于表征学习的方法。具体而言,本文采用原型学习的视角,设计了一种聚类原型学习方法来解决领域偏移问题。本文方法整体的架构如下图所示:

如上图所示,首先根据每个客户端属于类别cc[|C|]|C|为类别个数)的样本表征集合,来计算出在该客户端上各类别的原型:

pkc=1|Skc|(xi,yi)Skcfk(xi)

这里Skc为客户端k上属于类别c的样本集合。

之后,将每种类别对应的原型集合{pkc}k=1K分别聚为K类,以得到K个代表性的聚类原型:

Pc={pkc}k=1K Cluster {pkc}k=1KRK×dP={P1,,Pc,,P|C|}

在聚完类之后,再对每个类别所对应的K个聚类原型取平均,得到最终的无偏原型:

Uc=1Kk=1KpkcRdU=[U1,,Uc,,UC]

作者还画了一张图来解释为什么对类别原型进行聚类可以有效解决域偏移的问题:

如图所示,全局原型无法描述不同领域的信息,并且被潜在的主导领域所支配。而聚类原型和无偏原型则携带着多个域的知识和平稳的优化信号。

在得到了聚类原型和无偏原型之后,作者设计了对比正则项以拉近同类的聚类原型之间的距离,而增大不同类的聚类原型之间的距离:

LCPCL=logpPcexp(s(zi,p))pPcexp(s(zi,c))+pNcexp(s(zi,p))

这里表征zi=f(xi)xi为图片实例),其与对应类别的聚类原型p之间的相似度定义为:s(zi,c)=ziczi×c/τNc=PPc为类别不为c的聚类原型集合。

此外,作者还设计了一致化正则项来拉近表征zi和其对应类别的无偏原型Uc之间的距离:

LUPCR=v=1d(zi,vUvk)2

这里v用于索引表征的各个维度。

最后,将LCPCLLUPCR和图片分类任务本身的交叉熵LCE加起来,就得到了总的损失函数:

L=LCPCL+LUPCR+LCE

2.7 ArXiv23 《PerAda: Parameter-Efficient and Generalizable Federated Learning Personalization with Guarantees》[16]

本篇论文属于基于学习策略的域泛化方法。具体而言,本文为每个客户端设置了个性化的模型,并在服务器端增设了知识蒸馏过程以从各客户端聚合泛化信息。本文方法整体的架构如下图所示:

本文整体依据Ditto[17]的架构,在客户端本地设置个性化模型{vk}k=1K(本文称之为个性化适配器,personalized adapter),并将问题建模为如下的优化问题:

min{vk}1Kk=1KPk(vk,w), (Personal Obj)  with Pk(vk,w):=Lk((u,vk))+λ2vkw2

这里uRdu表示固定的预训练参数,且vk,wRda分别表示个性化适配器和全局适配器(global adapter)。

因此,本地客户端的优化也分两步走,先在本轮接收到的全局适配器wt的约束条件下,更新本地的个性化适配器:

vkt,s+1vkt,sηp(Lk((u,vkt,s),ξkt,s)+λ(vkt,swt))

这里ξkt,s为从本地训练集Dk中采样的batch数据。

然后再更新本地适配器(local adapter)θ

θkt,e+1θkt,eηlLk((u,θkt,e),ξkt,e)

同样地,这里ξkt,e也表示从本地训练集Dk中采样的batch数据。

本文的创新之处在于服务器端采用参数平均wtkSt1|St|θkt+1St为所采样的客户端子集)进行聚合之后,没有直接将聚合所得的全局适配器wt广播给客户端,而是继续采用知识蒸馏来对服务器端的全局适配器wt进行更新:

wt,r+1wt,rηgβwRKD(u,{θkt+1}kSt,wt,r,ξt,r)

这里ξkt,r也表示从本地训练集Dk中采样的batch数据,知识蒸馏损失RKD定义如下:

RKD(u,{θk}k=1K,w):=j=1nauxKD(k=1Kf((u,θk),xj)K,f((u,w),xj))

该损失也即在辅助(无标签)数据集Daux={xj}j=1naux上,本地适配器的平均logits和全局适配器的logits的平均蒸馏损失。这里KD(a,b)=KL(σ(a/τ),σ(b/τ))为KL散度损失(σ为softmax函数,τ为温度)。

2.8 CVPR23 《Federated Domain Generalization with Generalization Adjustment》[18]

本篇论文属于基于学习策略的域泛化方法。具体而言,本文为联邦域泛化问题设计了一种新的目标函数,该目标函数考虑到了各客户端上泛化差距(generalization gap)的方差,从而保证了在所有领域上最优全局模型的平坦性(flatness)。在思想脉络上,本文是由解决OOD问题的经典方法《Out-of-distribution generalization via risk extrapolation》[19]得到的启发。本文方法整体的架构如下图所示:

如上图所示,相比普通FedAvg方法直接对各领域模型按照样本比例p1,p2,,p3(训练中固定)进行加权聚合,本文的GA方法按照可学习的权重αi来聚合各领域模型。此外,本文方法的目标函数中还带有一个公平性(fairness)正则化项Var(),可通过动态校准聚合权重来进行优化。

本文方法的全局目标函数如下:

minθ1,,θK,aE^D^(θ)=k=1KakE^D^k(θ)+βVar({GD^k(θ)}k=1K) s.t. k=1Kak=1,θ=k=1Kakθk, and k,ak0

这里a为可学习的客户端/域的聚合权重,而β[0,)用于控制在减少全局经验风险(即k=1KakE^D^k(θ))与加强泛化差距公平性(即Var())之间的平衡。当β=0时退化为普通的FedAvg算法,当β时将仅仅去使得泛化差距相等。

GD^k则计算的是全局模型θ和本地模型θk之间的泛化差距(定义为全局模型在本地训练集上的经验风险-本地模型在本地训练集上的经验风险),定义如下:

GD^k(θ)=E^D^k(θ)E^D^k(θk)

那么以上仅仅是给出了目标函数,具体是如何计算求解的呢?具体到每轮迭代上,本地客户端从服务器端接受全局模型θt后,先计算全局模型θt与本地模型θkt的泛化差距GD^k(θt);然后完成本地参数更新得到θkt+1,并计算本地模型的经验损失E^D^k(θkt+1)(留给下一轮迭代计算GD^k(θt)用)。最后,客户端将计算好的泛化差距GD^k(θt)与更新后的本地模型θkt+1发往服务器端。

而服务端则先通过{GD^k(θt)}k=1K与上一轮的聚合权重at来计算更新后的聚合权重at+1(动量更新),并对其进行归一化:

akt+1=GA(akt,{GD^k(θt)}k=1K,dt)=(GD^k(θt)μ)dtmaxl(GD^l(θt)μ)+aktakt+1=akt+1l=1Kalt+1

这里μ=1Kk=1KGD^k(θt),且dt=(1t/T)d,d(0,1)是一个控制每次更新幅值的超参数,可以被视为目标函数中β的替代。

然后服务器端再通过at+1来聚合{θkt+1}kK以得到最新的全局模型:

θt+1=k=1Kakt+1θkt+1

参考

  • [1] Wang J, Lan C, Liu C, et al. Generalizing to unseen domains: A survey on domain generalization[J]. IEEE Transactions on Knowledge and Data Engineering, 2022.
  • [2] 王晋东,陈益强. 迁移学习导论(第2版)[M]. 电子工业出版社, 2022.
  • [3] Volpi R, Namkoong H, Sener O, et al. Generalizing to unseen domains via adversarial data augmentation[C]. Advances in neural information processing systems, 2018, 31.
  • [4] Zhou K, Yang Y, Qiao Y, et al. Domain generalization with mixstyle[C]. ICLR, 2021.
  • [5] Li H, Pan S J, Wang S, et al. Domain generalization with adversarial feature learning[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 5400-5409.
  • [6] Li Y, Gong M, Tian X, et al. Domain generalization via conditional invariant representations[C]//Proceedings of the AAAI conference on artificial intelligence. 2018, 32(1).
  • [7] Ilse M, Tomczak J M, Louizos C, et al. Diva: Domain invariant variational autoencoders[C]//Medical Imaging with Deep Learning. PMLR, 2020: 322-348.
  • [8] Qin X, Wang J, Chen Y, et al. Domain Generalization for Activity Recognition via Adaptive Feature Fusion[J]. ACM Transactions on Intelligent Systems and Technology, 2022, 14(1): 1-21.
  • [9] Li D, Yang Y, Song Y Z, et al. Learning to generalize: Meta-learning for domain generalization[C]//Proceedings of the AAAI conference on artificial intelligence. 2018, 32(1).
  • [10] Chen J, Jiang M, Dou Q, et al. Federated Domain Generalization for Image Recognition via Cross-Client Style Transfer[C]//Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. 2023: 361-370.
  • [11] Nguyen A T, Torr P, Lim S N. Fedsr: A simple and effective domain generalization method for federated learning[J]. Advances in Neural Information Processing Systems, 2022, 35: 38831-38843.
  • [12] Zhang L, Lei X, Shi Y, et al. Federated learning with domain generalization[J]. arXiv preprint arXiv:2111.10487, 2021.
  • [13] Liu Q, Chen C, Qin J, et al. Feddg: Federated domain generalization on medical image segmentation via episodic learning in continuous frequency space[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021: 1013-1023.
  • [14] Peng X, Huang Z, Zhu Y, et al. Federated adversarial domain adaptation[J]. arXiv preprint arXiv:1911.02054, 2019.
  • [15] Huang W, Ye M, Shi Z, et al. Rethinking federated learning with domain shift: A prototype view[C]//2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2023: 16312-16322.
  • [16] Xie C, Huang D A, Chu W, et al. PerAda: Parameter-Efficient and Generalizable Federated Learning Personalization with Guarantees[J]. arXiv preprint arXiv:2302.06637, 2023.
  • [17] Li T, Hu S, Beirami A, et al. Ditto: Fair and robust federated learning through personalization[C]//International Conference on Machine Learning. PMLR, 2021: 6357-6368.
  • [18] Zhang R, Xu Q, Yao J, et al. Federated domain generalization with generalization adjustment[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023: 3954-3963.
  • [19] Krueger D, Caballero E, Jacobsen J H, et al. Out-of-distribution generalization via risk extrapolation (rex)[C]//International Conference on Machine Learning. PMLR, 2021: 5815-5826.
  • [20] Belghazi M I, Baratin A, Rajeshwar S, et al. Mutual information neural estimation[C]//International conference on machine learning. PMLR, 2018: 531-540.
  • [21] Li Y, Wang X, Zeng R, et al. Federated Domain Generalization: A Survey[J]. arXiv preprint arXiv:2306.01334, 2023.
  • [22] Albuquerque I, Monteiro J, Falk T H, et al. Adversarial target-invariant representation learning for domain generalization[J]. arXiv preprint arXiv:1911.00804, 2019, 8.
  • [23] Lu W, Wang J, Yu H, et al. FIXED: Frustratingly Easy Domain Generalization with Mixup[J]. arXiv preprint arXiv:2211.05228, 2022.
  • [24] Ye H, Xie C, Cai T, et al. Towards a theoretical framework of out-of-distribution generalization[J]. Advances in Neural Information Processing Systems, 2021, 34: 23519-23531.
  • [25] Deshmukh A A, Lei Y, Sharma S, et al. A generalization error bound for multi-class domain generalization[J]. arXiv preprint arXiv:1905.10392, 2019.
  • [26] Sicilia A, Zhao X, Hwang S J. Domain adversarial neural networks for domain generalization: When it works and how to improve[J]. Machine Learning, 2023: 1-37.
posted @   orion-orion  阅读(1377)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~
历史上的今天:
2022-05-13 Python:conda install 和pip install的区别
2022-05-13 用Docker打包Python运行环境
点击右上角即可分享
微信分享提示