self-attention为什么要除以根号d_k
参考文章:
正如上文所说,原因之一在于:
1、首先要除以一个数,防止输入softmax的值过大,导致偏导数趋近于0;
2、选择根号d_k是因为可以使得q*k的结果满足期望为0,方差为1的分布,类似于归一化。
首先我们来看看softmax函数的导数(下文的\(i,j\)的描述需要看该文章):
其中,
我们知道,假设\(j=i\),如果\(z_i\)比较大的话,那么\(p_i\)也会比较大,那么偏导会很小;
假设\(j\neq i\),如果\(z_i\)比较大或\(z_j\)比较大的话,那么\(p_i\)会比较大且\(p_j\)比较小(因为\(z_i\)比较大),或者是\(p_i\)会比较小且\(p_j\)比较大,无论什么情况,偏导都会很小,这在反向传播的时候会造成梯度消失。
现在让我们来看看原因二,首先看一个实验现象:
我们知道,self-attention的计算原理可以用下图表示:
红色波浪线处的行向量为\(Q\)中的第一个行向量对\(K^T\)做矩阵乘法得到,现在让我们通过实验来看看这个红色波浪线处的行向量的数据分布。
首先,让我们定义\(Q\)和\(K\):
Q = torch.randn(size=(512,512))
K = torch.randn(size=(512,512))
如下图所示,\(Q\)和\(K\)的每一行的均值基本上为0,方差为1:
然后,让我们计算\(Q[0]K^T\):
最后,让我们来看看变量\(result\)的标准差:
其平方很接近512.
如果我们把维度提升10倍:
Q = torch.randn(size=(512,5120))
K = torch.randn(size=(5120,5120))
那么结果为:
从图中也可以看出来,数据之间的分布差异很大,如果有个值很大,那么经过softmax后,其值会很大,也会导致其他值很小,基本上就是01分布了,那么在在反向传播的时候,梯度就会很小(见上面的 softmax函数的导函数公式)。另外我们知道,标准差是衡量数据偏离均值的总体程度,如果除以标准差的话,那么结果就不会太大,也不会很小。