diffusers source code walkthrough (19)
.\diffusers\pipelines\audioldm2\pipeline_audioldm2.py
import inspect
from typing import Any , Callable , Dict , List , Optional , Union
import numpy as np
import torch
from transformers import (
ClapFeatureExtractor,
ClapModel,
GPT2Model,
RobertaTokenizer,
RobertaTokenizerFast,
SpeechT5HifiGan,
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
VitsModel,
VitsTokenizer,
)
from ...models import AutoencoderKL
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_accelerate_available,
is_accelerate_version,
is_librosa_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
if is_librosa_available():
import librosa
logger = logging.get_logger(__name__)
EXAMPLE_DOC_STRING = """ # 示例文档字符串的开始
``` # 示例文档的内容
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
```py # 示例文档字符串的结束
``` # 示例文档的分隔
Examples:
```py
>>> import scipy  # scipy is used to write the generated audio to a .wav file
>>> import torch
>>> from diffusers import AudioLDM2Pipeline
>>> repo_id = "cvssp/audioldm2"  # repository id of the pretrained model
>>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)  # load the pipeline in float16
>>> pipe = pipe.to("cuda")  # move the pipeline to the GPU
>>> # define the prompts
>>> prompt = "The sound of a hammer hitting a wooden surface."  # positive prompt describing the desired audio
>>> negative_prompt = "Low quality."  # negative prompt describing what to avoid
>>> # set the seed for generator
>>> generator = torch.Generator("cuda").manual_seed(0)
>>> # run the generation
>>> audio = pipe(
...     prompt,
...     negative_prompt=negative_prompt,
...     num_inference_steps=200,  # number of denoising steps
...     audio_length_in_s=10.0,  # length of the generated audio in seconds
...     num_waveforms_per_prompt=3,  # generate 3 candidate waveforms per prompt
...     generator=generator,
... ).audios
>>> # save the best audio sample (index 0) as a .wav file
>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
```
```py
# Using AudioLDM2 for Text To Speech
>>> import scipy
>>> import torch
>>> from diffusers import AudioLDM2Pipeline
>>> repo_id = "anhnct/audioldm2_gigaspeech"  # repository id of the TTS checkpoint
>>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> # define the prompts
>>> prompt = "A female reporter is speaking"  # prompt describing the speaker
>>> transcript = "wish you have a good day"  # transcript of the speech to generate
>>> # set the seed for generator
>>> generator = torch.Generator("cuda").manual_seed(0)
>>> # run the generation
>>> audio = pipe(
...     prompt,
...     transcription=transcript,
...     num_inference_steps=200,
...     audio_length_in_s=10.0,
...     num_waveforms_per_prompt=2,
...     generator=generator,
...     max_new_tokens=512,  # must set max_new_tokens to 512 for TTS
... ).audios
>>> # save the best audio sample (index 0) as a .wav file
>>> scipy.io.wavfile.write("tts.wav", rate=16000, data=audio[0])
```
"""
def prepare_inputs_for_generation(
inputs_embeds,
attention_mask=None,
past_key_values=None,
**kwargs,
):
# once a key/value cache exists, only the most recent embedding position needs to be fed
if past_key_values is not None:
inputs_embeds = inputs_embeds[:, -1:]
return {
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
}
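A quick sketch of what this helper does, assuming the function above is in scope (any non-`None` value stands in for a real key/value cache): on the first step the full embedding sequence is passed, afterwards only the newest position.

```py
import torch

embeds = torch.randn(2, 5, 768)                               # (batch, seq_len, hidden)
no_cache = prepare_inputs_for_generation(embeds)              # first step: full sequence
with_cache = prepare_inputs_for_generation(embeds, past_key_values=object())  # later steps
print(no_cache["inputs_embeds"].shape)    # torch.Size([2, 5, 768])
print(with_cache["inputs_embeds"].shape)  # torch.Size([2, 1, 768])
```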
class AudioLDM2Pipeline(DiffusionPipeline):
r"""
Pipeline for text-to-audio generation using the AudioLDM2 model.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.ClapModel`]):
First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
[CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant.
The text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used
to rank generated waveforms against the text prompt by computing similarity scores.
text_encoder_2 ([`~transformers.T5EncoderModel`, `~transformers.VitsModel`]):
Second frozen text-encoder. AudioLDM2 uses the encoder of
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel),
specifically the [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant. For
text-to-speech (TTS), AudioLDM2 instead uses the encoder of
[Vits](https://huggingface.co/docs/transformers/model_doc/vits#transformers.VitsModel).
projection_model ([`AudioLDM2ProjectionModel`]):
A trained model used to linearly project the hidden-states from the first and second text encoder models
and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders
are concatenated to give the input to the language model. A learned position embedding is provided for
the Vits hidden-states.
language_model ([`~transformers.GPT2Model`]):
An auto-regressive language model used to generate a sequence of hidden-states conditioned on the
projected outputs from the two text encoders.
tokenizer ([`~transformers.RobertaTokenizer`]):
Tokenizer to tokenize text for the first frozen text-encoder.
tokenizer_2 ([`~transformers.T5Tokenizer`, `~transformers.VitsTokenizer`]):
Tokenizer to tokenize text for the second frozen text-encoder.
feature_extractor ([`~transformers.ClapFeatureExtractor`]):
Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
unet ([`UNet2DConditionModel`]):
A `UNet2DConditionModel` to denoise the encoded audio latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
vocoder ([`~transformers.SpeechT5HifiGan`]):
Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
"""
def __init__ (
self,
vae: AutoencoderKL,
text_encoder: ClapModel,
text_encoder_2: Union [T5EncoderModel, VitsModel],
projection_model: AudioLDM2ProjectionModel,
language_model: GPT2Model,
tokenizer: Union [RobertaTokenizer, RobertaTokenizerFast],
tokenizer_2: Union [T5Tokenizer, T5TokenizerFast, VitsTokenizer],
feature_extractor: ClapFeatureExtractor,
unet: AudioLDM2UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
vocoder: SpeechT5HifiGan,
):
super ().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
projection_model=projection_model,
language_model=language_model,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
feature_extractor=feature_extractor,
unet=unet,
scheduler=scheduler,
vocoder=vocoder,
)
self.vae_scale_factor = 2 ** (len (self.vae.config.block_out_channels) - 1 )
def enable_vae_slicing (self ):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing (self ):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_model_cpu_offload (self, gpu_id=0 ):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=" , "0.17.0.dev0" ):
from accelerate import cpu_offload_with_hook
else :
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher." )
device = torch.device(f"cuda:{gpu_id} " )
if self.device.type != "cpu" :
self.to("cpu" , silence_dtype_warnings=True )
torch.cuda.empty_cache()
model_sequence = [
self.text_encoder.text_model,
self.text_encoder.text_projection,
self.text_encoder_2,
self.projection_model,
self.language_model,
self.unet,
self.vae,
self.vocoder,
self.text_encoder,
]
hook = None
for cpu_offloaded_model in model_sequence:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
self.final_offload_hook = hook
def generate_language_model(
self,
inputs_embeds: torch.Tensor = None,
max_new_tokens: int = 8,
**model_kwargs,
):
"""
Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
Parameters:
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
The sequence used as a prompt for the generation.
max_new_tokens (`int`):
Number of new tokens to generate.
model_kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
function of the model.
Return:
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
The sequence of generated hidden-states.
"""
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
for _ in range(max_new_tokens):
# prepare model inputs (only the newest position once a key/value cache exists)
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
# forward pass through the language model
output = self.language_model(**model_inputs, return_dict=True)
next_hidden_states = output.last_hidden_state
# append the newly generated hidden state to the running sequence
inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs)
# return only the newly generated hidden states
return inputs_embeds[:, -max_new_tokens:, :]
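To make the loop above concrete, here is a hedged, self-contained sketch of the same autoregressive hidden-state generation with a small randomly initialised `GPT2Model` standing in for the pipeline's language model; the key/value-cache bookkeeping (`_get_initial_cache_position` / `_update_model_kwargs_for_generation`) is omitted for brevity.

```py
import torch
from transformers import GPT2Config, GPT2Model

lm = GPT2Model(GPT2Config(n_embd=64, n_layer=2, n_head=2))   # toy stand-in for the language model
inputs_embeds = torch.randn(1, 4, 64)                        # projected text-encoder hidden states
max_new_tokens = 8
for _ in range(max_new_tokens):
    out = lm(inputs_embeds=inputs_embeds, return_dict=True)
    next_state = out.last_hidden_state[:, -1:, :]            # take the last hidden state ...
    inputs_embeds = torch.cat([inputs_embeds, next_state], dim=1)  # ... and append it to the prompt
generated = inputs_embeds[:, -max_new_tokens:, :]            # the newly generated hidden states
print(generated.shape)                                       # torch.Size([1, 8, 64])
```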
def encode_prompt (
self,
prompt,
device,
num_waveforms_per_prompt,
do_classifier_free_guidance,
transcription=None ,
negative_prompt=None ,
prompt_embeds: Optional [torch.Tensor] = None ,
negative_prompt_embeds: Optional [torch.Tensor] = None ,
generated_prompt_embeds: Optional [torch.Tensor] = None ,
negative_generated_prompt_embeds: Optional [torch.Tensor] = None ,
attention_mask: Optional [torch.LongTensor] = None ,
negative_attention_mask: Optional [torch.LongTensor] = None ,
max_new_tokens: Optional [int ] = None ,
def mel_spectrogram_to_waveform(self, mel_spectrogram):
if mel_spectrogram.dim() == 4:
mel_spectrogram = mel_spectrogram.squeeze(1)
waveform = self.vocoder(mel_spectrogram)
# always cast to float32 on CPU; this is the format expected by downstream audio processing
waveform = waveform.cpu().float()
return waveform
def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
if not is_librosa_available():
logger.info(
"Automatic scoring of the generated audio waveforms against the input prompt text requires the "
"`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
"generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
)
return audio
inputs = self.tokenizer(text, return_tensors="pt", padding=True)
resampled_audio = librosa.resample(
audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
)
inputs["input_features"] = self.feature_extractor(
list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
).input_features.type(dtype)
inputs = inputs.to(device)
# compute CLAP audio-text similarity scores and keep the best waveforms per prompt
logits_per_text = self.text_encoder(**inputs).logits_per_text
indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
return audio
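A small sketch of the ranking step at the end of `score_waveforms`: for each prompt, the CLAP text-audio logits are sorted and the indices of the best `num_waveforms_per_prompt` candidates are kept (the numbers below are purely illustrative).

```py
import torch

logits_per_text = torch.tensor([[0.2, 0.9, 0.5]])   # 1 prompt scored against 3 candidate waveforms
num_waveforms_per_prompt = 2
indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
print(indices)   # tensor([[1, 2]]) -> waveform 1 scores best, then waveform 2
```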
def prepare_extra_step_kwargs(self, generator, eta ):
accepts_eta = "eta" in set (inspect.signature(self.scheduler.step ).parameters.keys( ) )
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta" ] = eta
accepts_generator = "generator" in set (inspect.signature(self.scheduler.step ).parameters.keys( ) )
if accepts_generator:
extra_step_kwargs["generator" ] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
audio_length_in_s,
vocoder_upsample_factor,
callback_steps,
transcription=None ,
negative_prompt=None ,
prompt_embeds=None ,
negative_prompt_embeds=None ,
generated_prompt_embeds=None ,
negative_generated_prompt_embeds=None ,
attention_mask=None ,
negative_attention_mask=None ,
def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None ):
shape = (
batch_size,
num_channels_latents,
int (height ) // self.vae_scale_factor,
int (self.vocoder.config.model_in_dim ) // self.vae_scale_factor,
)
if isinstance (generator, list ) and len (generator ) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len (generator)} , but requested an effective batch"
f" size of {batch_size} . Make sure the batch size matches the length of the generators."
)
if latents is None :
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype )
else :
latents = latents.to(device )
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad( )
@replace_example_docstring(EXAMPLE_DOC_STRING )
def __call__(
self,
prompt: Union [str , List [str ]] = None ,
transcription: Union [str , List [str ]] = None ,
audio_length_in_s: Optional [float ] = None ,
num_inference_steps: int = 200 ,
guidance_scale: float = 3.5 ,
negative_prompt: Optional [Union [str , List [str ]]] = None ,
num_waveforms_per_prompt: Optional [int ] = 1 ,
eta: float = 0.0 ,
generator: Optional [Union [torch.Generator, List [torch.Generator]]] = None ,
latents: Optional [torch.Tensor] = None ,
prompt_embeds: Optional [torch.Tensor] = None ,
negative_prompt_embeds: Optional [torch.Tensor] = None ,
generated_prompt_embeds: Optional [torch.Tensor] = None ,
negative_generated_prompt_embeds: Optional [torch.Tensor] = None ,
attention_mask: Optional [torch.LongTensor] = None ,
negative_attention_mask: Optional [torch.LongTensor] = None ,
max_new_tokens: Optional [int ] = None ,
return_dict: bool = True ,
callback: Optional [Callable [[int , int , torch.Tensor], None ]] = None ,
callback_steps: Optional [int ] = 1 ,
cross_attention_kwargs: Optional [Dict [str , Any ]] = None ,
output_type: Optional [str ] = "np" ,
.\diffusers\pipelines\audioldm2\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
_dummy_objects = {}
_import_structure = {}
try :
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=" , "4.27.0" )):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else :
_import_structure["modeling_audioldm2" ] = ["AudioLDM2ProjectionModel" , "AudioLDM2UNet2DConditionModel" ]
_import_structure["pipeline_audioldm2" ] = ["AudioLDM2Pipeline" ]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try :
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=" , "4.27.0" )):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else :
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
from .pipeline_audioldm2 import AudioLDM2Pipeline
else :
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals ()["__file__" ],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr (sys.modules[__name__], name, value)
.\diffusers\pipelines\aura_flow\pipeline_aura_flow.py
import inspect
from typing import List , Optional , Tuple , Union
import torch
from transformers import T5Tokenizer, UMT5EncoderModel
from ...image_processor import VaeImageProcessor
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__)
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AuraFlowPipeline
>>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)  # load the pretrained pipeline
>>> pipe = pipe.to("cuda")  # move the pipeline to the GPU
>>> prompt = "A cat holding a sign that says hello world"  # the input prompt
>>> image = pipe(prompt).images[0]  # generate the image
>>> image.save("aura_flow.png")  # save the generated image
```
"""
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
the second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None :
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" )
if timesteps is not None :
accepts_timesteps = "timesteps" in set (inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__} 's `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len (timesteps)
elif sigmas is not None :
accept_sigmas = "sigmas" in set (inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__} 's `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len (timesteps)
else :
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
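A hedged usage sketch of `retrieve_timesteps` (assuming the function above is in scope, e.g. imported from `diffusers.pipelines.aura_flow.pipeline_aura_flow`) with the flow-matching scheduler that AuraFlow uses; custom `timesteps` or `sigmas` lists would be forwarded to `scheduler.set_timesteps` in the same way, provided the scheduler accepts them.

```py
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=4, device="cpu")
print(num_inference_steps)  # 4
print(timesteps)            # the 4-step schedule chosen by the scheduler
```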
class AuraFlowPipeline(DiffusionPipeline):
r"""
Args:
tokenizer (`T5TokenizerFast`):
Tokenizer of class T5Tokenizer.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. AuraFlow uses T5, specifically the
[EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
transformer ([`AuraFlowTransformer2DModel`]):
Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__ (
self,
tokenizer: T5Tokenizer,
text_encoder: UMT5EncoderModel,
vae: AutoencoderKL,
transformer: AuraFlowTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super ().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = (
2 ** (len (self.vae.config.block_out_channels) - 1 ) if hasattr (self, "vae" ) and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
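Worked example for the `vae_scale_factor` computed above: a typical `AutoencoderKL` configuration with four `block_out_channels` entries downsamples by 2 three times, so latents are 8x smaller per spatial dimension (the fallback of 8 covers the case where no VAE is registered; the channel list below is illustrative).

```py
block_out_channels = [128, 256, 512, 512]           # illustrative AutoencoderKL configuration
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)                             # 8
```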
def check_inputs (
self,
prompt,
height,
width,
negative_prompt,
prompt_embeds=None ,
negative_prompt_embeds=None ,
prompt_attention_mask=None ,
negative_prompt_attention_mask=None ,
):
if height % 8 != 0 or width % 8 != 0 :
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width} ." )
if prompt is not None and prompt_embeds is not None :
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds} . Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None :
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance (prompt, str ) and not isinstance (prompt, list )):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type (prompt)} " )
if prompt is not None and negative_prompt_embeds is not None :
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds} . Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None :
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds} . Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_attention_mask is None :
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`." )
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None :
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`." )
if prompt_embeds is not None and negative_prompt_embeds is not None :
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape} ."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape} ."
)
def encode_prompt (
self,
prompt: Union [str , List [str ]],
negative_prompt: Union [str , List [str ]] = None ,
do_classifier_free_guidance: bool = True ,
num_images_per_prompt: int = 1 ,
device: Optional [torch.device] = None ,
prompt_embeds: Optional [torch.Tensor] = None ,
negative_prompt_embeds: Optional [torch.Tensor] = None ,
prompt_attention_mask: Optional [torch.Tensor] = None ,
negative_prompt_attention_mask: Optional [torch.Tensor] = None ,
max_sequence_length: int = 256 ,
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None ,
):
if latents is not None :
return latents.to(device=device, dtype=dtype )
shape = (
batch_size,
num_channels_latents,
int (height ) // self.vae_scale_factor,
int (width ) // self.vae_scale_factor,
)
if isinstance (generator, list ) and len (generator ) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len (generator)} , but requested an effective batch"
f" size of {batch_size} . Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype )
return latents
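Shape sketch for `prepare_latents`: for a 1024x1024 request with `vae_scale_factor = 8`, the initial noise has a 128x128 latent grid. The channel count of 4 used here is only illustrative; in practice it comes from the transformer configuration.

```py
import torch

batch_size, num_channels_latents = 1, 4             # channel count is an assumption for illustration
height = width = 1024
vae_scale_factor = 8
shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
print(torch.randn(shape).shape)                     # torch.Size([1, 4, 128, 128])
```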
def upcast_vae(self ):
dtype = self.vae.dtype
self.vae.to(dtype=torch.float32 )
use_torch_2_0_or_xformers = isinstance (
self.vae.decoder.mid_block.attentions[0 ].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
FusedAttnProcessor2_0,
),
)
if use_torch_2_0_or_xformers:
self.vae.post_quant_conv.to(dtype )
self.vae.decoder.conv_in.to(dtype )
self.vae.decoder.mid_block.to(dtype )
@torch.no_grad( )
@replace_example_docstring(EXAMPLE_DOC_STRING )
def __call__(
self,
prompt: Union [str , List [str ]] = None ,
negative_prompt: Union [str , List [str ]] = None ,
num_inference_steps: int = 50 ,
timesteps: List [int ] = None ,
sigmas: List [float ] = None ,
guidance_scale: float = 3.5 ,
num_images_per_prompt: Optional [int ] = 1 ,
height: Optional [int ] = 1024 ,
width: Optional [int ] = 1024 ,
generator: Optional [Union [torch.Generator, List [torch.Generator]]] = None ,
latents: Optional [torch.Tensor] = None ,
prompt_embeds: Optional [torch.Tensor] = None ,
prompt_attention_mask: Optional [torch.Tensor] = None ,
negative_prompt_embeds: Optional [torch.Tensor] = None ,
negative_prompt_attention_mask: Optional [torch.Tensor] = None ,
max_sequence_length: int = 256 ,
output_type: Optional [str ] = "pil" ,
return_dict: bool = True ,
.\diffusers\pipelines\aura_flow\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try :
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else :
_import_structure["pipeline_aura_flow" ] = ["AuraFlowPipeline" ]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try :
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else :
from .pipeline_aura_flow import AuraFlowPipeline
else :
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals ()["__file__" ],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr (sys.modules[__name__], name, value)
.\diffusers\pipelines\auto_pipeline.py
from collections import OrderedDict
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import FluxPipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
KandinskyInpaintCombinedPipeline,
KandinskyInpaintPipeline,
KandinskyPipeline,
)
from .kandinsky2_2 import (
KandinskyV22CombinedPipeline,
KandinskyV22Img2ImgCombinedPipeline,
KandinskyV22Img2ImgPipeline,
KandinskyV22InpaintCombinedPipeline,
KandinskyV22InpaintPipeline,
KandinskyV22Pipeline,
)
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
StableDiffusionXLPAGImg2ImgPipeline,
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
from .stable_diffusion import (
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
)
from .stable_diffusion_3 import (
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
StableDiffusion3Pipeline,
)
from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
[
("stable-diffusion" , StableDiffusionPipeline),
("stable-diffusion-xl" , StableDiffusionXLPipeline),
("stable-diffusion-3" , StableDiffusion3Pipeline),
("stable-diffusion-3-pag" , StableDiffusion3PAGPipeline),
("if" , IFPipeline),
("hunyuan" , HunyuanDiTPipeline),
("hunyuan-pag" , HunyuanDiTPAGPipeline),
("kandinsky" , KandinskyCombinedPipeline),
("kandinsky22" , KandinskyV22CombinedPipeline),
("kandinsky3" , Kandinsky3Pipeline),
("stable-diffusion-controlnet" , StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet" , StableDiffusionXLControlNetPipeline),
("wuerstchen" , WuerstchenCombinedPipeline),
("cascade" , StableCascadeCombinedPipeline),
("lcm" , LatentConsistencyModelPipeline),
("pixart-alpha" , PixArtAlphaPipeline),
("pixart-sigma" , PixArtSigmaPipeline),
("stable-diffusion-pag" , StableDiffusionPAGPipeline),
("stable-diffusion-controlnet-pag" , StableDiffusionControlNetPAGPipeline),
("stable-diffusion-xl-pag" , StableDiffusionXLPAGPipeline),
("stable-diffusion-xl-controlnet-pag" , StableDiffusionXLControlNetPAGPipeline),
("pixart-sigma-pag" , PixArtSigmaPAGPipeline),
("auraflow" , AuraFlowPipeline),
("flux" , FluxPipeline),
]
)
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
[
("stable-diffusion" , StableDiffusionImg2ImgPipeline),
("stable-diffusion-xl" , StableDiffusionXLImg2ImgPipeline),
("stable-diffusion-3" , StableDiffusion3Img2ImgPipeline),
("if" , IFImg2ImgPipeline),
("kandinsky" , KandinskyImg2ImgCombinedPipeline),
("kandinsky22" , KandinskyV22Img2ImgCombinedPipeline),
("kandinsky3" , Kandinsky3Img2ImgPipeline),
("stable-diffusion-controlnet" , StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-xl-controlnet" , StableDiffusionXLControlNetImg2ImgPipeline),
("stable-diffusion-xl-pag" , StableDiffusionXLPAGImg2ImgPipeline),
("lcm" , LatentConsistencyModelImg2ImgPipeline),
]
)
AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
[
("stable-diffusion" , StableDiffusionInpaintPipeline),
("stable-diffusion-xl" , StableDiffusionXLInpaintPipeline),
("stable-diffusion-3" , StableDiffusion3InpaintPipeline),
("if" , IFInpaintingPipeline),
("kandinsky" , KandinskyInpaintCombinedPipeline),
("kandinsky22" , KandinskyV22InpaintCombinedPipeline),
("stable-diffusion-controlnet" , StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-xl-controlnet" , StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag" , StableDiffusionXLPAGInpaintPipeline),
]
)
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
[
("kandinsky" , KandinskyPipeline),
("kandinsky22" , KandinskyV22Pipeline),
("wuerstchen" , WuerstchenDecoderPipeline),
("cascade" , StableCascadeDecoderPipeline),
]
)
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
[
("kandinsky" , KandinskyImg2ImgPipeline),
("kandinsky22" , KandinskyV22Img2ImgPipeline),
]
)
_AUTO_INPAINT_DECODER_PIPELINES_MAPPING = OrderedDict(
[
("kandinsky" , KandinskyInpaintPipeline),
("kandinsky22" , KandinskyV22InpaintPipeline),
]
)
if is_sentencepiece_available():
from .kolors import KolorsPipeline
from .pag import KolorsPAGPipeline
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors" ] = KolorsPipeline
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors-pag" ] = KolorsPAGPipeline
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors" ] = KolorsPipeline
SUPPORTED_TASKS_MAPPINGS = [
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
AUTO_INPAINT_PIPELINES_MAPPING,
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING,
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING,
_AUTO_INPAINT_DECODER_PIPELINES_MAPPING,
]
def _get_connected_pipeline (pipeline_cls ):
if pipeline_cls in _AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING.values():
return _get_task_class(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False
)
if pipeline_cls in _AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING.values():
return _get_task_class(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False
)
if pipeline_cls in _AUTO_INPAINT_DECODER_PIPELINES_MAPPING.values():
return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False )
def _get_task_class (mapping, pipeline_class_name, throw_error_if_not_exist: bool = True ):
def get_model (pipeline_class_name ):
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
for model_name, pipeline in task_mapping.items():
if pipeline.__name__ == pipeline_class_name:
return model_name
model_name = get_model(pipeline_class_name)
if model_name is not None :
task_class = mapping.get(model_name, None )
if task_class is not None :
return task_class
if throw_error_if_not_exist:
raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name} " )
class AutoPipelineForText2Image(ConfigMixin):
r"""
[`AutoPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The
specific underlying pipeline class is automatically selected from either the
[`~AutoPipelineForText2Image.from_pretrained`] or [`~AutoPipelineForText2Image.from_pipe`] methods.
This class cannot be instantiated using `__init__()` (throws an error).
Class attributes:
- **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
diffusion pipeline's components.
"""
config_name = "model_index.json"
def __init__(self, *args, **kwargs):
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
)
@classmethod
@validate_hf_hub_args
@classmethod
class AutoPipelineForImage2Image(ConfigMixin):
r"""
[`AutoPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The
specific underlying pipeline class is automatically selected from either the
[`~AutoPipelineForImage2Image.from_pretrained`] or [`~AutoPipelineForImage2Image.from_pipe`] methods.
This class cannot be instantiated using `__init__()` (throws an error).
Class attributes:
- **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
diffusion pipeline's components.
"""
config_name = "model_index.json"
def __init__(self, *args, **kwargs):
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
)
@classmethod
@validate_hf_hub_args
@classmethod
class AutoPipelineForInpainting(ConfigMixin):
r"""
[`AutoPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The
specific underlying pipeline class is automatically selected from either the
[`~AutoPipelineForInpainting.from_pretrained`] or [`~AutoPipelineForInpainting.from_pipe`] methods.
This class cannot be instantiated using `__init__()` (throws an error).
Class attributes:
- **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
diffusion pipeline's components.
"""
config_name = "model_index.json"
def __init__(self, *args, **kwargs):
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
)
@classmethod
@validate_hf_hub_args
@classmethod
.\diffusers\pipelines\blip_diffusion\blip_image_processing.py
"""Image processor class for BLIP."""
from typing import Dict , List , Optional , Union
import numpy as np
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType, is_vision_available, logging
from diffusers.utils import numpy_to_pil
if is_vision_available():
import PIL.Image
logger = logging.get_logger(__name__)
class BlipImageProcessor(BaseImageProcessor):
r"""
Constructs a BLIP image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values" ]
def __init__ (
self,
do_resize: bool = True ,
size: Dict [str , int ] = None ,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True ,
rescale_factor: Union [int , float ] = 1 / 255 ,
do_normalize: bool = True ,
image_mean: Optional [Union [float , List [float ]]] = None ,
image_std: Optional [Union [float , List [float ]]] = None ,
do_convert_rgb: bool = True ,
do_center_crop: bool = True ,
**kwargs,
) -> None :
super ().__init__(**kwargs)
size = size if size is not None else {"height" : 224 , "width" : 224 }
size = get_size_dict(size, default_to_square=True )
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self.do_center_crop = do_center_crop
def resize (
self,
image: np.ndarray,
size: Dict [str , int ],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional [Union [str , ChannelDimension]] = None ,
input_data_format: Optional [Union [str , ChannelDimension]] = None ,
**kwargs,
) -> np.ndarray:
""" # 开始函数文档字符串
Resize an image to `(size["height"], size["width"])`. # 描述函数功能:调整图像大小
Args: # 参数说明部分
image (`np.ndarray`): # 输入参数:待调整大小的图像,类型为 numpy 数组
Image to resize. # 图像的说明
size (`Dict[str, int]`): # 输入参数:字典,包含目标图像的高度和宽度
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. # 字典格式的描述
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): # 可选参数:指定重采样的方法
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. # 重采样过滤器的说明
data_format (`ChannelDimension` or `str`, *optional*): # 可选参数:输出图像的通道维度格式
The channel dimension format for the output image. If unset, the channel dimension format of the input # 描述输入图像通道格式的使用
image is used. Can be one of: # 可能的通道格式选项
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. # 第一种格式的说明
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. # 第二种格式的说明
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. # 第三种格式的说明
input_data_format (`ChannelDimension` or `str`, *optional*): # 可选参数:输入图像的通道维度格式
The channel dimension format for the input image. If unset, the channel dimension format is inferred # 描述输入图像通道格式的推断
from the input image. Can be one of: # 可能的输入格式选项
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. # 第一种格式的说明
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. # 第二种格式的说明
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. # 第三种格式的说明
Returns: # 返回值说明部分
`np.ndarray`: The resized image. # 返回一个调整大小后的 numpy 数组图像
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()} " )
output_size = (size["height" ], size["width" ])
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess (
self,
images: ImageInput,
do_resize: Optional [bool ] = None ,
size: Optional [Dict [str , int ]] = None ,
resample: PILImageResampling = None ,
do_rescale: Optional [bool ] = None ,
do_center_crop: Optional [bool ] = None ,
rescale_factor: Optional [float ] = None ,
do_normalize: Optional [bool ] = None ,
image_mean: Optional [Union [float , List [float ]]] = None ,
image_std: Optional [Union [float , List [float ]]] = None ,
return_tensors: Optional [Union [str , TensorType]] = None ,
do_convert_rgb: bool = None ,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional [Union [str , ChannelDimension]] = None ,
**kwargs,
def postprocess(self, sample: torch.Tensor, output_type: str = "pil" ):
if output_type not in ["pt" , "np" , "pil" ]:
raise ValueError(
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
)
sample = (sample / 2 + 0.5 ).clamp(0 , 1 )
if output_type == "pt" :
return sample
sample = sample.cpu( ).permute(0 , 2 , 3 , 1 ).numpy( )
if output_type == "np" :
return sample
sample = numpy_to_pil(sample )
return sample
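Sketch of what `postprocess` does to a decoded sample: values in [-1, 1] are mapped to [0, 1], clamped, and moved to channels-last before the optional PIL conversion.

```py
import torch

sample = torch.rand(1, 3, 64, 64) * 2 - 1                               # fake decoder output in [-1, 1]
as_np = (sample / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
print(as_np.shape, float(as_np.min()) >= 0.0, float(as_np.max()) <= 1.0)  # (1, 64, 64, 3) True True
```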
.\diffusers\pipelines\blip_diffusion\modeling_blip2.py
from typing import Optional , Tuple , Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import BertTokenizer
from transformers.activations import QuickGELUActivation as QuickGELU
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig
from transformers.models.blip_2.modeling_blip_2 import (
Blip2Encoder,
Blip2PreTrainedModel,
Blip2QFormerAttention,
Blip2QFormerIntermediate,
Blip2QFormerOutput,
)
from transformers.pytorch_utils import apply_chunking_to_forward
from transformers.utils import (
logging,
replace_return_docstrings,
)
logger = logging.get_logger(__name__)
class Blip2TextEmbeddings (nn.Module):
"""从词和位置嵌入构建嵌入。"""
def __init__ (self, config ):
super ().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer("position_ids" , torch.arange(config.max_position_embeddings).expand((1 , -1 )))
self.position_embedding_type = getattr (config, "position_embedding_type" , "absolute" )
self.config = config
def forward (
self,
input_ids=None ,
position_ids=None ,
query_embeds=None ,
past_key_values_length=0 ,
):
if input_ids is not None :
seq_length = input_ids.size()[1 ]
else :
seq_length = 0
if position_ids is None :
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
if input_ids is not None :
embeddings = self.word_embeddings(input_ids)
if self.position_embedding_type == "absolute" :
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
if query_embeds is not None :
batch_size = embeddings.shape[0 ]
query_embeds = query_embeds.repeat(batch_size, 1 , 1 )
embeddings = torch.cat((query_embeds, embeddings), dim=1 )
else :
embeddings = query_embeds
embeddings = embeddings.to(query_embeds.dtype)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class Blip2VisionEmbeddings (nn.Module):
def __init__ (self, config: Blip2VisionConfig ):
super ().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1 , 1 , self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=3 , out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(torch.randn(1 , self.num_positions, self.embed_dim))
def forward (self, pixel_values: torch.Tensor ) -> torch.Tensor:
batch_size = pixel_values.shape[0 ]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
patch_embeds = patch_embeds.flatten(2 ).transpose(1 , 2 )
class_embeds = self.class_embedding.expand(batch_size, 1 , -1 ).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1 )
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1 ), :].to(target_dtype)
return embeddings
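Worked example for the position count above: with the default `Blip2VisionConfig` (`image_size=224`, `patch_size=14`, stated here as an assumption about the defaults), the patch convolution yields 256 patch embeddings, plus one class token, giving 257 positions.

```py
image_size, patch_size = 224, 14     # assumed Blip2VisionConfig defaults
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1      # +1 for the class embedding
print(num_patches, num_positions)    # 256 257
```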
class Blip2QFormerEncoder (nn.Module):
def __init__ (self, config ):
super ().__init__()
self.config = config
self.layer = nn.ModuleList(
[Blip2QFormerLayer(config, layer_idx) for layer_idx in range (config.num_hidden_layers)]
)
self.gradient_checkpointing = False
def forward (
self,
hidden_states,
attention_mask=None ,
head_mask=None ,
encoder_hidden_states=None ,
encoder_attention_mask=None ,
past_key_values=None ,
use_cache=None ,
output_attentions=False ,
output_hidden_states=False ,
return_dict=True ,
query_length=0 ,
class Blip2QFormerLayer(nn.Module ):
def __init__(self, config, layer_idx ):
super ( ).__init__( )
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = Blip2QFormerAttention(config )
self.layer_idx = layer_idx
if layer_idx % config.cross_attention_frequency == 0 :
self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True )
self.has_cross_attention = True
else :
self.has_cross_attention = False
self.intermediate = Blip2QFormerIntermediate(config )
self.intermediate_query = Blip2QFormerIntermediate(config )
self.output_query = Blip2QFormerOutput(config )
self.output = Blip2QFormerOutput(config )
def forward(
self,
hidden_states,
attention_mask=None ,
head_mask=None ,
encoder_hidden_states=None ,
encoder_attention_mask=None ,
past_key_value=None ,
output_attentions=False ,
query_length=0 ,
):
self_attn_past_key_value = past_key_value[:2 ] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0 ]
outputs = self_attention_outputs[1 :-1 ]
present_key_value = self_attention_outputs[-1 ]
if query_length > 0 :
query_attention_output = attention_output[:, :query_length, :]
if self.has_cross_attention:
if encoder_hidden_states is None :
raise ValueError("encoder_hidden_states must be given for cross-attention layers" )
cross_attention_outputs = self.crossattention(
query_attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
query_attention_output = cross_attention_outputs[0 ]
outputs = outputs + cross_attention_outputs[1 :-1 ]
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk_query,
self.chunk_size_feed_forward,
self.seq_len_dim,
query_attention_output,
)
if attention_output.shape[1 ] > query_length:
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output[:, query_length:, :],
)
layer_output = torch.cat([layer_output, layer_output_text], dim=1 )
else :
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
outputs = (layer_output, ) + outputs
outputs = outputs + (present_key_value, )
return outputs
def feed_forward_chunk(self, attention_output ):
intermediate_output = self.intermediate(attention_output )
layer_output = self.output(intermediate_output, attention_output )
return layer_output
def feed_forward_chunk_query(self, attention_output ):
intermediate_output = self.intermediate_query(attention_output )
layer_output = self.output_query(intermediate_output, attention_output )
return layer_output
class ProjLayer(nn.Module ):
def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1 , eps=1e-12 ):
super ( ).__init__( )
self.dense1 = nn.Linear(in_dim, hidden_dim )
self.act_fn = QuickGELU( )
self.dense2 = nn.Linear(hidden_dim, out_dim )
self.dropout = nn.Dropout(drop_p )
self.LayerNorm = nn.LayerNorm(out_dim, eps=eps )
def forward(self, x ):
x_in = x
x = self.LayerNorm(x )
x = self.dropout(self.dense2(self.act_fn(self.dense1(x ) ) ) ) + x_in
return x
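Usage sketch for `ProjLayer` as defined above: it is a pre-LayerNorm two-layer MLP with a residual connection, so the token shape is preserved (the residual add implicitly requires `out_dim == in_dim`, as in the Q-Former where both equal the hidden size).

```py
import torch

proj = ProjLayer(in_dim=32, out_dim=32, hidden_dim=128, drop_p=0.1, eps=1e-12)
tokens = torch.randn(2, 5, 32)        # (batch, num_tokens, hidden)
print(proj(tokens).shape)             # torch.Size([2, 5, 32])
```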
class Blip2VisionModel(Blip2PreTrainedModel ):
main_input_name = "pixel_values"
config_class = Blip2VisionConfig
def __init__(self, config: Blip2VisionConfig ):
super ( ).__init__(config )
self.config = config
embed_dim = config.hidden_size
self.embeddings = Blip2VisionEmbeddings(config )
self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps )
self.encoder = Blip2Encoder(config )
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps )
self.post_init( )
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig )
def forward(
self,
pixel_values: Optional [torch.Tensor] = None ,
output_attentions: Optional [bool ] = None ,
output_hidden_states: Optional [bool ] = None ,
return_dict: Optional [bool ] = None ,
) -> Union [Tuple , BaseModelOutputWithPooling]:
r""" # 文档字符串的开始,通常用于描述函数的用途
Returns: # 返回部分的说明
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None :
raise ValueError("You have to specify pixel_values" )
hidden_states = self.embeddings(pixel_values )
hidden_states = self.pre_layernorm(hidden_states )
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0 ]
last_hidden_state = self.post_layernorm(last_hidden_state )
pooled_output = last_hidden_state[:, 0 , :]
pooled_output = self.post_layernorm(pooled_output )
if not return_dict:
return (last_hidden_state, pooled_output ) + encoder_outputs[1 :]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self ):
return self.embeddings
class Blip2QFormerModel(Blip2PreTrainedModel ):
"""
Querying Transformer (Q-Former), used in BLIP-2.
"""
def __init__(self, config: Blip2Config ):
super ( ).__init__(config )
self.config = config
self.embeddings = Blip2TextEmbeddings(config.qformer_config )
self.visual_encoder = Blip2VisionModel(config.vision_config )
self.query_tokens = nn.Parameter(torch.zeros(1 , config.num_query_tokens, config.qformer_config.hidden_size ) )
if not hasattr (config, "tokenizer" ) or config.tokenizer is None :
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased" , truncation_side="right" )
else :
self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right" )
self.tokenizer.add_special_tokens({"bos_token" : "[DEC]" } )
self.proj_layer = ProjLayer(
in_dim=config.qformer_config.hidden_size,
out_dim=config.qformer_config.hidden_size,
hidden_dim=config.qformer_config.hidden_size * 4 ,
drop_p=0.1 ,
eps=1e-12 ,
)
self.encoder = Blip2QFormerEncoder(config.qformer_config )
self.post_init( )
def get_input_embeddings(self ):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value ):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune ):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items( ):
self.encoder.layer[layer].attention.prune_heads(heads )
def get_extended_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple [int ],
device: torch.device,
has_query: bool = False ,
) -> torch.Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored. # 准备可广播的注意力和因果掩码,以忽略未来和被掩盖的标记。
Arguments: # 参数说明
attention_mask (`torch.Tensor`): # 注意力掩码,类型为 torch.Tensor
Mask with ones indicating tokens to attend to, zeros for tokens to ignore. # 掩码中,1表示要关注的标记,0表示要忽略的标记。
input_shape (`Tuple[int]`): # 输入的形状,类型为整数元组
The shape of the input to the model. # 模型输入的形状。
device (`torch.device`): # 输入的设备类型
The device of the input to the model. # 模型输入的设备。
Returns: # 返回值说明
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. # 返回扩展的注意力掩码,其数据类型与 attention_mask 的数据类型相同。
"""
if attention_mask.dim( ) == 3 :
extended_attention_mask = attention_mask[:, None , :, :]
elif attention_mask.dim( ) == 2 :
extended_attention_mask = attention_mask[:, None , None , :]
else :
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})" .format (
input_shape, attention_mask.shape
)
)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype )
extended_attention_mask = (1.0 - extended_attention_mask ) * -10000.0
return extended_attention_mask
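To make the 2-D branch of `get_extended_attention_mask` concrete, here is a minimal sketch (values are illustrative, not from the source): a padding mask becomes a broadcastable additive mask with 0.0 for kept tokens and -10000.0 for masked ones.

```py
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                  # 1 = attend, 0 = ignore
extended = attention_mask[:, None, None, :].to(torch.float32)  # [bsz, 1, 1, seq_len]
extended = (1.0 - extended) * -10000.0
print(extended)  # tensor([[[[-0., -0., -0., -10000.]]]])
```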
def forward(
self,
text_input=None ,
image_input=None ,
head_mask=None ,
encoder_hidden_states=None ,
encoder_attention_mask=None ,
past_key_values=None ,
use_cache=None ,
output_attentions=None ,
output_hidden_states=None ,
return_dict=None ,
):
...  # body of Blip2QFormerModel.forward omitted in this excerpt
.\diffusers\pipelines\blip_diffusion\modeling_ctx_clip.py
from typing import Optional , Tuple , Union
import torch
from torch import nn
from transformers import CLIPPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.configuration_clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPEncoder
def _expand_mask (mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional [int ] = None ):
"""
Expands the attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None , None , :].expand(bsz, 1 , tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool ), torch.finfo(dtype).min )
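A quick check of the `_expand_mask` helper above, using an illustrative 3-token mask (assumes the function defined in this file is in scope):

```py
import torch

mask = torch.tensor([[1, 1, 0]])            # [bsz=1, seq_len=3]; last token is padding
out = _expand_mask(mask, torch.float32)     # [1, 1, 3, 3]
print(out.shape)
print(out[0, 0])  # every row is [0., 0., -3.4028e+38]: padded positions get the minimum float
```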
class ContextCLIPTextModel (CLIPPreTrainedModel ):
config_class = CLIPTextConfig
_no_split_modules = ["CLIPEncoderLayer" ]
def __init__ (self, config: CLIPTextConfig ):
super ().__init__(config)
self.text_model = ContextCLIPTextTransformer(config)
self.post_init()
def forward (
self,
ctx_embeddings: torch.Tensor = None ,
ctx_begin_pos: list = None ,
input_ids: Optional [torch.Tensor] = None ,
attention_mask: Optional [torch.Tensor] = None ,
position_ids: Optional [torch.Tensor] = None ,
output_attentions: Optional [bool ] = None ,
output_hidden_states: Optional [bool ] = None ,
return_dict: Optional [bool ] = None ,
) -> Union [Tuple , BaseModelOutputWithPooling]:
return self.text_model(
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=ctx_begin_pos,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
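A hedged usage sketch for `ContextCLIPTextModel`: it builds the model from a deliberately tiny `CLIPTextConfig` (all sizes are made up) and injects two context embeddings right after the BOS token of each prompt. It assumes the classes from this file are in scope and a transformers version compatible with this module.

```py
import torch
from transformers import CLIPTextConfig

config = CLIPTextConfig(
    vocab_size=49408, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4,
)
model = ContextCLIPTextModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 8))  # two dummy prompts, 8 tokens each
ctx_embeddings = torch.randn(2, 2, config.hidden_size)   # 2 context vectors per prompt
ctx_begin_pos = [1, 1]                                    # splice them in right after the BOS token

with torch.no_grad():
    out = model(ctx_embeddings=ctx_embeddings, ctx_begin_pos=ctx_begin_pos, input_ids=input_ids)
print(out.last_hidden_state.shape)  # torch.Size([2, 10, 64]): 8 prompt tokens + 2 context tokens
```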
class ContextCLIPTextTransformer (nn.Module):
def __init__ (self, config: CLIPTextConfig ):
super ().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = ContextCLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim)
def forward (
self,
ctx_embeddings: torch.Tensor,
ctx_begin_pos: list ,
input_ids: Optional [torch.Tensor] = None ,
attention_mask: Optional [torch.Tensor] = None ,
position_ids: Optional [torch.Tensor] = None ,
output_attentions: Optional [bool ] = None ,
output_hidden_states: Optional [bool ] = None ,
return_dict: Optional [bool ] = None ,
) -> Union [Tuple , BaseModelOutputWithPooling]:
r"""
Returns:
Text encoder outputs (`BaseModelOutputWithPooling` when `return_dict=True`, otherwise a `tuple`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None :
raise ValueError("You have to specify either input_ids" )
input_shape = input_ids.size()
input_ids = input_ids.view(-1 , input_shape[-1 ])
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=ctx_begin_pos,
)
bsz, seq_len = input_shape
if ctx_embeddings is not None :
seq_len += ctx_embeddings.size(1 )
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
if attention_mask is not None :
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0 ]
last_hidden_state = self.final_layer_norm(last_hidden_state)
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0 ], device=input_ids.device),
input_ids.to(torch.int ).argmax(dim=-1 ),
]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1 :]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask (self, bsz, seq_len, dtype ):
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min ))
mask.triu_(1 )
mask = mask.unsqueeze(1 )
return mask
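The causal mask built by `_build_causal_attention_mask` keeps the diagonal and everything below it at 0 and fills the strict upper triangle (future positions) with the most negative float. A 3-token illustration, mirroring the method above:

```py
import torch

mask = torch.empty(1, 3, 3, dtype=torch.float32)
mask.fill_(torch.finfo(torch.float32).min)
mask.triu_(1)      # zero out the diagonal and lower triangle, keep the strict upper triangle
print(mask[0])
# tensor([[ 0.0000e+00, -3.4028e+38, -3.4028e+38],
#         [ 0.0000e+00,  0.0000e+00, -3.4028e+38],
#         [ 0.0000e+00,  0.0000e+00,  0.0000e+00]])
```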
class ContextCLIPTextEmbeddings (nn.Module):
def __init__ (self, config: CLIPTextConfig ):
super ().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
self.register_buffer("position_ids" , torch.arange(config.max_position_embeddings).expand((1 , -1 )))
def forward (
self,
ctx_embeddings: torch.Tensor,
ctx_begin_pos: list ,
input_ids: Optional [torch.LongTensor] = None ,
position_ids: Optional [torch.LongTensor] = None ,
inputs_embeds: Optional [torch.Tensor] = None ,
) -> torch.Tensor:
if ctx_embeddings is None :
ctx_len = 0
else :
ctx_len = ctx_embeddings.shape[1 ]
seq_length = (input_ids.shape[-1 ] if input_ids is not None else inputs_embeds.shape[-2 ]) + ctx_len
if position_ids is None :
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None :
inputs_embeds = self.token_embedding(input_ids)
input_embeds_ctx = []
bsz = inputs_embeds.shape[0 ]
if ctx_embeddings is not None :
for i in range (bsz):
cbp = ctx_begin_pos[i]
prefix = inputs_embeds[i, :cbp]
suffix = inputs_embeds[i, cbp:]
input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0 ))
inputs_embeds = torch.stack(input_embeds_ctx, dim=0 )
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
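Finally, a standalone sketch of the splicing loop in `ContextCLIPTextEmbeddings.forward` above: two context vectors are inserted after the BOS token (`ctx_begin_pos = 1`) of a single prompt. The values and the 1-dimensional embedding width are purely illustrative.

```py
import torch

inputs_embeds = torch.arange(5.0).view(1, 5, 1)    # token embeddings; embed_dim = 1 for readability
ctx_embeddings = torch.full((1, 2, 1), 9.0)        # two "context" vectors
cbp = 1                                            # ctx_begin_pos for this prompt
spliced = torch.cat([inputs_embeds[0, :cbp], ctx_embeddings[0], inputs_embeds[0, cbp:]], dim=0)
print(spliced.squeeze(-1))  # tensor([0., 9., 9., 1., 2., 3., 4.])
```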