Transformers-源码解析-五十二-

Transformers 源码解析(五十二)

.\models\gemma\convert_gemma_weights_to_hf.py

# 版权声明和信息
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入必要的库和模块
import argparse  # 用于解析命令行参数
import os  # 用于操作系统相关功能
import warnings  # 用于警告处理

import torch  # PyTorch库,用于深度学习
from accelerate import init_empty_weights  # 加速库,用于加速训练

from transformers import GemmaConfig, GemmaForCausalLM, GemmaTokenizer  # Hugging Face Transformers库,用于自然语言处理模型

# 尝试导入GemmaTokenizerFast,如果失败则给出警告并设置为None
try:
    from transformers import GemmaTokenizerFast
except ImportError as e:
    warnings.warn(e)
    warnings.warn(
        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
    )
    GemmaTokenizerFast = None

"""
示例用法:

python src/transformers/models/gemma/convert_gemma_weights_to_hf.py \
    --input_dir /path/to/downloaded/gemma/weights --model_size 7B --output_dir /output/path
"""

# Gemma模型配置示例
gemma_2b_config = GemmaConfig(
    num_hidden_layers=18,
    num_attention_heads=8,
    num_key_value_heads=1,
    hidden_size=2048,
    intermediate_size=16384,
)

gemma_7b_config = GemmaConfig()  # Gemma 7B模型配置对象

CONFIG_MAPPING = {"2B": gemma_2b_config, "7B": gemma_7b_config}  # 配置映射字典
LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"}  # 层名称映射字典


def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False, dtype=torch.float32):
    # 从指定路径获取模型参数
    num_attn_heads = config.num_attention_heads  # 注意力头数目
    hidden_size = config.hidden_size  # 隐藏层大小
    num_kv_heads = config.num_key_value_heads  # 键值头数目
    head_dim = config.head_dim  # 头维度

    print(f"Fetching all parameters from the checkpoint at '{input_base_path}'")  # 输出信息:从指定路径获取所有参数
    model_state_dict = torch.load(input_base_path, map_location="cpu")["model_state_dict"]  # 加载模型状态字典
    model_state_dict.pop("freqs_cis")  # 移除特定键值对应的值

    state_dict = {}  # 初始化状态字典
    # 遍历模型状态字典中的键值对
    for k, v in model_state_dict.items():
        # 检查键名是否包含 "qkv_proj"
        if "qkv_proj" in k:
            # 如果 num_kv_heads 等于 1,则执行以下操作
            if num_kv_heads == 1:
                # 重塑张量 v 的形状,将其分成查询(q_proj)、键(k_proj)、值(v_proj)投影
                v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim, hidden_size)
                q_proj = v[:num_attn_heads, ...]  # 提取查询投影
                k_proj = v[num_attn_heads : num_attn_heads + num_kv_heads, ...].repeat(num_kv_heads, 1, 1)  # 提取键投影
                v_proj = v[-num_kv_heads:, ...].repeat(num_kv_heads, 1, 1)  # 提取值投影

                # 将投影后的张量存入状态字典中,键名替换 "qkv_proj" 为 "q_proj", "k_proj", "v_proj"
                state_dict[k.replace("qkv_proj", "q_proj")] = q_proj.reshape(
                    num_attn_heads * head_dim, hidden_size
                ).clone()
                state_dict[k.replace("qkv_proj", "k_proj")] = k_proj.reshape(
                    num_kv_heads * head_dim, hidden_size
                ).clone()
                state_dict[k.replace("qkv_proj", "v_proj")] = v_proj[0].clone()  # 取第一个值投影
            else:
                # 如果 num_kv_heads 不等于 1,则执行以下操作
                q_proj, k_proj, v_proj = torch.split(v, v.shape[0] // 3, 0)  # 分割 v 为查询、键、值投影
                state_dict[k.replace("qkv_proj", "q_proj")] = q_proj.reshape(
                    num_attn_heads * head_dim, hidden_size
                ).clone()
                state_dict[k.replace("qkv_proj", "k_proj")] = k_proj.reshape(
                    num_kv_heads * head_dim, hidden_size
                ).clone()
                state_dict[k.replace("qkv_proj", "v_proj")] = v_proj.clone()  # 存储值投影

        # 如果键名为 "embedder.weight",将其映射到指定的层名称,并同时将 "lm_head.weight" 也设置为该值
        elif k == "embedder.weight":
            state_dict[LAYER_NAME_MAPPING[k]] = v
            state_dict["lm_head.weight"] = v
        else:
            # 对于其他键名,直接复制对应的值到状态字典中
            state_dict[k] = v

    # 设置默认的张量数据类型
    torch.set_default_dtype(dtype)

    # 输出加载 Gemma 模型的消息
    print("Loading the checkpoint in a Gemma model.")
    
    # 使用空权重初始化上下文管理器
    with init_empty_weights():
        # 根据配置创建 GemmaForCausalLM 模型
        model = GemmaForCausalLM(config)
    
    # 使用状态字典加载模型的参数,允许参数赋值但不强制严格匹配
    model.load_state_dict(state_dict, assign=True, strict=False)

    # 设置模型配置中的 Torch 张量数据类型为 float32
    model.config.torch_dtype = torch.float32
    # 删除模型配置中的 _name_or_path 属性
    del model.config._name_or_path
    # 输出保存为 Transformers 格式的消息
    print("Saving in the Transformers format.")

    # 如果需要推送到 Hub
    if push_to_hub:
        # 输出推送模型到指定路径的消息
        print(f"pushing the model to {save_path}")
        # 将模型推送到 Hub,设置为私有模式
        model.push_to_hub(save_path, safe_serialization=safe_serialization, private=True)
    else:
        # 否则,保存模型到指定路径,进行安全序列化
        model.save_pretrained(save_path, safe_serialization=safe_serialization)
# 主函数,程序的入口点
def main():
    # 创建命令行参数解析器
    parser = argparse.ArgumentParser()
    # 添加命令行参数:输入的模型检查点的绝对路径,必选参数
    parser.add_argument(
        "--input_checkpoint",
        help="Absolute path to the target Gemma weights.",
        required=True,
    )
    # 添加命令行参数:Gemma tokenizer 模型的位置,可选参数
    parser.add_argument(
        "--tokenizer_checkpoint",
        help="Location of Gemma tokenizer model",
    )
    # 添加命令行参数:模型的尺寸,默认为 "7B",可选参数
    parser.add_argument(
        "--model_size",
        default="7B",
        choices=["2B", "7B", "tokenizer_only"],
        help="'f' models correspond to the finetuned versions, and are specific to the Gemma2 official release. For more details on Gemma2, checkout the original repo: https://huggingface.co/google/gemma-7b",
    )
    # 添加命令行参数:输出目录,默认为 "google/gemma-7b",用于保存 HF 模型和 tokenizer
    parser.add_argument(
        "--output_dir",
        default="google/gemma-7b",
        help="Location to write HF model and tokenizer",
    )
    # 添加命令行参数:是否使用 `safetensors` 保存数据,默认为 False,可选参数
    parser.add_argument(
        "--pickle_serialization",
        help="Whether or not to save using `safetensors`.",
        action="store_true",
        default=False,
    )
    # 添加命令行参数:是否转换 tokenizer,默认为 False,可选参数
    parser.add_argument(
        "--convert_tokenizer",
        help="Whether or not to convert the tokenizer as well.",
        action="store_true",
        default=False,
    )
    # 添加命令行参数:是否将模型推送到 HF Hub,默认为 False,可选参数
    parser.add_argument(
        "--push_to_hub",
        help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.",
        action="store_true",
        default=False,
    )
    # 添加命令行参数:转换后模型的目标数据类型,默认为 "float32",可选参数
    parser.add_argument(
        "--dtype",
        default="float32",
        help="Target dtype of the converted model",
    )
    # 解析命令行参数
    args = parser.parse_args()

    # 如果指定了 --convert_tokenizer 参数
    if args.convert_tokenizer:
        # 如果未提供 --tokenizer_checkpoint 参数,则抛出数值错误异常
        if args.tokenizer_checkpoint is None:
            raise ValueError("Path to the tokenizer is required when passing --convert_tokenizer")

        # 构建完整的 tokenizer 路径
        spm_path = os.path.join(args.tokenizer_checkpoint)
        # 调用 write_tokenizer 函数,保存或推送 tokenizer
        write_tokenizer(spm_path, args.output_dir, args.push_to_hub)

    # 根据模型尺寸选择对应的配置信息
    config = CONFIG_MAPPING[args.model_size]
    # 将 args.dtype 转换为 torch 中的数据类型
    dtype = getattr(torch, args.dtype)
    # 调用 write_model 函数,保存或推送模型
    write_model(
        config=config,
        input_base_path=args.input_checkpoint,
        save_path=args.output_dir,
        safe_serialization=not args.pickle_serialization,
        push_to_hub=args.push_to_hub,
        dtype=dtype,
    )


# 如果当前脚本作为主程序运行,则调用 main 函数
if __name__ == "__main__":
    main()

.\models\gemma\modeling_flax_gemma.py

# 导入必要的库和模块
from typing import Optional, Tuple  # 导入类型提示模块

import flax.linen as nn  # 导入Flax的Linen模块,用于定义模型
import jax  # 导入JAX,用于自动求导和数组操作
import jax.numpy as jnp  # 导入JAX的NumPy接口
import numpy as np  # 导入NumPy,用于数组操作
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # 导入Flax的FrozenDict相关模块,用于不可变字典操作
from flax.linen import combine_masks, make_causal_mask  # 导入Flax的Linen模块中的函数
from flax.linen.attention import dot_product_attention_weights  # 导入注意力机制相关函数
from flax.traverse_util import flatten_dict, unflatten_dict  # 导入Flax的工具函数,用于字典的扁平化和反扁平化
from jax import lax  # 导入JAX的lax模块,用于定义不同的线性代数和控制流操作

# 导入自定义的模块和类
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput  # 导入Flax模型输出相关类
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring  # 导入Flax预训练模型类和辅助函数
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging  # 导入辅助工具函数和日志记录模块
from .configuration_gemma import GemmaConfig  # 导入Gemma模型的配置类

# 获取logger对象,用于记录日志信息
logger = logging.get_logger(__name__)

# 文档中使用的配置信息
_CONFIG_FOR_DOC = "GemmaConfig"
_CHECKPOINT_FOR_DOC = "google/gemma-2b"
_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2"

# Gemma模型的开始文档字符串,包含模型的基本信息和JAX的特性说明
GEMMA_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
    # Parameters:
    #     config ([`GemmaConfig`]): Model configuration class with all the parameters of the model.
    #         Initializing with a config file does not load the weights associated with the model, only the
    #         configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
    #     dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
    #         The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or
    #         `jax.numpy.bfloat16`.
    #
    #         This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
    #         specified all the computation will be performed with the given `dtype`.
    #
    #         **Note that this only specifies the dtype of the computation and does not influence the dtype of model
    #         parameters.**
    #
    #         If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
    #         [`~FlaxPreTrainedModel.to_bf16`].
"""
定义了一个多行字符串变量 `GEMMA_INPUTS_DOCSTRING`,用于描述以下函数的参数和用法。

"""
def create_sinusoidal_positions(num_pos, dim):
    """
    创建一个正弦位置编码矩阵。

    Args:
        num_pos (int): 序列中位置的总数。
        dim (int): 编码向量的维度。

    Returns:
        numpy.ndarray: 形状为 `(num_pos, dim)` 的正弦位置编码矩阵。
    """
    # 计算逆频率,这里使用了正弦位置编码的标准公式
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2)[: (dim // 2)] / dim))
    # 计算频率矩阵,使用 numpy 的 einsum 函数进行计算
    freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
    
    # 返回正弦位置编码矩阵
    return freqs
    # 将频率向量 `freqs` 沿着最后一个轴复制一次,然后进行连接
    emb = np.concatenate((freqs, freqs), axis=-1)
    # 对连接后的数组 `emb` 分别计算正弦和余弦值,然后沿着新增的维度合并起来
    out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
    # 返回处理后的部分数据,只保留前 `num_pos` 个位置的结果
    return jnp.array(out[:, :, :num_pos])
# Copied from transformers.models.llama.modeling_flax_llama.rotate_half
# 函数:将输入张量的后一半隐藏维度进行旋转
def rotate_half(tensor):
    """Rotates half the hidden dims of the input."""
    rotate_half_tensor = jnp.concatenate(
        (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
    )
    return rotate_half_tensor


# Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb
# 函数:应用旋转位置编码到张量上
def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
    return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)


# 类:FlaxGemmaRMSNorm
# 类变量:配置信息 config 为 GemmaConfig 类型,默认数据类型为 jnp.float32
class FlaxGemmaRMSNorm(nn.Module):
    config: GemmaConfig
    dtype: jnp.dtype = jnp.float32

    # 方法:初始化设置
    def setup(self):
        self.epsilon = self.config.rms_norm_eps  # 设置 epsilon 参数
        self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)  # 初始化权重参数

    # 方法:调用实例时的行为
    def __call__(self, hidden_states):
        variance = jnp.asarray(hidden_states, dtype=jnp.float32)  # 将隐藏状态转换为 jnp.float32 类型的张量
        variance = jnp.power(variance, 2)  # 计算张量的平方
        variance = variance.mean(-1, keepdims=True)  # 沿着最后一个轴计算张量的均值并保持维度
        # 使用 `jax.numpy.sqrt` 代替 `jax.lax.rsqrt`,因为与 `torch.rsqrt` 不匹配
        hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)  # 对隐藏状态进行归一化

        return (1 + self.weight) * jnp.asarray(hidden_states, dtype=self.dtype)  # 返回归一化后的结果


# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Gemma
# 类:FlaxGemmaRotaryEmbedding
# 类变量:配置信息 config 为 GemmaConfig 类型,默认数据类型为 jnp.float32
class FlaxGemmaRotaryEmbedding(nn.Module):
    config: GemmaConfig
    dtype: jnp.dtype = jnp.float32

    # 方法:初始化设置
    def setup(self):
        head_dim = self.config.head_dim  # 从配置中获取头部维度信息
        self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim)  # 创建正弦余弦位置编码

    # 方法:调用实例时的行为
    def __call__(self, key, query, position_ids):
        sincos = self.sincos[position_ids]  # 获取指定位置的正弦余弦位置编码
        sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1)  # 将编码拆分为正弦部分和余弦部分

        key = apply_rotary_pos_emb(key, sin_pos, cos_pos)  # 对 key 应用旋转位置编码
        query = apply_rotary_pos_emb(query, sin_pos, cos_pos)  # 对 query 应用旋转位置编码

        key = jnp.asarray(key, dtype=self.dtype)  # 将 key 转换为指定数据类型
        query = jnp.asarray(query, dtype=self.dtype)  # 将 query 转换为指定数据类型

        return key, query  # 返回应用位置编码后的 key 和 query


