The Attention Mechanism
一、References
An intuitive explanation of the self-attention mechanism (Self-Attention)
https://www.bilibili.com/video/BV1sw4m1k7Gt/?spm_id_from=333.337.search-card.all.click&vd_source=3f409e335d99edd58fc22f4c59f6ae9e
A code-based introduction to attention, self-attention, multi-head attention, and the Transformer
https://www.bilibili.com/video/BV1QD4y177Sf/?spm_id_from=333.337.search-card.all.click&vd_source=3f409e335d99edd58fc22f4c59f6ae9e
二、Notes
In short, attention learns the relationships between each token and its context (including the token itself): every pair of words gets a weight, and a weighted average over the sequence then gives the vector output for the next layer.
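To make this concrete, here is a minimal toy sketch of self-attention in PyTorch (for illustration only; the tensor sizes are arbitrary and this is not the code from the videos above): query, key, and value all come from the same sequence, each pair of words gets a softmax weight, and the output is the weighted average of the context vectors.

import torch
import torch.nn.functional as F

x = torch.randn(5, 16)                      # 5 "words", each a 16-dim vector
scores = x @ x.transpose(0, 1) / 16 ** 0.5  # pairwise similarity scores, (5, 5)
weights = F.softmax(scores, dim=-1)         # one weight per word pair; each row sums to 1
out = weights @ x                           # weighted average over the context -> next-layer vectors, (5, 16)
print(weights.shape, out.shape)             # torch.Size([5, 5]) torch.Size([5, 16])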
三、Code Examples
1、Attention mechanism code
## How to build attention using PyTorch
# In PyTorch, you can build an attention mechanism by using dot-product or cosine similarity to compute the attention weights,
# and then applying those weights to the input to obtain the attended output.
# Here is an example of how you can implement attention using PyTorch:
import torch
import torch.nn as nn
class Attention(nn.Module):
    def __init__(self, attention_type='dot', hidden_size=256):
        super(Attention, self).__init__()
        self.attention_type = attention_type
        self.hidden_size = hidden_size
        # Linear layer to transform the query (decoder hidden state)
        self.query = nn.Linear(hidden_size, hidden_size, bias=False)
        # Linear layer to transform the key (encoder hidden state)
        self.key = nn.Linear(hidden_size, hidden_size, bias=False)
        # Linear layer to transform the value (encoder hidden state)
        self.value = nn.Linear(hidden_size, hidden_size, bias=False)
        # Softmax layer to compute the attention weights
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, keys, values):
        # Transform the query: (batch, hidden) -> (batch, 1, hidden)
        query = self.query(query)
        query = query.unsqueeze(1)
        # Transform the keys: (batch, seq_len, hidden)
        keys = self.key(keys)
        # Transform the values: (batch, seq_len, hidden)
        values = self.value(values)
        # Compute the attention scores
        if self.attention_type == 'dot':
            # Dot-product attention: batched matrix multiplication, giving (batch, 1, seq_len)
            attention_weights = torch.bmm(query, keys.transpose(1, 2))
        elif self.attention_type == 'cosine':
            # Cosine-similarity attention: normalize before the dot product
            query = query / query.norm(dim=-1, keepdim=True)
            keys = keys / keys.norm(dim=-1, keepdim=True)
            attention_weights = torch.bmm(query, keys.transpose(1, 2))
        else:
            raise ValueError(f"Invalid attention type: {self.attention_type}")
        # Normalize the attention weights
        attention_weights = self.softmax(attention_weights)
        # Apply the attention weights to the values to obtain the attended output
        attended_output = torch.bmm(attention_weights, values)
        return attended_output, attention_weights
# To use this attention module, you can pass it the query (decoder hidden state), keys (encoder hidden states), and values (encoder hidden states)
# as input, and it will return the attended output and the attention weights.
# For example:
# Define the attention module
attention = Attention(attention_type='dot', hidden_size=256)
# Inputs to the attention module
batch_size = 10
hidden_size = 256
sequence_length = 12
query = torch.randn(batch_size, hidden_size)
keys = torch.randn(batch_size, sequence_length, hidden_size)
values = torch.randn(batch_size, sequence_length, hidden_size)
# Compute the attended output and attention weights
attended_output, attention_weights = attention(query, keys, values)
print(attended_output.shape)
print(attention_weights.shape)
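With the sizes above, attended_output comes out as (batch_size, 1, hidden_size) and attention_weights as (batch_size, 1, sequence_length), because the query is a single decoder step that gets unsqueezed to length 1. A quick sanity check:

# Expected shapes for this example:
# attended_output:   (10, 1, 256) -> (batch_size, 1, hidden_size)
# attention_weights: (10, 1, 12)  -> (batch_size, 1, sequence_length)
assert attended_output.shape == (batch_size, 1, hidden_size)
assert attention_weights.shape == (batch_size, 1, sequence_length)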
2、Multi-head attention code
## How to build multi-head attention using PyTorch?
# Multi-head attention is an extension of the attention mechanism that
# allows the model to attend to multiple different parts of the input simultaneously.
# It does this by using multiple attention heads, each of which attends to a different part of the input and produces its own attended output.
# These attended outputs are then concatenated and transformed to obtain the final attended output.
# Here is an example of how you can implement multi-head attention using PyTorch:
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, input_dim, output_dim):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.input_dim = input_dim
        self.output_dim = output_dim
        # One query/key/value projection per head
        self.query_projections = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])
        self.key_projections = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])
        self.value_projections = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])
        # Final projection applied to the concatenated head outputs
        self.output_projection = nn.Linear(num_heads * output_dim, output_dim)

    def forward(self, query, key, value, mask=None):
        outputs = []
        for i in range(self.num_heads):
            # Project the inputs for this head
            query_projection = self.query_projections[i](query)
            key_projection = self.key_projections[i](key)
            value_projection = self.value_projections[i](value)
            # Attention scores: (batch, query_len, key_len)
            dot_product = torch.matmul(query_projection, key_projection.transpose(1, 2))
            if mask is not None:
                # Positions where mask == 0 are blocked from attention
                dot_product = dot_product.masked_fill(mask == 0, -1e9)
            attention_weights = torch.softmax(dot_product, dim=-1)
            # Weighted sum of the values for this head
            output = torch.matmul(attention_weights, value_projection)
            outputs.append(output)
        # Concatenate the per-head outputs and project down to output_dim
        concatenated_outputs = torch.cat(outputs, dim=-1)
        final_output = self.output_projection(concatenated_outputs)
        return final_output
# Define the multi-head attention module
attention = MultiHeadAttention(num_heads=3, input_dim=384, output_dim=128)
# Define the input tensors
query = torch.randn(32, 16, 384)
key = torch.randn(32, 16, 384)
value = torch.randn(32, 16, 384)
mask = torch.ones(32, 16, 16)  # 1 = attend, 0 = masked out; all ones means no position is masked
# Apply the attention module to the input tensors
output = attention(query, key, value, mask=mask)
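With these sizes, the final output has shape (32, 16, 128) after the output projection. For comparison, below is a rough sketch using PyTorch's built-in nn.MultiheadAttention with a matching configuration; this only illustrates the shapes and is not numerically equivalent to the class above, since the built-in layer splits a single 384-dim projection into 3 heads of 128, scales the scores by 1/sqrt(128), and keeps a 384-dim output.

print(output.shape)  # torch.Size([32, 16, 128])

# Built-in multi-head attention, for shape comparison only (batch_first requires PyTorch >= 1.9)
mha = nn.MultiheadAttention(embed_dim=384, num_heads=3, batch_first=True)
attn_mask = torch.zeros(16, 16, dtype=torch.bool)  # False everywhere = no position is masked
builtin_output, builtin_weights = mha(query, key, value, attn_mask=attn_mask)
print(builtin_output.shape)   # torch.Size([32, 16, 384])
print(builtin_weights.shape)  # torch.Size([32, 16, 16]), attention weights averaged over heads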