selfAttention

在PyTorch框架中,nn.MultiheadAttention模块用于实现多头注意力机制,这是Transformer架构中的一个关键组成部分。该模块的输入形状如下:

  • query:形状为(L, N, E)的张量,其中:
    • L 是序列的长度(例如,句子中的单词数量)。
    • N 是批次大小。
    • E 是特征维度(即每个单词的嵌入维度)。
  • key:形状为(S, N, E)的张量,其中:
    • S 是key序列的长度。
    • N 是批次大小。
    • E 是特征维度,通常与query的特征维度相同。
  • value:形状为(S, N, E)的张量,其中:
    • S 是value序列的长度。
    • N 是批次大小。
    • E 是特征维度,通常与query和key的特征维度相同。
      这里是一个简单的例子,展示如何初始化并使用nn.MultiheadAttention
import torch
from torch import nn
# 假设我们有一个嵌入维度为512的模型,序列长度为10,批次大小为32,头数为8
embed_dim = 512
num_heads = 8
seq_len = 10
batch_size = 32
# 初始化MultiheadAttention
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
# 创建随机的query,key和value张量
query = torch.rand(seq_len, batch_size, embed_dim)
key = torch.rand(seq_len, batch_size, embed_dim)
value = torch.rand(seq_len, batch_size, embed_dim)
# 应用多头注意力机制
attn_output, attn_output_weights = multihead_attn(query, key, value)
# attn_output形状为 (seq_len, batch_size, embed_dim)
# attn_output_weights形状为 (batch_size, seq_len, seq_len)

需要注意的是,nn.MultiheadAttention模块中的embed_dim参数指的是每个头的维度,而整个多头注意力的输入和输出张量的特征维度是所有头的总和。如果想要得到每个头的维度,通常是将embed_dim除以num_heads。即每个头的维度是embed_dim // num_heads。如果embed_dim不能被num_heads整除,则需要通过nn.MultiheadAttentionin_proj_weight参数手动指定每个头的维度。

本文作者:seekwhale13

本文链接:https://www.cnblogs.com/seekwhale13/p/18735877

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @   seekwhale13  阅读(2)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起