diffusers 源码解析(四十二)
.\diffusers\pipelines\pipeline_utils.py
# coding=utf-8 # 指定文件编码为 UTF-8
# Copyright 2024 The HuggingFace Inc. team. # 版权声明,表明文件归 HuggingFace Inc. 团队所有
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # 版权声明,表明文件归 NVIDIA CORPORATION 所有
#
# Licensed under the Apache License, Version 2.0 (the "License"); # 指明此文件的许可证为 Apache 2.0 版本
# you may not use this file except in compliance with the License. # 指出必须遵循许可证才能使用此文件
# You may obtain a copy of the License at # 提供获取许可证的方式
#
# http://www.apache.org/licenses/LICENSE-2.0 # 指向许可证的 URL
#
# Unless required by applicable law or agreed to in writing, software # 指出软件在没有明确同意或适用法律时按“原样”提供
# distributed under the License is distributed on an "AS IS" BASIS, # 声明不提供任何形式的保证
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # 没有任何形式的明示或暗示的担保
# See the License for the specific language governing permissions and # 建议查看许可证以获取权限和限制的具体信息
# limitations under the License. # 指出许可证下的限制
import fnmatch # 导入 fnmatch 模块,用于文件名匹配
import importlib # 导入 importlib 模块,用于动态导入模块
import inspect # 导入 inspect 模块,用于获取对象的信息
import os # 导入 os 模块,用于操作系统相关功能
import re # 导入 re 模块,用于正则表达式匹配
import sys # 导入 sys 模块,用于访问 Python 解释器的变量和函数
from dataclasses import dataclass # 从 dataclasses 导入 dataclass 装饰器,用于简化类的定义
from pathlib import Path # 从 pathlib 导入 Path 类,用于路径操作
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin # 导入类型提示相关的工具
import numpy as np # 导入 NumPy 库并简写为 np,用于数值计算
import PIL.Image # 导入 PIL 的 Image 模块,用于图像处理
import requests # 导入 requests 库,用于发送 HTTP 请求
import torch # 导入 PyTorch 库,用于深度学习
from huggingface_hub import ( # 从 huggingface_hub 导入多个功能
ModelCard, # 导入 ModelCard 类,用于处理模型卡
create_repo, # 导入 create_repo 函数,用于创建模型仓库
hf_hub_download, # 导入 hf_hub_download 函数,用于从 Hugging Face Hub 下载文件
model_info, # 导入 model_info 函数,用于获取模型信息
snapshot_download, # 导入 snapshot_download 函数,用于下载快照
)
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args # 导入帮助函数用于验证参数和检查离线模式
from packaging import version # 从 packaging 导入 version 模块,用于版本比较
from requests.exceptions import HTTPError # 从 requests.exceptions 导入 HTTPError,用于处理 HTTP 错误
from tqdm.auto import tqdm # 从 tqdm.auto 导入 tqdm,用于显示进度条
from .. import __version__ # 从当前模块导入版本号
from ..configuration_utils import ConfigMixin # 从上级模块导入 ConfigMixin 类,用于配置混入
from ..models import AutoencoderKL # 从上级模块导入 AutoencoderKL 模型
from ..models.attention_processor import FusedAttnProcessor2_0 # 从上级模块导入 FusedAttnProcessor2_0 类
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin # 从上级模块导入常量和 ModelMixin 类
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME # 从上级模块导入调度器配置名称
from ..utils import ( # 从上级模块导入多个工具函数和常量
CONFIG_NAME, # 配置文件名
DEPRECATED_REVISION_ARGS, # 已弃用的修订参数
BaseOutput, # 基础输出类
PushToHubMixin, # 推送到 Hub 的混入类
deprecate, # 用于标记弃用的函数
is_accelerate_available, # 检查 accelerate 库是否可用的函数
is_accelerate_version, # 检查 accelerate 版本的函数
is_torch_npu_available, # 检查 PyTorch NPU 是否可用的函数
is_torch_version, # 检查 PyTorch 版本的函数
logging, # 日志记录模块
numpy_to_pil, # NumPy 数组转换为 PIL 图像的函数
)
from ..utils.hub_utils import load_or_create_model_card, populate_model_card # 从上级模块导入处理模型卡的函数
from ..utils.torch_utils import is_compiled_module # 从上级模块导入检查模块是否已编译的函数
if is_torch_npu_available(): # 如果 PyTorch NPU 可用
import torch_npu # 导入 torch_npu 模块,提供对 NPU 的支持 # noqa: F401 # noqa: F401 表示忽略未使用的导入警告
from .pipeline_loading_utils import ( # 从当前包导入多个加载管道相关的工具
ALL_IMPORTABLE_CLASSES, # 所有可导入的类
CONNECTED_PIPES_KEYS, # 连接管道的键
CUSTOM_PIPELINE_FILE_NAME, # 自定义管道文件名
LOADABLE_CLASSES, # 可加载的类
_fetch_class_library_tuple, # 获取类库元组的私有函数
_get_custom_pipeline_class, # 获取自定义管道类的私有函数
_get_final_device_map, # 获取最终设备映射的私有函数
_get_pipeline_class, # 获取管道类的私有函数
_unwrap_model, # 解包模型的私有函数
is_safetensors_compatible, # 检查是否兼容 SafeTensors 的函数
load_sub_model, # 加载子模型的函数
maybe_raise_or_warn, # 可能抛出警告或错误的函数
variant_compatible_siblings, # 检查变体兼容的兄弟类的函数
warn_deprecated_model_variant, # 发出关于模型变体弃用的警告的函数
)
if is_accelerate_available(): # 如果 accelerate 库可用
import accelerate # 导入 accelerate 库,提供加速功能
LIBRARIES = [] # 初始化空列表,用于存储库
for library in LOADABLE_CLASSES: # 遍历可加载的类
LIBRARIES.append(library) # 将每个库添加到 LIBRARIES 列表中
SUPPORTED_DEVICE_MAP = ["balanced"] # 定义支持的设备映射,使用平衡策略
logger = logging.get_logger(__name__) # 创建一个与当前模块同名的日志记录器
@dataclass # 使用 dataclass 装饰器定义一个数据类
class ImagePipelineOutput(BaseOutput): # 定义图像管道输出类,继承自 BaseOutput
"""
Output class for image pipelines. # 图像管道的输出类
Args: # 参数说明
images (`List[PIL.Image.Image]` or `np.ndarray`) # images 参数,接受 PIL 图像列表或 NumPy 数组
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`. # 说明该参数可以是图像列表或具有特定形状的 NumPy 数组
"""
# 定义一个变量 images,它可以是一个 PIL 图像对象列表或一个 NumPy 数组
images: Union[List[PIL.Image.Image], np.ndarray]
# 定义音频管道输出的数据类,继承自 BaseOutput
@dataclass
class AudioPipelineOutput(BaseOutput):
"""
音频管道的输出类。
参数:
audios (`np.ndarray`)
一个形状为 `(batch_size, num_channels, sample_rate)` 的 NumPy 数组,表示去噪后的音频样本列表。
"""
# 存储音频样本的 NumPy 数组
audios: np.ndarray
# 定义扩散管道的基类,继承自 ConfigMixin 和 PushToHubMixin
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
r"""
所有管道的基类。
[`DiffusionPipeline`] 存储所有扩散管道的组件(模型、调度器和处理器),并提供加载、下载和保存模型的方法。它还包含以下方法:
- 将所有 PyTorch 模块移动到您选择的设备
- 启用/禁用去噪迭代的进度条
类属性:
- **config_name** (`str`) -- 存储扩散管道所有组件类和模块名称的配置文件名。
- **_optional_components** (`List[str]`) -- 所有可选组件的列表,这些组件在管道功能上并不是必需的(应由子类重写)。
"""
# 配置文件名称,默认值为 "model_index.json"
config_name = "model_index.json"
# 模型 CPU 卸载序列,初始值为 None
model_cpu_offload_seq = None
# Hugging Face 设备映射,初始值为 None
hf_device_map = None
# 可选组件列表,初始化为空
_optional_components = []
# 不参与 CPU 卸载的组件列表,初始化为空
_exclude_from_cpu_offload = []
# 是否加载连接的管道,初始化为 False
_load_connected_pipes = False
# 是否为 ONNX 格式,初始化为 False
_is_onnx = False
# 注册模块的方法,接收任意关键字参数
def register_modules(self, **kwargs):
# 遍历关键字参数中的模块
for name, module in kwargs.items():
# 检索库
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
# 如果模块为 None,注册字典设置为 None
register_dict = {name: (None, None)}
else:
# 获取库和类名的元组
library, class_name = _fetch_class_library_tuple(module)
# 注册字典设置为库和类名元组
register_dict = {name: (library, class_name)}
# 保存模型索引配置
self.register_to_config(**register_dict)
# 设置模型
setattr(self, name, module)
# 自定义属性设置方法
def __setattr__(self, name: str, value: Any):
# 检查属性是否在实例字典中且在配置中存在
if name in self.__dict__ and hasattr(self.config, name):
# 如果名称在配置中存在,则需要覆盖配置
if isinstance(getattr(self.config, name), (tuple, list)):
# 如果值不为 None 且配置中存在有效值
if value is not None and self.config[name][0] is not None:
# 获取类库元组
class_library_tuple = _fetch_class_library_tuple(value)
else:
# 否则设置为 None
class_library_tuple = (None, None)
# 注册到配置中
self.register_to_config(**{name: class_library_tuple})
else:
# 直接注册到配置中
self.register_to_config(**{name: value})
# 调用父类的设置属性方法
super().__setattr__(name, value)
# 保存预训练模型的方法
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
safe_serialization: bool = True,
variant: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
@property
# 定义一个方法,返回当前使用的设备类型
def device(self) -> torch.device:
r"""
Returns:
`torch.device`: The torch device on which the pipeline is located.
"""
# 获取当前实例的模块名和其他相关信息
module_names, _ = self._get_signature_keys(self)
# 根据模块名获取实例中对应的模块对象,若不存在则为 None
modules = [getattr(self, n, None) for n in module_names]
# 过滤出类型为 torch.nn.Module 的模块
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
# 遍历所有模块
for module in modules:
# 返回第一个模块的设备类型
return module.device
# 如果没有模块,默认返回 CPU 设备
return torch.device("cpu")
# 定义一个只读属性,返回当前使用的数据类型
@property
def dtype(self) -> torch.dtype:
r"""
Returns:
`torch.dtype`: The torch dtype on which the pipeline is located.
"""
# 获取当前实例的模块名和其他相关信息
module_names, _ = self._get_signature_keys(self)
# 根据模块名获取实例中对应的模块对象,若不存在则为 None
modules = [getattr(self, n, None) for n in module_names]
# 过滤出类型为 torch.nn.Module 的模块
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
# 遍历所有模块
for module in modules:
# 返回第一个模块的数据类型
return module.dtype
# 如果没有模块,默认返回 float32 数据类型
return torch.float32
# 定义一个类方法,返回模型的名称或路径
@classmethod
@validate_hf_hub_args
@property
def name_or_path(self) -> str:
# 从配置中获取名称或路径,若不存在则为 None
return getattr(self.config, "_name_or_path", None)
# 定义一个只读属性,返回执行设备
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
Accelerate's module hooks.
"""
# 遍历组件字典中的每个模型
for name, model in self.components.items():
# 如果不是 nn.Module 或者在排除列表中,则跳过
if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
continue
# 如果模型没有 HF hook,返回当前设备
if not hasattr(model, "_hf_hook"):
return self.device
# 遍历模型中的所有模块
for module in model.modules():
# 检查模块是否有执行设备信息
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
# 返回找到的执行设备
return torch.device(module._hf_hook.execution_device)
# 如果没有找到,返回当前设备
return self.device
# 定义一个方法,用于移除所有注册的 hook
def remove_all_hooks(self):
r"""
Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`.
"""
# 遍历组件字典中的每个模型
for _, model in self.components.items():
# 如果是 nn.Module 且有 HF hook,则移除 hook
if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
accelerate.hooks.remove_hook_from_module(model, recurse=True)
# 清空所有 hooks 列表
self._all_hooks = []
# 定义一个可能释放模型钩子的函数
def maybe_free_model_hooks(self):
r"""
该函数卸载所有组件,移除通过 `enable_model_cpu_offload` 添加的模型钩子,然后再次应用它们。
如果模型未被卸载,该函数无操作。确保将此函数添加到管道的 `__call__` 函数末尾,以便在应用 enable_model_cpu_offload 时正确工作。
"""
# 检查是否没有钩子被添加,如果没有,什么都不做
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
# `enable_model_cpu_offload` 尚未被调用,因此静默返回
return
# 确保模型的状态与调用之前一致
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
# 定义一个重置设备映射的函数
def reset_device_map(self):
r"""
将设备映射(如果存在)重置为 None。
"""
# 如果设备映射已经是 None,什么都不做
if self.hf_device_map is None:
return
else:
# 移除所有钩子
self.remove_all_hooks()
# 遍历组件,将每个 torch.nn.Module 移动到 CPU
for name, component in self.components.items():
if isinstance(component, torch.nn.Module):
component.to("cpu")
# 将设备映射设置为 None
self.hf_device_map = None
# 定义一个类方法以获取签名键
@classmethod
@validate_hf_hub_args
@classmethod
def _get_signature_keys(cls, obj):
# 获取对象初始化方法的参数
parameters = inspect.signature(obj.__init__).parameters
# 获取所需参数(没有默认值的)
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
# 获取可选参数(有默认值的)
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
# 预期模块为所需参数的键集,排除 "self"
expected_modules = set(required_parameters.keys()) - {"self"}
# 将可选参数名转换为列表
optional_names = list(optional_parameters)
# 遍历可选参数,如果在可选组件中,则添加到预期模块并从可选参数中移除
for name in optional_names:
if name in cls._optional_components:
expected_modules.add(name)
optional_parameters.remove(name)
# 返回预期模块和可选参数
return expected_modules, optional_parameters
# 定义一个类方法以获取签名类型
@classmethod
def _get_signature_types(cls):
# 初始化一个字典以存储签名类型
signature_types = {}
# 遍历初始化方法的参数,获取每个参数的注解
for k, v in inspect.signature(cls.__init__).parameters.items():
# 如果参数注解是类,存储该注解
if inspect.isclass(v.annotation):
signature_types[k] = (v.annotation,)
# 如果参数注解是联合类型,获取所有类型
elif get_origin(v.annotation) == Union:
signature_types[k] = get_args(v.annotation)
# 如果无法获取类型注解,记录警告
else:
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
# 返回签名类型字典
return signature_types
# 定义一个属性
@property
# 定义一个方法,返回一个字典,包含初始化管道所需的所有模块
def components(self) -> Dict[str, Any]:
r""" # 文档字符串,描述方法的功能和返回值
The `self.components` property can be useful to run different pipelines with the same weights and
configurations without reallocating additional memory.
Returns (`dict`):
A dictionary containing all the modules needed to initialize the pipeline.
Examples:
```py
>>> from diffusers import (
... StableDiffusionPipeline,
... StableDiffusionImg2ImgPipeline,
... StableDiffusionInpaintPipeline,
... )
>>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
>>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
```py
"""
# 获取预期模块和可选参数的签名
expected_modules, optional_parameters = self._get_signature_keys(self)
# 构建一个字典,包含所有必要的组件
components = {
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}
# 检查组件的键是否与预期模块匹配
if set(components.keys()) != expected_modules:
# 如果不匹配,抛出错误,说明初始化有误
raise ValueError(
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
f" {expected_modules} to be defined, but {components.keys()} are defined."
)
# 返回构建的组件字典
return components
# 定义一个静态方法,将 NumPy 图像或图像批次转换为 PIL 图像
@staticmethod
def numpy_to_pil(images):
"""
Convert a NumPy image or a batch of images to a PIL image.
"""
# 调用外部函数进行转换
return numpy_to_pil(images)
# 定义一个方法,用于创建进度条
def progress_bar(self, iterable=None, total=None):
# 检查是否已定义进度条配置,如果没有则初始化为空字典
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
# 如果已经定义,则检查其类型是否为字典
elif not isinstance(self._progress_bar_config, dict):
# 如果类型不匹配,抛出错误
raise ValueError(
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
# 如果提供了可迭代对象,则返回一个带进度条的可迭代对象
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
# 如果提供了总数,则返回一个总数为 total 的进度条
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
# 如果两个都没有提供,抛出错误
else:
raise ValueError("Either `total` or `iterable` has to be defined.")
# 定义一个方法,用于设置进度条的配置
def set_progress_bar_config(self, **kwargs):
# 将传入的参数存储到进度条配置中
self._progress_bar_config = kwargs
# 定义一个启用 xFormers 内存高效注意力的方法,支持可选的注意力操作
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
r"""
启用来自 [xFormers](https://facebookresearch.github.io/xformers/) 的内存高效注意力。启用此选项后,
你应该会观察到较低的 GPU 内存使用率,并在推理过程中可能加速。训练期间的加速不保证。
<Tip warning={true}>
⚠️ 当同时启用内存高效注意力和切片注意力时,内存高效注意力优先。
</Tip>
参数:
attention_op (`Callable`, *可选*):
用于覆盖默认的 `None` 操作符,以用作 xFormers 的
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
函数的 `op` 参数。
示例:
```py
>>> import torch
>>> from diffusers import DiffusionPipeline
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
>>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
>>> # 针对 Flash Attention 使用 VAE 时不接受注意力形状的解决方法
>>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
```py
"""
# 调用设置内存高效注意力的函数,并将标志设为 True 和传入的注意力操作
self.set_use_memory_efficient_attention_xformers(True, attention_op)
# 定义一个禁用 xFormers 内存高效注意力的方法
def disable_xformers_memory_efficient_attention(self):
r"""
禁用来自 [xFormers](https://facebookresearch.github.io/xformers/) 的内存高效注意力。
"""
# 调用设置内存高效注意力的函数,并将标志设为 False
self.set_use_memory_efficient_attention_xformers(False)
# 定义一个设置内存高效注意力的函数,接受有效标志和可选注意力操作
def set_use_memory_efficient_attention_xformers(
self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
# 递归遍历所有子模块
# 任何暴露 set_use_memory_efficient_attention_xformers 方法的子模块将接收此消息
def fn_recursive_set_mem_eff(module: torch.nn.Module):
# 检查模块是否有设置内存高效注意力的方法,如果有则调用
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
# 递归处理所有子模块
for child in module.children():
fn_recursive_set_mem_eff(child)
# 获取当前对象的模块名称及其签名
module_names, _ = self._get_signature_keys(self)
# 获取所有子模块,过滤出 torch.nn.Module 类型的模块
modules = [getattr(self, n, None) for n in module_names]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
# 对每个模块调用递归设置函数
for module in modules:
fn_recursive_set_mem_eff(module)
# 定义一个方法来启用切片注意力计算,默认为“auto”
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
启用切片注意力计算。当启用此选项时,注意力模块将输入张量分成多个切片
以分步骤计算注意力。对于多个注意力头,计算将在每个头上顺序执行。
这有助于节省一些内存,换取略微降低的速度。
<Tip warning={true}>
⚠️ 如果您已经在使用来自 PyTorch 2.0 的 `scaled_dot_product_attention` (SDPA) 或 xFormers,
请勿启用注意力切片。这些注意力计算已经非常节省内存,因此您不需要启用
此功能。如果在 SDPA 或 xFormers 中启用注意力切片,可能会导致严重的性能下降!
</Tip>
参数:
slice_size (`str` 或 `int`, *可选*, 默认为 `"auto"`):
当为 `"auto"` 时,将输入分为两个注意力头进行计算。
如果为 `"max"`,则通过一次只运行一个切片来保存最大内存。
如果提供一个数字,使用 `attention_head_dim // slice_size` 个切片。
在这种情况下,`attention_head_dim` 必须是 `slice_size` 的倍数。
示例:
```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline
>>> pipe = StableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5",
... torch_dtype=torch.float16,
... use_safetensors=True,
... )
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> pipe.enable_attention_slicing()
>>> image = pipe(prompt).images[0]
```
"""
# 调用设置切片的方法,传入切片大小
self.set_attention_slice(slice_size)
# 定义一个方法来禁用切片注意力计算
def disable_attention_slicing(self):
r"""
禁用切片注意力计算。如果之前调用过 `enable_attention_slicing`,则注意力
将在一步中计算。
"""
# 将切片大小设置为 `None` 以禁用 `attention slicing`
self.enable_attention_slicing(None)
# 定义一个方法来设置切片大小
def set_attention_slice(self, slice_size: Optional[int]):
# 获取当前类的签名键和模块名称
module_names, _ = self._get_signature_keys(self)
# 获取当前类的所有模块
modules = [getattr(self, n, None) for n in module_names]
# 过滤出具有 `set_attention_slice` 方法的 PyTorch 模块
modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")]
# 遍历所有模块并设置切片大小
for module in modules:
module.set_attention_slice(slice_size)
# 类方法的定义开始
@classmethod
# 定义一个混合类,用于处理具有 VAE 和 UNet 的扩散管道(主要用于稳定扩散 LDM)
class StableDiffusionMixin:
r"""
帮助 DiffusionPipeline 使用 VAE 和 UNet(主要用于 LDM,如稳定扩散)
"""
# 启用切片 VAE 解码的功能
def enable_vae_slicing(self):
r"""
启用切片 VAE 解码。当启用此选项时,VAE 将输入张量分割为切片
以分几步计算解码。这对于节省内存和允许更大的批处理大小很有用。
"""
# 调用 VAE 的方法以启用切片
self.vae.enable_slicing()
# 禁用切片 VAE 解码的功能
def disable_vae_slicing(self):
r"""
禁用切片 VAE 解码。如果之前启用了 `enable_vae_slicing`,此方法将恢复到
一步计算解码。
"""
# 调用 VAE 的方法以禁用切片
self.vae.disable_slicing()
# 启用平铺 VAE 解码的功能
def enable_vae_tiling(self):
r"""
启用平铺 VAE 解码。当启用此选项时,VAE 将输入张量分割为块
以分几步计算解码和编码。这对于节省大量内存并允许处理更大图像很有用。
"""
# 调用 VAE 的方法以启用平铺
self.vae.enable_tiling()
# 禁用平铺 VAE 解码的功能
def disable_vae_tiling(self):
r"""
禁用平铺 VAE 解码。如果之前启用了 `enable_vae_tiling`,此方法将恢复到
一步计算解码。
"""
# 调用 VAE 的方法以禁用平铺
self.vae.disable_tiling()
# 启用 FreeU 机制,使用指定的缩放因子
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""启用 FreeU 机制,如 https://arxiv.org/abs/2309.11497 所述。
缩放因子后缀表示应用它们的阶段。
请参考 [官方库](https://github.com/ChenyangSi/FreeU) 以获取已知适用于不同管道(如
稳定扩散 v1、v2 和稳定扩散 XL)组合的值。
Args:
s1 (`float`):
第一阶段的缩放因子,用于减轻跳过特征的贡献,以缓解增强去噪过程中的
“过平滑效应”。
s2 (`float`):
第二阶段的缩放因子,用于减轻跳过特征的贡献,以缓解增强去噪过程中的
“过平滑效应”。
b1 (`float`): 第一阶段的缩放因子,用于放大骨干特征的贡献。
b2 (`float`): 第二阶段的缩放因子,用于放大骨干特征的贡献。
"""
# 检查当前对象是否具有 `unet` 属性
if not hasattr(self, "unet"):
# 如果没有,则抛出值错误
raise ValueError("The pipeline must have `unet` for using FreeU.")
# 调用 UNet 的方法以启用 FreeU,传递缩放因子
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
# 禁用 FreeU 机制
def disable_freeu(self):
"""禁用 FreeU 机制(如果已启用)。"""
# 调用 UNet 的方法以禁用 FreeU
self.unet.disable_freeu()
# 定义融合 QKV 投影的方法,默认启用 UNet 和 VAE
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
启用融合 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)被融合。
对于交叉注意力模块,键和值投影矩阵被融合。
<Tip warning={true}>
此 API 为 🧪 实验性。
</Tip>
参数:
unet (`bool`, 默认值为 `True`): 是否在 UNet 上应用融合。
vae (`bool`, 默认值为 `True`): 是否在 VAE 上应用融合。
"""
# 初始化 UNet 和 VAE 的融合状态为 False
self.fusing_unet = False
self.fusing_vae = False
# 如果启用 UNet 融合
if unet:
# 设置 UNet 融合状态为 True
self.fusing_unet = True
# 调用 UNet 的 QKV 融合方法
self.unet.fuse_qkv_projections()
# 设置 UNet 的注意力处理器为融合版本
self.unet.set_attn_processor(FusedAttnProcessor2_0())
# 如果启用 VAE 融合
if vae:
# 检查 VAE 是否为 AutoencoderKL 类型
if not isinstance(self.vae, AutoencoderKL):
# 抛出异常提示不支持的 VAE 类型
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
# 设置 VAE 融合状态为 True
self.fusing_vae = True
# 调用 VAE 的 QKV 融合方法
self.vae.fuse_qkv_projections()
# 设置 VAE 的注意力处理器为融合版本
self.vae.set_attn_processor(FusedAttnProcessor2_0())
# 定义取消 QKV 投影融合的方法,默认启用 UNet 和 VAE
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""如果启用了 QKV 投影融合,则禁用它。
<Tip warning={true}>
此 API 为 🧪 实验性。
</Tip>
参数:
unet (`bool`, 默认值为 `True`): 是否在 UNet 上应用融合。
vae (`bool`, 默认值为 `True`): 是否在 VAE 上应用融合。
"""
# 如果启用 UNet 解融合
if unet:
# 检查 UNet 是否已经融合
if not self.fusing_unet:
# 如果没有融合,记录警告信息
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
# 调用 UNet 的解融合方法
self.unet.unfuse_qkv_projections()
# 设置 UNet 融合状态为 False
self.fusing_unet = False
# 如果启用 VAE 解融合
if vae:
# 检查 VAE 是否已经融合
if not self.fusing_vae:
# 如果没有融合,记录警告信息
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
# 调用 VAE 的解融合方法
self.vae.unfuse_qkv_projections()
# 设置 VAE 融合状态为 False
self.fusing_vae = False
.\diffusers\pipelines\pixart_alpha\pipeline_pixart_alpha.py
# 版权声明,2024年PixArt-Alpha团队与HuggingFace团队版权所有
#
# 根据Apache许可证第2.0版(“许可证”)授权;
# 您只能在遵循许可证的情况下使用此文件。
# 您可以在以下网址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是按“原样”提供的,
# 不附有任何形式的明示或暗示的保证或条件。
# 有关许可证的具体权限和限制,请参阅许可证。
import html # 导入html库,用于处理HTML内容
import inspect # 导入inspect库,用于获取对象的信息
import re # 导入re库,用于正则表达式处理
import urllib.parse as ul # 导入urllib.parse库并简化命名为ul,用于解析URL
from typing import Callable, List, Optional, Tuple, Union # 导入类型提示相关的工具
import torch # 导入PyTorch库,用于深度学习
from transformers import T5EncoderModel, T5Tokenizer # 从transformers库导入T5编码模型和分词器
from ...image_processor import PixArtImageProcessor # 从上级模块导入PixArt图像处理器
from ...models import AutoencoderKL, PixArtTransformer2DModel # 从上级模块导入模型类
from ...schedulers import DPMSolverMultistepScheduler # 从上级模块导入调度器类
from ...utils import ( # 从上级模块导入多个工具函数和常量
BACKENDS_MAPPING,
deprecate,
is_bs4_available,
is_ftfy_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor # 从工具模块导入随机张量生成函数
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput # 从上级模块导入扩散管道和图像输出类
logger = logging.get_logger(__name__) # 创建一个logger实例以记录日志,使用当前模块名称
if is_bs4_available(): # 检查BeautifulSoup库是否可用
from bs4 import BeautifulSoup # 如果可用,则导入BeautifulSoup类
if is_ftfy_available(): # 检查ftfy库是否可用
import ftfy # 如果可用,则导入ftfy库
EXAMPLE_DOC_STRING = """ # 定义示例文档字符串
Examples: # 示例部分
```py # Python代码块开始
>>> import torch # 导入torch库
>>> from diffusers import PixArtAlphaPipeline # 从diffusers模块导入PixArtAlphaPipeline类
>>> # 你可以用"PixArt-alpha/PixArt-XL-2-512x512"替换检查点ID。
>>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) # 创建管道实例,加载预训练模型
>>> # 启用内存优化。
>>> pipe.enable_model_cpu_offload() # 启用模型的CPU卸载以优化内存使用
>>> prompt = "A small cactus with a happy face in the Sahara desert." # 定义生成图像的提示
>>> image = pipe(prompt).images[0] # 使用提示生成图像并获取第一个图像
```py # Python代码块结束
"""
ASPECT_RATIO_1024_BIN = { # 定义一个字典以存储不同长宽比下的图像尺寸
"0.25": [512.0, 2048.0], # 长宽比为0.25时的尺寸
"0.28": [512.0, 1856.0], # 长宽比为0.28时的尺寸
"0.32": [576.0, 1792.0], # 长宽比为0.32时的尺寸
"0.33": [576.0, 1728.0], # 长宽比为0.33时的尺寸
"0.35": [576.0, 1664.0], # 长宽比为0.35时的尺寸
"0.4": [640.0, 1600.0], # 长宽比为0.4时的尺寸
"0.42": [640.0, 1536.0], # 长宽比为0.42时的尺寸
"0.48": [704.0, 1472.0], # 长宽比为0.48时的尺寸
"0.5": [704.0, 1408.0], # 长宽比为0.5时的尺寸
"0.52": [704.0, 1344.0], # 长宽比为0.52时的尺寸
"0.57": [768.0, 1344.0], # 长宽比为0.57时的尺寸
"0.6": [768.0, 1280.0], # 长宽比为0.6时的尺寸
"0.68": [832.0, 1216.0], # 长宽比为0.68时的尺寸
"0.72": [832.0, 1152.0], # 长宽比为0.72时的尺寸
"0.78": [896.0, 1152.0], # 长宽比为0.78时的尺寸
"0.82": [896.0, 1088.0], # 长宽比为0.82时的尺寸
"0.88": [960.0, 1088.0], # 长宽比为0.88时的尺寸
"0.94": [960.0, 1024.0], # 长宽比为0.94时的尺寸
"1.0": [1024.0, 1024.0], # 长宽比为1.0时的尺寸
"1.07": [1024.0, 960.0], # 长宽比为1.07时的尺寸
"1.13": [1088.0, 960.0], # 长宽比为1.13时的尺寸
"1.21": [1088.0, 896.0], # 长宽比为1.21时的尺寸
"1.29": [1152.0, 896.0], # 长宽比为1.29时的尺寸
"1.38": [1152.0, 832.0], # 长宽比为1.38时的尺寸
"1.46": [1216.0, 832.0], # 长宽比为1.46时的尺寸
"1.67": [1280.0, 768.0], # 长宽比为1.67时的尺寸
"1.75": [1344.0, 768.0], # 长宽比为1.75时的尺寸
"2.0": [1408.0, 704.0], # 长宽比为2.0时的尺寸
"2.09": [1472.0, 704.0], # 长宽比为2.09时的尺寸
"2.4": [1536.0, 640.0], # 长宽比为2.4时的尺寸
"2.5": [1600.0, 640.0], # 长宽比为2.5时的尺寸
"3.0": [1728.0, 576.0], # 长宽比为3.0时的尺寸
"4.0": [2048.0, 512.0], # 长宽比为4.0时的尺寸
}
ASPECT_RATIO_512_BIN = { # 定义一个字典以存储512宽度下的不同长宽比图像尺寸
# 定义一个字典的条目,键为字符串类型的数值,值为包含两个浮点数的列表
"0.25": [256.0, 1024.0], # 键 "0.25" 对应的值是一个列表,包含 256.0 和 1024.0
"0.28": [256.0, 928.0], # 键 "0.28" 对应的值是一个列表,包含 256.0 和 928.0
"0.32": [288.0, 896.0], # 键 "0.32" 对应的值是一个列表,包含 288.0 和 896.0
"0.33": [288.0, 864.0], # 键 "0.33" 对应的值是一个列表,包含 288.0 和 864.0
"0.35": [288.0, 832.0], # 键 "0.35" 对应的值是一个列表,包含 288.0 和 832.0
"0.4": [320.0, 800.0], # 键 "0.4" 对应的值是一个列表,包含 320.0 和 800.0
"0.42": [320.0, 768.0], # 键 "0.42" 对应的值是一个列表,包含 320.0 和 768.0
"0.48": [352.0, 736.0], # 键 "0.48" 对应的值是一个列表,包含 352.0 和 736.0
"0.5": [352.0, 704.0], # 键 "0.5" 对应的值是一个列表,包含 352.0 和 704.0
"0.52": [352.0, 672.0], # 键 "0.52" 对应的值是一个列表,包含 352.0 和 672.0
"0.57": [384.0, 672.0], # 键 "0.57" 对应的值是一个列表,包含 384.0 和 672.0
"0.6": [384.0, 640.0], # 键 "0.6" 对应的值是一个列表,包含 384.0 和 640.0
"0.68": [416.0, 608.0], # 键 "0.68" 对应的值是一个列表,包含 416.0 和 608.0
"0.72": [416.0, 576.0], # 键 "0.72" 对应的值是一个列表,包含 416.0 和 576.0
"0.78": [448.0, 576.0], # 键 "0.78" 对应的值是一个列表,包含 448.0 和 576.0
"0.82": [448.0, 544.0], # 键 "0.82" 对应的值是一个列表,包含 448.0 和 544.0
"0.88": [480.0, 544.0], # 键 "0.88" 对应的值是一个列表,包含 480.0 和 544.0
"0.94": [480.0, 512.0], # 键 "0.94" 对应的值是一个列表,包含 480.0 和 512.0
"1.0": [512.0, 512.0], # 键 "1.0" 对应的值是一个列表,包含 512.0 和 512.0
"1.07": [512.0, 480.0], # 键 "1.07" 对应的值是一个列表,包含 512.0 和 480.0
"1.13": [544.0, 480.0], # 键 "1.13" 对应的值是一个列表,包含 544.0 和 480.0
"1.21": [544.0, 448.0], # 键 "1.21" 对应的值是一个列表,包含 544.0 和 448.0
"1.29": [576.0, 448.0], # 键 "1.29" 对应的值是一个列表,包含 576.0 和 448.0
"1.38": [576.0, 416.0], # 键 "1.38" 对应的值是一个列表,包含 576.0 和 416.0
"1.46": [608.0, 416.0], # 键 "1.46" 对应的值是一个列表,包含 608.0 和 416.0
"1.67": [640.0, 384.0], # 键 "1.67" 对应的值是一个列表,包含 640.0 和 384.0
"1.75": [672.0, 384.0], # 键 "1.75" 对应的值是一个列表,包含 672.0 和 384.0
"2.0": [704.0, 352.0], # 键 "2.0" 对应的值是一个列表,包含 704.0 和 352.0
"2.09": [736.0, 352.0], # 键 "2.09" 对应的值是一个列表,包含 736.0 和 352.0
"2.4": [768.0, 320.0], # 键 "2.4" 对应的值是一个列表,包含 768.0 和 320.0
"2.5": [800.0, 320.0], # 键 "2.5" 对应的值是一个列表,包含 800.0 和 320.0
"3.0": [864.0, 288.0], # 键 "3.0" 对应的值是一个列表,包含 864.0 和 288.0
"4.0": [1024.0, 256.0], # 键 "4.0" 对应的值是一个列表,包含 1024.0 和 256.0
# 定义一个常量字典,表示不同宽高比对应的二进制值
ASPECT_RATIO_256_BIN = {
# 键为宽高比,值为对应的宽度和高度
"0.25": [128.0, 512.0],
"0.28": [128.0, 464.0],
"0.32": [144.0, 448.0],
"0.33": [144.0, 432.0],
"0.35": [144.0, 416.0],
"0.4": [160.0, 400.0],
"0.42": [160.0, 384.0],
"0.48": [176.0, 368.0],
"0.5": [176.0, 352.0],
"0.52": [176.0, 336.0],
"0.57": [192.0, 336.0],
"0.6": [192.0, 320.0],
"0.68": [208.0, 304.0],
"0.72": [208.0, 288.0],
"0.78": [224.0, 288.0],
"0.82": [224.0, 272.0],
"0.88": [240.0, 272.0],
"0.94": [240.0, 256.0],
"1.0": [256.0, 256.0],
"1.07": [256.0, 240.0],
"1.13": [272.0, 240.0],
"1.21": [272.0, 224.0],
"1.29": [288.0, 224.0],
"1.38": [288.0, 208.0],
"1.46": [304.0, 208.0],
"1.67": [320.0, 192.0],
"1.75": [336.0, 192.0],
"2.0": [352.0, 176.0],
"2.09": [368.0, 176.0],
"2.4": [384.0, 160.0],
"2.5": [400.0, 160.0],
"3.0": [432.0, 144.0],
"4.0": [512.0, 128.0],
}
# 定义一个函数,用于从调度器中检索时间步长
def retrieve_timesteps(
# 调度器对象
scheduler,
# 推断步骤数,可选
num_inference_steps: Optional[int] = None,
# 指定设备,可选
device: Optional[Union[str, torch.device]] = None,
# 自定义时间步长,可选
timesteps: Optional[List[int]] = None,
# 自定义sigma值,可选
sigmas: Optional[List[float]] = None,
# 其他可选参数
**kwargs,
):
"""
调用调度器的 `set_timesteps` 方法,并在调用后从调度器中检索时间步长。处理自定义时间步长。
任何kwargs将传递给 `scheduler.set_timesteps`。
参数:
scheduler (`SchedulerMixin`):
用于获取时间步长的调度器。
num_inference_steps (`int`):
生成样本时使用的扩散步骤数。如果使用,`timesteps` 必须为 `None`。
device (`str` 或 `torch.device`, *可选*):
时间步长移动到的设备。如果为 `None`,则不移动时间步长。
timesteps (`List[int]`, *可选*):
自定义时间步长,用于覆盖调度器的时间步长间距策略。如果传递 `timesteps`,`num_inference_steps` 和 `sigmas` 必须为 `None`。
sigmas (`List[float]`, *可选*):
自定义sigma值,用于覆盖调度器的时间步长间距策略。如果传递 `sigmas`,`num_inference_steps` 和 `timesteps` 必须为 `None`。
返回:
`Tuple[torch.Tensor, int]`: 一个元组,第一个元素是调度器的时间步长调度,第二个元素是推断步骤数。
"""
# 检查是否同时传递了时间步长和sigma值,抛出错误
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
# 检查 timesteps 是否为 None,如果不为 None,表示需要处理时间步
if timesteps is not None:
# 检查当前调度器的 set_timesteps 方法是否接受 timesteps 参数
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
# 如果不接受 timesteps 参数,抛出错误
if not accepts_timesteps:
raise ValueError(
# 报告调度器类不支持自定义时间步调度的错误
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
# 调用调度器的 set_timesteps 方法,传入 timesteps 和其他参数
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
# 从调度器获取当前的时间步
timesteps = scheduler.timesteps
# 计算推理步骤的数量
num_inference_steps = len(timesteps)
# 如果 timesteps 为 None,检查 sigmas 是否不为 None
elif sigmas is not None:
# 检查当前调度器的 set_timesteps 方法是否接受 sigmas 参数
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
# 如果不接受 sigmas 参数,抛出错误
if not accept_sigmas:
raise ValueError(
# 报告调度器类不支持自定义 sigma 调度的错误
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
# 调用调度器的 set_timesteps 方法,传入 sigmas 和其他参数
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
# 从调度器获取当前的时间步
timesteps = scheduler.timesteps
# 计算推理步骤的数量
num_inference_steps = len(timesteps)
# 如果两者都为 None,使用默认推理步骤数设置调度器
else:
# 调用调度器的 set_timesteps 方法,传入推理步骤数和其他参数
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
# 从调度器获取当前的时间步
timesteps = scheduler.timesteps
# 返回时间步和推理步骤的数量
return timesteps, num_inference_steps
# 定义一个名为 PixArtAlphaPipeline 的类,继承自 DiffusionPipeline
class PixArtAlphaPipeline(DiffusionPipeline):
r"""
用于文本到图像生成的 PixArt-Alpha 管道。
此模型继承自 [`DiffusionPipeline`]。有关库为所有管道实现的通用方法(例如下载或保存、在特定设备上运行等),请查看超类文档。
参数:
vae ([`AutoencoderKL`]):
变分自编码器(VAE)模型,用于将图像编码和解码为潜在表示。
text_encoder ([`T5EncoderModel`]):
冻结的文本编码器。PixArt-Alpha 使用
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel),具体为
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) 变体。
tokenizer (`T5Tokenizer`):
类的标记器
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer)。
transformer ([`PixArtTransformer2DModel`]):
一个文本条件的 `PixArtTransformer2DModel`,用于去噪编码的图像潜在。
scheduler ([`SchedulerMixin`]):
一个调度器,用于与 `transformer` 结合使用,以去噪编码的图像潜在。
"""
# 定义一个正则表达式,用于匹配不良标点符号
bad_punct_regex = re.compile(
r"["
+ "#®•©™&@·º½¾¿¡§~"
+ r"\)"
+ r"\("
+ r"\]"
+ r"\["
+ r"\}"
+ r"\{"
+ r"\|"
+ "\\"
+ r"\/"
+ r"\*"
+ r"]{1,}"
) # noqa
# 可选组件的列表,包含 tokenizer 和 text_encoder
_optional_components = ["tokenizer", "text_encoder"]
# 定义模型 CPU 卸载的顺序
model_cpu_offload_seq = "text_encoder->transformer->vae"
# 初始化方法,定义类的构造函数
def __init__(
self,
tokenizer: T5Tokenizer, # 输入的标记器
text_encoder: T5EncoderModel, # 输入的文本编码器
vae: AutoencoderKL, # 输入的变分自编码器模型
transformer: PixArtTransformer2DModel, # 输入的 PixArt 转换器模型
scheduler: DPMSolverMultistepScheduler, # 输入的调度器
):
super().__init__() # 调用父类的构造函数
# 注册模块,包括 tokenizer、text_encoder、vae、transformer 和 scheduler
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
# 根据 VAE 的配置计算缩放因子
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# 创建图像处理器实例,使用计算得到的缩放因子
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# 从 diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt 中改编而来
def encode_prompt(
self,
prompt: Union[str, List[str]], # 输入的提示,可以是字符串或字符串列表
do_classifier_free_guidance: bool = True, # 是否使用无分类器自由引导
negative_prompt: str = "", # 可选的负面提示
num_images_per_prompt: int = 1, # 每个提示生成的图像数量
device: Optional[torch.device] = None, # 设备参数,默认是 None
prompt_embeds: Optional[torch.Tensor] = None, # 提示的嵌入向量,默认是 None
negative_prompt_embeds: Optional[torch.Tensor] = None, # 负面提示的嵌入向量,默认是 None
prompt_attention_mask: Optional[torch.Tensor] = None, # 提示的注意力掩码,默认是 None
negative_prompt_attention_mask: Optional[torch.Tensor] = None, # 负面提示的注意力掩码,默认是 None
clean_caption: bool = False, # 是否清理标题,默认是 False
max_sequence_length: int = 120, # 最大序列长度,默认是 120
**kwargs, # 其他关键字参数
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制而来
def prepare_extra_step_kwargs(self, generator, eta):
# 为调度器步骤准备额外的参数,因为并非所有调度器都有相同的签名
# eta(η)仅用于 DDIMScheduler,其他调度器将忽略它。
# eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
# 并且应在 [0, 1] 之间
# 检查调度器步骤是否接受 eta 参数
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 创建一个空字典以存储额外参数
extra_step_kwargs = {}
# 如果接受 eta,将其添加到额外参数字典中
if accepts_eta:
extra_step_kwargs["eta"] = eta
# 检查调度器步骤是否接受 generator 参数
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 如果接受 generator,将其添加到额外参数字典中
if accepts_generator:
extra_step_kwargs["generator"] = generator
# 返回包含额外参数的字典
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
# 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing 复制而来
def _text_preprocessing(self, text, clean_caption=False):
# 如果需要清理标题但 bs4 不可用,则记录警告并将 clean_caption 设置为 False
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
# 如果需要清理标题但 ftfy 不可用,则记录警告并将 clean_caption 设置为 False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
# 如果 text 不是元组或列表,则将其转换为列表
if not isinstance(text, (tuple, list)):
text = [text]
# 定义处理文本的内部函数
def process(text: str):
# 如果需要清理标题,调用清理方法
if clean_caption:
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
# 否则将文本转为小写并去除空白
text = text.lower().strip()
return text
# 返回处理后的文本列表
return [process(t) for t in text]
# 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption 复制而来
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 复制而来
# 准备潜在向量,接受多种参数以控制生成过程
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
# 定义潜在向量的形状,基于输入参数计算
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor, # 通过 VAE 缩放因子调整高度
int(width) // self.vae_scale_factor, # 通过 VAE 缩放因子调整宽度
)
# 检查生成器是否是列表且其长度与批次大小匹配
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
# 抛出错误,提示生成器的长度与请求的批次大小不匹配
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# 如果没有提供潜在向量,则生成随机潜在向量
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 将提供的潜在向量转换到指定的设备上
latents = latents.to(device)
# 根据调度器要求的标准差缩放初始噪声
latents = latents * self.scheduler.init_noise_sigma
# 返回准备好的潜在向量
return latents
# 禁用梯度计算以减少内存使用
@torch.no_grad()
# 替换示例文档字符串
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
# 定义调用方法的输入参数及其默认值
prompt: Union[str, List[str]] = None, # 正向提示
negative_prompt: str = "", # 负向提示
num_inference_steps: int = 20, # 推理步骤的数量
timesteps: List[int] = None, # 时间步长列表
sigmas: List[float] = None, # 噪声标准差列表
guidance_scale: float = 4.5, # 引导缩放因子
num_images_per_prompt: Optional[int] = 1, # 每个提示生成的图像数量
height: Optional[int] = None, # 输出图像的高度
width: Optional[int] = None, # 输出图像的宽度
eta: float = 0.0, # 附加参数
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, # 生成器
latents: Optional[torch.Tensor] = None, # 潜在向量
prompt_embeds: Optional[torch.Tensor] = None, # 正向提示嵌入
prompt_attention_mask: Optional[torch.Tensor] = None, # 正向提示注意力掩码
negative_prompt_embeds: Optional[torch.Tensor] = None, # 负向提示嵌入
negative_prompt_attention_mask: Optional[torch.Tensor] = None, # 负向提示注意力掩码
output_type: Optional[str] = "pil", # 输出类型,默认为 PIL 图像
return_dict: bool = True, # 是否返回字典格式的结果
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, # 回调函数
callback_steps: int = 1, # 每隔多少步调用一次回调
clean_caption: bool = True, # 是否清理标题
use_resolution_binning: bool = True, # 是否使用分辨率分箱
max_sequence_length: int = 120, # 最大序列长度
**kwargs, # 其他可选参数
.\diffusers\pipelines\pixart_alpha\pipeline_pixart_sigma.py
# 版权声明,标明版权所有者和团队
# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
#
# 按照 Apache 2.0 许可证许可使用本文件
# Licensed under the Apache License, Version 2.0 (the "License");
# 仅可在遵守许可证的情况下使用此文件
# you may not use this file except in compliance with the License.
# 可在以下网址获取许可证副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,本软件按“现状”基础分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何明示或暗示的担保或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以了解特定的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
import html # 导入 html 模块,用于处理 HTML 实体
import inspect # 导入 inspect 模块,用于获取对象的信息
import re # 导入 re 模块,用于正则表达式匹配
import urllib.parse as ul # 导入 urllib.parse 模块并重命名为 ul,用于处理 URL
from typing import Callable, List, Optional, Tuple, Union # 导入类型注解,便于定义类型
import torch # 导入 PyTorch 库,用于深度学习
from transformers import T5EncoderModel, T5Tokenizer # 从 transformers 导入 T5 编码器模型和分词器
from ...image_processor import PixArtImageProcessor # 从相对路径导入 PixArt 图像处理器
from ...models import AutoencoderKL, PixArtTransformer2DModel # 从相对路径导入模型
from ...schedulers import KarrasDiffusionSchedulers # 从相对路径导入 Karras 采样调度器
from ...utils import ( # 从相对路径导入多个工具函数
BACKENDS_MAPPING, # 后端映射
deprecate, # 标记弃用的函数
is_bs4_available, # 检查 BeautifulSoup 是否可用
is_ftfy_available, # 检查 ftfy 是否可用
logging, # 日志记录工具
replace_example_docstring, # 替换示例文档字符串的函数
)
from ...utils.torch_utils import randn_tensor # 从相对路径导入 randn_tensor 函数
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput # 从相对路径导入扩散管道和图像输出
from .pipeline_pixart_alpha import ( # 从相对路径导入多个常量
ASPECT_RATIO_256_BIN, # 256 比例常量
ASPECT_RATIO_512_BIN, # 512 比例常量
ASPECT_RATIO_1024_BIN, # 1024 比例常量
)
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器,禁用 pylint 的无效名称警告
if is_bs4_available(): # 检查 BeautifulSoup 是否可用
from bs4 import BeautifulSoup # 导入 BeautifulSoup 库以解析 HTML
if is_ftfy_available(): # 检查 ftfy 库是否可用
import ftfy # 导入 ftfy 库用于文本修复
# 定义一个字典,用于存储不同宽高比的尺寸
ASPECT_RATIO_2048_BIN = {
"0.25": [1024.0, 4096.0], # 宽高比为 0.25 时的宽高尺寸
"0.26": [1024.0, 3968.0], # 宽高比为 0.26 时的宽高尺寸
"0.27": [1024.0, 3840.0], # 宽高比为 0.27 时的宽高尺寸
"0.28": [1024.0, 3712.0], # 宽高比为 0.28 时的宽高尺寸
"0.32": [1152.0, 3584.0], # 宽高比为 0.32 时的宽高尺寸
"0.33": [1152.0, 3456.0], # 宽高比为 0.33 时的宽高尺寸
"0.35": [1152.0, 3328.0], # 宽高比为 0.35 时的宽高尺寸
"0.4": [1280.0, 3200.0], # 宽高比为 0.4 时的宽高尺寸
"0.42": [1280.0, 3072.0], # 宽高比为 0.42 时的宽高尺寸
"0.48": [1408.0, 2944.0], # 宽高比为 0.48 时的宽高尺寸
"0.5": [1408.0, 2816.0], # 宽高比为 0.5 时的宽高尺寸
"0.52": [1408.0, 2688.0], # 宽高比为 0.52 时的宽高尺寸
"0.57": [1536.0, 2688.0], # 宽高比为 0.57 时的宽高尺寸
"0.6": [1536.0, 2560.0], # 宽高比为 0.6 时的宽高尺寸
"0.68": [1664.0, 2432.0], # 宽高比为 0.68 时的宽高尺寸
"0.72": [1664.0, 2304.0], # 宽高比为 0.72 时的宽高尺寸
"0.78": [1792.0, 2304.0], # 宽高比为 0.78 时的宽高尺寸
"0.82": [1792.0, 2176.0], # 宽高比为 0.82 时的宽高尺寸
"0.88": [1920.0, 2176.0], # 宽高比为 0.88 时的宽高尺寸
"0.94": [1920.0, 2048.0], # 宽高比为 0.94 时的宽高尺寸
"1.0": [2048.0, 2048.0], # 宽高比为 1.0 时的宽高尺寸
"1.07": [2048.0, 1920.0], # 宽高比为 1.07 时的宽高尺寸
"1.13": [2176.0, 1920.0], # 宽高比为 1.13 时的宽高尺寸
"1.21": [2176.0, 1792.0], # 宽高比为 1.21 时的宽高尺寸
"1.29": [2304.0, 1792.0], # 宽高比为 1.29 时的宽高尺寸
"1.38": [2304.0, 1664.0], # 宽高比为 1.38 时的宽高尺寸
"1.46": [2432.0, 1664.0], # 宽高比为 1.46 时的宽高尺寸
"1.67": [2560.0, 1536.0], # 宽高比为 1.67 时的宽高尺寸
"1.75": [2688.0, 1536.0], # 宽高比为 1.75 时的宽高尺寸
"2.0": [2816.0, 1408.0], # 宽高比为 2.0 时的宽高尺寸
"2.09": [2944.0, 1408.0], # 宽高比为 2.09 时的宽高尺寸
"2.4": [3072.0, 1280.0], # 宽高比为 2.4 时的宽高尺寸
"2.5": [3200.0, 1280.0], # 宽高比为 2.5 时的宽高尺寸
"2.89": [3328.0, 1152.0], # 宽高比为 2.89 时的宽高尺寸
"3.0": [3456.0, 1152.0], # 宽高比为 3.0 时的宽
# 示例代码部分
Examples:
```py
>>> import torch # 导入 PyTorch 库,以便进行张量操作和深度学习
>>> from diffusers import PixArtSigmaPipeline # 从 diffusers 库导入 PixArtSigmaPipeline 类
>>> # 你可以将检查点 ID 替换为 "PixArt-alpha/PixArt-Sigma-XL-2-512-MS"
>>> pipe = PixArtSigmaPipeline.from_pretrained( # 从预训练模型加载 PixArtSigmaPipeline
... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16 # 指定模型路径及张量数据类型为 float16
... )
>>> # 启用内存优化
>>> # pipe.enable_model_cpu_offload() # 可选:启用模型的 CPU 卸载以节省内存
>>> prompt = "A small cactus with a happy face in the Sahara desert." # 设置生成图像的提示文本
>>> image = pipe(prompt).images[0] # 生成图像并提取第一张图像
```py
"""
# 该函数用于从调度器中检索时间步
# 复制自 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler, # 调度器对象,用于获取时间步
num_inference_steps: Optional[int] = None, # 生成样本时使用的扩散步骤数量,默认为 None
device: Optional[Union[str, torch.device]] = None, # 要移动时间步的设备,默认为 None
timesteps: Optional[List[int]] = None, # 自定义时间步,用于覆盖调度器的时间步间距策略,默认为 None
sigmas: Optional[List[float]] = None, # 自定义 sigma,用于覆盖调度器的时间步间距策略,默认为 None
**kwargs, # 其他可选参数,传递给调度器的 set_timesteps 方法
):
"""
调用调度器的 `set_timesteps` 方法并在调用后从调度器检索时间步。处理
自定义时间步。任何 kwargs 将被传递给 `scheduler.set_timesteps`。
Args:
scheduler (`SchedulerMixin`): 需要从中获取时间步的调度器。
num_inference_steps (`int`): 生成样本时使用的扩散步骤数量。如果使用,`timesteps` 必须为 `None`。
device (`str` or `torch.device`, *optional*): 要移动时间步的设备。如果为 `None`,则不移动时间步。
timesteps (`List[int]`, *optional*): 自定义时间步,用于覆盖调度器的时间步间距策略。如果传递了 `timesteps`,则 `num_inference_steps` 和 `sigmas` 必须为 `None`。
sigmas (`List[float]`, *optional*): 自定义 sigma,用于覆盖调度器的时间步间距策略。如果传递了 `sigmas`,则 `num_inference_steps` 和 `timesteps` 必须为 `None`。
Returns:
`Tuple[torch.Tensor, int]`: 一个元组,其中第一个元素是调度器的时间步调度,第二个元素是推理步骤的数量。
"""
# 检查是否同时传递了自定义时间步和 sigma
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
# 如果传递了自定义时间步
if timesteps is not None:
# 检查调度器的 set_timesteps 方法是否接受自定义时间步
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps: # 如果不支持,抛出错误
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
# 设置自定义时间步并移动到指定设备
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
# 从调度器获取设置后的时间步
timesteps = scheduler.timesteps
# 计算推理步骤的数量
num_inference_steps = len(timesteps)
# 如果传递了自定义 sigma
elif sigmas is not None:
# 检查调度器的 set_timesteps 方法是否接受自定义 sigma
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas: # 如果不支持,抛出错误
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
# 设置自定义 sigma 并移动到指定设备
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
# 从调度器获取设置后的时间步
timesteps = scheduler.timesteps
# 计算推理步骤的数量
num_inference_steps = len(timesteps)
else: # 如果条件不满足,则执行下面的代码
# 设置调度器的时间步数,传入推理步数和设备参数,可能还包含其他关键字参数
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
# 获取当前调度器的时间步数
timesteps = scheduler.timesteps
# 返回时间步数和推理步数
return timesteps, num_inference_steps
# 定义一个名为 PixArtSigmaPipeline 的类,继承自 DiffusionPipeline 类
class PixArtSigmaPipeline(DiffusionPipeline):
r"""
使用 PixArt-Sigma 进行文本到图像生成的管道。
"""
# 编译一个正则表达式,用于匹配不良标点符号
bad_punct_regex = re.compile(
r"["
+ "#®•©™&@·º½¾¿¡§~"
+ r"\)"
+ r"\("
+ r"\]"
+ r"\["
+ r"\}"
+ r"\{"
+ r"\|"
+ "\\"
+ r"\/"
+ r"\*"
+ r"]{1,}"
) # noqa
# 定义可选组件的名称列表
_optional_components = ["tokenizer", "text_encoder"]
# 定义模型的 CPU 卸载顺序
model_cpu_offload_seq = "text_encoder->transformer->vae"
# 初始化方法,接受多个组件作为参数
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: PixArtTransformer2DModel,
scheduler: KarrasDiffusionSchedulers,
):
# 调用父类的初始化方法
super().__init__()
# 注册传入的模块
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
# 计算 VAE 的缩放因子
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# 创建图像处理器,使用计算出的缩放因子
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# 从 PixArtAlphaPipeline 复制的方法,用于编码提示信息,最大序列长度从 120 改为 300
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
clean_caption: bool = False,
max_sequence_length: int = 300,
**kwargs,
# 从 StableDiffusionPipeline 复制的方法,用于准备额外的调度步骤参数
def prepare_extra_step_kwargs(self, generator, eta):
# 为调度器步骤准备额外的关键字参数,因为并不是所有调度器都有相同的签名
# eta (η) 仅在 DDIMScheduler 中使用,其他调度器将忽略它。
# eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
# 应在 [0, 1] 之间
# 检查调度器是否接受 eta 参数
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
# 如果接受 eta,则将其添加到额外参数中
if accepts_eta:
extra_step_kwargs["eta"] = eta
# 检查调度器是否接受 generator 参数
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 如果接受 generator,则将其添加到额外参数中
if accepts_generator:
extra_step_kwargs["generator"] = generator
# 返回准备好的额外参数字典
return extra_step_kwargs
# 从 PixArtAlphaPipeline 复制的方法,用于检查输入参数
# 定义检查输入参数的函数
def check_inputs(
self,
prompt, # 提示文本
height, # 图像高度
width, # 图像宽度
negative_prompt, # 负面提示文本
callback_steps, # 回调步数
prompt_embeds=None, # 提示嵌入(可选)
negative_prompt_embeds=None, # 负面提示嵌入(可选)
prompt_attention_mask=None, # 提示注意力掩码(可选)
negative_prompt_attention_mask=None, # 负面提示注意力掩码(可选)
# 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing 复制
def _text_preprocessing(self, text, clean_caption=False): # 文本预处理函数,带有清理标志
if clean_caption and not is_bs4_available(): # 检查是否需要清理且 bs4 库不可用
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) # 记录警告
logger.warning("Setting `clean_caption` to False...") # 记录清理被设置为 False 的警告
clean_caption = False # 设置清理标志为 False
if clean_caption and not is_ftfy_available(): # 检查是否需要清理且 ftfy 库不可用
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) # 记录警告
logger.warning("Setting `clean_caption` to False...") # 记录清理被设置为 False 的警告
clean_caption = False # 设置清理标志为 False
if not isinstance(text, (tuple, list)): # 检查文本类型是否为元组或列表
text = [text] # 将文本包装为列表
def process(text: str): # 定义处理文本的内部函数
if clean_caption: # 如果需要清理文本
text = self._clean_caption(text) # 清理文本
text = self._clean_caption(text) # 再次清理文本
else: # 如果不需要清理文本
text = text.lower().strip() # 将文本转为小写并去除空格
return text # 返回处理后的文本
return [process(t) for t in text] # 对每个文本项进行处理并返回结果列表
# 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption 复制
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 复制
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): # 准备潜在变量的函数
shape = ( # 定义潜在变量的形状
batch_size, # 批量大小
num_channels_latents, # 潜在变量的通道数
int(height) // self.vae_scale_factor, # 缩放后的高度
int(width) // self.vae_scale_factor, # 缩放后的宽度
)
if isinstance(generator, list) and len(generator) != batch_size: # 检查生成器列表的长度是否与批量大小匹配
raise ValueError( # 抛出值错误
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" # 提示生成器长度与批量大小不匹配
f" size of {batch_size}. Make sure the batch size matches the length of the generators." # 提示用户检查匹配
)
if latents is None: # 如果未提供潜在变量
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # 生成随机潜在变量
else: # 如果提供了潜在变量
latents = latents.to(device) # 将潜在变量移动到指定设备
# 根据调度器要求的标准差缩放初始噪声
latents = latents * self.scheduler.init_noise_sigma # 缩放潜在变量
return latents # 返回潜在变量
@torch.no_grad() # 在无梯度模式下运行
@replace_example_docstring(EXAMPLE_DOC_STRING) # 替换示例文档字符串
# 定义一个可调用的类方法,接受多种参数
def __call__(
# 用户输入的提示,可以是字符串或字符串列表
self,
prompt: Union[str, List[str]] = None,
# 用户的负面提示,默认为空字符串
negative_prompt: str = "",
# 推理步骤的数量,默认为20
num_inference_steps: int = 20,
# 时间步列表,默认为None
timesteps: List[int] = None,
# Sigma值列表,默认为None
sigmas: List[float] = None,
# 指导比例,默认为4.5
guidance_scale: float = 4.5,
# 每个提示生成的图像数量,默认为1
num_images_per_prompt: Optional[int] = 1,
# 输出图像的高度,默认为None
height: Optional[int] = None,
# 输出图像的宽度,默认为None
width: Optional[int] = None,
# Eta值,默认为0.0
eta: float = 0.0,
# 随机数生成器,可为单个或多个torch.Generator,默认为None
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
# 潜在向量,默认为None
latents: Optional[torch.Tensor] = None,
# 提示的嵌入表示,默认为None
prompt_embeds: Optional[torch.Tensor] = None,
# 提示的注意力掩码,默认为None
prompt_attention_mask: Optional[torch.Tensor] = None,
# 负面提示的嵌入表示,默认为None
negative_prompt_embeds: Optional[torch.Tensor] = None,
# 负面提示的注意力掩码,默认为None
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
# 输出类型,默认为'pil'
output_type: Optional[str] = "pil",
# 是否返回字典格式,默认为True
return_dict: bool = True,
# 回调函数,接受三个参数,默认为None
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
# 回调步骤的数量,默认为1
callback_steps: int = 1,
# 是否清理标题,默认为True
clean_caption: bool = True,
# 是否使用分辨率分箱,默认为True
use_resolution_binning: bool = True,
# 最大序列长度,默认为300
max_sequence_length: int = 300,
# 其他关键字参数
**kwargs,
.\diffusers\pipelines\pixart_alpha\__init__.py
# 导入类型检查相关的类型
from typing import TYPE_CHECKING
# 从父目录导入所需的工具和依赖
from ...utils import (
DIFFUSERS_SLOW_IMPORT, # 慢导入的标志
OptionalDependencyNotAvailable, # 可选依赖不可用的异常
_LazyModule, # 懒加载模块的工具
get_objects_from_module, # 从模块中获取对象的工具
is_torch_available, # 检查 PyTorch 是否可用
is_transformers_available, # 检查 Transformers 是否可用
)
# 初始化一个空字典以存储虚拟对象
_dummy_objects = {}
# 初始化一个字典以存储模块的导入结构
_import_structure = {}
# 尝试检查依赖是否可用
try:
# 如果 Transformers 和 Torch 都不可用,抛出异常
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从工具中导入虚拟对象以避免错误
from ...utils import dummy_torch_and_transformers_objects # noqa F403
# 更新虚拟对象字典
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
# 如果依赖可用,更新导入结构
else:
# 在导入结构中添加 PixArtAlphaPipeline
_import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
# 在导入结构中添加 PixArtSigmaPipeline
_import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"]
# 如果在类型检查或慢导入模式下,进行以下操作
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
# 检查 Transformers 和 Torch 是否可用
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从工具中导入虚拟对象以避免错误
from ...utils.dummy_torch_and_transformers_objects import *
else:
# 从 PixArtAlphaPipeline 模块中导入所需的对象
from .pipeline_pixart_alpha import (
ASPECT_RATIO_256_BIN, # 256 比例的常量
ASPECT_RATIO_512_BIN, # 512 比例的常量
ASPECT_RATIO_1024_BIN, # 1024 比例的常量
PixArtAlphaPipeline, # PixArtAlphaPipeline 类
)
# 从 PixArtSigmaPipeline 模块中导入所需的对象
from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline
# 如果不是类型检查或慢导入模式,执行以下操作
else:
# 导入 sys 模块以进行模块操作
import sys
# 用懒加载模块替换当前模块
sys.modules[__name__] = _LazyModule(
__name__, # 当前模块的名称
globals()["__file__"], # 当前模块的文件路径
_import_structure, # 模块的导入结构
module_spec=__spec__, # 模块的规范
)
# 遍历虚拟对象字典并设置当前模块的属性
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value) # 设置模块的属性
.\diffusers\pipelines\semantic_stable_diffusion\pipeline_output.py
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入 List、Optional 和 Union 类型提示
from typing import List, Optional, Union
# 导入 numpy 库并简写为 np
import numpy as np
# 导入 PIL.Image 模块,用于处理图像
import PIL.Image
# 从上层目录的 utils 模块导入 BaseOutput 基类
from ...utils import BaseOutput
# 定义一个数据类 SemanticStableDiffusionPipelineOutput,继承自 BaseOutput
@dataclass
class SemanticStableDiffusionPipelineOutput(BaseOutput):
"""
Stable Diffusion 流水线的输出类。
参数:
images (`List[PIL.Image.Image]` 或 `np.ndarray`)
包含去噪后 PIL 图像的列表,长度为 `batch_size`,或形状为 `(batch_size, height, width,
num_channels)` 的 NumPy 数组。
nsfw_content_detected (`List[bool]`)
列表,指示相应生成的图像是否包含“非安全内容”(nsfw),
如果无法执行安全检查,则为 `None`。
"""
# 定义 images 属性,可以是 PIL 图像列表或 NumPy 数组
images: Union[List[PIL.Image.Image], np.ndarray]
# 定义 nsfw_content_detected 属性,表示安全检查结果,类型为可选布尔列表
nsfw_content_detected: Optional[List[bool]]
.\diffusers\pipelines\semantic_stable_diffusion\pipeline_semantic_stable_diffusion.py
# 导入 Python 的 inspect 模块,用于获取信息
import inspect
# 从 itertools 模块导入 repeat 函数,用于生成重复元素
from itertools import repeat
# 导入类型提示所需的 Callable, List, Optional, Union
from typing import Callable, List, Optional, Union
# 导入 PyTorch 库
import torch
# 从 transformers 库导入 CLIP 图像处理器、文本模型和分词器
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
# 从自定义模块导入 VAE 图像处理器
from ...image_processor import VaeImageProcessor
# 从自定义模型中导入 AutoencoderKL 和 UNet2DConditionModel
from ...models import AutoencoderKL, UNet2DConditionModel
# 从安全检查器模块导入 StableDiffusionSafetyChecker
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
# 从调度器模块导入 KarrasDiffusionSchedulers
from ...schedulers import KarrasDiffusionSchedulers
# 从工具模块导入弃用和日志记录功能
from ...utils import deprecate, logging
# 从 PyTorch 工具模块导入随机张量生成函数
from ...utils.torch_utils import randn_tensor
# 从管道工具模块导入 DiffusionPipeline 和 StableDiffusionMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
# 从管道输出模块导入 SemanticStableDiffusionPipelineOutput
from .pipeline_output import SemanticStableDiffusionPipelineOutput
# 创建一个日志记录器,记录当前模块的日志
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 定义一个用于语义稳定扩散的管道类,继承自 DiffusionPipeline 和 StableDiffusionMixin
class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
r"""
使用稳定扩散进行文本到图像生成的管道,支持潜在编辑。
此模型继承自 [`DiffusionPipeline`],并基于 [`StableDiffusionPipeline`]。有关所有管道的通用方法的文档,
请查阅超类文档(下载、保存、在特定设备上运行等)。
参数:
vae ([`AutoencoderKL`]):
用于将图像编码和解码为潜在表示的变分自编码器模型。
text_encoder ([`~transformers.CLIPTextModel`]):
冻结的文本编码器([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。
tokenizer ([`~transformers.CLIPTokenizer`]):
用于对文本进行分词的 `CLIPTokenizer`。
unet ([`UNet2DConditionModel`]):
用于去噪编码后的图像潜在表示的 `UNet2DConditionModel`。
scheduler ([`SchedulerMixin`]):
用于与 `unet` 结合以去噪编码图像潜在表示的调度器,可以是
[`DDIMScheduler`], [`LMSDiscreteScheduler`] 或 [`PNDMScheduler`]。
safety_checker ([`Q16SafetyChecker`]):
评估生成图像是否可能被视为冒犯或有害的分类模块。
有关模型潜在危害的更多详细信息,请参阅 [模型卡](https://huggingface.co/runwayml/stable-diffusion-v1-5)。
feature_extractor ([`~transformers.CLIPImageProcessor`]):
用于从生成图像中提取特征的 `CLIPImageProcessor`;用于 `safety_checker` 的输入。
"""
# 定义模型在 CPU 上的卸载顺序
model_cpu_offload_seq = "text_encoder->unet->vae"
# 定义可选组件列表
_optional_components = ["safety_checker", "feature_extractor"]
# 初始化方法,接受多个参数以配置管道
def __init__(
self,
vae: AutoencoderKL, # 接受变分自编码器模型
text_encoder: CLIPTextModel, # 接受文本编码器模型
tokenizer: CLIPTokenizer, # 接受文本分词器
unet: UNet2DConditionModel, # 接受去噪网络模型
scheduler: KarrasDiffusionSchedulers, # 接受调度器
safety_checker: StableDiffusionSafetyChecker, # 接受安全检查器
feature_extractor: CLIPImageProcessor, # 接受图像处理器
requires_safety_checker: bool = True, # 可选参数,指示是否需要安全检查器
):
# 调用父类的初始化方法
super().__init__()
# 如果未提供安全检查器且需要安全检查器,发出警告
if safety_checker is None and requires_safety_checker:
logger.warning(
# 输出警告信息,提示用户需要遵守使用条件
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
# 如果提供了安全检查器但未提供特征提取器,抛出异常
if safety_checker is not None and feature_extractor is None:
raise ValueError(
# 抛出值错误,提示用户需定义特征提取器
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
# 注册多个模块,方便后续调用
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
# 计算 VAE 的缩放因子
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# 创建图像处理器实例,用于后续图像处理
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# 将是否需要安全检查器的配置注册到实例
self.register_to_config(requires_safety_checker=requires_safety_checker)
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 复制的方法
def run_safety_checker(self, image, device, dtype):
# 如果没有安全检查器,初始化 nsfw 概念为 None
if self.safety_checker is None:
has_nsfw_concept = None
else:
# 如果输入是张量,进行后处理转换为 PIL 图像
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
# 如果输入不是张量,将其转换为 PIL 图像
feature_extractor_input = self.image_processor.numpy_to_pil(image)
# 将图像输入特征提取器并转换为适合设备的张量
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
# 运行安全检查器,获取处理后的图像和 nsfw 概念
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
# 返回处理后的图像和 nsfw 概念
return image, has_nsfw_concept
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 复制的方法
# 解码潜在表示的方法
def decode_latents(self, latents):
# 提示用户该方法已被弃用,并将在1.0.0中移除,建议使用新的方法
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
# 调用弃用警告函数,标记该方法为不推荐使用
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
# 将潜在表示根据缩放因子进行缩放
latents = 1 / self.vae.config.scaling_factor * latents
# 解码潜在表示,返回的第一项是解码后的图像
image = self.vae.decode(latents, return_dict=False)[0]
# 对图像进行归一化处理,将值限制在[0, 1]范围内
image = (image / 2 + 0.5).clamp(0, 1)
# 将图像转换为float32格式,兼容bfloat16,且不会造成显著开销
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
# 返回解码后的图像
return image
# 从diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs复制而来
def prepare_extra_step_kwargs(self, generator, eta):
# 为调度器步骤准备额外的参数,因为并非所有调度器具有相同的参数签名
# eta(η)仅用于DDIMScheduler,其他调度器将忽略该参数
# eta对应于DDIM论文中的η,范围应在[0, 1]之间
# 检查调度器的步骤方法是否接受eta参数
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 创建一个字典以存储额外的步骤参数
extra_step_kwargs = {}
# 如果接受eta参数,则将其添加到字典中
if accepts_eta:
extra_step_kwargs["eta"] = eta
# 检查调度器的步骤方法是否接受generator参数
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 如果接受generator参数,则将其添加到字典中
if accepts_generator:
extra_step_kwargs["generator"] = generator
# 返回包含额外参数的字典
return extra_step_kwargs
# 从diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs复制而来
def check_inputs(
# 方法参数包括提示、图像高度、宽度、回调步骤等
prompt,
height,
width,
callback_steps,
# 负面提示、提示嵌入和负面提示嵌入可选参数
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
# 检查高度和宽度是否能被8整除,若不满足则引发错误
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# 检查回调步数是否为正整数,若不满足则引发错误
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# 检查在步骤结束时的张量输入是否有效,若无效则引发错误
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
# 检查是否同时提供了提示和提示嵌入,若是则引发错误
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
# 检查提示和提示嵌入是否都未定义,若是则引发错误
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
# 检查提示的类型是否有效,若无效则引发错误
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 检查负面提示和负面提示嵌入是否同时提供,若是则引发错误
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
# 检查提示嵌入和负面提示嵌入的形状是否一致,若不一致则引发错误
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 中复制
# 准备潜在变量的函数,接收一系列参数
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
# 定义潜在变量的形状,包括批量大小、通道数、高度和宽度
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
# 检查生成器是否为列表且长度与批量大小不匹配,若不匹配则抛出错误
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# 如果潜在变量为 None,则生成随机的潜在变量
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 如果潜在变量不为 None,则将其移动到指定设备
latents = latents.to(device)
# 根据调度器所需的标准差缩放初始噪声
latents = latents * self.scheduler.init_noise_sigma
# 返回处理后的潜在变量
return latents
# 禁用梯度计算的上下文装饰器
@torch.no_grad()
def __call__(
# 接收多个参数,包括提示、图像高度和宽度、推理步骤数等
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
editing_prompt: Optional[Union[str, List[str]]] = None,
editing_prompt_embeddings: Optional[torch.Tensor] = None,
reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
edit_warmup_steps: Optional[Union[int, List[int]]] = 10,
edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
edit_threshold: Optional[Union[float, List[float]]] = 0.9,
edit_momentum_scale: Optional[float] = 0.1,
edit_mom_beta: Optional[float] = 0.4,
edit_weights: Optional[List[float]] = None,
sem_guidance: Optional[List[torch.Tensor]] = None,
.\diffusers\pipelines\semantic_stable_diffusion\__init__.py
# 导入类型检查相关常量
from typing import TYPE_CHECKING
# 从工具模块导入必要的组件
from ...utils import (
DIFFUSERS_SLOW_IMPORT, # 导入慢加载标志
OptionalDependencyNotAvailable, # 导入可选依赖不可用异常
_LazyModule, # 导入懒加载模块
get_objects_from_module, # 导入从模块获取对象的函数
is_torch_available, # 导入检查 PyTorch 是否可用的函数
is_transformers_available, # 导入检查 Transformers 是否可用的函数
)
# 初始化一个空字典用于存储虚拟对象
_dummy_objects = {}
# 初始化一个空字典用于存储导入结构
_import_structure = {}
# 尝试执行依赖检查
try:
# 检查 Transformers 和 PyTorch 是否都可用
if not (is_transformers_available() and is_torch_available()):
# 如果不可用,抛出异常
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 导入虚拟对象以避免实际依赖
from ...utils import dummy_torch_and_transformers_objects # noqa F403
# 更新虚拟对象字典
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
# 如果依赖可用,更新导入结构
_import_structure["pipeline_output"] = ["SemanticStableDiffusionPipelineOutput"]
_import_structure["pipeline_semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
# 根据类型检查或慢加载标志执行进一步的检查
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
# 检查 Transformers 和 PyTorch 是否都可用
if not (is_transformers_available() and is_torch_available()):
# 如果不可用,抛出异常
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从虚拟对象模块导入所有内容
from ...utils.dummy_torch_and_transformers_objects import *
else:
# 如果依赖可用,导入语义稳定扩散管道
from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
# 如果不是类型检查或慢加载
else:
import sys
# 用懒加载模块替代当前模块
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
# 将虚拟对象字典中的每个对象设置到当前模块
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
.\diffusers\pipelines\shap_e\camera.py
# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入数据类装饰器,用于简化类的定义
from dataclasses import dataclass
# 导入元组类型,用于类型注解
from typing import Tuple
# 导入 NumPy 库,进行数值计算
import numpy as np
# 导入 PyTorch 库,进行张量操作
import torch
# 定义一个可微分的投影相机类
@dataclass
class DifferentiableProjectiveCamera:
"""
Implements a batch, differentiable, standard pinhole camera
"""
# 相机的原点,形状为 [batch_size x 3]
origin: torch.Tensor # [batch_size x 3]
# x 轴方向向量,形状为 [batch_size x 3]
x: torch.Tensor # [batch_size x 3]
# y 轴方向向量,形状为 [batch_size x 3]
y: torch.Tensor # [batch_size x 3]
# z 轴方向向量,形状为 [batch_size x 3]
z: torch.Tensor # [batch_size x 3]
# 相机的宽度
width: int
# 相机的高度
height: int
# 水平视场角
x_fov: float
# 垂直视场角
y_fov: float
# 相机的形状信息,元组类型
shape: Tuple[int]
# 初始化后进行验证
def __post_init__(self):
# 验证原点和方向向量的批次大小一致
assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
# 验证每个方向向量的维度为3
assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
# 验证每个张量的维度都是2
assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2
# 返回相机的分辨率
def resolution(self):
# 将宽度和高度转为浮点型张量
return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))
# 返回相机的视场角
def fov(self):
# 将水平和垂直视场角转为浮点型张量
return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))
# 获取图像坐标
def get_image_coords(self) -> torch.Tensor:
"""
:return: coords of shape (width * height, 2)
"""
# 生成像素索引,范围从 0 到 width * height - 1
pixel_indices = torch.arange(self.height * self.width)
# 计算坐标,分离出 x 和 y 组件
coords = torch.stack(
[
pixel_indices % self.width, # x 坐标
torch.div(pixel_indices, self.width, rounding_mode="trunc"), # y 坐标
],
axis=1, # 沿新轴堆叠
)
# 返回坐标张量
return coords
# 计算相机光线
@property
def camera_rays(self):
# 获取批次大小和其他形状信息
batch_size, *inner_shape = self.shape
# 计算内部批次大小
inner_batch_size = int(np.prod(inner_shape))
# 获取图像坐标
coords = self.get_image_coords()
# 将坐标广播到批次大小
coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
# 获取相机光线
rays = self.get_camera_rays(coords)
# 调整光线张量的形状
rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)
# 返回光线张量
return rays
# 获取相机射线的函数,输入为坐标张量,输出为射线张量
def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
# 提取批大小、形状和坐标数量
batch_size, *shape, n_coords = coords.shape
# 确保坐标数量为 2
assert n_coords == 2
# 确保批大小与原点的数量一致
assert batch_size == self.origin.shape[0]
# 将坐标展平,形状变为 (batch_size, -1, 2)
flat = coords.view(batch_size, -1, 2)
# 获取分辨率
res = self.resolution()
# 获取视场角
fov = self.fov()
# 计算归一化坐标,范围从 -1 到 1
fracs = (flat.float() / (res - 1)) * 2 - 1
# 将归一化坐标转换为视场角下的方向
fracs = fracs * torch.tan(fov / 2)
# 将归一化坐标重新调整形状
fracs = fracs.view(batch_size, -1, 2)
# 计算射线方向
directions = (
self.z.view(batch_size, 1, 3) # z 方向
+ self.x.view(batch_size, 1, 3) * fracs[:, :, :1] # x 方向
+ self.y.view(batch_size, 1, 3) * fracs[:, :, 1:] # y 方向
)
# 对方向进行归一化
directions = directions / directions.norm(dim=-1, keepdim=True)
# 堆叠原点和方向形成射线张量
rays = torch.stack(
[
# 扩展原点以匹配方向的形状
torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
directions, # 射线方向
],
dim=2, # 在最后一个维度进行堆叠
)
# 返回最终的射线张量,形状为 (batch_size, *shape, 2, 3)
return rays.view(batch_size, *shape, 2, 3)
# 调整图像大小的函数,返回新的相机对象
def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
"""
创建一个新的相机用于调整后的视图,假设长宽比不变。
"""
# 确保宽高比不变
assert width * self.height == height * self.width, "The aspect ratio should not change."
# 返回新的可微分投影相机对象
return DifferentiableProjectiveCamera(
origin=self.origin, # 原点
x=self.x, # x 方向
y=self.y, # y 方向
z=self.z, # z 方向
width=width, # 新的宽度
height=height, # 新的高度
x_fov=self.x_fov, # x 方向的视场角
y_fov=self.y_fov, # y 方向的视场角
)
# 创建一个全景摄像机的函数,返回一个可微分的投影摄像机
def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera:
# 初始化原点、x轴、y轴和z轴的列表
origins = []
xs = []
ys = []
zs = []
# 生成20个从0到2π的均匀分布的角度
for theta in np.linspace(0, 2 * np.pi, num=20):
# 计算z轴方向的单位向量
z = np.array([np.sin(theta), np.cos(theta), -0.5])
# 将z向量标准化
z /= np.sqrt(np.sum(z**2))
# 计算相机原点位置,向外移动4个单位
origin = -z * 4
# 计算x轴方向的向量
x = np.array([np.cos(theta), -np.sin(theta), 0.0])
# 计算y轴方向的向量,通过z和x的叉积获得
y = np.cross(z, x)
# 将计算得到的原点和轴向向量添加到对应的列表中
origins.append(origin)
xs.append(x)
ys.append(y)
zs.append(z)
# 返回一个DifferentiableProjectiveCamera对象,包含原点、轴向向量和其它参数
return DifferentiableProjectiveCamera(
# 将原点列表转换为PyTorch的张量
origin=torch.from_numpy(np.stack(origins, axis=0)).float(),
# 将x轴列表转换为PyTorch的张量
x=torch.from_numpy(np.stack(xs, axis=0)).float(),
# 将y轴列表转换为PyTorch的张量
y=torch.from_numpy(np.stack(ys, axis=0)).float(),
# 将z轴列表转换为PyTorch的张量
z=torch.from_numpy(np.stack(zs, axis=0)).float(),
# 设置摄像机的宽度
width=size,
# 设置摄像机的高度
height=size,
# 设置x方向的视场角
x_fov=0.7,
# 设置y方向的视场角
y_fov=0.7,
# 设置形状参数,表示1个摄像机和其数量
shape=(1, len(xs)),
)
.\diffusers\pipelines\shap_e\pipeline_shap_e.py
# 版权信息,标明该文件的版权归属
# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved.
#
# 按照 Apache License, Version 2.0 进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 该文件只能在遵守许可的情况下使用
# you may not use this file except in compliance with the License.
# 可以通过以下链接获取许可副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,软件按“原样”提供,不提供任何明示或暗示的保证或条件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可文件以了解特定权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入数学模块
import math
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 导入类型提示相关的类型
from typing import List, Optional, Union
# 导入 numpy 库并命名为 np
import numpy as np
# 导入图像处理库 PIL 的 Image 模块
import PIL.Image
# 导入 PyTorch 库
import torch
# 从 transformers 库导入 CLIPTextModelWithProjection 和 CLIPTokenizer
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
# 从本地模型模块导入 PriorTransformer 类
from ...models import PriorTransformer
# 从调度器模块导入 HeunDiscreteScheduler 类
from ...schedulers import HeunDiscreteScheduler
# 从工具模块导入多个工具类和函数
from ...utils import (
BaseOutput,
logging,
replace_example_docstring,
)
# 从工具的 torch_utils 模块导入 randn_tensor 函数
from ...utils.torch_utils import randn_tensor
# 从管道工具模块导入 DiffusionPipeline 类
from ..pipeline_utils import DiffusionPipeline
# 从渲染器模块导入 ShapERenderer 类
from .renderer import ShapERenderer
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 示例文档字符串,展示如何使用该管道
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import DiffusionPipeline
>>> from diffusers.utils import export_to_gif
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> repo = "openai/shap-e"
>>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
>>> pipe = pipe.to(device)
>>> guidance_scale = 15.0
>>> prompt = "a shark"
>>> images = pipe(
... prompt,
... guidance_scale=guidance_scale,
... num_inference_steps=64,
... frame_size=256,
... ).images
>>> gif_path = export_to_gif(images[0], "shark_3d.gif")
```py
"""
# 定义 ShapEPipelineOutput 数据类,继承自 BaseOutput
@dataclass
class ShapEPipelineOutput(BaseOutput):
"""
ShapEPipeline 和 ShapEImg2ImgPipeline 的输出类。
参数:
images (`torch.Tensor`)
生成的 3D 渲染图像列表。
"""
# 声明一个属性,表示图像可以是多种格式的列表
images: Union[List[List[PIL.Image.Image]], List[List[np.ndarray]]]
# 定义 ShapEPipeline 类,继承自 DiffusionPipeline
class ShapEPipeline(DiffusionPipeline):
"""
用于生成 3D 资产的潜在表示并使用 NeRF 方法进行渲染的管道。
该模型继承自 DiffusionPipeline。请查看超类文档以获取所有管道实现的通用方法
(下载、保存、在特定设备上运行等)。
# 文档字符串,描述构造函数的参数及其类型
Args:
prior ([`PriorTransformer`]):
用于近似文本嵌入生成图像嵌入的标准 unCLIP 先验。
text_encoder ([`~transformers.CLIPTextModelWithProjection`]):
冻结的文本编码器。
tokenizer ([`~transformers.CLIPTokenizer`]):
用于对文本进行分词的 `CLIPTokenizer`。
scheduler ([`HeunDiscreteScheduler`]):
用于与 `prior` 模型结合生成图像嵌入的调度器。
shap_e_renderer ([`ShapERenderer`]):
Shap-E 渲染器将生成的潜在向量投影到 MLP 的参数中,以使用 NeRF 渲染方法创建 3D 对象。
"""
# 定义 CPU 卸载顺序,指定先卸载 text_encoder 后卸载 prior
model_cpu_offload_seq = "text_encoder->prior"
# 指定不进行 CPU 卸载的模块列表
_exclude_from_cpu_offload = ["shap_e_renderer"]
# 初始化方法,接收多个参数用于设置对象状态
def __init__(
self,
prior: PriorTransformer, # 先验模型
text_encoder: CLIPTextModelWithProjection, # 文本编码器
tokenizer: CLIPTokenizer, # 文本分词器
scheduler: HeunDiscreteScheduler, # 调度器
shap_e_renderer: ShapERenderer, # Shap-E 渲染器
):
super().__init__() # 调用父类的初始化方法
# 注册各个模块,将其绑定到当前实例
self.register_modules(
prior=prior, # 注册先验模型
text_encoder=text_encoder, # 注册文本编码器
tokenizer=tokenizer, # 注册文本分词器
scheduler=scheduler, # 注册调度器
shap_e_renderer=shap_e_renderer, # 注册 Shap-E 渲染器
)
# 从 diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline 复制的方法,准备潜在向量
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
# 如果未提供潜在向量,则生成随机的潜在向量
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 如果提供的潜在向量形状不符合预期,则抛出异常
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
# 将潜在向量移动到指定设备
latents = latents.to(device)
# 将潜在向量乘以调度器的初始噪声标准差
latents = latents * scheduler.init_noise_sigma
# 返回处理后的潜在向量
return latents
# 编码提示文本的方法,接收多个参数用于配置
def _encode_prompt(
self,
prompt, # 提示文本
device, # 指定设备
num_images_per_prompt, # 每个提示生成的图像数量
do_classifier_free_guidance, # 是否进行无分类器引导
# 定义一个方法,处理输入的提示文本
):
# 判断 prompt 是否为列表,如果是,返回其长度,否则返回 1
len(prompt) if isinstance(prompt, list) else 1
# YiYi 注释: 将 pad_token_id 设置为 0,不确定为何无法在配置文件中设置
self.tokenizer.pad_token_id = 0
# 获取提示文本的嵌入表示
text_inputs = self.tokenizer(
prompt,
# 在处理时填充到最大长度
padding="max_length",
# 设置最大长度为 tokenizer 的最大模型长度
max_length=self.tokenizer.model_max_length,
# 如果文本超长,进行截断
truncation=True,
# 返回张量格式,类型为 PyTorch
return_tensors="pt",
)
# 获取文本输入的 ID
text_input_ids = text_inputs.input_ids
# 获取未截断的 ID,填充方式为最长
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
# 检查未截断 ID 的形状是否大于等于文本输入 ID 的形状,并确保它们不相等
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
# 解码并记录被截断的文本部分
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"以下输入部分被截断,因为 CLIP 只能处理最长"
f" {self.tokenizer.model_max_length} 个 tokens: {removed_text}"
)
# 将文本输入 ID 转移到设备上并获取编码输出
text_encoder_output = self.text_encoder(text_input_ids.to(device))
# 获取文本嵌入
prompt_embeds = text_encoder_output.text_embeds
# 根据每个提示文本的图像数量重复嵌入
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
# 在 Shap-E 中,先对 prompt_embeds 进行归一化,然后再重新缩放
prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)
# 如果需要分类自由引导
if do_classifier_free_guidance:
# 创建与 prompt_embeds 形状相同的零张量
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
# 为了分类自由引导,需要进行两次前向传递
# 将无条件嵌入和文本嵌入连接到一个批次中,以避免进行两次前向传递
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 将特征重新缩放为单位方差
prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
# 返回最终的提示嵌入
return prompt_embeds
# 装饰器,禁用梯度计算
@torch.no_grad()
# 替换示例文档字符串
@replace_example_docstring(EXAMPLE_DOC_STRING)
# 定义调用方法,处理输入参数
def __call__(
# 提示文本
prompt: str,
# 每个提示生成的图像数量,默认为 1
num_images_per_prompt: int = 1,
# 推理步骤数量,默认为 25
num_inference_steps: int = 25,
# 随机数生成器,可选
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
# 潜在张量,可选
latents: Optional[torch.Tensor] = None,
# 指导比例,默认为 4.0
guidance_scale: float = 4.0,
# 帧大小,默认为 64
frame_size: int = 64,
# 输出类型,可选,默认为 'pil'
output_type: Optional[str] = "pil", # pil, np, latent, mesh
# 是否返回字典格式,默认为 True
return_dict: bool = True,