# 类:FlaxGemmaAttention
# 类变量:配置信息 config 为 GemmaConfig 类型,默认数据类型为 jnp.float32
#       causal 表示是否是因果注意力,is_cross_attention 表示是否是交叉注意力
class FlaxGemmaAttention(nn.Module):
    config: GemmaConfig
    dtype: jnp.dtype = jnp.float32
    causal: bool = True
    is_cross_attention: bool = False
    def setup(self):
        config = self.config
        self.embed_dim = config.hidden_size  # 从配置中获取隐藏层大小作为嵌入维度
        self.num_heads = config.num_attention_heads  # 从配置中获取注意力头的数量
        self.head_dim = config.head_dim  # 从配置中获取每个注意力头的维度
        self.attention_softmax_in_fp32 = self.dtype is not jnp.float32  # 检查数据类型是否为 jnp.float32,用于注意力 softmax

        self.num_key_value_heads = config.num_key_value_heads  # 从配置中获取键值头的数量
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # 计算键值分组数

        kernel = jax.nn.initializers.normal(self.config.initializer_range)  # 使用正态分布初始化器初始化 kernel

        # 初始化查询投影层,设置输出维度为 num_heads * head_dim
        self.q_proj = nn.Dense(
            self.num_heads * self.head_dim,
            use_bias=config.attention_bias,
            dtype=self.dtype,
            kernel_init=kernel
        )

        # 初始化键投影层,设置输出维度为 num_key_value_heads * head_dim
        self.k_proj = nn.Dense(
            self.num_key_value_heads * self.head_dim,
            use_bias=config.attention_bias,
            dtype=self.dtype,
            kernel_init=kernel,
        )

        # 初始化值投影层,设置输出维度为 num_key_value_heads * head_dim
        self.v_proj = nn.Dense(
            self.num_key_value_heads * self.head_dim,
            use_bias=config.attention_bias,
            dtype=self.dtype,
            kernel_init=kernel,
        )

        # 初始化输出投影层,设置输出维度为 embed_dim
        self.o_proj = nn.Dense(
            self.embed_dim,
            use_bias=config.attention_bias,
            dtype=self.dtype,
            kernel_init=kernel
        )

        # 创建因果掩码,用于自注意力机制,形状为 (1, max_position_embeddings),数据类型为布尔型
        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")

        # 初始化旋转嵌入,使用 FlaxGemmaRotaryEmbedding 类,传入配置和数据类型
        self.rotary_emb = FlaxGemmaRotaryEmbedding(config, dtype=self.dtype)

    def _split_heads(self, hidden_states, num_heads):
        # 将隐藏状态张量按指定的 num_heads 分割成多个头,保留前两个维度不变
        return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        # 将多个头的张量合并成一个头,保留前两个维度不变
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads * self.head_dim,))

    @nn.compact
    # 从 transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache 复制而来
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # 检测是否正在初始化,通过检查是否存在缓存数据来判断
        is_initialized = self.has_variable("cache", "cached_key")
        # 获取或创建缓存的键值状态,初始化为全零数组,与输入的 key 形状和数据类型相同
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        # 获取或创建缓存的值状态,初始化为全零数组,与输入的 value 形状和数据类型相同
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # 获取或创建缓存索引,初始化为整数0
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # 获取当前缓存数据的形状,假设为 (*batch_dims, max_length, num_heads, depth_per_head)
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # 更新缓存的键和值,使用新的一维空间切片
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            # 使用 lax.dynamic_update_slice 函数更新缓存的键和值
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            # 更新缓存中的键和值状态
            cached_key.value = key
            cached_value.value = value
            # 更新缓存索引,增加已更新的缓存向量数量
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # 对于缓存的解码器自注意力,生成因果掩码:我们的单个查询位置应仅与已生成和缓存的键位置相对应,而不是剩余的零元素。
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # 结合因果掩码和输入的注意力掩码
            attention_mask = combine_masks(pad_mask, attention_mask)
        # 返回更新后的键、值和注意力掩码
        return key, value, attention_mask
    ):
        # 使用 self.q_proj 对隐藏状态进行查询投影
        query = self.q_proj(hidden_states)
        # 使用 self.k_proj 对隐藏状态进行键投影
        key = self.k_proj(hidden_states)
        # 使用 self.v_proj 对隐藏状态进行值投影
        value = self.v_proj(hidden_states)

        # 将查询张量按照头数目拆分
        query = self._split_heads(query, self.num_heads)
        # 将键张量按照键值头数目拆分
        key = self._split_heads(key, self.num_key_value_heads)
        # 将值张量按照键值头数目拆分
        value = self._split_heads(value, self.num_key_value_heads)

        # 对键和查询应用旋转嵌入
        key, query = self.rotary_emb(key, query, position_ids)

        # 获取查询和键的长度
        query_length, key_length = query.shape[1], key.shape[1]

        # 如果存在缓存的键,则根据缓存的键值创建一个因果掩码
        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            # 使用 lax.dynamic_slice 创建因果掩码
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
            )
        else:
            # 否则,直接使用预定义的因果掩码
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        # 获取批量大小
        batch_size = hidden_states.shape[0]
        # 广播因果掩码以匹配批量大小
        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # 广播注意力掩码以匹配因果掩码的形状
        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        # 将注意力掩码与因果掩码结合
        attention_mask = combine_masks(attention_mask, causal_mask)

        # 初始化 dropout_rng
        dropout_rng = None
        # 如果不是确定性运行且配置中指定了注意力 dropout,则创建 dropout_rng
        if not deterministic and self.config.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # 在快速自回归解码期间,逐步逐步提供一个位置,并逐步缓存键和值
        if self.has_variable("cache", "cached_key") or init_cache:
            # 将键、值和查询连接到缓存中
            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

        # 将布尔类型的注意力掩码转换为浮点数类型
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
        )

        # 根据指定的键值组数重复键张量
        key = jnp.repeat(key, repeats=self.num_key_value_groups, axis=2)
        # 根据指定的键值组数重复值张量
        value = jnp.repeat(value, repeats=self.num_key_value_groups, axis=2)

        # 执行常规的点积注意力操作
        attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_dropout,
            deterministic=deterministic,
            dtype=attention_dtype,
        )

        # 如果指定了 attention_softmax_in_fp32,则将注意力权重转换为指定的数据类型
        if self.attention_softmax_in_fp32:
            attn_weights = attn_weights.astype(self.dtype)

        # 执行注意力输出计算
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        # 合并多头的输出
        attn_output = self._merge_heads(attn_output)
        # 对注意力输出进行最终的投影
        attn_output = self.o_proj(attn_output)

        # 根据需要返回注意力输出及其权重
        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
# Gemma MLP 模型的定义,继承自 nn.Module 类
class FlaxGemmaMLP(nn.Module):
    # 指定配置参数为 GemmaConfig 类型
    config: GemmaConfig
    # 指定数据类型为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 模型初始化方法
    def setup(self):
        # 获取嵌入维度
        embed_dim = self.config.hidden_size
        # 获取内部维度,如果未指定则设为 4 倍的嵌入维度
        inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim

        # 使用正态分布初始化器初始化核矩阵,范围为配置中的 initializer_range
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
        
        # 如果隐藏激活函数未指定,发出警告并设置为 'gelu_pytorch_tanh'
        if self.config.hidden_activation is None:
            logger.warning_once(
                "Gemma's activation function should be approximate GeLU and not exact GeLU. "
                "Changing the activation function to `gelu_pytorch_tanh`."
                f"if you want to use the legacy `{self.config.hidden_act}`, "
                f"edit the `model.config` to set `hidden_activation={self.config.hidden_act}` "
                "  instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
            )
            hidden_activation = "gelu_pytorch_tanh"
        else:
            # 否则使用配置中指定的隐藏激活函数
            hidden_activation = self.config.hidden_activation
        
        # 根据激活函数名从预定义的 ACT2FN 字典中获取对应的激活函数
        self.act = ACT2FN[hidden_activation]

        # 初始化门控投影层,使用内部维度,不使用偏置,指定数据类型和核初始化器
        self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
        # 初始化下投影层,使用嵌入维度,不使用偏置,指定数据类型和核初始化器
        self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
        # 初始化上投影层,使用内部维度,不使用偏置,指定数据类型和核初始化器
        self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)

    # 模型调用方法
    def __call__(self, hidden_states):
        # 上投影操作,将隐藏状态映射到内部维度空间
        up_proj_states = self.up_proj(hidden_states)
        # 门控状态,通过激活函数处理门控投影层的输出
        gate_states = self.act(self.gate_proj(hidden_states))

        # 下投影操作,将上投影状态乘以门控状态,映射到嵌入维度空间
        hidden_states = self.down_proj(up_proj_states * gate_states)
        # 返回处理后的隐藏状态
        return hidden_states


# 从 transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer 复制而来,将 Llama 改为 Gemma
class FlaxGemmaDecoderLayer(nn.Module):
    # 指定配置参数为 GemmaConfig 类型
    config: GemmaConfig
    # 指定数据类型为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 模型初始化方法
    def setup(self):
        # 初始化输入层归一化,使用 GemmaRMSNorm 类处理配置和数据类型
        self.input_layernorm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype)
        # 初始化自注意力层,使用 GemmaAttention 类处理配置和数据类型
        self.self_attn = FlaxGemmaAttention(self.config, dtype=self.dtype)
        # 初始化注意力后归一化层,使用 GemmaRMSNorm 类处理配置和数据类型
        self.post_attention_layernorm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype)
        # 初始化 MLP 层,使用 GemmaMLP 类处理配置和数据类型
        self.mlp = FlaxGemmaMLP(self.config, dtype=self.dtype)

    # 模型调用方法
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        ):
            residual = hidden_states
            # 应用输入层归一化
            hidden_states = self.input_layernorm(hidden_states)
            # 使用自注意力机制处理隐藏状态
            outputs = self.self_attn(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            # 残差连接
            attn_output = outputs[0]
            hidden_states = residual + attn_output

            residual = hidden_states
            # 应用自注意力后归一化
            hidden_states = self.post_attention_layernorm(hidden_states)
            # 应用多层感知机(MLP)
            hidden_states = self.mlp(hidden_states)
            # 残差连接
            hidden_states = residual + hidden_states

            return (hidden_states,) + outputs[1:]
# 从transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel复制而来,替换GPTNeo为Gemma,GPT_NEO为GEMMA,transformer为model
class FlaxGemmaPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 使用GemmaConfig作为配置类
    config_class = GemmaConfig
    # 基础模型的前缀为"model"
    base_model_prefix = "model"
    # 模块类未定义
    module_class: nn.Module = None

    def __init__(
        self,
        config: GemmaConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 使用给定的config和dtype初始化模块
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 调用父类的初始化方法
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        # 生成位置ID,广播以匹配输入形状
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        # 分割随机数生成器
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 使用模块的初始化方法生成随机参数
        random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]

        if params is not None:
            # 如果有提供参数,则与随机参数合并
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length):
        """
        Args:
            batch_size (`int`):
                用于快速自回归解码的批量大小。定义初始化缓存的批量大小。
            max_length (`int`):
                自回归解码的最大可能长度。定义初始化缓存的序列长度。
        """
        # 初始化用于检索缓存的输入变量
        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # 使用模块的初始化方法生成缓存
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        # 返回解冻的缓存变量
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
    # 定义一个特殊方法 __call__,使得对象可以像函数一样被调用
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 如果指定了输出注意力的选项,则使用指定的值;否则使用默认配置中的值
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果指定了输出隐藏状态的选项,则使用指定的值;否则使用默认配置中的值
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果指定了返回字典的选项,则使用指定的值;否则使用默认配置中的值
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # 获取输入张量的批量大小和序列长度
        batch_size, sequence_length = input_ids.shape

        # 如果未提供位置编码,则根据序列长度和批量大小创建一个默认的位置编码
        if position_ids is None:
            if past_key_values is not None:
                # 如果传递了过去的键值,但未提供位置编码,则引发值错误
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
            # 使用序列长度创建一个二维数组,用于表示位置编码
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # 如果未提供注意力掩码,则创建一个全为1的注意力掩码数组
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # 处理任何需要的随机数生成器
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # 准备输入参数字典,包括模型参数或者自身保存的参数
        inputs = {"params": params or self.params}

        # 如果传递了过去的键值,将它们作为缓存输入,并将"cache"标记为可变
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        # 调用模块的 apply 方法,执行模型的前向计算
        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # 如果传递了过去的键值并且要求返回字典,则将更新后的缓存添加到模型输出中
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        # 如果传递了过去的键值但不要求返回字典,则将更新后的缓存添加到模型输出的第一个元素中
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        # 返回模型的输出
        return outputs
# 从 transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection 复制而来,将 Llama 替换为 Gemma
class FlaxGemmaLayerCollection(nn.Module):
    config: GemmaConfig  # 类型注解,指定 config 为 GemmaConfig 类型
    dtype: jnp.dtype = jnp.float32  # 类型注解,指定 dtype 默认为 jnp.float32

    def setup(self):
        # 初始化 self.blocks 列表,其中每个元素为一个 FlaxGemmaDecoderLayer 实例,根据 config.num_hidden_layers 的值进行循环创建
        self.blocks = [
            FlaxGemmaDecoderLayer(self.config, dtype=self.dtype, name=str(i))
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
    ):
        # 如果 output_attentions 为 True,则初始化 all_attentions 为空元组,否则为 None
        all_attentions = () if output_attentions else None
        # 如果 output_hidden_states 为 True,则初始化 all_hidden_states 为空元组,否则为 None
        all_hidden_states = () if output_hidden_states else None

        # 遍历 self.blocks 列表中的每个 block
        for block in self.blocks:
            # 如果 output_hidden_states 为 True,则将当前 hidden_states 添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            
            # 调用 block 的 __call__ 方法,传递参数并接收返回的 layer_outputs
            layer_outputs = block(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            # 更新 hidden_states 为 layer_outputs 的第一个元素(通常是模型的输出)
            hidden_states = layer_outputs[0]

            # 如果 output_attentions 为 True,则将当前层的注意力矩阵添加到 all_attentions 中
            if output_attentions:
                all_attentions += (layer_outputs[1],)

        # 输出包含可能为 None 值的元组 outputs,`FlaxGemmaModule` 将会过滤掉这些 None 值
        outputs = (hidden_states, all_hidden_states, all_attentions)

        return outputs


# 从 transformers.models.llama.modeling_flax_llama.FlaxLlamaModule 复制而来,将 Llama 替换为 Gemma
class FlaxGemmaModule(nn.Module):
    config: GemmaConfig  # 类型注解,指定 config 为 GemmaConfig 类型
    dtype: jnp.dtype = jnp.float32  # 类型注解,指定 dtype 默认为 jnp.float32

    def setup(self):
        # 初始化 hidden_size 为 config.hidden_size
        self.hidden_size = self.config.hidden_size
        # 使用正态分布初始化 embed_tokens,形状为 (config.vocab_size, self.hidden_size),dtype 为 self.dtype
        embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            self.hidden_size,
            embedding_init=embedding_init,
            dtype=self.dtype,
        )
        # 初始化 layers 为 FlaxGemmaLayerCollection 实例,传递 config 和 dtype
        self.layers = FlaxGemmaLayerCollection(self.config, dtype=self.dtype)
        # 初始化 norm 为 FlaxGemmaRMSNorm 实例,传递 config 和 dtype
        self.norm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype)

    # 忽略复制
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # ...
        ):
            # 使用模型的嵌入层处理输入的标识符(转换为整数类型)
            input_embeds = self.embed_tokens(input_ids.astype("i4"))

            # 根据论文中建议的缩放因子对嵌入向量进行缩放
            input_embeds = input_embeds * (self.config.hidden_size**0.5)

            # 将输入嵌入向量传递给模型的多层网络进行处理
            outputs = self.layers(
                input_embeds,
                position_ids=position_ids,
                attention_mask=attention_mask,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            # 提取模型处理后的隐藏状态向量
            hidden_states = outputs[0]

            # 对隐藏状态向量进行规范化处理
            hidden_states = self.norm(hidden_states)

            # 如果需要输出所有隐藏状态向量,则构建包含所有隐藏状态的元组
            if output_hidden_states:
                all_hidden_states = outputs[1] + (hidden_states,)
                outputs = (hidden_states, all_hidden_states) + outputs[2:]
            else:
                # 否则,只输出规范化后的隐藏状态向量
                outputs = (hidden_states,) + outputs[1:]

            # 如果不需要返回字典形式的输出,则过滤掉值为None的输出结果
            if not return_dict:
                return tuple(v for v in outputs if v is not None)

            # 返回FlaxBaseModelOutput对象,包含最后的隐藏状态、所有隐藏状态和注意力权重
            return FlaxBaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=outputs[1],
                attentions=outputs[-1],
            )
# 定义一个 FlaxGemmaModel 类,继承自 FlaxGemmaPreTrainedModel,用于 Gemma 模型的 transformer 输出原始隐藏状态,没有额外的特定头部。
# 这个类被修改自 transformers.models.llama.modeling_flax_llama.FlaxLlamaModel,其中 Llama 被替换成 Gemma。

@add_start_docstrings(
    "The bare Gemma Model transformer outputting raw hidden-states without any specific head on top.",
    GEMMA_START_DOCSTRING,
)
class FlaxGemmaModel(FlaxGemmaPreTrainedModel):
    module_class = FlaxGemmaModule


# 为 FlaxGemmaModel 类添加样例调用文档字符串,用于检查点、配置等文档化信息。
append_call_sample_docstring(
    FlaxGemmaModel,
    _CHECKPOINT_FOR_DOC,
    FlaxBaseModelOutput,
    _CONFIG_FOR_DOC,
    real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)


# 定义一个 FlaxGemmaForCausalLMModule 类,继承自 nn.Module,表示带因果语言建模头部的 Gemma 模型模块。
# 这个类被复制自 transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule,其中 Llama 被替换成 Gemma。
class FlaxGemmaForCausalLMModule(nn.Module):
    config: GemmaConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.model = FlaxGemmaModule(self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    # 忽略复制
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用模型进行前向传播
        outputs = self.model(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        # 如果配置指定共享词嵌入,则共享的核心来自模型的参数中的嵌入
        if self.config.tie_word_embeddings:
            shared_kernel = self.model.variables["params"]["embed_tokens"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        # 如果不返回字典,则返回 logits 和可能的其他输出
        if not return_dict:
            return (lm_logits,) + outputs[1:]

        # 返回 FlaxCausalLMOutput 对象,包含 logits、隐藏状态和注意力信息
        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)


@add_start_docstrings(
    """
    The Gemma Model transformer with a language modeling head (linear layer) on top.
    """,
    GEMMA_START_DOCSTRING,
)
# 定义一个 FlaxGemmaForCausalLM 类,继承自 FlaxGemmaPreTrainedModel,表示带因果语言建模头部的 Gemma 模型。
# 这个类被复制自 transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM,其中 GPTJ 被替换成 Gemma。
class FlaxGemmaForCausalLM(FlaxGemmaPreTrainedModel):
    module_class = FlaxGemmaForCausalLMModule
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        # 获取输入张量的批量大小和序列长度
        batch_size, seq_length = input_ids.shape

        # 使用模型的方法初始化缓存,返回过去键值对
        past_key_values = self.init_cache(batch_size, max_length)

        # 注意:通常需要在 attention_mask 的 x > input_ids.shape[-1] 和 x < cache_length 的位置上放置 0。
        # 但由于 Gemma 使用因果掩码,这些位置已经被掩盖了。
        # 因此我们可以在这里创建一个静态的 attention_mask,这对编译更有效率。
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # 计算位置编码,累积求和并减去 1
            position_ids = attention_mask.cumsum(axis=-1) - 1
            # 使用 lax.dynamic_update_slice 更新 extended_attention_mask 的部分区域
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            # 如果没有提供 attention_mask,则使用广播方式创建位置编码
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        # 返回准备好的输入字典,包含过去键值对、扩展后的注意力掩码和位置编码
        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # 更新输入用于生成的模型参数
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        # 更新位置编码为最后一个位置的下一个位置
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
# 调用函数 `append_call_sample_docstring`,添加示例文档字符串到类 `FlaxGemmaForCausalLM` 上
# 使用变量 `_CHECKPOINT_FOR_DOC` 作为实际检查点的示例
# 将类 `FlaxCausalLMOutput` 作为输出配置信息的示例
# 使用变量 `_CONFIG_FOR_DOC` 作为配置信息的示例
# 使用变量 `_REAL_CHECKPOINT_FOR_DOC` 作为真实检查点的示例
append_call_sample_docstring(
    FlaxGemmaForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutput,
    _CONFIG_FOR_DOC,
    real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)

.\models\gemma\modeling_gemma.py

# coding=utf-8
# 版权 2024 Google Inc. HuggingFace Inc. 团队。保留所有权利。
#
#
# 根据 Apache 许可证版本 2.0 使用本文件("许可证");
# 除非符合许可证,否则不得使用此文件。
# 您可以在以下网址获取许可证副本
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则按"原样"分发软件,
# 没有任何明示或暗示的担保或条件。
# 有关特定语言下的权限,请参阅许可证。
""" PyTorch Gemma 模型。"""

import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_causal_attention_mask,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from ...utils.import_utils import is_torch_fx_available
from .configuration_gemma import GemmaConfig

# 如果支持 flash_attn 2.x 版本,则导入相关函数
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


# 当使用 Torch FX 时,使 `_prepare_4d_causal_attention_mask` 成为 FX 图中的叶子节点。
# 这意味着该函数不会被跟踪,只会作为图中的一个节点出现。
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx

    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "GemmaConfig"


def _get_unpad_data(attention_mask):
    # 计算每个序列的长度
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # 找出非零位置的索引
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # 找出批次中最大序列长度
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # 计算累积序列长度
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))
    # 定义一个私有方法 `_norm`,用于对输入张量 x 进行归一化处理
    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    # 定义前向传播方法 `forward`,接收输入张量 x
    def forward(self, x):
        # 使用私有方法 `_norm` 对输入张量 x 进行归一化处理,转换为 float 类型
        output = self._norm(x.float())
        
        # 做了一个特定的乘法操作,修改了输出结果 `output` 的值
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # 参考:https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        
        # 将输出结果 `output` 转换为与输入张量 x 相同的数据类型,并返回
        return output.type_as(x)
