diffusers-源码解析-一-

diffusers 源码解析(一)

.\diffusers\callbacks.py

# 导入类型注解 Any, Dict, List
from typing import Any, Dict, List

# 从配置工具导入基类 ConfigMixin 和注册函数 register_to_config
from .configuration_utils import ConfigMixin, register_to_config
# 从工具模块导入常量 CONFIG_NAME
from .utils import CONFIG_NAME


# 定义一个回调基类,用于管道中的所有官方回调
class PipelineCallback(ConfigMixin):
    """
    Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
    custom callbacks and ensures that all callbacks have a consistent interface.

    Please implement the following:
        `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
        include
            variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        `callback_fn`: This method defines the core functionality of your callback.
    """

    # 设置配置名称为 CONFIG_NAME
    config_name = CONFIG_NAME

    # 注册构造函数到配置
    @register_to_config
    def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
        # 调用父类构造函数
        super().__init__()

        # 检查 cutoff_step_ratio 和 cutoff_step_index 是否同时为 None 或同时存在
        if (cutoff_step_ratio is None and cutoff_step_index is None) or (
            cutoff_step_ratio is not None and cutoff_step_index is not None
        ):
            # 如果同时为 None 或同时存在则抛出异常
            raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

        # 检查 cutoff_step_ratio 是否为有效的浮点数
        if cutoff_step_ratio is not None and (
            not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
        ):
            # 如果 cutoff_step_ratio 不在 0.0 到 1.0 之间则抛出异常
            raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

    # 定义 tensor_inputs 属性,返回类型为 List[str]
    @property
    def tensor_inputs(self) -> List[str]:
        # 抛出未实现错误,提醒用户必须实现该属性
        raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")

    # 定义 callback_fn 方法,接收管道、步骤索引、时间步和回调参数,返回类型为 Dict[str, Any]
    def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
        # 抛出未实现错误,提醒用户必须实现该方法
        raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")

    # 定义可调用方法,接收管道、步骤索引、时间步和回调参数,返回类型为 Dict[str, Any]
    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # 调用 callback_fn 方法并返回结果
        return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)


# 定义一个多管道回调类
class MultiPipelineCallbacks:
    """
    This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
    provides a unified interface for calling all of them.
    """

    # 初始化方法,接收回调列表
    def __init__(self, callbacks: List[PipelineCallback]):
        # 将回调列表存储为类属性
        self.callbacks = callbacks

    # 定义 tensor_inputs 属性,返回所有回调的输入列表
    @property
    def tensor_inputs(self) -> List[str]:
        # 使用列表推导式从每个回调中获取 tensor_inputs
        return [input for callback in self.callbacks for input in callback.tensor_inputs]

    # 定义可调用方法,接收管道、步骤索引、时间步和回调参数,返回类型为 Dict[str, Any]
    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
        """
        # 遍历所有回调并依次调用
        for callback in self.callbacks:
            # 更新回调参数
            callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)

        # 返回最终的回调参数
        return callback_kwargs


# 定义稳定扩散管道的截止回调
class SDCFGCutoffCallback(PipelineCallback):
    """
    Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
    # 回调函数用于在特定步骤禁用 CFG
        `cutoff_step_index`), this callback will disable the CFG.
    
        # 注意:此回调通过将 `_guidance_scale` 属性更改为 0.0 来修改管道,发生在截止步骤之后。
        """
    
        # 定义输入的张量名称列表
        tensor_inputs = ["prompt_embeds"]
    
        # 定义回调函数,接收管道、步骤索引、时间步和其他回调参数
        def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
            # 获取截止步骤比例
            cutoff_step_ratio = self.config.cutoff_step_ratio
            # 获取截止步骤索引
            cutoff_step_index = self.config.cutoff_step_index
    
            # 如果截止步骤索引不为 None,使用该索引,否则根据比例计算截止步骤
            cutoff_step = (
                cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
            )
    
            # 如果当前步骤索引等于截止步骤,进行以下操作
            if step_index == cutoff_step:
                # 从回调参数中获取提示嵌入
                prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
                # 获取最后一个嵌入,表示条件文本标记的嵌入
                prompt_embeds = prompt_embeds[-1:]  # "-1" denotes the embeddings for conditional text tokens.
    
                # 将管道的指导比例设置为 0.0
                pipeline._guidance_scale = 0.0
    
                # 更新回调参数中的提示嵌入
                callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
            # 返回更新后的回调参数
            return callback_kwargs
# 定义 SDXLCFGCutoffCallback 类,继承自 PipelineCallback
class SDXLCFGCutoffCallback(PipelineCallback):
    """
    Stable Diffusion XL 管道的回调函数。在指定步骤数后(由 `cutoff_step_ratio` 或 `cutoff_step_index` 设置),
    此回调将禁用 CFG。

    注意:此回调通过将 `_guidance_scale` 属性在截止步骤后更改为 0.0 来改变管道。
    """

    # 定义需要处理的张量输入
    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]

    # 定义回调函数,接受管道、步骤索引、时间步和回调参数
    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # 从配置中获取截止步骤比例
        cutoff_step_ratio = self.config.cutoff_step_ratio
        # 从配置中获取截止步骤索引
        cutoff_step_index = self.config.cutoff_step_index

        # 如果截止步骤索引不为 None,则使用该值,否则使用截止步骤比例计算
        cutoff_step = (
            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
        )

        # 如果当前步骤等于截止步骤
        if step_index == cutoff_step:
            # 获取条件文本令牌的嵌入
            prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
            # 取最后一个嵌入,表示条件文本令牌的嵌入
            prompt_embeds = prompt_embeds[-1:]  # "-1" denotes the embeddings for conditional text tokens.

            # 获取条件池化文本令牌的嵌入
            add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
            # 取最后一个嵌入,表示条件池化文本令牌的嵌入
            add_text_embeds = add_text_embeds[-1:]  # "-1" denotes the embeddings for conditional pooled text tokens

            # 获取条件附加时间向量的 ID
            add_time_ids = callback_kwargs[self.tensor_inputs[2]]
            # 取最后一个 ID,表示条件附加时间向量的 ID
            add_time_ids = add_time_ids[-1:]  # "-1" denotes the embeddings for conditional added time vector

            # 将管道的引导比例设置为 0.0
            pipeline._guidance_scale = 0.0

            # 更新回调参数中的嵌入和 ID
            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
            callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
            callback_kwargs[self.tensor_inputs[2]] = add_time_ids
        # 返回更新后的回调参数
        return callback_kwargs


# 定义 IPAdapterScaleCutoffCallback 类,继承自 PipelineCallback
class IPAdapterScaleCutoffCallback(PipelineCallback):
    """
    适用于任何继承 `IPAdapterMixin` 的管道的回调函数。在指定步骤数后(由 `cutoff_step_ratio` 或
    `cutoff_step_index` 设置),此回调将 IP 适配器的比例设置为 `0.0`。

    注意:此回调通过在截止步骤后将比例设置为 0.0 来改变 IP 适配器注意力处理器。
    """

    # 定义需要处理的张量输入(此类无具体输入)
    tensor_inputs = []

    # 定义回调函数,接受管道、步骤索引、时间步和回调参数
    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # 从配置中获取截止步骤比例
        cutoff_step_ratio = self.config.cutoff_step_ratio
        # 从配置中获取截止步骤索引
        cutoff_step_index = self.config.cutoff_step_index

        # 如果截止步骤索引不为 None,则使用该值,否则使用截止步骤比例计算
        cutoff_step = (
            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
        )

        # 如果当前步骤等于截止步骤
        if step_index == cutoff_step:
            # 将 IP 适配器的比例设置为 0.0
            pipeline.set_ip_adapter_scale(0.0)
        # 返回回调参数
        return callback_kwargs

.\diffusers\commands\diffusers_cli.py

# 指定解释器路径
#!/usr/bin/env python
# 版权信息,表明版权归 HuggingFace 团队所有
# 版权声明,受 Apache 2.0 许可协议约束
#
# 在使用此文件前必须遵循许可证的规定
# 许可证获取地址
#
# 软件在无任何保证的情况下按 "原样" 发行
# 参见许可证以了解权限和限制

# 导入命令行参数解析器
from argparse import ArgumentParser

# 导入环境命令模块
from .env import EnvironmentCommand
# 导入 FP16 Safetensors 命令模块
from .fp16_safetensors import FP16SafetensorsCommand

# 定义主函数
def main():
    # 创建命令行解析器,设置程序名称和使用说明
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    # 添加子命令解析器,帮助信息
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # 注册环境命令为子命令
    EnvironmentCommand.register_subcommand(commands_parser)
    # 注册 FP16 Safetensors 命令为子命令
    FP16SafetensorsCommand.register_subcommand(commands_parser)

    # 解析命令行参数
    args = parser.parse_args()

    # 检查是否有有效的命令函数
    if not hasattr(args, "func"):
        # 打印帮助信息并退出
        parser.print_help()
        exit(1)

    # 执行指定的命令函数
    service = args.func(args)
    # 运行服务
    service.run()

# 如果当前文件是主程序,则执行主函数
if __name__ == "__main__":
    main()

.\diffusers\commands\env.py

# 版权声明,标明版权归 HuggingFace 团队所有
# 
# 根据 Apache 许可证第 2.0 版(“许可证”)授权;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律要求或书面同意,软件按“原样”分发,
# 不附有任何形式的担保或条件,无论是明示或暗示的。
# 有关许可证的具体条款,请参阅许可证。
#
# 导入 platform 模块,用于获取系统平台信息
import platform
# 导入 subprocess 模块,用于创建子进程
import subprocess
# 从 argparse 模块导入 ArgumentParser 类,用于处理命令行参数
from argparse import ArgumentParser

# 导入 huggingface_hub 库,提供与 Hugging Face Hub 交互的功能
import huggingface_hub

# 从上层包中导入版本信息
from .. import __version__ as version
# 从 utils 模块中导入多个可用性检查函数
from ..utils import (
    is_accelerate_available,      # 检查 accelerate 库是否可用
    is_bitsandbytes_available,     # 检查 bitsandbytes 库是否可用
    is_flax_available,             # 检查 flax 库是否可用
    is_google_colab,              # 检查当前环境是否为 Google Colab
    is_peft_available,             # 检查 peft 库是否可用
    is_safetensors_available,      # 检查 safetensors 库是否可用
    is_torch_available,            # 检查 torch 库是否可用
    is_transformers_available,     # 检查 transformers 库是否可用
    is_xformers_available,         # 检查 xformers 库是否可用
)
# 从当前包中导入 BaseDiffusersCLICommand 基类
from . import BaseDiffusersCLICommand


# 定义一个工厂函数,返回 EnvironmentCommand 的实例
def info_command_factory(_):
    return EnvironmentCommand()


# 定义 EnvironmentCommand 类,继承自 BaseDiffusersCLICommand
class EnvironmentCommand(BaseDiffusersCLICommand):
    # 注册子命令的方法,接收 ArgumentParser 对象
    @staticmethod
    def register_subcommand(parser: ArgumentParser) -> None:
        # 在解析器中添加名为 "env" 的子命令
        download_parser = parser.add_parser("env")
        # 设置默认的处理函数为 info_command_factory
        download_parser.set_defaults(func=info_command_factory)

    # 格式化字典的方法,将字典转换为字符串
    @staticmethod
    def format_dict(d: dict) -> str:
        # 以特定格式将字典内容转换为字符串
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"

.\diffusers\commands\fp16_safetensors.py

# 版权声明,指明代码的版权归 HuggingFace 团队所有
# 
# 根据 Apache License 2.0 版本授权,声明用户须遵守该许可
# 用户可以在以下网址获取许可副本
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律规定或书面同意,软件在“按原样”基础上分发,
# 不提供任何形式的担保或条件
# 请查看许可以了解特定语言的权限和限制

"""
用法示例:
    diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
