Attention Mechanism

I. References

An intuitive, easy-to-follow explanation of the self-attention mechanism (Self-Attention)
https://www.bilibili.com/video/BV1sw4m1k7Gt/?spm_id_from=333.337.search-card.all.click&vd_source=3f409e335d99edd58fc22f4c59f6ae9e

A code-based introduction to attention, self-attention, multi-head attention, and the Transformer
https://www.bilibili.com/video/BV1QD4y177Sf/?spm_id_from=333.337.search-card.all.click&vd_source=3f409e335d99edd58fc22f4c59f6ae9e

II. Notes

In short, attention learns the relationships between a token and its context (including the token itself): every pair of words gets a weight, and a weighted average of the token vectors then produces the output vector passed to the next layer.
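To make the note above concrete, here is a minimal sketch of that idea (the tensor names and sizes are made up for illustration): pairwise scores between tokens, a softmax to turn them into weights, and a weighted average of the token vectors.

import torch
import torch.nn.functional as F

seq_len, d_model = 5, 8
x = torch.randn(seq_len, d_model)      # 5 token vectors for one sentence

scores = x @ x.T / d_model ** 0.5      # pairwise similarity between tokens, scaled
weights = F.softmax(scores, dim=-1)    # each row sums to 1: one weight per word pair
out = weights @ x                      # weighted average -> next-layer vector per token

print(weights.shape)  # torch.Size([5, 5])
print(out.shape)      # torch.Size([5, 8])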

III. Code Examples

1. Attention mechanism code

## How to build attention using PyTorch
# In PyTorch, you can build an attention mechanism by using dot-product or cosine-similarity
# scores to compute the attention weights, and then applying those weights to the values
# to obtain the attended output.
# Here is an example of how you can implement attention using PyTorch:

import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, attention_type='dot', hidden_size=256):
        super(Attention, self).__init__()
        self.attention_type = attention_type
        self.hidden_size = hidden_size

        # Linear layer to transform the query (decoder hidden state)
        self.query = nn.Linear(hidden_size, hidden_size, bias=False)
        # Linear layer to transform the key (encoder hidden state)
        self.key = nn.Linear(hidden_size, hidden_size, bias=False)
        # Linear layer to transform the value (encoder hidden state)
        self.value = nn.Linear(hidden_size, hidden_size, bias=False)

        # Softmax layer to compute the attention weights
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, keys, values):
        # Transform the query
        query = self.query(query)
        # Add a length-1 "query position" dimension: (batch, hidden) -> (batch, 1, hidden)
        query = query.unsqueeze(1)

        # Transform the keys
        keys = self.key(keys)
        # Transform the values
        values = self.value(values)

        # Compute the attention weights
        if self.attention_type == 'dot':
            # dot-product attention: batched matrix multiplication,
            # (batch, 1, hidden) x (batch, hidden, seq_len) -> (batch, 1, seq_len)
            attention_weights = torch.bmm(query, keys.transpose(1, 2))
        elif self.attention_type == 'cosine':
            # cosine similarity attention
            query = query / query.norm(dim=-1, keepdim=True)
            keys = keys / keys.norm(dim=-1, keepdim=True)
            attention_weights = torch.bmm(query, keys.transpose(1, 2))
        else:
            raise ValueError(f"Invalid attention type: {self.attention_type}")

        # Normalize the attention weights
        attention_weights = self.softmax(attention_weights)

        # Apply the attention weights to the values to obtain the attended output
        attended_output = torch.bmm(attention_weights, values)

        return attended_output, attention_weights

#To use this attention module, you can pass it the query (decoder hidden state), keys (encoder hidden states), and values (encoder hidden states) 
#as input, and it will return the attended output and the attention weights.
#For example:

# Define the attention module
attention = Attention(attention_type='dot', hidden_size=256)

# Inputs to the attention module
batch_size = 10
hidden_size = 256
sequence_length = 12
query = torch.randn(batch_size, hidden_size)
keys = torch.randn(batch_size, sequence_length, hidden_size)
values = torch.randn(batch_size, sequence_length, hidden_size)

# Compute the attended output and attention weights
attended_output, attention_weights = attention(query, keys, values)
print(attended_output.shape)    # torch.Size([10, 1, 256])
print(attention_weights.shape)  # torch.Size([10, 1, 12])
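For comparison, recent PyTorch releases (2.0+) provide torch.nn.functional.scaled_dot_product_attention, which performs the same softmax(QK^T)V computation in one call. Unlike the module above, it scales the scores by 1/sqrt(d) and has no learned query/key/value projections, so treat this as a conceptual cross-check rather than a drop-in replacement:

import torch
import torch.nn.functional as F

q = torch.randn(10, 1, 256)     # (batch, one query position, hidden)
k = torch.randn(10, 12, 256)    # (batch, sequence_length, hidden)
v = torch.randn(10, 12, 256)

out = F.scaled_dot_product_attention(q, k, v)   # softmax(q k^T / sqrt(256)) v
print(out.shape)  # torch.Size([10, 1, 256])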

2. Multi-Head Attention mechanism code

## How to build multi-head attention using PyTorch?
# Multi-head attention is an extension of the attention mechanism that
# allows the model to attend to multiple different parts of the input simultaneously.
# It does this by using multiple attention heads, each with its own learned projections,
# so each head can attend to a different part of the input and produce its own attended output.
# These attended outputs are then concatenated and linearly transformed to obtain the final output.
# Here is an example of how you can implement multi-head attention using PyTorch:


import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, input_dim, output_dim):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.input_dim = input_dim
        self.output_dim = output_dim
        
        self.query_projections = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])
        self.key_projections = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])
        self.value_projections = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])
        self.output_projection = nn.Linear(num_heads * output_dim, output_dim)

    def forward(self, query, key, value, mask=None):
        outputs = []
        for i in range(self.num_heads):
            query_projection = self.query_projections[i](query)
            key_projection = self.key_projections[i](key)
            value_projection = self.value_projections[i](value)
            
            # scaled dot-product: (batch, seq, d) x (batch, d, seq) -> (batch, seq, seq)
            dot_product = torch.matmul(query_projection, key_projection.transpose(1, 2)) / (self.output_dim ** 0.5)
            if mask is not None:
                dot_product = dot_product.masked_fill(mask == 0, -1e9)
            attention_weights = torch.softmax(dot_product, dim=-1)
            output = torch.matmul(attention_weights, value_projection)
            outputs.append(output)
            
        concatenated_outputs = torch.cat(outputs, dim=-1)
        final_output = self.output_projection(concatenated_outputs)
        return final_output

# Define the multi-head attention module
attention = MultiHeadAttention(num_heads=3, input_dim=384, output_dim=128)

# Define the input tensors
query = torch.randn(32, 16, 384)
key = torch.randn(32, 16, 384)
value = torch.randn(32, 16, 384)
mask = torch.ones(32, 16, 16)   # all ones = nothing masked; positions where mask == 0 get a large negative score

# Apply the attention module to the input tensors
output = attention(query, key, value, mask=mask)
print(output.shape)   # torch.Size([32, 16, 128])
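For reference, PyTorch also ships a built-in torch.nn.MultiheadAttention module. It follows the standard Transformer formulation (the embedding dimension is split across heads rather than giving each head a full-size projection as in the class above, and the output size stays equal to embed_dim), so it is not numerically equivalent to the hand-written module, but it is usually the practical choice. A short usage sketch with sizes mirroring the example above (batch_first requires a reasonably recent PyTorch):

import torch
import torch.nn as nn

# embed_dim must be divisible by num_heads (384 / 3 = 128 per head)
mha = nn.MultiheadAttention(embed_dim=384, num_heads=3, batch_first=True)

query = torch.randn(32, 16, 384)
key = torch.randn(32, 16, 384)
value = torch.randn(32, 16, 384)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # torch.Size([32, 16, 384])
print(attn_weights.shape)  # torch.Size([32, 16, 16]) -- weights averaged over heads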