# 将 GemmaRMSNorm 类添加到 ALL_LAYERNORM_LAYERS 列表中
ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)

# 定义 GemmaRotaryEmbedding 类,继承自 nn.Module
class GemmaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        
        # 初始化 GemmaRotaryEmbedding 类的参数
        self.dim = dim  # 维度
        self.max_position_embeddings = max_position_embeddings  # 最大位置嵌入长度
        self.base = base  # 基础值
        self.register_buffer("inv_freq", None, persistent=False)  # 注册非持久化的缓冲区 inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        
        # 如果 inv_freq 为空,则根据公式计算 inv_freq
        if self.inv_freq is None:
            self.inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
            )
        
        # 扩展 inv_freq 和 position_ids
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        
        # 强制将 freqs 的计算结果转换为 float32,因为 bfloat16 在长上下文中会失去精度
        # 参考 https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)  # 连接 freqs 的 cos 和 sin
            cos = emb.cos()  # 计算余弦值
            sin = emb.sin()  # 计算正弦值
        
        # 返回 cos 和 sin,转换为输入 x 的数据类型
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# 从 transformers.models.llama.modeling_llama.rotate_half 复制并定义 rotate_half 函数
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]  # 取输入的前一半维度
    x2 = x[..., x.shape[-1] // 2 :]  # 取输入的后一半维度
    return torch.cat((-x2, x1), dim=-1)  # 将 x2 反转后与 x1 连接并返回


# 从 transformers.models.llama.modeling_llama.apply_rotary_pos_emb 复制并定义 apply_rotary_pos_emb 函数
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.
    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=torch.Tensor(), unsqueeze_dim=1):
        """
        Apply rotary position embedding to query and key tensors.
    
        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`, *optional*):
                Deprecated and unused.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        # Unsqueezing cos and sin along the specified dimension to enable broadcasting
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        
        # Applying rotary position embedding to q and k tensors
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        
        return q_embed, k_embed
class GemmaMLP(nn.Module):
    # GemmaMLP 类定义,继承自 nn.Module
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size  # 从配置中获取隐藏层大小
        self.intermediate_size = config.intermediate_size  # 从配置中获取中间层大小
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # 创建线性变换,用于门控投影,输入维度为隐藏层大小,输出维度为中间层大小,无偏置
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # 创建线性变换,用于上游投影,输入维度为隐藏层大小,输出维度为中间层大小,无偏置
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # 创建线性变换,用于下游投影,输入维度为中间层大小,输出维度为隐藏层大小,无偏置
        
        # 如果隐藏层激活函数为 None,则发出警告并设置为 'gelu_pytorch_tanh'
        if config.hidden_activation is None:
            logger.warning_once(
                "Gemma's activation function should be approximate GeLU and not exact GeLU.\n"
                "Changing the activation function to `gelu_pytorch_tanh`."
                f"if you want to use the legacy `{config.hidden_act}`, "
                f"edit the `model.config` to set `hidden_activation={config.hidden_act}` "
                "  instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
            )
            hidden_activation = "gelu_pytorch_tanh"
        else:
            hidden_activation = config.hidden_activation
        
        # 根据配置选择激活函数
        self.act_fn = ACT2FN[hidden_activation]

    def forward(self, x):
        # 前向传播方法,使用门控投影和上游投影进行激活函数后的加权,再经过下游投影
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


# 从 transformers.models.llama.modeling_llama.repeat_kv 复制过来的函数
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    # 将隐藏状态张量在第三维度上进行复制,使得维度从 (batch, num_key_value_heads, slen, head_dim)
    # 变为 (batch, num_key_value_heads * n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class GemmaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    # 忽略复制部分
    # 初始化函数,接受配置对象和可选的层索引作为参数
    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
        # 调用父类的初始化方法
        super().__init__()
        # 将传入的配置对象和层索引保存到实例变量中
        self.config = config
        self.layer_idx = layer_idx
        
        # 如果未传入层索引,则记录警告信息,建议在使用缓存时传入层索引,以避免前向调用中的错误
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        
        # 从配置对象中获取并设置注意力机制的参数
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        
        # 检查隐藏层大小是否能够被注意力头数整除,如果不能,则引发值错误异常
        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        
        # 初始化线性变换层,用于将输入向量投影到注意力头的维度上
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        
        # 初始化旋转嵌入层,用于引入轮转注意力机制
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
    # 定义函数签名,指定函数的输入参数类型和返回类型
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # 获取输入张量的形状信息
        bsz, q_len, _ = hidden_states.size()

        # 将隐藏状态投影到查询、键、值空间
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # 重新组织张量形状以适应多头注意力机制的计算需求,并进行维度转置
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # 获取过去的键值信息(如果存在),并应用旋转位置编码到查询和键状态
        past_key_value = getattr(self, "past_key_value", past_key_value)
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

        # 如果存在过去的键值信息,则更新键值状态
        if past_key_value is not None:
            # sin 和 cos 是 RoPE 模型特定的参数;cache_position 是用于静态缓存的参数
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # 将键值信息根据 num_key_value_groups 的设置进行重复
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # 计算注意力权重,采用缩放点积注意力计算方法
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        # 如果存在注意力掩码,则应用到注意力权重上
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # 将注意力权重进行 softmax 归一化,并转换为与 query_states 相同的数据类型
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        # 检查注意力输出的形状是否符合预期
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # 调整注意力输出的维度顺序,并使其连续存储
        attn_output = attn_output.transpose(1, 2).contiguous()

        # 将注意力输出重新组织为最终输出的形状,并应用输出投影层
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        # 如果不需要输出注意力权重,则将其置为 None
        if not output_attentions:
            attn_weights = None

        # 返回注意力输出、注意力权重(如果需要)、以及更新后的过去键值信息(如果存在)
        return attn_output, attn_weights, past_key_value
# 从 `transformers.models.llama.modeling_llama.LlamaFlashAttention2` 复制并重命名为 `GemmaFlashAttention2`
class GemmaFlashAttention2(GemmaAttention):
    """
    Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        # 调用父类构造函数,传递所有参数
        super().__init__(*args, **kwargs)

        # TODO: Flash Attention 版本升级至 2.1 后应该移除此段代码。
        # flash_attn<2.1 生成左上角对齐的因果蒙版,而这里需要的是默认为 flash_attn>=2.1 的右下角对齐。此属性用于处理这种差异。
        # 参考链接:https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # 注意,对于 flash_attn<2.1,除了 q_seqlen == 1 的情况外,使用 q_seqlen != k_seqlen 会产生错误的蒙版(左上角)。
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    # 忽略复制
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """
        执行前向传播,调用 Flash Attention 的公共 API,并处理输入中可能存在的填充标记。
        """
        pass

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # Determine if causal masking is required based on model configuration and query length
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in GemmaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Apply padding-aware operations if attention_mask is provided
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            # Unpad inputs based on attention_mask and query_length
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            # Extract lengths for query and key sequences after unpadding
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            # Perform variable-length Flash Attention computation
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            # Pad the attention output back to the original sequence length
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            # Perform Flash Attention without padding-aware operations
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        # Return the computed attention output
        return attn_output
    # 定义一个方法 `_upad_input`,用于处理注意力机制的输入数据
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # 从注意力掩码中获取未填充数据的索引、当前序列长度和批次内最大序列长度
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        # 获取 key_layer 的形状信息
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # 重新组织 key_layer 和 value_layer,根据索引 indices_k 来索引未填充的数据
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )

        # 根据 query_length 的不同情况处理 query_layer
        if query_length == kv_seq_len:
            # 如果 query_length 等于 kv_seq_len,则直接根据 indices_k 索引未填充的数据
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # 如果 query_length 等于 1,则创建一个长度为 batch_size 的序列长度 cu_seqlens_q,并进行索引处理
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # 这里有一个 memcpy,这是非常糟糕的。
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # 否则,根据 query_layer 和 attention_mask 的未填充数据,获取未填充的输入数据
            # 这里假设 `-q_len:` 切片表示左填充
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        # 返回处理后的 query_layer, key_layer, value_layer 以及相关的索引和长度信息
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemma
class GemmaSdpaAttention(GemmaAttention):
    """
    Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Ignore copy

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        """
        Perform forward pass of GemmaSdpaAttention.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape (batch_size, sequence_length, hidden_size).
            attention_mask (Optional[torch.Tensor]): Optional tensor of shape (batch_size, sequence_length) 
                containing attention mask for the input sequence. 1.0 for positions that should be attended to, 
                0.0 for masked positions.
            position_ids (Optional[torch.LongTensor]): Optional tensor of shape (batch_size, sequence_length) 
                containing position indices to help distinguish different positions in the input.
            past_key_value (Optional[Cache]): Optional tuple containing cached key and value tensors used for 
                fast decoding.
            output_attentions (bool): Whether to output attentions weights.
            use_cache (bool): Whether to use past key-value states to speed up decoding.
            cache_position (Optional[torch.LongTensor]): Optional tensor of shape (batch_size,) specifying 
                positions in the cache.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, sequence_length, hidden_size).
        """
        # Implementation details of attention mechanism adapted for Gemma architecture using SDPA API
        pass

GEMMA_ATTENTION_CLASSES = {
    "eager": GemmaAttention,
    "flash_attention_2": GemmaFlashAttention2,
    "sdpa": GemmaSdpaAttention,
}


# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
class GemmaDecoderLayer(nn.Module):
    def __init__(self, config: GemmaConfig, layer_idx: int):
        """
        Initialize GemmaDecoderLayer.

        Args:
            config (GemmaConfig): Configuration object containing model-specific settings.
            layer_idx (int): Index of the decoder layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        # Initialize self-attention mechanism based on configuration
        self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        # Initialize MLP layer
        self.mlp = GemmaMLP(config)

        # Layer normalization for input to the layer
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Layer normalization after attention mechanism
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """
        Perform forward pass of GemmaDecoderLayer.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape (batch_size, sequence_length, hidden_size).
            attention_mask (Optional[torch.Tensor]): Optional tensor of shape (batch_size, sequence_length) 
                containing attention mask for the input sequence. 1.0 for positions that should be attended to, 
                0.0 for masked positions.
            position_ids (Optional[torch.LongTensor]): Optional tensor of shape (batch_size, sequence_length) 
                containing position indices to help distinguish different positions in the input.
            past_key_value (Optional[Tuple[torch.Tensor]]): Optional tuple containing cached key and value tensors 
                used for fast decoding.
            output_attentions (Optional[bool]): Whether to output attentions weights.
            use_cache (Optional[bool]): Whether to use past key-value states to speed up decoding.
            cache_position (Optional[torch.LongTensor]): Optional tensor of shape (batch_size,) specifying 
                positions in the cache.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, sequence_length, hidden_size).
        """
        # Detailed implementation of forward pass through a GemmaDecoderLayer
        pass
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # 记录输入的隐藏状态,用于残差连接
        residual = hidden_states

        # 应用输入层的 Layer Normalization
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        # 调用 self_attn 方法进行自注意力计算,并返回更新后的隐藏状态、注意力权重和更新的键值对
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        # 将残差连接加回到更新后的隐藏状态中
        hidden_states = residual + hidden_states

        # Fully Connected
        # 记录当前的隐藏状态,用于残差连接
        residual = hidden_states

        # 应用后注意力层的 Layer Normalization
        hidden_states = self.post_attention_layernorm(hidden_states)

        # 应用 MLP 层
        hidden_states = self.mlp(hidden_states)

        # 将残差连接加回到 MLP 输出的隐藏状态中
        hidden_states = residual + hidden_states

        # 构造输出元组,包含更新后的隐藏状态
        outputs = (hidden_states,)

        # 如果需要返回注意力权重,则添加到输出元组中
        if output_attentions:
            outputs += (self_attn_weights,)

        # 如果需要使用缓存,将更新的键值对添加到输出元组中
        if use_cache:
            outputs += (present_key_value,)

        return outputs
# 定义模型文档字符串,描述该模型继承自`PreTrainedModel`,指向其超类文档以获取通用方法信息,
# 并说明它也是一个PyTorch的`torch.nn.Module`子类,应当按照PyTorch文档使用。
GEMMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`GemmaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# 应用文档字符串到GemmaPreTrainedModel类,描述其作为一个裸的Gemma模型,输出没有特定顶部头部的原始隐藏状态。
# 包含先前定义的模型文档字符串作为参数详细信息的一部分。
@add_start_docstrings(
    "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
    GEMMA_START_DOCSTRING,
)
class GemmaPreTrainedModel(PreTrainedModel):
    # GemmaPreTrainedModel类使用GemmaConfig作为其配置类
    config_class = GemmaConfig
    # 指定基础模型前缀为'model'
    base_model_prefix = "model"
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 指定需要保持在fp32模块中的参数列表
    _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
    # 指定不分割的模块列表
    _no_split_modules = ["GemmaDecoderLayer"]
    # 指定跳过设备放置的键列表
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    # 支持flash_attention_2
    _supports_flash_attn_2 = True
    # 支持sdpa
    _supports_sdpa = True
    # 支持cache类
    _supports_cache_class = True

    # 初始化模型权重的私有方法,根据配置中的initializer_range初始化线性层和嵌入层的权重
    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    # 设置缓存的私有方法,根据特定条件初始化模型的缓存
    def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
        # 如果使用flash_attention_2且缓存类为StaticCache,则抛出异常
        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        # 如果最大缓存长度大于模型的causal_mask形状或设备不匹配,则重新生成causal_mask并注册为模型的缓冲区
        if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
            causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
            self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

        # 遍历模型的每一层,为其self-attention层的past_key_value属性设置缓存
        for layer in self.model.layers:
            weights = layer.self_attn.o_proj.weight
            layer.self_attn.past_key_value = cache_cls(
                self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
            )
    # 重置缓存函数,用于清空模型中每个层的注意力机制的过去键值缓存
    def _reset_cache(self):
        # 遍历模型中的每一层
        for layer in self.model.layers:
            # 将每一层的自注意力机制的过去键值缓存置为None,即清空缓存
            layer.self_attn.past_key_value = None
GEMMA_INPUTS_DOCSTRING = r"""
"""


@add_start_docstrings(
    "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
    GEMMA_START_DOCSTRING,
)
# 从transformers.models.llama.modeling_llama.LlamaModel复制而来,将LLAMA->GEMMA,Llama->Gemma
# GemmaModel类定义,用于Transformer解码器,包含config.num_hidden_layers层,每层是GemmaDecoderLayer
# Args:
#     config: GemmaConfig,Gemma配置对象
class GemmaModel(GemmaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]

    Args:
        config: GemmaConfig
    """

    def __init__(self, config: GemmaConfig):
        super().__init__(config)
        # 设置填充索引为config中的pad_token_id
        self.padding_idx = config.pad_token_id
        # 设置词汇表大小为config中的vocab_size
        self.vocab_size = config.vocab_size

        # 创建词嵌入层,参数为词汇表大小、隐藏大小、填充索引
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # 创建包含config.num_hidden_layers个GemmaDecoderLayer对象的层列表
        self.layers = nn.ModuleList(
            [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # 创建GemmaRMSNorm对象,参数为隐藏大小、eps值
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 设置梯度检查点为False
        self.gradient_checkpointing = False

        # 初始化权重并应用最终处理
        self.post_init()

    # 返回词嵌入层对象
    def get_input_embeddings(self):
        return self.embed_tokens

    # 设置词嵌入层对象的值
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # 从transformers.models.llama.modeling_llama.LlamaModel复制而来,将LLAMA->GEMMA,Llama->Gemma
    # GemmaModel前向传播函数,忽略复制
    # TODO: 截至torch==2.2.0,在generate中传递给模型的attention_mask是二维的,即使在使用静态KV缓存时也是动态长度。这是torch.compile的问题,
    #  导致每个解码步骤都重新捕获cudagraphs(例如,`recording cudagraph tree for symint key 13`),速度非常慢。
    #  一个解决方法是@torch.compiler.disable,但这会阻止使用fullgraph=True。详细内容请参见https://github.com/huggingface/transformers/pull/29114
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        pass


# 从transformers.models.llama.modeling_llama.LlamaForCausalLM复制而来,将LLAMA->GEMMA,Llama->Gemma,llama->gemma
# GemmaForCausalLM类定义,继承自GemmaPreTrainedModel
class GemmaForCausalLM(GemmaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    # 初始化函数,接受一个config参数
    def __init__(self, config):
        super().__init__(config)
        # 创建一个GemmaModel对象,传入config参数
        self.model = GemmaModel(config)
        # 设置词汇表大小为config中的vocab_size
        self.vocab_size = config.vocab_size
        # 创建一个线性层,将隐藏大小转换为词汇表大小,没有偏置
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # 初始化权重并应用最终处理
        self.post_init()
    # 返回模型的输入嵌入
    def get_input_embeddings(self):
        return self.model.embed_tokens

    # 设置模型的输入嵌入
    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    # 返回模型的输出嵌入
    def get_output_embeddings(self):
        return self.lm_head

    # 设置模型的输出嵌入
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # 设置解码器模型
    def set_decoder(self, decoder):
        self.model = decoder

    # 返回当前模型
    def get_decoder(self):
        return self.model

    # 忽略复制,该函数装饰了 forward 方法,添加了模型前向传播的文档字符串
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        pass  # 此处定义了模型的前向传播逻辑,具体实现在其它地方

    # 准备生成的输入,在生成阶段用于处理输入的方法
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        pass  # 此方法用于生成阶段准备输入,具体实现在其它地方

    # 静态方法:重新排序缓存中的过去键值,用于束搜索生成
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
# Gemma 模型用于序列分类任务,在预训练模型基础上增加了顶部的线性层用于分类。
# 这里的 [`GemmaForSequenceClassification`] 类似于其他因果模型(如 GPT-2),使用最后一个标记进行分类。
# 由于它在最后一个标记上执行分类,需要知道最后一个标记的位置。如果配置中定义了 `pad_token_id`,则会找到每行中最后一个非填充标记。如果没有定义 `pad_token_id`,则简单地取批处理中每行的最后一个值。
# 当传递 `inputs_embeds` 而不是 `input_ids` 时,由于无法猜测填充标记,它执行相同的操作(取批处理中每行的最后一个值)。

@add_start_docstrings(
    GEMMA_START_DOCSTRING,
)
# 从 transformers.models.llama.modeling_llama.LlamaForSequenceClassification 复制并将 LLAMA 改为 GEMMA,Llama 改为 Gemma
class GemmaForSequenceClassification(GemmaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = GemmaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
    # 前向传播函数,接受多种输入参数用于序列分类任务
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,

.\models\gemma\tokenization_gemma.py

# coding=utf-8
# 定义编码格式为 UTF-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# 版权声明:2024 年 HuggingFace 公司团队。保留所有权利。

# Licensed under the Apache License, Version 2.0 (the "License");
# 根据 Apache 许可证 2.0 版本授权使用此文件。

# you may not use this file except in compliance with the License.
# 您除非遵循许可证,否则不得使用此文件。

# You may obtain a copy of the License at
# 您可以在以下网址获取许可证副本:

#     http://www.apache.org/licenses/LICENSE-2.0
#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 除非适用法律要求或书面同意,本软件按“原样”分发,不附带任何明示或暗示的保证或条件。

# See the License for the specific language governing permissions and
# limitations under the License.
# 详细了解许可证以了解权限和限制。

"""Tokenization classes for Gemma."""
# 导入 Gemma 的 Tokenization 类

import os
# 导入操作系统相关的模块
from shutil import copyfile
# 从 shutil 模块导入 copyfile 函数
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
# 导入类型检查相关的模块,以及一些数据结构

import sentencepiece as spm
# 导入 sentencepiece 库

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
# 从 tokenization_utils 模块中导入 AddedToken 和 PreTrainedTokenizer
from ...utils import logging
# 从 utils 模块导入 logging

if TYPE_CHECKING:
    pass
# 如果是类型检查阶段,则不执行任何操作

logger = logging.get_logger(__name__)
# 获取当前模块的 logger 对象

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
# 定义词汇文件名的映射,vocab_file 对应的文件名是 tokenizer.model

SPIECE_UNDERLINE = "▁"
# 定义特殊字符 SPIECE_UNDERLINE 为 "▁",用于表示词汇中的连接符

class GemmaTokenizer(PreTrainedTokenizer):
    """
    Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
    no padding token in the original model.
    """
    # GemmaTokenizer 类,继承自 PreTrainedTokenizer 类

    def __init__(
        self,
        # 初始化方法,接受以下参数:

    ```
    # 定义函数的参数和默认值
    Args:
        vocab_file (`str`):
            词汇表文件的路径。
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            未知标记。词汇表中不存在的标记将被设置为此标记。
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
            序列的起始标记,用于预训练过程中。也可以用作序列分类器的标记。
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
            序列的结束标记。
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
            特殊标记,用于使标记数组在批处理时具有相同的大小。在注意力机制或损失计算中将被忽略。
        sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
            将传递给 `SentencePieceProcessor.__init__()` 方法的参数字典。
            可用于设置 SentencePiece 的参数,如启用子词正则化和采样参数等。
        add_bos_token (`bool`, *optional*, defaults to `True`):
            是否在序列的开头添加 `bos_token`。
        add_eos_token (`bool`, *optional*, defaults to `False`):
            是否在序列的末尾添加 `eos_token`。
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            是否在解码后清理空格,清理包括移除额外的空格等。
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            是否使用默认的系统提示。
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            是否在特殊标记之间添加空格。
    """

    # 定义词汇表文件名和模型输入名称列表
    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    # 初始化函数,用于创建一个新的 LlamaTokenizer 对象
    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<bos>",
        eos_token="<eos>",
        pad_token="<pad>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        spaces_between_special_tokens=False,
        **kwargs,
    ):
        # 如果 sp_model_kwargs 为 None,则设为空字典,否则使用传入的参数
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # 如果 bos_token 是字符串,则将其封装为一个 AddedToken 对象
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        # 如果 eos_token 是字符串,则将其封装为一个 AddedToken 对象
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        # 如果 unk_token 是字符串,则将其封装为一个 AddedToken 对象
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        # 如果 pad_token 是字符串,则将其封装为一个 AddedToken 对象
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token

        # 将传入的参数赋值给对象的属性
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt

        # 使用 SentencePieceProcessor 初始化 sp_model 对象,并加载给定的词汇文件
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        # 调用父类的初始化函数
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    # 复制自 transformers.models.llama.tokenization_llama.LlamaTokenizer.__getstate__
    def __getstate__(self):
        # 复制对象的当前状态
        state = self.__dict__.copy()
        # 将 sp_model 设置为 None,避免序列化时包含模型本身
        state["sp_model"] = None
        # 获取 sp_model 的序列化模型,并保存到状态中
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    # 复制自 transformers.models.llama.tokenization_llama.LlamaTokenizer.__setstate__
    def __setstate__(self, d):
        # 恢复对象的状态
        self.__dict__ = d
        # 使用 sp_model_kwargs 初始化 sp_model 对象,并从序列化的 proto 中加载模型
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

    @property
    # 复制自 transformers.models.llama.tokenization_llama.LlamaTokenizer.vocab_size
    def vocab_size(self):
        """Returns vocab size"""
        # 返回词汇表的大小,即 sp_model 中词汇的数量
        return self.sp_model.get_piece_size()

    # 复制自 transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
    def get_vocab(self):
        """Returns vocab as a dict"""
        # 创建一个词汇表字典,将词汇索引映射为词汇
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        # 将已添加的特殊 token 编码器合并到词汇表中
        vocab.update(self.added_tokens_encoder)
        return vocab
    # 返回经过 Gemma 分词器处理后的文本字符串,不添加前导空格
    def _tokenize(self, text, **kwargs):
        return self.sp_model.encode(text, out_type=str)

    # 从词汇表中将 token(字符串)转换为对应的 id,方法来自于 llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        return self.sp_model.piece_to_id(token)

    # 从词汇表中将 id(整数)转换为对应的 token(字符串),方法来自于 llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        token = self.sp_model.IdToPiece(index)
        return token

    # 将 token_ids(整数列表)解码为字符串,可以选择跳过特殊 token 和在特殊 token 之间添加空格
    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        sub_texts = []
        current_sub_text = []
        for ids in token_ids:
            if skip_special_tokens and ids in self.all_special_ids:
                continue
            if ids in self._added_tokens_decoder:
                if current_sub_text:
                    sub_texts.append(self.sp_model.decode(current_sub_text))
                sub_texts.append(self._added_tokens_decoder[ids].content)
                current_sub_text = []
            else:
                current_sub_text.append(ids)
        if current_sub_text:
            sub_texts.append(self.sp_model.decode(current_sub_text))

        if spaces_between_special_tokens:
            sub_texts = " ".join(sub_texts)
        else:
            sub_texts = "".join(sub_texts)

        return sub_texts

    # 将 token(字符串)序列转换为单个字符串
    def convert_tokens_to_string(self, tokens):
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            if token in self._added_tokens_encoder:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    # 保存词汇表的方法,来自 llama.tokenization_llama.LlamaTokenizer.save_vocabulary
    # 从给定的 `token_ids_0` 和 `token_ids_1` 构建带有特殊标记的输入序列
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # 如果需要,添加开头的 BOS (Beginning of Sentence) 标记
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        # 如果需要,添加结尾的 EOS (End of Sentence) 标记
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        # 构建包含特殊标记的输出序列
        output = bos_token_id + token_ids_0 + eos_token_id

        # 如果提供了 `token_ids_1`,再次构建包含特殊标记的输出序列
        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    # 从给定的 `token_ids_0` 和 `token_ids_1` 判断特殊标记的掩码
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ):
    # 从没有添加特殊标记的标记列表中获取序列ID。当使用分词器的 `prepare_for_model` 方法添加特殊标记时调用此方法。
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        # 如果已经包含特殊标记,则调用父类方法获取特殊标记的掩码
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # 根据是否添加起始(bos)和结束(eos)标记,初始化起始和结束标记ID列表
        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        # 如果没有第二个序列token_ids_1,则返回仅包含第一个序列的特殊标记掩码
        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        
        # 否则,返回包含两个序列的特殊标记掩码
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    # 从 `transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences` 复制过来的方法
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of token type IDs according to the given sequence(s).
        """
        # 根据是否添加起始(bos)和结束(eos)标记,初始化起始和结束标记ID列表
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        # 初始化输出为全0的列表,长度为起始 + 第一个序列 + 结束的总长度
        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        # 如果有第二个序列token_ids_1,则设置第二个序列部分的token type ID为1
        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output

.\models\gemma\tokenization_gemma_fast.py

# 导入必要的模块和函数
import os  # 导入操作系统模块
from shutil import copyfile  # 从 shutil 模块导入 copyfile 函数
from typing import Optional, Tuple  # 导入类型提示相关的类和函数

from tokenizers import processors  # 导入 tokenizers 库中的 processors 模块

# 导入相对路径下的模块和函数
from ...tokenization_utils_fast import PreTrainedTokenizerFast  # 导入预训练的快速分词器
from ...utils import is_sentencepiece_available, logging  # 导入工具函数和日志模块
from ...utils.versions import require_version  # 导入版本控制相关函数

# 确保 tokenizers 的版本符合要求
require_version("tokenizers>=0.13.3")

# 如果系统支持 sentencepiece,则导入 GemmaTokenizer
if is_sentencepiece_available():
    from .tokenization_gemma import GemmaTokenizer
else:
    GemmaTokenizer = None  # 否则将 GemmaTokenizer 设置为 None

# 获取日志记录器
logger = logging.get_logger(__name__)

# 定义词汇文件的名称
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}


class GemmaTokenizerFast(PreTrainedTokenizerFast):
    """
    构建一个快速的 Gemma 分词器,基于字节级别的 Byte-Pair-Encoding。

    这个分词器使用了 ByteFallback 和没有前缀空格。标准化应用于将 `" "` 替换为 `"▁"`

    ```
    >>> from transformers import GemmaTokenizerFast

    >>> tokenizer = GemmaTokenizerFast.from_pretrained("hf-internal-testing/dummy-gemma")
    >>> tokenizer.encode("Hello this is a test")
    [2, 4521, 736, 603, 476, 2121]
    ```

    如果您想要更改 `bos_token` 或 `eos_token`,请确保在初始化模型时指定它们,或者调用 `tokenizer.update_post_processor()` 
    确保后处理正确完成(否则编码序列的第一个令牌和最后一个令牌的值将不正确)。更多详情,请查看
    [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) 文档。

    这个分词器继承自 [`PreTrainedTokenizerFast`],其中包含大多数主要方法。用户应该参考这个超类获取更多关于这些方法的信息。
    """
    Args:
        vocab_file (`str`, *optional*):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        tokenizer_file (`str`, *optional*):
            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
            extra spaces.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The padding token
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an `eos_token` at the end of sequences.
    """
    # 定义用于加载和保存的文件名列表
    vocab_files_names = VOCAB_FILES_NAMES
    # 使用 GemmaTokenizer 作为慢速分词器的类
    slow_tokenizer_class = GemmaTokenizer
    # 填充位置设为左侧
    padding_side = "left"
    # 模型输入名称列表
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        clean_up_tokenization_spaces=False,
        unk_token="<unk>",
        bos_token="<bos>",
        eos_token="<eos>",
        pad_token="<pad>",
        add_bos_token=True,
        add_eos_token=False,
        **kwargs,
    ):
        # 调用父类的初始化方法,传递参数以配置分词器
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            **kwargs,
        )
        # 设置是否添加 bos_token 到序列开始的标志
        self._add_bos_token = add_bos_token
        # 设置是否添加 eos_token 到序列结尾的标志
        self._add_eos_token = add_eos_token
        # 更新分词器的后处理器
        self.update_post_processor()
        # 保存词汇文件的路径
        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        # 检查是否可以保存慢速分词器,需要有有效的词汇文件路径
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    # 从 transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor 复制
    # 更新后处理器,使用当前的 `bos_token` 和 `eos_token` 更新底层后处理器
    def update_post_processor(self):
        bos = self.bos_token  # 获取开始标记符
        bos_token_id = self.bos_token_id  # 获取开始标记符的 ID
        # 如果 `add_bos_token` 为 True 但 `bos_token` 为 None,抛出数值错误
        if bos is None and self.add_bos_token:
            raise ValueError("add_bos_token = True but bos_token = None")

        eos = self.eos_token  # 获取结束标记符
        eos_token_id = self.eos_token_id  # 获取结束标记符的 ID
        # 如果 `add_eos_token` 为 True 但 `eos_token` 为 None,抛出数值错误
        if eos is None and self.add_eos_token:
            raise ValueError("add_eos_token = True but eos_token = None")

        # 构建单句模板
        single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
        # 构建双句模板
        pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"

        special_tokens = []
        # 如果需要添加开始标记符,则将其添加到特殊标记列表中
        if self.add_bos_token:
            special_tokens.append((bos, bos_token_id))
        # 如果需要添加结束标记符,则将其添加到特殊标记列表中
        if self.add_eos_token:
            special_tokens.append((eos, eos_token_id))
        
        # 将后处理器设为模板处理器,使用构建好的模板和特殊标记列表
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=single, pair=pair, special_tokens=special_tokens
        )

    @property
    def add_eos_token(self):
        return self._add_eos_token  # 返回是否添加结束标记符的属性值

    @property
    def add_bos_token(self):
        return self._add_bos_token  # 返回是否添加开始标记符的属性值

    @add_eos_token.setter
    def add_eos_token(self, value):
        self._add_eos_token = value  # 设置是否添加结束标记符的属性值
        self.update_post_processor()  # 更新后处理器

    @add_bos_token.setter
    def add_bos_token(self, value):
        self._add_bos_token = value  # 设置是否添加开始标记符的属性值
        self.update_post_processor()  # 更新后处理器

    # 从 transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary 复制而来
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # 如果无法保存慢速分词器的词汇表,则抛出数值错误
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        # 如果保存路径不是目录,则记录错误信息并返回
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        
        # 设置输出的词汇表文件路径
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # 如果当前词汇表文件路径与目标路径不同,则复制词汇表文件到目标路径
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)
    # 构建带有特殊令牌的输入序列
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # 如果需要添加起始令牌,将起始令牌 ID 添加到列表中
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        # 如果需要添加结束令牌,将结束令牌 ID 添加到列表中
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        # 构建输出序列,连接起始令牌、token_ids_0、结束令牌
        output = bos_token_id + token_ids_0 + eos_token_id

        # 如果存在第二个输入序列 token_ids_1,进行相同的处理
        if token_ids_1 is not None:
            # 连接起始令牌、token_ids_1、结束令牌
            output = output + bos_token_id + token_ids_1 + eos_token_id

        # 返回构建好的输出序列
        return output

.\models\gemma\__init__.py

# 导入所需模块和函数,这里从不同的模块和子模块中导入特定的内容
from typing import TYPE_CHECKING  # 导入类型检查相关的功能

from ...utils import (
    OptionalDependencyNotAvailable,  # 导入自定义的异常类
    _LazyModule,  # 导入懒加载模块的支持
    is_flax_available,  # 检查是否有Flax库可用
    is_sentencepiece_available,  # 检查是否有SentencePiece库可用
    is_tokenizers_available,  # 检查是否有Tokenizers库可用
    is_torch_available,  # 检查是否有PyTorch库可用
)

# 定义一个字典,用于描述导入结构
_import_structure = {
    "configuration_gemma": ["GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "GemmaConfig"],  # Gemma模型配置相关内容
}

# 检查是否有SentencePiece库可用,若不可用则引发OptionalDependencyNotAvailable异常
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_gemma"] = ["GemmaTokenizer"]  # 导入Gemma模型的分词器

# 检查是否有Tokenizers库可用,若不可用则引发OptionalDependencyNotAvailable异常
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_gemma_fast"] = ["GemmaTokenizerFast"]  # 导入Gemma模型的快速分词器

# 检查是否有PyTorch库可用,若不可用则引发OptionalDependencyNotAvailable异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_gemma"] = [
        "GemmaForCausalLM",  # Gemma模型的因果语言模型
        "GemmaModel",  # Gemma模型基类
        "GemmaPreTrainedModel",  # Gemma模型的预训练模型基类
        "GemmaForSequenceClassification",  # Gemma模型的序列分类模型
    ]

# 检查是否有Flax库可用,若不可用则引发OptionalDependencyNotAvailable异常
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_gemma"] = [
        "FlaxGemmaForCausalLM",  # Flax版本的Gemma因果语言模型
        "FlaxGemmaModel",  # Flax版本的Gemma模型基类
        "FlaxGemmaPreTrainedModel",  # Flax版本的Gemma预训练模型基类
    ]

# 如果在类型检查模式下
if TYPE_CHECKING:
    from .configuration_gemma import GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP, GemmaConfig  # 导入Gemma模型的配置映射和配置类

    # 检查是否有SentencePiece库可用,若不可用则忽略导入
    try:
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_gemma import GemmaTokenizer  # 导入Gemma模型的分词器

    # 检查是否有Tokenizers库可用,若不可用则忽略导入
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_gemma_fast import GemmaTokenizerFast  # 导入Gemma模型的快速分词器

    # 检查是否有PyTorch库可用,若不可用则忽略导入
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_gemma import (
            GemmaForCausalLM,  # Gemma模型的因果语言模型
            GemmaForSequenceClassification,  # Gemma模型的序列分类模型
            GemmaModel,  # Gemma模型基类
            GemmaPreTrainedModel,  # Gemma模型的预训练模型基类
        )

    # 检查是否有Flax库可用,若不可用则忽略导入
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    # 如果发生 OptionalDependencyNotAvailable 异常,则什么也不做,直接 pass
    except OptionalDependencyNotAvailable:
        pass
    # 如果没有发生异常,则导入以下模块
    else:
        from .modeling_flax_gemma import (
            FlaxGemmaForCausalLM,    # 导入 FlaxGemmaForCausalLM 类
            FlaxGemmaModel,         # 导入 FlaxGemmaModel 类
            FlaxGemmaPreTrainedModel,  # 导入 FlaxGemmaPreTrainedModel 类
        )
else:
    # 如果不是以上情况,即模块不是以导入方式被调用
    import sys
    # 导入 sys 模块,用于操作 Python 解释器的系统功能

    # 将当前模块注册为懒加载模块的实例
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
    # 使用 _LazyModule 类创建一个新的模块对象,并将其注册到 sys.modules 中,以实现懒加载模块的特性

.\models\git\configuration_git.py

# coding=utf-8
# 上面的行声明了文件的编码格式为 UTF-8,确保文件中的中文和特殊字符能正确解析
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
# 版权声明,声明代码版权归 HuggingFace Inc. 团队所有,保留所有权利
#
# Licensed under the Apache License, Version 2.0 (the "License");
# 根据 Apache 许可证版本 2.0 进行许可,即除非符合许可证要求,否则不得使用此文件
# You may obtain a copy of the License at
# 你可以在上述链接获取许可证的副本
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 根据许可证分发的软件是基于 "AS IS" 基础分发的,没有任何形式的担保或条件
# See the License for the specific language governing permissions and
# limitations under the License.
# 查看许可证以了解详细的条款和条件
#
# Importing necessary modules
# 导入必要的模块
import os
# Importing Union type hint from typing module
# 从 typing 模块导入 Union 类型提示
from typing import Union
# Importing necessary modules from local package
# 从本地包中导入必要的模块
from ...configuration_utils import PretrainedConfig
from ...utils import logging
# Getting the logger object specific to the current module
# 获取与当前模块相关的日志记录器对象
logger = logging.get_logger(__name__)
# Mapping of pretrained model identifier to its configuration file URL
# 预训练模型标识符到配置文件 URL 的映射
GIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/git-base": "https://huggingface.co/microsoft/git-base/resolve/main/config.json",
}


class GitVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GitVisionModel`]. It is used to instantiate a GIT
    vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the vision encoder of the GIT
    [microsoft/git-base](https://huggingface.co/microsoft/git-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """
    # Configuration class for GitVisionModel
    # GitVisionModel 的配置类
    # This class inherits from PretrainedConfig
    # 该类继承自 PretrainedConfig
    # It defines configuration parameters for GitVisionModel
    # 定义 GitVisionModel 的配置参数
    # Read PretrainedConfig documentation for more details
    # 详细信息请参阅 PretrainedConfig 文档
    # 模型类型标识字符串,表示这是一个 Git Vision 模型
    model_type = "git_vision_model"
    
    # GitVisionConfig 类的构造函数,用于初始化模型配置参数
    def __init__(
        self,
        hidden_size=768,  # 编码器层和池化层的维度大小,默认为768
        intermediate_size=3072,  # Transformer 编码器中间层(即前馈层)的维度大小,默认为3072
        num_hidden_layers=12,  # Transformer 编码器中的隐藏层数,默认为12
        num_attention_heads=12,  # Transformer 编码器中每个注意力层的注意头数量,默认为12
        num_channels=3,  # 图像通道数,默认为3(RGB)
        image_size=224,  # 每个图像的分辨率大小,默认为224
        patch_size=16,  # 每个图像块(patch)的大小,默认为16
        hidden_act="quick_gelu",  # 编码器和池化器中的非线性激活函数,默认为"quick_gelu"
        layer_norm_eps=1e-5,  # 层归一化层使用的 epsilon 值,默认为1e-5
        attention_dropout=0.0,  # 注意力概率的 dropout 比率,默认为0.0(不进行 dropout)
        initializer_range=0.02,  # 初始化所有权重矩阵的截断正态分布的标准差,默认为0.02
        **kwargs,  # 其他可选关键字参数
    ):
        # 调用父类的构造函数,初始化其他可能存在的关键字参数
        super().__init__(**kwargs)
    
        # 设置实例变量,将传入的参数赋值给对象的对应属性
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        # 调用类方法 _set_token_in_kwargs,将 token 设置到 kwargs 中
        cls._set_token_in_kwargs(kwargs)

        # 调用类方法 get_config_dict,获取预训练模型的配置字典和更新后的 kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # 如果配置字典中的 model_type 是 "git",则从 vision_config 中获取配置字典
        if config_dict.get("model_type") == "git":
            config_dict = config_dict["vision_config"]

        # 如果配置字典中有 "model_type",并且类有 model_type 属性,并且它们不相等,发出警告
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        # 调用类方法 from_dict,使用配置字典和 kwargs 创建预训练配置对象并返回
        return cls.from_dict(config_dict, **kwargs)
class GitConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GitModel`]. It is used to instantiate a GIT model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the GIT
    [microsoft/git-base](https://huggingface.co/microsoft/git-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Examples:

    ```
    >>> from transformers import GitConfig, GitModel

    >>> # Initializing a GIT microsoft/git-base style configuration
    >>> configuration = GitConfig()

    >>> # Initializing a model (with random weights) from the microsoft/git-base style configuration
    >>> model = GitModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "git"

    def __init__(
        self,
        vision_config=None,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=6,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1024,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        tie_word_embeddings=False,
        bos_token_id=101,
        eos_token_id=102,
        num_image_with_embedding=None,
        **kwargs,
    ):
        # 调用父类构造函数,初始化基本配置,如起始、结束、填充 token ID 等
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)

        # 如果未提供 vision_config,则使用空字典,并记录日志
        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. initializing the GitVisionConfig with default values.")

        # 根据提供的 vision_config 创建 GitVisionConfig 实例
        self.vision_config = GitVisionConfig(**vision_config)
        # 设置模型的词汇表大小
        self.vocab_size = vocab_size
        # 设置隐藏层的大小
        self.hidden_size = hidden_size
        # 设置隐藏层的数量
        self.num_hidden_layers = num_hidden_layers
        # 设置注意力头的数量
        self.num_attention_heads = num_attention_heads
        # 设置隐藏层的激活函数
        self.hidden_act = hidden_act
        # 设置中间层的大小
        self.intermediate_size = intermediate_size
        # 设置隐藏层的 dropout 概率
        self.hidden_dropout_prob = hidden_dropout_prob
        # 设置注意力机制的 dropout 概率
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # 设置最大位置嵌入的长度
        self.max_position_embeddings = max_position_embeddings
        # 设置初始化范围
        self.initializer_range = initializer_range
        # 设置层归一化的 epsilon 值
        self.layer_norm_eps = layer_norm_eps
        # 设置位置嵌入类型
        self.position_embedding_type = position_embedding_type
        # 设置是否使用缓存
        self.use_cache = use_cache
        # 设置是否绑定词嵌入
        self.tie_word_embeddings = tie_word_embeddings
        # 设置具有嵌入的图像数量
        self.num_image_with_embedding = num_image_with_embedding

        # 设置起始 token ID
        self.bos_token_id = bos_token_id
        # 设置结束 token ID
        self.eos_token_id = eos_token_id

.\models\git\convert_git_to_pytorch.py

# 设置脚本的编码格式为 UTF-8
# 版权声明,声明代码归 HuggingFace Inc. 团队所有,遵循 Apache License 2.0
# 获取命令行参数解析器
import argparse
# 导入路径处理模块 Path
from pathlib import Path

# 导入 numpy 库,用于科学计算
import numpy as np
# 导入 requests 库,用于发送 HTTP 请求
import requests
# 导入 PyTorch 深度学习库
import torch
# 从 huggingface_hub 库中导入 hf_hub_download 函数
from huggingface_hub import hf_hub_download
# 导入 PIL 库中的 Image 模块,用于图像处理
from PIL import Image
# 从 torchvision.transforms 模块导入图像预处理函数
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

# 从 transformers 库中导入相关模块和类
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    GitConfig,
    GitForCausalLM,
    GitProcessor,
    GitVisionConfig,
    VideoMAEImageProcessor,
)
# 从 transformers.utils 模块中导入 logging 模块
from transformers.utils import logging

# 设置日志输出级别为 INFO
logging.set_verbosity_info()
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)


# 定义函数,根据模型名称获取 GitConfig 对象
def get_git_config(model_name):
    # 根据模型名称设置图像大小
    if "base" in model_name and "vqa" in model_name:
        image_size = 480
    elif "large" in model_name and "vqa" in model_name:
        image_size = 420
    else:
        image_size = 224

    # 创建 GitVisionConfig 对象,设置图像大小
    vision_config = GitVisionConfig(image_size=image_size)

    # 如果模型名称中包含 "large",则设置更大的模型参数
    if "large" in model_name:
        vision_config.patch_size = 14
        vision_config.hidden_size = 1024
        vision_config.intermediate_size = 4096
        vision_config.num_hidden_layers = 24
        vision_config.num_attention_heads = 16

    # 根据模型名称判断是否处理视频
    is_video = "vatex" in model_name or "msrvtt" in model_name
    # 如果处理视频,则设置 num_image_with_embedding 为 6,否则为 None
    num_image_with_embedding = 6 if is_video else None
    # 创建 GitConfig 对象,包含视觉配置和图像嵌入数量
    config = GitConfig(vision_config=vision_config.to_dict(), num_image_with_embedding=num_image_with_embedding)

    return config, image_size, is_video


# 定义函数,创建用于重命名的键列表
def create_rename_keys(config, prefix=""):
    rename_keys = []

    # 图像编码器部分的键重命名
    # ftm: off
    rename_keys.append(
        (f"{prefix}image_encoder.class_embedding", "git.image_encoder.vision_model.embeddings.class_embedding")
    )
    rename_keys.append(
        (
            f"{prefix}image_encoder.positional_embedding",
            "git.image_encoder.vision_model.embeddings.position_embedding.weight",
        )
    )
    rename_keys.append(
        (f"{prefix}image_encoder.conv1.weight", "git.image_encoder.vision_model.embeddings.patch_embedding.weight")
    )
    rename_keys.append((f"{prefix}image_encoder.ln_pre.weight", "git.image_encoder.vision_model.pre_layrnorm.weight"))
    rename_keys.append((f"{prefix}image_encoder.ln_pre.bias", "git.image_encoder.vision_model.pre_layrnorm.bias"))
    rename_keys.append(
        (f"{prefix}image_encoder.ln_post.weight", "git.image_encoder.vision_model.post_layernorm.weight")
    )
    rename_keys.append((f"{prefix}image_encoder.ln_post.bias", "git.image_encoder.vision_model.post_layernorm.bias"))
    # 将旧的键和新的键对添加到 rename_keys 列表中,用于重命名权重和偏置项

    # fmt: on
    rename_keys.append((f"{prefix}image_encoder.proj", "git.image_encoder.visual_projection.weight"))
    # 将旧的键和新的键对添加到 rename_keys 列表中,用于重命名视觉投影的权重

    # fmt: off
    for i in range(config.vision_config.num_hidden_layers):
        # 对于每一个视觉编码器的层,依次添加权重和偏置项的重命名对
        rename_keys.append((f"{prefix}image_encoder.transformer.resblocks.{i}.attn.out_proj.weight", f"git.image_encoder.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"))
        rename_keys.append((f"{prefix}image_encoder.transformer.resblocks.{i}.attn.out_proj.bias", f"git.image_encoder.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"))
        rename_keys.append((f"{prefix}image_encoder.transformer.resblocks.{i}.ln_1.weight", f"git.image_encoder.vision_model.encoder.layers.{i}.layer_norm1.weight"))
        rename_keys.append((f"{prefix}image_encoder.transformer.resblocks.{i}.ln_1.bias", f"git.image_encoder.vision_model.encoder.layers.{i}.layer_norm1.bias"))
        rename_keys.append((f"{prefix}image_encoder.transformer.resblocks.{i}.mlp.c_fc.weight", f"git.image_encoder.vision_model.encoder.layers.{i}.mlp.fc1.weight"))
        rename_keys.append((f"{prefix}image_encoder.transformer.resblocks.{i}.mlp.c_fc.bias", f"git.image_encoder.vision_model.encoder.layers.{i}.mlp.fc1.bias"))
        rename_keys.append((f"{prefix}image_encoder.transformer.resblocks.{i}.mlp.c_proj.weight", f"git.image_encoder.vision_model.encoder.layers.{i}.mlp.fc2.weight"))
        rename_keys.append((f"{prefix}image_encoder.transformer.resblocks.{i}.mlp.c_proj.bias", f"git.image_encoder.vision_model.encoder.layers.{i}.mlp.fc2.bias"))
        rename_keys.append((f"{prefix}image_encoder.transformer.resblocks.{i}.ln_2.weight", f"git.image_encoder.vision_model.encoder.layers.{i}.layer_norm2.weight"))
        rename_keys.append((f"{prefix}image_encoder.transformer.resblocks.{i}.ln_2.bias", f"git.image_encoder.vision_model.encoder.layers.{i}.layer_norm2.bias"))
    # fmt: on

    # text decoder
    # fmt: off
    rename_keys.append((f"{prefix}textual.embedding.words.weight", "git.embeddings.word_embeddings.weight"))
    rename_keys.append((f"{prefix}textual.embedding.positions.weight", "git.embeddings.position_embeddings.weight"))
    rename_keys.append((f"{prefix}textual.visual_projection.0.weight", "git.visual_projection.visual_projection.0.weight"))
    rename_keys.append((f"{prefix}textual.visual_projection.0.bias", "git.visual_projection.visual_projection.0.bias"))
    rename_keys.append((f"{prefix}textual.visual_projection.1.weight", "git.visual_projection.visual_projection.1.weight"))
    rename_keys.append((f"{prefix}textual.visual_projection.1.bias", "git.visual_projection.visual_projection.1.bias"))
    # 将文本解码器相关的旧的键和新的键对添加到 rename_keys 列表中,用于重命名文本嵌入和视觉投影的权重和偏置项
    # 将需要重命名的键值对添加到 rename_keys 列表中
    rename_keys.append((f"{prefix}textual.embedding.layer_norm.weight", "git.embeddings.LayerNorm.weight"))
    rename_keys.append((f"{prefix}textual.embedding.layer_norm.bias", "git.embeddings.LayerNorm.bias"))
    rename_keys.append((f"{prefix}textual.output.weight", "output.weight"))
    rename_keys.append((f"{prefix}textual.output.bias", "output.bias"))
    
    # 遍历配置中指定的隐藏层数量,生成对应的键值对并添加到 rename_keys 中
    for i in range(config.num_hidden_layers):
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.attention.self.query.weight", f"git.encoder.layer.{i}.attention.self.query.weight"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.attention.self.query.bias", f"git.encoder.layer.{i}.attention.self.query.bias"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.attention.self.key.weight", f"git.encoder.layer.{i}.attention.self.key.weight"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.attention.self.key.bias", f"git.encoder.layer.{i}.attention.self.key.bias"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.attention.self.value.weight", f"git.encoder.layer.{i}.attention.self.value.weight"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.attention.self.value.bias", f"git.encoder.layer.{i}.attention.self.value.bias"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.attention.output.dense.weight", f"git.encoder.layer.{i}.attention.output.dense.weight"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.attention.output.dense.bias", f"git.encoder.layer.{i}.attention.output.dense.bias"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.attention.output.LayerNorm.weight", f"git.encoder.layer.{i}.attention.output.LayerNorm.weight"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.attention.output.LayerNorm.bias", f"git.encoder.layer.{i}.attention.output.LayerNorm.bias"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.intermediate.dense.weight", f"git.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.intermediate.dense.bias", f"git.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.output.dense.weight", f"git.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.output.dense.bias", f"git.encoder.layer.{i}.output.dense.bias"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.output.LayerNorm.weight", f"git.encoder.layer.{i}.output.LayerNorm.weight"))
        rename_keys.append((f"{prefix}textual.transformer.encoder.layer.{i}.output.LayerNorm.bias", f"git.encoder.layer.{i}.output.LayerNorm.bias"))
    # fmt: on
    # 如果配置中指定了嵌入图像的数量,则执行以下操作
    if config.num_image_with_embedding is not None:
        # 将以下键值对添加到重命名键列表中,用于重命名图像临时嵌入的索引
        rename_keys.append(("img_temperal_embedding.0", "git.img_temperal_embedding.0"))
        rename_keys.append(("img_temperal_embedding.1", "git.img_temperal_embedding.1"))
        rename_keys.append(("img_temperal_embedding.2", "git.img_temperal_embedding.2"))
        rename_keys.append(("img_temperal_embedding.3", "git.img_temperal_embedding.3"))
        rename_keys.append(("img_temperal_embedding.4", "git.img_temperal_embedding.4"))
        rename_keys.append(("img_temperal_embedding.5", "git.img_temperal_embedding.5"))

    # 返回更新后的重命名键列表
    return rename_keys
