from_rnn_2_transformer-cnblog
从RNN到Transformer
各式各样的“attention”
不管是在CV领域还是NLP领域, attention实质上就是一种取权重求和的过程。使得网络focus在其应该focus的地方。
根据Attention的计算区域,可以分成以下几种:
1)Soft Attention,这是比较常见的Attention方式,对所有key求权重概率,每个key都有一个对应的权重,是一种全局的计算方式(也可以叫Global Attention)。这种方式比较理性,参考了所有key的内容,再进行加权。但是计算量可能会比较大一些。
2)Hard Attention,这种方式是直接精准定位到某个key,其余key就都不管了,相当于这个key的概率是1,其余key的概率全部是0。因此这种对齐方式要求很高,要求一步到位,如果没有正确对齐,会带来很大的影响。另一方面,因为不可导,一般需要用强化学习的方法进行训练。(或者使用gumbel softmax之类的)
3)Local Attention,这种方式其实是以上两种方式的一个折中,对一个窗口区域进行计算。先用Hard方式定位到某个地方,以这个点为中心可以得到一个窗口区域,在这个小区域内用Soft方式来算Attention。
循环神经网络
RNN
用公式表达就是
当t=0时, 使用torch.zeros作为上一个的输入
LSTM
如果句子过长,rnn有可能会出现梯度消失或者爆炸现象, 导致“记不住句子开头的详细内容”。LSTM通过它的“门控装置”有效的缓解了这个问题,这也就是为什么我们现在都在使用LSTM而非普通RNN。
遗忘门
如下图
记忆门
更新细胞状态
输出门
Transformer
整体结构
以中英翻译为例, 介绍transformer整体流程
step.1 获得输入X
step.2 将得到的单词表示向量矩阵 (如上图所示,每一行是一个单词的表示 x) 传入 Encoder 中,经过 6 个 Encoder block 后可以得到句子所有单词的编码信息矩阵 C,如下图。单词向量矩阵用
step.3 将 Encoder 输出的编码信息矩阵 C传递到 Decoder 中,Decoder 依次会根据当前翻译过的单词 1~ i 翻译下一个单词 i+1,如下图所示。在使用的过程中,翻译到单词 i+1 的时候需要通过 Mask (掩盖) 操作遮盖住 i+1 之后的单词。
上图 Decoder 接收了 Encoder 的编码矩阵 C,然后首先输入一个翻译开始符 "
Embedding
word embedding
单词的 Embedding 有很多种方式可以获取,例如可以采用 Word2Vec、Glove 等算法预训练得到,也可以在 Transformer 中训练得到
Annotated Transformer实现
positional embedding
因为 Transformer 不采用 RNN 的结构,而是使用全局信息,不能利用单词的顺序信息,而这部分信息对于 NLP 来说非常重要。所以 Transformer 中使用位置 Embedding 保存单词在序列中的相对或绝对位置。
Annotated Transformer实现
其中 div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
由下面公式转换得来
在具体使用的时候, 词编码与位置编码结合参考如下
Self-Attention(自注意力机制)
transformer内部结构图如下图所示,红色圈中的部分为 Multi-Head Attention,是由多个 Self-Attention组成的。Multi-Head Attention 上方还包括一个 Add & Norm 层,Add 表示残差连接 (Residual Connection) 用于防止网络退化,Norm 表示 Layer Normalization,用于对每一层的激活值进行归一化。下面详细介绍一下self-attention
Self-Attention ,在计算的时候需要用到矩阵Q(查询),K(键值),V(值)。在实际中,Self-Attention 接收的是输入(单词的表示向量x组成的矩阵X) 或者上一个 Encoder block 的输出。而Q,K,V正是通过 Self-Attention 的输入进行线性变换得到的。
可以线性变阵矩阵WQ,WK,WV计算得到Q,K,V。得到QKV就可以计算self-attention的输出, 公式和图解如上图所示
代码实现
Multi-head Attention
实际上就是h个self-attention计算得到h个输出矩阵Z, 再进行拼接
Encoder结构
结构图可参考self-attention中的结构图, 由 Multi-Head Attention, Add & Norm, Feed Forward, Add & Norm 组成的,
公式如下
Feed Forward 层比较简单,是一个两层的全连接层,第一层的激活函数为 Relu,第二层不使用激活函数,对应的公式如下。
Decoder结构
decoder block有一些特殊的地方
- 包含两个 Multi-Head Attention 层。
- 第一个 Multi-Head Attention 层采用了 Masked 操作。
- 第二个 Multi-Head Attention 层的K, V矩阵使用 Encoder 的编码信息矩阵C进行计算,而Q使用上一个 Decoder block 的输出计算。
- 最后有一个 Softmax 层计算下一个翻译单词的概率。
第一个mult-head attention
第二个multi-head attention
Decoder block 第二个 Multi-Head Attention 变化不大, 主要的区别在于其中 Self-Attention 的 K, V矩阵不是使用 上一个 Decoder block 的输出计算的,而是使用 Encoder 的编码信息矩阵 C 计算的。
根据 Encoder 的输出 C计算得到 K, V,根据上一个 Decoder block 的输出 Z 计算 Q (如果是第一个 Decoder block 则使用输入矩阵 X 进行计算),后续的计算方法与之前描述的一致。
这样做的好处是在 Decoder 的时候,每一位单词都可以利用到 Encoder 所有单词的信息 (这些信息无需 Mask)。
其它
why LayerNorm?
BN
如下图所示, 一个batch有R条数据, BN就是对同一维度的特征进行normalization, 这是基于同一维度都是表示同一特征(比如第n维为身高)。但是bs比较小的时候bn效果不好
对于nlp而言, 词嵌入之后的句子长度是不一样的, 比如一个句子长度是20,另外9个句子长度不及5, 那这个时候做BN就是对每一个单词做bn, 显然是不太合理的, 除此之外, 不同句子相同位置的单词是没有语义关系的, 不适合做BN
LN
BatchNorm是对一个batch-size样本内的每个特征做归一化,LayerNorm是对每个样本的所有特征做归一化,对于NLP来说, 有两个句子“今天天气很好”, “你中午吃的什么”, LN就是对整个句子进行归一化处理
transformer为什么可以并行计算?
Transformer可以并行运算的原因是因为其使用了自注意力机制(Self-Attention)。自注意力机制可以同时计算所有输入序列中每个位置的表示,因此可以并行化处理整个输入序列。
具体来说,自注意力机制通过计算每个位置与所有其他位置之间的相关性来确定位置的表示。在计算相关性时,可以通过矩阵乘法来实现并行计算。因此,Transformer可以将整个输入序列通过矩阵乘法并行化处理,从而大大提高了模型的计算效率。
除此之外,Transformer还使用了多头注意力机制(Multi-Head Attention),将自注意力机制并行化,进一步提高了模型的计算效率。多头注意力机制将输入序列分成多个子序列,并对每个子序列进行注意力计算,最后将不同子序列的注意力计算结果拼接在一起,得到整个输入序列的表示。这样可以同时计算多个位置之间的相关性,大大提高了模型的计算效率。
因此,Transformer可以通过自注意力机制和多头注意力机制的并行计算,实现对输入序列进行高效的处理和表示学习。
未完待更新...
__EOF__

本文链接:https://www.cnblogs.com/xle97/p/17747183.html
关于博主:评论和私信会在第一时间回复。或者直接私信我。
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角【推荐】一下。您的鼓励是博主的最大动力!
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性