A PyTorch Implementation of CoPE Based on the Paper

CoPE code written following the paper "Contextual Position Encoding: Learning to Count What's Important". It supports multi-head computation. I am no expert and this personal code has not been verified; questions and suggestions are welcome.

The CoPE code is designed to replace RoPE in the Llama model, so the function names resemble the RoPE-related ones in LlamaModel, although the actual behaviour differs somewhat and is open to discussion. A monkey patch dynamically replaces the forward of LlamaAttention so that CoPE is used.
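For completeness, the snippets below assume roughly the following imports; the transformers paths follow the older (pre-4.36, tuple-cache) modeling_llama layout this post targets, so treat them as an assumption rather than a tested setup.

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn

from transformers import LlamaConfig, LlamaModel
from transformers.models.llama.modeling_llama import LlamaAttention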

1. llama attention init function

def llama_attn_init(self, config: LlamaConfig):
    super(LlamaAttention, self).__init__()
    self.config = config
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.hidden_size // self.num_heads
    self.num_key_value_heads = config.num_key_value_heads
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
    self.max_position_embeddings = config.max_position_embeddings

    if (self.head_dim * self.num_heads) != self.hidden_size:
        raise ValueError(
            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
            f" and `num_heads`: {self.num_heads})."
        )
    self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
    self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
    self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
    self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
    self._init_cope()

2. init cope function

def _init_cope(self):
    self.max_cope_position = self.config.max_cope_position
    # one learnable position-embedding table e[p] per attention head
    self.cope_embeddings = nn.Parameter(
        torch.randn(self.num_heads, self.max_cope_position, self.head_dim))

3. Computing CoPE

def apply_cope_emb(query_states, key_states, cope_embs, max_cope_position):
    """
    :param query_states: [bsz, num_heads, q_len, head_dim], query tensors
    :param key_states: [bsz, num_heads, kv_len, head_dim], key tensors (expanded to num_heads,
        e.g. via repeat_kv, when grouped-query attention is used)
    :param cope_embs: [num_heads, max_cope_position, head_dim], contextual position embeddings
    :param max_cope_position: maximum position of the contextual position embedding
    :return: position_weight: [bsz, num_heads, q_len, kv_len], position weights added to the attention logits
    First compute the z weights, i.e. all $$z_{i}[p] = q_i \cdot e[p]$$ in the paper.
    Then compute the gate values and accumulate them into positions, i.e. all $$p_{i,j} = \sum^{i}_{k=j} g_{ik}$$ in the paper.
    Finally interpolate between neighbouring integer positions to obtain $$z_i[p_{ij}]$$.
    """
    z_weights = torch.einsum("bnld, nmd -> bnlm", query_states, cope_embs)  # [bsz, num_heads, q_len, max_cope_position]
    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, kv_len, _ = key_states.size()
    # g_{ij} = sigmoid(q_i · k_j)
    gate_values = torch.sigmoid(torch.einsum("bnid, bnjd -> bnij", query_states, key_states))
    cum_gate_values = gate_values.cumsum(dim=-1)  # [..., i, j] = sum_{k<=j} g_{ik}
    # query i sits at absolute position kv_len - q_len + i; pick cum_gate_values[..., i, i] per row
    index = torch.arange(kv_len - q_len, kv_len, device=key_states.device).reshape(1, 1, -1, 1)
    index = index.expand(bsz, num_heads, q_len, 1)
    base_gate_values = torch.gather(cum_gate_values, 3, index)  # [bsz, num_heads, q_len, 1]
    # p_{ij} = sum_{k=j}^{i} g_{ik} = cum[i] - cum[j] + g_{ij}
    gate_values = base_gate_values - cum_gate_values + gate_values
    # clamp to [0, max_cope_position - 1] so both floor and ceil index into the embedding table
    gate_values = gate_values.clamp(min=0.0, max=max_cope_position - 1)
    floor_gate_values = torch.floor(gate_values)  # [bsz, num_heads, q_len, kv_len]
    ceil_gate_values = torch.ceil(gate_values)
    position_weight_floor = torch.gather(z_weights, -1, floor_gate_values.long())
    position_weight_ceil = torch.gather(z_weights, -1, ceil_gate_values.long())
    # linear interpolation between the two neighbouring integer positions, i.e. z_i[p_{ij}]
    position_weight = position_weight_floor + (gate_values - floor_gate_values) * (position_weight_ceil - position_weight_floor)
    return position_weight
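A quick standalone shape check of apply_cope_emb (toy sizes and random tensors, purely illustrative):

bsz, num_heads, q_len, kv_len, head_dim, max_pos = 2, 4, 5, 7, 16, 32
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, kv_len, head_dim)
e = torch.randn(num_heads, max_pos, head_dim)
print(apply_cope_emb(q, k, e, max_pos).shape)  # torch.Size([2, 4, 5, 7])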

4. llama attention forward function

def llama_attn_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    ...
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    position_weight = apply_cope_emb(query_states, key_states, self.cope_embeddings, self.max_cope_position)
    ...
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    attn_weights = attn_weights + position_weight
    ...
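The elided parts above are not spelled out in the original snippet; the following is only a minimal sketch of how they might be filled in, modelled on the pre-4.36 tuple-cache LlamaAttention.forward (repeat_kv is the grouped-query helper from transformers.models.llama.modeling_llama, and the function name here is hypothetical); it has not been verified end to end.

from transformers.models.llama.modeling_llama import repeat_kv

def llama_attn_forward_sketch(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()
    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # append the cached keys/values (tuple-style KV cache of older transformers versions)
    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    past_key_value = (key_states, value_states) if use_cache else None

    # expand grouped KV heads so keys/values match num_heads before CoPE and the score matmul
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # CoPE replaces RoPE: add the contextual position weights to the attention logits
    position_weight = apply_cope_emb(query_states, key_states, self.cope_embeddings, self.max_cope_position)
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    attn_weights = attn_weights + position_weight

    if attention_mask is not None:  # additive causal mask, [bsz, 1, q_len, kv_len]
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value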

5. Usage example

LlamaAttention._init_cope = _init_cope
LlamaAttention.__init__ = llama_attn_init
LlamaAttention.forward = llama_attn_forward
config = LlamaConfig(num_hidden_layers=3, num_attention_heads=2, hidden_size=128, intermediate_size=256,
                     hidden_act="gelu", max_cope_position=128)  # max_cope_position is an extra, non-standard parameter
model = LlamaModel(config)
print(model)
inputs_embeds = torch.randn(1, 5, 128)
output = model(inputs_embeds=inputs_embeds)
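If the patched model runs, the hidden states keep the input shape; a quick sanity check (assuming LlamaModel's default return_dict=True output):

print(output.last_hidden_state.shape)  # expected: torch.Size([1, 5, 128])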

 
