MoNA:复用跨模态预训练模型,少样本模态的福音 | ICML'24
跨模态转移旨在利用大型预训练模型来完成可能不属于预训练数据模态的任务。现有的研究在将经典微调扩展到跨模态场景方面取得了一定的成功,但仍然缺乏对模态差距对转移的影响的理解。在这项工作中,进行了一系列关于转移过程中源表示质量的实验,揭示了更大的模态差距与较少知识重用之间的联系,这意味着转移效果不佳。然后,使用条件分布 \(P(Y|X)\) 将这种差距形式化为不同模态之间的知识不对齐。针对这个问题,论文提出了“模态知识对齐”(
MoNA
)方法,这是一种元学习方法,旨在学习目标数据变换,以减少转移前的模态知识差异。实验证明,论文的方法实现了更好地重用跨模态转移中的源模态知识,从而对现有微调方法的改进。来源:晓飞的算法工程笔记 公众号,转载请注明出处
论文: Learning Modality Knowledge Alignment for Cross-Modality Transfer
Introduction
将来自过去经验的知识转移至新任务是人类智能的基本能力。在机器学习社区中,不断追求这种获取和重用知识的能力,旨在构建能够更准确预测并更高效学习数据的人工智能系统。如今,由于在大量数据上进行训练的大型基础模型已广泛可用,因此使用这样的预训练模型作为新任务的强大特征提取器已成为迁移学习的常见做法。自然地,预训练模型和下游任务来自同一模态,例如,在ImageNet
上预训练的视觉Transformer
模型和CIFAR-100
分类任务。然而,最近的研究尝试扩展到跨模态转移,例如使用视觉Transformer
进行音频分类,并对表格数据进行微调语言模型。
这种跨模态转移的动机易于理解,特别是当目标模态数据稀缺时。科学任务,如心电图分类和蛋白质距离预测,在收集大量训练数据方面存在困难,并且进一步需要人类专家进行昂贵的注释成本。在这种情况下,利用其他数据更容易收集的模态(如视觉和语言)的预训练模型,来帮助目标模式任务是可取的。然而,由于两个挑战,跨模态转移并不像模态内转移那样直接:1
)跨模态的输入空间和标签空间不同,2
)解决不同模态任务所需的知识也可能不同。
先前的研究通过设计模态特定的嵌入器和预测器来应对第一个挑战,以便从输入到输出与预训练模型进行接口。然而,第二个挑战尚未得到很好解决。一些方法认为,大型预训练模型可以作为通用编码器,并在微调过程中冻结预训练模型。其他方法则同时微调预训练模型和模态特定组件。这两种方法实证表明,预训练模型可以迁移到其他模态。然而,源模态中哪些知识通过预训练模型进行了转移,以及这些知识如何有利于目标模态,仍然是一个未解决的核心问题。例如,ORCA
观察到,在某些目标模态任务上从头开始训练模型甚至比对预训练模型进行普通微调更好。这表明如果不恰当地转移,预训练模型中包含的知识可能不会提高目标性能。
在这项工作中,论文深入探讨了跨模态转移的第二个挑战。首先进行实验,研究目标模态微调如何影响源模态数据的表示质量。论文观察到,在一些目标模态任务上微调预训练的Swin Transformer
可以帮助Swin
编码器提取更具有区分性的图像特征,而在其他模态上微调则会削弱这种能力。这一实证观察表明,在不同程度上,不同模态之间可能存在知识面,称之为模态语义知识,这些知识会影响跨模态转移的有效性。
为了明确模态之间差异的这一方面,将模态语义知识解释为条件概率分布 \(P(Y|X)\) 。根据目标模态的任务修改源模态的条件分布,使两者可以进行比较。因此,能够将模态知识差异形式化为源模态和目标模态的条件分布之间的差异。当目标条件分布与修改后的源条件分布相似时,称模态语义知识是对齐的,预训练模型学习的源区分功能可以被重用于目标模态。相反,模态语义知识相互矛盾,可能没有相互促进的作用,这解释了ORCA
中的观察。
论文的解释为理解先前跨模态转移作品提出的两阶段调优流程的有效性提供了新的视角:将第一阶段视为针对目标模态的隐式数据转换学习,从而使转换后数据的条件分布与源数据更加对齐。因此,这启示可以在微调之前直接学习一个适当的目标嵌入函数,有助于最小化知识不对齐。基于此,论文提出了一种新方法MoNA
,通过两阶段训练改善跨模态转移。在第一阶段,MoNA
利用元学习来学习一个最优的目标嵌入器,在全面微调时作为初始化之一,结合预训练权重,实现在全面微调过程中最大程度地重用源模态知识。在第二阶段,利用学习到的目标嵌入器作为起点,沿用传统微调方法,更新所有参数以适应目标任务,同时最大程度地利用源知识。
论文在两个跨模态转移基准数据集NAS-Bench-360
和PDEBench
上进行了大量实验,以验证假设和提出的方法的有效性。这两个基准数据集都集中在与科学问题相关的模态上,其中训练数据的稀缺性尤为严重。对MoNA
与先前方法进行了比较,实验结果表明该方法表现出色。
Problem Formulation and Analysis
Introduction to basic notations and architecture
考虑源模态 \(\mathcal{M}^s\) 和目标模态 \(\mathcal{M}^t\) 之间的知识转移。源模态中的数据(如视觉或语言数据)更容易获取且成本更低,同时大型预训练模型也是公开可用的。相反,目标模态数据不足以预训练自己的大型模型。这两个模态在输入空间和标签空间上均存在差异,即 \(\mathcal{X}^s\neq\mathcal{X}^t\) , \(\mathcal{Y}^s\neq\mathcal{Y}^t\) 。跨模态转移旨在利用源预训练模型(由参数 \(\boldsymbol{\theta}^{\mathcal{S}}\) 参数化)来帮助处理目标任务,该目标任务仅具有一小组带标签数据 \(\{\boldsymbol{x}_i^{t}, y_i^{t}\}_{i=1}^{n_t}\) 。
根据先前的研究,模型结构 \(g_{\boldsymbol{\theta}}\) 包括一个嵌入器 \(e(\cdot;\boldsymbol{\theta}_e)\) ,一个Transformer
编码器 \(f(\cdot;\boldsymbol{\theta}_f)\) 和一个预测器 \(h(\cdot;\boldsymbol{\theta}_h)\) ,整个模型的参数表示为 \(\boldsymbol{\theta} = \{\boldsymbol{\theta}_e, \boldsymbol{\theta}_f, \boldsymbol{\theta}_h\}\) 。特别地,预训练的Transformer
具有自己的嵌入器和预测器,因此将源模型的预训练权重表示为 \(\boldsymbol{\theta}^{\mathcal{S}}_0 = \{\boldsymbol{\theta}_{e_0}^\mathcal{S}, \boldsymbol{\theta}_{f_0}^\mathcal{S}, \boldsymbol{\theta}_{h_0}^\mathcal{S}\}\) 。
嵌入器将输入数据映射到共享的输入嵌入空间 \(\hat{\mathcal{X}}\) ,编码器从嵌入的输入中提取特征。预测器是一个线性层,将编码器的输出映射到标签空间上。对于目标模型 \(g_{\boldsymbol{\theta}^\mathcal{T}}: \boldsymbol{\theta}^\mathcal{T} = \{\boldsymbol{\theta}^\mathcal{T}_e, \boldsymbol{\theta}^\mathcal{T}_f, \boldsymbol{\theta}^\mathcal{T}_h\}\) ,嵌入器和预测器均经过特定重新设计以适应目标任务的输入和标签空间,同时使用 \(\boldsymbol{\theta}_{f_0}^\mathcal{S}\) 来初始化编码器权重 \(\boldsymbol{\theta}^\mathcal{T}_f\) 。
这种架构的灵活性使得能够在目标任务上进行端到端的训练,简单地通过在给定训练数据集上最小化特定任务损失来微调目标模型的所有参数:
这里的 \(\ell\) 是任务损失函数,例如交叉熵。
通过这种方式直接从目标监督中学习,鼓励模型学习有助于区分目标数据的知识。由于预训练模型已经包含源领域的辨别知识,因此跨模态转移自然期望源领域和目标领域的知识在某些方面相似,以便源领域的知识可以被重用来促进目标学习。接下来,1
)进行实验表明这种相似性取决于模态,2
)提供模态知识的解释并形式化知识差异。
-
Detailed Explanation of the Model Architecture
特定于模态的嵌入器和预测器的实现完全按照ORCA
中的设计进行。
特定于模态的嵌入器的结构取决于任务是2D
还是1D
。
-
对于
2D
任务,嵌入器由线性投影层和LayerNorm
操作组成。对于任何大小为 \(C\times H\times W\) 的输入数据,其中 \(C\) 、 \(H\) 和 \(W\) 分别表示通道数、高度和宽度。首先将其调整大小为 \(C\times 224^2\) 并分成大小为 \(C\times 4^2\) 的 \(N\) 个图像块,然后线性投影层将每个图像块映射到大小为 \(128\) 的标记,LayerNorm
操作应用于所有映射的图像块。因此,嵌入器可以表示为一个函数 \(e_{2D}: \mathbb R^{N\times 16C}\to\mathbb R^{N\times 128}\) 。 -
对于
1D
任务,嵌入器由线性投影层、LayerNorm
操作和可学习的位置嵌入组成。对于任何大小为 \(C\times L\) 的输入数据,其中 \(C\) 和 \(L\) 分别表示通道数和序列长度。首先将其分成大小为 \(C\times \frac{L}{N}\) 的 \(N\) 个块,然后线性投影层将每个块映射到大小为 \(768\) 的令牌,LayerNorm
操作应用于所有投影的块,最后将位置嵌入添加到块中。因此,嵌入器可以表示为一个函数 \(e_{1D}: \mathbb R^{CL}\to\mathbb R^{768N}\) 。
特定于模态的预测器的结构取决于任务是分类还是密集预测。
-
对于分类任务,预测器由一个平均池化层和一个线性投影层组成。平均池化层将大小为 \(N'\times d\) 的密集特征图平均为大小为 \(d\) 的特征,然后线性投影层将特征映射到大小为 \(K\) 的对数值,其中 \(d\) 和 \(K\) 分别表示特征维度和类别数量。因此,预测器可以表示为一个函数 \(h_{c}: \mathbb R^{N'd} \to \mathbb R^K\) 。
-
对于密集预测任务,预测器由一个线性投影层、一个像素重新排列操作和两个自适应池化层组成。线性投影层以大小为 \(7^2\times d\) 的密集特征图作为输入,输出大小为 \(7^2 \times 3072\) 的新特征,然后重新组织成形状为 \(224^2 \times 3\) 。接下来,两个池化操作依次应用,将特征大小从 \(3 \times 224^2\) 变为 \(K \times 224^2\) ,最终变为 \(K \times H \times W\) ,与输入的空间维度相符。因此,预测器可以表示为一个函数 \(h_{d}:\mathbb R^{49d}\to\mathbb R^{KHW}\) 。
Distortion of learned source modality knowledge
论文寻找一种定量比较不同跨模态转移场景中知识重用程度的方法。选择图像模态作为知识源,并从不同模态中选择四个目标任务,括两个与图像密切相关的任务:CIFAR-100
,其中包含球形投影图像,以及两个与图像模态不相似的任务:表示手势的NinaPro
和包含声音事件音频剪辑的FSD50K
。具体来说,采用在ImageNet-22k
上预训练的Swin Transformer Base
作为源模型,并在不同任务上微调后检查模型的属性。
考虑到比较是在不同模态间进行的,缺乏一个通用的度量标准来衡量转移过程中知识重用程度。因此,转而比较源知识的失真程度。具体来说,如果更多的源知识被重用来解决目标任务,则认为失真会更小,反之亦然。因此,利用预训练的源模型提取CIFAR-10
的视觉表示,这是模型未见过的替代图像数据集。该特定源数据集中的样本被表示为 \(\{\boldsymbol{x}^s_i,y^s_i\}\) ,其对应的特征集为 \(\{\boldsymbol{f}^s_i = f(e(\boldsymbol{x}^s_i;\boldsymbol{\theta}_{e_0}^\mathcal{S});\boldsymbol{\theta}_{f_0}^\mathcal{S})\}\) 。然后,分别使用公式1
在四个目标任务上对预训练模型进行微调。在微调过程之后,再次利用微调后的编码器提取CIFAR-10
的表示,并得到 \(\{\boldsymbol{f}^s_i(\mathcal{M}_t) = f(e(\boldsymbol{x}^s_i;\boldsymbol{\theta}_{e_0}^\mathcal{S});\boldsymbol{\theta}_f^\mathcal{T}, \mathcal{M}_t)\}\) 。
图2
展示了五组不同的CIFAR-10
图像特征的T-SNE
可视化结果。该图表明,微调在CIFAR-100
或Spherical
上的编码器在CIFAR-10
图像样本上保持或甚至提高了它们的可区分性,而在NinaPro
和FSD50K
上微调的编码器则无法提取适用于图像的类别判别特征。考虑到在目标模态上微调会使得编码器专注于对目标数据进行分类和学习目标判别函数,这一观察表明,相较于后两种模态,用于区分CIFAR-100
和Spherical
样本所需的知识与用于CIFAR-10
样本所需的知识更加一致。这样的结论符合论文的直觉,因为CIFAR-100
是视觉数据集,Spherical
源自自然图像,而NinaPro
和FSD50K
与图像相关性较低。
另一方面,结果显示,CIFAR-100
和Spherical
能更好地重用预训练编码器中的源知识来解决任务,而NinaPro
和FSD50K
需要编码器进行更大调整,以适应目标任务。
为了更全面地定量研究跨模态转移过程中源知识的重用(或失真),在CIFAR-10
上使用线性探针评估使用不同目标模态微调的编码器提取表示的质量,分别考虑:1
)不同的微调目标模态,2
)不同的训练轮数,以及3
)不同的转移方法。除了普通微调之外,还考虑以下两个基线:
ORCA
在微调之前添加了一个嵌入器训练阶段。第一阶段仅更新目标嵌入器参数 \(\boldsymbol{\theta}_e^\mathcal{T}\) ,在共享输入空间 \(\hat{\mathcal{X}}\) 中最小化源嵌入和目标嵌入之间的最优数据集转移距离(经过嵌入器处理后的分布尽量相似)。- 论文提出了另一个基线方法,即从先前的工作修改而来的
Embedder warmup
(Emb
),这也是一种两阶段训练方法。第一阶段仅通过使用与普通微调相同的任务损失来更新目标嵌入器,同时保持网络的其余部分冻结。第二阶段对整个网络进行微调。
图3
显示了线性探测的错误率,虚线表示预训练编码器上的线性探测结果作为参考。请注意,所有这些结果都是在CIFAR-10
数据集上的错误率,反映了模型保留源模态知识的程度,暂时不关注目标模态的性能比较。从实验中,可以观察到模态对线性探测结果有最大的影响。在FSD50K
上进行微调显著扭曲了编码器并损害了其在图像数据上的可区分性。在目标数据集上进行更多轮数的微调会导致对所有目标模态的源知识更大的扭曲,除了图像模态(CIFAR-100
)。这些观察结果导致了一个结论,即在不同模态中区分样本的知识在不同程度上有所不同,将其称为模态语义知识的错位。论文认为,巨大的差异可能阻碍跨模态转移的有效性,因此,关于源模态预训练对目标模态有益的假设应该取决于这种差异。
论文对两阶段训练方法的源知识保留效果做出了额外观察。与普通微调相比,ORCA
和Emb
都实现了更低的源错误率,并且Emb
的表现优于ORCA
。这表明,在它们的第一阶段训练中,目标嵌入器隐式学习了一个从 \(\mathcal{X}^t\) 到 \(\hat{\mathcal{X}}\) 的映射,缓解了目标和源之间的知识错位,并因此减少了模型在适应目标任务期间的扭曲。
Modality semantic knowledge discrepancy
考虑使用条件分布 \(P(Y|X)\) 来表示模态内的语义知识,该条件分布描述了模态的原始数据空间和语义空间之间的关系。这是因为对于神经网络而言,获取语义知识意味着学习一个从数据空间到语义空间的映射,这个映射类似于真实的条件分布。
然而,衡量两个模态之间这种知识的一致性或“相似度”是非常具有挑战性的。困难在于,数据空间 \(\mathcal{X}\) 和标签空间 \(\mathcal{Y}\) 在不同的模态之间甚至是不同且不重叠的。因此,需要修改条件分布以使其在不同模态之间可比较。修改输入空间相对较容易,因为可以使用特定于模态的嵌入器将输入嵌入到一个共享空间 \(\hat{\mathcal{X}}\) 中。然而,修改标签空间则更加复杂。
考虑到源模态(如视觉和语言)拥有大型预训练模型且在语义上都非常丰富,论文做出以下假设:源模态标签空间的基数大于目标模态标签空间的基数,即 \(|\mathcal{Y}^s| = |\mathcal{Y}^t|\)。
这一假设在实践中很容易得到满足。例如,在ImageNet
上训练的视觉Transformer
可以学习一个包含一千个类别的判别函数,而在心电图分类任务中仅考虑了四类。基于这一假设,可以选择源模态标签空间 \(\mathcal{Y}^s_{\mathcal{B}} \subset \mathcal{Y}^s\) 的子集,使得 \(|\mathcal{Y}^s_{\mathcal{B}}| = |\mathcal{Y}^t|\) 。进一步引入一个类别置换 \(\pi(\cdot)\) ,调整源类别的顺序。因此,可以定义一个新的源模态标签空间,即在置换后的源子集 \(\mathcal{Y}^s_{\pi,\mathcal{B}} \triangleq \pi(\mathcal{Y}^s_{\mathcal{B}})\) 。通过衡量修改后的条件分布 \(P(Y^s_{\pi,\mathcal{B}}|\hat{X})\) 和 \(P(Y^t|\hat{X})\) 之间的差异,可以形式化模态语义知识的对齐程度如下:
给定满足假设的源模态 \(\mathcal{M}^s\) 和目标模态 \(\mathcal{M}^t\) ,设 \(\hat{\mathcal{X}}\) 是由特定于模态的嵌入器从原始数据空间生成的共享输入空间, \(P(Y^s|\hat{X})\) , \(P(Y^t|\hat{X})\) 分别是源模态和目标模态的条件分布。那么,两个模态之间的模态语义知识差异为
其中 $ d(\cdot,\cdot)$ 是两个条件分布之间的任意差异度量。
该定义基本上表示,如果能在源语义中找到一个最优子集,并在源语义和目标语义之间进行适当的一一对应匹配,使其与目标模态具有相似的条件分布,那么知识差异被认为很小。源模型应当能够像在子集内辨别源样本一样正确地区分目标样本。
通过该定义,论文使用一种极端近似算法计算图像模态与四个目标任务之间的模态语义知识差异。如图4
所示,与先前的观察相一致,表明不同模态确实具有不同程度的知识差异,而在这四个任务中,FSD50K
是与图像模态最不相似的模态。
Modality Knowledge Alignment
发现模态知识可能没有很好地对齐以及源知识重用不足的后果,论文提出了一种新方法MoNA
,完整算法如算法1
所示。该方法改善模态知识的对齐性,提高了跨模态传输的效果。
Embedder Warmup
在先前的实验中,论文发现嵌入器warmup
尽管训练目标很简单,却比其他方法更好地保留源知识。相应地,开始测试其在目标模态上的表现,同样优于其对应的方法。论文认为,在嵌入器warmup
过程中,为了最小化任务损失,嵌入器被明确地强制将目标原始输入投影到源模型所冻结并根据源知识提取特征的可区分嵌入中。
结合先前的分析,假设有效转移的关键是学习一个目标嵌入函数 \(e^\mathcal{T}: \mathcal{X}\to \hat{\mathcal{X}}\) ,使得目标条件分布 \(P(Y^t|\hat{X})\) 更加与源知识对齐。因此,论文建议在完整微调过程之前,提前学习这个嵌入函数。
Learning to Align Modality Knowledge
由于无法在没有训练模型的情况下估计目标条件概率,直接将模态知识差异作为优化目标是困难的。作为替代方案,建议利用元学习流程来模拟图3
中的过程,并在微调后优化源数据的表示质量。具体来说,一个理想的目标嵌入器会对齐模态知识,使得编码器在目标微调过程中保持其在图像数据上的可区分性。因此,如果使用一个源数据集来评估由这个理想目标嵌入器初始化的微调编码器,将获得源数据上的最小误差。
这个过程是元学习中广泛研究的标准双层优化问题。特别是在当前的场景下,外部循环根据外部循环损失来更新目标嵌入器,该损失通过内部循环优化之后的目标编码器计算。图5(a)
说明了在元学习期间外部循环中嵌入器参数 \(\boldsymbol{\phi}_e\) 的单次更新,图5(b)
展示了双层优化的过程。
更具体地说,内部循环是在目标数据集上对模型进行优化,受到目标嵌入器由 \(\boldsymbol{\phi}_e\) 初始化条件的限制
其中, \(\mathcal{L}_{inner}\) 是与公式1
中相同的损失函数,而
这种内部循环优化模拟了第二阶段的完整微调过程,并返回一个已经适应目标模态的编码器。请注意,内部循环中的整个目标模型的最优化取决于目标嵌入器的初始化,因此有 \(\boldsymbol{\theta}^{\mathcal{T}^*}(\boldsymbol{\phi}_e) = \{ \boldsymbol{\theta}^{\mathcal{T}^*}_e(\boldsymbol{\phi}_e),\boldsymbol{\theta}^{\mathcal{T}^*}_f(\boldsymbol{\phi}_e),\boldsymbol{\theta}^{\mathcal{T}^*}_h(\boldsymbol{\phi}_e) \}\) 。
外部循环是针对目标嵌入器的优化问题,目标是找到最优嵌入器参数 \({\phi}_e^*\) ,使得生成的最优目标编码器 \(\boldsymbol{\theta}^{\mathcal{T}^*}_f(\boldsymbol{\phi}_e^*)\) 能够产生高质量的源数据表示。为了计算损失,利用源模态中的一小部分带标签的数据集 \(\{\boldsymbol{x}_i^s, y_i^s\}\) 作为替代,并计算它们的特征 \(\{\boldsymbol{f}^s_i = f(e(\boldsymbol{x}^s_i;\boldsymbol{\theta}_{e_0}^\mathcal{S});\boldsymbol{\theta}_f^{\mathcal{T}^*}(\boldsymbol{\phi}_e))\}\) 。然后,将这些特征归一化到单位球上,并测量源特征的对齐性和均匀性。具体来说,对齐损失衡量了来自同一类别的特征是否接近,而均匀性损失则衡量了来自不同类别的特征是否均匀分布在球面上。
衡量编码器源模态可区分性的外部循环目标具有以下形式:
值得注意的是,在嵌入器训练开始阶段,源知识无法被很好地保留。为了防止嵌入器过分关注源模态,并保持优化过程稳定,通过共同最小化两个目标并引入权衡参数 \(\lambda\) 在源知识学习和目标知识学习之间取得平衡:
在实践中,内部循环中采用简化的单步更新,这使得能够重复使用在内部循环模拟更新期间计算的损失 \(\mathcal{L}_{inner}\) ,来有效地计算这个组合目标 \(\mathcal{L}_{outer}^{'}\) 。为此,论文提出的MoNA
在第一阶段,使用以下公式更新目标嵌入器:
随着模态知识的更好对齐,MoNA
在第二阶段进行普通微调。
Experiments
如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】