多头注意力机制的python实现

多头注意力机制是一种用于处理序列数据的神经网络结构,在自然语言处理领域中得到广泛应用。它可以帮助模型更好地理解和学习输入序列中的信息,提高模型在各种任务上的性能。

多头注意力机制是基于注意力机制的改进版本,它引入了多个注意力头,每个头都可以关注输入序列中不同位置的信息。通过汇总多个头的输出,模型可以更全面地捕捉输入序列中的特征。

下面我们用一个简单的例子来演示如何使用python实现多头注意力机制。我们将使用pytorch框架来构建模型。

import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
    def forward(self, query, key, value):
        batch_size = query.size(0)
        query = self.query_linear(query)
        key = self.key_linear(key)
        value = self.value_linear(value)
        query = query.view(batch_size, -1, self.num_heads, self.d_model// self.num_heads).transpose(1,2)
        key = key.view(batch_size, -1, self.num_heads, self.d_model // self.num_heads).transpose(1,2)
        value = value.view(batch_size, -1, self.num_heads, self.d_model // self.num_heads).transpose(1,2)
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d_model // self.num_heads) ** 0.5
        attention_weights = F.softmax(scores, dim = -1)
        output = torch.matmul(attention_weights, value)
        output = output.transpose(1,2).contiguous().view(batch_size, -1, self.d_model)
        return self.output_linear(output)
if __name__ == "__main__":
    query = torch.randn(5,10,20)
    key = torch.randn(5,10,20)
    value = torch.randn(5,10,20)
    multi_head_attention = MultiHeadAttention(d_model = 20, num_heads = 4)
    output = multi_head_attention(query, key, value)
    print("output.shape: ", output.shape)

 运行上面的代码,我们可以看到模型输出的形状为(5,10,20),说明多头注意力机制成功运行并得到了输出。

 

posted @ 2024-07-09 18:19  小丑_jk  阅读(134)  评论(0编辑  收藏  举报