diffusers Source Code Walkthrough (Part 2)
.\diffusers\loaders\lora_conversion_utils.py
import re
from ..utils import is_peft_version, logging
logger = logging.get_logger(__name__)
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
all_keys = list(state_dict.keys())
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
is_in_sgm_format = False
for key in all_keys:
if any(p in key for p in sgm_patterns):
is_in_sgm_format = True
break
if not is_in_sgm_format:
return state_dict
new_state_dict = {}
inner_block_map = ["resnets", "attentions", "upsamplers"]
input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
for layer in all_keys:
if "text" in layer:
new_state_dict[layer] = state_dict.pop(layer)
else:
layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
if sgm_patterns[0] in layer:
input_block_ids.add(layer_id)
elif sgm_patterns[1] in layer:
middle_block_ids.add(layer_id)
elif sgm_patterns[2] in layer:
output_block_ids.add(layer_id)
else:
raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
input_blocks = {
layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
for layer_id in input_block_ids
}
middle_blocks = {
layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
for layer_id in middle_block_ids
}
output_blocks = {
layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
for layer_id in output_block_ids
}
for i in input_block_ids:
block_id = (i - 1) // (unet_config.layers_per_block + 1)
layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
for key in input_blocks[i]:
inner_block_id = int(key.split(delimiter)[block_slice_pos])
inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1]
+ [str(block_id), inner_block_key, inner_layers_in_block]
+ key.split(delimiter)[block_slice_pos + 1 :]
)
new_state_dict[new_key] = state_dict.pop(key)
for i in middle_block_ids:
key_part = None
if i == 0:
key_part = [inner_block_map[0], "0"]
elif i == 1:
key_part = [inner_block_map[1], "0"]
elif i == 2:
key_part = [inner_block_map[0], "1"]
else:
raise ValueError(f"Invalid middle block id {i}.")
for key in middle_blocks[i]:
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
)
new_state_dict[new_key] = state_dict.pop(key)
for i in output_block_ids:
block_id = i // (unet_config.layers_per_block + 1)
layer_in_block_id = i % (unet_config.layers_per_block + 1)
for key in output_blocks[i]:
inner_block_id = int(key.split(delimiter)[block_slice_pos])
inner_block_key = inner_block_map[inner_block_id]
inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1]
+ [str(block_id), inner_block_key, inner_layers_in_block]
+ key.split(delimiter)[block_slice_pos + 1 :]
)
new_state_dict[new_key] = state_dict.pop(key)
if len(state_dict) > 0:
raise ValueError("At this point all state dict entries have to be converted.")
return new_state_dict
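To make the index arithmetic above concrete, here is a small sketch (not from the source; it assumes an SDXL-style UNet with `layers_per_block = 2`, so each down stage holds three SGM input blocks) that prints how SGM `input_blocks` indices fold into diffusers `down_blocks` coordinates:

```py
# A minimal sketch of the mapping arithmetic, assuming layers_per_block = 2.
layers_per_block = 2
for i in range(1, 9):  # SGM input_blocks 1..8 (block 0 is the stem conv)
    block_id = (i - 1) // (layers_per_block + 1)
    layer_in_block_id = (i - 1) % (layers_per_block + 1)
    print(f"input_blocks.{i} -> down_blocks.{block_id}, inner layer {layer_in_block_id}")
```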
def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
unet_state_dict = {}
te_state_dict = {}
te2_state_dict = {}
network_alphas = {}
dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
all_lora_keys = list(state_dict.keys())
for key in all_lora_keys:
if not key.endswith("lora_down.weight"):
continue
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
if lora_name.startswith("lora_unet_"):
diffusers_name = _convert_unet_lora_key(key)
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
unet_state_dict[
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
if dora_present_in_te or dora_present_in_te2:
dora_scale_key_to_replace_te = (
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
)
if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
elif lora_name.startswith("lora_te2_"):
te2_state_dict[
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
if lora_name_alpha in state_dict:
alpha = state_dict.pop(lora_name_alpha).item()
network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
if len(state_dict) > 0:
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
logger.info("Non-diffusers checkpoint detected.")
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
te2_state_dict = (
{f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
if len(te2_state_dict) > 0
else None
)
if te2_state_dict is not None:
te_state_dict.update(te2_state_dict)
new_state_dict = {**unet_state_dict, **te_state_dict}
return new_state_dict, network_alphas
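As a quick illustration of the whole conversion, here is a toy round trip through this private helper (hypothetical kohya-ss key, zero tensors just to show shapes; not from the source):

```py
import torch

# Hypothetical kohya-ss (non-diffusers) LoRA entry for a UNet attention projection.
base = "lora_unet_up_blocks_3_attentions_0_transformer_blocks_0_attn1_to_q"
kohya_sd = {
    f"{base}.lora_down.weight": torch.zeros(4, 320),
    f"{base}.lora_up.weight": torch.zeros(320, 4),
    f"{base}.alpha": torch.tensor(4.0),
}
new_sd, alphas = _convert_non_diffusers_lora_to_diffusers(kohya_sd)
# The converted keys are dot-separated and prefixed with "unet.", e.g.
# "unet.up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
```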
def _convert_unet_lora_key(key):
"""
转换 U-Net LoRA 键为 Diffusers 兼容的键。
"""
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
if ".out." in diffusers_name:
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
diffusers_name = diffusers_name.replace("op", "conv")
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
if "time.emb.proj" in diffusers_name:
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
if "conv.shortcut" in diffusers_name:
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
elif "ff" in diffusers_name:
pass
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
pass
else:
pass
return diffusers_name
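A hedged trace of this helper on a resnet time-embedding key (hypothetical key, chosen to exercise the `time_emb_proj` branch):

```py
key = "lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_down.weight"
print(_convert_unet_lora_key(key))
# -> "down_blocks.0.resnets.0.time_emb_proj.lora.down.weight"
```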
def _convert_text_encoder_lora_key(key, lora_name):
"""
转换文本编码器 LoRA 键为 Diffusers 兼容的键。
"""
if lora_name.startswith(("lora_te_", "lora_te1_")):
key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
else:
key_to_replace = "lora_te2_"
diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
pass
elif "mlp" in diffusers_name:
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
return diffusers_name
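Two hedged traces (hypothetical keys, calling the private helper directly) showing the attention and MLP branches:

```py
# Attention projection: "q.proj.lora" collapses to "to_q_lora".
print(_convert_text_encoder_lora_key(
    "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight",
    "lora_te_text_model_encoder_layers_0_self_attn_q_proj",
))
# -> "text_model.encoder.layers.0.self_attn.to_q_lora.down.weight"

# MLP: ".lora." is renamed to ".lora_linear_layer.".
print(_convert_text_encoder_lora_key(
    "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
    "lora_te_text_model_encoder_layers_0_mlp_fc1",
))
# -> "text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.down.weight"
```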
def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
"""
Gets the correct alpha name for the Diffusers model.
"""
if lora_name_alpha.startswith("lora_unet_"):
prefix = "unet."
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
prefix = "text_encoder."
else:
prefix = "text_encoder_2."
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
return {new_name: alpha}
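For example (hypothetical key, calling the private helpers directly), a `proj_in` alpha resolves like this:

```py
name = _convert_unet_lora_key("lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight")
# name == "down_blocks.0.attentions.0.proj_in.lora.down.weight"
print(_get_alpha_name("lora_unet_down_blocks_0_attentions_0_proj_in.alpha", name, 8.0))
# -> {"unet.down_blocks.0.attentions.0.proj_in.alpha": 8.0}
```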
.\diffusers\loaders\lora_pipeline.py
import os
from typing import Callable, Dict, List, Optional, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import (
USE_PEFT_BACKEND,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_peft_version,
is_transformers_available,
logging,
scale_lora_layers,
)
from .lora_base import LoraBaseMixin
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
if is_transformers_available():
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into Stable Diffusion's [`UNet2DConditionModel`] and
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
"""
_lora_loadable_modules = ["unet", "text_encoder"]
unet_name = UNET_NAME
text_encoder_name = TEXT_ENCODER_NAME
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load the LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
loaded into `self.unet`.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
dict is loaded into `self.text_encoder`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
adapter_name=adapter_name,
_pipeline=self,
)
self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=getattr(self, self.text_encoder_name)
if not hasattr(self, "text_encoder")
else self.text_encoder,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
)
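A typical call path, sketched (model and LoRA repo names below are placeholders, not recommendations):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Hypothetical LoRA repo and file; kohya-ss checkpoints are converted on the fly.
pipe.load_lora_weights("some-user/some-lora", weight_name="lora.safetensors", adapter_name="style")
```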
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
pass
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
if not only_text_encoder:
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
)
@classmethod
def load_lora_into_text_encoder(
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
):
pass
@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r""" # 文档字符串,描述函数的作用和参数
Save the LoRA parameters corresponding to the UNet and text encoder. # 保存与 UNet 和文本编码器相对应的 LoRA 参数
Arguments: # 参数说明
save_directory (`str` or `os.PathLike`): # 保存目录的类型说明
Directory to save LoRA parameters to. Will be created if it doesn't exist. # 保存 LoRA 参数的目录,如果不存在则创建
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): # UNet 的 LoRA 层状态字典
State dict of the LoRA layers corresponding to the `unet`. # 与 `unet` 相对应的 LoRA 层的状态字典
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): # 文本编码器的 LoRA 层状态字典
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text # 与 `text_encoder` 相对应的 LoRA 层状态字典,必须显式传递
encoder LoRA state dict because it comes from 🤗 Transformers. # 因为它来自 🤗 Transformers
is_main_process (`bool`, *optional*, defaults to `True`): # 主要进程的布尔值,可选,默认值为 True
Whether the process calling this is the main process or not. Useful during distributed training and you # 调用此函数的进程是否为主进程,在分布式训练中很有用
need to call this function on all processes. In this case, set `is_main_process=True` only on the main # 在这种情况下,只在主进程上设置 `is_main_process=True` 以避免竞争条件
process to avoid race conditions. # 避免竞争条件
save_function (`Callable`): # 保存函数的类型说明
The function to use to save the state dictionary. Useful during distributed training when you need to # 用于保存状态字典的函数,在分布式训练中很有用
replace `torch.save` with another method. Can be configured with the environment variable # 可以通过环境变量配置
`DIFFUSERS_SAVE_MODE`. # `DIFFUSERS_SAVE_MODE`
safe_serialization (`bool`, *optional*, defaults to `True`): # 安全序列化的布尔值,可选,默认值为 True
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. # 是否使用 `safetensors` 或传统的 PyTorch 方法 `pickle` 保存模型
"""
state_dict = {}
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
if unet_lora_layers:
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
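The inverse direction, sketched. This assumes `unet_lora_state_dict` already exists from a training loop (e.g. produced via `get_peft_model_state_dict` followed by `convert_state_dict_to_diffusers`):

```py
from diffusers import StableDiffusionPipeline

StableDiffusionPipeline.save_lora_weights(
    save_directory="./my-lora",             # created if missing
    unet_lora_layers=unet_lora_state_dict,  # assumed to exist from a training loop
    safe_serialization=True,                # writes pytorch_lora_weights.safetensors
)
```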
def fuse_lora(
self,
components: List[str] = ["unet", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r""" # 开始文档字符串,描述该方法的功能和用法
Fuses the LoRA parameters into the original parameters of the corresponding blocks. # 将 LoRA 参数融合到对应块的原始参数中
<Tip warning={true}> # 开始警告提示框
This is an experimental API. # 说明这是一个实验性 API
</Tip> # 结束警告提示框
Args: # 开始参数说明
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. # 可注入 LoRA 的组件列表
lora_scale (`float`, defaults to 1.0): # LoRA 参数对输出影响的比例
Controls how much to influence the outputs with the LoRA parameters. # 控制 LoRA 参数对输出的影响程度
safe_fusing (`bool`, defaults to `False`): # 是否在融合前检查权重是否为 NaN
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. # 如果值为 NaN 则不进行融合
adapter_names (`List[str]`, *optional*): # 可选的适配器名称
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. # 如果未传入,默认融合所有活动适配器
Example: # 示例部分的开始
```py # Python 代码块开始
from diffusers import DiffusionPipeline # 导入 DiffusionPipeline 模块
import torch # 导入 PyTorch 库
pipeline = DiffusionPipeline.from_pretrained( # 从预训练模型创建管道
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 # 使用 float16 类型的模型
).to("cuda") # 将管道移动到 GPU
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") # 加载 LoRA 权重
pipeline.fuse_lora(lora_scale=0.7) # 融合 LoRA,影响比例为 0.7
``` # Python 代码块结束
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
r""" # 开始文档字符串,描述该方法的功能和用法
Reverses the effect of # 反转 fuse_lora 方法的效果
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). # 提供 fuse_lora 的链接
<Tip warning={true}> # 开始警告提示框
This is an experimental API. # 说明这是一个实验性 API
</Tip> # 结束警告提示框
Args: # 开始参数说明
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. # 可注入 LoRA 的组件列表,用于反融合
unfuse_unet (`bool`, defaults to `True`): # 是否反融合 UNet 的 LoRA 参数
Whether to unfuse the UNet LoRA parameters. # 反融合 UNet LoRA 参数的选项
unfuse_text_encoder (`bool`, defaults to `True`): # 是否反融合文本编码器的 LoRA 参数
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the # 反融合文本编码器的 LoRA 参数的选项
LoRA parameters then it won't have any effect. # 如果文本编码器未被修改,则不会有任何效果
"""
super().unfuse_lora(components=components)
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into Stable Diffusion XL's [`UNet2DConditionModel`],
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
[`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
"""
_lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]
unet_name = UNET_NAME
text_encoder_name = TEXT_ENCODER_NAME
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
**kwargs,
):
pass
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
pass
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
if not only_text_encoder:
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
)
@classmethod
def load_lora_into_text_encoder(
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
):
pass
@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `unet`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
)
if unet_lora_layers:
state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def fuse_lora(
self,
components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# load the LoRA weights
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
# fuse the LoRA parameters at scale 0.7
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
class SD3LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`SD3Transformer2DModel`],
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
[`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
Specific to [`StableDiffusion3Pipeline`].
"""
_lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
pass
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
pass
@classmethod
def load_lora_into_text_encoder(
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
):
pass
@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the transformer and text encoders.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"必须至少传递一个 `transformer_lora_layers`、`text_encoder_lora_layers` 或 `text_encoder_2_lora_layers`。"
)
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
class FluxLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`FluxTransformer2DModel`] and
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
Specific to [`FluxPipeline`].
"""
_lora_loadable_modules = ["transformer", "text_encoder"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
return_alphas: bool = False,
**kwargs,
):
pass
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name=None,
**kwargs
):
"""
Load the LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
)
@classmethod
def load_lora_into_text_encoder(
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
):
pass
@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r""" # 定义文档字符串,描述此函数的功能及参数
Save the LoRA parameters corresponding to the UNet and text encoder. # 描述保存LoRA参数的功能
Arguments: # 开始列出函数的参数
save_directory (`str` or `os.PathLike`): # 参数:保存LoRA参数的目录,类型为字符串或路径类
Directory to save LoRA parameters to. Will be created if it doesn't exist. # 描述:如果目录不存在,将创建该目录
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): # 参数:与transformer对应的LoRA层的状态字典
State dict of the LoRA layers corresponding to the `transformer`. # 描述:说明参数的作用
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): # 参数:与text_encoder对应的LoRA层的状态字典
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text # 描述:说明此参数必须提供,来自🤗 Transformers
encoder LoRA state dict because it comes from 🤗 Transformers. # 继续描述参数的来源
is_main_process (`bool`, *optional*, defaults to `True`): # 参数:指示当前进程是否为主进程,类型为布尔值
Whether the process calling this is the main process or not. Useful during distributed training and you # 描述:用于分布式训练时判断主进程
need to call this function on all processes. In this case, set `is_main_process=True` only on the main # 进一步说明如何使用此参数
process to avoid race conditions. # 描述:避免竞争条件
save_function (`Callable`): # 参数:用于保存状态字典的函数,类型为可调用对象
The function to use to save the state dictionary. Useful during distributed training when you need to # 描述:在分布式训练中,可能需要替换默认的保存方法
replace `torch.save` with another method. Can be configured with the environment variable # 说明如何配置此参数
`DIFFUSERS_SAVE_MODE`. # 提供环境变量名称
safe_serialization (`bool`, *optional*, defaults to `True`): # 参数:指示是否使用安全序列化保存模型,类型为布尔值
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. # 描述:选择保存模型的方式
"""
state_dict = {}
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
# 文档字符串,说明此函数的作用和用法
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
# 将 LoRA 参数融合到对应块的原始参数中
<Tip warning={true}>
# 警告提示,说明这是一个实验性 API
This is an experimental API.
# 这是一项实验性 API
</Tip>
Args:
components: (`List[str]`):
# 参数说明,接受一个字符串列表,表示要融合 LoRA 的组件
lora_scale (`float`, defaults to 1.0):
# 参数说明,控制 LoRA 参数对输出的影响程度
Controls how much to influence the outputs with the LoRA parameters.
# 控制 LoRA 参数对输出的影响程度
safe_fusing (`bool`, defaults to `False`):
# 参数说明,是否在融合之前检查权重中是否有 NaN 值
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
# 是否在融合之前检查权重的 NaN 值,如果存在则不进行融合
adapter_names (`List[str]`, *optional*):
# 参数说明,可选的适配器名称列表,用于融合
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
# 用于融合的适配器名称列表,如果未传入,则将融合所有活动适配器
Example:
# 示例代码,展示如何使用该 API
```py
from diffusers import DiffusionPipeline
# 导入 DiffusionPipeline 类
import torch
# 导入 PyTorch 库
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# 从预训练模型创建管道,并将其移动到 CUDA 设备上
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
# 加载 LoRA 权重到管道中
pipeline.fuse_lora(lora_scale=0.7)
# 融合 LoRA 参数,设置影响程度为 0.7
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
# 方法文档字符串,说明此方法的作用和用法
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
# 反转 fuse_lora 方法的效果
<Tip warning={true}>
# 警告提示,说明这是一个实验性 API
This is an experimental API.
# 这是一项实验性 API
</Tip>
Args:
components (`List[str]`):
# 参数说明,接受一个字符串列表,表示要从中解除 LoRA 的组件
List of LoRA-injectable components to unfuse LoRA from.
# 要从中解除 LoRA 的组件列表
"""
super().unfuse_lora(components=components)
class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
_lora_loadable_modules = ["transformer", "text_encoder"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME
@classmethod
def load_lora_into_text_encoder(
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
):
pass
@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the transformer and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
.\diffusers\loaders\peft.py
import inspect
from functools import partial
from typing import Dict, List, Optional, Union
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
delete_adapter_layers,
is_peft_available,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .unet_loader_utils import _maybe_expand_lora_scales
_SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
"FluxTransformer2DModel": lambda model_cls, weights: weights,
}
class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapter weights that are supported in the PEFT library.
For more details about adapters and injecting them into a base model, check out the PEFT
[documentation](https://huggingface.co/docs/peft/index).
Install the latest version of PEFT, and use this mixin to:
- Attach new adapters in the model.
- Attach multiple adapters and iteratively activate/deactivate them.
- Activate/deactivate all adapters from the model.
- Get a list of the active adapters.
"""
_hf_peft_config_loaded = False
def set_adapters(
self,
adapter_names: Union[List[str], str],
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
):
"""
Set the currently active adapters for use in the UNet.
Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
adapter_weights (`Union[List[float], float]`, *optional*):
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
adapters.
Example:
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `set_adapters()`.")
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)
if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
)
weights = [w if w is not None else 1.0 for w in weights]
scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
weights = scale_expansion_fn(self, weights)
set_weights_and_activate_adapters(self, adapter_names, weights)
def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
r"""
Adds a new adapter to the current model for training. If no adapter name is passed, a default name is
assigned to the adapter to follow the convention of the PEFT library.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
[documentation](https://huggingface.co/docs/peft).
Args:
adapter_config (`[~peft.PeftConfig]`):
The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
methods.
adapter_name (`str`, *optional*, defaults to `"default"`):
The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not is_peft_available():
raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
from peft import PeftConfig, inject_adapter_in_model
if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
elif adapter_name in self.peft_config:
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
if not isinstance(adapter_config, PeftConfig):
raise ValueError(
f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
)
adapter_config.base_model_name_or_path = None
inject_adapter_in_model(adapter_config, self, adapter_name)
self.set_adapter(adapter_name)
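A training-time sketch of `add_adapter` (rank, alpha, and target modules below are illustrative choices, not defaults):

```py
from peft import LoraConfig
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)  # injects trainable LoRA layers and activates them
```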
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
"""
Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
[documentation](https://huggingface.co/docs/peft).
Args:
adapter_name (Union[str, List[str]]):
The name of the adapter to set, or a list of adapter names to set multiple adapters at once.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
if isinstance(adapter_name, str):
adapter_name = [adapter_name]
missing = set(adapter_name) - set(self.peft_config)
if len(missing) > 0:
raise ValueError(
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
f" current loaded adapters are: {list(self.peft_config.keys())}"
)
from peft.tuners.tuners_utils import BaseTunerLayer
_adapters_has_been_set = False
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
raise ValueError(
"You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
" `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True
if not _adapters_has_been_set:
raise ValueError(
"Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
)
def disable_adapters(self) -> None:
r"""
Disable all adapters attached to the model and fall back to inference with the base model only.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
[documentation](https://huggingface.co/docs/peft).
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
module.disable_adapters = True
def enable_adapters(self) -> None:
"""
Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the
list of adapters to enable.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
[documentation](https://huggingface.co/docs/peft).
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
module.disable_adapters = False
def active_adapters(self) -> List[str]:
"""
Gets the current list of active adapters of the model.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
[documentation](https://huggingface.co/docs/peft).
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not is_peft_available():
raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter
def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `fuse_lora()`.")
self.lora_scale = lora_scale
self._safe_fusing = safe_fusing
self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
def _fuse_lora_apply(self, module, adapter_names=None):
from peft.tuners.tuners_utils import BaseTunerLayer
merge_kwargs = {"safe_merge": self._safe_fusing}
if isinstance(module, BaseTunerLayer):
if self.lora_scale != 1.0:
module.scale_layer(self.lora_scale)
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
merge_kwargs["adapter_names"] = adapter_names
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
" to the latest version of PEFT. `pip install -U peft`"
)
module.merge(**merge_kwargs)
def unfuse_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `unfuse_lora()`.")
self.apply(self._unfuse_lora_apply)
def _unfuse_lora_apply(self, module):
from peft.tuners.tuners_utils import BaseTunerLayer
if isinstance(module, BaseTunerLayer):
module.unmerge()
def unload_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `unload_lora()`.")
from ..utils import recurse_remove_peft_layers
recurse_remove_peft_layers(self)
if hasattr(self, "peft_config"):
del self.peft_config
def disable_lora(self):
"""
Disables the active LoRA layers of the underlying model.
Example:
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.disable_lora()
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
set_adapter_layers(self, enabled=False)
def enable_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
set_adapter_layers(self, enabled=True)
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Delete an adapter's LoRA layers from the underlying model.
Args:
adapter_names (`Union[List[str], str]`):
The names (single string or list of strings) of the adapter to delete.
Example:
...
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
if isinstance(adapter_names, str):
adapter_names = [adapter_names]
for adapter_name in adapter_names:
delete_adapter_layers(self, adapter_name)
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
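Usage sketch (the repo and adapter names are hypothetical; assumes a LoRA was loaded onto the UNet earlier):

```py
pipe.load_lora_weights("some-user/some-lora", adapter_name="style")
pipe.unet.delete_adapters("style")  # removes the adapter's LoRA layers and its peft_config entry
```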
.\diffusers\loaders\single_file.py
import importlib
import inspect
import os
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
from packaging import version
from ..utils import deprecate, is_transformers_available, logging
from .single_file_utils import (
SingleFileComponentError,
_is_legacy_scheduler_kwargs,
_is_model_weights_in_cached_folder,
_legacy_load_clip_tokenizer,
_legacy_load_safety_checker,
_legacy_load_scheduler,
create_diffusers_clip_model_from_ldm,
create_diffusers_t5_model_from_checkpoint,
fetch_diffusers_config,
fetch_original_config,
is_clip_model_in_single_file,
is_t5_in_single_file,
load_single_file_checkpoint,
)
logger = logging.get_logger(__name__)
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer
def load_single_file_sub_model(
library_name,
class_name,
name,
checkpoint,
pipelines,
is_pipeline_module,
cached_model_config_path,
original_config=None,
local_files_only=False,
torch_dtype=None,
is_legacy_loading=False,
**kwargs,
):
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
else:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
is_tokenizer = (
is_transformers_available()
and issubclass(class_obj, PreTrainedTokenizer)
and transformers_version >= version.parse("4.20.0")
)
diffusers_module = importlib.import_module(__name__.split(".")[0])
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
if is_diffusers_single_file_model:
load_method = getattr(class_obj, "from_single_file")
if original_config:
cached_model_config_path = None
loaded_sub_model = load_method(
pretrained_model_link_or_path_or_dict=checkpoint,
original_config=original_config,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
**kwargs,
)
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
loaded_sub_model = create_diffusers_clip_model_from_ldm(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
)
elif is_transformers_model and is_t5_in_single_file(checkpoint):
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
)
elif is_tokenizer and is_legacy_loading:
loaded_sub_model = _legacy_load_clip_tokenizer(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
local_files_only=local_files_only
)
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
loaded_sub_model = _legacy_load_scheduler(
class_obj,
checkpoint=checkpoint,
component_name=name,
original_config=original_config,
**kwargs
)
else:
if not hasattr(class_obj, "from_pretrained"):
raise ValueError(
(
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
" a supported loading method."
)
)
loading_kwargs = {}
loading_kwargs.update(
{
"pretrained_model_name_or_path": cached_model_config_path,
"subfolder": name,
"local_files_only": local_files_only,
}
)
if issubclass(class_obj, torch.nn.Module):
loading_kwargs.update({"torch_dtype": torch_dtype})
if is_diffusers_model or is_transformers_model:
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
)
load_method = getattr(class_obj, "from_pretrained")
loaded_sub_model = load_method(**loading_kwargs)
return loaded_sub_model
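For context, the usual entry point that drives this helper is the pipeline-level `from_single_file` API; a minimal usage sketch (the checkpoint URL is a placeholder):

```py
from diffusers import StableDiffusionPipeline

# Any local .ckpt/.safetensors path or a direct Hub file URL works here.
pipe = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/some-user/some-model/blob/main/model.safetensors"
)
```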
def _map_component_types_to_config_dict(component_types):
diffusers_module = importlib.import_module(__name__.split(".")[0])
config_dict = {}
component_types.pop("self", None)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
for component_name, component_value in component_types.items():
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
is_transformers_model = (
is_transformers_available()
and issubclass(component_value[0], PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
is_transformers_tokenizer = (
is_transformers_available()
and issubclass(component_value[0], PreTrainedTokenizer)
and transformers_version >= version.parse("4.20.0")
)
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
config_dict[component_name] = ["diffusers", component_value[0].__name__]
elif is_scheduler_enum or is_scheduler:
if is_scheduler_enum:
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
elif is_scheduler:
config_dict[component_name] = ["diffusers", component_value[0].__name__]
elif (
is_transformers_model or is_transformers_tokenizer
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
config_dict[component_name] = ["transformers", component_value[0].__name__]
else:
config_dict[component_name] = [None, None]
return config_dict
def _infer_pipeline_config_dict(pipeline_class):
parameters = inspect.signature(pipeline_class.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
component_types = pipeline_class._get_signature_types()
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
config_dict = _map_component_types_to_config_dict(component_types)
return config_dict
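# Illustrative sketch (not part of the diffusers source) of how the two helpers
# above cooperate: required components are read off the pipeline's __init__
# signature, and only parameters without defaults survive the filter. The toy
# pipeline class below is hypothetical.
import inspect

class ToyPipeline:
    def __init__(self, unet, scheduler, safety_checker=None):
        ...

params = inspect.signature(ToyPipeline.__init__).parameters
required = {k: v for k, v in params.items() if v.default is inspect.Parameter.empty}
# "self" survives too; _map_component_types_to_config_dict drops it via pop("self")
print(sorted(required))  # ['scheduler', 'self', 'unet']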
def _download_diffusers_model_config_from_hub(
pretrained_model_name_or_path,
cache_dir,
revision,
proxies,
force_download=None,
local_files_only=None,
token=None,
):
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
cached_model_path = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
local_files_only=local_files_only,
token=token,
allow_patterns=allow_patterns,
)
return cached_model_path
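# Hedged usage sketch: fetching only the lightweight config/tokenizer files of a
# repo with huggingface_hub.snapshot_download, mirroring the allow_patterns above.
# The repo id is only an example; any diffusers-format repo works.
from huggingface_hub import snapshot_download

config_path = snapshot_download(
    "stabilityai/stable-diffusion-xl-base-1.0",
    allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"],
)
print(config_path)  # local snapshot directory with configs but no weight files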
class FromSingleFileMixin:
"""
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
"""
@classmethod
@validate_hf_hub_args
.\diffusers\loaders\single_file_model.py
import importlib
import inspect
import re
from contextlib import nullcontext
from typing import Optional
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
fetch_diffusers_config,
fetch_original_config,
load_single_file_checkpoint,
)
logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
SINGLE_FILE_LOADABLE_CLASSES = {
"StableCascadeUNet": {
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
},
"UNet2DConditionModel": {
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
"default_subfolder": "unet",
"legacy_kwargs": {
"num_in_channels": "in_channels",
},
},
"AutoencoderKL": {
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
"default_subfolder": "vae",
},
"ControlNetModel": {
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
},
"SD3Transformer2DModel": {
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"MotionAdapter": {
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
},
"SparseControlNetModel": {
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
},
"FluxTransformer2DModel": {
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}
def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
loadable_class = getattr(diffusers_module, loadable_class_str)
if issubclass(cls, loadable_class):
return loadable_class_str
return None
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
parameters = inspect.signature(mapping_fn).parameters
mapping_kwargs = {}
for parameter in parameters:
if parameter in kwargs:
mapping_kwargs[parameter] = kwargs[parameter]
return mapping_kwargs
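# Sketch of the kwargs-filtering idiom used by _get_mapping_function_kwargs: only
# keyword arguments that appear in the target function's signature are forwarded,
# so callers may pass a superset safely. `demo_mapping_fn` is a hypothetical
# stand-in for a checkpoint_mapping_fn.
import inspect

def demo_mapping_fn(checkpoint, image_size=None):
    return checkpoint, image_size

user_kwargs = {"image_size": 512, "unrelated_flag": True}
accepted = {
    name: user_kwargs[name]
    for name in inspect.signature(demo_mapping_fn).parameters
    if name in user_kwargs
}
print(accepted)  # {'image_size': 512} -- 'unrelated_flag' is silently dropped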
class FromOriginalModelMixin:
"""
Load pretrained weights saved in `.ckpt` or `.safetensors` format into a model.
"""
@classmethod
@validate_hf_hub_args
"""用于 Stable Diffusion 检查点的转换脚本。"""
import os
import re
from contextlib import nullcontext
from io import BytesIO
from urllib.parse import urlparse
import requests
import torch
import yaml
from ..models.modeling_utils import load_state_dict
from ..schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EDMDPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ..utils import (
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
deprecate,
is_accelerate_available,
is_transformers_available,
logging,
)
from ..utils.hub_utils import _get_model_file
if is_transformers_available():
from transformers import AutoImageProcessor
if is_accelerate_available():
from accelerate import init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
logger = logging.get_logger(__name__)
CHECKPOINT_KEY_NAMES = {
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
"upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
"controlnet": "control_model.time_embed.0.weight",
"playground-v2-5": "edm_mean",
"inpainting": "model.diffusion_model.input_blocks.0.0.weight",
"clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
"clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
"clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
"open_clip": "cond_stage_model.model.token_embedding.weight",
"open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
"open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight",
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight",
"flux": [
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"},
"xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"},
"xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
"playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
"upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
"inpainting": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8-inpainting"},
"inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
"controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
"v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
"v1": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8"},
"stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
"stable_cascade_stage_b_lite": {
"pretrained_model_name_or_path": "stabilityai/stable-cascade",
"subfolder": "decoder_lite",
},
"stable_cascade_stage_c": {
"pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
"subfolder": "prior",
},
"stable_cascade_stage_c_lite": {
"pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
"subfolder": "prior_lite",
},
"sd3": {
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
},
"animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
}
DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
"xl_base": 1024,
"xl_refiner": 1024,
"xl_inpaint": 1024,
"playground-v2-5": 1024,
"upscale": 512,
"inpainting": 512,
"inpainting_v2": 512,
"controlnet": 512,
"v2": 768,
"v1": 512,
}
DIFFUSERS_TO_LDM_MAPPING = {
"unet": {
"layers": {
"time_embedding.linear_1.weight": "time_embed.0.weight",
"time_embedding.linear_1.bias": "time_embed.0.bias",
"time_embedding.linear_2.weight": "time_embed.2.weight",
"time_embedding.linear_2.bias": "time_embed.2.bias",
"conv_in.weight": "input_blocks.0.0.weight",
"conv_in.bias": "input_blocks.0.0.bias",
"conv_norm_out.weight": "out.0.weight",
"conv_norm_out.bias": "out.0.bias",
"conv_out.weight": "out.2.weight",
"conv_out.bias": "out.2.bias",
},
"class_embed_type": {
"class_embedding.linear_1.weight": "label_emb.0.0.weight",
"class_embedding.linear_1.bias": "label_emb.0.0.bias",
"class_embedding.linear_2.weight": "label_emb.0.2.weight",
"class_embedding.linear_2.bias": "label_emb.0.2.bias",
},
"addition_embed_type": {
"add_embedding.linear_1.weight": "label_emb.0.0.weight",
"add_embedding.linear_1.bias": "label_emb.0.0.bias",
"add_embedding.linear_2.weight": "label_emb.0.2.weight",
"add_embedding.linear_2.bias": "label_emb.0.2.bias",
},
},
"controlnet": {
"layers": {
"time_embedding.linear_1.weight": "time_embed.0.weight",
"time_embedding.linear_1.bias": "time_embed.0.bias",
"time_embedding.linear_2.weight": "time_embed.2.weight",
"time_embedding.linear_2.bias": "time_embed.2.bias",
"conv_in.weight": "input_blocks.0.0.weight",
"conv_in.bias": "input_blocks.0.0.bias",
"controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
"controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
"controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight",
"controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias",
},
"class_embed_type": {
"class_embedding.linear_1.weight": "label_emb.0.0.weight",
"class_embedding.linear_1.bias": "label_emb.0.0.bias",
"class_embedding.linear_2.weight": "label_emb.0.2.weight",
"class_embedding.linear_2.bias": "label_emb.0.2.bias",
},
"addition_embed_type": {
"add_embedding.linear_1.weight": "label_emb.0.0.weight",
"add_embedding.linear_1.bias": "label_emb.0.0.bias",
"add_embedding.linear_2.weight": "label_emb.0.2.weight",
"add_embedding.linear_2.bias": "label_emb.0.2.bias",
},
},
"vae": {
"encoder.conv_in.weight": "encoder.conv_in.weight",
"encoder.conv_in.bias": "encoder.conv_in.bias",
"encoder.conv_out.weight": "encoder.conv_out.weight",
"encoder.conv_out.bias": "encoder.conv_out.bias",
"encoder.conv_norm_out.weight": "encoder.norm_out.weight",
"encoder.conv_norm_out.bias": "encoder.norm_out.bias",
"decoder.conv_in.weight": "decoder.conv_in.weight",
"decoder.conv_in.bias": "decoder.conv_in.bias",
"decoder.conv_out.weight": "decoder.conv_out.weight",
"decoder.conv_out.bias": "decoder.conv_out.bias",
"decoder.conv_norm_out.weight": "decoder.norm_out.weight",
"decoder.conv_norm_out.bias": "decoder.norm_out.bias",
"quant_conv.weight": "quant_conv.weight",
"quant_conv.bias": "quant_conv.bias",
"post_quant_conv.weight": "post_quant_conv.weight",
"post_quant_conv.bias": "post_quant_conv.bias",
},
"openclip": {
"layers": {
"text_model.embeddings.position_embedding.weight": "positional_embedding",
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.final_layer_norm.weight": "ln_final.weight",
"text_model.final_layer_norm.bias": "ln_final.bias",
"text_projection.weight": "text_projection",
},
"transformer": {
"text_model.encoder.layers.": "resblocks.",
"layer_norm1": "ln_1",
"layer_norm2": "ln_2",
".fc1.": ".c_fc.",
".fc2.": ".c_proj.",
"transformer.text_model.final_layer_norm.": "ln_final.",
"transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"transformer.text_model.embeddings.position_embedding.weight": "positional_embedding",
},
},
}
SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
"cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
"cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
"cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.23.ln_1.bias",
"cond_stage_model.model.transformer.resblocks.23.ln_1.weight",
"cond_stage_model.model.transformer.resblocks.23.ln_2.bias",
"cond_stage_model.model.transformer.resblocks.23.ln_2.weight",
"cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias",
"cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight",
"cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias",
"cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight",
"cond_stage_model.model.text_projection",
]
SCHEDULER_DEFAULT_CONFIG = {
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"beta_end": 0.012,
"interpolation_type": "linear",
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"sample_max_value": 1.0,
"set_alpha_to_one": False,
"skip_prk_steps": True,
"steps_offset": 1,
"timestep_spacing": "leading",
}
LDM_VAE_KEYS = ["first_stage_model.", "vae."]
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
LDM_UNET_KEY = "model.diffusion_model."
LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_PREFIX_TO_REMOVE = [
"cond_stage_model.transformer.",
"conditioner.embedders.0.transformer.",
]
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
class SingleFileComponentError(Exception):
def __init__(self, message=None):
self.message = message
super().__init__(self.message)
def is_valid_url(url):
result = urlparse(url)
if result.scheme and result.netloc:
return True
return False
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
if not is_valid_url(pretrained_model_name_or_path):
raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
weights_name = None
repo_id = None
for prefix in VALID_URL_PREFIXES:
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
match = re.match(pattern, pretrained_model_name_or_path)
if not match:
logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
return repo_id, weights_name
repo_id = f"{match.group(1)}/{match.group(2)}"
weights_name = match.group(3)
return repo_id, weights_name
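# Worked example of the URL parsing above, run against a typical single-file URL:
# the known prefixes are stripped, then the pattern splits the remainder into
# "org/repo" and a weights filename, tolerating an optional "blob/main/".
import re

url = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
    url = url.replace(prefix, "")
m = re.match(r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)", url)
print(f"{m.group(1)}/{m.group(2)}", m.group(3))
# -> stabilityai/stable-diffusion-xl-base-1.0 sd_xl_base_1.0.safetensors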
def _is_model_weights_in_cached_folder(cached_folder, name):
pretrained_model_name_or_path = os.path.join(cached_folder, name)
weights_exist = False
for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]:
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
weights_exist = True
return weights_exist
def _is_legacy_scheduler_kwargs(kwargs):
return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())
def load_single_file_checkpoint(
pretrained_model_link_or_path,
force_download=False,
proxies=None,
token=None,
cache_dir=None,
local_files_only=None,
revision=None,
):
if not os.path.isfile(pretrained_model_link_or_path):
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
pretrained_model_link_or_path = _get_model_file(
repo_id,
weights_name=weights_name,
force_download=force_download,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
)
checkpoint = load_state_dict(pretrained_model_link_or_path)
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
return checkpoint
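# Tiny sketch of the "state_dict" unwrapping loop above: original .ckpt files
# often nest the weights one or more levels deep under a "state_dict" key, so
# the loader peels layers until it reaches the flat tensor mapping.
nested = {"state_dict": {"state_dict": {"model.diffusion_model.x": 0}}}
while "state_dict" in nested:
    nested = nested["state_dict"]
print(nested)  # {'model.diffusion_model.x': 0}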
def fetch_original_config(original_config_file, local_files_only=False):
if os.path.isfile(original_config_file):
with open(original_config_file, "r") as fp:
original_config_file = fp.read()
elif is_valid_url(original_config_file):
if local_files_only:
raise ValueError(
"`local_files_only` is set to True, but a URL was provided as `original_config_file`. "
"Please provide a valid local file path."
)
original_config_file = BytesIO(requests.get(original_config_file).content)
else:
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
original_config = yaml.safe_load(original_config_file)
return original_config
def is_clip_model(checkpoint):
if CHECKPOINT_KEY_NAMES["clip"] in checkpoint:
return True
return False
def is_clip_sdxl_model(checkpoint):
if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint:
return True
return False
def is_clip_sd3_model(checkpoint):
if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
return True
return False
def is_open_clip_model(checkpoint):
if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
return True
return False
def is_open_clip_sdxl_model(checkpoint):
if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint:
return True
return False
def is_open_clip_sd3_model(checkpoint):
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
return True
return False
def is_open_clip_sdxl_refiner_model(checkpoint):
if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
return True
return False
def is_clip_model_in_single_file(class_obj, checkpoint):
is_clip_in_checkpoint = any(
[
is_clip_model(checkpoint),
is_clip_sd3_model(checkpoint),
is_open_clip_model(checkpoint),
is_open_clip_sdxl_model(checkpoint),
is_open_clip_sdxl_refiner_model(checkpoint),
is_open_clip_sd3_model(checkpoint),
]
)
if (
class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection"
) and is_clip_in_checkpoint:
return True
return False
def infer_diffusers_model_type(checkpoint):
if (
CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
):
if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
model_type = "inpainting_v2"
else:
model_type = "inpainting"
elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
model_type = "v2"
elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
model_type = "playground-v2-5"
elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
model_type = "xl_base"
elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
model_type = "xl_refiner"
elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
model_type = "upscale"
elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint:
model_type = "controlnet"
elif (
CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
):
model_type = "stable_cascade_stage_c_lite"
elif (
CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
):
model_type = "stable_cascade_stage_c"
elif (
CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
):
model_type = "stable_cascade_stage_b_lite"
elif (
CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
):
model_type = "stable_cascade_stage_b"
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
model_type = "sd3"
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
model_type = "animatediff_scribble"
elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
model_type = "animatediff_rgb"
elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
model_type = "animatediff_v2"
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
model_type = "animatediff_sdxl_beta"
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
model_type = "animatediff_v1"
else:
model_type = "animatediff_v3"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
model_type = "flux-dev"
else:
model_type = "flux-schnell"
else:
model_type = "v1"
return model_type
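# Minimal sketch of the detection idea above: a checkpoint is just a dict of
# tensors, so the model type can be inferred from marker keys (and, for the
# ambiguous cases, tensor shapes). The fake checkpoint below carries only the
# SDXL-base marker key, so no shapes need to be consulted.
fake_checkpoint = {"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": None}
assert infer_diffusers_model_type(fake_checkpoint) == "xl_base"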
def fetch_diffusers_config(checkpoint):
model_type = infer_diffusers_model_type(checkpoint)
model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
return model_path
def set_image_size(checkpoint, image_size=None):
if image_size:
return image_size
model_type = infer_diffusers_model_type(checkpoint)
image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type]
return image_size
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
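# Sketch of what conv_attn_to_linear does to a single entry: a 1x1 conv weight
# of shape (out, in, 1, 1) is squeezed down to the (out, in) matrix that a
# Linear layer expects. Dimensions are illustrative.
demo_ckpt = {"mid.attn_1.query.weight": torch.randn(512, 512, 1, 1)}
conv_attn_to_linear(demo_ckpt)
print(demo_ckpt["mid.attn_1.query.weight"].shape)  # torch.Size([512, 512])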
def create_unet_diffusers_config_from_ldm(
original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None
):
"""
Create a diffusers config from the LDM model's configuration.
"""
if image_size is not None:
deprecation_message = (
"Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`"
"is deprecated and will be ignored in future versions."
)
deprecate("image_size", "1.0.0", deprecation_message)
image_size = set_image_size(checkpoint, image_size=image_size)
if (
"unet_config" in original_config["model"]["params"]
and original_config["model"]["params"]["unet_config"] is not None
):
unet_params = original_config["model"]["params"]["unet_config"]["params"]
else:
unet_params = original_config["model"]["params"]["network_config"]["params"]
if num_in_channels is not None:
deprecation_message = (
"Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file` "
"is deprecated and will be ignored in future versions."
)
deprecate("num_in_channels", "1.0.0", deprecation_message)
in_channels = num_in_channels
else:
in_channels = unet_params["in_channels"]
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
if unet_params["transformer_depth"] is not None:
transformer_layers_per_block = (
unet_params["transformer_depth"]
if isinstance(unet_params["transformer_depth"], int)
else list(unet_params["transformer_depth"])
)
else:
transformer_layers_per_block = 1
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
use_linear_projection = (
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
if head_dim is None:
head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
class_embed_type = None
addition_embed_type = None
addition_time_embed_dim = None
projection_class_embeddings_input_dim = None
context_dim = None
if unet_params["context_dim"] is not None:
context_dim = (
unet_params["context_dim"]
if isinstance(unet_params["context_dim"], int)
else unet_params["context_dim"][0]
)
if "num_classes" in unet_params:
if unet_params["num_classes"] == "sequential":
if context_dim in [2048, 1280]:
addition_embed_type = "text_time"
addition_time_embed_dim = 256
else:
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": in_channels,
"down_block_types": down_block_types,
"block_out_channels": block_out_channels,
"layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": context_dim,
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
"class_embed_type": class_embed_type,
"addition_embed_type": addition_embed_type,
"addition_time_embed_dim": addition_time_embed_dim,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
"transformer_layers_per_block": transformer_layers_per_block,
}
if upcast_attention is not None:
deprecation_message = (
"Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`"
"is deprecated and will be ignored in future versions."
)
deprecate("image_size", "1.0.0", deprecation_message)
config["upcast_attention"] = upcast_attention
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params["disable_self_attentions"]
if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
config["num_class_embeds"] = unet_params["num_classes"]
config["out_channels"] = unet_params["out_channels"]
config["up_block_types"] = up_block_types
return config
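# Worked example of the down_block_types derivation above, using SD 1.x UNet
# hyperparameters (channel_mult=[1, 2, 4, 4], attention_resolutions=[4, 2, 1]):
# attention is attached while the running downsampling factor is listed, and the
# factor doubles after every level except the last.
channel_mult = [1, 2, 4, 4]
attention_resolutions = [4, 2, 1]
down_block_types, resolution = [], 1
for i in range(len(channel_mult)):
    down_block_types.append(
        "CrossAttnDownBlock2D" if resolution in attention_resolutions else "DownBlock2D"
    )
    if i != len(channel_mult) - 1:
        resolution *= 2
print(down_block_types)
# ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']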
def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs):
if image_size is not None:
deprecation_message = (
"Configuring ControlNetModel with the `image_size` argument"
"is deprecated and will be ignored in future versions."
)
deprecate("image_size", "1.0.0", deprecation_message)
image_size = set_image_size(checkpoint, image_size=image_size)
unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, checkpoint, image_size=image_size)
controlnet_config = {
"conditioning_channels": unet_params["hint_channels"],
"in_channels": diffusers_unet_config["in_channels"],
"down_block_types": diffusers_unet_config["down_block_types"],
"block_out_channels": diffusers_unet_config["block_out_channels"],
"layers_per_block": diffusers_unet_config["layers_per_block"],
"cross_attention_dim": diffusers_unet_config["cross_attention_dim"],
"attention_head_dim": diffusers_unet_config["attention_head_dim"],
"use_linear_projection": diffusers_unet_config["use_linear_projection"],
"class_embed_type": diffusers_unet_config["class_embed_type"],
"addition_embed_type": diffusers_unet_config["addition_embed_type"],
"addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"],
"projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"],
"transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"],
}
return controlnet_config
def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None):
"""
Create a Diffusers config from the LDM model's configuration.
"""
if image_size is not None:
deprecation_message = (
"Configuring AutoencoderKL with the `image_size` argument"
"is deprecated and will be ignored in future versions."
)
deprecate("image_size", "1.0.0", deprecation_message)
image_size = set_image_size(checkpoint, image_size=image_size)
if "edm_mean" in checkpoint and "edm_std" in checkpoint:
latents_mean = checkpoint["edm_mean"]
latents_std = checkpoint["edm_std"]
else:
latents_mean = None
latents_std = None
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
scaling_factor = original_config["model"]["params"]["scale_factor"]
elif scaling_factor is None:
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = {
"sample_size": image_size,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": down_block_types,
"up_block_types": up_block_types,
"block_out_channels": block_out_channels,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
"scaling_factor": scaling_factor,
}
if latents_mean is not None and latents_std is not None:
config.update({"latents_mean": latents_mean, "latents_std": latents_std})
return config
def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None):
for ldm_key in ldm_keys:
diffusers_key = (
ldm_key.replace("in_layers.0", "norm1")
.replace("in_layers.2", "conv1")
.replace("out_layers.0", "norm2")
.replace("out_layers.3", "conv2")
.replace("emb_layers.1", "time_emb_proj")
.replace("skip_connection", "conv_shortcut")
)
if mapping:
diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"])
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
for ldm_key in ldm_keys:
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"])
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = (
ldm_key.replace(mapping["old"], mapping["new"])
.replace("norm.weight", "group_norm.weight")
.replace("norm.bias", "group_norm.bias")
.replace("q.weight", "to_q.weight")
.replace("q.bias", "to_q.bias")
.replace("k.weight", "to_k.weight")
.replace("k.bias", "to_k.bias")
.replace("v.weight", "to_v.weight")
.replace("v.bias", "to_v.bias")
.replace("proj_out.weight", "to_out.0.weight")
.replace("proj_out.bias", "to_out.0.bias")
)
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
shape = new_checkpoint[diffusers_key].shape
if len(shape) == 3:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
elif len(shape) == 4:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs):
is_stage_c = "clip_txt_mapper.weight" in checkpoint
if is_stage_c:
state_dict = {}
for key in checkpoint.keys():
if key.endswith("in_proj_weight"):
weights = checkpoint[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = checkpoint[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = checkpoint[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = checkpoint[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = checkpoint[key]
else:
state_dict = {}
for key in checkpoint.keys():
if key.endswith("in_proj_weight"):
weights = checkpoint[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = checkpoint[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = checkpoint[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = checkpoint[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
elif key.endswith("clip_mapper.weight"):
weights = checkpoint[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = checkpoint[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
state_dict[key] = checkpoint[key]
return state_dict
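# Sketch of the fused-attention split performed above: torch.nn.MultiheadAttention
# stores q/k/v as one stacked in_proj tensor, and chunk(3, 0) recovers the three
# diffusers-style projection matrices. Dimensions are illustrative.
fused = torch.randn(3 * 1024, 1024)  # stacked [q; k; v] rows
q_w, k_w, v_w = fused.chunk(3, 0)
print(q_w.shape, k_w.shape, v_w.shape)  # three torch.Size([1024, 1024])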
def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
"""
Take a state dict and a config, and return a converted checkpoint.
"""
unet_state_dict = {}
keys = list(checkpoint.keys())
unet_key = LDM_UNET_KEY
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
logger.warning("Checkpoint has both EMA and non-EMA weights.")
logger.warning(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
logger.warning(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key)
new_checkpoint = {}
ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
for diffusers_key, ldm_key in ldm_unet_keys.items():
if ldm_key not in unet_state_dict:
continue
new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"]
for diffusers_key, ldm_key in class_embed_keys.items():
new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"):
addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"]
for diffusers_key, ldm_key in addition_embed_keys.items():
new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
if "num_class_embeds" in config:
if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
update_unet_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
unet_state_dict,
{"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
)
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get(
f"input_blocks.{i}.0.op.bias"
)
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if attentions:
update_unet_attention_ldm_to_diffusers(
attentions,
new_checkpoint,
unet_state_dict,
{"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
)
for key in middle_blocks.keys():
diffusers_key = max(key - 1, 0)
if key % 2 == 0:
update_unet_resnet_ldm_to_diffusers(
middle_blocks[key],
new_checkpoint,
unet_state_dict,
mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
)
else:
update_unet_attention_ldm_to_diffusers(
middle_blocks[key],
new_checkpoint,
unet_state_dict,
mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
resnets = [
key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key
]
update_unet_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
unet_state_dict,
{"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"},
)
attentions = [
key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key
]
if attentions:
update_unet_attention_ldm_to_diffusers(
attentions,
new_checkpoint,
unet_state_dict,
{"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"},
)
if f"output_blocks.{i}.1.conv.weight" in unet_state_dict:
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.1.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.1.conv.bias"
]
if f"output_blocks.{i}.2.conv.weight" in unet_state_dict:
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.2.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.2.conv.bias"
]
return new_checkpoint
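# Worked example of the input-block index arithmetic above with
# layers_per_block=2 (the SD 1.x value): each down block owns layers_per_block
# resnet slots plus one downsampler slot, so LDM indices 1..11 fold into
# (down_block, layer) pairs; layer index 2 is the downsampler ("op") slot.
layers_per_block = 2
for i in range(1, 12):
    block_id = (i - 1) // (layers_per_block + 1)
    layer_in_block_id = (i - 1) % (layers_per_block + 1)
    print(i, "->", (block_id, layer_in_block_id))
# 1..3 -> block 0, 4..6 -> block 1, 7..9 -> block 2, 10..11 -> block 3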
def convert_controlnet_checkpoint(
checkpoint,
config,
**kwargs,
):
if "time_embed.0.weight" in checkpoint:
controlnet_state_dict = checkpoint
else:
controlnet_state_dict = {}
keys = list(checkpoint.keys())
controlnet_key = LDM_CONTROLNET_KEY
for key in keys:
if key.startswith(controlnet_key):
controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)
new_checkpoint = {}
ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
for diffusers_key, ldm_key in ldm_controlnet_keys.items():
if ldm_key not in controlnet_state_dict:
continue
new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key]
num_input_blocks = len(
{".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer}
)
input_blocks = {
layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
update_unet_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
controlnet_state_dict,
{"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
)
if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
f"input_blocks.{i}.0.op.bias"
)
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if attentions:
update_unet_attention_ldm_to_diffusers(
attentions,
new_checkpoint,
controlnet_state_dict,
{"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
)
for i in range(num_input_blocks):
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")
num_middle_blocks = len(
{".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer}
)
middle_blocks = {
layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
for key in middle_blocks.keys():
diffusers_key = max(key - 1, 0)
if key % 2 == 0:
update_unet_resnet_ldm_to_diffusers(
middle_blocks[key],
new_checkpoint,
controlnet_state_dict,
mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
)
else:
update_unet_attention_ldm_to_diffusers(
middle_blocks[key],
new_checkpoint,
controlnet_state_dict,
mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
)
new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")
cond_embedding_blocks = {
".".join(layer.split(".")[:2])
for layer in controlnet_state_dict
if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
}
num_cond_embedding_blocks = len(cond_embedding_blocks)
for idx in range(1, num_cond_embedding_blocks + 1):
diffusers_idx = idx - 1
cond_block_id = 2 * idx
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
f"input_hint_block.{cond_block_id}.weight"
)
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
f"input_hint_block.{cond_block_id}.bias"
)
return new_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
vae_state_dict = {}
keys = list(checkpoint.keys())
vae_key = ""
for ldm_vae_key in LDM_VAE_KEYS:
if any(k.startswith(ldm_vae_key) for k in keys):
vae_key = ldm_vae_key
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
new_checkpoint = {}
vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"]
for diffusers_key, ldm_key in vae_diffusers_ldm_map.items():
if ldm_key not in vae_state_dict:
continue
new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
num_down_blocks = len(config["down_block_types"])
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
)
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
f"encoder.down.{i}.downsample.conv.bias"
)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
update_vae_attentions_ldm_to_diffusers(
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
)
num_up_blocks = len(config["up_block_types"])
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
)
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
update_vae_attentions_ldm_to_diffusers(
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
keys = list(checkpoint.keys())
text_model_dict = {}
remove_prefixes = []
remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
if remove_prefix:
remove_prefixes.append(remove_prefix)
for key in keys:
for prefix in remove_prefixes:
if key.startswith(prefix):
diffusers_key = key.replace(prefix, "")
text_model_dict[diffusers_key] = checkpoint.get(key)
return text_model_dict
def convert_open_clip_checkpoint(
text_model,
checkpoint,
prefix="cond_stage_model.model.",
):
text_model_dict = {}
text_proj_key = prefix + "text_projection"
if text_proj_key in checkpoint:
text_proj_dim = int(checkpoint[text_proj_key].shape[0])
elif hasattr(text_model.config, "projection_dim"):
text_proj_dim = text_model.config.projection_dim
else:
text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
keys = list(checkpoint.keys())
keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"]
for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items():
ldm_key = prefix + ldm_key
if ldm_key not in checkpoint:
continue
if ldm_key in keys_to_ignore:
continue
if ldm_key.endswith("text_projection"):
text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous()
else:
text_model_dict[diffusers_key] = checkpoint[ldm_key]
for key in keys:
if key in keys_to_ignore:
continue
if not key.startswith(prefix + "transformer."):
continue
diffusers_key = key.replace(prefix + "transformer.", "")
transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"]
for new_key, old_key in transformer_diffusers_to_ldm_map.items():
diffusers_key = (
diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "")
)
if key.endswith(".in_proj_weight"):
weight_value = checkpoint.get(key)
text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach()
text_model_dict[diffusers_key + ".k_proj.weight"] = (
weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach()
)
text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach()
elif key.endswith(".in_proj_bias"):
weight_value = checkpoint.get(key)
text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach()
text_model_dict[diffusers_key + ".k_proj.bias"] = (
weight_value[text_proj_dim : text_proj_dim * 2].clone().detach()
)
text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach()
else:
text_model_dict[diffusers_key] = checkpoint.get(key)
return text_model_dict
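# Sketch of the in_proj slicing above: for OpenCLIP the fused projection is cut
# into q/k/v by rows of size text_proj_dim instead of chunk(3, 0); both give the
# same result when the fused tensor has exactly 3 * text_proj_dim rows.
text_proj_dim = 1024
in_proj_weight = torch.randn(3 * text_proj_dim, text_proj_dim)
q = in_proj_weight[:text_proj_dim, :]
k = in_proj_weight[text_proj_dim : text_proj_dim * 2, :]
v = in_proj_weight[text_proj_dim * 2 :, :]
print(q.shape == k.shape == v.shape)  # True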
def create_diffusers_clip_model_from_ldm(
cls,
checkpoint,
subfolder="",
config=None,
torch_dtype=None,
local_files_only=None,
is_legacy_loading=False,
):
if config:
config = {"pretrained_model_name_or_path": config}
else:
config = fetch_diffusers_config(checkpoint)
if is_legacy_loading:
logger.warning(
(
"Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update "
"the local cache directory with the necessary CLIP model config files. "
"Attempting to load CLIP model from legacy cache directory."
)
)
if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
clip_config = "openai/clip-vit-large-patch14"
config["pretrained_model_name_or_path"] = clip_config
subfolder = ""
elif is_open_clip_model(checkpoint):
clip_config = "stabilityai/stable-diffusion-2"
config["pretrained_model_name_or_path"] = clip_config
subfolder = "text_encoder"
else:
clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config["pretrained_model_name_or_path"] = clip_config
subfolder = ""
model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls(model_config)
position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1]
if is_clip_model(checkpoint):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
elif (
is_clip_sdxl_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
elif (
is_clip_sd3_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
elif is_open_clip_model(checkpoint):
prefix = "cond_stage_model.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
elif (
is_open_clip_sdxl_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim
):
prefix = "conditioner.embedders.1.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
elif is_open_clip_sdxl_refiner_model(checkpoint):
prefix = "conditioner.embedders.0.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
elif (
is_open_clip_sd3_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
else:
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
if torch_dtype is not None:
model.to(torch_dtype)
model.eval()
return model
def _legacy_load_scheduler(
cls,
checkpoint,
component_name,
original_config=None,
**kwargs,
):
scheduler_type = kwargs.get("scheduler_type", None)
prediction_type = kwargs.get("prediction_type", None)
if scheduler_type is not None:
deprecation_message = (
"Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
"Example:\n\n"
"from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
"scheduler = DDIMScheduler()\n"
"pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
)
deprecate("scheduler_type", "1.0.0", deprecation_message)
if prediction_type is not None:
deprecation_message = (
"Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
"pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
"Example:\n\n"
"from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
"pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
)
deprecate("prediction_type", "1.0.0", deprecation_message)
scheduler_config = SCHEDULER_DEFAULT_CONFIG
model_type = infer_diffusers_model_type(checkpoint=checkpoint)
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
if original_config:
num_train_timesteps = original_config["model"]["params"].get("timesteps", 1000)
else:
num_train_timesteps = 1000
scheduler_config["num_train_timesteps"] = num_train_timesteps
if model_type == "v2":
if prediction_type is None:
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
else:
prediction_type = prediction_type or "epsilon"
scheduler_config["prediction_type"] = prediction_type
if model_type in ["xl_base", "xl_refiner"]:
scheduler_type = "euler"
elif model_type == "playground":
scheduler_type = "edm_dpm_solver_multistep"
else:
if original_config:
beta_start = original_config["model"]["params"].get("linear_start")
beta_end = original_config["model"]["params"].get("linear_end")
else:
beta_start = 0.00085
beta_end = 0.012
scheduler_config["beta_start"] = beta_start
scheduler_config["beta_end"] = beta_end
scheduler_config["beta_schedule"] = "scaled_linear"
scheduler_config["clip_sample"] = False
scheduler_config["set_alpha_to_one"] = False
if component_name == "low_res_scheduler":
return cls.from_config(
{
"beta_end": 0.02,
"beta_schedule": "scaled_linear",
"beta_start": 0.0001,
"clip_sample": True,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"trained_betas": None,
"variance_type": "fixed_small",
}
)
if scheduler_type is None:
return cls.from_config(scheduler_config)
elif scheduler_type == "pndm":
scheduler_config["skip_prk_steps"] = True
scheduler = PNDMScheduler.from_config(scheduler_config)
elif scheduler_type == "lms":
scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
elif scheduler_type == "heun":
scheduler = HeunDiscreteScheduler.from_config(scheduler_config)
elif scheduler_type == "euler":
scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
elif scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
elif scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
elif scheduler_type == "ddim":
scheduler = DDIMScheduler.from_config(scheduler_config)
elif scheduler_type == "edm_dpm_solver_multistep":
scheduler_config = {
"algorithm_type": "dpmsolver++",
"dynamic_thresholding_ratio": 0.995,
"euler_at_final": False,
"final_sigmas_type": "zero",
"lower_order_final": True,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"rho": 7.0,
"sample_max_value": 1.0,
"sigma_data": 0.5,
"sigma_max": 80.0,
"sigma_min": 0.002,
"solver_order": 2,
"solver_type": "midpoint",
"thresholding": False,
}
scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
return scheduler
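The deprecation messages above describe the supported replacement: construct the scheduler yourself and pass it to `from_single_file`. A minimal sketch (the checkpoint filename and scheduler settings are illustrative, not taken from this source):
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
pipe = StableDiffusionPipeline.from_single_file("v1-5-pruned-emaonly.safetensors", scheduler=scheduler)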
def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False):
# The tokenizer repo is chosen from the checkpoint contents below, overriding any
# `pretrained_model_name_or_path` that `config` or `fetch_diffusers_config` supplied.
if config:
config = {"pretrained_model_name_or_path": config}
else:
config = fetch_diffusers_config(checkpoint)
if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
clip_config = "openai/clip-vit-large-patch14"
config["pretrained_model_name_or_path"] = clip_config
subfolder = ""
elif is_open_clip_model(checkpoint):
clip_config = "stabilityai/stable-diffusion-2"
config["pretrained_model_name_or_path"] = clip_config
subfolder = "tokenizer"
else:
clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config["pretrained_model_name_or_path"] = clip_config
subfolder = ""
tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
return tokenizer
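In short: CLIP ViT-L checkpoints load openai/clip-vit-large-patch14, OpenCLIP checkpoints load the tokenizer subfolder of stabilityai/stable-diffusion-2, and anything else falls through to the bigG tokenizer. A usage sketch (the checkpoint path is hypothetical; `CLIPTokenizer` is the usual `cls`):
from transformers import CLIPTokenizer
from safetensors.torch import load_file

ckpt = load_file("v1-5-pruned-emaonly.safetensors")  # hypothetical local file
tokenizer = _legacy_load_clip_tokenizer(CLIPTokenizer, ckpt, local_files_only=False)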
def _legacy_load_safety_checker(local_files_only, torch_dtype):
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
feature_extractor = AutoImageProcessor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
)
return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
def swap_scale_shift(weight, dim):
# Original checkpoints store AdaLN parameters as (shift, scale) concatenated along
# dim 0; diffusers expects (scale, shift). The `dim` argument is unused here.
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
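A tiny self-check of the swap: a vector whose first half is the shift (zeros) and second half the scale (ones) comes back with the halves exchanged:
import torch

w = torch.cat([torch.zeros(3), torch.ones(3)])  # (shift, scale)
assert torch.equal(swap_scale_shift(w, dim=0), torch.cat([torch.ones(3), torch.zeros(3)]))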
def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
# Strip the "model.diffusion_model." prefix used by original single-file checkpoints.
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
# Number of MMDiT joint blocks, inferred from the highest block index present
# (`max(...)` would be a more robust way to pick it than `list(set(...))[-1]`).
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1
caption_projection_dim = 1536
converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
)
return converted_state_dict
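Typical driving code, as a sketch (the filename is hypothetical; `load_file` comes from safetensors):
from safetensors.torch import load_file

ckpt = load_file("sd3_medium.safetensors")
sd = convert_sd3_transformer_checkpoint_to_diffusers(ckpt)
# e.g. "t_embedder.mlp.0.weight" now lives at "time_text_embed.timestep_embedder.linear_1.weight"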
def is_t5_in_single_file(checkpoint):
return "text_encoders.t5xxl.transformer.shared.weight" in checkpoint
def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
keys = list(checkpoint.keys())
text_model_dict = {}
remove_prefixes = ["text_encoders.t5xxl.transformer."]
for key in keys:
for prefix in remove_prefixes:
if key.startswith(prefix):
diffusers_key = key.replace(prefix, "")
text_model_dict[diffusers_key] = checkpoint.get(key)
return text_model_dict
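The conversion is a pure prefix strip; keys outside the T5 namespace are simply dropped:
ckpt = {"text_encoders.t5xxl.transformer.shared.weight": "W", "unrelated.key": "X"}
assert convert_sd3_t5_checkpoint_to_diffusers(ckpt) == {"shared.weight": "W"}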
def create_diffusers_t5_model_from_checkpoint(
cls,
checkpoint,
subfolder="",
config=None,
torch_dtype=None,
local_files_only=None,
):
if config:
config = {"pretrained_model_name_or_path": config}
else:
config = fetch_diffusers_config(checkpoint)
model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
# Instantiate on the meta device (no memory allocated) when accelerate is available.
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls(model_config)
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
model.load_state_dict(diffusers_format_checkpoint)
# Modules listed in `_keep_in_fp32_modules` are kept in fp32 when loading in fp16.
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = model._keep_in_fp32_modules
else:
keep_in_fp32_modules = []
if keep_in_fp32_modules:
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
param.data = param.data.to(torch.float32)
return model
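The `init_empty_weights` context used above is accelerate's standard trick: parameters are created on the `meta` device so instantiation allocates no memory, and the real tensors arrive later via `load_model_dict_into_meta`. The pattern in isolation (a plain `Linear` stands in for the real model):
import torch
from accelerate import init_empty_weights

with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)  # parameters live on the "meta" device, no RAM used
assert layer.weight.device.type == "meta"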
def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
for k, v in checkpoint.items():
# Positional-encoding buffers are recreated by the diffusers motion modules, so drop them.
if "pos_encoder" in k:
continue
converted_state_dict[
k.replace(".norms.0", ".norm1")
.replace(".norms.1", ".norm2")
.replace(".ff_norm", ".norm3")
.replace(".attention_blocks.0", ".attn1")
.replace(".attention_blocks.1", ".attn2")
.replace(".temporal_transformer", "")
] = v
return converted_state_dict
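One key pushed through the rename chain, for concreteness:
ckpt = {"down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight": 0}
out = convert_animatediff_checkpoint_to_diffusers(ckpt)
assert list(out) == ["down_blocks.0.motion_modules.0.transformer_blocks.0.attn1.to_q.weight"]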
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
# Strip the "model.diffusion_model." prefix used by original single-file checkpoints.
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
# Block counts are inferred from the highest block index present in the checkpoint.
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1
mlp_ratio = 4.0
inner_dim = 3072
def swap_scale_shift(weight):
# Local variant of the module-level helper above, without the unused `dim` argument.
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
"time_in.in_layer.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
"time_in.out_layer.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
"vector_in.out_layer.weight"
)
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
# Guidance-distilled variants (e.g. FLUX.1-dev) carry a guidance embedder; map it when present.
has_guidance = any("guidance" in k for k in checkpoint)
if has_guidance:
converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
"guidance_in.in_layer.weight"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
"guidance_in.in_layer.bias"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
"guidance_in.out_layer.weight"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
"guidance_in.out_layer.bias"
)
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
f"single_blocks.{i}.modulation.lin.weight"
)
converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
f"single_blocks.{i}.modulation.lin.bias"
)
mlp_hidden_dim = int(inner_dim * mlp_ratio)
# `linear1` fuses the q, k, v projections and the MLP input projection along dim 0.
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
q_bias, k_bias, v_bias, mlp_bias = torch.split(
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.key_norm.scale"
)
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.weight")
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.bias")
)
return converted_state_dict
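The only non-trivial reshape in the single-block loop is that `linear1` split: q, k, v and the MLP input projection are fused along dim 0. The same split in isolation, with toy sizes in place of the real 3072/12288:
import torch

inner_dim, mlp_hidden = 8, 32
fused = torch.randn(3 * inner_dim + mlp_hidden, inner_dim)
q, k, v, mlp = torch.split(fused, (inner_dim, inner_dim, inner_dim, mlp_hidden), dim=0)
assert q.shape == (inner_dim, inner_dim) and mlp.shape == (mlp_hidden, inner_dim)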