Attention pseudocode (PyTorch version)

The principles of attention have been explained many times elsewhere, so this post skips the theory. The pseudocode follows the Transformer/BERT implementation; below is the simplest possible version.
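Concretely, the code implements standard scaled dot-product attention from "Attention Is All You Need", with a single head and no output projection:

    Attention(Q, K, V) = softmax(Q K^T / sqrt(d)) V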

import math

import torch
from torch import nn


class SelfAttention(nn.Module):
    def __init__(self, hidden_size, dropout_prob=0.1):
        super().__init__()
        self.hidden_size = hidden_size  # d
        # Projection matrices for queries, keys, and values, each (d, d).
        # These must be created once in __init__, not inside forward(),
        # otherwise every call would use fresh random weights.
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, attention_mask=None):
        # hidden_states: (b, s, d); attention_mask: (b, s, s), additive
        query_layer = self.query(hidden_states)  # (b, s, d)
        key_layer = self.key(hidden_states)      # (b, s, d)
        value_layer = self.value(hidden_states)  # (b, s, d)

        # Raw attention scores: (b, s, d) x (b, d, s) -> (b, s, s)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # Scale by sqrt(d) to keep the softmax logits in a stable range.
        attention_scores = attention_scores / math.sqrt(self.hidden_size)
        if attention_mask is not None:
            # The mask is additive: 0 for visible positions and a large
            # negative value for masked ones (precomputed for all layers
            # in BertModel's forward() function).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities over the keys.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)  # (b, s, s)

        # This actually drops out entire tokens to attend to, which might
        # seem a bit unusual, but it is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values: (b, s, s) x (b, s, d) -> (b, s, d)
        outputs = torch.matmul(attention_probs, value_layer)
        return outputs
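As a quick sanity check, here is a minimal usage sketch. The shapes and the -10000 masking value are illustrative choices (following BERT's additive-mask convention), not part of the original post:

# Batch of 2, sequence length 4, hidden size 8.
attn = SelfAttention(hidden_size=8)
x = torch.randn(2, 4, 8)
# Additive mask that hides the last position of every sequence.
mask = torch.zeros(2, 4, 4)
mask[:, :, -1] = -10000.0
out = attn(x, mask)
print(out.shape)  # torch.Size([2, 4, 8])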
