diffusers-源码解析-二-

diffusers 源码解析(二)

.\diffusers\loaders\lora_conversion_utils.py

# 版权声明,表示该代码属于 HuggingFace 团队,所有权利保留
# 授权信息,依据 Apache 许可证版本 2.0
# 用户必须遵循许可证条款使用此文件
# 许可证的获取链接
# 除非适用法律要求或书面同意,软件在“现状”基础上分发,不提供任何形式的保证或条件
# 查看许可证获取关于权限和限制的详细信息

# 导入正则表达式模块
import re

# 从 utils 模块导入 is_peft_version 和 logging 函数
from ..utils import is_peft_version, logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义可能将 SGM 块映射到 diffusers 的函数
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
    # 1. 获取所有的 state_dict 键
    all_keys = list(state_dict.keys())
    # 定义 SGM 模式的列表
    sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]

    # 2. 检查是否需要重新映射,如果不需要则返回原始字典
    is_in_sgm_format = False
    # 遍历所有键以检查是否包含 SGM 模式
    for key in all_keys:
        if any(p in key for p in sgm_patterns):
            is_in_sgm_format = True
            break

    # 如果不在 SGM 格式中,直接返回原字典
    if not is_in_sgm_format:
        return state_dict

    # 3. 否则,根据 SGM 模式重新映射
    new_state_dict = {}
    # 定义内部块映射的列表
    inner_block_map = ["resnets", "attentions", "upsamplers"]

    # 初始化输入、中间和输出块的 ID 集合
    input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()

    # 遍历所有层以填充块 ID
    for layer in all_keys:
        # 如果层名称中包含 "text",直接移到新字典中
        if "text" in layer:
            new_state_dict[layer] = state_dict.pop(layer)
        else:
            # 从层名称中提取 ID
            layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
            # 根据层类型添加到相应的 ID 集合中
            if sgm_patterns[0] in layer:
                input_block_ids.add(layer_id)
            elif sgm_patterns[1] in layer:
                middle_block_ids.add(layer_id)
            elif sgm_patterns[2] in layer:
                output_block_ids.add(layer_id)
            else:
                # 如果层不支持,则抛出异常
                raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")

    # 根据层 ID 获取输入块的所有键
    input_blocks = {
        layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
        for layer_id in input_block_ids
    }
    # 根据层 ID 获取中间块的所有键
    middle_blocks = {
        layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
        for layer_id in middle_block_ids
    }
    # 根据层 ID 获取输出块的所有键
    output_blocks = {
        layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
        for layer_id in output_block_ids
    }

    # 按照新的规则重命名键
    # 遍历输入块的 ID 列表
        for i in input_block_ids:
            # 计算当前块的 ID
            block_id = (i - 1) // (unet_config.layers_per_block + 1)
            # 计算当前层在块内的 ID
            layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
    
            # 遍历当前输入块中的每个键
            for key in input_blocks[i]:
                # 从键中提取内部块的 ID
                inner_block_id = int(key.split(delimiter)[block_slice_pos])
                # 判断当前键是操作还是下采样,并获取对应的内部块键
                inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
                # 确定块内层的字符串表示
                inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
                # 构造新的键
                new_key = delimiter.join(
                    key.split(delimiter)[: block_slice_pos - 1]
                    + [str(block_id), inner_block_key, inner_layers_in_block]
                    + key.split(delimiter)[block_slice_pos + 1 :]
                )
                # 将状态字典中的旧键对应的值存入新键
                new_state_dict[new_key] = state_dict.pop(key)
    
        # 遍历中间块的 ID 列表
        for i in middle_block_ids:
            key_part = None
            # 根据中间块 ID 设置键部分
            if i == 0:
                key_part = [inner_block_map[0], "0"]
            elif i == 1:
                key_part = [inner_block_map[1], "0"]
            elif i == 2:
                key_part = [inner_block_map[0], "1"]
            else:
                # 抛出异常以防无效的中间块 ID
                raise ValueError(f"Invalid middle block id {i}.")
    
            # 遍历当前中间块中的每个键
            for key in middle_blocks[i]:
                # 构造新的键
                new_key = delimiter.join(
                    key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
                )
                # 将状态字典中的旧键对应的值存入新键
                new_state_dict[new_key] = state_dict.pop(key)
    
        # 遍历输出块的 ID 列表
        for i in output_block_ids:
            # 计算当前块的 ID
            block_id = i // (unet_config.layers_per_block + 1)
            # 计算当前层在块内的 ID
            layer_in_block_id = i % (unet_config.layers_per_block + 1)
    
            # 遍历当前输出块中的每个键
            for key in output_blocks[i]:
                # 从键中提取内部块的 ID
                inner_block_id = int(key.split(delimiter)[block_slice_pos])
                # 获取对应的内部块键
                inner_block_key = inner_block_map[inner_block_id]
                # 确定块内层的字符串表示
                inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
                # 构造新的键
                new_key = delimiter.join(
                    key.split(delimiter)[: block_slice_pos - 1]
                    + [str(block_id), inner_block_key, inner_layers_in_block]
                    + key.split(delimiter)[block_slice_pos + 1 :]
                )
                # 将状态字典中的旧键对应的值存入新键
                new_state_dict[new_key] = state_dict.pop(key)
    
        # 如果状态字典还有未转换的条目,则抛出异常
        if len(state_dict) > 0:
            raise ValueError("At this point all state dict entries have to be converted.")
    
        # 返回新的状态字典
        return new_state_dict
