【NLP】Attention机制学习

Attention 机制学习

Attention 机制中一般需要用到的三个参数

query(Q), key(K), value(V)

preview

attention 包括硬编码和软编码

其中\(h^i\)是编码器每个step的输出, \(z^j\) 是解码器每个step的输出,计算步骤是这样的:

  1. 先对输入进行编码,得到 \([h^1, h^2, h^3, h^4]\)
  2. 开始解码了,先用固定的start token也就是 \(z ^ 0\) 最为Q,去和每个\(h^i\) (同时作为K和V)去计算attention,得到加权的 \(c ^ 0\)
  3. \(c^0\)作为解码的RNN输入(同时还有上一步的 \(z ^ 0\)),得到 \(z ^ 1\) 并预测出第一个词是machine
  4. 再继续预测的话,就是用\(z^1\) 作为Q去求attention:

增加了attention的学习机制后,可以编码更长的序列信息,同时,也可以优化输出序列和输入序列中,单词排序不同情况下的表现,这在机器对语句进行理解、摘要或者翻译中,具有重要影响。

当然,这种attention可能会减少对序列顺序的敏感性,同时,由于使用rnn,不能并行化计算。


在实现Seq2Seq模型中,Decoder解码部分,对于前一个预测词,有两种来源可以采用:

  • 模型预测的单词
  • 给定结果的单词

模型01、02,我没有采用Attention机制,同时只用给定结果的单词参与运算。导致了模型训练极度拟合了train训练集,因此,在valid集上的loss越来越大。在自己抽样调查中,可以明显感知,valid上,完全是用train中的原句去预测。可以在抽样中看到。

在模型03中,我采用了Attention机制,虽然沿用了只“给定结果的单词”的方式,但是效果还不错,在20个epoch训练后,valid集上交叉熵损失随着train集的损失稳步下降。抽样调查也令人欣慰。

参考:https://zhuanlan.zhihu.com/p/44121378
图片来源也是

posted @ 2021-02-20 23:09  ckxkexing  阅读(120)  评论(0编辑  收藏  举报