[论文阅读] EMO@ Earth Mover Distance Optimization for Auto-Regressive Language Modeling
Pre
title: EMO: Earth Mover Distance Optimization for Auto-Regressive Language Modeling
accepted: arXiv2023
paper: https://arxiv.org/abs/2310.04691
code: https://github.com/DRSY/EMO
ref: https://spaces.ac.cn/archives/9797
关键词: language-modeling, optimal transport, earth mover distance
阅读理由: 或许可以通过替换交叉熵而带来性能提升?
Idea
将自回归语言建模中常用的交叉熵损失换成基于推土机距离的新损失,通过 Embedding 算相似度,来为“近义词”分配了更合理的惩罚,主要用于微调阶段
Motivation&Solution
- 自然语言模型训练用的最大似然估计(MLE)等效于交叉熵(forward cross-entropy),但它由于 recall-prioritization, negative diversity ignorance, train-test mismatch ,对于人类语言-模型生成分布的对齐并非最优 —— 基于最优传输思想进行改进
- 直接计算EMO(Earth Mover Distance Optimization)复杂度较高且无法梯度传播 —— 求解其上界来简化端到端训练
Background
交叉熵训练损失低未必能说明模型学得好,它作为一个指标有些固有缺陷:
- recall-prioritized (只关心召回?): 每个时间步都仅专注于增加下一个 ground-truth token 的模型概率,这会导致训练数据有噪声时学到的模型分布精度很差或高质量文本充足时收敛却很慢。
- negative diversity ignorance (忽视了负例的多样性?): 预测下一个token时所有非groundtruth被视为同等的不正确,然而某些token还是有点道理,甚至能替代groundtruth
- train-test objective mismatch (缺少与测试指标的相关性?): 其形式与语言模型评估时不一致,无法作为模型能力的指示器。
EMD的优点:
- 建模时能同时考虑精度和召回率
- 认可数据样本有不同的正确度,允许更细致地训练
- 其数学形式使得训练测试阶段更一致
但计算EMD需要额外的求解器(solver),它不属于计算图,会阻碍梯度传播,通过使用EMD的上界进行端到端的训练
EMO: EARTH MOVER DISTANCE OPTIMIZATION
ADAPTING EARTH MOVER’S DISTANCE TO AUTO-REGRESSIVE LANGUAGE MODELING
推土机距离(EMD)定义为两个概率分布 \(P_1,P_2\) 之间的最优传输成本:
其中 \(\prod(P_1,P_2)\) 表示以 \(P_1, P_2\) 为边缘分布的所有联合分布 \(\gamma(x_1, x_2)\) 的集合 (两个分布中各自取两个样本?)。 \(\gamma(x_1, x_2)\) 可解释为从 \(P_1(x_1)\) 传输到 \(P_2(x_2)\) 的概率物质总量。 \(C(x_1, x_2)\) 是非负函数,测量从 \(x_1\) 传输单位物质到 \(x_2\) 所需的成本。
而 \(\inf\) 是下确界\(^{[ref1]}\),也就是将最低的传输成本作为 \(P_1,P_2\) 之间的差异度量。
在自回归语言建模中, \(P_1\) 指模型分布, \(P_2\) 则是数据分布,两者都表示给定先前token、在时间步t时,下一个token的局部可分解概率分布 (locally factorized probability distribution),也就是说 \(P_1 := Q_\theta(\cdot|x_{\lt t}),\; P_2 := P(\cdot|x_{\lt t})\) 。因此公式8可以重写为下列形式:
公式9
其中 \(V\) 表示语言模型的词汇表,\(v_i\) 是其中第i个token,一旦成本函数 \(C\) 确定下来,上述推土机距离的计算就相当于求解下列约束线性优化问题:
公式10
晕了,这两个看不懂,而且大脑也在抗拒着,总之就是将EMD作为语言模型的损失吧。[ref1]似乎浓缩大概介绍了下,不过为了不重复就不贴过来了。
Semantically-Informed Transport Cost 接下来要建立\(C\)的定义,它得反映出token对 \(v_i,\; v_j\) 之间的有意义的距离。直观上,那些可以互相替换的token理应有更近的距离,例如 glad 和 happy,而那些无法适应对方语境的就应该远离,如 cat 和 galaxy。在上下文嵌入空间中有一种余弦距离可作为token距离,即 \(C(v_i,v_j) = 1 - \frac{e^\top_i e_j}{|e_i||ej|}\) 其中 \(e_i\) 是语言模型 \(Q_\phi\) 的语言建模头 \(E\) ,它使用MLE进行预训练。由于训练时 \(e_i\) 被优化去接近所有下一个token是 \(v_i\) 的前缀(已有的token序列?)的上下文表达,因此 \(e_i,\; e_j\) 之间的余弦距离可作为 \(v_i, v_j\) 之间有效的代理。由于成本函数是一个priori(先验的?表示需要固定不变?),因此它(\(Q_\phi?)\)需要在\(Q_\theta\)训练时固定
A TRACTABLE UPPER BOUND
传统EMD求解器计算公式10的复杂度是 \(O(|V|^3\log|V|)\) ,而当今的LLM词汇表都很大,因此难以计算。但用外部的求解器又会扰乱梯度传播。
数据分布\(P\),模型分布\(Q_\theta\),\(\tilde{\gamma}\) 是一个传输计划,同时它也满足公式10的约束条件(边缘分布为P/Q):
本质上 \(\tilde{\gamma}\) 代表一种数据相关的传输计划,它按 \(P\) 指定的比例将 \(Q_\theta\) 的概率物质 \(v_i\) 移动到其他token上,由于 \(Q_\theta, P\) 两者各自的和皆为1, 因此 \(\tilde{\gamma}\) 是一个可行但非最优的计划。
没懂, \(Q_\theta(v_i) \cdot \{P(v_j)_0, P(v_j)_1, \ldots, P(v_j)_n\}\) 计算 \(v_i\) 移动到 \(P\) 中各个元素的概率吗?感觉是已知序列前缀 \(x\lt t\) 时,模型预测(\(Q_\theta\))和真实数据(\(P\))中下一个token(t时刻)的分布概率分别为 \(Q_\theta(v),\; P(v)\) 而 \(\tilde{\gamma}(v_i, v_j)\) 则表示各种情况出现的概率(预测\(v_1\), 真实分布是\(v_3\); 预测\(v_7\), 真实分布是\(v_5\)...),再配合 \(C(v_i,v_j)\) 来衡量t时刻模型预测 \(v_i\) 和 真实分布 \(v_j\) 之间的距离?但一般前缀一样,模型预测的跟真实的下一个token不是应该一样,直接让 \(v_i, v_i\) 之间小不就可以...又或者 \(P(v)\) 其实是one-hot...
总之记最优传输方案为 \(\gamma^*\) ,就能用不等式变出了EMD的上界:
公式12~14
这里推导的上界只与\(Q_\theta\)的训练有关,更加稳定和有效,将公式14定义为可微推土机距离 Differentiable Earth Mover Distance (DEMD) ,并将其作为token级别的优化目标。
PROPERTIES OF DEMD
分析了DEMD,介绍了它是如何做到比MLE优秀的,模型参数为\(\theta\),作者给出了它的梯度:
公式15
- Harmonizing Recall and Precision
- Negative Diversity Awareness
- Better Train-Test Consistency
就是跟本文提到的MLE三个缺点一一对应,具体分析这里略过!
Experiment
作为一个基于预训练LM的微调方法(continual fine-tuning method),实验主要比较EMO和其他损失在语言模型微调上的效果
Settings
使用的数据集略
选 GPT-2 和 OPT-125M,在每个数据集的训练集部分微调3epoch,保存验证损失最低的一个checkpoint。EMO最终损失由MLE和DEMD构成,使用AdamW优化器,学习率5e-5,batchsize在所有实验中固定为32,最大输入长度设置为256。
LANGUAGE MODELING MAIN RESULTS
表1 对预训练LM微调后的无偏采样(unbiased sampling, ancestral sampling)结果(Mauve↑)。分数由3个不同的随机种子各5轮采样平均而来,粗体表示结果远好于MLE(p-value<0.001)
TaiLr和MixCE都用上了新的距离度量,相对交叉熵有理论优势,但有着对模型训练动态假设弱(mild assumption about the model’s training dynamics)或退化为交叉熵的正则化版本的问题,因此仍然有MLE的部分缺点
这里的评价指标是MAUVE,越大越好,它出自《MAUVE: Measuring the Gap Between Neural Text and Human Text using Divergence Frontiers》,是跟人工评价最相关的自动评测指标之一
EXPERIMENT WITH ORACLE DATA GENERATOR
表2 不同训练标准微调的GPT-2无偏采样结果。分数由3个不同的随机种子各5轮采样平均而来,粗体表示结果远好于MLE(p-value<0.001)
EMO比起基线方法有着更低的PPL,表示它能有效缓解与低质量文本相关的高估问题,而PPL跟MLE更相关,表示EMO的优秀并非是指标选取带来的假象。而且EMO也取得更高的ROUGE分数,表示用它训练模型能更有效地平衡精度和召回率。
这里的PPL(perplexity)是借助 orcale GPT-2-Large 模型来计算的,其分布已知
LANGUAGE UNDERSTANDING MAIN RESULTS
表3 不同训练目标微调的LLaMa-7B/13B在下游任务的表现
图1 EMO对于模型规模和数据量的缩放法则
除了 LLaMa-7B/13B 还加上了 OPT-1.3B/2.7B 去做下游微调实验,并可视化8个数据集上随模型规模缩放变化的任务准确率,如图1左侧,发现MLE微调后并不能总是比预训练的好,而TaiLr和MixCE在精调权重系数时能得到提升,EMO总是更好。
针对LLaMa-13B,微调时改变数据量能得到图1右侧的图,MLE去微调反而随着见到的数据而性能下降,作者认为是其理论上的缺陷导致的。EMO提升很大,4m数据就能匹敌MixCE的100m数据微调。
Conclusion
EMO是一种训练自回归语言模型的新方法,通过优化模型分布和人类文本分布之间推土机距离的可微上界来实现,它效果好,并且展示出了对训练数据量的缩放属性,是一种通用的连续微调方法(continual fine-tuning)
Critique
图表展现出来的性能提升很吸引人,但通过文章和实验来看好像是用在模型微调上,似乎不能直接用于训练模型比较遗憾。
推导有点没看懂