# 将非 Diffusers 格式的 LoRA 状态字典转换为兼容 Diffusers 格式的状态字典
def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
    # 初始化 U-Net 模块的状态字典
    unet_state_dict = {}
    # 初始化文本编码器的状态字典
    te_state_dict = {}
    # 初始化第二文本编码器的状态字典
    te2_state_dict = {}
    # 初始化网络 alphas 字典
    network_alphas = {}

    # 检查是否存在 DoRA 支持的 LoRA
    dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
    dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
    dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
    # 如果存在 DoRA,则检查 peft 版本是否满足要求
    if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
        if is_peft_version("<", "0.9.0"):
            # 抛出错误提示需要更新 peft 版本
            raise ValueError(
                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
            )

    # 遍历所有 LoRA 权重的键
    all_lora_keys = list(state_dict.keys())
    # 遍历所有 LoRA 键
    for key in all_lora_keys:
        # 如果键不以 "lora_down.weight" 结尾,则跳过
        if not key.endswith("lora_down.weight"):
            continue

        # 提取 LoRA 名称
        lora_name = key.split(".")[0]

        # 找到对应的上升权重和 alpha
        lora_name_up = lora_name + ".lora_up.weight"
        lora_name_alpha = lora_name + ".alpha"

        # 处理 U-Net 的 LoRAs
        if lora_name.startswith("lora_unet_"):
            # 转换为 Diffusers 格式的 U-Net LoRA 键
            diffusers_name = _convert_unet_lora_key(key)

            # 存储下权重和上权重
            unet_state_dict[diffusers_name] = state_dict.pop(key)
            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

            # 如果存在 DoRA 规模,则存储
            if dora_present_in_unet:
                # 替换相应的 DoRA 规模键
                dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
                unet_state_dict[
                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))

        # 处理文本编码器的 LoRAs
        elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
            # 转换为 Diffusers 格式的文本编码器 LoRA 键
            diffusers_name = _convert_text_encoder_lora_key(key, lora_name)

            # 存储 te 或 te2 的下权重和上权重
            if lora_name.startswith(("lora_te_", "lora_te1_")):
                te_state_dict[diffusers_name] = state_dict.pop(key)
                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
            else:
                te2_state_dict[diffusers_name] = state_dict.pop(key)
                te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

            # 如果存在 DoRA 规模,则存储
            if dora_present_in_te or dora_present_in_te2:
                # 替换相应的 DoRA 规模键
                dora_scale_key_to_replace_te = (
                    "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
                )
                if lora_name.startswith(("lora_te_", "lora_te1_")):
                    te_state_dict[
                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
                elif lora_name.startswith("lora_te2_"):
                    te2_state_dict[
                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))

        # 如果存在 alpha,则存储
        if lora_name_alpha in state_dict:
            alpha = state_dict.pop(lora_name_alpha).item()
            # 更新网络中的 alpha 名称
            network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))

    # 检查是否还有剩余的键
    if len(state_dict) > 0:
        # 如果有未重命名的键,则引发错误
        raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
    # 记录日志,提示检测到非扩散模型的检查点
        logger.info("Non-diffusers checkpoint detected.")
    
        # 构造最终的状态字典
        # 为 UNet 的状态字典添加前缀
        unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
        # 为文本编码器的状态字典添加前缀
        te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
        # 检查第二个文本编码器状态字典是否非空,如果非空则添加前缀
        te2_state_dict = (
            {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
            if len(te2_state_dict) > 0
            else None
        )
        # 如果第二个文本编码器状态字典存在,则更新第一个状态字典
        if te2_state_dict is not None:
            te_state_dict.update(te2_state_dict)
    
        # 合并 UNet 和文本编码器的状态字典
        new_state_dict = {**unet_state_dict, **te_state_dict}
        # 返回合并后的状态字典和网络的 alpha 值
        return new_state_dict, network_alphas
# 定义一个将 U-Net LoRA 键转换为 Diffusers 兼容键的函数
def _convert_unet_lora_key(key):
    """
    转换 U-Net LoRA 键为 Diffusers 兼容的键。
    """
    # 将键中的前缀替换为更通用的格式
    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")

    # 替换常见的 U-Net 命名模式
    diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
    diffusers_name = diffusers_name.replace("middle.block", "mid_block")
    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
    diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
    diffusers_name = diffusers_name.replace("proj.in", "proj_in")
    diffusers_name = diffusers_name.replace("proj.out", "proj_out")
    diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")

    # 针对 SDXL 特定的转换
    if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
        # 使用正则表达式移除最后一个数字
        pattern = r"\.\d+(?=\D*$)"
        diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
    if ".in." in diffusers_name:
        # 将指定层替换为新的名称
        diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
    if ".out." in diffusers_name:
        # 将指定层替换为新的名称
        diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
    if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
        # 将操作符替换为卷积层
        diffusers_name = diffusers_name.replace("op", "conv")
    if "skip" in diffusers_name:
        # 将跳跃连接替换为卷积捷径
        diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")

    # 针对 LyCORIS 特定的转换
    if "time.emb.proj" in diffusers_name:
        # 替换时间嵌入投影的名称
        diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
    if "conv.shortcut" in diffusers_name:
        # 替换卷积捷径的名称
        diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")

    # 一般性的转换
    if "transformer_blocks" in diffusers_name:
        if "attn1" in diffusers_name or "attn2" in diffusers_name:
            # 将注意力层的名称进行处理
            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
        elif "ff" in diffusers_name:
            pass  # 如果包含前馈层,什么都不做
    elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
        pass  # 如果包含投影层,什么都不做
    else:
        pass  # 其他情况,什么都不做

    # 返回转换后的键
    return diffusers_name


# 定义一个将文本编码器 LoRA 键转换为 Diffusers 兼容键的函数
def _convert_text_encoder_lora_key(key, lora_name):
    """
    转换文本编码器 LoRA 键为 Diffusers 兼容的键。
    """
    # 检查 LoRA 名称的前缀
    if lora_name.startswith(("lora_te_", "lora_te1_")):
        # 根据前缀设置要替换的键
        key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
    else:
        # 如果条件不满足,则设置替换的键为 "lora_te2_"
        key_to_replace = "lora_te2_"

    # 将原始键中的指定键去除,并将下划线替换为点,以形成 diffusers_name
    diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
    # 将 "text.model" 替换为 "text_model"
    diffusers_name = diffusers_name.replace("text.model", "text_model")
    # 将 "self.attn" 替换为 "self_attn"
    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
    # 将 "q.proj.lora" 替换为 "to_q_lora"
    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
    # 将 "k.proj.lora" 替换为 "to_k_lora"
    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
    # 将 "v.proj.lora" 替换为 "to_v_lora"
    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
    # 将 "out.proj.lora" 替换为 "to_out_lora"
    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
    # 将 "text.projection" 替换为 "text_projection"
    diffusers_name = diffusers_name.replace("text.projection", "text_projection")

    # 检查 diffusers_name 是否包含 "self_attn" 或 "text_projection"
    if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
        # 如果包含,则不进行任何操作,直接跳过
        pass
    # 检查 diffusers_name 是否包含 "mlp"
    elif "mlp" in diffusers_name:
        # 注意这是新的 diffusers 约定,其余代码可能尚未使用此约定
        # 将 ".lora." 替换为 ".lora_linear_layer."
        diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
    # 返回最终的 diffusers_name
    return diffusers_name
# 获取 Diffusers 模型的正确 alpha 名称
def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
    # 文档字符串,说明函数的功能
    """
    Gets the correct alpha name for the Diffusers model.
    """
    # 检查 lora_name_alpha 是否以 "lora_unet_" 开头
    if lora_name_alpha.startswith("lora_unet_"):
        # 设置前缀为 "unet."
        prefix = "unet."
    # 检查 lora_name_alpha 是否以 "lora_te_" 或 "lora_te1_" 开头
    elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
        # 设置前缀为 "text_encoder."
        prefix = "text_encoder."
    else:
        # 其他情况设置前缀为 "text_encoder_2."
        prefix = "text_encoder_2."
    # 生成新的名称,组合前缀和 diffusers_name 的部分,并加上 ".alpha"
    new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
    # 返回一个字典,包含新名称和 alpha 值
    return {new_name: alpha}

.\diffusers\loaders\lora_pipeline.py

# 版权声明,表明该代码的版权所有者及其权利
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 2.0 许可证的规定使用该文件
# Licensed under the Apache License, Version 2.0 (the "License");
# 只有在遵循许可证的情况下,才能使用此文件
# you may not use this file except in compliance with the License.
# 可以在以下网址获取许可证副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面协议另有约定,否则根据许可证分发的软件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的明示或暗示保证或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 参见许可证以获取有关权限和限制的具体条款
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入操作系统模块
import os
# 从 typing 模块导入类型提示相关的工具
from typing import Callable, Dict, List, Optional, Union

# 导入 PyTorch 库
import torch
# 从 huggingface_hub.utils 导入验证 Hugging Face Hub 参数的函数
from huggingface_hub.utils import validate_hf_hub_args

# 从 utils 模块中导入多个工具函数和常量
from ..utils import (
    USE_PEFT_BACKEND,  # 用于指示是否使用 PEFT 后端的常量
    convert_state_dict_to_diffusers,  # 转换状态字典到 Diffusers 格式的函数
    convert_state_dict_to_peft,  # 转换状态字典到 PEFT 格式的函数
    convert_unet_state_dict_to_peft,  # 将 UNet 状态字典转换为 PEFT 格式的函数
    deprecate,  # 用于标记过时函数的装饰器
    get_adapter_name,  # 获取适配器名称的函数
    get_peft_kwargs,  # 获取 PEFT 关键字参数的函数
    is_peft_version,  # 检查是否为 PEFT 版本的函数
    is_transformers_available,  # 检查 Transformers 库是否可用的函数
    logging,  # 日志记录工具
    scale_lora_layers,  # 调整 LoRA 层规模的函数
)
# 从 lora_base 模块导入 LoraBaseMixin 类
from .lora_base import LoraBaseMixin
# 从 lora_conversion_utils 模块导入两个用于转换的函数
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers

# 如果 Transformers 库可用,则导入相关的模块
if is_transformers_available():
    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules

# 创建一个日志记录器,用于记录本模块的日志信息
logger = logging.get_logger(__name__)

# 定义一些常量,表示不同组件的名称
TEXT_ENCODER_NAME = "text_encoder"  # 文本编码器的名称
UNET_NAME = "unet"  # UNet 模型的名称
TRANSFORMER_NAME = "transformer"  # Transformer 模型的名称

# 定义 LoRA 权重文件的名称
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"  # 二进制格式的权重文件名
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"  # 安全格式的权重文件名


# 定义一个类,用于加载 LoRA 层到稳定扩散模型中
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    r"""
    将 LoRA 层加载到稳定扩散模型 [`UNet2DConditionModel`] 和
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) 中。
    """

    # 可加载的 LoRA 模块列表
    _lora_loadable_modules = ["unet", "text_encoder"]
    unet_name = UNET_NAME  # UNet 模型的名称
    text_encoder_name = TEXT_ENCODER_NAME  # 文本编码器的名称

    # 定义加载 LoRA 权重的方法
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        加载指定的 LoRA 权重到 `self.unet` 和 `self.text_encoder` 中。

        所有关键字参数将转发给 `self.lora_state_dict`。

        详情请参阅 [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`],了解如何加载状态字典。

        详情请参阅 [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`],了解如何将状态字典加载到 `self.unet` 中。

        详情请参阅 [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`],了解如何将状态字典加载到 `self.text_encoder` 中。

        参数:
            pretrained_model_name_or_path_or_dict (`str` 或 `os.PathLike` 或 `dict`):
                详情请参阅 [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]。
            kwargs (`dict`, *可选*):
                详情请参阅 [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]。
            adapter_name (`str`, *可选*):
                用于引用加载的适配器模型的适配器名称。如果未指定,将使用 `default_{i}`,其中 i 是加载的适配器总数。
        """
        # 检查是否使用 PEFT 后端,如果未使用则引发错误
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # 如果传入的是字典,则复制一份而不是就地修改
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # 首先,确保检查点是兼容的,并且可以成功加载
        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        # 检查状态字典中的所有键是否包含 "lora" 或 "dora_scale"
        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
        # 如果格式不正确,则引发错误
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        # 将 LoRA 权重加载到 UNet 中
        self.load_lora_into_unet(
            state_dict,
            network_alphas=network_alphas,
            unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
            adapter_name=adapter_name,
            _pipeline=self,
        )
        # 将 LoRA 权重加载到文本编码器中
        self.load_lora_into_text_encoder(
            state_dict,
            network_alphas=network_alphas,
            text_encoder=getattr(self, self.text_encoder_name)
            if not hasattr(self, "text_encoder")
            else self.text_encoder,
            lora_scale=self.lora_scale,
            adapter_name=adapter_name,
            _pipeline=self,
        )

    # 类方法,用于验证 HF Hub 参数
    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    @classmethod
    # 定义一个类方法,用于将 LoRA 层加载到 UNet 模型中
    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
        """
        将 `state_dict` 中指定的 LoRA 层加载到 `unet` 中。
    
        参数:
            state_dict (`dict`):
                包含 LoRA 层参数的标准状态字典。键可以直接索引到 unet,或者以额外的 `unet` 前缀标识,以区分文本编码器的 LoRA 层。
            network_alphas (`Dict[str, float]`):
                用于稳定学习和防止下溢的网络 alpha 值。此值与 kohya-ss 训练脚本中的 `--network_alpha` 选项含义相同。参考[此链接](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning)。
            unet (`UNet2DConditionModel`):
                用于加载 LoRA 层的 UNet 模型。
            adapter_name (`str`, *可选*):
                用于引用加载的适配器模型的适配器名称。如果未指定,将使用 `default_{i}`,其中 i 是加载的适配器总数。
        """
        # 检查是否使用 PEFT 后端,如果未使用则引发错误
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
    
        # 检查序列化格式是否为新格式,`state_dict` 的键是否以 `cls.unet_name` 和/或 `cls.text_encoder_name` 为前缀
        keys = list(state_dict.keys())
        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
        if not only_text_encoder:
            # 加载与 UNet 对应的层
            logger.info(f"Loading {cls.unet_name}.")
            # 调用 UNet 的加载方法,传入状态字典和其他参数
            unet.load_attn_procs(
                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
            )
    
    # 定义一个类方法,用于将 LoRA 层加载到文本编码器中
    @classmethod
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
    ):
        # 方法定义,具体实现未提供
        pass
    
    # 定义一个类方法,用于保存 LoRA 权重
    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        # 方法定义,具体实现未提供
        pass
    ):
        r"""  # 文档字符串,描述函数的作用和参数
        Save the LoRA parameters corresponding to the UNet and text encoder.  # 保存与 UNet 和文本编码器相对应的 LoRA 参数

        Arguments:  # 参数说明
            save_directory (`str` or `os.PathLike`):  # 保存目录的类型说明
                Directory to save LoRA parameters to. Will be created if it doesn't exist.  # 保存 LoRA 参数的目录,如果不存在则创建
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):  # UNet 的 LoRA 层状态字典
                State dict of the LoRA layers corresponding to the `unet`.  # 与 `unet` 相对应的 LoRA 层的状态字典
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):  # 文本编码器的 LoRA 层状态字典
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text  # 与 `text_encoder` 相对应的 LoRA 层状态字典,必须显式传递
                encoder LoRA state dict because it comes from 🤗 Transformers.  # 因为它来自 🤗 Transformers
            is_main_process (`bool`, *optional*, defaults to `True`):  # 主要进程的布尔值,可选,默认值为 True
                Whether the process calling this is the main process or not. Useful during distributed training and you  # 调用此函数的进程是否为主进程,在分布式训练中很有用
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main  # 在这种情况下,只在主进程上设置 `is_main_process=True` 以避免竞争条件
                process to avoid race conditions.  # 避免竞争条件
            save_function (`Callable`):  # 保存函数的类型说明
                The function to use to save the state dictionary. Useful during distributed training when you need to  # 用于保存状态字典的函数,在分布式训练中很有用
                replace `torch.save` with another method. Can be configured with the environment variable  # 可以通过环境变量配置
                `DIFFUSERS_SAVE_MODE`.  # `DIFFUSERS_SAVE_MODE`
            safe_serialization (`bool`, *optional*, defaults to `True`):  # 安全序列化的布尔值,可选,默认值为 True
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.  # 是否使用 `safetensors` 或传统的 PyTorch 方法 `pickle` 保存模型
        """  # 文档字符串结束
        state_dict = {}  # 初始化一个空的状态字典

        if not (unet_lora_layers or text_encoder_lora_layers):  # 检查是否至少有一个 LoRA 层
            raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")  # 如果没有,抛出错误

        if unet_lora_layers:  # 如果存在 UNet 的 LoRA 层
            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))  # 更新状态字典,打包 UNet 权重

        if text_encoder_lora_layers:  # 如果存在文本编码器的 LoRA 层
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))  # 更新状态字典,打包文本编码器权重

        # Save the model  # 保存模型的注释
        cls.write_lora_layers(  # 调用类方法保存 LoRA 层
            state_dict=state_dict,  # 状态字典参数
            save_directory=save_directory,  # 保存目录参数
            is_main_process=is_main_process,  # 主要进程参数
            weight_name=weight_name,  # 权重名称参数
            save_function=save_function,  # 保存函数参数
            safe_serialization=safe_serialization,  # 安全序列化参数
        )  # 方法调用结束

    def fuse_lora(  # 定义 fuse_lora 方法
        self,  # 实例方法的 self 参数
        components: List[str] = ["unet", "text_encoder"],  # 组件列表,默认包含 UNet 和文本编码器
        lora_scale: float = 1.0,  # LoRA 缩放因子,默认值为 1.0
        safe_fusing: bool = False,  # 安全融合的布尔值,默认值为 False
        adapter_names: Optional[List[str]] = None,  # 适配器名称的可选列表,默认值为 None
        **kwargs,  # 接收额外的关键字参数
    ):
        r"""  # 开始文档字符串,描述该方法的功能和用法
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.  # 将 LoRA 参数融合到对应块的原始参数中

        <Tip warning={true}>  # 开始警告提示框
        This is an experimental API.  # 说明这是一个实验性 API
        </Tip>  # 结束警告提示框

        Args:  # 开始参数说明
            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.  # 可注入 LoRA 的组件列表
            lora_scale (`float`, defaults to 1.0):  # LoRA 参数对输出影响的比例
                Controls how much to influence the outputs with the LoRA parameters.  # 控制 LoRA 参数对输出的影响程度
            safe_fusing (`bool`, defaults to `False`):  # 是否在融合前检查权重是否为 NaN
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.  # 如果值为 NaN 则不进行融合
            adapter_names (`List[str]`, *optional*):  # 可选的适配器名称
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.  # 如果未传入,默认融合所有活动适配器

        Example:  # 示例部分的开始
        ```py  # Python 代码块开始
        from diffusers import DiffusionPipeline  # 导入 DiffusionPipeline 模块
        import torch  # 导入 PyTorch 库

        pipeline = DiffusionPipeline.from_pretrained(  # 从预训练模型创建管道
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16  # 使用 float16 类型的模型
        ).to("cuda")  # 将管道移动到 GPU
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")  # 加载 LoRA 权重
        pipeline.fuse_lora(lora_scale=0.7)  # 融合 LoRA,影响比例为 0.7
        ```  # Python 代码块结束
        """
        super().fuse_lora(  # 调用父类的 fuse_lora 方法
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names  # 将参数传递给父类方法
        )

    def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):  # 定义 unfuse_lora 方法,带有默认组件
        r"""  # 开始文档字符串,描述该方法的功能和用法
        Reverses the effect of  # 反转 fuse_lora 方法的效果
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).  # 提供 fuse_lora 的链接

        <Tip warning={true}>  # 开始警告提示框
        This is an experimental API.  # 说明这是一个实验性 API
        </Tip>  # 结束警告提示框

        Args:  # 开始参数说明
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.  # 可注入 LoRA 的组件列表,用于反融合
            unfuse_unet (`bool`, defaults to `True`):  # 是否反融合 UNet 的 LoRA 参数
                Whether to unfuse the UNet LoRA parameters.  # 反融合 UNet LoRA 参数的选项
            unfuse_text_encoder (`bool`, defaults to `True`):  # 是否反融合文本编码器的 LoRA 参数
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the  # 反融合文本编码器的 LoRA 参数的选项
                LoRA parameters then it won't have any effect.  # 如果文本编码器未被修改,则不会有任何效果
        """  # 结束文档字符串
        super().unfuse_lora(components=components)  # 调用父类的 unfuse_lora 方法,并传递组件参数
# 定义一个类,混合自 LoraBaseMixin,用于加载 LoRA 层到 Stable Diffusion XL
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
    r"""
    将 LoRA 层加载到 Stable Diffusion XL 的 [`UNet2DConditionModel`]、
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) 和
    [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection) 中。
    """

    # 定义可以加载 LoRA 的模块名列表
    _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]
    # 指定 UNET 的名称
    unet_name = UNET_NAME
    # 指定文本编码器的名称
    text_encoder_name = TEXT_ENCODER_NAME

    # 定义一个加载 LoRA 权重的方法
    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        **kwargs,
    ):
    @classmethod
    # 验证传入的 Hugging Face Hub 参数
    @validate_hf_hub_args
    # 定义一个类方法,获取 LoRA 状态字典
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
    @classmethod
    # 定义一个类方法,用于将 LoRA 加载到 UNET 中
    # 定义一个类方法,用于将 LoRA 层加载到 UNet 模型中
    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
        # 文档字符串,描述方法的作用和参数
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.
    
        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            unet (`UNet2DConditionModel`):
                The UNet model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
        """
        # 检查是否启用 PEFT 后端,若未启用则抛出异常
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
    
        # 获取 state_dict 中的所有键
        keys = list(state_dict.keys())
        # 检查所有键是否都以 text_encoder_name 开头
        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
        # 如果不是仅有文本编码器
        if not only_text_encoder:
            # 记录正在加载的 UNet 名称
            logger.info(f"Loading {cls.unet_name}.")
            # 加载与 UNet 对应的层
            unet.load_attn_procs(
                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
            )
    
        @classmethod
        # 从 diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder 复制的方法
        def load_lora_into_text_encoder(
            cls,
            state_dict,
            network_alphas,
            text_encoder,
            prefix=None,
            lora_scale=1.0,
            adapter_name=None,
            _pipeline=None,
        @classmethod
        # 定义一个类方法,用于保存 LoRA 权重
        def save_lora_weights(
            cls,
            save_directory: Union[str, os.PathLike],
            unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
            text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
            text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
            is_main_process: bool = True,
            weight_name: str = None,
            save_function: Callable = None,
            safe_serialization: bool = True,
    ):
        r"""
        # 文档字符串,描述保存 UNet 和文本编码器对应的 LoRA 参数的功能

        Arguments:
            # 保存 LoRA 参数的目录,若不存在则创建
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            # UNet 对应的 LoRA 层的状态字典
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `unet`.
            # 文本编码器对应的 LoRA 层的状态字典,必须显式传入
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            # 第二个文本编码器对应的 LoRA 层的状态字典,必须显式传入
            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            # 表示调用此函数的进程是否为主进程,主要用于分布式训练
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            # 保存状态字典的函数,分布式训练时可替换 `torch.save`
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            # 是否使用 safetensors 保存模型,默认为 True
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        # 初始化状态字典,用于存储 LoRA 参数
        state_dict = {}

        # 如果没有传入任何 LoRA 层,则抛出异常
        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                # 报告至少需要传入一个 LoRA 层
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        # 如果有 UNet 的 LoRA 层,则打包并更新状态字典
        if unet_lora_layers:
            state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))

        # 如果有文本编码器的 LoRA 层,则打包并更新状态字典
        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))

        # 如果有第二个文本编码器的 LoRA 层,则打包并更新状态字典
        if text_encoder_2_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        # 写入 LoRA 层参数,调用保存函数
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )
    # 定义一个方法 fuse_lora,用于将 LoRA 参数融合到相应模块的原始参数中
    def fuse_lora(
        # 方法参数:可注入 LoRA 的组件列表,默认为 ["unet", "text_encoder", "text_encoder_2"]
        self,
        components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
        # LoRA 权重影响输出的程度,默认为 1.0
        lora_scale: float = 1.0,
        # 是否在融合前检查权重是否为 NaN,默认为 False
        safe_fusing: bool = False,
        # 可选参数,指定用于融合的适配器名称
        adapter_names: Optional[List[str]] = None,
        # 允许传入额外的关键字参数
        **kwargs,
    ):
        r"""
        将 LoRA 参数融合到相应模块的原始参数中。

        <Tip warning={true}>

        这是一个实验性 API。

        </Tip>

        Args:
            components: (`List[str]`): 需要融合 LoRA 的组件列表。
            lora_scale (`float`, defaults to 1.0):
                控制 LoRA 参数对输出的影响程度。
            safe_fusing (`bool`, defaults to `False`):
                在融合前检查权重是否为 NaN 的开关。
            adapter_names (`List[str]`, *optional*):
                用于融合的适配器名称。如果未传入,则将融合所有活动适配器。

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        # 加载 LoRA 权重
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        # 融合 LoRA 参数,影响程度为 0.7
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        # 调用父类的 fuse_lora 方法,传入组件、LoRA 权重、检查 NaN 的选项和适配器名称
        super().fuse_lora(
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
        )

    # 定义一个方法 unfuse_lora,用于逆转 LoRA 参数的融合效果
    def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
        r"""
        逆转
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora) 的效果。

        <Tip warning={true}>

        这是一个实验性 API。

        </Tip>

        Args:
            components (`List[str]`): 需要从中解融合 LoRA 的组件列表。
            unfuse_unet (`bool`, defaults to `True`): 是否解融合 UNet 的 LoRA 参数。
            unfuse_text_encoder (`bool`, defaults to `True`):
                是否解融合文本编码器的 LoRA 参数。如果文本编码器没有被 LoRA 参数修补,则不会有任何效果。
        """
        # 调用父类的 unfuse_lora 方法,传入组件和其他参数
        super().unfuse_lora(components=components)
# 定义一个混合类 SD3LoraLoaderMixin,继承自 LoraBaseMixin
class SD3LoraLoaderMixin(LoraBaseMixin):
    r"""
    加载 LoRA 层到 [`SD3Transformer2DModel`]、
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) 和
    [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection)。

    特定于 [`StableDiffusion3Pipeline`]。
    """

    # 可加载 LoRA 的模块列表
    _lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"]
    # 转换器名称,使用预定义的常量
    transformer_name = TRANSFORMER_NAME
    # 文本编码器名称,使用预定义的常量
    text_encoder_name = TEXT_ENCODER_NAME

    # 类方法,验证 Hugging Face Hub 参数
    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        # 加载 LoRA 权重的方法,接收模型名称或路径或字典
        def load_lora_weights(
            self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
        ):
            # 类方法,从 diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin 中复制的加载文本编码器的方法
            @classmethod
            def load_lora_into_text_encoder(
                cls,
                state_dict,
                network_alphas,
                text_encoder,
                prefix=None,
                lora_scale=1.0,
                adapter_name=None,
                _pipeline=None,
            ):
                # 类方法,保存 LoRA 权重到指定目录
                def save_lora_weights(
                    cls,
                    save_directory: Union[str, os.PathLike],
                    transformer_lora_layers: Dict[str, torch.nn.Module] = None,
                    text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
                    text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
                    is_main_process: bool = True,
                    weight_name: str = None,
                    save_function: Callable = None,
                    safe_serialization: bool = True,
                ):
    ):
        r"""
        保存与 UNet 和文本编码器对应的 LoRA 参数。

        参数:
            save_directory (`str` 或 `os.PathLike`):
                保存 LoRA 参数的目录。如果不存在,将创建该目录。
            transformer_lora_layers (`Dict[str, torch.nn.Module]` 或 `Dict[str, torch.Tensor]`):
                与 `transformer` 相关的 LoRA 层的状态字典。
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` 或 `Dict[str, torch.Tensor]`):
                与 `text_encoder` 相关的 LoRA 层的状态字典。必须显式传递文本编码器的 LoRA 状态字典,因为它来自 🤗 Transformers。
            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` 或 `Dict[str, torch.Tensor]`):
                与 `text_encoder_2` 相关的 LoRA 层的状态字典。必须显式传递文本编码器的 LoRA 状态字典,因为它来自 🤗 Transformers。
            is_main_process (`bool`, *可选*, 默认值为 `True`):
                调用此函数的进程是否为主进程。在分布式训练期间非常有用,您需要在所有进程上调用此函数。在这种情况下,只有在主进程上设置 `is_main_process=True` 以避免竞争条件。
            save_function (`Callable`):
                用于保存状态字典的函数。在分布式训练时,当您需要将 `torch.save` 替换为其他方法时非常有用。可以通过环境变量 `DIFFUSERS_SAVE_MODE` 进行配置。
            safe_serialization (`bool`, *可选*, 默认值为 `True`):
                是否使用 `safetensors` 保存模型,或使用传统的 PyTorch 方法 `pickle`。
        """
        # 初始化一个空字典,用于存储状态字典
        state_dict = {}

        # 检查是否至少传递了一个 LoRA 层的状态字典,如果没有则引发错误
        if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "必须至少传递一个 `transformer_lora_layers`、`text_encoder_lora_layers` 或 `text_encoder_2_lora_layers`。"
            )

        # 如果传递了 transformer_lora_layers,则将其打包并更新状态字典
        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        # 如果传递了 text_encoder_lora_layers,则将其打包并更新状态字典
        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))

        # 如果传递了 text_encoder_2_lora_layers,则将其打包并更新状态字典
        if text_encoder_2_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        # 保存模型
        cls.write_lora_layers(
            state_dict=state_dict,  # 要保存的状态字典
            save_directory=save_directory,  # 保存目录
            is_main_process=is_main_process,  # 主进程标志
            weight_name=weight_name,  # 权重名称
            save_function=save_function,  # 保存函数
            safe_serialization=safe_serialization,  # 安全序列化标志
        )
    # 定义一个方法,用于将 LoRA 参数融合到原始参数中
    def fuse_lora(
        # 方法的参数列表
        self,
        # 可选组件列表,默认包括 "transformer"、"text_encoder" 和 "text_encoder_2"
        components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
        # LoRA 参数影响输出的比例,默认为 1.0
        lora_scale: float = 1.0,
        # 安全融合标志,默认为 False
        safe_fusing: bool = False,
        # 可选适配器名称列表,默认为 None
        adapter_names: Optional[List[str]] = None,
        # 其他关键字参数
        **kwargs,
    ):
        # 方法文档字符串,描述该方法的功能和参数
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        # 调用父类的方法进行 LoRA 参数融合,传递相关参数
        super().fuse_lora(
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
        )

    # 定义一个方法,用于将 LoRA 参数从组件中移除
    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
        # 方法文档字符串,描述该方法的功能和参数
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
            unfuse_text_encoder (`bool`, defaults to `True`):
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
        # 调用父类的方法进行 LoRA 参数移除,传递组件参数
        super().unfuse_lora(components=components)
# 定义一个混合类,用于加载 LoRA 层,继承自 LoraBaseMixin
class FluxLoraLoaderMixin(LoraBaseMixin):
    r"""
    加载 LoRA 层到 [`FluxTransformer2DModel`] 和 [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)。
    
    特定于 [`StableDiffusion3Pipeline`]。
    """

    # 可加载的 LoRA 模块名称列表
    _lora_loadable_modules = ["transformer", "text_encoder"]
    # Transformer 的名称
    transformer_name = TRANSFORMER_NAME
    # 文本编码器的名称
    text_encoder_name = TEXT_ENCODER_NAME

    # 类方法,验证 Hugging Face Hub 参数,并获取 LoRA 状态字典
    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        # 预训练模型的名称、路径或字典,类型可以是字符串或字典
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        # 是否返回 alpha 值,默认为 False
        return_alphas: bool = False,
        # 其他关键字参数
        **kwargs,
    ):
        # 方法体缺失,需实现具体逻辑
        pass

    # 实例方法,加载 LoRA 权重
    def load_lora_weights(
        self, 
        # 预训练模型的名称、路径或字典,类型可以是字符串或字典
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], 
        # 可选的适配器名称
        adapter_name=None, 
        # 其他关键字参数
        **kwargs
    ):
        # 方法体缺失,需实现具体逻辑
        pass
    ):
        """
        加载指定的 LoRA 权重到 `self.transformer` 和 `self.text_encoder`。

        所有关键字参数会转发给 `self.lora_state_dict`。

        详见 [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] 如何加载状态字典。

        详见 [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] 如何将状态字典加载到 `self.transformer`。

        参数:
            pretrained_model_name_or_path_or_dict (`str` 或 `os.PathLike` 或 `dict`):
                详见 [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]。
            kwargs (`dict`, *可选*):
                详见 [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]。
            adapter_name (`str`, *可选*):
                用于引用加载的适配器模型的名称。如果未指定,将使用
                `default_{i}`,其中 i 是加载的适配器总数。
        """
        # 检查是否启用 PEFT 后端,若未启用则抛出错误
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # 如果传入的是字典,则复制它以避免就地修改
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # 首先,确保检查点是兼容的,并可以成功加载
        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
        )

        # 验证状态字典的格式是否正确,确保包含 "lora" 或 "dora_scale"
        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        # 将状态字典加载到 transformer 中
        self.load_lora_into_transformer(
            state_dict,
            network_alphas=network_alphas,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            _pipeline=self,
        )

        # 从状态字典中提取与 text_encoder 相关的部分
        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        # 如果提取的字典不为空,则加载到 text_encoder 中
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
            )

    @classmethod
    @classmethod
    # 从 diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder 复制的
    # 定义一个类方法,用于将 Lora 模型加载到文本编码器中
        def load_lora_into_text_encoder(
            cls,  # 类本身
            state_dict,  # 状态字典,包含模型权重
            network_alphas,  # 网络中的缩放因子
            text_encoder,  # 文本编码器实例
            prefix=None,  # 可选的前缀,用于命名
            lora_scale=1.0,  # Lora 缩放因子,默认为 1.0
            adapter_name=None,  # 可选的适配器名称
            _pipeline=None,  # 可选的管道参数,用于进一步处理
        @classmethod  # 指定这是一个类方法
        # 从 diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights 拷贝而来,将 unet 替换为 transformer
        def save_lora_weights(
            cls,  # 类本身
            save_directory: Union[str, os.PathLike],  # 保存权重的目录
            transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,  # transformer 的 Lora 层
            text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,  # 文本编码器的 Lora 层
            is_main_process: bool = True,  # 标识当前是否为主进程
            weight_name: str = None,  # 权重文件的名称
            save_function: Callable = None,  # 自定义保存函数
            safe_serialization: bool = True,  # 是否安全序列化,默认为 True
    ):
        r"""  # 定义文档字符串,描述此函数的功能及参数
        Save the LoRA parameters corresponding to the UNet and text encoder.  # 描述保存LoRA参数的功能

        Arguments:  # 开始列出函数的参数
            save_directory (`str` or `os.PathLike`):  # 参数:保存LoRA参数的目录,类型为字符串或路径类
                Directory to save LoRA parameters to. Will be created if it doesn't exist.  # 描述:如果目录不存在,将创建该目录
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):  # 参数:与transformer对应的LoRA层的状态字典
                State dict of the LoRA layers corresponding to the `transformer`.  # 描述:说明参数的作用
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):  # 参数:与text_encoder对应的LoRA层的状态字典
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text  # 描述:说明此参数必须提供,来自🤗 Transformers
                encoder LoRA state dict because it comes from 🤗 Transformers.  # 继续描述参数的来源
            is_main_process (`bool`, *optional*, defaults to `True`):  # 参数:指示当前进程是否为主进程,类型为布尔值
                Whether the process calling this is the main process or not. Useful during distributed training and you  # 描述:用于分布式训练时判断主进程
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main  # 进一步说明如何使用此参数
                process to avoid race conditions.  # 描述:避免竞争条件
            save_function (`Callable`):  # 参数:用于保存状态字典的函数,类型为可调用对象
                The function to use to save the state dictionary. Useful during distributed training when you need to  # 描述:在分布式训练中,可能需要替换默认的保存方法
                replace `torch.save` with another method. Can be configured with the environment variable  # 说明如何配置此参数
                `DIFFUSERS_SAVE_MODE`.  # 提供环境变量名称
            safe_serialization (`bool`, *optional*, defaults to `True`):  # 参数:指示是否使用安全序列化保存模型,类型为布尔值
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.  # 描述:选择保存模型的方式
        """
        state_dict = {}  # 初始化一个空字典,用于存储状态字典

        if not (transformer_lora_layers or text_encoder_lora_layers):  # 检查是否至少有一个LoRA层字典传入
            raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")  # 如果没有,抛出异常

        if transformer_lora_layers:  # 如果存在transformer的LoRA层字典
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))  # 打包LoRA权重并更新状态字典

        if text_encoder_lora_layers:  # 如果存在text_encoder的LoRA层字典
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))  # 打包LoRA权重并更新状态字典

        # Save the model  # 保存模型的注释
        cls.write_lora_layers(  # 调用类方法以写入LoRA层
            state_dict=state_dict,  # 传入状态字典
            save_directory=save_directory,  # 传入保存目录
            is_main_process=is_main_process,  # 传入主进程标志
            weight_name=weight_name,  # 传入权重名称
            save_function=save_function,  # 传入保存函数
            safe_serialization=safe_serialization,  # 传入安全序列化标志
        )

    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer  # 注释说明此方法的来源和修改
    def fuse_lora(  # 定义fuse_lora方法
        self,  # 方法的第一个参数是实例自身
        components: List[str] = ["transformer", "text_encoder"],  # 参数:要融合的组件列表,默认包含transformer和text_encoder
        lora_scale: float = 1.0,  # 参数:LoRA的缩放因子,默认值为1.0
        safe_fusing: bool = False,  # 参数:指示是否安全融合,默认值为False
        adapter_names: Optional[List[str]] = None,  # 参数:可选的适配器名称列表,默认为None
        **kwargs,  # 可接收其他关键字参数
    ):
        r""" 
        # 文档字符串,说明此函数的作用和用法
        
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
        # 将 LoRA 参数融合到对应块的原始参数中

        <Tip warning={true}>
        # 警告提示,说明这是一个实验性 API

        This is an experimental API.
        # 这是一项实验性 API

        </Tip>

        Args:
            components: (`List[str]`): 
            # 参数说明,接受一个字符串列表,表示要融合 LoRA 的组件
            
            lora_scale (`float`, defaults to 1.0):
            # 参数说明,控制 LoRA 参数对输出的影响程度
            
                Controls how much to influence the outputs with the LoRA parameters.
                # 控制 LoRA 参数对输出的影响程度
            
            safe_fusing (`bool`, defaults to `False`):
            # 参数说明,是否在融合之前检查权重中是否有 NaN 值
            
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
                # 是否在融合之前检查权重的 NaN 值,如果存在则不进行融合
            
            adapter_names (`List[str]`, *optional*):
            # 参数说明,可选的适配器名称列表,用于融合
            
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
                # 用于融合的适配器名称列表,如果未传入,则将融合所有活动适配器

        Example:
        # 示例代码,展示如何使用该 API

        ```py
        from diffusers import DiffusionPipeline
        # 导入 DiffusionPipeline 类
        
        import torch
        # 导入 PyTorch 库

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        # 从预训练模型创建管道,并将其移动到 CUDA 设备上
        
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        # 加载 LoRA 权重到管道中
        
        pipeline.fuse_lora(lora_scale=0.7)
        # 融合 LoRA 参数,设置影响程度为 0.7
        ```
        """
        super().fuse_lora(
            # 调用父类的 fuse_lora 方法,将相关参数传递给它
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
        )

    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
        r"""
        # 方法文档字符串,说明此方法的作用和用法
        
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
        # 反转 fuse_lora 方法的效果

        <Tip warning={true}>
        # 警告提示,说明这是一个实验性 API

        This is an experimental API.
        # 这是一项实验性 API

        </Tip>

        Args:
            components (`List[str]`): 
            # 参数说明,接受一个字符串列表,表示要从中解除 LoRA 的组件
            
            List of LoRA-injectable components to unfuse LoRA from.
            # 要从中解除 LoRA 的组件列表
        """
        super().unfuse_lora(components=components)
        # 调用父类的 unfuse_lora 方法,将相关参数传递给它
# 这里我们从 `StableDiffusionLoraLoaderMixin` 子类化,因为 Amused 最初依赖于该类提供 LoRA 支持
class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    # 可加载的 LoRA 模块列表
    _lora_loadable_modules = ["transformer", "text_encoder"]
    # 定义变换器的名称
    transformer_name = TRANSFORMER_NAME
    # 定义文本编码器的名称
    text_encoder_name = TEXT_ENCODER_NAME

    @classmethod
    @classmethod
    # 从 diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin 中复制的方法,用于将 LoRA 加载到文本编码器中
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
    @classmethod
    # 定义保存 LoRA 权重的方法
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    # 定义一个方法,保存与 UNet 和文本编码器对应的 LoRA 参数
        ):
            r""" 
            保存与 UNet 和文本编码器对应的 LoRA 参数。
    
            参数:
                save_directory (`str` 或 `os.PathLike`):
                    保存 LoRA 参数的目录。如果目录不存在,将被创建。
                unet_lora_layers (`Dict[str, torch.nn.Module]` 或 `Dict[str, torch.Tensor]`):
                    与 `unet` 相关的 LoRA 层的状态字典。
                text_encoder_lora_layers (`Dict[str, torch.nn.Module]` 或 `Dict[str, torch.Tensor]`):
                    与 `text_encoder` 相关的 LoRA 层的状态字典。必须明确传递文本编码器的 LoRA 状态字典,因为它来自 🤗 Transformers。
                is_main_process (`bool`, *可选*, 默认值为 `True`):
                    调用此函数的过程是否为主过程。在分布式训练期间,您需要在所有进程上调用此函数。在这种情况下,只有在主过程中将 `is_main_process=True`,以避免竞争条件。
                save_function (`Callable`):
                    用于保存状态字典的函数。在分布式训练时,需要用其他方法替换 `torch.save`。可以通过环境变量 `DIFFUSERS_SAVE_MODE` 进行配置。
                safe_serialization (`bool`, *可选*, 默认值为 `True`):
                    是否使用 `safetensors` 或传统的 PyTorch 方式通过 `pickle` 保存模型。
            """
            # 初始化状态字典,用于存储 LoRA 参数
            state_dict = {}
    
            # 检查至少传递一个 LoRA 层的状态字典
            if not (transformer_lora_layers or text_encoder_lora_layers):
                # 如果两个都没有,抛出错误
                raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
    
            # 如果有 transformer LoRA 层,更新状态字典
            if transformer_lora_layers:
                state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
    
            # 如果有文本编码器 LoRA 层,更新状态字典
            if text_encoder_lora_layers:
                state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
    
            # 保存模型的过程
            cls.write_lora_layers(
                # 传入状态字典
                state_dict=state_dict,
                # 保存目录
                save_directory=save_directory,
                # 是否为主进程
                is_main_process=is_main_process,
                # 权重名称
                weight_name=weight_name,
                # 保存函数
                save_function=save_function,
                # 是否使用安全序列化
                safe_serialization=safe_serialization,
            )
# 定义一个名为 LoraLoaderMixin 的类,继承自 StableDiffusionLoraLoaderMixin
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    # 初始化方法,接收可变位置和关键字参数
    def __init__(self, *args, **kwargs):
        # 设置弃用警告信息,提示用户该类将在未来版本中移除
        deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
        # 调用 deprecate 函数,记录该类的弃用信息
        deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
        # 调用父类的初始化方法,传递位置和关键字参数
        super().__init__(*args, **kwargs)

.\diffusers\loaders\peft.py

# 指定文件的编码格式为 UTF-8
# coding=utf-8
# 版权声明,表示此代码归 2024 The HuggingFace Inc. 团队所有
# Copyright 2024 The HuggingFace Inc. team.
#
# 根据 Apache 许可证第 2.0 版(“许可证”)授权;
# 除非遵循许可证,否则不得使用此文件。
# 可在以下网址获取许可证副本:
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,否则软件在“按原样”基础上分发,
# 不附带任何明示或暗示的担保或条件。
# 有关许可证下具体权利和限制的更多信息,请查看许可证。
import inspect  # 导入 inspect 模块,用于获取对象的信息
from functools import partial  # 从 functools 模块导入 partial 函数,用于部分应用
from typing import Dict, List, Optional, Union  # 导入类型注解,便于类型提示

from ..utils import (  # 从父目录的 utils 模块导入多个工具函数
    MIN_PEFT_VERSION,  # 最小 PEFT 版本常量
    USE_PEFT_BACKEND,  # 使用 PEFT 后端的标志
    check_peft_version,  # 检查 PEFT 版本的函数
    delete_adapter_layers,  # 删除适配器层的函数
    is_peft_available,  # 检查 PEFT 是否可用的函数
    set_adapter_layers,  # 设置适配器层的函数
    set_weights_and_activate_adapters,  # 设置权重并激活适配器的函数
)
from .unet_loader_utils import _maybe_expand_lora_scales  # 从当前目录的 unet_loader_utils 模块导入函数

# 定义适配器缩放函数映射字典,以模型名称为键,缩放函数为值
_SET_ADAPTER_SCALE_FN_MAPPING = {
    "UNet2DConditionModel": _maybe_expand_lora_scales,  # UNet2DConditionModel 使用特定缩放函数
    "UNetMotionModel": _maybe_expand_lora_scales,  # UNetMotionModel 使用特定缩放函数
    "SD3Transformer2DModel": lambda model_cls, weights: weights,  # SD3Transformer2DModel 直接返回权重
    "FluxTransformer2DModel": lambda model_cls, weights: weights,  # FluxTransformer2DModel 直接返回权重
}

class PeftAdapterMixin:  # 定义 PeftAdapterMixin 类
    """
    包含用于加载和使用适配器权重的所有函数,该函数在 PEFT 库中受支持。有关适配器的更多详细信息以及如何将其注入基础模型,请查阅 PEFT
    [文档](https://huggingface.co/docs/peft/index)。

    安装最新版本的 PEFT,并使用此混入以:

    - 在模型中附加新适配器。
    - 附加多个适配器并逐步激活/停用它们。
    - 激活/停用模型中的所有适配器。
    - 获取活动适配器的列表。
    """

    _hf_peft_config_loaded = False  # 初始化标志,指示 PEFT 配置是否已加载

    def set_adapters(  # 定义设置适配器的方法
        self,
        adapter_names: Union[List[str], str],  # 适配器名称,可以是字符串或字符串列表
        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,  # 可选的权重参数
    ):
        """
        设置当前活跃的适配器,以便在 UNet 中使用。

        参数:
            adapter_names (`List[str]` 或 `str`):
                要使用的适配器名称。
            adapter_weights (`Union[List[float], float]`, *可选*):
                与 UNet 一起使用的适配器权重。如果为 `None`,则所有适配器的权重设置为 `1.0`。

        示例:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
        ```py
        """
        # 检查是否启用 PEFT 后端,如果未启用则引发错误
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `set_adapters()`.")

        # 如果 adapter_names 是字符串,则将其转换为列表
        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        # 将权重扩展为列表,每个适配器一个条目
        # 例如对于 2 个适配器:  [{...}, 7] -> [7,7] ; None -> [None, None]
        if not isinstance(weights, list):
            weights = [weights] * len(adapter_names)

        # 检查适配器名称和权重的长度是否匹配
        if len(adapter_names) != len(weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
            )

        # 将 None 值设置为默认的 1.0
        # 例如: [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
        weights = [w if w is not None else 1.0 for w in weights]

        # 扩展权重以适配特定适配器的要求
        # 例如: [{...}, 7] -> [{expanded dict...}, 7]
        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
        weights = scale_expansion_fn(self, weights)

        # 设置权重并激活适配器
        set_weights_and_activate_adapters(self, adapter_names, weights)
    # 定义一个添加适配器的方法,接受适配器配置和可选的适配器名称
    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
        r"""
        向当前模型添加一个新的适配器用于训练。如果未传递适配器名称,将为适配器分配默认名称,以遵循 PEFT 库的约定。

        如果您不熟悉适配器和 PEFT 方法,建议您查看 PEFT 的
        [文档](https://huggingface.co/docs/peft)。

        参数:
            adapter_config (`[~peft.PeftConfig]`):
                要添加的适配器的配置;支持的适配器包括非前缀调整和适应提示方法。
            adapter_name (`str`, *可选*, 默认为 `"default"`):
                要添加的适配器名称。如果未传递名称,将为适配器分配默认名称。
        """
        # 检查 PEFT 的版本是否符合最低要求
        check_peft_version(min_version=MIN_PEFT_VERSION)

        # 如果 PEFT 不可用,则引发 ImportError 异常
        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        # 从 peft 模块导入 PeftConfig 和 inject_adapter_in_model
        from peft import PeftConfig, inject_adapter_in_model

        # 如果尚未加载 HF PEFT 配置,则设置标志为 True
        if not self._hf_peft_config_loaded:
            self._hf_peft_config_loaded = True
        # 如果适配器名称已存在,则引发 ValueError 异常
        elif adapter_name in self.peft_config:
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

        # 检查 adapter_config 是否为 PeftConfig 的实例
        if not isinstance(adapter_config, PeftConfig):
            raise ValueError(
                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
            )

        # 将适配器配置的基础模型名称或路径设置为 None,因为加载逻辑由 load_lora_layers 或 StableDiffusionLoraLoaderMixin 处理
        adapter_config.base_model_name_or_path = None
        # 将适配器注入到模型中
        inject_adapter_in_model(adapter_config, self, adapter_name)
        # 设置适配器名称
        self.set_adapter(adapter_name)
    # 设置特定适配器,强制模型仅使用该适配器并禁用其他适配器
    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
        """
        设置特定适配器,强制模型仅使用该适配器并禁用其他适配器。
    
        如果您不熟悉适配器和 PEFT 方法,我们邀请您阅读 PEFT 的更多信息
        [文档](https://huggingface.co/docs/peft)。
    
        参数:
            adapter_name (Union[str, List[str]])):
                要设置的适配器名称或适配器名称列表(如果是单个适配器)。
        """
        # 检查 PEFT 版本是否符合最低要求
        check_peft_version(min_version=MIN_PEFT_VERSION)
    
        # 如果尚未加载 HF PEFT 配置,则抛出错误
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")
    
        # 如果适配器名称是字符串,则将其转换为列表
        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]
    
        # 计算缺失的适配器名称
        missing = set(adapter_name) - set(self.peft_config)
        # 如果有缺失的适配器,则抛出错误
        if len(missing) > 0:
            raise ValueError(
                f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
                f" current loaded adapters are: {list(self.peft_config.keys())}"
            )
    
        # 从 peft.tuners.tuners_utils 导入 BaseTunerLayer
        from peft.tuners.tuners_utils import BaseTunerLayer
    
        # 初始化适配器设置标志
        _adapters_has_been_set = False
    
        # 遍历模型中命名的模块
        for _, module in self.named_modules():
            # 如果模块是 BaseTunerLayer 的实例
            if isinstance(module, BaseTunerLayer):
                # 如果模块具有 set_adapter 方法,则调用该方法
                if hasattr(module, "set_adapter"):
                    module.set_adapter(adapter_name)
                # 如果没有 set_adapter 方法且适配器名称列表不为1,抛出错误
                elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
                    raise ValueError(
                        "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
                    )
                # 否则,将活动适配器设置为适配器名称
                else:
                    module.active_adapter = adapter_name
                # 标记适配器已设置
                _adapters_has_been_set = True
    
        # 如果没有成功设置适配器,则抛出错误
        if not _adapters_has_been_set:
            raise ValueError(
                "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
            )
    # 定义一个方法来禁用模型中所有附加的适配器,仅使用基础模型进行推理
    def disable_adapters(self) -> None:
        r"""
        禁用所有附加到模型的适配器,并回退到仅使用基础模型进行推理。
    
        如果您对适配器和 PEFT 方法不熟悉,我们邀请您在 PEFT
        [文档](https://huggingface.co/docs/peft) 中了解更多信息。
        """
        # 检查 PEFT 版本,确保满足最低版本要求
        check_peft_version(min_version=MIN_PEFT_VERSION)
    
        # 如果没有加载 HF PEFT 配置,则抛出异常
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")
    
        # 从 tuners_utils 导入基础调优层
        from peft.tuners.tuners_utils import BaseTunerLayer
    
        # 遍历模型中所有命名的模块
        for _, module in self.named_modules():
            # 如果模块是基础调优层的实例
            if isinstance(module, BaseTunerLayer):
                # 如果模块具有 enable_adapters 属性,则禁用适配器
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=False)
                else:
                    # 支持旧版 PEFT
                    module.disable_adapters = True
    
    # 定义一个方法来启用附加到模型的适配器
    def enable_adapters(self) -> None:
        """
        启用附加到模型的适配器。模型使用 `self.active_adapters()` 检索要启用的适配器列表。
    
        如果您对适配器和 PEFT 方法不熟悉,我们邀请您在 PEFT
        [文档](https://huggingface.co/docs/peft) 中了解更多信息。
        """
        # 检查 PEFT 版本,确保满足最低版本要求
        check_peft_version(min_version=MIN_PEFT_VERSION)
    
        # 如果没有加载 HF PEFT 配置,则抛出异常
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")
    
        # 从 tuners_utils 导入基础调优层
        from peft.tuners.tuners_utils import BaseTunerLayer
    
        # 遍历模型中所有命名的模块
        for _, module in self.named_modules():
            # 如果模块是基础调优层的实例
            if isinstance(module, BaseTunerLayer):
                # 如果模块具有 enable_adapters 属性,则启用适配器
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=True)
                else:
                    # 支持旧版 PEFT
                    module.disable_adapters = False
    
    # 定义一个方法来获取当前活动的适配器列表
    def active_adapters(self) -> List[str]:
        """
        获取模型当前活动的适配器列表。
    
        如果您对适配器和 PEFT 方法不熟悉,我们邀请您在 PEFT
        [文档](https://huggingface.co/docs/peft) 中了解更多信息。
        """
        # 检查 PEFT 版本,确保满足最低版本要求
        check_peft_version(min_version=MIN_PEFT_VERSION)
    
        # 如果 PEFT 不可用,则抛出导入错误
        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
    
        # 如果没有加载 HF PEFT 配置,则抛出异常
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")
    
        # 从 tuners_utils 导入基础调优层
        from peft.tuners.tuners_utils import BaseTunerLayer
    
        # 遍历模型中所有命名的模块
        for _, module in self.named_modules():
            # 如果模块是基础调优层的实例
            if isinstance(module, BaseTunerLayer):
                # 返回活动适配器
                return module.active_adapter
    # 定义融合 LoRA 的方法,允许传入比例和安全融合标志
        def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
            # 检查是否使用 PEFT 后端
            if not USE_PEFT_BACKEND:
                raise ValueError("PEFT backend is required for `fuse_lora()`.")
    
            # 设置 LoRA 的比例
            self.lora_scale = lora_scale
            # 设置安全融合标志
            self._safe_fusing = safe_fusing
            # 应用融合方法,部分应用于指定的适配器名称
            self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
    
        # 定义具体的 LoRA 融合应用方法
        def _fuse_lora_apply(self, module, adapter_names=None):
            # 从 PEFT 库导入基础调优层
            from peft.tuners.tuners_utils import BaseTunerLayer
    
            # 设置合并参数,包含安全合并标志
            merge_kwargs = {"safe_merge": self._safe_fusing}
    
            # 检查模块是否为基础调优层
            if isinstance(module, BaseTunerLayer):
                # 如果 LoRA 比例不为 1.0,则调整比例
                if self.lora_scale != 1.0:
                    module.scale_layer(self.lora_scale)
    
                # 检查合并方法是否支持适配器名称参数
                supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
                if "adapter_names" in supported_merge_kwargs:
                    merge_kwargs["adapter_names"] = adapter_names
                elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                    # 抛出错误提示 PEFT 版本不支持适配器名称
                    raise ValueError(
                        "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
                        " to the latest version of PEFT. `pip install -U peft`"
                    )
    
                # 调用合并方法,传入合并参数
                module.merge(**merge_kwargs)
    
        # 定义解除 LoRA 融合的方法
        def unfuse_lora(self):
            # 检查是否使用 PEFT 后端
            if not USE_PEFT_BACKEND:
                raise ValueError("PEFT backend is required for `unfuse_lora()`.")
            # 应用解除融合方法
            self.apply(self._unfuse_lora_apply)
    
        # 定义具体的 LoRA 解除融合应用方法
        def _unfuse_lora_apply(self, module):
            # 从 PEFT 库导入基础调优层
            from peft.tuners.tuners_utils import BaseTunerLayer
    
            # 检查模块是否为基础调优层
            if isinstance(module, BaseTunerLayer):
                # 调用解除合并方法
                module.unmerge()
    
        # 定义卸载 LoRA 的方法
        def unload_lora(self):
            # 检查是否使用 PEFT 后端
            if not USE_PEFT_BACKEND:
                raise ValueError("PEFT backend is required for `unload_lora()`.")
    
            # 从工具库导入递归删除 PEFT 层的方法
            from ..utils import recurse_remove_peft_layers
    
            # 递归删除 PEFT 层
            recurse_remove_peft_layers(self)
            # 如果存在 PEFT 配置,则删除该属性
            if hasattr(self, "peft_config"):
                del self.peft_config
    
        # 定义禁用 LoRA 的方法
        def disable_lora(self):
            """
            禁用底层模型的活动 LoRA 层。
    
            示例:
    
            ```py
            from diffusers import AutoPipelineForText2Image
            import torch
    
            pipeline = AutoPipelineForText2Image.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
            ).to("cuda")
            pipeline.load_lora_weights(
                "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
            )
            pipeline.disable_lora()
            ```py
            """
            # 检查是否使用 PEFT 后端
            if not USE_PEFT_BACKEND:
                raise ValueError("PEFT backend is required for this method.")
            # 设置适配器层为禁用状态
            set_adapter_layers(self, enabled=False)
    # 启用底层模型的活动 LoRA 层
    def enable_lora(self):
        # 检查是否启用了 PEFT 后端,若未启用则抛出错误
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
        # 设置适配器层为启用状态
        set_adapter_layers(self, enabled=True)
    
    # 删除底层模型的适配器 LoRA 层
    def delete_adapters(self, adapter_names: Union[List[str], str]):
        # 删除适配器的 LoRA 层的说明
        """
        Delete an adapter's LoRA layers from the underlying model.
    
        Args:
            adapter_names (`Union[List[str], str]`):
                The names (single string or list of strings) of the adapter to delete.
    
        Example:
        ...
        """
        # 检查是否启用了 PEFT 后端,若未启用则抛出错误
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
    
        # 如果传入的是单个适配器名称,将其转换为列表
        if isinstance(adapter_names, str):
            adapter_names = [adapter_names]
    
        # 遍历所有适配器名称,逐个删除
        for adapter_name in adapter_names:
            # 删除指定适配器的层
            delete_adapter_layers(self, adapter_name)
    
            # 从配置中也删除相应的适配器
            if hasattr(self, "peft_config"):
                self.peft_config.pop(adapter_name, None)

.\diffusers\loaders\single_file.py

# 版权所有 2024 HuggingFace 团队,保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)许可;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下位置获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,软件
# 按“原样”分发,没有任何明示或暗示的保证或条件。
# 请参阅许可证以获取特定语言的权限和
# 限制。
import importlib  # 导入用于动态导入模块的库
import inspect  # 导入用于检查对象的库
import os  # 导入与操作系统交互的库

import torch  # 导入 PyTorch 库
from huggingface_hub import snapshot_download  # 从 Hugging Face Hub 下载快照
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args  # 导入特定异常和验证函数
from packaging import version  # 导入用于处理版本的库

from ..utils import deprecate, is_transformers_available, logging  # 从上级模块导入工具函数
from .single_file_utils import (  # 从当前模块导入多个单文件相关的工具函数和类
    SingleFileComponentError,
    _is_legacy_scheduler_kwargs,
    _is_model_weights_in_cached_folder,
    _legacy_load_clip_tokenizer,
    _legacy_load_safety_checker,
    _legacy_load_scheduler,
    create_diffusers_clip_model_from_ldm,
    create_diffusers_t5_model_from_checkpoint,
    fetch_diffusers_config,
    fetch_original_config,
    is_clip_model_in_single_file,
    is_t5_in_single_file,
    load_single_file_checkpoint,
)

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

# 旧版行为。`from_single_file` 不会加载安全检查器,除非明确提供
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]

if is_transformers_available():  # 检查 transformers 库是否可用
    import transformers  # 导入 transformers 库
    from transformers import PreTrainedModel, PreTrainedTokenizer  # 导入预训练模型和分词器类


def load_single_file_sub_model(  # 定义加载单文件子模型的函数
    library_name,  # 库名称
    class_name,  # 类名称
    name,  # 模型名称
    checkpoint,  # 检查点
    pipelines,  # 管道
    is_pipeline_module,  # 是否为管道模块
    cached_model_config_path,  # 缓存模型配置路径
    original_config=None,  # 原始配置(可选)
    local_files_only=False,  # 是否仅使用本地文件
    torch_dtype=None,  # PyTorch 数据类型
    is_legacy_loading=False,  # 是否为旧版加载
    **kwargs,  # 其他参数
):
    if is_pipeline_module:  # 如果是管道模块
        pipeline_module = getattr(pipelines, library_name)  # 从管道中获取指定库的模块
        class_obj = getattr(pipeline_module, class_name)  # 获取指定类
    else:  # 否则从库中导入
        library = importlib.import_module(library_name)  # 动态导入库
        class_obj = getattr(library, class_name)  # 获取指定类

    if is_transformers_available():  # 检查 transformers 库是否可用
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)  # 解析 transformers 版本
    else:  # 如果不可用
        transformers_version = "N/A"  # 设置版本为不可用

    is_transformers_model = (  # 检查是否为 transformers 模型
        is_transformers_available()  # transformers 可用
        and issubclass(class_obj, PreTrainedModel)  # 是预训练模型的子类
        and transformers_version >= version.parse("4.20.0")  # 版本不低于 4.20.0
    )
    is_tokenizer = (  # 检查是否为分词器
        is_transformers_available()  # transformers 可用
        and issubclass(class_obj, PreTrainedTokenizer)  # 是预训练分词器的子类
        and transformers_version >= version.parse("4.20.0")  # 版本不低于 4.20.0
    )

    diffusers_module = importlib.import_module(__name__.split(".")[0])  # 动态导入当前模块的上级模块
    is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)  # 检查是否为 diffusers 单文件模型的子类
    # 检查类对象是否是 diffusers_module.ModelMixin 的子类
    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
    # 检查类对象是否是 diffusers_module.SchedulerMixin 的子类
    is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
    
    # 如果是单文件模型
    if is_diffusers_single_file_model:
        # 获取类对象的 from_single_file 方法
        load_method = getattr(class_obj, "from_single_file")
    
        # 如果提供了 original_config,则不能同时使用 cached_model_config_path
        if original_config:
            # 忽略加载缓存的模型配置路径
            cached_model_config_path = None
    
        # 调用 from_single_file 方法加载子模型
        loaded_sub_model = load_method(
            pretrained_model_link_or_path_or_dict=checkpoint,  # 加载预训练模型链接或路径或字典
            original_config=original_config,  # 原始配置
            config=cached_model_config_path,  # 缓存的模型配置路径
            subfolder=name,  # 子文件夹名称
            torch_dtype=torch_dtype,  # Torch 数据类型
            local_files_only=local_files_only,  # 仅加载本地文件
            **kwargs,  # 其他参数
        )
    
    # 如果是 transformers 模型且是单文件中的 CLIP 模型
    elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
        # 从 LDM 创建 diffusers CLIP 模型
        loaded_sub_model = create_diffusers_clip_model_from_ldm(
            class_obj,  # 类对象
            checkpoint=checkpoint,  # 检查点
            config=cached_model_config_path,  # 缓存的模型配置路径
            subfolder=name,  # 子文件夹名称
            torch_dtype=torch_dtype,  # Torch 数据类型
            local_files_only=local_files_only,  # 仅加载本地文件
            is_legacy_loading=is_legacy_loading,  # 是否为遗留加载
        )
    
    # 如果是 transformers 模型且检查点是单文件中的 T5 模型
    elif is_transformers_model and is_t5_in_single_file(checkpoint):
        # 从检查点创建 diffusers T5 模型
        loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
            class_obj,  # 类对象
            checkpoint=checkpoint,  # 检查点
            config=cached_model_config_path,  # 缓存的模型配置路径
            subfolder=name,  # 子文件夹名称
            torch_dtype=torch_dtype,  # Torch 数据类型
            local_files_only=local_files_only,  # 仅加载本地文件
        )
    
    # 如果是 tokenizer 并且在遗留加载状态
    elif is_tokenizer and is_legacy_loading:
        # 从检查点加载遗留 CLIP tokenizer
        loaded_sub_model = _legacy_load_clip_tokenizer(
            class_obj,  # 类对象
            checkpoint=checkpoint,  # 检查点
            config=cached_model_config_path,  # 缓存的模型配置路径
            local_files_only=local_files_only  # 仅加载本地文件
        )
    
    # 如果是 diffusers scheduler 且处于遗留加载状态或参数为遗留 scheduler 的关键字
    elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
        # 加载遗留调度器
        loaded_sub_model = _legacy_load_scheduler(
            class_obj,  # 类对象
            checkpoint=checkpoint,  # 检查点
            component_name=name,  # 组件名称
            original_config=original_config,  # 原始配置
            **kwargs  # 其他参数
        )
    else:  # 处理非预期条件的情况
        # 检查 class_obj 是否具有 from_pretrained 方法
        if not hasattr(class_obj, "from_pretrained"):
            # 如果没有,抛出值错误,提示加载方法不支持
            raise ValueError(
                (
                    f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
                    " a supported loading method."
                )
            )

        loading_kwargs = {}  # 初始化加载参数的字典
        # 更新加载参数字典,添加预训练模型路径和其他配置
        loading_kwargs.update(
            {
                "pretrained_model_name_or_path": cached_model_config_path,  # 预训练模型路径
                "subfolder": name,  # 子文件夹名称
                "local_files_only": local_files_only,  # 仅加载本地文件的标志
            }
        )

        # Schedulers 和 Tokenizers 不使用 torch_dtype
        # 因此跳过将其传递给这些对象
        if issubclass(class_obj, torch.nn.Module):  # 检查 class_obj 是否是 torch.nn.Module 的子类
            loading_kwargs.update({"torch_dtype": torch_dtype})  # 如果是,添加 torch_dtype 到加载参数

        # 检查是否为 diffusers 或 transformers 模型
        if is_diffusers_model or is_transformers_model:
            # 检查权重文件是否存在于缓存文件夹中
            if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
                # 如果权重缺失,抛出错误
                raise SingleFileComponentError(
                    f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
                )

        # 获取 class_obj 的 from_pretrained 方法
        load_method = getattr(class_obj, "from_pretrained")
        # 调用 from_pretrained 方法并传入加载参数,加载子模型
        loaded_sub_model = load_method(**loading_kwargs)

    return loaded_sub_model  # 返回加载的子模型
# 映射组件类型到配置字典
def _map_component_types_to_config_dict(component_types):
    # 导入当前模块的主模块
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    # 初始化配置字典
    config_dict = {}
    # 从组件类型中移除 'self' 键
    component_types.pop("self", None)

    # 检查 transformers 库是否可用
    if is_transformers_available():
        # 解析 transformers 版本的基版本
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        # 如果不可用,则版本设置为 "N/A"
        transformers_version = "N/A"

    # 遍历组件名称和对应的值
    for component_name, component_value in component_types.items():
        # 检查组件值是否为 diffusers 模型的子类
        is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
        # 检查组件值是否为 KarrasDiffusionSchedulers 枚举
        is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
        # 检查组件值是否为调度器的子类
        is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)

        # 检查组件值是否为 transformers 模型
        is_transformers_model = (
            is_transformers_available()
            and issubclass(component_value[0], PreTrainedModel)
            and transformers_version >= version.parse("4.20.0")
        )
        # 检查组件值是否为 transformers 分词器
        is_transformers_tokenizer = (
            is_transformers_available()
            and issubclass(component_value[0], PreTrainedTokenizer)
            and transformers_version >= version.parse("4.20.0")
        )

        # 如果是 diffusers 模型且不在单文件可选组件中
        if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
            # 将组件名称和模型名称添加到配置字典
            config_dict[component_name] = ["diffusers", component_value[0].__name__]

        # 如果是调度器枚举或调度器
        elif is_scheduler_enum or is_scheduler:
            # 如果是调度器枚举,默认使用 DDIMScheduler
            if is_scheduler_enum:
                # 因为无法从 hub 获取调度器配置,默认使用 DDIMScheduler
                config_dict[component_name] = ["diffusers", "DDIMScheduler"]

            # 如果是调度器
            elif is_scheduler:
                config_dict[component_name] = ["diffusers", component_value[0].__name__]

        # 如果是 transformers 模型或分词器且不在单文件可选组件中
        elif (
            is_transformers_model or is_transformers_tokenizer
        ) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
            # 将组件名称和模型名称添加到配置字典
            config_dict[component_name] = ["transformers", component_value[0].__name__]

        # 否则设置为 None
        else:
            config_dict[component_name] = [None, None]

    # 返回配置字典
    return config_dict


# 推断管道配置字典
def _infer_pipeline_config_dict(pipeline_class):
    # 获取管道类初始化方法的参数
    parameters = inspect.signature(pipeline_class.__init__).parameters
    # 收集所有必需的参数
    required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
    # 获取管道类的组件类型
    component_types = pipeline_class._get_signature_types()

    # 忽略非必需参数的组件类型
    component_types = {k: v for k, v in component_types.items() if k in required_parameters}
    # 映射组件类型到配置字典
    config_dict = _map_component_types_to_config_dict(component_types)

    # 返回配置字典
    return config_dict


# 从 hub 下载 diffusers 模型配置
def _download_diffusers_model_config_from_hub(
    pretrained_model_name_or_path,
    cache_dir,
    revision,
    proxies,
    force_download=None,
    local_files_only=None,
    token=None,
):
    # 定义允许的文件模式
    allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
    # 下载预训练模型的快照,并将其缓存到指定目录
        cached_model_path = snapshot_download(
            # 指定要下载的预训练模型名称或路径
            pretrained_model_name_or_path,
            # 指定缓存目录
            cache_dir=cache_dir,
            # 指定版本修订
            revision=revision,
            # 代理设置
            proxies=proxies,
            # 是否强制下载,即使缓存中已有
            force_download=force_download,
            # 是否仅使用本地文件
            local_files_only=local_files_only,
            # 访问令牌
            token=token,
            # 允许的文件模式
            allow_patterns=allow_patterns,
        )
    
    # 返回缓存模型的路径
        return cached_model_path
# 定义一个名为 FromSingleFileMixin 的类
class FromSingleFileMixin:
    """
    加载以 `.ckpt` 格式保存的模型权重到 [`DiffusionPipeline`] 中。
    """

    # 定义一个类方法,通常用于从类中直接调用
    @classmethod
    # 装饰器,用于验证与 Hugging Face Hub 相关的参数
    @validate_hf_hub_args

.\diffusers\loaders\single_file_model.py

# 版权声明,说明该文件的所有权归 HuggingFace 团队所有,保留所有权利
# 
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵循该许可证,否则您不能使用此文件。
# 您可以在以下网址获取许可证的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是按“现状”基础提供的,
# 不提供任何形式的明示或暗示的担保或条件。
# 请参阅许可证以了解有关权限和限制的具体信息。
import importlib  # 导入动态模块加载功能的标准库
import inspect  # 导入用于获取对象信息的标准库
import re  # 导入正则表达式模块
from contextlib import nullcontext  # 从上下文管理器模块导入 nullcontext,用于无操作的上下文管理
from typing import Optional  # 从类型提示模块导入 Optional,表示可选类型

from huggingface_hub.utils import validate_hf_hub_args  # 从 Hugging Face Hub 工具导入验证函数

from ..utils import deprecate, is_accelerate_available, logging  # 从上级模块导入相关的工具和日志模块
from .single_file_utils import (  # 从当前包的单文件工具模块导入多个函数和类
    SingleFileComponentError,  # 导入单文件组件错误类
    convert_animatediff_checkpoint_to_diffusers,  # 导入将 Animatediff 检查点转换为 Diffusers 的函数
    convert_controlnet_checkpoint,  # 导入将 ControlNet 检查点转换的函数
    convert_flux_transformer_checkpoint_to_diffusers,  # 导入将 Flux Transformer 检查点转换为 Diffusers 的函数
    convert_ldm_unet_checkpoint,  # 导入将 LDM UNet 检查点转换的函数
    convert_ldm_vae_checkpoint,  # 导入将 LDM VAE 检查点转换的函数
    convert_sd3_transformer_checkpoint_to_diffusers,  # 导入将 SD3 Transformer 检查点转换为 Diffusers 的函数
    convert_stable_cascade_unet_single_file_to_diffusers,  # 导入将 Stable Cascade UNet 单文件转换为 Diffusers 的函数
    create_controlnet_diffusers_config_from_ldm,  # 导入从 LDM 创建 ControlNet Diffusers 配置的函数
    create_unet_diffusers_config_from_ldm,  # 导入从 LDM 创建 UNet Diffusers 配置的函数
    create_vae_diffusers_config_from_ldm,  # 导入从 LDM 创建 VAE Diffusers 配置的函数
    fetch_diffusers_config,  # 导入获取 Diffusers 配置的函数
    fetch_original_config,  # 导入获取原始配置的函数
    load_single_file_checkpoint,  # 导入加载单文件检查点的函数
)

logger = logging.get_logger(__name__)  # 创建一个日志记录器,使用当前模块的名称

if is_accelerate_available():  # 检查是否可用加速库
    from accelerate import init_empty_weights  # 从 accelerate 导入初始化空权重的功能

    from ..models.modeling_utils import load_model_dict_into_meta  # 从上级模型工具模块导入将模型字典加载到元数据的函数

# 定义一个包含可加载类及其相关配置的字典
SINGLE_FILE_LOADABLE_CLASSES = {
    "StableCascadeUNet": {  # 对应 StableCascadeUNet 类的配置
        "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,  # 检查点映射函数
    },
    "UNet2DConditionModel": {  # 对应 UNet2DConditionModel 类的配置
        "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,  # 检查点映射函数
        "config_mapping_fn": create_unet_diffusers_config_from_ldm,  # 配置映射函数
        "default_subfolder": "unet",  # 默认子文件夹名称
        "legacy_kwargs": {  # 旧参数的映射
            "num_in_channels": "in_channels",  # 旧参数映射到新参数的例子
        },
    },
    "AutoencoderKL": {  # 对应 AutoencoderKL 类的配置
        "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,  # 检查点映射函数
        "config_mapping_fn": create_vae_diffusers_config_from_ldm,  # 配置映射函数
        "default_subfolder": "vae",  # 默认子文件夹名称
    },
    "ControlNetModel": {  # 对应 ControlNetModel 类的配置
        "checkpoint_mapping_fn": convert_controlnet_checkpoint,  # 检查点映射函数
        "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,  # 配置映射函数
    },
    "SD3Transformer2DModel": {  # 对应 SD3Transformer2DModel 类的配置
        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,  # 检查点映射函数
        "default_subfolder": "transformer",  # 默认子文件夹名称
    },
    "MotionAdapter": {  # 对应 MotionAdapter 类的配置
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,  # 检查点映射函数
    },
    "SparseControlNetModel": {  # 对应 SparseControlNetModel 类的配置
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,  # 检查点映射函数
    },
}
    # 定义一个包含 FluxTransformer2DModel 配置的字典
        "FluxTransformer2DModel": {
            # 指定一个函数,用于将 FluxTransformer 的检查点映射到 diffusers 格式
            "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
            # 设置默认子文件夹名称为 "transformer"
            "default_subfolder": "transformer",
        },
# 结束上一个代码块
}


# 定义获取单一文件可加载映射类的函数
def _get_single_file_loadable_mapping_class(cls):
    # 导入当前模块的父级模块
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    # 遍历所有单文件可加载类的字符串名称
    for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
        # 获取对应的可加载类
        loadable_class = getattr(diffusers_module, loadable_class_str)

        # 检查给定的类是否是可加载类的子类
        if issubclass(cls, loadable_class):
            # 如果是,则返回该可加载类的字符串名称
            return loadable_class_str

    # 如果没有找到合适的可加载类,返回 None
    return None


# 定义获取映射函数关键字参数的函数
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
    # 获取映射函数的参数签名
    parameters = inspect.signature(mapping_fn).parameters

    # 创建一个字典以存储匹配的关键字参数
    mapping_kwargs = {}
    # 遍历所有参数
    for parameter in parameters:
        # 如果参数在提供的关键字参数中,则存入字典
        if parameter in kwargs:
            mapping_kwargs[parameter] = kwargs[parameter]

    # 返回匹配的关键字参数字典
    return mapping_kwargs


# 定义一个混合类,用于从原始模型加载预训练权重
class FromOriginalModelMixin:
    """
    加载保存为 `.ckpt` 或 `.safetensors` 格式的预训练权重到模型中。
    """

    # 声明该方法为类方法
    @classmethod
    # 应用参数验证装饰器
    @validate_hf_hub_args

.\diffusers\loaders\single_file_utils.py

# 设置文件编码为 UTF-8
# 版权信息,标明版权归 HuggingFace Inc. 团队所有
#
# 根据 Apache 许可证 2.0 版本授权该文件;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律另有规定或书面同意,软件在许可证下分发,
# 均按“原样”基础提供,不附带任何形式的保证或条件,
# 明示或暗示均不作任何承诺。
# 请参阅许可证以获取特定的语言管理权限和
# 限制条款。
"""用于 Stable Diffusion 检查点的转换脚本。"""

# 导入必要的模块
import os  # 用于操作系统功能的模块
import re  # 用于正则表达式操作的模块
from contextlib import nullcontext  # 提供上下文管理器功能
from io import BytesIO  # 用于处理字节流的模块
from urllib.parse import urlparse  # 用于解析 URL 的模块

import requests  # 用于发送 HTTP 请求的库
import torch  # PyTorch 深度学习库
import yaml  # 用于处理 YAML 文件的库

# 导入模型相关的工具
from ..models.modeling_utils import load_state_dict  # 加载模型状态字典的函数
from ..schedulers import (  # 导入不同的调度器类
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EDMDPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)

# 导入实用工具
from ..utils import (  # 导入一些实用函数和常量
    SAFETENSORS_WEIGHTS_NAME,  # 安全张量权重名称
    WEIGHTS_NAME,  # 权重名称
    deprecate,  # 警告使用过时功能的函数
    is_accelerate_available,  # 检查 accelerate 模块是否可用的函数
    is_transformers_available,  # 检查 transformers 模块是否可用的函数
    logging,  # 日志记录功能
)
from ..utils.hub_utils import _get_model_file  # 获取模型文件的辅助函数

# 如果 transformers 可用,则导入相关类
if is_transformers_available():
    from transformers import AutoImageProcessor  # 自动图像处理器类

# 如果 accelerate 可用,则导入相关功能
if is_accelerate_available():
    from accelerate import init_empty_weights  # 初始化空权重的函数

    from ..models.modeling_utils import load_model_dict_into_meta  # 加载模型字典到元数据的函数

logger = logging.get_logger(__name__)  # 创建一个记录器实例,用于日志记录;禁用 pylint 的名称检查

# 定义一个字典,用于存储不同检查点关键名称的映射
CHECKPOINT_KEY_NAMES = {
    "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",  # v2 模型的权重名称
    "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",  # xl_base 模型的偏置名称
    "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",  # xl_refiner 模型的偏置名称
    "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",  # upscale 模型的偏置名称
    "controlnet": "control_model.time_embed.0.weight",  # controlnet 模型的权重名称
    "playground-v2-5": "edm_mean",  # playground-v2-5 模型的平均值
    "inpainting": "model.diffusion_model.input_blocks.0.0.weight",  # inpainting 模型的权重名称
    "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",  # clip 模型的位置嵌入权重
    "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",  # clip_sdxl 模型的位置嵌入权重
    "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",  # clip_sd3 模型的位置嵌入权重
    "open_clip": "cond_stage_model.model.token_embedding.weight",  # open_clip 模型的嵌入权重
    "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",  # open_clip_sdxl 模型的位置嵌入
    "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",  # open_clip_sdxl_refiner 模型的文本投影
    "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",  # open_clip_sd3 模型的位置嵌入权重
    "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",  # stable_cascade_stage_b 模型的权重名称
    "stable_cascade_stage_c": "clip_txt_mapper.weight",  # stable_cascade_stage_c 模型的权重名称
}
    # 定义一个字典的键值对,键为模型名称,值为对应的模型参数路径
    "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
    # 定义另一个模型的参数路径,适用于 animatediff 模型
    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
    # 定义 animatediff_v2 模型的参数路径
    "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
    # 定义 animatediff_sdxl_beta 模型的参数路径
    "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
    # 定义 animatediff_scribble 模型的参数路径
    "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
    # 定义 animatediff_rgb 模型的参数路径
    "animatediff_rgb": "controlnet_cond_embedding.weight",
    # 定义一个列表,包含 flux 模型相关的参数路径
    "flux": [
        # flux 模型中 double_blocks 组件的参数路径
        "double_blocks.0.img_attn.norm.key_norm.scale",
        # flux 模型中另一个 double_blocks 组件的参数路径
        "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
    ],
# 定义默认的 Diffusers 管道路径,映射模型名称到其预训练模型的路径
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    # xl_base 模型的预训练模型路径
    "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"},
    # xl_refiner 模型的预训练模型路径
    "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"},
    # xl_inpaint 模型的预训练模型路径
    "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
    # playground-v2-5 模型的预训练模型路径
    "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
    # upscale 模型的预训练模型路径
    "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
    # inpainting 模型的预训练模型路径
    "inpainting": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8-inpainting"},
    # inpainting_v2 模型的预训练模型路径
    "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
    # controlnet 模型的预训练模型路径
    "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
    # v2 模型的预训练模型路径
    "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
    # v1 模型的预训练模型路径
    "v1": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8"},
    # stable_cascade_stage_b 模型的预训练模型路径及其子文件夹
    "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
    # stable_cascade_stage_b_lite 模型的预训练模型路径及其子文件夹
    "stable_cascade_stage_b_lite": {
        "pretrained_model_name_or_path": "stabilityai/stable-cascade",
        "subfolder": "decoder_lite",
    },
    # stable_cascade_stage_c 模型的预训练模型路径及其子文件夹
    "stable_cascade_stage_c": {
        "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
        "subfolder": "prior",
    },
    # stable_cascade_stage_c_lite 模型的预训练模型路径及其子文件夹
    "stable_cascade_stage_c_lite": {
        "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
        "subfolder": "prior_lite",
    },
    # sd3 模型的预训练模型路径
    "sd3": {
        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
    },
    # animatediff_v1 模型的预训练模型路径
    "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
    # animatediff_v2 模型的预训练模型路径
    "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
    # animatediff_v3 模型的预训练模型路径
    "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
    # animatediff_sdxl_beta 模型的预训练模型路径
    "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
    # animatediff_scribble 模型的预训练模型路径
    "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
    # animatediff_rgb 模型的预训练模型路径
    "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
    # flux-dev 模型的预训练模型路径
    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
    # flux-schnell 模型的预训练模型路径
    "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
}

# 用于配置模型样本大小,当提供原始配置时
DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
    # xl_base 模型的默认图像大小
    "xl_base": 1024,
    # xl_refiner 模型的默认图像大小
    "xl_refiner": 1024,
    # xl_inpaint 模型的默认图像大小
    "xl_inpaint": 1024,
    # playground-v2-5 模型的默认图像大小
    "playground-v2-5": 1024,
    # upscale 模型的默认图像大小
    "upscale": 512,
    # inpainting 模型的默认图像大小
    "inpainting": 512,
    # inpainting_v2 模型的默认图像大小
    "inpainting_v2": 512,
    # controlnet 模型的默认图像大小
    "controlnet": 512,
    # v2 模型的默认图像大小
    "v2": 768,
    # v1 模型的默认图像大小
    "v1": 512,
}

# 定义 Diffusers 到 LDM 的映射
DIFFUSERS_TO_LDM_MAPPING = {
    # 定义一个包含 UNet 模型参数的字典
    "unet": {
        # 定义 UNet 模型的层参数
        "layers": {
            # 将时间嵌入层的第一个线性层权重映射到新位置
            "time_embedding.linear_1.weight": "time_embed.0.weight",
            # 将时间嵌入层的第一个线性层偏置映射到新位置
            "time_embedding.linear_1.bias": "time_embed.0.bias",
            # 将时间嵌入层的第二个线性层权重映射到新位置
            "time_embedding.linear_2.weight": "time_embed.2.weight",
            # 将时间嵌入层的第二个线性层偏置映射到新位置
            "time_embedding.linear_2.bias": "time_embed.2.bias",
            # 将输入卷积层的权重映射到新位置
            "conv_in.weight": "input_blocks.0.0.weight",
            # 将输入卷积层的偏置映射到新位置
            "conv_in.bias": "input_blocks.0.0.bias",
            # 将输出归一化层的权重映射到新位置
            "conv_norm_out.weight": "out.0.weight",
            # 将输出归一化层的偏置映射到新位置
            "conv_norm_out.bias": "out.0.bias",
            # 将输出卷积层的权重映射到新位置
            "conv_out.weight": "out.2.weight",
            # 将输出卷积层的偏置映射到新位置
            "conv_out.bias": "out.2.bias",
        },
        # 定义分类嵌入层的参数
        "class_embed_type": {
            # 将分类嵌入层的第一个线性层权重映射到新位置
            "class_embedding.linear_1.weight": "label_emb.0.0.weight",
            # 将分类嵌入层的第一个线性层偏置映射到新位置
            "class_embedding.linear_1.bias": "label_emb.0.0.bias",
            # 将分类嵌入层的第二个线性层权重映射到新位置
            "class_embedding.linear_2.weight": "label_emb.0.2.weight",
            # 将分类嵌入层的第二个线性层偏置映射到新位置
            "class_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
        # 定义附加嵌入层的参数
        "addition_embed_type": {
            # 将附加嵌入层的第一个线性层权重映射到新位置
            "add_embedding.linear_1.weight": "label_emb.0.0.weight",
            # 将附加嵌入层的第一个线性层偏置映射到新位置
            "add_embedding.linear_1.bias": "label_emb.0.0.bias",
            # 将附加嵌入层的第二个线性层权重映射到新位置
            "add_embedding.linear_2.weight": "label_emb.0.2.weight",
            # 将附加嵌入层的第二个线性层偏置映射到新位置
            "add_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
    },
    # 定义一个包含 ControlNet 模型参数的字典
    "controlnet": {
        # 定义 ControlNet 模型的层参数
        "layers": {
            # 将时间嵌入层的第一个线性层权重映射到新位置
            "time_embedding.linear_1.weight": "time_embed.0.weight",
            # 将时间嵌入层的第一个线性层偏置映射到新位置
            "time_embedding.linear_1.bias": "time_embed.0.bias",
            # 将时间嵌入层的第二个线性层权重映射到新位置
            "time_embedding.linear_2.weight": "time_embed.2.weight",
            # 将时间嵌入层的第二个线性层偏置映射到新位置
            "time_embedding.linear_2.bias": "time_embed.2.bias",
            # 将输入卷积层的权重映射到新位置
            "conv_in.weight": "input_blocks.0.0.weight",
            # 将输入卷积层的偏置映射到新位置
            "conv_in.bias": "input_blocks.0.0.bias",
            # 将 ControlNet 条件嵌入的输入卷积层权重映射到新位置
            "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
            # 将 ControlNet 条件嵌入的输入卷积层偏置映射到新位置
            "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
            # 将 ControlNet 条件嵌入的输出卷积层权重映射到新位置
            "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight",
            # 将 ControlNet 条件嵌入的输出卷积层偏置映射到新位置
            "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias",
        },
        # 定义分类嵌入层的参数
        "class_embed_type": {
            # 将分类嵌入层的第一个线性层权重映射到新位置
            "class_embedding.linear_1.weight": "label_emb.0.0.weight",
            # 将分类嵌入层的第一个线性层偏置映射到新位置
            "class_embedding.linear_1.bias": "label_emb.0.0.bias",
            # 将分类嵌入层的第二个线性层权重映射到新位置
            "class_embedding.linear_2.weight": "label_emb.0.2.weight",
            # 将分类嵌入层的第二个线性层偏置映射到新位置
            "class_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
        # 定义附加嵌入层的参数
        "addition_embed_type": {
            # 将附加嵌入层的第一个线性层权重映射到新位置
            "add_embedding.linear_1.weight": "label_emb.0.0.weight",
            # 将附加嵌入层的第一个线性层偏置映射到新位置
            "add_embedding.linear_1.bias": "label_emb.0.0.bias",
            # 将附加嵌入层的第二个线性层权重映射到新位置
            "add_embedding.linear_2.weight": "label_emb.0.2.weight",
            # 将附加嵌入层的第二个线性层偏置映射到新位置
            "add_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
    },
    # 定义一个字典,包含 VAE 模型的参数映射
    "vae": {
        # 映射编码器输入卷积层的权重
        "encoder.conv_in.weight": "encoder.conv_in.weight",
        # 映射编码器输入卷积层的偏置
        "encoder.conv_in.bias": "encoder.conv_in.bias",
        # 映射编码器输出卷积层的权重
        "encoder.conv_out.weight": "encoder.conv_out.weight",
        # 映射编码器输出卷积层的偏置
        "encoder.conv_out.bias": "encoder.conv_out.bias",
        # 映射编码器归一化输出层的权重
        "encoder.conv_norm_out.weight": "encoder.norm_out.weight",
        # 映射编码器归一化输出层的偏置
        "encoder.conv_norm_out.bias": "encoder.norm_out.bias",
        # 映射解码器输入卷积层的权重
        "decoder.conv_in.weight": "decoder.conv_in.weight",
        # 映射解码器输入卷积层的偏置
        "decoder.conv_in.bias": "decoder.conv_in.bias",
        # 映射解码器输出卷积层的权重
        "decoder.conv_out.weight": "decoder.conv_out.weight",
        # 映射解码器输出卷积层的偏置
        "decoder.conv_out.bias": "decoder.conv_out.bias",
        # 映射解码器归一化输出层的权重
        "decoder.conv_norm_out.weight": "decoder.norm_out.weight",
        # 映射解码器归一化输出层的偏置
        "decoder.conv_norm_out.bias": "decoder.norm_out.bias",
        # 映射量化卷积层的权重
        "quant_conv.weight": "quant_conv.weight",
        # 映射量化卷积层的偏置
        "quant_conv.bias": "quant_conv.bias",
        # 映射后量化卷积层的权重
        "post_quant_conv.weight": "post_quant_conv.weight",
        # 映射后量化卷积层的偏置
        "post_quant_conv.bias": "post_quant_conv.bias",
    },
    # 定义一个字典,包含 OpenCLIP 模型的参数映射
    "openclip": {
        # 定义嵌套字典,包含文本模型的层参数映射
        "layers": {
            # 映射文本模型的位置信息嵌入层权重
            "text_model.embeddings.position_embedding.weight": "positional_embedding",
            # 映射文本模型的标记嵌入层权重
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            # 映射文本模型最终归一化层的权重
            "text_model.final_layer_norm.weight": "ln_final.weight",
            # 映射文本模型最终归一化层的偏置
            "text_model.final_layer_norm.bias": "ln_final.bias",
            # 映射文本投影层的权重
            "text_projection.weight": "text_projection",
        },
        # 定义嵌套字典,包含 transformer 的参数映射
        "transformer": {
            # 映射文本模型编码器层的前缀
            "text_model.encoder.layers.": "resblocks.",
            # 映射 transformer 的第一层归一化
            "layer_norm1": "ln_1",
            # 映射 transformer 的第二层归一化
            "layer_norm2": "ln_2",
            # 映射全连接层的第一部分
            ".fc1.": ".c_fc.",
            # 映射全连接层的第二部分
            ".fc2.": ".c_proj.",
            # 映射 transformer 最终层归一化的前缀
            "transformer.text_model.final_layer_norm.": "ln_final.",
            # 映射 transformer 中的标记嵌入层权重
            "transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            # 映射 transformer 中的位置信息嵌入层权重
            "transformer.text_model.embeddings.position_embedding.weight": "positional_embedding",
        },
    },
# 定义一个列表,用于存储需要忽略的文本编码器键
SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
    # 忽略的特定权重和偏置参数
    "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
    "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
    "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias",
    "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight",
    "cond_stage_model.model.transformer.resblocks.23.ln_1.bias",
    "cond_stage_model.model.transformer.resblocks.23.ln_1.weight",
    "cond_stage_model.model.transformer.resblocks.23.ln_2.bias",
    "cond_stage_model.model.transformer.resblocks.23.ln_2.weight",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight",
    "cond_stage_model.model.text_projection",
]

# 定义调度器的默认配置,支持遗留的参数类型
SCHEDULER_DEFAULT_CONFIG = {
    # β调度类型
    "beta_schedule": "scaled_linear",
    # 调度的起始值
    "beta_start": 0.00085,
    # 调度的结束值
    "beta_end": 0.012,
    # 插值类型
    "interpolation_type": "linear",
    # 训练时间步数
    "num_train_timesteps": 1000,
    # 预测类型
    "prediction_type": "epsilon",
    # 采样最大值
    "sample_max_value": 1.0,
    # 是否将 alpha 设置为 1
    "set_alpha_to_one": False,
    # 是否跳过 PRK 步骤
    "skip_prk_steps": True,
    # 时间步偏移
    "steps_offset": 1,
    # 时间步间隔类型
    "timestep_spacing": "leading",
}

# 定义包含 VAE 相关键的列表
LDM_VAE_KEYS = ["first_stage_model.", "vae."]
# 定义 VAE 默认缩放因子
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
# 定义 Playground VAE 缩放因子
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
# 定义 LDM UNet 的键
LDM_UNET_KEY = "model.diffusion_model."
# 定义 LDM ControlNet 的键
LDM_CONTROLNET_KEY = "control_model."
# 定义要去除的 CLIP 前缀列表
LDM_CLIP_PREFIX_TO_REMOVE = [
    "cond_stage_model.transformer.",
    "conditioner.embedders.0.transformer.",
]
# 定义 LDM Open CLIP 文本投影维度
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
# 定义遗留调度器的关键字参数列表
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]

# 定义有效的 URL 前缀列表
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]

# 定义自定义异常类,用于单文件组件错误
class SingleFileComponentError(Exception):
    # 初始化异常类
    def __init__(self, message=None):
        self.message = message
        # 调用父类构造函数
        super().__init__(self.message)

# 定义验证 URL 的函数
def is_valid_url(url):
    # 解析 URL
    result = urlparse(url)
    # 检查是否有有效的方案和网络地址
    if result.scheme and result.netloc:
        return True

    # 返回无效
    return False

# 定义提取模型 ID 和权重名称的私有函数
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
    # 检查提供的路径是否是有效 URL
    if not is_valid_url(pretrained_model_name_or_path):
        raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")

    # 定义匹配模式
    pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
    weights_name = None
    repo_id = (None,)
    # 遍历有效的 URL 前缀进行替换
    for prefix in VALID_URL_PREFIXES:
        pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
    # 使用正则表达式匹配模式
    match = re.match(pattern, pretrained_model_name_or_path)
    # 如果没有匹配,记录警告并返回
    if not match:
        logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
        return repo_id, weights_name

    # 提取 repo_id 和 weights_name
    repo_id = f"{match.group(1)}/{match.group(2)}"
    weights_name = match.group(3)

    # 返回提取的结果
    return repo_id, weights_name
# 检查模型权重是否在缓存文件夹中
def _is_model_weights_in_cached_folder(cached_folder, name):
    # 拼接缓存文件夹路径和模型名称,形成预训练模型的路径
    pretrained_model_name_or_path = os.path.join(cached_folder, name)
    # 初始化权重存在标志为 False
    weights_exist = False

    # 遍历可能的权重文件名
    for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]:
        # 检查指定路径下是否存在权重文件
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
            # 如果存在,设置权重存在标志为 True
            weights_exist = True

    # 返回权重是否存在的标志
    return weights_exist


# 检查传入的关键字参数是否包含遗留调度器的关键字
def _is_legacy_scheduler_kwargs(kwargs):
    # 检查关键字参数中是否有遗留调度器的关键字
    return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())


# 加载单个文件的检查点
def load_single_file_checkpoint(
    pretrained_model_link_or_path,
    force_download=False,
    proxies=None,
    token=None,
    cache_dir=None,
    local_files_only=None,
    revision=None,
):
    # 检查给定路径是否是一个文件
    if os.path.isfile(pretrained_model_link_or_path):
        # 如果是文件,保持路径不变
        pretrained_model_link_or_path = pretrained_model_link_or_path

    else:
        # 如果不是文件,提取仓库 ID 和权重名称
        repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
        # 获取模型文件的路径
        pretrained_model_link_or_path = _get_model_file(
            repo_id,
            weights_name=weights_name,
            force_download=force_download,
            cache_dir=cache_dir,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )

    # 加载状态字典
    checkpoint = load_state_dict(pretrained_model_link_or_path)

    # 一些检查点的模型状态字典可能在 "state_dict" 键下
    while "state_dict" in checkpoint:
        # 取出状态字典
        checkpoint = checkpoint["state_dict"]

    # 返回最终的检查点
    return checkpoint


# 获取原始配置
def fetch_original_config(original_config_file, local_files_only=False):
    # 检查给定的配置文件是否是一个有效的文件
    if os.path.isfile(original_config_file):
        # 如果是文件,读取其内容
        with open(original_config_file, "r") as fp:
            original_config_file = fp.read()

    elif is_valid_url(original_config_file):
        # 如果是有效的 URL
        if local_files_only:
            # 如果设置为只允许本地文件,抛出错误
            raise ValueError(
                "`local_files_only` is set to True, but a URL was provided as `original_config_file`. "
                "Please provide a valid local file path."
            )

        # 下载 URL 的内容并封装为字节流
        original_config_file = BytesIO(requests.get(original_config_file).content)

    else:
        # 如果既不是文件也不是有效的 URL,抛出错误
        raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")

    # 解析 YAML 格式的原始配置
    original_config = yaml.safe_load(original_config_file)

    # 返回解析后的配置
    return original_config


# 检查给定的检查点是否为 CLIP 模型
def is_clip_model(checkpoint):
    # 检查检查点中是否包含 CLIP 模型的键
    if CHECKPOINT_KEY_NAMES["clip"] in checkpoint:
        return True

    # 返回 False
    return False


# 检查给定的检查点是否为 CLIP SDXL 模型
def is_clip_sdxl_model(checkpoint):
    # 检查检查点中是否包含 CLIP SDXL 模型的键
    if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint:
        return True

    # 返回 False
    return False


# 检查给定的检查点是否为 CLIP SD3 模型
def is_clip_sd3_model(checkpoint):
    # 检查检查点中是否包含 CLIP SD3 模型的键
    if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
        return True

    # 返回 False
    return False


# 检查给定的检查点是否为 Open CLIP 模型
def is_open_clip_model(checkpoint):
    # 检查检查点中是否包含 Open CLIP 模型的键
    if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
        return True

    # 返回 False
    return False


# 检查给定的检查点是否为 Open CLIP SDXL 模型
def is_open_clip_sdxl_model(checkpoint):
    # 检查检查点中是否包含 Open CLIP SDXL 模型的键
    if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint:
        return True

    # 返回 False
    return False


# 检查给定的检查点是否为 Open CLIP SD3 模型
def is_open_clip_sd3_model(checkpoint):
    # 检查检查点中是否包含特定的键
        if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
            # 如果找到特定键,返回 True
            return True
    
        # 如果没有找到特定键,返回 False
        return False
# 检查给定的检查点是否包含 OpenCLIP SDXL Refiner 模型的关键字
def is_open_clip_sdxl_refiner_model(checkpoint):
    # 如果检查点中包含指定的关键字,则返回 True
    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
        return True

    # 否则返回 False
    return False


# 检查给定的类对象是否与单个文件中的 CLIP 模型匹配
def is_clip_model_in_single_file(class_obj, checkpoint):
    # 检查检查点是否包含任何 CLIP 模型的关键字
    is_clip_in_checkpoint = any(
        [
            is_clip_model(checkpoint),  # 检查是否是 CLIP 模型
            is_clip_sd3_model(checkpoint),  # 检查是否是 SD3 模型
            is_open_clip_model(checkpoint),  # 检查是否是 OpenCLIP 模型
            is_open_clip_sdxl_model(checkpoint),  # 检查是否是 OpenCLIP SDXL 模型
            is_open_clip_sdxl_refiner_model(checkpoint),  # 检查是否是 OpenCLIP SDXL Refiner 模型
            is_open_clip_sd3_model(checkpoint),  # 检查是否是 OpenCLIP SD3 模型
        ]
    )
    # 如果类对象名称是 CLIPTextModel 或 CLIPTextModelWithProjection,并且检查点中存在 CLIP 模型
    if (
        class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection"
    ) and is_clip_in_checkpoint:
        return True  # 返回 True,表示匹配

    # 否则返回 False
    return False


# 推断给定检查点的 Diffusers 模型类型
def infer_diffusers_model_type(checkpoint):
    # 检查点中是否包含“inpainting”关键字,并且其形状的第二维为 9
    if (
        CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
    ):
        # 检查点中是否包含“v2”关键字,并且其形状的最后一维为 1024
        if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
            model_type = "inpainting_v2"  # 设置模型类型为 inpainting_v2
        else:
            model_type = "inpainting"  # 设置模型类型为 inpainting

    # 检查点中是否仅包含“v2”关键字,并且其形状的最后一维为 1024
    elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
        model_type = "v2"  # 设置模型类型为 v2

    # 检查点中是否包含“playground-v2-5”关键字
    elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
        model_type = "playground-v2-5"  # 设置模型类型为 playground-v2-5

    # 检查点中是否包含“xl_base”关键字
    elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
        model_type = "xl_base"  # 设置模型类型为 xl_base

    # 检查点中是否包含“xl_refiner”关键字
    elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
        model_type = "xl_refiner"  # 设置模型类型为 xl_refiner

    # 检查点中是否包含“upscale”关键字
    elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
        model_type = "upscale"  # 设置模型类型为 upscale

    # 检查点中是否包含“controlnet”关键字
    elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint:
        model_type = "controlnet"  # 设置模型类型为 controlnet

    # 检查点中是否包含“stable_cascade_stage_c”关键字,且其形状的第一维为 1536
    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
    ):
        model_type = "stable_cascade_stage_c_lite"  # 设置模型类型为 stable_cascade_stage_c_lite

    # 检查点中是否包含“stable_cascade_stage_c”关键字,且其形状的第一维为 2048
    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
    ):
        model_type = "stable_cascade_stage_c"  # 设置模型类型为 stable_cascade_stage_c

    # 检查点中是否包含“stable_cascade_stage_b”关键字,且其形状的最后一维为 576
    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
    ):
        model_type = "stable_cascade_stage_b_lite"  # 设置模型类型为 stable_cascade_stage_b_lite

    # 检查点中是否包含“stable_cascade_stage_b”关键字,且其形状的最后一维为 640
    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
    ):
        model_type = "stable_cascade_stage_b"  # 设置模型类型为 stable_cascade_stage_b

    # 检查点中是否包含“sd3”关键字
    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
        model_type = "sd3"  # 设置模型类型为 sd3
    # 检查 checkpoint 中是否包含 "animatediff" 的键
    elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
        # 检查 checkpoint 中是否包含 "animatediff_scribble" 的键
        if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
            # 设置模型类型为 "animatediff_scribble"
            model_type = "animatediff_scribble"

        # 检查 checkpoint 中是否包含 "animatediff_rgb" 的键
        elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
            # 设置模型类型为 "animatediff_rgb"
            model_type = "animatediff_rgb"

        # 检查 checkpoint 中是否包含 "animatediff_v2" 的键
        elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
            # 设置模型类型为 "animatediff_v2"
            model_type = "animatediff_v2"

        # 检查 checkpoint 中 "animatediff_sdxl_beta" 的形状最后一维是否为 320
        elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
            # 设置模型类型为 "animatediff_sdxl_beta"
            model_type = "animatediff_sdxl_beta"

        # 检查 checkpoint 中 "animatediff" 的形状第二维是否为 24
        elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
            # 设置模型类型为 "animatediff_v1"
            model_type = "animatediff_v1"

        # 以上条件都不满足时
        else:
            # 设置模型类型为 "animatediff_v3"
            model_type = "animatediff_v3"

    # 检查 checkpoint 中是否包含 "flux" 相关的任意键
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
        # 检查 checkpoint 中是否包含特定的权重偏置键
        if any(
            g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
        ):
            # 设置模型类型为 "flux-dev"
            model_type = "flux-dev"
        # 如果不包含特定的权重偏置键
        else:
            # 设置模型类型为 "flux-schnell"
            model_type = "flux-schnell"
    # 以上条件都不满足时
    else:
        # 设置模型类型为 "v1"
        model_type = "v1"

    # 返回确定的模型类型
    return model_type
# 根据检查点获取 diffuser 配置
def fetch_diffusers_config(checkpoint):
    # 推断模型类型
    model_type = infer_diffusers_model_type(checkpoint)
    # 从默认路径获取模型路径
    model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]

    # 返回模型路径
    return model_path


# 设置图像大小,如果未提供,则基于检查点推断
def set_image_size(checkpoint, image_size=None):
    # 如果提供了图像大小,直接返回
    if image_size:
        return image_size

    # 推断模型类型
    model_type = infer_diffusers_model_type(checkpoint)
    # 根据模型类型获取默认图像大小
    image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type]

    # 返回图像大小
    return image_size


# 从检查点转换卷积注意力为线性形式
def conv_attn_to_linear(checkpoint):
    # 获取检查点中的所有键
    keys = list(checkpoint.keys())
    # 定义需要转换的注意力权重键
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    # 遍历每个键
    for key in keys:
        # 如果键属于注意力权重
        if ".".join(key.split(".")[-2:]) in attn_keys:
            # 如果权重维度大于2,则只保留第一维
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        # 如果键是投影注意力权重
        elif "proj_attn.weight" in key:
            # 如果权重维度大于2,则只保留第一维
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


# 从 LDM 创建 UNet diffuser 配置
def create_unet_diffusers_config_from_ldm(
    original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None
):
    """
    基于 LDM 模型配置创建 diffuser 配置。
    """
    # 如果提供了图像大小,记录弃用信息
    if image_size is not None:
        deprecation_message = (
            "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`"
            "is deprecated and will be ignored in future versions."
        )
        # 调用弃用警告函数
        deprecate("image_size", "1.0.0", deprecation_message)

    # 设置图像大小
    image_size = set_image_size(checkpoint, image_size=image_size)

    # 获取 UNet 参数,如果存在 unet_config
    if (
        "unet_config" in original_config["model"]["params"]
        and original_config["model"]["params"]["unet_config"] is not None
    ):
        unet_params = original_config["model"]["params"]["unet_config"]["params"]
    else:
        # 否则从网络配置中获取
        unet_params = original_config["model"]["params"]["network_config"]["params"]

    # 如果提供了输入通道数,记录弃用信息
    if num_in_channels is not None:
        deprecation_message = (
            "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`"
            "is deprecated and will be ignored in future versions."
        )
        # 调用弃用警告函数
        deprecate("image_size", "1.0.0", deprecation_message)
        # 设置输入通道数
        in_channels = num_in_channels
    else:
        # 否则从 UNet 参数中获取输入通道数
        in_channels = unet_params["in_channels"]

    # 获取 VAE 参数
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
    # 计算每个块的输出通道数
    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]

    # 定义向下块的类型
    down_block_types = []
    resolution = 1
    # 遍历每个输出通道块
    for i in range(len(block_out_channels)):
        # 根据分辨率选择块类型
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
        down_block_types.append(block_type)
        # 更新分辨率
        if i != len(block_out_channels) - 1:
            resolution *= 2

    # 定义向上块的类型
    up_block_types = []
    # 遍历输出通道的数量
    for i in range(len(block_out_channels)):
        # 根据分辨率判断块类型,选择跨注意力块或上采样块
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
        # 将块类型添加到上采样块列表中
        up_block_types.append(block_type)
        # 更新分辨率,进行下一个块的处理
        resolution //= 2

    # 检查是否设置了变换器的深度
    if unet_params["transformer_depth"] is not None:
        # 获取每个块的变换器层数,支持整型或列表
        transformer_layers_per_block = (
            unet_params["transformer_depth"]
            if isinstance(unet_params["transformer_depth"], int)
            else list(unet_params["transformer_depth"])
        )
    else:
        # 如果没有设置,默认每块只有一层
        transformer_layers_per_block = 1

    # 计算 VAE 的缩放因子
    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

    # 获取头部维度,如果存在的话
    head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
    # 检查是否使用线性投影
    use_linear_projection = (
        unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
    )
    # 如果使用线性投影
    if use_linear_projection:
        # 针对稳定扩散 2 的特定配置
        if head_dim is None:
            # 计算头部维度的乘数
            head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
            # 根据通道乘数生成头部维度列表
            head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]

    # 初始化额外的嵌入类型和维度
    class_embed_type = None
    addition_embed_type = None
    addition_time_embed_dim = None
    projection_class_embeddings_input_dim = None
    context_dim = None

    # 如果上下文维度存在
    if unet_params["context_dim"] is not None:
        # 获取上下文维度,支持整型或列表
        context_dim = (
            unet_params["context_dim"]
            if isinstance(unet_params["context_dim"], int)
            else unet_params["context_dim"][0]
        )

    # 检查类别数量设置
    if "num_classes" in unet_params:
        # 如果类别为顺序
        if unet_params["num_classes"] == "sequential":
            # 根据上下文维度决定额外嵌入类型
            if context_dim in [2048, 1280]:
                # 针对 SDXL 的配置
                addition_embed_type = "text_time"
                addition_time_embed_dim = 256
            else:
                # 其他情况下使用投影嵌入
                class_embed_type = "projection"
            # 确保包含 ADM 输入通道
            assert "adm_in_channels" in unet_params
            # 获取投影嵌入的输入维度
            projection_class_embeddings_input_dim = unet_params["adm_in_channels"]

    # 配置字典,包含各种模型参数
    config = {
        # 计算样本大小
        "sample_size": image_size // vae_scale_factor,
        # 输入通道数
        "in_channels": in_channels,
        # 各个下采样块类型
        "down_block_types": down_block_types,
        # 输出通道数量
        "block_out_channels": block_out_channels,
        # 每个块的层数
        "layers_per_block": unet_params["num_res_blocks"],
        # 上下文维度
        "cross_attention_dim": context_dim,
        # 注意力头的维度
        "attention_head_dim": head_dim,
        # 是否使用线性投影
        "use_linear_projection": use_linear_projection,
        # 类别嵌入类型
        "class_embed_type": class_embed_type,
        # 额外嵌入类型
        "addition_embed_type": addition_embed_type,
        # 额外时间嵌入维度
        "addition_time_embed_dim": addition_time_embed_dim,
        # 投影类别嵌入的输入维度
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        # 每个块的变换器层数
        "transformer_layers_per_block": transformer_layers_per_block,
    }
    # 检查是否提供了 upcast_attention 参数
    if upcast_attention is not None:
        # 构造弃用提示信息,告知用户该参数在未来版本中将被忽略
        deprecation_message = (
            "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`"
            "is deprecated and will be ignored in future versions."
        )
        # 调用 deprecate 函数,记录该参数的弃用信息
        deprecate("image_size", "1.0.0", deprecation_message)
        # 将 upcast_attention 参数存储到 config 字典中
        config["upcast_attention"] = upcast_attention

    # 检查 unet_params 中是否包含 disable_self_attentions 键
    if "disable_self_attentions" in unet_params:
        # 如果存在,设置 config 字典中的 only_cross_attention 为对应值
        config["only_cross_attention"] = unet_params["disable_self_attentions"]

    # 检查 unet_params 中是否包含 num_classes 键且其值为整数
    if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
        # 将 num_classes 的值存储到 config 字典中的 num_class_embeds 键
        config["num_class_embeds"] = unet_params["num_classes"]

    # 将 unet_params 中的 out_channels 值存储到 config 字典中
    config["out_channels"] = unet_params["out_channels"]
    # 将 up_block_types 存储到 config 字典中
    config["up_block_types"] = up_block_types

    # 返回配置字典
    return config
# 从 LDM 模型的配置创建 ControlNet 的 Diffusers 配置
def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs):
    # 检查 image_size 参数是否提供
    if image_size is not None:
        # 创建弃用提示信息
        deprecation_message = (
            "Configuring ControlNetModel with the `image_size` argument"
            "is deprecated and will be ignored in future versions."
        )
        # 调用 deprecate 函数记录弃用信息
        deprecate("image_size", "1.0.0", deprecation_message)

    # 设置 image_size,使用检查点中的值
    image_size = set_image_size(checkpoint, image_size=image_size)

    # 从原始配置中提取 UNet 参数
    unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
    # 创建 Diffusers 的 UNet 配置
    diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, image_size=image_size)

    # 构建 ControlNet 配置字典
    controlnet_config = {
        "conditioning_channels": unet_params["hint_channels"],
        "in_channels": diffusers_unet_config["in_channels"],
        "down_block_types": diffusers_unet_config["down_block_types"],
        "block_out_channels": diffusers_unet_config["block_out_channels"],
        "layers_per_block": diffusers_unet_config["layers_per_block"],
        "cross_attention_dim": diffusers_unet_config["cross_attention_dim"],
        "attention_head_dim": diffusers_unet_config["attention_head_dim"],
        "use_linear_projection": diffusers_unet_config["use_linear_projection"],
        "class_embed_type": diffusers_unet_config["class_embed_type"],
        "addition_embed_type": diffusers_unet_config["addition_embed_type"],
        "addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"],
        "projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"],
        "transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"],
    }

    # 返回构建好的 ControlNet 配置
    return controlnet_config


# 从 LDM 模型的配置创建 VAE 的 Diffusers 配置
def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None):
    """
    根据 LDM 模型的配置创建 Diffusers 配置。
    """
    # 检查 image_size 参数是否提供
    if image_size is not None:
        # 创建弃用提示信息
        deprecation_message = (
            "Configuring AutoencoderKL with the `image_size` argument"
            "is deprecated and will be ignored in future versions."
        )
        # 调用 deprecate 函数记录弃用信息
        deprecate("image_size", "1.0.0", deprecation_message)

    # 设置 image_size,使用检查点中的值
    image_size = set_image_size(checkpoint, image_size=image_size)

    # 检查检查点中是否包含 edm_mean 和 edm_std
    if "edm_mean" in checkpoint and "edm_std" in checkpoint:
        # 提取潜变量的均值
        latents_mean = checkpoint["edm_mean"]
        # 提取潜变量的标准差
        latents_std = checkpoint["edm_std"]
    else:
        # 如果未提供,设置为 None
        latents_mean = None
        latents_std = None

    # 从原始配置中提取 VAE 参数
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
    # 根据条件设置缩放因子
    if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
        scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR

    elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
        scaling_factor = original_config["model"]["params"]["scale_factor"]

    elif scaling_factor is None:
        scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
    # 计算每个块的输出通道数,乘以相应的倍数
        block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
        # 生成与输出通道数相同数量的下采样块类型
        down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
        # 生成与输出通道数相同数量的上采样块类型
        up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
    
        # 创建配置字典,存储各类参数
        config = {
            # 图像样本大小
            "sample_size": image_size,
            # 输入通道数
            "in_channels": vae_params["in_channels"],
            # 输出通道数
            "out_channels": vae_params["out_ch"],
            # 下采样块类型列表
            "down_block_types": down_block_types,
            # 上采样块类型列表
            "up_block_types": up_block_types,
            # 块的输出通道数
            "block_out_channels": block_out_channels,
            # 潜在通道数
            "latent_channels": vae_params["z_channels"],
            # 每个块的层数
            "layers_per_block": vae_params["num_res_blocks"],
            # 缩放因子
            "scaling_factor": scaling_factor,
        }
        # 如果潜在均值和标准差不为 None,更新配置字典
        if latents_mean is not None and latents_std is not None:
            config.update({"latents_mean": latents_mean, "latents_std": latents_std})
    
        # 返回配置字典
        return config
# 更新 UNet 中 ResNet 结构的 LDM 到 Diffusers 格式
def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None):
    # 遍历所有 LDM 键
    for ldm_key in ldm_keys:
        # 根据规则替换 LDM 键为对应的 Diffusers 键
        diffusers_key = (
            ldm_key.replace("in_layers.0", "norm1")
            .replace("in_layers.2", "conv1")
            .replace("out_layers.0", "norm2")
            .replace("out_layers.3", "conv2")
            .replace("emb_layers.1", "time_emb_proj")
            .replace("skip_connection", "conv_shortcut")
        )
        # 如果有映射,则替换旧键为新键
        if mapping:
            diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"])
        # 从 checkpoint 获取 LDM 键的数据并存入新的 checkpoint
        new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)


# 更新 UNet 中注意力结构的 LDM 到 Diffusers 格式
def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
    # 遍历所有 LDM 键
    for ldm_key in ldm_keys:
        # 根据映射替换 LDM 键为对应的 Diffusers 键
        diffusers_key = ldm_key.replace(mapping["old"], mapping["new"])
        # 从 checkpoint 获取 LDM 键的数据并存入新的 checkpoint
        new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)


# 更新 VAE 中 ResNet 结构的 LDM 到 Diffusers 格式
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
    # 遍历所有 LDM 键
    for ldm_key in keys:
        # 根据映射替换 LDM 键,并替换特定的 nin_shortcut
        diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
        # 从 checkpoint 获取 LDM 键的数据并存入新的 checkpoint
        new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)


# 更新 VAE 中注意力结构的 LDM 到 Diffusers 格式
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
    # 遍历所有 LDM 键
    for ldm_key in keys:
        # 根据映射替换 LDM 键为对应的 Diffusers 键,并进行多个字段的替换
        diffusers_key = (
            ldm_key.replace(mapping["old"], mapping["new"])
            .replace("norm.weight", "group_norm.weight")
            .replace("norm.bias", "group_norm.bias")
            .replace("q.weight", "to_q.weight")
            .replace("q.bias", "to_q.bias")
            .replace("k.weight", "to_k.weight")
            .replace("k.bias", "to_k.bias")
            .replace("v.weight", "to_v.weight")
            .replace("v.bias", "to_v.bias")
            .replace("proj_out.weight", "to_out.0.weight")
            .replace("proj_out.bias", "to_out.0.bias")
        )
        # 从 checkpoint 获取 LDM 键的数据并存入新的 checkpoint
        new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)

        # proj_attn.weight 需要从一维卷积转换为线性
        shape = new_checkpoint[diffusers_key].shape

        # 如果形状为三维,截取第一个维度
        if len(shape) == 3:
            new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
        # 如果形状为四维,截取前两个维度
        elif len(shape) == 4:
            new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]


# 转换稳定级联 UNet 单文件到 Diffusers 格式
def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs):
    # 检查是否包含特定的权重
    is_stage_c = "clip_txt_mapper.weight" in checkpoint
    # 检查是否处于阶段 C
    if is_stage_c:
        # 初始化一个空字典,用于存储状态
        state_dict = {}
        # 遍历检查点中的所有键
        for key in checkpoint.keys():
            # 如果键以 "in_proj_weight" 结尾
            if key.endswith("in_proj_weight"):
                # 将权重分块为三个部分
                weights = checkpoint[key].chunk(3, 0)
                # 替换键名并保存对应的权重到字典
                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
            # 如果键以 "in_proj_bias" 结尾
            elif key.endswith("in_proj_bias"):
                # 将偏置分块为三个部分
                weights = checkpoint[key].chunk(3, 0)
                # 替换键名并保存对应的偏置到字典
                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
            # 如果键以 "out_proj.weight" 结尾
            elif key.endswith("out_proj.weight"):
                # 获取权重
                weights = checkpoint[key]
                # 替换键名并保存对应的权重到字典
                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
            # 如果键以 "out_proj.bias" 结尾
            elif key.endswith("out_proj.bias"):
                # 获取偏置
                weights = checkpoint[key]
                # 替换键名并保存对应的偏置到字典
                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
            # 对于其它情况,直接保存键值对
            else:
                state_dict[key] = checkpoint[key]
    # 如果不在阶段 C
    else:
        # 初始化一个空字典,用于存储状态
        state_dict = {}
        # 遍历检查点中的所有键
        for key in checkpoint.keys():
            # 如果键以 "in_proj_weight" 结尾
            if key.endswith("in_proj_weight"):
                # 将权重分块为三个部分
                weights = checkpoint[key].chunk(3, 0)
                # 替换键名并保存对应的权重到字典
                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
            # 如果键以 "in_proj_bias" 结尾
            elif key.endswith("in_proj_bias"):
                # 将偏置分块为三个部分
                weights = checkpoint[key].chunk(3, 0)
                # 替换键名并保存对应的偏置到字典
                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
            # 如果键以 "out_proj.weight" 结尾
            elif key.endswith("out_proj.weight"):
                # 获取权重
                weights = checkpoint[key]
                # 替换键名并保存对应的权重到字典
                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
            # 如果键以 "out_proj.bias" 结尾
            elif key.endswith("out_proj.bias"):
                # 获取偏置
                weights = checkpoint[key]
                # 替换键名并保存对应的偏置到字典
                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
            # 如果键以 "clip_mapper.weight" 结尾
            elif key.endswith("clip_mapper.weight"):
                # 获取权重
                weights = checkpoint[key]
                # 替换键名并保存对应的权重到字典
                state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
            # 如果键以 "clip_mapper.bias" 结尾
            elif key.endswith("clip_mapper.bias"):
                # 获取偏置
                weights = checkpoint[key]
                # 替换键名并保存对应的偏置到字典
                state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
            # 对于其它情况,直接保存键值对
            else:
                state_dict[key] = checkpoint[key]

    # 返回构建好的状态字典
    return state_dict
# 转换 LDM UNet 检查点,接受检查点和配置,并返回转换后的检查点
def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
    """
    接受状态字典和配置,并返回转换后的检查点。
    """
    # 创建一个字典用于提取 UNet 的状态字典
    unet_state_dict = {}
    # 获取检查点的所有键
    keys = list(checkpoint.keys())
    # 定义 UNet 的关键字
    unet_key = LDM_UNET_KEY

    # 检查有多少参数以 `model_ema` 开头,以确定是否为 EMA 检查点
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        # 记录警告:检查点包含 EMA 和非 EMA 权重
        logger.warning("Checkpoint has both EMA and non-EMA weights.")
        # 记录警告:仅提取 EMA 权重
        logger.warning(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        # 遍历所有键
        for key in keys:
            # 如果键以 "model.diffusion_model" 开头
            if key.startswith("model.diffusion_model"):
                # 替换键前缀并获取对应的 EMA 权重
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key)
    else:
        # 如果存在 EMA 权重,但不提取 EMA
        if sum(k.startswith("model_ema") for k in keys) > 100:
            # 记录警告:仅提取非 EMA 权重
            logger.warning(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )
        # 遍历所有键
        for key in keys:
            # 如果键以 UNet 的关键字开头
            if key.startswith(unet_key):
                # 将对应的权重添加到 UNet 状态字典中
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key)

    # 创建一个新的检查点字典
    new_checkpoint = {}
    # 获取 UNet 图层的映射键
    ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
    # 遍历 Diffusers 和 LDM 键的映射
    for diffusers_key, ldm_key in ldm_unet_keys.items():
        # 如果 LDM 键不在 UNet 状态字典中,则跳过
        if ldm_key not in unet_state_dict:
            continue
        # 将 UNet 状态字典中的权重添加到新的检查点
        new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

    # 如果配置中存在 class_embed_type 且其值为 "timestep" 或 "projection"
    if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
        # 获取对应的 class_embed 键映射
        class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"]
        # 遍历并添加 class_embed 的权重到新的检查点
        for diffusers_key, ldm_key in class_embed_keys.items():
            new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

    # 如果配置中存在 addition_embed_type 且其值为 "text_time"
    if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"):
        # 获取对应的 addition_embed 键映射
        addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"]
        # 遍历并添加 addition_embed 的权重到新的检查点
        for diffusers_key, ldm_key in addition_embed_keys.items():
            new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

    # 与 StableDiffusionUpscalePipeline 相关
    if "num_class_embeds" in config:
        # 检查 num_class_embeds 是否不为空且 UNet 状态字典中存在 label_emb.weight
        if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
            # 将 label_emb.weight 添加到新的检查点中
            new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]

    # 仅获取输入块的键
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    # 创建一个字典,存储每个输入块的相关键
    input_blocks = {
        # 遍历每个输入块的层ID,生成对应的键列表
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # 获取中间块的数量
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    # 创建一个字典,存储每个中间块的相关键
    middle_blocks = {
        # 遍历每个中间块的层ID,生成对应的键列表
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # 获取输出块的数量
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    # 创建一个字典,存储每个输出块的相关键
    output_blocks = {
        # 遍历每个输出块的层ID,生成对应的键列表
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # 处理输入块
    for i in range(1, num_input_blocks):
        # 计算当前块的ID
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        # 计算当前层在块内的ID
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # 找到当前输入块的残差连接
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        # 更新 UNet 残差连接到新的检查点
        update_unet_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            unet_state_dict,
            {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
        )

        # 如果当前输入块的权重存在,则将其更新到新的检查点
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get(
                f"input_blocks.{i}.0.op.bias"
            )

        # 找到当前输入块的注意力连接
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
        # 如果注意力连接存在,更新它到新的检查点
        if attentions:
            update_unet_attention_ldm_to_diffusers(
                attentions,
                new_checkpoint,
                unet_state_dict,
                {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
            )

    # 处理中间块
    for key in middle_blocks.keys():
        # 计算对应的 diffusers 键
        diffusers_key = max(key - 1, 0)
        # 如果是偶数层,更新残差连接
        if key % 2 == 0:
            update_unet_resnet_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                unet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
            )
        # 如果是奇数层,更新注意力连接
        else:
            update_unet_attention_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                unet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
            )

    # 处理上升块
    # 遍历输出块的数量
    for i in range(num_output_blocks):
        # 计算当前块的 ID
        block_id = i // (config["layers_per_block"] + 1)
        # 计算当前层在块中的 ID
        layer_in_block_id = i % (config["layers_per_block"] + 1)

        # 筛选当前输出块中的 ResNet 相关键,排除特定的操作键
        resnets = [
            key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key
        ]
        # 更新 U-Net 中 ResNet 的状态字典,映射旧的键到新的键
        update_unet_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            unet_state_dict,
            {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"},
        )

        # 筛选当前输出块中的注意力相关键,排除特定的卷积键
        attentions = [
            key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key
        ]
        # 如果找到注意力键,则更新状态字典
        if attentions:
            update_unet_attention_ldm_to_diffusers(
                attentions,
                new_checkpoint,
                unet_state_dict,
                {"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"},
            )

        # 如果在状态字典中找到当前卷积层的权重,则更新新的检查点字典
        if f"output_blocks.{i}.1.conv.weight" in unet_state_dict:
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                f"output_blocks.{i}.1.conv.weight"
            ]
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                f"output_blocks.{i}.1.conv.bias"
            ]
        # 如果在状态字典中找到下一个卷积层的权重,则更新新的检查点字典
        if f"output_blocks.{i}.2.conv.weight" in unet_state_dict:
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                f"output_blocks.{i}.2.conv.weight"
            ]
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                f"output_blocks.{i}.2.conv.bias"
            ]

    # 返回更新后的检查点字典
    return new_checkpoint
# 定义一个函数,用于转换 ControlNet 的检查点文件
def convert_controlnet_checkpoint(
    checkpoint,  # 输入的检查点数据
    config,  # 配置参数
    **kwargs,  # 额外的关键字参数
):
    # 检查点中如果包含时间嵌入权重,则将其直接赋值
    if "time_embed.0.weight" in checkpoint:
        controlnet_state_dict = checkpoint
    # 否则,初始化空的状态字典
    else:
        controlnet_state_dict = {}
        keys = list(checkpoint.keys())  # 获取检查点的所有键
        controlnet_key = LDM_CONTROLNET_KEY  # 定义 ControlNet 的关键字
        # 遍历检查点中的所有键
        for key in keys:
            # 如果键以 ControlNet 的关键字开头,则提取相关数据
            if key.startswith(controlnet_key):
                controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)

    new_checkpoint = {}  # 初始化新的检查点字典
    ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]  # 获取 LDM 和 Diffusers 的映射
    # 将 Diffusers 的键映射到新的检查点
    for diffusers_key, ldm_key in ldm_controlnet_keys.items():
        if ldm_key not in controlnet_state_dict:  # 如果 LDM 键不存在,则跳过
            continue
        new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key]  # 添加映射数据到新检查点

    # 仅检索输入块的键
    num_input_blocks = len(
        {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer}
    )  # 计算输入块的数量
    # 创建输入块字典
    input_blocks = {
        layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # 处理下块
    for i in range(1, num_input_blocks):  # 从第一个输入块开始处理
        block_id = (i - 1) // (config["layers_per_block"] + 1)  # 计算当前块的ID
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)  # 计算当前层在块中的ID

        # 获取当前块中的所有 ResNet
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        # 更新 UNet 中 ResNet 的映射
        update_unet_resnet_ldm_to_diffusers(
            resnets,  # ResNet 列表
            new_checkpoint,  # 新检查点
            controlnet_state_dict,  # 原状态字典
            {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
        )

        # 如果有权重数据,则映射到新的检查点
        if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
                f"input_blocks.{i}.0.op.bias"
            )

        # 获取当前块中的所有注意力层
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
        if attentions:  # 如果存在注意力层,则进行更新
            update_unet_attention_ldm_to_diffusers(
                attentions,
                new_checkpoint,
                controlnet_state_dict,
                {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
            )

    # 处理 ControlNet 的下块
    for i in range(num_input_blocks):
        # 将零卷积的权重映射到新的检查点
        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")
    # 仅检索中间块的键
    num_middle_blocks = len(
        # 从 controlnet_state_dict 中提取包含 'middle_block' 的层的前两个部分,去重后形成集合
        {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer}
    )
    # 为每个中间块的层 ID 创建一个字典,映射到对应的控制网络状态字典中的键列表
    middle_blocks = {
        layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }
    
    # 遍历中间块的键
    for key in middle_blocks.keys():
        # 获取前一个块的索引,确保不为负数
        diffusers_key = max(key - 1, 0)
        # 如果键是偶数,调用更新函数处理 ResNet
        if key % 2 == 0:
            update_unet_resnet_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                controlnet_state_dict,
                # 映射旧的和新的层名称
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
            )
        # 如果键是奇数,调用更新函数处理 Attention
        else:
            update_unet_attention_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                controlnet_state_dict,
                # 映射旧的和新的层名称
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
            )
    
    # 处理中间块的输出权重和偏差
    new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
    new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")
    
    # 控制网络条件嵌入块
    cond_embedding_blocks = {
        # 提取包含 'input_hint_block' 的层的前两部分,去重后形成集合,排除特定的键
        ".".join(layer.split(".")[:2])
        for layer in controlnet_state_dict
        if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
    }
    # 计算条件嵌入块的数量
    num_cond_embedding_blocks = len(cond_embedding_blocks)
    
    # 遍历条件嵌入块索引
    for idx in range(1, num_cond_embedding_blocks + 1):
        diffusers_idx = idx - 1  # 转换为 Diffusers 索引
        cond_block_id = 2 * idx  # 计算条件块的 ID
    
        # 从控制网络状态字典获取对应权重并添加到新检查点
        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
            f"input_hint_block.{cond_block_id}.weight"
        )
        # 从控制网络状态字典获取对应偏差并添加到新检查点
        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
            f"input_hint_block.{cond_block_id}.bias"
        )
    
    # 返回更新后的检查点
    return new_checkpoint
# 将 LDM VAE 检查点转换为适用于 Diffusers 的格式
def convert_ldm_vae_checkpoint(checkpoint, config):
    # 提取 VAE 的状态字典
    vae_state_dict = {}
    # 获取检查点中所有键的列表
    keys = list(checkpoint.keys())
    vae_key = ""
    # 找到 LDM VAE 关键字
    for ldm_vae_key in LDM_VAE_KEYS:
        # 检查是否有键以当前 LDM VAE 关键字开头
        if any(k.startswith(ldm_vae_key) for k in keys):
            vae_key = ldm_vae_key

    # 从检查点中提取与 VAE 相关的键
    for key in keys:
        # 如果键以 VAE 关键字开头,则替换关键字并存储在状态字典中
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}
    # 获取 VAE 对应的 Diffusers 键映射
    vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"]
    # 构建新的检查点字典
    for diffusers_key, ldm_key in vae_diffusers_ldm_map.items():
        # 如果状态字典中没有对应的 LDM 键,则跳过
        if ldm_key not in vae_state_dict:
            continue
        # 将 LDM 状态字典中的值映射到新的检查点中
        new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]

    # 仅获取编码器下块的键
    num_down_blocks = len(config["down_block_types"])
    # 构建每个下块的键映射
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # 遍历下块进行处理
    for i in range(num_down_blocks):
        # 获取当前下块的所有 ResNet 键
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
        # 更新 VAE ResNet 的键映射到 Diffusers
        update_vae_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            vae_state_dict,
            mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
        )
        # 如果存在下采样权重,则添加到新的检查点中
        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
                f"encoder.down.{i}.downsample.conv.bias"
            )

    # 获取中间 ResNet 的键
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    # 遍历中间 ResNet 进行处理
    for i in range(1, num_mid_res_blocks + 1):
        # 获取当前中间块的所有 ResNet 键
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
        # 更新中间 ResNet 的键映射到 Diffusers
        update_vae_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            vae_state_dict,
            mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
        )

    # 获取中间注意力的键
    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    # 更新中间注意力的键映射到 Diffusers
    update_vae_attentions_ldm_to_diffusers(
        mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    )

    # 仅获取解码器上块的键
    num_up_blocks = len(config["up_block_types"])
    # 构建每个上块的键映射
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }
    # 遍历向上块的数量
    for i in range(num_up_blocks):
        # 计算当前块的 ID,从最后一个块开始往前
        block_id = num_up_blocks - 1 - i
        # 收集当前块的所有 ResNet 相关的键,排除上采样的键
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]
        # 更新 VAE 的 ResNet 模块到新的 Diffusers 格式
        update_vae_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            vae_state_dict,
            # 映射旧键到新键
            mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
        )
        # 检查 VAE 状态字典中是否包含当前块的上采样卷积权重
        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            # 将上采样卷积权重更新到新检查点中
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            # 将上采样卷积偏置更新到新检查点中
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

    # 收集中间块的所有 ResNet 相关的键
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    # 定义中间 ResNet 块的数量
    num_mid_res_blocks = 2
    # 遍历中间块的数量
    for i in range(1, num_mid_res_blocks + 1):
        # 收集当前中间块的所有 ResNet 相关的键
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
        # 更新 VAE 的中间块 ResNet 模块到新的 Diffusers 格式
        update_vae_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            vae_state_dict,
            # 映射旧键到新键
            mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
        )

    # 收集中间块的所有注意力相关的键
    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    # 更新 VAE 的中间注意力模块到新的 Diffusers 格式
    update_vae_attentions_ldm_to_diffusers(
        mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    )
    # 将卷积注意力转换为线性注意力
    conv_attn_to_linear(new_checkpoint)

    # 返回更新后的新检查点
    return new_checkpoint
# 将 LDM-CLIP 检查点转换为文本模型的字典
def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
    # 获取检查点的所有键,转换为列表
    keys = list(checkpoint.keys())
    # 初始化一个空字典,用于存储文本模型的键值对
    text_model_dict = {}

    # 创建一个空列表用于存储要移除的前缀
    remove_prefixes = []
    # 将预定义的前缀添加到列表中
    remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
    # 如果提供了移除前缀,则添加到列表
    if remove_prefix:
        remove_prefixes.append(remove_prefix)

    # 遍历所有键
    for key in keys:
        # 对每个前缀进行检查
        for prefix in remove_prefixes:
            # 如果键以当前前缀开头
            if key.startswith(prefix):
                # 替换前缀,得到新的键
                diffusers_key = key.replace(prefix, "")
                # 将原始检查点中的值赋给新的键
                text_model_dict[diffusers_key] = checkpoint.get(key)

    # 返回文本模型字典
    return text_model_dict


# 将 Open-CLIP 检查点转换为文本模型的字典
def convert_open_clip_checkpoint(
    text_model,
    checkpoint,
    prefix="cond_stage_model.model.",
):
    # 初始化一个空字典,用于存储文本模型的键值对
    text_model_dict = {}
    # 构造文本投影的键
    text_proj_key = prefix + "text_projection"

    # 如果文本投影键在检查点中
    if text_proj_key in checkpoint:
        # 获取文本投影的维度
        text_proj_dim = int(checkpoint[text_proj_key].shape[0])
    # 如果文本模型配置中有投影维度属性
    elif hasattr(text_model.config, "projection_dim"):
        # 获取该投影维度
        text_proj_dim = text_model.config.projection_dim
    # 否则使用默认的投影维度
    else:
        text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM

    # 获取检查点的所有键,转换为列表
    keys = list(checkpoint.keys())
    # 获取要忽略的键列表
    keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE

    # 获取 Open-CLIP 到 LDM 的映射
    openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"]
    # 遍历映射中的每个键
    for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items():
        # 将 LDM 键添加前缀
        ldm_key = prefix + ldm_key
        # 如果 LDM 键不在检查点中,则跳过
        if ldm_key not in checkpoint:
            continue
        # 如果 LDM 键在要忽略的键列表中,则跳过
        if ldm_key in keys_to_ignore:
            continue
        # 如果 LDM 键以文本投影结尾
        if ldm_key.endswith("text_projection"):
            # 转置并存储值到文本模型字典
            text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous()
        else:
            # 否则直接存储值到文本模型字典
            text_model_dict[diffusers_key] = checkpoint[ldm_key]
    # 遍历给定的键列表
        for key in keys:
            # 如果当前键在忽略的键列表中,则跳过
            if key in keys_to_ignore:
                continue
    
            # 如果当前键不以指定前缀开头,则跳过
            if not key.startswith(prefix + "transformer."):
                continue
    
            # 移除前缀,得到变换器的键
            diffusers_key = key.replace(prefix + "transformer.", "")
            # 获取变换器到 LDM 映射的字典
            transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"]
            # 遍历新旧键的映射
            for new_key, old_key in transformer_diffusers_to_ldm_map.items():
                # 替换旧键为新键,并移除特定后缀
                diffusers_key = (
                    diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "")
                )
    
            # 如果当前键以 ".in_proj_weight" 结尾
            if key.endswith(".in_proj_weight"):
                # 从检查点获取权重值
                weight_value = checkpoint.get(key)
    
                # 将权重值的子集赋值给查询投影权重
                text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach()
                # 将权重值的子集赋值给键投影权重
                text_model_dict[diffusers_key + ".k_proj.weight"] = (
                    weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach()
                )
                # 将权重值的子集赋值给值投影权重
                text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach()
    
            # 如果当前键以 ".in_proj_bias" 结尾
            elif key.endswith(".in_proj_bias"):
                # 从检查点获取偏置值
                weight_value = checkpoint.get(key)
                # 将偏置值的子集赋值给查询投影偏置
                text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach()
                # 将偏置值的子集赋值给键投影偏置
                text_model_dict[diffusers_key + ".k_proj.bias"] = (
                    weight_value[text_proj_dim : text_proj_dim * 2].clone().detach()
                )
                # 将偏置值的子集赋值给值投影偏置
                text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach()
            # 如果当前键既不是权重也不是偏置,则直接获取该键的值
            else:
                text_model_dict[diffusers_key] = checkpoint.get(key)
    
        # 返回最终的文本模型字典
        return text_model_dict
# 创建一个从 LDM 生成 Diffusers CLIP 模型的函数
def create_diffusers_clip_model_from_ldm(
    cls,  # 模型类
    checkpoint,  # 训练检查点
    subfolder="",  # 子文件夹名称,默认为空
    config=None,  # 配置参数,默认为 None
    torch_dtype=None,  # PyTorch 数据类型,默认为 None
    local_files_only=None,  # 是否仅使用本地文件,默认为 None
    is_legacy_loading=False,  # 是否为旧版加载,默认为 False
):
    # 如果提供了配置,则将其封装为字典
    if config:
        config = {"pretrained_model_name_or_path": config}
    # 如果未提供配置,则从检查点中获取配置
    else:
        config = fetch_diffusers_config(checkpoint)

    # 向后兼容处理
    # 旧版的 `from_single_file` 期望 CLIP 配置放在原始 transformers 模型库的缓存目录中
    # 而不是放在 Diffusers 模型的子文件夹中
    if is_legacy_loading:
        # 发出警告,提示用户进行兼容性更新
        logger.warning(
            (
                "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update "
                "the local cache directory with the necessary CLIP model config files. "
                "Attempting to load CLIP model from legacy cache directory."
            )
        )

        # 如果检查点是 CLIP 模型或 CLIP SDXL 模型
        if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
            clip_config = "openai/clip-vit-large-patch14"  # 设置 CLIP 配置为 OpenAI 的模型
            config["pretrained_model_name_or_path"] = clip_config  # 更新配置
            subfolder = ""  # 子文件夹设为空

        # 如果检查点是 OpenCLIP 模型
        elif is_open_clip_model(checkpoint):
            clip_config = "stabilityai/stable-diffusion-2"  # 设置 CLIP 配置为 StabilityAI 的模型
            config["pretrained_model_name_or_path"] = clip_config  # 更新配置
            subfolder = "text_encoder"  # 子文件夹设为 text_encoder

        # 如果不满足以上条件,使用默认 CLIP 配置
        else:
            clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"  # 设置为默认的 CLIP 配置
            config["pretrained_model_name_or_path"] = clip_config  # 更新配置
            subfolder = ""  # 子文件夹设为空

    # 从预训练配置加载模型配置
    model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
    # 根据是否可用的加速库选择上下文
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    # 使用上下文初始化模型
    with ctx():
        model = cls(model_config)  # 实例化模型

    # 获取位置嵌入的维度
    position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1]

    # 如果检查点是 CLIP 模型
    if is_clip_model(checkpoint):
        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)  # 转换检查点格式

    # 如果检查点是 CLIP SDXL 模型并且形状与位置嵌入维度匹配
    elif (
        is_clip_sdxl_model(checkpoint)
        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim
    ):
        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)  # 转换检查点格式

    # 如果检查点是 CLIP SD3 模型并且形状与位置嵌入维度匹配
    elif (
        is_clip_sd3_model(checkpoint)
        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
    ):
        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")  # 转换检查点格式
        diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)  # 设置权重为单位矩阵

    # 如果检查点是 OpenCLIP 模型
    elif is_open_clip_model(checkpoint):
        prefix = "cond_stage_model.model."  # 设置前缀
        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)  # 转换检查点格式

    # 如果检查点是 OpenCLIP SDXL 模型并且形状与位置嵌入维度匹配
    elif (
        is_open_clip_sdxl_model(checkpoint)
        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim
    ):
    # 检查条件,开始处理不同类型的检查点
    ):
        # 设置前缀,用于转换模型检查点
        prefix = "conditioner.embedders.1.model."
        # 将检查点转换为 diffusers 格式,使用指定前缀
        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

    # 检查是否为 SDXL 精炼器模型
    elif is_open_clip_sdxl_refiner_model(checkpoint):
        # 设置前缀,用于转换模型检查点
        prefix = "conditioner.embedders.0.model."
        # 将检查点转换为 diffusers 格式,使用指定前缀
        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

    # 检查是否为 SD3 模型,并验证位置嵌入维度是否匹配
    elif (
        is_open_clip_sd3_model(checkpoint)
        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
    ):
        # 将检查点转换为 LDM 格式,指定文本编码器的前缀
        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")

    # 如果以上条件都不满足,抛出异常
    else:
        raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

    # 检查是否可以使用加速功能
    if is_accelerate_available():
        # 加载模型字典并处理意外的键
        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
    else:
        # 加载模型状态字典,允许不严格匹配
        _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

    # 如果模型有忽略的意外键,进行过滤
    if model._keys_to_ignore_on_load_unexpected is not None:
        for pat in model._keys_to_ignore_on_load_unexpected:
            # 过滤掉与忽略模式匹配的意外键
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

    # 如果存在意外的键,记录警告
    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
        )

    # 如果指定了数据类型,转换模型
    if torch_dtype is not None:
        model.to(torch_dtype)

    # 设置模型为评估模式
    model.eval()

    # 返回初始化后的模型
    return model
# 定义一个私有方法来加载调度器
def _legacy_load_scheduler(
    cls,
    checkpoint,  # 传入检查点数据
    component_name,  # 组件名称
    original_config=None,  # 原始配置,可选
    **kwargs,  # 其他关键字参数
):
    # 从关键字参数获取调度器类型,默认值为 None
    scheduler_type = kwargs.get("scheduler_type", None)
    # 从关键字参数获取预测类型,默认值为 None
    prediction_type = kwargs.get("prediction_type", None)

    # 如果调度器类型不为 None,发出弃用警告
    if scheduler_type is not None:
        deprecation_message = (
            "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
            "Example:\n\n"
            "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
            "scheduler = DDIMScheduler()\n"
            "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
        )
        # 调用弃用函数,记录 scheduler_type 的弃用信息
        deprecate("scheduler_type", "1.0.0", deprecation_message)

    # 如果预测类型不为 None,发出弃用警告
    if prediction_type is not None:
        deprecation_message = (
            "Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
            "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
            "Example:\n\n"
            "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
            'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
            "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
        )
        # 调用弃用函数,记录 prediction_type 的弃用信息
        deprecate("prediction_type", "1.0.0", deprecation_message)

    # 初始化调度器配置为默认配置
    scheduler_config = SCHEDULER_DEFAULT_CONFIG
    # 推断模型类型
    model_type = infer_diffusers_model_type(checkpoint=checkpoint)

    # 获取全局步数,如果不存在则为 None
    global_step = checkpoint["global_step"] if "global_step" in checkpoint else None

    # 如果原始配置存在,获取训练时间步数,否则使用默认值
    if original_config:
        num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", 1000)
    else:
        num_train_timesteps = 1000

    # 将训练时间步数存入调度器配置
    scheduler_config["num_train_timesteps"] = num_train_timesteps

    # 如果模型类型是 v2
    if model_type == "v2":
        if prediction_type is None:
            # 对于稳定扩散 2 基础版本,建议传递 `prediction_type=="epsilon"`,因为这里依赖于脆弱的全局步数参数
            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"

    else:
        # 如果模型类型不是 v2,设置预测类型为 epsilon 或现有值
        prediction_type = prediction_type or "epsilon"

    # 将预测类型存入调度器配置
    scheduler_config["prediction_type"] = prediction_type

    # 根据模型类型设置调度器类型和相关参数
    if model_type in ["xl_base", "xl_refiner"]:
        scheduler_type = "euler"
    elif model_type == "playground":
        scheduler_type = "edm_dpm_solver_multistep"
    else:
        # 如果原始配置存在,获取 beta_start 和 beta_end 值
        if original_config:
            beta_start = original_config["model"]["params"].get("linear_start")
            beta_end = original_config["model"]["params"].get("linear_end")

        else:
            # 否则使用默认的 beta_start 和 beta_end 值
            beta_start = 0.02
            beta_end = 0.085

        # 将 beta 参数和其他调度器配置存入调度器配置
        scheduler_config["beta_start"] = beta_start
        scheduler_config["beta_end"] = beta_end
        scheduler_config["beta_schedule"] = "scaled_linear"
        scheduler_config["clip_sample"] = False
        scheduler_config["set_alpha_to_one"] = False
    # 处理特殊情况,StableDiffusionUpscale 管道有两个调度器
    if component_name == "low_res_scheduler":
        # 从配置中创建并返回一个调度器实例
        return cls.from_config(
            {
                # 设置 Beta 结束值
                "beta_end": 0.02,
                # 设置 Beta 调度类型
                "beta_schedule": "scaled_linear",
                # 设置 Beta 起始值
                "beta_start": 0.0001,
                # 是否剪辑样本
                "clip_sample": True,
                # 训练时间步数
                "num_train_timesteps": 1000,
                # 预测类型
                "prediction_type": "epsilon",
                # 训练的 Beta 值
                "trained_betas": None,
                # 方差类型
                "variance_type": "fixed_small",
            }
        )

    # 如果调度器类型为空
    if scheduler_type is None:
        # 从给定的调度器配置中创建调度器
        return cls.from_config(scheduler_config)

    # 如果调度器类型为 "pndm"
    elif scheduler_type == "pndm":
        # 设置跳过 PRK 步骤为真
        scheduler_config["skip_prk_steps"] = True
        # 从配置中创建 PNDM 调度器
        scheduler = PNDMScheduler.from_config(scheduler_config)

    # 如果调度器类型为 "lms"
    elif scheduler_type == "lms":
        # 从配置中创建 LMS 离散调度器
        scheduler = LMSDiscreteScheduler.from_config(scheduler_config)

    # 如果调度器类型为 "heun"
    elif scheduler_type == "heun":
        # 从配置中创建 Heun 离散调度器
        scheduler = HeunDiscreteScheduler.from_config(scheduler_config)

    # 如果调度器类型为 "euler"
    elif scheduler_type == "euler":
        # 从配置中创建 Euler 离散调度器
        scheduler = EulerDiscreteScheduler.from_config(scheduler_config)

    # 如果调度器类型为 "euler-ancestral"
    elif scheduler_type == "euler-ancestral":
        # 从配置中创建 Euler 祖先离散调度器
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)

    # 如果调度器类型为 "dpm"
    elif scheduler_type == "dpm":
        # 从配置中创建 DPM 多步调度器
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)

    # 如果调度器类型为 "ddim"
    elif scheduler_type == "ddim":
        # 从配置中创建 DDIM 调度器
        scheduler = DDIMScheduler.from_config(scheduler_config)

    # 如果调度器类型为 "edm_dpm_solver_multistep"
    elif scheduler_type == "edm_dpm_solver_multistep":
        # 定义调度器配置字典
        scheduler_config = {
            # 算法类型
            "algorithm_type": "dpmsolver++",
            # 动态阈值比例
            "dynamic_thresholding_ratio": 0.995,
            # 最终是否使用欧拉法
            "euler_at_final": False,
            # 最终 sigma 类型
            "final_sigmas_type": "zero",
            # 较低阶最终设置
            "lower_order_final": True,
            # 训练时间步数
            "num_train_timesteps": 1000,
            # 预测类型
            "prediction_type": "epsilon",
            # rho 值
            "rho": 7.0,
            # 样本最大值
            "sample_max_value": 1.0,
            # 数据 sigma
            "sigma_data": 0.5,
            # 最大 sigma
            "sigma_max": 80.0,
            # 最小 sigma
            "sigma_min": 0.002,
            # 求解器阶数
            "solver_order": 2,
            # 求解器类型
            "solver_type": "midpoint",
            # 是否使用阈值处理
            "thresholding": False,
        }
        # 从配置中创建 EDM DPM 多步调度器
        scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)

    # 如果调度器类型不匹配,抛出异常
    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

    # 返回创建的调度器实例
    return scheduler
# 定义一个类方法,用于加载旧版 CLIP 分词器
def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False):
    # 如果提供了 config,则将其包装为包含模型路径的字典
    if config:
        config = {"pretrained_model_name_or_path": config}
    # 如果未提供 config,则从检查点获取配置信息
    else:
        config = fetch_diffusers_config(checkpoint)

    # 检查点是否为 CLIP 模型或 CLIP SDXL 模型
    if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
        # 设置使用的 CLIP 配置
        clip_config = "openai/clip-vit-large-patch14"
        # 将配置中的模型路径设置为 CLIP 配置
        config["pretrained_model_name_or_path"] = clip_config
        # 设置子文件夹为空
        subfolder = ""

    # 检查点是否为 Open CLIP 模型
    elif is_open_clip_model(checkpoint):
        # 设置使用的 Open CLIP 配置
        clip_config = "stabilityai/stable-diffusion-2"
        # 将配置中的模型路径设置为 Open CLIP 配置
        config["pretrained_model_name_or_path"] = clip_config
        # 设置子文件夹为 tokenizer
        subfolder = "tokenizer"

    # 如果不是以上模型,则使用默认的 CLIP 配置
    else:
        clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
        # 将配置中的模型路径设置为默认 CLIP 配置
        config["pretrained_model_name_or_path"] = clip_config
        # 设置子文件夹为空
        subfolder = ""

    # 从预训练模型加载分词器,传入配置和子文件夹
    tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)

    # 返回加载的分词器
    return tokenizer


# 定义一个加载安全检查器的函数
def _legacy_load_safety_checker(local_files_only, torch_dtype):
    # 使用过时的 `load_safety_checker` 参数支持加载安全检查器组件

    # 从指定路径导入稳定扩散安全检查器
    from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

    # 从预训练模型加载特征提取器
    feature_extractor = AutoImageProcessor.from_pretrained(
        "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
    )
    # 从预训练模型加载安全检查器
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
    )

    # 返回包含安全检查器和特征提取器的字典
    return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}


# 在 SD3 的原始实现中,AdaLayerNormContinuous 将线性投影输出分为 shift 和 scale;
# 而在 diffusers 中,顺序为 scale 和 shift。这里交换线性投影的权重,以便能使用 diffusers 的实现
def swap_scale_shift(weight, dim):
    # 将权重在指定维度分成两部分:shift 和 scale
    shift, scale = weight.chunk(2, dim=0)
    # 重新组合权重,将 scale 放在前面,shift 放在后面
    new_weight = torch.cat([scale, shift], dim=0)
    # 返回新的权重
    return new_weight


# 定义一个将 SD3 转换为 diffusers 的检查点转换函数
def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    # 初始化一个空的字典用于保存转换后的状态字典
    converted_state_dict = {}
    # 获取检查点中的所有键并转换为列表
    keys = list(checkpoint.keys())
    # 遍历所有键
    for k in keys:
        # 如果键包含特定字符串,则将其替换
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    # 获取 joint_blocks 的层数,并计算出总层数
    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
    # 设置 caption projection 的维度
    caption_projection_dim = 1536

    # 处理位置和补丁嵌入
    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")

    # 处理时间步嵌入
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
        "t_embedder.mlp.0.weight"
    )
    # 从检查点中弹出时间文本嵌入器的线性层1的偏置,并赋值给转换后的状态字典
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
    # 从检查点中弹出时间文本嵌入器的线性层2的权重,并赋值给转换后的状态字典
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
        "t_embedder.mlp.2.weight"
    )
    # 从检查点中弹出时间文本嵌入器的线性层2的偏置,并赋值给转换后的状态字典
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
    
    # 从检查点中弹出上下文嵌入器的权重,并赋值给转换后的状态字典
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
    # 从检查点中弹出上下文嵌入器的偏置,并赋值给转换后的状态字典
    converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
    
    # 从检查点中弹出时间文本嵌入的线性层1的权重,并赋值给转换后的状态字典
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
    # 从检查点中弹出时间文本嵌入的线性层1的偏置,并赋值给转换后的状态字典
    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
    # 从检查点中弹出时间文本嵌入的线性层2的权重,并赋值给转换后的状态字典
    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
    # 从检查点中弹出时间文本嵌入的线性层2的偏置,并赋值给转换后的状态字典
    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
    
    # 从检查点中弹出最终层的线性权重,并赋值给转换后的状态字典
    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    # 从检查点中弹出最终层的偏置,并赋值给转换后的状态字典
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
    # 从检查点中弹出最终层的自适应层归一化调制的权重,进行维度调整后赋值给转换后的状态字典
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
    )
    # 从检查点中弹出最终层的自适应层归一化调制的偏置,进行维度调整后赋值给转换后的状态字典
    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
    )
    
    # 返回转换后的状态字典
    return converted_state_dict
# 检查给定的检查点是否包含特定的 T5 模型权重
def is_t5_in_single_file(checkpoint):
    # 如果检查点中包含 T5 权重,则返回 True
    if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
        return True

    # 否则返回 False
    return False


# 将 SD3 格式的 T5 检查点转换为 Diffusers 格式
def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
    # 获取检查点中的所有键
    keys = list(checkpoint.keys())
    # 初始化空的字典以存储转换后的模型权重
    text_model_dict = {}

    # 定义需要移除的前缀
    remove_prefixes = ["text_encoders.t5xxl.transformer."]

    # 遍历每个键
    for key in keys:
        # 对每个前缀进行检查
        for prefix in remove_prefixes:
            # 如果键以前缀开头
            if key.startswith(prefix):
                # 替换前缀并获取新的键名
                diffusers_key = key.replace(prefix, "")
                # 将原键对应的值存入新字典中
                text_model_dict[diffusers_key] = checkpoint.get(key)

    # 返回转换后的模型权重字典
    return text_model_dict


# 从检查点创建 Diffusers 格式的 T5 模型
def create_diffusers_t5_model_from_checkpoint(
    cls,
    checkpoint,
    subfolder="",
    config=None,
    torch_dtype=None,
    local_files_only=None,
):
    # 如果提供了配置,则使用它
    if config:
        config = {"pretrained_model_name_or_path": config}
    else:
        # 否则从检查点中获取配置
        config = fetch_diffusers_config(checkpoint)

    # 从配置中加载模型配置
    model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
    # 根据是否可用的加速初始化上下文
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    # 使用上下文创建模型
    with ctx():
        model = cls(model_config)

    # 将检查点转换为 Diffusers 格式
    diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)

    # 如果加速可用,加载模型权重
    if is_accelerate_available():
        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
        # 检查是否有需要忽略的意外键
        if model._keys_to_ignore_on_load_unexpected is not None:
            for pat in model._keys_to_ignore_on_load_unexpected:
                # 过滤意外键
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        # 如果存在意外键,发出警告
        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
            )

    else:
        # 否则直接加载权重
        model.load_state_dict(diffusers_format_checkpoint)

    # 检查是否需要保持 FP32 模块
    use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
    if use_keep_in_fp32_modules:
        # 获取需要保持为 FP32 的模块
        keep_in_fp32_modules = model._keep_in_fp32_modules
    else:
        keep_in_fp32_modules = []

    # 如果存在需要保持为 FP32 的模块
    if keep_in_fp32_modules is not None:
        # 遍历模型的每个参数
        for name, param in model.named_parameters():
            # 如果参数名中包含需要保持为 FP32 的模块
            if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
                # 将参数数据转换为 FP32(只在局部作用域有效)
                param.data = param.data.to(torch.float32)

    # 返回最终模型
    return model


# 将 Animatediff 格式的检查点转换为 Diffusers 格式
def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
    # 初始化空字典以存储转换后的状态字典
    converted_state_dict = {}
    # 遍历检查点字典中的每个键值对
        for k, v in checkpoint.items():
            # 如果键中包含 "pos_encoder",则跳过此项
            if "pos_encoder" in k:
                continue
    
            else:
                # 替换键名中的特定子字符串,并将值赋给新字典
                converted_state_dict[
                    k.replace(".norms.0", ".norm1")  # 替换 ".norms.0" 为 ".norm1"
                    .replace(".norms.1", ".norm2")  # 替换 ".norms.1" 为 ".norm2"
                    .replace(".ff_norm", ".norm3")  # 替换 ".ff_norm" 为 ".norm3"
                    .replace(".attention_blocks.0", ".attn1")  # 替换 ".attention_blocks.0" 为 ".attn1"
                    .replace(".attention_blocks.1", ".attn2")  # 替换 ".attention_blocks.1" 为 ".attn2"
                    .replace(".temporal_transformer", "")  # 移除 ".temporal_transformer"
                ] = v  # 将原值 v 赋给新的键
    
        # 返回转换后的状态字典
        return converted_state_dict
# 将给定的检查点转换为 Diffusers 格式的模型
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    # 初始化一个空字典,用于存储转换后的状态字典
    converted_state_dict = {}
    # 获取检查点中所有键的列表
    keys = list(checkpoint.keys())
    # 遍历每个键
    for k in keys:
        # 如果键包含 "model.diffusion_model.",则替换该部分并更新检查点
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    # 计算双层块的数量
    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
    # 计算单层块的数量
    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
    # 设置 MLP 比率
    mlp_ratio = 4.0
    # 设置内部维度
    inner_dim = 3072

    # 定义一个函数,用于交换线性投影的权重顺序
    def swap_scale_shift(weight):
        # 将权重拆分为 shift 和 scale
        shift, scale = weight.chunk(2, dim=0)
        # 重新连接为新的权重
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

    ## 将时间嵌入的线性层权重从检查点中提取并赋值
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
        "time_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
        "time_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")

    ## 将文本嵌入的线性层权重从检查点中提取并赋值
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
        "vector_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")

    # 检查是否有引导信息
    has_guidance = any("guidance" in k for k in checkpoint)
    # 如果存在引导信息,从检查点中提取并赋值
    if has_guidance:
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
            "guidance_in.in_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
            "guidance_in.in_layer.bias"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
            "guidance_in.out_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
            "guidance_in.out_layer.bias"
        )

    # 提取上下文嵌入的权重和偏置
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
    # x_embedder
    # 从检查点中弹出图像输入的权重,赋值给转换后的状态字典的 x_embedder 权重
    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
    # 从检查点中弹出图像输入的偏置,赋值给转换后的状态字典的 x_embedder 偏置
    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")

    # double transformer blocks
    # single transfomer blocks
    # 遍历单个变换器层的数量
    for i in range(num_single_layers):
        # 生成当前单个变换器块的前缀
        block_prefix = f"single_transformer_blocks.{i}."
        # norm.linear  <- single_blocks.0.modulation.lin
        # 从检查点中弹出当前层的线性权重,赋值给转换后的状态字典的归一化层权重
        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
            f"single_blocks.{i}.modulation.lin.weight"
        )
        # 从检查点中弹出当前层的线性偏置,赋值给转换后的状态字典的归一化层偏置
        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
            f"single_blocks.{i}.modulation.lin.bias"
        )
        # Q, K, V, mlp
        # 计算 MLP 的隐藏维度
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        # 定义分割大小,用于分割线性权重
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        # 从检查点中弹出当前层的线性权重,并按分割大小分割成 Q, K, V 和 MLP
        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        # 从检查点中弹出当前层的线性偏置,并按分割大小分割成 Q, K, V 和 MLP 偏置
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        # 将 Q 的权重和偏置添加到转换后的状态字典
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        # 将 K 的权重和偏置添加到转换后的状态字典
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        # 将 V 的权重和偏置添加到转换后的状态字典
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        # 将 MLP 的权重和偏置添加到转换后的状态字典
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        # qk norm
        # 从检查点中弹出当前层的 Q 归一化权重,赋值给转换后的状态字典
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        # 从检查点中弹出当前层的 K 归一化权重,赋值给转换后的状态字典
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )
        # output projections.
        # 从检查点中弹出当前层的输出线性权重,赋值给转换后的状态字典
        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
        # 从检查点中弹出当前层的输出线性偏置,赋值给转换后的状态字典
        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")

    # 从检查点中弹出最终层的线性权重,赋值给转换后的状态字典
    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    # 从检查点中弹出最终层的线性偏置,赋值给转换后的状态字典
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
    # 从检查点中弹出最终层的归一化调制权重,并进行换位操作
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
    )
    # 从检查点中弹出最终层的归一化调制偏置,并进行换位操作
    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
    )

    # 返回转换后的状态字典
    return converted_state_dict
posted @   绝不原创的飞龙  阅读(309)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示