Attention pseudocode (PyTorch version)
How attention works has been explained many times elsewhere, so this post skips the theory. The implementation below follows the Transformer (the comments come from the BERT source); it is the simplest possible version: single-head self-attention, without the output projection.
import math

import torch
from torch import nn


class SelfAttention(nn.Module):
    """Single-head self-attention, following the Transformer / BERT layout."""

    def __init__(self, hidden_size, dropout_prob=0.1):
        super().__init__()
        self.hidden_size = hidden_size  # d
        # Q, K, V projections, each with weight shape (d, d).
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, attention_mask=None):
        # hidden_states: (b, s, d); attention_mask: (b, s, s), additive.
        query_layer = self.query(hidden_states)  # (b, s, d)
        key_layer = self.key(hidden_states)
        value_layer = self.value(hidden_states)
        # (b, s, d) @ (b, d, s) -> (b, s, s)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.hidden_size)
        if attention_mask is not None:
            # Apply the additive attention mask: 0 where attention is allowed,
            # a large negative number where it is blocked (precomputed for all
            # layers in BertModel's forward() function).
            attention_scores = attention_scores + attention_mask
        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)  # (b, s, s)
        # This is actually dropping out entire tokens to attend to, which
        # might seem a bit unusual, but is taken from the original
        # Transformer paper.
        attention_probs = self.dropout(attention_probs)
        # (b, s, s) @ (b, s, d) -> (b, s, d)
        outputs = torch.matmul(attention_probs, value_layer)
        return outputs
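A minimal usage sketch as a sanity check. The batch shapes, the padding layout, and the -10000.0 additive-mask convention (as in BERT) are illustrative assumptions; the cross-check at the end uses torch.nn.functional.scaled_dot_product_attention, which requires PyTorch >= 2.0.

# Batch of 2 sequences, length 4, hidden size 8 (all values illustrative).
b, s, d = 2, 4, 8
attn = SelfAttention(hidden_size=d)
attn.eval()  # disable dropout so the comparison below is deterministic
hidden_states = torch.randn(b, s, d)

# Suppose the trailing tokens are padding. Build the additive mask:
# 0.0 where attention is allowed, -10000.0 where it is blocked.
padding_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 0, 0]])  # (b, s), 1 = real token
additive_mask = (1.0 - padding_mask.float()) * -10000.0  # (b, s)
additive_mask = additive_mask[:, None, :].expand(b, s, s)  # (b, s, s)

with torch.no_grad():
    outputs = attn(hidden_states, additive_mask)
    print(outputs.shape)  # torch.Size([2, 4, 8])

    # Cross-check against PyTorch's fused implementation (>= 2.0):
    q = attn.query(hidden_states)
    k = attn.key(hidden_states)
    v = attn.value(hidden_states)
    ref = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
    print(torch.allclose(outputs, ref, atol=1e-5))  # True

The two results match because both scale the scores by 1/sqrt(d), add the same float mask before the softmax, and apply no dropout in eval mode.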