"""

# 导入所需模块
import glob  # 用于文件路径匹配
import json  # 用于 JSON 数据解析
import warnings  # 用于发出警告
from argparse import ArgumentParser, Namespace  # 用于命令行参数解析
from importlib import import_module  # 动态导入模块

import huggingface_hub  # Hugging Face Hub 的接口库
import torch  # PyTorch 深度学习库
from huggingface_hub import hf_hub_download  # 从 Hugging Face Hub 下载模型的函数
from packaging import version  # 用于版本比较

from ..utils import logging  # 导入日志模块
from . import BaseDiffusersCLICommand  # 导入基本 CLI 命令类


def conversion_command_factory(args: Namespace):
    # 根据传入的命令行参数创建转换命令
    if args.use_auth_token:
        # 发出警告,提示 --use_auth_token 参数已弃用
        warnings.warn(
            "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
            " handled automatically if user is logged in."
        )
    # 返回 FP16SafetensorsCommand 的实例
    return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)


class FP16SafetensorsCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # 注册子命令 fp16_safetensors
        conversion_parser = parser.add_parser("fp16_safetensors")
        # 添加 ckpt_id 参数,用于指定检查点的仓库 ID
        conversion_parser.add_argument(
            "--ckpt_id",
            type=str,
            help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
        )
        # 添加 fp16 参数,指示是否以 FP16 精度序列化变量
        conversion_parser.add_argument(
            "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
        )
        # 添加 use_safetensors 参数,指示是否以 safetensors 格式序列化
        conversion_parser.add_argument(
            "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
        )
        # 添加 use_auth_token 参数,用于处理私有可见性的检查点
        conversion_parser.add_argument(
            "--use_auth_token",
            action="store_true",
            help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
        )
        # 设置默认函数为 conversion_command_factory
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
        # 初始化命令类,设置日志记录器和参数
        self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
        # 存储检查点 ID
        self.ckpt_id = ckpt_id
        # 定义本地检查点目录
        self.local_ckpt_dir = f"/tmp/{ckpt_id}"
        # 存储 FP16 精度设置
        self.fp16 = fp16
        # 存储 safetensors 设置
        self.use_safetensors = use_safetensors

        # 检查是否同时未使用 safetensors 和 fp16,若是则抛出异常
        if not self.use_safetensors and not self.fp16:
            raise NotImplementedError(
                "When `use_safetensors` and `fp16` both are False, then this command is of no use."
            )
    # 定义运行方法
    def run(self):
        # 检查 huggingface_hub 版本是否低于 0.9.0
        if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
            # 如果版本低于要求,抛出导入错误
            raise ImportError(
                "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
                " installation."
            )
        else:
            # 从 huggingface_hub 导入创建提交的函数
            from huggingface_hub import create_commit
            # 从 huggingface_hub 导入提交操作类
            from huggingface_hub._commit_api import CommitOperationAdd
    
        # 下载模型索引文件
        model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
        # 打开模型索引文件并读取内容
        with open(model_index, "r") as f:
            # 从 JSON 中提取管道类名称
            pipeline_class_name = json.load(f)["_class_name"]
        # 动态导入对应的管道类
        pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
        # 记录导入的管道类名称
        self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
    
        # 加载适当的管道
        pipeline = pipeline_class.from_pretrained(
            self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
        )
        # 将管道保存到本地目录
        pipeline.save_pretrained(
            self.local_ckpt_dir,
            safe_serialization=True if self.use_safetensors else False,
            variant="fp16" if self.fp16 else None,
        )
        # 记录管道保存的本地目录
        self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")
    
        # 获取所有的路径
        if self.fp16:
            # 获取所有 FP16 文件的路径
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
        elif self.use_safetensors:
            # 获取所有 Safetensors 文件的路径
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")
    
        # 准备提交请求
        commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
        operations = []
        # 遍历修改过的路径,准备提交操作
        for path in modified_paths:
            operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))
    
        # 打开提交请求
        commit_description = (
            "Variables converted by the [`diffusers`' `fp16_safetensors`"
            " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
        )
        # 创建提交请求并获取其 URL
        hub_pr_url = create_commit(
            repo_id=self.ckpt_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            repo_type="model",
            create_pr=True,
        ).pr_url
        # 记录提交请求的 URL
        self.logger.info(f"PR created here: {hub_pr_url}.")

.\diffusers\commands\__init__.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)授权;
# 除非遵守该许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,否则根据许可证分发的软件是按“原样”基础分发的,
# 不提供任何形式的明示或暗示的担保或条件。
# 请参阅许可证以获取有关权限和限制的具体信息。

# 从 abc 模块导入 ABC 类和 abstractmethod 装饰器
from abc import ABC, abstractmethod
# 从 argparse 模块导入 ArgumentParser 类,用于解析命令行参数
from argparse import ArgumentParser


# 定义一个抽象基类 BaseDiffusersCLICommand,继承自 ABC
class BaseDiffusersCLICommand(ABC):
    # 定义一个静态抽象方法 register_subcommand,接受一个 ArgumentParser 实例作为参数
    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        # 如果子类没有实现此方法,则抛出 NotImplementedError
        raise NotImplementedError()

    # 定义一个抽象方法 run,供子类实现具体的执行逻辑
    @abstractmethod
    def run(self):
        # 如果子类没有实现此方法,则抛出 NotImplementedError
        raise NotImplementedError()

.\diffusers\configuration_utils.py

# 指定编码为 UTF-8
# 版权声明,指明版权归 HuggingFace Inc. 团队所有
# 版权声明,指明版权归 NVIDIA CORPORATION 所有
#
# 根据 Apache License, Version 2.0 授权本文件
# 只能在遵循许可证的情况下使用该文件
# 可在以下网址获取许可证副本
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 本文件按“原样”提供,未提供任何明示或暗示的保证或条件
# 参见许可证中关于权限和限制的具体条款
"""配置混合类的基类及其工具函数。"""

# 导入必要的库和模块
import dataclasses  # 提供数据类支持
import functools  # 提供高阶函数的工具
import importlib  # 提供导入模块的功能
import inspect  # 提供获取对象信息的功能
import json  # 提供JSON数据的处理
import os  # 提供与操作系统交互的功能
import re  # 提供正则表达式支持
from collections import OrderedDict  # 提供有序字典支持
from pathlib import Path  # 提供路径操作支持
from typing import Any, Dict, Tuple, Union  # 提供类型提示支持

import numpy as np  # 导入 NumPy 库
from huggingface_hub import create_repo, hf_hub_download  # 从 Hugging Face Hub 导入相关函数
from huggingface_hub.utils import (  # 导入 Hugging Face Hub 工具中的异常和验证函数
    EntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    validate_hf_hub_args,
)
from requests import HTTPError  # 导入处理 HTTP 错误的类

from . import __version__  # 导入当前模块的版本信息
from .utils import (  # 从工具模块导入常用工具
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    DummyObject,
    deprecate,
    extract_commit_hash,
    http_user_agent,
    logging,
)

# 创建日志记录器实例
logger = logging.get_logger(__name__)

# 编译正则表达式用于匹配配置文件名
_re_configuration_file = re.compile(r"config\.(.*)\.json")


class FrozenDict(OrderedDict):  # 定义一个不可变字典类,继承自有序字典
    def __init__(self, *args, **kwargs):  # 初始化方法,接收任意参数
        super().__init__(*args, **kwargs)  # 调用父类初始化方法

        for key, value in self.items():  # 遍历字典中的每个键值对
            setattr(self, key, value)  # 将每个键值对作为属性设置

        self.__frozen = True  # 标记字典为不可变状态

    def __delitem__(self, *args, **kwargs):  # 禁止使用 del 删除字典项
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")  # 抛出异常

    def setdefault(self, *args, **kwargs):  # 禁止使用 setdefault 方法
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")  # 抛出异常

    def pop(self, *args, **kwargs):  # 禁止使用 pop 方法
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")  # 抛出异常

    def update(self, *args, **kwargs):  # 禁止使用 update 方法
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")  # 抛出异常

    def __setattr__(self, name, value):  # 重写设置属性的方法
        if hasattr(self, "__frozen") and self.__frozen:  # 检查是否已被冻结
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")  # 抛出异常
        super().__setattr__(name, value)  # 调用父类方法设置属性

    def __setitem__(self, name, value):  # 重写设置字典项的方法
        if hasattr(self, "__frozen") and self.__frozen:  # 检查是否已被冻结
            raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")  # 抛出异常
        super().__setitem__(name, value)  # 调用父类方法设置字典项


class ConfigMixin:  # 定义配置混合类
    r"""  # 类文档字符串,描述类的用途
    Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
    provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
    # 保存从 `ConfigMixin` 继承的类的配置。
    # 类属性:
    # - **config_name** (`str`) -- 应该在调用 `~ConfigMixin.save_config` 时存储的配置文件名(应由父类重写)。
    # - **ignore_for_config** (`List[str]`) -- 不应在配置中保存的属性列表(应由子类重写)。
    # - **has_compatibles** (`bool`) -- 类是否有兼容的类(应由子类重写)。
    # - **_deprecated_kwargs** (`List[str]`) -- 已废弃的关键字参数。注意,`init` 函数只有在至少有一个参数被废弃时才应具有 `kwargs` 参数(应由子类重写)。
    class ConfigMixin:
        config_name = None  # 配置文件名初始化为 None
        ignore_for_config = []  # 不保存到配置的属性列表初始化为空
        has_compatibles = False  # 默认没有兼容类
    
        _deprecated_kwargs = []  # 已废弃的关键字参数列表初始化为空
    
        def register_to_config(self, **kwargs):
            # 检查 config_name 是否已定义
            if self.config_name is None:
                raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
            # 针对用于废弃警告的特殊情况
            # TODO: 当移除废弃警告和 `kwargs` 参数时删除此处
            kwargs.pop("kwargs", None)  # 从 kwargs 中移除 "kwargs"
    
            # 如果没有 _internal_dict 则初始化
            if not hasattr(self, "_internal_dict"):
                internal_dict = kwargs  # 直接使用 kwargs
            else:
                previous_dict = dict(self._internal_dict)  # 复制之前的字典
                # 合并之前的字典和新的 kwargs
                internal_dict = {**self._internal_dict, **kwargs}
                logger.debug(f"Updating config from {previous_dict} to {internal_dict}")  # 记录更新日志
    
            self._internal_dict = FrozenDict(internal_dict)  # 将内部字典冻结以防修改
    
        def __getattr__(self, name: str) -> Any:
            """覆盖 `getattr` 的唯一原因是优雅地废弃直接访问配置属性。
            参见 https://github.com/huggingface/diffusers/pull/3129
    
            此函数主要复制自 PyTorch 的 `__getattr__` 重写。
            """
            # 检查是否在 _internal_dict 中,并且该名称存在
            is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
            is_attribute = name in self.__dict__  # 检查是否为直接属性
    
            # 如果在配置中但不是属性,则发出废弃警告
            if is_in_config and not is_attribute:
                deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
                deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
                return self._internal_dict[name]  # 通过 _internal_dict 返回该属性
    
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")  # 引发属性错误
    # 定义一个保存配置的方法,接受保存目录和其他可选参数
    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        保存配置对象到指定目录 `save_directory`,以便使用
        [`~ConfigMixin.from_config`] 类方法重新加载。

        参数:
            save_directory (`str` 或 `os.PathLike`):
                保存配置 JSON 文件的目录(如果不存在则会创建)。
            push_to_hub (`bool`, *可选*, 默认值为 `False`):
                保存后是否将模型推送到 Hugging Face Hub。可以用 `repo_id` 指定
                要推送的仓库(默认为 `save_directory` 的名称)。
            kwargs (`Dict[str, Any]`, *可选*):
                额外的关键字参数,将传递给 [`~utils.PushToHubMixin.push_to_hub`] 方法。
        """
        # 如果提供的路径是文件,则抛出异常,要求是目录
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        # 创建目录,存在时不会报错
        os.makedirs(save_directory, exist_ok=True)

        # 根据预定义名称保存时,可以使用 `from_config` 加载
        output_config_file = os.path.join(save_directory, self.config_name)

        # 将配置写入 JSON 文件
        self.to_json_file(output_config_file)
        # 记录保存配置的日志信息
        logger.info(f"Configuration saved in {output_config_file}")

        # 如果需要推送到 Hub
        if push_to_hub:
            # 从 kwargs 中弹出提交信息
            commit_message = kwargs.pop("commit_message", None)
            # 从 kwargs 中弹出私有标志
            private = kwargs.pop("private", False)
            # 从 kwargs 中弹出创建 PR 的标志
            create_pr = kwargs.pop("create_pr", False)
            # 从 kwargs 中弹出令牌
            token = kwargs.pop("token", None)
            # 从 kwargs 中获取 repo_id,默认为保存目录名称
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            # 创建仓库,若存在则不报错,并返回仓库 ID
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

            # 上传文件夹到指定的仓库
            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )

    # 定义一个类方法,获取配置字典
    @classmethod
    @classmethod
    def get_config_dict(cls, *args, **kwargs):
        # 生成废弃消息,提醒用户此方法将被移除
        deprecation_message = (
            f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
            " removed in version v1.0.0"
        )
        # 调用废弃警告函数
        deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
        # 返回加载的配置
        return cls.load_config(*args, **kwargs)

    # 定义一个类方法,用于加载配置
    @classmethod
    @validate_hf_hub_args
    def load_config(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        return_unused_kwargs=False,
        return_commit_hash=False,
        **kwargs,
    # 定义一个静态方法,获取初始化所需的关键字
    @staticmethod
    def _get_init_keys(input_class):
        # 返回类初始化方法的参数名称集合
        return set(dict(inspect.signature(input_class.__init__).parameters).keys())

    # 额外的类方法
    @classmethod
    @classmethod
    # 从 JSON 文件创建字典
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        # 打开指定的 JSON 文件,使用 UTF-8 编码读取内容
        with open(json_file, "r", encoding="utf-8") as reader:
            # 读取文件内容并存储到变量 text 中
            text = reader.read()
        # 将读取的 JSON 字符串解析为字典并返回
        return json.loads(text)

    # 返回类的字符串表示形式
    def __repr__(self):
        # 使用类名和 JSON 字符串表示配置实例返回字符串
        return f"{self.__class__.__name__} {self.to_json_string()}"

    # 定义一个只读属性 config
    @property
    def config(self) -> Dict[str, Any]:
        """
        返回类的配置作为一个不可变字典

        Returns:
            `Dict[str, Any]`: 类的配置字典。
        """
        # 返回内部字典 _internal_dict
        return self._internal_dict

    # 将配置实例序列化为 JSON 字符串
    def to_json_string(self) -> str:
        """
        将配置实例序列化为 JSON 字符串。

        Returns:
            `str`:
                包含配置实例的所有属性的 JSON 格式字符串。
        """
        # 检查是否存在 _internal_dict,若不存在则使用空字典
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
        # 将类名添加到配置字典中
        config_dict["_class_name"] = self.__class__.__name__
        # 将当前版本添加到配置字典中
        config_dict["_diffusers_version"] = __version__

        # 定义一个用于将值转换为可保存的 JSON 格式的函数
        def to_json_saveable(value):
            # 如果值是 numpy 数组,则转换为列表
            if isinstance(value, np.ndarray):
                value = value.tolist()
            # 如果值是 Path 对象,则转换为 POSIX 路径字符串
            elif isinstance(value, Path):
                value = value.as_posix()
            # 返回转换后的值
            return value

        # 对配置字典中的每个键值对进行转换
        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        # 从字典中移除 "_ignore_files" 和 "_use_default_values" 项
        config_dict.pop("_ignore_files", None)
        config_dict.pop("_use_default_values", None)

        # 将配置字典转换为格式化的 JSON 字符串并返回
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    # 将配置实例的参数保存到 JSON 文件
    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        将配置实例的参数保存到 JSON 文件。

        Args:
            json_file_path (`str` or `os.PathLike`):
                要保存配置实例参数的 JSON 文件路径。
        """
        # 打开指定的 JSON 文件进行写入,使用 UTF-8 编码
        with open(json_file_path, "w", encoding="utf-8") as writer:
            # 将配置实例转换为 JSON 字符串并写入文件
            writer.write(self.to_json_string())
# 装饰器,用于应用在继承自 [`ConfigMixin`] 的类的初始化方法上,自动将所有参数发送到 `self.register_for_config`
def register_to_config(init):
    # 文档字符串,描述装饰器的功能和警告
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    # 包装原始初始化方法,以便在其上添加功能
    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # 忽略初始化方法中的私有关键字参数
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        # 提取私有关键字参数以供后续使用
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        # 检查当前类是否继承自 ConfigMixin
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        # 获取需要忽略的配置参数列表
        ignore = getattr(self, "ignore_for_config", [])
        # 对齐位置参数与关键字参数
        new_kwargs = {}
        # 获取初始化方法的签名
        signature = inspect.signature(init)
        # 提取参数名和默认值,排除忽略的参数
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        # 将位置参数映射到新关键字参数
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # 更新新关键字参数,加入所有未被忽略的关键字参数
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )

        # 记录未在配置中出现的参数
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        # 合并配置初始化参数和新关键字参数
        new_kwargs = {**config_init_kwargs, **new_kwargs}
        # 调用类的注册方法,将参数发送到配置
        getattr(self, "register_to_config")(**new_kwargs)
        # 调用原始初始化方法
        init(self, *args, **init_kwargs)

    # 返回包装后的初始化方法
    return inner_init


