Scaled Dot-Product Attention 

在实际应用中,经常会用到 Attention 机制,其中最常用的是 Scaled Dot-Product Attention,它是通过计算query和key之间的点积 来作为 之间的相似度。

  • Scaled 指的是 Q和K计算得到的相似度 再经过了一定的量化,具体就是 除以 根号下K_dim;
  • Dot-Product 指的是 Q和K之间 通过计算点积作为相似度;
  • Mask 可选择性 目的是将 padding的部分 填充负无穷,这样算softmax的时候这里就attention为0,从而避免padding带来的影响.

Mask attention 的思想是 掩盖掉部分内容,不参与 attention 的计算,或许是因为不需要,或许因为不存在,根据实际场景来;

不参与attention计算 其实 就把 qk = 0 就行了

mask

上代码吧

import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F


class Attention_Layer(nn.Module):
    # 用来实现mask-attention layer
    def __init__(self, input_size, hidden_dim):
        super(Attention_Layer, self).__init__()

        self.hidden_dim = hidden_dim
        self.Q_linear = nn.Linear(input_size, hidden_dim, bias=False)
        self.K_linear = nn.Linear(input_size, hidden_dim, bias=False)
        self.V_linear = nn.Linear(input_size, hidden_dim, bias=False)

    def forward(self, inputs, lens):
        size = inputs.size()        # [b h w]   h代表词总量,w代表每个词的编码长度
        # 计算生成QKV矩阵
        Q = self.Q_linear(inputs)   # [b h hidden_dim]
        K = self.K_linear(inputs).permute(0, 2, 1)  # # [b hidden_dim h]
        V = self.V_linear(inputs)   # [b h hidden_dim]

        # 还要计算生成mask矩阵
        max_len = max(lens)  # 最大的句子长度,生成mask矩阵
        sentence_lengths = torch.Tensor(lens)  # 代表每个句子的长度
        print(sentence_lengths)                     # tensor([ 7., 10.,  4.])
        print(sentence_lengths.max().item())        # 10.0
        print(torch.arange(sentence_lengths.max().item()))          # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
        print(torch.arange(sentence_lengths.max().item())[None, :]) # tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])
        print(sentence_lengths[:, None])
        # tensor([[7.],
        #         [10.],
        #         [4.]])
        mask = torch.arange(sentence_lengths.max().item())[None, :] < sentence_lengths[:, None]
        print(mask) # <前每一行的所有值 分别与 <后每一列的值 进行比较
        # tensor([[True, True, True, True, True, True, True, False, False, False],
        #         [True, True, True, True, True, True, True, True, True, True],
        #         [True, True, True, True, False, False, False, False, False, False]])
        mask = mask.unsqueeze(dim=1)  # [batch_size, 1, max_len]
        mask = mask.expand(size[0], max_len, max_len)  # [batch_size, max_len, max_len]

        padding_num = torch.ones_like(mask)     # 全1
        padding_num = -2 ** 31 * padding_num.float()    # 全无穷小
        # qk=[b h hidden_dim]*[b hidden_dim h]=[b h h] 代表每句话的每个词 和 其他词的关系
        alpha = torch.matmul(Q, K)
        # mask True 区域的 alpha 值 置为 无穷小
        alpha = torch.where(mask, alpha, padding_num)   # 用法:满足条件,返回x,否则返回y
        pd.DataFrame(alpha[0].data).to_csv('mask.csv')
        alpha = F.softmax(alpha, dim=2)
        pd.DataFrame(alpha[0].data).to_csv('softmax.csv')
        # softmax*v=[b h h]*[b h hidden_dim]=[b h hidden_dim]
        out = torch.matmul(alpha, V)
        return out


if __name__ == '__main__':
    input_size = 100
    hidden_size = 8
    input = torch.rand(3, 10, input_size)
    att_L = Attention_Layer(input_size, hidden_size)
    lens = [7, 10, 4]  # 一个batch文本的真实长度

    att_out = att_L(input, lens) 

看看中间结果就明白了

 

 

参考资料: