Lucidrains Series Project Source Code Walkthrough (Part 8)

.\lucidrains\DALLE-pytorch\setup.py

# import setuptools' setup and find_packages helpers
from setuptools import setup, find_packages
# execute the version file so __version__ is defined in the current namespace
exec(open('dalle_pytorch/version.py').read())

# package metadata
setup(
  # package name
  name = 'dalle-pytorch',
  # discover all packages
  packages = find_packages(),
  # include all data files
  include_package_data = True,
  # version number
  version = __version__,
  # license
  license='MIT',
  # description
  description = 'DALL-E - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # content type of the long description
  long_description_content_type = 'text/markdown',
  # project URL
  url = 'https://github.com/lucidrains/dalle-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'text-to-image'
  ],
  # install dependencies
  install_requires=[
    'axial_positional_embedding',
    'DALL-E',
    'einops>=0.3.2',
    'ftfy',
    'packaging',
    'pillow',
    'regex',
    'rotary-embedding-torch',
    'taming-transformers-rom1504',
    'tokenizers',
    'torch>=1.6',
    'torchvision',
    'transformers',
    'tqdm',
    'youtokentome',
    'WebDataset'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
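A quick aside on the exec(...) line above: executing the source of version.py injects __version__ into the namespace so setup() can reference it without importing dalle_pytorch itself (whose dependencies may not be installed yet at build time). A minimal, self-contained sketch of the same trick; the version string below is made up, and a dict namespace is used here instead of module globals:

namespace = {}
# stands in for open('dalle_pytorch/version.py').read()
exec("__version__ = '1.0.0'", namespace)
print(namespace['__version__'])  # -> 1.0.0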

.\lucidrains\DALLE-pytorch\train_dalle.py

# import required libraries
import argparse
from pathlib import Path
import time
from glob import glob
import os
import shutil

import torch
import wandb  # fail early here if the user has not installed wandb
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

# import the DALL-E related modules
from dalle_pytorch import __version__
from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE
from dalle_pytorch import distributed_utils
from dalle_pytorch.loader import TextImageDataset
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer

# libraries needed for webdataset support
import webdataset as wds
from torchvision import transforms as T
from PIL import Image
from io import BytesIO

# argument parsing
parser = argparse.ArgumentParser()

group = parser.add_mutually_exclusive_group(required=False)

# 添加参数:离散 VAE 的路径
group.add_argument('--vae_path', type=str,
                   help='path to your trained discrete VAE')

# 添加参数:部分训练的 DALL-E 的路径
group.add_argument('--dalle_path', type=str,
                   help='path to your partially trained DALL-E')

# 添加参数:训练好的 VQGAN 权重路径
parser.add_argument('--vqgan_model_path', type=str, default=None,
                   help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')

# 添加参数:训练好的 VQGAN 配置路径
parser.add_argument('--vqgan_config_path', type=str, default=None,
                   help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)')

# 添加参数:包含图像和文本用于学习 DALL-E 的文件夹路径
parser.add_argument('--image_text_folder', type=str, required=True,
                    help='path to your folder of images and text for learning the DALL-E')

# 添加参数:WebDataset 的列名,用于图像和文本
parser.add_argument('--wds', type=str, default='',
                    help='Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.')

# 添加参数:是否截断超过最大标记长度的标题
parser.add_argument('--truncate_captions', dest='truncate_captions', action='store_true',
                    help='Captions passed in which exceed the max token length will be truncated if this is set.')

# 添加参数:随机调整裁剪的较低比率
parser.add_argument('--random_resize_crop_lower_ratio', dest='resize_ratio', type=float, default=0.75,
                    help='Random resized crop lower ratio')

# 添加参数:是否使用中文
parser.add_argument('--chinese', dest='chinese', action='store_true')

# 添加参数:是否启用 taming 模式
parser.add_argument('--taming', dest='taming', action='store_true')

# 添加参数:是否使用 Hugging Face Tokenizer
parser.add_argument('--hug', dest='hug', action='store_true')

# 添加参数:BPE json 文件路径
parser.add_argument('--bpe_path', type=str,
                    help='path to your BPE json file')

# 添加参数:DALL-E 输出文件名
parser.add_argument('--dalle_output_file_name', type=str, default="dalle",
                    help='output_file_name')

# 添加参数:启用 DeepSpeed 16 位精度
parser.add_argument('--fp16', action='store_true',
                    help='(experimental) - Enable DeepSpeed 16 bit precision. Reduces VRAM.')

# 添加参数:启用 Apex "O1" 自动混合精度
parser.add_argument('--amp', action='store_true',
                   help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.')

# 添加参数:W&B 保存结果时使用的名称
parser.add_argument('--wandb_name', default='dalle_train_transformer',
                    help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`')

# 添加参数:W&B 日志记录的团队/实体名称
parser.add_argument('--wandb_entity', default=None,
                    help='(optional) Name of W&B team/entity to log to.')

# 添加参数:稳定 softmax,防止在 softmax 过程中值变得过大
parser.add_argument('--stable_softmax', dest='stable_softmax', action='store_true',
                    help='Prevent values from becoming too large during softmax. Helps with stability in fp16 and Mixture of Quantization training.')

# 分布式训练参数
parser = distributed_utils.wrap_arg_parser(parser)

# 训练设置参数
train_group = parser.add_argument_group('Training settings')

# 添加参数:是否启用 FLOPS 分析
train_group.add_argument('--flops_profiler', dest='flops_profiler', action='store_true', help='Exits after printing detailed flops/runtime analysis of forward/backward')

# 添加参数:训练轮数
train_group.add_argument('--epochs', default=20, type=int, help='Number of epochs')
# 添加一个参数到训练组,保存每n步一个检查点
train_group.add_argument('--save_every_n_steps', default=1000, type=int, help='Save a checkpoint every n steps')

# 添加一个参数到训练组,保留n个检查点,如果检查点数量超过n则删除旧的deepspeed检查点(谨慎操作)
train_group.add_argument('--keep_n_checkpoints', default=None, type=int, help='(Careful) Deletes old deepspeed checkpoints if there are more than n')

# 添加一个参数到训练组,批量大小
train_group.add_argument('--batch_size', default=4, type=int, help='Batch size')

# 添加一个参数到训练组,GA步数,每次迭代中跨步累积梯度的步数。仅适用于DeepSpeed。
train_group.add_argument('--ga_steps', default=1, type=int, help='Number of steps to accumulate gradients across per each iteration. DeepSpeed only.')

# 添加一个参数到训练组,学习率
train_group.add_argument('--learning_rate', default=3e-4, type=float, help='Learning rate')

# 添加一个参数到训练组,梯度规范化裁剪
train_group.add_argument('--clip_grad_norm', default=0.5, type=float, help='Clip gradient norm')

# 添加一个参数到训练组,学习率衰减
train_group.add_argument('--lr_decay', dest='lr_decay', action='store_true')

# 创建模型设置参数组
model_group = parser.add_argument_group('Model settings')

# 添加一个参数到模型设置组,模型维度
model_group.add_argument('--dim', default=512, type=int, help='Model dimension')

# 添加一个参数到模型设置组,文本序列长度
model_group.add_argument('--text_seq_len', default=256, type=int, help='Text sequence length')

# 添加一个参数到模型设置组,模型深度
model_group.add_argument('--depth', default=2, type=int, help='Model depth')

# 添加一个参数到模型设置组,模型头数
model_group.add_argument('--heads', default=8, type=int, help='Model number of heads')

# 添加一个参数到模型设置组,模型头维度
model_group.add_argument('--dim_head', default=64, type=int, help='Model head dimension')

# 添加一个参数到训练组,前馈层dropout
train_group.add_argument('--ff_dropout', default=0.0, type=float, help='Feed forward dropout.')

# training-group argument: attention dropout
train_group.add_argument('--attn_dropout', default=0.0, type=float, help='Attention dropout.')

# 添加一个参数到模型设置组,可逆性
model_group.add_argument('--reversible', dest='reversible', action='store_true')

# 添加一个参数到模型设置组,图像损失权重
model_group.add_argument('--loss_img_weight', default=7, type=int, help='Image loss weight')

# 添加一个参数到模型设置组,注意力类型
model_group.add_argument('--attn_types', default='full', type=str, help='comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.')

# 添加一个参数到模型设置组,使用移位标记特性
model_group.add_argument('--shift_tokens', help='Use the shift tokens feature', action='store_true')

# 添加一个参数到模型设置组,使用旋转嵌入
model_group.add_argument('--rotary_emb', help='Use rotary embeddings', action='store_true')

# 添加一个参数到模型设置组,共享注意力层ID
model_group.add_argument('--shared_attn_ids', default=None, type=str, help='Comma separated list of shared attention layer ids. Default: sharing is disabled')

# 添加一个参数到模型设置组,共享前馈层ID
model_group.add_argument('--shared_ff_ids', default=None, type=str, help='Comma separated list of shared feed forward layer ids. Default: sharing is disabled')

# 添加一个参数到模型设置组,共享输入和输出嵌入
model_group.add_argument('--share_input_output_emb', help='Share input and output embeddings', action='store_true')

# 解析命令行参数
args = parser.parse_args()

# helper functions

# check whether a value is not None
def exists(val):
    return val is not None

# collect the trainable parameters of a model
def get_trainable_params(model):
    return [params for params in model.parameters() if params.requires_grad]

# convert a checkpoint path into a directory name with a tag inserted
def cp_path_to_dir(cp_path, tag):
    """Convert a checkpoint path to a directory with `tag` inserted.
    If `cp_path` is already a directory, return it unchanged.
    """
    if not isinstance(cp_path, Path):
        cp_path = Path(cp_path)
    if cp_path.is_dir():
        return cp_path
    path_sans_extension = cp_path.parent / cp_path.stem
    cp_dir = Path(f'{path_sans_extension}-{tag}-cp')
    return cp_dir
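For example, assuming the helper above is in scope, a checkpoint file path gains a -<tag>-cp suffix, while a path that is already a directory is returned unchanged (the path below is made up):

print(cp_path_to_dir('checkpoints/dalle.pt', 'ds'))  # -> checkpoints/dalle-ds-cp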

# constants

# WebDataset image and text column names (from --wds, e.g. img,cap)
WEBDATASET_IMAGE_TEXT_COLUMNS = tuple(args.wds.split(','))
ENABLE_WEBDATASET = True if len(WEBDATASET_IMAGE_TEXT_COLUMNS) == 2 else False

# DALLE输出文件名
DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name + ".pt"

# VAE路径
VAE_PATH = args.vae_path
VQGAN_MODEL_PATH = args.vqgan_model_path
VQGAN_CONFIG_PATH = args.vqgan_config_path
DALLE_PATH = args.dalle_path
RESUME = exists(DALLE_PATH)

# 训练周期
EPOCHS = args.epochs
BATCH_SIZE = args.batch_size

# 学习率
LEARNING_RATE = args.learning_rate
GRAD_CLIP_NORM = args.clip_grad_norm
LR_DECAY = args.lr_decay
SAVE_EVERY_N_STEPS = args.save_every_n_steps
KEEP_N_CHECKPOINTS = args.keep_n_checkpoints

# 模型维度
MODEL_DIM = args.dim
TEXT_SEQ_LEN = args.text_seq_len
DEPTH = args.depth
HEADS = args.heads
DIM_HEAD = args.dim_head
REVERSIBLE = args.reversible
# 从参数中获取损失图像权重
LOSS_IMG_WEIGHT = args.loss_img_weight
# 从参数中获取前馈神经网络的丢弃率
FF_DROPOUT = args.ff_dropout
# 从参数中获取注意力机制的丢弃率
ATTN_DROPOUT = args.attn_dropout
# 从参数中获取是否使用稳定的 softmax 函数
STABLE = args.stable_softmax
# 从参数中获取是否移动标记
SHIFT_TOKENS = args.shift_tokens
# 从参数中获取是否使用旋转嵌入
ROTARY_EMB = args.rotary_emb

# 从参数中获取注意力类型并转换为元组
ATTN_TYPES = tuple(args.attn_types.split(','))
# 如果存在共享的注意力 ID,则从参数中获取并转换为元组,否则为 None
SHARED_ATTN_IDS = tuple(args.shared_attn_ids.split(',')) if exists(args.shared_attn_ids) else None
# 如果存在共享的前馈神经网络 ID,则从参数中获取并转换为元组,否则为 None
SHARED_FF_IDS = tuple(args.shared_ff_ids.split(',')) if exists(args.shared_ff_ids) else None
# 从参数中获取是否共享输入输出嵌入
SHARE_INPUT_OUTPUT_EMB = args.share_input_output_emb

# auxiliary filename stored inside DeepSpeed checkpoint directories
DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

# when WebDataset is not enabled
if not ENABLE_WEBDATASET:
    # the image/text folder must exist
    assert Path(args.image_text_folder).exists(), f'The path {args.image_text_folder} was not found.'
# when WebDataset is enabled
else:
    # if the image/text argument is a directory
    if Path(args.image_text_folder).is_dir():
        # collect every .tar file underneath it
        DATASET = [str(p) for p in Path(args.image_text_folder).glob("**/*") if ".tar" in str(p).lower()] # .name
        # fail if no .tar files were found
        assert len(DATASET) > 0, 'The directory ({}) does not contain any WebDataset/.tar files.'.format(args.image_text_folder)
        print('Found {} WebDataset .tar(.gz) file(s) under given path {}!'.format(len(DATASET), args.image_text_folder))
    # if the argument is an http(s) link
    elif ('http://' in args.image_text_folder.lower()) | ('https://' in args.image_text_folder.lower()):
        # stream the shards through curl
        DATASET = f"pipe:curl -L -s {args.image_text_folder} || true"
        print('Found {} http(s) link under given path!'.format(len(DATASET), args.image_text_folder))
    # if the argument is a Google Cloud Storage (GCS) link
    elif 'gs://' in args.image_text_folder.lower():
        # stream the shards through gsutil
        DATASET = f"pipe:gsutil cat {args.image_text_folder} || true"
        print('Found {} GCS link under given path!'.format(len(DATASET), args.image_text_folder))
    # if the argument itself points at a .tar file
    elif '.tar' in args.image_text_folder:
        # use the path as-is
        DATASET = args.image_text_folder
        print('Found WebDataset .tar(.gz) file under given path {}!'.format(args.image_text_folder))
    else:
        # nothing usable was given: no folder, no .tar(.gz) file and no url pointing to tar files
        raise Exception('No folder, no .tar(.gz) and no url pointing to tar files provided under {}.'.format(args.image_text_folder))

# initialize the distributed backend
distr_backend = distributed_utils.set_backend_from_args(args)
distr_backend.initialize()

# whether DeepSpeed is the active backend
using_deepspeed = distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)
# whether the current process is the root worker
is_root = distr_backend.is_root_worker()

# tokenizer
if exists(args.bpe_path):
    # pick the tokenizer class based on the flags and load it from the BPE file
    klass = HugTokenizer if args.hug else YttmTokenizer
    tokenizer = klass(args.bpe_path)
elif args.chinese:
    # for Chinese text, use the Chinese tokenizer
    tokenizer = ChineseTokenizer()

# reconstitute the VAE
if RESUME:
    # path to the partially trained DALL-E
    dalle_path = Path(DALLE_PATH)
    # when using DeepSpeed, resolve the DeepSpeed checkpoint directory instead
    if using_deepspeed:
        cp_dir = cp_path_to_dir(dalle_path, 'ds')
        # the DeepSpeed checkpoint directory must exist
        assert cp_dir.is_dir(), f'DeepSpeed checkpoint directory {cp_dir} not found'
        dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
    else:
        # otherwise the DALL-E checkpoint file must exist
        assert dalle_path.exists(), 'DALL-E model file does not exist'
    # load the hyperparameters, VAE parameters, weights and optimizer/scheduler state
    loaded_obj = torch.load(str(dalle_path), map_location='cpu')

    dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']
    opt_state = loaded_obj.get('opt_state')
    scheduler_state = loaded_obj.get('scheduler_state')

    # rebuild the VAE from the stored parameters
    if vae_params is not None:
        vae = DiscreteVAE(**vae_params)
    elif args.taming:
        vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
    else:
        vae = OpenAIDiscreteVAE()

    # epoch to resume from
    resume_epoch = loaded_obj.get('epoch', 0)
else:
    # if a trained discrete VAE checkpoint path was given
    if exists(VAE_PATH):
        # resolve the VAE checkpoint path
        vae_path = Path(VAE_PATH)
        # the VAE checkpoint file must exist
        assert vae_path.exists(), 'VAE model file does not exist'
        assert not vae_path.is_dir(), \
            ('Cannot load VAE model from directory; please use a '
             'standard *.pt checkpoint. '
             'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
             'model is not supported.')

        # load the VAE hyperparameters and weights
        loaded_obj = torch.load(str(vae_path))

        vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

        # rebuild the VAE from its parameters and load the weights
        vae = DiscreteVAE(**vae_params)
        vae.load_state_dict(weights)
    else:
        # no custom VAE given, notify that a pretrained VAE will be used
        if is_root:
            print('using pretrained VAE for encoding images to tokens')
        # no custom VAE parameters in this case
        vae_params = None

        # if the taming option is enabled
        if args.taming:
            # use the VQGanVAE
            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
        else:
            # use the OpenAI discrete VAE
            vae = OpenAIDiscreteVAE()

    # assemble the DALL-E hyperparameter dictionary
    dalle_params = dict(
        num_text_tokens=tokenizer.vocab_size,
        text_seq_len=TEXT_SEQ_LEN,
        dim=MODEL_DIM,
        depth=DEPTH,
        heads=HEADS,
        dim_head=DIM_HEAD,
        reversible=REVERSIBLE,
        loss_img_weight=LOSS_IMG_WEIGHT,
        attn_types=ATTN_TYPES,
        ff_dropout=FF_DROPOUT,
        attn_dropout=ATTN_DROPOUT,
        stable=STABLE,
        shift_tokens=SHIFT_TOKENS,
        rotary_emb=ROTARY_EMB,
        shared_attn_ids=SHARED_ATTN_IDS,
        shared_ff_ids=SHARED_FF_IDS,
        share_input_output_emb=SHARE_INPUT_OUTPUT_EMB,
    )
    # when training from scratch, start at epoch 0
    resume_epoch = 0
# image size comes from the VAE
IMAGE_SIZE = vae.image_size
# number of channels comes from the VAE
CHANNELS = vae.channels
# a fourth channel means transparency
TRANSPARENT = CHANNELS == 4
# image mode is RGBA when transparent, otherwise RGB
IMAGE_MODE = 'RGBA' if CHANNELS == 4 else 'RGB'

# configure the OpenAI VAE for float16
if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
    # with the OpenAI discrete VAE and fp16 enabled, run the encoder's output convolution in float16
    vae.enc.blocks.output.conv.use_float16 = True

# helper functions

# split the model parameters into weight-decay and no-weight-decay groups
def group_weight(model):
    group_decay, group_no_decay = [], []
    for params in model.named_parameters():
        if 'transformer' in params[0]:
            if 'bias' in params[0] or 'norm' in params[0]:
                group_no_decay.append(params[1])
                continue
        group_decay.append(params[1])

    assert len(list(model.parameters())) == len(group_decay) + len(group_no_decay)
    groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
    return groups
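Note that group_weight is defined here but the optimizer further down passes get_trainable_params(dalle) to Adam directly. The sketch below shows how parameter groups like these would typically be consumed, using a toy module and AdamW so weight decay only touches the first group; the module layout is hypothetical and group_weight is assumed to be in scope:

from torch import nn
from torch.optim import AdamW

# toy module whose parameter names contain 'transformer', 'bias' and 'norm',
# so the name checks inside group_weight apply
transformer = nn.Sequential()
transformer.add_module('proj', nn.Linear(8, 8))
transformer.add_module('norm', nn.LayerNorm(8))
model = nn.Sequential()
model.add_module('transformer', transformer)

groups = group_weight(model)                      # [decayed params, non-decayed params]
opt = AdamW(groups, lr=3e-4, weight_decay=1e-2)   # weight decay skips the bias/norm group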

# create the dataset and dataloader

# whether to shuffle in the DataLoader (disabled under Horovod, which gets a distributed sampler instead)
is_shuffle = not distributed_utils.using_backend(distributed_utils.HorovodBackend)

# image preprocessing
imagepreproc = T.Compose([
    T.Lambda(lambda img: img.convert(IMAGE_MODE) if img.mode != IMAGE_MODE else img),
    T.RandomResizedCrop(IMAGE_SIZE, scale=(args.resize_ratio, 1.), ratio=(1., 1.)),
    T.ToTensor(),
])

# decode raw bytes into a PIL image
def imagetransform(b):
    return Image.open(BytesIO(b))

# tokenize a raw caption string
def tokenize(s):
    return tokenizer.tokenize(s.decode('utf-8'), TEXT_SEQ_LEN, truncate_text=args.truncate_captions).squeeze(0)

if ENABLE_WEBDATASET:
    # 设置数据集大小
    DATASET_SIZE = int(1e9) # You need to set a nominal length for the Dataset in order to avoid warnings from DataLoader

    myimg, mycap = WEBDATASET_IMAGE_TEXT_COLUMNS
    # 图像文本映射
    image_text_mapping = {
        myimg: imagetransform,
        mycap: tokenize
    }
    # 图像映射
    image_mapping = {
        myimg: imagepreproc
    }

    # 数据集过滤函数
    def filter_dataset(item):
        if mycap not in item:
            return False
        if myimg not in item:
            return False
        return True

    # 创建WebDataset
    w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue)
    filtered_dataset = w_dataset.select(filter_dataset)
    ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE / distr_backend.get_world_size(), partial=True)
else:
    # 创建TextImageDataset
    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        transparent=TRANSPARENT,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )
    assert len(ds) > 0, 'dataset is empty'

if is_root:
    if not ENABLE_WEBDATASET:
        print(f'{len(ds)} image-text pairs found for training')

# 数据采样器

data_sampler = None

if not is_shuffle:
    data_sampler = torch.utils.data.distributed.DistributedSampler(
        ds,
        num_replicas=distr_backend.get_world_size(),
        rank=distr_backend.get_rank()
    )

# WebLoader用于WebDataset和DeepSpeed兼容性

if ENABLE_WEBDATASET:
    dl = wds.WebLoader(ds, batch_size=None, shuffle=False, num_workers=4) # optionally add num_workers=2 (n) argument
    number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend.get_world_size())
    dl = dl.slice(number_of_batches)
    dl.length = number_of_batches
else:
    # 用于图像文本文件夹数据集的常规DataLoader
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)

# 初始化DALL-E

dalle = DALLE(vae=vae, **dalle_params)

if not using_deepspeed:
    if args.fp16:
        # 如果启用fp16,将DALL-E设置为半精度
        dalle = dalle.half()
    # 将DALL-E移动到GPU
    dalle = dalle.cuda()

if RESUME and not using_deepspeed:
    # 如果恢复训练并且不使用DeepSpeed,加载权重
    dalle.load_state_dict(weights)

# 优化器

# 创建Adam优化器
opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

if RESUME and opt_state:
    # 如果恢复训练并且有优化器状态,加载优化器状态
    opt.load_state_dict(opt_state)

# 调度器

scheduler = None

if LR_DECAY:
    # 创建一个学习率调度器 ReduceLROnPlateau
    scheduler = ReduceLROnPlateau(
        opt,  # 传入优化器
        mode="min",  # 设置模式为最小化
        factor=0.5,  # 学习率调整因子
        patience=10,  # 忍耐次数
        cooldown=10,  # 冷却时间
        min_lr=1e-6,  # 最小学习率
        verbose=True,  # 是否打印信息
    )
    # 如果 RESUME 为真且存在学习率调度器状态
    if RESUME and scheduler_state:
        # 加载学习率调度器状态
        scheduler.load_state_dict(scheduler_state)
# 实验跟踪器

# 如果是根节点
if is_root:

    # 定义模型配置字典
    model_config = dict(
        depth=DEPTH,
        heads=HEADS,
        dim_head=DIM_HEAD
    )

    # 初始化 wandb 实验
    run = wandb.init(
        project=args.wandb_name,
        entity=args.wandb_entity,
        resume=False,
        config=model_config,
    )

# 分发

# 检查批量大小是否符合要求
distr_backend.check_batch_size(BATCH_SIZE)
# 配置 DeepSpeed
deepspeed_config = {
    'train_batch_size': BATCH_SIZE,
    'gradient_accumulation_steps': args.ga_steps,
    'gradient_clipping': GRAD_CLIP_NORM,
    'fp16': {
        'enabled': args.fp16,
    },
    'amp': {
        'enabled': args.amp,
        'opt_level': 'O1',
    },
    "flops_profiler": {
        "enabled": args.flops_profiler,
        "profile_step": 200,
        "module_depth": -1,
        "top_modules": 1,
        "detailed": True,
        "output_file": None # TODO 无法使其工作。
    },
}

# 如果 DeepSpeed 配置中的零优化阶段大于等于 2
if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2:
    print(f"Checkpoints made with DeepSpeed ZeRO Stages 2 and 3 will be stored in deepspeed checkpoint folder")
    print(f"As such, they will require DeepSpeed as a dependency in order to resume from or generate with.")
    print("See the deespeed conversion script for details on how to convert your ZeRO stage 2/3 checkpoint to a single file.")
    print("If using a single GPU, consider running with apex automatic mixed precision instead for a similar speedup to ZeRO.")
    time.sleep(2)

# 分发模型、优化器、数据加载器和调度器
(distr_dalle, distr_opt, distr_dl, distr_scheduler) = distr_backend.distribute(
    args=args,
    model=dalle,
    optimizer=opt,
    model_parameters=get_trainable_params(dalle),
    training_data=(
        (None if ENABLE_WEBDATASET else ds)
        if using_deepspeed
        else dl
    ),
    # 不将 LR 调度器传递给 DeepSpeed,以便手动推进
    lr_scheduler=scheduler if LR_DECAY and not using_deepspeed else None,
    config_params=deepspeed_config,
)
# prefer a scheduler defined in `deepspeed_config`, if any.

# if LR decay is enabled but the backend did not return a scheduler, fall back to the one created above
if LR_DECAY and distr_scheduler is None:
    distr_scheduler = scheduler

# avoid calling the underlying model for sampling when DeepSpeed is used with fp16
avoid_model_calls = using_deepspeed and args.fp16

# when resuming under DeepSpeed, load the DeepSpeed checkpoint
if RESUME and using_deepspeed:
    distr_dalle.load_checkpoint(str(cp_dir))

# model saving
def save_model(path, epoch=0):
    save_obj = {
        'hparams': dalle_params,
        'vae_params': vae_params,
        'epoch': epoch,
        'version': __version__,
        'vae_class_name': vae.__class__.__name__
    }

    # when using DeepSpeed
    if using_deepspeed:
        cp_dir = cp_path_to_dir(path, 'ds')

        # if only n checkpoints should be kept, prune old DeepSpeed checkpoints on the root worker
        if KEEP_N_CHECKPOINTS is not None and is_root:
            checkpoints = sorted(glob(str(cp_dir / "global*")), key=os.path.getmtime, reverse=True)
            for checkpoint in checkpoints[KEEP_N_CHECKPOINTS:]:
                shutil.rmtree(checkpoint)

        # save the DeepSpeed checkpoint
        distr_dalle.save_checkpoint(cp_dir, client_state=save_obj)

        if not is_root:
            return

        # save auxiliary values so the standard loading code can be reused
        save_obj = {
            **save_obj,
            # store a placeholder value that points the user towards further help
            'weights': (
                'To get a working standard checkpoint, '
                'look into consolidating DeepSpeed checkpoints.'
            ),
        }
        torch.save(save_obj, str(cp_dir / DEEPSPEED_CP_AUX_FILENAME))
        if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2: # see https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
            return

    if not is_root:
        return

    save_obj = {
        **save_obj,
        'weights': dalle.state_dict(),
        'opt_state': opt.state_dict(),
        'scheduler_state': (scheduler.state_dict() if scheduler else None)
    }

    torch.save(save_obj, path)
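For reference, a checkpoint written by the non-DeepSpeed path of save_model can be read back the same way the RESUME branch near the top of this script does. A minimal sketch; the filename is whatever --dalle_output_file_name produced, dalle.pt by default:

import torch

loaded = torch.load('dalle.pt', map_location='cpu')
dalle_params, vae_params = loaded['hparams'], loaded['vae_params']
weights = loaded['weights']             # a state dict here; the DeepSpeed aux file stores a hint string instead
opt_state = loaded.get('opt_state')
scheduler_state = loaded.get('scheduler_state')
start_epoch = loaded.get('epoch', 0)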

# save the model config and checkpoint file to wandb as an artifact
def save_artifact(model_config, model_path, name = 'trained-dalle'):
    model_artifact = wandb.Artifact(name, type='model', metadata=dict(model_config))
    model_artifact.add_file(model_path)
    run.log_artifact(model_artifact)

# training
# save a checkpoint before training even starts, so a misconfiguration fails early
# see https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints

# initial save
save_model(DALLE_OUTPUT_FILE_NAME, epoch=resume_epoch)

# loop over the epochs
for epoch in range(resume_epoch, EPOCHS):
    # if a data sampler is used, tell it the current epoch
    if data_sampler:
        data_sampler.set_epoch(epoch)

    # iterate over the dataloader
    for i, (text, images) in enumerate((dl if ENABLE_WEBDATASET else distr_dl)):
        # take a timestamp every 10 steps
        if i % 10 == 0 and is_root:
            t = time.time()

        # with fp16 enabled, cast the images to half precision
        if args.fp16:
            images = images.half()

        # move text and images to the GPU
        text, images = map(lambda t: t.cuda(), (text, images))

        # compute the loss
        loss = distr_dalle(text, images, return_loss=True)

        # when using DeepSpeed
        if using_deepspeed:
            distr_dalle.backward(loss)
            distr_dalle.step()
            # gradients are automatically zeroed after the step
        else:
            loss.backward()
            clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
            distr_opt.step()
            distr_opt.zero_grad()

        # collective loss, averaged across workers
        avg_loss = distr_backend.average_all(loss)

        log = {}

        # print the loss every 10 steps
        if i % 10 == 0 and is_root:
            print(epoch, i, f'loss - {avg_loss.item()}')

            log = {
                **log,
                'epoch': epoch,
                'iter': i,
                'loss': avg_loss.item()
            }

        # save the model every SAVE_EVERY_N_STEPS steps
        if i % SAVE_EVERY_N_STEPS == 0:
            save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)

        # every 100 steps, sample an image and log it
        if i % 100 == 0 and is_root:
            sample_text = text[:1]
            token_list = sample_text.masked_select(sample_text != 0).tolist()
            decoded_text = tokenizer.decode(token_list)

            if not avoid_model_calls:
                # avoid CUDA indexing errors
                image = dalle.generate_images(text[:1], filter_thres=0.9)  # topk sampling at 0.9

            if not avoid_model_calls:
                log['image'] = wandb.Image(image, caption=decoded_text)

        # report samples per second every 10 steps
        if i % 10 == 9 and is_root:
            sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
            log["sample_per_sec"] = sample_per_sec
            print(epoch, i, f'sample_per_sec - {sample_per_sec}')

        # stop training early once the FLOPS profiler has run for the requested steps
        if i == 201 and args.flops_profiler:
            raise StopIteration("Profiler has finished running. Stopping training early.")

        # on the root worker, log to wandb
        if is_root:
            wandb.log(log)

    # with LR decay enabled, step the scheduler on the averaged loss
    if LR_DECAY:
        distr_scheduler.step(avg_loss)

    # save the model at the end of every epoch
    save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)

    if is_root:
        # at the end of every epoch, save the trained model to wandb as an artifact
        save_artifact(model_config, DALLE_OUTPUT_FILE_NAME)

# final save
save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)

if is_root:
    # upload the trained model to wandb and finish the wandb run
    wandb.save(DALLE_OUTPUT_FILE_NAME)
    save_artifact(model_config, DALLE_OUTPUT_FILE_NAME)
    wandb.finish()

.\lucidrains\DALLE-pytorch\train_vae.py

# 导入数学库
import math
# 从数学库中导入平方根函数
from math import sqrt
# 导入参数解析库
import argparse
# 从路径库中导入路径类
from pathlib import Path

# 导入 torch 库
import torch
# 从 torch 优化模块中导入 Adam 优化器
from torch.optim import Adam
# 从 torch 优化学习率调度模块中导入指数衰减学习率调度器
from torch.optim.lr_scheduler import ExponentialLR

# 导入视觉库
from torchvision import transforms as T
# 从 torch 工具数据模块中导入数据加载器
from torch.utils.data import DataLoader
# 从 torchvision 数据集模块中导入图像文件夹数据集类
from torchvision.datasets import ImageFolder
# 从 torchvision 工具模块中导入制作网格、保存图像的函数
from torchvision.utils import make_grid, save_image

# 导入 dalle_pytorch 类和工具
from dalle_pytorch import distributed_utils
from dalle_pytorch import DiscreteVAE

# 参数解析
parser = argparse.ArgumentParser()

# 添加图像文件夹路径参数
parser.add_argument('--image_folder', type=str, required=True,
                    help='path to your folder of images for learning the discrete VAE and its codebook')
# 添加图像大小参数
parser.add_argument('--image_size', type=int, required=False, default=128,
                    help='image size')

# 将参数解析器包装为分布式工具的参数解析器
parser = distributed_utils.wrap_arg_parser(parser)

# 训练参数组
train_group = parser.add_argument_group('Training settings')

# 添加训练轮数参数
train_group.add_argument('--epochs', type=int, default=20, help='number of epochs')
# 添加批量大小参数
train_group.add_argument('--batch_size', type=int, default=8, help='batch size')
# 添加学习率参数
train_group.add_argument('--learning_rate', type=float, default=1e-3, help='learning rate')
# 添加学习率衰减率参数
train_group.add_argument('--lr_decay_rate', type=float, default=0.98, help='learning rate decay')
# 添加初始温度参数
train_group.add_argument('--starting_temp', type=float, default=1., help='starting temperature')
# 添加最小温度参数
train_group.add_argument('--temp_min', type=float, default=0.5, help='minimum temperature to anneal to')
# 添加退火率参数
train_group.add_argument('--anneal_rate', type=float, default=1e-6, help='temperature annealing rate')
# 添加保存图像数量参数
train_group.add_argument('--num_images_save', type=int, default=4, help='number of images to save')

# 模型参数组
model_group = parser.add_argument_group('Model settings')

# model-group argument: number of image tokens
model_group.add_argument('--num_tokens', type=int, default=8192, help='number of image tokens')
# 添加层数参数
model_group.add_argument('--num_layers', type=int, default=3, help='number of layers (should be 3 or above)')
# 添加残差网络块数量参数
model_group.add_argument('--num_resnet_blocks', type=int, default=2, help='number of residual net blocks')
# 添加平滑 L1 损失参数
model_group.add_argument('--smooth_l1_loss', dest='smooth_l1_loss', action='store_true')
# 添加嵌入维度参数
model_group.add_argument('--emb_dim', type=int, default=512, help='embedding dimension')
# 添加隐藏维度参数
model_group.add_argument('--hidden_dim', type=int, default=256, help='hidden dimension')
# 添加 KL 损失权重参数
model_group.add_argument('--kl_loss_weight', type=float, default=0., help='KL loss weight')
# 添加透明度参数
model_group.add_argument('--transparent', dest='transparent', action='store_true')

# 解析参数
args = parser.parse_args()

# 常量

# 图像大小
IMAGE_SIZE = args.image_size
# 图像文件夹路径
IMAGE_PATH = args.image_folder

# 训练轮数
EPOCHS = args.epochs
# 批量大小
BATCH_SIZE = args.batch_size
# 学习率
LEARNING_RATE = args.learning_rate
# 学习率衰减率
LR_DECAY_RATE = args.lr_decay_rate

# 图像令牌数量
NUM_TOKENS = args.num_tokens
# 层数
NUM_LAYERS = args.num_layers
# 残差网络块数量
NUM_RESNET_BLOCKS = args.num_resnet_blocks
# 平滑 L1 损失
SMOOTH_L1_LOSS = args.smooth_l1_loss
# 嵌入维度
EMB_DIM = args.emb_dim
# 隐藏维度
HIDDEN_DIM = args.hidden_dim
# KL 损失权重
KL_LOSS_WEIGHT = args.kl_loss_weight

# 透明度
TRANSPARENT = args.transparent
# 通道数
CHANNELS = 4 if TRANSPARENT else 3
# 图像模式
IMAGE_MODE = 'RGBA' if TRANSPARENT else 'RGB'

# 初始温度
STARTING_TEMP = args.starting_temp
# 最小温度
TEMP_MIN = args.temp_min
# 退火率
ANNEAL_RATE = args.anneal_rate

# 保存图像数量
NUM_IMAGES_SAVE = args.num_images_save

# 初始化分布式后端
distr_backend = distributed_utils.set_backend_from_args(args)
distr_backend.initialize()

# 是否使用 DeepSpeed
using_deepspeed = distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)

# 数据

# 创建图像文件夹数据集
ds = ImageFolder(
    IMAGE_PATH,
    T.Compose([
        # 将图像转换为指定模式
        T.Lambda(lambda img: img.convert(IMAGE_MODE) if img.mode != IMAGE_MODE else img),
        # 调整大小
        T.Resize(IMAGE_SIZE),
        # 中心裁剪
        T.CenterCrop(IMAGE_SIZE),
        # 转换为张量
        T.ToTensor()
    ])
)

if distributed_utils.using_backend(distributed_utils.HorovodBackend):
    # 创建一个用于分布式训练的数据采样器,用于在不同进程之间分配数据
    data_sampler = torch.utils.data.distributed.DistributedSampler(
        ds, num_replicas=distr_backend.get_world_size(),
        rank=distr_backend.get_rank())
# 如果条件不成立,将数据采样器设置为 None
else:
    data_sampler = None

# 创建数据加载器,设置批量大小、是否打乱数据、数据采样器
dl = DataLoader(ds, BATCH_SIZE, shuffle = not data_sampler, sampler=data_sampler)

# 定义 VAE 的参数
vae_params = dict(
    image_size = IMAGE_SIZE,
    num_layers = NUM_LAYERS,
    num_tokens = NUM_TOKENS,
    channels = CHANNELS,
    codebook_dim = EMB_DIM,
    hidden_dim   = HIDDEN_DIM,
    num_resnet_blocks = NUM_RESNET_BLOCKS
)

# 创建离散 VAE 模型
vae = DiscreteVAE(
    **vae_params,
    smooth_l1_loss = SMOOTH_L1_LOSS,
    kl_div_loss_weight = KL_LOSS_WEIGHT
)

# 如果不使用 DeepSpeed,则将 VAE 模型移到 GPU 上
if not using_deepspeed:
    vae = vae.cuda()

# 断言数据集中有数据
assert len(ds) > 0, 'folder does not contain any images'
if distr_backend.is_root_worker():
    # 打印找到的图片数量
    print(f'{len(ds)} images found for training')

# 优化器
opt = Adam(vae.parameters(), lr = LEARNING_RATE)
sched = ExponentialLR(optimizer = opt, gamma = LR_DECAY_RATE)

if distr_backend.is_root_worker():
    # weights & biases 实验跟踪
    import wandb

    model_config = dict(
        num_tokens = NUM_TOKENS,
        smooth_l1_loss = SMOOTH_L1_LOSS,
        num_resnet_blocks = NUM_RESNET_BLOCKS,
        kl_loss_weight = KL_LOSS_WEIGHT
    )

    # 初始化 weights & biases 实验
    run = wandb.init(
        project = 'dalle_train_vae',
        job_type = 'train_model',
        config = model_config
    )

# 分布式
distr_backend.check_batch_size(BATCH_SIZE)
deepspeed_config = {'train_batch_size': BATCH_SIZE}

# 分布式训练
(distr_vae, distr_opt, distr_dl, distr_sched) = distr_backend.distribute(
    args=args,
    model=vae,
    optimizer=opt,
    model_parameters=vae.parameters(),
    training_data=ds if using_deepspeed else dl,
    lr_scheduler=sched if not using_deepspeed else None,
    config_params=deepspeed_config,
)

using_deepspeed_sched = False
# 如果没有使用 DeepSpeed 调度器,则使用 sched
if distr_sched is None:
    distr_sched = sched
elif using_deepspeed:
    # 使用 DeepSpeed LR 调度器,并让 DeepSpeed 处理调度
    using_deepspeed_sched = True

# 保存模型
def save_model(path):
    save_obj = {
        'hparams': vae_params,
    }
    if using_deepspeed:
        cp_path = Path(path)
        path_sans_extension = cp_path.parent / cp_path.stem
        cp_dir = str(path_sans_extension) + '-ds-cp'

        # 保存 DeepSpeed 检查点
        distr_vae.save_checkpoint(cp_dir, client_state=save_obj)
        # 不返回以获取一个“正常”的检查点来参考

    if not distr_backend.is_root_worker():
        return

    save_obj = {
        **save_obj,
        'weights': vae.state_dict()
    }

    # 保存模型权重
    torch.save(save_obj, path)

# 设置初始温度
global_step = 0
temp = STARTING_TEMP

# 训练循环
for epoch in range(EPOCHS):
    # 遍历数据加载器中的图像数据和标签,使用enumerate获取索引和数据
    for i, (images, _) in enumerate(distr_dl):
        # 将图像数据移动到GPU上进行加速处理
        images = images.cuda()

        # 使用分布式VAE模型计算损失和重构图像
        loss, recons = distr_vae(
            images,
            return_loss = True,
            return_recons = True,
            temp = temp
        )

        # 如果使用DeepSpeed,则自动将梯度清零并执行优化步骤
        if using_deepspeed:
            # 梯度在步骤后自动清零
            distr_vae.backward(loss)
            distr_vae.step()
        else:
            # 否则手动将优化器梯度清零,计算梯度并执行优化步骤
            distr_opt.zero_grad()
            loss.backward()
            distr_opt.step()

        # 初始化日志字典
        logs = {}

        # 每100个迭代打印日志
        if i % 100 == 0:
            # 如果是根节点工作进程
            if distr_backend.is_root_worker():
                k = NUM_IMAGES_SAVE

                # 使用无梯度计算获取编码和硬重构图像
                with torch.no_grad():
                    codes = vae.get_codebook_indices(images[:k])
                    hard_recons = vae.decode(codes)

                # 截取部分图像和重构图像
                images, recons = map(lambda t: t[:k], (images, recons))
                # 将图像、重构图像、硬重构图像、编码转移到CPU并去除梯度信息
                images, recons, hard_recons, codes = map(lambda t: t.detach().cpu(), (images, recons, hard_recons, codes))
                # 将图像、重构图像、硬重构图像转换为图像网格
                images, recons, hard_recons = map(lambda t: make_grid(t.float(), nrow = int(sqrt(k)), normalize = True, range = (-1, 1)), (images, recons, hard_recons))

                # 更新日志字典
                logs = {
                    **logs,
                    'sample images':        wandb.Image(images, caption = 'original images'),
                    'reconstructions':      wandb.Image(recons, caption = 'reconstructions'),
                    'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'),
                    'codebook_indices':     wandb.Histogram(codes),
                    'temperature':          temp
                }

                # 保存模型
                wandb.save('./vae.pt')
            save_model(f'./vae.pt')

            # 温度退火

            temp = max(temp * math.exp(-ANNEAL_RATE * global_step), TEMP_MIN)

            # 学习率衰减

            # 不要从`deepspeed_config`中提前调整调度器
            if not using_deepspeed_sched:
                distr_sched.step()

        # 计算集合损失,取平均值
        avg_loss = distr_backend.average_all(loss)

        # 如果是根节点工作进程
        if distr_backend.is_root_worker():
            # 每10个迭代打印学习率和损失
            if i % 10 == 0:
                lr = distr_sched.get_last_lr()[0]
                print(epoch, i, f'lr - {lr:6f} loss - {avg_loss.item()}')

                # 更新日志字典
                logs = {
                    **logs,
                    'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item(),
                    'lr': lr
                }

            # 记录日志
            wandb.log(logs)
        global_step += 1

    # 如果是根节点工作进程
    if distr_backend.is_root_worker():
        # 在每个epoch结束时将训练好的模型保存到wandb作为artifact

        model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
        model_artifact.add_file('vae.pt')
        run.log_artifact(model_artifact)
# 如果当前进程是根节点工作进程
if distr_backend.is_root_worker():
    # 保存最终的 VAE 模型并清理工作

    # 保存模型到文件 './vae-final.pt'
    save_model('./vae-final.pt')
    # 将模型文件上传到 wandb 服务器
    wandb.save('./vae-final.pt')

    # 创建一个 wandb Artifact 对象,用于存储训练好的 VAE 模型
    model_artifact = wandb.Artifact('trained-vae', type='model', metadata=dict(model_config))
    # 将 'vae-final.pt' 文件添加到 Artifact 对象中
    model_artifact.add_file('vae-final.pt')
    # 记录 Artifact 对象到当前运行日志中
    run.log_artifact(model_artifact)

    # 结束当前 wandb 运行
    wandb.finish()

DALLE2 Training Configurations

For more complex configuration, we provide the option of using a configuration file instead of command line arguments.

Decoder Trainer

The decoder trainer has 7 main configuration options. A full example of their use can be found in the example decoder configuration.

Unet:

This is a single unet config, which belongs as an array nested under the decoder config as a list of unets

Option Required Default Description
dim Yes N/A The starting channels of the unet.
image_embed_dim Yes N/A The dimension of the image embeddings.
dim_mults No (1, 2, 4, 8) The growth factors of the channels.

Any parameter from the Unet constructor can also be given here.

Decoder:

Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.

Option Required Default Description
unets Yes N/A A list of unets, using the configuration above
image_sizes Yes N/A The resolution of the image after each upsampling step. The length of this array should be the number of unets defined.
image_size Yes N/A Not used. Can be any number.
timesteps No 1000 The number of diffusion timesteps used for generation.
loss_type No l2 The loss function. Options are l1, huber, or l2.
beta_schedule No cosine The noising schedule. Options are cosine, linear, quadratic, jsd, or sigmoid.
learned_variance No True Whether to learn the variance.
clip No None The clip model to use if embeddings are being generated on the fly. Takes keys make and model with defaults openai and ViT-L/14.

Any parameter from the Decoder constructor can also be given here.
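Putting the two tables above together, the decoder portion of a config might look like the following. The values are illustrative only, the keys are the documented ones, and nesting everything under a top-level decoder section follows the layout of the example decoder configuration referenced above (shown here as a Python dict; the actual file is JSON with the same structure):

decoder_config = {
    "decoder": {
        "unets": [
            {"dim": 128, "image_embed_dim": 768, "dim_mults": [1, 2, 4, 8]}
        ],
        "image_sizes": [64],       # one entry per unet
        "image_size": 64,          # documented as unused
        "timesteps": 1000,
        "loss_type": "l2",
        "beta_schedule": "cosine",
        "learned_variance": True
    }
}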

Data:

Settings for creation of the dataloaders.

Option Required Default Description
webdataset_base_url Yes N/A The url of a shard in the webdataset with the shard replaced with {}[1].
img_embeddings_url No None The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used.
text_embeddings_url No None The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used.
num_workers No 4 The number of workers used in the dataloader.
batch_size No 64 The batch size.
start_shard No 0 Defines the start of the shard range the dataset will recall.
end_shard No 9999999 Defines the end of the shard range the dataset will recall.
shard_width No 6 Defines the width of one webdataset shard number[2].
index_width No 4 Defines the width of the index of a file inside a shard[3].
splits No { "train": 0.75, "val": 0.15, "test": 0.1 } Defines the proportion of shards that will be allocated to the training, validation, and testing datasets.
shuffle_train No True Whether to shuffle the shards of the training dataset.
resample_train No False If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if shuffle_train is enabled.
preprocessing No { "ToTensor": True } Defines preprocessing applied to images from the datasets.
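A sketch of the corresponding data section, again with illustrative values only; the shard URL is a placeholder, with {} standing in for the shard number as described above:

data_config = {
    "data": {
        "webdataset_base_url": "https://example.org/my-dataset/{}.tar",  # placeholder URL
        "num_workers": 4,
        "batch_size": 64,
        "start_shard": 0,
        "end_shard": 100,
        "shard_width": 6,
        "index_width": 4,
        "splits": {"train": 0.75, "val": 0.15, "test": 0.1},
        "shuffle_train": True,
        "resample_train": False,
        "preprocessing": {"ToTensor": True}
    }
}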

Train:

Settings for controlling the training hyperparameters.

Option Required Default Description
epochs No 20 The number of epochs in the training run.
lr No 1e-4 The learning rate.
wd No 0.01 The weight decay.
max_grad_norm No 0.5 The grad norm clipping.
save_every_n_samples No 100000 Samples will be generated and a checkpoint will be saved every save_every_n_samples samples.
cond_scale No 1.0 Conditioning scale to use for sampling. Can also be an array of values, one for each unet.
device No cuda:0 The device to train on.
epoch_samples No None Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit.
validation_samples No None The number of samples to use for validation. None means the entire validation set.
use_ema No True Whether to use exponential moving average models for sampling.
ema_beta No 0.99 The ema coefficient.
unet_training_mask No None A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of None trains all unets.
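And a sketch of the train section, mostly using the documented defaults; unet_training_mask is set explicitly here (one flag per unet) purely as an illustration, since the default None trains all unets:

train_config = {
    "train": {
        "epochs": 20,
        "lr": 1e-4,
        "wd": 0.01,
        "max_grad_norm": 0.5,
        "save_every_n_samples": 100000,
        "cond_scale": 1.0,
        "device": "cuda:0",
        "epoch_samples": None,        # no per-epoch limit; must be set when resampling
        "use_ema": True,
        "ema_beta": 0.99,
        "unet_training_mask": [True]
    }
}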

Evaluate:

Defines which evaluation metrics will be used to test the model.
Each metric can be enabled by setting its configuration. The configuration keys for each metric are those of the corresponding torchmetrics constructor.

Option Required Default Description
n_evaluation_samples No 1000 The number of samples to generate to test the model.
FID No None Setting to an object enables the Frechet Inception Distance metric.
IS No None Setting to an object enables the Inception Score metric.
KID No None Setting to an object enables the Kernel Inception Distance metric.
LPIPS No None Setting to an object enables the Learned Perceptual Image Patch Similarity metric.

Tracker:

Selects how the experiment will be tracked.

Option Required Default Description
data_path No ./.tracker-data The path to the folder where temporary tracker data will be saved.
overwrite_data_path No False If true, the data path will be overwritten. Otherwise, you need to delete it yourself.
log Yes N/A Logging configuration.
load No None Checkpoint loading configuration.
save Yes N/A Checkpoint/Model saving configuration.
Tracking is split up into three sections:
  • Log: Where to save run metadata and image output. Options are console or wandb.
  • Load: Where to load a checkpoint from. Options are local, url, or wandb.
  • Save: Where to save a checkpoint to. Options are local, huggingface, or wandb.

Logging:

All loggers have the following keys:

Option Required Default Description
log_type Yes N/A The type of logger class to use.
resume No False For loggers that have the option to resume an old run, resume it using manually input parameters.
auto_resume No False If true, the logger will attempt to resume an old run using parameters from that previous run.

If using console there is no further configuration than setting log_type to console.

Option Required Default Description
log_type Yes N/A Must be console.

If using wandb

Option Required Default Description
log_type Yes N/A Must be wandb.
wandb_entity Yes N/A The wandb entity to log to.
wandb_project Yes N/A The wandb project to save the run to.
wandb_run_name No None The wandb run name.
wandb_run_id No None The wandb run id. Used if resuming an old run.

Loading:

All loaders have the following keys:

Option Required Default Description
load_from Yes N/A The type of loader class to use.
only_auto_resume No False If true, the loader will only load the model if the run is being auto resumed.

If using local

Option Required Default Description
load_from Yes N/A Must be local.
file_path Yes N/A The path to the checkpoint file.

If using url

Option Required Default Description
load_from Yes N/A Must be url.
url Yes N/A The url of the checkpoint file.

If using wandb

Option Required Default Description
load_from Yes N/A Must be wandb.
wandb_run_path No None The wandb run path. If None, uses the run that is being resumed.
wandb_file_path Yes N/A The path to the checkpoint file in the W&B file system.

Saving:
Unlike log and load, save may be an array of options so that you can save to different locations in a run.

All save locations have these configuration options

Option Required Default Description
save_to Yes N/A Must be local, huggingface, or wandb.
save_latest_to No None Sets the relative path to save the latest model to.
save_best_to No None Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models.
save_meta_to No None The path to save metadata files in. This includes the config files used to start the training.
save_type No checkpoint The type of save. checkpoint saves a checkpoint, model saves a model without any fluff (Saves with ema if ema is enabled).

If using local

Option Required Default Description
save_to Yes N/A Must be local.

If using huggingface

Option Required Default Description
save_to Yes N/A Must be huggingface.
huggingface_repo Yes N/A The huggingface repository to save to.
token_path No None If logging in with the huggingface cli is not possible, point to a token file instead.

If using wandb

Option Required Default Description
save_to Yes N/A Must be wandb.
wandb_run_path No None The wandb run path. If None, uses the current run. You will almost always want this to be None.
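Finally, a sketch of a tracker section combining the three parts described above: logging to wandb, loading a checkpoint from a local file, and saving both locally and to wandb. The entity, project and file paths are placeholders:

tracker_config = {
    "tracker": {
        "data_path": "./.tracker-data",
        "overwrite_data_path": True,
        "log": {
            "log_type": "wandb",
            "wandb_entity": "my-entity",       # placeholder
            "wandb_project": "dalle2-decoder"  # placeholder
        },
        "load": {
            "load_from": "local",
            "file_path": "./checkpoints/latest.pth"  # placeholder
        },
        "save": [
            {"save_to": "local", "save_latest_to": "./checkpoints/latest.pth", "save_type": "checkpoint"},
            {"save_to": "wandb"}
        ]
    }
}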

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\cli.py

# import required libraries
import click
import torch
import torchvision.transforms as T
from functools import reduce
from pathlib import Path

# import the package's own modules
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior

# fetch a value from a nested dict via a dot-separated key path
def safeget(dictionary, keys, default = None):
    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
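safeget walks the dot-separated key path through nested dicts and returns the default as soon as a key is missing. A self-contained illustration; the dict contents are made up, and the helper is repeated verbatim only so the snippet runs on its own:

from functools import reduce

def safeget(dictionary, keys, default = None):  # same helper as above
    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)

cfg = {'init_params': {'prior': {'dim': 512}}}
print(safeget(cfg, 'init_params.prior'))       # -> {'dim': 512}
print(safeget(cfg, 'init_params.missing', 0))  # -> 0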

# simple slugifier: turn a prompt into a filesystem-friendly name
def simple_slugify(text, max_length = 255):
    return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]

# look up the installed package version
def get_pkg_version():
    from pkg_resources import get_distribution
    return get_distribution('dalle2_pytorch').version

# main entry point (no-op)
def main():
    pass

# command line options
@click.command()
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
@click.argument('text')
def dream(
    model,
    cond_scale,
    text
):
    # resolve the model path
    model_path = Path(model)
    full_model_path = str(model_path.resolve())
    # the model file must exist
    assert model_path.exists(), f'model not found at {full_model_path}'
    # load the checkpoint
    loaded = torch.load(str(model_path))

    # read the saved package version
    version = safeget(loaded, 'version')
    print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')

    # pull out the init parameters
    prior_init_params = safeget(loaded, 'init_params.prior')
    decoder_init_params = safeget(loaded, 'init_params.decoder')
    model_params = safeget(loaded, 'model_params')

    # instantiate the DiffusionPrior and Decoder
    prior = DiffusionPrior(**prior_init_params)
    decoder = Decoder(**decoder_init_params)

    # build DALLE2 and load the weights
    dalle2 = DALLE2(prior, decoder)
    dalle2.load_state_dict(model_params)

    # generate the image
    image = dalle2(text, cond_scale = cond_scale)

    # convert to a PIL image and save it
    pil_image = T.ToPILImage()(image)
    return pil_image.save(f'./{simple_slugify(text)}.png')

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\dalle2_pytorch.py

# 导入数学库
import math
# 导入随机数库
import random
# 导入进度条库
from tqdm.auto import tqdm
# 导入偏函数库
from functools import partial, wraps
# 导入上下文管理库
from contextlib import contextmanager
# 导入命名元组库
from collections import namedtuple
# 导入路径库
from pathlib import Path

# 导入 PyTorch 库
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch import nn, einsum
import torchvision.transforms as T

# 导入 einops 库
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# 导入 kornia 库
from kornia.filters import gaussian_blur2d
import kornia.augmentation as K

# 导入 dalle2_pytorch 库
from dalle2_pytorch.tokenizer import tokenizer
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE

# 导入 resize_right 库
from resize_right import resize

# 导入旋转嵌入库
from rotary_embedding_torch import RotaryEmbedding

# 导入 x-clip 库
from x_clip import CLIP
from coca_pytorch import CoCa

# constants
NAT = 1. / math.log(2.)

# 定义命名元组 UnetOutput
UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])

# helper functions

# 判断值是否存在
def exists(val):
    return val is not None

# 返回输入值
def identity(t, *args, **kwargs):
    return t

# 返回列表的第一个元素
def first(arr, d = None):
    if len(arr) == 0:
        return d
    return arr[0]

# 可选函数装饰器
def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

# 默认值函数
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# 将值转换为元组
def cast_tuple(val, length = None, validate = True):
    if isinstance(val, list):
        val = tuple(val)

    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length) and validate:
        assert len(out) == length

    return out

# 获取模块的设备
def module_device(module):
    if isinstance(module, nn.Identity):
        return 'cpu' # does not matter
    return next(module.parameters()).device

# 初始化权重为零
def zero_init_(m):
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

# 空上下文管理器
@contextmanager
def null_context(*args, **kwargs):
    yield

# decorator that runs a method in eval mode and restores the previous training mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner
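A decorator of this shape is meant for methods such as sampling, where the module should temporarily run in eval mode and then be returned to whatever mode it was in. A hedged sketch with a toy module; eval_decorator is assumed to be in scope from the definition above:

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 4)

    @torch.no_grad()
    @eval_decorator           # assumed in scope from the definition above
    def sample(self, x):
        return self.net(x)

m = Toy().train()
_ = m.sample(torch.randn(1, 4))
print(m.training)             # True - training mode was restored after sampling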

# 判断是否为浮点数类型
def is_float_dtype(dtype):
    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

# 判断是否为字符串列表
def is_list_str(x):
    if not isinstance(x, (list, tuple)):
        return False
    return all([type(el) == str for el in x])

# 将元组填充到指定长度
def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

# checkpointing helpers

def make_checkpointable(fn, **kwargs):
    if isinstance(fn, nn.ModuleList):
        return [maybe(make_checkpointable)(el, **kwargs) for el in fn]

    condition = kwargs.pop('condition', None)

    if exists(condition) and not condition(fn):
        return fn

    @wraps(fn)
    def inner(*args):
        input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])

        if not input_needs_grad:
            return fn(*args)

        return checkpoint(fn, *args)

    return inner
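make_checkpointable wraps a callable (or each element of an nn.ModuleList) so that, whenever any tensor argument requires grad, the call goes through torch.utils.checkpoint and activations are recomputed during the backward pass instead of being stored. A hedged sketch; the helper is assumed to be in scope, and recent PyTorch versions may warn about the use_reentrant default:

import torch
from torch import nn

block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
ckpt_block = make_checkpointable(block)   # assumed in scope

x = torch.randn(2, 16, requires_grad=True)
y = ckpt_block(x)        # re-runs block's forward during backward instead of caching activations
y.sum().backward()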

# functions to control freezing of CLIP

def set_module_requires_grad_(module, requires_grad):
    for param in module.parameters():
        param.requires_grad = requires_grad

def freeze_all_layers_(module):
    set_module_requires_grad_(module, False)

def unfreeze_all_layers_(module):
    set_module_requires_grad_(module, True)

def freeze_model_and_make_eval_(model):
    model.eval()
    freeze_all_layers_(model)

# tensor helpers

# 对数函数
def log(t, eps = 1e-12):
    return torch.log(t.clamp(min = eps))

# L2 归一化函数
def l2norm(t):
    return F.normalize(t, dim = -1)

# 调整图像大小函数
def resize_image_to(
    image,
    target_image_size,
    clamp_range = None,
    nearest = False,
    **kwargs
):
    orig_image_size = image.shape[-1]
    # 如果原始图像大小与目标图像大小相同,则直接返回原始图像
    if orig_image_size == target_image_size:
        return image

    # 如果不使用最近邻插值,则计算缩放因子并调整图像大小
    if not nearest:
        scale_factors = target_image_size / orig_image_size
        out = resize(image, scale_factors=scale_factors, **kwargs)
    # 如果使用最近邻插值,则使用最近邻插值方法调整图像大小
    else:
        out = F.interpolate(image, target_image_size, mode='nearest')

    # 如果指定了范围限制,则对输出图像进行范围限制
    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    # 返回调整后的图像
    return out
# image normalization helpers
# DDPMs expect images in the range -1 to 1
# but CLIP may expect something different

def normalize_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_zero_to_one(normed_img):
    return (normed_img + 1) * 0.5
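A quick numeric check of the two range conversions above, assuming both helpers are in scope: [0, 1] maps to [-1, 1] and back.

import torch

img = torch.tensor([0.0, 0.5, 1.0])
normed = normalize_neg_one_to_one(img)     # tensor([-1., 0., 1.])
print(unnormalize_zero_to_one(normed))     # tensor([0.0000, 0.5000, 1.0000])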

# CLIP adapters

EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings'])
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])

class BaseClipAdapter(nn.Module):
    def __init__(self, clip, **kwargs):
        super().__init__()
        self.clip = clip
        self.overrides = kwargs

    def validate_and_resize_image(self, image):
        image_size = image.shape[-1]
        assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'
        return resize_image_to(image, self.image_size)

    @property
    def dim_latent(self):
        raise NotImplementedError

    @property
    def image_size(self):
        raise NotImplementedError

    @property
    def image_channels(self):
        raise NotImplementedError

    @property
    def max_text_len(self):
        raise NotImplementedError

    def embed_text(self, text):
        raise NotImplementedError

    def embed_image(self, image):
        raise NotImplementedError

class XClipAdapter(BaseClipAdapter):
    @property
    def dim_latent(self):
        return self.clip.dim_latent

    @property
    def image_size(self):
        return self.clip.image_size

    @property
    def image_channels(self):
        return self.clip.image_channels

    @property
    def max_text_len(self):
        return self.clip.text_seq_len

    @torch.no_grad()
    def embed_text(self, text):
        text = text[..., :self.max_text_len]
        text_mask = text != 0
        encoder_output = self.clip.text_transformer(text)

        encoder_output_is_cls = encoder_output.ndim == 3

        text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None)
        text_embed = self.clip.to_text_latent(text_cls)

        if exists(text_encodings):
            text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)

        return EmbeddedText(l2norm(text_embed), text_encodings)

    @torch.no_grad()
    def embed_image(self, image):
        image = self.validate_and_resize_image(image)
        encoder_output = self.clip.visual_transformer(image)
        image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
        image_embed = self.clip.to_visual_latent(image_cls)
        return EmbeddedImage(l2norm(image_embed), image_encodings)

class CoCaAdapter(BaseClipAdapter):
    @property
    def dim_latent(self):
        return self.clip.dim

    @property
    def image_size(self):
        assert 'image_size' in self.overrides
        return self.overrides['image_size']

    @property
    def image_channels(self):
        assert 'image_channels' in self.overrides
        return self.overrides['image_channels']

    @property
    def max_text_len(self):
        assert 'max_text_len' in self.overrides
        return self.overrides['max_text_len']

    @torch.no_grad()
    def embed_text(self, text):
        text = text[..., :self.max_text_len]
        text_mask = text != 0
        text_embed, text_encodings = self.clip.embed_text(text)
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        return EmbeddedText(text_embed, text_encodings)

    @torch.no_grad()
    def embed_image(self, image):
        image = self.validate_and_resize_image(image)
        image_embed, image_encodings = self.clip.embed_image(image)
        return EmbeddedImage(image_embed, image_encodings)

class OpenAIClipAdapter(BaseClipAdapter):
    def __init__(
        self,
        name = 'ViT-B/32'
    ): 
        # 导入 clip 模块
        import clip
        # 加载 OpenAI 的 CLIP 模型和预处理函数
        openai_clip, preprocess = clip.load(name)
        # 调用父类的构造函数,初始化 CLIP 模型
        super().__init__(openai_clip)
        # 设置结束符号的 ID,用于处理 0 也是 '!' 的情况
        self.eos_id = 49407 

        # 获取文本注意力最终层
        text_attention_final = self.find_layer('ln_final')

        # 设置潜在维度
        self.dim_latent_ = text_attention_final.weight.shape[0]
        # 注册前向钩子
        self.handle = text_attention_final.register_forward_hook(self._hook)

        # 获取 CLIP 模型的归一化函数
        self.clip_normalize = preprocess.transforms[-1]
        # 标记是否已清除
        self.cleared = False

    # 查找指定层
    def find_layer(self,  layer):
        modules = dict([*self.clip.named_modules()])
        return modules.get(layer, None)

    # 清除钩子
    def clear(self):
        if self.cleared:
            return

        self.handle()

    # 钩子函数
    def _hook(self, _, inputs, outputs):
        self.text_encodings = outputs

    # 获取潜在维度
    @property
    def dim_latent(self):
        return self.dim_latent_

    # 获取图像大小
    @property
    def image_size(self):
        return self.clip.visual.input_resolution

    # 获取图像通道数
    @property
    def image_channels(self):
        return 3

    # 获取最大文本长度
    @property
    def max_text_len(self):
        return self.clip.context_length

    # 嵌入文本
    @torch.no_grad()
    def embed_text(self, text):
        text = text[..., :self.max_text_len]

        # 判断是否为结束符号
        is_eos_id = (text == self.eos_id)
        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
        text_mask = text_mask & (text != 0)
        assert not self.cleared

        # 编码文本
        text_embed = self.clip.encode_text(text)
        text_encodings = self.text_encodings
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        del self.text_encodings
        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())

    # 嵌入图像
    @torch.no_grad()
    def embed_image(self, image):
        assert not self.cleared
        # 验证和调整图像大小
        image = self.validate_and_resize_image(image)
        image = self.clip_normalize(image)
        # 编码图像
        image_embed = self.clip.encode_image(image)
        return EmbeddedImage(l2norm(image_embed.float()), None)
class OpenClipAdapter(BaseClipAdapter):
    # OpenClipAdapter 类继承自 BaseClipAdapter 类
    def __init__(
        self,
        name = 'ViT-B/32',
        pretrained = 'laion400m_e32'
    ):
        # 导入 open_clip 模块
        import open_clip
        # 创建 OpenCLIP 模型和预处理方法
        clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)

        # 调用父类的构造函数,传入 clip 模型
        super().__init__(clip)
        # 设置结束符 ID
        self.eos_id = 49407

        # 查找文本注意力最终层
        text_attention_final = self.find_layer('ln_final')
        # 获取潜在维度
        self._dim_latent = text_attention_final.weight.shape[0]

        # 注册 forward hook
        self.handle = text_attention_final.register_forward_hook(self._hook)
        # 获取 CLIP 模型的归一化方法
        self.clip_normalize = preprocess.transforms[-1]
        # 标记是否已清除
        self.cleared = False

    # 查找指定层
    def find_layer(self,  layer):
        modules = dict([*self.clip.named_modules()])
        return modules.get(layer, None)

    # 清除方法
    def clear(self):
        if self.cleared:
            return

        self.handle()

    # 钩子方法
    def _hook(self, _, inputs, outputs):
        self.text_encodings = outputs

    @property
    def dim_latent(self):
        return self._dim_latent

    @property
    def image_size(self):
        # 获取图像尺寸
        image_size = self.clip.visual.image_size
        if isinstance(image_size, tuple):
            return max(image_size)
        return image_size

    @property
    def image_channels(self):
        return 3

    @property
    def max_text_len(self):
        return self.clip.context_length

    @torch.no_grad()
    def embed_text(self, text):
        # 截取文本长度
        text = text[..., :self.max_text_len]

        # 创建文本掩码
        is_eos_id = (text == self.eos_id)
        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
        text_mask = text_mask & (text != 0)
        assert not self.cleared

        # 编码文本
        text_embed = self.clip.encode_text(text)
        text_encodings = self.text_encodings
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        del self.text_encodings
        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())

    @torch.no_grad()
    def embed_image(self, image):
        assert not self.cleared
        # 验证并调整图像大小
        image = self.validate_and_resize_image(image)
        image = self.clip_normalize(image)
        image_embed = self.clip.encode_image(image)
        return EmbeddedImage(l2norm(image_embed.float()), None)

# 分类器自由指导函数

# 创建概率掩码
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
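
prob_mask_like 生成的布尔保留掩码是无分类器引导训练的基础:被随机"丢弃"的样本会在后面用空(null)嵌入替换,使模型同时学到有条件与无条件两种预测。下面是一个纯张量的小草图(批大小、维度与空嵌入均为假设的示例值):

import torch

batch, dim = 4, 8
text_embed = torch.randn(batch, 1, dim)
null_text_embed = torch.zeros(1, 1, dim)                    # 假设的空(null)嵌入

keep_mask = torch.zeros(batch).uniform_(0, 1) < (1 - 0.2)   # 与 prob_mask_like 的随机分支一致
keep_mask = keep_mask.view(batch, 1, 1)

# 被丢弃条件的样本替换为 null 嵌入
text_embed = torch.where(keep_mask, text_embed, null_text_embed)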

# 高斯扩散辅助函数

# 提取函数
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
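
extract 按每个样本自己的时间步 t 取出对应系数,并 reshape 成能与图像形状广播的张量。一个小示例(数值均为假设,仅作说明):

import torch

a = torch.linspace(0.1, 1.0, steps = 10)     # 假设的每个时间步的系数表
t = torch.tensor([0, 4, 9])                  # 批内每个样本各自的时间步
out = a.gather(-1, t).reshape(3, 1, 1, 1)    # 与上文 extract 对 (b, c, h, w) 的处理等价
print(out.shape)                             # torch.Size([3, 1, 1, 1]),可与 (3, 3, 32, 32) 广播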

# 平均扁平函数
def meanflat(x):
    return x.mean(dim = tuple(range(1, len(x.shape))))

# 正态 KL 散度
def normal_kl(mean1, logvar1, mean2, logvar2):
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))

# 近似标准正态 CDF
def approx_standard_normal_cdf(x):
    return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))

# 离散化高斯对数似然
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
    assert x.shape == means.shape == log_scales.shape

    # 修正 nan 梯度
    eps = 1e-12 if x.dtype == torch.float32 else 1e-3

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = log(cdf_plus, eps = eps)
    log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
    cdf_delta = cdf_plus - cdf_min
    # 使用 torch.where 函数根据条件选择不同的操作
    # 如果 x 小于 -thres,则返回 log_cdf_plus
    # 如果 x 大于 thres,则返回 log_one_minus_cdf_min
    # 否则返回 log(cdf_delta, eps = eps)
    log_probs = torch.where(x < -thres,
        log_cdf_plus,
        torch.where(x > thres,
            log_one_minus_cdf_min,
            log(cdf_delta, eps = eps)))

    # 返回计算得到的 log_probs
    return log_probs
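
approx_standard_normal_cdf 用的是 tanh 近似(与 GELU 常用的近似相同)。下面的小检验(仅作说明)把它与基于 erf 的精确标准正态 CDF 作对比:

import math
import torch

x = torch.linspace(-4, 4, steps = 1001, dtype = torch.float64)

approx = 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
exact  = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

print((approx - exact).abs().max())   # 最大绝对误差小于 1e-3
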
# 定义一个余弦调度函数,根据给定的时间步数和参数s生成一组beta值
def cosine_beta_schedule(timesteps, s = 0.008):
    # 计算总步数
    steps = timesteps + 1
    # 在0到timesteps之间生成均匀间隔的值,作为x
    x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
    # 根据余弦函数计算alpha的累积乘积
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    # 将alpha的累积乘积除以第一个元素,得到归一化后的值
    alphas_cumprod = alphas_cumprod / first(alphas_cumprod)
    # 计算beta值
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # 将beta值限制在0到0.999之间
    return torch.clip(betas, 0, 0.999)


# 定义一个线性调度函数,根据给定的时间步数生成一组beta值
def linear_beta_schedule(timesteps):
    # 计算比例尺
    scale = 1000 / timesteps
    # 计算起始beta值
    beta_start = scale * 0.0001
    # 计算结束beta值
    beta_end = scale * 0.02
    # 在起始和结束之间生成均匀间隔的值,作为beta值
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)


# 定义一个二次调度函数,根据给定的时间步数生成一组beta值
def quadratic_beta_schedule(timesteps):
    # 计算比例尺
    scale = 1000 / timesteps
    # 计算起始beta值
    beta_start = scale * 0.0001
    # 计算结束beta值
    beta_end = scale * 0.02
    # 在起始和结束之间生成均匀间隔的值,然后取平方,作为beta值
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2


# 定义一个sigmoid调度函数,根据给定的时间步数生成一组beta值
def sigmoid_beta_schedule(timesteps):
    # 计算比例尺
    scale = 1000 / timesteps
    # 计算起始beta值
    beta_start = scale * 0.0001
    # 计算结束beta值
    beta_end = scale * 0.02
    # 在-6到6之间生成均匀间隔的值,作为betas
    betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
    # 对betas应用sigmoid函数,然后乘以结束和起始之间的差值,再加上起始值,得到最终的beta值
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
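
下面的小草图(仅作说明)按与上文相同的公式计算 1000 步下的余弦与线性调度,并打印首末两个 β 值:

import torch

timesteps, s = 1000, 0.008

# 余弦调度:与上文 cosine_beta_schedule 相同,只是用 alphas_cumprod[0] 代替 first()
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
cosine_betas = torch.clip(1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]), 0, 0.999)

# 线性调度:与上文 linear_beta_schedule 相同
scale = 1000 / timesteps
linear_betas = torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype = torch.float64)

print(cosine_betas[0].item(), cosine_betas[-1].item())   # 起始很小,末端被截断到 0.999
print(linear_betas[0].item(), linear_betas[-1].item())   # 0.0001 ... 0.02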


# 定义一个噪声调度器类
class NoiseScheduler(nn.Module):
    # 初始化函数,设置参数和计算beta值
    def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
        # 调用父类的初始化函数
        super().__init__()

        # 根据不同的beta调度方式计算beta值
        if beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "quadratic":
            betas = quadratic_beta_schedule(timesteps)
        elif beta_schedule == "jsd":
            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
        elif beta_schedule == "sigmoid":
            betas = sigmoid_beta_schedule(timesteps)
        else:
            raise NotImplementedError()

        # 计算alphas值
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis = 0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        # 获取时间步数并设置为类属性
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # 根据损失类型选择损失函数
        if loss_type == 'l1':
            loss_fn = F.l1_loss
        elif loss_type == 'l2':
            loss_fn = F.mse_loss
        elif loss_type == 'huber':
            loss_fn = F.smooth_l1_loss
        else:
            raise NotImplementedError()

        # 设置损失类型和损失函数为类属性
        self.loss_type = loss_type
        self.loss_fn = loss_fn

        # 注册缓冲区辅助函数,将double类型转换为float类型
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        # 注册各种缓冲区
        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # 计算后验分布的方差
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        register_buffer('posterior_variance', posterior_variance)
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # 设置是否进行p2损失重新加权的标志和p2损失权重
        self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
        register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)

    # 生成随机时间步
    def sample_random_times(self, batch):
        return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long)

    # 计算后验分布
    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    # 从q分布中采样
    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
    # 计算给定时间点 t 的速度 v
    def calculate_v(self, x_start, t, noise = None):
        # 使用累积平方根 alpha 乘以噪声,减去累积平方根 1-alpha 乘以起始位置 x_start
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    # 从起始位置 x_from 到目标时间 to_t 的采样
    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        shape = x_from.shape
        noise = default(noise, lambda: torch.randn_like(x_from))

        # 提取累积平方根 alpha 和 1-alpha
        alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
        sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
        alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
        sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)

        # 计算采样结果
        return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha

    # 根据速度 v 预测起始位置
    def predict_start_from_v(self, x_t, t, v):
        # 使用累积平方根 alpha 乘以当前位置 x_t,减去累积平方根 1-alpha 乘以速度 v
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    # 根据噪声预测起始位置
    def predict_start_from_noise(self, x_t, t, noise):
        # 使用倒数累积平方根 alpha 乘以当前位置 x_t,减去倒数累积平方根 alpha-1 乘以噪声
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    # 根据起始位置和当前位置预测噪声
    def predict_noise_from_start(self, x_t, t, x0):
        # 使用倒数累积平方根 alpha 乘以当前位置 x_t 减去起始位置 x0,再除以倒数累积平方根 alpha-1
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    # 对损失进行 P2 重加权
    def p2_reweigh_loss(self, loss, times):
        # 如果没有 P2 损失重加权,则直接返回原始损失
        if not self.has_p2_loss_reweighting:
            return loss
        # 返回损失乘以 P2 损失权重
        return loss * extract(self.p2_loss_weight, times, loss.shape)
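
NoiseScheduler.q_sample 实现的正是前向扩散的闭式采样 x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε。下面用线性调度脱离该类演示这一公式(随机数据,仅作说明):

import torch

timesteps = 1000
betas = torch.linspace(1e-4, 0.02, timesteps, dtype = torch.float64)
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)

x_start = torch.randn(2, 512, dtype = torch.float64)      # 假设的图像嵌入,batch = 2
t = torch.tensor([10, 900])                               # 每个样本各自的时间步
noise = torch.randn_like(x_start)

coef1 = alphas_cumprod[t].sqrt().unsqueeze(-1)            # sqrt(alphas_cumprod_t)
coef2 = (1. - alphas_cumprod[t]).sqrt().unsqueeze(-1)     # sqrt(1 - alphas_cumprod_t)
x_t = coef1 * x_start + coef2 * noise                     # 与 q_sample 的返回值一致

print(coef1.squeeze(-1))    # t 越大,保留的原始信号越少
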
# 重新排列图像为序列

class RearrangeToSequence(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        x = rearrange(x, 'b c ... -> b ... c')  # 重新排列输入张量的维度
        x, ps = pack([x], 'b * c')  # 打包张量

        x = self.fn(x)  # 使用给定的函数处理张量

        x, = unpack(x, ps, 'b * c')  # 解包张量
        x = rearrange(x, 'b ... c -> b c ...')  # 重新排列输出张量的维度
        return x

# 扩散先验

class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
        super().__init__()
        self.eps = eps
        self.fp16_eps = fp16_eps
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        eps = self.eps if x.dtype == torch.float32 else self.fp16_eps

        if self.stable:
            x = x / x.amax(dim = -1, keepdim = True).detach()

        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) * (var + eps).rsqrt() * self.g

class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
        super().__init__()
        self.eps = eps
        self.fp16_eps = fp16_eps
        self.stable = stable
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        eps = self.eps if x.dtype == torch.float32 else self.fp16_eps

        if self.stable:
            x = x / x.amax(dim = 1, keepdim = True).detach()

        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * (var + eps).rsqrt() * self.g

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# 多层感知机

class MLP(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        *,
        expansion_factor = 2.,
        depth = 2,
        norm = False,
    ):
        super().__init__()
        hidden_dim = int(expansion_factor * dim_out)
        norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()

        layers = [nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.SiLU(),
            norm_fn()
        )]

        for _ in range(depth - 1):
            layers.append(nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                norm_fn()
            ))

        layers.append(nn.Linear(hidden_dim, dim_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x.float())

# 因果变换器的相对位置偏差

class RelPosBias(nn.Module):
    def __init__(
        self,
        heads = 8,
        num_buckets = 32,
        max_distance = 128,
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        num_buckets = 32,
        max_distance = 128
    ):
        n = -relative_position
        n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        return torch.where(is_small, n, val_if_large)
    # 前向传播函数,接受输入参数 i, j 和 device
    def forward(self, i, j, *, device):
        # 生成一个从 0 到 i-1 的长整型张量,使用指定设备
        q_pos = torch.arange(i, dtype = torch.long, device = device)
        # 生成一个从 0 到 j-1 的长整型张量,使用指定设备
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        # 计算相对位置矩阵,即 k_pos 和 q_pos 的差值
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        # 将相对位置矩阵映射到指定的桶中,使用 self._relative_position_bucket 方法
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        # 计算相对位置注意力偏置,使用 self.relative_attention_bias 方法
        values = self.relative_attention_bias(rp_bucket)
        # 重新排列结果张量的维度,将 'i j h' 转换为 'h i j'
        return rearrange(values, 'i j h -> h i j')
# 定义一个 SwiGLU 类,用于前向传播
class SwiGLU(nn.Module):
    """ 在 https://arxiv.org/abs/2204.0231 中成功使用 """
    def forward(self, x):
        # 将输入张量 x 按照最后一个维度分成两部分
        x, gate = x.chunk(2, dim = -1)
        # 返回经过门控线性单元激活函数处理后的结果
        return x * F.silu(gate)

# 定义一个 FeedForward 函数,用于创建前馈神经网络
def FeedForward(
    dim,
    mult = 4,
    dropout = 0.,
    post_activation_norm = False
):
    """ 后激活归一化 https://arxiv.org/abs/2110.09456 """

    # 计算内部维度
    inner_dim = int(mult * dim)
    # 返回一个包含多个层的神经网络模型
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        SwiGLU(),
        LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim, bias = False)
    )
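
SwiGLU 把最后一维一分为二,一半经 SiLU 作为门控再与另一半逐元素相乘,所以 FeedForward 的第一层线性层输出维度是 inner_dim * 2。一个形状示例(仅作说明):

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 8)            # (batch, seq, inner_dim * 2),这里 inner_dim = 4
h, gate = x.chunk(2, dim = -1)
out = h * F.silu(gate)               # 与上文 SwiGLU.forward 一致
print(out.shape)                     # torch.Size([2, 16, 4])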

# 定义一个 Attention 类,用于实现注意力机制
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        causal = False,
        rotary_emb = None,
        cosine_sim = True,
        cosine_sim_scale = 16
    ):
        super().__init__()
        # 初始化注意力机制的参数
        self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
        self.cosine_sim = cosine_sim

        self.heads = heads
        inner_dim = dim_head * heads

        self.causal = causal
        self.norm = LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.rotary_emb = rotary_emb

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, mask = None, attn_bias = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # 旋转嵌入

        if exists(self.rotary_emb):
            q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))

        # 添加空键/值以用于先验网络中的无分类器引导

        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # 是否使用余弦相似度

        if self.cosine_sim:
            q, k = map(l2norm, (q, k))

        q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))

        # 计算查询/键的相似性

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # 相对位置编码(T5 风格)

        if exists(attn_bias):
            sim = sim + attn_bias

        # 掩码

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        # 注意力

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.type(sim.dtype)

        attn = self.dropout(attn)

        # 聚合值

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义一个 CausalTransformer 类,用于实现因果变换器
class CausalTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        norm_in = False,
        norm_out = True,
        attn_dropout = 0.,
        ff_dropout = 0.,
        final_proj = True,
        normformer = False,
        rotary_emb = True
    ): 
        # 调用父类的构造函数
        super().__init__()
        # 如果需要进行输入层归一化,则初始化 LayerNorm 对象,否则使用 nn.Identity()
        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM

        # 初始化相对位置偏置对象
        self.rel_pos_bias = RelPosBias(heads = heads)

        # 如果需要旋转嵌入,则初始化 RotaryEmbedding 对象,否则为 None
        rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None

        # 初始化多层 Transformer 模块
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # 每层包含注意力机制和前馈神经网络
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
            ]))

        # 如果需要输出层归一化,则初始化 LayerNorm 对象,否则使用 nn.Identity()
        self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity()  # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
        # 如果需要最终投影,则初始化线性层,否则使用 nn.Identity()
        self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()

    def forward(self, x):
        # 获取输入张量的长度和设备信息
        n, device = x.shape[1], x.device

        # 对输入张量进行初始归一化处理
        x = self.init_norm(x)

        # 计算注意力偏置
        attn_bias = self.rel_pos_bias(n, n + 1, device = device)

        # 遍历每一层 Transformer 模块
        for attn, ff in self.layers:
            # 执行注意力机制和前馈神经网络操作
            x = attn(x, attn_bias = attn_bias) + x
            x = ff(x) + x

        # 对输出结果进行归一化处理
        out = self.norm(x)
        # 返回最终输出结果
        return self.project_out(out)
# 定义一个名为 DiffusionPriorNetwork 的神经网络模块
class DiffusionPriorNetwork(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        dim,
        num_timesteps = None,
        num_time_embeds = 1,
        num_image_embeds = 1,
        num_text_embeds = 1,
        max_text_len = 256,
        self_cond = False,
        **kwargs
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 设置维度属性
        self.dim = dim

        # 设置时间嵌入、图像嵌入和文本嵌入的数量
        self.num_time_embeds = num_time_embeds
        self.num_image_embeds = num_image_embeds
        self.num_text_embeds = num_text_embeds

        # 将输入转换为文本嵌入
        self.to_text_embeds = nn.Sequential(
            nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),
            Rearrange('b (n d) -> b n d', n = num_text_embeds)
        )

        # 检查是否存在时间步长
        self.continuous_embedded_time = not exists(num_timesteps)

        # 将输入转换为时间嵌入
        self.to_time_embeds = nn.Sequential(
            nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
            Rearrange('b (n d) -> b n d', n = num_time_embeds)
        )

        # 将输入转换为图像嵌入
        self.to_image_embeds = nn.Sequential(
            nn.Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(),
            Rearrange('b (n d) -> b n d', n = num_image_embeds)
        )

        # 学习查询向量
        self.learned_query = nn.Parameter(torch.randn(dim))
        # 创建因果变换器
        self.causal_transformer = CausalTransformer(dim = dim, **kwargs)

        # dalle1 学习的填充策略

        # 设置最大文本长度
        self.max_text_len = max_text_len

        # 创建空文本编码和空文本嵌入
        self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
        self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
        self.null_image_embed = nn.Parameter(torch.randn(1, dim))

        # 是否使用自我条件,Hinton 的团队的新 ddpm 技术

        self.self_cond = self_cond

    # 带有条件缩放的前向传播函数
    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        # 调用前向传播函数
        logits = self.forward(*args, **kwargs)

        # 如果条件缩放为 1,则直接返回 logits
        if cond_scale == 1:
            return logits

        # 计算空logits
        null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
        # 返回经过条件缩放后的logits
        return null_logits + (logits - null_logits) * cond_scale

    # 前向传播函数
    def forward(
        self,
        image_embed,
        diffusion_timesteps,
        *,
        text_embed,
        text_encodings = None,
        self_cond = None,
        text_cond_drop_prob = 0.,
        image_cond_drop_prob = 0.
        ):
            # 解包图像嵌入的批次大小、维度、设备和数据类型
            batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype

            # 获取时间嵌入、图像嵌入和文本嵌入的数量
            num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds

            # 设置自身条件

            if self.self_cond:
                # 如果存在自身条件,则创建一个全零张量
                self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
                self_cond = rearrange(self_cond, 'b d -> b 1 d')

            # 在第2.2节,最后一段
            # "... 包括编码文本、CLIP文本嵌入、扩散时间步嵌入、噪声CLIP图像嵌入、用于预测的最终嵌入"

            # 将文本嵌入转换为所需格式
            text_embed = self.to_text_embeds(text_embed)
            # 将图像嵌入转换为所需格式
            image_embed = self.to_image_embeds(image_embed)

            # 分类器自由引导掩码

            # 创建文本保留掩码
            text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
            text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')

            # 创建图像保留掩码
            image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
            image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')

            # 使文本编码变为可选
            # 尽管论文似乎暗示它是存在的 <--

            if not exists(text_encodings):
                text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
        
            # 创建一个掩码,用于检测文本编码中的填充
            mask = torch.any(text_encodings != 0., dim = -1)

            # 用学习填充令牌替换文本编码中的任何填充
            text_encodings = text_encodings[:, :self.max_text_len]
            mask = mask[:, :self.max_text_len]

            text_len = text_encodings.shape[-2]
            remainder = self.max_text_len - text_len

            if remainder > 0:
                text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
                mask = F.pad(mask, (0, remainder), value = False)

            # 使用空编码屏蔽文本编码
            null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)

            text_encodings = torch.where(
                rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
                text_encodings,
                null_text_encodings
            )

            # 使用空文本嵌入屏蔽文本嵌入
            null_text_embeds = self.null_text_embeds.to(text_embed.dtype)

            text_embed = torch.where(
                text_keep_mask,
                text_embed,
                null_text_embeds
            )

            # 使用空图像嵌入屏蔽图像嵌入
            null_image_embed = self.null_image_embed.to(image_embed.dtype)

            image_embed = torch.where(
                image_keep_mask,
                image_embed,
                null_image_embed
            )

            # 文本嵌入是否用于条件取决于是否文本编码可用于注意力(对于分类器自由引导,尽管从论文中看出先前的ddpm未使用,因为目标不同)
            # 但让我们做正确的事情

            if self.continuous_embedded_time:
                diffusion_timesteps = diffusion_timesteps.type(dtype)

            # 将时间嵌入转换为所需格式
            time_embed = self.to_time_embeds(diffusion_timesteps)

            # 重复学习的查询,以预测图像嵌入(每个DDPM时间步)
            learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)

            if self.self_cond:
                learned_queries = torch.cat((self_cond, learned_queries), dim = -2)

            # 将各种嵌入拼接在一起
            tokens = torch.cat((
                text_encodings,
                text_embed,
                time_embed,
                image_embed,
                learned_queries
            ), dim = -2)

            # 注意力机制
            tokens = self.causal_transformer(tokens)

            # 获取学习的查询,应该预测图像嵌入(每个DDPM时间步)
            pred_image_embed = tokens[..., -1, :]

            return pred_image_embed
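
上面的 forward_with_cond_scale 实现的无分类器引导可以概括为一行公式:guided = null + (cond - null) * cond_scale;cond_scale = 1 时退化为普通的条件预测。一个纯数值的小示例(仅作说明):

import torch

cond_pred = torch.tensor([1.0, 2.0, 3.0])    # 有条件预测
null_pred = torch.tensor([0.5, 0.5, 0.5])    # 全空条件下的预测
cond_scale = 2.0

guided = null_pred + (cond_pred - null_pred) * cond_scale
print(guided)    # tensor([1.5000, 3.5000, 5.5000]),沿条件方向外推
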
# 定义一个 DiffusionPrior 类,继承自 nn.Module
class DiffusionPrior(nn.Module):
    # 初始化函数,接受一系列参数
    def __init__(
        self,
        net,
        *,
        clip = None,  # 可选的 CLIP 模型(或其适配器),用于在线计算文本/图像嵌入
        image_embed_dim = None,  # 图像嵌入维度
        image_size = None,  # 图像尺寸
        image_channels = 3,  # 图像通道数,默认为3
        timesteps = 1000,  # 时间步数
        sample_timesteps = None,  # 采样时间步数
        cond_drop_prob = 0.,  # 条件丢弃概率
        text_cond_drop_prob = None,  # 文本条件丢弃概率
        image_cond_drop_prob = None,  # 图像条件丢弃概率
        loss_type = "l2",  # 损失类型,默认为 l2
        predict_x_start = True,  # 是否预测 x 的起始值
        predict_v = False,  # 是否预测速度
        beta_schedule = "cosine",  # beta 调度方式
        condition_on_text_encodings = True,  # 是否在文本编码上进行条件化,论文建议开启,但可以在 CLIP 预处理文本嵌入到图像嵌入训练中关闭
        sampling_clamp_l2norm = False,  # 是否在每个去噪迭代中对图像嵌入进行 l2 范数裁剪(类似于通常 DDPMs 的 -1 到 1 裁剪)
        sampling_final_clamp_l2norm = False,  # 是否对最终图像嵌入输出进行 l2 范数裁剪(这也适用于 DDPM 中的图像)
        training_clamp_l2norm = False,  # 是否在训练时对 l2 范数进行裁剪
        init_image_embed_l2norm = False,  # 是否初始化图像嵌入的 l2 范数
        image_embed_scale = None,  # 用于缩放 l2 范数的图像嵌入,使其更适合高斯扩散,由 Katherine (@crowsonkb) 在 https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 中提出
        clip_adapter_overrides = dict()  # 用于覆盖 clip 适配器的字典
    ):
        # 调用父类的构造函数
        super().__init__()

        # 设置样本时间步数
        self.sample_timesteps = sample_timesteps

        # 创建噪声调度器对象
        self.noise_scheduler = NoiseScheduler(
            beta_schedule = beta_schedule,
            timesteps = timesteps,
            loss_type = loss_type
        )

        # 如果指定了 clip 参数
        if exists(clip):
            # 检查图像通道数是否与 clip 接受的通道数相同
            assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'

            # 根据 clip 的类型进行适配
            if isinstance(clip, CLIP):
                clip = XClipAdapter(clip, **clip_adapter_overrides)
            elif isinstance(clip, CoCa):
                clip = CoCaAdapter(clip, **clip_adapter_overrides)

            # 断言 clip 是 BaseClipAdapter 类型
            assert isinstance(clip, BaseClipAdapter)
            # 冻结模型并设置为评估模式
            freeze_model_and_make_eval_(clip)
            self.clip = clip
        else:
            # 如果未指定 clip 参数,则需要指定图像嵌入维度
            assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
            self.clip = None

        # 设置网络和图像嵌入维度
        self.net = net
        self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)

        # 断言网络维度与图像嵌入维度相同
        assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}'
        # 断言 clip 的潜在维度与图像嵌入维度相同
        assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}'

        # 设置通道数
        self.channels = default(image_channels, lambda: clip.image_channels)

        # 设置文本条件丢弃概率和图像条件丢弃概率
        self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
        self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)

        # 是否使用分类器指导
        self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
        self.condition_on_text_encodings = condition_on_text_encodings

        # 在论文中,他们不预测噪声,而是直接为图像嵌入预测 x0,声称实验结果更好。我将提供两者。

        self.predict_x_start = predict_x_start
        self.predict_v = predict_v # 优先于 predict_x_start

        # @crowsonkb 的建议 - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132

        # 设置图像嵌入缩放因子
        self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)

        # 是否在采样时强制进行 l2norm,类似于裁剪去噪时的操作

        self.sampling_clamp_l2norm = sampling_clamp_l2norm
        self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm

        self.training_clamp_l2norm = training_clamp_l2norm
        self.init_image_embed_l2norm = init_image_embed_l2norm

        # 设备跟踪器

        self.register_buffer('_dummy', torch.tensor([True]), persistent = False)

    @property
    def device(self):
        # 返回设备信息
        return self._dummy.device

    # 对图像嵌入进行 l2norm 裁剪
    def l2norm_clamp_embed(self, image_embed):
        return l2norm(image_embed) * self.image_embed_scale
    # 计算预测的均值、后验方差和后验对数方差,以及起始值
    def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
        # 断言条件,如果条件不成立则抛出异常
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
        
        # 使用网络进行预测,根据条件缩放和文本条件
        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)

        # 根据预测值选择起始值
        if self.predict_v:
            x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
        elif self.predict_x_start:
            x_start = pred
        else:
            x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

        # 如果需要剪裁去噪后的值,并且不是预测 x 的起始值
        if clip_denoised and not self.predict_x_start:
            x_start.clamp_(-1., 1.)

        # 如果预测 x 的起始值并且采样剪裁 L2 范数
        if self.predict_x_start and self.sampling_clamp_l2norm:
            x_start = l2norm(x_start) * self.image_embed_scale

        # 获取模型均值、后验方差和后验对数方差
        model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    # 生成样本
    @torch.no_grad()
    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
        # 获取输入 x 的形状和设备信息
        b, *_, device = *x.shape, x.device
        # 计算模型均值、模型方差和模型对数方差,以及起始值
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
        # 生成噪声
        noise = torch.randn_like(x)
        # 当 t == 0 时不添加噪声
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        # 根据模型均值、模型对数方差和噪声生成预测值
        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred, x_start

    # 循环生成样本
    @torch.no_grad()
    def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
        # 获取批量大小和设备信息
        batch, device = shape[0], self.device

        # 生成随机图像嵌入
        image_embed = torch.randn(shape, device = device)
        x_start = None # 用于自我条件

        # 如果初始化图像嵌入的 L2 范数
        if self.init_image_embed_l2norm:
            image_embed = l2norm(image_embed) * self.image_embed_scale

        # 遍历时间步骤,生成样本
        for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
            times = torch.full((batch,), i, device = device, dtype = torch.long)

            self_cond = x_start if self.net.self_cond else None
            image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)

        # 如果采样最终剪裁 L2 范数并且预测 x 的起始值
        if self.sampling_final_clamp_l2norm and self.predict_x_start:
            image_embed = self.l2norm_clamp_embed(image_embed)

        return image_embed

    # 无梯度计算
    @torch.no_grad()
    # 定义一个函数,用于在动态图像生成中循环采样,支持不同维度的输入
    def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
        # 获取输入形状的相关信息
        batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps

        # 在指定时间范围内生成时间序列
        times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]

        # 将时间序列反转并转换为整数列表
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        # 生成随机的图像嵌入向量
        image_embed = torch.randn(shape, device = device)

        x_start = None # 用于自条件生成

        # 如果需要对初始图像嵌入向量进行 L2 范数归一化
        if self.init_image_embed_l2norm:
            image_embed = l2norm(image_embed) * self.image_embed_scale

        # 在时间序列上进行循环采样
        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            alpha = alphas[time]
            alpha_next = alphas[time_next]

            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

            self_cond = x_start if self.net.self_cond else None

            # 使用条件信息生成预测结果
            pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)

            # 推导 x0

            if self.predict_v:
                x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
            elif self.predict_x_start:
                x_start = pred
            else:
                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)

            # 在可能预测噪声之前对 x0 进行裁剪

            if not self.predict_x_start:
                x_start.clamp_(-1., 1.)

            if self.predict_x_start and self.sampling_clamp_l2norm:
                x_start = self.l2norm_clamp_embed(x_start)

            # 预测噪声

            pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)

            if time_next < 0:
                image_embed = x_start
                continue

            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
            noise = torch.randn_like(image_embed) if time_next > 0 else 0.

            image_embed = x_start * alpha_next.sqrt() + \
                          c1 * noise + \
                          c2 * pred_noise

        # 如果需要对最终的图像嵌入向量进行 L2 范数归一化
        if self.predict_x_start and self.sampling_final_clamp_l2norm:
            image_embed = self.l2norm_clamp_embed(image_embed)

        return image_embed

    # 用于在动态图像生成中循环采样的函数,支持不同维度的输入
    @torch.no_grad()
    def p_sample_loop(self, *args, timesteps = None, **kwargs):
        # 如果未指定时间步长,则使用默认值
        timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
        assert timesteps <= self.noise_scheduler.num_timesteps
        is_ddim = timesteps < self.noise_scheduler.num_timesteps

        # 根据是否为低维输入选择不同的采样函数
        if not is_ddim:
            normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
        else:
            normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)

        # 对图像嵌入向量进行缩放处理并返回
        image_embed = normalized_image_embed / self.image_embed_scale
        return image_embed
    # 定义一个函数,计算损失值
    def p_losses(self, image_embed, times, text_cond, noise = None):
        # 如果没有提供噪声,则生成一个默认的噪声
        noise = default(noise, lambda: torch.randn_like(image_embed))

        # 使用噪声调度器生成噪声图像嵌入
        image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)

        self_cond = None
        # 如果网络支持自身条件,并且随机数小于0.5
        if self.net.self_cond and random.random() < 0.5:
            # 使用网络生成自身条件
            with torch.no_grad():
                self_cond = self.net(image_embed_noisy, times, **text_cond).detach()

        # 使用网络进行预测
        pred = self.net(
            image_embed_noisy,
            times,
            self_cond = self_cond,
            text_cond_drop_prob = self.text_cond_drop_prob,
            image_cond_drop_prob = self.image_cond_drop_prob,
            **text_cond
        )

        # 如果需要预测起始图像并且训练时使用L2范数约束
        if self.predict_x_start and self.training_clamp_l2norm:
            # 对预测结果进行L2范数约束
            pred = self.l2norm_clamp_embed(pred)

        # 如果需要预测速度
        if self.predict_v:
            # 计算目标速度
            target = self.noise_scheduler.calculate_v(image_embed, times, noise)
        # 如果需要预测起始图像
        elif self.predict_x_start:
            target = image_embed
        else:
            target = noise

        # 计算损失值
        loss = self.noise_scheduler.loss_fn(pred, target)
        return loss

    # 生成一个批次的图像
    @torch.no_grad()
    @eval_decorator
    def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
        # 获取设备信息
        device = self.betas.device
        shape = (batch_size, self.image_embed_dim)

        # 生成随机噪声图像
        img = torch.randn(shape, device = device)

        # 对于每个时间步长,生成图像
        for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):
            img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
        return img

    # 生成样本
    @torch.no_grad()
    @eval_decorator
    def sample(
        self,
        text,
        num_samples_per_batch = 2,
        cond_scale = 1.,
        timesteps = None
    ):
        timesteps = default(timesteps, self.sample_timesteps)

        # 重复文本以匹配样本数
        text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)

        batch_size = text.shape[0]
        image_embed_dim = self.image_embed_dim

        # 嵌入文本
        text_embed, text_encodings = self.clip.embed_text(text)

        text_cond = dict(text_embed = text_embed)

        if self.condition_on_text_encodings:
            text_cond = {**text_cond, 'text_encodings': text_encodings}

        # 生成图像嵌入
        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)

        # 计算文本和图像之间的相似度
        text_embeds = text_cond['text_embed']
        text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
        image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
        text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
        top_sim_indices = text_image_sims.topk(k = 1).indices
        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
        top_image_embeds = image_embeds.gather(1, top_sim_indices)
        return rearrange(top_image_embeds, 'b 1 d -> b d')

    # 前向传播函数
    def forward(
        self,
        text = None,
        image = None,
        text_embed = None,      # 允许在预处理的CLIP文本和图像嵌入上进行训练
        image_embed = None,
        text_encodings = None,  # 以及CLIP文本编码
        *args,
        **kwargs
    ):
        # 检查是否提供了文本或文本嵌入,二者必须有一个
        assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
        # 检查是否提供了图像或图像嵌入,二者必须有一个
        assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
        # 如果在初始化时指定了要在文本编码上进行条件化,则文本编码必须存在
        assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'

        # 如果提供了图像,则使用CLIP模型嵌入图像
        if exists(image):
            image_embed, _ = self.clip.embed_image(image)

        # 根据传入的内容计算文本条件
        if exists(text):
            text_embed, text_encodings = self.clip.embed_text(text)

        # 创建文本条件字典
        text_cond = dict(text_embed = text_embed)

        # 如果在文本编码上进行条件化,则文本编码必须存在
        if self.condition_on_text_encodings:
            assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
            text_cond = {**text_cond, 'text_encodings': text_encodings}

        # 从ddpm中获取时间步条件
        batch, device = image_embed.shape[0], image_embed.device
        times = self.noise_scheduler.sample_random_times(batch)

        # 缩放图像嵌入
        image_embed *= self.image_embed_scale

        # 计算前向损失
        return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
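
结合上面的 DiffusionPriorNetwork 与 DiffusionPrior,下面是一个在预先算好的 CLIP 嵌入上训练先验网络的最小用法草图(假设已安装 dalle2-pytorch 并导出这两个类;此处关闭了对逐 token 文本编码的条件化,超参数仅为示例):

import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,                  # 不传入 CLIP 时必须显式给出潜在维度
    timesteps = 100,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = False     # 只用 CLIP 文本嵌入,不用逐 token 编码
)

# 假设这些嵌入来自预先跑好的 CLIP
text_embed  = torch.randn(4, 512)
image_embed = torch.randn(4, 512)

loss = diffusion_prior(text_embed = text_embed, image_embed = image_embed)
loss.backward()
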
# 定义一个最近邻上采样模块,将输入维度提升为指定的输出维度
def NearestUpsample(dim, dim_out = None):
    # 如果未指定输出维度,则默认与输入维度相同
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        # 使用最近邻插值方式上采样,比例为2
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        # 使用3x3卷积核进行卷积,将输入维度转换为输出维度
        nn.Conv2d(dim, dim_out, 3, padding = 1)
    )

# 定义一个像素混洗上采样模块,用于解决棋盘伪影问题
class PixelShuffleUpsample(nn.Module):
    """
    code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        # 如果未指定输出维度,则默认与输入维度相同
        dim_out = default(dim_out, dim)
        # 使用1x1卷积核将输入维度转换为输出维度的4倍
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            # 进行卷积操作
            conv,
            # 使用SiLU激活函数
            nn.SiLU(),
            # 像素混洗操作,将通道数减少为原来的四分之一
            nn.PixelShuffle(2)
        )

        # 初始化卷积层的权重
        self.init_conv_(conv)

    # 初始化卷积层的权重
    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    # 前向传播函数
    def forward(self, x):
        return self.net(x)

# 定义一个下采样模块,采用最优的像素解开操作
def Downsample(dim, dim_out = None):
    # https://arxiv.org/abs/2208.03641 显示这是最优的下采样方式
    # 在论文中被称为SP-conv,实际上是像素解开操作
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        # 像素解开操作,将每个像素分成4个像素
        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
        # 使用1x1卷积核将输入维度转换为输出维度
        nn.Conv2d(dim * 4, dim_out, 1)
    )
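
这个 Downsample 实际上是像素解开(pixel unshuffle)后接 1x1 卷积:空间分辨率减半,通道数先变成 4 倍再投影到目标维度。一个形状检查的小示例(仅作说明,依赖 einops):

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, dim_out = 32, 64
down = nn.Sequential(
    Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
    nn.Conv2d(dim * 4, dim_out, 1)
)

x = torch.randn(1, dim, 16, 16)
print(down(x).shape)    # torch.Size([1, 64, 8, 8])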

# 定义一个权重标准化的卷积层
class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        flattened_weights = rearrange(weight, 'o ... -> o (...)')

        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')

        var = torch.var(flattened_weights, dim = -1, unbiased = False)
        var = rearrange(var, 'o -> o 1 1 1')

        weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

# 定义一个正弦位置编码模块
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        dtype, device = x.dtype, x.device
        assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'

        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
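
SinusoidalPosEmb 把一批标量时间步映射为 (batch, dim) 的正弦位置编码,供时间条件 MLP 使用。下面按相同公式直接计算,不依赖该类(仅作说明):

import math
import torch

dim = 16
t = torch.tensor([0., 10., 100.])     # 一批时间步

half_dim = dim // 2
freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000) / (half_dim - 1)))
args = t[:, None] * freqs[None, :]
emb = torch.cat((args.sin(), args.cos()), dim = -1)

print(emb.shape)    # torch.Size([3, 16])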

# 定义一个块模块
class Block(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        weight_standardization = False
    ):
        super().__init__()
        conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d

        # 使用3x3卷积核进行卷积,将输入维度转换为输出维度
        self.project = conv_klass(dim, dim_out, 3, padding = 1)
        # 使用组归一化进行归一化
        self.norm = nn.GroupNorm(groups, dim_out)
        # 使用SiLU激活函数
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.project(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        weight_standardization = False,
        cosine_sim_cross_attn = False
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 初始化时间多层感知器为 None
        self.time_mlp = None

        # 如果时间条件维度存在
        if exists(time_cond_dim):
            # 创建时间多层感知器模型
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        # 初始化交叉注意力为 None
        self.cross_attn = None

        # 如果条件维度存在
        if exists(cond_dim):
            # 创建交叉注意力模型
            self.cross_attn = CrossAttention(
                dim = dim_out,
                context_dim = cond_dim,
                cosine_sim = cosine_sim_cross_attn
            )

        # 创建第一个块
        self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
        # 创建第二个块
        self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
        # 如果输入维度不等于输出维度,创建卷积层;否则创建恒等映射
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    # 前向传播函数
    def forward(self, x, time_emb = None, cond = None):

        # 初始化缩放和平移为 None
        scale_shift = None
        # 如果时间多层感知器和时间嵌入都存在
        if exists(self.time_mlp) and exists(time_emb):
            # 通过时间多层感知器处理时间嵌入
            time_emb = self.time_mlp(time_emb)
            # 重新排列时间嵌入的维度
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            # 将处理后的时间嵌入分成两部分,分别表示缩放和平移
            scale_shift = time_emb.chunk(2, dim = 1)

        # 使用第一个块处理输入数据
        h = self.block1(x, scale_shift = scale_shift)

        # 如果交叉注意力存在
        if exists(self.cross_attn):
            # 确保条件存在
            assert exists(cond)

            # 重新排列隐藏状态的维度
            h = rearrange(h, 'b c ... -> b ... c')
            # 打包隐藏状态
            h, ps = pack([h], 'b * c')

            # 使用交叉注意力处理隐藏状态
            h = self.cross_attn(h, context = cond) + h

            # 解包隐藏状态
            h, = unpack(h, ps, 'b * c')
            # 重新排列隐藏状态的维度
            h = rearrange(h, 'b ... c -> b c ...')

        # 使用第二个块处理隐藏状态
        h = self.block2(h)
        # 返回最终结果,加上残差连接
        return h + self.res_conv(x)
# 定义交叉注意力模块
class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        norm_context = False,
        cosine_sim = False,
        cosine_sim_scale = 16
    ):
        super().__init__()
        self.cosine_sim = cosine_sim
        self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
        self.dropout = nn.Dropout(dropout)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        if self.cosine_sim:
            q, k = map(l2norm, (q, k))

        q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))

        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.type(sim.dtype)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
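
The block below is an illustrative aside, not part of the library source: a minimal, self-contained sketch of the learned "null" key/value idea used in CrossAttention above. One extra key/value pair is prepended to the context so attention has something to fall back on when conditioning is dropped for classifier-free guidance (single head, random tensors, shapes only).

# illustrative sketch only -- not library code
import torch

b, n, m, d = 2, 16, 8, 64                                # batch, query len, context len, head dim
q, k, v = torch.randn(b, n, d), torch.randn(b, m, d), torch.randn(b, m, d)
null_k, null_v = torch.randn(d), torch.randn(d)          # learned nn.Parameter(s) in the real module

k = torch.cat((null_k.expand(b, 1, d), k), dim = 1)      # prepend null key   -> (b, m + 1, d)
v = torch.cat((null_v.expand(b, 1, d), v), dim = 1)      # prepend null value -> (b, m + 1, d)

sim = torch.einsum('b i d, b j d -> b i j', q, k) * d ** -0.5
out = torch.einsum('b i j, b j d -> b i d', sim.softmax(dim = -1), v)
print(out.shape)                                         # torch.Size([2, 16, 64])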

# 定义线性注意力模块
class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.GELU()
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]
        seq_len = x * y

        fmap = self.norm(fmap)
        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale
        v = l2norm(v)

        k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)
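
As a side note (illustrative sketch, not library code): the reason this "linear" attention avoids the quadratic cost of standard attention is that keys and values are first aggregated into a small per-head context, and queries then read from that context, so the cost grows linearly with the sequence length n instead of forming an n x n attention matrix.

# illustrative sketch only -- not library code
import torch

b, n, d = 2, 4096, 32
q = torch.randn(b, n, d).softmax(dim = -1)
k = torch.randn(b, n, d).softmax(dim = -2)
v = torch.randn(b, n, d)

context = torch.einsum('b n d, b n e -> b d e', k, v)    # (b, d, d) summary of the sequence
out = torch.einsum('b n d, b d e -> b n e', q, context)  # (b, n, d), no n x n matrix needed
print(out.shape)                                         # torch.Size([2, 4096, 32])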

# 定义交叉嵌入层模块
class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 断言所有卷积核大小与步长的奇偶性相同
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        # 如果未指定输出维度,则与输入维度相同
        dim_out = default(dim_out, dim_in)

        # 对卷积核大小进行排序
        kernel_sizes = sorted(kernel_sizes)
        # 计算总共有多少个尺度
        num_scales = len(kernel_sizes)

        # 计算每个尺度的维度
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        # 最后一个尺度的维度为总维度减去前面各尺度的维度之和
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        # 创建卷积层列表
        self.convs = nn.ModuleList([])
        # 遍历卷积核大小和对应的尺度维度
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            # 将每个尺度的卷积层添加到列表中
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    # 前向传播函数
    def forward(self, x):
        # 对输入数据进行多尺度卷积操作,得到特征图元组
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        # 在通道维度上拼接特征图
        return torch.cat(fmaps, dim = 1)
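
A small worked example (assumed values, not library code) of how the per-kernel channel widths are derived above: each scale takes half of the remaining output channels and the last kernel takes whatever is left, so the concatenated feature maps add up to exactly dim_out channels.

# illustrative sketch only -- not library code
dim_out = 128
kernel_sizes = sorted((3, 7, 15))
num_scales = len(kernel_sizes)

dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
print(dim_scales, sum(dim_scales))    # [64, 32, 32] 128
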
# 定义一个 UpsampleCombiner 类,继承自 nn.Module
class UpsampleCombiner(nn.Module):
    # 初始化函数,接受维度 dim 和一些可选参数
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 断言输入维度和输出维度的长度相等
        assert len(dim_ins) == len(dim_outs)
        # 设置是否启用的标志
        self.enabled = enabled

        # 如果未启用,设置输出维度为输入维度并直接返回
        if not self.enabled:
            self.dim_out = dim
            return

        # 使用输入维度和输出维度创建 Block 对象列表
        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        # 设置输出维度为输入维度加上所有输出维度之和
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    # 前向传播函数,接受输入 x 和特征图列表 fmaps,默认为 None
    def forward(self, x, fmaps = None):
        # 获取输入 x 的最后一个维度大小
        target_size = x.shape[-1]

        # 如果 fmaps 为 None,则设置为空元组
        fmaps = default(fmaps, tuple())

        # 如果未启用或者 fmaps 为空或者 fmap_convs 为空,直接返回输入 x
        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # 调整特征图大小为目标大小
        fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
        # 对每个特征图应用对应的卷积操作
        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
        # 沿着指定维度拼接输入 x 和处理后的特征图列表
        return torch.cat((x, *outs), dim = 1)

class Unet(nn.Module):
    # 定义一个 Unet 类,继承自 nn.Module
    def __init__(
        self,
        dim,
        *,
        image_embed_dim = None,
        text_embed_dim = None,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        channels_out = None,
        self_attn = False,
        attn_dim_head = 32,
        attn_heads = 16,
        lowres_cond = False,             # for cascading diffusion - https://cascaded-diffusion.github.io/
        lowres_noise_cond = False,       # for conditioning on low resolution noising, based on Imagen
        self_cond = False,               # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
        sparse_attn = False,
        cosine_sim_cross_attn = False,
        cosine_sim_self_attn = False,
        attend_at_middle = True,         # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
        cond_on_text_encodings = False,
        max_text_len = 256,
        cond_on_image_embeds = False,
        add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
        init_dim = None,
        init_conv_kernel_size = 7,
        resnet_groups = 8,
        resnet_weight_standardization = False,
        num_resnet_blocks = 2,
        init_cross_embed = True,
        init_cross_embed_kernel_sizes = (3, 7, 15),
        cross_embed_downsample = False,
        cross_embed_downsample_kernel_sizes = (2, 4),
        memory_efficient = False,
        scale_skip_connection = False,
        pixel_shuffle_upsample = True,
        final_conv_kernel_size = 1,
        combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
        checkpoint_during_training = False,
        **kwargs
    # 定义初始化函数,接受一系列参数

    def cast_model_parameters(
        self,
        *,
        lowres_cond,
        lowres_noise_cond,
        channels,
        channels_out,
        cond_on_image_embeds,
        cond_on_text_encodings,
    ):
        # 如果当前模型参数与输入参数相同,则返回当前模型
        if lowres_cond == self.lowres_cond and \
            channels == self.channels and \
            cond_on_image_embeds == self.cond_on_image_embeds and \
            cond_on_text_encodings == self.cond_on_text_encodings and \
            lowres_noise_cond == self.lowres_noise_cond and \
            channels_out == self.channels_out:
            return self

        # 更新参数字典
        updated_kwargs = dict(
            lowres_cond = lowres_cond,
            channels = channels,
            channels_out = channels_out,
            cond_on_image_embeds = cond_on_image_embeds,
            cond_on_text_encodings = cond_on_text_encodings,
            lowres_noise_cond = lowres_noise_cond
        )

        # 返回一个新的类实例,使用当前模型的局部变量和更新后的参数
        return self.__class__(**{**self._locals, **updated_kwargs})

    # 带有条件缩放的前向传播函数
    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        # 调用前向传播函数获取 logits
        logits = self.forward(*args, **kwargs)

        # 如果条件缩放因子为1,则直接返回 logits
        if cond_scale == 1:
            return logits

        # 计算无条件 logits
        null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
        # 返回加权后的 logits
        return null_logits + (logits - null_logits) * cond_scale
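
A quick numeric illustration (made-up tensors, not library code) of the guidance interpolation returned above: the output starts from the unconditional prediction and moves toward, and for cond_scale > 1 past, the conditional prediction.

# illustrative sketch only -- not library code
import torch

cond_scale = 3.
logits = torch.tensor([1.0, 2.0])          # stand-in for the conditional prediction
null_logits = torch.tensor([0.5, 0.5])     # stand-in for the prediction with conditioning dropped
guided = null_logits + (logits - null_logits) * cond_scale
print(guided)                              # tensor([2., 5.])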

    # 前向传播函数
    def forward(
        self,
        x,
        time,
        *,
        image_embed,
        lowres_cond_img = None,
        lowres_noise_level = None,
        text_encodings = None,
        image_cond_drop_prob = 0.,
        text_cond_drop_prob = 0.,
        blur_sigma = None,
        blur_kernel_size = None,
        disable_checkpoint = False,
        self_cond = None
# 定义一个低分辨率条件器的类,继承自 nn.Module
class LowresConditioner(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        downsample_first = True,  # 是否先降采样
        use_blur = True,  # 是否使用模糊
        blur_prob = 0.5,  # 模糊概率
        blur_sigma = 0.6,  # 模糊标准差
        blur_kernel_size = 3,  # 模糊核大小
        use_noise = False,  # 是否使用噪声
        input_image_range = None,  # 输入图像范围
        normalize_img_fn = identity,  # 图像归一化函数
        unnormalize_img_fn = identity  # 图像反归一化函数
    ):
        super().__init__()  # 调用父类的初始化函数
        self.downsample_first = downsample_first  # 是否先降采样
        self.input_image_range = input_image_range  # 输入图像范围

        self.use_blur = use_blur  # 是否使用模糊
        self.blur_prob = blur_prob  # 模糊概率
        self.blur_sigma = blur_sigma  # 模糊标准差
        self.blur_kernel_size = blur_kernel_size  # 模糊核大小

        self.use_noise = use_noise  # 是否使用噪声
        self.normalize_img = normalize_img_fn  # 图像归一化函数
        self.unnormalize_img = unnormalize_img_fn  # 图像反归一化函数
        self.noise_scheduler = NoiseScheduler(beta_schedule = 'linear', timesteps = 1000, loss_type = 'l2') if use_noise else None  # 噪声调度器

    # 添加噪声到图像
    def noise_image(self, cond_fmap, noise_levels = None):
        assert exists(self.noise_scheduler)  # 断言噪声调度器存在

        batch = cond_fmap.shape[0]  # 批次大小
        cond_fmap = self.normalize_img(cond_fmap)  # 归一化图像

        random_noise_levels = default(noise_levels, lambda: self.noise_scheduler.sample_random_times(batch))  # 随机噪声级别
        cond_fmap = self.noise_scheduler.q_sample(cond_fmap, t = random_noise_levels, noise = torch.randn_like(cond_fmap))  # 添加噪声

        cond_fmap = self.unnormalize_img(cond_fmap)  # 反归一化图像
        return cond_fmap, random_noise_levels  # 返回添加噪声后的图像和随机噪声级别

    # 前向传播函数
    def forward(
        self,
        cond_fmap,
        *,
        target_image_size,  # 目标图像大小
        downsample_image_size = None,  # 降采样图像大小
        should_blur = True,  # 是否应该模糊
        blur_sigma = None,  # 模糊标准差
        blur_kernel_size = None  # 模糊核大小
    ):
        if self.downsample_first and exists(downsample_image_size):  # 如果先降采样且降采样图像大小存在
            cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = True)  # 调整图像大小

        # 模糊只有50%的概率应用
        # 参考 https://arxiv.org/abs/2106.15282 中的第3.1节

        if self.use_blur and should_blur and random.random() < self.blur_prob:  # 如果使用模糊且应该模糊且随机数小于模糊概率
            # 在训练时,模糊低分辨率条件图像

            blur_sigma = default(blur_sigma, self.blur_sigma)  # 默认模糊标准差
            blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)  # 默认模糊核大小

            # 允许在 lo 和 hi 浮点值之间绘制随机标准差

            if isinstance(blur_sigma, tuple):  # 如果模糊标准差是元组
                blur_sigma = tuple(map(float, blur_sigma))  # 转换为浮点数元组
                blur_sigma = random.uniform(*blur_sigma)  # 在范围内随机选择一个值

            # 允许在 lo 和 hi 整数值之间绘制随机核大小

            if isinstance(blur_kernel_size, tuple):  # 如果模糊核大小是元组
                blur_kernel_size = tuple(map(int, blur_kernel_size))  # 转换为整数元组
                kernel_size_lo, kernel_size_hi = blur_kernel_size  # 获取最小和最大值
                blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)  # 在范围内随机选择一个值

            cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))  # 二维高斯模糊

        # 调整到目标图像大小

        cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)  # 调整图像大小

        # 噪声调节,如在 Imagen 中所做
        # 作为 BSR 噪声的替代,并可能替换第一阶段的模糊

        random_noise_levels = None  # 随机噪声级别为空

        if self.use_noise:  # 如果使用噪声
            cond_fmap, random_noise_levels = self.noise_image(cond_fmap)  # 添加噪声

        # 返回条件特征图,以及增强噪声级别

        return cond_fmap, random_noise_levels  # 返回条件特征图和随机噪声级别
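
A simplified, self-contained sketch (assuming resize_image_to behaves like nearest-neighbor interpolation here) of the cascading low-resolution conditioning this module implements: the ground-truth image is first brought down to the previous stage's resolution, then resized back up to the current stage's input size before being handed to the unet as conditioning.

# illustrative sketch only -- not library code
import torch
import torch.nn.functional as F

image = torch.rand(2, 3, 256, 256)                                       # ground-truth batch
downsampled = F.interpolate(image, size = 64, mode = 'nearest')          # previous stage resolution
lowres_cond = F.interpolate(downsampled, size = 256, mode = 'nearest')   # conditioning for this stage
print(lowres_cond.shape)                                                 # torch.Size([2, 3, 256, 256])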

# 解码器类
class Decoder(nn.Module):
    # 初始化函数,设置各种参数和默认值
    def __init__(
        self,
        unet,
        *,
        clip = None,                               # 可选的 CLIP 模型,用于自动计算图像/文本嵌入
        image_size = None,                         # 图像大小
        channels = 3,                              # 通道数
        vae = tuple(),                             # 变分自动编码器
        timesteps = 1000,                          # 时间步数
        sample_timesteps = None,                   # 采样时间步数
        image_cond_drop_prob = 0.1,                # 图像条件丢弃概率(用于无分类器引导)
        text_cond_drop_prob = 0.5,                 # 文本条件丢弃概率(用于无分类器引导)
        loss_type = 'l2',                          # 损失类型
        beta_schedule = None,                      # beta调度
        predict_x_start = False,                   # 预测x的起始点
        predict_v = False,                         # 预测v
        predict_x_start_for_latent_diffusion = False,  # 用于潜在扩散的预测x的起始点
        image_sizes = None,                        # 用于级联ddpm,每个阶段的图像大小
        random_crop_sizes = None,                  # 是否在级联中随机裁剪图像
        use_noise_for_lowres_cond = False,         # 是否在低分辨率条件下使用噪声
        use_blur_for_lowres_cond = True,           # 是否在低分辨率条件下使用模糊
        lowres_downsample_first = True,            # 级联ddpm - 先缩小分辨率,然后到下一个条件分辨率+模糊
        blur_prob = 0.5,                           # 训练时,高斯模糊仅应用50%的时间
        blur_sigma = 0.6,                          # 模糊sigma
        blur_kernel_size = 3,                      # 模糊核大小
        lowres_noise_sample_level = 0.2,           # 在样本时间为低分辨率条件使用0.2的噪声水平
        clip_denoised = True,                      # 是否对去噪结果进行截断(clamp)
        clip_x_start = True,                       # 是否对预测的 x_start 进行截断(clamp)
        clip_adapter_overrides = dict(),           # CLIP 适配器的覆盖参数
        learned_variance = True,                   # 学习方差
        learned_variance_constrain_frac = False,   # 学习方差约束分数
        vb_loss_weight = 0.001,                    # vb损失权重
        unconditional = False,                     # 为生成没有条件的图像设置为True
        auto_normalize_img = True,                 # 是否自动归一化图像
        use_dynamic_thres = False,                 # 是否使用动态阈值
        dynamic_thres_percentile = 0.95,           # 动态阈值百分位数
        p2_loss_weight_gamma = 0.,                 # p2损失权重
        p2_loss_weight_k = 1,                      # p2损失权重k
        ddim_sampling_eta = 0.                     # 确定性采样
    @property
    def device(self):
        return self._dummy.device

    @property
    def condition_on_text_encodings(self):
        return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])

    # 获取指定编号的unet
    def get_unet(self, unet_number):
        assert 0 < unet_number <= self.num_unets
        index = unet_number - 1
        return self.unets[index]

    # 解析unet输出
    def parse_unet_output(self, learned_variance, output):
        var_interp_frac_unnormalized = None

        if learned_variance:
            output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)

        return UnetOutput(output, var_interp_frac_unnormalized)

    # 上下文管理器,用于在GPU上处理一个unet
    @contextmanager
    def one_unet_in_gpu(self, unet_number = None, unet = None):
        assert exists(unet_number) ^ exists(unet)

        if exists(unet_number):
            unet = self.get_unet(unet_number)

        # 设备
        cuda, cpu = torch.device('cuda'), torch.device('cpu')

        self.cuda()

        devices = [module_device(unet) for unet in self.unets]

        self.unets.to(cpu)
        unet.to(cuda)

        yield

        for unet, device in zip(self.unets, devices):
            unet.to(device)
    # 定义一个动态阈值函数,用于改进分类器自由引导设置中的夹紧操作
    def dynamic_threshold(self, x):
        """ proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
        
        # s 是阈值量
        # 静态阈值设定为 s = 1
        s = 1.
        # 如果使用动态阈值
        if self.use_dynamic_thres:
            # 计算 x 的绝对值的分位数,用于确定动态阈值
            s = torch.quantile(
                rearrange(x, 'b ... -> b (...)').abs(),
                self.dynamic_thres_percentile,
                dim = -1
            )

            # 夹紧阈值,确保不小于1
            s.clamp_(min = 1.)
            s = s.view(-1, *((1,) * (x.ndim - 1)))

        # 根据阈值夹紧 x,取值范围为 [-s, s],然后归一化
        x = x.clamp(-s, s) / s
        return x
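
A standalone numeric sketch (made-up values, not library code) of the dynamic thresholding above: s is taken as a high quantile of |x| per sample (never below 1), then x is clamped to [-s, s] and rescaled by s so the output stays within [-1, 1].

# illustrative sketch only -- not library code
import torch

x = torch.tensor([[0.2, -0.5, 3.0, -4.0]])     # one flattened sample
s = torch.quantile(x.abs(), 0.95, dim = -1, keepdim = True).clamp(min = 1.)
print(s)                                       # tensor([[3.8500]])
print(x.clamp(-s, s) / s)                      # values now lie in [-1, 1]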

    # 计算模型的均值、后验方差和后验对数方差,用于生成样本
    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
        # 断言条件,确保条件满足
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

        # 默认情况下,使用 unet 进行前向传播
        model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))

        # 解析 unet 输出,获取预测值和方差插值比例
        pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)

        # 根据预测值选择不同的处理方式
        if predict_v:
            x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
        elif predict_x_start:
            x_start = pred
        else:
            x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

        # 如果需要对去噪后的结果进行夹紧
        if clip_denoised:
            x_start = self.dynamic_threshold(x_start)

        # 计算模型均值、后验方差和后验对数方差
        model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)

        # 如果使用了学习的方差
        if learned_variance:
            # 根据网络预测的最大和最小对数 beta 值进行插值,计算后验对数方差和后验方差
            min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
            max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)
            var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)

            if self.learned_variance_constrain_frac:
                var_interp_frac = var_interp_frac.sigmoid()

            posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
            posterior_variance = posterior_log_variance.exp()

        return model_mean, posterior_variance, posterior_log_variance, x_start

    # 生成样本,使用模型均值和后验方差
    @torch.no_grad()
    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
        b, *_, device = *x.shape, x.device
        # 计算模型均值、后验方差和后验对数方差
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
        noise = torch.randn_like(x)
        # 当 t == 0 时不添加噪声
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred, x_start

    # 生成样本,使用模型均值和后验方差
    @torch.no_grad()
    # 定义一个函数,用于执行采样循环,生成图片
    def p_sample_loop_ddpm(
        self,
        unet,
        shape,
        image_embed,
        noise_scheduler,
        predict_x_start = False,
        predict_v = False,
        learned_variance = False,
        clip_denoised = True,
        lowres_cond_img = None,
        text_encodings = None,
        cond_scale = 1,
        is_latent_diffusion = False,
        lowres_noise_level = None,
        inpaint_image = None,
        inpaint_mask = None,
        inpaint_resample_times = 5
    ):
        # 获取设备信息
        device = self.device

        # 获取 batch 大小
        b = shape[0]
        # 生成随机噪声图片
        img = torch.randn(shape, device = device)

        x_start = None # for self-conditioning

        is_inpaint = exists(inpaint_image)
        resample_times = inpaint_resample_times if is_inpaint else 1

        if is_inpaint:
            # 对 inpaint_image 进行归一化处理
            inpaint_image = self.normalize_img(inpaint_image)
            # 将 inpaint_image 调整大小以匹配 shape[-1]
            inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
            # 将 inpaint_mask 调整大小以匹配 shape[-1]
            inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
            inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
            inpaint_mask = inpaint_mask.bool()

        if not is_latent_diffusion:
            # 对 lowres_cond_img 进行归一化处理
            lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

        # 遍历时间步骤
        for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
            is_last_timestep = time == 0

            # 遍历重新采样次数
            for r in reversed(range(0, resample_times)):
                is_last_resample_step = r == 0

                # 生成时间步骤的张量
                times = torch.full((b,), time, device = device, dtype = torch.long)

                if is_inpaint:
                    # 根据 repaint 论文进行处理
                    noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
                    img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)

                self_cond = x_start if unet.self_cond else None

                # 执行采样操作
                img, x_start = self.p_sample(
                    unet,
                    img,
                    times,
                    image_embed = image_embed,
                    text_encodings = text_encodings,
                    cond_scale = cond_scale,
                    self_cond = self_cond,
                    lowres_cond_img = lowres_cond_img,
                    lowres_noise_level = lowres_noise_level,
                    predict_x_start = predict_x_start,
                    predict_v = predict_v,
                    noise_scheduler = noise_scheduler,
                    learned_variance = learned_variance,
                    clip_denoised = clip_denoised
                )

                if is_inpaint and not (is_last_timestep or is_last_resample_step):
                    # 在 repaint 中,每个步骤最多重新噪声和重新采样 10 次
                    img = noise_scheduler.q_sample_from_to(img, times - 1, times)

        if is_inpaint:
            img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)

        # 对生成的图片进行反归一化处理
        unnormalize_img = self.unnormalize_img(img)
        return unnormalize_img

    @torch.no_grad()
    def p_sample_loop_ddim(
        self,
        unet,
        shape,
        image_embed,
        noise_scheduler,
        timesteps,
        eta = 1.,
        predict_x_start = False,
        predict_v = False,
        learned_variance = False,
        clip_denoised = True,
        lowres_cond_img = None,
        text_encodings = None,
        cond_scale = 1,
        is_latent_diffusion = False,
        lowres_noise_level = None,
        inpaint_image = None,
        inpaint_mask = None,
        inpaint_resample_times = 5
    ):
        # 解构 shape 变量,获取批次大小、设备、总时间步长、alpha 值、eta 值
        batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod, self.ddim_sampling_eta

        # 在 0 到总时间步长之间生成 timesteps + 2 个步长的时间点,并去除最后一个时间点
        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]

        # 将时间点列表反转,并转换为整数列表
        times = list(reversed(times.int().tolist()))
        # 生成时间点对列表
        time_pairs = list(zip(times[:-1], times[1:]))
        # 过滤出时间点对中第一个时间点大于第二个时间点的情况
        time_pairs = list(filter(lambda t: t[0] > t[1], time_pairs))

        # 检查是否存在 inpaint_image
        is_inpaint = exists(inpaint_image)
        # 如果存在 inpaint_image,则使用 inpaint_resample_times,否则为 1
        resample_times = inpaint_resample_times if is_inpaint else 1

        # 如果存在 inpaint_image,则对其进行归一化和调整大小,并生成对应的掩码
        if is_inpaint:
            inpaint_image = self.normalize_img(inpaint_image)
            inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
            inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
            inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
            inpaint_mask = inpaint_mask.bool()

        # 生成随机噪声图像
        img = torch.randn(shape, device = device)

        # 初始化 x_start 为 None,用于自条件
        x_start = None

        # 如果不是潜在扩散,则对低分辨率条件图像进行归一化
        if not is_latent_diffusion:
            lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

        # 遍历时间点对
        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            # 判断是否为最后一个时间步
            is_last_timestep = time_next == 0

            # 反向遍历重采样次数
            for r in reversed(range(0, resample_times)):
                # 判断是否为最后一个重采样步骤
                is_last_resample_step = r == 0

                # 获取当前时间点和下一个时间点的 alpha 值
                alpha = alphas[time]
                alpha_next = alphas[time_next]

                # 生成当前时间点的条件
                time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

                # 如果存在 inpaint_image,则根据时间点和掩码生成噪声图像
                if is_inpaint:
                    noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
                    img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)

                # 根据 unet 的 self_cond 属性确定是否使用自条件
                self_cond = x_start if unet.self_cond else None

                # 使用 unet 模型生成输出
                unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)

                # 解析 unet 输出
                pred, _ = self.parse_unet_output(learned_variance, unet_output)

                # 预测 x0
                if predict_v:
                    x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
                elif predict_x_start:
                    x_start = pred
                else:
                    x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)

                # 可能对 x0 进行裁剪
                if clip_denoised:
                    x_start = self.dynamic_threshold(x_start)

                # 预测噪声
                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)

                # 计算 c1 和 c2
                c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
                c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
                noise = torch.randn_like(img) if not is_last_timestep else 0.

                # 更新图像
                img = x_start * alpha_next.sqrt() + \
                      c1 * noise + \
                      c2 * pred_noise

                # 如果存在 inpaint_image 且不是最后一个时间步或最后一个重采样步骤,则重新噪声和重采样
                if is_inpaint and not (is_last_timestep or is_last_resample_step):
                    time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
                    img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)

        # 如果存在 inpaint_image,则将图像还原为原始图像
        if exists(inpaint_image):
            img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)

        # 将图像还原为原始图像
        img = self.unnormalize_img(img)
        # 返回生成的图像
        return img
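
A brief numeric check (assumed alpha values, not library code) of the c1 / c2 coefficients above: with ddim_sampling_eta = 0 the noise term vanishes and the update becomes the deterministic DDIM step x = sqrt(alpha_next) * x_start + sqrt(1 - alpha_next) * pred_noise.

# illustrative sketch only -- not library code
import torch

alpha, alpha_next, eta = torch.tensor(0.5), torch.tensor(0.7), 0.
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - c1 ** 2).sqrt()
print(c1.item(), round(c2.item(), 4))    # 0.0 0.5477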

    # 禁用梯度
    @torch.no_grad()
    # 定义一个方法 p_sample_loop,接受可变数量的参数和关键字参数
    def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
        # 获取噪声调度器的总时间步数
        num_timesteps = noise_scheduler.num_timesteps

        # 如果未指定时间步数,则使用默认值为总时间步数
        timesteps = default(timesteps, num_timesteps)
        # 断言指定的时间步数不超过总时间步数
        assert timesteps <= num_timesteps
        # 如果采样步数少于训练时的总步数,则采用 DDIM 采样
        is_ddim = timesteps < num_timesteps

        # 使用完整步数时,调用 p_sample_loop_ddpm 方法(DDPM 祖先采样)
        if not is_ddim:
            return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)

        # 否则调用 p_sample_loop_ddim 方法(DDIM 采样)
        return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)

    # 定义一个函数,计算损失值
    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
        # 设置默认的噪声函数
        noise = default(noise, lambda: torch.randn_like(x_start))

        # 将输入归一化到[-1, 1]范围内
        if not is_latent_diffusion:
            x_start = self.normalize_img(x_start)
            lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

        # 获取带噪声的输入图像
        x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

        # 设置 UNet 的参数
        unet_kwargs = dict(
            image_embed = image_embed,
            text_encodings = text_encodings,
            lowres_cond_img = lowres_cond_img,
            lowres_noise_level = lowres_noise_level,
        )

        # 自我条件
        self_cond = None

        # 如果 UNet 具有自我条件属性且随机数小于0.5
        if unet.self_cond and random.random() < 0.5:
            with torch.no_grad():
                unet_output = unet(x_noisy, times, **unet_kwargs)
                self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
                self_cond = self_cond.detach()

        # 前向传播获取模型预测
        unet_output = unet(
            x_noisy,
            times,
            **unet_kwargs,
            self_cond = self_cond,
            image_cond_drop_prob = self.image_cond_drop_prob,
            text_cond_drop_prob = self.text_cond_drop_prob,
        )

        pred, _ = self.parse_unet_output(learned_variance, unet_output)

        # 根据需求选择目标值
        if predict_v:
            target = noise_scheduler.calculate_v(x_start, times, noise)
        elif predict_x_start:
            target = x_start
        else:
            target = noise

        # 计算损失值
        loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')

        # 对损失值进行重新加权
        loss = noise_scheduler.p2_reweigh_loss(loss, times)

        loss = loss.mean()

        if not learned_variance:
            # 如果不使用学习的方差,则返回简单的损失值
            return loss

        # 如果学习方差,还包括额外的 kl 损失
        true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
        model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)

        # KL 损失
        detached_model_mean = model_mean.detach()
        kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
        kl = meanflat(kl) * NAT

        # 解码器负对数似然
        decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
        decoder_nll = meanflat(decoder_nll) * NAT

        # 在第一个时间步返回解码器 NLL,否则返回 KL 散度
        vb_losses = torch.where(times == 0, decoder_nll, kl)

        # 对 vb 损失进行加权
        vb_loss = vb_losses.mean() * self.vb_loss_weight

        return loss + vb_loss

    # 禁止梯度计算
    @torch.no_grad()
    # 评估装饰器
    @eval_decorator
    # 定义一个名为sample的方法,用于生成样本
    def sample(
        self,
        image = None, # 图像输入,默认为None
        image_embed = None, # 图像嵌入,默认为None
        text = None, # 文本输入,默认为None
        text_encodings = None, # 文本编码,默认为None
        batch_size = 1, # 批处理大小,默认为1
        cond_scale = 1., # 条件比例,默认为1.0
        start_at_unet_number = 1, # 开始的UNET编号,默认为1
        stop_at_unet_number = None, # 结束的UNET编号,默认为None
        distributed = False, # 是否分布式,默认为False
        inpaint_image = None, # 修复图像,默认为None
        inpaint_mask = None, # 修复掩码,默认为None
        inpaint_resample_times = 5, # 修复重采样次数,默认为5
        one_unet_in_gpu_at_time = True # 是否一次在GPU上运行一个UNET,默认为True
    # 定义一个名为forward的方法,用于前向传播
    def forward(
        self,
        image, # 图像输入
        text = None, # 文本输入,默认为None
        image_embed = None, # 图像嵌入,默认为None
        text_encodings = None, # 文本编码,默认为None
        unet_number = None, # UNET编号,默认为None
        return_lowres_cond_image = False # 是否返回低分辨率的条件图像,用于调试上采样器的目的,默认为False
    ):
        # 断言语句,用于检查是否指定了要训练的 unet 编号,如果训练多个 unet,则必须指定要训练的 unet 编号
        assert not (self.num_unets > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {self.num_unets}, if you are training cascading DDPM (multiple unets)'
        # 如果未指定 unet 编号,则默认为 1
        unet_number = default(unet_number, 1)
        # 计算 unet 编号在列表中的索引
        unet_index = unet_number - 1

        # 获取指定编号的 unet 模型
        unet = self.get_unet(unet_number)

        # 获取对应 unet 编号的 VAE 模型、噪声调度器、低分辨率条件器、目标图像大小、预测 x 起始位置、预测速度、随机裁剪大小、学习的方差、图像的形状和设备
        vae                 = self.vaes[unet_index]
        noise_scheduler     = self.noise_schedulers[unet_index]
        lowres_conditioner  = self.lowres_conds[unet_index]
        target_image_size   = self.image_sizes[unet_index]
        predict_x_start     = self.predict_x_start[unet_index]
        predict_v           = self.predict_v[unet_index]
        random_crop_size    = self.random_crop_sizes[unet_index]
        learned_variance    = self.learned_variance[unet_index]
        b, c, h, w, device, = *image.shape, image.device

        # 断言语句,用于检查图像通道数是否与模型要求的通道数相同
        assert image.shape[1] == self.channels
        # 断言语句,用于检查图像的高度和宽度是否大于等于目标图像大小
        assert h >= target_image_size and w >= target_image_size

        # 生成一组随机时间步长
        times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)

        # 如果未提供图像嵌入且不是无条件生成,则使用 CLIP 模型对图像进行嵌入
        if not exists(image_embed) and not self.unconditional:
            assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
            image_embed, _ = self.clip.embed_image(image)

        # 如果提供了文本且未提供文本编码且不是无条件生成,则使用 CLIP 模型对文本进行嵌入
        if exists(text) and not exists(text_encodings) and not self.unconditional:
            assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
            _, text_encodings = self.clip.embed_text(text)

        # 断言语句,用于检查是否传入了文本编码
        assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
        # 断言语句,用于检查是否指定了不基于文本编码的解码器
        assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

        # 如果存在低分辨率条件器,则对图像进行低分辨率处理
        lowres_cond_img, lowres_noise_level = lowres_conditioner(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if exists(lowres_conditioner) else (None, None)
        # 调整图像大小为目标图像大小
        image = resize_image_to(image, target_image_size, nearest = True)

        # 如果存在随机裁剪大小,则对图像进行随机裁剪
        if exists(random_crop_size):
            aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)

            # 确保低分辨率条件器和图像都以相同方式进行增强
            # 详细信息请参考 https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
            image = aug(image)
            lowres_cond_img = aug(lowres_cond_img, params = aug._params)

        # 判断是否为潜在扩散模型
        is_latent_diffusion = not isinstance(vae, NullVQGanVAE)

        # 将 VAE 模型设置为评估模式,并禁用梯度计算
        vae.eval()
        with torch.no_grad():
            # 对图像进行编码
            image = vae.encode(image)
            # 对低分辨率条件图像进行编码
            lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

        # 计算损失
        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)

        # 如果不返回低分辨率条件图像,则返回损失
        if not return_lowres_cond_image:
            return losses

        # 返回损失和低分辨率条件图像
        return losses, lowres_cond_img

# 主类定义
class DALLE2(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        prior,  # 先验模型
        decoder,  # 解码器
        prior_num_samples = 2  # 先验模型采样数量,默认为2
    ):
        super().__init__()
        # 断言先验模型和解码器的类型
        assert isinstance(prior, DiffusionPrior)
        assert isinstance(decoder, Decoder)
        # 初始化先验模型和解码器
        self.prior = prior
        self.decoder = decoder

        self.prior_num_samples = prior_num_samples  # 先验模型采样数量
        self.decoder_need_text_cond = self.decoder.condition_on_text_encodings  # 解码器是否需要文本编码

        self.to_pil = T.ToPILImage()  # 转换为 PIL 图像

    @torch.no_grad()
    @eval_decorator
    # 前向传播函数
    def forward(
        self,
        text,  # 文本输入
        cond_scale = 1.,  # 条件缩放
        prior_cond_scale = 1.,  # 先验条件缩放
        return_pil_images = False  # 是否返回 PIL 图像
    ):
        device = module_device(self)  # 获取设备
        one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)  # 判断是否为单个文本

        if isinstance(text, str) or is_list_str(text):
            text = [text] if not isinstance(text, (list, tuple)) else text
            text = tokenizer.tokenize(text).to(device)  # 对文本进行标记化处理并移动到设备

        # 从先验模型中采样图像编码
        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)

        text_cond = text if self.decoder_need_text_cond else None  # 如果解码器需要文本编码,则传入文本编码,否则为None
        # 从解码器中采样图像
        images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)

        if return_pil_images:
            images = list(map(self.to_pil, images.unbind(dim = 0)))  # 将图像转换为 PIL 图像

        if one_text:
            return first(images)  # 如果只有一个文本输入,则返回第一个图像

        return images  # 返回图像列表
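
A hedged usage sketch of the forward pass above: `diffusion_prior` and `decoder` are placeholders for already-constructed (and trained) DiffusionPrior and Decoder instances, and the keyword names mirror the signatures shown in this class.

# illustrative sketch only -- assumes diffusion_prior and decoder are built and trained elsewhere
dalle2 = DALLE2(prior = diffusion_prior, decoder = decoder, prior_num_samples = 2)

images = dalle2(
    ['a red tricycle under a tree'],   # hypothetical prompt
    cond_scale = 2.,                   # decoder classifier-free guidance scale
    prior_cond_scale = 1.,
    return_pil_images = True
)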

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\dataloaders\decoder_loader.py

import os
import webdataset as wds
import torch
from torch.utils.data import DataLoader
import numpy as np
import fsspec
import shutil

def get_shard(filename):
    """
    Filenames with shards in them have a consistent structure that we can take advantage of
    Standard structure: path/to/file/prefix_string_00001.ext
    """
    try:
        return filename.split("_")[-1].split(".")[0]
    except ValueError:
        raise RuntimeError(f"Could not find shard for filename {filename}")
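
A tiny illustration of get_shard() with a hypothetical filename that follows the prefix_string_00001.ext convention from the docstring:

# illustrative sketch only -- hypothetical filename
print(get_shard("path/to/file/img_emb_00042.npy"))    # -> 00042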

def get_example_file(fs, path, file_format):
    """
    Given a file system and a file extension, return the example file
    """
    return fs.glob(os.path.join(path, f"*.{file_format}"))[0]

def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
    """Given a datum of {"__key__": str, "__url__": str, ...} adds the corresponding embedding and yields"""
    previous_tar_url = None
    current_embeddings = None
    # Get a reference to an abstract file system where the embeddings are stored
    embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
    example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy")
    example_embedding_shard = get_shard(example_embedding_file)
    emb_shard_width = len(example_embedding_shard)
    # Easier to get the basename without the shard once than search through for the correct file every time
    embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_"

    def load_corresponding_embeds(tar_url):
      """Finds and reads the npy files that contains embeddings for the given webdataset tar"""
      shard = int(tar_url.split("/")[-1].split(".")[0])
      embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'
      with embeddings_fs.open(embedding_url) as f:
        data = np.load(f)
      return torch.from_numpy(data)

    for sample in samples:
        try:
            tar_url = sample["__url__"]
            key = sample["__key__"]
            if tar_url != previous_tar_url:
                # If the tar changed, we need to download new embeddings
                # This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.
                previous_tar_url = tar_url
                current_embeddings = load_corresponding_embeds(tar_url)
                
            embedding_index = int(key[-index_width:])
            embedding = current_embeddings[embedding_index]
            # If the embedding is all zeros, no real embedding was stored for this sample, so it is invalid and we raise
            if torch.count_nonzero(embedding) == 0:
                raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
            sample[sample_key] = embedding
            yield sample
        except Exception as exn:  # From wds implementation
            if handler(exn):
                continue
            else:
                break

insert_embedding = wds.filters.pipelinefilter(embedding_inserter)

def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
    """Finds if there is a corresponding embedding for the tarfile at { url: [URL] }"""
    embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
    embedding_files = embeddings_fs.ls(embeddings_path)
    get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
    embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files])  # Sets have O(1) check for member

    get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
    # 遍历 tarfiles 列表中的每个 tarfile
    for tarfile in tarfiles:
        try:
            # 获取 tarfile 对应的 webdataset shard
            webdataset_shard = get_tar_shard(tarfile["url"])
            # 如果该 shard 有关联的 embeddings 文件,则返回该 tarfile
            # 否则继续迭代直到找到有关联的 embeddings 文件
            if webdataset_shard in embedding_shards:
                yield tarfile
        except Exception as exn:  # 从 wds 实现中捕获异常
            # 如果 handler 函数处理了异常,则继续循环
            if handler(exn):
                continue
            # 如果 handler 函数未处理异常,则跳出循环
            else:
                break
# 创建一个过滤器,用于跳过未关联的碎片
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)

# 将样本中的img_emb和text_emb键合并为一个键"emb": { "text": text_emb, "img": img_emb }
# 如果text_emb和img_emb中的一个或两个不存在于样本中,则只添加存在的部分
def join_embeddings(samples, handler=wds.handlers.reraise_exception):
    """
    Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
    either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
    """
    for sample in samples:
        try:
            sample['emb'] = {}
            if 'text_emb' in sample:
                sample['emb']['text'] = sample['text_emb']
            if 'img_emb' in sample:
                sample['emb']['img'] = sample['img_emb']
            yield sample
        except Exception as exn:  # From wds implementation
            if handler(exn):
                continue
            else:
                break

# 验证样本中是否存在所需的键,如果不存在则抛出异常
def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
    """
    Requires that both the image and embedding are present in the sample
    This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
    """
    for sample in samples:
        try:
            for key in required_keys:
                assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
            yield sample
        except Exception as exn:  # From wds implementation
            if handler(exn):
                continue
            else:
                break

# 创建一个过滤器,用于验证样本中是否存在所需的键
key_verifier = wds.filters.pipelinefilter(verify_keys)

# ImageEmbeddingDataset类,是DataPipeline的流式接口包装器,返回图像嵌入对
# 从webdataset中读取npy文件作为嵌入,如果存在的话。如果设置了embedding_folder_url,则会从替代来源插入它们
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
    """
    A fluid interface wrapper for DataPipline that returns image embedding pairs
    Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted in from the alternate source.
    """

    def __init__(
            self,
            urls,
            img_embedding_folder_url=None,
            text_embedding_folder_url=None,
            index_width=None,
            img_preproc=None,
            extra_keys=[],
            handler=wds.handlers.reraise_exception,
            resample=False,
            shuffle_shards=True
    def preproc(self, sample):
        """Applies the preprocessing for images"""
        if self.img_preproc is not None:
            sample["jpg"] = self.img_preproc(sample["jpg"])
        return sample

# 创建一个图像嵌入数据加载器的便捷函数
def create_image_embedding_dataloader(
    tar_url,
    num_workers,
    batch_size,
    img_embeddings_url=None,
    text_embeddings_url=None,
    index_width=None,
    shuffle_num = None,
    shuffle_shards = True,
    resample_shards = False, 
    img_preproc=None,
    extra_keys=[],
    handler=wds.handlers.reraise_exception#warn_and_continue
):
    """
    Convenience function to create an image embedding dataseta and dataloader in one line

    :param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
    :param num_workers: The number of workers to use for the dataloader
    :param batch_size: The batch size to use for the dataloader
    :param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
        Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
    :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
            For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
    :param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
    :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
    :param resample_shards: 如果为True,则对webdataset分片进行有放回抽样。如果设置为True,则需要设置自己的epoch大小,因为它将无限重采样。
    :param handler: webdataset处理程序。
    """
    # 创建ImageEmbeddingDataset对象
    ds = ImageEmbeddingDataset(
        tar_url,
        img_embedding_folder_url=img_embeddings_url,
        text_embedding_folder_url=text_embeddings_url,
        index_width=index_width,
        shuffle_shards=shuffle_shards,
        resample=resample_shards,
        extra_keys=extra_keys,
        img_preproc=img_preproc,
        handler=handler
    )
    # 如果设置了shuffle_num并且大于0,则对数据集进行洗牌
    if shuffle_num is not None and shuffle_num > 0:
        ds.shuffle(1000)
    # 返回一个DataLoader对象
    return DataLoader(
        ds,
        num_workers=num_workers,
        batch_size=batch_size,
        prefetch_factor=2,  # 这可能是一个好主意,使其较高,以便预取下一个npy文件
        pin_memory=True,
        shuffle=False
    )
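
A hedged usage sketch of the convenience function above. All urls and numbers are made-up placeholders; the keyword names come straight from the signature, and the exact structure of each yielded sample depends on the keys configured in ImageEmbeddingDataset.

# illustrative sketch only -- hypothetical paths and sizes
dataloader = create_image_embedding_dataloader(
    tar_url = "/data/webdataset/{0000..0009}.tar",     # hypothetical shard pattern
    num_workers = 4,
    batch_size = 32,
    img_embeddings_url = "/data/img_embeddings",       # hypothetical embeddings folder
    index_width = 4,
    shuffle_num = 200,
    shuffle_shards = True,
    resample_shards = False
)

for batch in dataloader:
    ...   # each batch carries the image tensor plus the inserted embedding(s)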

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\dataloaders\prior_loader.py

# 从 math 模块中导入 ceil 函数
from math import ceil
# 从 clip 模块中导入 tokenize 函数
from clip import tokenize
# 从 embedding_reader 模块中导入 EmbeddingReader 类
from embedding_reader import EmbeddingReader
# 从 torch 模块中导入 from_numpy 函数和 DataLoader 类
from torch import from_numpy
from torch.utils.data import IterableDataset, DataLoader

# 定义 PriorEmbeddingDataset 类,继承自 IterableDataset 类
class PriorEmbeddingDataset(IterableDataset):
    """
    PriorEmbeddingDataset is a wrapper of EmbeddingReader.

    It enables one to simplify the logic necessary to yield samples from
    the different EmbeddingReader configurations available.
    """

    # 初始化方法
    def __init__(
        self,
        text_conditioned: bool,
        batch_size: int,
        start: int,
        stop: int,
        image_reader,
        text_reader: EmbeddingReader = None,
    ) -> None:
        # 调用父类的初始化方法
        super(PriorEmbeddingDataset).__init__()

        # 设置属性值
        self.text_conditioned = text_conditioned

        # 如果不是文本条件,则设置文本阅读器
        if not self.text_conditioned:
            self.text_reader = text_reader

        # 设置属性值
        self.image_reader = image_reader
        self.start = start
        self.stop = stop
        self.batch_size = batch_size

    # 返回数据集的长度
    def __len__(self):
        return self.stop - self.start

    # 迭代器方法
    def __iter__(self):
        # 定义 loader_args 字典
        loader_args = dict(
            batch_size=self.batch_size,
            start=self.start,
            end=self.stop,
            show_progress=False,
        )

        # 如果请求的数据是文本条件的,则只加载图像
        if self.text_conditioned:
            self.loader = self.image_reader(**loader_args)
        # 否则,包括文本嵌入并绕过元数据
        else:
            self.loader = zip(
                self.image_reader(**loader_args), self.text_reader(**loader_args)
            )

        # 返回格式化后的数据加载器
        return self

    # 获取下一个数据样本
    def __next__(self):
        try:
            return self.get_sample()
        except StopIteration:
            raise StopIteration

    # 返回对象的字符串表示形式
    def __str__(self):
        return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"

    # 设置起始点
    def set_start(self, start):
        """
        Adjust the starting point within the reader, useful for resuming an epoch
        """
        self.start = start

    # 获取起始点
    def get_start(self):
        return self.start

    # 获取样本数据
    def get_sample(self):
        """
        pre-process data from either reader into a common format
        """
        if self.text_conditioned:
            image_embedding, caption = next(self.loader)

            image_embedding = from_numpy(image_embedding)
            tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)

            return image_embedding, tokenized_caption

        else:
            (image_embedding, _), (text_embedding, _) = next(self.loader)

            image_embedding = from_numpy(image_embedding)
            text_embedding = from_numpy(text_embedding)

            return image_embedding, text_embedding


# 辅助函数

# 分发数据给每个排名
def distribute_to_rank(start, stop, rank, world_size):
    """
    Distribute data to each rank given the world size.

    Return:
        - New start and stop points for this rank.
    """
    num_samples = int(stop - start)

    per_rank = int(ceil((num_samples) / float(world_size)))

    assert (
        per_rank > 0
    ), f"Number of samples per rank must be larger than 0, (found: {per_rank})"

    rank_start = start + rank * per_rank

    rank_stop = min(rank_start + per_rank, stop)

    new_length = rank_stop - rank_start

    assert (
        new_length > 0
    ), "Calculated start and stop points result in a length of zero for this rank."

    return rank_start, rank_stop
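
A worked example (made-up numbers) of the split above: 100 samples across a world size of 4 gives each rank a contiguous chunk of 25.

# illustrative sketch only
print(distribute_to_rank(start = 0, stop = 100, rank = 0, world_size = 4))    # (0, 25)
print(distribute_to_rank(start = 0, stop = 100, rank = 3, world_size = 4))    # (75, 100)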

# 获取阅读器对象
def get_reader(
    text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
):
    """
    Create an EmbeddingReader object from the specified URLs

    get_reader() will always expect a url to image embeddings.

    If text-conditioned, it will also expect a meta_url for the captions.
    Otherwise, it will need txt_url for the matching text embeddings.

    Returns an image_reader object if text-conditioned.
    Otherwise it returns both an image_reader and a text_reader
    """

    # 断言确保图像 URL 不为空
    assert img_url is not None, "Must supply a image url"

    # 如果需要文本条件,则断言确保元数据 URL 不为空
    if text_conditioned:
        assert meta_url is not None, "Must supply meta url if text-conditioned"

        # 创建一个 EmbeddingReader 对象用于读取图像数据
        image_reader = EmbeddingReader(
            embeddings_folder=img_url,
            file_format="parquet_npy",
            # 假设标题列存在且是唯一请求的列
            meta_columns=["caption"],
            metadata_folder=meta_url,
        )

        # 返回图像数据读取器
        return image_reader

    # 否则,需要文本嵌入,返回两个读取器
    assert (
        txt_url is not None
    ), "Must supply text embedding url if not text-conditioning"

    # 创建一个 EmbeddingReader 对象用于读取图像数据
    image_reader = EmbeddingReader(img_url, file_format="npy")
    # 创建一个 EmbeddingReader 对象用于读取文本数据
    text_reader = EmbeddingReader(txt_url, file_format="npy")

    # 返回图像数据读取器和文本数据读取器
    return image_reader, text_reader

def make_splits(
    text_conditioned: bool,
    batch_size: int,
    num_data_points: int,
    train_split: float,
    eval_split: float,
    image_reader: EmbeddingReader,
    text_reader: EmbeddingReader = None,
    start=0,
    rank=0,
    world_size=1,
):
    """
    Split an embedding reader object as needed.

    NOTE: make_splits() will infer the test set size from your train and eval.

    Input:
        - text_conditioned: whether to prepare text-conditioned training data
        - batch_size: the batch size for a single gpu
        - num_data_points: the total number of data points you wish to train on
        - train_split: the percentage of data you wish to train on
        - eval_split: the percentage of data you wish to validate on
        - image_reader: the image_reader you wish to split
        - text_reader: the text_reader you want to split (if !text_conditioned)
        - start: the starting point within your dataset
        - rank: the rank of your worker
        - world_size: the total world size of your distributed training run

    Returns:
        - PyTorch Dataloaders that yield tuples of (img, txt) data.
    """

    assert start < image_reader.count, "start position cannot exceed reader count."

    # verify that the num_data_points does not exceed the max points
    if num_data_points > (image_reader.count - start):
        print(
            "Specified count is larger than what's available...defaulting to reader's count."
        )
        num_data_points = image_reader.count

    # compute split points
    train_set_size = int(train_split * num_data_points)
    eval_set_size = int(eval_split * num_data_points)
    eval_start = train_set_size
    eval_stop = int(eval_start + eval_set_size)

    assert (
        train_split + eval_split
    ) < 1.0, "Specified train and eval split is too large to infer a test split."

    # distribute to rank
    rank_train_start, rank_train_stop = distribute_to_rank(
        start, train_set_size, rank, world_size
    )
    rank_eval_start, rank_eval_stop = distribute_to_rank(
        train_set_size, eval_stop, rank, world_size
    )
    rank_test_start, rank_test_stop = distribute_to_rank(
        eval_stop, num_data_points, rank, world_size
    )

    # wrap up splits into a dict
    train_split_args = dict(
        start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
    )
    eval_split_args = dict(
        start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
    )
    test_split_args = dict(
        start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
    )

    if text_conditioned:
        # add the text-conditioned args to a unified dict
        reader_args = dict(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
        )

        train_split_args = dict(**reader_args, **train_split_args)
        eval_split_args = dict(**reader_args, **eval_split_args)
        test_split_args = dict(**reader_args, **test_split_args)

        train = PriorEmbeddingDataset(**train_split_args)
        val = PriorEmbeddingDataset(**eval_split_args)
        test = PriorEmbeddingDataset(**test_split_args)

    else:
        # add the non-conditioned args to a unified dict
        reader_args = dict(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            text_reader=text_reader,
        )

        train_split_args = dict(**reader_args, **train_split_args)
        eval_split_args = dict(**reader_args, **eval_split_args)
        test_split_args = dict(**reader_args, **test_split_args)

        train = PriorEmbeddingDataset(**train_split_args)
        val = PriorEmbeddingDataset(**eval_split_args)
        test = PriorEmbeddingDataset(**test_split_args)

    # true batch size is specified in the PriorEmbeddingDataset
    train_loader = DataLoader(train, batch_size=None)
    eval_loader = DataLoader(val, batch_size=None)
    # 创建一个数据加载器用于加载测试数据集,batch_size设置为None表示每次加载整个数据集
    test_loader = DataLoader(test, batch_size=None)

    # 返回训练数据加载器、验证数据加载器和测试数据加载器
    return train_loader, eval_loader, test_loader
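
A hedged end-to-end sketch (hypothetical urls, single process) tying get_reader() and make_splits() together; the keyword names follow the signatures above, and the train/eval fractions leave room for the inferred test split.

# illustrative sketch only -- hypothetical paths and sizes
image_reader = get_reader(text_conditioned = True, img_url = "/data/img_emb", meta_url = "/data/meta")

train_loader, eval_loader, test_loader = make_splits(
    text_conditioned = True,
    batch_size = 64,
    num_data_points = 1_000_000,
    train_split = 0.9,
    eval_split = 0.05,
    image_reader = image_reader,
    start = 0,
    rank = 0,
    world_size = 1
)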

  1. If your shard files have the paths protocol://path/to/shard/00104.tar, then the base url would be protocol://path/to/shard/{}.tar. If you are using a protocol like s3, you need to pipe the tars. For example pipe:s3cmd get s3://bucket/path/{}.tar -.

  2. This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename 00104.tar, your shard length is 5.

  3. Inside the webdataset tar, you have files named something like 001045945.jpg. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is 00104 and index is 5945). The index_width in this case is 4.
