[机器学习]对transformer使用padding mask
注:本文是对GPT4的回答的整理校正补充。
在处理序列数据时,由于不同的序列可能具有不同的长度,我们经常需要对较短的序列进行填充(padding)以使它们具有相同的长度。但是,在模型的计算过程中,这些填充值是没有实际意义的,因此我们需要一种方法来确保模型在其计算中忽略这些填充值。这就是padding mask的作用。
比如常用的就是在数据集准备中,想用batch来训练,就得将一个batch的数据的长度全部对齐。
1. 什么是Padding Mask?
Padding mask是一个与输入序列形状相同的二进制矩阵,用于指示哪些位置是真实的数据,哪些位置是填充值。
- 真实数据位置的mask值为0。
- 填充位置的mask值为1。
2. 如何使用Padding Mask?
在自注意力机制中,我们计算查询和键的点积来得到注意力分数。在应用softmax函数之前,我们可以使用padding mask来确保填充位置的注意力分数为一个非常大的负数(例如,乘以-1e9)。这样,当应用softmax函数时,这些位置的权重将接近于零,从而确保模型在其计算中忽略这些填充值。
3. 示例
假设我们有一个长度为4的序列:[A, B, C, <pad>]
,其中<pad>
是填充标记。对应的padding mask是:[0, 0, 0, 1]
。
在计算注意力分数后,我们可以使用以下方法应用padding mask:
attention_scores = attention_scores.masked_fill(mask == 1, -1e9)
这里,masked_fill
是一个PyTorch函数,它会将mask中值为1的位置替换为-1e9
看图,这里的attention_scores就是Q×K的矩阵,把尾部多余的部分变成-inf,再过SoftMax,这样就是0了。这样,即使V的后半部分有padding的部分,也会因为乘0而变回0。这样被padding掉的部分就从计算图上被剥离了,由此不会影响模型的训练。
4. 代码
笔者自己写的,不保证靠谱哈。
import torch.nn as nn
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, mask=None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
# Apply the padding mask
if mask is not None:
attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2) == 1, float('-inf'))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
5. 为什么需要Padding Mask?
-
忽略无关信息:通过使用padding mask,我们可以确保模型在其计算中忽略填充值,从而避免这些无关的信息对模型的输出产生影响。
-
稳定性:如果不使用padding mask,填充值可能会对模型的输出产生不稳定的影响,尤其是在使用softmax函数时。
-
解释性:使用padding mask可以提高模型的解释性,因为我们可以确保模型的输出只与真实的输入数据有关,而不是与填充值有关。
总之,padding mask是处理序列数据时的一个重要工具,它确保模型在其计算中忽略填充值,从而提高模型的性能和稳定性。