diffusers 源码解析(九)
.\diffusers\models\embeddings_flax.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# 该文件的使用需要遵循 Apache 2.0 许可证
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何明示或暗示的担保或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# 查看许可证以了解特定权限和限制
import math # 导入数学库以进行数学运算
import flax.linen as nn # 导入Flax库中的神经网络模块
import jax.numpy as jnp # 导入JAX的numpy模块以进行数值计算
def get_sinusoidal_embeddings(
timesteps: jnp.ndarray, # 定义输入参数 timesteps 为一维 JAX 数组
embedding_dim: int, # 定义输出嵌入的维度
freq_shift: float = 1, # 频率偏移的默认值为1
min_timescale: float = 1, # 最小时间尺度的默认值
max_timescale: float = 1.0e4, # 最大时间尺度的默认值
flip_sin_to_cos: bool = False, # 是否翻转正弦和余弦
scale: float = 1.0, # 缩放因子的默认值
) -> jnp.ndarray: # 函数返回一个 JAX 数组
"""Returns the positional encoding (same as Tensor2Tensor).
返回位置编码,类似于Tensor2Tensor
Args:
timesteps: a 1-D Tensor of N indices, one per batch element.
输入为一维张量,N个索引,每个批次元素一个
These may be fractional.
embedding_dim: The number of output channels.
嵌入的通道数
min_timescale: The smallest time unit (should probably be 0.0).
最小时间单位
max_timescale: The largest time unit.
最大时间单位
Returns:
a Tensor of timing signals [N, num_channels]
返回时间信号的张量 [N, num_channels]
"""
assert timesteps.ndim == 1, "Timesteps should be a 1d-array" # 检查 timesteps 是否为一维数组
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" # 检查嵌入维度是否为偶数
num_timescales = float(embedding_dim // 2) # 计算时间尺度的数量
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) # 计算对数时间尺度增量
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) # 计算反时间尺度
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) # 计算嵌入
# scale embeddings
scaled_time = scale * emb # 对嵌入进行缩放
if flip_sin_to_cos: # 如果需要翻转正弦和余弦
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) # 拼接余弦和正弦信号
else: # 否则
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) # 拼接正弦和余弦信号
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) # 重塑信号的形状
return signal # 返回信号
class FlaxTimestepEmbedding(nn.Module): # 定义时间步嵌入模块
r"""
Time step Embedding Module. Learns embeddings for input time steps.
时间步嵌入模块。学习输入时间步的嵌入
Args:
time_embed_dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
时间步嵌入维度
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
参数的数据类型
"""
time_embed_dim: int = 32 # 设置时间嵌入维度的默认值为32
dtype: jnp.dtype = jnp.float32 # 设置参数的数据类型的默认值为jnp.float32
@nn.compact # 指示该方法为紧凑的神经网络模块
def __call__(self, temb): # 定义模块的调用方法,接收输入参数 temb
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) # 第一个全连接层
temb = nn.silu(temb) # 应用Silu激活函数
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) # 第二个全连接层
return temb # 返回处理后的temb
class FlaxTimesteps(nn.Module): # 定义时间步模块
r"""
# 包装类,用于生成正弦时间步嵌入,详细说明见 https://arxiv.org/abs/2006.11239
# 参数:
# dim (`int`, *可选*, 默认为 `32`):
# 时间步嵌入的维度
dim: int = 32 # 定义时间步嵌入的维度,默认值为 32
flip_sin_to_cos: bool = False # 定义是否将正弦值转换为余弦值,默认为 False
freq_shift: float = 1 # 定义频率偏移量,默认为 1
@nn.compact # 表示这是一个紧凑模式的神经网络层,适合 JAX 使用
def __call__(self, timesteps): # 定义调用方法,接受时间步作为输入
return get_sinusoidal_embeddings( # 调用函数生成正弦嵌入
timesteps, # 输入的时间步
embedding_dim=self.dim, # 嵌入维度设置为实例属性 dim
flip_sin_to_cos=self.flip_sin_to_cos, # 设置是否翻转正弦到余弦
freq_shift=self.freq_shift # 设置频率偏移量
) # 返回生成的正弦嵌入
.\diffusers\models\lora.py
# 版权信息,指明文件的版权所有者和保留权利
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵循许可证,否则不得使用本文件。
# 可在以下网址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,
# 否则根据许可证分发的软件在“按原样”基础上提供,
# 不提供任何形式的明示或暗示的担保或条件。
# 请参阅许可证,以获取有关权限和
# 限制的具体条款。
# 重要提示: #
###################################################################
# ----------------------------------------------------------------#
# 此文件已被弃用,将很快删除 #
# (一旦 PEFT 成为 LoRA 的必需依赖项) #
# ----------------------------------------------------------------#
###################################################################
from typing import Optional, Tuple, Union # 导入可选类型、元组和联合类型以用于类型注解
import torch # 导入 PyTorch 库
import torch.nn.functional as F # 导入 PyTorch 的功能性神经网络模块
from torch import nn # 从 PyTorch 导入神经网络模块
from ..utils import deprecate, logging # 从上级目录导入工具函数 deprecate 和 logging
from ..utils.import_utils import is_transformers_available # 导入检查 transformers 库可用性的函数
# 如果 transformers 库可用,则导入相关模型
if is_transformers_available():
from transformers import CLIPTextModel, CLIPTextModelWithProjection # 导入 CLIP 文本模型及其变体
logger = logging.get_logger(__name__) # 创建一个记录器实例,用于日志记录,禁用 pylint 的名称检查
def text_encoder_attn_modules(text_encoder):
attn_modules = [] # 初始化一个空列表,用于存储注意力模块
# 检查文本编码器是否为 CLIPTextModel 或 CLIPTextModelWithProjection 的实例
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
# 遍历编码器层,收集每一层的自注意力模块
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn" # 构造注意力模块的名称
mod = layer.self_attn # 获取当前层的自注意力模块
attn_modules.append((name, mod)) # 将名称和模块元组添加到列表中
else:
# 如果文本编码器不是预期的类型,抛出值错误
raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
return attn_modules # 返回注意力模块的列表
def text_encoder_mlp_modules(text_encoder):
mlp_modules = [] # 初始化一个空列表,用于存储 MLP 模块
# 检查文本编码器是否为 CLIPTextModel 或 CLIPTextModelWithProjection 的实例
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
# 遍历编码器层,收集每一层的 MLP 模块
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
mlp_mod = layer.mlp # 获取当前层的 MLP 模块
name = f"text_model.encoder.layers.{i}.mlp" # 构造 MLP 模块的名称
mlp_modules.append((name, mlp_mod)) # 将名称和模块元组添加到列表中
else:
# 如果文本编码器不是预期的类型,抛出值错误
raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
return mlp_modules # 返回 MLP 模块的列表
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
# 遍历文本编码器中的注意力模块
for _, attn_module in text_encoder_attn_modules(text_encoder):
# 检查当前注意力模块的查询投影是否为 PatchedLoraProjection 实例
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale # 调整查询投影的 Lora 缩放因子
attn_module.k_proj.lora_scale = lora_scale # 调整键投影的 Lora 缩放因子
attn_module.v_proj.lora_scale = lora_scale # 调整值投影的 Lora 缩放因子
attn_module.out_proj.lora_scale = lora_scale # 调整输出投影的 Lora 缩放因子
# 遍历文本编码器中的 MLP 模块
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
# 检查当前模块的 fc1 层是否为 PatchedLoraProjection 类型
if isinstance(mlp_module.fc1, PatchedLoraProjection):
# 设置 fc1 层的 lora_scale 属性
mlp_module.fc1.lora_scale = lora_scale
# 设置 fc2 层的 lora_scale 属性
mlp_module.fc2.lora_scale = lora_scale
# 定义一个名为 PatchedLoraProjection 的类,继承自 PyTorch 的 nn.Module
class PatchedLoraProjection(torch.nn.Module):
# 初始化方法,接受多个参数以设置 LoraProjection
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
# 设置弃用警告信息
deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
# 调用 deprecate 函数记录弃用信息
deprecate("PatchedLoraProjection", "1.0.0", deprecation_message)
# 调用父类的初始化方法
super().__init__()
# 从 lora 模块导入 LoRALinearLayer 类
from ..models.lora import LoRALinearLayer
# 保存传入的常规线性层
self.regular_linear_layer = regular_linear_layer
# 获取常规线性层的设备信息
device = self.regular_linear_layer.weight.device
# 如果未指定数据类型,则使用常规线性层的权重数据类型
if dtype is None:
dtype = self.regular_linear_layer.weight.dtype
# 创建 LoRALinearLayer 实例
self.lora_linear_layer = LoRALinearLayer(
self.regular_linear_layer.in_features,
self.regular_linear_layer.out_features,
network_alpha=network_alpha,
device=device,
dtype=dtype,
rank=rank,
)
# 保存 LoRA 的缩放因子
self.lora_scale = lora_scale
# 重写 PyTorch 的 state_dict 方法以确保仅保存 'regular_linear_layer' 权重
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# 如果没有 LoRA 线性层,返回常规线性层的状态字典
if self.lora_linear_layer is None:
return self.regular_linear_layer.state_dict(
*args, destination=destination, prefix=prefix, keep_vars=keep_vars
)
# 否则调用父类的 state_dict 方法
return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
# 定义一个融合 LoRA 权重的方法
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
# 如果没有 LoRA 线性层,则直接返回
if self.lora_linear_layer is None:
return
# 获取常规线性层的权重数据类型和设备
dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
# 将常规线性层的权重转换为浮点类型
w_orig = self.regular_linear_layer.weight.data.float()
# 将 LoRA 层的上权重转换为浮点类型
w_up = self.lora_linear_layer.up.weight.data.float()
# 将 LoRA 层的下权重转换为浮点类型
w_down = self.lora_linear_layer.down.weight.data.float()
# 如果 network_alpha 不为 None,则调整上权重
if self.lora_linear_layer.network_alpha is not None:
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
# 计算融合后的权重
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
# 如果安全融合并且融合权重中包含 NaN,抛出异常
if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
"LoRA weights will not be fused."
)
# 更新常规线性层的权重数据
self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
# 将 LoRA 线性层设为 None,表示已经融合
self.lora_linear_layer = None
# 将上、下权重矩阵转移到 CPU 以节省内存
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
# 更新 LoRA 的缩放因子
self.lora_scale = lora_scale
# 定义解融合 Lora 的私有方法
def _unfuse_lora(self):
# 检查 w_up 和 w_down 属性是否存在且不为 None
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
# 如果任一属性为 None,则直接返回
return
# 获取常规线性层的权重数据
fused_weight = self.regular_linear_layer.weight.data
# 保存权重的数据类型和设备信息
dtype, device = fused_weight.dtype, fused_weight.device
# 将 w_up 转换为目标设备并转为浮点类型
w_up = self.w_up.to(device=device).float()
# 将 w_down 转换为目标设备并转为浮点类型
w_down = self.w_down.to(device).float()
# 计算未融合的权重,通过从融合权重中减去 Lora 的贡献
unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
# 将未融合的权重赋值回常规线性层
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
# 清空 w_up 和 w_down 属性
self.w_up = None
self.w_down = None
# 定义前向传播方法
def forward(self, input):
# 如果 lora_scale 为 None,则设置为 1.0
if self.lora_scale is None:
self.lora_scale = 1.0
# 如果 lora_linear_layer 为 None,则直接返回常规线性层的输出
if self.lora_linear_layer is None:
return self.regular_linear_layer(input)
# 返回常规线性层的输出加上 Lora 的贡献
return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
# 定义一个用于 LoRA 的线性层,继承自 nn.Module
class LoRALinearLayer(nn.Module):
r"""
A linear layer that is used with LoRA.
Parameters:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
rank (`int`, `optional`, defaults to 4):
The rank of the LoRA layer.
network_alpha (`float`, `optional`, defaults to `None`):
The value of the network alpha used for stable learning and preventing underflow. This value has the same
meaning as the `--network_alpha` option in the kohya-ss trainer script. See
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
device (`torch.device`, `optional`, defaults to `None`):
The device to use for the layer's weights.
dtype (`torch.dtype`, `optional`, defaults to `None`):
The dtype to use for the layer's weights.
"""
# 初始化方法,定义输入输出特征和其他参数
def __init__(
self,
in_features: int, # 输入特征数量
out_features: int, # 输出特征数量
rank: int = 4, # LoRA 层的秩,默认为 4
network_alpha: Optional[float] = None, # 用于稳定学习的网络 alpha,默认为 None
device: Optional[Union[torch.device, str]] = None, # 权重使用的设备,默认为 None
dtype: Optional[torch.dtype] = None, # 权重使用的数据类型,默认为 None
):
super().__init__() # 调用父类的初始化方法
# 弃用提示消息,提醒用户切换到 PEFT 后端
deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRALinearLayer", "1.0.0", deprecation_message) # 记录弃用信息
# 定义向下线性层,不使用偏置
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
# 定义向上线性层,不使用偏置
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
# 将网络 alpha 值赋给实例变量
self.network_alpha = network_alpha
self.rank = rank # 保存秩
self.out_features = out_features # 保存输出特征数量
self.in_features = in_features # 保存输入特征数量
# 使用正态分布初始化向下权重
nn.init.normal_(self.down.weight, std=1 / rank)
# 将向上权重初始化为零
nn.init.zeros_(self.up.weight)
# 前向传播方法,接受隐藏状态并返回处理后的结果
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_dtype = hidden_states.dtype # 保存输入数据类型
dtype = self.down.weight.dtype # 获取向下层权重的数据类型
# 通过向下层处理隐藏状态
down_hidden_states = self.down(hidden_states.to(dtype))
# 通过向上层处理向下层输出
up_hidden_states = self.up(down_hidden_states)
# 如果网络 alpha 不为 None,则调整向上层输出
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank
# 返回与原始数据类型相同的输出
return up_hidden_states.to(orig_dtype)
# 定义一个用于 LoRA 的卷积层,继承自 nn.Module
class LoRAConv2dLayer(nn.Module):
r"""
A convolutional layer that is used with LoRA.
# 参数说明
Parameters:
in_features (`int`): # 输入特征的数量
Number of input features. # 输入特征的数量
out_features (`int`): # 输出特征的数量
Number of output features. # 输出特征的数量
rank (`int`, `optional`, defaults to 4): # LoRA 层的秩,默认为 4
The rank of the LoRA layer. # LoRA 层的秩
kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1): # 卷积核的大小,默认为 (1, 1)
The kernel size of the convolution. # 卷积核的大小
stride (`int` or `tuple` of two `int`, `optional`, defaults to 1): # 卷积的步幅,默认为 (1, 1)
The stride of the convolution. # 卷积的步幅
padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0): # 卷积的填充方式,默认为 0
The padding of the convolution. # 卷积的填充方式
network_alpha (`float`, `optional`, defaults to `None`): # 网络 alpha 的值,用于稳定学习,防止下溢
The value of the network alpha used for stable learning and preventing underflow. This value has the same
meaning as the `--network_alpha` option in the kohya-ss trainer script. See # 与 kohya-ss 训练脚本中的 `--network_alpha` 选项含义相同
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning # 参考链接
# 初始化方法
def __init__(
self,
in_features: int, # 输入特征数量
out_features: int, # 输出特征数量
rank: int = 4, # LoRA 层的秩,默认为 4
kernel_size: Union[int, Tuple[int, int]] = (1, 1), # 卷积核大小,默认为 (1, 1)
stride: Union[int, Tuple[int, int]] = (1, 1), # 卷积步幅,默认为 (1, 1)
padding: Union[int, Tuple[int, int], str] = 0, # 卷积填充,默认为 0
network_alpha: Optional[float] = None, # 网络 alpha 的值,默认为 None
):
super().__init__() # 调用父类的初始化方法
# 弃用警告信息,提示用户切换到 PEFT 后端
deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message) # 发出弃用警告
# 定义下卷积层,输入为 in_features,输出为 rank,使用指定的卷积参数
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
# 根据官方 kohya_ss 训练器,向上卷积层的卷积核大小始终固定
# # 参考链接: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
# 定义上卷积层,输入为 rank,输出为 out_features,使用固定的卷积核大小 (1, 1)
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
# 保存网络 alpha 值,与训练脚本中的相同含义
# 参考链接: https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha # 设置网络 alpha 值
self.rank = rank # 设置秩
# 初始化下卷积层的权重为均值为 0,标准差为 1/rank 的正态分布
nn.init.normal_(self.down.weight, std=1 / rank)
# 初始化上卷积层的权重为 0
nn.init.zeros_(self.up.weight)
# 前向传播方法
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # 定义前向传播函数
orig_dtype = hidden_states.dtype # 保存输入张量的原始数据类型
dtype = self.down.weight.dtype # 获取下卷积层权重的数据类型
# 将输入的隐状态张量通过下卷积层
down_hidden_states = self.down(hidden_states.to(dtype))
# 将下卷积层的输出通过上卷积层
up_hidden_states = self.up(down_hidden_states)
# 如果 network_alpha 不为 None,则进行缩放
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank # 根据 network_alpha 进行缩放
# 返回转换回原始数据类型的输出张量
return up_hidden_states.to(orig_dtype) # 返回最终输出
# 定义一个可以与 LoRA 兼容的卷积层,继承自 nn.Conv2d
class LoRACompatibleConv(nn.Conv2d):
"""
A convolutional layer that can be used with LoRA.
"""
# 初始化方法,接受可变数量的参数,lora_layer 为可选参数,其他参数通过 kwargs 接收
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
# 设置弃用消息,提示用户切换到 PEFT 后端
deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
# 调用弃用函数,记录此类的弃用信息
deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 将 lora_layer 赋值给实例变量
self.lora_layer = lora_layer
# 设置 lora_layer 的方法
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
# 设置弃用消息,提示用户切换到 PEFT 后端
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
# 调用弃用函数,记录此方法的弃用信息
deprecate("set_lora_layer", "1.0.0", deprecation_message)
# 将传入的 lora_layer 赋值给实例变量
self.lora_layer = lora_layer
# 融合 LoRA 权重的方法
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
# 如果 lora_layer 为 None,直接返回
if self.lora_layer is None:
return
# 获取当前权重的数据类型和设备
dtype, device = self.weight.data.dtype, self.weight.data.device
# 将权重转换为浮点型
w_orig = self.weight.data.float()
# 获取 lora_layer 的上升和下降权重,并转换为浮点型
w_up = self.lora_layer.up.weight.data.float()
w_down = self.lora_layer.down.weight.data.float()
# 如果 network_alpha 不为 None,调整上升权重
if self.lora_layer.network_alpha is not None:
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
# 进行矩阵乘法,融合上升和下降权重
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
# 将融合的结果调整为原始权重的形状
fusion = fusion.reshape((w_orig.shape))
# 计算最终融合权重
fused_weight = w_orig + (lora_scale * fusion)
# 如果安全融合为 True,检查融合权重中是否有 NaN 值
if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
"LoRA weights will not be fused."
)
# 将融合后的权重赋值回实例的权重,保持设备和数据类型
self.weight.data = fused_weight.to(device=device, dtype=dtype)
# 融合后可以删除 lora_layer
self.lora_layer = None
# 将上升和下降矩阵转移到 CPU,以减少内存占用
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
# 存储 lora_scale
self._lora_scale = lora_scale
# 解融合 LoRA 权重的方法
def _unfuse_lora(self):
# 检查 w_up 和 w_down 是否存在
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
# 获取当前融合权重
fused_weight = self.weight.data
# 获取当前权重的数据类型和设备
dtype, device = fused_weight.data.dtype, fused_weight.data.device
# 将 w_up 和 w_down 转移到正确的设备并转换为浮点型
self.w_up = self.w_up.to(device=device).float()
self.w_down = self.w_down.to(device).float()
# 进行矩阵乘法,重新计算未融合权重
fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
# 将融合结果调整为融合权重的形状
fusion = fusion.reshape((fused_weight.shape))
# 计算最终的未融合权重
unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
# 更新实例的权重
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
# 清空 w_up 和 w_down
self.w_up = None
self.w_down = None
# 定义前向传播函数,接收隐藏状态和缩放因子,返回张量
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
# 检查填充模式是否不是“零”,若是则进行相应填充
if self.padding_mode != "zeros":
# 对隐藏状态进行填充,使用反向填充参数和指定的填充模式
hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode)
# 设置填充为 (0, 0)
padding = (0, 0)
else:
# 使用类中的填充属性
padding = self.padding
# 进行二维卷积操作,返回卷积结果
original_outputs = F.conv2d(
hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups
)
# 如果 Lora 层不存在,则返回卷积结果
if self.lora_layer is None:
return original_outputs
else:
# 否则,将卷积结果与 Lora 层的结果按比例相加并返回
return original_outputs + (scale * self.lora_layer(hidden_states))
# 定义一个兼容 LoRA 的线性层,继承自 nn.Linear
class LoRACompatibleLinear(nn.Linear):
"""
A Linear layer that can be used with LoRA.
"""
# 初始化方法,接收参数并可选传入 LoRA 层
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
# 定义弃用提示信息,建议用户切换到 PEFT 后端
deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
# 调用弃用函数提示用户
deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 设置 LoRA 层
self.lora_layer = lora_layer
# 设置 LoRA 层的方法
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
# 定义弃用提示信息,建议用户切换到 PEFT 后端
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
# 调用弃用函数提示用户
deprecate("set_lora_layer", "1.0.0", deprecation_message)
# 设置 LoRA 层
self.lora_layer = lora_layer
# 融合 LoRA 权重的方法
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
# 如果没有 LoRA 层,直接返回
if self.lora_layer is None:
return
# 获取权重的数据类型和设备
dtype, device = self.weight.data.dtype, self.weight.data.device
# 将原始权重转换为浮点型
w_orig = self.weight.data.float()
# 获取 LoRA 层的上权重并转换为浮点型
w_up = self.lora_layer.up.weight.data.float()
# 获取 LoRA 层的下权重并转换为浮点型
w_down = self.lora_layer.down.weight.data.float()
# 如果网络 alpha 不为 None,则调整上权重
if self.lora_layer.network_alpha is not None:
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
# 融合权重的计算
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
# 如果进行安全融合且融合权重存在 NaN,则抛出错误
if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
"LoRA weights will not be fused."
)
# 更新当前权重为融合后的权重
self.weight.data = fused_weight.to(device=device, dtype=dtype)
# 将 LoRA 层设为 None,表示已融合
self.lora_layer = None
# 将上权重和下权重移到 CPU,防止内存溢出
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
# 保存 LoRA 融合的缩放因子
self._lora_scale = lora_scale
# 反融合 LoRA 权重的方法
def _unfuse_lora(self):
# 如果上权重和下权重不存在,直接返回
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
# 获取当前融合权重
fused_weight = self.weight.data
# 获取当前权重的数据类型和设备
dtype, device = fused_weight.dtype, fused_weight.device
# 将上权重和下权重移到对应设备并转换为浮点型
w_up = self.w_up.to(device=device).float()
w_down = self.w_down.to(device).float()
# 计算未融合的权重
unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
# 更新当前权重为未融合的权重
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
# 将上权重和下权重设为 None
self.w_up = None
self.w_down = None
# 前向传播方法
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
# 如果没有 LoRA 层,直接使用父类的前向传播
if self.lora_layer is None:
out = super().forward(hidden_states)
return out
else:
# 使用父类的前向传播加上 LoRA 层的输出
out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
return out
.\diffusers\models\modeling_flax_pytorch_utils.py
# coding=utf-8 # 指定文件编码为 UTF-8
# Copyright 2024 The HuggingFace Inc. team. # 版权信息,表明版权所有者
# Licensed under the Apache License, Version 2.0 (the "License"); # 说明该文件根据 Apache 2.0 许可证发布
# you may not use this file except in compliance with the License. # 说明只能在遵守许可证的情况下使用此文件
# You may obtain a copy of the License at # 提供获取许可证的地址
#
# http://www.apache.org/licenses/LICENSE-2.0 # 许可证的具体链接
#
# Unless required by applicable law or agreed to in writing, software # 免责声明,除非另有规定或书面同意
# distributed under the License is distributed on an "AS IS" BASIS, # 说明软件是按“现状”提供的
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # 没有任何形式的明示或暗示的保证
# See the License for the specific language governing permissions and # 指向许可证以获取具体条款
# limitations under the License. # 以及使用限制的说明
"""PyTorch - Flax general utilities.""" # 文档字符串,描述该模块的功能
import re # 导入正则表达式模块
import jax.numpy as jnp # 导入 JAX 的 NumPy 库,并重命名为 jnp
from flax.traverse_util import flatten_dict, unflatten_dict # 从 flax 导入字典扁平化和还原的工具
from jax.random import PRNGKey # 从 jax 导入伪随机数生成器的键
from ..utils import logging # 从父目录导入 logging 模块
logger = logging.get_logger(__name__) # 创建一个日志记录器,记录当前模块的信息
def rename_key(key): # 定义一个函数,用于重命名键
regex = r"\w+[.]\d+" # 定义一个正则表达式,匹配包含点号和数字的字符串
pats = re.findall(regex, key) # 使用正则表达式查找所有匹配的字符串
for pat in pats: # 遍历所有找到的匹配
key = key.replace(pat, "_".join(pat.split("."))) # 将匹配的字符串中的点替换为下划线
return key # 返回修改后的键
#####################
# PyTorch => Flax #
##################### # 注释区分 PyTorch 到 Flax 的转换部分
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 # 说明该函数的来源链接
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py # 说明该函数的另一来源链接
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): # 定义函数,重命名权重并在必要时改变张量形状
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" # 文档字符串,说明函数功能
# conv norm or layer norm # 注释,说明即将处理的内容是卷积归一化或层归一化
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) # 将原键的最后一个元素替换为 "scale"
# rename attention layers # 注释,说明将重命名注意力层
if len(pt_tuple_key) > 1: # 如果元组键的长度大于 1
for rename_from, rename_to in ( # 遍历重命名映射的元组
("to_out_0", "proj_attn"), # 旧名称到新名称的映射
("to_k", "key"), # 旧名称到新名称的映射
("to_v", "value"), # 旧名称到新名称的映射
("to_q", "query"), # 旧名称到新名称的映射
):
if pt_tuple_key[-2] == rename_from: # 如果倒数第二个元素匹配旧名称
weight_name = pt_tuple_key[-1] # 获取最后一个元素作为权重名称
weight_name = "kernel" if weight_name == "weight" else weight_name # 如果权重名称是 "weight",则改为 "kernel"
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name) # 生成新的键
if renamed_pt_tuple_key in random_flax_state_dict: # 如果新键存在于状态字典中
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape # 断言新键的形状与转置的张量形状相同
return renamed_pt_tuple_key, pt_tensor.T # 返回新的键和转置的张量
if ( # 检查是否满足以下条件
any("norm" in str_ for str_ in pt_tuple_key) # 如果键中任何部分包含 "norm"
and (pt_tuple_key[-1] == "bias") # 并且最后一个元素是 "bias"
and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) # 并且去掉最后一个元素后加 "bias" 的键不在状态字典中
and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) # 并且去掉最后一个元素后加 "scale" 的键在状态字典中
):
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) # 将键的最后一个元素替换为 "scale"
return renamed_pt_tuple_key, pt_tensor # 返回新的键和原张量
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: # 如果最后一个元素是 "weight" 或 "gamma" 并且去掉最后一个元素后加 "scale" 的键在状态字典中
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) # 将键的最后一个元素替换为 "scale"
return renamed_pt_tuple_key, pt_tensor # 返回新的键和原张量
# embedding # 注释,表明此处将处理嵌入相关的内容
# 检查元组的最后一个元素是否为 "weight",并且在字典中查找相应的 "embedding" 键
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
# 将元组的最后一个元素替换为 "embedding"
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
# 返回更新后的元组键和张量
return renamed_pt_tuple_key, pt_tensor
# 卷积层处理
# 更新元组的最后一个元素为 "kernel"
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
# 检查元组的最后一个元素是否为 "weight",并且张量的维度是否为 4
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
# 转置张量的维度顺序
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
# 返回更新后的元组键和张量
return renamed_pt_tuple_key, pt_tensor
# 线性层处理
# 更新元组的最后一个元素为 "kernel"
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
# 检查元组的最后一个元素是否为 "weight"
if pt_tuple_key[-1] == "weight":
# 转置张量
pt_tensor = pt_tensor.T
# 返回更新后的元组键和张量
return renamed_pt_tuple_key, pt_tensor
# 旧版 PyTorch 层归一化权重处理
# 更新元组的最后一个元素为 "weight"
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
# 检查元组的最后一个元素是否为 "gamma"
if pt_tuple_key[-1] == "gamma":
# 返回更新后的元组键和张量
return renamed_pt_tuple_key, pt_tensor
# 旧版 PyTorch 层归一化偏置处理
# 更新元组的最后一个元素为 "bias"
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
# 检查元组的最后一个元素是否为 "beta"
if pt_tuple_key[-1] == "beta":
# 返回更新后的元组键和张量
return renamed_pt_tuple_key, pt_tensor
# 如果没有匹配的条件,则返回原始元组键和张量
return pt_tuple_key, pt_tensor
# 将 PyTorch 的状态字典转换为 Flax 模型的参数字典
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
# 步骤 1:将 PyTorch 张量转换为 NumPy 数组
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
# 步骤 2:由于模型是无状态的,使用随机种子初始化 Flax 参数
random_flax_params = flax_model.init_weights(PRNGKey(init_key))
# 将随机生成的 Flax 参数展平为字典形式
random_flax_state_dict = flatten_dict(random_flax_params)
# 初始化一个空的 Flax 状态字典
flax_state_dict = {}
# 需要修改一些参数名称以匹配 Flax 的命名
for pt_key, pt_tensor in pt_state_dict.items():
# 重命名 PyTorch 的键
renamed_pt_key = rename_key(pt_key)
# 将重命名后的键分割成元组形式
pt_tuple_key = tuple(renamed_pt_key.split("."))
# 正确重命名权重参数并调整形状
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
# 检查 Flax 键是否在随机生成的状态字典中
if flax_key in random_flax_state_dict:
# 如果形状不匹配,抛出错误
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
raise ValueError(
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
)
# 也将意外的权重添加到字典中,以便引发警告
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
# 返回解压缩后的 Flax 状态字典
return unflatten_dict(flax_state_dict)
.\diffusers\models\modeling_flax_utils.py
# 指定文件编码为 UTF-8
# coding=utf-8
# 版权声明,表示文件由 HuggingFace Inc. 团队拥有
# Copyright 2024 The HuggingFace Inc. team.
#
# 根据 Apache 2.0 许可证许可本文件,使用时需遵循该许可证
# Licensed under the Apache License, Version 2.0 (the "License");
# 只能在遵循许可证的前提下使用此文件
# you may not use this file except in compliance with the License.
# 可以在此网址获取许可证副本
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面协议另有约定,软件按“原样”提供
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的明示或暗示的保证或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以获取特定语言管理权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入操作系统模块
import os
# 导入反序列化错误类
from pickle import UnpicklingError
# 导入类型提示所需的 Any, Dict, Union 类型
from typing import Any, Dict, Union
# 导入 JAX 库及其 NumPy 子模块
import jax
import jax.numpy as jnp
# 导入 msgpack 异常
import msgpack.exceptions
# 从 flax 库导入冻结字典及其解冻方法
from flax.core.frozen_dict import FrozenDict, unfreeze
# 从 flax 库导入字节序列化与反序列化方法
from flax.serialization import from_bytes, to_bytes
# 从 flax 库导入字典扁平化与解扁平化方法
from flax.traverse_util import flatten_dict, unflatten_dict
# 从 huggingface_hub 导入创建仓库和下载方法
from huggingface_hub import create_repo, hf_hub_download
# 导入 huggingface_hub 的一些异常类
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
# 导入请求库中的 HTTP 错误类
from requests import HTTPError
# 导入当前包的版本和 PyTorch 可用性检查
from .. import __version__, is_torch_available
# 导入工具函数和常量
from ..utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
PushToHubMixin,
logging,
)
# 从模型转换工具中导入 PyTorch 状态字典转换为 Flax 的方法
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)
# 定义 FlaxModelMixin 类,继承自 PushToHubMixin
class FlaxModelMixin(PushToHubMixin):
r"""
所有 Flax 模型的基类。
[`FlaxModelMixin`] 负责存储模型配置,并提供加载、下载和保存模型的方法。
- **config_name** ([`str`]) -- 调用 [`~FlaxModelMixin.save_pretrained`] 时保存模型的文件名。
"""
# 配置文件名常量,指定模型配置文件名
config_name = CONFIG_NAME
# 自动保存的参数列表
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
# Flax 内部参数列表
_flax_internal_args = ["name", "parent", "dtype"]
# 类方法,用于根据配置创建模型实例
@classmethod
def _from_config(cls, config, **kwargs):
"""
模型初始化所需的上下文管理器在这里定义。
"""
# 返回类的实例,传入配置和其他参数
return cls(config, **kwargs)
# 定义一个方法,将给定参数的浮点值转换为指定的数据类型
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
# 帮助方法,用于将给定 PyTree 中的浮点值转换为给定的数据类型
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# 条件转换函数,判断参数类型并执行转换
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
# 检查参数是否为浮点类型的数组
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
# 将数组转换为指定的数据类型
param = param.astype(dtype)
# 返回转换后的参数
return param
# 如果没有提供掩码,则对所有参数应用条件转换
if mask is None:
# 使用 jax.tree_map 对参数树中的每个元素应用条件转换
return jax.tree_map(conditional_cast, params)
# 扁平化参数字典以便处理
flat_params = flatten_dict(params)
# 扁平化掩码,并丢弃结构信息
flat_mask, _ = jax.tree_flatten(mask)
# 遍历掩码和参数的扁平化键
for masked, key in zip(flat_mask, flat_params.keys()):
# 如果掩码为真,则执行转换
if masked:
param = flat_params[key]
# 将转换后的参数重新存储回扁平化参数字典中
flat_params[key] = conditional_cast(param)
# 将扁平化的参数字典转换回原始结构
return unflatten_dict(flat_params)
# 定义一个方法,将参数转换为 bfloat16 类型
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
# 将浮点参数转换为 jax.numpy.bfloat16,返回新的参数树
r"""
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
the `params` in place.
This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
for params you want to cast, and `False` for those you want to skip.
Examples:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
>>> params = model.to_bf16(params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
... for path in flat_params
... }
>>> mask = traverse_util.unflatten_dict(mask)
>>> params = model.to_bf16(params, mask)
```py"""
# 调用内部方法,将参数转换为 bfloat16 类型
return self._cast_floating_to(params, jnp.bfloat16, mask)
# 将模型参数转换为浮点32位格式的方法
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
将浮点数 `params` 转换为 `jax.numpy.float32`。此方法可用于显式将模型参数转换为 fp32 精度。
返回一个新的 `params` 树,而不在原地转换 `params`。
参数:
params (`Union[Dict, FrozenDict]`):
模型参数的 `PyTree`。
mask (`Union[Dict, FrozenDict]`):
与 `params` 树具有相同结构的 `PyTree`。叶子应为布尔值。应为要转换的参数设置为 `True`,为要跳过的参数设置为 `False`。
示例:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # 从 huggingface.co 下载模型和配置
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # 默认情况下,模型参数将是 fp32,为了说明此方法的用法,
>>> # 我们将首先转换为 fp16,然后再转换回 fp32
>>> params = model.to_f16(params)
>>> # 现在转换回 fp32
>>> params = model.to_fp32(params)
```py"""
# 调用私有方法,将参数转换为浮点32格式,传入参数、目标类型和掩码
return self._cast_floating_to(params, jnp.float32, mask)
# 定义一个将浮点数参数转换为 float16 的方法,接受参数字典和可选的掩码
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
将浮点数 `params` 转换为 `jax.numpy.float16`。该方法返回一个新的 `params` 树,不会在原地转换 `params`。
此方法可在 GPU 上使用,显式地将模型参数转换为 float16 精度,以进行全半精度训练,或将权重保存为 float16 以便推理,从而节省内存并提高速度。
参数:
params (`Union[Dict, FrozenDict]`):
一个模型参数的 `PyTree`。
mask (`Union[Dict, FrozenDict]`):
具有与 `params` 树相同结构的 `PyTree`。叶子节点应为布尔值。对于要转换的参数,应为 `True`,而要跳过的参数应为 `False`。
示例:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # 加载模型
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # 默认情况下,模型参数将为 fp32,转换为 float16
>>> params = model.to_fp16(params)
>>> # 如果你不想转换某些参数(例如层归一化的偏差和尺度)
>>> # 则可以按如下方式传递掩码
>>> from flax import traverse_util
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
... for path in flat_params
... }
>>> mask = traverse_util.unflatten_dict(mask)
>>> params = model.to_fp16(params, mask)
```py"""
# 调用内部方法将参数转换为 float16 类型,传入可选的掩码
return self._cast_floating_to(params, jnp.float16, mask)
# 定义一个初始化权重的方法,接受随机数生成器作为参数,返回字典
def init_weights(self, rng: jax.Array) -> Dict:
# 抛出未实现的错误,提示此方法需要被实现
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
# 定义一个类方法用于从预训练模型加载参数,接受模型名称或路径等参数
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
dtype: jnp.dtype = jnp.float32,
*model_args,
**kwargs,
# 定义一个保存预训练模型的方法,接受保存目录和参数等
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
params: Union[Dict, FrozenDict],
is_main_process: bool = True,
push_to_hub: bool = False,
**kwargs,
):
"""
保存模型及其配置文件到指定目录,以便使用
[`~FlaxModelMixin.from_pretrained`] 类方法重新加载。
参数:
save_directory (`str` 或 `os.PathLike`):
保存模型及其配置文件的目录。如果目录不存在,将会被创建。
params (`Union[Dict, FrozenDict]`):
模型参数的 `PyTree`。
is_main_process (`bool`, *可选*, 默认为 `True`):
调用此函数的进程是否为主进程。在分布式训练中非常有用,
需要在所有进程上调用此函数。此时,仅在主进程上将 `is_main_process=True`
以避免竞争条件。
push_to_hub (`bool`, *可选*, 默认为 `False`):
保存模型后是否将其推送到 Hugging Face 模型库。可以使用 `repo_id`
指定要推送到的库(默认为 `save_directory` 中的名称)。
kwargs (`Dict[str, Any]`, *可选*):
额外的关键字参数,将传递给 [`~utils.PushToHubMixin.push_to_hub`] 方法。
"""
# 检查提供的路径是否为文件,如果是则记录错误并返回
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
# 如果目录不存在则创建该目录
os.makedirs(save_directory, exist_ok=True)
# 如果需要推送到模型库
if push_to_hub:
# 从关键字参数中弹出提交信息,如果没有则为 None
commit_message = kwargs.pop("commit_message", None)
# 从关键字参数中弹出隐私设置,默认为 False
private = kwargs.pop("private", False)
# 从关键字参数中弹出创建 PR 的设置,默认为 False
create_pr = kwargs.pop("create_pr", False)
# 从关键字参数中弹出 token,默认为 None
token = kwargs.pop("token", None)
# 从关键字参数中弹出 repo_id,默认为 save_directory 的最后一部分
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
# 创建库并获取 repo_id
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
# 将当前对象赋值给 model_to_save
model_to_save = self
# 将模型架构附加到配置中
# 保存配置
if is_main_process:
# 如果是主进程,保存模型配置到指定目录
model_to_save.save_config(save_directory)
# 保存模型的输出文件路径
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
# 以二进制写入模式打开模型文件
with open(output_model_file, "wb") as f:
# 将模型参数转换为字节
model_bytes = to_bytes(params)
# 将字节数据写入文件
f.write(model_bytes)
# 记录模型权重保存的路径信息
logger.info(f"Model weights saved in {output_model_file}")
# 如果需要推送到模型库
if push_to_hub:
# 调用上传文件夹的方法,将模型文件夹推送到模型库
self._upload_folder(
save_directory,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
.\diffusers\models\modeling_outputs.py
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从上级目录的 utils 模块导入 BaseOutput 类
from ..utils import BaseOutput
# 定义 AutoencoderKLOutput 类,继承自 BaseOutput
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
AutoencoderKL 编码方法的输出。
参数:
latent_dist (`DiagonalGaussianDistribution`):
编码器的输出,以 `DiagonalGaussianDistribution` 的均值和对数方差表示。
`DiagonalGaussianDistribution` 允许从分布中采样潜在变量。
"""
# 定义 latent_dist 属性,类型为 DiagonalGaussianDistribution
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
# 定义 Transformer2DModelOutput 类,继承自 BaseOutput
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
[`Transformer2DModel`] 的输出。
参数:
sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)` 或 `(batch size, num_vector_embeds - 1, num_latent_pixels)` 如果 [`Transformer2DModel`] 是离散的):
基于 `encoder_hidden_states` 输入的隐藏状态输出。如果是离散的,则返回无噪声潜在像素的概率分布。
"""
# 定义 sample 属性,类型为 torch.Tensor
sample: "torch.Tensor" # noqa: F821
.\diffusers\models\modeling_pytorch_flax_utils.py
# 指定文件编码为 UTF-8
# coding=utf-8
# 版权所有 2024 The HuggingFace Inc. 团队。
#
# 根据 Apache 许可证版本 2.0("许可证")许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,软件在许可证下以“原样”方式分发,
# 不提供任何形式的保证或条件,无论是明示或暗示的。
# 有关许可证下的特定权限和限制,请参见许可证。
"""PyTorch - Flax 一般实用工具。"""
# 从 pickle 模块导入 UnpicklingError 异常
from pickle import UnpicklingError
# 导入 jax 库及其 numpy 模块
import jax
import jax.numpy as jnp
# 导入 numpy 库
import numpy as np
# 从 flax.serialization 导入 from_bytes 函数
from flax.serialization import from_bytes
# 从 flax.traverse_util 导入 flatten_dict 函数
from flax.traverse_util import flatten_dict
# 从 utils 模块导入 logging
from ..utils import logging
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)
#####################
# Flax => PyTorch #
#####################
# 从指定模型文件加载 Flax 检查点到 PyTorch 模型
# 来源:https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
# 尝试打开模型文件以读取 Flax 状态
try:
with open(model_file, "rb") as flax_state_f:
# 从字节流中反序列化 Flax 状态
flax_state = from_bytes(None, flax_state_f.read())
# 捕获反序列化错误
except UnpicklingError as e:
try:
# 以文本模式打开模型文件
with open(model_file) as f:
# 检查文件内容是否以 "version" 开头
if f.read().startswith("version"):
# 如果是,抛出 OSError,提示缺少 git-lfs
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
# 否则,抛出 ValueError
raise ValueError from e
# 捕获 Unicode 解码错误和其他值错误
except (UnicodeDecodeError, ValueError):
# 抛出环境错误,提示无法转换文件
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
# 返回加载的 Flax 权重到 PyTorch 模型
return load_flax_weights_in_pytorch_model(pt_model, flax_state)
# 从 Flax 状态加载权重到 PyTorch 模型
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
"""将 Flax 检查点加载到 PyTorch 模型中"""
# 尝试导入 PyTorch
try:
import torch # noqa: F401
# 捕获导入错误
except ImportError:
# 记录错误信息,提示需要安装 PyTorch 和 Flax
logger.error(
"Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
# 抛出异常
raise
# 检查是否存在 bf16 权重
is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
# 如果存在 bf16 类型的权重
if any(is_type_bf16):
# 如果权重是 bf16 类型,转换为 fp32,因为 torch.from_numpy 无法处理 bf16
# 而且 bf16 在 PyTorch 中尚未完全支持。
logger.warning(
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
"before loading those in PyTorch model."
)
# 使用 tree_map 遍历 flax_state,将 bf16 权重转换为 float32
flax_state = jax.tree_util.tree_map(
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
)
# 将基础模型前缀设为空
pt_model.base_model_prefix = ""
# 将 flax_state 字典扁平化,使用 "." 作为分隔符
flax_state_dict = flatten_dict(flax_state, sep=".")
# 获取 PyTorch 模型的状态字典
pt_model_dict = pt_model.state_dict()
# 记录意外和缺失的键
unexpected_keys = [] # 存储意外键
missing_keys = set(pt_model_dict.keys()) # 存储缺失键的集合
# 遍历 flax_state_dict 中的每个键值对
for flax_key_tuple, flax_tensor in flax_state_dict.items():
# 将键元组转换为数组形式
flax_key_tuple_array = flax_key_tuple.split(".")
# 如果键的最后一个元素是 "kernel" 且张量维度为 4
if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
# 将最后一个元素替换为 "weight",并调整张量的维度顺序
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
# 如果键的最后一个元素是 "kernel"
elif flax_key_tuple_array[-1] == "kernel":
# 将最后一个元素替换为 "weight",并转置张量
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
flax_tensor = flax_tensor.T
# 如果键的最后一个元素是 "scale"
elif flax_key_tuple_array[-1] == "scale":
# 将最后一个元素替换为 "weight"
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
# 如果键数组中不包含 "time_embedding"
if "time_embedding" not in flax_key_tuple_array:
# 遍历键数组,替换下划线为点
for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
flax_key_tuple_array[i] = (
flax_key_tuple_string.replace("_0", ".0")
.replace("_1", ".1")
.replace("_2", ".2")
.replace("_3", ".3")
.replace("_4", ".4")
.replace("_5", ".5")
.replace("_6", ".6")
.replace("_7", ".7")
.replace("_8", ".8")
.replace("_9", ".9")
)
# 将键数组重新连接为字符串
flax_key = ".".join(flax_key_tuple_array)
# 如果当前键在 PyTorch 模型的字典中
if flax_key in pt_model_dict:
# 如果权重形状不匹配,抛出错误
if flax_tensor.shape != pt_model_dict[flax_key].shape:
raise ValueError(
f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
)
else:
# 将权重添加到 PyTorch 字典中
flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
# 从缺失键中移除当前键
missing_keys.remove(flax_key)
else:
# 权重不是 PyTorch 模型所期望的
unexpected_keys.append(flax_key)
# 将状态字典加载到 PyTorch 模型中
pt_model.load_state_dict(pt_model_dict)
# 将缺失键重新转换为列表
# 将 missing_keys 转换为列表,以便后续处理
missing_keys = list(missing_keys)
# 检查 unexpected_keys 的长度,如果大于 0,表示有未使用的权重
if len(unexpected_keys) > 0:
# 记录警告信息,提示某些权重未被使用
logger.warning(
"Some weights of the Flax model were not used when initializing the PyTorch model"
f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
" (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
" to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
" FlaxBertForSequenceClassification model)."
)
# 检查 missing_keys 的长度,如果大于 0,表示有权重未被初始化
if len(missing_keys) > 0:
# 记录警告信息,提示某些权重是新初始化的
logger.warning(
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
" use it for predictions and inference."
)
# 返回初始化后的 PyTorch 模型
return pt_model
.\diffusers\models\modeling_utils.py
# coding=utf-8 # 指定文件编码为 UTF-8
# Copyright 2024 The HuggingFace Inc. team. # HuggingFace Inc. 团队的版权声明
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # NVIDIA 的版权声明
#
# Licensed under the Apache License, Version 2.0 (the "License"); # 指定此文件使用 Apache 2.0 许可证
# you may not use this file except in compliance with the License. # 使用此文件需要遵循许可证的规定
# You may obtain a copy of the License at # 可以在以下网址获取许可证
#
# http://www.apache.org/licenses/LICENSE-2.0 # 许可证的具体链接
#
# Unless required by applicable law or agreed to in writing, software # 除非法律要求或书面同意
# distributed under the License is distributed on an "AS IS" BASIS, # 否则按 "现状" 基础分发软件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # 不提供任何形式的担保或条件
# See the License for the specific language governing permissions and # 参见许可证了解特定权限和限制
# limitations under the License. # 以及许可证下的限制
import inspect # 导入 inspect 模块,用于获取对象的信息
import itertools # 导入 itertools 模块,提供高效的迭代器
import json # 导入 json 模块,用于 JSON 数据的解析和生成
import os # 导入 os 模块,提供与操作系统交互的功能
import re # 导入 re 模块,提供正则表达式操作
from collections import OrderedDict # 从 collections 导入有序字典
from functools import partial # 从 functools 导入部分函数应用工具
from pathlib import Path # 从 pathlib 导入路径处理工具
from typing import Any, Callable, List, Optional, Tuple, Union # 导入类型注解支持
import safetensors # 导入 safetensors 库,处理安全的张量
import torch # 导入 PyTorch 库
from huggingface_hub import create_repo, split_torch_state_dict_into_shards # 从 huggingface_hub 导入相关功能
from huggingface_hub.utils import validate_hf_hub_args # 导入验证 Hugging Face Hub 参数的工具
from torch import Tensor, nn # 从 torch 导入 Tensor 和神经网络模块
from .. import __version__ # 从父级模块导入当前版本
from ..utils import ( # 从父级模块的 utils 导入多个工具
CONFIG_NAME, # 配置文件名常量
FLAX_WEIGHTS_NAME, # Flax 权重文件名常量
SAFE_WEIGHTS_INDEX_NAME, # 安全权重索引文件名常量
SAFETENSORS_WEIGHTS_NAME, # Safetensors 权重文件名常量
WEIGHTS_INDEX_NAME, # 权重索引文件名常量
WEIGHTS_NAME, # 权重文件名常量
_add_variant, # 导入添加变体的工具
_get_checkpoint_shard_files, # 导入获取检查点分片文件的工具
_get_model_file, # 导入获取模型文件的工具
deprecate, # 导入弃用标记的工具
is_accelerate_available, # 导入检测加速库可用性的工具
is_torch_version, # 导入检测 PyTorch 版本的工具
logging, # 导入日志记录工具
)
from ..utils.hub_utils import ( # 从父级模块的 hub_utils 导入多个工具
PushToHubMixin, # 导入用于推送到 Hub 的混合类
load_or_create_model_card, # 导入加载或创建模型卡的工具
populate_model_card, # 导入填充模型卡的工具
)
from .model_loading_utils import ( # 从当前包的 model_loading_utils 导入多个工具
_determine_device_map, # 导入确定设备映射的工具
_fetch_index_file, # 导入获取索引文件的工具
_load_state_dict_into_model, # 导入将状态字典加载到模型中的工具
load_model_dict_into_meta, # 导入将模型字典加载到元数据中的工具
load_state_dict, # 导入加载状态字典的工具
)
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器
_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}") # 编译正则表达式,用于匹配分片文件名
if is_torch_version(">=", "1.9.0"): # 检查当前 PyTorch 版本是否大于等于 1.9.0
_LOW_CPU_MEM_USAGE_DEFAULT = True # 设置低 CPU 内存使用默认值为 True
else: # 如果 PyTorch 版本小于 1.9.0
_LOW_CPU_MEM_USAGE_DEFAULT = False # 设置低 CPU 内存使用默认值为 False
if is_accelerate_available(): # 检查加速库是否可用
import accelerate # 如果可用,则导入 accelerate 库
def get_parameter_device(parameter: torch.nn.Module) -> torch.device: # 定义获取模型参数设备的函数
try: # 尝试执行以下代码
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) # 合并模型参数和缓冲区
return next(parameters_and_buffers).device # 返回第一个参数或缓冲区的设备
except StopIteration: # 如果没有参数和缓冲区
# For torch.nn.DataParallel compatibility in PyTorch 1.5 # 为兼容 PyTorch 1.5 的 DataParallel
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: # 定义查找张量属性的内部函数
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] # 获取模块中所有张量属性
return tuples # 返回张量属性的列表
gen = parameter._named_members(get_members_fn=find_tensor_attributes) # 获取模型的命名成员生成器
first_tuple = next(gen) # 获取生成器中的第一个元组
return first_tuple[1].device # 返回第一个张量的设备
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: # 定义获取模型参数数据类型的函数
try: # 尝试执行以下代码
params = tuple(parameter.parameters()) # 将模型参数转换为元组
if len(params) > 0: # 如果参数数量大于零
return params[0].dtype # 返回第一个参数的数据类型
buffers = tuple(parameter.buffers()) # 将缓冲区转换为元组
if len(buffers) > 0: # 如果缓冲区数量大于零
return buffers[0].dtype # 返回第一个缓冲区的数据类型
# 捕获 StopIteration 异常,处理迭代器停止的情况
except StopIteration:
# 为了兼容 PyTorch 1.5 中的 torch.nn.DataParallel
# 定义一个函数,用于查找模块中所有的张量属性,返回属性名和张量的元组列表
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
# 生成一个元组列表,包含模块中所有张量属性的名称和对应的张量
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
# 返回元组列表
return tuples
# 使用指定的函数获取模块的命名成员生成器
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
# 获取生成器中的第一个元组
first_tuple = next(gen)
# 返回第一个张量的 dtype(数据类型)
return first_tuple[1].dtype
# 定义一个模型混合类,继承自 PyTorch 的 nn.Module 和 PushToHubMixin
class ModelMixin(torch.nn.Module, PushToHubMixin):
r"""
所有模型的基类。
[`ModelMixin`] 负责存储模型配置,并提供加载、下载和保存模型的方法。
- **config_name** ([`str`]) -- 保存模型时的文件名,调用 [`~models.ModelMixin.save_pretrained`]。
"""
# 配置名称,作为模型保存时的文件名
config_name = CONFIG_NAME
# 自动保存的参数列表
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
# 是否支持梯度检查点
_supports_gradient_checkpointing = False
# 加载时忽略的意外键
_keys_to_ignore_on_load_unexpected = None
# 不分割的模块
_no_split_modules = None
# 初始化方法
def __init__(self):
# 调用父类的初始化方法
super().__init__()
# 重写 getattr 方法以优雅地弃用直接访问配置属性
def __getattr__(self, name: str) -> Any:
"""重写 `getattr` 的唯一原因是优雅地弃用直接访问配置属性。
参见 https://github.com/huggingface/diffusers/pull/3129 需要在这里重写
__getattr__,以免触发 `torch.nn.Module` 的 __getattr__:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
"""
# 检查属性是否在内部字典中,并且是否存在于内部字典的属性中
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
# 检查属性是否在当前实例的字典中
is_attribute = name in self.__dict__
# 如果属性在配置中且不在实例字典中,显示弃用警告
if is_in_config and not is_attribute:
# 构建弃用消息
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
# 调用弃用函数显示警告
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
# 返回内部字典中的属性值
return self._internal_dict[name]
# 调用 PyTorch 的原始 __getattr__ 方法
return super().__getattr__(name)
# 定义一个只读属性,检查是否启用了梯度检查点
@property
def is_gradient_checkpointing(self) -> bool:
"""
检查该模型是否启用了梯度检查点。
"""
# 遍历模型中的所有模块,检查是否有启用梯度检查点的模块
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
# 启用梯度检查点的方法
def enable_gradient_checkpointing(self) -> None:
"""
启用当前模型的梯度检查点(在其他框架中可能称为 *激活检查点* 或
*检查点激活*)。
"""
# 检查当前模型是否支持梯度检查点
if not self._supports_gradient_checkpointing:
# 如果不支持,抛出值错误
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
# 应用设置,启用梯度检查点
self.apply(partial(self._set_gradient_checkpointing, value=True))
# 禁用梯度检查点的方法
def disable_gradient_checkpointing(self) -> None:
"""
禁用当前模型的梯度检查点(在其他框架中可能称为 *激活检查点* 或
*检查点激活*)。
"""
# 检查当前模型是否支持梯度检查点
if self._supports_gradient_checkpointing:
# 应用设置,禁用梯度检查点
self.apply(partial(self._set_gradient_checkpointing, value=False))
# 定义一个设置 npu flash attention 开关的方法,接收布尔值 valid
def set_use_npu_flash_attention(self, valid: bool) -> None:
r"""
设置 npu flash attention 的开关。
"""
# 定义一个递归设置 npu flash attention 的内部方法,接收一个模块
def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
# 如果模块有设置 npu flash attention 的方法,则调用它
if hasattr(module, "set_use_npu_flash_attention"):
module.set_use_npu_flash_attention(valid)
# 递归遍历模块的所有子模块
for child in module.children():
fn_recursive_set_npu_flash_attention(child)
# 遍历当前对象的所有子模块
for module in self.children():
# 如果子模块是一个 torch.nn.Module 类型,则调用递归方法
if isinstance(module, torch.nn.Module):
fn_recursive_set_npu_flash_attention(module)
# 定义一个启用 npu flash attention 的方法
def enable_npu_flash_attention(self) -> None:
r"""
启用来自 torch_npu 的 npu flash attention。
"""
# 调用设置方法,将开关置为 True
self.set_use_npu_flash_attention(True)
# 定义一个禁用 npu flash attention 的方法
def disable_npu_flash_attention(self) -> None:
r"""
禁用来自 torch_npu 的 npu flash attention。
"""
# 调用设置方法,将开关置为 False
self.set_use_npu_flash_attention(False)
# 定义一个设置内存高效注意力的 xformers 方法,接收布尔值 valid 和可选的注意力操作
def set_use_memory_efficient_attention_xformers(
self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
# 递归遍历所有子模块。
# 任何暴露 set_use_memory_efficient_attention_xformers 方法的子模块都会接收到消息
def fn_recursive_set_mem_eff(module: torch.nn.Module):
# 如果模块有设置内存高效注意力的方法,则调用它
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
# 递归遍历模块的所有子模块
for child in module.children():
fn_recursive_set_mem_eff(child)
# 遍历当前对象的所有子模块
for module in self.children():
# 如果子模块是一个 torch.nn.Module 类型,则调用递归方法
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)
# 启用来自 xFormers 的内存高效注意力
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
# 文档字符串,描述该方法的功能和使用示例
r"""
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
inference. Speed up during training is not guaranteed.
<Tip warning={true}>
⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
precedent.
</Tip>
Parameters:
attention_op (`Callable`, *optional*):
Override the default `None` operator for use as `op` argument to the
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
function of xFormers.
Examples:
```py
>>> import torch
>>> from diffusers import UNet2DConditionModel
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
>>> model = UNet2DConditionModel.from_pretrained(
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
... )
>>> model = model.to("cuda")
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
```py
"""
# 设置使用 xFormers 的内存高效注意力,传入可选的注意力操作
self.set_use_memory_efficient_attention_xformers(True, attention_op)
# 禁用来自 xFormers 的内存高效注意力
def disable_xformers_memory_efficient_attention(self) -> None:
# 文档字符串,描述该方法的功能
r"""
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
"""
# 设置不使用 xFormers 的内存高效注意力
self.set_use_memory_efficient_attention_xformers(False)
# 保存预训练模型的方法
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Optional[Callable] = None,
safe_serialization: bool = True,
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "10GB",
push_to_hub: bool = False,
**kwargs,
@classmethod
# 类方法,加载预训练模型
@validate_hf_hub_args
@classmethod
def _load_pretrained_model(
cls,
model,
state_dict: OrderedDict,
resolved_archive_file,
pretrained_model_name_or_path: Union[str, os.PathLike],
ignore_mismatched_sizes: bool = False,
@classmethod
# 获取对象的构造函数签名参数
def _get_signature_keys(cls, obj):
# 获取构造函数的参数字典
parameters = inspect.signature(obj.__init__).parameters
# 提取必需的参数
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
# 提取可选参数
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
# 计算期望的模块,排除 'self'
expected_modules = set(required_parameters.keys()) - {"self"}
return expected_modules, optional_parameters
# 从 transformers 的 modeling_utils.py 修改而来
# 定义一个私有方法,用于获取在使用 device_map 时不应拆分的模块
def _get_no_split_modules(self, device_map: str):
"""
获取模型中在使用 device_map 时不应拆分的模块。我们遍历模块以获取底层的 `_no_split_modules`。
参数:
device_map (`str`):
设备映射值。选项包括 ["auto", "balanced", "balanced_low_0", "sequential"]
返回:
`List[str]`: 不应拆分的模块列表
"""
# 初始化一个集合,用于存储不应拆分的模块
_no_split_modules = set()
# 将当前对象添加到待检查的模块列表中
modules_to_check = [self]
# 当待检查模块列表不为空时继续循环
while len(modules_to_check) > 0:
# 从待检查列表中弹出最后一个模块
module = modules_to_check.pop(-1)
# 如果模块不在不应拆分的模块集合中,检查其子模块
if module.__class__.__name__ not in _no_split_modules:
# 如果模块是 ModelMixin 的实例
if isinstance(module, ModelMixin):
# 如果模块的 `_no_split_modules` 属性为 None,抛出异常
if module._no_split_modules is None:
raise ValueError(
f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
"class needs to implement the `_no_split_modules` attribute."
)
# 否则,将模块的不应拆分模块添加到集合中
else:
_no_split_modules = _no_split_modules | set(module._no_split_modules)
# 将当前模块的所有子模块添加到待检查列表中
modules_to_check += list(module.children())
# 返回不应拆分模块的列表
return list(_no_split_modules)
# 定义一个属性,用于获取模块所在的设备
@property
def device(self) -> torch.device:
"""
`torch.device`: 模块所在的设备(假设所有模块参数在同一设备上)。
"""
# 调用函数获取当前对象的参数设备
return get_parameter_device(self)
# 定义一个属性,用于获取模块的数据类型
@property
def dtype(self) -> torch.dtype:
"""
`torch.dtype`: 模块的数据类型(假设所有模块参数具有相同的数据类型)。
"""
# 调用函数获取当前对象的参数数据类型
return get_parameter_dtype(self)
# 定义一个方法,用于获取模块中的参数数量
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
"""
获取模块中(可训练或非嵌入)参数的数量。
参数:
only_trainable (`bool`, *可选*, 默认为 `False`):
是否仅返回可训练参数的数量。
exclude_embeddings (`bool`, *可选*, 默认为 `False`):
是否仅返回非嵌入参数的数量。
返回:
`int`: 参数的数量。
示例:
```py
from diffusers import UNet2DConditionModel
model_id = "runwayml/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
unet.num_parameters(only_trainable=True)
859520964
```py
"""
# 如果排除嵌入参数
if exclude_embeddings:
# 获取所有嵌入层的参数名
embedding_param_names = [
f"{name}.weight"
for name, module_type in self.named_modules()
if isinstance(module_type, torch.nn.Embedding)
]
# 筛选出非嵌入参数
non_embedding_parameters = [
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
]
# 返回所有非嵌入参数的数量(可训练或非可训练)
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
else:
# 返回所有参数的数量(可训练或非可训练)
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
# 定义一个方法,用于转换过时的注意力块
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
# 初始化一个列表,用于存储过时注意力块的路径
deprecated_attention_block_paths = []
# 定义一个递归函数,用于查找过时的注意力块
def recursive_find_attn_block(name, module):
# 检查当前模块是否是过时的注意力块
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
# 将找到的模块名称添加到路径列表中
deprecated_attention_block_paths.append(name)
# 遍历模块的子模块
for sub_name, sub_module in module.named_children():
# 形成完整的子模块名称
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
# 递归查找子模块
recursive_find_attn_block(sub_name, sub_module)
# 从当前对象开始递归查找过时的注意力块
recursive_find_attn_block("", self)
# 注意:需要检查过时参数是否在状态字典中
# 因为可能加载的是已经转换过的状态字典
# 遍历所有找到的过时注意力块路径
for path in deprecated_attention_block_paths:
# group_norm 路径保持不变
# 将 query 参数转换为 to_q
if f"{path}.query.weight" in state_dict:
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
if f"{path}.query.bias" in state_dict:
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
# 将 key 参数转换为 to_k
if f"{path}.key.weight" in state_dict:
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
if f"{path}.key.bias" in state_dict:
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
# 将 value 参数转换为 to_v
if f"{path}.value.weight" in state_dict:
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
if f"{path}.value.bias" in state_dict:
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
# 将 proj_attn 参数转换为 to_out.0
if f"{path}.proj_attn.weight" in state_dict:
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
if f"{path}.proj_attn.bias" in state_dict:
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
# 将当前对象的注意力模块转换为已弃用的注意力块
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
# 初始化一个列表,用于存储已弃用的注意力块模块
deprecated_attention_block_modules = []
# 定义递归函数以查找注意力块模块
def recursive_find_attn_block(module):
# 检查模块是否为已弃用的注意力块
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
# 将找到的模块添加到列表中
deprecated_attention_block_modules.append(module)
# 遍历子模块并递归调用
for sub_module in module.children():
recursive_find_attn_block(sub_module)
# 从当前对象开始递归查找
recursive_find_attn_block(self)
# 遍历所有已弃用的注意力块模块
for module in deprecated_attention_block_modules:
# 将新属性赋值给相应的旧属性
module.query = module.to_q
module.key = module.to_k
module.value = module.to_v
module.proj_attn = module.to_out[0]
# 删除旧属性以确保所有权重都加载到新属性中
del module.to_q
del module.to_k
del module.to_v
del module.to_out
# 将已弃用的注意力块模块恢复为当前对象的注意力模块
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
# 初始化一个列表,用于存储已弃用的注意力块模块
deprecated_attention_block_modules = []
# 定义递归函数以查找注意力块模块
def recursive_find_attn_block(module) -> None:
# 检查模块是否为已弃用的注意力块
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
# 将找到的模块添加到列表中
deprecated_attention_block_modules.append(module)
# 遍历子模块并递归调用
for sub_module in module.children():
recursive_find_attn_block(sub_module)
# 从当前对象开始递归查找
recursive_find_attn_block(self)
# 遍历所有已弃用的注意力块模块
for module in deprecated_attention_block_modules:
# 将旧属性赋值给相应的新属性
module.to_q = module.query
module.to_k = module.key
module.to_v = module.value
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
# 删除新属性以恢复旧的模块结构
del module.query
del module.key
del module.value
del module.proj_attn
# 定义一个继承自 ModelMixin 的类,用于处理从旧类到特定管道类的映射
class LegacyModelMixin(ModelMixin):
r"""
一个 `ModelMixin` 的子类,用于从旧类(如 `Transformer2DModel`)解析到更具体的管道类(如 `DiTTransformer2DModel`)的类映射。
"""
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
# 为了避免依赖导入问题
from .model_loading_utils import _fetch_remapped_cls_from_config
# 创建 kwargs 的副本,以避免对后续调用中的关键字参数造成影响
kwargs_copy = kwargs.copy()
# 从 kwargs 中提取 cache_dir 参数,若未提供则为 None
cache_dir = kwargs.pop("cache_dir", None)
# 从 kwargs 中提取 force_download 参数,默认为 False
force_download = kwargs.pop("force_download", False)
# 从 kwargs 中提取 proxies 参数,默认为 None
proxies = kwargs.pop("proxies", None)
# 从 kwargs 中提取 local_files_only 参数,默认为 None
local_files_only = kwargs.pop("local_files_only", None)
# 从 kwargs 中提取 token 参数,默认为 None
token = kwargs.pop("token", None)
# 从 kwargs 中提取 revision 参数,默认为 None
revision = kwargs.pop("revision", None)
# 从 kwargs 中提取 subfolder 参数,默认为 None
subfolder = kwargs.pop("subfolder", None)
# 如果未提供配置,则将配置路径设置为预训练模型名称或路径
config_path = pretrained_model_name_or_path
# 设置用户代理信息
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
# 加载配置
config, _, _ = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
return_commit_hash=True,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
**kwargs,
)
# 解析类的映射
remapped_class = _fetch_remapped_cls_from_config(config, cls)
# 返回映射后的类的 from_pretrained 方法调用
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
.\diffusers\models\model_loading_utils.py
# 指定编码为 UTF-8
# coding=utf-8
# 版权声明,表明此文件的版权归 HuggingFace Inc. 团队所有
# Copyright 2024 The HuggingFace Inc. team.
# 版权声明,表明此文件的版权归 NVIDIA CORPORATION 所有
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# 根据 Apache 许可证第 2.0 版进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 使用此文件必须遵守许可证
# you may not use this file except in compliance with the License.
# 可以在此处获取许可证的副本
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件在 "AS IS" 基础上分发
# Unless required by applicable law or agreed to in writing, software
# 不提供任何明示或暗示的担保或条件
# distributed under the License is distributed on an "AS IS" BASIS,
# 查看许可证以获取特定权限和限制的详细信息
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入标准库中的 importlib 模块
import importlib
# 导入 inspect 模块,用于检查对象
import inspect
# 导入操作系统模块
import os
# 从 collections 导入 OrderedDict,用于保持字典的顺序
from collections import OrderedDict
# 从 pathlib 导入 Path,处理文件路径
from pathlib import Path
# 导入 List、Optional 和 Union 类型提示
from typing import List, Optional, Union
# 导入 safetensors 模块
import safetensors
# 导入 PyTorch 库
import torch
# 从 huggingface_hub.utils 导入 EntryNotFoundError 异常
from huggingface_hub.utils import EntryNotFoundError
# 从 utils 模块中导入常量和函数
from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
WEIGHTS_INDEX_NAME,
_add_variant,
_get_model_file,
is_accelerate_available,
is_torch_version,
logging,
)
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)
# 定义类重映射字典,将旧类名映射到新类名
_CLASS_REMAPPING_DICT = {
"Transformer2DModel": {
"ada_norm_zero": "DiTTransformer2DModel",
"ada_norm_single": "PixArtTransformer2DModel",
}
}
# 如果可用,导入加速库的相关功能
if is_accelerate_available():
from accelerate import infer_auto_device_map
from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
# 根据模型和设备映射确定设备映射
# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
# 如果 device_map 是字符串,获取不拆分模块
if isinstance(device_map, str):
no_split_modules = model._get_no_split_modules(device_map)
device_map_kwargs = {"no_split_module_classes": no_split_modules}
# 如果 device_map 不是 "sequential",计算平衡内存
if device_map != "sequential":
max_memory = get_balanced_memory(
model,
dtype=torch_dtype,
low_zero=(device_map == "balanced_low_0"),
max_memory=max_memory,
**device_map_kwargs,
)
# 否则获取最大内存
else:
max_memory = get_max_memory(max_memory)
# 更新 device_map 参数并推断设备映射
device_map_kwargs["max_memory"] = max_memory
device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
# 返回最终的设备映射
return device_map
# 从配置中获取重映射的类
def _fetch_remapped_cls_from_config(config, old_class):
# 获取旧类的名称
previous_class_name = old_class.__name__
# 根据配置中的 norm_type 查找重映射的类名
remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)
# 详细信息:
# https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
# 如果 remapped_class_name 存在
if remapped_class_name:
# 加载 diffusers 库以导入兼容的原始调度器
diffusers_library = importlib.import_module(__name__.split(".")[0])
# 从 diffusers 库中获取 remapped_class_name 指定的类
remapped_class = getattr(diffusers_library, remapped_class_name)
# 记录日志,说明类对象正在更改,因之前的类将在未来版本中弃用
logger.info(
f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
" DOESN'T affect the final results."
)
# 返回映射后的类
return remapped_class
else:
# 如果没有 remapped_class_name,返回旧类
return old_class
# 定义一个函数,用于加载检查点文件,返回格式化的错误信息(如有)
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
"""
读取检查点文件,如果出现错误,则返回正确格式的错误信息。
"""
try:
# 获取检查点文件名的扩展名
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
# 如果文件扩展名是 SAFETENSORS_FILE_EXTENSION,则使用 safetensors 加载文件
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
else:
# 检查 PyTorch 版本,如果大于等于 1.13,则设置 weights_only 参数
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
# 加载检查点文件,并将模型权重映射到 CPU
return torch.load(
checkpoint_file,
map_location="cpu",
**weights_only_kwarg,
)
except Exception as e:
try:
# 尝试打开检查点文件
with open(checkpoint_file) as f:
# 检查文件是否以 "version" 开头,以确定是否缺少 git-lfs
if f.read().startswith("version"):
raise OSError(
"您似乎克隆了一个没有安装 git-lfs 的库。请安装 "
"git-lfs 并在克隆的文件夹中运行 `git lfs install` 以及 `git lfs pull`。"
)
else:
# 如果文件不存在,抛出 ValueError
raise ValueError(
f"无法找到加载此预训练模型所需的文件 {checkpoint_file}。请确保已正确保存模型。"
) from e
except (UnicodeDecodeError, ValueError):
# 如果读取文件时出现错误,抛出 OSError
raise OSError(
f"无法从检查点文件加载权重 '{checkpoint_file}' " f"在 '{checkpoint_file}'。"
)
# 定义一个函数,将模型状态字典加载到元数据中
def load_model_dict_into_meta(
model,
state_dict: OrderedDict,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
) -> List[str]:
# 如果未提供设备,则默认使用 CPU
device = device or torch.device("cpu")
# 如果未提供数据类型,则默认使用 float32
dtype = dtype or torch.float32
# 检查 set_module_tensor_to_device 函数是否接受 dtype 参数
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
# 初始化一个列表以存储意外的键
unexpected_keys = []
# 获取模型的空状态字典
empty_state_dict = model.state_dict()
# 遍历状态字典中的每个参数名称和对应的参数值
for param_name, param in state_dict.items():
# 如果参数名称不在空状态字典中,则记录为意外的键
if param_name not in empty_state_dict:
unexpected_keys.append(param_name)
continue # 跳过本次循环,继续下一个参数
# 检查空状态字典中对应参数的形状是否与当前参数的形状匹配
if empty_state_dict[param_name].shape != param.shape:
# 如果模型路径存在,则格式化字符串以包含模型路径
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
# 抛出值错误,提示参数形状不匹配,并给出解决方案和参考链接
raise ValueError(
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
# 如果接受数据类型,则将参数设置到模型的指定设备上,并指定数据类型
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
else:
# 如果不接受数据类型,则仅将参数设置到模型的指定设备上
set_module_tensor_to_device(model, param_name, device, value=param)
# 返回意外的键列表
return unexpected_keys
# 定义一个函数,将状态字典加载到模型中,并返回错误信息列表
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
# 如果需要,从 PyTorch 的 state_dict 转换旧格式到新格式
# 复制 state_dict,以便 _load_from_state_dict 可以对其进行修改
state_dict = state_dict.copy()
# 用于存储加载过程中的错误信息
error_msgs = []
# PyTorch 的 `_load_from_state_dict` 不会复制模块子孙中的参数
# 所以我们需要递归地应用这个函数
def load(module: torch.nn.Module, prefix: str = ""):
# 准备参数,调用模块的 `_load_from_state_dict` 方法
args = (state_dict, prefix, {}, True, [], [], error_msgs)
module._load_from_state_dict(*args)
# 遍历模块的所有子模块
for name, child in module._modules.items():
# 如果子模块存在,递归加载
if child is not None:
load(child, prefix + name + ".")
# 初始调用加载模型
load(model_to_load)
# 返回所有错误信息
return error_msgs
# 定义一个函数,获取索引文件的路径
def _fetch_index_file(
is_local,
pretrained_model_name_or_path,
subfolder,
use_safetensors,
cache_dir,
variant,
force_download,
proxies,
local_files_only,
token,
revision,
user_agent,
commit_hash,
):
# 如果是本地文件
if is_local:
# 构造索引文件的路径
index_file = Path(
pretrained_model_name_or_path,
subfolder or "", # 如果子文件夹为空,则使用空字符串
_add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
)
else:
# 构造索引文件在远程仓库中的路径
index_file_in_repo = Path(
subfolder or "", # 如果子文件夹为空,则使用空字符串
_add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
).as_posix() # 转换为 POSIX 路径格式
try:
# 获取模型文件的路径
index_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=index_file_in_repo, # 指定权重文件名
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=None, # 子文件夹为 None
user_agent=user_agent,
commit_hash=commit_hash,
)
# 将返回的路径转换为 Path 对象
index_file = Path(index_file)
except (EntryNotFoundError, EnvironmentError):
# 如果找不到文件或发生环境错误,将索引文件设置为 None
index_file = None
# 返回索引文件的路径
return index_file
.\diffusers\models\normalization.py
# 指定文件编码为 UTF-8
# copyright 信息,标识版权所有者及年份
# 许可证声明,指明使用的许可证类型及条件
# 提供许可证的获取链接
# 声明在适用情况下,软件是以“原样”方式分发的,且不提供任何形式的担保或条件
# 引用许可证中关于权限和限制的具体条款
# 导入 numbers 模块,用于处理数值相关的操作
from typing import Dict, Optional, Tuple # 导入类型提示所需的类型
# 导入 PyTorch 相关模块和功能
import torch
import torch.nn as nn # 导入神经网络模块
import torch.nn.functional as F # 导入功能性神经网络操作模块
# 导入工具函数以检查 PyTorch 版本
from ..utils import is_torch_version
# 导入激活函数获取方法
from .activations import get_activation
# 导入嵌入层相关类
from .embeddings import (
CombinedTimestepLabelEmbeddings,
PixArtAlphaCombinedTimestepSizeEmbeddings,
)
class AdaLayerNorm(nn.Module): # 定义自定义的层归一化类,继承自 nn.Module
r""" # 文档字符串,描述此类的功能和参数
Norm layer modified to incorporate timestep embeddings. # 说明此层归一化是为了支持时间步嵌入
Parameters:
embedding_dim (`int`): The size of each embedding vector. # 嵌入向量的维度
num_embeddings (`int`, *optional*): The size of the embeddings dictionary. # 嵌入字典的大小(可选)
output_dim (`int`, *optional*): # 输出维度(可选)
norm_elementwise_affine (`bool`, defaults to `False): # 是否应用元素级仿射变换(默认 False)
norm_eps (`bool`, defaults to `False`): # 归一化时的小常数(默认 1e-5)
chunk_dim (`int`, defaults to `0`): # 分块维度(默认 0)
"""
def __init__( # 初始化方法,定义类的构造函数
self,
embedding_dim: int, # 嵌入维度
num_embeddings: Optional[int] = None, # 嵌入字典的大小(可选)
output_dim: Optional[int] = None, # 输出维度(可选)
norm_elementwise_affine: bool = False, # 是否应用元素级仿射变换
norm_eps: float = 1e-5, # 归一化时的小常数
chunk_dim: int = 0, # 分块维度
):
super().__init__() # 调用父类构造函数
self.chunk_dim = chunk_dim # 保存分块维度
output_dim = output_dim or embedding_dim * 2 # 如果未指定输出维度,则计算输出维度
if num_embeddings is not None: # 如果指定了嵌入字典大小
self.emb = nn.Embedding(num_embeddings, embedding_dim) # 初始化嵌入层
else:
self.emb = None # 嵌入层为 None
self.silu = nn.SiLU() # 初始化 SiLU 激活函数
self.linear = nn.Linear(embedding_dim, output_dim) # 初始化线性层
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) # 初始化层归一化
def forward( # 定义前向传播方法
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None # 输入张量及可选时间步和嵌入
) -> torch.Tensor: # 返回类型为张量
if self.emb is not None: # 如果嵌入层存在
temb = self.emb(timestep) # 通过嵌入层计算时间步的嵌入
temb = self.linear(self.silu(temb)) # 应用激活函数并通过线性层处理嵌入
if self.chunk_dim == 1: # 如果分块维度为 1
# 对于 CogVideoX 的特殊情况,分割嵌入为偏移量和缩放量
shift, scale = temb.chunk(2, dim=1) # 按照维度 1 分块
shift = shift[:, None, :] # 扩展偏移量维度
scale = scale[:, None, :] # 扩展缩放量维度
else: # 如果分块维度不是 1
scale, shift = temb.chunk(2, dim=0) # 按照维度 0 分块
x = self.norm(x) * (1 + scale) + shift # 进行层归一化,并应用缩放和偏移
return x # 返回结果
class FP32LayerNorm(nn.LayerNorm): # 定义 FP32 层归一化类,继承自 nn.LayerNorm
# 定义前向传播方法,接受输入张量并返回输出张量
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
# 保存输入张量的数据类型
origin_dtype = inputs.dtype
# 进行层归一化处理,并将结果转换回原始数据类型
return F.layer_norm(
# 将输入张量转换为浮点型进行归一化
inputs.float(),
# 归一化的形状
self.normalized_shape,
# 如果权重存在,将其转换为浮点型;否则为 None
self.weight.float() if self.weight is not None else None,
# 如果偏置存在,将其转换为浮点型;否则为 None
self.bias.float() if self.bias is not None else None,
# 设置一个小的数值以避免除零
self.eps,
).to(origin_dtype) # 将归一化后的结果转换回原始数据类型
# 定义自适应层归一化零层的类
class AdaLayerNormZero(nn.Module):
r"""
自适应层归一化零层 (adaLN-Zero)。
参数:
embedding_dim (`int`): 每个嵌入向量的大小。
num_embeddings (`int`): 嵌入字典的大小。
"""
# 初始化方法,接收嵌入维度和可选的嵌入数量及归一化类型
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
# 调用父类初始化方法
super().__init__()
# 如果提供了嵌入数量,初始化嵌入层
if num_embeddings is not None:
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
else:
# 否则,嵌入层设置为 None
self.emb = None
# 初始化 SiLU 激活函数
self.silu = nn.SiLU()
# 初始化线性变换层,输出维度为 6 倍的嵌入维度
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
# 根据提供的归一化类型,初始化归一化层
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
# 如果提供了不支持的归一化类型,抛出错误
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
# 定义前向传播方法
def forward(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# 如果嵌入层不为 None,则计算嵌入
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
# 先经过 SiLU 激活函数再经过线性变换
emb = self.linear(self.silu(emb))
# 将嵌入切分为 6 个部分
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
# 对输入 x 应用归一化,并结合缩放和偏移
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
# 返回处理后的 x 及其他信息
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
# 定义自适应层归一化零层单一版本的类
class AdaLayerNormZeroSingle(nn.Module):
r"""
自适应层归一化零层 (adaLN-Zero)。
参数:
embedding_dim (`int`): 每个嵌入向量的大小。
num_embeddings (`int`): 嵌入字典的大小。
"""
# 初始化方法,接收嵌入维度和归一化类型
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
# 调用父类初始化方法
super().__init__()
# 初始化 SiLU 激活函数
self.silu = nn.SiLU()
# 初始化线性变换层,输出维度为 3 倍的嵌入维度
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
# 根据提供的归一化类型,初始化归一化层
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
# 如果提供了不支持的归一化类型,抛出错误
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
# 定义前向传播方法
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
# 定义一个函数的返回类型为五个张量的元组
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# 通过线性层和Silu激活函数处理嵌入向量
emb = self.linear(self.silu(emb))
# 将处理后的嵌入向量分割成三个部分:shift_msa, scale_msa 和 gate_msa
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
# 对输入x进行归一化,并结合scale和shift进行变换
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
# 返回变换后的x和gate_msa
return x, gate_msa
# 定义 LuminaRMSNormZero 类,继承自 nn.Module
class LuminaRMSNormZero(nn.Module):
"""
Norm layer adaptive RMS normalization zero.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""
# 初始化方法,设置嵌入维度、正则化参数和元素级偏置
def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
# 调用父类构造函数
super().__init__()
# 初始化 SiLU 激活函数
self.silu = nn.SiLU()
# 初始化线性变换层,输入为 embedding_dim 或 1024 中的较小值,输出为 4 倍的 embedding_dim
self.linear = nn.Linear(
min(embedding_dim, 1024),
4 * embedding_dim,
bias=True,
)
# 初始化 RMSNorm 层
self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
# 前向传播方法
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# 对 emb 应用线性变换和 SiLU 激活
emb = self.linear(self.silu(emb))
# 将嵌入分块为四个部分
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
# 对输入 x 应用 RMSNorm 并与 scale_msa 相乘
x = self.norm(x) * (1 + scale_msa[:, None])
# 返回处理后的 x 以及门控和缩放值
return x, gate_msa, scale_mlp, gate_mlp
# 定义 AdaLayerNormSingle 类,继承自 nn.Module
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
# 初始化方法,设置嵌入维度和是否使用额外条件
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
# 调用父类构造函数
super().__init__()
# 初始化 PixArtAlphaCombinedTimestepSizeEmbeddings,用于时间步嵌入
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
)
# 初始化 SiLU 激活函数
self.silu = nn.SiLU()
# 初始化线性变换层,输出为 6 倍的嵌入维度
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
# 前向传播方法
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
batch_size: Optional[int] = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# 嵌入时间步,可能使用额外的条件
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
# 返回线性变换后的嵌入和嵌入结果
return self.linear(self.silu(embedded_timestep)), embedded_timestep
# 定义 AdaGroupNorm 类,继承自 nn.Module
class AdaGroupNorm(nn.Module):
r"""
GroupNorm layer modified to incorporate timestep embeddings.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
num_groups (`int`): The number of groups to separate the channels into.
act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
"""
# 初始化方法,用于设置类的基本属性
def __init__(
# 嵌入向量的维度
self, embedding_dim: int,
# 输出向量的维度
out_dim: int,
# 组的数量
num_groups: int,
# 激活函数名称(可选)
act_fn: Optional[str] = None,
# 防止除零错误的微小值
eps: float = 1e-5
):
# 调用父类初始化方法
super().__init__()
# 设置组的数量
self.num_groups = num_groups
# 设置用于数值稳定性的微小值
self.eps = eps
# 如果没有提供激活函数,则设置为 None
if act_fn is None:
self.act = None
else:
# 根据激活函数名称获取激活函数
self.act = get_activation(act_fn)
# 创建一个线性层,将嵌入维度映射到输出维度的两倍
self.linear = nn.Linear(embedding_dim, out_dim * 2)
# 前向传播方法,定义输入数据的处理方式
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
# 如果存在激活函数,则对嵌入进行激活
if self.act:
emb = self.act(emb)
# 将嵌入传递通过线性层
emb = self.linear(emb)
# 扩展嵌入的维度,以适配后续操作
emb = emb[:, :, None, None]
# 将嵌入分割为缩放因子和偏移量
scale, shift = emb.chunk(2, dim=1)
# 对输入数据进行分组归一化
x = F.group_norm(x, self.num_groups, eps=self.eps)
# 使用缩放因子和偏移量调整归一化后的数据
x = x * (1 + scale) + shift
# 返回处理后的数据
return x
# 定义一个自定义的神经网络模块,继承自 nn.Module
class AdaLayerNormContinuous(nn.Module):
# 初始化方法,接受多个参数以配置层的特性
def __init__(
self,
embedding_dim: int, # 嵌入维度
conditioning_embedding_dim: int, # 条件嵌入维度
# 注释:规范层可以配置缩放和偏移参数有点奇怪,因为输出会被投影的条件嵌入立即缩放和偏移。
# 注意,AdaLayerNorm 不允许规范层有缩放和偏移参数。
# 但是这是原始代码中的实现,您应该将 `elementwise_affine` 设置为 False。
elementwise_affine=True, # 是否允许元素级的仿射变换
eps=1e-5, # 防止除零错误的小值
bias=True, # 是否在全连接层中使用偏置
norm_type="layer_norm", # 规范化类型
):
super().__init__() # 调用父类构造函数
self.silu = nn.SiLU() # 定义 SiLU 激活函数
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) # 全连接层,输出两倍嵌入维度
# 根据指定的规范类型初始化规范层
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) # 层规范化
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) # RMS 规范化
else:
raise ValueError(f"unknown norm_type {norm_type}") # 抛出错误,若规范类型未知
# 前向传播方法,定义如何计算输出
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
# 将条件嵌入转换为与输入 x 相同的数据类型
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) # 应用激活函数和全连接层
scale, shift = torch.chunk(emb, 2, dim=1) # 将输出拆分为缩放和偏移
# 规范化输入 x,并进行缩放和偏移操作
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] # 返回处理后的输出
return x # 返回最终结果
# 定义另一个自定义的神经网络模块,继承自 nn.Module
class LuminaLayerNormContinuous(nn.Module):
# 初始化方法,接受多个参数以配置层的特性
def __init__(
self,
embedding_dim: int, # 嵌入维度
conditioning_embedding_dim: int, # 条件嵌入维度
# 注释:规范层可以配置缩放和偏移参数有点奇怪,因为输出会被投影的条件嵌入立即缩放和偏移。
# 注意,AdaLayerNorm 不允许规范层有缩放和偏移参数。
# 但是这是原始代码中的实现,您应该将 `elementwise_affine` 设置为 False。
elementwise_affine=True, # 是否允许元素级的仿射变换
eps=1e-5, # 防止除零错误的小值
bias=True, # 是否在全连接层中使用偏置
norm_type="layer_norm", # 规范化类型
out_dim: Optional[int] = None, # 可选的输出维度
):
super().__init__() # 调用父类构造函数
# AdaLN
self.silu = nn.SiLU() # 定义 SiLU 激活函数
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) # 全连接层,将条件嵌入映射到嵌入维度
# 根据指定的规范类型初始化规范层
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) # 层规范化
else:
raise ValueError(f"unknown norm_type {norm_type}") # 抛出错误,若规范类型未知
# 如果指定了输出维度,则创建第二个全连接层
if out_dim is not None:
self.linear_2 = nn.Linear(
embedding_dim, # 输入维度为嵌入维度
out_dim, # 输出维度
bias=bias, # 是否使用偏置
)
# 前向传播方法,定义如何计算输出
def forward(
self,
x: torch.Tensor, # 输入张量
conditioning_embedding: torch.Tensor, # 条件嵌入张量
# 返回一个张量,类型为 torch.Tensor
) -> torch.Tensor:
# 将条件嵌入转换回原始数据类型,以防止其被提升为 float32(用于 hunyuanDiT)
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
# 将嵌入值赋给 scale
scale = emb
# 对输入 x 进行规范化,并乘以(1 + scale),同时在新维度上扩展
x = self.norm(x) * (1 + scale)[:, None, :]
# 如果 linear_2 存在,则对 x 应用 linear_2
if self.linear_2 is not None:
x = self.linear_2(x)
# 返回处理后的张量 x
return x
# 定义一个自定义的层,继承自 nn.Module
class CogVideoXLayerNormZero(nn.Module):
# 初始化方法,定义该层的参数
def __init__(
self,
conditioning_dim: int, # 输入的条件维度
embedding_dim: int, # 嵌入的维度
elementwise_affine: bool = True, # 是否启用逐元素仿射变换
eps: float = 1e-5, # 防止除零的一个小常数
bias: bool = True, # 是否添加偏置
) -> None:
# 调用父类的初始化方法
super().__init__()
# 使用 SiLU 激活函数
self.silu = nn.SiLU()
# 线性变换,将条件维度映射到 6 倍的嵌入维度
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
# 归一化层,使用层归一化
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
# 前向传播方法,定义输入和输出
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 通过线性层处理 temb,并分成 6 个部分
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
# 对隐藏状态进行归一化并应用缩放和平移
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
# 对编码器隐藏状态进行相同处理
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
# 返回处理后的隐藏状态和编码器隐藏状态,以及门控信号
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
# 根据 PyTorch 版本决定是否使用标准 LayerNorm
if is_torch_version(">=", "2.1.0"):
# 使用标准的 LayerNorm
LayerNorm = nn.LayerNorm
else:
# 定义自定义的 LayerNorm 类,兼容旧版本 PyTorch
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
# 初始化方法
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
# 调用父类的初始化方法
super().__init__()
# 设置小常数以避免除零
self.eps = eps
# 如果维度是整数,则转为元组
if isinstance(dim, numbers.Integral):
dim = (dim,)
# 保存维度信息
self.dim = torch.Size(dim)
# 如果启用逐元素仿射,则初始化权重和偏置
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
else:
self.weight = None
self.bias = None
# 前向传播方法
def forward(self, input):
# 应用层归一化
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
# 定义 RMSNorm 类,继承自 nn.Module
class RMSNorm(nn.Module):
# 初始化方法
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
# 调用父类的初始化方法
super().__init__()
# 设置小常数以避免除零
self.eps = eps
# 如果维度是整数,则转为元组
if isinstance(dim, numbers.Integral):
dim = (dim,)
# 保存维度信息
self.dim = torch.Size(dim)
# 如果启用逐元素仿射,则初始化权重
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.weight = None
# 前向传播方法
def forward(self, hidden_states):
# 保存输入数据类型
input_dtype = hidden_states.dtype
# 计算输入的方差
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
# 对隐藏状态进行缩放
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
# 如果有权重,则进行进一步处理
if self.weight is not None:
# 如果需要,将隐藏状态转换为半精度
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
# 应用权重
hidden_states = hidden_states * self.weight
else:
# 将隐藏状态转换回原数据类型
hidden_states = hidden_states.to(input_dtype)
# 返回处理后的隐藏状态
return hidden_states
# 定义一个全局响应归一化的类,继承自 nn.Module
class GlobalResponseNorm(nn.Module):
# 初始化方法,接受一个维度参数 dim
def __init__(self, dim):
# 调用父类构造函数
super().__init__()
# 初始化可学习参数 gamma,形状为 (1, 1, 1, dim)
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
# 初始化可学习参数 beta,形状为 (1, 1, 1, dim)
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
# 定义前向传播方法,接受输入 x
def forward(self, x):
# 计算输入 x 在 (1, 2) 维度上的 L2 范数,保持维度
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
# 归一化 gx,计算每个样本的均值并防止除以零
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
# 返回归一化后的结果,加上可学习的 gamma 和 beta
return self.gamma * (x * nx) + self.beta + x