self-attention为什么要除以根号d_k

参考文章:

https://blog.csdn.net/tailonh/article/details/120544719

正如上文所说,原因之一在于:

1、首先要除以一个数,防止输入softmax的值过大,导致偏导数趋近于0;
2、选择根号d_k是因为可以使得q*k的结果满足期望为0,方差为1的分布,类似于归一化。

首先我们来看看softmax函数的导数(下文的\(i,j\)的描述需要看该文章):

https://zhuanlan.zhihu.com/p/105722023
image

其中,

\[p_i= \frac{e^{z_i}}{\sum_{c=1}^{c}e^{z_c}} \]

我们知道,假设\(j=i\),如果\(z_i\)比较大的话,那么\(p_i\)也会比较大,那么偏导会很小;
假设\(j\neq i\),如果\(z_i\)比较大或\(z_j\)比较大的话,那么\(p_i\)会比较大且\(p_j\)比较小(因为\(z_i\)比较大),或者是\(p_i\)会比较小且\(p_j\)比较大,无论什么情况,偏导都会很小,这在反向传播的时候会造成梯度消失。

现在让我们来看看原因二,首先看一个实验现象:
我们知道,self-attention的计算原理可以用下图表示:
image

红色波浪线处的行向量为\(Q\)中的第一个行向量对\(K^T\)做矩阵乘法得到,现在让我们通过实验来看看这个红色波浪线处的行向量的数据分布。

首先,让我们定义\(Q\)\(K\):

Q = torch.randn(size=(512,512))
K = torch.randn(size=(512,512))

如下图所示,\(Q\)\(K\)的每一行的均值基本上为0,方差为1:
image
image

然后,让我们计算\(Q[0]K^T\):
image
最后,让我们来看看变量\(result\)的标准差:
image
其平方很接近512.
如果我们把维度提升10倍:

Q = torch.randn(size=(512,5120))
K = torch.randn(size=(5120,5120))

那么结果为:
image
从图中也可以看出来,数据之间的分布差异很大,如果有个值很大,那么经过softmax后,其值会很大,也会导致其他值很小,基本上就是01分布了,那么在在反向传播的时候,梯度就会很小(见上面的 softmax函数的导函数公式)。另外我们知道,标准差是衡量数据偏离均值的总体程度,如果除以标准差的话,那么结果就不会太大,也不会很小。

posted @ 2022-09-20 08:40  Hisi  阅读(3230)  评论(0编辑  收藏  举报