A PyTorch Implementation of CoPE Based on the Paper
This is CoPE code written following the paper "Contextual Position Encoding: Learning to Count What's Important". It supports multi-head computation. I am no expert and this personal code has not been verified, so questions and suggestions are very welcome.
The CoPE code is designed as a replacement for RoPE in the Llama model, so the function names mirror the RoPE-related ones in LlamaModel, although the actual behavior differs slightly from the paper and is open to discussion. A monkey patch is used to dynamically replace the forward method of LlamaAttention so that CoPE takes effect.
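For context, here is a brief recap of the mechanism from the paper, using the same notation as the docstrings below. Each gate is $$g_{ij} = \sigma(q_i \cdot k_j),$$ the contextual position of key $j$ relative to query $i$ is $$p_{ij} = \sum^{i}_{k=j} g_{ik},$$ and, because $p_{ij}$ is generally fractional, the position logit is interpolated from the learned embeddings $e[p]$: $$z_i[p_{ij}] = (p_{ij} - \lfloor p_{ij} \rfloor)\, z_i[\lceil p_{ij} \rceil] + (1 - p_{ij} + \lfloor p_{ij} \rfloor)\, z_i[\lfloor p_{ij} \rfloor], \qquad z_i[p] = q_i \cdot e[p].$$ This position logit is added to the content logit before the softmax, which is what the code below does to the Llama attention scores.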
1. llama attention init function
# imports assumed for all snippets below
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel
from transformers.models.llama.modeling_llama import LlamaAttention


def llama_attn_init(self, config: LlamaConfig):
    # essentially the stock LlamaAttention.__init__ from an older transformers release,
    # with the rotary-embedding setup replaced by _init_cope()
    super(LlamaAttention, self).__init__()
    self.config = config
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.hidden_size // self.num_heads
    self.num_key_value_heads = config.num_key_value_heads
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
    self.max_position_embeddings = config.max_position_embeddings
    if (self.head_dim * self.num_heads) != self.hidden_size:
        raise ValueError(
            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
            f" and `num_heads`: {self.num_heads})."
        )
    self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
    self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
    self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
    self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
    self._init_cope()
2. init cope function
def _init_cope(self):
    self.max_cope_position = self.config.max_cope_position
    # one learnable embedding table per head: [num_heads, max_cope_position, head_dim]
    self.cope_embeddings = nn.Parameter(
        torch.randn(self.num_heads, self.max_cope_position, self.hidden_size // self.num_heads))
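For reference, a quick way to see what this adds per attention layer (the numbers below are illustrative, not from the original post; head_dim = hidden_size // num_heads):

# illustrative numbers only: 2 heads, 128 positions, head_dim 64
num_heads, max_cope_position, head_dim = 2, 128, 64
cope_embeddings = nn.Parameter(torch.randn(num_heads, max_cope_position, head_dim))
print(cope_embeddings.shape)    # torch.Size([2, 128, 64]); each head learns its own e[p] vectors
print(cope_embeddings.numel())  # 16384 extra parameters per attention layer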
3. Computing CoPE
def apply_cope_emb(query_states, key_states, cope_embs, max_cope_position):
    r"""
    :param query_states: [bsz, num_heads, q_len, head_dim], query tensors
    :param key_states: [bsz, num_key_value_heads, kv_len, head_dim], key tensors
        (the einsum below assumes num_key_value_heads == num_heads, or that key_states
        has already been expanded to num_heads, e.g. via repeat_kv)
    :param cope_embs: [num_heads, max_cope_position, head_dim], Contextual Position Embedding
    :param max_cope_position: maximum position of the Contextual Position Embedding
    :return: position_weight: [bsz, num_heads, q_len, kv_len], position weight

    First compute the z weights, i.e. the map of all $$z_{i}[p] = q_i \cdot e[p]$$ in the paper.
    Then compute the gate values, i.e. the map of all $$p_{i,j} = \sum^{i}_{k=j} g_{ik}$$ in the paper.
    Finally compute the position weight, i.e. the map of all $$z_i[p_{ij}]$$ in the paper.
    """
    # z_i[p] = q_i . e[p] for every integer position p
    z_weights = torch.einsum("bnld, nmd -> bnlm", query_states, cope_embs)  # [bsz, num_heads, q_len, max_cope_position]
    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, kv_len, _ = key_states.size()

    # gates g_ij = sigmoid(q_i . k_j), accumulated along the key axis
    gate_values = torch.einsum("bnid, bnjd -> bnij", query_states, key_states)
    gate_values = torch.sigmoid(gate_values)
    gate_values = gate_values.cumsum(dim=-1)

    # for each query i, take its own cumulative gate at the diagonal key position, then
    # subtract, giving p_ij = sum_{k=j+1}^{i} g_ik (the paper sums from k = j)
    index = torch.arange(kv_len - q_len, kv_len, device=key_states.device).reshape(1, 1, -1, 1)
    index = index.expand(bsz, num_heads, q_len, 1)
    base_gate_values = torch.gather(gate_values, 3, index)
    gate_values = base_gate_values - gate_values

    # clamp positions to [0, max_cope_position - 1] so the gathers below stay in range
    gate_values = torch.clamp(gate_values, min=0.0, max=float(max_cope_position - 1))

    # interpolate between the embeddings of the two nearest integer positions
    floor_gate_values = torch.floor(gate_values)  # [bsz, num_heads, q_len, kv_len]
    ceil_gate_values = torch.ceil(gate_values)
    position_weight_floor = torch.gather(z_weights, -1, floor_gate_values.long())
    position_weight_ceil = torch.gather(z_weights, -1, ceil_gate_values.long())
    position_weight = position_weight_floor + (gate_values - floor_gate_values) * (
            position_weight_ceil - position_weight_floor)
    return position_weight
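As a quick sanity check, the function can be exercised on its own with random tensors (the shapes below are illustrative only; kv_len equals q_len here, i.e. no KV cache):

bsz, num_heads, q_len, head_dim, max_cope_position = 1, 2, 5, 64, 128
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)         # kv_len == q_len, no cache
e = torch.randn(num_heads, max_cope_position, head_dim)  # stands in for self.cope_embeddings
pos_w = apply_cope_emb(q, k, e, max_cope_position)
print(pos_w.shape)  # torch.Size([1, 2, 5, 5]), same shape as the attention logits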
4. llama attention forward function
# The signature follows an older transformers LlamaAttention.forward; newer versions pass
# extra kwargs (e.g. cache_position), so match the signature to your installed version.
def llama_attn_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
...
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    position_weight = apply_cope_emb(query_states, key_states, self.cope_embeddings,
                                     self.max_cope_position)
...
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    # CoPE: add the interpolated position logits to the content logits before the mask/softmax
    attn_weights = attn_weights + position_weight
...
5. Usage example
LlamaAttention._init_cope = _init_cope
LlamaAttention.__init__ = llama_attn_init
LlamaAttention.forward = llama_attn_forward

config = LlamaConfig(num_hidden_layers=3, num_attention_heads=2, hidden_size=128,
                     intermediate_size=256, hidden_act="gelu",
                     max_cope_position=128)  # add an extra new parameter max_cope_position
model = LlamaModel(config)
print(model)

inputs_embeds = torch.randn(1, 5, 128)
output = model(inputs_embeds=inputs_embeds)
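If everything is wired up correctly (and assuming a transformers version whose LlamaAttention.forward matches the patched signature above), the forward pass runs as usual and the hidden states keep their normal shape:

print(output.last_hidden_state.shape)  # expected: torch.Size([1, 5, 128])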