cross attention的源码实现,并代码详细讲解
import numpy as np def softmax(x, axis=-1): """Softmax函数,用于计算注意力权重""" e_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) return e_x / e_x.sum(axis=axis, keepdims=True) def scaled_dot_product_attention(q, k, v, mask=None): """缩放点积注意力机制,用于计算输出和注意力权重""" print(q.shape) print(k.transpose().shape) matmul_qk = np.matmul(q, k.transpose(0,2,1)) # 计算查询和键的矩阵乘积 d_k = k.shape[-1] # 键的维度 scaled_attention_logits = matmul_qk / np.sqrt(d_k) # 缩放注意力分数 if mask is not None: # 如果有注意力掩码,将其添加到分数上 scaled_attention_logits += (mask * -1e9) attention_weights = softmax(scaled_attention_logits) # 计算注意力权重 output = np.matmul(attention_weights, v) # 计算输出 return output, attention_weights def cross_attention(q, k, v, mask=None): """Cross-Attention机制""" # q, k, v 必须有匹配的前导维度 # q, k, v 的形状: (batch_size, seq_len, embed_dim) # mask 的形状: (batch_size, seq_len_q, seq_len_k) # 使用缩放点积注意力机制计算注意力 output, attention_weights = scaled_dot_product_attention(q, k, v, mask) return output, attention_weights # 测试用例 np.random.seed(0) # 确保可重复性 # 创建查询、键和值矩阵 batch_size = 2 seq_len_q = 3 seq_len_k = 4 embed_dim = 5 q = np.random.rand(batch_size, seq_len_q, embed_dim) k = np.random.rand(batch_size, seq_len_k, embed_dim) v = np.random.rand(batch_size, seq_len_k, embed_dim) # 创建注意力掩码(可选) mask = np.zeros((batch_size, seq_len_q, seq_len_k)) mask[:, :, -1:] = 1 # 假设我们想忽略每个序列的最后一个元素 # 计算Cross-Attention output, attention_weights = cross_attention(q, k, v, mask) print("Output shape:", output.shape) # 应该是 (batch_size, seq_len_q, embed_dim) print("Attention weights shape:", attention_weights.shape) # 应该是 (batch_size, seq_len_q, seq_len_k)
Cross-Attention,也称为自注意力或查询(Query)-键(Key)-值(Value)注意力机制,是一种在Transformer模型中广泛使用的注意力机制。在Cross-Attention中,查询(Query)通常来自于一个序列(如文本序列),而键(Key)和值(Value)来自于另一个序列(如另一个文本序列或图像特征)。
以下是一个简化的Cross-Attention的源码实现,使用Python和NumPy库。这个实现是为了说明Cross-Attention的基本概念,并不是一个高效或完整的实现。在实际应用中,Cross-Attention通常使用更高效的库,如TensorFlow或PyTorch。
代码讲解:
-
softmax
函数:用于计算注意力权重。它首先从输入矩阵中减去每行的最大值,以增加数值稳定性,然后计算指数,最后将结果归一化为概率分布。 -
scaled_dot_product_attention
函数:实现缩放点积注意力机制。它首先计算查询(q)和键(k)的转置的矩阵乘积,然后除以键的维度的平方根进行缩放。如果有注意力掩码(mask),将其应用于注意力分数以忽略某些部分。最后,使用softmax函数计算注意力权重,并将其与值(v)相乘以得到输出。 -
cross_attention
函数:实现Cross-Attention机制。它接受查询(q)、键(k)和值(v)作为输入,以及一个可选的注意力掩码(mask)。它调用scaled_dot_product_attention
函数来计算输出和注意力权重,并将其返回。
在实际应用中,Cross-Attention通常使用深度学习框架(如PyTorch或TensorFlow)的内置函数和类来实现,这些实现更加高效和灵活。上述代码仅用于说明Cross-Attention的基本概念。