LLaMA 3 Source Code Walkthrough - Large Language Models (5)

I wasn't particularly keen on writing this post, because articles like it are a dime a dozen online, and mine may well not be as good as what's already out there. But then again, I created this blog to improve my own understanding by writing things up, not because every post has to be brilliant enough to attract readers. So, to round out this series, this post walks through the architecture of the currently most popular model, LLaMA 3, to figure out what today's LLMs actually look like.

To be clear up front: LLaMA 3 makes no changes to the network architecture relative to LLaMA 2. As a Zhihu user put it, "the release of LLaMA 3 underscores the importance of data engineering: with the architecture unchanged, more data and higher data quality deliver a clear improvement in model quality." Still, for a beginner like me, reading carefully through the source of an LLM is well worth it.

https://zhuanlan.zhihu.com/p/693428105

One more note: this post walks through the source at commit d6e09315954d1a547bf45e37269978c049e73d33. If Meta later updates the code and it no longer matches this post, check out that revision first. If something still doesn't add up, leave a comment below and we can figure it out together.

Generation#

Llama.build: model instantiation, and how to read the source#

Starting from the llama3 README, we find this demo, which uses the following to complete a chat:

from llama import Dialog, Llama

generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )
results = generator.chat_completion(dialogs, max_gen_len, temperature, top_p)

It first calls Llama.build, then calls generator.chat_completion on the returned object to carry out the conversation. The imported package is llama, which points us to the llama folder in the repo, so let's first look at its __init__.py:

from .generation import Llama
from .model import ModelArgs, Transformer
from .tokenizer import Dialog, Tokenizer

So the Llama.build that the demo calls lives in .generation, i.e. llama/generation.py. Following that thread, we find:

class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
        seed: int = 1,
    ) -> "Llama":
        """
        Build a Llama instance by initializing and loading a model checkpoint.

        Args:
            ckpt_dir (str): Path to the directory containing the model checkpoint files.
            tokenizer_path (str): Path to the model's tokenizer file.
            max_seq_len (int): Maximum sequence length for input text.
            max_batch_size (int): Maximum batch size for inference.
            model_parallel_size (Optional[int], optional): Number of model parallel processes.
                If not provided, it's determined from the environment. Defaults to None.

        Returns:
            Llama: An instance of the Llama class with the loaded model and tokenizer.
        """
        # First, some model-parallel setup
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)

        # When a model is trained/served across multiple devices, each process gets a rank; this configures the local rank.
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        # random seed
        torch.manual_seed(seed)
        # only rank 0 prints; all other ranks redirect stdout to /dev/null
        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        # Finally, the model-loading code
        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        # Check that the number of checkpoint files matches expectations
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"

        # Load the checkpoint. With multiple processes, `get_model_parallel_rank()` returns a different value in each one, so no for-loop is needed -- very much the CUDA programming mindset.
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu")

        # TODO: read `params.json` and load it into `model_args` via the `ModelArgs` class (covered below)
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())
        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )

        # TODO: load the Tokenizer (covered below)
        tokenizer = Tokenizer(model_path=tokenizer_path)
        assert model_args.vocab_size == tokenizer.n_words

        # half-precision setup
        if torch.cuda.is_bf16_supported():
            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
        else:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        
        # TODO: yes, the Transformer class here is the entire llama3 model body; model.load_state_dict is enough to load the weights (covered below)
        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        # TODO: everything is loaded at this point; return a Llama instance
        return Llama(model, tokenizer)

The logic here is straightforward; it just leaves us a few TODOs, all of which we'll cover.

ModelArgs#

First, the ModelArgs class. It only stores parameters -- the @dataclass decorator says it all:

@dataclass
class ModelArgs:
    dim: int = 4096  # model (embedding) dimension
    n_layers: int = 32  # number of transformer layers
    n_heads: int = 32  # number of attention heads
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # vocabulary size
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000
    max_batch_size: int = 32
    max_seq_len: int = 2048  # maximum sequence length
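
To make the **params merge in Llama.build concrete, here is a small sketch (my own example; the field values are illustrative, not read from a real checkpoint):

import json

# Illustrative stand-in for the contents of params.json -- a real checkpoint ships its own values.
params = json.loads(
    '{"dim": 4096, "n_layers": 32, "n_heads": 32, "n_kv_heads": 8,'
    ' "vocab_size": 128256, "multiple_of": 1024, "ffn_dim_multiplier": 1.3,'
    ' "norm_eps": 1e-05, "rope_theta": 500000.0}'
)
# max_seq_len and max_batch_size come from the build() arguments; everything else from params.json.
model_args = ModelArgs(max_seq_len=512, max_batch_size=4, **params)
print(model_args.dim, model_args.vocab_size, model_args.max_seq_len)  # 4096 128256 512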

Llama.__init__()#

The final line, return Llama(model, tokenizer), invokes Llama.__init__(), which looks like this:

from llama.tokenizer import ChatFormat, Dialog, Message, Tokenizer

def __init__(self, model: Transformer, tokenizer: Tokenizer):
    self.model = model
    self.tokenizer = tokenizer
    # TODO: the ChatFormat class, covered later
    self.formatter = ChatFormat(tokenizer)

Yes, it's just simple assignments. We'll look at the ChatFormat class used for formatter together with the tokenizer later.

Llama.chat_completion: how the LLM generates replies#

In generator.chat_completion(dialogs, max_gen_len, temperature, top_p), generator comes from Llama.build, and Llama.build returns a Llama instance, so generator.chat_completion is simply the Llama class's chat_completion method:

    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """
        Generate assistant responses for a list of conversational dialogs using the language generation model.

        Args:
            dialogs (List[Dialog]):
                The model input: a list of conversational dialogs, where each dialog is a list of messages.
            temperature (float, optional):
                Temperature for controlling randomness in sampling. Higher values (e.g. 1.0) make the output
                more random; lower values (e.g. 0.2) make it more deterministic. Defaults to 0.6.
            top_p (float, optional):
                Probability threshold for nucleus (top-p) sampling: only the smallest set of tokens whose
                cumulative probability reaches this threshold is sampled from. Defaults to 0.9.
            max_gen_len (Optional[int], optional):
                Maximum length of the generated response sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional):
                If True, compute and return the log probability of each generated token. Defaults to False.

        Returns:
            List[ChatPrediction]:
                A list of chat predictions, one per dialog; each contains the assistant's generated
                response ("generation"), plus "tokens" and "logprobs" when logprobs is True.

        Note:
            This method generates assistant responses for the provided conversational dialogs.
            It employs nucleus sampling to introduce controlled randomness in text generation.
            If logprobs is True, token log probabilities are computed for each generated token.
        """
        # If max_gen_len is not provided, default to the model's maximum sequence length minus 1
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1

        # Encode each dialog into model-ready prompt_tokens using encode_dialog_prompt.
        # TODO: self.formatter.encode_dialog_prompt
        prompt_tokens = [
            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
        ]

        # Call generate, which produces the reply from prompt_tokens and the sampling parameters.
        # It returns two outputs: generation_tokens, the generated reply tokens, and
        # generation_logprobs, the per-token log probabilities (only populated when logprobs=True).
        # TODO: self.generate, covered next
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )

        # If logprobs=True, return each generated reply's tokens together with their log probabilities:
        if logprobs:
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "content": self.tokenizer.decode(t),  # 生成的文本内容
                    },
                    "tokens": [self.tokenizer.decode([x]) for x in t],  # 生成的token
                    "logprobs": logprobs_i,  # 对数概率
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]

        # If logprobs=False, return only the generated text:
        return [
            {
                "generation": {
                    "role": "assistant",
                    "content": self.tokenizer.decode(t),
                },
            }
            for t in generation_tokens
        ]

Llama.generate: generating output from the prompt#

    @torch.inference_mode()  # decorator that disables gradient tracking to speed up inference
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        """
        Generate text sequences based on provided prompts using the language generation model.

        Args:
            prompt_tokens (List[List[int]]):
                List of tokenized prompts, where each prompt is represented as a list of token IDs.
            max_gen_len (int): Maximum length of the generated text sequence.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional):
                Flag indicating whether to include the prompt tokens in the generated output. Defaults to False.

        Returns:
            Tuple[List[List[int]], Optional[List[List[float]]]]: 
                A tuple containing generated token sequences and, if logprobs is True, 
                corresponding token log probabilities.

        Note:
            This method uses the provided prompts as a basis for generating text. 
            It employs nucleus sampling to produce text with controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.

        """
        # Make sure the batch size (the number of prompts) does not exceed the model's max_batch_size
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        # Make sure the longest tokenized prompt fits within params.max_seq_len
        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        # total_len: total length of the generated sequence
        # max_gen_len: defaults to params.max_seq_len - 1
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        # Create a tokens tensor filled with pad_id, then copy each prompt's actual tokens into place.
        # In other words, every prompt is padded out to total_len with pad_id.
        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        # prev_pos starts at 0: the position in the sequence processed so far.
        prev_pos = 0
        # marks whether each sequence in the batch has reached an End-Of-Sequence (EOS) / stop token
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        # input_text_mask marks the positions that are not pad_id (i.e., belong to a prompt)
        input_text_mask = tokens != pad_id
        
        # If the shortest prompt already fills total_len, there is nothing left to generate:
        # run the model once and compute the per-token log-likelihoods via cross entropy.
        if min_prompt_len == total_len:
            logits = self.model.forward(tokens, prev_pos)
            token_logprobs = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens,
                reduction="none",
                ignore_index=pad_id,
            )

        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))

        # from the end of the shortest prompt up to the total sequence length
        for cur_pos in range(min_prompt_len, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)  # forward pass produces the logits
            if temperature > 0:
                # scale the logits by temperature, softmax, then sample the next token with top-p
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                # no need for softmax here -- greedy argmax does the job
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace the token if the prompt at this position has already been consumed;
            # otherwise keep the prompt token
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token

            # compute and update token_logprobs
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )

            # check whether a stop token was generated (outside the prompt region)
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                torch.isin(next_token, stop_tokens)
            )
            prev_pos = cur_pos
            if all(eos_reached):  # if every sequence has reached a stop token, stop generating early
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []

        # iterate over each sequence's tokens:
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            # if echo=True the output includes the original prompt; otherwise it does not
            start = 0 if echo else len(prompt_tokens[i])
            # cap the sequence at the prompt length plus max_gen_len
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]

            # cut everything from the first stop token onward, if any
            for stop_token in self.tokenizer.stop_tokens:
                try:
                    eos_idx = toks.index(stop_token)
                    toks = toks[:eos_idx]
                    probs = probs[:eos_idx] if logprobs else None
                except ValueError:
                    pass

            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)
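
generate relies on a helper sample_top_p that is not shown in this post. Below is a minimal sketch of nucleus sampling that matches how it is called above (sort, drop the tail whose cumulative probability exceeds top_p, renormalize, sample); the repo's own implementation may differ in details:

import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # probs: (bsz, vocab_size), already softmax-normalized
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # zero out every token whose cumulative probability, excluding itself, already exceeds p
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  # renormalize what is left
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # map back from sorted positions to the original vocabulary indices
    return torch.gather(probs_idx, -1, next_token)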

Model#

This is probably the part most people care about.

Transformer.__init__()#

First, model initialization, which just sets up a pile of attributes. Straight to the code; commentary is in the comments:

from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
) # These FairScale modules implement model parallelism; no need to dig into them here

class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        # VocabParallelEmbedding comes from fairscale and behaves like `torch.nn.Embedding`
        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            # TODO: TransformerBlock
            self.layers.append(TransformerBlock(layer_id, params))

        # TODO: RMSNorm
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)

        # ColumnParallelLinear is the model-parallel counterpart of `torch.nn.Linear`
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        # TODO: precompute_freqs_cis
        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )

RMSNorm#

RMSNorm is LayerNorm with the mean-centering removed, i.e. LayerNorm under the assumption that the activations have zero mean:

$$\bar{a}_i = \frac{a_i}{\mathrm{RMS}(\mathbf{a})}\, g_i, \quad \text{where } \mathrm{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2} \tag{1}$$

For comparison, LayerNorm is

$$\bar{a}_i = \frac{a_i - \mu}{\sigma}\, g_i, \quad \text{where } \mu = \frac{1}{n}\sum_{i=1}^{n} a_i \ \text{ and } \ \sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n}(a_i - \mu)^2} \tag{2}$$

In code it looks like this:

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain, initialized to ones

    def _norm(self, x):
        # torch.rsqrt: reciprocal of the square root, used here as 1 / RMS
        # x.pow(2).mean(-1, keepdim=True): mean of the squared elements along the last dimension
        #    i.e., x * rsqrt(mean(x^2) + eps)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

The authors report that this simplification of LayerNorm cuts computation time by roughly 7%-64% across a range of models.
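
As a quick sanity check of the "LayerNorm without mean-centering" reading (my own snippet, not from the repo; it assumes the RMSNorm class above and from torch import nn are in scope): when the input already has zero mean along the last dimension, RMSNorm coincides with an affine-free LayerNorm:

import torch

x = torch.randn(2, 8)
x = x - x.mean(dim=-1, keepdim=True)  # force zero mean per row

rms = RMSNorm(dim=8, eps=1e-6)        # weight is initialized to ones
ln = torch.nn.LayerNorm(8, eps=1e-6, elementwise_affine=False)

print(torch.allclose(rms(x), ln(x), atol=1e-5))  # True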

Rotary Position Embedding (RoPE)#

This section draws on the blog of Su Jianlin, the inventor of RoPE.

RoPE realizes a relative position encoding by means of an absolute one. Suppose we add absolute position information to $q$ and $k$ through the following operations:

that is, we design operations $f(\cdot, m)$ and $f(\cdot, n)$ for $q$ and $k$ such that, after applying them, $\tilde{q}_m = f(q, m)$ and $\tilde{k}_n = f(k, n)$ carry the absolute positions $m$ and $n$. The core operation of attention is the inner product, so we want the inner product to carry relative position information; we therefore require the identity:

$$\langle f(q, m),\, f(k, n) \rangle = g(q, k, m - n) \tag{3}$$

Solving this (working in two dimensions and treating $(q_0, q_1)$ as the complex number $q_0 + i q_1$) gives:

$$f(q, m) = R_f(q, m)\, e^{i\Theta_f(q, m)} = \lVert q \rVert\, e^{i(\Theta(q) + m\theta)} = q\, e^{im\theta} \tag{4}$$

which can equivalently be written as a rotation:

$$f(q, m) = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} q_0 \\ q_1 \end{pmatrix} \tag{5}$$

Because the inner product is linear, RoPE in any even dimension $d$ can be expressed as a concatenation of the two-dimensional case:

$$\underbrace{\begin{pmatrix}
\cos m\theta_0 & -\sin m\theta_0 & 0 & 0 & \cdots & 0 & 0 \\
\sin m\theta_0 & \cos m\theta_0 & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & \cos m\theta_1 & -\sin m\theta_1 & \cdots & 0 & 0 \\
0 & 0 & \sin m\theta_1 & \cos m\theta_1 & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2-1} & -\sin m\theta_{d/2-1} \\
0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2-1} & \cos m\theta_{d/2-1}
\end{pmatrix}}_{\boldsymbol{R}_m}
\begin{pmatrix} q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix} \tag{6}$$

Since $\boldsymbol{R}_m$ is very sparse, RoPE can be implemented element-wise as:

$$\begin{pmatrix} q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix} \otimes
\begin{pmatrix} \cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1} \end{pmatrix} +
\begin{pmatrix} -q_1 \\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{pmatrix} \otimes
\begin{pmatrix} \sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1} \end{pmatrix} \tag{7}$$

precompute_freqs_cis#

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Compute the base rotation angle for each pair of adjacent embedding dimensions
    # torch.arange(0, dim, 2): yields [0, 2, 4, ..., dim-2] (e.g. [0, 2, ..., 126] for dim=128)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)  # t = [0, ..., end-1]
    # torch.outer(a, b): outer product, result[i][j] = a[i] * b[j]
    freqs = torch.outer(t, freqs)  # freqs.shape = (t.len(), freqs.len()) = (end, dim//2)

    # Turn the angles into unit complex numbers (polar form)
    # torch.polar(abs, angle): abs*cos(angle) + abs*sin(angle)*j
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # freqs_cis.shape  = (end,dim//2)
    return freqs_cis
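
A toy call, just to see what comes back (head_dim = 8, 4 positions): entry $[m, i]$ of the result is the unit complex number $e^{i\, m\theta_i}$:

freqs_cis = precompute_freqs_cis(dim=8, end=4)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([4, 4]) torch.complex64
print(freqs_cis[0])                      # position 0: all ones (rotation by angle 0)
print(freqs_cis[1, 0])                   # e^{i * 1 * theta_0} = cos(1) + i*sin(1)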

reshape_for_broadcast#

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # ndim is the number of dimensions of x; here it should be 4
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    # (1, x.shape[1], 1, x.shape[-1])
    return freqs_cis.view(*shape)

apply_rotary_emb#

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """将xq和xk的最后一个维度进行复数运算,得到新的xq和xk"""
    # xq.shape = [bsz, seqlen, self.n_local_heads, self.head_dim]
    # xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim//2 , 2]
    # torch.view_as_complex用于将二维向量转换为复数域 torch.view_as_complex即([x,y]) -> (x+yj)
    # 所以经过view_as_complex变换后xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim//2]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # freqs_cis.shape = (1,x.shape[1],1,x.shape[-1])
    
    # broadcasted element-wise (Hadamard) product of xq_ and freqs_cis:
    # [bsz, seqlen, self.n_local_heads, self.head_dim//2] * [1, seqlen, 1, self.head_dim//2]
    # torch.view_as_real converts the complex values back to real pairs, then flatten(3) merges the last two dims:
    # [..., self.head_dim//2] -> [..., self.head_dim//2, 2] -> [..., self.head_dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
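
To convince ourselves that this complex-multiplication form really satisfies identity (3), here is a small self-contained check (my own code, not from the repo) that the inner product of two rotated vectors depends only on the relative offset $m - n$:

import torch

def rope(x: torch.Tensor, pos: float, theta: float = 10000.0) -> torch.Tensor:
    # x: (..., d) with d even; rotate each pair (x0, x1), (x2, x3), ... by pos * theta_i
    d = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
    angles = pos * freqs  # (d/2,)
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rot = torch.polar(torch.ones_like(angles), angles)
    return torch.view_as_real(xc * rot).flatten(-2)

q, k = torch.randn(8), torch.randn(8)
a = torch.dot(rope(q, 5.0), rope(k, 2.0))      # positions (5, 2): offset 3
b = torch.dot(rope(q, 103.0), rope(k, 100.0))  # positions (103, 100): offset 3
print(torch.allclose(a, b, atol=1e-4))  # True: only m - n matters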

TransformerBlock#

This class is simple: it is just a single transformer block.


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        """初始化函数主要就是定义了transformer block的各个组件,包括自注意力机制和前馈神经网络。"""
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        # TODO: Attention
        self.attention = Attention(args)

        # TODO: FeedForward
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of,  ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """这个函数是transformer block的前向传播函数,输入是x,start_pos,freqs_cis,mask,输出是out"""
        # 这个函数的实现比较简单,首先对输入张量x进行自注意力机制计算,然后对计算结果进行残差连接和归一化,再通过前馈神经网络计算,最后再次进行残差连接和归一化。
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

Attention#

To implement grouped-query attention (GQA), a helper function repeat_kv repeats the key and value heads n_rep times so that they match the number of query heads. It uses expand to replicate the tensor along a new fourth dimension and reshape to fold that dimension back into the head dimension.

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

# Trimmed-down Attention (illustrative excerpt, not the full class)
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.wq = Linear(...)
        self.wk = Linear(...)
        self.wv = Linear(...)

        self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)

    def forward(self, x: torch.Tensor):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        # apply the rotary position embedding before the attention computation
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        # ...
        # the rest of the attention computation
        scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
        scores = F.softmax(scores.float(), dim=-1)
        output = torch.matmul(scores, xv)  # (batch_size, seq_len, dim)
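
As a toy illustration of repeat_kv in the grouped-query setting (my own example): with 2 KV heads and 8 query heads, each KV head is repeated 4 times, so query heads 0-3 all attend against copies of KV head 0:

import torch

bs, seqlen, n_kv_heads, head_dim = 1, 3, 2, 4
n_heads = 8
keys = torch.randn(bs, seqlen, n_kv_heads, head_dim)

expanded = repeat_kv(keys, n_rep=n_heads // n_kv_heads)
print(expanded.shape)                                     # torch.Size([1, 3, 8, 4])
print(torch.equal(expanded[:, :, 0], expanded[:, :, 3]))  # True: heads 0-3 share KV head 0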

The FeedForward class and the SwiGLU activation#

The FeedForward class implements:

$$\mathrm{FFN}_{\mathrm{SwiGLU}}(x, W, V, W_2) = \left(\mathrm{Swish}_1(xW) \otimes xV\right) W_2 \tag{8}$$


The activation function used is SwiGLU, where:

$$\mathrm{SwiGLU}(x) = \mathrm{Swish}(Wx + b) \otimes (Vx + c) \tag{9}$$

$$\mathrm{Swish}_{\beta}(x) = x \cdot \mathrm{sigmoid}(\beta x) \tag{10}$$

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):  # let's skip this constructor for now -- it's a bit tedious (see the sketch after this block)
        ...

    def forward(self, x):
        # w2 * silu(w1 * x) * w3
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
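
For completeness, the skipped constructor mainly computes the hidden width: the nominal 4 * dim passed in from TransformerBlock is shrunk to two thirds (to compensate for the extra w3 projection), optionally rescaled by ffn_dim_multiplier, and then rounded up to a multiple of multiple_of. A sketch of that logic (check the repo for the exact code):

from typing import Optional

def swiglu_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]) -> int:
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)  # compensate for the third (gating) projection
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # round up to the nearest multiple of `multiple_of`
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(swiglu_hidden_dim(4096, 1024, 1.3))  # 14336, matching the 8B config's reported FFN width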

The following draws on a Zhihu answer.

When $\beta = 1$, $\mathrm{Swish}(x)$ is exactly $\mathrm{SiLU}(x)$:

$$\mathrm{SiLU}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}} \tag{11}$$

[Figure: plot of the SiLU activation function]

Transformer.forward()#

The forward pass is the Transformer forward pass we already know.

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape  # batch size and sequence length
        h = self.tok_embeddings(tokens)  # embed the tokens to get the input representation h
        self.freqs_cis = self.freqs_cis.to(h.device)  # move the precomputed frequencies to h's device
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]  # slice out the frequencies for these positions

        mask = None  # mask that hides disallowed positions in self-attention
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)  # (seqlen, seqlen) tensor filled with -inf
            mask = torch.triu(mask, diagonal=1)  # keep only the strictly upper triangle (future positions)
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)  # prepend zero columns for the start_pos cached positions so the mask matches the key length

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)  # run through the transformer blocks layer by layer
        h = self.norm(h)  # final RMSNorm
        output = self.output(h).float()  # project to vocabulary logits
        return output
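
To see concretely what this mask looks like, here is a standalone reproduction with toy numbers (3 new tokens on top of 4 cached positions): the columns covering the cached prefix are all zeros, i.e. fully visible, while the new block is causal:

import torch

seqlen, start_pos = 3, 4
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])
print(mask)
# tensor([[0., 0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0.,   0., -inf],
#         [0., 0., 0., 0., 0.,   0.,   0.]])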

Tokenizer#

Role, Message and Dialog#

Just read the code; if it's unclear, brush up on basic Python syntax:

Role = Literal["system", "user", "assistant"]

class Message(TypedDict):
    role: Role
    content: str

Dialog = Sequence[Message]
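
So a Dialog is just a sequence of role/content dicts, and the dialogs argument that chat_completion expects is a list of such dialogs. For example (assuming from typing import List):

dialogs: List[Dialog] = [
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is RMSNorm?"},
    ],
    [  # a second, independent conversation in the same batch
        {"role": "user", "content": "Write a haiku about attention."},
    ],
]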

Tokenizer#

The Tokenizer class mostly wraps the tiktoken library, so there isn't much to explain. Its main reason for existing is to make custom tokens easy to define. Its methods declare a lot up front, but when you read the actual logic they still call into the library: encode merely wraps parameters such as bos, eos, allowed_special and disallowed_special, and decode is literally self.model.decode(cast(List[int], t)). _split_whitespaces_or_nonwhitespaces does what its name says: it splits a string so that no run of consecutive whitespace or non-whitespace characters exceeds a maximum length.

class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]  # note: the table of special tokens, filled in __init__
    num_reserved_special_tokens = 256
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.
        """
        assert os.path.isfile(model_path), model_path
        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>", "<|end_of_text|>",
            "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>",  # end of turn
            "<|reserved_special_token_0|>", "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>", "<|reserved_special_token_3|>", "<|reserved_special_token_4|>",
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name, pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens,
        )
        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }

    def encode(
        self, s: str, *, bos: bool, eos: bool,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_tokens ("all"|set[str]): allowed special tokens in string
            disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (insteading of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str

        # The tiktoken tokenizer can handle <=400k chars without pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # Here we iterate over subsequences and split if we exceed the limit of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                # 调用在这里
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
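
A hypothetical usage sketch (the tokenizer path below is a placeholder; point it at the tokenizer.model file that ships with the checkpoint):

tok = Tokenizer(model_path="/path/to/Meta-Llama-3-8B/tokenizer.model")  # placeholder path
ids = tok.encode("Hello, Llama 3!", bos=True, eos=False)
print(ids[0] == tok.bos_id)  # True: the BOS token was prepended
print(tok.decode(ids[1:]))   # "Hello, Llama 3!"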

ChatFormat#

The ChatFormat class builds on the Tokenizer and wraps it further, providing three encoders: encode_header, encode_message and encode_dialog_prompt. (The Tokenizer itself is only responsible for converting text to tokens and back.)

class ChatFormat:
    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer

    def encode_header(self, message: Message) -> List[int]:
        """将一条消息的角色部分编码成token序列"""
        tokens = []
        # 首先添加特殊token <|start_header_id|> 标记角色名称的开始
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        # 然后编码角色名称,不添加句首(BOS)和句尾(EOS)标记
        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
        # 接着添加特殊token <|end_header_id|> 表示角色信息结束
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        # 最后编码两个换行符 \n\n 以分隔不同部分
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_message(self, message: Message) -> List[int]:
        """编码整个消息(包括角色和内容)为token序列"""
        # 使用 encode_header 方法编码角色信息
        tokens = self.encode_header(message)
        # 接着编码消息的内容部分,去除首尾空格,不添加BOS和EOS标记
        tokens.extend(
            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
        )
        # 在最后添加特殊的 <|eot_id|> token表示消息结束
        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        return tokens

    def encode_dialog_prompt(self, dialog: Dialog) -> List[int]:
        """将整个对话编码为token序列,准备输入到模型中"""
        tokens = []
        # 首先添加特殊token <|begin_of_text|> 标识文本的开始
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        # 遍历对话中的每条消息,使用 encode_message 方法将其编码并追加到tokens列表中
        for message in dialog:
            tokens.extend(self.encode_message(message))
        # 在所有已有消息之后,添加一个仅包含角色"assistant"的header,但没有具体内容,
        # 这是为了让模型知道接下来应该生成assistant的消息。
        # Add the start of an assistant message for the model to complete.
        tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
        return tokens
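
Decoded back to text, the prompt that encode_dialog_prompt builds for a single user message looks roughly like this (shown as a string for readability; the model is then asked to continue after the trailing assistant header):

prompt_text = (
    "<|begin_of_text|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Hello!<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)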

Summary#

That's the whole source walkthrough. If anything is unclear, leave a comment.

Author: xiangcaoacao

Source: https://www.cnblogs.com/xiangcaoacao/p/18173863

License: this work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International license.
