为什么 argmax 不可微?
前言:这个题目有些标题党,argmax 的值域甚至都不是实数域,没有良定义的梯度。那为什么要讨论这个问题呢?因为 argmax 是深度学习模型中常用的一个操作,对于包含离散随机变量的模型来说,更是一个不可回避的操作,这就带来了不可微的问题,使得模型难以进行端到端的训练。但究竟为什么不可微,似乎鲜有深入的讨论。本文刨根问到底,希望回答以下问题:
- 能否定义 argmax 的梯度?
- argmax 的梯度具有什么性质?
通过本文的探讨将得到以下结论:
argmax 可以用独热编码等价地表示为几乎处处连续的函数,除间断点以外,梯度为零。
背景
自动微分以及基于梯度的优化是目前深度学习模型能够进行端到端训练的根基之一。但是我们并不总是在处理图像或者隐层向量表示这样的连续型变量,常常也会面对离散型变量,例如模型生成句子等。离散型变量的引入使得模型缺少了良定义的梯度,阻断了反向传播的过程,给端到端训练带来了困难。通常我们处理的离散型变量为服从类别分布的随机变量,如词表中的一个词等。记 \(c\sim\text{Cat}(\boldsymbol{\alpha})\) 为一个服从类别分布的随机变量,\(c\) 是取值范围在 \(1\) 到 \(K\) 的正整数,\(\boldsymbol{\alpha}\) 是类别分布的参数,表示各个类别出现的概率。
通常会通过两类方式从分布 \(\text{Cat}(\boldsymbol{\alpha})\) 中得到 \(c\),一种是贪婪地选取概率最大的类别,另一种是按概率采样。显然前者需要使用 argmax 函数;后者可以通过 Gumbel-Max 技巧将采样过程变成一个确定性的计算过程,通过 argmax 得到采样结果,具体参考博客。下面直接讨论 argmax 本身的性质。记 \([K]\) 为 1 到 \(K\) 的正整数构成的集合,定义 argmax 函数如下:
需要注意的是,当存在超过一个最大值时,\(k'\) 取其中的最小下标,这也符合 argmax 函数的具体实现。这是一个约定,同时也是 argmax 不连续点和不可微点的根源。
由上述定义可知,由于值域是离散域,因此 \(f\) 本身就没有良定义的梯度。\(k\) 实际上是下标索引,可以等价替换为其独热表示 \(\mathbf{e}_{k'}\in\{0,1\}^K\),即 \(\mathbb{R}^K\) 标准基底中的第 \(k'\) 个向量,满足第 \(k'\) 个元素为 1 ,其余元素均为 0。进一步,可以把值域扩展到 \(\mathbb{R}^K\),然后就得到了 argmax 的连续值域版本,即
那为什么不直接把 \([K]\) 扩展到 \(\mathbb{R}\) 呢?因为 argmax 的结果还要参与模型的后续运算,下标索引的连续性和可微性难以定义。反之,用独热表示,下标索引的过程可以表示为独热向量 \(\mathbf{e}_{k'}\) 与嵌入矩阵 \(\mathbf{H}\in\mathbb{R}^{K\times D}\) 的乘法,即
这样易于对函数的性质进行分析。
argmax 函数的连续性与可微性
根据 \(\mathbf{x}\) 是否存在唯一的最大值进行讨论。
- 若 \(\mathbf{x}\) 存在唯一的最大值,则 \(f\) 在 \(\mathbf{x}\) 处连续。
若 \(x_{k'} > x_k,\forall\,k\neq k'\),取 \(\epsilon=x_{k'}-\max\{x_k:k\neq k'\}\)。当 \(\| \Delta \mathbf{x}\|_\infty<\epsilon\) 时,\(f(\mathbf{x})=f(\mathbf{x} + \Delta \mathbf{x})\)。令 \(\epsilon\to 0\),即可得到 \(f\) 在 \(\mathbf{x}\) 处连续。进一步可以得到 \(f\) 在该点可微,并且满足 \(\nabla f=\mathbf{O}\),其中 \(\mathbf{O}\in\mathbb{R}^{K\times K}\) 是一个全零矩阵。
- 若 \(\mathbf{x}\) 存在不唯一的最大值,则 \(f\) 在 \(\mathbf{x}\) 处间断。
不失一般性,假设存在两个最大值 \(x_k=x_{k'},k>k'\)。令 \(\Delta \mathbf{x}=\epsilon\cdot\mathbf{e}_{k}\),则
因此,\(f\) 在 \(\mathbf{x}\) 处不连续。显然,\(f\) 在这些点不可微。
argmax 不连续点的“个数”
不连续点至少有两个分量相等,分量的组合只有有限种,下面证明其中的每一种测度都是零,因此不连续点的测度是零。
记 \(\mathcal{A}=\{ \mathbf{x} \in\mathbb{R}^K:x_1=x_2\}\)。显然测度 \(\mu(\mathcal{A})=0\),因此 \(f\) 的不连续点的测度为零。
后记:argmax 不可微的问题,可以通过连续松弛和梯度估计来解决,不在本篇博客讨论范围内,感兴趣的读者可以自行搜索相关资料。