动手学pytorch-注意力机制和Seq2Seq模型

注意力机制和Seq2Seq模型

1.基本概念

2.两种常用的attention层

3.带注意力机制的Seq2Seq模型

4.实验

1. 基本概念

Attention 是一种通用的带权池化方法,输入由两部分构成:询问(query)和键值对(key-value pairs)。\(𝐤_𝑖∈ℝ^{𝑑_𝑘}, 𝐯_𝑖∈ℝ^{𝑑_𝑣}\). Query \(𝐪∈ℝ^{𝑑_𝑞}\) , attention layer得到输出与value的维度一致 \(𝐨∈ℝ^{𝑑_𝑣}\). 对于一个query来说,attention layer 会与每一个key计算注意力分数并进行权重的归一化,输出的向量\(o\)则是value的加权求和,而每个key计算的权重与value一一对应。

为了计算输出,我们首先假设有一个函数\(\alpha\) 用于计算query和key的相似性,然后可以计算所有的 attention scores \(a_1, \ldots, a_n\) by

\[a_i = \alpha(\mathbf q, \mathbf k_i). \]

我们使用 softmax函数 获得注意力权重:

\[b_1, \ldots, b_n = \textrm{softmax}(a_1, \ldots, a_n). \]

最终的输出就是value的加权求和:

\[\mathbf o = \sum_{i=1}^n b_i \mathbf v_i. \]

Image Name

不同的attetion layer的区别在于score函数的选择,下面主要讨论两个常用的注意层 Dot-product Attention 和 Multilayer Perceptron Attention

Softmax屏蔽

在深入研究实现之前,首先介绍softmax操作符的一个屏蔽操作,主要目的是屏蔽无关信息。

超出2维矩阵的乘法

\(X\)\(Y\) 是维度分别为\((b,n,m)\)\((b, m, k)\)的张量,进行 \(b\) 次二维矩阵乘法后得到 \(Z\), 维度为 \((b, n, k)\)

\[ Z[i,:,:] = dot(X[i,:,:], Y[i,:,:])\qquad for\ i= 1,…,n\ . \]

2. 两种常用的attention层

2.1点积注意力

The dot product 假设query和keys有相同的维度, 即 $\forall i, 𝐪,𝐤_𝑖 ∈ ℝ_𝑑 $. 通过计算query和key转置的乘积来计算attention score,通常还会除去 \(\sqrt{d}\) 减少计算出来的score对维度𝑑的依赖性,如下

\[𝛼(𝐪,𝐤)=⟨𝐪,𝐤⟩/ \sqrt{d} \]

假设 $ 𝐐∈ℝ^{𝑚×𝑑}$ 有 \(m\) 个query,\(𝐊∈ℝ^{𝑛×𝑑}\)\(n\) 个keys. 我们可以通过矩阵运算的方式计算所有 \(mn\) 个score:

\[𝛼(𝐐,𝐊)=𝐐𝐊^𝑇/\sqrt{d} \]

下面来实现这个层,它支持一批查询和键值对。此外,它支持作为正则化随机删除一些注意力权重.

测试
创建两个batch,每个batch有一个query和10个key-values对。通过valid_length指定,对于第一批,只关注前2个键-值对,而对于第二批,我们将检查前6个键-值对。尽管这两个批处理具有相同的查询和键值对,但获得的输出是不同的。

2.2 多层感知机注意力

在多层感知器中,我们首先将 query and keys 投影到 \(ℝ^ℎ\) .为了更具体,我们将可以学习的参数做如下映射
\(𝐖_𝑘∈ℝ^{ℎ×𝑑_𝑘}\) , \(𝐖_𝑞∈ℝ^{ℎ×𝑑_𝑞}\) , and \(𝐯∈ℝ^h\) . 将score函数定义

\[𝛼(𝐤,𝐪)=𝐯^𝑇tanh(𝐖_𝑘𝐤+𝐖_𝑞𝐪) \]

.
然后将key 和 value 在特征的维度上合并(concatenate),然后送至 a single hidden layer perceptron 这层中 hidden layer 为 ℎ and 输出的size为 1 .隐层激活函数为tanh,无偏置.

测试
尽管MLPAttention包含一个额外的MLP模型,但如果给定相同的输入和相同的键,我们将获得与DotProductAttention相同的输出

3. 带注意力机制的Seq2seq模型

解码器

在解码的每个时间步,使用解码器的最后一个RNN层的输出作为注意层的query。然后,将注意力模型的输出与输入嵌入向量连接起来,输入到RNN层。虽然RNN层隐藏状态也包含来自解码器的历史信息,但是attention model的输出显式地选择了enc_valid_len以内的编码器输出,这样attention机制就会尽可能排除其他不相关的信息。

4. 实验

posted @   hou永胜  阅读(1044)  评论(0编辑  收藏  举报
编辑推荐:
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
阅读排行:
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· Docker 太简单,K8s 太复杂?w7panel 让容器管理更轻松!
点击右上角即可分享
微信分享提示