# 从字典中移除旧键,将其对应的值保存到变量val中
def rename_key(dct, old, new):
    val = dct.pop(old)
    # 如果新键中包含特定字符串,则对值进行转置操作
    dct[new] = val.T if "image_encoder.visual_projection" in new else val


# 从状态字典中读取查询、键和值,并添加到指定位置的新键名下
def read_in_q_k_v(state_dict, config, prefix=""):
    # 获取隐藏层的大小
    dim = config.vision_config.hidden_size
    for i in range(config.vision_config.num_hidden_layers):
        # 读取注意力机制中的输入投影层的权重和偏置
        in_proj_weight = state_dict.pop(f"{prefix}image_encoder.transformer.resblocks.{i}.attn.in_proj_weight")
        in_proj_bias = state_dict.pop(f"{prefix}image_encoder.transformer.resblocks.{i}.attn.in_proj_bias")
        # 将查询、键和值的投影加入到状态字典中
        state_dict[f"git.image_encoder.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:dim, :]
        state_dict[f"git.image_encoder.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:dim]
        state_dict[f"git.image_encoder.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[dim:dim*2, :]
        state_dict[f"git.image_encoder.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[dim:dim*2]
        state_dict[f"git.image_encoder.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-dim:, :]
        state_dict[f"git.image_encoder.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-dim:]


# 根据模型名称准备图像数据
def prepare_img(model_name):
    if "textvqa" in model_name:
        # 如果模型名称包含"textvqa",则下载并打开示例图像文件
        filepath = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
        image = Image.open(filepath).convert("RGB")
    else:
        # 否则,从指定的 URL 下载图像文件
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)

    return image


# 准备视频数据,使用decord库进行视频处理
def prepare_video():
    from decord import VideoReader, cpu

    # 设置随机数种子以保证可重现性
    np.random.seed(0)

    def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        """
        Sample a given number of frame indices from the video.

        Args:
            clip_len (`int`): Total number of frames to sample.
            frame_sample_rate (`int`): Sample every n-th frame.
            seg_len (`int`): Maximum allowed index of sample's last frame.

        Returns:
            indices (`List[int]`): List of sampled frame indices
        """
        # 计算需要采样的帧的数量
        converted_len = int(clip_len * frame_sample_rate)
        # 在视频长度内随机选择结束帧索引
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
        # 生成均匀分布的帧索引列表,并限制在视频长度范围内
        indices = np.linspace(start_idx, end_idx, num=clip_len)
        indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        return indices
    # 从指定的 HF Hub 仓库下载视频数据集中的特定文件,此处下载的文件是 "eating_spaghetti.mp4"
    file_path = hf_hub_download(repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset")
    
    # 使用 VideoReader 类读取视频文件,设置线程数为 1,在 CPU 0 上执行
    videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
    
    # 将视频读取器定位到视频的起始位置
    videoreader.seek(0)
    
    # 通过 sample_frame_indices 函数从视频中随机抽取 6 个帧的索引
    # clip_len=6 表示要抽取 6 个帧
    # frame_sample_rate=4 表示每隔 4 个帧抽取一次
    # seg_len=len(videoreader) 获取视频的总帧数,作为抽取帧的范围
    indices = sample_frame_indices(clip_len=6, frame_sample_rate=4, seg_len=len(videoreader))
    
    # 从 videoreader 中获取指定 indices 的帧数据,返回一个 numpy 数组
    video = videoreader.get_batch(indices).asnumpy()
    
    # 返回抽取的视频帧数据
    return video
# 声明一个装饰器,用于指示在函数执行过程中不需要计算梯度
@torch.no_grad()
def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our GIT structure.
    """

    # 定义不同模型名称对应的预训练模型下载链接
    model_name_to_url = {
        "git-base": "https://publicgit.blob.core.windows.net/data/output/GIT_BASE/snapshot/model.pt",
        "git-base-coco": "https://publicgit.blob.core.windows.net/data/output/GIT_BASE_COCO/snapshot/model.pt",
        "git-base-textcaps": "https://publicgit.blob.core.windows.net/data/output/GIT_BASE_TEXTCAPS/snapshot/model.pt",
        "git-base-vqav2": "https://publicgit.blob.core.windows.net/data/output/GIT_BASE_VQAv2/snapshot/model.pt",
        "git-base-textvqa": "https://publicgit.blob.core.windows.net/data/output/GIT_BASE_TEXTVQA/snapshot/model.pt",  # todo
        "git-base-vatex": "https://publicgit.blob.core.windows.net/data/output/GIT_BASE_VATEX/snapshot/model.pt",
        "git-base-msrvtt-qa": (
            "https://publicgit.blob.core.windows.net/data/output/GIT_BASE_MSRVTT_QA/snapshot/model.pt"
        ),
        "git-large": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE/snapshot/model.pt",
        "git-large-coco": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_COCO/snapshot/model.pt",
        "git-large-textcaps": (
            "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_TEXTCAPS/snapshot/model.pt"
        ),
        "git-large-vqav2": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_VQAv2/snapshot/model.pt",
        "git-large-textvqa": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_TEXTVQA/snapshot/model.pt",
        "git-large-vatex": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_VATEX/snapshot/model.pt",
        "git-large-msrvtt-qa": (
            "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_MSRVTT_QA/snapshot/model.pt"
        ),
        "git-large-r": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R/snapshot/model.pt",
        "git-large-r-coco": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R_COCO/snapshot/model.pt",
        "git-large-r-textcaps": (
            "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R_TEXTCAPS/snapshot/model.pt"
        ),
    }

    # 定义不同模型名称对应的本地路径
    model_name_to_path = {
        "git-large": "/Users/nielsrogge/Documents/GIT/git_large_model.pt",
        "git-large-coco": "/Users/nielsrogge/Documents/GIT/git_large_coco_model.pt",
        "git-large-textcaps": "/Users/nielsrogge/Documents/GIT/git_large_textcaps_model.pt",
        "git-large-vqav2": "/Users/nielsrogge/Documents/GIT/git_large_vqav2_model.pt",
        "git-large-textvqa": "/Users/nielsrogge/Documents/GIT/git_large_textvqa_model.pt",
    }

    # 根据模型名称获取相应的 GIT 配置,图像尺寸和是否为视频
    config, image_size, is_video = get_git_config(model_name)
    # 检查模型名称中是否包含"large",且不是视频模型,且不是"large-r"模型
    if "large" in model_name and not is_video and "large-r" not in model_name:
        # 如果是大模型,从本地加载预训练权重
        checkpoint_path = model_name_to_path[model_name]
        state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    else:
        # 否则,从指定的 URL 加载预训练权重
        checkpoint_url = model_name_to_url[model_name]
        state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", file_name=model_name)["model"]

    # 根据模型名称确定键名前缀是否为"module."
    prefix = "module." if model_name == "git-base" else ""
    # 创建重命名键名的映射列表
    rename_keys = create_rename_keys(config, prefix=prefix)
    # 对预训练权重中的键名进行重命名
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    # 读取输入、查询和值的权重
    read_in_q_k_v(state_dict, config, prefix=prefix)

    # 加载 HuggingFace 模型
    model = GitForCausalLM(config)
    # 加载模型权重,允许缺少键名和不期待的键名
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    model.eval()

    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)

    # 断言确实缺少的键名和意外的键名
    assert missing_keys == ["git.embeddings.position_ids", "git.image_encoder.vision_model.embeddings.position_ids"]
    assert unexpected_keys == ["git.image_encoder.visual_projection.weight"]

    # 验证处理结果
    # 根据是否为视频选择不同的图像处理器
    image_processor = (
        VideoMAEImageProcessor(
            size={"shortest_edge": image_size}, crop_size={"height": image_size, "width": image_size}
        )
        if is_video
        else CLIPImageProcessor(
            size={"shortest_edge": image_size}, crop_size={"height": image_size, "width": image_size}
        )
    )
    # 根据模型类型选择适当的 tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        "google-bert/bert-base-uncased", model_input_names=["input_ids", "attention_mask"]
    )
    # 创建 GitProcessor 对象,用于处理文本和图像输入
    processor = GitProcessor(tokenizer=tokenizer, image_processor=image_processor)

    if is_video:
        # 准备视频并处理像素值
        video = prepare_video()
        pixel_values = processor(images=list(video), return_tensors="pt").pixel_values
    else:
        # 准备图像并进行图像转换
        image = prepare_img(model_name)
        image_transforms = Compose(
            [
                Resize(image_size, interpolation=Image.BICUBIC),
                CenterCrop(image_size),
                ToTensor(),
                Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        # 对原始图像应用转换并获取像素值张量
        original_pixel_values = image_transforms(image).unsqueeze(0)
        pixel_values = processor(images=image, return_tensors="pt").pixel_values

        # 断言处理后的像素值与原始像素值接近
        assert torch.allclose(pixel_values, original_pixel_values)

    # 创建输入张量
    input_ids = torch.tensor([[101]])
    # 使用模型生成输出
    outputs = model(input_ids, pixel_values=pixel_values)
    logits = outputs.logits
    print("Logits:", logits[0, -1, :3])

    # 根据模型名称选择预期的切片 logits
    if model_name == "git-base":
        expected_slice_logits = torch.tensor([-1.2832, -1.2835, -1.2840])
    elif model_name == "git-base-coco":
        expected_slice_logits = torch.tensor([-0.9925, -0.9930, -0.9935])
    # 如果模型名称为 "git-base-textcaps",设置预期的输出 logits
    elif model_name == "git-base-textcaps":
        expected_slice_logits = torch.tensor([-1.2980, -1.2983, -1.2985])
    # 如果模型名称为 "git-base-vqav2",设置预期的输出 logits
    elif model_name == "git-base-vqav2":
        expected_slice_logits = torch.tensor([-0.8570, -0.8568, -0.8561])
    # 如果模型名称为 "git-base-textvqa",设置预期的输出 logits
    elif model_name == "git-base-textvqa":
        expected_slice_logits = torch.tensor([-1.4085, -1.4083, -1.4082])
    # 如果模型名称为 "git-base-vatex",设置预期的输出 logits
    elif model_name == "git-base-vatex":
        expected_slice_logits = torch.tensor([-1.3451, -1.3447, -1.3447])
    # 如果模型名称为 "git-base-msrvtt-qa",设置预期的输出 logits
    elif model_name == "git-base-msrvtt-qa":
        expected_slice_logits = torch.tensor([-0.8554, -0.8550, -0.8540])
    # 如果模型名称为 "git-large",设置预期的输出 logits
    elif model_name == "git-large":
        expected_slice_logits = torch.tensor([-1.1708, -1.1707, -1.1705])
    # 如果模型名称为 "git-large-coco",设置预期的输出 logits
    elif model_name == "git-large-coco":
        expected_slice_logits = torch.tensor([-1.0425, -1.0423, -1.0422])
    # 如果模型名称为 "git-large-textcaps",设置预期的输出 logits
    elif model_name == "git-large-textcaps":
        expected_slice_logits = torch.tensor([-1.2705, -1.2708, -1.2706])
    # 如果模型名称为 "git-large-vqav2",设置预期的输出 logits
    elif model_name == "git-large-vqav2":
        expected_slice_logits = torch.tensor([-0.7042, -0.7043, -0.7043])
    # 如果模型名称为 "git-large-textvqa",设置预期的输出 logits
    elif model_name == "git-large-textvqa":
        expected_slice_logits = torch.tensor([-0.8590, -0.8592, -0.8590])
    # 如果模型名称为 "git-large-vatex",设置预期的输出 logits
    elif model_name == "git-large-vatex":
        expected_slice_logits = torch.tensor([-1.0113, -1.0114, -1.0113])
    # 如果模型名称为 "git-large-msrvtt-qa",设置预期的输出 logits
    elif model_name == "git-large-msrvtt-qa":
        expected_slice_logits = torch.tensor([0.0130, 0.0134, 0.0131])
    # 如果模型名称为 "git-large-r",设置预期的输出 logits
    elif model_name == "git-large-r":
        expected_slice_logits = torch.tensor([-1.1283, -1.1285, -1.1286])
    # 如果模型名称为 "git-large-r-coco",设置预期的输出 logits
    elif model_name == "git-large-r-coco":
        expected_slice_logits = torch.tensor([-0.9641, -0.9641, -0.9641])
    # 如果模型名称为 "git-large-r-textcaps",设置预期的输出 logits
    elif model_name == "git-large-r-textcaps":
        expected_slice_logits = torch.tensor([-1.1121, -1.1120, -1.1124])

    # 断言检查模型输出 logits 的正确性
    assert torch.allclose(logits[0, -1, :3], expected_slice_logits, atol=1e-4)
    # 输出提示信息
    print("Looks ok!")

    # 根据模型名称设置不同的提示语句
    prompt = ""
    if "textvqa" in model_name:
        prompt = "what does the front of the bus say at the top?"
    elif "msrvtt-qa" in model_name:
        prompt = "what does the woman eat?"
    elif "vqa" in model_name:
        prompt = "what are the cats doing?"

    # 使用分词器处理提示语句,生成输入的 token IDs
    input_ids = tokenizer(prompt, add_special_tokens=False).input_ids
    # 在 token IDs 前添加特殊 token 的 ID
    input_ids = [processor.tokenizer.cls_token_id] + input_ids
    # 将输入 token IDs 转换成张量并增加一个维度
    input_ids = torch.tensor(input_ids).unsqueeze(0)
    # 输出生成标题的提示信息
    print("Generating caption...")
    # 使用模型生成标题
    generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    # 打印生成的标题,跳过特殊 token 的解码
    print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))

    # 如果指定了 PyTorch 模型保存路径
    if pytorch_dump_folder_path is not None:
        # 确保路径存在或创建路径
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        # 输出保存模型和处理器的信息
        print(f"Saving model and processor of {model_name} to {pytorch_dump_folder_path}")
        # 将模型保存到指定路径
        model.save_pretrained(pytorch_dump_folder_path)
        # 将处理器保存到指定路径
        processor.save_pretrained(pytorch_dump_folder_path)
    # 如果 push_to_hub 为 True,则执行以下操作
    if push_to_hub:
        # 打印推送模型和处理器到 hub 的信息,包括模型名称
        print(f"Pushing model and processor of {model_name} to the hub...")
        # 调用 model 对象的 push_to_hub 方法,将模型推送到 Microsoft 的 hub 中
        model.push_to_hub(f"microsoft/{model_name}")
        # 调用 processor 对象的 push_to_hub 方法,将处理器推送到 Microsoft 的 hub 中
        processor.push_to_hub(f"microsoft/{model_name}")
if __name__ == "__main__":
    # 如果作为主程序运行,执行以下代码块

    parser = argparse.ArgumentParser()
    # 创建参数解析器对象

    # Required parameters
    parser.add_argument(
        "--model_name",
        default="git-base",
        type=str,
        help="Name of the model you'd like to convert.",
    )
    # 添加一个必需的参数:模型的名称,如果未提供则默认为"git-base",类型为字符串

    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    # 添加一个参数:PyTorch 模型输出目录的路径,如果未提供则为None,默认类型为字符串

    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the model to the hub.",
    )
    # 添加一个参数:是否将模型推送到 hub 上,采用布尔标志方式

    args = parser.parse_args()
    # 解析命令行参数并将其存储在 args 对象中

    convert_git_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
    # 调用函数 convert_git_checkpoint,传入解析后的参数:模型名称、输出目录路径、是否推送到 hub

.\models\git\modeling_git.py

# coding=utf-8
# 定义文件编码格式为 UTF-8

# 版权声明和许可证信息,声明版权归 Microsoft Research 和 HuggingFace Inc. 团队所有
# 受 Apache 许可证第 2.0 版的限制,除非遵守许可证,否则不得使用此文件
# 可在 http://www.apache.org/licenses/LICENSE-2.0 获取许可证副本

# 引入 math 模块,用于数学运算
import math
# 引入 dataclass 用于创建数据类,引入 List、Optional、Tuple 和 Union 用于类型注解
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

# 引入 PyTorch 深度学习框架
import torch
# 引入 PyTorch 中的检查点功能
import torch.utils.checkpoint
# 从 torch 模块中导入 nn 模块,用于神经网络构建
from torch import nn
# 从 nn 模块导入交叉熵损失函数
from torch.nn import CrossEntropyLoss

# 引入相对路径下的模块和函数
from ...activations import ACT2FN
from ...file_utils import ModelOutput
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    BaseModelOutputWithPooling,
    CausalLMOutputWithPast,
)
# 从 modeling_utils 模块导入预训练模型的基类
from ...modeling_utils import PreTrainedModel
# 从 pytorch_utils 模块导入一些辅助函数
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
# 从 utils 模块导入添加文档字符串、模型前向方法的文档字符串、日志记录和替换返回文档字符串的函数
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
# 从当前目录的 configuration_git 模块中导入 GitConfig 和 GitVisionConfig 类
from .configuration_git import GitConfig, GitVisionConfig

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 用于文档的检查点名称
_CHECKPOINT_FOR_DOC = "microsoft/git-base"
# 用于文档的配置名称
_CONFIG_FOR_DOC = "GitConfig"

# 预训练的 Git 模型存档列表
GIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/git-base",
    # 可以在 https://huggingface.co/models?filter=git 查看所有 GIT 模型
]

# 数据类,用于描述 Git 视觉模型的输出
@dataclass
# 继承自 ModelOutput 类
# 与 CLIP 模型中的 CLIPVisionModelOutput 类似,但适用于 Git 模型
class GitVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
    """
    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 可选参数:图像嵌入向量,形状为 `(batch_size, output_dim)`,在模型初始化时如果使用 `with_projection=True` 会返回
    image_embeds: Optional[torch.FloatTensor] = None

    # 必需参数:最后一层模型的隐藏状态,形状为 `(batch_size, sequence_length, hidden_size)`
    last_hidden_state: torch.FloatTensor = None

    # 可选参数:模型隐藏状态的元组,包含每层的输出,如果设置了 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None

    # 可选参数:注意力权重的元组,包含每层的注意力权重,如果设置了 `output_attentions=True` 或 `config.output_attentions=True` 时返回
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class GitEmbeddings(nn.Module):
    """构建从词嵌入和位置嵌入到最终嵌入的模块。"""

    def __init__(self, config):
        super().__init__()
        # 初始化词嵌入层,根据配置参数指定词汇量大小、隐藏层大小,并设置填充索引
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # 初始化位置嵌入层,根据配置参数指定最大位置嵌入数量和隐藏层大小
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # 使用非蛇形命名以保持与 TensorFlow 模型变量名的一致性,并能够加载任何 TensorFlow 检查点文件
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 初始化 Dropout 层,根据配置参数指定隐藏层的丢弃概率
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 设置位置嵌入类型,默认为绝对位置嵌入
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # 注册位置 id 缓冲区,是一个 1x最大位置嵌入数量的张量,用于序列化时持久化存储
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        # 如果未提供位置 ids,则使用预注册的位置 ids,并根据序列长度截取所需部分
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # 如果未提供输入嵌入向量,则根据输入 ids 获取词嵌入
        if inputs_embeds is None:
            embeddings = self.word_embeddings(input_ids)
        else:
            embeddings = inputs_embeds

        # 如果位置嵌入类型为绝对,则添加位置嵌入到词嵌入中
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

        # 对嵌入向量进行 LayerNorm 处理
        embeddings = self.LayerNorm(embeddings)
        # 对嵌入向量进行 Dropout 处理
        embeddings = self.dropout(embeddings)
        return embeddings


class GitSelfAttention(nn.Module):
    # 在这里开始编写 GitSelfAttention 类的注释
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # 检查隐藏大小是否可以被注意力头的数量整除,同时没有嵌入大小属性
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        # 设置注意力头的数量和每个头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 计算图像块标记的数量
        self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
        if config.num_image_with_embedding is not None:
            self.image_patch_tokens *= config.num_image_with_embedding

        # 定义查询、键、值的线性层
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # 定义dropout层
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # 确定位置嵌入的类型,默认为绝对位置嵌入
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        # 如果位置嵌入类型是相对键或相对键查询,则设置最大位置嵌入数和距离嵌入层
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # 将张量重新形状以适应注意力头的排列
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class GitSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建一个线性层,输入和输出大小都为config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 创建一个LayerNorm层,对隐藏状态进行归一化,设置eps为config.layer_norm_eps
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 创建一个Dropout层,使用config.hidden_dropout_prob概率进行随机丢弃
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 将输入的隐藏状态通过线性层dense进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对线性变换后的结果进行随机丢弃
        hidden_states = self.dropout(hidden_states)
        # 将丢弃后的结果与输入张量input_tensor相加,并通过LayerNorm层进行归一化处理
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # 返回处理后的隐藏状态张量
        return hidden_states


class GitAttention(nn.Module):
    # Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # 创建GitSelfAttention对象,传入config和position_embedding_type参数
        self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type)
        # 创建GitSelfOutput对象,传入config参数
        self.output = GitSelfOutput(config)
        # 初始化一个空集合,用于存储被修剪的注意力头
        self.pruned_heads = set()

    # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
    def prune_heads(self, heads):
        # 如果heads列表为空,则直接返回
        if len(heads) == 0:
            return
        # 调用find_pruneable_heads_and_indices函数,找到可以修剪的头部及其索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # 对self中的query、key、value和output.dense属性进行修剪线性层操作
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并存储修剪后的头部
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # 调用self.self的forward方法,进行自注意力机制计算
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            past_key_value,
            output_attentions,
            pixel_values_present,
        )
        # 将self_outputs的第一个元素作为输入,通过self.output进行输出层的处理
        attention_output = self.output(self_outputs[0], hidden_states)
        # 如果需要输出注意力权重,则将其加入到输出元组中
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        # 返回输出元组
        return outputs


# Copied from transformers.models.bert.modeling_bert.BertIntermediate
class GitIntermediate(nn.Module):
    # 初始化函数,用于创建一个新的神经网络层对象
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 创建一个全连接层,将输入维度为 config.hidden_size 的向量映射到 config.intermediate_size 的向量
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # 根据配置文件中的激活函数类型,选择合适的激活函数
        if isinstance(config.hidden_act, str):
            # 如果配置的激活函数是字符串类型,则从预定义的字典 ACT2FN 中获取对应的函数
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            # 否则,直接使用配置中指定的激活函数
            self.intermediate_act_fn = config.hidden_act
    
    # 前向传播函数,接收一个张量 hidden_states 作为输入,返回一个张量作为输出
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将输入张量通过全连接层 self.dense 进行线性变换
        hidden_states = self.dense(hidden_states)
        # 将线性变换后的张量通过选择的激活函数 self.intermediate_act_fn 进行非线性变换
        hidden_states = self.intermediate_act_fn(hidden_states)
        # 返回经过激活函数处理后的张量作为输出
        return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertOutput
# 从 transformers 库中的 BertOutput 类复制而来,用于定义 Git 模型的输出层
class GitOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 定义一个全连接层,将中间隐藏层的大小映射到最终隐藏层的大小
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # LayerNorm 层,对隐藏状态进行归一化处理
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout 层,以一定概率随机将输入置零,防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 全连接层映射操作,将中间隐藏状态映射到最终隐藏状态
        hidden_states = self.dense(hidden_states)
        # Dropout 操作,对全连接层输出进行随机置零
        hidden_states = self.dropout(hidden_states)
        # LayerNorm 操作,对加上输入张量后的隐藏状态进行归一化处理
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # 返回处理后的隐藏状态作为输出
        return hidden_states


class GitLayer(nn.Module):
    # 从 transformers.models.bert.modeling_bert.BertLayer 复制而来,用于定义 Git 模型的层
    def __init__(self, config):
        super().__init__()
        # 定义一个块大小用于前馈传播的参数
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # 序列长度维度,默认为1
        self.seq_len_dim = 1
        # Git 模型中的注意力机制层
        self.attention = GitAttention(config)
        # Git 模型中的中间层
        self.intermediate = GitIntermediate(config)
        # Git 模型中的输出层
        self.output = GitOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # 如果过去键/值存在,则从中提取自注意力的缓存键/值元组
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # 执行自注意力机制层的前向传播,输出包含注意力输出和其他可能的输出
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
            pixel_values_present=pixel_values_present,
        )
        # 提取自注意力层的注意力输出
        attention_output = self_attention_outputs[0]

        # 如果是解码器,最后的输出是自注意力缓存元组
        outputs = self_attention_outputs[1:-1]
        # 提取当前键/值
        present_key_value = self_attention_outputs[-1]

        # 将前馈传播函数应用于注意力输出
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        # 将层输出添加到输出元组中
        outputs = (layer_output,) + outputs

        # 如果是解码器,将注意力键/值作为最后的输出添加到元组中
        outputs = outputs + (present_key_value,)

        # 返回所有输出
        return outputs

    def feed_forward_chunk(self, attention_output):
        # 中间层的前馈传播函数,先通过中间层处理注意力输出,然后通过输出层得到层输出
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        # 返回层输出
        return layer_output


class GitEncoder(nn.Module):
    # 从 transformers.models.bert.modeling_bert.BertEncoder.__init__ 复制而来,用于定义 Git 模型的编码器层
    # 暂时没有提供具体实现
    # 初始化方法,用于创建一个新的实例
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 将传入的配置参数保存到实例变量中
        self.config = config
        # 创建一个包含多个 GitLayer 实例的模块列表,数量由配置参数决定
        self.layer = nn.ModuleList([GitLayer(config) for _ in range(config.num_hidden_layers)])
        # 梯度检查点标志,默认为 False
        self.gradient_checkpointing = False

    # 前向传播方法,定义了模型的正向计算过程
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        # 如果启用了梯度检查点并且处于训练模式下
        if self.gradient_checkpointing and self.training:
            # 如果 use_cache 参数为 True,则发出警告并强制将其设为 False
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # 初始化用于存储所有隐藏状态和自注意力权重的变量
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # 如果不使用缓存,初始化存储下一步解码器缓存的变量
        next_decoder_cache = () if use_cache else None

        # 遍历所有的 Transformer 层
        for i, layer_module in enumerate(self.layer):
            # 如果输出隐藏状态,则将当前层的隐藏状态添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 获取当前层的头部掩码(如果有)
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # 获取过去的键值对(如果有)
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # 如果启用了梯度检查点并且处于训练模式下,调用梯度检查点函数
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                # 否则直接调用当前层的 forward 方法
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    past_key_value,
                    output_attentions,
                    pixel_values_present,
                )

            # 更新当前隐藏状态为当前层输出的隐藏状态
            hidden_states = layer_outputs[0]

            # 如果使用缓存,将当前层的缓存信息添加到 next_decoder_cache 中
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)

            # 如果输出自注意力权重,则将当前层的自注意力权重添加到 all_self_attentions 中
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # 如果输出隐藏状态,则将最终的隐藏状态添加到 all_hidden_states 中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不返回字典形式的结果,将各个部分组合成元组返回
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )
        # 否则返回一个包含多个属性的对象,表示模型的输出
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
# GitPreTrainedModel 类继承自 PreTrainedModel 类,用于处理权重初始化和预训练模型的下载和加载接口。
class GitPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 配置类,指定为 GitConfig
    config_class = GitConfig
    # 基础模型前缀为 "git"
    base_model_prefix = "git"
    # 支持梯度检查点
    supports_gradient_checkpointing = True

    # 初始化模型权重
    def _init_weights(self, module):
        """Initialize the weights"""
        # 如果 module 是 GitVisionEmbeddings 类的实例
        if isinstance(module, GitVisionEmbeddings):
            # 初始化 class_embedding 层的权重为正态分布,均值为 0.0,标准差为 self.config.initializer_range
            nn.init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
            # 初始化 patch_embedding 层的权重为正态分布,标准差为 self.config.initializer_range
            nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
            # 初始化 position_embedding 层的权重为正态分布,标准差为 self.config.initializer_range
            nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
        # 如果 module 是 nn.Linear 类的实例
        if isinstance(module, nn.Linear):
            # 使用正态分布初始化权重,均值为 0.0,标准差为 self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果有偏置项,将其初始化为零
            if module.bias is not None:
                module.bias.data.zero_()
        # 如果 module 是 nn.Embedding 类的实例
        elif isinstance(module, nn.Embedding):
            # 使用正态分布初始化权重,均值为 0.0,标准差为 self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果有 padding_idx,将其对应的权重初始化为零
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # 如果 module 是 nn.LayerNorm 类的实例
        elif isinstance(module, nn.LayerNorm):
            # 将偏置项初始化为零
            module.bias.data.zero_()
            # 将权重初始化为 1.0
            module.weight.data.fill_(1.0)


