Lucidrains Series Projects: Source Code Analysis (Part 9)

Dataloaders

In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.

Decoder: Image Embedding Dataset

When training the decoder (and upsamplers, if training them together) in isolation, you will need to load images and their corresponding image embeddings. This dataset can read two similar dataset layouts. First, it can read a webdataset whose .tar shards contain .jpg and .npy files holding the images and their associated image embeddings, respectively. Alternatively, you can specify a source for the embeddings outside of the webdataset. In that case, the embeddings path should contain .npy files with the same shard numbers as the webdataset, and the filename of each .jpg must map to the index of its embedding in the corresponding .npy. For example, if 0001.tar from the webdataset contains the image 00010509.jpg (the first 4 digits are the shard number and the last 4 are the index), it should be paralleled by an img_emb_0001.npy containing a NumPy array with that image's embedding at index 509.
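
As an illustration of that correspondence, here is a minimal sketch (the helper name is hypothetical and not part of the library) that maps a webdataset filename back to its shard number and embedding index, assuming a shard width of 4:

from pathlib import Path

def shard_and_index(filename, shard_width = 4):
    # hypothetical helper: '00010509.jpg' -> shard '0001', embedding index 509
    stem = Path(filename).stem
    return stem[:shard_width], int(stem[shard_width:])

print(shard_and_index('00010509.jpg'))  # ('0001', 509)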

Generating a dataset of this type:

  1. Use img2dataset to generate a webdataset.
  2. Use clip-retrieval to convert the images to embeddings.
  3. Use embedding-dataset-reordering to reorder the embeddings into the expected format.

Usage:

from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader

# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
    embeddings_url="path/or/url/to/embeddings/folder",     # Included if .npy files are not in webdataset. Left out or set to None otherwise
    num_workers=4,
    batch_size=32,
    shard_width=4,                                         # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
    shuffle_num=200,                                       # Does a shuffle of the data with a buffer size of 200
    shuffle_shards=True,                                   # Shuffle the order the shards are read in
    resample_shards=False,                                 # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
    print(img.shape)  # torch.Size([32, 3, 256, 256])
    print(emb.shape)  # torch.Size([32, 512])
    # Train decoder only as shown above

# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
    urls="/path/or/url/to/webdataset/{0000..9999}.tar",
    embedding_folder_url="path/or/url/to/embeddings/folder",
    shard_width=4,
    shuffle_shards=True,
    resample=False
)

Diffusion Prior: Prior Embedding Dataset

When training the prior it is much more efficient to work with pre-computed embeddings. The PriorEmbeddingDataset class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.

To utilize the PriorEmbeddingDataset, all you need to do is make a single call to get_reader(), which will create EmbeddingReader object(s) for you. Afterwards, you can utilize make_splits() to cleanly create DataLoader objects for your training run.

If you are training in a distributed manner, make_splits() accepts rank and world_size arguments to properly distribute to each process. The defaults for these values are rank=0 and world_size=1, so single-process training can safely ignore these parameters.

Usage:

from dalle2_pytorch.dataloaders import get_reader, make_splits

# grab embeddings from some specified location
IMG_URL = "data/img_emb/"
META_URL = "data/meta/"

reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)

# some config for training
TRAIN_ARGS = {
    "world_size": 3,
    "text_conditioned": True,
    "start": 0,
    "num_data_points": 10000,
    "batch_size": 2,
    "train_split": 0.5,
    "eval_split": 0.25,
    "image_reader": reader,
}

# specifying a rank will handle allocation internally
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\dataloaders\simple_image_only_dataloader.py

# Import required libraries
from pathlib import Path
import torch
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image

# Generator that cycles through a dataloader indefinitely
def cycle(dl):
    while True:
        for data in dl:
            yield data

# Dataset class
class Dataset(data.Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # Collect the paths of all files under the folder with the given extensions
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # Define the preprocessing transforms
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(image_size),
            transforms.ToTensor()
        ])

    # Return the length of the dataset
    def __len__(self):
        return len(self.paths)

    # Fetch an item by index
    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# Build a dataloader over a folder of images
def get_images_dataloader(
    folder,
    *,
    batch_size,
    image_size,
    shuffle = True,
    cycle_dl = True,
    pin_memory = True
):
    # Create the dataset
    ds = Dataset(folder, image_size)
    # Create the dataloader
    dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)

    # Optionally wrap the dataloader in an infinite cycle
    if cycle_dl:
        dl = cycle(dl)
    return dl
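
A minimal usage sketch for this file (the folder path and sizes are placeholders). Since cycle_dl defaults to True, the returned object is an infinite generator, so batches are drawn with next() rather than a for loop:

from dalle2_pytorch.dataloaders.simple_image_only_dataloader import get_images_dataloader

dl = get_images_dataloader('/path/to/images', batch_size = 16, image_size = 256)  # hypothetical folder
images = next(dl)  # a batch of image tensors cropped to 256x256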

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\dataloaders\__init__.py

# Import ImageEmbeddingDataset and create_image_embedding_dataloader from dalle2_pytorch.dataloaders.decoder_loader
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
# Import make_splits, get_reader, and PriorEmbeddingDataset from dalle2_pytorch.dataloaders.prior_loader
from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\optimizer.py

# Import the AdamW and Adam optimizers from torch.optim
from torch.optim import AdamW, Adam

# Split parameters into two lists: those that should receive weight decay and those that should not
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # Parameters with fewer than 2 dimensions (biases, norm scales) are excluded from weight decay
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params
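
A tiny worked example of the ndim < 2 heuristic, using a small hypothetical module: the 2-D weight matrix receives weight decay, while the 1-D bias and norm parameters do not:

import torch.nn as nn
from dalle2_pytorch.optimizer import separate_weight_decayable_params

net = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
wd_params, no_wd_params = separate_weight_decayable_params(net.parameters())
print(len(wd_params), len(no_wd_params))  # 1 weight matrix vs. 3 one-dimensional params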

# Build the optimizer
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    **kwargs
):
    # Optionally keep only parameters that require gradients
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # If weight decay is 0, use the plain Adam optimizer
    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    # Optionally group parameters by whether they should receive weight decay
    if group_wd_params:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        # One group with weight decay, one group without
        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    # Use AdamW with the given learning rate, weight decay, betas, and epsilon
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
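
And a short sketch of how the trainers call into this helper (the model here is a hypothetical stand-in):

import torch.nn as nn
from dalle2_pytorch.optimizer import get_optimizer

model = nn.Linear(512, 512)  # placeholder model
opt = get_optimizer(model.parameters(), lr = 1e-4, wd = 1e-2)
print([len(g['params']) for g in opt.param_groups])  # two groups: with and without weight decay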

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\trackers.py

# Import required libraries
import urllib.request
import os
import json
from pathlib import Path
import shutil
from itertools import zip_longest
from typing import Any, Optional, List, Union
from pydantic import BaseModel

import torch
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.version import __version__
from packaging import version

# Constants
DEFAULT_DATA_PATH = './.tracker-data'

# Helper functions
def exists(val):
    return val is not None

# Base logger class
class BaseLogger:
    """
    An abstract class representing an object that can log data.
    Parameters:
        data_path (str): A file path for storing temporary data.
        verbose (bool): Whether or not to always print logs to the console.
    """
    def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
        self.data_path = Path(data_path)
        self.resume = resume
        self.auto_resume = auto_resume
        self.verbose = verbose

    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
        """
        Initializes the logger.
        Errors if the logger is invalid.
        full_config is the config file dict while extra_config is anything else from the script that is not defined in the config file.
        """
        raise NotImplementedError

    def log(self, log, **kwargs) -> None:
        raise NotImplementedError

    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
        raise NotImplementedError

    def log_file(self, file_path, **kwargs) -> None:
        raise NotImplementedError

    def log_error(self, error_string, **kwargs) -> None:
        raise NotImplementedError

    def get_resume_data(self, **kwargs) -> dict:
        """
        Sets tracker attributes that along with { "resume": True } will be used to resume training.
        It is assumed that after init is called this data will be complete.
        If the logger does not have any resume functionality, it should return an empty dict.
        """
        raise NotImplementedError

# Console logger
class ConsoleLogger(BaseLogger):
    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
        print("Logging to console")

    def log(self, log, **kwargs) -> None:
        print(log)

    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
        pass

    def log_file(self, file_path, **kwargs) -> None:
        pass

    def log_error(self, error_string, **kwargs) -> None:
        print(error_string)

    def get_resume_data(self, **kwargs) -> dict:
        return {}

# Wandb logger
class WandbLogger(BaseLogger):
    """
    Logs to a wandb run.
    Parameters:
        data_path (str): A file path for storing temporary data.
        wandb_entity (str): The wandb entity to log to.
        wandb_project (str): The wandb project to log to.
        wandb_run_id (str): The wandb run id to resume.
        wandb_run_name (str): The wandb run name to use.
    """
    def __init__(self,
        data_path: str,
        wandb_entity: str,
        wandb_project: str,
        wandb_run_id: Optional[str] = None,
        wandb_run_name: Optional[str] = None,
        **kwargs
    ):
        super().__init__(data_path, **kwargs)
        self.entity = wandb_entity
        self.project = wandb_project
        self.run_id = wandb_run_id
        self.run_name = wandb_run_name
    # Initialize the logger from the full config, extra config, and any other kwargs
    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
        # wandb_entity must be specified to use the wandb logger
        assert self.entity is not None, "wandb_entity must be specified for wandb logger"
        # wandb_project must be specified to use the wandb logger
        assert self.project is not None, "wandb_project must be specified for wandb logger"
        # Import the wandb module, or print an error if it is not installed
        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
        # Set the WANDB_SILENT environment variable to true
        os.environ["WANDB_SILENT"] = "true"
        # Build the wandb init arguments
        init_object = {
            "entity": self.entity,
            "project": self.project,
            "config": {**full_config.dict(), **extra_config}
        }
        # If a run name was given, pass it along
        if self.run_name is not None:
            init_object['name'] = self.run_name
        # If resuming a run, set the corresponding arguments
        if self.resume:
            assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
            if self.run_name is not None:
                print("You are renaming a run. I hope that is what you intended.")
            init_object['resume'] = 'must'
            init_object['id'] = self.run_id

        # Start the wandb run
        self.wandb.init(**init_object)
        print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")

    # Log a dictionary of values
    def log(self, log, **kwargs) -> None:
        # If verbose, also print the log to the console
        if self.verbose:
            print(log)
        # Log to wandb
        self.wandb.log(log, **kwargs)

    # Log images
    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
        """
        Takes a tensor of images and a list of captions and logs them to wandb.
        """
        # Build the list of wandb Image objects
        wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
        # Log the images to wandb
        self.wandb.log({ image_section: wandb_images }, **kwargs)

    # Log a file
    def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
        # If no base path was given, use the parent of the file path
        if base_path is None:
            base_path = Path(file_path).parent
        # Save the file to wandb
        self.wandb.save(str(file_path), base_path = str(base_path))

    # Log an error
    def log_error(self, error_string, step=None, **kwargs) -> None:
        # If verbose, also print the error to the console
        if self.verbose:
            print(error_string)
        # Log the error to wandb
        self.wandb.log({"error": error_string, **kwargs}, step=step)

    # Return the data needed to resume this run
    def get_resume_data(self, **kwargs) -> dict:
        # In order to resume a run, we need the wandb_entity, wandb_project, and wandb_run_id
        return {
            "entity": self.entity,
            "project": self.project,
            "run_id": self.wandb.run.id
        }

# Map logger type names to their logger classes
logger_type_map = {
    'console': ConsoleLogger,
    'wandb': WandbLogger,
}

# Create a logger by instantiating the class that matches the given logger type
def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
    # 'custom' loggers are not implemented yet
    if logger_type == 'custom':
        raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
    try:
        # Look up the logger class for the given type
        logger_class = logger_type_map[logger_type]
    except KeyError:
        # Unknown logger type
        raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
    # Return an instance of the logger class
    return logger_class(data_path, **kwargs)
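
A minimal sketch of the factory in use (the data path is a placeholder). The console logger needs no extra arguments, while 'wandb' would additionally take the wandb_entity and wandb_project keyword arguments shown above:

logger = create_logger('console', './.tracker-data', verbose = True)
logger.log({'loss': 0.123})  # ConsoleLogger simply prints the dict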

# Abstract base class for an object that can load a model checkpoint
class BaseLoader:
    """
    An abstract class representing an object that can load a model checkpoint.
    Parameters:
        data_path (str): A file path for storing temporary data.
    """
    def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
        self.data_path = Path(data_path)
        self.only_auto_resume = only_auto_resume

    def init(self, logger: BaseLogger, **kwargs) -> None:
        raise NotImplementedError

    def recall(self) -> dict:
        raise NotImplementedError

# A loader that downloads a checkpoint file from a URL and loads it
class UrlLoader(BaseLoader):
    """
    A loader that downloads the file from a url and loads it
    Parameters:
        data_path (str): A file path for storing temporary data.
        url (str): The url to download the file from.
    """
    def __init__(self, data_path: str, url: str, **kwargs):
        super().__init__(data_path, **kwargs)
        self.url = url

    def init(self, logger: BaseLogger, **kwargs) -> None:
        # Make sure the file being downloaded exists
        pass  # TODO: Actually implement that

    def recall(self) -> dict:
        # Download the file
        save_path = self.data_path / 'loaded_checkpoint.pth'
        urllib.request.urlretrieve(self.url, str(save_path))
        # Load the file
        return torch.load(str(save_path), map_location='cpu')

# A loader that loads a checkpoint file from a local path
class LocalLoader(BaseLoader):
    """
    A loader that loads a file from a local path
    Parameters:
        data_path (str): A file path for storing temporary data.
        file_path (str): The path to the file to load.
    """
    def __init__(self, data_path: str, file_path: str, **kwargs):
        super().__init__(data_path, **kwargs)
        self.file_path = Path(file_path)

    def init(self, logger: BaseLogger, **kwargs) -> None:
        # Make sure the file to load exists
        if not self.file_path.exists() and not self.only_auto_resume:
            raise FileNotFoundError(f'Model not found at {self.file_path}')

    def recall(self) -> dict:
        # Load the file
        return torch.load(str(self.file_path), map_location='cpu')

# A loader that loads a model from an existing wandb run
class WandbLoader(BaseLoader):
    """
    A loader that loads a model from an existing wandb run
    """
    def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
        super().__init__(data_path, **kwargs)
        self.run_path = wandb_run_path
        self.file_path = wandb_file_path

    def init(self, logger: BaseLogger, **kwargs) -> None:
        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
        # Make sure the file can be downloaded
        if self.wandb.run is not None and self.run_path is None:
            self.run_path = self.wandb.run.path
            assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
        assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
        assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
        
        os.environ["WANDB_SILENT"] = "true"
        pass  # TODO: Actually implement that

    def recall(self) -> dict:
        file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
        return torch.load(file_reference.name, map_location='cpu')

# Map loader type names to their loader classes
loader_type_map = {
    'url': UrlLoader,
    'local': LocalLoader,
    'wandb': WandbLoader,
}

# Create a loader by instantiating the class that matches the given loader type and data path
def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
    # 'custom' loaders are not implemented yet
    if loader_type == 'custom':
        raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
    # Look up the loader class for the given type
    try:
        loader_class = loader_type_map[loader_type]
    except KeyError:
        # Unknown loader type
        raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
    # Return a loader instantiated with the given data path and arguments
    return loader_class(data_path, **kwargs)
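
A minimal sketch, with placeholder paths, of recalling a checkpoint through a loader:

logger = create_logger('console', './.tracker-data')
loader = create_loader('local', './.tracker-data', file_path = './checkpoints/latest.pth')
loader.init(logger)           # raises FileNotFoundError if the checkpoint is missing
state_dict = loader.recall()  # the checkpoint dict, loaded onto the CPU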

# Base saver class
class BaseSaver:
    def __init__(self,
        data_path: str,
        save_latest_to: Optional[Union[str, bool]] = None,
        save_best_to: Optional[Union[str, bool]] = None,
        save_meta_to: Optional[str] = None,
        save_type: str = 'checkpoint',
        **kwargs
    ):
        # Store the saver configuration
        self.data_path = Path(data_path)
        self.save_latest_to = save_latest_to
        self.saving_latest = save_latest_to is not None and save_latest_to is not False
        self.save_best_to = save_best_to
        self.saving_best = save_best_to is not None and save_best_to is not False
        self.save_meta_to = save_meta_to
        self.saving_meta = save_meta_to is not None
        self.save_type = save_type
        # The save type must be either 'checkpoint' or 'model'
        assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
        # At least one saving option must be specified
        assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'

    # Initialization is left to the subclasses
    def init(self, logger: BaseLogger, **kwargs) -> None:
        raise NotImplementedError

    # Saving a file is left to the subclasses
    def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
        """
        Save a general file under save_meta_to
        """
        raise NotImplementedError

# Local saver, saves files to the local filesystem
class LocalSaver(BaseSaver):
    def __init__(self,
        data_path: str,
        **kwargs
    ):
        super().__init__(data_path, **kwargs)

    # Make sure the directory we are saving to exists
    def init(self, logger: BaseLogger, **kwargs) -> None:
        print(f"Saving {self.save_type} locally")
        # Create the data path if it does not exist
        if not self.data_path.exists():
            self.data_path.mkdir(parents=True)

    # Copy the file to the given save path
    def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
        # Get the file name of the save path
        save_path_file_name = Path(save_path).name
        # Make sure the parent directory exists
        save_path_parent = Path(save_path).parent
        if not save_path_parent.exists():
            save_path_parent.mkdir(parents=True)
        print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
        # Copy the file to the save path
        shutil.copy(local_path, save_path)

# Wandb saver, saves files to a wandb run
class WandbSaver(BaseSaver):
    def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
        super().__init__(data_path, **kwargs)
        self.run_path = wandb_run_path

    # Initialize wandb and make sure the user can upload to this run
    def init(self, logger: BaseLogger, **kwargs) -> None:
        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
        os.environ["WANDB_SILENT"] = "true"
        # Make sure the user can upload to this run
        if self.run_path is not None:
            entity, project, run_id = self.run_path.split("/")
            self.run = self.wandb.init(entity=entity, project=project, id=run_id)
        else:
            assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
            self.run = self.wandb.run
        # TODO: Now actually check that uploading is possible
        print(f"Saving to wandb run {self.run.path}-{self.run.name}")

    # Save the file to the given path, mirroring the same file structure in wandb
    def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
        # Get the file name of the save path
        save_path_file_name = Path(save_path).name
        # Print what is being saved, including the wandb run path and name
        print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
        # Place the save path under the data path
        save_path = Path(self.data_path) / save_path
        # Create the parent directory of the save path if it does not exist
        save_path.parent.mkdir(parents=True, exist_ok=True)
        # Copy the local file to the save path
        shutil.copy(local_path, save_path)
        # Save the file to wandb with the data path as base path, uploading immediately
        self.run.save(str(save_path), base_path = str(self.data_path), policy='now')

# Huggingface saver, uploads files to a Huggingface Hub repo
class HuggingfaceSaver(BaseSaver):
    def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
        super().__init__(data_path, **kwargs)
        self.huggingface_repo = huggingface_repo
        self.token_path = token_path

    def init(self, logger: BaseLogger, **kwargs):
        # Make sure the user can upload to the repo
        self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
        try:
            identity = self.hub.whoami()  # Errors if not logged in
            # If no error was raised, we are already logged in
        except:
            # Not logged in; use the token at token_path to authenticate
            if not os.path.exists(self.token_path):
                raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
            with open(self.token_path, "r") as f:
                token = f.read().strip()
            self.hub.HfApi.set_access_token(token)
            identity = self.hub.whoami()
        print(f"Saving to huggingface repo {self.huggingface_repo}")

    def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
        # Saving to Huggingface is simple: upload the file under the correct name in the repo
        save_path_file_name = Path(save_path).name
        print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
        self.hub.upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=str(save_path),
            repo_id=self.huggingface_repo
        )

# Map saver type names to their saver classes
saver_type_map = {
    'local': LocalSaver,
    'wandb': WandbSaver,
    'huggingface': HuggingfaceSaver
}

# Create a saver by instantiating the class that matches the given saver type and data path
def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
    # 'custom' savers are not implemented yet
    if saver_type == 'custom':
        raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
    # Look up the saver class for the given type
    try:
        saver_class = saver_type_map[saver_type]
    except KeyError:
        raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
    # Return an instance of the saver class
    return saver_class(data_path, **kwargs)
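
A minimal sketch of creating a saver (the target paths are placeholders; save_latest_to and save_best_to are the destinations used by Tracker.save below):

saver = create_saver(
    'local',
    './.tracker-data',
    save_latest_to = './checkpoints/latest.pth',  # copied on every "latest" save
    save_best_to = './checkpoints/best.pth',      # copied whenever is_best is True
    save_type = 'checkpoint'
)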

# Tracker class, ties together a logger, a loader, and any number of savers
class Tracker:
    def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
        # Store the data path
        self.data_path = Path(data_path)
        if not dummy_mode:
            # Not in dummy mode
            if not overwrite_data_path:
                # The data path must not already exist unless overwriting is allowed
                assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
                # Create the data path if it does not exist
                if not self.data_path.exists():
                    self.data_path.mkdir(parents=True)
        # The logger, loader, and savers are attached later via add_logger / add_loader / add_saver
        self.logger: BaseLogger = None
        self.loader: Optional[BaseLoader] = None
        self.savers: List[BaseSaver]= []
        # Dummy mode turns logging, loading, and saving into no-ops
        self.dummy_mode = dummy_mode

    def _load_auto_resume(self) -> bool:
        # Load the auto-resume data.
        # If the file does not exist, return False. If auto_resume is enabled, print a warning so the user knows this is the first run.
        if not self.auto_resume_path.exists():
            if self.logger.auto_resume:
                print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
            return False

        # The auto-resume file exists, but if we are not auto-resuming we should delete it so it is not accidentally loaded next time
        if not self.logger.auto_resume:
            print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
            self.auto_resume_path.unlink()
            return False

        # Otherwise, read the JSON into a dict that will override part of logger.__dict__
        with open(self.auto_resume_path, 'r') as f:
            auto_resume_dict = json.load(f)
        # Check that the logger is the same type as the one saved in the auto-resume file
        if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
            raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
        # Now we are ready to override the logger with the auto-resume data
        self.logger.__dict__["resume"] = True
        print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
        self.logger.__dict__.update(auto_resume_dict)
        return True

    def _save_auto_resume(self):
        # Get the resume dict from the logger, add "logger_type" to it, and write it to the auto_resume file
        auto_resume_dict = self.logger.get_resume_data()
        auto_resume_dict['logger_type'] = self.logger.__class__.__name__
        with open(self.auto_resume_path, 'w') as f:
            json.dump(auto_resume_dict, f)

    def init(self, full_config: BaseModel, extra_config: dict):
        self.auto_resume_path = self.data_path / 'auto_resume.json'
        # Check if this run is being resumed
        self.did_auto_resume = self._load_auto_resume()
        if self.did_auto_resume:
            print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
            print(f"New logger config: {self.logger.__dict__}")

        self.save_metadata = dict(
            version = version.parse(__version__)
        )  # Data that will be saved alongside the checkpoint or model
        self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps']  # These keys would cause errors if we tried to save them as metadata

        assert self.logger is not None, '`logger` must be set before `init` is called'
        if self.dummy_mode:
            # The only thing we might still need is a loader
            if self.loader is not None:
                self.loader.init(self.logger)
            return
        assert len(self.savers) > 0, '`savers` must be set before `init` is called'

        self.logger.init(full_config, extra_config)
        if self.loader is not None:
            self.loader.init(self.logger)
        for saver in self.savers:
            saver.init(self.logger)

        if self.logger.auto_resume:
            # Then we need to save the auto-resume file. It is assumed that the logger is ready to save once logger.init has been called.
            self._save_auto_resume()

    def add_logger(self, logger: BaseLogger):
        self.logger = logger

    def add_loader(self, loader: BaseLoader):
        self.loader = loader

    def add_saver(self, saver: BaseSaver):
        self.savers.append(saver)

    # Log data, unless in dummy mode
    def log(self, *args, **kwargs):
        if self.dummy_mode:
            return
        # Delegate to the logger's log method
        self.logger.log(*args, **kwargs)

    # Log images, unless in dummy mode
    def log_images(self, *args, **kwargs):
        if self.dummy_mode:
            return
        # Delegate to the logger's log_images method
        self.logger.log_images(*args, **kwargs)

    # Log a file, unless in dummy mode
    def log_file(self, *args, **kwargs):
        if self.dummy_mode:
            return
        # Delegate to the logger's log_file method
        self.logger.log_file(*args, **kwargs)

    # Save the config file, unless in dummy mode
    def save_config(self, current_config_path: str, config_name = 'config.json'):
        if self.dummy_mode:
            return
        # Copy the current config file to config_name in the root of data_path
        shutil.copy(current_config_path, self.data_path / config_name)
        # For every saver that saves metadata, also save the current config to its meta path
        for saver in self.savers:
            if saver.saving_meta:
                remote_path = Path(saver.save_meta_to) / config_name
                saver.save_file(current_config_path, str(remote_path))

    # Add metadata that will be saved alongside the model or decoder
    def add_save_metadata(self, state_dict_key: str, metadata: Any):
        """
        Adds a new piece of metadata that will be saved along with the model or decoder.
        """
        # Add the metadata to the save_metadata dict
        self.save_metadata[state_dict_key] = metadata

    # Save a state dict of the given save type to file_path
    def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
        """
        Gets the state dict to be saved and writes it to file_path.
        If save_type is 'checkpoint', we save the entire trainer state dict.
        If save_type is 'model', we save only the model state dict.
        """
        assert save_type in ['checkpoint', 'model']
        if save_type == 'checkpoint':
            # Build a metadata dict without the blacklisted keys so creating the state dict does not error
            metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
            # Save the entire trainer state dict
            trainer.save(file_path, overwrite=True, **kwargs, **metadata)
        elif save_type == 'model':
            if isinstance(trainer, DiffusionPriorTrainer):
                prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
                prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
                # If the model includes CLIP, remove it before saving
                original_clip = prior.clip
                prior.clip = None
                model_state_dict = prior.state_dict()
                prior.clip = original_clip
            elif isinstance(trainer, DecoderTrainer):
                decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
                # If the model includes CLIP, remove it before saving
                original_clip = decoder.clip
                decoder.clip = None
                if trainer.use_ema:
                    trainable_unets = decoder.unets
                    decoder.unets = trainer.unets  # Swap in the EMA unets
                    model_state_dict = decoder.state_dict()
                    decoder.unets = trainable_unets  # Restore the original unets
                else:
                    model_state_dict = decoder.state_dict()
                decoder.clip = original_clip
            else:
                raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
            # Build the state dict containing save_metadata and the model's state_dict
            state_dict = {
                **self.save_metadata,
                'model': model_state_dict
            }
            # Write the state dict to file_path
            torch.save(state_dict, file_path)
        return Path(file_path)

    # Save the trainer state and model to the configured locations
    def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
        # In dummy mode there is nothing to do
        if self.dummy_mode:
            return
        # If this is neither the best nor the latest checkpoint, there is nothing to save
        if not is_best and not is_latest:
            # Nothing to do
            return
        # Save the checkpoint and model to the local data path
        checkpoint_path = self.data_path / 'checkpoint.pth'
        self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
        model_path = self.data_path / 'model.pth'
        self._save_state_dict(trainer, 'model', model_path, **kwargs)
        print("Saved cached models")
        # Hand the cached files to each saver
        for saver in self.savers:
            local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
            # If this saver saves the latest checkpoint and this is the latest, save it
            if saver.saving_latest and is_latest:
                latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
                try:
                    saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
                except Exception as e:
                    self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
                    print(f'Error saving checkpoint: {e}')
            # If this saver saves the best checkpoint and this is the best, save it
            if saver.saving_best and is_best:
                best_checkpoint_path = saver.save_best_to.format(**kwargs)
                try:
                    saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
                except Exception as e:
                    self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
                    print(f'Error saving checkpoint: {e}')

    @property
    # Whether a recall (loading a checkpoint) can be performed
    def can_recall(self):
        return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)

    # Recall (load) a checkpoint through the loader
    def recall(self):
        if self.can_recall:
            return self.loader.recall()
        else:
            raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
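
Putting the pieces together, a minimal sketch of assembling a Tracker. The paths are placeholders and RunConfig is a hypothetical pydantic model standing in for the real training configuration:

from pydantic import BaseModel

class RunConfig(BaseModel):  # hypothetical minimal config, for illustration only
    seed: int = 0

tracker = Tracker('./.tracker-data', overwrite_data_path = True)
tracker.add_logger(create_logger('console', './.tracker-data'))
tracker.add_saver(create_saver('local', './.tracker-data', save_latest_to = './checkpoints/latest.pth'))
tracker.init(full_config = RunConfig(), extra_config = {})

tracker.log({'loss': 0.123}, step = 1)
# inside a training loop, after evaluating the model:
# tracker.save(trainer, is_best = False, is_latest = True)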

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\trainer.py

# Import the necessary libraries
import time
import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from contextlib import nullcontext
from collections.abc import Iterable

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler

# Import project modules
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version

# Third-party libraries
import pytorch_warmup as warmup
from ema_pytorch import EMA
from accelerate import Accelerator, DistributedType
import numpy as np

# Helper functions

# Check whether a value exists (is not None)
def exists(val):
    return val is not None

# Return the value if it exists, otherwise the default
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# Cast a value to a tuple of the given length
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# Pop the given keys from a dict and return them as a new dict
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

# Split a dict into two dicts according to a predicate on the keys
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# Check whether a string starts with the given prefix
def string_begins_with(prefix, str):
    return str.startswith(prefix)

# Split a dict into two dicts according to a key prefix
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# Split a dict by a key prefix and strip the prefix from the matching keys
def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# Split a number into groups of at most `divisor`
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
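
For instance, splitting a batch of 10 into groups of at most 4:

print(num_to_groups(10, 4))  # [4, 4, 2] -- three sub-batches covering a batch of 10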

# Decorators

# Cast function arguments to torch tensors (and move them to the model's device)
def cast_torch_tensor(fn):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', next(model.parameters()).device)
        cast_device = kwargs.pop('_cast_device', True)
        cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)

        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        if cast_deepspeed_precision:
            try:
                accelerator = model.accelerator
                if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
                    cast_type_map = {
                        "fp16": torch.half,
                        "bf16": torch.bfloat16,
                        "no": torch.float
                    }
                    precision_type = cast_type_map[accelerator.mixed_precision]
                    all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
            except AttributeError:
                # Then this model doesn't have an accelerator
                pass

        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        out = fn(model, *args, **kwargs)
        return out
    return inner

# Gradient accumulation helpers

# Split an iterable into chunks of the given size
def split_iterable(it, split_size):
    accum = []
    for ind in range(ceil(len(it) / split_size)):
        start_index = ind * split_size
        accum.append(it[start_index: (start_index + split_size)])
    return accum

# If no split size is given, return the original object unchanged
def split(t, split_size = None):
    if not exists(split_size):
        return t
    # If the input is a torch.Tensor, split it along the batch dimension
    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim=0)

    # If the input is an iterable, split it with split_iterable()
    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    # The input is neither a tensor nor an iterable
    return TypeError

# Return the first element of the array that satisfies the condition, or None
def find_first(cond, arr):
    for el in arr:
        if cond(el):
            return el
    return None

# Split positional and keyword arguments into chunks, yielding each chunk along with its fraction of the batch
def split_args_and_kwargs(*args, split_size = None, **kwargs):
    # Combine all argument values into a single tuple
    all_args = (*args, *kwargs.values())
    len_all_args = len(all_args)
    # Find the first argument that is a torch.Tensor
    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    # There must be at least one tensor argument
    assert exists(first_tensor)

    # The length of the first tensor gives the batch size
    batch_size = len(first_tensor)
    # If no split size is given, default to the full batch size
    split_size = default(split_size, batch_size)
    # Number of chunks after splitting
    num_chunks = ceil(batch_size / split_size)

    # Number of keyword arguments and their keys
    dict_len = len(kwargs)
    dict_keys = kwargs.keys()
    # Index where the keyword argument values start within the combined tuple
    split_kwargs_index = len_all_args - dict_len

    # Split every argument: tensors and iterables are chunked by split_size, everything else is repeated per chunk
    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
    # Size of each chunk
    chunk_sizes = tuple(map(len, split_all_args[0]))

    # For each chunk, rebuild the positional args and kwargs and yield them with the chunk's fraction of the batch
    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
        chunk_size_frac = chunk_size / batch_size
        yield chunk_size_frac, (chunked_args, chunked_kwargs)
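
A small sketch of what the generator yields, assuming a single tensor argument with batch size 10 split into chunks of 4:

x = torch.randn(10, 512)
for frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = 4):
    print(round(frac, 1), chunked_args[0].shape)
# 0.4 torch.Size([4, 512])
# 0.4 torch.Size([4, 512])
# 0.2 torch.Size([2, 512])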

# Diffusion prior trainer

# Decorator that runs sampling in chunks of at most max_batch_size
def prior_sample_in_chunks(fn):
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, **kwargs):
        # If no maximum batch size is given, call the function directly
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        # Otherwise split the arguments, call the function on each chunk, and concatenate the results
        outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
        return torch.cat(outputs, dim = 0)
    return inner

# Diffusion prior trainer class
class DiffusionPriorTrainer(nn.Module):
    def __init__(
        self,
        diffusion_prior,
        accelerator = None,
        use_ema = True,
        lr = 3e-4,
        wd = 1e-2,
        eps = 1e-6,
        max_grad_norm = None,
        group_wd_params = True,
        warmup_steps = None,
        cosine_decay_max_steps = None,
        **kwargs
    ):
        # Set up member variables and training hyperparameters
        super().__init__()
        # Make sure we were handed a DiffusionPrior
        assert isinstance(diffusion_prior, DiffusionPrior)

        # Group kwargs prefixed with 'ema_' and strip the prefix, keeping the remaining kwargs separate
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
        # Group kwargs prefixed with 'accelerator_' and strip the prefix, keeping the remaining kwargs separate
        accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)

        # If no accelerator was passed in, create one from the accelerator kwargs
        if not exists(accelerator):
            accelerator = Accelerator(**accelerator_kwargs)

        # Store some useful member variables

        self.accelerator = accelerator
        self.text_conditioned = diffusion_prior.condition_on_text_encodings

        # Set the device

        self.device = accelerator.device
        diffusion_prior.to(self.device)

        # Store the model

        self.diffusion_prior = diffusion_prior

        # Mixed precision checks

        if (
            exists(self.accelerator)
            and self.accelerator.distributed_type == DistributedType.DEEPSPEED
            and self.diffusion_prior.clip is not None
            ):
            # Make sure clip is using the correct precision, otherwise deepspeed will error
            cast_type_map = {
                "fp16": torch.half,
                "bf16": torch.bfloat16,
                "no": torch.float
            }
            precision_type = cast_type_map[accelerator.mixed_precision]
            assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
            self.diffusion_prior.clip.to(precision_type)

        # Optimizer settings

        self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)

        # Create the optimizer from the settings above
        self.optimizer = get_optimizer(
            self.diffusion_prior.parameters(),
            **self.optim_kwargs,
            **kwargs
        )

        # If cosine_decay_max_steps is given, use CosineAnnealingLR; otherwise a constant LambdaLR
        if exists(cosine_decay_max_steps):
            self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
        else:
            self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

        # If warmup_steps is given, use a LinearWarmup schedule
        self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None

        # Distribute the model via Hugging Face Accelerate
        self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)

        # Exponential moving average setup

        self.use_ema = use_ema

        if self.use_ema:
            self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)

        # Gradient clipping, if desired

        self.max_grad_norm = max_grad_norm

        # Track the number of steps taken internally

        self.register_buffer('step', torch.tensor([0], device = self.device))

    # Utility functions

    def save(self, path, overwrite = True, **kwargs):

        # Only save on the main process
        if self.accelerator.is_main_process:
            print(f"Saving checkpoint at step: {self.step.item()}")
            path = Path(path)
            assert not (path.exists() and not overwrite)
            path.parent.mkdir(parents = True, exist_ok = True)

            # FIXME: LambdaLR cannot be saved due to pickling issues
            save_obj = dict(
                optimizer = self.optimizer.state_dict(),
                scheduler = self.scheduler.state_dict(),
                warmup_scheduler = self.warmup_scheduler,
                model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
                version = version.parse(__version__),
                step = self.step,
                **kwargs
            )

            # If using an exponential moving average, also save the EMA state
            if self.use_ema:
                save_obj = {
                    **save_obj,
                    'ema': self.ema_diffusion_prior.state_dict(),
                    'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model on its own for convenience
                }

            # Save the checkpoint
            torch.save(save_obj, str(path))

    def load(self, path_or_state, overwrite_lr = True, strict = True):
        """
        Load a checkpoint of a diffusion prior trainer.

        Will load the entire trainer, including the optimizer and EMA.

        Params:
            - path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file
            - overwrite_lr (bool): whether or not to overwrite the stored LR with the LR specified in the new trainer
            - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match

        Returns:
            loaded_obj (dict): The loaded checkpoint dictionary
        """

        # all processes need to load checkpoint. no restriction here
        if isinstance(path_or_state, str):
            path = Path(path_or_state)
            assert path.exists()
            loaded_obj = torch.load(str(path), map_location=self.device)

        elif isinstance(path_or_state, dict):
            loaded_obj = path_or_state

        if version.parse(__version__) != loaded_obj['version']:
            print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')

        # unwrap the model when loading from checkpoint
        self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
        self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))

        self.optimizer.load_state_dict(loaded_obj['optimizer'])
        self.scheduler.load_state_dict(loaded_obj['scheduler'])

        # set warmupstep
        if exists(self.warmup_scheduler):
            self.warmup_scheduler.last_step = self.step.item()

        # ensure new lr is used if different from old one
        if overwrite_lr:
            new_lr = self.optim_kwargs["lr"]

            for group in self.optimizer.param_groups:
                group["lr"] = new_lr if group["lr"] > 0.0 else 0.0

        if self.use_ema:
            assert 'ema' in loaded_obj
            self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
            # below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
            self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])

        return loaded_obj

    # model functionality

    def update(self):

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
        
        self.optimizer.step()
        self.optimizer.zero_grad()

        # accelerator will occasionally skip optimizer steps in a "dynamic loss scaling strategy"
        if not self.accelerator.optimizer_step_was_skipped:
            sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
            with sched_context():
                self.scheduler.step()

        if self.use_ema:
            self.ema_diffusion_prior.update()

        self.step += 1

    @torch.no_grad()
    @cast_torch_tensor
    @prior_sample_in_chunks
    def p_sample_loop(self, *args, **kwargs):
        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
        return model.p_sample_loop(*args, **kwargs)

    @torch.no_grad()
    @cast_torch_tensor
    @prior_sample_in_chunks
    def sample(self, *args, **kwargs):
        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
        return model.sample(*args, **kwargs)

    @torch.no_grad()
    def sample_batch_size(self, *args, **kwargs):
        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
        return model.sample_batch_size(*args, **kwargs)

    @torch.no_grad()
    @cast_torch_tensor
    @prior_sample_in_chunks
    def embed_text(self, *args, **kwargs):
        # Unwrap the diffusion prior from the accelerator and call its CLIP's embed_text, returning the result
        return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)

    # The decorator casts incoming arguments to torch tensors
    @cast_torch_tensor
    def forward(
        self,
        *args,
        max_batch_size = None,
        **kwargs
    ):
        # Accumulate the total loss over all chunks
        total_loss = 0.

        # Split the args and kwargs into chunks of at most max_batch_size and iterate over them
        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            # Use the accelerator's autocast for mixed precision
            with self.accelerator.autocast():
                # Run the diffusion prior on this chunk to get the loss
                loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
                # Scale the loss by the chunk's fraction of the batch
                loss = loss * chunk_size_frac

            # Add the chunk loss to the total
            total_loss += loss.item()

            # If training, backpropagate through the accelerator
            if self.training:
                self.accelerator.backward(loss)

        # Return the total loss value
        return total_loss
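
A compact sketch of the training step this trainer enables. The diffusion_prior construction and embedding_batches iterable are placeholders, and the text_embed / image_embed keyword names assume the pre-computed-embedding path of DiffusionPrior.forward with condition_on_text_encodings disabled:

# diffusion_prior: a configured DiffusionPrior (construction omitted here)
trainer = DiffusionPriorTrainer(diffusion_prior, lr = 3e-4, use_ema = True)

for text_embed, image_embed in embedding_batches:  # hypothetical iterable of precomputed CLIP embeddings
    loss = trainer(text_embed = text_embed, image_embed = image_embed, max_batch_size = 4)
    trainer.update()  # clip gradients, step the optimizer and schedulers, update the EMA
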
# Decoder trainer

# Decorator that runs decoder sampling in chunks of at most max_batch_size
def decoder_sample_in_chunks(fn):
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, **kwargs):
        # If no maximum batch size is given, call the function directly
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        # If the decoder is unconditional, split the requested batch size into sub-batches
        if self.decoder.unconditional:
            batch_size = kwargs.get('batch_size')
            batch_sizes = num_to_groups(batch_size, max_batch_size)
            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
        else:
            # If the decoder is conditional, split the input arguments into chunks instead
            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]

        # Concatenate the outputs of all sub-batches or chunks
        return torch.cat(outputs, dim = 0)
    return inner

# Decoder trainer class
class DecoderTrainer(nn.Module):
    def __init__(
        self,
        decoder,
        accelerator = None,
        dataloaders = None,
        use_ema = True,
        lr = 1e-4,
        wd = 1e-2,
        eps = 1e-8,
        warmup_steps = None,
        cosine_decay_max_steps = None,
        max_grad_norm = 0.5,
        amp = False,
        group_wd_params = True,
        **kwargs
    ):
        super().__init__()
        # Make sure decoder is a Decoder instance
        assert isinstance(decoder, Decoder)
        # Group kwargs prefixed with 'ema_' and strip the prefix, returning two dicts
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        # Set the accelerator, defaulting to Accelerator
        self.accelerator = default(accelerator, Accelerator)

        # Number of unets contained in the decoder
        self.num_unets = len(decoder.unets)

        # Whether to use an exponential moving average
        self.use_ema = use_ema
        # Start with an empty ModuleList of EMA unets
        self.ema_unets = nn.ModuleList([])

        # Whether to use mixed precision training
        self.amp = amp

        # Learning rate, weight decay, etc. can be customized per unet

        # Cast lr, wd, eps, warmup_steps, cosine_decay_max_steps to tuples of length num_unets
        lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))

        # Make sure no unet's learning rate exceeds 1e-2
        assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'

        # Lists of optimizers, schedulers, and warmup schedulers (one slot per unet)
        optimizers = []
        schedulers = []
        warmup_schedulers = []

        # Iterate over the decoder's unets together with their per-unet lr, wd, eps, warmup_steps, cosine_decay_max_steps
        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
            # If the unet is an nn.Identity placeholder, append None to each list
            if isinstance(unet, nn.Identity):
                optimizers.append(None)
                schedulers.append(None)
                warmup_schedulers.append(None)
            else:
                # Build an optimizer over the unet's parameters
                optimizer = get_optimizer(
                    unet.parameters(),
                    lr = unet_lr,
                    wd = unet_wd,
                    eps = unet_eps,
                    group_wd_params = group_wd_params,
                    **kwargs
                )

                optimizers.append(optimizer)

                # Build the scheduler and warmup scheduler
                if exists(unet_cosine_decay_max_steps):
                    scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
                else:
                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
                warmup_schedulers.append(warmup_scheduler)

                schedulers.append(scheduler)

            # If using an exponential moving average, wrap the unet in EMA and keep it
            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

        # Gradient clipping, if desired
        self.max_grad_norm = max_grad_norm

        # Register a 'steps' buffer: one zero-initialized counter per unet
        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

        # If using DeepSpeed and the decoder has a clip model
        if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
            # Make sure clip is using the correct precision, otherwise it will error
            cast_type_map = {
                "fp16": torch.half,
                "bf16": torch.bfloat16,
                "no": torch.float
            }
            precision_type = cast_type_map[accelerator.mixed_precision]
            assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
            clip = decoder.clip
            clip.to(precision_type)

        # Prepare the decoder and optimizers with the accelerator
        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))

        self.decoder = decoder

        # Prepare the dataloaders

        train_loader = val_loader = None
        if exists(dataloaders):
            train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])

        self.train_loader = train_loader
        self.val_loader = val_loader

        # Store the optimizers

        for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
            setattr(self, f'optim{opt_ind}', optimizer)

        # Store the schedulers

        for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
            setattr(self, f'sched{sched_ind}', scheduler)

        # Store the warmup schedulers

        self.warmup_schedulers = warmup_schedulers

    # 验证并返回unet的编号
    def validate_and_return_unet_number(self, unet_number = None):
        # 如果只有一个unet,则默认unet_number为1
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        # 断言确保unet_number存在且在1到num_unets之间
        assert exists(unet_number) and 1 <= unet_number <= self.num_unets
        return unet_number
    # 返回指定 UNet 编号已经执行的步数
    def num_steps_taken(self, unet_number = None):
        # 验证并返回 UNet 编号
        unet_number = self.validate_and_return_unet_number(unet_number)
        # 返回指定 UNet 编号已经执行的步数
        return self.steps[unet_number - 1].item()

    # 保存模型状态到指定路径
    def save(self, path, overwrite = True, **kwargs):
        # 转换路径为 Path 对象
        path = Path(path)
        # 断言路径不存在或者可以覆盖
        assert not (path.exists() and not overwrite)
        # 创建父目录
        path.parent.mkdir(parents = True, exist_ok = True)

        # 构建保存对象字典
        save_obj = dict(
            model = self.accelerator.unwrap_model(self.decoder).state_dict(),
            version = __version__,
            steps = self.steps.cpu(),
            **kwargs
        )

        # 遍历 UNet 数量
        for ind in range(0, self.num_unets):
            optimizer_key = f'optim{ind}'
            scheduler_key = f'sched{ind}'

            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)

            optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
            scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None

            # 更新保存对象字典
            save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}

        # 如果使用 EMA,更新保存对象字典
        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        # 保存模型状态到指定路径
        self.accelerator.save(save_obj, str(path))

    # 加载模型状态
    def load_state_dict(self, loaded_obj, only_model = False, strict = True):
        # 检查版本是否匹配
        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')

        # 加载模型状态
        self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
        self.steps.copy_(loaded_obj['steps'])

        # 如果只加载模型状态,直接返回加载的对象
        if only_model:
            return loaded_obj

        # 遍历 UNet 数量,加载优化器和调度器状态
        for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):

            optimizer_key = f'optim{ind}'
            optimizer = getattr(self, optimizer_key)

            scheduler_key = f'sched{ind}'
            scheduler = getattr(self, scheduler_key)

            warmup_scheduler = self.warmup_schedulers[ind]

            if exists(optimizer):
                optimizer.load_state_dict(loaded_obj[optimizer_key])

            if exists(scheduler):
                scheduler.load_state_dict(loaded_obj[scheduler_key])

            if exists(warmup_scheduler):
                warmup_scheduler.last_step = last_step

        # 如果使用 EMA,加载 EMA 模型状态
        if self.use_ema:
            assert 'ema' in loaded_obj
            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)

    # 加载模型状态
    def load(self, path, only_model = False, strict = True):
        # 转换路径为 Path 对象
        path = Path(path)
        # 断言路径存在
        assert path.exists()

        # 加载模型状态
        loaded_obj = torch.load(str(path), map_location = 'cpu')

        # 调用 load_state_dict 方法加载模型状态
        self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)

        return loaded_obj

    # 返回 EMA 模型列表
    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    # 增加步数
    def increment_step(self, unet_number):
        # 断言 UNet 编号在有效范围内
        assert 1 <= unet_number <= self.num_unets

        # 转换 UNet 编号为张量
        unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
        # 增加步数
        self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
    # 更新模型参数
    def update(self, unet_number = None):
        # 验证并返回UNET编号
        unet_number = self.validate_and_return_unet_number(unet_number)
        index = unet_number - 1

        # 获取对应的优化器和调度器
        optimizer = getattr(self, f'optim{index}')
        scheduler = getattr(self, f'sched{index}')

        # 如果存在最大梯度范数,则对解码器参数进行梯度裁剪
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm)  # Automatically unscales gradients

        # 执行优化器的步骤和梯度清零操作
        optimizer.step()
        optimizer.zero_grad()

        # 获取热身调度器,并根据是否存在进行相应操作
        warmup_scheduler = self.warmup_schedulers[index]
        scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext

        # 在上下文中执行调度器的步骤
        with scheduler_context():
            scheduler.step()

        # 如果使用指数移动平均模型,则更新模型
        if self.use_ema:
            ema_unet = self.ema_unets[index]
            ema_unet.update()

        # 增加步数
        self.increment_step(unet_number)

    # 生成样本
    @torch.no_grad()
    @cast_torch_tensor
    @decoder_sample_in_chunks
    def sample(self, *args, **kwargs):
        distributed = self.accelerator.num_processes > 1
        base_decoder = self.accelerator.unwrap_model(self.decoder)

        was_training = base_decoder.training
        base_decoder.eval()

        # 根据是否使用EMA模型进行采样
        if kwargs.pop('use_non_ema', False) or not self.use_ema:
            out = base_decoder.sample(*args, **kwargs, distributed = distributed)
            base_decoder.train(was_training)
            return out

        # 切换为指数移动平均UNET进行采样
        trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
        base_decoder.unets = self.unets                  # swap in exponential moving averaged unets for sampling

        output = base_decoder.sample(*args, **kwargs, distributed = distributed)

        base_decoder.unets = trainable_unets             # restore original training unets

        # 将EMA模型UNET转回原始设备
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        base_decoder.train(was_training)
        return output

    # 嵌入文本
    @torch.no_grad()
    @cast_torch_tensor
    @prior_sample_in_chunks
    def embed_text(self, *args, **kwargs):
        return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)

    # 嵌入图像
    @torch.no_grad()
    @cast_torch_tensor
    @prior_sample_in_chunks
    def embed_image(self, *args, **kwargs):
        return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)

    # 前向传播
    @cast_torch_tensor
    def forward(
        self,
        *args,
        unet_number = None,
        max_batch_size = None,
        return_lowres_cond_image=False,
        **kwargs
    ):
        # 验证并返回UNET编号
        unet_number = self.validate_and_return_unet_number(unet_number)

        total_loss = 0.
        cond_images = []
        # 将参数拆分为指定大小的块,并进行处理
        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            with self.accelerator.autocast():
                # 调用解码器进行前向传播,计算损失
                loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
                # 如果需要返回低分辨率条件图像,则提取出来
                if return_lowres_cond_image:
                    loss, cond_image = loss_obj
                else:
                    loss = loss_obj
                    cond_image = None
                loss = loss * chunk_size_frac
                if cond_image is not None:
                    cond_images.append(cond_image)

            total_loss += loss.item()

            # 如果处于训练状态,则进行反向传播
            if self.training:
                self.accelerator.backward(loss)

        # 如果需要返回低分辨率条件图像,则返回总损失和条件图像的张量
        if return_lowres_cond_image:
            return total_loss, torch.stack(cond_images)
        else:
            return total_loss
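
To tie the trainer's pieces together, here is a minimal usage sketch: the forward pass computes the loss and backpropagates it in chunks, `update` steps the optimizer, scheduler and EMA for one unet, and `sample` swaps in the EMA unets. The dimensions and hyperparameters below are illustrative only.

```python
import torch
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer

# a small unet + decoder conditioned directly on precomputed image embeddings (no clip attached)
unet = Unet(dim = 32, image_embed_dim = 512, cond_dim = 32, channels = 3, dim_mults = (1, 2)).cuda()
decoder = Decoder(unet = unet, image_embed_dim = 512, image_size = 64, timesteps = 100).cuda()

trainer = DecoderTrainer(decoder, lr = 1e-4, wd = 1e-2, use_ema = True, ema_beta = 0.99)

images = torch.randn(4, 3, 64, 64).cuda()
image_embed = torch.randn(4, 512).cuda()

loss = trainer(images, image_embed = image_embed, unet_number = 1, max_batch_size = 2)  # forward + backward, chunked
trainer.update(unet_number = 1)   # optimizer step, scheduler step, EMA update, steps += 1

images_out = trainer.sample(image_embed = image_embed)  # samples with the EMA unets by default
trainer.save('./decoder.pt')      # checkpoint with model, optimizer, scheduler and EMA state
```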

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\train_configs.py

# 导入所需的库
import json
from torchvision import transforms as T
from pydantic import BaseModel, validator, model_validator
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar

# 导入自定义的模块
from x_clip import CLIP as XCLIP
from open_clip import list_pretrained
from coca_pytorch import CoCa

# 导入自定义的模块中的类
from dalle2_pytorch.dalle2_pytorch import (
    CoCaAdapter,
    OpenAIClipAdapter,
    OpenClipAdapter,
    Unet,
    Decoder,
    DiffusionPrior,
    DiffusionPriorNetwork,
    XClipAdapter
)
from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义类型变量
InnerType = TypeVar('InnerType')
ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
SingularOrIterable = Union[InnerType, ListOrTuple]

# 通用的 pydantic 类

# 训练集划分配置类
class TrainSplitConfig(BaseModel):
    train: float = 0.75
    val: float = 0.15
    test: float = 0.1

    # 验证所有参数的和是否为1
    @model_validator(mode = 'after')
    def validate_all(self, m):
        actual_sum = sum([*dict(self).values()])
        if actual_sum != 1.:
            raise ValueError(f'{dict(self).keys()} must sum to 1.0. Found: {actual_sum}')
        return self
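
As a quick illustration of the validator above, the three fractions must sum exactly to 1.0 (a minimal sketch; the numbers are arbitrary):

```python
from dalle2_pytorch.train_configs import TrainSplitConfig

TrainSplitConfig(train = 0.8, val = 0.1, test = 0.1)   # ok, sums to 1.0
TrainSplitConfig(train = 0.8, val = 0.5, test = 0.1)   # raises ValueError: must sum to 1.0
```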

# 日志追踪配置类
class TrackerLogConfig(BaseModel):
    log_type: str = 'console'
    resume: bool = False  # For logs that are saved to unique locations, resume a previous run
    auto_resume: bool = False  # If the process crashes and restarts, resume from the run that crashed
    verbose: bool = False

    class Config:
        # 每个日志类型都有自己的参数,将通过配置传递
        extra = "allow"

    # 创建日志记录器
    def create(self, data_path: str):
        kwargs = self.dict()
        return create_logger(self.log_type, data_path, **kwargs)

# 加载追踪配置类
class TrackerLoadConfig(BaseModel):
    load_from: Optional[str] = None
    only_auto_resume: bool = False  # Only attempt to load if the logger is auto-resuming

    class Config:
        extra = "allow"

    # 创建加载器
    def create(self, data_path: str):
        kwargs = self.dict()
        if self.load_from is None:
            return None
        return create_loader(self.load_from, data_path, **kwargs)

# 保存追踪配置类
class TrackerSaveConfig(BaseModel):
    save_to: str = 'local'
    save_all: bool = False
    save_latest: bool = True
    save_best: bool = True

    class Config:
        extra = "allow"

    # 创建保存器
    def create(self, data_path: str):
        kwargs = self.dict()
        return create_saver(self.save_to, data_path, **kwargs)

# 追踪配置类
class TrackerConfig(BaseModel):
    data_path: str = '.tracker_data'
    overwrite_data_path: bool = False
    log: TrackerLogConfig
    load: Optional[TrackerLoadConfig] = None
    save: Union[List[TrackerSaveConfig], TrackerSaveConfig]

    # 创建追踪器
    def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
        tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
        # 添加日志记录器
        tracker.add_logger(self.log.create(self.data_path))
        # 添加加载器
        if self.load is not None:
            tracker.add_loader(self.load.create(self.data_path))
        # Add the saver(s): a single save config or a list of them

        if isinstance(self.save, list):
            for save_config in self.save:
                tracker.add_saver(save_config.create(self.data_path))
        else:
            tracker.add_saver(self.save.create(self.data_path))
        # 初始化所有组件并验证所有数据是否有效
        tracker.init(full_config, extra_config)
        return tracker
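
Put together, a minimal tracker setup that logs to the console and saves checkpoints locally could look like the following sketch (`create` is normally called later with the full pydantic training config):

```python
from dalle2_pytorch.train_configs import TrackerConfig, TrackerLogConfig, TrackerSaveConfig

tracker_config = TrackerConfig(
    data_path = '.tracker_data',
    log = TrackerLogConfig(log_type = 'console'),
    save = TrackerSaveConfig(save_to = 'local')
)

# later, once the full training config exists:
# tracker = tracker_config.create(full_config, extra_config, dummy_mode = False)
```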

# 扩散先验配置类

# 适配器配置类
class AdapterConfig(BaseModel):
    make: str = "openai"
    model: str = "ViT-L/14"
    base_model_kwargs: Optional[Dict[str, Any]] = None
    # 创建适配器对象的方法
    def create(self):
        # 如果适配器类型是 "openai",则返回 OpenAIClipAdapter 对象
        if self.make == "openai":
            return OpenAIClipAdapter(self.model)
        # 如果适配器类型是 "open_clip",则返回 OpenClipAdapter 对象
        elif self.make == "open_clip":
            # 获取预训练模型列表,并选择对应模型的检查点
            pretrained = dict(list_pretrained())
            checkpoint = pretrained[self.model]
            return OpenClipAdapter(name=self.model, pretrained=checkpoint)
        # 如果适配器类型是 "x-clip",则返回 XClipAdapter 对象
        elif self.make == "x-clip":
            return XClipAdapter(XCLIP(**self.base_model_kwargs))
        # 如果适配器类型是 "coca",则返回 CoCaAdapter 对象
        elif self.make == "coca":
            return CoCaAdapter(CoCa(**self.base_model_kwargs))
        # 如果适配器类型不匹配任何已知类型,则抛出属性错误异常
        else:
            raise AttributeError("No adapter with that name is available.")
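
For instance, the default settings resolve to the OpenAI CLIP adapter (a sketch; calling `create` downloads the pretrained ViT-L/14 weights):

```python
from dalle2_pytorch.train_configs import AdapterConfig

clip_config = AdapterConfig(make = "openai", model = "ViT-L/14")
clip = clip_config.create()   # an OpenAIClipAdapter wrapping pretrained ViT-L/14
```
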
# 定义 DiffusionPriorNetworkConfig 类,包含了模型的各种配置参数
class DiffusionPriorNetworkConfig(BaseModel):
    dim: int  # 模型维度
    depth: int  # 模型深度
    max_text_len: Optional[int] = None  # 最大文本长度
    num_timesteps: Optional[int] = None  # 时间步数
    num_time_embeds: int = 1  # 时间嵌入数量
    num_image_embeds: int = 1  # 图像嵌入数量
    num_text_embeds: int = 1  # 文本嵌入数量
    dim_head: int = 64  # 头部维度
    heads: int = 8  # 头部数量
    ff_mult: int = 4  # FeedForward 层倍数
    norm_in: bool = False  # 输入层是否进行归一化
    norm_out: bool = True  # 输出层是否进行归一化
    attn_dropout: float = 0.  # 注意力机制的 dropout 概率
    ff_dropout: float = 0.  # FeedForward 层的 dropout 概率
    final_proj: bool = True  # 是否进行最终投影
    normformer: bool = False  # 是否使用 Normformer
    rotary_emb: bool = True  # 是否使用旋转嵌入

    class Config:
        extra = "allow"

    # 创建 DiffusionPriorNetwork 对象
    def create(self):
        kwargs = self.dict()
        return DiffusionPriorNetwork(**kwargs)

# 定义 DiffusionPriorConfig 类,包含了模型的配置参数
class DiffusionPriorConfig(BaseModel):
    clip: Optional[AdapterConfig] = None  # 适配器配置
    net: DiffusionPriorNetworkConfig  # DiffusionPriorNetworkConfig 对象
    image_embed_dim: int  # 图像嵌入维度
    image_size: int  # 图像尺寸
    image_channels: int = 3  # 图像通道数
    timesteps: int = 1000  # 时间步数
    sample_timesteps: Optional[int] = None  # 采样时间步数
    cond_drop_prob: float = 0.  # 条件丢弃概率
    loss_type: str = 'l2'  # 损失类型
    predict_x_start: bool = True  # 是否预测 x 起始点
    beta_schedule: str = 'cosine'  # beta 调度
    condition_on_text_encodings: bool = True  # 是否在文本编码上进行条件

    class Config:
        extra = "allow"

    # 创建 DiffusionPrior 对象
    def create(self):
        kwargs = self.dict()

        has_clip = exists(kwargs.pop('clip'))
        kwargs.pop('net')

        clip = None
        if has_clip:
            clip = self.clip.create()

        diffusion_prior_network = self.net.create()
        return DiffusionPrior(net=diffusion_prior_network, clip=clip, **kwargs)
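
A minimal nested prior configuration might look like the sketch below (toy dimensions, no clip attached, embedding-only training; real values normally come from a JSON config):

```python
from dalle2_pytorch.train_configs import DiffusionPriorConfig, DiffusionPriorNetworkConfig

prior_config = DiffusionPriorConfig(
    net = DiffusionPriorNetworkConfig(dim = 512, depth = 6, max_text_len = 77),  # max_text_len set explicitly here
    image_embed_dim = 512,
    image_size = 224,
    timesteps = 100,
    cond_drop_prob = 0.1,
    condition_on_text_encodings = False   # train on embeddings only, so no clip is needed
)

diffusion_prior = prior_config.create()   # builds the DiffusionPriorNetwork, then the DiffusionPrior
```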

# 定义 DiffusionPriorTrainConfig 类,包含了训练配置参数
class DiffusionPriorTrainConfig(BaseModel):
    epochs: int = 1  # 训练轮数
    lr: float = 1.1e-4  # 学习率
    wd: float = 6.02e-2  # 权重衰减
    max_grad_norm: float = 0.5  # 最大梯度范数
    use_ema: bool = True  # 是否使用指数移动平均
    ema_beta: float = 0.99  # 指数移动平均的 beta
    amp: bool = False  # 是否使用混合精度训练
    warmup_steps: Optional[int] = None  # 热身步数
    save_every_seconds: int = 3600  # 保存模型的时间间隔
    eval_timesteps: List[int] = [64]  # 评估时间步数
    best_validation_loss: float = 1e9  # 最佳验证损失
    current_epoch: int = 0  # 当前轮数
    num_samples_seen: int = 0  # 当前样本数
    random_seed: int = 0  # 随机种子

# 定义 DiffusionPriorDataConfig 类,包含了数据配置参数
class DiffusionPriorDataConfig(BaseModel):
    image_url: str  # 嵌入文件夹路径
    meta_url: str  # 图像元数据(标题)路径
    splits: TrainSplitConfig  # 数据集的训练、验证、测试拆分
    batch_size: int  # 每个 GPU 的批量大小
    num_data_points: int = 25e7  # 训练数据点总数
    eval_every_seconds: int = 3600  # 多久进行一次验证统计

# 定义 TrainDiffusionPriorConfig 类,包含了训练配置参数
class TrainDiffusionPriorConfig(BaseModel):
    prior: DiffusionPriorConfig  # DiffusionPriorConfig 对象
    data: DiffusionPriorDataConfig  # DiffusionPriorDataConfig 对象
    train: DiffusionPriorTrainConfig  # DiffusionPriorTrainConfig 对象
    tracker: TrackerConfig  # 跟踪器配置

    # 从 JSON 路径加载配置
    @classmethod
    def from_json_path(cls, json_path):
        with open(json_path) as f:
            config = json.load(f)
        return cls(**config)
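
Loading a full prior training configuration is then a one-liner (a sketch; the path is hypothetical and the JSON must contain the `prior`, `data`, `train` and `tracker` sections):

```python
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig

config = TrainDiffusionPriorConfig.from_json_path('configs/train_prior_config.example.json')  # hypothetical path
diffusion_prior = config.prior.create()
batch_size = config.data.batch_size
```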

# 解码器 Pydantic 类

# 定义 UnetConfig 类,包含了 Unet 模型的配置参数
class UnetConfig(BaseModel):
    dim: int  # 维度
    dim_mults: ListOrTuple[int]  # 维度倍增列表
    image_embed_dim: Optional[int] = None  # 图像嵌入维度
    text_embed_dim: Optional[int] = None  # 文本嵌入维度
    cond_on_text_encodings: Optional[bool] = None  # 是否在文本编码上进行条件
    cond_dim: Optional[int] = None  # 条件维度
    channels: int = 3  # 通道数
    self_attn: SingularOrIterable[bool] = False  # 自注意力机制
    attn_dim_head: int = 32  # 注意力头部维度
    attn_heads: int = 16  # 注意力头部数量
    init_cross_embed: bool = True  # 是否初始化交叉嵌入

    class Config:
        extra = "allow"

# 定义 DecoderConfig 类,包含了解码器的配置参数
class DecoderConfig(BaseModel):
    unets: ListOrTuple[UnetConfig]  # UnetConfig 列表
    image_size: Optional[int] = None  # 图像尺寸
    image_sizes: ListOrTuple[int] = None  # 图像尺寸列表
    clip: Optional[AdapterConfig] = None  # 适配器配置(如果未提供嵌入,则使用 clip 模型)
    channels: int = 3  # 通道数
    timesteps: int = 1000  # 时间步数
    sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None  # 采样时间步数
    loss_type: str = 'l2'  # 损失类型
    beta_schedule: Optional[ListOrTuple[str]] = None  # beta 调度(None 表示所有余弦)
    # 定义学习方差的参数,默认为 True
    learned_variance: SingularOrIterable[bool] = True
    # 定义图像条件下的丢弃概率,默认为 0.1
    image_cond_drop_prob: float = 0.1
    # Text conditioning drop probability, default 0.5
    text_cond_drop_prob: float = 0.5

    def create(self):
        # 从参数中提取解码器的参数
        decoder_kwargs = self.dict()

        # 从解码器参数中提取 UNet 的配置
        unet_configs = decoder_kwargs.pop('unets')
        # 根据 UNet 的配置创建 UNet 对象列表
        unets = [Unet(**config) for config in unet_configs]

        # 检查是否存在剪辑参数
        has_clip = exists(decoder_kwargs.pop('clip'))
        clip = None
        # 如果存在剪辑参数,则创建剪辑对象
        if has_clip:
            clip = self.clip.create()

        # 返回解码器对象,传入 UNet 对象列表和剪辑对象
        return Decoder(unets, clip=clip, **decoder_kwargs)

    # 验证器,用于检查图像大小参数
    @validator('image_sizes')
    def check_image_sizes(cls, image_sizes, values):
        # Exactly one of image_size or image_sizes must be provided: the XOR check passes when only one exists, otherwise raise
        if exists(values.get('image_size')) ^ exists(image_sizes):
            return image_sizes
        raise ValueError('either image_size or image_sizes is required, but not both')

    # 类配置,允许额外参数
    class Config:
        extra = "allow"
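
For example, a two-unet cascading decoder conditioned on precomputed image embeddings could be configured as in the sketch below (toy dimensions; with no clip configured, the embeddings must be supplied by the dataloader):

```python
from dalle2_pytorch.train_configs import DecoderConfig, UnetConfig

decoder_config = DecoderConfig(
    unets = [
        UnetConfig(dim = 32, image_embed_dim = 512, dim_mults = (1, 2, 4)),
        UnetConfig(dim = 32, image_embed_dim = 512, dim_mults = (1, 2, 4))
    ],
    image_sizes = (64, 256),   # one target resolution per unet
    timesteps = 1000
)

decoder = decoder_config.create()   # builds the Unets, then the cascading Decoder
```
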
# 定义一个配置类,用于存储解码器的训练配置信息
class DecoderDataConfig(BaseModel):
    webdataset_base_url: str                     # 存储包含jpg图像的webdataset的路径
    img_embeddings_url: Optional[str] = None     # 存储包含嵌入向量的.npy文件的路径
    text_embeddings_url: Optional[str] = None    # 存储包含嵌入向量的.npy文件的路径
    num_workers: int = 4                         # 工作进程数
    batch_size: int = 64                         # 批量大小
    start_shard: int = 0                         # 起始分片
    end_shard: int = 9999999                     # 结束分片
    shard_width: int = 6                         # 分片宽度
    index_width: int = 4                         # 索引宽度
    splits: TrainSplitConfig                     # 训练数据集拆分配置
    shuffle_train: bool = True                    # 是否对训练数据进行洗牌
    resample_train: bool = False                  # 是否重新采样训练数据
    preprocessing: Dict[str, Any] = {'ToTensor': True}  # 预处理步骤配置

    @property
    def img_preproc(self):
        # 获取图像预处理转换函数
        def _get_transformation(transformation_name, **kwargs):
            if transformation_name == "RandomResizedCrop":
                return T.RandomResizedCrop(**kwargs)
            elif transformation_name == "RandomHorizontalFlip":
                return T.RandomHorizontalFlip()
            elif transformation_name == "ToTensor":
                return T.ToTensor()

        transforms = []
        # 遍历预处理配置,生成转换函数列表
        for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
            transforms.append(_get_transformation(transform_name, **transform_kwargs))
        return T.Compose(transforms)
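
For example, a `preprocessing` section like the one in the sketch below turns into a torchvision pipeline (the webdataset path is hypothetical; keys must match the transformation names handled above):

```python
from dalle2_pytorch.train_configs import DecoderDataConfig, TrainSplitConfig

data_config = DecoderDataConfig(
    webdataset_base_url = '/path/to/webdataset/{}.tar',   # hypothetical path to the image shards
    splits = TrainSplitConfig(),
    preprocessing = {
        'RandomResizedCrop': {'size': 64, 'scale': (0.75, 1.0), 'ratio': (1.0, 1.0)},
        'ToTensor': True
    }
)

transform = data_config.img_preproc   # T.Compose([T.RandomResizedCrop(...), T.ToTensor()])
```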

# 定义一个配置类,用于存储解码器的训练配置信息
class DecoderTrainConfig(BaseModel):
    epochs: int = 20                             # 训练轮数
    lr: SingularOrIterable[float] = 1e-4         # 学习率
    wd: SingularOrIterable[float] = 0.01         # 权重衰减
    warmup_steps: Optional[SingularOrIterable[int]] = None  # 预热步数
    find_unused_parameters: bool = True          # 是否查找未使用的参数
    static_graph: bool = True                    # 是否使用静态图
    max_grad_norm: SingularOrIterable[float] = 0.5  # 最大梯度范数
    save_every_n_samples: int = 100000           # 每隔多少样本保存一次模型
    n_sample_images: int = 6                     # Number of example images generated when sampling the train and test datasets
    cond_scale: Union[float, List[float]] = 1.0  # 条件缩放
    device: str = 'cuda:0'                       # 设备
    epoch_samples: Optional[int] = None          # 每轮样本数限制
    validation_samples: Optional[int] = None     # 验证集样本数限制
    save_immediately: bool = False                # 是否立即保存
    use_ema: bool = True                         # 是否使用指数移动平均
    ema_beta: float = 0.999                      # 指数移动平均的beta值
    amp: bool = False                            # 是否使用混合精度训练
    unet_training_mask: Optional[ListOrTuple[bool]] = None  # UNet训练掩码

# 定义一个配置类,用于存储解码器的评估配置信息
class DecoderEvaluateConfig(BaseModel):
    n_evaluation_samples: int = 1000             # 评估样本数
    FID: Optional[Dict[str, Any]] = None         # FID评估配置
    IS: Optional[Dict[str, Any]] = None          # IS评估配置
    KID: Optional[Dict[str, Any]] = None         # KID评估配置
    LPIPS: Optional[Dict[str, Any]] = None       # LPIPS评估配置

# 定义一个配置类,用于存储训练解码器的完整配置信息
class TrainDecoderConfig(BaseModel):
    decoder: DecoderConfig                      # 解码器配置
    data: DecoderDataConfig                      # 数据配置
    train: DecoderTrainConfig                    # 训练配置
    evaluate: DecoderEvaluateConfig              # 评估配置
    tracker: TrackerConfig                      # 追踪器配置
    seed: int = 0                                # 随机种子

    @classmethod
    def from_json_path(cls, json_path):
        with open(json_path) as f:
            config = json.load(f)                 # 从JSON文件中加载配置
            print(config)
        return cls(**config)

    @model_validator(mode = 'after')             # 模型验证器
    # 检查是否提供了足够的信息来获取指定用于训练的嵌入
    def check_has_embeddings(self, m):
        # 将self转换为字典
        values = dict(self)

        # 获取data和decoder配置
        data_config, decoder_config = values.get('data'), values.get('decoder')

        # 如果data_config或decoder_config不存在
        if not exists(data_config) or not exists(decoder_config):
            # 则发生了其他错误,应该直接返回values
            return values

        # 检查decoder是否使用文本嵌入
        using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
        # 检查是否使用了clip
        using_clip = exists(decoder_config.clip)
        # 获取图片嵌入和文本嵌入的URL
        img_emb_url = data_config.img_embeddings_url
        text_emb_url = data_config.text_embeddings_url

        # 如果使用了文本嵌入
        if using_text_embeddings:
            # 需要一种方法来获取嵌入
            assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'

        # 如果使用了clip
        if using_clip:
            # 如果同时使用了文本嵌入和图片嵌入的URL
            if using_text_embeddings:
                assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
            else:
                assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'

        # 如果存在文本嵌入的URL
        if text_emb_url:
            assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."

        # 返回m
        return m
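
In practice the whole decoder training setup is loaded in one step, and the embedding check above runs as part of validation when the config is constructed (a sketch; the path is hypothetical):

```python
from dalle2_pytorch.train_configs import TrainDecoderConfig

config = TrainDecoderConfig.from_json_path('configs/train_decoder_config.example.json')  # hypothetical path

decoder = config.decoder.create()
img_preproc = config.data.img_preproc
```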

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\utils.py

# 导入时间模块
import time
# 导入 importlib 模块
import importlib

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 时间辅助函数

# 计时器类
class Timer:
    def __init__(self):
        self.reset()

    # 重置计时器
    def reset(self):
        self.last_time = time.time()

    # 返回经过的时间
    def elapsed(self):
        return time.time() - self.last_time
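
A typical use is timing sections of a training loop (a minimal sketch):

```python
from dalle2_pytorch.utils import Timer

timer = Timer()
# ... run one training step ...
print(f'step took {timer.elapsed():.2f}s')
timer.reset()   # start timing the next section
```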

# 打印辅助函数

# 打印带边框的字符串
def print_ribbon(s, symbol='=', repeat=40):
    flank = symbol * repeat
    return f'{flank} {s} {flank}'

# 导入辅助函数

# 尝试导入指定模块,如果失败则打印错误信息并退出程序
def import_or_print_error(pkg_name, err_str=None):
    try:
        return importlib.import_module(pkg_name)
    except ModuleNotFoundError as e:
        if exists(err_str):
            print(err_str)
        exit()

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\version.py

# 定义变量 __version__,赋值为字符串 '1.15.6'
__version__ = '1.15.6'

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\vqgan_vae.py

# 导入必要的库
import copy
import math
from math import sqrt
from functools import partial, wraps

# 导入自定义模块
from vector_quantize_pytorch import VectorQuantize as VQ

# 导入 PyTorch 库
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision

# 导入 einops 库
from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange

# 定义常量
MList = nn.ModuleList

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 返回默认值
def default(val, d):
    return val if exists(val) else d

# 装饰器

# 模型评估装饰器
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# 移除 VGG 模型装饰器
def remove_vgg(fn):
    @wraps(fn)
    def inner(self, *args, **kwargs):
        has_vgg = hasattr(self, 'vgg')
        if has_vgg:
            vgg = self.vgg
            delattr(self, 'vgg')

        out = fn(self, *args, **kwargs)

        if has_vgg:
            self.vgg = vgg

        return out
    return inner

# 关键字参数辅助函数

# 从字典中选择指定键的值并弹出这些键
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

# 根据条件将字典分组
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# 判断字符串是否以指定前缀开头
def string_begins_with(prefix, string_input):
    return string_input.startswith(prefix)

# 根据前缀将字典分组
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# 根据前缀将字典分组并去除前缀
def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
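
This is the same helper the trainers use to route `ema_`-prefixed keyword arguments to the EMA wrapper, for example (a minimal sketch):

```python
from dalle2_pytorch.vqgan_vae import groupby_prefix_and_trim

ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', {'ema_beta': 0.99, 'ema_update_every': 10, 'lr': 1e-4})
# ema_kwargs == {'beta': 0.99, 'update_every': 10}
# kwargs     == {'lr': 1e-4}
```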

# 张量辅助函数

# 对数函数
def log(t, eps = 1e-10):
    return torch.log(t + eps)

# 计算梯度惩罚
def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]
    gradients = torch_grad(outputs = output, inputs = images,
                           grad_outputs = torch.ones(output.size(), device = images.device),
                           create_graph = True, retain_graph = True, only_inputs = True)[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# L2 归一化
def l2norm(t):
    return F.normalize(t, dim = -1)

# Leaky ReLU 激活函数
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(0.1)

# 稳定的 Softmax 函数
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True).detach()
    return (t * alpha).softmax(dim = dim)

# 安全除法
def safe_div(numer, denom, eps = 1e-8):
    return numer / (denom + eps)

# GAN 损失函数

# Hinge 判别器损失
def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

# Hinge 生成器损失
def hinge_gen_loss(fake):
    return -fake.mean()

# 二元交叉熵判别器损失
def bce_discr_loss(fake, real):
    return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()

# 二元交叉熵生成器损失
def bce_gen_loss(fake):
    return -log(torch.sigmoid(fake)).mean()

# 计算损失对层的梯度
def grad_layer_wrt_loss(loss, layer):
    return torch_grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0].detach()

# VQGAN VAE

# 通道层归一化
class LayerNormChan(nn.Module):
    def __init__(
        self,
        dim,
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.gamma

# 判别器

class Discriminator(nn.Module):
    def __init__(
        self,
        dims,
        channels = 3,
        groups = 16,
        init_kernel_size = 5
    # 定义一个继承自 nn.Module 的类,用于构建一个简单的卷积神经网络
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 将输入维度按照前后两两配对,形成一个维度对的列表
        dim_pairs = zip(dims[:-1], dims[1:])

        # 初始化网络的第一层,包括一个卷积层和激活函数
        self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])

        # 遍历维度对列表,构建网络的中间层,每层包括卷积层、归一化层和激活函数
        for dim_in, dim_out in dim_pairs:
            self.layers.append(nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
                nn.GroupNorm(groups, dim_out),
                leaky_relu()
            ))

        # 获取最后一个维度
        dim = dims[-1]
        # 构建输出层,包括两个卷积层和激活函数,用于生成输出结果
        self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
            nn.Conv2d(dim, dim, 1),
            leaky_relu(),
            nn.Conv2d(dim, 1, 4)
        )

    # 定义前向传播方法,将输入数据通过网络层进行处理,得到输出结果
    def forward(self, x):
        # 遍历网络的每一层,将输入数据依次传递给每一层
        for net in self.layers:
            x = net(x)

        # 返回经过所有网络层处理后的输出结果
        return self.to_logits(x)
# positional encoding

class ContinuousPositionBias(nn.Module):
    """ from https://arxiv.org/abs/2111.09883 """

    def __init__(self, *, dim, heads, layers = 2):
        super().__init__()
        self.net = MList([])
        self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))

        self.net.append(nn.Linear(dim, heads))
        # 初始化一个空的相对位置矩阵
        self.register_buffer('rel_pos', None, persistent = False)

    def forward(self, x):
        n, device = x.shape[-1], x.device
        fmap_size = int(sqrt(n))

        if not exists(self.rel_pos):
            # 生成位置信息
            pos = torch.arange(fmap_size, device = device)
            grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
            grid = rearrange(grid, 'c i j -> (i j) c')
            rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
            rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
            # 将生成的位置信息存储在缓冲区中
            self.register_buffer('rel_pos', rel_pos, persistent = False)

        rel_pos = self.rel_pos.float()

        for layer in self.net:
            rel_pos = layer(rel_pos)

        bias = rearrange(rel_pos, 'i j h -> h i j')
        return x + bias
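
Shape-wise, the module takes an attention similarity map over a square feature map and adds a learned relative position bias to it (a sketch with toy sizes):

```python
import torch
from dalle2_pytorch.vqgan_vae import ContinuousPositionBias

cpb = ContinuousPositionBias(dim = 16, heads = 4)
sim = torch.randn(2, 4, 64, 64)   # (batch, heads, i, j) for an 8 x 8 feature map (64 positions)
out = cpb(sim)                    # same shape, with the (heads, i, j) bias broadcast over the batch
```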

# resnet encoder / decoder

class ResnetEncDec(nn.Module):
    def __init__(
        self,
        dim,
        *,
        channels = 3,
        layers = 4,
        layer_mults = None,
        num_resnet_blocks = 1,
        resnet_groups = 16,
        first_conv_kernel_size = 5,
        use_attn = True,
        attn_dim_head = 64,
        attn_heads = 8,
        attn_dropout = 0.,
    ):
        super().__init__()
        assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'

        self.layers = layers

        self.encoders = MList([])
        self.decoders = MList([])

        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'

        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        self.encoded_dim = dims[-1]

        dim_pairs = zip(dims[:-1], dims[1:])

        append = lambda arr, t: arr.append(t)
        prepend = lambda arr, t: arr.insert(0, t)

        if not isinstance(num_resnet_blocks, tuple):
            num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)

        if not isinstance(use_attn, tuple):
            use_attn = (*((False,) * (layers - 1)), use_attn)

        assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
        assert len(use_attn) == layers

        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))

            if layer_use_attn:
                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))

            for _ in range(layer_num_resnet_blocks):
                append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
                prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))

            if layer_use_attn:
                append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))

        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
        append(self.decoders, nn.Conv2d(dim, channels, 1))

    def get_encoded_fmap_size(self, image_size):
        return image_size // (2 ** self.layers)
    # 定义一个属性,返回最后一个解码器的权重
    @property
    def last_dec_layer(self):
        return self.decoders[-1].weight

    # 编码函数,对输入数据进行编码
    def encode(self, x):
        # 遍历所有编码器,对输入数据进行编码
        for enc in self.encoders:
            x = enc(x)
        # 返回编码后的数据
        return x

    # 解码函数,对输入数据进行解码
    def decode(self, x):
        # 遍历所有解码器,对输入数据进行解码
        for dec in self.decoders:
            x = dec(x)
        # 返回解码后的数据
        return x
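
Putting the encoder/decoder pair together, each layer halves the spatial resolution on the way down and doubles it on the way back up (a sketch with toy sizes):

```python
import torch
from dalle2_pytorch.vqgan_vae import ResnetEncDec

enc_dec = ResnetEncDec(dim = 32, layers = 2)
x = torch.randn(1, 3, 64, 64)

fmap = enc_dec.encode(x)      # (1, 64, 16, 16): channels follow layer_mults, resolution shrinks by 2 ** layers
recon = enc_dec.decode(fmap)  # (1, 3, 64, 64)
```
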
# 定义 GLUResBlock 类,继承自 nn.Module
class GLUResBlock(nn.Module):
    # 初始化函数,接受通道数和组数作为参数
    def __init__(self, chan, groups = 16):
        super().__init__()
        # 定义网络结构为一个序列
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan * 2, 3, padding = 1),  # 3x3 卷积层
            nn.GLU(dim = 1),  # GLU 激活函数
            nn.GroupNorm(groups, chan),  # 分组归一化
            nn.Conv2d(chan, chan * 2, 3, padding = 1),  # 3x3 卷积层
            nn.GLU(dim = 1),  # GLU 激活函数
            nn.GroupNorm(groups, chan),  # 分组归一化
            nn.Conv2d(chan, chan, 1)  # 1x1 卷积层
        )

    # 前向传播函数
    def forward(self, x):
        return self.net(x) + x  # 返回网络输出与输入的和

# 定义 ResBlock 类,继承自 nn.Module
class ResBlock(nn.Module):
    # 初始化函数,接受通道数和组数作为参数
    def __init__(self, chan, groups = 16):
        super().__init__()
        # 定义网络结构为一个序列
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding = 1),  # 3x3 卷积层
            nn.GroupNorm(groups, chan),  # 分组归一化
            leaky_relu(),  # leaky_relu 激活函数
            nn.Conv2d(chan, chan, 3, padding = 1),  # 3x3 卷积层
            nn.GroupNorm(groups, chan),  # 分组归一化
            leaky_relu(),  # leaky_relu 激活函数
            nn.Conv2d(chan, chan, 1)  # 1x1 卷积层
        )

    # 前向传播函数
    def forward(self, x):
        return self.net(x) + x  # 返回网络输出与输入的和

# 定义 VQGanAttention 类,继承自 nn.Module
class VQGanAttention(nn.Module):
    # 初始化函数,接受维度、头数、头维度和 dropout 等参数
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        self.dropout = nn.Dropout(dropout)
        self.pre_norm = LayerNormChan(dim)

        self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)

    # 前向传播函数
    def forward(self, x):
        h = self.heads
        height, width, residual = *x.shape[-2:], x.clone()

        x = self.pre_norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))

        sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale

        sim = self.cpb(sim)

        attn = stable_softmax(sim, dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h c j -> b h c i', attn, v)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
        out = self.to_out(out)

        return out + residual

# 定义 RearrangeImage 类,继承自 nn.Module
class RearrangeImage(nn.Module):
    # 前向传播函数
    def forward(self, x):
        n = x.shape[1]
        w = h = int(sqrt(n))
        return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)

# 定义 Attention 类,继承自 nn.Module
class Attention(nn.Module):
    # 初始化函数,接受维度、头数和头维度等参数
    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_head = 32
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    # 前向传播函数
    def forward(self, x):
        h = self.heads

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义 FeedForward 函数,返回一个包含层归一化、线性层、GELU 激活函数和线性层的序列
def FeedForward(dim, mult = 4):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult, bias = False),
        nn.GELU(),
        nn.Linear(dim * mult, dim, bias = False)
    )

# 定义 Transformer 类,继承自 nn.Module
class Transformer(nn.Module):
    # 初始化函数,接受维度、层数、头维度、头数和前馈网络倍数等参数
    def __init__(
        self,
        dim,
        *,
        layers,
        dim_head = 32,
        heads = 8,
        ff_mult = 4
    ):  
        # 调用父类的构造函数
        super().__init__()
        # 初始化一个空的神经网络模块列表
        self.layers = nn.ModuleList([])
        # 循环创建指定数量的层
        for _ in range(layers):
            # 向神经网络模块列表中添加一个包含注意力和前馈神经网络的模块列表
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        # 初始化一个 LayerNorm 层
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # 遍历每一层的注意力和前馈神经网络
        for attn, ff in self.layers:
            # 对输入进行注意力操作并加上原始输入
            x = attn(x) + x
            # 对输入进行前馈神经网络操作并加上原始输入
            x = ff(x) + x

        # 对最终结果进行 LayerNorm 操作
        return self.norm(x)
# 定义 ViTEncDec 类,继承自 nn.Module
class ViTEncDec(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        dim,
        channels = 3,
        layers = 4,
        patch_size = 8,
        dim_head = 32,
        heads = 8,
        ff_mult = 4
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 设置编码后的维度
        self.encoded_dim = dim
        # 设置补丁大小
        self.patch_size = patch_size

        # 计算输入维度
        input_dim = channels * (patch_size ** 2)

        # 定义编码器部分
        self.encoder = nn.Sequential(
            # 重排输入数据形状
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            # 线性层
            nn.Linear(input_dim, dim),
            # Transformer 模块
            Transformer(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                layers = layers
            ),
            # 重排图像数据形状
            RearrangeImage(),
            # 重排输出数据形状
            Rearrange('b h w c -> b c h w')
        )

        # 定义解码器部分
        self.decoder = nn.Sequential(
            # 重排输入数据形状
            Rearrange('b c h w -> b (h w) c'),
            # Transformer 模块
            Transformer(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                layers = layers
            ),
            # 线性层和激活函数
            nn.Sequential(
                nn.Linear(dim, dim * 4, bias = False),
                nn.Tanh(),
                nn.Linear(dim * 4, input_dim, bias = False),
            ),
            # 重排图像数据形状
            RearrangeImage(),
            # 重排输出数据形状
            Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
        )

    # 获取编码后特征图的大小
    def get_encoded_fmap_size(self, image_size):
        return image_size // self.patch_size

    # 返回解码器的最后一层
    @property
    def last_dec_layer(self):
        return self.decoder[-3][-1].weight

    # 编码函数
    def encode(self, x):
        return self.encoder(x)

    # 解码函数
    def decode(self, x):
        return self.decoder(x)

# 定义 NullVQGanVAE 类,继承自 nn.Module
class NullVQGanVAE(nn.Module):
    # Constructor, takes a channels argument
    def __init__(
        self,
        *,
        channels
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 设置编码后的维度为 channels
        self.encoded_dim = channels
        # 设置层数为 0
        self.layers = 0

    # 获取编码后特征图的大小
    def get_encoded_fmap_size(self, size):
        return size

    # 复制模型用于评估
    def copy_for_eval(self):
        return self

    # 编码函数
    def encode(self, x):
        return x

    # 解码函数
    def decode(self, x):
        return x

# 定义 VQGanVAE 类,继承自 nn.Module
class VQGanVAE(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        *,
        dim,
        image_size,
        channels = 3,
        layers = 4,
        l2_recon_loss = False,
        use_hinge_loss = True,
        vgg = None,
        vq_codebook_dim = 256,
        vq_codebook_size = 512,
        vq_decay = 0.8,
        vq_commitment_weight = 1.,
        vq_kmeans_init = True,
        vq_use_cosine_sim = True,
        use_vgg_and_gan = True,
        vae_type = 'resnet',
        discr_layers = 4,
        **kwargs
    # 初始化函数,设置各种参数
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 将参数按照前缀分组,提取出以'vq_'开头的参数
        vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
        # 将参数按照前缀分组,提取出以'encdec_'开头的参数
        encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)

        # 设置图像大小、通道数、VQ 编码簇大小
        self.image_size = image_size
        self.channels = channels
        self.codebook_size = vq_codebook_size

        # 根据 VAE 类型选择编码器解码器类
        if vae_type == 'resnet':
            enc_dec_klass = ResnetEncDec
        elif vae_type == 'vit':
            enc_dec_klass = ViTEncDec
        else:
            raise ValueError(f'{vae_type} not valid')

        # 初始化编码器解码器
        self.enc_dec = enc_dec_klass(
            dim = dim,
            channels = channels,
            layers = layers,
            **encdec_kwargs
        )

        # 初始化 VQ 模块
        self.vq = VQ(
            dim = self.enc_dec.encoded_dim,
            codebook_dim = vq_codebook_dim,
            codebook_size = vq_codebook_size,
            decay = vq_decay,
            commitment_weight = vq_commitment_weight,
            accept_image_fmap = True,
            kmeans_init = vq_kmeans_init,
            use_cosine_sim = vq_use_cosine_sim,
            **vq_kwargs
        )

        # 设置重构损失函数
        self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss

        # 如果是灰度图像,则关闭 GAN 和感知损失
        self.vgg = None
        self.discr = None
        self.use_vgg_and_gan = use_vgg_and_gan

        if not use_vgg_and_gan:
            return

        # 初始化感知损失
        if exists(vgg):
            self.vgg = vgg
        else:
            self.vgg = torchvision.models.vgg16(pretrained = True)
            self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])

        # 初始化 GAN 相关损失
        layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        self.discr = Discriminator(dims = dims, channels = channels)

        self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
        self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss

    # 获取编码后的维度
    @property
    def encoded_dim(self):
        return self.enc_dec.encoded_dim

    # 获取编码后特征图的大小
    def get_encoded_fmap_size(self, image_size):
        return self.enc_dec.get_encoded_fmap_size(image_size)

    # 复制模型用于评估
    def copy_for_eval(self):
        device = next(self.parameters()).device
        vae_copy = copy.deepcopy(self.cpu())

        if vae_copy.use_vgg_and_gan:
            del vae_copy.discr
            del vae_copy.vgg

        vae_copy.eval()
        return vae_copy.to(device)

    # 获取模型状态字典
    @remove_vgg
    def state_dict(self, *args, **kwargs):
        return super().state_dict(*args, **kwargs)

    # 加载模型状态字典
    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        return super().load_state_dict(*args, **kwargs)

    # 获取编码簇
    @property
    def codebook(self):
        return self.vq.codebook

    # 编码
    def encode(self, fmap):
        fmap = self.enc_dec.encode(fmap)
        return fmap

    # 解码
    def decode(self, fmap, return_indices_and_loss = False):
        fmap, indices, commit_loss = self.vq(fmap)

        fmap = self.enc_dec.decode(fmap)

        if not return_indices_and_loss:
            return fmap

        return fmap, indices, commit_loss

    # 前向传播
    def forward(
        self,
        img,
        return_loss = False,
        return_discr_loss = False,
        return_recons = False,
        add_gradient_penalty = True
        ):
            # 解构赋值,获取图像的批次、通道数、高度、宽度、设备信息
            batch, channels, height, width, device = *img.shape, img.device
            # 断言输入图像的高度和宽度与设定的图像大小相等
            assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}'
            # 断言输入图像的通道数与 VQGanVAE 中设定的通道数相等
            assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'

            # 编码输入图像
            fmap = self.encode(img)

            # 解码编码后的特征图,并返回索引和损失
            fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)

            if not return_loss and not return_discr_loss:
                return fmap

            # 断言只能返回自编码器损失或鉴别器损失,不能同时返回
            assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'

            # 是否返回鉴别器损失
            if return_discr_loss:
                # 断言鉴别器存在
                assert exists(self.discr), 'discriminator must exist to train it'

                # 分离编码后的特征图,设置输入图像为需要梯度
                fmap.detach_()
                img.requires_grad_()

                # 获取编码后特征图和输入图像的鉴别器输出
                fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))

                # 计算鉴别器损失
                discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)

                if add_gradient_penalty:
                    # 添加梯度惩罚项
                    gp = gradient_penalty(img, img_discr_logits)
                    loss = discr_loss + gp

                if return_recons:
                    return loss, fmap

                return loss

            # 重构损失
            recon_loss = self.recon_loss_fn(fmap, img)

            # 若不使用 VGG 和 GAN
            if not self.use_vgg_and_gan:
                if return_recons:
                    return recon_loss, fmap

                return recon_loss

            # 感知损失
            img_vgg_input = img
            fmap_vgg_input = fmap

            if img.shape[1] == 1:
                # 处理灰度图像用于 VGG
                img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))

            # 获取输入图像和重构图像的 VGG 特征
            img_vgg_feats = self.vgg(img_vgg_input)
            recon_vgg_feats = self.vgg(fmap_vgg_input)
            perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)

            # 生成器损失
            gen_loss = self.gen_loss(self.discr(fmap))

            # 计算自适应权重
            last_dec_layer = self.enc_dec.last_dec_layer

            norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
            norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

            adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
            adaptive_weight.clamp_(max = 1e4)

            # 组合损失
            loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss

            if return_recons:
                return loss, fmap

            return loss
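
A minimal round trip through the VAE (a sketch; `use_vgg_and_gan = False` skips the VGG download and the discriminator, so only the reconstruction loss path is exercised):

```python
import torch
from dalle2_pytorch.vqgan_vae import VQGanVAE

vae = VQGanVAE(dim = 32, image_size = 64, layers = 3, use_vgg_and_gan = False)
img = torch.rand(1, 3, 64, 64)

recon = vae(img)                       # encode -> vector quantize -> decode
loss = vae(img, return_loss = True)    # reconstruction loss only, since vgg / gan are disabled
loss.backward()
```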

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\vqgan_vae_trainer.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 从 copy 模块中导入 copy 函数
import copy
# 从 random 模块中导入 choice 函数
from random import choice
# 从 pathlib 模块中导入 Path 类
from pathlib import Path
# 从 shutil 模块中导入 rmtree 函数
from shutil import rmtree
# 从 PIL 模块中导入 Image 类
from PIL import Image

# 导入 torch 库
import torch
# 从 torch 模块中导入 nn 模块
from torch import nn
# 从 torch.cuda.amp 模块中导入 autocast, GradScaler 函数
from torch.cuda.amp import autocast, GradScaler
# 从 torch.utils.data 模块中导入 Dataset, DataLoader, random_split 类
from torch.utils.data import Dataset, DataLoader, random_split

# 导入 torchvision.transforms 模块,并重命名为 T
import torchvision.transforms as T
# 从 torchvision.datasets 模块中导入 ImageFolder 类
from torchvision.datasets import ImageFolder
# 从 torchvision.utils 模块中导入 make_grid, save_image 函数
from torchvision.utils import make_grid, save_image

# 导入 einops 模块中的 rearrange 函数
from einops import rearrange

# 导入 dalle2_pytorch.vqgan_vae 模块中的 VQGanVAE 类
from dalle2_pytorch.vqgan_vae import VQGanVAE
# 导入 dalle2_pytorch.optimizer 模块中的 get_optimizer 函数
from dalle2_pytorch.optimizer import get_optimizer

# 导入 ema_pytorch 模块中的 EMA 类
from ema_pytorch import EMA

# helpers

# 定义函数 exists,判断值是否不为 None
def exists(val):
    return val is not None

# 定义函数 noop,空函数,不执行任何操作
def noop(*args, **kwargs):
    pass

# 定义函数 cycle,生成一个无限循环的数据生成器
def cycle(dl):
    while True:
        for data in dl:
            yield data

# 定义函数 cast_tuple,将输入转换为元组类型
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# 定义函数 yes_or_no,询问用户问题并返回 True 或 False
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# 定义函数 accum_log,累积日志信息
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# classes

# 定义类 ImageDataset,继承自 Dataset 类
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # 获取指定文件夹下指定扩展名的所有文件路径
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        # 定义数据转换操作
        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    # 返回数据集的长度
    def __len__(self):
        return len(self.paths)

    # 获取指定索引处的数据
    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# main trainer class

# 定义类 VQGanVAETrainer,继承自 nn.Module 类
class VQGanVAETrainer(nn.Module):
    def __init__(
        self,
        vae,
        *,
        num_train_steps,
        lr,
        batch_size,
        folder,
        grad_accum_every,
        wd = 0.,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        ema_beta = 0.995,
        ema_update_after_step = 500,
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
        amp = False
        ):
        # 调用父类的构造函数
        super().__init__()
        # 断言 vae 是 VQGanVAE 的实例
        assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
        # 获取 VAE 的图像大小
        image_size = vae.image_size

        # 设置 VAE 和 EMA_VAE
        self.vae = vae
        self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)

        # 注册步数缓冲区
        self.register_buffer('steps', torch.Tensor([0]))

        # 设置训练步数、批量大小、梯度累积频率
        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # 获取所有参数、判别器参数、VAE 参数
        all_parameters = set(vae.parameters())
        discr_parameters = set(vae.discr.parameters())
        vae_parameters = all_parameters - discr_parameters

        # 获取优化器
        self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
        self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)

        # 设置混合精度训练
        self.amp = amp
        self.scaler = GradScaler(enabled = amp)
        self.discr_scaler = GradScaler(enabled = amp)

        # 创建数据集
        self.ds = ImageDataset(folder, image_size = image_size)

        # 划分验证集
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # 创建数据加载器
        self.dl = cycle(DataLoader(
            self.ds,
            batch_size = batch_size,
            shuffle = True
        ))

        self.valid_dl = cycle(DataLoader(
            self.valid_ds,
            batch_size = batch_size,
            shuffle = True
        ))

        # 设置保存模型和结果的频率
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        # 设置应用梯度惩罚的频率
        self.apply_grad_penalty_every = apply_grad_penalty_every

        # set up the results folder
        self.results_folder = Path(results_folder)

        # if the results folder already contains files and the user confirms, clear previous checkpoints and results
        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        # create the results folder
        self.results_folder.mkdir(parents = True, exist_ok = True)
    # define a single training step
    def train_step(self):
        # get the device the model parameters live on
        device = next(self.vae.parameters()).device
        # get the current step count
        steps = int(self.steps.item())
        # whether to apply the gradient penalty on this step
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

        # put the VAE in training mode
        self.vae.train()

        # initialize the logs dict
        logs = {}

        # update the VAE (generator)

        # accumulate gradients over several batches
        for _ in range(self.grad_accum_every):
            # fetch the next data batch
            img = next(self.dl)
            img = img.to(device)

            # enable automatic mixed precision
            with autocast(enabled = self.amp):
                # compute the loss
                loss = self.vae(
                    img,
                    return_loss = True,
                    apply_grad_penalty = apply_grad_penalty
                )

                # scale the loss and backpropagate
                self.scaler.scale(loss / self.grad_accum_every).backward()

            # accumulate the loss into the logs
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # step the optimizer
        self.scaler.step(self.optim)
        self.scaler.update()
        self.optim.zero_grad()

        # update the discriminator

        if exists(self.vae.discr):
            discr_loss = 0
            for _ in range(self.grad_accum_every):
                img = next(self.dl)
                img = img.to(device)

                with autocast(enabled = self.amp):
                    loss = self.vae(img, return_discr_loss = True)

                    self.discr_scaler.scale(loss / self.grad_accum_every).backward()

                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})

            self.discr_scaler.step(self.discr_optim)
            self.discr_scaler.update()
            self.discr_optim.zero_grad()

            # print the logs
            print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")

        # update the exponential moving averaged generator
        self.ema_vae.update()

        # periodically sample and save reconstructions
        if not (steps % self.save_results_every):
            for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
                model.eval()

                imgs = next(self.dl)
                imgs = imgs.to(device)

                recons = model(imgs)
                nrows = int(sqrt(self.batch_size))

                imgs_and_recons = torch.stack((imgs, recons), dim = 0)
                imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')

                imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
                grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))

                logs['reconstructions'] = grid

                save_image(grid, str(self.results_folder / f'{filename}.png'))

            print(f'{steps}: saving to {str(self.results_folder)}')

        # periodically save the model
        if not (steps % self.save_model_every):
            state_dict = self.vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.pt')
            torch.save(state_dict, model_path)

            ema_state_dict = self.ema_vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
            torch.save(ema_state_dict, model_path)

            print(f'{steps}: saving model to {str(self.results_folder)}')

        # increment the step count and return the logs
        self.steps += 1
        return logs

    # training loop
    def train(self, log_fn = noop):
        # get the device the model parameters live on
        device = next(self.vae.parameters()).device

        # run training steps until the total number of training steps is reached
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        # training complete
        print('training complete')
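
A minimal usage sketch for the trainer defined above (the image folder and hyperparameters are placeholders; VQGanVAETrainer is the class defined in this file and is not re-exported from the package __init__ shown below):

from dalle2_pytorch import VQGanVAE

# a small VQGanVAE; the settings here are placeholders
vae = VQGanVAE(
    dim = 32,
    image_size = 256,
    layers = 3
)

# VQGanVAETrainer is the class defined above
trainer = VQGanVAETrainer(
    vae,
    folder = '/path/to/images',   # placeholder folder of training images
    num_train_steps = 100000,
    lr = 3e-4,
    batch_size = 4,
    grad_accum_every = 8
)

trainer.train()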

.\lucidrains\DALLE2-pytorch\dalle2_pytorch\__init__.py

# import the version number from the dalle2_pytorch version module
from dalle2_pytorch.version import __version__
# import the DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, and Decoder classes from the dalle2_pytorch module
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
# import the OpenAIClipAdapter and OpenClipAdapter classes from the dalle2_pytorch module
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
# import the DecoderTrainer and DiffusionPriorTrainer classes from the trainer module
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer

# import the VQGanVAE class from the vqgan_vae module
from dalle2_pytorch.vqgan_vae import VQGanVAE
# import the CLIP class from the x_clip package
from x_clip import CLIP

Diffusion Prior

This readme serves as an introduction to the diffusion prior.

Intro

A properly trained prior will allow you to translate between two embedding spaces. If you know a priori that two embeddings are connected in some way, then the ability to translate between them could be extremely helpful.

Motivation

Before we dive into the model, let’s look at a quick example of where the model may be helpful.

For demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.

CLIP is a contrastive model that learns to maximize the cosine similarity between a given image and caption. However, there is no guarantee that these embeddings are in the same space: while the generated embeddings are close, the image and text embeddings occupy two disjoint sets.

# Load Models
clip_model = clip.load("ViT-L/14")
decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings

# Retrieve prompt from user and encode with CLIP
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = clip_model.encode_text(tokenized_text)

# Now, pass the text embedding to the decoder
predicted_image = decoder.sample(text_embedding)

Question: Can you spot the issue here?

Answer: We’re trying to generate an image from a text embedding!

Unfortunately, we run into the issue previously mentioned: the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.

# Load Models
prior = Prior(checkpoint="prior.pth")       # A prior trained to go from: text -> clip text emb -> clip img emb
decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings

# Retrieve prompt from user and encode with a prior
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!

# Now, pass the predicted image embedding to the decoder
predicted_image = decoder.sample(text_embedding)

With the prior we are able to successfully generate embeddings within CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.

You may be asking yourself the following question:

"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"

OpenAI covers this topic in their DALLE-2 paper. The TL;DR is "it doesn't work as well as decoders trained on image embeddings"...also...it's just an example 😄

Usage

Utilizing a pre-trained prior is quite simple.

Loading Checkpoints

import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path):

    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )

    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,

    )

    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer

Here we instantiate a model that matches the configuration it was trained with, and then load the weights (just like any other PyTorch model!).

Sampling

Once we have a pre-trained model, generating embeddings is quite simple!

# tokenize the text
tokenized_text = clip.tokenize("<your amazing prompt>")
# predict an embedding
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)

The resulting tensor returned from .sample() is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on ViT-L/14 embeddings will predict an embedding of shape (1, 768).

For CLIP priors, this is quite handy as it means that you can use prior.sample(tokenized_text) as a drop-in replacement for clip.encode_text().

Some things to note:

  • It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is n=2). Put simply, the idea here is that you avoid getting unlucky with a bad embedding generation by creating two and selecting the one with the higher cosine similarity to the prompt (a minimal sketch of this reranking follows the list).
  • You may specify a higher conditioning scale than the default (1.0). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than 1.0 but ymmv.
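
A minimal, manual sketch of that reranking idea (the prior's sample already does this internally via n_samples_per_batch; prior, clip_model, and tokenized_text are reused from the earlier examples):

import torch
import torch.nn.functional as F

# draw a couple of candidate image embeddings for the same prompt
candidates = torch.cat([prior.sample(tokenized_text) for _ in range(2)])   # (2, 768)

# score each candidate against the prompt and keep the closer one
text_embed = clip_model.encode_text(tokenized_text)                        # (1, 768)
sims = F.cosine_similarity(candidates, text_embed, dim = -1)               # (2,)
best_embedding = candidates[sims.argmax()][None]                           # (1, 768)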

Training

Overview

Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step that is required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended, as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.

Dataset

To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage img2dataset to pull images from a list of URLs and clip_retrieval for generating the actual embeddings that can be used in the prior's dataloader.

Configuration

The configuration file allows for you to easily track and reproduce experiments. It is a simple JSON file that will specify the architecture, dataset, and training parameters. For more information and specifics please see the configuration README.

Distributed Training

If you would like to train in a distributed manner, we have opted to leverage huggingface's Accelerate library. HFA makes it extremely simple to distribute work across multiple GPUs and nodes. All that is required of you is to follow the simple CLI configuration tool; more information here.

Evaluation

There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below; a short sketch of how the cosine-similarity metrics can be computed follows the table.

Metric | Description | Comments
Online Model Validation | The validation loss associated with your online model. | Ideally the validation loss will be as low as possible. Using L2 loss, values as low as 0.1 and lower are possible after around 1 billion samples seen.
EMA Validation | The validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform it in the long term.
Baseline Similarity | The similarity between your dataset's prompts and the associated image embeddings. This serves as a guide for your prior's performance in cosine similarity. | Generally 0.3 is considered a good cosine similarity for caption similarity.
Similarity With Original Image | The cosine similarity between your prior's predicted image embedding and the actual image the caption was associated with. This is useful for determining whether your prior is generating embeddings with the right contents. | Values around 0.75+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above 0.5/0.6 similarity, then you are likely suffering from some kind of training error or inefficiency (i.e. not using EMA).
Difference From Baseline Similarity | Sometimes it is useful to visualize a metric in another light. This shows how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around 0.0, with some room for variation. After a billion samples seen, values stay within +/- 0.01 of 0.0. If it climbs too high (~>0.02), this may be a sign that your model is overfitting.
Similarity With Text | Your bread-and-butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be one of your main focuses and it is probably the second most important metric behind your loss. | As mentioned, this value should be close to the baseline similarity. We have observed an early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity, it could be an indication that your model is overfitting.
Similarity With Unrelated Caption | Attempts to expose an overfit prior by feeding it arbitrary prompts (from your dataset) and then measuring the similarity of the predicted embedding with some other image. Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two images was high (when in fact the caption and image were completely unrelated). | With this in mind, a low value is ideal; anything below 0.1 is probably safe.
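
A rough sketch of how the similarity metrics above can be computed (this is not the repository's exact evaluation code; mock tensors stand in for a validation batch of CLIP embeddings and prior predictions):

import torch
import torch.nn.functional as F

text_embeds       = torch.randn(64, 768)   # ground-truth CLIP text embeddings
image_embeds      = torch.randn(64, 768)   # ground-truth CLIP image embeddings
pred_image_embeds = torch.randn(64, 768)   # image embeddings predicted by the prior

baseline_sim       = F.cosine_similarity(text_embeds, image_embeds).mean()        # Baseline Similarity
sim_with_text      = F.cosine_similarity(pred_image_embeds, text_embeds).mean()   # Similarity With Text
sim_with_image     = F.cosine_similarity(pred_image_embeds, image_embeds).mean()  # Similarity With Original Image
diff_from_baseline = sim_with_text - baseline_sim                                 # Difference From Baseline Similarity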

Launching the script

Now that you’ve done all the prep it’s time for the easy part! 🚀

To actually launch the script, you will either use

$ accelerate launch train_diffusion_prior.py --config_path <path to your config>

to launch with distributed training & huggingface accelerate, or

$ python train_diffusion_prior.py

if you would like to train on your gpu/cpu without huggingface accelerate.

Checkpointing

Checkpoints will be saved to the directory specified in your configuration file.

Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled “latest.pth”. This is to avoid problems where your save_every configuration does not overlap with the number of steps required to do a complete pass through the data.
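
For example, the load_diffusion_model helper from the Loading Checkpoints section above can be pointed at this file once training finishes (the checkpoint directory below is a placeholder):

# reusing the load_diffusion_model helper defined earlier; the directory is a placeholder
trainer = load_diffusion_model("<your checkpoint directory>/latest.pth")

# then sample exactly as shown in the Sampling section
tokenized_text = clip.tokenize("<your amazing prompt>")
predicted_embedding = trainer.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)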

Things To Keep In Mind

The prior has not been trained for tasks other than the traditional CLIP embedding translation…at least yet.

As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.

With that in mind, you are more or less a pioneer in embedding-translation if you are reading this and attempting something you don’t see documentation for!

DALL-E 2 - Pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch.

Yannic Kilcher summary | AssemblyAI explainer

The main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (but which incidentally involves a causal transformer as the denoising network 😂)

This model is SOTA for text-to-image for now.

Please join us on Discord if you are interested in helping out with the replication with the LAION community | Yannic Interview

As of 5/23/22, it is no longer SOTA. SOTA will be here. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.

Status

  • A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and Katherine's own experiments, validate OpenAI's finding that the extra prior increases variety of generations.

  • Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.

ongoing at 21k steps

  • Justin Pinkney successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application

  • Romain has scaled up training to 800 GPUs with the available scripts without any issues

Pre-Trained Models

Appreciation

This library would not have gotten to this working state without the help of

  • Zion for the distributed training code for the diffusion prior
  • Aidan for the distributed training code for the decoder as well as the dataloaders
  • Kumar for working on the initial diffusion training script
  • Romain for the pull request reviews and project management
  • He Cao and xiankgx for the Q&A and for identifying critical bugs
  • Marunine for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
  • MalumaDev for proposing the use of the pixel shuffle upsampler for fixing checkerboard artifacts
  • Katherine for her advice
  • Stability AI for the generous sponsorship
  • 🤗 Huggingface and in particular Sylvain for the Accelerate library
  • Alex for einops, an indispensable tool for tensor manipulation

... and many others. Thank you! 🙏

Install

$ pip install dalle2-pytorch

Usage

Training DALLE-2 is a 3-step process, with the training of CLIP being the most important

To train CLIP, you can either use x-clip package, or join the LAION discord, where a lot of replication efforts are already underway.

This repository will demonstrate integration with x-clip for starters

import torch
from dalle2_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = True,            # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                  # whether to do self supervised learning on images
    visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,            # weight for text MLM loss
    image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train

loss = clip(
    text,
    images,
    return_loss = True              # needs to be set to True to return contrastive loss
)

loss.backward()

# do the above with as many texts and images as possible in a loop

Then, you will need to train the decoder, which learns to generate images based on the image embedding coming from the trained CLIP above

import torch
from dalle2_pytorch import Unet, Decoder, CLIP

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# unet for the decoder

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

# decoder, which contains the unet and clip

decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)

images = torch.randn(4, 3, 256, 256).cuda()

# feed images into decoder

loss = decoder(images)
loss.backward()

# do the above for many many many many steps
# then it will learn to generate images based on the CLIP image embeddings

Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step

import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP

# get trained CLIP from step one

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
).cuda()

# setup prior network, which contains an autoregressive transformer

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

# diffusion prior network, which contains the CLIP and network (with transformer) above

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed text and images into diffusion prior network

loss = diffusion_prior(text, images)
loss.backward()

# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings

In the paper, they actually used a recently discovered technique, from Jonathan Ho himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.

This can easily be used within this framework as so

import torch
from dalle2_pytorch import Unet, Decoder, CLIP

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# 2 unets for the decoder (a la cascading DDPM)

unet1 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

# decoder, which contains the unet(s) and clip

decoder = Decoder(
    clip = clip,
    unet = (unet1, unet2),            # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
    image_sizes = (256, 512),         # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
    timesteps = 1000,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)

images = torch.randn(4, 3, 512, 512).cuda()

# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme

loss = decoder(images, unet_number = 1)
loss.backward()

loss = decoder(images, unet_number = 2)
loss.backward()

# do the above for many steps for both unets

Finally, to generate DALL-E2 images from text, insert the trained DiffusionPrior as well as the Decoder (which wraps CLIP, the causal transformer, and the unet(s))

from dalle2_pytorch import DALLE2

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

# send the text as a string if you want to use the simple tokenizer from DALLE v1
# or you can do it as token ids, if you have your own tokenizer

texts = ['glistening morning dew on a flower petal']
images = dalle2(texts) # (1, 3, 256, 256)

That's it!

Let's see the whole script below

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train

loss = clip(
    text,
    images,
    return_loss = True
)

loss.backward()

# do above for many steps ...

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 1000,
    sample_timesteps = 64,
    cond_drop_prob = 0.2
).cuda()

loss = diffusion_prior(text, images)
loss.backward()

# do above for many steps ...

# decoder (with unet)

unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings = True    # set to True for any unets that need to be conditioned on text encodings
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

for unet_number in (1, 2):
    loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)

# save your image (in this example, of size 256x256)

Everything in this readme should run without error

You can also train the decoder on images of greater than the size (say 512x512) at which CLIP was trained (256x256). The images will be resized to CLIP image resolution for the image embeddings

For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.

Training on Preprocessed CLIP Embeddings

It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in image_embed, text_embed, and optionally text_encodings

Working example below

import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP

# get trained CLIP from step one

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
).cuda()

# setup prior network, which contains an autoregressive transformer

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

# diffusion prior network, which contains the CLIP and network (with transformer) above

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = False  # this probably should be true, but just to get Laion started
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone

clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed

# feed text and images into diffusion prior network

loss = diffusion_prior(
    text_embed = clip_text_embeds,
    image_embed = clip_image_embeds
)

loss.backward()

# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings

You can also completely go CLIP-less, in which case you will need to pass in the image_embed_dim into the DiffusionPrior on initialization

import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior

# setup prior network, which contains an autoregressive transformer

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

# diffusion prior network, which contains the CLIP and network (with transformer) above

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,               # this needs to be set
    timesteps = 100,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = False  # this probably should be true, but just to get Laion started
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone

clip_image_embeds = torch.randn(4, 512).cuda()
clip_text_embeds = torch.randn(4, 512).cuda()

# feed text and images into diffusion prior network

loss = diffusion_prior(
    text_embed = clip_text_embeds,
    image_embed = clip_image_embeds
)

loss.backward()

# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings

OpenAI CLIP

Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.

To use a pretrained OpenAI CLIP, simply import OpenAIClipAdapter and pass it into the DiffusionPrior or Decoder like so

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter

# openai pretrained clip - defaults to ViT-B/32

clip = OpenAIClipAdapter()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

loss = diffusion_prior(text, images)
loss.backward()

# do above for many steps ...

# decoder (with unet)

unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8),
    text_embed_dim = 512,
    cond_on_text_encodings = True  # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 1000,
    sample_timesteps = (250, 27),
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

for unet_number in (1, 2):
    loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['a butterfly trying to escape a tornado'],
    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)

# save your image (in this example, of size 256x256)

Alternatively, you can also use Open Clip

$ pip install open-clip-torch

Ex. using the SOTA Open Clip model trained by Romain

from dalle2_pytorch import OpenClipAdapter

clip = OpenClipAdapter('ViT-H/14')

Now you'll just have to worry about training the Prior and the Decoder!

Inpainting

Inpainting is also built into the Decoder. You simply have to pass in the inpaint_image and inpaint_mask (boolean tensor where True indicates which regions of the inpaint image to keep)

This repository uses the formulation put forth by Lugmayr et al. in Repaint

import torch
from dalle2_pytorch import Unet, Decoder, CLIP

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# a single unet for the decoder

unet = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 1, 1, 1)
).cuda()


# decoder, which contains the unet(s) and clip

decoder = Decoder(
    clip = clip,
    unet = (unet,),               # insert the unet(s) in order of low resolution to highest resolution (you can have as many stages as you want here)
    image_sizes = (256,),         # resolutions for each unet; these must be unique and in ascending order (matches the unets passed in)
    timesteps = 1000,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)

images = torch.randn(4, 3, 256, 256).cuda()

# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme

loss = decoder(images, unet_number = 1)
loss.backward()

# do the above for many steps

mock_image_embed = torch.randn(1, 512).cuda()

# then to do inpainting

inpaint_image = torch.randn(1, 3, 256, 256).cuda()      # (batch, channels, height, width)
inpaint_mask = torch.ones(1, 256, 256).bool().cuda()    # (batch, height, width)

inpainted_images = decoder.sample(
    image_embed = mock_image_embed,
    inpaint_image = inpaint_image,    # just pass in the inpaint image
    inpaint_mask = inpaint_mask       # and the mask
)

inpainted_images.shape # (1, 3, 256, 256)
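
The example above uses an all-True mask, which keeps every pixel; a more illustrative mask might keep only part of the image (a sketch, reusing the names defined above):

# keep the left half of inpaint_image and let the decoder repaint the right half
inpaint_mask = torch.zeros(1, 256, 256).bool().cuda()
inpaint_mask[:, :, :128] = True   # True marks the regions of the inpaint image to keep

inpainted_images = decoder.sample(
    image_embed = mock_image_embed,
    inpaint_image = inpaint_image,
    inpaint_mask = inpaint_mask
)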

Experimental

DALL-E2 with Latent Diffusion

This repository decides to take the next step and offer DALL-E v2 combined with latent diffusion, from Rombach et al.

You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.

The repository also comes equipped with all the necessary settings to recreate ViT-VQGan from the Improved VQGans paper. Furthermore, the vector quantization library also comes equipped to do residual or multi-headed quantization, which I believe will give an even further boost in performance to the autoencoder.

import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

# 3 unets for the decoder (a la cascading DDPM)

# first two unets are doing latent diffusion
# vqgan-vae must be trained beforehand

vae1 = VQGanVAE(
    dim = 32,
    image_size = 256,
    layers = 3,
    layer_mults = (1, 2, 4)
)

vae2 = VQGanVAE(
    dim = 32,
    image_size = 512,
    layers = 3,
    layer_mults = (1, 2, 4)
)

unet1 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    sparse_attn = True,
    sparse_attn_window = 2,
    dim_mults = (1, 2, 4, 8)
)

unet2 = Unet(
    dim = 32,
    image_embed_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
    cond_on_image_embeds = True,
    cond_on_text_encodings = False
)

unet3 = Unet(
    dim = 32,
    image_embed_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
    cond_on_image_embeds = True,
    cond_on_text_encodings = False,
    attend_at_middle = False
)

# decoder, which contains the unet(s) and clip

decoder = Decoder(
    clip = clip,
    vae = (vae1, vae2),                # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
    unet = (unet1, unet2, unet3),      # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
    image_sizes = (256, 512, 1024),    # resolutions, 256 for first unet, 512 for second, 1024 for third
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)

images = torch.randn(1, 3, 1024, 1024).cuda()

# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme

with decoder.one_unet_in_gpu(1):
    loss = decoder(images, unet_number = 1)
    loss.backward()

with decoder.one_unet_in_gpu(2):
    loss = decoder(images, unet_number = 2)
    loss.backward()

with decoder.one_unet_in_gpu(3):
    loss = decoder(images, unet_number = 3)
    loss.backward()

# do the above for many steps for both unets

# then it will learn to generate images based on the CLIP image embeddings

# chaining the unets from lowest resolution to highest resolution (thus cascading)

mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)

Training wrapper

Decoder Training

Training the Decoder may be confusing, as one needs to keep track of an optimizer for each of the Unet(s) separately. Each Unet will also need its own corresponding exponential moving average. The DecoderTrainer hopes to make this simple, as shown below

import torch
from dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data

text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()

# decoder (with unet)

unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings = True,
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 1000
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr = 3e-4,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)

for unet_number in (1, 2):
    loss = decoder_trainer(
        images,
        text = text,
        unet_number = unet_number, # which unet to train on
        max_batch_size = 4         # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
    )

    decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average

# after much training
# you can sample from the exponentially moving averaged unets as so

mock_image_embed = torch.randn(32, 512).cuda()
images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (32, 3, 256, 256)

Diffusion Prior Training

Similarly, one can use the DiffusionPriorTrainer to automatically instantiate and keep track of an exponential moving averaged prior.

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data

text = torch.randint(0, 49408, (512, 256)).cuda()
images = torch.randn(512, 3, 256, 256).cuda()

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

diffusion_prior_trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)

loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
diffusion_prior_trainer.update()  # this will update the optimizer as well as the exponential moving averaged diffusion prior

# after much of the above three lines in a loop
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior

image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings

Bonus

Unconditional Training

The repository also contains the means to train an unconditional DDPM model, or even cascading DDPMs. You simply have to set unconditional = True in the Decoder

ex.

import torch
from dalle2_pytorch import Unet, Decoder, DecoderTrainer

# unet for the cascading ddpm

unet1 = Unet(
    dim = 128,
    dim_mults=(1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

# decoder, which contains the unets

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (256, 512),  # first unet up to 256px, then second to 512px
    timesteps = 1000,
    unconditional = True
).cuda()

# decoder trainer

decoder_trainer = DecoderTrainer(decoder)

# images (get a lot of this)

images = torch.randn(1, 3, 512, 512).cuda()

# feed images into decoder

for i in (1, 2):
    loss = decoder_trainer(images, unet_number = i)
    decoder_trainer.update(unet_number = i)

# do the above for many many many many images
# then it will learn to generate images

images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)

Dataloaders

Decoder Dataloaders

In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.

Decoder: Image Embedding Dataset

When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a webdataset that contains .jpg and .npy files in the .tars that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain .npy files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the .jpg and the index of the embedding in the .npy. So, for example, 0001.tar from the webdataset with image 00010509.jpg (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a img_emb_0001.npy which contains a NumPy array with the embedding at index 509.

Generating a dataset of this type:

  1. Use img2dataset to generate a webdataset.
  2. Use clip-retrieval to convert the images to embeddings.
  3. Use embedding-dataset-reordering to reorder the embeddings into the expected format.

Usage:

from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader

# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
    embeddings_url="path/or/url/to/embeddings/folder",     # Included if .npy files are not in webdataset. Left out or set to None otherwise
    num_workers=4,
    batch_size=32,
    shard_width=4,                                         # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
    shuffle_num=200,                                       # Does a shuffle of the data with a buffer size of 200
    shuffle_shards=True,                                   # Shuffle the order the shards are read in
    resample_shards=False,                                 # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
    print(img.shape)  # torch.Size([32, 3, 256, 256])
    print(emb["img"].shape)  # torch.Size([32, 512])
    # Train decoder only as shown above

# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
    urls="/path/or/url/to/webdataset/{0000..9999}.tar",
    embedding_folder_url="path/or/url/to/embeddings/folder",
    shard_width=4,
    shuffle_shards=True,
    resample=False
)

Scripts

train_diffusion_prior.py

For detailed information on training the diffusion prior, please refer to the dedicated readme

Todo

Citations

@misc{ramesh2022,
    title   = {Hierarchical Text-Conditional Image Generation with CLIP Latents}, 
    author  = {Aditya Ramesh et al},
    year    = {2022}
}
@misc{crowson2022,
    author  = {Katherine Crowson},
    url     = {https://twitter.com/rivershavewings}
}
@misc{rombach2021highresolution,
    title   = {High-Resolution Image Synthesis with Latent Diffusion Models}, 
    author  = {Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
    year    = {2021},
    eprint  = {2112.10752},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{shen2019efficient,
    author  = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
    title   = {Efficient Attention: Attention with Linear Complexities},
    journal = {CoRR},
    year    = {2018},
    url     = {http://arxiv.org/abs/1812.01243},
}
@article{Yu2021VectorquantizedIM,
    title   = {Vector-quantized Image Modeling with Improved VQGAN},
    author  = {Jiahui Yu and Xin Li and Jing Yu Koh and Han Zhang and Ruoming Pang and James Qin and Alexander Ku and Yuanzhong Xu and Jason Baldridge and Yonghui Wu},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.04627}
}
@article{Shleifer2021NormFormerIT,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Sam Shleifer and Jason Weston and Myle Ott},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.09456}
}
@article{Yu2022CoCaCC,
    title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
    author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.01917}
}
@misc{wang2021crossformer,
    title   = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
    author  = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
    year    = {2021},
    eprint  = {2108.00154},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{ho2021cascaded,
    title   = {Cascaded Diffusion Models for High Fidelity Image Generation},
    author  = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
    journal = {arXiv preprint arXiv:2106.15282},
    year    = {2021}
}
@misc{Saharia2022,
    title   = {Imagen: unprecedented photorealism × deep level of language understanding},
    author  = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
    year    = {2022}
}
@article{Choi2022PerceptionPT,
    title   = {Perception Prioritized Training of Diffusion Models},
    author  = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2204.00227}
}
@article{Saharia2021PaletteID,
    title   = {Palette: Image-to-Image Diffusion Models},
    author  = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2111.05826}
}
@article{Lugmayr2022RePaintIU,
    title   = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
    author  = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2201.09865}
}
@misc{chen2022analog,
    title   = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
    author  = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
    year    = {2022},
    eprint  = {2208.04202},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Qiao2019WeightS,
    title   = {Weight Standardization},
    author  = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1903.10520}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@article{Sunkara2022NoMS,
    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
    author  = {Raja Sunkara and Tie Luo},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.03641}
}
@article{Salimans2022ProgressiveDF,
    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
    author  = {Tim Salimans and Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.00512}
}

Creating noise from data is easy; creating data from noise is generative modeling. - Yang Song's paper
