关于attention中对padding的处理:mask

先问了下chatgpt:
我正在学习torch.nn.multiheadattention,请告诉我att_mask和key_padding_mask这两个参数有什么不同,分别用于处理什么问题,以及输出有什么不同,并给出代码示例

chatgpt的回答:
torch.nn.MultiheadAttention 中的 attn_maskkey_padding_mask 是两个非常重要的参数,用于处理不同的注意力机制中的问题。

attn_mask (Attention Mask)

attn_mask 用于在计算注意力权重时屏蔽某些位置。这在例如自回归生成模型中非常有用,可以防止模型在生成当前词时看到未来的词。它可以是一个 2D 的布尔型矩阵,用于屏蔽特定位置,或是一个上三角矩阵,用于屏蔽未来的位置。

key_padding_mask

key_padding_mask 用于在计算注意力权重时忽略某些填充位置(padding)。这是在处理不同长度的序列时常用的技巧,确保模型不会将填充值(通常为0)作为有效的信息来处理。

不同之处

  • attn_mask 是用来遮蔽未来的位置或指定位置,以防止信息泄露或特定位置的信息影响计算。
  • key_padding_mask 是用来忽略填充值的位置,确保这些填充值不会影响模型的注意力计算。

代码示例

import torch
import torch.nn as nn

# 创建一些示例数据
batch_size = 4
seq_length = 6
embed_dim = 8
num_heads = 2

# 模拟输入 (batch_size, seq_length, embed_dim)
x = torch.rand(batch_size, seq_length, embed_dim)

# 模拟填充后的序列
padded_x = nn.utils.rnn.pad_sequence(
    [torch.rand(3, embed_dim), torch.rand(4, embed_dim), torch.rand(5, embed_dim), torch.rand(6, embed_dim)], 
    batch_first=True
)

# 创建 SelfAttention 实例
self_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Attention mask: 遮蔽未来的位置(上三角矩阵)
attn_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()

# Key padding mask: 遮蔽填充值(padding)
key_padding_mask = torch.tensor([
    [False, False, False, True, True, True],  # 只有前三个位置有效
    [False, False, False, False, True, True], # 只有前四个位置有效
    [False, False, False, False, False, True],# 只有前五个位置有效
    [False, False, False, False, False, False]# 所有位置都有效
])

# 计算注意力
attn_output, attn_weights = self_attention(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)

print("Attention Output:", attn_output)
print("Attention Weights:", attn_weights)

解释代码

  1. 模拟输入数据:创建一个随机的输入张量 x,形状为 (batch_size, seq_length, embed_dim)
  2. 模拟填充后的序列:使用 nn.utils.rnn.pad_sequence 创建一个填充后的序列 padded_x
  3. 创建 SelfAttention 实例:实例化 nn.MultiheadAttention,设置 embed_dimnum_heads
  4. 创建 Attention Maskattn_mask 是一个上三角矩阵,用于屏蔽未来的位置。
  5. 创建 Key Padding Maskkey_padding_mask 是一个布尔型张量,用于屏蔽填充值的位置。
  6. 计算注意力:调用 self_attention,传入输入张量、attn_maskkey_padding_mask

输出

  • Attention Output:经过注意力机制后的输出张量。
  • Attention Weights:注意力权重,展示了每个位置的权重值。

这样,你可以通过 attn_maskkey_padding_mask 灵活地控制注意力机制的行为,以适应不同的任务需求。

参考博客:

  • 【pytorch的key_padding_mask和参数attn_mask有什么区别?】https://www.zhihu.com/question/455164736
    key_padding_mask是一个bool值矩阵[batch_size, seq_len],在seq_len这一维度,有效值的位置对应True,padding值对应False
    比如[2,1,4,0] → [True, True, True, False]
    刚好我过nn.embedding之前的就是一个pad好的序列idx_tensor,直接padding_mask == 0即可
posted @ 2024-05-21 17:52  waterrr。  阅读(823)  评论(0编辑  收藏  举报