GIT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`GitConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

GIT_INPUTS_DOCSTRING = r"""
        Args:
            input_ids (`torch.LongTensor` of shape `({0})`):
                # 输入序列标记在词汇表中的索引

                Indices of input sequence tokens in the vocabulary.
                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
                # 注意力遮罩,避免在填充标记上执行注意力操作

                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)

            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
                # 输入序列标记在位置嵌入中的索引

                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
                config.max_position_embeddings - 1]`.

                [What are position IDs?](../glossary#position-ids)

            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                # 像素值,可以使用 [`AutoImageProcessor`] 获取

                Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
                [`CLIPImageProcessor.__call__`] for details.

            head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
                # 自注意力模块中选择性屏蔽的头部掩码

                Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
                # 直接传入嵌入表示而不是 `input_ids` 的选择性参数

                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
                is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
                model's internal embedding lookup matrix.

            output_attentions (`bool`, *optional*):
                # 是否返回所有注意力层的注意力张量

                Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
                tensors for more detail.

            output_hidden_states (`bool`, *optional*):
                # 是否返回所有层的隐藏状态

                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
                more detail.

            return_dict (`bool`, *optional*):
                # 是否返回 [`~utils.ModelOutput`] 而不是普通元组

                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
# 从 transformers.models.clip.modeling_clip.CLIPVisionEmbeddings 复制而来的 GitVisionEmbeddings 类定义
class GitVisionEmbeddings(nn.Module):
    # 初始化函数,接收一个 GitVisionConfig 类型的参数 config
    def __init__(self, config: GitVisionConfig):
        super().__init__()
        # 将 config 参数保存到对象的 config 属性中
        self.config = config
        # 设置嵌入维度为隐藏大小
        self.embed_dim = config.hidden_size
        # 设置图像大小为配置中的图像大小
        self.image_size = config.image_size
        # 设置patch大小为配置中的patch大小
        self.patch_size = config.patch_size

        # 创建一个 nn.Parameter 类型的类别嵌入,维度为隐藏大小
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # 创建一个二维卷积层,输入通道数为配置中的通道数,输出通道数为隐藏大小,核大小为patch大小,步长为patch大小,无偏置项
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        # 计算图像中patch的数量
        self.num_patches = (self.image_size // self.patch_size) ** 2
        # 计算位置嵌入的数量为patch的数量加1
        self.num_positions = self.num_patches + 1
        # 创建一个位置嵌入,大小为(num_positions, embed_dim),用于嵌入位置信息
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # 注册一个缓冲区,存储位置ID的张量,形状为(1, num_positions),不持久保存
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    # 前向传播函数,接收像素值张量作为输入,返回嵌入的张量
    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        # 获取输入张量的批量大小
        batch_size = pixel_values.shape[0]
        # 设置目标数据类型为patch_embedding权重的数据类型
        target_dtype = self.patch_embedding.weight.dtype
        # 对输入像素值进行patch嵌入,形状为[*, width, grid, grid]
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        # 将patch_embeds扁平化,并转置维度1和2
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        # 将类别嵌入扩展到与批量大小匹配的维度
        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        # 拼接类别嵌入和patch嵌入,维度为[batch_size, num_patches + 1, embed_dim]
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        # 将位置嵌入加到嵌入张量上,形状为[batch_size, num_patches + 1, embed_dim]
        embeddings = embeddings + self.position_embedding(self.position_ids)
        # 返回嵌入张量
        return embeddings


# 从 transformers.models.clip.modeling_clip.CLIPMLP 复制而来的 GitVisionMLP 类定义
class GitVisionMLP(nn.Module):
    # 初始化函数,接收一个 config 参数
    def __init__(self, config):
        super().__init__()
        # 将 config 参数保存到对象的 config 属性中
        self.config = config
        # 设置激活函数为配置中指定的隐藏层激活函数
        self.activation_fn = ACT2FN[config.hidden_act]
        # 创建一个全连接层,输入大小为隐藏大小,输出大小为中间大小
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        # 创建一个全连接层,输入大小为中间大小,输出大小为隐藏大小
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    # 前向传播函数,接收隐藏状态张量作为输入,返回变换后的张量
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 经过第一个全连接层和激活函数的变换
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        # 经过第二个全连接层的变换
        hidden_states = self.fc2(hidden_states)
        # 返回变换后的张量
        return hidden_states


# 从 transformers.models.clip.modeling_clip.CLIPAttention 复制而来的 GitVisionAttention 类定义
class GitVisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    # 初始化函数,接受一个配置对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 将配置对象保存为属性
        self.config = config
        # 设置嵌入维度为配置中的隐藏大小
        self.embed_dim = config.hidden_size
        # 设置注意力头的数量为配置中的注意力头数
        self.num_heads = config.num_attention_heads
        # 计算每个注意力头的维度
        self.head_dim = self.embed_dim // self.num_heads
        # 检查嵌入维度是否能整除注意力头数,若不能则抛出异常
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: "
                f"{self.num_heads})."
            )
        # 计算缩放因子,用于注意力计算
        self.scale = self.head_dim**-0.5
        # 设置注意力中的 dropout 率
        self.dropout = config.attention_dropout

        # 初始化线性变换层,用于查询、键、值和输出的投影
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    # 定义一个函数用于重塑张量形状,用于多头注意力
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    # 前向传播函数,接受隐藏状态张量和可选的注意力掩码作为输入
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->GitVision
class GitVisionEncoderLayer(nn.Module):
    def __init__(self, config: GitVisionConfig):
        super().__init__()
        # 设置嵌入维度为隐藏大小
        self.embed_dim = config.hidden_size
        # 使用 GitVisionAttention 类创建自注意力机制对象
        self.self_attn = GitVisionAttention(config)
        # 第一层归一化,使用 LayerNorm
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        # MLP 层,使用 GitVisionMLP 创建多层感知机
        self.mlp = GitVisionMLP(config)
        # 第二层归一化,使用 LayerNorm
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): 输入层的张量,形状为 `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): 注意力掩码张量,形状为
                `(batch, 1, tgt_len, src_len)`,其中填充元素由非常大的负值指示。
            output_attentions (`bool`, *optional*):
                是否返回所有注意力层的注意力张量。详见返回的张量中的 `attentions`。

        Returns:
            outputs: 包含处理结果的元组,其中元素为 `torch.FloatTensor` 张量
        """
        # 保存残差连接
        residual = hidden_states

        # 应用第一层归一化
        hidden_states = self.layer_norm1(hidden_states)
        # 使用自注意力机制计算注意力权重和新的隐藏状态
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        # 添加残差连接到新的隐藏状态
        hidden_states = residual + hidden_states

        # 保存残差连接
        residual = hidden_states
        # 应用第二层归一化
        hidden_states = self.layer_norm2(hidden_states)
        # 应用 MLP 层
        hidden_states = self.mlp(hidden_states)
        # 添加残差连接到最终输出
        hidden_states = residual + hidden_states

        # 构建输出元组,包含最终隐藏状态
        outputs = (hidden_states,)

        # 如果需要输出注意力权重,添加到输出元组中
        if output_attentions:
            outputs += (attn_weights,)

        return outputs


# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->GitVision, CLIPConfig
class GitVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`GitVisionEncoderLayer`].

    Args:
        config: GitVisionConfig
    """

    def __init__(self, config: GitVisionConfig):
        super().__init__()
        # 保存配置对象
        self.config = config
        # 使用 GitVisionEncoderLayer 创建多个编码层,并组成层列表
        self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # 梯度检查点默认关闭
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Args:
            inputs_embeds: 输入的嵌入张量
            attention_mask: 注意力掩码张量
            causal_attention_mask: 因果注意力掩码张量
            output_attentions: 是否返回注意力张量
            output_hidden_states: 是否返回隐藏状态张量
            return_dict: 是否以字典形式返回结果

        Returns:
            depending on `return_dict`, a tuple of shape `(last_hidden_state, (attentions))` where each
            element is a tensor
        """
        # 留空,等待实现具体的前向传播逻辑
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values input to the model. This tensor represents the images to be processed, organized as batches
            with specified channels, height, and width.
            Padding in the input will be ignored by default.
            Pixel values can be obtained using `AutoImageProcessor`. See `CLIPImageProcessor.__call__` for more details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attention tensors from all attention layers. If set to `True`, the returned
            tensors will include the attentions for each layer.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states from all layers. If set to `True`, the returned tensors will
            include the hidden states for each layer.

        return_dict (`bool`, *optional*):
            Whether or not to return a `utils.ModelOutput` object instead of a plain tuple. If `True`, the returned
            output will be a structured object containing various model outputs such as logits, hidden states,
            attentions, etc.
"""
@add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
# 覆盖了 forward 方法的 docstring,指定了输入和输出的详细描述

class GitVisionTransformer(nn.Module):
    # 从 transformers.models.clip.modeling_clip.CLIPVisionTransformer.__init__ 复制而来,将 CLIPEncoder 改为 GitVisionEncoder,CLIP 改为 Git
    # 初始化 GitVisionTransformer 类
    def __init__(self, config: GitVisionConfig):
        super().__init__()
        # 设置配置信息
        self.config = config
        # 设定嵌入维度为隐藏大小
        embed_dim = config.hidden_size

        # 初始化嵌入层
        self.embeddings = GitVisionEmbeddings(config)
        # 初始化前层归一化层
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        # 初始化编码器
        self.encoder = GitVisionEncoder(config)
        # 初始化后层归一化层
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    # 覆盖了 forward 方法的 docstring,指定了返回值的详细描述
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Returns:
            Either a tuple or a BaseModelOutput depending on `return_dict`.
        """
        # 如果 output_attentions 未指定,则使用配置中的值
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果 output_hidden_states 未指定,则使用配置中的值
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果 return_dict 未指定,则使用配置中的值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 如果 pixel_values 为 None,则抛出数值错误
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # 对输入的像素值进行嵌入处理
        hidden_states = self.embeddings(pixel_values)
        # 应用前层归一化
        hidden_states = self.pre_layrnorm(hidden_states)

        # 将处理后的隐藏状态传入编码器进行处理
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取编码器的最后隐藏状态
        last_hidden_state = encoder_outputs[0]

        # 应用后层归一化
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # 如果 return_dict 为 False,则返回一个包含最后隐藏状态和其他编码器输出的元组
        if not return_dict:
            return (last_hidden_state,) + encoder_outputs[1:]

        # 否则,返回一个包含所有输出的 BaseModelOutput 对象
        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """The vision model from CLIP, used in GIT, without any head or projection on top.""",
    GIT_START_DOCSTRING,
)
# 覆盖了类 GitVisionModel 的 docstring,提供了关于这个视觉模型的描述以及 GIT 的起始文档字符串
class GitVisionModel(GitPreTrainedModel):
    # 指定配置类为 GitVisionConfig
    config_class = GitVisionConfig
    # 主输入名称为 "pixel_values"
    main_input_name = "pixel_values"

    # 从 transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ 复制而来,将 CLIP 改为 Git
    # 初始化 GitVisionModel 类
    def __init__(self, config: GitVisionConfig):
        super().__init__(config)
        # 初始化视觉模型
        self.vision_model = GitVisionTransformer(config)
        # 执行初始化权重和应用最终处理
        self.post_init()

    # 获取输入嵌入的方法
    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding
    @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
    # 调用装饰器,添加模型前向传播方法的文档字符串,使用给定的输入文档字符串GIT_VISION_INPUTS_DOCSTRING
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
    # 调用装饰器,替换返回值的文档字符串,指定输出类型为BaseModelOutput,配置类为GitVisionConfig

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        模型的前向传播方法。

        Returns:
            返回值是一个元组或BaseModelOutput对象。

        Examples:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 如果return_dict不为None,则使用它;否则使用self.config.use_return_dict的值

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 调用self.vision_model进行模型的前向传播,传入相关参数和返回值设置
# 定义一个 GitModel 类,继承自 GitPreTrainedModel,表示一个 GIT 模型
@add_start_docstrings(
    "The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states"
    " without any specific head on top.",
    GIT_START_DOCSTRING,
)
class GitModel(GitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # 初始化 GitModel 对象时执行以下操作:

        # 实例化 GitEmbeddings 类,用于处理模型的嵌入层
        self.embeddings = GitEmbeddings(config)

        # 实例化 GitVisionModel 类,用于处理视觉编码器部分
        self.image_encoder = GitVisionModel(config.vision_config)

        # 实例化 GitEncoder 类,用于处理文本编码器部分
        self.encoder = GitEncoder(config)

        # 实例化 GitProjection 类,用于定义视觉投影层
        self.visual_projection = GitProjection(config)

        # 如果配置中指定了 num_image_with_embedding,创建对应数量的图像嵌入参数列表
        if config.num_image_with_embedding is not None:
            self.img_temperal_embedding = nn.ParameterList(
                nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
                for _ in range(config.num_image_with_embedding)
            )

        # 调用 post_init 方法,用于初始化权重并进行最终处理
        self.post_init()

    def get_input_embeddings(self):
        # 返回模型的输入嵌入层对象
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # 设置模型的输入嵌入层对象
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # 遍历待修剪的头信息,并在相应层中执行修剪操作
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        # 生成一个未来的遮罩,用于自注意力机制
        # 默认遮罩适用于正向方向,如果需要反向遮罩,需要将其翻转
        mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
        mask = mask.masked_fill(mask == 1, float("-inf"))  # 将遮罩中的所有 1 替换为负无穷
        return mask
    # 创建注意力遮罩,用于Transformer模型的self-attention机制,生成一个三维张量
    def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
        # 获取目标(target)序列的长度
        num_tgt = tgt.shape[1]
        # 获取记忆(memory)序列的长度
        num_memory = memory.shape[1]
        # 获取目标(target)张量所在设备
        device = tgt.device
        # 获取目标(target)张量的数据类型
        dtype = tgt.dtype

        # 创建左上角的全零张量
        top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
        # 创建右上角的全负无穷张量,用于填充注意力矩阵的右上部分
        top_right = torch.full(
            (num_memory, num_tgt + past_key_values_length),
            float("-inf"),
            device=tgt.device,
            dtype=dtype,
        )
        # 创建左下角的全零张量,用于填充注意力矩阵的左下部分
        bottom_left = torch.zeros(
            (num_tgt, num_memory),
            dtype=dtype,
            device=tgt_mask.device,
        )

        # 如果存在过去的键值长度大于零,则需要重新定义目标掩码
        if past_key_values_length > 0:
            tgt_mask = torch.zeros(
                (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
                dtype=dtype,
                device=tgt_mask.device,
            )

        # 将左上角、左下角组合成左侧部分张量
        left = torch.cat((top_left, bottom_left), dim=0)
        # 将右上角和目标掩码组合成右侧部分张量
        right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)

        # 将左侧部分张量和右侧部分张量连接起来,形成完整的注意力矩阵
        full_attention_mask = torch.cat((left, right), dim=1)[None, :]

        # 如果记忆序列的键值填充掩码为None,则创建全假掩码张量
        if memory_key_padding_mask is None:
            memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
        # 如果记忆序列的键值填充掩码不是布尔类型,则抛出值错误异常
        if memory_key_padding_mask.dtype != torch.bool:
            raise ValueError("Memory key padding mask must be a boolean tensor.")
        # 创建与记忆序列形状相同的全零张量
        zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
        # 将填充的位置替换为负无穷
        zero_negative_infinity[memory_key_padding_mask] = float("-inf")
        # 将完整的注意力矩阵张量扩展为指定形状
        full_attention_mask = full_attention_mask.expand(
            (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
        )
        # 克隆扩展后的完整注意力矩阵张量
        full_attention_mask = full_attention_mask.clone()
        # 获取注意力矩阵的左侧原始部分
        origin_left = full_attention_mask[:, :, :num_memory]
        # 执行更新操作,将负无穷的张量加到原始左侧部分
        update = zero_negative_infinity[:, None, :]
        full_attention_mask[:, :, :num_memory] = origin_left + update

        # 为多头注意力添加额外的维度
        full_attention_mask = full_attention_mask[:, None, :, :]

        # 返回完整的注意力矩阵
        return full_attention_mask

    # 用于Transformer模型的前向传播,包含各种输入参数和输出控制标志
    @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
"""
GIT Model with a `language modeling` head on top for autoregressive language modeling.
"""
@add_start_docstrings(
    GIT_START_DOCSTRING
)
class GitForCausalLM(GitPreTrainedModel):
    # List of keys whose weights are tied
    _tied_weights_keys = ["output.weight"]

    def __init__(self, config):
        """
        Initializes the GitForCausalLM model.

        Args:
            config (:class:`~transformers.GitConfig`):
                The configuration object that defines the model architecture.
        """
        super().__init__(config)

        # Initialize the base GitModel with the provided configuration
        self.git = GitModel(config)
        # Linear layer for output
        self.output = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """
        Returns the output layer.

        Returns:
            :obj:`torch.nn.Linear`: The output layer of the model.
        """
        return self.output

    def set_output_embeddings(self, new_embeddings):
        """
        Sets new output embeddings.

        Args:
            new_embeddings (:obj:`torch.nn.Linear`):
                New embeddings to set for the output layer.
        """
        self.output = new_embeddings

    @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass of the GitForCausalLM model.

        Args:
            input_ids (:obj:`torch.Tensor`, optional):
                Input tensor of token indices.
            attention_mask (:obj:`torch.Tensor`, optional):
                Mask to avoid performing attention on padding tokens.
            position_ids (:obj:`torch.Tensor`, optional):
                Indices of positions of each input sequence tokens in the position embeddings.
            pixel_values (:obj:`torch.Tensor`, optional):
                Pixel values if the model is a vision model.
            head_mask (:obj:`torch.Tensor`, optional):
                Mask to nullify selected heads of the self-attention modules.
            inputs_embeds (:obj:`torch.Tensor`, optional):
                Optional tensor to override the input embeddings.
            labels (:obj:`torch.Tensor`, optional):
                Labels for computing the cross entropy classification loss.
            past_key_values (:obj:`List[torch.Tensor]`, optional):
                List of tensors containing cached keys and values.
            use_cache (:obj:`bool`, optional):
                Whether to use the cache for faster decoding.
            output_attentions (:obj:`bool`, optional):
                Whether to output the attentions weights.
            output_hidden_states (:obj:`bool`, optional):
                Whether to output the hidden states.
            return_dict (:obj:`bool`, optional):
                Whether to return a dictionary as the output instead of a tuple.

        Returns:
            :class:`~transformers.modeling_outputs.CausalLMOutputWithPast`:
                The output of the model, potentially with past states.
        """
        # Actual forward logic implementation is within the transformers library

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        """
        Prepares input for generation.

        Args:
            input_ids (:obj:`torch.Tensor`):
                Input tensor of token indices.
            past_key_values (:obj:`List[torch.Tensor]`, optional):
                List of tensors containing cached keys and values.
            attention_mask (:obj:`torch.Tensor`, optional):
                Mask to avoid performing attention on padding tokens.
            use_cache (:obj:`bool`, optional):
                Whether to use the cache for faster decoding.
            **kwargs:
                Additional keyword arguments.

        Returns:
            :obj:`Dict[str, torch.Tensor]`:
                Dictionary containing the prepared inputs for generation.
        """
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        input_shape = input_ids.shape
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": kwargs.get("pixel_values", None),
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """
        Reorders the cached states for beam search.

        Args:
            past_key_values (:obj:`List[torch.Tensor]`):
                List of tensors containing cached keys and values.
            beam_idx (:obj:`torch.Tensor`):
                Indices of beams to reorder the past states.

        Returns:
            Tuple[List[torch.Tensor]]:
                Reordered past states for beam search.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past

.\models\git\processing_git.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image/Text processor class for GIT
"""

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding


class GitProcessor(ProcessorMixin):
    r"""
    Constructs a GIT processor which wraps a CLIP image processor and a BERT tokenizer into a single processor.

    [`GitProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BertTokenizerFast`]. See the
    [`~GitProcessor.__call__`] and [`~GitProcessor.decode`] for more information.

    Args:
        image_processor ([`AutoImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`AutoTokenizer`]):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        # 将传入的图像处理器设为当前处理器

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)
        # 调用 tokenizer 对象的 batch_decode 方法,并将参数传递给 BertTokenizerFast 对象的 batch_decode 方法

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
        # 调用 tokenizer 对象的 decode 方法,并将参数传递给 BertTokenizerFast 对象的 decode 方法

    @property
    def model_input_names(self):
        return ["input_ids", "attention_mask", "pixel_values"]
        # 返回模型输入的名称列表,包括 input_ids、attention_mask 和 pixel_values

