Implementing Multi-Head Attention in Python
Multi-head attention is a neural network building block for processing sequence data, widely used in natural language processing. It helps a model capture the relationships within an input sequence and improves performance across a range of tasks.
Multi-head attention extends the basic attention mechanism by introducing several attention heads, each of which can attend to information at different positions in the input sequence. By concatenating the outputs of all heads and passing them through a final linear projection, the model captures features of the input sequence more comprehensively.
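Each individual head runs ordinary scaled dot-product attention. As a warm-up before the multi-head version, that single-head computation can be sketched like this (a minimal illustration; the function name scaled_dot_product_attention is our own choice, not part of any library here):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value):
    # query, key, value: (batch, seq_len, d_k)
    d_k = query.size(-1)
    # Similarity score between every query position and every key position,
    # scaled by sqrt(d_k) to keep softmax gradients well-behaved
    scores = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5
    # Normalize scores into attention weights along the key dimension
    weights = F.softmax(scores, dim=-1)
    # Weighted sum of the value vectors
    return torch.matmul(weights, value)

q = torch.randn(2, 6, 8)
out = scaled_dot_product_attention(q, q, q)
print(out.shape)  # torch.Size([2, 6, 8])
```

The multi-head module below applies exactly this computation in parallel across several smaller subspaces of the model dimension.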
Below is a simple example of implementing multi-head attention in Python, using the PyTorch framework to build the model.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        # Linear projections for queries, keys, values, and the final output
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)

    def forward(self, query, key, value):
        batch_size = query.size(0)
        query = self.query_linear(query)
        key = self.key_linear(key)
        value = self.value_linear(value)
        # Split into heads: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        d_k = self.d_model // self.num_heads
        query = query.view(batch_size, -1, self.num_heads, d_k).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, d_k).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, d_k).transpose(1, 2)
        # Scaled dot-product attention within each head
        scores = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, value)
        # Concatenate the heads back into (batch, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.output_linear(output)

if __name__ == "__main__":
    query = torch.randn(5, 10, 20)
    key = torch.randn(5, 10, 20)
    value = torch.randn(5, 10, 20)
    multi_head_attention = MultiHeadAttention(d_model=20, num_heads=4)
    output = multi_head_attention(query, key, value)
    print("output.shape:", output.shape)
Running the code above prints an output shape of (5, 10, 20), matching the input shape (batch_size=5, seq_len=10, d_model=20) and confirming that the multi-head attention module runs end to end.
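As a sanity check on the shapes, PyTorch also ships a built-in module, nn.MultiheadAttention, that implements the same mechanism; with batch_first=True it accepts the same (batch, seq_len, embed_dim) tensors as our hand-written version (the hyperparameters below just mirror the example above):

```python
import torch
import torch.nn as nn

# Built-in multi-head attention; batch_first=True means inputs are
# (batch, seq_len, embed_dim) rather than (seq_len, batch, embed_dim)
mha = nn.MultiheadAttention(embed_dim=20, num_heads=4, batch_first=True)
query = torch.randn(5, 10, 20)
key = torch.randn(5, 10, 20)
value = torch.randn(5, 10, 20)
output, attn_weights = mha(query, key, value)
print(output.shape)        # torch.Size([5, 10, 20])
print(attn_weights.shape)  # torch.Size([5, 10, 10]), averaged over heads
```

For production code the built-in module is usually preferable, since it fuses the projections and supports masking; the hand-written class above is mainly useful for understanding what happens inside.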