Fastformer: Additive Attention Can Be All You Need
Key ideas:
- The paper builds on the Transformer and proposes an efficient model whose complexity is linear in the sequence length.
- The main change is to the attention mechanism. The starting point is to reduce the reliance on the pairwise attention matrix: the original T*T attention matrix is replaced by a single (1*T) vector of attention weights (a minimal sketch follows below).
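As a rough illustration of why this is linear in T (a plain PyTorch sketch with made-up shapes and a single head, not the ESPnet module shown later): each position gets one scalar score, the softmax over those T scores is the (1*T) weight vector, and the whole sequence is pooled into a single global vector at O(T*D) cost.

import torch

B, T, D = 2, 5, 8
x = torch.randn(B, T, D)                # projected queries, shape (B, T, D)
w = torch.nn.Linear(D, 1, bias=False)   # learned scoring vector, one scalar score per position

scores = w(x).transpose(1, 2) / D**0.5  # (B, 1, T): one score per time step
alpha = torch.softmax(scores, dim=-1)   # (B, 1, T): additive attention weights
global_q = alpha @ x                    # (B, 1, D): weighted sum over time
print(global_q.shape)                   # torch.Size([2, 1, 8])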
Attention structure diagram:
Here, as in the standard Transformer, the input is passed through separate linear projections to obtain Q, K, and V; a per-position weight for Q is then computed from Q itself:
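In the paper this weight is an additive attention score; a sketch of the formula (here $\mathbf{w}_q$ is a learned per-head vector and $d$ the per-head dimension, corresponding to query_att and the $1/\sqrt{d}$ scaling in the code below):

$$\alpha_i = \frac{\exp\left(\mathbf{w}_q^{\top}\mathbf{q}_i / \sqrt{d}\right)}{\sum_{j=1}^{T}\exp\left(\mathbf{w}_q^{\top}\mathbf{q}_j / \sqrt{d}\right)}$$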
The shape transformation from Q to its weights is (B = batch, T = time, D = model size, h = number of heads):
weight: (B, T, D) -> (B, T, h) -> (B, h, T) -> (B, h, 1, T)
Q is then multiplied with its weights: weight * query = (B, h, 1, ad) -> (B, 1, h, ad) -> (B, 1, D) -> (B, T, D), where ad = D / h is the per-head dimension and the pooled result is repeated along the time axis:
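In formula form (a sketch in the same notation as above, where $\mathbf{q}$ is the pooled "global query"):

$$\mathbf{q} = \sum_{i=1}^{T} \alpha_i\,\mathbf{q}_i$$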
The result is then multiplied element-wise with K:
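That is, writing $\odot$ for element-wise multiplication and $\mathbf{p}_i$ for the resulting vectors (notation assumed here):

$$\mathbf{p}_i = \mathbf{q} \odot \mathbf{k}_i$$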
The weight vector for K is computed in the same way as the one for Q:
Likewise, K (here, the element-wise product from the previous step) is multiplied with its weights and pooled:
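Sketched in the same notation, with $\mathbf{w}_k$ another learned vector and $\mathbf{k}$ the pooled "global key":

$$\beta_i = \frac{\exp\left(\mathbf{w}_k^{\top}\mathbf{p}_i / \sqrt{d}\right)}{\sum_{j=1}^{T}\exp\left(\mathbf{w}_k^{\top}\mathbf{p}_j / \sqrt{d}\right)}, \qquad \mathbf{k} = \sum_{i=1}^{T} \beta_i\,\mathbf{p}_i$$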
Finally, the result is multiplied element-wise with V:
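In the ESPnet implementation below this last step also applies a linear transform (self.transform) and a residual connection back to the query projection, roughly:

$$\mathbf{u}_i = \mathbf{k} \odot \mathbf{v}_i, \qquad \mathbf{r}_i = W\,\mathbf{u}_i + \mathbf{q}_i$$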
Here Q and V are identical: the implementation uses parameter sharing, so the value is taken to be the query.
Code implementation in ESPnet:
import numpy
import torch


class FastSelfAttention(torch.nn.Module):
    """Fast self-attention used in Fastformer."""

    def __init__(
        self,
        size,
        attention_heads,
        dropout_rate,
    ):
        super().__init__()
        if size % attention_heads != 0:
            raise ValueError(
                f"Hidden size ({size}) is not an integer multiple "
                f"of attention heads ({attention_heads})"
            )
        self.attention_head_size = size // attention_heads
        self.num_attention_heads = attention_heads
        self.query = torch.nn.Linear(size, size)
        self.query_att = torch.nn.Linear(size, attention_heads)
        self.key = torch.nn.Linear(size, size)
        self.key_att = torch.nn.Linear(size, attention_heads)
        self.transform = torch.nn.Linear(size, size)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, torch.nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def transpose_for_scores(self, x):
        """Reshape and transpose to compute scores.

        Args:
            x: (batch, time, size = n_heads * attn_dim)

        Returns:
            (batch, n_heads, time, attn_dim)
        """
        new_x_shape = x.shape[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.reshape(*new_x_shape).transpose(1, 2)

    def forward(self, xs_pad, mask):
        """Forward method.

        Args:
            xs_pad: (batch, time, size = n_heads * attn_dim)
            mask: (batch, 1, time), nonpadding is 1, padding is 0

        Returns:
            torch.Tensor: (batch, time, size)
        """
        batch_size, seq_len, _ = xs_pad.shape
        mixed_query_layer = self.query(xs_pad)  # (batch, time, size)
        mixed_key_layer = self.key(xs_pad)  # (batch, time, size)

        if mask is not None:
            mask = mask.eq(0)  # padding is 1, nonpadding is 0

        # (batch, n_heads, time)
        query_for_score = (
            self.query_att(mixed_query_layer).transpose(1, 2)
            / self.attention_head_size**0.5
        )
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
                ).min
            )
            query_for_score = query_for_score.masked_fill(mask, min_value)
            query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
        else:
            query_weight = torch.softmax(query_for_score, dim=-1)

        query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        query_layer = self.transpose_for_scores(
            mixed_query_layer
        )  # (batch, n_heads, time, attn_dim)

        pooled_query = (
            torch.matmul(query_weight, query_layer)
            .transpose(1, 2)
            .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
        )  # (batch, 1, size = n_heads * attn_dim)
        pooled_query = self.dropout(pooled_query)
        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)

        mixed_query_key_layer = (
            mixed_key_layer * pooled_query_repeat
        )  # (batch, time, size)

        # (batch, n_heads, time)
        query_key_score = (
            self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
        ).transpose(1, 2)
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
                ).min
            )
            query_key_score = query_key_score.masked_fill(mask, min_value)
            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
                mask, 0.0
            )
        else:
            query_key_weight = torch.softmax(query_key_score, dim=-1)

        query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        key_layer = self.transpose_for_scores(
            mixed_query_key_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_key = torch.matmul(
            query_key_weight, key_layer
        )  # (batch, n_heads, 1, attn_dim)
        pooled_key = self.dropout(pooled_key)

        # NOTE: value = query, due to param sharing
        weighted_value = (pooled_key * query_layer).transpose(
            1, 2
        )  # (batch, time, n_heads, attn_dim)
        weighted_value = weighted_value.reshape(
            weighted_value.shape[:-2]
            + (self.num_attention_heads * self.attention_head_size,)
        )  # (batch, time, size)
        weighted_value = (
            self.dropout(self.transform(weighted_value)) + mixed_query_layer
        )

        return weighted_value
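A quick usage sketch of the module above (the sizes and the padding positions here are chosen arbitrarily, just to show the expected tensor shapes):

import torch

attn = FastSelfAttention(size=256, attention_heads=4, dropout_rate=0.1)
attn.espnet_initialization_fn()

xs = torch.randn(8, 100, 256)                    # (batch, time, size)
mask = torch.ones(8, 1, 100, dtype=torch.long)   # (batch, 1, time), 1 = nonpadding
mask[:, :, 80:] = 0                              # pretend the last 20 frames are padding

out = attn(xs, mask)
print(out.shape)                                 # torch.Size([8, 100, 256])

Note that the output has the same shape as the input, so the module can be dropped into an encoder layer in place of standard multi-head self-attention.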