.\models\git\__init__.py

# 导入类型检查模块
from typing import TYPE_CHECKING

# 导入自定义异常和模块延迟加载工具函数
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# 定义模块的导入结构
_import_structure = {
    "configuration_git": ["GIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "GitConfig", "GitVisionConfig"],
    "processing_git": ["GitProcessor"],
}

# 检查是否可以导入 Torch,如果不行则抛出自定义的异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可以导入 Torch,则增加一些模型相关的导入结构
    _import_structure["modeling_git"] = [
        "GIT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "GitForCausalLM",
        "GitModel",
        "GitPreTrainedModel",
        "GitVisionModel",
    ]

# 如果当前是类型检查模式
if TYPE_CHECKING:
    # 导入配置相关的类和常量
    from .configuration_git import GIT_PRETRAINED_CONFIG_ARCHIVE_MAP, GitConfig, GitVisionConfig
    # 导入处理相关的类
    from .processing_git import GitProcessor

    # 再次检查 Torch 是否可用,如果不行则忽略模型相关导入
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入模型相关的类和常量
        from .modeling_git import (
            GIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            GitForCausalLM,
            GitModel,
            GitPreTrainedModel,
            GitVisionModel,
        )

# 如果不是类型检查模式,则将当前模块设为一个延迟加载模块
else:
    import sys

    # 使用 _LazyModule 实现模块的延迟加载
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\glpn\configuration_glpn.py

# coding=utf-8
# Copyright 2022 KAIST and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" GLPN model configuration"""

# 导入预训练配置类和日志模块
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义预训练模型配置文件的映射字典,包含预训练模型名称和其配置文件的链接
GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "vinvino02/glpn-kitti": "https://huggingface.co/vinvino02/glpn-kitti/resolve/main/config.json",
    # See all GLPN models at https://huggingface.co/models?filter=glpn
}


class GLPNConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GLPNModel`]. It is used to instantiate an GLPN
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the GLPN
    [vinvino02/glpn-kitti](https://huggingface.co/vinvino02/glpn-kitti) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """
    # 定义一个配置类GLPNConfig,用于初始化GLPN模型的参数
    Args:
        num_channels (`int`, *optional*, defaults to 3):
            输入通道的数量。
        num_encoder_blocks (`int`, *optional*, defaults to 4):
            编码器块的数量(即Mix Transformer编码器中的阶段数)。
        depths (`List[int]`, *optional*, defaults to `[2, 2, 2, 2]`):
            每个编码器块中的层的数量。
        sr_ratios (`List[int]`, *optional*, defaults to `[8, 4, 2, 1]`):
            每个编码器块中的序列缩减比率。
        hidden_sizes (`List[int]`, *optional*, defaults to `[32, 64, 160, 256]`):
            每个编码器块的维度。
        patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3, 3]`):
            每个编码器块之前的补丁大小。
        strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`):
            每个编码器块之前的步幅。
        num_attention_heads (`List[int]`, *optional*, defaults to `[1, 2, 5, 8]`):
            Transformer编码器每个注意层中的注意头数量。
        mlp_ratios (`List[int]`, *optional*, defaults to `[4, 4, 4, 4]`):
            Mix FFN中隐藏层大小与输入层大小的比率。
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            编码器和池化器中的非线性激活函数(函数或字符串)。支持的字符串有:"gelu", "relu", "selu"和"gelu_new"。
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            嵌入、编码器和池化器中所有全连接层的丢弃概率。
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            注意力概率的丢弃比率。
        initializer_range (`float`, *optional*, defaults to 0.02):
            用于初始化所有权重矩阵的截断正态初始化器的标准差。
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            随机深度的丢弃概率,用于Transformer编码器块。
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            层归一化层使用的epsilon值。
        decoder_hidden_size (`int`, *optional*, defaults to 64):
            解码器的维度。
        max_depth (`int`, *optional*, defaults to 10):
            解码器的最大深度。
        head_in_index (`int`, *optional*, defaults to -1):
            在头部使用的特征的索引。

    Example:

    ```
    >>> from transformers import GLPNModel, GLPNConfig

    >>> # 初始化一个GLPN vinvino02/glpn-kitti风格的配置
    >>> configuration = GLPNConfig()

    >>> # 使用vinvino02/glpn-kitti风格的配置初始化一个模型
    >>> model = GLPNModel(configuration)
    ```
    # 访问模型配置
    configuration = model.config
posted @ 2024-06-30 15:38  绝不原创的飞龙  阅读(61)  评论(0编辑  收藏  举报