Reading the LLaMA source code

I got the inspiration for this from:

https://www.bilibili.com/video/BV1Cw411y7gs/?p=5&spm_id_from=pageDriver&vd_source=d68ed178f151e80fea1e02efd205802c

It turns out the original model can be debugged on a single machine at low cost: just shrink the parameters in the config. The script below, run in debug mode on my own machine, used about 2.6 GB of RAM.

Here is the debugging script:

from transformers.models.llama import LlamaModel, LlamaConfig
import torch

def run():
    llama_config = LlamaConfig(
        vocab_size=32000,
        hidden_size=4096 // 2,
        intermediate_size=1108 // 2,
        num_hidden_layers=32 // 2,
        num_attention_heads=32 // 2,
        max_position_embeddings=2048 // 2,
    )
    llama_model = LlamaModel(config=llama_config)
    input_ids = torch.randint(low=0, high=llama_config.vocab_size, size=(4, 30))
    res = llama_model(input_ids)
    print(res)

run()

 

Below I note down the core parts of the code encountered while stepping through it in the debugger.

 

def _fill_padding_idx_with_zero(self) -> None:
    if self.padding_idx is not None:
        with torch.no_grad():
            self.weight[self.padding_idx].fill_(0)

This is nn.Embedding's padding_idx helper: if padding_idx is set, the corresponding row of the weight matrix is zeroed out, so the padding token is embedded as an all-zero vector. LLaMA does not use this logic.
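For illustration, here is a minimal sketch with a plain nn.Embedding (not LLaMA-specific) showing the effect of padding_idx:

import torch
from torch import nn

# Row 0 is declared the padding index, so it is initialized to the zero vector.
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)
print(emb.weight[0])               # all zeros
print(emb(torch.tensor([0, 3])))   # the first looked-up row is the zero vector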
 
 

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)  #=======从1衰减到0.
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):  # Compute cos/sin for each position and cache them.
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
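To see the shapes involved, here is a small sketch (assuming the LlamaRotaryEmbedding class excerpted above is in scope; the sizes are made up):

import torch

rope = LlamaRotaryEmbedding(dim=8, max_position_embeddings=16)
x = torch.zeros(1, 1, 10, 8)     # [bs, num_attention_heads, seq_len, head_size]
cos, sin = rope(x, seq_len=10)
print(rope.inv_freq)             # 4 frequencies, decaying from 1 toward 0
print(cos.shape, sin.shape)      # both [10, 8]: one cos/sin value per position and dimension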

 

 

#========== Normalization.

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
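RMSNorm divides by the root mean square of the last dimension, without subtracting the mean (unlike LayerNorm). A quick sanity-check sketch (assuming the LlamaRMSNorm class above is in scope):

import torch

norm = LlamaRMSNorm(hidden_size=6)
x = torch.randn(2, 3, 6)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), manual))  # True, because the weight starts as all ones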

 

#======= Core of the self-attention forward pass:

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
    kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
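In words: Q/K/V are reshaped to [bs, num_heads, seq_len, head_dim], rotary position embeddings are applied to Q and K, the cached K/V from earlier decoding steps are concatenated along the sequence dimension (the KV cache), the K/V heads are repeated so grouped-query attention matches the number of query heads, and finally the scaled dot-product scores are computed. repeat_kv itself is not shown in the excerpt; below is my sketch of what it does (the actual implementation in transformers may differ in detail):

import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [bs, num_key_value_heads, seq_len, head_dim] -> [bs, num_key_value_heads * n_rep, seq_len, head_dim]
    bs, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(bs, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bs, num_kv_heads * n_rep, seq_len, head_dim)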

 

#==== Swap the second half of the vector to the front (negated) and the first half to the back.

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
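A tiny worked example (values chosen arbitrarily, assuming the rotate_half function above is in scope):

import torch

# rotate_half([1, 2, 3, 4]) -> [-3, -4, 1, 2]: the second half is negated and moved to the front.
print(rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])))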

 

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)  # The rotated part is multiplied by sin, the original by cos; as for why this works, I will check the paper later.
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
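Putting the pieces together, a small end-to-end sketch (assuming LlamaRotaryEmbedding, rotate_half and apply_rotary_pos_emb from the excerpts above are all in scope; the shapes are made up):

import torch

bs, n_heads, seq_len, head_dim = 2, 4, 10, 8
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1)

rope = LlamaRotaryEmbedding(dim=head_dim, max_position_embeddings=seq_len)
cos, sin = rope(q, seq_len=seq_len)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)  # shapes are unchanged: [2, 4, 10, 8]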

 

 

That covers the core of the LLaMA code.

I will add an analysis based on the paper later.

Rotary position embedding: https://zhuanlan.zhihu.com/p/653958566
