联邦学习:联邦场景下的域泛化
1 导引
1.1 域泛化
域泛化(domain generalization, DG) [1][2]旨在从多个源域中学习一个能够泛化到未知目标域的模型。形式化地说,给定\(K\)个训练的源域数据集\(\mathcal{S}=\left\{\mathcal{S}^k \mid k=1, \cdots, K\right\}\),其中第\(k\)个域的数据被表示为\(\mathcal{S}^k = \left\{\left(x_i^k, y_i^k\right)\right\}_{i=1}^{n^k}\)。这些源域的数据分布各不相同:\(P_{X Y}^k \neq P_{X Y}^l, 1 \leq k \neq l \leq K\)。域泛化的目标是从这\(K\)个源域的数据中学习一个具有强泛化能力的模型:\(h: \mathcal{X}\rightarrow \mathcal{Y}\),使其在一个未知的测试数据集\(\mathcal{T}\)(即\(\mathcal{T}\)在训练过程中不可访问且\(P_{X Y}^{\mathcal{T}} \neq P_{X Y}^k \text { for } k \in\{1, \cdots, K\}\))上具有最小的误差:
这里\(\mathbb{E}\)和\(\ell(\cdot, \cdot)\)分别为期望和损失函数。域泛化示意图如下图所示:
在对域泛化的理论分析方面,我们常常会在协变量偏移(即标签函数\(h^*\)或者说条件分布\(P_{Y\mid X}\)在所有域中都相同)的假设下考虑特定目标域上的风险。设\(\epsilon^1, \cdots, \epsilon^K\)为源域风险,\(\epsilon^t\)为目标域风险。则在协变量偏移的假设下,每个域均可以通过数据\(\mathcal{X}\)上的分布刻画,故域泛化的学习过程可以被认为是在源域分布的凸包\(\Lambda=\{\sum_{k=1}^K\pi_kP^k_X \mid \pi \in \Delta_K\}\)内去找一个目标域分布\(P^t_X\)[22]的最优近似(优化变量\(\pi\)),其中\(\Delta_K\)是\((K - 1)\)维的单纯形,每个\(\pi\)表示一个归一化的混合权重。源域和目标域之间的差异可以通过\(\mathcal{H}-\text{divergence}\)来度量,\(\mathcal{H}-\text{divergence}\)同时包括了假设空间的影响。
域泛化的误差界 设\(\gamma:=\min _{\pi \in \Delta_M} d_{\mathcal{H}}\left(P_X^t, \sum_{k=1}^K \pi_k P_X^k\right)\)为从凸包\(\Lambda\)到目标域特征分布\(P^t_X\)的距离,且\(P_X^*:=\sum_{k=1}^K \pi_k^* P_X^k\)为在\(\Lambda\)内的最优近似(可以理解为\(P^t_X\)在凸包\(\Lambda\)中的投影)。设\(\rho:=\sup _{P_X^{\prime}, P_X^{\prime \prime} \in \Lambda} d_{\mathcal{H}}\left(P_X^{\prime}, P_X^{\prime \prime}\right)\)为凸包\(\Lambda\)的直径。则目标域\(\mathcal{T}\)的风险\(\epsilon^t(h)\)、源域\(k\)的风险\(\epsilon^k(h)\)与\(\gamma\)、\(\rho\)之间满足如下的关系:
这里\(\lambda_{\mathcal{H},\left(P_X^t, P_X^*\right)}\)是目标域和最优近似分布\(P^*_X\)的理想联合风险,在很多情况下我们假设它是一个极小的值,可以忽略不计。那么我们想要最小化目标域的风险,可以:
- 最小化源域风险(对应上界的第一项);
- 最小化源域和目标域之间的表征分布差异来在表征空间中减小\(\gamma\)和\(\rho\)(对应上界的第二项)。
当然上述理论只是提供了一个视角,亦有文献[23]基于Mixup和领域不变表征学习提出了新的理论,他们的方法表明,域不变表征的Mixup本质上是在增大训练域的覆盖范围。还有许多学者进行了基于信息论[24]和对抗训练[22][24][25][26]的研究。
域泛化的方法 目前为了解决域泛化中的域偏移(domain shift) 问题,已经提出了许多方法,大致以分为下列三类:
-
数据操作(data manipulation) 这种方法旨在通过数据增强(data augmentation)或数据生成(data generation)方法来丰富数据的多样性,从而辅助学习更有泛化能力的表征。其中数据增强方法常利用数据变换、对抗数据增强(adversarial data augmentation)[3]等手段来增强数据;数据生成方法则通过Mixup(也即对数据进行两两线性插值)[4]等手段来生成一些辅助样本。
-
表征学习(representation learning) 这种方法旨在通过学习领域不变表征(domain-invariant representations),或者对领域共享(domain-shared)和领域特异(domain-specific)的特征进行特征解耦(feature disentangle),从而增强模型的泛化性能。该类方法我们在往期博客《寻找领域不变量:从生成模型到因果表征 》和《跨域推荐:嵌入映射、联合训练和解耦表征》中亦有详细的论述。其中领域不变表征的学习手段包括了对抗学习[5]、显式表征对齐(如优化分布间的MMD距离)[6]等等,而特征解耦则常常通过优化含有互信息(信息瓶颈的思想)或KL散度[7]的损失项来达成,其中大多数会利用VAE等生成模型。
-
学习策略(learning stategy) 这种方法包括了集成学习[8]、元学习[9]等学习范式。其中,以元学习为基础的方法则利用元学习自发地从构造的任务中学习元知识,这里的构造具体而言是指将源域数据集\(\mathcal{S}\)按照域为单位来拆分成元训练(meta-train)部分\(\bar{\mathcal{S}}\)和元测试(meta-test)部分\(\breve{\mathcal{S}}\)以便对分布偏移进行模拟,最终能够在目标域\(\mathcal{T}\)的final-test中取得良好的泛化表现。
1.2 联邦域泛化
然而,目前大多数域泛化方法需要将不同领域的数据进行集中收集。然而在现实场景下,由于隐私性的考虑,数据常常是分布式收集的。因此我们需要考虑联邦域泛化(federated domain generalization, FedDG) [21]方法。形式化的说,设\(\mathcal{S}=\left\{\mathcal{S}^1, \mathcal{S}^2, \ldots, \mathcal{S}^K\right\}\)表示在联邦场景下的\(K\)个分布式的源域数据集,每个源域数据集包含数据和标签对\(\mathcal{S}^k=\left\{\left(x_i^k, y_i^k\right)\right\}_{i=1}^{n^k}\),采样自域分布\(P_{X Y}^k\)。联邦域泛化的目标是利用\(K\)个分布式的源域学习模型\(h_\theta: \mathcal{X} \rightarrow \mathcal{Y}\),该模型能够泛化到未知的测试域\(\mathcal{T}\)。联邦域泛化的架构如下图所示:
这里需要注意的是,传统的域泛化方法常常要求直接对齐表征或操作数据,这在联邦场景下是违反数据隐私性的。此外对于跨域的联邦学习,由于客户端异构的数据分布/领域偏移(如不同的图像风格)所导致的模型偏差(bias),直接聚合本地模型的参数也会导致次优(sub-optimal)的全局模型,从而更难泛化到新的目标域。因此,许多传统域泛化方法在联邦场景下都不太可行,需要因地制宜进行修改,下面试举几例:
-
对于数据操作的方法,我们常常需要用其它领域的数据来对某个领域的数据进行增强(或进行新数据的插值生成),而这显然违反了数据隐私。目前论文的解决方案是不直接传数据,而传数据的统计量来对数据进行增强[10],这里的统计量指图片的style(即图片逐通道计算的均值和方差)等等。
-
对于表征学习的方法,也需要在对不同域的表征进行共享/对比的条件下获得领域不变表征(或对表征进行分解),而传送表征事实上也违反了数据隐私。目前论文采用的解决方案包括不显式对齐表征,而是使得所有领域的表征显式/隐式地对齐一个参考分布(reference distribution)[11][12],这个参考分布可以是高斯,也可以由GAN来自适应地生成。也有论文不直接对齐表征,而是对齐不同客户端的类别原型[15]。
-
基于学习策略的方法,如元学习也需要利用多个域的数据来构建meta-train和meta-test,并进行元更新(meta-update),而这也违反了数据隐私性。目前论文的解决方案是使用来自其它域的变换后数据来为当前域构造元学习数据集[13],这里的变换后数据指图像的幅度谱等等。此外,有的方法还针对联邦场景的特点,对联邦学习的策略如聚合方式等进行修改[16][18]。
2 论文阅读
2.1 ICLR20 《Federated Adversarial Domain Adaptation》[14]
严格来说,本文属于联邦域自适应范畴(与域泛化的区别在于目标域在训练过程中可访问),不过其方法非常经典,对于联邦域泛化也有较强的指导意义,故在这里也记录一下。本篇论文采用了基于表征学习的方法。具体而言,本文采用对抗学习方法的方法来使得领域间的表征进行对齐,并进一步采用表征解耦来增强知识迁移。本文方法整体的架构如下图所示:
如上图所示,每个源域上都设置有特征提取器\(G_k\),目标域\(T\)上亦设置有特征提取器\(G_t\)(\(G_i\)、\(G_t\)都将做为GAN的生成器使用)。对于每个源域-目标域对\((S_k, T)\),域识别器\(DI\)(做为GAN的判别器)负责去区分源域和目标域的表征,而生成器\((G_k, G_t)\)则尽量去欺骗\(DI\),从而以对抗的方式来完成表征分布的对齐。注意这里\(DI\)只能访问\(G_i\)和\(G_t\)的输出表征,故并不违反联邦的隐私设置。事实上我们在博客《联邦学习:联邦场景下的多源知识图谱嵌入》中提到的联邦跨域知识图谱对齐方法也是基于GAN的思想。
接下来我们来看GAN是如何优化的。首先优化是判别器\(DI_k\)。设第\(k\)个源域的数据为\(\mathbf{X}^{S_k}\),目标域数据为\(\mathbf{X}^T\),则判别器\(DI_k\)的目标函数定义如下:
直观地理解,该目标函数使判别器将\(G_k\)产出的表征打高分,而将\(G_t\)产出的表征打低分,已完成对源域和目标域的表征对齐。
接下来,判别器\(DI_k\)保持不动,按照下列目标函数来更新生成器\(G_k\)、\(G_t\)(注意这里\(G_k\)、\(G_t\)是在各自的计算节点上单独进行更新,这里为了方便写成一个目标函数):
直观地理解,该目标函数使生成器\(G_k\)、\(G_t\)产出的表征都获得较高的判别器得分,以欺骗判别器。
除了GAN模块之外,本文还设计了表征解耦模块,采用对抗性表征解耦来提取领域不变的特征,即将\((G_i, G_t)\)提取到的特征进一步解耦为领域不变(domain-invariant)和领域特异(domain-specific)的特征。正如上面的框架图所示,解耦器\(D_k\)将提取到的特征分离为了\(f_{di}=D_k(G_k(\mathbf{x}^{\mathbf{s}_k}))\)(领域不变)和\(f_{ds}=D_k(G_k(\mathbf{x}^{\mathbf{s}_k}))\)(领域特异)这两个分支(branch)。
针对这两个branch的表征,作者首先设置一个分类器\(C_i\)与一个类识别器\(CI_i\)来分别基于\(f_{di}\)和\(f_{ds}\)特征对标签进行预测,并采用下列的交叉熵损失函数进行训练:
在下一步中,我们冻结类识别器\(CI_k\),并只训练特征解耦器\(D_k\),通过生成领域特异的特征\(f_{ds}\)来欺骗类识别器\(CI_k\)。而这可以通过最小化预测类别分布的负熵损失来达到,目标函数如下所示:
在这里,特征解耦通过保留\(f_{di}\)并消除\(f_{ds}\)来促进知识迁移。
最后,为了增强特征解耦,作者设计了一个互信息项来最小化领域不变特征\(f_{di}\)和领域特异特征\(f_{ds}\)之间的互信息\(I\left(f_{d i} ; f_{d s}\right)\),这里采用MINE来对互信息进行估计[20]:
关于互信息的上下界估计,大家可以参见我的博客《迁移学习:互信息的变分上下界》。为了避免计算积分,这里采用蒙特卡洛积分来计算该估计:
2.2 CVPR21《FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space》[13]
本篇论文是联邦域泛化的第一篇工作。这篇论文属于基于学习策略(采用元学习)的域泛化方法,并通过传图像的幅度谱(amplitude spectrum),而非图像数据本身来构建本地的元学习任务,从而保证联邦场景下的数据隐私性。本文方法的框架示意图如下:
这里\(K\)为领域/客户端的个数。该方法使图像的低级特征——幅度谱在不同客户端间共享,而使高级语义特征——相位谱留在本地。这里再不同客户端间共享的幅度谱就可以作为多领域/多源数据分布供本地元学习训练使用。
接下来我们看本地的元学习部分。元学习的基本思想是通过模拟训练/测试数据集的领域偏移来学得具有泛化性的模型参数。而在本文中,本地客户端的领域偏移来自不同分布的频率空间。具体而言,对每轮迭代,我们考虑本地的原输入图片\(x_{i}^k\)做为meta-train,它的训练搭档\(\mathcal{T}_i^{k}\)则由来自其它客户端的频域产生,做为meta-test来表示分布偏移。
设客户端\(k\)中的图片\(x^k_i\)由正向傅里叶变换\(\mathcal{F}\)得到的幅度谱为\(\mathcal{A}_i^k \in \mathbb{R}^{H \times W \times C}\),相位谱为\(\mathcal{P}_i^k \in \mathbb{R}^{H \times W \times C}\)(\(C\)为图片通道数)。本文欲在客户端之间交换低级分布也即幅度谱信息,因此需要先构建一个供所有客户端共享的distribution bank \(\mathcal{A} = [\mathcal{A}^1, \cdots, \mathcal{A}^K]\),这里\(A^k = {\{\mathcal{A}^k_i\}}^{n^k}_{i=1}\)包含了来自第\(k\)个客户端所有图片的幅度谱信息,可视为代表了\(\mathcal{X}^k\)的分布。
之后,作者通过在频域进行连续插值的手段,将distribution bank中的多源分布信息送到本地客户端。如上图所示,对于第\(k\)个客户端的图片幅度谱\(\mathcal{A}_i^{k}\),我们会将其与另外\(K-1\)个客户端的幅度谱进行插值,其中与第\(l(l\neq k)\)个外部客户端的图片幅度谱\(\mathcal{A}_j\)插值的结果表示为:
这里\(\mathcal{M}\)是一个控制幅度谱内低频成分比例的二值掩码,\(\lambda\)是插值率。然后以此通过反向傅里叶变换生成变换后的图片:
就这样,对于第\(k\)个客户端的输入图片\(x^k_i\),我们就得到了属于不同分布的\(K-1\)个变换后的图片数据\(\mathcal{T}^k_i = \{x^{k\rightarrow l}_i\}_{l\neq k}\),这些图片和\(x^k_i\)共享了相同的语义标签。
接下来在元学习的每轮迭代中,我们将原始数据\(x^k_i\)做为meta-train,并将其对应的\(K-1\)个由频域产生的新数据\(\mathcal{T}^k_i\)做为meta-test来表示分布偏移,从而完成在当前客户端的inner-loop的参数更新。
具体而言,元学习范式可以被分解为两步:
第一步 模型参数\(\theta^k\)在meta-train上通过segmentaion Dice loss \(\mathcal{L}_{seg}\)来更新:
这里参数\(\beta\)表示内层更新的学习率。
第二步 在meta-test数据集\(\mathcal{T}^k_i\)上使用元目标函数(meta objective)\(\mathcal{L}_{meta}\)对已更新的参数\(\hat{\theta}^k\)进行进一步元更新。
这里特别重要的是,第二步所要优化的目标函数由在第一部中所更新的参数\(\hat{\theta}^k\)计算,最终的优化结果覆盖掉原来的参数\(\theta^k\)。
如果我们将一二步合在一起看,则可以视为通过下面目标函数来一起优化关于参数\(\theta^k\)的内层目标函数和元目标函数:
最后,一旦本地训练完成,则来自所有客户端的本地参数\(\theta^k\)会被服务器聚合并更新全局模型。
2.3 Arxiv21《Federated Learning with Domain Generalization 》[12]
本篇论文属于基于学习领域不变表征的域泛化方法,并通过使所有客户端的表征对齐一个由GAN自适应生成的参考分布,而非使客户端之间的表征互相对齐,来保证联邦场景下的数据隐私性。本文方法整体的架构如下图所示:
注意,这里所有客户端共享一个参考分布,而这通过共享同一个分布生成器(distribution generator)来实现。在训练过程一边使每个域(客户端)的数据分布会和参考分布对齐,一边最小化分布生成器的损失函数,使其产生的参考分布接近所有源数据分布的“中心”(这也就是”自适应“的体现)。一旦判别器很难区分从特征提取器中提取的特征和从分布生成器中所生成的特征,此时所提取的特征就被认为是跨多个源域不变的。这里的特征分布生成器的输入为噪声样本和标签的one-hot向量,它会按照一定的分布(即参考分布)生成特征。最后,作者还采用了随机投影层来使得判别器更难区分实际提取的特征和生成器生成的特征,使得对抗网络更稳定。在训练完成之后,参考分布和所有源域的数据分布会对齐,此时学得的特征表征被认为是通用(universal)的,能够泛化到未知的领域。
接下来我们来看GAN部分具体的细节。设\(F(\cdot)\)为特征提取器,\(G(\cdot)\)为分布生成器,\(D(\cdot)\)为判别器。设由特征提取器所提取的特征\(\mathbf{h} = F(\mathbf{x})\)(数据\(\mathbf{x}\)的生成分布为\(p(\mathbf{h})\)),而由分布生成器所产生的特征为\(\mathbf{h}'= G(\mathbf{z})\)(噪声\(\mathbf{z}\)的生成分布为\(p(\mathbf{h}')\)。我们设特征提取器所提取的特征为负例,生成器所生成的特征为正例。
于是,我们可以将判别器的优化目标定义为使将特征提取器所生成的特征\(\mathbf{h}\)判为正类的概率\(D(\mathbf{h}|\mathbf{y})\)更小,而使将生成器所生成的特征\(\mathbf{h}'\)判为正类的概率\(D(\mathbf{h}'|\mathbf{y})\)更大。
生成器尽量使判别器\(D(\cdot)\)将其生成特征\(\mathbf{h}'\)判别为正类的概率\(D\left(\mathbf{h}^{\prime} \mid \mathbf{y}\right)\)更大,以求以假乱真:
特征提取器也需要尽量使得其所生成的特征\(\mathbf{h}\)能够以假乱真:
再加上图像分类本身的交叉熵损失\(\mathcal{L}_{err}\),则总的损失定义为:
论文的最后,作者还对一个问题进行了探讨:关于这里的参考分布,我们为什么不用一个预先选好的确定的分布,要用一个自适应生成的分布呢?那是因为自适应生成的分布有一个重要的好处,那就是少对齐期间的失真(distortion)。作者对多个域/客户端的分布和参考分布进行了可视化,如下图所示:
(a)中为参考分布选择为固定的分布后,与各域特征对比的示意图,图(b)为参考分布选择为自适应生成的分布后,和各域特征对比的示意图。在这两幅图中,红色五角星表示参考分布的特征,除了五角星之外的每种形状代表一个域,每种颜色代表一个类别的样本。可以看到自适应生成的分布和多个源域数据分布的距离,相比固定参考分布和多个源域数据分布的距离更小,因此自适应生成的分布能够减少对齐期间提取特征表征的失真。而更好的失真也就意味着源域数据的关键信息被最大程度的保留,这让本文的方法所得到的表征拥有更好的泛化表现。
2.4 NIPS22 《FedSR: A Simple and Effective Domain Generalization Method for Federated Learning》[11]
本篇论文属于基于学习领域不变表征的域泛化方法,并通过使所有客户端的表征对齐一个高斯参考分布,而非使客户端之间的表征互相对齐,来保证联邦场景下的数据隐私性。本文的动机源于经典机器学习算法的思想,旨在学习一个“简单”(simple)的表征从而获得更好的泛化性能。
首先,作者以生成模型的视角,将表征\(z\)建模为从\(p(z|x)\)中的采样,然后在此基础上定义领域\(k\)的分类目标函数以学得表征:
这里领域\(k\)的样本表征\(z_j^{(i)}\)通过编码器+重参数化从\(p(z|x_k^{(i)})\)中采样产生。
接下来我们来看怎么使得表征更“简单”。本文采用了两个正则项,一个是关于表征的\(L2\)正则项来限制表征中所包含的信息;一个是在给定\(y\)的条件下,\(x\)与\(z\)的条件互信息\(I(x, z\mid y)\)(的上界)来使表征只学习重要的信息,而忽视诸如图片背景之类的伪相关性(spurious correlations)。
关于表征\(z\)的\(L2\)正则项定义如下:
于是,上式的微妙之处在于可以和领域不变表征联系起来,事实上我们有\(\mathcal{L}_k^{L 2 R}=\mathbb{E}_{p_k(x)}\left[\mathbb{E}_{p(z \mid x)}\left[\|z\|_2^2\right]\right]=\mathbb{E}_{p_k(x, z)}\left[\|z\|_2^2\right]=\mathbb{E}_{p_k(z)}\left[\|z\|_2^2\right]=2 \sigma^2 \mathbb{E}_{p_k(z)}[-\log q(z)]=2 \sigma^2 H\left(p_k(z), q(z)\right)\),这里\(H\left(p_k(z), q(z)\right)=H\left(p_k(z)\right)+ D_{\text{KL}} \left[p_k(z) \Vert q(z)\right]\),参考分布\(q(z)=\mathcal{N}\left(0, \sigma^2 I\right)\)。如果\(H(p_i(z))\)在训练中并未发生大的改变,那么最小化\(l_k^{L2R}\)也就是在最小化\(D_{\text{KL}}[p_k(z) \Vert q(z)]\),也即在隐式地对齐一个参考的边缘分布\(q(z)\),而这就使得标准的边缘分布\(p_k(z)\)是跨域不变的。注意该对齐是不需要显式地比较不同客户端分布的。
接下来我们来看条件互信息项。在信息瓶颈理论中,常对\(x\)和表征\(z\)之间的互信息项\(I(x, z)\)进行最小化以对\(z\)中所包含的信息进行加以正则,但是这样的约束在实践中如果系数没调整好,就很可能过于严格了,毕竟它迫使表征不包含数据的信息。因此,在这篇论文中,作者选择最小化给定\(y\)时\(x\)和\(z\)之间的条件互信息。领域\(k\)的条件互信息被计算为:
直观地看,\(\bar{f}_k\)和\(I_k(x, z\mid y)\)共同作用,迫使表征\(z\)仅仅拥有预测标签\(y\)使所包含的信息,而没有关于\(x\)的额外(即和标签无关的)信息。
然而,这个互信息项是难解(intractable)的,这是由于计算\(p_k(z|y)\)很难计算(由于需要对\(x\)进行积分将其边缘化消掉)。因此,作者导出了一个上界来对齐进行最小化:
这里\(r(z|y)\)可以是一个输入\(y\)输出分布\(r(z|y)\)的神经网络,作者将其设置为高斯\(\mathcal{N}\left(z ; \mu_y, \sigma_y^2\right)\),这里\(u_y\),\(\sigma^2_y\)(\(y=1, 2, \cdots, C\))是需要优化的神经网络参数,\(C\)是类别数量。
事实上,该正则项和域泛化中的条件分布对齐亦有着理论上的联系,这是因为\( \mathcal{L}_k^{C M I}=\mathbb{E}_{p_k(x, y)}[D_{\text{KL}}[p(z \mid x) \Vert r(z \mid y)]] \geq \mathbb{E}_{p_k(y)}\left[D_{\text{KL}}\left[p_k(z \mid y) \Vert r(z \mid y)\right]\right] \)。因此,最小化\(\mathcal{L}_k^{CMI}\)我们必然就能够最小化\(D_{\text{KL}}\left[p_k(z \mid y) \Vert r(z \mid y)\right]\)(因为\(\mathcal{L}^{CMI}_k\)是其上界),使得\(p_k(z|y)\)和\(r(z|y)\)互相接近,即:\(p_k(z|y)\approx r(z|y)\)。因此,模型会尝试迫使\(p_k(z \mid y) \approx p_l(z \mid y)(\approx r(z \mid y))\)(对任意客户端/领域\(k, l\))。这也就是说,我们是在做给定标签\(y\)时表征\(z\)的条件分布的隐式对齐,这在传统的领域泛化中是一种很常见与有效的技术,区别就是这里不需要显式地比较不同客户端的分布。
最后,每个客户端的总体目标函数可以表示为:
总结一下,这里\(L2\)范数正则项\(\mathcal{L}_k^{L2R}\)和给定标签时数据和表征的条件互信息\(\mathcal{L}_k^{CMI}\)(的上界)用于限制表征中所包含的信息。此外,\(\mathcal{L}_k^{L2R}\)将边缘分布\(p_k(z)\)对齐到一个聚集在0周围的高斯分布,而\(\mathcal{L}_i^{CMI}\)则将条件分布\(p_k(z|y)\)对齐到一个参考分布(在实验环节作者亦将其选择为高斯)。
2.5 WACV23 《Federated Domain Generalization for Image Recognition via Cross-客户端 Style Transfer》[10]
本篇论文属于基于数据操作的域泛化方法,并通过构造一个style bank供所有客户端共享(类似CVPR21那篇),以使客户端在不共享数据的条件下基于风格(style)来进行数据增强,从而保证联邦场景下的数据隐私性。本文方法整体的架构如下图所示:
如图所示,每个客户端的数据集都有自己的风格。且对于每个客户端而言,都会接受其余客户端的风格来进行数据增强。事实上,这样就可以使得分布式的客户端在不泄露数据的情况下拥有相似的数据分布 。在本方法中,所有客户端的本地模型都拥有一致的学习目标——那就是拟合来自于所有源域的styles,而这种一致性就避免了本地模型之间的模型偏差,从而避免了影响全局模型的效果。此外,本方法可和其它DG的方法结合使用,从而使得其它中心化的DG方法均能得到精度的提升。
关于本文采用的风格迁移模型,有下列要求:1、所有客户端共享的style不能够被用来对数据集进行重构,从而保证数据隐私性;2、用于风格迁移的方法需要是一个实时的任意风格迁移模型,以允许高效和直接的风格迁移。本文最终选择了AdaIN做为本地的风格迁移模型。整个跨客户端/领域风格迁移流程如下图所示:
可以看到,整个跨客户端/领域风格迁移流程分为了三个阶段:
1. Local style Computation
每个客户端需要计算它们的风格并上传到全局服务器。其中可选择单张图片风格(single image style)和整体领域风格(overall domain style )这两种风格来进行计算。
- 单张图片风格 单张图片风格是图片VGG特征的像素级逐通道(channel-wise)均值和方差。比如我们设在第\(k\)个客户端上,随机选取的图片索引为\(i\),其对应的VGG特征\(F_k^{(i)}=\Phi(I^{(i)}_k)\)(这里的\(I^{(i)}_k\)表示图像内容,\(\Phi\)为VGG的编码器),单张图片风格可以被计算为:
如果单张图片风格被用于风格迁移,那么就需要将该客户端不同图片对应的多种风格都上传到服务器,从而避免单张图片的偏差并增加多样性。而这就需要建立本地图片的style bank \(\mathcal{S}_k^{single}\)并将其上传到服务器。这里作者随机选择\(J\)张图像的style加入了本地style bank:
- 整体领域风格 整体领域风格是领域层次的逐通道均值和方差,其中考虑了一个客户端中的所有图片。比如我们假设客户端\(k\)拥有\(N_k\)个训练图片和对应的VGG特征\(\{F_k^{(1)}, F_k^{(2)}, \ldots, F_{k}^{(N_k)}\}\)。则该客户端的整体领域风格\(S_k^{overall}\)为:
相比单张图片风格,整体领域风格的计算代价非常高。不过,由于每个客户端/领域只有一个领域风格\(S_k^{overall}\),选择上传整体领域风格到服务器的通信效率会更高。
2. Style Bank on Server
当服务器接收到来自各个客户端的风格时,它会将所有风格汇总为一个style bank \(\mathcal{B}\) 并将其广播回所有客户端。在两种不同的风格共享模式下,style bank亦会有所不同。
- 单图像风格的style bank \(\mathcal{B}\)为:
- 整体领域风格的style bank \(\mathcal{B}\)为:
\(\mathcal{B}_{single}\)比\(\mathcal{B}_{overall}\)会消耗更多存储空间,因此后者会更加通信友好。
3. Local Style Transfer
当客户端\(k\)收到style bank \(\mathcal{B}\)后,本地数据会通过迁移\(\mathcal{B}\)中的风格来进行增强,而这就将其它领域的风格引入了当前客户端。作者设置了超参数\(L \in\{1,2, \ldots, K\}\)做为增强级别,意为从style bank \(\mathcal{B}\)中随机选择\(L\)个域所对应的风格来对每个图片进行增强,因此\(L\)表明了增强数据集的多样性。设第\(k\)个客户端数据集大小为\(N_k\),则在进行跨客户端的领域迁移之后,增强后数据集的大小会变为\(N_k \times L\)。其中对客户端\(k\)中的每张图片\(I^{(i)}_k\),其对应的每个被选中的域都会拥有一个style vector\(S\)被作为图像生成器\(G\)的输入。这里关于style vector的获取有个细节需要注意:假设我们选了域\(k\),如果迁移的是整体领域风格,则\(S^{overall}_k\)直接即可做为style vector;如果迁移的是单图片风格,则还会进一步从选中\(\mathcal{S}^{single}_k\)中随机选择一个风格\(S_k^{(i)}\)做为域\(k\)的style vector。对以上两种风格模式而言,如果一个域被选中,则其对应的风格化图片就会被直接加入增强后的数据集中。
2.6 CVPR23 《Rethinking Federated Learning with Domain Shift: A Prototype View》[15]
本文属于考虑了领域漂移的异构联邦学习,而不属于域泛化,不过两个领域有很多相似之处,故在这里也记录一下。本篇论文采用了基于表征学习的方法。具体而言,本文采用原型学习的视角,设计了一种聚类原型学习方法来解决领域偏移问题。本文方法整体的架构如下图所示:
如上图所示,首先根据每个客户端属于类别\(c\)(\(c\in \left[ |C| \right]\),\(|C|\)为类别个数)的样本表征集合,来计算出在该客户端上各类别的原型:
这里\(S^c_k\)为客户端\(k\)上属于类别\(c\)的样本集合。
之后,将每种类别对应的原型集合\(\{p^c_k\}^K_{k=1}\)分别聚为\(K^{\prime}\)类,以得到\(K^{\prime}\)个代表性的聚类原型:
在聚完类之后,再对每个类别所对应的\(K^{\prime}\)个聚类原型取平均,得到最终的无偏原型:
作者还画了一张图来解释为什么对类别原型进行聚类可以有效解决域偏移的问题:
如图所示,全局原型无法描述不同领域的信息,并且被潜在的主导领域所支配。而聚类原型和无偏原型则携带着多个域的知识和平稳的优化信号。
在得到了聚类原型和无偏原型之后,作者设计了对比正则项以拉近同类的聚类原型之间的距离,而增大不同类的聚类原型之间的距离:
这里表征\(z_i = f(x_i)\)(\(x_i\)为图片实例),其与对应类别的聚类原型\(p\)之间的相似度定义为:\(s\left(z_i, c\right)=\frac{z_i \cdot c}{\left\|z_i\right\| \times\|c\| / \tau}\)。\(\mathcal{N}^c = \mathcal{P} - \mathcal{P}^c\)为类别不为\(c\)的聚类原型集合。
此外,作者还设计了一致化正则项来拉近表征\(z_i\)和其对应类别的无偏原型\(\mathcal{U}^c\)之间的距离:
这里\(v\)用于索引表征的各个维度。
最后,将\(\mathcal{L}_{C P C L}\)、\(\mathcal{L}_{U P C R}\)和图片分类任务本身的交叉熵\(\mathcal{L}_{\text{CE}}\)加起来,就得到了总的损失函数:
2.7 ArXiv23 《PerAda: Parameter-Efficient and Generalizable Federated Learning Personalization with Guarantees》[16]
本篇论文属于基于学习策略的域泛化方法。具体而言,本文为每个客户端设置了个性化的模型,并在服务器端增设了知识蒸馏过程以从各客户端聚合泛化信息。本文方法整体的架构如下图所示:
本文整体依据Ditto[17]的架构,在客户端本地设置个性化模型\(\{v_k\}^K_{k=1}\)(本文称之为个性化适配器,personalized adapter),并将问题建模为如下的优化问题:
这里\(u\in \mathbb{R}^{d_u}\)表示固定的预训练参数,且\(v_k, w\in\mathbb{R}^{d_a}\)分别表示个性化适配器和全局适配器(global adapter)。
因此,本地客户端的优化也分两步走,先在本轮接收到的全局适配器\(w^t\)的约束条件下,更新本地的个性化适配器:
这里\(\xi_k^{t, s}\)为从本地训练集\(\mathcal{D}_k\)中采样的batch数据。
然后再更新本地适配器(local adapter)\(\theta\):
同样地,这里\(\xi_k^{t, e}\)也表示从本地训练集\(\mathcal{D}_k\)中采样的batch数据。
本文的创新之处在于服务器端采用参数平均\(w^{t} \leftarrow \sum_{k \in \mathcal{S}_t} \frac{1}{\left|\mathcal{S}_t\right|} \theta_k^{t+1}\)(\(\mathcal{S}_t\)为所采样的客户端子集)进行聚合之后,没有直接将聚合所得的全局适配器\(w^t\)广播给客户端,而是继续采用知识蒸馏来对服务器端的全局适配器\(w^t\)进行更新:
这里\(\xi_k^{t, r}\)也表示从本地训练集\(\mathcal{D}_k\)中采样的batch数据,知识蒸馏损失\(\mathcal{R}_{\text{KD}}\)定义如下:
该损失也即在辅助(无标签)数据集\({\mathcal{D}}_{\text{aux}}={\{x_j\}}^{n_{\text{aux}}}_{j=1}\)上,本地适配器的平均logits和全局适配器的logits的平均蒸馏损失。这里\(\ell_{\mathrm{KD}}(a, b)=\mathrm{KL}(\sigma(a / \tau), \sigma(b / \tau))\)为KL散度损失(\(\sigma\)为softmax函数,\(\tau\)为温度)。
2.8 CVPR23 《Federated Domain Generalization with Generalization Adjustment》[18]
本篇论文属于基于学习策略的域泛化方法。具体而言,本文为联邦域泛化问题设计了一种新的目标函数,该目标函数考虑到了各客户端上泛化差距(generalization gap)的方差,从而保证了在所有领域上最优全局模型的平坦性(flatness)。在思想脉络上,本文是由解决OOD问题的经典方法《Out-of-distribution generalization via risk extrapolation》[19]得到的启发。本文方法整体的架构如下图所示:
如上图所示,相比普通FedAvg方法直接对各领域模型按照样本比例\(p_1, p_2, \cdots, p_3\)(训练中固定)进行加权聚合,本文的GA方法按照可学习的权重\(\alpha_i\)来聚合各领域模型。此外,本文方法的目标函数中还带有一个公平性(fairness)正则化项\(\text{Var}\left(\cdot\right)\),可通过动态校准聚合权重来进行优化。
本文方法的全局目标函数如下:
这里\(\mathbf{a}\)为可学习的客户端/域的聚合权重,而\(\beta \in \left[ 0, \infty \right)\)用于控制在减少全局经验风险(即\(\sum_{k=1}^K a_k \widehat{\mathcal{E}}_{\widehat{D}_k}(\theta)\))与加强泛化差距公平性(即\(\text{Var}\left(\cdot\right)\))之间的平衡。当\(\beta=0\)时退化为普通的FedAvg算法,当\(\beta\rightarrow \infty\)时将仅仅去使得泛化差距相等。
\(G_{\widehat{D}_k}\)则计算的是全局模型\(\theta\)和本地模型\(\theta_k\)之间的泛化差距(定义为全局模型在本地训练集上的经验风险-本地模型在本地训练集上的经验风险),定义如下:
那么以上仅仅是给出了目标函数,具体是如何计算求解的呢?具体到每轮迭代上,本地客户端从服务器端接受全局模型\(\theta^t\)后,先计算全局模型\(\theta^t\)与本地模型\(\theta^{t}_k\)的泛化差距\(G_{\widehat{D}_k}\left(\theta^t\right)\);然后完成本地参数更新得到\(\theta^{t+1}_k\),并计算本地模型的经验损失\(\widehat{\mathcal{E}}_{\widehat{D}_k}\left(\theta_k^{t+1}\right)\)(留给下一轮迭代计算\(G_{\widehat{D}_k}\left(\theta^t\right)\)用)。最后,客户端将计算好的泛化差距\(G_{\widehat{D}_k}\left(\theta^t\right)\)与更新后的本地模型\(\theta^{t+1}_k\)发往服务器端。
而服务端则先通过\(\left\{G_{\widehat{D}_k}\left(\theta^{t}\right)\right\}_{k=1}^K\)与上一轮的聚合权重\(\mathbf{a}^{t}\)来计算更新后的聚合权重\(\mathbf{a}^{t+1}\)(动量更新),并对其进行归一化:
这里\(\mu=\frac{1}{K} \sum_{k=1}^K G_{\widehat{D}_k}\left(\theta^t\right)\),且\(d^t=(1-t / T) * d, d\in (0, 1)\)是一个控制每次更新幅值的超参数,可以被视为目标函数中\(\beta\)的替代。
然后服务器端再通过\(\mathbf{a}^{t+1}\)来聚合\(\{\theta_k^{t+1}\}^K_k\)以得到最新的全局模型:
参考
- [1] Wang J, Lan C, Liu C, et al. Generalizing to unseen domains: A survey on domain generalization[J]. IEEE Transactions on Knowledge and Data Engineering, 2022.
- [2] 王晋东,陈益强. 迁移学习导论(第2版)[M]. 电子工业出版社, 2022.
- [3] Volpi R, Namkoong H, Sener O, et al. Generalizing to unseen domains via adversarial data augmentation[C]. Advances in neural information processing systems, 2018, 31.
- [4] Zhou K, Yang Y, Qiao Y, et al. Domain generalization with mixstyle[C]. ICLR, 2021.
- [5] Li H, Pan S J, Wang S, et al. Domain generalization with adversarial feature learning[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 5400-5409.
- [6] Li Y, Gong M, Tian X, et al. Domain generalization via conditional invariant representations[C]//Proceedings of the AAAI conference on artificial intelligence. 2018, 32(1).
- [7] Ilse M, Tomczak J M, Louizos C, et al. Diva: Domain invariant variational autoencoders[C]//Medical Imaging with Deep Learning. PMLR, 2020: 322-348.
- [8] Qin X, Wang J, Chen Y, et al. Domain Generalization for Activity Recognition via Adaptive Feature Fusion[J]. ACM Transactions on Intelligent Systems and Technology, 2022, 14(1): 1-21.
- [9] Li D, Yang Y, Song Y Z, et al. Learning to generalize: Meta-learning for domain generalization[C]//Proceedings of the AAAI conference on artificial intelligence. 2018, 32(1).
- [10] Chen J, Jiang M, Dou Q, et al. Federated Domain Generalization for Image Recognition via Cross-Client Style Transfer[C]//Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. 2023: 361-370.
- [11] Nguyen A T, Torr P, Lim S N. Fedsr: A simple and effective domain generalization method for federated learning[J]. Advances in Neural Information Processing Systems, 2022, 35: 38831-38843.
- [12] Zhang L, Lei X, Shi Y, et al. Federated learning with domain generalization[J]. arXiv preprint arXiv:2111.10487, 2021.
- [13] Liu Q, Chen C, Qin J, et al. Feddg: Federated domain generalization on medical image segmentation via episodic learning in continuous frequency space[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021: 1013-1023.
- [14] Peng X, Huang Z, Zhu Y, et al. Federated adversarial domain adaptation[J]. arXiv preprint arXiv:1911.02054, 2019.
- [15] Huang W, Ye M, Shi Z, et al. Rethinking federated learning with domain shift: A prototype view[C]//2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2023: 16312-16322.
- [16] Xie C, Huang D A, Chu W, et al. PerAda: Parameter-Efficient and Generalizable Federated Learning Personalization with Guarantees[J]. arXiv preprint arXiv:2302.06637, 2023.
- [17] Li T, Hu S, Beirami A, et al. Ditto: Fair and robust federated learning through personalization[C]//International Conference on Machine Learning. PMLR, 2021: 6357-6368.
- [18] Zhang R, Xu Q, Yao J, et al. Federated domain generalization with generalization adjustment[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023: 3954-3963.
- [19] Krueger D, Caballero E, Jacobsen J H, et al. Out-of-distribution generalization via risk extrapolation (rex)[C]//International Conference on Machine Learning. PMLR, 2021: 5815-5826.
- [20] Belghazi M I, Baratin A, Rajeshwar S, et al. Mutual information neural estimation[C]//International conference on machine learning. PMLR, 2018: 531-540.
- [21] Li Y, Wang X, Zeng R, et al. Federated Domain Generalization: A Survey[J]. arXiv preprint arXiv:2306.01334, 2023.
- [22] Albuquerque I, Monteiro J, Falk T H, et al. Adversarial target-invariant representation learning for domain generalization[J]. arXiv preprint arXiv:1911.00804, 2019, 8.
- [23] Lu W, Wang J, Yu H, et al. FIXED: Frustratingly Easy Domain Generalization with Mixup[J]. arXiv preprint arXiv:2211.05228, 2022.
- [24] Ye H, Xie C, Cai T, et al. Towards a theoretical framework of out-of-distribution generalization[J]. Advances in Neural Information Processing Systems, 2021, 34: 23519-23531.
- [25] Deshmukh A A, Lei Y, Sharma S, et al. A generalization error bound for multi-class domain generalization[J]. arXiv preprint arXiv:1905.10392, 2019.
- [26] Sicilia A, Zhao X, Hwang S J. Domain adversarial neural networks for domain generalization: When it works and how to improve[J]. Machine Learning, 2023: 1-37.