Multi-head self-attention implementation in Keras (multi-head attention)
from keras import Sequential, Model
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras.layers import Layer, Input, Embedding, Conv1D, Bidirectional, LSTM, Dense, Dropout, BatchNormalization, GlobalMaxPooling1D, Flatten
import tensorflow as tf  # Only used for various tensor operations

# A more general and complete version of the layer defined in the linked keras example
class MultiHeadSelfAttention(Layer):
    """
    Multi-head self-attention using scaled dot-product attention
    """

    def __init__(self, num_heads=8, weights_dim=64):
        """
        Constructor: Initializes parameters of the Attention layer
        """
        # Initialize base class:
        super(MultiHeadSelfAttention, self).__init__()

        # Initialize parameters of the layer:
        self.num_heads = num_heads
        self.weights_dim = weights_dim

        if self.weights_dim % self.num_heads != 0:
            raise ValueError(f"Weights dimension = {weights_dim} should be divisible by number of heads = {num_heads} to ensure proper division into sub-matrices")

        # We use this to divide the Q, K, V matrices into num_heads sub-matrices, to compute multi-headed attention
        self.sub_matrix_dim = self.weights_dim // self.num_heads

        # Note that the Q, K, V matrices and their respective weight matrices are initialized and computed as a whole,
        # which gives us some parallelism/vectorization. After computing Q, K, V, we split them into num_heads
        # sub-matrices for computing the different attention heads.

        # Weight matrices for computing query, key and value (note that no activation function is used anywhere)
        # Important: in Keras, "units" is the dimensionality of the output
        self.W_q = Dense(units=weights_dim)
        self.W_k = Dense(units=weights_dim)
        self.W_v = Dense(units=weights_dim)

    def get_config(self):
        """
        Required for saving/loading the model
        """
        config = super().get_config().copy()
        config.update({
            "num_heads": self.num_heads,
            "weights_dim": self.weights_dim
            # All args of __init__() must be included here
        })
        return config

    def build(self, input_shape):
        """
        Initializes weights dynamically based on input_shape
        """
        input_dim = input_shape[-1]
        self.input_dim = input_dim

        # Weight matrix for combining the outputs from the multiple heads:
        # takes input of shape (batch_size, seq_len, weights_dim) and returns output of shape (batch_size, seq_len, input_dim)
        self.W_h = Dense(units=input_dim)

    def attention(self, query, key, value):
        """
        The main logic: scaled dot-product attention
        """
        # Compute the raw score = QK^T
        score = tf.matmul(query, key, transpose_b=True)

        # Scale by the dimension of K
        dim_key = tf.cast(tf.shape(key)[-1], tf.float32)  # == DIM_KEY
        scaled_score = score / tf.math.sqrt(dim_key)

        # Weights are the softmax of the scaled scores
        weights = tf.nn.softmax(scaled_score, axis=-1)

        # The final output of the attention layer (weighted sum of the values)
        output = tf.matmul(weights, value)

        return output, weights

    def separate_heads(self, x, batch_size):
        """
        Splits the given x into num_heads sub-matrices and stacks them along a new head dimension
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.sub_matrix_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs):
        """
        All computations take place here
        """
        batch_size = tf.shape(inputs)[0]

        # Compute Q = W_q*X
        query = self.W_q(inputs)  # (batch_size, seq_len, weights_dim)

        # Compute K = W_k*X
        key = self.W_k(inputs)  # (batch_size, seq_len, weights_dim)

        # Compute V = W_v*X
        value = self.W_v(inputs)  # (batch_size, seq_len, weights_dim)

        # Split into num_heads sub-matrices
        query = self.separate_heads(query, batch_size)  # (batch_size, num_heads, seq_len, sub_matrix_dim)
        key = self.separate_heads(key, batch_size)      # (batch_size, num_heads, seq_len, sub_matrix_dim)
        value = self.separate_heads(value, batch_size)  # (batch_size, num_heads, seq_len, sub_matrix_dim)

        # Compute attention (contains the outputs and weights for all heads):
        attention, weights = self.attention(query, key, value)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len, num_heads, sub_matrix_dim)

        # Concatenate the attention outputs from the different heads (merge the last two dimensions):
        concat_attention = tf.reshape(attention, (batch_size, -1, self.weights_dim))  # (batch_size, seq_len, weights_dim)

        # Project the concatenated heads back to the input dimension:
        output = self.W_h(concat_attention)  # (batch_size, seq_len, input_dim)

        return output

    def compute_output_shape(self, input_shape):
        """
        Specifies the output shape of the custom layer; without this, the model doesn't work
        """
        return input_shape
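A quick sanity check for the layer above is to call it on a random tensor and confirm that the output has the same shape as the input (a minimal sketch; the batch size 2, sequence length 50 and feature dimension 128 are arbitrary example values, not taken from the original model):

layer = MultiHeadSelfAttention(num_heads=8, weights_dim=64)
dummy_inputs = tf.random.normal((2, 50, 128))  # (batch_size, seq_len, input_dim)
outputs = layer(dummy_inputs)                  # build() and call() run here
print(outputs.shape)                           # (2, 50, 128): the layer preserves the input shape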
Model implementation:
def buildModel1():
    # The input as sequences:
    input_layer = Input(shape=(N_SEQ,))

    # Create the embedding layer
    embedding_layer = Embedding(
        N_vocab,
        N_EMB,
        weights=[embedding_matrix],
        input_length=N_SEQ,
        trainable=False  # No need to train, as our embeddings are already finetuned on Twitter data
    )

    # Create the embeddings
    embedded_sequences = embedding_layer(input_layer)

    # The core of the model: a single-layer BiLSTM architecture
    x = Bidirectional(
        LSTM(
            units=DIM_HIDDEN,  # In Keras, "units" is the dimensionality of the hidden states h_t output by the LSTM
            dropout=0.2,
            # recurrent_dropout=0.2,  # Can't use the GPU if this is included
            return_sequences=True
        ),
        merge_mode="concat"  # Just like in Transformers; the output h = [h_f; h_b] has dimension 2*DIM_HIDDEN
    )(embedded_sequences)

    # Add multi-headed self-attention
    x = MultiHeadSelfAttention(N_HEADS, DIM_KEY)(x)

    outputs = Flatten()(x)

    model = Model(input_layer, outputs)
    return model
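For completeness, here is a hedged sketch of how this model might be compiled and trained for a binary classification task. The Dense classification head, the checkpoint path best_model.h5, and X_train / y_train are assumptions for illustration, not part of the original post; the hyperparameters used by buildModel1() are assumed to be defined elsewhere, and Adam, ModelCheckpoint, ReduceLROnPlateau, Dense and Model are already imported at the top of the first snippet:

base = buildModel1()
predictions = Dense(1, activation="sigmoid")(base.output)  # hypothetical binary classification head
model = Model(base.input, predictions)

model.compile(
    optimizer=Adam(learning_rate=1e-3),
    loss="binary_crossentropy",
    metrics=["accuracy"]
)

callbacks = [
    ModelCheckpoint("best_model.h5", monitor="val_loss", save_best_only=True),
    ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2)
]

model.fit(
    X_train, y_train,        # preprocessed padded sequences and labels (placeholders)
    validation_split=0.1,
    epochs=5,
    batch_size=64,
    callbacks=callbacks
)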
Reference: https://keras.io/examples/nlp/text_classification_with_transformer/
One thing to note: the output has shape (?, ?, dim). In practice, the second dimension needs to be the real sequence length, so change the reshape in call() manually as follows:
concat_attention = tf.reshape(attention, (batch_size, seq_len, self.weights_dim)) # (batch_size, seq_len, weights_dim)
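One way to see why this matters: with the -1 reshape, the static shape of the second dimension is lost, so a downstream Flatten ends up with an unknown output size, while the hard-coded seq_len preserves it. A minimal check (a sketch; seq_len = 50 and the feature dimension 128 are arbitrary example values, and the modified reshape is assumed to read seq_len from a module-level variable like the one defined here):

seq_len = 50
check_input = Input(shape=(seq_len, 128))
att = MultiHeadSelfAttention(num_heads=8, weights_dim=64)(check_input)
flat = Flatten()(att)
print(Model(check_input, flat).output_shape)  # (None, 6400) with the hard-coded seq_len; (None, None) with the -1 reshape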