# 装饰器函数,用于在类上注册配置功能
def flax_register_to_config(cls):
    # 保存原始初始化方法
    original_init = cls.__init__

    # 包装原始初始化方法,以便在其上添加功能
    @functools.wraps(original_init)
    # 定义初始化方法,接受可变位置和关键字参数
        def init(self, *args, **kwargs):
            # 检查当前实例是否继承自 ConfigMixin
            if not isinstance(self, ConfigMixin):
                raise RuntimeError(
                    # 抛出异常,提示类未继承 ConfigMixin
                    f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                    "not inherit from `ConfigMixin`."
                )
    
            # 忽略私有关键字参数,获取所有传入的属性
            init_kwargs = dict(kwargs.items())
    
            # 获取默认值
            fields = dataclasses.fields(self)
            default_kwargs = {}
            for field in fields:
                # 忽略 flax 特定属性
                if field.name in self._flax_internal_args:
                    continue
                # 检查字段的默认值是否缺失
                if type(field.default) == dataclasses._MISSING_TYPE:
                    default_kwargs[field.name] = None
                else:
                    # 获取字段的默认值
                    default_kwargs[field.name] = getattr(self, field.name)
    
            # 确保 init_kwargs 可以覆盖默认值
            new_kwargs = {**default_kwargs, **init_kwargs}
            # 从 new_kwargs 中移除 dtype,确保它仅在 init_kwargs 中
            if "dtype" in new_kwargs:
                new_kwargs.pop("dtype")
    
            # 获取与关键字参数对齐的位置参数
            for i, arg in enumerate(args):
                name = fields[i].name
                new_kwargs[name] = arg
    
            # 记录未在加载配置中出现的参数
            if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
                new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
    
            # 调用 register_to_config 方法,传入新构建的关键字参数
            getattr(self, "register_to_config")(**new_kwargs)
            # 调用原始初始化方法
            original_init(self, *args, **kwargs)
    
        # 将自定义初始化方法赋值给类的 __init__ 方法
        cls.__init__ = init
        return cls
# 定义一个名为 LegacyConfigMixin 的类,它是 ConfigMixin 的子类
class LegacyConfigMixin(ConfigMixin):
    r"""
    该类是 `ConfigMixin` 的子类,用于将旧类(如 `Transformer2DModel`)映射到更
    特定于管道的类(如 `DiTTransformer2DModel`)。
    """

    # 定义一个类方法 from_config,接收配置和其他可选参数
    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        # 为了防止依赖导入问题,从指定模块导入函数
        from .models.model_loading_utils import _fetch_remapped_cls_from_config

        # 调用函数,解析类的映射关系
        remapped_class = _fetch_remapped_cls_from_config(config, cls)

        # 返回映射后的类使用配置和其他参数进行的实例化
        return remapped_class.from_config(config, return_unused_kwargs, **kwargs)

.\diffusers\dependency_versions_check.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵守该许可证,否则您不得使用此文件。
# 您可以在以下位置获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,
# 否则根据许可证分发的软件按“原样”提供,
# 不附带任何形式的明示或暗示的保证或条件。
# 请参阅许可证以获取特定语言管理权限和
# 限制的信息。

# 从当前包中导入依赖版本表
from .dependency_versions_table import deps
# 从当前包中导入版本检查工具
from .utils.versions import require_version, require_version_core


# 定义我们在运行时始终要检查的模块版本
# (通常是 setup.py 中定义的 `install_requires`)
#
# 特定顺序说明:
# - tqdm 必须在 tokenizers 之前检查

# 需要在运行时检查的包列表,使用空格分隔并拆分成列表
pkgs_to_check_at_runtime = "python requests filelock numpy".split()
# 遍历每个需要检查的包
for pkg in pkgs_to_check_at_runtime:
    # 如果包在依赖版本字典中
    if pkg in deps:
        # 检查该包的版本是否符合要求
        require_version_core(deps[pkg])
    # 如果包不在依赖版本字典中,抛出异常
    else:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")


# 定义一个函数用于检查特定包的版本
def dep_version_check(pkg, hint=None):
    # 调用版本检查工具,检查指定包的版本
    require_version(deps[pkg], hint)

.\diffusers\dependency_versions_table.py

# 该文件为自动生成文件。要更新:
# 1. 修改 setup.py 中的 `_deps` 字典
# 2. 运行 `make deps_table_update`
deps = {  # 创建一个字典,用于存储依赖包及其版本
    "Pillow": "Pillow",  # 指定 Pillow 包,版本默认为最新
    "accelerate": "accelerate>=0.31.0",  # 指定 accelerate 包,版本要求 >=0.31.0
    "compel": "compel==0.1.8",  # 指定 compel 包,版本固定为 0.1.8
    "datasets": "datasets",  # 指定 datasets 包,版本默认为最新
    "filelock": "filelock",  # 指定 filelock 包,版本默认为最新
    "flax": "flax>=0.4.1",  # 指定 flax 包,版本要求 >=0.4.1
    "hf-doc-builder": "hf-doc-builder>=0.3.0",  # 指定 hf-doc-builder 包,版本要求 >=0.3.0
    "huggingface-hub": "huggingface-hub>=0.23.2",  # 指定 huggingface-hub 包,版本要求 >=0.23.2
    "requests-mock": "requests-mock==1.10.0",  # 指定 requests-mock 包,版本固定为 1.10.0
    "importlib_metadata": "importlib_metadata",  # 指定 importlib_metadata 包,版本默认为最新
    "invisible-watermark": "invisible-watermark>=0.2.0",  # 指定 invisible-watermark 包,版本要求 >=0.2.0
    "isort": "isort>=5.5.4",  # 指定 isort 包,版本要求 >=5.5.4
    "jax": "jax>=0.4.1",  # 指定 jax 包,版本要求 >=0.4.1
    "jaxlib": "jaxlib>=0.4.1",  # 指定 jaxlib 包,版本要求 >=0.4.1
    "Jinja2": "Jinja2",  # 指定 Jinja2 包,版本默认为最新
    "k-diffusion": "k-diffusion>=0.0.12",  # 指定 k-diffusion 包,版本要求 >=0.0.12
    "torchsde": "torchsde",  # 指定 torchsde 包,版本默认为最新
    "note_seq": "note_seq",  # 指定 note_seq 包,版本默认为最新
    "librosa": "librosa",  # 指定 librosa 包,版本默认为最新
    "numpy": "numpy",  # 指定 numpy 包,版本默认为最新
    "parameterized": "parameterized",  # 指定 parameterized 包,版本默认为最新
    "peft": "peft>=0.6.0",  # 指定 peft 包,版本要求 >=0.6.0
    "protobuf": "protobuf>=3.20.3,<4",  # 指定 protobuf 包,版本要求 >=3.20.3 且 <4
    "pytest": "pytest",  # 指定 pytest 包,版本默认为最新
    "pytest-timeout": "pytest-timeout",  # 指定 pytest-timeout 包,版本默认为最新
    "pytest-xdist": "pytest-xdist",  # 指定 pytest-xdist 包,版本默认为最新
    "python": "python>=3.8.0",  # 指定 python,版本要求 >=3.8.0
    "ruff": "ruff==0.1.5",  # 指定 ruff 包,版本固定为 0.1.5
    "safetensors": "safetensors>=0.3.1",  # 指定 safetensors 包,版本要求 >=0.3.1
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",  # 指定 sentencepiece 包,版本要求 >=0.1.91 且 !=0.1.92
    "GitPython": "GitPython<3.1.19",  # 指定 GitPython 包,版本要求 <3.1.19
    "scipy": "scipy",  # 指定 scipy 包,版本默认为最新
    "onnx": "onnx",  # 指定 onnx 包,版本默认为最新
    "regex": "regex!=2019.12.17",  # 指定 regex 包,版本要求 !=2019.12.17
    "requests": "requests",  # 指定 requests 包,版本默认为最新
    "tensorboard": "tensorboard",  # 指定 tensorboard 包,版本默认为最新
    "torch": "torch>=1.4",  # 指定 torch 包,版本要求 >=1.4
    "torchvision": "torchvision",  # 指定 torchvision 包,版本默认为最新
    "transformers": "transformers>=4.41.2",  # 指定 transformers 包,版本要求 >=4.41.2
    "urllib3": "urllib3<=2.0.0",  # 指定 urllib3 包,版本要求 <=2.0.0
    "black": "black",  # 指定 black 包,版本默认为最新
}

.\diffusers\experimental\rl\value_guided_sampling.py

# 版权声明,2024年HuggingFace团队版权所有
# 
# 根据Apache许可证第2.0版("许可证")许可;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律或书面协议另有约定,
# 否则根据许可证分发的软件按“原样”提供,
# 不附带任何明示或暗示的担保或条件。
# 请参阅许可证以了解有关权限和限制的具体语言。

# 导入numpy库以进行数值计算
import numpy as np
# 导入PyTorch库以进行深度学习
import torch
# 导入tqdm库以显示进度条
import tqdm

# 从自定义模块中导入UNet1DModel
from ...models.unets.unet_1d import UNet1DModel
# 从自定义模块中导入DiffusionPipeline
from ...pipelines import DiffusionPipeline
# 从自定义模块中导入DDPMScheduler
from ...utils.dummy_pt_objects import DDPMScheduler
# 从自定义模块中导入randn_tensor函数
from ...utils.torch_utils import randn_tensor


# 定义用于值引导采样的管道类
class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    用于从训练的扩散模型中进行值引导采样的管道,模型预测状态序列。

    该模型继承自[`DiffusionPipeline`]。请查阅超类文档以获取所有管道实现的通用方法
    (下载、保存、在特定设备上运行等)。

    参数:
        value_function ([`UNet1DModel`]):
            一个专门用于基于奖励微调轨迹的UNet。
        unet ([`UNet1DModel`]):
            用于去噪编码轨迹的UNet架构。
        scheduler ([`SchedulerMixin`]):
            用于与`unet`结合去噪编码轨迹的调度器。此应用程序的默认调度器为[`DDPMScheduler`]。
        env ():
            一个遵循OpenAI gym API的环境进行交互。目前仅Hopper有预训练模型。
    """

    # 初始化方法,接受各个组件作为参数
    def __init__(
        self,
        value_function: UNet1DModel,  # 值函数UNet模型
        unet: UNet1DModel,             # 去噪UNet模型
        scheduler: DDPMScheduler,      # 调度器
        env,                           # 环境
    ):
        super().__init__()  # 调用父类的初始化方法

        # 注册模型和调度器模块
        self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env)

        # 从环境获取数据集
        self.data = env.get_dataset()
        self.means = {}  # 初始化均值字典
        # 遍历数据集的每个键
        for key in self.data.keys():
            try:
                # 计算并存储每个键的均值
                self.means[key] = self.data[key].mean()
            except:  # 捕获异常
                pass
        self.stds = {}  # 初始化标准差字典
        # 再次遍历数据集的每个键
        for key in self.data.keys():
            try:
                # 计算并存储每个键的标准差
                self.stds[key] = self.data[key].std()
            except:  # 捕获异常
                pass
        # 获取状态维度
        self.state_dim = env.observation_space.shape[0]
        # 获取动作维度
        self.action_dim = env.action_space.shape[0]

    # 归一化输入数据
    def normalize(self, x_in, key):
        return (x_in - self.means[key]) / self.stds[key]  # 根据均值和标准差归一化

    # 反归一化输入数据
    def de_normalize(self, x_in, key):
        return x_in * self.stds[key] + self.means[key]  # 根据均值和标准差反归一化
    # 定义将输入转换为 Torch 张量的方法
        def to_torch(self, x_in):
            # 检查输入是否为字典类型
            if isinstance(x_in, dict):
                # 递归地将字典中的每个值转换为 Torch 张量
                return {k: self.to_torch(v) for k, v in x_in.items()}
            # 检查输入是否为 Torch 张量
            elif torch.is_tensor(x_in):
                # 将张量移动到指定设备
                return x_in.to(self.unet.device)
            # 将输入转换为 Torch 张量,并移动到指定设备
            return torch.tensor(x_in, device=self.unet.device)
    
    # 定义重置输入状态的方法
        def reset_x0(self, x_in, cond, act_dim):
            # 遍历条件字典中的每个键值对
            for key, val in cond.items():
                # 用条件值的克隆来更新输入的特定部分
                x_in[:, key, act_dim:] = val.clone()
            # 返回更新后的输入
            return x_in
    
    # 定义运行扩散过程的方法
        def run_diffusion(self, x, conditions, n_guide_steps, scale):
            # 获取输入的批次大小
            batch_size = x.shape[0]
            # 初始化输出
            y = None
            # 遍历调度器的每个时间步
            for i in tqdm.tqdm(self.scheduler.timesteps):
                # 创建用于传递给模型的时间步批次
                timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
                # 对于每个引导步骤
                for _ in range(n_guide_steps):
                    # 启用梯度计算
                    with torch.enable_grad():
                        # 设置输入张量为需要梯度计算
                        x.requires_grad_()
    
                        # 变换维度以匹配预训练模型的输入格式
                        y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                        # 计算损失的梯度
                        grad = torch.autograd.grad([y.sum()], [x])[0]
    
                        # 获取当前时间步的后验方差
                        posterior_variance = self.scheduler._get_variance(i)
                        # 计算模型的标准差
                        model_std = torch.exp(0.5 * posterior_variance)
                        # 根据标准差缩放梯度
                        grad = model_std * grad
    
                    # 对于前两个时间步,设置梯度为零
                    grad[timesteps < 2] = 0
                    # 分离计算图,防止反向传播
                    x = x.detach()
                    # 更新输入张量,增加缩放后的梯度
                    x = x + scale * grad
                    # 使用条件重置输入张量
                    x = self.reset_x0(x, conditions, self.action_dim)
    
                # 使用 UNet 模型生成前一步的样本
                prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
    
                # TODO: 验证此关键字参数的弃用情况
                # 根据调度器步骤更新输入张量
                x = self.scheduler.step(prev_x, i, x)["prev_sample"]
    
                # 将条件应用于轨迹(设置初始状态)
                x = self.reset_x0(x, conditions, self.action_dim)
                # 将输入转换为 Torch 张量
                x = self.to_torch(x)
            # 返回最终输出和生成的样本
            return x, y
    # 定义调用方法,接收观测值及其他参数
    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        # 归一化观测值并创建批次维度
        obs = self.normalize(obs, "observations")
        # 在第一个维度上重复观测值以形成批次
        obs = obs[None].repeat(batch_size, axis=0)
    
        # 将观测值转换为 PyTorch 张量,并创建条件字典
        conditions = {0: self.to_torch(obs)}
        # 定义输出张量的形状
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
    
        # 生成初始噪声并应用条件,使轨迹从当前状态开始
        x1 = randn_tensor(shape, device=self.unet.device)
        # 重置噪声张量,使其符合条件
        x = self.reset_x0(x1, conditions, self.action_dim)
        # 将张量转换为 PyTorch 格式
        x = self.to_torch(x)
    
        # 运行扩散过程以生成轨迹
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
    
        # 按值对输出轨迹进行排序
        sorted_idx = y.argsort(0, descending=True).squeeze()
        # 根据排序索引获取对应的值
        sorted_values = x[sorted_idx]
        # 提取行动部分
        actions = sorted_values[:, :, : self.action_dim]
        # 将张量转换为 NumPy 数组并分离
        actions = actions.detach().cpu().numpy()
        # 反归一化行动
        denorm_actions = self.de_normalize(actions, key="actions")
    
        # 选择具有最高值的行动
        if y is not None:
            # 如果存在值,引导选择索引为 0
            selected_index = 0
        else:
            # 如果没有运行值引导,随机选择一个行动
            selected_index = np.random.randint(0, batch_size)
    
        # 获取选中的反归一化行动
        denorm_actions = denorm_actions[selected_index, 0]
        # 返回最终选定的行动
        return denorm_actions

.\diffusers\experimental\rl\__init__.py

# 从当前包中导入 ValueGuidedRLPipeline 类
from .value_guided_sampling import ValueGuidedRLPipeline

.\diffusers\experimental\__init__.py

# 从当前模块导入 ValueGuidedRLPipeline 类
from .rl import ValueGuidedRLPipeline

.\diffusers\image_processor.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 版权声明,说明该代码的版权归 HuggingFace 团队所有
#
# Licensed under the Apache License, Version 2.0 (the "License");
# 在遵守许可证的前提下,可以使用此文件
# you may not use this file except in compliance with the License.
# 使用本文件必须遵循许可证的条款
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
# 可以在指定的网址获取许可证副本
#
# Unless required by applicable law or agreed to in writing, software
# 在法律要求或书面协议下,软件的分发是按“现状”基础进行的
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 没有任何明示或暗示的担保或条件
# See the License for the specific language governing permissions and
# 查看许可证以获取适用的权限和限制
# limitations under the License.

import math
# 导入数学库,用于数学运算
import warnings
# 导入警告库,用于发出警告信息
from typing import List, Optional, Tuple, Union
# 从 typing 模块导入类型注解,便于类型检查

import numpy as np
# 导入 numpy 库,用于数组操作
import PIL.Image
# 导入 PIL.Image,用于图像处理
import torch
# 导入 PyTorch 库,用于深度学习操作
import torch.nn.functional as F
# 导入 PyTorch 的函数式API,用于神经网络操作
from PIL import Image, ImageFilter, ImageOps
# 从 PIL 导入图像处理相关类

from .configuration_utils import ConfigMixin, register_to_config
# 从配置工具模块导入 ConfigMixin 和 register_to_config,用于配置管理
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
# 从工具模块导入常量和函数

PipelineImageInput = Union[
    PIL.Image.Image,
    np.ndarray,
    torch.Tensor,
    List[PIL.Image.Image],
    List[np.ndarray],
    List[torch.Tensor],
]
# 定义图像输入的类型,可以是单个图像或图像列表

PipelineDepthInput = PipelineImageInput
# 深度输入类型与图像输入类型相同

def is_valid_image(image):
    # 检查输入是否为有效图像
    return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
    # 如果是 PIL 图像,或 2D/3D 的 numpy 数组或 PyTorch 张量,则返回 True

def is_valid_image_imagelist(images):
    # 检查图像输入是否为支持的格式,支持以下三种格式:
    # (1) 4D 的 PyTorch 张量或 numpy 数组
    # (2) 有效图像:PIL.Image.Image,2D np.ndarray 或 torch.Tensor(灰度图像),3D np.ndarray 或 torch.Tensor
    # (3) 有效图像列表
    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
        return True
        # 如果是 4D 的 numpy 数组或 PyTorch 张量,返回 True
    elif is_valid_image(images):
        return True
        # 如果是有效的单个图像,返回 True
    elif isinstance(images, list):
        return all(is_valid_image(image) for image in images)
        # 如果是列表,检查列表中每个图像是否有效,全部有效则返回 True
    return False
    # 如果不满足以上条件,返回 False

class VaeImageProcessor(ConfigMixin):
    # 定义 VAE 图像处理器类,继承自 ConfigMixin
    """
    Image processor for VAE.
    # VAE 的图像处理器
    # 参数列表,定义该类或函数的输入参数
        Args:
            do_resize (`bool`, *optional*, defaults to `True`):  # 是否将图像的高度和宽度缩放到 `vae_scale_factor` 的倍数
                Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
                `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
            vae_scale_factor (`int`, *optional*, defaults to `8`):  # VAE缩放因子,影响图像的缩放行为
                VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
            resample (`str`, *optional*, defaults to `lanczos`):  # 指定图像缩放时使用的重采样滤波器
                Resampling filter to use when resizing the image.
            do_normalize (`bool`, *optional*, defaults to `True`):  # 是否将图像归一化到[-1,1]的范围
                Whether to normalize the image to [-1,1].
            do_binarize (`bool`, *optional*, defaults to `False`):  # 是否将图像二值化为0或1
                Whether to binarize the image to 0/1.
            do_convert_rgb (`bool`, *optional*, defaults to be `False`):  # 是否将图像转换为RGB格式
                Whether to convert the images to RGB format.
            do_convert_grayscale (`bool`, *optional*, defaults to be `False`):  # 是否将图像转换为灰度格式
                Whether to convert the images to grayscale format.
        """
    
        config_name = CONFIG_NAME  # 将配置名称赋值给config_name变量
    
        @register_to_config  # 装饰器,将该函数注册为配置项
        def __init__((
            self,
            do_resize: bool = True,  # 初始化时的参数,是否缩放图像
            vae_scale_factor: int = 8,  # VAE缩放因子,默认值为8
            vae_latent_channels: int = 4,  # VAE潜在通道数,默认值为4
            resample: str = "lanczos",  # 重采样滤波器的默认值为lanczos
            do_normalize: bool = True,  # 初始化时的参数,是否归一化图像
            do_binarize: bool = False,  # 初始化时的参数,是否二值化图像
            do_convert_rgb: bool = False,  # 初始化时的参数,是否转换为RGB格式
            do_convert_grayscale: bool = False,  # 初始化时的参数,是否转换为灰度格式
        ):
            super().__init__()  # 调用父类的初始化方法
            if do_convert_rgb and do_convert_grayscale:  # 检查同时设置RGB和灰度格式的情况
                raise ValueError(  # 抛出值错误,提示不允许同时设置为True
                    "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
                    " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
                    " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
                )
    
        @staticmethod  # 静态方法,不依赖于类的实例
        def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:  # 将numpy数组转换为PIL图像列表
            """
            Convert a numpy image or a batch of images to a PIL image.
            """
            if images.ndim == 3:  # 检查图像是否为三维数组(单个图像)
                images = images[None, ...]  # 将其扩展为四维数组
            images = (images * 255).round().astype("uint8")  # 将图像值从[0, 1]转换为[0, 255]并转为无符号8位整数
            if images.shape[-1] == 1:  # 如果是单通道(灰度)图像
                # special case for grayscale (single channel) images
                pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]  # 转换为PIL灰度图像
            else:  # 处理多通道图像
                pil_images = [Image.fromarray(image) for image in images]  # 转换为PIL图像
    
            return pil_images  # 返回PIL图像列表
    
        @staticmethod  # 静态方法,不依赖于类的实例
        def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:  # 将PIL图像转换为numpy数组
            """
            Convert a PIL image or a list of PIL images to NumPy arrays.
            """
            if not isinstance(images, list):  # 如果输入不是列表
                images = [images]  # 将其转换为列表
            images = [np.array(image).astype(np.float32) / 255.0 for image in images]  # 转换为numpy数组并归一化
            images = np.stack(images, axis=0)  # 在新的轴上堆叠数组,形成四维数组
    
            return images  # 返回numpy数组
    # 将 NumPy 图像转换为 PyTorch 张量
    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
        # 文档字符串:将 NumPy 图像转换为 PyTorch 张量
        """
        Convert a NumPy image to a PyTorch tensor.
        """
        # 检查图像的维度是否为 3(即 H x W x C)
        if images.ndim == 3:
            # 如果是 3 维,添加一个新的维度以适应模型输入
            images = images[..., None]
    
        # 将 NumPy 数组转置并转换为 PyTorch 张量
        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        # 返回转换后的张量
        return images
    
    # 将 PyTorch 张量转换为 NumPy 图像
    @staticmethod
    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
        # 文档字符串:将 PyTorch 张量转换为 NumPy 图像
        """
        Convert a PyTorch tensor to a NumPy image.
        """
        # 将张量移至 CPU,并调整维度顺序为 H x W x C,转换为浮点型并转为 NumPy 数组
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        # 返回转换后的 NumPy 数组
        return images
    
    # 规范化图像数组到 [-1,1] 范围
    @staticmethod
    def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        # 文档字符串:将图像数组规范化到 [-1,1] 范围
        """
        Normalize an image array to [-1,1].
        """
        # 将图像数组的值范围缩放到 [-1, 1]
        return 2.0 * images - 1.0
    
    # 将图像数组反规范化到 [0,1] 范围
    @staticmethod
    def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        # 文档字符串:将图像数组反规范化到 [0,1] 范围
        """
        Denormalize an image array to [0,1].
        """
        # 将图像数组的值范围调整到 [0, 1] 并限制在该范围内
        return (images / 2 + 0.5).clamp(0, 1)
    
    # 将 PIL 图像转换为 RGB 格式
    @staticmethod
    def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
        # 文档字符串:将 PIL 图像转换为 RGB 格式
        """
        Converts a PIL image to RGB format.
        """
        # 使用 PIL 库将图像转换为 RGB 格式
        image = image.convert("RGB")
    
        # 返回转换后的图像
        return image
    
    # 将 PIL 图像转换为灰度格式
    @staticmethod
    def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
        # 文档字符串:将 PIL 图像转换为灰度格式
        """
        Converts a PIL image to grayscale format.
        """
        # 使用 PIL 库将图像转换为灰度格式
        image = image.convert("L")
    
        # 返回转换后的图像
        return image
    
    # 对图像应用高斯模糊
    @staticmethod
    def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
        # 文档字符串:对图像应用高斯模糊
        """
        Applies Gaussian blur to an image.
        """
        # 使用 PIL 库对图像应用高斯模糊
        image = image.filter(ImageFilter.GaussianBlur(blur_factor))
    
        # 返回模糊后的图像
        return image
    
    # 调整图像大小并填充
    @staticmethod
    def _resize_and_fill(
        self,
        image: PIL.Image.Image,
        width: int,
        height: int,
    ) -> PIL.Image.Image:  # 返回处理后的图像对象
        """  # 文档字符串,描述函数的作用
        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
        the image within the dimensions, filling empty with data from image.  # 说明功能:调整图像大小并居中

        Args:  # 参数说明
            image: The image to resize.  # 待调整大小的图像
            width: The width to resize the image to.  # 目标宽度
            height: The height to resize the image to.  # 目标高度
        """  # 文档字符串结束

        ratio = width / height  # 计算目标宽高比
        src_ratio = image.width / image.height  # 计算源图像的宽高比

        src_w = width if ratio < src_ratio else image.width * height // image.height  # 计算源宽度
        src_h = height if ratio >= src_ratio else image.height * width // image.width  # 计算源高度

        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])  # 调整图像大小
        res = Image.new("RGB", (width, height))  # 创建新的 RGB 图像,尺寸为目标宽高
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))  # 将调整后的图像居中粘贴

        if ratio < src_ratio:  # 如果目标宽高比小于源宽高比
            fill_height = height // 2 - src_h // 2  # 计算需要填充的高度
            if fill_height > 0:  # 如果需要填充高度大于零
                res.paste(resized.resize((width, fill_height), box=(0, 0)), box=(0, 0))  # 填充上方空白
                res.paste(  # 填充下方空白
                    resized.resize((width, fill_height), box=(0, resized.height)), 
                    box=(0, fill_height + src_h),
                )
        elif ratio > src_ratio:  # 如果目标宽高比大于源宽高比
            fill_width = width // 2 - src_w // 2  # 计算需要填充的宽度
            if fill_width > 0:  # 如果需要填充宽度大于零
                res.paste(resized.resize((fill_width, height), box=(0, 0)), box=(0, 0))  # 填充左侧空白
                res.paste(  # 填充右侧空白
                    resized.resize((fill_width, height), box=(resized.width, 0)), 
                    box=(fill_width + src_w, 0),
                )

        return res  # 返回最终调整后的图像

    def _resize_and_crop(  # 定义一个私有方法,用于调整大小并裁剪
        self,
        image: PIL.Image.Image,  # 输入的图像
        width: int,  # 目标宽度
        height: int,  # 目标高度
    ) -> PIL.Image.Image:  # 返回处理后的图像对象
        """  # 文档字符串,描述函数的作用
        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
        the image within the dimensions, cropping the excess.  # 说明功能:调整大小并裁剪

        Args:  # 参数说明
            image: The image to resize.  # 待调整大小的图像
            width: The width to resize the image to.  # 目标宽度
            height: The height to resize the image to.  # 目标高度
        """  # 文档字符串结束

        ratio = width / height  # 计算目标宽高比
        src_ratio = image.width / image.height  # 计算源图像的宽高比

        src_w = width if ratio > src_ratio else image.width * height // image.height  # 计算源宽度
        src_h = height if ratio <= src_ratio else image.height * width // image.width  # 计算源高度

        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])  # 调整图像大小
        res = Image.new("RGB", (width, height))  # 创建新的 RGB 图像,尺寸为目标宽高
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))  # 将调整后的图像居中粘贴
        return res  # 返回最终调整后的图像

    def resize(  # 定义调整大小的公共方法
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],  # 输入的图像类型
        height: int,  # 目标高度
        width: int,  # 目标宽度
        resize_mode: str = "default",  # 指定调整大小模式,默认为 "default"
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        调整图像大小。

        参数:
            image (`PIL.Image.Image`, `np.ndarray` 或 `torch.Tensor`):
                输入图像,可以是 PIL 图像、numpy 数组或 pytorch 张量。
            height (`int`):
                要调整的高度。
            width (`int`):
                要调整的宽度。
            resize_mode (`str`, *可选*, 默认为 `default`):
                使用的调整模式,可以是 `default` 或 `fill`。如果是 `default`,将调整图像以适应
                指定的宽度和高度,可能不保持原始纵横比。如果是 `fill`,将调整图像以适应
                指定的宽度和高度,保持纵横比,然后将图像居中填充空白。如果是 `crop`,将调整
                图像以适应指定的宽度和高度,保持纵横比,然后将图像居中,裁剪多余部分。请注意,
                `fill` 和 `crop` 只支持 PIL 图像输入。

        返回:
            `PIL.Image.Image`, `np.ndarray` 或 `torch.Tensor`:
                调整后的图像。
        """
        # 检查调整模式是否有效,并确保图像为 PIL 图像
        if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
            raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
        # 如果输入是 PIL 图像
        if isinstance(image, PIL.Image.Image):
            # 如果调整模式是默认模式,调整图像大小
            if resize_mode == "default":
                image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
            # 如果调整模式是填充,调用填充函数
            elif resize_mode == "fill":
                image = self._resize_and_fill(image, width, height)
            # 如果调整模式是裁剪,调用裁剪函数
            elif resize_mode == "crop":
                image = self._resize_and_crop(image, width, height)
            # 如果调整模式不支持,抛出错误
            else:
                raise ValueError(f"resize_mode {resize_mode} is not supported")

        # 如果输入是 PyTorch 张量
        elif isinstance(image, torch.Tensor):
            # 使用插值调整张量大小
            image = torch.nn.functional.interpolate(
                image,
                size=(height, width),
            )
        # 如果输入是 numpy 数组
        elif isinstance(image, np.ndarray):
            # 将 numpy 数组转换为 PyTorch 张量
            image = self.numpy_to_pt(image)
            # 使用插值调整张量大小
            image = torch.nn.functional.interpolate(
                image,
                size=(height, width),
            )
            # 将张量转换回 numpy 数组
            image = self.pt_to_numpy(image)
        # 返回调整后的图像
        return image

    def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """
        创建掩膜。

        参数:
            image (`PIL.Image.Image`):
                输入图像,应该是 PIL 图像。

        返回:
            `PIL.Image.Image`:
                二值化图像。值小于 0.5 的设置为 0,值大于等于 0.5 的设置为 1。
        """
        # 将小于 0.5 的像素值设置为 0
        image[image < 0.5] = 0
        # 将大于等于 0.5 的像素值设置为 1
        image[image >= 0.5] = 1

        # 返回二值化后的图像
        return image
    # 定义一个方法,获取图像的默认高度和宽度
    def get_default_height_width(
            self,
            image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
            height: Optional[int] = None,
            width: Optional[int] = None,
        ) -> Tuple[int, int]:
            """
            该函数返回按 `vae_scale_factor` 下调到下一个整数倍的高度和宽度。
    
            参数:
                image(`PIL.Image.Image`, `np.ndarray` 或 `torch.Tensor`):
                    输入图像,可以是 PIL 图像、numpy 数组或 pytorch 张量。若为 numpy 数组,应该具有
                    形状 `[batch, height, width]` 或 `[batch, height, width, channel]`;若为 pytorch 张量,应该
                    具有形状 `[batch, channel, height, width]`。
                height (`int`, *可选*, 默认为 `None`):
                    预处理图像的高度。如果为 `None`,将使用输入图像的高度。
                width (`int`, *可选*, 默认为 `None`):
                    预处理的宽度。如果为 `None`,将使用输入图像的宽度。
            """
    
            # 如果高度为 None,尝试从图像中获取高度
            if height is None:
                # 如果图像是 PIL 图像,使用其高度
                if isinstance(image, PIL.Image.Image):
                    height = image.height
                # 如果图像是 pytorch 张量,使用其形状中的高度
                elif isinstance(image, torch.Tensor):
                    height = image.shape[2]
                # 否则,假设是 numpy 数组,使用其形状中的高度
                else:
                    height = image.shape[1]
    
            # 如果宽度为 None,尝试从图像中获取宽度
            if width is None:
                # 如果图像是 PIL 图像,使用其宽度
                if isinstance(image, PIL.Image.Image):
                    width = image.width
                # 如果图像是 pytorch 张量,使用其形状中的宽度
                elif isinstance(image, torch.Tensor):
                    width = image.shape[3]
                # 否则,假设是 numpy 数组,使用其形状中的宽度
                else:
                    width = image.shape[2]
    
            # 将宽度和高度调整为 vae_scale_factor 的整数倍
            width, height = (
                x - x % self.config.vae_scale_factor for x in (width, height)
            )  # 调整为 vae_scale_factor 的整数倍
    
            # 返回调整后的高度和宽度
            return height, width
    
    # 定义一个方法,预处理图像
    def preprocess(
            self,
            image: PipelineImageInput,
            height: Optional[int] = None,
            width: Optional[int] = None,
            resize_mode: str = "default",  # "default", "fill", "crop"
            crops_coords: Optional[Tuple[int, int, int, int]] = None,
    # 定义一个方法,后处理图像
    def postprocess(
            self,
            image: torch.Tensor,
            output_type: str = "pil",
            do_denormalize: Optional[List[bool]] = None,
    # 返回处理后的图像,类型为 PIL.Image.Image、np.ndarray 或 torch.Tensor
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        处理从张量输出的图像,转换为 `output_type`。
    
        参数:
            image (`torch.Tensor`):
                输入的图像,应该是形状为 `B x C x H x W` 的 pytorch 张量。
            output_type (`str`, *可选*, 默认为 `pil`):
                图像的输出类型,可以是 `pil`、`np`、`pt`、`latent` 之一。
            do_denormalize (`List[bool]`, *可选*, 默认为 `None`):
                是否将图像反归一化到 [0,1]。如果为 `None`,将使用 `VaeImageProcessor` 配置中的 `do_normalize` 值。
    
        返回:
            `PIL.Image.Image`、`np.ndarray` 或 `torch.Tensor`:
                处理后的图像。
        """
        # 检查输入的图像是否为 pytorch 张量
        if not isinstance(image, torch.Tensor):
            # 抛出值错误,如果输入格式不正确
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        # 检查输出类型是否在支持的列表中
        if output_type not in ["latent", "pt", "np", "pil"]:
            # 创建弃用信息,说明当前输出类型已过时
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            # 调用弃用函数,记录警告信息
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            # 将输出类型设置为默认值
            output_type = "np"
    
        # 如果输出类型为 "latent",直接返回输入图像
        if output_type == "latent":
            return image
    
        # 如果 do_denormalize 为 None,则根据配置设置其值
        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]
    
        # 通过 denormalize 方法处理图像,生成新张量
        image = torch.stack(
            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
        )
    
        # 如果输出类型为 "pt",返回处理后的图像张量
        if output_type == "pt":
            return image
    
        # 将处理后的图像从 pytorch 张量转换为 numpy 数组
        image = self.pt_to_numpy(image)
    
        # 如果输出类型为 "np",返回 numpy 数组
        if output_type == "np":
            return image
    
        # 如果输出类型为 "pil",将 numpy 数组转换为 PIL 图像并返回
        if output_type == "pil":
            return self.numpy_to_pil(image)
    
    # 定义应用遮罩的函数,接受多个参数
    def apply_overlay(
        self,
        mask: PIL.Image.Image,
        init_image: PIL.Image.Image,
        image: PIL.Image.Image,
        crop_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> PIL.Image.Image:
        """
        将修复输出叠加到原始图像上
        """

        # 获取原始图像的宽度和高度
        width, height = image.width, image.height

        # 调整初始图像和掩膜图像到与原始图像相同的大小
        init_image = self.resize(init_image, width=width, height=height)
        mask = self.resize(mask, width=width, height=height)

        # 创建一个新的 RGBA 图像,用于存放初始图像的掩膜效果
        init_image_masked = PIL.Image.new("RGBa", (width, height))
        # 将初始图像按掩膜方式粘贴到新的图像上,掩膜为掩码的反转图像
        init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
        # 将初始图像掩膜转换为 RGBA 格式
        init_image_masked = init_image_masked.convert("RGBA")

        # 如果给定了裁剪坐标
        if crop_coords is not None:
            # 解包裁剪坐标
            x, y, x2, y2 = crop_coords
            # 计算裁剪区域的宽度和高度
            w = x2 - x
            h = y2 - y
            # 创建一个新的 RGBA 图像作为基础图像
            base_image = PIL.Image.new("RGBA", (width, height))
            # 将原始图像调整到裁剪区域的大小
            image = self.resize(image, height=h, width=w, resize_mode="crop")
            # 将调整后的图像粘贴到基础图像的指定位置
            base_image.paste(image, (x, y))
            # 将基础图像转换为 RGB 格式
            image = base_image.convert("RGB")

        # 将图像转换为 RGBA 格式
        image = image.convert("RGBA")
        # 将初始图像的掩膜叠加到当前图像上
        image.alpha_composite(init_image_masked)
        # 将结果图像转换为 RGB 格式
        image = image.convert("RGB")

        # 返回最终的图像
        return image
# 定义 VAE LDM3D 图像处理器类,继承自 VaeImageProcessor
class VaeImageProcessorLDM3D(VaeImageProcessor):
    """
    VAE LDM3D 的图像处理器。

    参数:
        do_resize (`bool`, *可选*, 默认值为 `True`):
            是否将图像的(高度,宽度)尺寸缩小到 `vae_scale_factor` 的倍数。
        vae_scale_factor (`int`, *可选*, 默认值为 `8`):
            VAE 缩放因子。如果 `do_resize` 为 `True`,图像会自动调整为该因子的倍数。
        resample (`str`, *可选*, 默认值为 `lanczos`):
            在调整图像大小时使用的重采样滤波器。
        do_normalize (`bool`, *可选*, 默认值为 `True`):
            是否将图像归一化到 [-1,1] 范围内。
    """

    # 配置名称常量
    config_name = CONFIG_NAME

    # 注册到配置中的初始化方法
    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,        # 是否调整大小,默认为 True
        vae_scale_factor: int = 8,     # VAE 缩放因子,默认为 8
        resample: str = "lanczos",     # 重采样方法,默认为 lanczos
        do_normalize: bool = True,     # 是否归一化,默认为 True
    ):
        # 调用父类的初始化方法
        super().__init__()

    # 静态方法:将 NumPy 图像或图像批次转换为 PIL 图像
    @staticmethod
    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
        """
        将 NumPy 图像或图像批次转换为 PIL 图像。
        """
        # 如果输入是 3 维数组,添加一个新的维度
        if images.ndim == 3:
            images = images[None, ...]
        # 将图像数据放大到 255,四舍五入并转换为无符号 8 位整数
        images = (images * 255).round().astype("uint8")
        # 检查最后一个维度是否为 1(灰度图像)
        if images.shape[-1] == 1:
            # 特殊情况处理灰度(单通道)图像
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            # 处理 RGB 图像(提取前三个通道)
            pil_images = [Image.fromarray(image[:, :, :3]) for image in images]

        # 返回 PIL 图像列表
        return pil_images

    # 静态方法:将 PIL 图像或图像列表转换为 NumPy 数组
    @staticmethod
    def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
        """
        将 PIL 图像或图像列表转换为 NumPy 数组。
        """
        # 如果输入不是列表,将其转换为单元素列表
        if not isinstance(images, list):
            images = [images]

        # 将每个 PIL 图像转换为 NumPy 数组,并归一化到 [0, 1] 范围
        images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
        # 将图像堆叠成一个 4D 数组
        images = np.stack(images, axis=0)
        # 返回 NumPy 数组
        return images

    # 静态方法:将 RGB 深度图像转换为深度图
    @staticmethod
    def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """
        参数:
            image: RGB 类似深度图像

        返回: 深度图
        """
        # 提取深度图,使用红色通道和蓝色通道计算深度值
        return image[:, :, 1] * 2**8 + image[:, :, 2]
    # 将 NumPy 深度图像或图像批处理转换为 PIL 图像
    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
        # 文档字符串,说明函数的作用
        """
        Convert a NumPy depth image or a batch of images to a PIL image.
        """
        # 检查输入图像的维度是否为 3,如果是,则在前面添加一个维度
        if images.ndim == 3:
            images = images[None, ...]
        # 从输入图像中提取深度信息,假设深度数据在最后的几维
        images_depth = images[:, :, :, 3:]
        # 检查最后一个维度是否为 6,表示有额外的信息
        if images.shape[-1] == 6:
            # 将深度值范围缩放到 0-255,并转换为无符号 8 位整数
            images_depth = (images_depth * 255).round().astype("uint8")
            # 将每个深度图像转换为 PIL 图像,使用特定模式
            pil_images = [
                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
            ]
        # 检查最后一个维度是否为 4,表示仅有深度数据
        elif images.shape[-1] == 4:
            # 将深度值范围缩放到 0-65535,并转换为无符号 16 位整数
            images_depth = (images_depth * 65535.0).astype(np.uint16)
            # 将每个深度图像转换为 PIL 图像,使用特定模式
            pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
        # 如果输入的形状不符合要求,抛出异常
        else:
            raise Exception("Not supported")
    
        # 返回生成的 PIL 图像列表
        return pil_images
    
    # 处理图像的后处理函数,接受图像和输出类型等参数
    def postprocess(
        self,
        image: torch.Tensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        处理图像输出,将张量转换为 `output_type` 格式。

        参数:
            image (`torch.Tensor`):
                输入的图像,应该是形状为 `B x C x H x W` 的 PyTorch 张量。
            output_type (`str`, *可选*, 默认为 `pil`):
                图像的输出类型,可以是 `pil`、`np`、`pt` 或 `latent` 之一。
            do_denormalize (`List[bool]`, *可选*, 默认为 `None`):
                是否将图像反归一化到 [0,1]。如果为 `None`,将使用 `VaeImageProcessor` 配置中的 `do_normalize` 值。

        返回:
            `PIL.Image.Image`、`np.ndarray` 或 `torch.Tensor`:
                处理后的图像。
        """
        # 检查输入图像是否为 PyTorch 张量,如果不是,则抛出错误
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        # 检查输出类型是否在支持的选项中,如果不在,发送弃用警告并设置为默认值 `np`
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"  # 设置输出类型为默认的 `np`

        # 如果反归一化标志为 None,则根据配置初始化为与图像批大小相同的列表
        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        # 对每个图像进行反归一化处理,构建处理后的图像堆叠
        image = torch.stack(
            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
        )

        # 将处理后的图像从 PyTorch 张量转换为 NumPy 数组
        image = self.pt_to_numpy(image)

        # 根据输出类型返回相应的处理结果
        if output_type == "np":
            # 如果图像的最后一个维度为 6,则提取深度图
            if image.shape[-1] == 6:
                image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
            else:
                # 否则直接提取最后三个通道作为深度图
                image_depth = image[:, :, :, 3:]
            return image[:, :, :, :3], image_depth  # 返回 RGB 图像和深度图

        if output_type == "pil":
            # 将 NumPy 数组转换为 PIL 图像并返回
            return self.numpy_to_pil(image), self.numpy_to_depth(image)
        else:
            # 如果输出类型不被支持,抛出异常
            raise Exception(f"This type {output_type} is not supported")

    def preprocess(
        self,
        rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
        depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
        height: Optional[int] = None,
        width: Optional[int] = None,
        target_res: Optional[int] = None,
# 定义一个处理 IP 适配器图像掩码的图像处理器类
class IPAdapterMaskProcessor(VaeImageProcessor):
    """
    IP适配器图像掩码的图像处理器。

    参数:
        do_resize (`bool`, *可选*, 默认为 `True`):
            是否将图像的高度和宽度缩小为 `vae_scale_factor` 的倍数。
        vae_scale_factor (`int`, *可选*, 默认为 `8`):
            VAE缩放因子。如果 `do_resize` 为 `True`,图像将自动调整为该因子的倍数。
        resample (`str`, *可选*, 默认为 `lanczos`):
            调整图像大小时使用的重采样滤波器。
        do_normalize (`bool`, *可选*, 默认为 `False`):
            是否将图像标准化到 [-1,1]。
        do_binarize (`bool`, *可选*, 默认为 `True`):
            是否将图像二值化为 0/1。
        do_convert_grayscale (`bool`, *可选*, 默认为 `True`):
            是否将图像转换为灰度格式。

    """

    # 配置名称常量
    config_name = CONFIG_NAME

    @register_to_config
    # 初始化函数,设置处理器的参数
    def __init__(
        self,
        do_resize: bool = True,  # 是否缩放图像
        vae_scale_factor: int = 8,  # VAE缩放因子
        resample: str = "lanczos",  # 重采样滤波器
        do_normalize: bool = False,  # 是否标准化图像
        do_binarize: bool = True,  # 是否二值化图像
        do_convert_grayscale: bool = True,  # 是否转换为灰度图像
    ):
        # 调用父类的初始化方法,传递参数
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

    @staticmethod
    # 定义 downsample 函数,输入为掩码张量和其他参数,输出为下采样后的掩码张量
        def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
            """
            将提供的掩码张量下采样到与缩放点积注意力预期的维度匹配。如果掩码的长宽比与输出图像的长宽比不匹配,则发出警告。
    
            参数:
                mask (`torch.Tensor`):
                    由 `IPAdapterMaskProcessor.preprocess()` 生成的输入掩码张量。
                batch_size (`int`):
                    批处理大小。
                num_queries (`int`):
                    查询的数量。
                value_embed_dim (`int`):
                    值嵌入的维度。
    
            返回:
                `torch.Tensor`:
                    下采样后的掩码张量。
    
            """
            # 获取掩码的高度和宽度
            o_h = mask.shape[1]
            o_w = mask.shape[2]
            # 计算掩码的长宽比
            ratio = o_w / o_h
            # 计算下采样后掩码的高度
            mask_h = int(math.sqrt(num_queries / ratio))
            # 根据掩码高度调整,确保可以容纳所有查询
            mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
            # 计算下采样后掩码的宽度
            mask_w = num_queries // mask_h
    
            # 对掩码进行插值下采样
            mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
    
            # 重复掩码以匹配批处理大小
            if mask_downsample.shape[0] < batch_size:
                mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
    
            # 调整掩码形状为 (batch_size, -1)
            mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
    
            # 计算下采样后的区域大小
            downsampled_area = mask_h * mask_w
            # 如果输出图像和掩码的长宽比不相同,发出警告并填充张量
            if downsampled_area < num_queries:
                warnings.warn(
                    "掩码的长宽比与输出图像的长宽比不匹配。"
                    "请更新掩码或调整输出大小以获得最佳性能。",
                    UserWarning,
                )
                mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
            # 如果下采样后的掩码形状大于查询数量,则截断最后的嵌入
            if downsampled_area > num_queries:
                warnings.warn(
                    "掩码的长宽比与输出图像的长宽比不匹配。"
                    "请更新掩码或调整输出大小以获得最佳性能。",
                    UserWarning,
                )
                mask_downsample = mask_downsample[:, :num_queries]
    
            # 重复最后一个维度以匹配 SDPA 输出形状
            mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
                1, 1, value_embed_dim
            )
    
            # 返回下采样后的掩码
            return mask_downsample
# PixArt 图像处理器类,继承自 VaeImageProcessor
class PixArtImageProcessor(VaeImageProcessor):
    """
    PixArt 图像的调整大小和裁剪处理器。

    参数:
        do_resize (`bool`, *可选*, 默认为 `True`):
            是否将图像的(高度,宽度)尺寸缩小为 `vae_scale_factor` 的倍数。可以接受
            来自 [`image_processor.VaeImageProcessor.preprocess`] 方法的 `height` 和 `width` 参数。
        vae_scale_factor (`int`, *可选*, 默认为 `8`):
            VAE 缩放因子。如果 `do_resize` 为 `True`,图像会自动调整为该因子的倍数。
        resample (`str`, *可选*, 默认为 `lanczos`):
            调整图像大小时使用的重采样滤镜。
        do_normalize (`bool`, *可选*, 默认为 `True`):
            是否将图像标准化到 [-1,1]。
        do_binarize (`bool`, *可选*, 默认为 `False`):
            是否将图像二值化为 0/1。
        do_convert_rgb (`bool`, *可选*, 默认为 `False`):
            是否将图像转换为 RGB 格式。
        do_convert_grayscale (`bool`, *可选*, 默认为 `False`):
            是否将图像转换为灰度格式。
    """

    # 注册到配置中的初始化方法
    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,  # 是否调整大小
        vae_scale_factor: int = 8,  # VAE 缩放因子
        resample: str = "lanczos",  # 重采样滤镜
        do_normalize: bool = True,  # 是否标准化
        do_binarize: bool = False,  # 是否二值化
        do_convert_grayscale: bool = False,  # 是否转换为灰度
    ):
        # 调用父类初始化方法,传递参数
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

    # 静态方法,分类高度和宽度到最近的比例
    @staticmethod
    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
        """返回分箱的高度和宽度。"""
        ar = float(height / width)  # 计算高度与宽度的比率
        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))  # 找到最接近的比率
        default_hw = ratios[closest_ratio]  # 获取该比率对应的默认高度和宽度
        return int(default_hw[0]), int(default_hw[1])  # 返回整数形式的高度和宽度

    @staticmethod
    # 定义一个函数,调整张量的大小并裁剪到指定的宽度和高度
    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
        # 获取原始张量的高度和宽度
        orig_height, orig_width = samples.shape[2], samples.shape[3]
    
        # 检查是否需要调整大小
        if orig_height != new_height or orig_width != new_width:
            # 计算调整大小的比例
            ratio = max(new_height / orig_height, new_width / orig_width)
            # 计算调整后的宽度和高度
            resized_width = int(orig_width * ratio)
            resized_height = int(orig_height * ratio)
    
            # 调整大小
            samples = F.interpolate(
                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
            )
    
            # 计算中心裁剪的起始和结束坐标
            start_x = (resized_width - new_width) // 2
            end_x = start_x + new_width
            start_y = (resized_height - new_height) // 2
            end_y = start_y + new_height
            # 裁剪样本到目标大小
            samples = samples[:, :, start_y:end_y, start_x:end_x]
    
        # 返回调整和裁剪后的张量
        return samples

.\diffusers\loaders\ip_adapter.py

# 版权声明,2024年HuggingFace团队保留所有权利
# 
# 根据Apache许可证第2.0版(“许可证”)授权;
# 除非遵守许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律要求或书面同意,软件在“原样”基础上分发,
# 不附带任何形式的明示或暗示的担保或条件。
# 请参阅许可证以了解有关权限和
# 限制的具体条款。

# 从pathlib模块导入Path类,用于路径操作
from pathlib import Path
# 从typing模块导入各种类型提示
from typing import Dict, List, Optional, Union

# 导入torch库
import torch
# 导入torch的功能模块,用于神经网络操作
import torch.nn.functional as F
# 从huggingface_hub.utils导入验证函数,用于验证HF Hub参数
from huggingface_hub.utils import validate_hf_hub_args
# 从safetensors导入安全打开函数
from safetensors import safe_open

# 从本地模型工具导入低CPU内存使用默认值和加载状态字典的函数
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
# 从本地工具导入多个实用函数和变量
from ..utils import (
    USE_PEFT_BACKEND,  # 是否使用PEFT后端的标志
    _get_model_file,  # 获取模型文件的函数
    is_accelerate_available,  # 检查加速库是否可用
    is_torch_version,  # 检查Torch版本的函数
    is_transformers_available,  # 检查Transformers库是否可用
    logging,  # 导入日志模块
)
# 从unet_loader_utils模块导入可能扩展LoRA缩放的函数
from .unet_loader_utils import _maybe_expand_lora_scales

# 如果Transformers库可用,则导入相关的类和函数
if is_transformers_available():
    # 导入CLIP图像处理器和带投影的CLIP视觉模型
    from transformers import (
        CLIPImageProcessor,
        CLIPVisionModelWithProjection,
    )

    # 导入注意力处理器类
    from ..models.attention_processor import (
        AttnProcessor,  # 注意力处理器
        AttnProcessor2_0,  # 版本2.0的注意力处理器
        IPAdapterAttnProcessor,  # IP适配器注意力处理器
        IPAdapterAttnProcessor2_0,  # 版本2.0的IP适配器注意力处理器
    )

# 获取日志记录器实例,使用当前模块的名称
logger = logging.get_logger(__name__)

# 定义一个处理IP适配器的Mixin类
class IPAdapterMixin:
    """处理IP适配器的Mixin类。"""

    # 使用装饰器验证HF Hub参数
    @validate_hf_hub_args
    def load_ip_adapter(
        # 定义加载IP适配器所需的参数,包括模型名称、子文件夹、权重名称等
        self,
        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
        subfolder: Union[str, List[str]],
        weight_name: Union[str, List[str]],
        # 可选参数,默认值为“image_encoder”
        image_encoder_folder: Optional[str] = "image_encoder",
        **kwargs,  # 其他关键字参数
    # 定义方法以设置 IP-Adapter 的缩放比例,输入参数 scale 可为单个配置或多个配置的列表
    def set_ip_adapter_scale(self, scale):
        # 文档字符串,提供使用示例和配置说明
        """
        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
    
        Example:
    
        ```py
        # To use original IP-Adapter
        scale = 1.0
        pipeline.set_ip_adapter_scale(scale)
    
        # To use style block only
        scale = {
            "up": {"block_0": [0.0, 1.0, 0.0]},
        }
        pipeline.set_ip_adapter_scale(scale)
    
        # To use style+layout blocks
        scale = {
            "down": {"block_2": [0.0, 1.0]},
            "up": {"block_0": [0.0, 1.0, 0.0]},
        }
        pipeline.set_ip_adapter_scale(scale)
    
        # To use style and layout from 2 reference images
        scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
        pipeline.set_ip_adapter_scale(scales)
        ```py
        """
        # 根据名称获取 UNet 对象,如果没有则使用已有的 unet 属性
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        # 如果 scale 不是列表,将其转换为列表
        if not isinstance(scale, list):
            scale = [scale]
        # 调用辅助函数以展开缩放配置,默认缩放为 0.0
        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
    
        # 遍历 UNet 的注意力处理器字典
        for attn_name, attn_processor in unet.attn_processors.items():
            # 检查处理器是否为 IPAdapter 类型
            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                # 验证缩放配置数量与处理器的数量匹配
                if len(scale_configs) != len(attn_processor.scale):
                    raise ValueError(
                        f"Cannot assign {len(scale_configs)} scale_configs to "
                        f"{len(attn_processor.scale)} IP-Adapter."
                    )
                # 如果只有一个缩放配置,复制到每个处理器
                elif len(scale_configs) == 1:
                    scale_configs = scale_configs * len(attn_processor.scale)
                # 遍历每个缩放配置
                for i, scale_config in enumerate(scale_configs):
                    # 如果配置是字典,则根据名称匹配进行设置
                    if isinstance(scale_config, dict):
                        for k, s in scale_config.items():
                            if attn_name.startswith(k):
                                attn_processor.scale[i] = s
                    # 否则直接将缩放配置赋值
                    else:
                        attn_processor.scale[i] = scale_config
    # 定义一个方法来卸载 IP 适配器的权重
    def unload_ip_adapter(self):
        """
        卸载 IP 适配器的权重

        示例:

        ```python
        >>> # 假设 `pipeline` 已经加载了 IP 适配器的权重。
        >>> pipeline.unload_ip_adapter()
        >>> ...
        ```py
        """
        # 移除 CLIP 图像编码器
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
            # 将图像编码器设为 None
            self.image_encoder = None
            # 更新配置,移除图像编码器的相关信息
            self.register_to_config(image_encoder=[None, None])

        # 仅当 safety_checker 为 None 时移除特征提取器,因为 safety_checker 后续会使用 feature_extractor
        if not hasattr(self, "safety_checker"):
            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
                # 将特征提取器设为 None
                self.feature_extractor = None
                # 更新配置,移除特征提取器的相关信息
                self.register_to_config(feature_extractor=[None, None])

        # 移除隐藏编码器
        self.unet.encoder_hid_proj = None
        # 将编码器的隐藏维度类型设为 None
        self.unet.config.encoder_hid_dim_type = None

        # Kolors: 使用 `text_encoder_hid_proj` 恢复 `encoder_hid_proj`
        if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
            # 将 encoder_hid_proj 设置为 text_encoder_hid_proj
            self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
            # 将 text_encoder_hid_proj 设为 None
            self.unet.text_encoder_hid_proj = None
            # 更新编码器的隐藏维度类型为 "text_proj"
            self.unet.config.encoder_hid_dim_type = "text_proj"

        # 恢复原始 Unet 注意力处理器层
        attn_procs = {}
        # 遍历 Unet 的注意力处理器
        for name, value in self.unet.attn_processors.items():
            # 根据 F 是否具有 scaled_dot_product_attention 选择注意力处理器的类
            attn_processor_class = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
            )
            # 将注意力处理器添加到字典中,若是 IPAdapter 的类则使用新的类,否则使用原类
            attn_procs[name] = (
                attn_processor_class
                if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
                else value.__class__()
            )
        # 设置 Unet 的注意力处理器为新生成的处理器字典
        self.unet.set_attn_processor(attn_procs)

.\diffusers\loaders\lora_base.py

# 版权声明,指明版权持有者及其权利
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache License 2.0 版本授权该文件,用户需遵守该授权
# Licensed under the Apache License, Version 2.0 (the "License");
# 只能在符合该许可证的情况下使用此文件
# 许可证副本可在以下地址获取
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,否则根据许可证分发的软件是按“原样”提供
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以获取有关权限和限制的具体信息
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入标准库中的复制模块
import copy
# 导入检查模块以获取对象的签名和源代码
import inspect
# 导入操作系统模块以处理文件和目录
import os
# 从路径库导入 Path 类以方便路径操作
from pathlib import Path
# 从类型库导入所需的类型注解
from typing import Callable, Dict, List, Optional, Union

# 导入 safetensors 库以处理安全张量
import safetensors
# 导入 PyTorch 库
import torch
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 从 huggingface_hub 导入模型信息获取函数
from huggingface_hub import model_info
# 从 huggingface_hub 导入常量以指示离线模式
from huggingface_hub.constants import HF_HUB_OFFLINE

# 从父模块导入模型混合相关的工具和函数
from ..models.modeling_utils import ModelMixin, load_state_dict
# 从工具模块导入多个实用函数和常量
from ..utils import (
    USE_PEFT_BACKEND,              # 是否使用 PEFT 后端的标志
    _get_model_file,               # 获取模型文件的函数
    delete_adapter_layers,         # 删除适配器层的函数
    deprecate,                     # 用于标记弃用功能的函数
    is_accelerate_available,       # 检查 accelerate 是否可用的函数
    is_peft_available,             # 检查 PEFT 是否可用的函数
    is_transformers_available,     # 检查 transformers 是否可用的函数
    logging,                       # 日志模块
    recurse_remove_peft_layers,    # 递归删除 PEFT 层的函数
    set_adapter_layers,            # 设置适配器层的函数
    set_weights_and_activate_adapters, # 设置权重并激活适配器的函数
)

# 如果 transformers 可用,则导入 PreTrainedModel 类
if is_transformers_available():
    from transformers import PreTrainedModel

# 如果 PEFT 可用,则导入 BaseTunerLayer 类
if is_peft_available():
    from peft.tuners.tuners_utils import BaseTunerLayer

# 如果 accelerate 可用,则导入相关的钩子
if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

# 创建一个日志记录器实例,用于当前模块
logger = logging.get_logger(__name__)

# 定义一个函数以融合文本编码器的 LoRA
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
    """
    融合文本编码器的 LoRA。

    参数:
        text_encoder (`torch.nn.Module`):
            要设置适配器层的文本编码器模块。如果为 `None`,则会尝试获取 `text_encoder`
            属性。
        lora_scale (`float`, defaults to 1.0):
            控制 LoRA 参数对输出的影响程度。
        safe_fusing (`bool`, defaults to `False`):
            是否在融合之前检查融合的权重是否为 NaN 值,如果为 NaN 则不进行融合。
        adapter_names (`List[str]` 或 `str`):
            要使用的适配器名称列表。
    """
    # 定义合并参数字典,包含安全合并选项
    merge_kwargs = {"safe_merge": safe_fusing}
    # 遍历文本编码器中的所有模块
    for module in text_encoder.modules():
        # 检查当前模块是否是 BaseTunerLayer 类型
        if isinstance(module, BaseTunerLayer):
            # 如果 lora_scale 不是 1.0,则对当前模块进行缩放
            if lora_scale != 1.0:
                module.scale_layer(lora_scale)

            # 为了与之前的 PEFT 版本兼容,检查 `merge` 方法的签名
            # 以查看是否支持 `adapter_names` 参数
            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
            # 如果 `adapter_names` 参数被支持,则将其添加到合并参数中
            if "adapter_names" in supported_merge_kwargs:
                merge_kwargs["adapter_names"] = adapter_names
            # 如果不支持 `adapter_names` 且其值不为 None,则抛出错误
            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                raise ValueError(
                    # 抛出的错误信息,提示用户升级 PEFT 版本
                    "The `adapter_names` argument is not supported with your PEFT version. "
                    "Please upgrade to the latest version of PEFT. `pip install -U peft`"
                )

            # 调用模块的 merge 方法,使用合并参数进行合并
            module.merge(**merge_kwargs)
# 解锁文本编码器的 LoRA 层
def unfuse_text_encoder_lora(text_encoder):
    """
    解锁文本编码器的 LoRA 层。

    参数:
        text_encoder (`torch.nn.Module`):
            要设置适配器层的文本编码器模块。如果为 `None`,将尝试获取 `text_encoder` 属性。
    """
    # 遍历文本编码器中的所有模块
    for module in text_encoder.modules():
        # 检查当前模块是否是 BaseTunerLayer 的实例
        if isinstance(module, BaseTunerLayer):
            # 对于符合条件的模块,解除合并操作
            module.unmerge()


# 设置文本编码器的适配器层
def set_adapters_for_text_encoder(
    adapter_names: Union[List[str], str],
    text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
    text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
):
    """
    设置文本编码器的适配器层。

    参数:
        adapter_names (`List[str]` 或 `str`):
            要使用的适配器名称。
        text_encoder (`torch.nn.Module`, *可选*):
            要设置适配器层的文本编码器模块。如果为 `None`,将尝试获取 `text_encoder` 属性。
        text_encoder_weights (`List[float]`, *可选*):
            要用于文本编码器的权重。如果为 `None`,则所有适配器的权重均设置为 `1.0`。
    """
    # 如果文本编码器为 None,抛出错误
    if text_encoder is None:
        raise ValueError(
            "管道没有默认的 `pipe.text_encoder` 类。请确保传递一个 `text_encoder`。"
        )

    # 处理适配器权重的函数
    def process_weights(adapter_names, weights):
        # 将权重扩展为列表,确保每个适配器都有一个权重
        # 例如,对于 2 个适配器:  7 -> [7,7] ; [3, None] -> [3, None]
        if not isinstance(weights, list):
            weights = [weights] * len(adapter_names)

        # 检查适配器名称与权重列表的长度是否相等
        if len(adapter_names) != len(weights):
            raise ValueError(
                f"适配器名称的长度 {len(adapter_names)} 不等于权重的长度 {len(weights)}"
            )

        # 将 None 值设置为默认值 1.0
        # 例如: [7,7] -> [7,7] ; [3, None] -> [3,1]
        weights = [w if w is not None else 1.0 for w in weights]

        return weights

    # 如果适配器名称是字符串,则将其转为列表
    adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
    # 处理适配器权重
    text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
    # 设置权重并激活适配器
    set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)


# 禁用文本编码器的 LoRA 层
def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
    """
    禁用文本编码器的 LoRA 层。

    参数:
        text_encoder (`torch.nn.Module`, *可选*):
            要禁用 LoRA 层的文本编码器模块。如果为 `None`,将尝试获取 `text_encoder` 属性。
    """
    # 如果文本编码器为 None,抛出错误
    if text_encoder is None:
        raise ValueError("未找到文本编码器。")
    # 设置适配器层为禁用状态
    set_adapter_layers(text_encoder, enabled=False)


# 启用文本编码器的 LoRA 层
def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
    """
    启用文本编码器的 LoRA 层。
    # 函数参数文档字符串,说明 text_encoder 的作用和类型
        Args:
            text_encoder (`torch.nn.Module`, *optional*):
                # 可选参数,文本编码器模块,用于启用 LoRA 层。如果为 `None`,将尝试获取 `text_encoder`
                attribute.
        """
        # 如果未提供文本编码器,则抛出错误
        if text_encoder is None:
            raise ValueError("Text Encoder not found.")
        # 调用函数以启用适配器层
        set_adapter_layers(text_encoder, enabled=True)
# 移除文本编码器的猴子补丁
def _remove_text_encoder_monkey_patch(text_encoder):
    # 递归移除 PEFT 层
    recurse_remove_peft_layers(text_encoder)
    # 如果 text_encoder 有 peft_config 属性且不为 None
    if getattr(text_encoder, "peft_config", None) is not None:
        # 删除 peft_config 属性
        del text_encoder.peft_config
        # 将 hf_peft_config_loaded 设置为 None
        text_encoder._hf_peft_config_loaded = None


class LoraBaseMixin:
    """处理 LoRA 的实用类。"""

    # 可加载的 LoRA 模块列表
    _lora_loadable_modules = []
    # 融合 LoRA 的数量
    num_fused_loras = 0

    # 加载 LoRA 权重的未实现方法
    def load_lora_weights(self, **kwargs):
        raise NotImplementedError("`load_lora_weights()` is not implemented.")

    # 保存 LoRA 权重的未实现方法
    @classmethod
    def save_lora_weights(cls, **kwargs):
        raise NotImplementedError("`save_lora_weights()` not implemented.")

    # 获取 LoRA 状态字典的未实现方法
    @classmethod
    def lora_state_dict(cls, **kwargs):
        raise NotImplementedError("`lora_state_dict()` is not implemented.")

    # 可选地禁用管道的离线加载
    @classmethod
    def _optionally_disable_offloading(cls, _pipeline):
        """
        可选地移除已离线加载到 CPU 的管道。

        Args:
            _pipeline (`DiffusionPipeline`):
                要禁用离线加载的管道。

        Returns:
            tuple:
                指示 `is_model_cpu_offload` 或 `is_sequential_cpu_offload` 是否为 True 的元组。
        """
        # 初始化模型和序列 CPU 离线标志
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False

        # 如果管道不为 None 且没有 hf_device_map
        if _pipeline is not None and _pipeline.hf_device_map is None:
            # 遍历管道的组件
            for _, component in _pipeline.components.items():
                # 如果组件是 nn.Module 且有 _hf_hook 属性
                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                    # 判断模型是否已经离线加载到 CPU
                    if not is_model_cpu_offload:
                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                    # 判断是否序列化离线加载
                    if not is_sequential_cpu_offload:
                        is_sequential_cpu_offload = (
                            isinstance(component._hf_hook, AlignDevicesHook)
                            or hasattr(component._hf_hook, "hooks")
                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
                        )

                    # 记录检测到的加速钩子信息
                    logger.info(
                        "检测到加速钩子。由于您已调用 `load_lora_weights()`,之前的钩子将首先被移除。然后将加载 LoRA 参数并再次应用钩子。"
                    )
                    # 从模块中移除钩子
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        # 返回模型和序列 CPU 离线状态
        return (is_model_cpu_offload, is_sequential_cpu_offload)

    # 获取状态字典的方法,参数尚未完全列出
    @classmethod
    def _fetch_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict,
        weight_name,
        use_safetensors,
        local_files_only,
        cache_dir,
        force_download,
        proxies,
        token,
        revision,
        subfolder,
        user_agent,
        allow_pickle,
    ):
        # 从当前模块导入 LORA_WEIGHT_NAME 和 LORA_WEIGHT_NAME_SAFE
        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

        # 初始化模型文件为 None
        model_file = None
        # 检查传入的模型参数是否为字典
        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            # 如果使用 safetensors 且权重名称为空,或者权重名称以 .safetensors 结尾
            # 则尝试加载 .safetensors 权重
            if (use_safetensors and weight_name is None) or (
                weight_name is not None and weight_name.endswith(".safetensors")
            ):
                try:
                    # 放宽加载检查以提高推理 API 的友好性
                    # 有时无法自动确定 `weight_name`
                    if weight_name is None:
                        # 获取最佳猜测的权重名称
                        weight_name = cls._best_guess_weight_name(
                            pretrained_model_name_or_path_or_dict,
                            file_extension=".safetensors",  # 指定文件扩展名为 .safetensors
                            local_files_only=local_files_only,  # 仅限本地文件
                        )
                    # 获取模型文件
                    model_file = _get_model_file(
                        pretrained_model_name_or_path_or_dict,
                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,  # 使用安全的权重名称
                        cache_dir=cache_dir,  # 指定缓存目录
                        force_download=force_download,  # 是否强制下载
                        proxies=proxies,  # 代理设置
                        local_files_only=local_files_only,  # 仅限本地文件
                        token=token,  # 认证令牌
                        revision=revision,  # 版本信息
                        subfolder=subfolder,  # 子文件夹
                        user_agent=user_agent,  # 用户代理信息
                    )
                    # 从模型文件加载状态字典到 CPU 设备
                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
                except (IOError, safetensors.SafetensorError) as e:
                    # 如果不允许使用 pickle,则抛出异常
                    if not allow_pickle:
                        raise e
                    # 尝试加载非 safetensors 权重
                    model_file = None
                    pass  # 忽略异常并继续执行

            # 如果模型文件仍然为 None
            if model_file is None:
                # 如果权重名称为空,获取最佳猜测的权重名称
                if weight_name is None:
                    weight_name = cls._best_guess_weight_name(
                        pretrained_model_name_or_path_or_dict,  # 使用给定的参数
                        file_extension=".bin",  # 指定文件扩展名为 .bin
                        local_files_only=local_files_only  # 仅限本地文件
                    )
                # 获取模型文件
                model_file = _get_model_file(
                    pretrained_model_name_or_path_or_dict,
                    weights_name=weight_name or LORA_WEIGHT_NAME,  # 使用常规权重名称
                    cache_dir=cache_dir,  # 指定缓存目录
                    force_download=force_download,  # 是否强制下载
                    proxies=proxies,  # 代理设置
                    local_files_only=local_files_only,  # 仅限本地文件
                    token=token,  # 认证令牌
                    revision=revision,  # 版本信息
                    subfolder=subfolder,  # 子文件夹
                    user_agent=user_agent,  # 用户代理信息
                )
                # 从模型文件加载状态字典
                state_dict = load_state_dict(model_file)
        else:
            # 如果传入的是字典,则直接将其赋值给状态字典
            state_dict = pretrained_model_name_or_path_or_dict

        # 返回加载的状态字典
        return state_dict

    # 定义类方法的装饰器
    @classmethod
    # 获取最佳权重名称的方法,支持多种输入形式
        def _best_guess_weight_name(
            # 类参数,预训练模型名称或路径或字典,文件扩展名,是否仅使用本地文件
            cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
        ):
            # 从lora_pipeline模块导入权重名称常量
            from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
    
            # 如果是本地文件模式或离线模式,抛出错误
            if local_files_only or HF_HUB_OFFLINE:
                raise ValueError("When using the offline mode, you must specify a `weight_name`.")
    
            # 初始化目标文件列表
            targeted_files = []
    
            # 如果输入是文件,直接返回
            if os.path.isfile(pretrained_model_name_or_path_or_dict):
                return
            # 如果输入是目录,列出符合扩展名的文件
            elif os.path.isdir(pretrained_model_name_or_path_or_dict):
                targeted_files = [
                    f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
                ]
            # 否则从模型信息中获取文件列表
            else:
                files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
                targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
            # 如果没有找到目标文件,直接返回
            if len(targeted_files) == 0:
                return
    
            # 定义不允许的子字符串
            unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
            # 过滤掉包含不允许子字符串的文件
            targeted_files = list(
                filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
            )
    
            # 如果找到以LORA_WEIGHT_NAME结尾的文件,仅保留这些文件
            if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
                targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
            # 否则如果找到以LORA_WEIGHT_NAME_SAFE结尾的文件,保留这些
            elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
                targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
    
            # 如果找到多个目标文件,抛出错误
            if len(targeted_files) > 1:
                raise ValueError(
                    f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one  `.safetensors` or `.bin` file in  {pretrained_model_name_or_path_or_dict}."
                )
            # 选择第一个目标文件作为权重名称
            weight_name = targeted_files[0]
            # 返回权重名称
            return weight_name
    
        # 卸载LoRA权重的方法
        def unload_lora_weights(self):
            """
            卸载LoRA参数的方法。
    
            示例:
    
            ```python
            >>> # 假设`pipeline`已经加载了LoRA参数。
            >>> pipeline.unload_lora_weights()
            >>> ...
            ```py
            """
            # 如果未使用PEFT后端,抛出错误
            if not USE_PEFT_BACKEND:
                raise ValueError("PEFT backend is required for this method.")
    
            # 遍历可加载LoRA模块
            for component in self._lora_loadable_modules:
                # 获取相应的模型
                model = getattr(self, component, None)
                # 如果模型存在
                if model is not None:
                    # 如果模型是ModelMixin的子类,卸载LoRA
                    if issubclass(model.__class__, ModelMixin):
                        model.unload_lora()
                    # 如果模型是PreTrainedModel的子类,移除文本编码器的猴子补丁
                    elif issubclass(model.__class__, PreTrainedModel):
                        _remove_text_encoder_monkey_patch(model)
    # 定义一个方法,融合 LoRA 参数
        def fuse_lora(
            self,  # 方法所属的类实例
            components: List[str] = [],  # 要融合的组件列表,默认为空列表
            lora_scale: float = 1.0,  # LoRA 的缩放因子,默认为1.0
            safe_fusing: bool = False,  # 是否安全融合的标志,默认为 False
            adapter_names: Optional[List[str]] = None,  # 可选的适配器名称列表
            **kwargs,  # 额外的关键字参数
        ):
            # 定义一个方法,反融合 LoRA 参数
            def unfuse_lora(self, components: List[str] = [], **kwargs):
                r"""  # 文档字符串,描述该方法的作用
                Reverses the effect of  # 反转 fuse_lora 方法的效果
                [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
    
                <Tip warning={true}>  # 提示框,表示这是一个实验性 API
                This is an experimental API.  # 说明该 API 是实验性质的
                </Tip>
    
                Args:  # 参数说明
                    components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.  # 要反融合 LoRA 的组件列表
                    unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.  # 是否反融合 UNet 的 LoRA 参数
                    unfuse_text_encoder (`bool`, defaults to `True`):  # 是否反融合文本编码器的 LoRA 参数
                        Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the  # 如果文本编码器没有打补丁,则无效
                        LoRA parameters then it won't have any effect.  # 如果没有效果,则不会反融合
                """
                # 检查关键字参数中是否包含 unfuse_unet
                if "unfuse_unet" in kwargs:
                    depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."  # 过时的警告消息
                    deprecate(  # 调用 deprecate 函数
                        "unfuse_unet",  # 被弃用的参数名
                        "1.0.0",  # 被弃用的版本
                        depr_message,  # 过时消息
                    )
                # 检查关键字参数中是否包含 unfuse_transformer
                if "unfuse_transformer" in kwargs:
                    depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."  # 过时的警告消息
                    deprecate(  # 调用 deprecate 函数
                        "unfuse_transformer",  # 被弃用的参数名
                        "1.0.0",  # 被弃用的版本
                        depr_message,  # 过时消息
                    )
                # 检查关键字参数中是否包含 unfuse_text_encoder
                if "unfuse_text_encoder" in kwargs:
                    depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."  # 过时的警告消息
                    deprecate(  # 调用 deprecate 函数
                        "unfuse_text_encoder",  # 被弃用的参数名
                        "1.0.0",  # 被弃用的版本
                        depr_message,  # 过时消息
                    )
    
                # 如果组件列表为空,则抛出异常
                if len(components) == 0:
                    raise ValueError("`components` cannot be an empty list.")  # 抛出 ValueError,说明组件列表不能为空
    
                # 遍历组件列表中的每个组件
                for fuse_component in components:
                    # 如果组件不在可加载的 LoRA 模块中,抛出异常
                    if fuse_component not in self._lora_loadable_modules:
                        raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")  # 抛出 ValueError,说明组件未找到
    
                    # 获取当前组件的模型
                    model = getattr(self, fuse_component, None)  # 从当前实例获取组件的模型
                    # 如果模型存在
                    if model is not None:
                        # 检查模型是否是 ModelMixin 或 PreTrainedModel 的子类
                        if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
                            # 遍历模型中的每个模块
                            for module in model.modules():
                                # 如果模块是 BaseTunerLayer 的实例
                                if isinstance(module, BaseTunerLayer):
                                    module.unmerge()  # 调用 unmerge 方法,反融合 LoRA 参数
    
                # 将融合的 LoRA 数量减少1
                self.num_fused_loras -= 1  # 更新已融合 LoRA 的数量
    
        # 定义一个方法,设置适配器
        def set_adapters(
            self,  # 方法所属的类实例
            adapter_names: Union[List[str], str],  # 适配器名称,可以是列表或字符串
            adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,  # 可选的适配器权重
    ):
        # 将 adapter_names 转换为列表,如果它是字符串则单独包装成列表
        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        # 深拷贝 adapter_weights,防止修改原始数据
        adapter_weights = copy.deepcopy(adapter_weights)

        # 如果 adapter_weights 不是列表,则将其扩展为与 adapter_names 相同长度的列表
        if not isinstance(adapter_weights, list):
            adapter_weights = [adapter_weights] * len(adapter_names)

        # 检查 adapter_names 和 adapter_weights 的长度是否一致,若不一致则抛出错误
        if len(adapter_names) != len(adapter_weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
            )

        # 获取所有适配器的列表,返回一个字典,示例:{"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
        
        # 获取所有适配器的集合,例如:{"adapter1", "adapter2"}
        all_adapters = {
            adapter for adapters in list_adapters.values() for adapter in adapters
        }  # eg ["adapter1", "adapter2"]
        
        # 生成一个字典,键为适配器,值为适配器所对应的部分
        invert_list_adapters = {
            adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
            for adapter in all_adapters
        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}

        # 初始化一个空字典,用于存放分解后的权重
        _component_adapter_weights = {}
        
        # 遍历可加载的模块
        for component in self._lora_loadable_modules:
            # 动态获取模块的实例
            model = getattr(self, component)

            # 将适配器名称与权重一一对应
            for adapter_name, weights in zip(adapter_names, adapter_weights):
                # 如果权重是字典,尝试从中获取特定组件的权重
                if isinstance(weights, dict):
                    component_adapter_weights = weights.pop(component, None)

                    # 如果权重存在但模型中没有该组件,记录警告
                    if component_adapter_weights is not None and not hasattr(self, component):
                        logger.warning(
                            f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
                        )

                    # 如果权重存在但适配器中不包含该组件,记录警告
                    if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
                        logger.warning(
                            (
                                f"Lora weight dict for adapter '{adapter_name}' contains {component},"
                                f"but this will be ignored because {adapter_name} does not contain weights for {component}."
                                f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
                            )
                        )

                else:
                    # 如果权重不是字典,直接使用权重
                    component_adapter_weights = weights

                # 确保组件权重字典中有该组件的列表,如果没有则初始化为空列表
                _component_adapter_weights.setdefault(component, [])
                # 将组件的权重添加到对应的列表中
                _component_adapter_weights[component].append(component_adapter_weights)

            # 如果模型是 ModelMixin 的子类,设置适配器
            if issubclass(model.__class__, ModelMixin):
                model.set_adapters(adapter_names, _component_adapter_weights[component])
            # 如果模型是 PreTrainedModel 的子类,设置文本编码器的适配器
            elif issubclass(model.__class__, PreTrainedModel):
                set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
    # 定义一个禁用 LoRA 的方法
        def disable_lora(self):
            # 检查是否使用 PEFT 后端,若不使用则抛出错误
            if not USE_PEFT_BACKEND:
                raise ValueError("PEFT backend is required for this method.")
    
            # 遍历可加载 LoRA 的模块
            for component in self._lora_loadable_modules:
                # 获取当前组件对应的模型
                model = getattr(self, component, None)
                # 如果模型存在
                if model is not None:
                    # 如果模型是 ModelMixin 的子类,禁用其 LoRA
                    if issubclass(model.__class__, ModelMixin):
                        model.disable_lora()
                    # 如果模型是 PreTrainedModel 的子类,调用相应的禁用方法
                    elif issubclass(model.__class__, PreTrainedModel):
                        disable_lora_for_text_encoder(model)
    
    # 定义一个启用 LoRA 的方法
        def enable_lora(self):
            # 检查是否使用 PEFT 后端,若不使用则抛出错误
            if not USE_PEFT_BACKEND:
                raise ValueError("PEFT backend is required for this method.")
    
            # 遍历可加载 LoRA 的模块
            for component in self._lora_loadable_modules:
                # 获取当前组件对应的模型
                model = getattr(self, component, None)
                # 如果模型存在
                if model is not None:
                    # 如果模型是 ModelMixin 的子类,启用其 LoRA
                    if issubclass(model.__class__, ModelMixin):
                        model.enable_lora()
                    # 如果模型是 PreTrainedModel 的子类,调用相应的启用方法
                    elif issubclass(model.__class__, PreTrainedModel):
                        enable_lora_for_text_encoder(model)
    
    # 定义一个删除适配器的函数
        def delete_adapters(self, adapter_names: Union[List[str], str]):
            """
            Args:
            Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
                adapter_names (`Union[List[str], str]`):
                    The names of the adapter to delete. Can be a single string or a list of strings
            """
            # 检查是否使用 PEFT 后端,若不使用则抛出错误
            if not USE_PEFT_BACKEND:
                raise ValueError("PEFT backend is required for this method.")
    
            # 如果 adapter_names 是字符串,则转换为列表
            if isinstance(adapter_names, str):
                adapter_names = [adapter_names]
    
            # 遍历可加载 LoRA 的模块
            for component in self._lora_loadable_modules:
                # 获取当前组件对应的模型
                model = getattr(self, component, None)
                # 如果模型存在
                if model is not None:
                    # 如果模型是 ModelMixin 的子类,删除适配器
                    if issubclass(model.__class__, ModelMixin):
                        model.delete_adapters(adapter_names)
                    # 如果模型是 PreTrainedModel 的子类,逐个删除适配器层
                    elif issubclass(model.__class__, PreTrainedModel):
                        for adapter_name in adapter_names:
                            delete_adapter_layers(model, adapter_name)
    # 定义获取当前活动适配器的函数,返回类型为字符串列表
    def get_active_adapters(self) -> List[str]:
        # 函数说明:获取当前活动适配器的列表,包含使用示例
        """
        Gets the list of the current active adapters.
    
        Example:
    
        ```python
        from diffusers import DiffusionPipeline
    
        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
        ).to("cuda")
        pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
        pipeline.get_active_adapters()
        ```py
        """
        # 检查是否启用了 PEFT 后端,未启用则抛出异常
        if not USE_PEFT_BACKEND:
            raise ValueError(
                "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
            )
    
        # 初始化活动适配器列表
        active_adapters = []
    
        # 遍历可加载的 LORA 模块
        for component in self._lora_loadable_modules:
            # 获取当前组件的模型,如果不存在则为 None
            model = getattr(self, component, None)
            # 检查模型是否存在且是 ModelMixin 的子类
            if model is not None and issubclass(model.__class__, ModelMixin):
                # 遍历模型的所有模块
                for module in model.modules():
                    # 检查模块是否为 BaseTunerLayer 的实例
                    if isinstance(module, BaseTunerLayer):
                        # 获取活动适配器并赋值
                        active_adapters = module.active_adapters
                        break
    
        # 返回活动适配器列表
        return active_adapters
    
    # 定义获取当前所有可用适配器列表的函数,返回类型为字典
    def get_list_adapters(self) -> Dict[str, List[str]]:
        # 函数说明:获取当前管道中所有可用适配器的列表
        """
        Gets the current list of all available adapters in the pipeline.
        """
        # 检查是否启用了 PEFT 后端,未启用则抛出异常
        if not USE_PEFT_BACKEND:
            raise ValueError(
                "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
            )
    
        # 初始化适配器集合字典
        set_adapters = {}
    
        # 遍历可加载的 LORA 模块
        for component in self._lora_loadable_modules:
            # 获取当前组件的模型,如果不存在则为 None
            model = getattr(self, component, None)
            # 检查模型是否存在且是 ModelMixin 或 PreTrainedModel 的子类,并具有 peft_config 属性
            if (
                model is not None
                and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
                and hasattr(model, "peft_config")
            ):
                # 将适配器配置的键列表存入字典
                set_adapters[component] = list(model.peft_config.keys())
    
        # 返回适配器集合字典
        return set_adapters
    # 定义一个方法,用于将指定的 LoRA 适配器移动到目标设备
    def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
        """
        将 `adapter_names` 中列出的 LoRA 适配器移动到目标设备。此方法用于在加载多个适配器时将 LoRA 移动到 CPU,以释放一些 GPU 内存。

        Args:
            adapter_names (`List[str]`):
                要发送到设备的适配器列表。
            device (`Union[torch.device, str, int]`):
                适配器要发送到的设备,可以是 torch 设备、字符串或整数。
        """
        # 检查是否使用 PEFT 后端,如果没有则抛出错误
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # 遍历可加载 LoRA 模块的组件
        for component in self._lora_loadable_modules:
            # 获取当前组件的模型,如果没有则为 None
            model = getattr(self, component, None)
            # 如果模型存在,则继续处理
            if model is not None:
                # 遍历模型的所有模块
                for module in model.modules():
                    # 检查模块是否是 BaseTunerLayer 的实例
                    if isinstance(module, BaseTunerLayer):
                        # 遍历适配器名称列表
                        for adapter_name in adapter_names:
                            # 将 lora_A 适配器移动到指定设备
                            module.lora_A[adapter_name].to(device)
                            # 将 lora_B 适配器移动到指定设备
                            module.lora_B[adapter_name].to(device)
                            # 如果模块有 lora_magnitude_vector 属性并且不为 None
                            if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
                                # 将 lora_magnitude_vector 中的适配器移动到指定设备,并重新赋值
                                module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
                                    adapter_name
                                ].to(device)

    # 定义一个静态方法,用于打包层的权重
    @staticmethod
    def pack_weights(layers, prefix):
        # 获取层的状态字典,如果层是 nn.Module 则调用其 state_dict() 方法
        layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        # 将权重和模块名称组合成一个新的字典
        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
        # 返回新的权重字典
        return layers_state_dict

    # 定义一个静态方法,用于写入 LoRA 层的权重
    @staticmethod
    def write_lora_layers(
        state_dict: Dict[str, torch.Tensor],
        save_directory: str,
        is_main_process: bool,
        weight_name: str,
        save_function: Callable,
        safe_serialization: bool,
    ):
        # 从本地模块导入 LORA_WEIGHT_NAME 和 LORA_WEIGHT_NAME_SAFE 常量
        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

        # 检查提供的路径是否为文件,如果是则记录错误并返回
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        # 如果没有提供保存函数
        if save_function is None:
            # 根据是否使用安全序列化来定义保存函数
            if safe_serialization:
                # 定义一个保存函数,使用 safetensors 库保存文件,带有元数据
                def save_function(weights, filename):
                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
            else:
                # 如果不使用安全序列化,使用 PyTorch 自带的保存函数
                save_function = torch.save

        # 创建保存目录,如果目录已存在则不报错
        os.makedirs(save_directory, exist_ok=True)

        # 如果没有提供权重名称,根据安全序列化的设置选择默认名称
        if weight_name is None:
            if safe_serialization:
                # 使用安全权重名称
                weight_name = LORA_WEIGHT_NAME_SAFE
            else:
                # 使用普通权重名称
                weight_name = LORA_WEIGHT_NAME

        # 构造保存文件的完整路径
        save_path = Path(save_directory, weight_name).as_posix()
        # 调用保存函数,将状态字典保存到指定路径
        save_function(state_dict, save_path)
        # 记录模型权重保存成功的信息
        logger.info(f"Model weights saved in {save_path}")

    @property
    # 定义属性函数,返回 lora_scale 的值,可以在运行时由管道设置
    def lora_scale(self) -> float:
        # 如果 _lora_scale 未被设置,返回默认值 1
        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
posted @ 2024-10-22 12:39  绝不原创的飞龙  阅读(209)  评论(0编辑  收藏  举报