Lucidrains Series Source Code Analysis (Part 5)

.\lucidrains\byol-pytorch\byol_pytorch\trainer.py

# import the required libraries
from pathlib import Path
import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn import SyncBatchNorm
from torch.optim import Optimizer, Adam
from torch.utils.data import Dataset, DataLoader
from byol_pytorch.byol_pytorch import BYOL
from beartype import beartype
from beartype.typing import Optional
from accelerate import Accelerator

# helper functions

# check whether a value exists
def exists(v):
    return v is not None

# endlessly cycle over a dataloader
def cycle(dl):
    while True:
        for batch in dl:
            yield batch

# mock dataset class

class MockDataset(Dataset):
    def __init__(self, image_size, length):
        self.length = length
        self.image_size = image_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return torch.randn(3, self.image_size, self.image_size)

# main trainer class

class BYOLTrainer(Module):
    @beartype
    def __init__(
        self,
        net: Module,
        *,
        image_size: int,
        hidden_layer: str,
        learning_rate: float,
        dataset: Dataset,
        num_train_steps: int,
        batch_size: int = 16,
        optimizer_klass = Adam,
        checkpoint_every: int = 1000,
        checkpoint_folder: str = './checkpoints',
        byol_kwargs: dict = dict(),
        optimizer_kwargs: dict = dict(),
        accelerator_kwargs: dict = dict(),
    ):
        super().__init__()
        # instantiate the accelerator
        self.accelerator = Accelerator(**accelerator_kwargs)

        # if distributed training is initialized with a world size greater than 1, convert the network's batchnorms to synced batchnorm
        if dist.is_initialized() and dist.get_world_size() > 1:
            net = SyncBatchNorm.convert_sync_batchnorm(net)

        self.net = net

        # wrap the network with BYOL
        self.byol = BYOL(net, image_size=image_size, hidden_layer=hidden_layer, **byol_kwargs)

        # optimizer
        self.optimizer = optimizer_klass(self.byol.parameters(), lr=learning_rate, **optimizer_kwargs)

        # dataloader
        self.dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

        self.num_train_steps = num_train_steps

        self.checkpoint_every = checkpoint_every
        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok=True, parents=True)
        assert self.checkpoint_folder.is_dir()

        # prepare the model, optimizer and dataloader with the accelerator
        (
            self.byol,
            self.optimizer,
            self.dataloader
        ) = self.accelerator.prepare(
            self.byol,
            self.optimizer,
            self.dataloader
        )

        # register a buffer to track the training step
        self.register_buffer('step', torch.tensor(0))

    # wait for all processes to sync
    def wait(self):
        return self.accelerator.wait_for_everyone()

    # print only from the main process
    def print(self, msg):
        return self.accelerator.print(msg)

    # forward runs the entire training loop
    def forward(self):
        step = self.step.item()
        data_it = cycle(self.dataloader)

        for _ in range(self.num_train_steps):
            images = next(data_it)

            with self.accelerator.autocast():
                loss = self.byol(images)
                self.accelerator.backward(loss)

            self.print(f'loss {loss.item():.3f}')

            self.optimizer.step()
            self.optimizer.zero_grad()

            self.wait()

            self.byol.update_moving_average()

            self.wait()

            # save a checkpoint every `checkpoint_every` steps, from the main process only
            if not (step % self.checkpoint_every) and self.accelerator.is_main_process:
                checkpoint_num = step // self.checkpoint_every
                checkpoint_path = self.checkpoint_folder / f'checkpoint.{checkpoint_num}.pt'
                torch.save(self.net.state_dict(), str(checkpoint_path))

            self.wait()

            step += 1

        self.print('training complete')

.\lucidrains\byol-pytorch\byol_pytorch\__init__.py

# import the BYOL class from the byol_pytorch module
# this class implements the BYOL (Bootstrap Your Own Latent) algorithm
from byol_pytorch.byol_pytorch import BYOL
# import the BYOLTrainer and MockDataset classes from the trainer module
# BYOLTrainer trains a BYOL model, MockDataset provides a mock dataset
from byol_pytorch.trainer import BYOLTrainer, MockDataset

PyTorch Lightning example script

Requirements

$ pip install pytorch-lightning
$ pip install pillow

Run

$ python train.py --image_folder /path/to/your/images

.\lucidrains\byol-pytorch\examples\lightning\train.py

# import required libraries
import os
import argparse
import multiprocessing
from pathlib import Path
from PIL import Image

import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

# import the BYOL module and pytorch lightning
from byol_pytorch import BYOL
import pytorch_lightning as pl

# load a pretrained resnet50
resnet = models.resnet50(pretrained=True)

# parse command line arguments
parser = argparse.ArgumentParser(description='byol-lightning-test')
parser.add_argument('--image_folder', type=str, required=True,
                    help='path to your folder of images for self-supervised learning')
args = parser.parse_args()

# constants
BATCH_SIZE = 32
EPOCHS = 1000
LR = 3e-4
NUM_GPUS = 2
IMAGE_SIZE = 256
IMAGE_EXTS = ['.jpg', '.png', '.jpeg']
NUM_WORKERS = multiprocessing.cpu_count()

# pytorch lightning module
class SelfSupervisedLearner(pl.LightningModule):
    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        return self.learner(images)

    def training_step(self, images, _):
        loss = self.forward(images)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        if self.learner.use_momentum:
            self.learner.update_moving_average()

# expand greyscale (single-channel) images to 3 channels
def expand_greyscale(t):
    return t.expand(3, -1, -1)

# dataset of images found under a folder
class ImagesDataset(Dataset):
    def __init__(self, folder, image_size):
        super().__init__()
        self.folder = folder
        self.paths = []

        for path in Path(f'{folder}').glob('**/*'):
            _, ext = os.path.splitext(path)
            if ext.lower() in IMAGE_EXTS:
                self.paths.append(path)

        print(f'{len(self.paths)} images found')

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Lambda(expand_greyscale)
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        img = img.convert('RGB')
        return self.transform(img)

# main entrypoint
if __name__ == '__main__':
    # instantiate the image dataset
    ds = ImagesDataset(args.image_folder, IMAGE_SIZE)
    # create the dataloader
    train_loader = DataLoader(ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)

    # create the self-supervised learner
    model = SelfSupervisedLearner(
        resnet,
        image_size=IMAGE_SIZE,
        hidden_layer='avgpool',
        projection_size=256,
        projection_hidden_size=4096,
        moving_average_decay=0.99
    )

    # create the trainer
    trainer = pl.Trainer(
        gpus=NUM_GPUS,
        max_epochs=EPOCHS,
        accumulate_grad_batches=1,
        sync_batchnorm=True
    )

    # train the model
    trainer.fit(model, train_loader)

Bootstrap Your Own Latent (BYOL), in Pytorch

PyPI version

Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning or having to designate negative pairs.

This repository offers a module with which one can wrap any image-based neural network (residual network, discriminator, policy network) to immediately start benefiting from unlabelled image data.

Update 1: There is now new evidence that batch normalization is key to making this technique work well

Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting that batch statistics are needed for BYOL to work

Update 3: Finally, we have some analysis for why this works

Yannic Kilcher's excellent explanation

Now go save your organization from having to pay for labels 😃

Install

$ pip install byol-pytorch

Usage

Simply plug in your neural network, specifying (1) the image dimensions as well as (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
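
As a quick illustration, here is a minimal sketch (not from the repository) of reusing the improved weights downstream with a linear probe; the 10-class head is a hypothetical placeholder.

import torch
from torch import nn
from torchvision import models

resnet = models.resnet50()
resnet.load_state_dict(torch.load('./improved-net.pt'))

# freeze the self-supervised trunk and train only a fresh classification head
for param in resnet.parameters():
    param.requires_grad = False

resnet.fc = nn.Linear(resnet.fc.in_features, 10)  # hypothetical 10 downstream classes

logits = resnet(torch.randn(2, 3, 256, 256))      # (2, 10)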

BYOL → SimSiam

A new paper from Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the use_momentum flag to False. You will no longer need to invoke update_moving_average if you go this route as shown in the example below.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False       # turn off momentum in the target encoder
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

Advanced

While the hyperparameters have already been set to what the paper has found optimal, you can change them with extra keyword arguments to the base wrapper class.

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    projection_size = 256,           # the projection size
    projection_hidden_size = 4096,   # the hidden dimension of the MLP for both the projection and prediction
    moving_average_decay = 0.99      # the moving average decay factor for the target encoder, already set at what paper recommends
)

By default, this library will use the augmentations from the SimCLR paper (which is also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, you can simply pass in your own custom augmentation function with the augment_fn keyword.

import kornia
from torch import nn

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn
)

In the paper, they seem to ensure that one of the augmentations has a higher gaussian blur probability than the other. You can adjust this to your heart's delight.

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn,
    augment_fn2 = augment_fn2,
)

To fetch the embeddings or the projections, you simply have to pass in a return_embedding = True flag to the BYOL learner instance

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)

Distributed Training

The repository now offers distributed training with 🤗 Huggingface Accelerate. You just have to pass your own Dataset into the imported BYOLTrainer

First setup the configuration for distributed training by invoking the accelerate CLI

$ accelerate config

Then craft your training script as shown below, say in ./train.py

from torchvision import models

from byol_pytorch import (
    BYOL,
    BYOLTrainer,
    MockDataset
)

resnet = models.resnet50(pretrained = True)

dataset = MockDataset(256, 10000)

trainer = BYOLTrainer(
    resnet,
    dataset = dataset,
    image_size = 256,
    hidden_layer = 'avgpool',
    learning_rate = 3e-4,
    num_train_steps = 100_000,
    batch_size = 16,
    checkpoint_every = 1000     # improved model will be saved periodically to ./checkpoints folder 
)

trainer()

Then use the accelerate CLI again to launch the script

$ accelerate launch ./train.py

Alternatives

If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning.

https://github.com/lucidrains/pixel-level-contrastive-learning

Citation

@misc{grill2020bootstrap,
    title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
    author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
    year = {2020},
    eprint = {2006.07733},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{chen2020exploring,
    title={Exploring Simple Siamese Representation Learning}, 
    author={Xinlei Chen and Kaiming He},
    year={2020},
    eprint={2011.10566},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

.\lucidrains\byol-pytorch\setup.py

# import setup utilities and the package finder
from setuptools import setup, find_packages

# set up the package metadata
setup(
  # package name
  name = 'byol-pytorch',
  # find and include all packages except 'examples'
  packages = find_packages(exclude=['examples']),
  # version number
  version = '0.8.0',
  # license type
  license='MIT',
  # description
  description = 'Self-supervised contrastive learning made simple',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project url
  url = 'https://github.com/lucidrains/byol-pytorch',
  # long description content type
  long_description_content_type = 'text/markdown',
  # keywords
  keywords = [
      'self-supervised learning',
      'artificial intelligence'
  ],
  # install dependencies
  install_requires=[
      'accelerate',
      'beartype',
      'torch>=1.6',
      'torchvision>=0.8'
  ],
  # classifiers
  classifiers=[
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\CALM-pytorch\CALM_pytorch\CALM.py

from math import ceil
from pathlib import Path
from functools import partial
from contextlib import nullcontext, contextmanager

from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import _LRScheduler
from torch import nn, einsum, Tensor

# beartype for runtime type checking
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import List, Optional, Callable, Type, Tuple, Union, Literal

from einops import rearrange, repeat

# building blocks from x-transformers
from x_transformers.x_transformers import (
    RMSNorm,
    Attention,
    TransformerWrapper,
)

from accelerate import Accelerator

# optimizer / accelerate utilities
from pytorch_custom_utils import (
    OptimizerWithWarmupSchedule,
    get_adam_optimizer,
    auto_unwrap_model
)

from pytorch_custom_utils.accelerate_utils import (
    model_forward_contexts
)

# import the sample, top_p and top_k functions from CALM_pytorch.sampling_utils
from CALM_pytorch.sampling_utils import sample, top_p, top_k

# types

# a Sequence is either a Tuple or a List
Sequence = Union[Tuple, List]

# a hidden state can be taken from a block's 'input' or 'output'
HiddenPosition = Union[Literal['input'], Literal['output']]

# SequenceOf(t) -> Tuple[t, ...] or List[t]
def SequenceOf(t):
    return Union[Tuple[t, ...], List[t]]

# SingularOrMany(t) -> t or SequenceOf(t)
def SingularOrMany(t):
    return Union[t, SequenceOf(t)]

# helpers

# check whether a value exists
def exists(v):
    return v is not None

# return the value if it exists, otherwise the default
def default(v, d):
    return v if exists(v) else d

# xnor - true when both or neither of the inputs are true
def xnor(x, y):
    return not (x ^ y)

# cast a value to a tuple of the given length
def cast_tuple(t, length = 1):
    return t if is_bearable(t, Sequence) else ((t,) * length)

# extract the block output (or input) from forward hook arguments
def get_block_output_from_hook_outputs(
    hidden_position: HiddenPosition,
    _, inp, out
):
    maybe_tensor = out if hidden_position == 'output' else inp

    if isinstance(maybe_tensor, tuple):
        maybe_tensor = maybe_tensor[0]

    assert torch.is_tensor(maybe_tensor)
    return maybe_tensor

# freezing llms

# set requires_grad for all parameters of a module
@beartype
def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    for param in module.parameters():
        param.requires_grad = requires_grad

# freeze all layers of a module
def freeze_all_layers_(module):
    set_module_requires_grad_(module, False)

# function for returning an ordered list of modules, where the output of the module is the output of that transformer block layer
# ex. for x-transformers TransformerWrapper

@beartype
def x_transformer_blocks(transformer: TransformerWrapper) -> List[Module]:
    blocks = []
    for layer in transformer.attn_layers.layers:
        blocks.append(layer[-1])
    # layers alternate attention and feedforward, so take every other module
    return blocks[1::2]

# helper classes

# 定义 Recorder 类
class Recorder:
    # Recorder 类的构造函数
    @beartype
    def __init__(
        self,
        outputs: Optional[List] = None,
        forward_hook_get_hidden: HiddenPosition = 'output',
        modules: Optional[List] = None,
    ):
        self.output = default(outputs, [])
        self.modules = modules
        self.get_output_fn = partial(get_block_output_from_hook_outputs, forward_hook_get_hidden)

    # Recorder 类的调用函数
    def __call__(self, *args):

        if exists(self.modules):
            self.modules.append(args[0])

        hidden = self.get_output_fn(*args)
        self.output.append(hidden.detach())

# 定义 ExtractHiddensWrapper 类
class ExtractHiddensWrapper(Module):
    # ExtractHiddensWrapper 类的构造函数
    @beartype
    def __init__(
        self,
        model: Module,
        blocks: List[Module],
        hidden_positions: SingularOrMany(HiddenPosition) = 'output'
    ):
        super().__init__()
        hidden_positions = cast_tuple(hidden_positions, len(blocks))
        assert len(hidden_positions) == len(blocks)

        self.model = model

        self.outputs = []
        self.modules = []
        self.recorders = []

        for block, hidden_position in zip(blocks, hidden_positions):
            recorder = Recorder(self.outputs, hidden_position, self.modules)
            self.recorders.append(recorder)
            block.register_forward_hook(recorder)
    # 定义一个方法用于前向传播,接受任意参数和关键字参数,可以选择是否返回被挂钩的模块
    def forward(self, *args, return_hooked_modules = False, **kwargs):
        # 调用模型的前向传播方法,传入参数和关键字参数
        self.model(*args, **kwargs)

        # 复制输出和模块字典
        outputs = self.outputs.copy()
        modules = self.modules.copy()

        # 清空输出和模块字典
        self.outputs.clear()
        self.modules.clear()

        # 如果不需要返回被挂钩的模块,则返回输出字典
        if not return_hooked_modules:
            return outputs

        # 如果需要返回被挂钩的模块,则同时返回输出字典和模块字典
        return outputs, modules

# cross attention block

class CrossAttentionBlock(Module):
    @beartype
    def __init__(
        self,
        dim,
        dim_context,
        linear_project_context = True,  # in the paper, they project the augment hidden states. not sure if needed, but better to be accurate first
        pre_rmsnorm = False,
        forward_hook_get_hidden: Union[
            Literal['output'],
            Literal['input']
        ] = 'output',
        **kwargs
    ):
        super().__init__()
        # optionally rmsnorm the anchor hidden state before attending
        self.pre_rmsnorm = RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        self.context_proj = None

        self.dim = dim
        self.dim_context = dim_context

        # optionally linearly project the context to the anchor dimension
        if linear_project_context:
            self.context_proj = nn.Linear(dim_context, dim)
            dim_context = dim

        # the cross attention itself
        self.attn = Attention(
            dim = dim,
            dim_context = dim_context,
            zero_init_output = True,
            gate_value_heads = True,
            **kwargs
        )

        self.context = None
        self.context_mask = None
        self.forward_hook_get_hidden = forward_hook_get_hidden

    # set the context mask
    def set_mask(self, mask: Tensor):
        self.context_mask = mask

    # unset the context mask
    def unset_mask(self):
        self.context_mask = None

    # forward, invoked as a forward hook on an anchor block
    def forward(self, *hook_args):
        x = get_block_output_from_hook_outputs(self.forward_hook_get_hidden, *hook_args)

        context = self.context
        assert exists(context)

        maybe_enable_grad = torch.enable_grad if self.training else nullcontext

        with maybe_enable_grad():
            res = x
            x = self.pre_rmsnorm(x)

            if exists(self.context_proj):
                context = self.context_proj(context)

            out = self.attn(x, context, context_mask = self.context_mask) + res

        return out

# main classes

@dataclass
class AugmentParams:
    model: Module
    hidden_position: SingularOrMany(HiddenPosition) = 'output'
    transformer_blocks: Optional[List[Module]] = None
    extract_blocks_fn: Optional[Callable[[Module], List[Module]]] = None
    model_return_hiddens: bool = False
    input_shape: Optional[Tuple[int, ...]] = None
    connections: Optional[Tuple[Tuple[int, int], ...]] = None
    connect_every_num_layers: int = 4 # in the paper, they connect every 4 layers
    mask_kwarg: Optional[str] = None

# CALM

class CALM(Module):
    @beartype
    def __init__(
        self,
        anchor_llm: Module,
        augment_llms: SingularOrMany(AugmentParams),
        *,
        attn_kwargs: dict = dict(
            linear_project_context = True,
            pre_rmsnorm = True,
            flash = True
        ),
        anchor_extract_blocks_fn: Callable[[Module], List[Module]] = None,
        anchor_transformer_blocks: Optional[List[Module]] = None,
        anchor_hidden_position: SingularOrMany(HiddenPosition) = 'output',
        pad_id: int = -1
    ):
        super().__init__()
        # ... (the body of __init__ is elided in this excerpt)

    # only the cross attention parameters are saved / loaded / trained

    def state_dict(self):
        return self.cross_attns.state_dict()

    def load_state_dict(self, pkg, strict = False):
        self.cross_attns.load_state_dict(pkg, strict = strict)

    def parameters(self):
        return self.cross_attns.parameters()

    # drop the cross attention contexts after use
    def release_cross_attn_contexts(self):
        for one_augment_cross_attns in self.cross_attns:
            for cross_attn in one_augment_cross_attns:
                cross_attn.context = None

    def forward_augments(
        self,
        prompt: Tensor,
        prompt_mask: Optional[SingularOrMany(SequenceOf(Tensor))] = None
    ):
        # if only one prompt is given with multiple augment llms, feed the same prompt to all of them

        num_augment_llms = len(self.augment_llms)

        prompts = cast_tuple(prompt, num_augment_llms)

        assert len(prompts) == num_augment_llms

        # prompt masks

        if not exists(prompt_mask):
            prompt_mask = tuple((p != self.pad_id if not torch.is_floating_point(p) else None) for p in prompts)

        prompt_mask = cast_tuple(prompt_mask, num_augment_llms)

        prompt_masks = prompt_mask # at this point, should be plural

        assert len(prompt_masks) == num_augment_llms

        # invoke the augment llms, gathering their hidden states with forward hooks

        augments_hiddens = []

        with torch.no_grad():

            self.augment_llms.eval()

            for augment_llm, params, prompt, prompt_mask in zip(self.augment_llms, self.augment_llms_params, prompts, prompt_masks):
                augment_llm_kwarg = dict()

                if exists(params.mask_kwarg):
                    augment_llm_kwarg = {params.mask_kwarg: prompt_mask}

                one_augment_hiddens = augment_llm(prompt, **augment_llm_kwarg)

                augments_hiddens.append(one_augment_hiddens)

        # set the context on each cross attention block, for the anchor forward pass

        for one_augment_hiddens, one_augment_cross_attns, one_augment_connections in zip(augments_hiddens, self.cross_attns, self.connections):

            for (augment_layer_index, _), cross_attn in zip(one_augment_connections, one_augment_cross_attns):
                cross_attn.context = one_augment_hiddens[augment_layer_index - 1]

        return prompts, prompt_masks

    @contextmanager
    def set_cross_attn_masks(self, masks):
        # set the context masks for the cross attention blocks

        for one_cross_attn, mask in zip(self.cross_attns, masks):
            for cross_attn in one_cross_attn:
                cross_attn.set_mask(mask)

        yield

        # unset the context masks

        for one_cross_attn in self.cross_attns:
            for cross_attn in one_cross_attn:
                cross_attn.unset_mask()


    @torch.no_grad()
    def generate(
        self,
        prompt: Tensor,
        seq_len: int,
        prompt_mask: Optional[SingularOrMany(SequenceOf(Tensor))] = None,
        filter_fn: Callable = top_p,
        filter_kwargs: dict = dict(
            thres = 0.9
        )
    ):
        batch, device = prompt.shape[0], next(self.cross_attns.parameters()).device

        self.eval()

        # run a forward pass through all augment models and gather their hidden states

        prompts, prompt_masks = self.forward_augments(prompt = prompt, prompt_mask = prompt_mask)

        with self.set_cross_attn_masks(prompt_masks):

            # sample from the anchor llm

            generated = sample(
                self.anchor_llm,
                prompt,
                seq_len = seq_len,
                filter_fn = filter_fn,
                filter_kwargs = filter_kwargs
            )

            self.release_cross_attn_contexts()

        return generated

    @beartype
    def forward(
        self,
        seq: Tensor,
        *,
        prompt: SingularOrMany(Tensor),
        prompt_mask: Optional[SingularOrMany(Tensor)] = None,
        mask: Optional[Tensor] = None,
        return_loss = True,
        anchor_llm_in_train_mode = True  # not sure about this
    ):
        # if returning a loss, put the cross attention blocks in train mode
        if return_loss:
            self.cross_attns.train()

            # the anchor llm can be kept in train or eval mode during finetuning
            if anchor_llm_in_train_mode:
                self.anchor_llm.train()
            else:
                self.anchor_llm.eval()

            # shift the sequence by one for autoregressive input and labels
            seq, labels = seq[:, :-1], seq[:, 1:]

        # run a forward pass through all augment models and gather their hidden states

        prompts, prompt_masks = self.forward_augments(prompt=prompt, prompt_mask=prompt_mask)

        # set the context masks for cross attention
        with self.set_cross_attn_masks(prompt_masks):
            # invoke the anchor llm, which should cross attend to the augment llm hidden states via the hooks

            logits = self.anchor_llm(seq)

            # release the cross attention contexts
            self.release_cross_attn_contexts()

            assert logits.ndim == 3, 'anchor llm should return logits in the shape (batch, seq, num tokens)'

        # return the logits for decoding

        if not return_loss:
            return logits

        # take care of the sequence mask

        if exists(mask):
            # fill masked-out label positions with the pad id, so they are ignored in the loss
            labels = labels.masked_fill(~mask[:, 1:], self.pad_id)

        # cross entropy loss for finetuning

        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels,
            ignore_index=self.pad_id
        )

        return loss

# generator that endlessly cycles over the batches of a dataloader
def cycle(dl):
    while True:
        for batch in dl:
            yield batch

# trainer, with the model automatically unwrapped from the accelerator
@auto_unwrap_model()
class FineTuner:

    @beartype
    def __init__(
        self,
        calm: CALM,
        *,
        num_train_steps: int,
        learning_rate: float,
        weight_decay: float,
        batch_size: int,
        dataset: Dataset,
        data_kwarg_names: Tuple[str, ...] = ('seq', 'mask', 'prompt'),
        accelerate_kwargs: dict = dict(),
        checkpoint_every: int = 1000,
        checkpoint_path: str = './checkpoints',
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        warmup_steps: int = 1000,
        max_grad_norm = 0.5,
        grad_accum_steps = 1
    ):
        # accelerator
        self.accelerator = Accelerator(**accelerate_kwargs)

        # dataloader
        self.dl = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)
        self.data_kwarg_names = data_kwarg_names

        # model
        self.model = calm

        # adam optimizer
        adam = get_adam_optimizer(
            calm.parameters(),
            lr = learning_rate,
            wd = weight_decay
        )

        # optimizer with warmup schedule
        self.optimizer = OptimizerWithWarmupSchedule(
            accelerator = self.accelerator,
            optimizer = adam,
            scheduler = scheduler,
            scheduler_kwargs = scheduler_kwargs,
            warmup_steps = warmup_steps,
            max_grad_norm = max_grad_norm
        )

        self.step = 0
        self.num_train_steps = num_train_steps
        self.grad_accum_steps = grad_accum_steps

        self.checkpoint_every = checkpoint_every
        self.checkpoint_path = Path(checkpoint_path)
        self.checkpoint_path.mkdir(exist_ok = True, parents = True)

    # whether this process is the main one
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # print only from the main process
    def print(self, msg):
        self.accelerator.print(msg)

    # save model and optimizer state
    def save(self, filename: str, overwrite: bool = True):
        path = self.checkpoint_path / filename
        assert overwrite or not path.exists()

        pkg = dict(
            model = self.model.state_dict(),
            optimizer = self.optimizer.state_dict(),
            step = self.step
        )

        torch.save(pkg, str(path))

    # load model and optimizer state
    def load(self, filename: str):
        path = self.checkpoint_path / filename
        assert path.exists()

        pkg = torch.load(str(path))

        self.model.load_state_dict(pkg['model'])
        self.optimizer.load_state_dict(pkg['optimizer'])
        self.step = pkg['step']

    # calling the FineTuner runs the training loop
    def __call__(self, forward_kwargs: dict = dict()):
        dl_iter = cycle(self.dl)
        self.model.train()

        for step in range(self.step, self.num_train_steps):

            for context in model_forward_contexts(
                model = self.model,
                accelerator = self.accelerator,
                grad_accum_steps = self.grad_accum_steps
            ):
                with context():
                    data = next(dl_iter)

                    if not isinstance(data, dict):
                        data = dict(zip(self.data_kwarg_names, data))

                    loss = self.model(**data, **forward_kwargs)

                    self.accelerator.backward(loss / self.grad_accum_steps)

            self.print(f'{step + 1}: {loss.item():.3f}')

            self.optimizer.step()
            self.optimizer.zero_grad()

            self.step += 1

            self.accelerator.wait_for_everyone()

            if self.is_main and not (self.step % self.checkpoint_every):
                num = self.step // self.checkpoint_every
                self.save(f'checkpoint.{num}.pt')

            self.accelerator.wait_for_everyone()

        self.print('training complete')
        self.save('checkpoint.-1.pt')

.\lucidrains\CALM-pytorch\CALM_pytorch\sampling_utils.py

from math import ceil

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from torch.nn.utils.rnn import pad_sequence

from beartype import beartype
from beartype.typing import Optional, Callable, List, Tuple

from einops import rearrange

from tqdm import tqdm

# check whether a value exists
def exists(v):
    return v is not None

# return the value if it exists, otherwise the default
def default(v, d):
    return v if exists(v) else d

# sampling helpers

# log, clamped to avoid log(0)
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# gumbel noise, for the gumbel-max trick
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample from a categorical distribution by taking the argmax of logits / temperature plus gumbel noise
def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True, eps = 1e-10):
    return ((t / max(temperature, eps)) + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim)

# nucleus (top-p) filtering

def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    # remove tokens once the cumulative probability exceeds the threshold
    sorted_indices_to_remove = cum_probs > thres
    sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)

    # mask out the removed logits, then scatter back to the original ordering
    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# top-k filtering

def top_k(logits, frac_num_tokens = 0.1, k: Optional[int] = None):
    num_tokens = logits.shape[-1]

    # default k to a fraction of the vocabulary size
    k = default(k, ceil(frac_num_tokens * num_tokens))
    k = min(k, num_tokens)

    # keep only the top-k logits, set the rest to -inf
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# decoding

@torch.no_grad()
@beartype
def sample(
    net: Module,
    prompts,
    seq_len: int,
    temperature = 1.,
    filter_fn: Callable = top_p,
    filter_kwargs: dict = dict(),
    pad_id: int = -1,
    eos_id: Optional[int] = None,
    output_keep_prompt = False
):
    device = next(net.parameters()).device
    net.eval()

    # variable-length prompts are right padded to the longest

    if isinstance(prompts, (tuple, list)):
        prompts = pad_sequence(prompts, batch_first = True, padding_value = pad_id)

    batch, prompts_tensor_len = prompts.shape

    batch_arange = torch.arange(batch, device = device)[..., None]

    # per-sample prompt lengths, and the index each sample writes to next

    prompt_lens = (prompts != pad_id).sum(dim = -1)
    curr_seq_indices = prompt_lens[..., None]

    out = prompts.clone()

    pbar = tqdm(
        initial = out.shape[-1],
        total = seq_len,
        desc = 'sampling'
    )

    # sample one token at a time until every sample reaches seq_len

    while (curr_seq_indices < seq_len).any():
        out = F.pad(out, (0, 1), value = pad_id)

        # pad tokens are replaced with 0 for the network input

        net_input = out.masked_fill(out == pad_id, 0)

        logits = net(net_input)

        # select the logits at each sample's current position

        logits = logits[batch_arange, curr_seq_indices]
        logits = rearrange(logits, 'b 1 d -> b d')

        # filter (top-p / top-k), then gumbel sample

        logits = filter_fn(logits, **filter_kwargs)
        sampled_tokens = gumbel_sample(logits, temperature = temperature, dim = -1)

        out[batch_arange, curr_seq_indices] = sampled_tokens

        curr_seq_indices += 1
        curr_seq_indices.clamp_(max = seq_len)
        pbar.update(1)

        # early exit once every sequence contains an eos token

        if not exists(eos_id):
            continue

        is_eos_mask = out == eos_id
        all_eos = is_eos_mask.any(dim = -1).all()

        if all_eos:
            break

    pbar.close()

    out = out[:, :seq_len]

    # mask out everything after the first eos token with the pad id

    if exists(eos_id):
        is_eos_mask = out == eos_id
        after_eos_mask = F.pad(is_eos_mask.cumsum(dim = -1) > 0, (1, -1), value = False)
        out = out.masked_fill_(after_eos_mask, pad_id)

    if output_keep_prompt:
        return out

    # strip the prompts and return the variable-length generated sequences

    prompt_mask = torch.arange(out.shape[-1], device = device) < prompt_lens[..., None]

    generated_seq_mask = (out != pad_id) & ~prompt_mask
    seq_lens = generated_seq_mask.sum(dim = -1).tolist()

    return out[generated_seq_mask].split(seq_lens)
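
As a quick aside, here is a minimal sketch of calling sample directly; the ToyLM below is hypothetical and merely stands in for any module returning (batch, seq, num_tokens) logits.

import torch
from torch import nn
from CALM_pytorch.sampling_utils import sample, top_k

# hypothetical toy language model, for demonstration only
class ToyLM(nn.Module):
    def __init__(self, num_tokens = 256, dim = 32):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x):
        return self.to_logits(self.emb(x))  # (batch, seq, num_tokens)

net = ToyLM()

# ragged prompts of different lengths are right padded internally
prompts = [torch.randint(0, 256, (5,)), torch.randint(0, 256, (8,))]

generated = sample(
    net,
    prompts,
    seq_len = 32,
    filter_fn = top_k,
    filter_kwargs = dict(k = 50)
)
# a tuple of variable-length generated sequences, prompts stripped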

.\lucidrains\CALM-pytorch\CALM_pytorch\__init__.py

# import the public classes from the CALM_pytorch.CALM module
from CALM_pytorch.CALM import (
    AugmentParams,
    ExtractHiddensWrapper,
    CALM,
    FineTuner
)

CALM - Pytorch

Implementation of CALM from the paper LLM Augmented LLMs: Expanding Capabilities through Composition, out of Google Deepmind

Can support any number of augmentation LLMs

Install

$ pip install CALM-pytorch

Appreciation

Usage

ex. with x-transformers

import torch
from x_transformers import TransformerWrapper, Decoder

augment_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

# import CALM wrapper

from CALM_pytorch import CALM, AugmentParams

calm = CALM(
    anchor_llm,
    augment_llms = AugmentParams(
        model = augment_llm,
        connect_every_num_layers = 4
    )
)

# mock input

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()
prompt = torch.randint(0, 20000, (1, 256))

# forward for finetuning loss

loss = calm(
    seq,
    mask = mask,
    prompt = prompt
)

loss.backward()

# after much training, prompt the composed model

generated = calm.generate(
    prompt = seq[:, :1],
    seq_len = 1024
)

To use a handy trainer class using 🤗 Accelerate, just import FineTuner and use as follows

trainer = FineTuner(
    calm = calm,
    dataset = dataset,   # returns a dictionary of input kwargs to calm - dict(seq: Tensor, mask: Tensor, prompt: Tensor). it can also return a Tuple, in which case data_kwarg_names needs to be set to the correctly ordered kwarg names
    batch_size = 16,
    num_train_steps = 10000,
    learning_rate = 3e-4,
    weight_decay = 1e-2,
    warmup_steps = 1000,
    checkpoint_every = 1000
)

trainer()

# checkpoints of the cross attention parameters will be saved to ./checkpoints every 1000 steps
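
For concreteness, here is a minimal sketch of such a dataset; MockCALMDataset and its shapes are illustrative placeholders matching the earlier mock inputs.

import torch
from torch.utils.data import Dataset

# hypothetical toy dataset - each item is a dict of CALM's forward kwargs
class MockCALMDataset(Dataset):
    def __len__(self):
        return 10000

    def __getitem__(self, idx):
        return dict(
            seq = torch.randint(0, 20000, (1024,)),
            mask = torch.ones(1024).bool(),
            prompt = torch.randint(0, 20000, (256,))
        )

dataset = MockCALMDataset()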

To explore multiple augmentation LLMs, simply pass in a list for augment_llms

ex.

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = [AugmentParams(augment_llm1), AugmentParams(augment_llm2)] # pass in a list of AugmentParams wrapping each model and other hparams specific to that transformer
)

Say you want to explore different types of connectivity between anchor and augmentation model(s), just pass in the connections as a tuple of integer pairs, each specifying (augment layer number, anchor layer number).

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = (
        AugmentParams(
            model = augment_llm1,
            connections = (
                (1, 12),  # 1st layer of augment llm1 attended to by 12th layer of anchor llm
                (2, 12),
                (3, 12),
                (4, 12),
            ),
        ),
        AugmentParams(
            model = augment_llm2,
            connections = (
                (6, 1),   # 6th layer of augment llm2 attended to by 1st layer of anchor llm
                (6, 2),
                (12, 12),
            )
        )
    )
)

CALM setup with 2 specialized augmentation LLMs + a vision transformer

import torch

# pip install vit-pytorch x-transformers

from vit_pytorch.vit import ViT, Attention
from x_transformers import TransformerWrapper, Encoder, Decoder

anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

augment_llm1 = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

augment_llm2 = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 256,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# calm

from CALM_pytorch import CALM, AugmentParams, FineTuner

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = (
        AugmentParams(
            model = augment_llm1,
            mask_kwarg = 'mask'
        ),
        AugmentParams(
            model = augment_llm2,
            mask_kwarg = 'mask'
        ),
        AugmentParams(
            model = vit,
            input_shape = (3, 256, 256),
            hidden_position = 'input',
            extract_blocks_fn = lambda vit: [m for m in vit.modules() if isinstance(m, Attention)]
        )
    ),
    attn_kwargs = dict(
        linear_project_context = True,
        pre_rmsnorm = True,
        flash = True
    )
)

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()

prompt = (
    torch.randint(0, 20000, (1, 256)),
    torch.randint(0, 20000, (1, 256)),
    torch.randn(1, 3, 256, 256)
)

loss = calm(
    seq,
    mask = mask,
    prompt = prompt
)

loss.backward()

Todo

Citations

@inproceedings{Bansal2024LLMAL,
  title   = {LLM Augmented LLMs: Expanding Capabilities through Composition},
  author  = {Rachit Bansal and Bidisha Samanta and Siddharth Dalmia and Nitish Gupta and Shikhar Vashishth and Sriram Ganapathy and Abhishek Bapna and Prateek Jain and Partha Pratim Talukdar},
  year    = {2024},
  url     = {https://api.semanticscholar.org/CorpusID:266755751}
}

.\lucidrains\CALM-pytorch\setup.py

# import setup utilities and the package finder
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'CALM-Pytorch',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.2.1',  # version number
  license='MIT',  # license
  description = 'CALM - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/CALM-pytorch',  # url
  keywords = [
    'artificial intelligence',
    'deep learning',
    'composing LLMs'
  ],
  install_requires = [  # install dependencies
    'accelerate',
    'beartype',
    'einops>=0.7.0',
    'pytorch-custom-utils>=0.0.11',
    'torch>=2.0',
    'tqdm',
    'x-transformers>=1.27.3'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\charformer-pytorch\charformer_pytorch\charformer_pytorch.py

import math
from math import gcd
import functools

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# least common multiple of several numbers
def lcm(*numbers):
    return int(functools.reduce(lambda x, y: int((x * y) / gcd(x, y)), numbers, 1))

# mean over a dimension, ignoring masked-out positions
def masked_mean(tensor, mask, dim = -1):
    diff_len = len(tensor.shape) - len(mask.shape)
    mask = mask[(..., *((None,) * diff_len))]
    tensor.masked_fill_(~mask, 0.)

    total_el = mask.sum(dim = dim)
    mean = tensor.sum(dim = dim) / total_el.clamp(min = 1.)
    mean.masked_fill_(total_el == 0, 0.)
    return mean

# next length greater than or equal to seqlen that is divisible by multiple
# e.g. next_divisible_length(10, 4) -> 12
def next_divisible_length(seqlen, multiple):
    return math.ceil(seqlen / multiple) * multiple

# pad a tensor along a dimension so its length is a multiple of the given number
def pad_to_multiple(tensor, multiple, *, seq_dim, dim = -1, value = 0.):
    seqlen = tensor.shape[seq_dim]
    length = next_divisible_length(seqlen, multiple)
    if length == seqlen:
        return tensor
    remainder = length - seqlen
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(tensor, (*pad_offset, 0, remainder), value = value)

# helper classes

# padding layer
class Pad(nn.Module):
    def __init__(self, padding, value = 0.):
        super().__init__()
        self.padding = padding
        self.value = value

    def forward(self, x):
        return F.pad(x, self.padding, value = self.value)

# depthwise convolution followed by a pointwise projection
class DepthwiseConv1d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size):
        super().__init__()
        self.conv = nn.Conv1d(dim_in, dim_out, kernel_size, groups = dim_in)
        self.proj_out = nn.Conv1d(dim_out, dim_out, 1)

    def forward(self, x):
        x = self.conv(x)
        return self.proj_out(x)

# main class

class GBST(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_block_size = None,
        blocks = None,
        downsample_factor = 4,
        score_consensus_attn = True
    ):
        super().__init__()
        assert exists(max_block_size) ^ exists(blocks), 'either max_block_size or blocks are given on initialization'
        self.token_emb = nn.Embedding(num_tokens, dim)

        if exists(blocks):
            assert isinstance(blocks, tuple), 'blocks must be a tuple of block sizes'
            self.blocks = tuple(map(lambda el: el if isinstance(el, tuple) else (el, 0), blocks))
            assert all([(offset < block_size) for block_size, offset in self.blocks]), 'offset must be always smaller than the block size'

            max_block_size = max(list(map(lambda t: t[0], self.blocks)))
        else:
            self.blocks = tuple(map(lambda el: (el, 0), range(1, max_block_size + 1)))

        self.pos_conv = nn.Sequential(
            Pad((0, 0, 0, max_block_size - 1)),
            Rearrange('b n d -> b d n'),
            DepthwiseConv1d(dim, dim, kernel_size = max_block_size),
            Rearrange('b d n -> b n d')
        )

        self.score_fn = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... () -> ...')
        )

        self.score_consensus_attn = score_consensus_attn

        assert downsample_factor <= max_block_size, 'final downsample factor should be less than the maximum block size'

        self.block_pad_multiple = lcm(*[block_size for block_size, _ in self.blocks])
        self.downsample_factor = downsample_factor

.\lucidrains\charformer-pytorch\charformer_pytorch\__init__.py

# import the GBST class from the charformer_pytorch.charformer_pytorch module
from charformer_pytorch.charformer_pytorch import GBST

Charformer - Pytorch

Implementation of the GBST (gradient-based subword tokenization) module from the Charformer paper, in Pytorch. The paper proposes a module that automatically learns subword representations, obviating the need for tokenizers in the encoder setting.

AI Coffee Break with Letitia video

Install

$ pip install charformer-pytorch

Usage

import torch
from charformer_pytorch import GBST

tokenizer = GBST(
    num_tokens = 257,             # number of tokens, should be 256 for byte encoding (+ 1 special token for padding in this example)
    dim = 512,                    # dimension of token and intra-block positional embedding
    max_block_size = 4,           # maximum block size
    downsample_factor = 4,        # the final downsample factor by which the sequence length will decrease by
    score_consensus_attn = True   # whether to do the cheap score consensus (aka attention) as in eq. 5 in the paper
)

tokens = torch.randint(0, 257, (1, 1023)) # uneven number of tokens (1023)
mask   = torch.ones(1, 1023).bool()

# both tokens and mask will be appropriately downsampled

tokens, mask = tokenizer(tokens, mask = mask) # (1, 256, 512), (1, 256)

# now pass this on to your transformer
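
A minimal sketch, assuming a plain torch TransformerEncoder as the downstream model (not part of this repository):

from torch import nn

encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model = 512, nhead = 8, batch_first = True),
    num_layers = 6
)

# GBST's mask is True at valid positions; torch expects True at padded ones
out = encoder(tokens, src_key_padding_mask = ~mask)  # (1, 256, 512)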

Deviating from the paper, you can also specify block size(s) with different offsets. This is to cover a potential use-case for genomics pre-training, where the tokenizer should be able to learn the correct frame. Simply omit the max_block_size, and pass in blocks as a tuple (or list) of (block size, offset) tuples. Offsets must be less than the block size.

import torch
from charformer_pytorch import GBST

tokenizer = GBST(
    num_tokens = 4 + 1,
    dim = 512,
    blocks = ((3, 0), (3, 1), (3, 2)),  # block size of 3, with offsets of 0, 1, 2
    downsample_factor = 3,
    score_consensus_attn = True
).cuda()

basepairs = torch.randint(0, 4, (1, 1023)).cuda()
mask      = torch.ones(1, 1023).bool().cuda()

# both basepairs and mask will be appropriately downsampled

basepairs, mask = tokenizer(basepairs, mask = mask)

Citations

@misc{tay2021charformer,
    title   = {Charformer: Fast Character Transformers via Gradient-based Subword Tokenization}, 
    author  = {Yi Tay and Vinh Q. Tran and Sebastian Ruder and Jai Gupta and Hyung Won Chung and Dara Bahri and Zhen Qin and Simon Baumgartner and Cong Yu and Donald Metzler},
    year    = {2021},
    eprint  = {2106.12672},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\charformer-pytorch\setup.py

# import setup utilities and the package finder
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'charformer-pytorch',  # package name
  packages = find_packages(),  # find all packages
  version = '0.0.4',  # version number
  license='MIT',  # license
  description = 'Charformer - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/charformer-pytorch',  # project url
  keywords = [
    'artificial intelligence',
    'deep learning',
    'learned tokenization'
  ],
  install_requires=[  # install dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\chroma-pytorch\chroma_pytorch\chroma_pytorch.py

import math
from pathlib import Path
from random import random
from functools import partial
from multiprocessing import cpu_count

import torch
from torch import nn, einsum
from torch.special import expm1
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

from torchvision import transforms as T, utils

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from tqdm.auto import tqdm
from ema_pytorch import EMA

from accelerate import Accelerator

# helper functions

# check whether a value exists
def exists(x):
    return x is not None

# return val if it exists, otherwise the default (calling it if callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# endlessly cycle over a dataloader
def cycle(dl):
    while True:
        for data in dl:
            yield data

# whether a number has an integer square root
def has_int_squareroot(num):
    return (math.sqrt(num) ** 2) == num

# split num into groups of divisor, with a remainder group if needed
# e.g. num_to_groups(10, 4) -> [4, 4, 2]
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

# convert an image to the given mode, if it isn't already
def convert_image_to(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# small helper modules

# residual connection - adds the input back to the function's output
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

# nearest-neighbor upsample followed by a convolution
def Upsample(dim, dim_out = None):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
    )

# strided convolution downsample
def Downsample(dim, dim_out = None):
    return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)

# layernorm over the channel dimension
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * (var + eps).rsqrt() * self.g

# normalize before applying the function
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

# positional embeds

# learned sinusoidal positional embedding
class LearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered

# building block modules

class Block(nn.Module):  # 定义 Block 类,实现基本块
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ResnetBlock(nn.Module):  # 定义 ResnetBlock 类,实现残差块
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
    # 定义前向传播函数,接受输入 x 和时间嵌入 time_emb
    def forward(self, x, time_emb = None):

        # 初始化 scale_shift 为 None
        scale_shift = None
        # 如果 self.mlp 和 time_emb 都存在
        if exists(self.mlp) and exists(time_emb):
            # 将 time_emb 输入到 self.mlp 中进行处理
            time_emb = self.mlp(time_emb)
            # 重新排列 time_emb 的维度,增加两个维度
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            # 将 time_emb 拆分成两部分,分别赋值给 scale 和 shift
            scale_shift = time_emb.chunk(2, dim = 1)

        # 将输入 x 传入第一个块中进行处理
        h = self.block1(x, scale_shift = scale_shift)

        # 将处理后的结果传入第二个块中进行处理
        h = self.block2(h)

        # 返回处理后的结果与输入 x 经过残差卷积的结果之和
        return h + self.res_conv(x)
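
# Usage sketch (shapes are assumptions, not from the source): the time embedding is
# projected to 2 * dim_out channels and split into a per-channel (scale, shift) that
# modulates block1's normalized activations -- FiLM-style conditioning on diffusion time.
import torch

block = ResnetBlock(dim = 64, dim_out = 128, time_emb_dim = 256)
x = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
t = torch.randn(2, 256)          # stand-in output of the time MLP
out = block(x, t)
print(out.shape)                 # torch.Size([2, 128, 32, 32])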
class LinearAttention(nn.Module):
    # 定义线性注意力机制模块
    def __init__(self, dim, heads = 4, dim_head = 32):
        # 初始化函数
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # 将输入转换为查询、键、值
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            # 输出转换为指定维度
            nn.Conv2d(hidden_dim, dim, 1),
            # 对输出进行 LayerNorm 处理
            LayerNorm(dim)
        )

    def forward(self, x):
        # 前向传播函数
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale
        v = v / (h * w)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)

class Attention(nn.Module):
    # 定义注意力机制模块
    def __init__(self, dim, heads = 4, dim_head = 32):
        # 初始化函数
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # 将输入转换为查询、键、值
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        # 前向传播函数
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q * self.scale

        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

# model

class Unet(nn.Module):
    # 定义 Unet 模型
    def __init__(
        self,
        dim,
        init_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        resnet_block_groups = 8,
        learned_sinusoidal_dim = 16
    ):
        # 调用父类的构造函数
        super().__init__()

        # 确定维度;输入通道数翻倍,用于拼接自条件(self-conditioning)张量
        self.channels = channels
        input_channels = channels * 2
        init_dim = default(init_dim, dim)
        # 初始化卷积层,输入通道数为input_channels,输出通道数为init_dim,卷积核大小为7,填充为3
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)

        # 计算不同层次的维度
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # 定义ResnetBlock类的部分参数
        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # 时间嵌入
        time_dim = dim * 4
        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
        fourier_dim = learned_sinusoidal_dim + 1

        # 时间嵌入的多层感知机
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # 层次
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # 遍历不同层次的维度
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            # 添加不同层次的模块到downs列表中
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

        mid_dim = dims[-1]
        # 中间块
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        # 反向遍历不同层次的维度
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            # 添加不同层次的模块到ups列表中
            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else  nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

        # 最终的残差块
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(dim, channels, 1)

    def forward(self, x, time, x_self_cond = None):

        # 默认x_self_cond为与x相同形状的全零张量
        x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
        x = torch.cat((x_self_cond, x), dim = 1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        # 遍历downs列表中的模块
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # 遍历ups列表中的模块
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
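
# Smoke test for the Unet above (sizes are illustrative assumptions): it takes a noised
# image, a per-sample noise-level condition (log SNR), and an optional self-conditioning
# tensor, and predicts x0 at the same resolution.
import torch

unet = Unet(dim = 64, dim_mults = (1, 2, 4, 8), channels = 3)
img = torch.randn(2, 3, 64, 64)   # noised images
noise_level = torch.randn(2)      # one noise-level scalar per sample
pred_x0 = unet(img, noise_level)  # x_self_cond defaults to zeros_like(img)
assert pred_x0.shape == img.shape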
# 定义一个名为 Chroma 的类
class Chroma(nn.Module):
    # 初始化方法
    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps = 1000,
        use_ddim = False,
        noise_schedule = 'cosine',
        time_difference = 0.
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置模型和通道数
        self.model = model
        self.channels = self.model.channels

        # 设置图像大小和噪声调度
        self.image_size = image_size

        # 根据噪声调度选择不同的 log_snr 函数
        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        # 设置采样时间步数和是否使用 ddim
        self.timesteps = timesteps
        self.use_ddim = use_ddim

        # 设置时间差异
        self.time_difference = time_difference

    # 定义 device 属性
    @property
    def device(self):
        return next(self.model.parameters()).device

    # 获取采样时间步数
    def get_sampling_timesteps(self, batch, *, device):
        # 生成时间序列
        times = torch.linspace(1., 0., self.timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    # 生成样本
    @torch.no_grad()
    def ddpm_sample(self, shape, time_difference = None):
        # 获取 batch 大小和设备
        batch, device = shape[0], self.device

        # 设置时间差异
        time_difference = default(time_difference, self.time_difference)

        # 获取采样时间步数
        time_pairs = self.get_sampling_timesteps(batch, device = device)

        # 生成随机噪声图像
        img = torch.randn(shape, device=device)

        x_start = None

        # 循环采样时间步数
        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.timesteps):

            # 添加时间延迟(使用本方法的 time_difference 局部变量,避免忽略传入参数)
            time_next = (time_next - time_difference).clamp(min = 0.)

            # 获取噪声条件
            noise_cond = self.log_snr(time)

            # 获取预测的 x0
            x_start = self.model(img, noise_cond, x_start)

            # 限制 x0 的范围
            x_start.clamp_(-1., 1.)

            # 获取 log(snr)
            log_snr = self.log_snr(time)
            log_snr_next = self.log_snr(time_next)
            log_snr, log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))

            # 获取时间和下一个时间的 alpha 和 sigma
            alpha, sigma = log_snr_to_alpha_sigma(log_snr)
            alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

            # 推导后验均值和方差
            c = -expm1(log_snr - log_snr_next)
            mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
            variance = (sigma_next ** 2) * c
            log_variance = log(variance)

            # 生成噪声
            noise = torch.where(
                rearrange(time_next > 0, 'b -> b 1 1 1'),
                torch.randn_like(img),
                torch.zeros_like(img)
            )

            # 更新图像
            img = mean + (0.5 * log_variance).exp() * noise

        return img
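
    # A note on the update above (hedged, not quoted source): assuming the helpers defined
    # earlier in this file follow lucidrains' bit-diffusion convention,
    #   alpha = sqrt(sigmoid(log_snr)), sigma = sqrt(sigmoid(-log_snr)),
    # so that alpha^2 + sigma^2 = 1 and SNR = exp(log_snr), then
    #   c = -expm1(log_snr - log_snr_next) = 1 - SNR / SNR_next,
    # and the `mean` / `variance` computed above are exactly the DDPM posterior
    # q(x_next | x_t, x0_hat) written in log-SNR form.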

    @torch.no_grad()
    # 从给定形状中采样数据,可以指定时间差
    def ddim_sample(self, shape, time_difference = None):
        # 获取批次大小和设备
        batch, device = shape[0], self.device

        # 设置时间差,默认为self.time_difference
        time_difference = default(time_difference, self.time_difference)

        # 获取采样时间步
        time_pairs = self.get_sampling_timesteps(batch, device = device)

        # 生成符合正态分布的随机数据
        img = torch.randn(shape, device = device)

        x_start = None

        # 遍历时间对
        for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):

            # 获取时间和噪声水平
            log_snr = self.log_snr(times)
            log_snr_next = self.log_snr(times_next)

            # 将噪声水平填充到与img相同的维度
            padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))

            # 将噪声水平转换为alpha和sigma
            alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
            alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)

            # 添加时间延迟
            times_next = (times_next - time_difference).clamp(min = 0.)

            # 预测x0
            x_start = self.model(img, log_snr, x_start)

            # 限制x0的取值范围
            x_start.clamp_(-1., 1.)

            # 获取预测的噪声
            pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)

            # 计算下一个x
            img = x_start * alpha_next + pred_noise * sigma_next

        return img

    # 无梯度计算
    @torch.no_grad()
    def sample(self, batch_size = 16):
        image_size, channels = self.image_size, self.channels
        # 根据是否使用DDIM选择采样函数
        sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size))

    # 前向传播函数
    def forward(self, img, *args, **kwargs):
        batch, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        # 断言图像的高度和宽度必须为img_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        # 生成随机时间
        times = torch.zeros((batch,), device = device).float().uniform_(0, 1.)

        # 生成噪声
        noise = torch.randn_like(img)

        # 获取噪声水平并填充到与img相同的维度
        noise_level = self.log_snr(times)
        padded_noise_level = right_pad_dims_to(img, noise_level)
        alpha, sigma =  log_snr_to_alpha_sigma(padded_noise_level)

        # 添加噪声到图像
        noised_img = alpha * img + sigma * noise

        # 如果进行自条件训练,50%的概率从当前时间预测x_start,并用unet进行条件
        # 这种技术会使训练速度减慢25%,但似乎显著降低FID
        self_cond = None
        if random() < 0.5:
            with torch.no_grad():
                self_cond = self.model(noised_img, noise_level).detach_()

        # 预测并进行梯度下降
        pred = self.model(noised_img, noise_level, self_cond)

        return F.mse_loss(pred, img)
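
# Putting the pieces together (assumed workflow, mirroring lucidrains' other diffusion
# repos): wrap the Unet in Chroma, take the denoising loss on a batch, then sample.
import torch

unet = Unet(dim = 64)
diffusion = Chroma(unet, image_size = 64, timesteps = 1000)

imgs = torch.randn(4, 3, 64, 64)   # stand-in batch, assumed normalized to [-1, 1]
loss = diffusion(imgs)
loss.backward()

sampled = diffusion.sample(batch_size = 4)   # (4, 3, 64, 64)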
# trainer 类
class Trainer(object):
    # 初始化方法
    def __init__(
        self,
        diffusion_model,
        folder,
        *,
        train_batch_size = 16,
        gradient_accumulate_every = 1,
        augment_horizontal_flip = True,
        train_lr = 1e-4,
        train_num_steps = 100000,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_and_sample_every = 1000,
        num_samples = 25,
        results_folder = './results',
        amp = False,
        fp16 = False,
        split_batches = True,
        convert_image_to = None
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 初始化加速器
        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = 'fp16' if fp16 else 'no'
        )

        # 设置是否使用 amp
        self.accelerator.native_amp = amp

        # 设置扩散模型
        self.model = diffusion_model

        # 检查 num_samples 是否有整数平方根
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        # 设置训练批次大小和梯度累积频率
        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        # 设置训练步数和图像大小
        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size

        # 数据集和数据加载器
        self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
        dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

        # 准备数据加载器
        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)

        # 优化器
        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)

        # 定期记录结果到文件夹
        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)

            self.results_folder = Path(results_folder)
            self.results_folder.mkdir(exist_ok = True)

        # 步数计数器状态
        self.step = 0

        # 使用加速器准备模型、数据加载器和优化器
        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    # 保存模型
    def save(self, milestone):
        if not self.accelerator.is_local_main_process:
            return

        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    # 加载模型
    def load(self, milestone):
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        self.ema.load_state_dict(data['ema'])

        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])
    # 定义训练方法
    def train(self):
        # 获取加速器和设备
        accelerator = self.accelerator
        device = accelerator.device

        # 使用 tqdm 显示训练进度条,设置初始值、总步数和是否禁用
        with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:

            # 在未达到总步数前循环
            while self.step < self.train_num_steps:

                # 初始化总损失
                total_loss = 0.

                # 根据梯度累积次数循环
                for _ in range(self.gradient_accumulate_every):
                    # 获取下一个数据批次并发送到设备
                    data = next(self.dl).to(device)

                    # 使用加速器自动混合精度
                    with self.accelerator.autocast():
                        # 计算模型损失
                        loss = self.model(data)
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    # 反向传播
                    self.accelerator.backward(loss)

                # 更新进度条显示损失值
                pbar.set_description(f'loss: {total_loss:.4f}')

                # 等待所有进程完成
                accelerator.wait_for_everyone()

                # 更新优化器参数
                self.opt.step()
                self.opt.zero_grad()

                # 等待所有进程完成
                accelerator.wait_for_everyone()

                # 如果是主进程
                if accelerator.is_main_process:
                    # 将指数移动平均模型发送到设备并更新
                    self.ema.to(device)
                    self.ema.update()

                    # 如果步数不为0且可以保存和采样
                    if self.step != 0 and self.step % self.save_and_sample_every == 0:
                        # 将指数移动平均模型设置为评估模式
                        self.ema.ema_model.eval()

                        # 使用无梯度计算
                        with torch.no_grad():
                            # 计算里程碑和批次数
                            milestone = self.step // self.save_and_sample_every
                            batches = num_to_groups(self.num_samples, self.batch_size)
                            all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

                        # 拼接所有图像并保存
                        all_images = torch.cat(all_images_list, dim=0)
                        utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow=int(math.sqrt(self.num_samples)))
                        self.save(milestone)

                # 更新步数并进度条
                self.step += 1
                pbar.update(1)

        # 打印训练完成信息
        accelerator.print('training complete')
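
# Hypothetical entrypoint for the Trainer above (the folder path and hyperparameters
# are placeholders, not from the source):
trainer = Trainer(
    diffusion,                      # the Chroma wrapper from above
    folder = './path/to/images',    # hypothetical folder of training images
    train_batch_size = 16,
    train_lr = 1e-4,
    train_num_steps = 100000,
    gradient_accumulate_every = 2,
    fp16 = True
)
trainer.train()   # checkpoints and sample grids land in ./results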

.\lucidrains\chroma-pytorch\chroma_pytorch\semantic_conditioner.py

# 导入所需的库
import torch
import os
import logging
from transformers import AutoTokenizer, AutoModelForMaskedLM, logging  # 注意:这里的 logging 来自 transformers,覆盖了上一行导入的标准库 logging
from tf_bind_transformer.cache_utils import cache_fn, run_once

# 设置日志级别为错误
logging.set_verbosity_error()

# 检查值是否存在
def exists(val):
    return val is not None

# 对字典中的值应用函数
def map_values(fn, dictionary):
    return {k: fn(v) for k, v in dictionary.items()}

# 检查是否在环境变量中设置了使用 CPU 进行上下文嵌入
CONTEXT_EMBED_USE_CPU = os.getenv('CONTEXT_EMBED_USE_CPU', None) is not None

# 如果设置了使用 CPU 进行上下文嵌入,则打印提示信息
if CONTEXT_EMBED_USE_CPU:
    print('calculating context embed only on cpu')

# 预定义模型的维度和路径
MODELS = dict(
    pubmed = dict(
        dim = 768,
        path = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
    )
)

# 全局变量,用于存储模型和分词器
GLOBAL_VARIABLES = dict(model = None, tokenizer = None)

# 获取指定模型的上下文维度
def get_contextual_dim(model_name):
    assert model_name in MODELS
    return MODELS[model_name]['dim']

# 初始化模型和分词器,只运行一次
@run_once('init_transformer')
def init_transformer(model_name):
    path = MODELS[model_name]['path']
    GLOBAL_VARIABLES['tokenizer'] = AutoTokenizer.from_pretrained(path)

    model = AutoModelForMaskedLM.from_pretrained(path)

    # 如果未设置使用 CPU 进行上下文嵌入,则将模型移至 GPU
    if not CONTEXT_EMBED_USE_CPU:
        model = model.cuda()

    GLOBAL_VARIABLES['model'] = model

# 对文本进行分词和编码
@torch.no_grad()
def tokenize_text(
    text,
    max_length = 256,
    model_name = 'pubmed',
    hidden_state_index = -1,
    return_cls_token = True
):
    init_transformer(model_name)

    model = GLOBAL_VARIABLES['model']
    tokenizer = GLOBAL_VARIABLES['tokenizer']

    encoding = tokenizer.batch_encode_plus(
        [text],
        add_special_tokens = True,
        padding = True,
        truncation = True,
        max_length = max_length,
        return_attention_mask = True,
        return_tensors = 'pt'
    )

    # 如果未设置使用 CPU 进行上下文嵌入,则将编码移至 GPU
    if not CONTEXT_EMBED_USE_CPU:
        encoding = map_values(lambda t: t.cuda(), encoding)

    model.eval()
    with torch.no_grad():
        outputs = model(**encoding, output_hidden_states = True)

    hidden_state = outputs.hidden_states[hidden_state_index][0]

    if return_cls_token:
        return hidden_state[0]

    return hidden_state.mean(dim = 0)

# 获取文本表示
def get_text_repr(
    texts,
    *,
    device,
    max_length = 256,
    model_name = 'pubmed',
    hidden_state_index = -1,
    return_cls_token = True,
):
    assert model_name in MODELS, f'{model_name} not found in available text transformers to use'

    # 如果输入为字符串,则转换为列表
    if isinstance(texts, str):
        texts = [texts]

    # 缓存文本表示函数
    get_context_repr_fn = cache_fn(tokenize_text, path = f'contexts/{model_name}')

    # 获取文本的表示
    representations = [get_context_repr_fn(text, max_length = max_length, model_name = model_name, hidden_state_index = hidden_state_index, return_cls_token = return_cls_token) for text in texts]

    return torch.stack(representations).to(device)
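
# Hedged usage sketch: embed free-text descriptions into fixed-size vectors (the CLS
# token of PubMedBERT by default); results are cached on disk by `cache_fn`. As written
# above, this needs a GPU unless the CONTEXT_EMBED_USE_CPU env var is set.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

text_repr = get_text_repr(
    ['a protein that binds the coronavirus spike protein'],   # illustrative text
    device = device,
    model_name = 'pubmed'
)
print(text_repr.shape)   # (1, 768) -- one CLS embedding per input text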

.\lucidrains\chroma-pytorch\chroma_pytorch\__init__.py

# 从 chroma_pytorch 包中导入 Chroma 类
from chroma_pytorch.chroma_pytorch import Chroma

[image: figure 1 in paper]

[image: generating a protein that binds to spike protein of coronavirus - Baker lab's concurrent RFDiffusion work]

Chroma - Pytorch (wip)

Implementation of Chroma, generative model of proteins using DDPM and GNNs, in Pytorch. Concurrent work seems to suggest we have a slight lift-off applying denoising diffusion probabilistic models to protein design. Will also incorporate self-conditioning, applied successfully by Baker lab in RFDiffusion.

Explanation by Stephan Heijl

If you are interested in open sourcing works like these out in the wild, please consider joining OpenBioML

Todo

Citations

@misc{
    title   = {Illuminating protein space with a programmable generative model},
    author  = {John Ingraham, Max Baranov, Zak Costello, Vincent Frappier, Ahmed Ismail, Shan Tie, Wujie Wang, Vincent Xue, Fritz Obermeyer, Andrew Beam, Gevorg Grigoryan},    
    year    = {2022},
    url     = {https://cdn.generatebiomedicines.com/assets/ingraham2022.pdf}
}

.\lucidrains\chroma-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'chroma-pytorch',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.0.1',  # 版本号
  license='MIT',  # 许可证
  description = 'Chroma - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  url = 'https://github.com/lucidrains/chroma-pytorch',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'denoising diffusion',
    'protein design'
  ],
  install_requires=[  # 依赖的包列表
    'einops>=0.6',
    'invariant-point-attention',
    'torch>=1.6',
  ],
  classifiers=[  # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\classifier-free-guidance-pytorch\classifier_free_guidance_pytorch\attend.py

# 导入必要的库
from collections import namedtuple
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat

# 定义一个命名元组EfficientAttentionConfig,包含三个布尔类型的参数
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# 定义辅助函数exists,用于检查值是否存在
def exists(val):
    return val is not None

# 定义装饰器once,确保函数只被调用一次
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# 用once装饰print函数,确保只打印一次
print_once = once(print)

# 主要类Attend
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # 确定cuda和cpu的高效注意力配置
        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    # 获取掩码
    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    # Flash Attention函数
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # 推荐使用多查询单键值注意力的Tri Dao
        if k.ndim == 3:
            k = repeat(k, 'b ... -> b h ...', h = heads)

        if v.ndim == 3:
            v = repeat(v, 'b ... -> b h ...', h = heads)

        # 检查掩码是否存在并扩展到兼容的形状
        if exists(mask):
            if mask.ndim == 2:
                mask = rearrange(mask, 'b j -> b 1 1 j')

            mask = mask.expand(-1, heads, q_len, -1)

        # 检查是否有兼容的设备用于Flash Attention
        config = self.cuda_config if is_cuda else self.cpu_config

        # 使用torch.backends.cuda.sdp_kernel(**config._asdict())来调用pytorch 2.0的flash attention
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0., 
                is_causal = self.causal
            )

        return out
    # 定义一个前向传播函数,接受查询(q), 键(k), 值(v)和掩码(mask)作为输入参数
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # 获取序列长度n和设备信息
        n, device = q.shape[-2], q.device

        # 缩放因子,根据特征维度的倒数开根号
        scale = q.shape[-1] ** -0.5

        # 如果启用了flash注意力机制,则调用flash_attn函数
        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # 根据键的维度确定键值对的einsum等式
        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        # 计算相似度
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # 键的填充掩码
        if exists(mask):
            if mask.ndim == 2:
                mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # 因果掩码
        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # 注意力权重
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # 聚合值
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out
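
# Standalone check of Attend (shapes are illustrative): feed (batch, heads, seq, dim_head)
# tensors; with flash = True and PyTorch >= 2.0 it dispatches to F.scaled_dot_product_attention,
# otherwise it falls back to the explicit einsum path above.
import torch

attend = Attend(dropout = 0.1, causal = True, flash = False)
q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)
out = attend(q, k, v)   # (2, 8, 1024, 64)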

.\lucidrains\classifier-free-guidance-pytorch\classifier_free_guidance_pytorch\bge.py

# 导入所需的模块和函数
from typing import List
from beartype import beartype

import torch
import transformers 
from transformers import AutoTokenizer, AutoModel, AutoConfig
transformers.logging.set_verbosity_error()

# 创建 BGEAdapter 类
class BGEAdapter():
    def __init__(
        self,
        name
    ):
        # 固定使用 BAAI/bge-base-en-v1.5,忽略传入的 name 参数
        name = 'BAAI/bge-base-en-v1.5'
        # 根据模型名称加载对应的 tokenizer
        tokenizer = AutoTokenizer.from_pretrained(name)
        # 根据模型名称加载对应的 model
        model = AutoModel.from_pretrained(name)
        # 根据模型名称加载对应的配置
        self.Config = AutoConfig.from_pretrained(name)
        
        # 如果有可用的 CUDA 设备,则将模型移动到 CUDA 上
        if torch.cuda.is_available():
            model = model.to("cuda")  
            
        # 设置对象的名称、模型和 tokenizer
        self.name =  name
        self.model = model
        self.tokenizer = tokenizer

    # 定义 dim_latent 属性,返回隐藏层的大小
    @property
    def dim_latent(self):
        return self.Config.hidden_size

    # 定义 max_text_len 属性,返回文本的最大长度
    @property
    def max_text_len(self):
        return 512

    # 定义 embed_text 方法,用于文本嵌入
    @torch.no_grad()
    @beartype
    def embed_text(
        self,
        texts: List[str],
        return_text_encodings = False,
        output_device = None
    ):
        # 使用 tokenizer 对文本进行编码
        encoded_input  = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to("cuda")
 
        # 将模型设置为评估模式
        self.model.eval()
         
        # 使用模型对编码后的输入进行推理
        with torch.no_grad():
            model_output = self.model(**encoded_input)  
            
        # 如果不需要返回文本编码,则返回规范化后的 CLS 嵌入
        if not return_text_encodings: 
            sentence_embeddings = model_output[0][:, 0]
            sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
            return sentence_embeddings  # 返回规范化后的 CLS 嵌入

        # 如果需要返回文本编码,则返回最后一个隐藏状态,并根据输出设备进行转换
        return model_output.last_hidden_state.to(output_device)
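
# Hedged usage sketch: BGEAdapter exposes the same small interface as the other adapters
# (dim_latent, max_text_len, embed_text), so it can be swapped in as a conditioner backend.
# As written above it assumes a CUDA device for the tokenized batch.
bge = BGEAdapter(name = 'bge')                           # the passed name is overridden internally
sentence_embeds = bge.embed_text(['a photo of a dog'])   # (1, 768) normalized CLS embeddings
encodings = bge.embed_text(['a photo of a dog'], return_text_encodings = True, output_device = 'cpu')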

.\lucidrains\classifier-free-guidance-pytorch\classifier_free_guidance_pytorch\classifier_free_guidance_pytorch.py

# 导入必要的模块
from collections import namedtuple
from functools import wraps, partial, cache

import torch
import torch.nn.functional as F
from torch import nn, einsum, Tensor

from einops import rearrange, repeat, pack, unpack

from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Callable, Tuple, Optional, List, Literal, Union, Dict, Any

from inspect import signature

from classifier_free_guidance_pytorch.t5 import T5Adapter
from classifier_free_guidance_pytorch.open_clip import OpenClipAdapter
from classifier_free_guidance_pytorch.attend import Attend
from classifier_free_guidance_pytorch.bge import BGEAdapter

# 常量定义

COND_DROP_KEY_NAME = 'cond_drop_prob'

TEXTS_KEY_NAME = 'texts'
TEXT_EMBEDS_KEY_NAME = 'text_embeds'
TEXT_CONDITIONER_NAME = 'text_conditioner'
CONDITION_FUNCTION_KEY_NAME = 'cond_fns'

# 定义命名元组
TextCondReturn = namedtuple('TextCondReturn', [
    'embed',
    'mask'
])

# 辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 判断列表是否为空
def is_empty(l):
    return len(l) == 0

# 返回第一个存在的值
def default(*values):
    for value in values:
        if exists(value):
            return value
    return None

# 将值转换为元组
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# 将单个值打包成元组
def pack_one(x, pattern):
    return pack([x], pattern)

# 从元组中解包单个值
def unpack_one(x, ps, pattern):
    return unpack(x, ps, pattern)[0]

# 张量辅助函数

# 根据概率生成掩码张量
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
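
# Illustration (not from the source): prob_mask_like is the Bernoulli keep-mask behind
# condition dropout -- each row keeps its text condition with probability 1 - cond_drop_prob.
import torch

keep_mask = prob_mask_like((4, 1), 0.75, device = torch.device('cpu'))
# e.g. tensor([[True], [True], [False], [True]]) -- dropped rows get the null condition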

# 使用自动文本条件的分类器自由引导

# 装饰器函数,用于处理函数的参数和自动文本条件
@beartype
def classifier_free_guidance(
    fn: Callable,
    cond_drop_prob_keyname = COND_DROP_KEY_NAME,
    texts_key_name = TEXTS_KEY_NAME,
    text_embeds_key_name = TEXT_EMBEDS_KEY_NAME,
    cond_fns_keyname = CONDITION_FUNCTION_KEY_NAME,
    text_conditioner_name = TEXT_CONDITIONER_NAME
):
    # 获取函数的参数信息
    fn_params = signature(fn).parameters

    # 判断是否需要自动处理文本条件
    auto_handle_text_condition = texts_key_name not in fn_params and text_embeds_key_name not in fn_params

    # 内部函数,用于实际执行分类器自由引导
    @wraps(fn)
    def inner(
        self,
        *args,
        cond_scale: float = 1.,
        rescale_phi: float = 0.,
        cfg_routed_kwargs: Dict[str, Tuple[Any, Any]] = dict(),   # 用于传递参数到前向和无效前向调用的字典(用于处理在使用 CFG 进行变换解码时的缓存)
        **kwargs
    ):
        # 定义一个内部函数,用于包装原始函数并自动处理文本条件
        @wraps(fn)
        def fn_maybe_with_text(self, *args, **kwargs):
            # 在可能包含文本的情况下,对原始函数进行包装
            if auto_handle_text_condition:
                # 如果自动处理文本条件为真
                texts = kwargs.pop('texts', None)
                text_embeds = kwargs.pop('text_embeds', None)

                assert not (exists(texts) and exists(text_embeds))
                # 断言不存在同时有texts和text_embeds

                raw_text_cond = cond_fns = None

                text_conditioner = getattr(self, text_conditioner_name, None)
                # 获取文本条件器对象

                cond_drop_prob = kwargs.pop(cond_drop_prob_keyname, None)

                assert not exists(cond_drop_prob) or 0. <= cond_drop_prob <= 1.
                # 断言不存在cond_drop_prob或者其值在0到1之间

                # 自动将文本转换为条件函数
                if exists(texts) ^ exists(text_embeds):

                    assert is_bearable(texts, Optional[List[str]]), f'keyword `{texts_key_name}` must be a list of strings'
                    # 断言texts是可接受的类型,必须是字符串列表

                    assert exists(text_conditioner) and is_bearable(text_conditioner, Conditioner), 'text_conditioner must be set on your network with the correct hidden dimensions to be conditioned on'
                    # 断言存在text_conditioner并且其类型是Conditioner

                    text_condition_input = dict(texts = texts) if exists(texts) else dict(text_embeds = text_embeds)

                    cond_fns, raw_text_cond = text_conditioner(**text_condition_input, cond_drop_prob = cond_drop_prob)
                    # 调用文本条件器生成条件函数和原始文本条件

                elif isinstance(text_conditioner, NullConditioner):
                    assert cond_drop_prob == 0., 'null conditioner has nothing to dropout'
                    # 断言cond_drop_prob为0,空条件器没有需要丢弃的内容

                    cond_fns, raw_text_cond = text_conditioner()
                    # 调用空条件器

                if 'cond_fns' in fn_params:
                    kwargs.update(cond_fns = cond_fns)

                if 'raw_text_cond' in fn_params:
                    kwargs.update(raw_text_cond = raw_text_cond)

            return fn(self, *args, **kwargs)
            # 返回原始函数的结果

        # 主分类器自由引导逻辑

        if self.training:
            assert cond_scale == 1, 'you cannot do condition scaling when in training mode'
            # 断言在训练模式下不能进行条件缩放

            return fn_maybe_with_text(self, *args, **kwargs)
            # 返回可能包含文本的函数结果

        assert cond_scale >= 1, 'invalid conditioning scale, must be greater or equal to 1'
        # 断言条件缩放必须大于等于1

        kwargs_without_cond_dropout = {**kwargs, cond_drop_prob_keyname: 0.}
        kwargs_with_cond_dropout = {**kwargs, cond_drop_prob_keyname: 1.}
        # 创建不带条件丢弃和带条件丢弃的参数字典

        # 处理要路由到前向和空前向的参数,以便处理两次调用的缓存
        fn_kwargs = {k: v[0] for k, v in cfg_routed_kwargs.items()}
        null_fn_kwargs = {k: v[1] for k, v in cfg_routed_kwargs.items()}
        # 创建非空前向和空前向的参数字典

        # 非空前向
        outputs = fn_maybe_with_text(self, *args, **fn_kwargs, **kwargs_without_cond_dropout)
        # 调用可能包含文本的函数

        if cond_scale == 1:
            return outputs
            # 如果条件缩放为1,则直接返回结果

        logits, *rest = cast_tuple(outputs)
        # 将输出结果拆分为logits和其余部分

        # 空前向
        null_outputs = fn_maybe_with_text(self, *args, **null_fn_kwargs, **kwargs_with_cond_dropout)
        # 调用可能包含文本的函数

        null_logits, *null_rest = cast_tuple(null_outputs)
        # 将空前向的输出结果拆分为null_logits和其余部分

        zipped_rest = tuple(zip(rest, null_rest))
        # 将非空前向和空前向的其余部分进行压缩

        scaled_logits = null_logits + (logits - null_logits) * cond_scale
        # 计算缩放后的logits

        if rescale_phi <= 0:
            logit_output = scaled_logits
        else:
            # 提议的方法,用于防止分类器自由引导过度饱和
            # 与imagen的解决方案不同,适用于像素空间和潜在空间

            dims = tuple(range(1, logits.ndim - 1))
            rescaled_logits = scaled_logits * (logits.std(dim = dims, keepdim = True) / scaled_logits.std(dim = dims, keepdim= True))
            logit_output = rescaled_logits * rescale_phi + scaled_logits * (1. - rescale_phi)
            # 计算最终输出logits

        if is_empty(zipped_rest):
            return logit_output
            # 如果压缩后的结果为空,则直接返回logit_output

        return (logit_output, *zipped_rest)
        # 返回最终结果
    return inner
# class decorator

# 装饰器函数,用于添加分类器自由引导的类装饰器
@beartype
def classifier_free_guidance_class_decorator(
    orig_class,
    cond_drop_prob_keyname = COND_DROP_KEY_NAME,
    texts_key_name = TEXTS_KEY_NAME,
    text_embeds_key_name = TEXT_EMBEDS_KEY_NAME,
    cond_fns_keyname = CONDITION_FUNCTION_KEY_NAME,
    text_conditioner_name = TEXT_CONDITIONER_NAME
):
    assert issubclass(orig_class, nn.Module)

    # decorate init

    # 保存原始类的初始化方法
    orig_init = orig_class.__init__

    # 装饰原始类的初始化方法
    @wraps(orig_init)
    @beartype
    def __init__(
        self,
        *args,
        text_condition_type: Union[
            Literal['film'],
            Literal['attention'],
            Literal['null'],
            Literal['raw'],
        ] = 'film',
        text_condition_model_types: Tuple[str, ...] = ('t5',),
        text_condition_hidden_dims: Tuple[int, ...],
        text_condition_cond_drop_prob: float,
        **kwargs
    ):
        # 调用原始类的初始化方法
        orig_init(self, *args, **kwargs)

        # 根据文本条件类型选择相应的条件器类
        if text_condition_type == 'film':
            condition_klass = TextConditioner
        elif text_condition_type == 'attention':
            condition_klass = AttentionTextConditioner
        elif text_condition_type == 'raw':
            condition_klass = TextEmbeddingReturner
        else:
            condition_klass = NullConditioner

        # 初始化文本条件器
        self.text_conditioner = condition_klass(
            model_types = text_condition_model_types,
            hidden_dims = text_condition_hidden_dims,
            cond_drop_prob = text_condition_cond_drop_prob
        )

    orig_class.__init__ = __init__

    # decorate forward

    # 装饰原始类的前向传播方法
    decorated_forward = classifier_free_guidance(
        orig_class.forward,
        cond_drop_prob_keyname = cond_drop_prob_keyname,
        texts_key_name = texts_key_name,
        text_embeds_key_name = text_embeds_key_name,
        cond_fns_keyname = cond_fns_keyname,
        text_conditioner_name = text_conditioner_name
    )

    orig_class.forward = decorated_forward

    # forward `embed_texts` to the `text_conditioner.embed_texts`

    # 定义嵌入文本的方法,将其转发到文本条件器的嵌入文本方法
    @beartype
    def embed_texts(self, texts: List[str]):
        return self.text_conditioner.embed_texts(texts)

    # 定义属性,缓存最大条件文本长度
    @property
    @cache
    def max_cond_text_len(self):
        total_cond_text_len = sum([text_model.max_text_len for text_model in self.text_conditioner.text_models])
        return total_cond_text_len

    # 如果原始类没有最大条件文本长度属性,则添加
    if not hasattr(orig_class, 'max_cond_text_len'):
        orig_class.max_cond_text_len = max_cond_text_len

    # 如果原始类没有嵌入文本方法,则添加
    if not hasattr(orig_class, 'embed_texts'):
        orig_class.embed_texts = embed_texts

    # 标记类已被装饰
    orig_class.__decorated_with_cfg = True
    return orig_class
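
# Usage sketch for the class decorator (adapted from the repo's README; exact values are
# illustrative): decorate any nn.Module whose forward accepts `cond_fns`, construct it with
# the extra text_condition_* kwargs injected above, then train with raw texts and sample
# with cond_scale > 1 for classifier-free guidance.
import torch
from torch import nn

@classifier_free_guidance_class_decorator
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj_in  = nn.Sequential(nn.Linear(dim, dim * 2), nn.ReLU())
        self.proj_mid = nn.Sequential(nn.Linear(dim * 2, dim), nn.ReLU())
        self.proj_out = nn.Linear(dim, 1)

    def forward(self, x, cond_fns):
        cond_fn1, cond_fn2 = cond_fns   # one FiLM conditioner per hidden dim declared below
        x = self.proj_in(x)
        x = cond_fn1(x)
        x = self.proj_mid(x)
        x = cond_fn2(x)
        return self.proj_out(x)

mlp = MLP(
    dim = 256,
    text_condition_type = 'film',
    text_condition_model_types = ('t5',),
    text_condition_hidden_dims = (512, 256),
    text_condition_cond_drop_prob = 0.25
)

x = torch.randn(2, 256)
loss = mlp(x, texts = ['a dog', 'a cat']).mean()   # training: conditions dropped at 25%
loss.backward()

mlp.eval()
guided = mlp(x, texts = ['a dog', 'a cat'], cond_scale = 3.)   # CFG at inference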

# attention

# 定义注意力模块类
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dim_context = None,
        norm_context = False,
        num_null_kv = 0,
        flash = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        dim_context = default(dim_context, dim)

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()

        self.attend = Attend(flash = flash)        

        self.num_null_kv = num_null_kv
        self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head))

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context = None,
        mask = None
        ):
        # 获取输入张量 x 的第一个维度大小
        b = x.shape[0]

        # 如果上下文存在,则对上下文进行归一化处理
        if exists(context):
            context = self.context_norm(context)

        # 如果上下文不存在,则使用默认的 x 作为上下文输入
        kv_input = default(context, x)

        # 对输入张量 x 进行归一化处理
        x = self.norm(x)

        # 将输入张量 x 分别转换为查询 q,键 k,值 v
        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

        # 如果存在空键值对数量大于 0
        if self.num_null_kv > 0:
            # 重复空键值对,使其与输入张量 x 的第一个维度大小相匹配
            null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0)
            # 将空键值对与原始键 k 和值 v 进行拼接
            k = torch.cat((null_k, k), dim = -2)
            v = torch.cat((null_v, v), dim = -2)

        # 如果存在掩码 mask
        if exists(mask):
            # 在掩码 mask 上添加指定数量的填充值
            mask = F.pad(mask, (self.num_null_kv, 0), value = True)
            # 重新排列掩码 mask 的维度
            mask = rearrange(mask, 'b j -> b 1 1 j')

        # 重新排列查询 q 的维度
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # 进行注意力计算
        out = self.attend(q, k, v, mask = mask)

        # 重新排列输出 out 的维度
        out = rearrange(out, 'b h n d -> b n (h d)')
        # 返回最终输出
        return self.to_out(out)
# dimension adapters

# 重新排列通道为最后一个维度的函数装饰器
def rearrange_channel_last(fn):
    @wraps(fn)
    def inner(hiddens):
        hiddens, ps = pack_one(hiddens, 'b * d')
        conditioned = fn(hiddens)
        return unpack_one(conditioned, ps, 'b * d')
    return inner

# 重新排列通道为第一个维度的函数装饰器
def rearrange_channel_first(fn):
    """ will adapt shape of (batch, feature, ...) for conditioning """

    @wraps(fn)
    def inner(hiddens):
        hiddens, ps = pack_one(hiddens, 'b d *')
        hiddens = rearrange(hiddens, 'b d n -> b n d')
        conditioned =  fn(hiddens)
        conditioned = rearrange(conditioned, 'b n d -> b d n')
        return unpack_one(conditioned, ps, 'b d *')

    return inner

# conditioning modules

# FiLM 模块
class FiLM(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim * 2)
        )

        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, conditions, hiddens):
        scale, shift = self.net(conditions).chunk(2, dim = -1)
        assert scale.shape[-1] == hiddens.shape[-1], f'unexpected hidden dimesion {hiddens.shape[-1]} used for conditioning'
        scale, shift = map(lambda t: rearrange(t, 'b d -> b 1 d'), (scale, shift))
        return hiddens * (scale + 1) + shift
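
# Quick check (not from the source): because the last linear layer is zero-initialized,
# FiLM starts as the identity -- scale = shift = 0 gives hiddens * (0 + 1) + 0 = hiddens,
# so conditioning is learned gradually from a no-op.
import torch

film = FiLM(dim = 512, hidden_dim = 256)
cond = torch.randn(2, 512)           # pooled text condition
hiddens = torch.randn(2, 10, 256)    # (batch, seq, hidden)
assert torch.allclose(film(cond, hiddens), hiddens)   # holds at initialization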

# 交叉注意力模块
class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        heads = 8,
        dim_head = 64,
        flash = False
    ):
        super().__init__()
        self.attn = Attention(
            dim = hidden_dim,
            dim_context = dim,
            norm_context = True,
            num_null_kv = 1,
            dim_head = dim_head,
            heads = heads,
            flash = flash
        )

    def forward(
        self,
        condition,
        hiddens,
        mask = None
    ):
        return self.attn(hiddens, condition, mask = mask) + hiddens

# film text conditioning

# 条件配置字典
CONDITION_CONFIG = dict(
    t5 = T5Adapter,
    clip = OpenClipAdapter,
    bge = BGEAdapter
)

# 模型类型列表
MODEL_TYPES = CONDITION_CONFIG.keys()

# 条件器基类
class Conditioner(nn.Module):
    pass

# 空条件器
class Identity(nn.Module):
    def forward(self, t, *args, **kwargs):
        return t

# 空条件器类,继承自 Conditioner
@beartype
class NullConditioner(Conditioner):
    def __init__(
        self,
        *,
        hidden_dims: Tuple[int, ...],
        **kwargs
    ):
        super().__init__()
        num_null_conditioners = len(hidden_dims)
        self.cond_fns = tuple(Identity() for _ in range(num_null_conditioners))

        self.register_buffer('_device_param', torch.tensor(0), persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

    def embed_texts(self, texts: List[str]):
        assert False, 'null conditioner cannot embed text'

    def forward(self, *args, **kwargs):
        return self.cond_fns, None

# 带有 FiLM 的文本条件器
@beartype
class TextConditioner(Conditioner):
    def __init__(
        self,
        *,
        hidden_dims: Tuple[int, ...],
        model_types = 't5',
        model_names = None,
        cond_drop_prob = 0.,
        hiddens_channel_first = True,
        text_embed_stem_dim_mult = 2
    ):
        # 调用父类的构造函数
        super().__init__()
        # 将 model_types 转换为元组
        model_types = cast_tuple(model_types)
        # 将 model_names 转换为元组,并确保其长度与 model_types 相同
        model_names = cast_tuple(model_names, length = len(model_types))

        # 断言 model_types 和 model_names 的长度相同
        assert len(model_types) == len(model_names)
        # 断言 model_types 中的每个元素都在 MODEL_TYPES 中
        assert all([model_type in MODEL_TYPES for model_type in model_types])

        # 初始化一个空列表 text_models
        text_models = []

        # 遍历 model_types 和 model_names,根据 model_type 创建对应的模型,并添加到 text_models 中
        for model_type, model_name in zip(model_types, model_names):
            klass = CONDITION_CONFIG.get(model_type)
            model = klass(model_name)
            text_models.append(model)

        # 将 text_models 赋值给 self.text_models
        self.text_models = text_models
        # 获取每个模型的潜在维度,存储在 latent_dims 中
        self.latent_dims = [model.dim_latent for model in text_models]

        # 初始化一个空的 nn.ModuleList,用于存储条件器
        self.conditioners = nn.ModuleList([])

        # 将 hidden_dims、num_condition_fns、hiddens_channel_first、cond_drop_prob 等属性赋值
        self.hidden_dims = hidden_dims
        self.num_condition_fns = len(hidden_dims)
        self.hiddens_channel_first = cast_tuple(hiddens_channel_first, self.num_condition_fns) # 是否将待条件化的隐藏层放在通道维度的第一位

        # 断言 hiddens_channel_first 的长度与 num_condition_fns 相同
        assert len(self.hiddens_channel_first) == self.num_condition_fns

        # 将 cond_drop_prob 赋值给 self.cond_drop_prob
        self.cond_drop_prob = cond_drop_prob

        # 计算总的潜在维度
        total_latent_dim = sum(self.latent_dims)

        # 计算 MLP 的输出维度
        mlp_stem_output_dim = total_latent_dim * text_embed_stem_dim_mult

        # 定义文本嵌入的 MLP 结构
        self.text_embed_stem_mlp = nn.Sequential(
            nn.Linear(total_latent_dim, mlp_stem_output_dim),
            nn.SiLU()
        )

        # 根据 hidden_dims 创建条件器,并添加到 self.conditioners 中
        for hidden_dim in hidden_dims:
            self.conditioners.append(FiLM(mlp_stem_output_dim, hidden_dim))

        # 初始化一个随机参数 null_text_embed
        self.null_text_embed = nn.Parameter(torch.randn(total_latent_dim))

        # 注册一个缓冲区 _device_param
        self.register_buffer('_device_param', torch.tensor(0.), persistent = False)

    @property
    def device(self):
        # 返回第一个缓冲区的设备
        return next(self.buffers()).device

    def embed_texts(self, texts: List[str]):
        # 获取设备信息
        device = self.device

        # 初始化一个空列表 text_embeds,用于存储文本嵌入结果
        text_embeds = []
        # 遍历每个文本模型,将文本嵌入结果添加到 text_embeds 中
        for text_model in self.text_models:
            text_embed = text_model.embed_text(texts)
            text_embeds.append(text_embed.to(device))

        # 沿着最后一个维度拼接文本嵌入结果
        return torch.cat(text_embeds, dim = -1)

    def forward(
        self,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        cond_drop_prob = None,
        repeat_batch = 1,               # 用于机器人变压器边缘情况
    ) -> Tuple[
        Tuple[Callable, ...],
        TextCondReturn
    ]:

        # 断言 texts 和 text_embeds 只有一个存在
        assert exists(texts) ^ exists(text_embeds)

        # 如果处于训练状态,则使用默认的 cond_drop_prob,否则需要显式设置
        if self.training:
            cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
        else:
            assert exists(cond_drop_prob), '当不处于训练状态时,必须显式设置 cond_drop_prob'

        # 根据 texts 或 text_embeds 的存在情况确定 batch 大小
        if exists(texts):
            batch = len(texts)
        elif exists(text_embeds):
            batch = text_embeds.shape[0]

        # 如果 text_embeds 不存在,则调用 embed_texts 方法生成
        if not exists(text_embeds):
            text_embeds = self.embed_texts(texts)

        # 如果 cond_drop_prob 大于 0,则生成一个掩码,用于对文本嵌入进行条件化
        if cond_drop_prob > 0.:
            prob_keep_mask = prob_mask_like((batch, 1), 1. - cond_drop_prob, device = self.device)
            null_text_embeds = rearrange(self.null_text_embed, 'd -> 1 d')

            text_embeds = torch.where(
                prob_keep_mask,
                text_embeds,
                null_text_embeds
            )

        # 对文本嵌入进行 MLP 处理
        text_embeds = self.text_embed_stem_mlp(text_embeds)

        # 准备条件函数
        repeat_batch = cast_tuple(repeat_batch, self.num_condition_fns)

        cond_fns = []

        # 遍历条件器,生成条件函数
        for cond, cond_hiddens_channel_first, cond_repeat_batch in zip(self.conditioners, self.hiddens_channel_first, repeat_batch):
            cond_text_embeds = repeat(text_embeds, 'b ... -> (b r) ...', r = cond_repeat_batch)
            cond_fn = partial(cond, cond_text_embeds)

            wrapper_fn = rearrange_channel_first if cond_hiddens_channel_first else rearrange_channel_last

            cond_fns.append(wrapper_fn(cond_fn))

        # 返回条件函数和文本条件返回值
        return tuple(cond_fns), TextCondReturn(text_embeds, None)
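
# Direct use of TextConditioner (adapted from the repo's README; hedged sketch): declare
# one hidden dim per FiLM function you want back, call with raw texts, and apply the
# returned cond_fns to your intermediate features.
import torch

text_conditioner = TextConditioner(
    model_types = 't5',
    hidden_dims = (256,),
    hiddens_channel_first = False,   # features below are (batch, seq, dim)
    cond_drop_prob = 0.2
)

(first_cond_fn,), _ = text_conditioner(texts = ['a dog chasing a ball'])

features = torch.randn(1, 16, 256)
conditioned = first_cond_fn(features)   # FiLM-modulated, same shape as features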
# 定义一个名为 AttentionTextConditioner 的类,继承自 Conditioner 类
@beartype
class AttentionTextConditioner(Conditioner):
    # 初始化函数,接受一系列参数
    def __init__(
        self,
        *,
        hidden_dims: Tuple[int, ...],  # 隐藏层维度的元组
        model_types = 't5',  # 模型类型,默认为 't5'
        model_names = None,  # 模型名称,默认为 None
        cond_drop_prob = 0.,  # 条件丢弃概率,默认为 0
        hiddens_channel_first = True,  # 是否隐藏层优先,默认为 True
        dim_latent = None,  # 潜在维度,默认为 None
        attn_dim_head = 64,  # 注意力头维度,默认为 64
        attn_heads = 8,  # 注意力头数,默认为 8
        flash = True  # 是否闪烁,默认为 True
    ):
        super().__init__()  # 调用父类的初始化函数
        model_types = cast_tuple(model_types)  # 将模型类型转换为元组
        model_names = cast_tuple(model_names, length = len(model_types))  # 将模型名称转换为元组,长度与模型类型相同

        assert len(model_types) == len(model_names)  # 断言模型类型和模型名称长度相同
        assert all([model_type in MODEL_TYPES for model_type in model_types])  # 断言所有模型类型在 MODEL_TYPES 中

        text_models = []  # 初始化文本模型列表

        # 遍历模型类型和模型名称,创建文本模型并添加到列表中
        for model_type, model_name in zip(model_types, model_names):
            klass = CONDITION_CONFIG.get(model_type)
            model = klass(model_name)
            text_models.append(model)

        self.text_models = text_models  # 将文本模型列表赋值给类属性

        self.to_latent_dims = nn.ModuleList([])  # 初始化线性层列表

        dim_latent = default(dim_latent, max([model.dim_latent for model in text_models]))  # 计算潜在维度

        self.dim_latent = dim_latent  # 将潜在维度赋值给类属性

        # 遍历文本模型,为每个模型添加线性层
        for model in text_models:
            self.to_latent_dims.append(nn.Linear(model.dim_latent, dim_latent))

        self.conditioners = nn.ModuleList([])  # 初始化条件器列表

        self.hidden_dims = hidden_dims  # 隐藏层维度赋值给类属性
        self.num_condition_fns = len(hidden_dims)  # 隐藏层维度数量赋值给类属性
        self.hiddens_channel_first = cast_tuple(hiddens_channel_first, self.num_condition_fns)  # 是否隐藏层优先赋值给类属性

        assert len(self.hiddens_channel_first) == self.num_condition_fns  # 断言隐藏层优先长度与隐藏层维度数量相同

        self.cond_drop_prob = cond_drop_prob  # 条件丢弃概率赋值给类属性

        # 遍历隐藏层维度,为每个维度添加交叉注意力模块
        for hidden_dim in hidden_dims:
            self.conditioners.append(CrossAttention(dim_latent, hidden_dim, flash = flash))

        self.register_buffer('_device_param', torch.tensor(0), persistent = False)  # 注册缓冲区

    @property
    def device(self):
        return next(self.buffers()).device  # 返回设备信息

    # 嵌入文本函数,接受文本列表,返回文本嵌入向量
    def embed_texts(self, texts: List[str]):
        device = self.device  # 获取设备信息

        text_embeds = []  # 初始化文本嵌入列表

        # 遍历文本模型和线性层,为每个文本嵌入向量添加嵌入
        for text_model, to_latent in zip(self.text_models, self.to_latent_dims):
            text_embed = text_model.embed_text(texts, return_text_encodings = True)  # 嵌入文本并返回文本编码

            text_embed = text_embed.to(device)  # 将文本嵌入向量移动到设备

            mask = (text_embed != 0).any(dim = -1)  # 创建掩码

            text_embed = to_latent(text_embed)  # 使用线性层转换文本嵌入向量
            text_embed = text_embed.masked_fill(~mask[..., None], 0.)  # 根据掩码填充文本嵌入向量

            text_embeds.append(text_embed)  # 将处理后的文本嵌入向量添加到列表中

        return torch.cat(text_embeds, dim = -2)  # 沿指定维度连接文本嵌入向量

    # 前向传播函数,接受文本列表、文本嵌入向量等参数,返回元组
    def forward(
        self,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        cond_drop_prob = None,
        repeat_batch = 1,  # 用于机器人变压器边缘情况
    ) -> Tuple[
        Tuple[Callable, ...],
        TextCondReturn
    ]:
        # 检查是否存在文本或文本嵌入
        assert exists(texts) or exists(text_embeds)

        # 如果存在文本嵌入和文本,则文本嵌入优先
        if exists(text_embeds) and exists(texts):
            texts = None

        # 如果处于训练状态,则使用默认的条件丢弃概率,否则需要显式设置条件丢弃概率
        if self.training:
            cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
        else:
            assert exists(cond_drop_prob), 'when not training, cond_drop_prob must be explicitly set'

        # 根据文本或文本嵌入的存在情况确定批次大小
        if exists(texts):
            batch = len(texts)
        elif exists(text_embeds):
            batch = text_embeds.shape[0]

        # 如果不存在文本嵌入,则使用模型的 embed_texts 方法生成文本嵌入
        if not exists(text_embeds):
            text_embeds = self.embed_texts(texts)

        # 创建一个掩码,标记非零元素的位置
        mask = (text_embeds != 0).any(dim=-1)

        # 如果条件丢弃概率大于0,则生成一个概率保留掩码
        if cond_drop_prob > 0.:
            prob_keep_mask = prob_mask_like((batch, 1), 1. - cond_drop_prob, device=self.device)
            mask = mask & prob_keep_mask

        # 准备条件函数
        repeat_batch = cast_tuple(repeat_batch, self.num_condition_fns)
        cond_fns = []

        # 遍历条件器,生成条件函数列表
        for cond, cond_hiddens_channel_first, cond_repeat_batch in zip(self.conditioners, self.hiddens_channel_first, repeat_batch):
            cond_text_embeds = repeat(text_embeds, 'b ... -> (b r) ...', r=cond_repeat_batch)
            cond_mask = repeat(mask, 'b ... -> (b r) ...', r=cond_repeat_batch) if exists(mask) else None

            cond_fn = partial(cond, cond_text_embeds, mask=cond_mask)

            wrapper_fn = rearrange_channel_first if cond_hiddens_channel_first else rearrange_channel_last

            cond_fns.append(wrapper_fn(cond_fn))

        # 返回条件函数列表和文本条件返回对象
        return tuple(cond_fns), TextCondReturn(text_embeds, mask)
# 返回原始文本嵌入

# 定义一个文本嵌入返回器类,继承自 Conditioner 类
@beartype
class TextEmbeddingReturner(Conditioner):
    def __init__(
        self,
        *,
        dim_latent = None,                       # shared latent dimension, defaults to the largest model latent dim
        hidden_dims: Tuple[int, ...] = tuple(),  # hidden dimensions to condition, one identity conditioner each
        model_types = 't5',
        model_names = None,
        cond_drop_prob = 0.,
    ):
        super().__init__()
        model_types = cast_tuple(model_types)
        model_names = cast_tuple(model_names, length = len(model_types))

        assert len(model_types) == len(model_names)
        assert all([model_type in MODEL_TYPES for model_type in model_types])

        text_models = []

        # instantiate one text model per (type, name) pair
        for model_type, model_name in zip(model_types, model_names):
            klass = CONDITION_CONFIG.get(model_type)
            model = klass(model_name)
            text_models.append(model)

        self.text_models = text_models

        self.to_latent_dims = nn.ModuleList([])

        dim_latent = default(dim_latent, max([model.dim_latent for model in text_models]))  # default to the largest model latent dimension

        self.dim_latent = dim_latent

        # one linear projection per text model into the shared latent dimension
        for model in text_models:
            self.to_latent_dims.append(nn.Linear(model.dim_latent, dim_latent))

        self.conditioners = nn.ModuleList([])

        self.cond_drop_prob = cond_drop_prob

        # the "conditioners" here are identity maps - this class only returns the embeddings
        for hidden_dim in hidden_dims:
            self.conditioners.append(nn.Identity())

        self.register_buffer('_device_param', torch.tensor(0), persistent = False)  # dummy buffer used only to track the module's device

    @property
    def device(self):
        return next(self.buffers()).device  # device of the registered buffer

    # embed a list of texts with every text model and concatenate the latents
    def embed_texts(self, texts: List[str]):
        device = self.device

        text_embeds = []

        # embed with each text model, project into the shared latent dimension, and re-mask the padding
        for text_model, to_latent in zip(self.text_models, self.to_latent_dims):
            text_embed = text_model.embed_text(texts, return_text_encodings = True)  # (batch, seq, model dim) token encodings

            text_embed = text_embed.to(device)

            mask = (text_embed != 0).any(dim = -1)  # padded positions were zero-filled by the adapter

            text_embed = to_latent(text_embed)  # project into the shared latent dimension
            text_embed = text_embed.masked_fill(~mask[..., None], 0.)  # re-zero the padded positions after projection

            text_embeds.append(text_embed)

        return torch.cat(text_embeds, dim = -2)  # concatenate along the sequence dimension

    def forward(
        self,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        cond_drop_prob = None
    ) -> Tuple[
        Tuple[Callable, ...],   # tuple of (identity) conditioning functions
        TextCondReturn          # the text embeddings and mask
    ]:

        assert exists(texts) ^ exists(text_embeds)  # exactly one of texts / text_embeds must be given

        if self.training:
            cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)  # during training fall back to the default conditional dropout probability
        else:
            assert exists(cond_drop_prob), 'when not training, cond_drop_prob must be explicitly set'

        if exists(texts):
            batch = len(texts)

        elif exists(text_embeds):
            batch = text_embeds.shape[0]

        if not exists(text_embeds):
            text_embeds = self.embed_texts(texts)

        mask = (text_embeds != 0).any(dim = -1)  # the mask marks the non-padded (non-zero) token positions

        if cond_drop_prob > 0.:
            prob_keep_mask = prob_mask_like((batch, 1), 1. - cond_drop_prob, device = self.device)  # randomly keep conditioning per batch element
            mask = mask & prob_keep_mask

        return tuple(self.conditioners), TextCondReturn(text_embeds, mask)  # identity conditioners plus the text conditioning
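
# usage sketch (hedged: assumes the package is installed, a T5 checkpoint can be downloaded,
# and that TextCondReturn is the (embeddings, mask) named tuple defined earlier in this file)
_returner = TextEmbeddingReturner(hidden_dims = (256,), cond_drop_prob = 0.2)
_cond_fns, _text_cond = _returner(texts = ['a dog chasing after a ball'])
# _cond_fns: one nn.Identity per entry in hidden_dims; _text_cond carries the projected token embeddings and their boolean mask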

.\lucidrains\classifier-free-guidance-pytorch\classifier_free_guidance_pytorch\open_clip.py

# import the necessary libraries and modules
from beartype import beartype
from typing import List

import torch
from torch import nn, einsum
import torch.nn.functional as F

import open_clip
from classifier_free_guidance_pytorch.tokenizer import tokenizer

# constants

DEFAULT_CLIP_NAME = 'ViT-B-32'
DEFAULT_PRETRAINED_CLIP = 'laion400m_e32'

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# l2 normalize a tensor along its last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# adapter class

class OpenClipAdapter():
    def __init__(
        self,
        name = DEFAULT_CLIP_NAME,
        pretrained = DEFAULT_PRETRAINED_CLIP
    ):
        # fall back to the default model name and pretrained weights
        name = default(name, DEFAULT_CLIP_NAME)
        pretrained = default(pretrained, DEFAULT_PRETRAINED_CLIP)

        # create the OpenCLIP model and its preprocessing transforms
        clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)

        self.clip = clip
        clip.eval()

        self.tokenizer = tokenizer

        self.eos_id = 49407

        # the final layernorm of the text transformer - its output gives the per-token encodings
        text_attention_final = self.find_layer('ln_final')
        self._dim_latent = text_attention_final.weight.shape[0]

        # register a forward hook to capture the text encodings
        self.handle = text_attention_final.register_forward_hook(self._hook)
        self.clip_normalize = preprocess.transforms[-1]
        self.cleared = False

    # look up a submodule by its qualified name
    def find_layer(self, layer):
        modules = dict([*self.clip.named_modules()])
        return modules.get(layer, None)

    # remove the forward hook (register_forward_hook returns a RemovableHandle, which is removed via .remove())
    def clear(self):
        if self.cleared:
            return

        self.handle.remove()
        self.cleared = True

    # forward hook that stashes the per-token text encodings
    def _hook(self, _, inputs, outputs):
        self.text_encodings = outputs

    @property
    def dim_latent(self):
        return self._dim_latent

    @property
    def max_text_len(self):
        return 77

    # embed texts, optionally returning the per-token encodings
    @torch.no_grad()
    @beartype
    def embed_text(
        self,
        texts: List[str],
        return_text_encodings = False,
        output_device = None
    ):
        # tokenize the texts and truncate to the CLIP context length
        texts, max_length = self.tokenizer.tokenize(texts)
        texts = texts[..., :self.max_text_len]

        # encode the texts (the forward hook captures the per-token encodings)
        text_embeds = self.clip.encode_text(texts)

        texts = texts[..., :max_length]

        if not return_text_encodings:
            return l2norm(text_embeds).to(output_device)

        # mask over real tokens up to and including EOS, excluding padding
        is_eos_id = (texts == self.eos_id)
        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
        text_mask = text_mask & (texts != 0)

        assert not self.cleared

        text_encodings = self.text_encodings[:, :max_length]
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        del self.text_encodings

        return text_encodings.float().to(output_device)
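
# usage sketch (hedged: downloads the ViT-B-32 / laion400m_e32 checkpoint on first use)
_adapter = OpenClipAdapter()
_pooled = _adapter.embed_text(['a cat on a mat'])  # (1, dim_latent) l2-normalized sentence embeddings
_tokens = _adapter.embed_text(['a cat on a mat'], return_text_encodings = True)  # (1, seq, dim_latent), zero-filled at padding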

.\lucidrains\classifier-free-guidance-pytorch\classifier_free_guidance_pytorch\t5.py

# import the required modules
from typing import List
from beartype import beartype

import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config

# silence transformers warnings (log errors only)
transformers.logging.set_verbosity_error()

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# config

# maximum token length for encoding
MAX_LENGTH = 256

# default T5 model name
DEFAULT_T5_NAME = 'google/t5-v1_1-base'

# cache of T5 models, tokenizers and configs
T5_CONFIGS = {}

# global singletons

# get the tokenizer for a given model name
def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name)
    return tokenizer

# get the encoder model for a given model name
def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

# get (and cache) both the model and tokenizer
def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']

# get the encoder output dimension without necessarily loading the full model
def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoid loading the model weights just to read off a dimension
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model

# encoding text

# encode a list of texts into T5 last hidden states plus an attention mask
def t5_encode_text(texts, name = DEFAULT_T5_NAME, output_device = None):
    t5, tokenizer = get_model_and_tokenizer(name)

    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    attn_mask = attn_mask.bool()

    if not exists(output_device):
        return encoded_text, attn_mask

    # Tensor.to is not in-place, so the results must be reassigned
    encoded_text = encoded_text.to(output_device)
    attn_mask = attn_mask.to(output_device)

    return encoded_text, attn_mask
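
# usage sketch (hedged: downloads the default T5 checkpoint on first use)
_encoded, _mask = t5_encode_text(['a small red ball', 'two dogs playing'])
# _encoded: (2, seq, d_model) last hidden states; _mask: (2, seq) booleans over real tokens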

# T5 adapter class
class T5Adapter():
    def __init__(
        self,
        name
    ):
        name = default(name, DEFAULT_T5_NAME)
        t5, tokenizer = get_model_and_tokenizer(name)

        if torch.cuda.is_available():
            t5 = t5.cuda()

        self.name = name
        self.t5 = t5
        self.tokenizer = tokenizer

    @property
    def dim_latent(self):
        return get_encoded_dim(self.name)

    @property
    def max_text_len(self):
        return MAX_LENGTH

    @torch.no_grad()
    @beartype
    def embed_text(
        self,
        texts: List[str],
        return_text_encodings = False,
        output_device = None
    ):
        device = next(self.t5.parameters()).device

        encoded = self.tokenizer.batch_encode_plus(
            texts,
            return_tensors = "pt",
            padding = 'longest',
            max_length = MAX_LENGTH,
            truncation = True
        )

        input_ids = encoded.input_ids.to(device)
        attn_mask = encoded.attention_mask.to(device)

        self.t5.eval()

        with torch.no_grad():
            output = self.t5(input_ids = input_ids, attention_mask = attn_mask)
            encoded_text = output.last_hidden_state.detach()

        attn_mask = attn_mask.bool()

        encoded_text.masked_fill_(~attn_mask[..., None], 0.)

        if not return_text_encodings:
            numer = encoded_text.sum(dim = -2)
            denom = attn_mask.sum(dim = -1)[..., None]
            numer.masked_fill_(denom == 0, 0.)
            mean_encodings = numer / denom.clamp(min = 1e-3)
            return mean_encodings

        return encoded_text.to(output_device)
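
# toy check (hedged, made-up numbers): the masked mean pooling above, computed standalone
_enc = torch.tensor([[[1., 1.], [3., 3.], [0., 0.]]])  # (1 batch, 3 tokens, dim 2); last token is padding
_msk = torch.tensor([[True, True, False]])
_numer = (_enc * _msk[..., None]).sum(dim = -2)  # tensor([[4., 4.]]) - sum over real tokens only
_denom = _msk.sum(dim = -1)[..., None]           # tensor([[2]])     - count of real tokens
_mean = _numer / _denom.clamp(min = 1)           # tensor([[2., 2.]])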

.\lucidrains\classifier-free-guidance-pytorch\classifier_free_guidance_pytorch\__init__.py

# import the NullConditioner, TextConditioner, AttentionTextConditioner and TextEmbeddingReturner classes
from classifier_free_guidance_pytorch.classifier_free_guidance_pytorch import (
    NullConditioner,
    TextConditioner,
    AttentionTextConditioner,
    TextEmbeddingReturner
)

# import the classifier_free_guidance function and class decorator
from classifier_free_guidance_pytorch.classifier_free_guidance_pytorch import (
    classifier_free_guidance,
    classifier_free_guidance_class_decorator
)

# import the OpenClipAdapter class
from classifier_free_guidance_pytorch.open_clip import OpenClipAdapter

# import the T5Adapter class
from classifier_free_guidance_pytorch.t5 import T5Adapter

# import the BGEAdapter class
from classifier_free_guidance_pytorch.bge import BGEAdapter

.\lucidrains\classifier-free-guidance-pytorch\README.md

Classifier Free Guidance - Pytorch

Implementation of Classifier Free Guidance in Pytorch, with emphasis on text conditioning, and flexibility to include multiple text embedding models, as done in eDiff-I

It is clear now that text guidance is the ultimate interface to models. This repository will leverage some python decorator magic to make it easy to incorporate SOTA text conditioning to any model.

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors out there

  • 🤗 Huggingface for their amazing transformers library. The text conditioning module will use T5 embeddings, as latest research recommends

  • OpenCLIP for providing SOTA open sourced CLIP models. The eDiff model sees immense improvements by combining the T5 embeddings with CLIP text embeddings

Install

$ pip install classifier-free-guidance-pytorch

Usage

import torch
from classifier_free_guidance_pytorch import TextConditioner

text_conditioner = TextConditioner(
    model_types = 't5',    
    hidden_dims = (256, 512),
    hiddens_channel_first = False,
    cond_drop_prob = 0.2  # conditional dropout 20% of the time, must be greater than 0. to unlock classifier free guidance
).cuda()

# pass in your text as a List[str], and get back a List[callable]
# each callable function receives the hiddens in the dimensions listed at init (hidden_dims)

first_condition_fn, second_condition_fn = text_conditioner(['a dog chasing after a ball'])

# these hiddens will be in the direct flow of your model, say in a unet

first_hidden = torch.randn(1, 16, 256).cuda()
second_hidden = torch.randn(1, 32, 512).cuda()

# conditioned features

first_conditioned = first_condition_fn(first_hidden)
second_conditioned = second_condition_fn(second_hidden)

If you wish to use cross attention based conditioning (each hidden feature in your network can attend to individual subword tokens), just import the AttentionTextConditioner instead. The rest is the same.

from classifier_free_guidance_pytorch import AttentionTextConditioner

text_conditioner = AttentionTextConditioner(
    model_types = ('t5', 'clip'),   # something like in eDiff paper, where they used both T5 and Clip for even better results (Balaji et al.)
    hidden_dims = (256, 512),
    cond_drop_prob = 0.2
)

Magic Class Decorator

This is a work in progress to make it as easy as possible to text condition your network.

First, let's say you have a simple two layer network

import torch
from torch import nn

class MLP(nn.Module):
    def __init__(
        self,
        dim
    ):
        super().__init__()
        self.proj_in = nn.Sequential(nn.Linear(dim, dim * 2), nn.ReLU())
        self.proj_mid = nn.Sequential(nn.Linear(dim * 2, dim), nn.ReLU())
        self.proj_out = nn.Linear(dim, 1)

    def forward(
        self,
        data
    ):
        hiddens1 = self.proj_in(data)
        hiddens2 = self.proj_mid(hiddens1)
        return self.proj_out(hiddens2)

# instantiate model and pass in some data, get (in this case) a binary prediction

model = MLP(dim = 256)

data = torch.randn(2, 256)

pred = model(data)

You would like to condition the hidden layers (hiddens1 and hiddens2) with text. Each batch element here would get its own free text conditioning.

This has been whittled down to ~3 steps using this repository.

import torch
from torch import nn

from classifier_free_guidance_pytorch import classifier_free_guidance_class_decorator

@classifier_free_guidance_class_decorator
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.proj_in = nn.Sequential(nn.Linear(dim, dim * 2), nn.ReLU())
        self.proj_mid = nn.Sequential(nn.Linear(dim * 2, dim), nn.ReLU())
        self.proj_out = nn.Linear(dim, 1)

    def forward(
        self,
        inp,
        cond_fns # List[Callable] - (1) your forward function now receives a list of conditioning functions, which you invoke on your hidden tensors
    ):
        cond_hidden1, cond_hidden2 = cond_fns # conditioning functions are given back in the order of the `hidden_dims` set on the text conditioner

        hiddens1 = self.proj_in(inp)
        hiddens1 = cond_hidden1(hiddens1) # (2) condition the first hidden layer with FiLM

        hiddens2 = self.proj_mid(hiddens1)
        hiddens2 = cond_hidden2(hiddens2) # condition the second hidden layer with FiLM

        return self.proj_out(hiddens2)

# instantiate your model - extra keyword arguments will need to be defined, prepended by `text_condition_`

model = MLP(
    dim = 256,
    text_condition_type = 'film',                 # can be film, attention, or null (none)
    text_condition_model_types = ('t5', 'clip'),  # in this example, conditioning on both T5 and OpenCLIP
    text_condition_hidden_dims = (512, 256),      # and pass in the hidden dimensions you would like to condition on. in this case there are two hidden dimensions (dim * 2 and dim, after the first and second projections)
    text_condition_cond_drop_prob = 0.25          # conditional dropout probability for classifier free guidance. can be set to 0. if you do not need it and just want the text conditioning
)

# now you have your input data as well as corresponding free text as List[str]

data = torch.randn(2, 256)
texts = ['a description', 'another description']

# (3) train your model, passing in your list of strings as 'texts'

pred  = model(data, texts = texts)

# after much training, you can now do classifier free guidance by passing in a condition scale of > 1. !

model.eval()
guided_pred = model(data, texts = texts, cond_scale = 3.)  # cond_scale stands for conditioning scale from classifier free guidance paper
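
Under the hood, the guidance step amounts to extrapolating away from the unconditional (fully dropped-out) prediction. A minimal sketch of the formula (illustrative names, not the library's exact internals):

import torch

def guided(cond_logits, null_logits, cond_scale):
    # classifier free guidance: null + (cond - null) * scale
    return null_logits + (cond_logits - null_logits) * cond_scale

guided(torch.tensor([2.]), torch.tensor([1.]), cond_scale = 3.)  # tensor([4.])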

Todo

Citations

@article{Ho2022ClassifierFreeDG,
    title   = {Classifier-Free Diffusion Guidance},
    author  = {Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2207.12598}
}
@article{Balaji2022eDiffITD,
    title   = {eDiff-I: Text-to-Image Diffusion Models with an Ensemble of Expert Denoisers},
    author  = {Yogesh Balaji and Seungjun Nah and Xun Huang and Arash Vahdat and Jiaming Song and Karsten Kreis and Miika Aittala and Timo Aila and Samuli Laine and Bryan Catanzaro and Tero Karras and Ming-Yu Liu},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2211.01324}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Lin2023CommonDN,
    title   = {Common Diffusion Noise Schedules and Sample Steps are Flawed},
    author  = {Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang},
    year    = {2023}
}

.\lucidrains\classifier-free-guidance-pytorch\setup.py

# import setuptools' setup and package discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'classifier-free-guidance-pytorch',  # package name
  packages = find_packages(exclude=[]),  # discover packages
  include_package_data = True,  # include data files
  version = '0.5.3',  # version number
  license='MIT',  # license
  description = 'Classifier Free Guidance - Pytorch',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/classifier-free-guidance-pytorch',  # project URL
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'classifier free guidance',
    'text conditioning and guidance'
  ],
  install_requires=[  # dependencies
    'beartype',
    'einops>=0.7',
    'ftfy',
    'open-clip-torch>=2.8.0',
    'torch>=2.0',
    'transformers[torch]'
  ],
  classifiers=[  # PyPI classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\CoCa-pytorch\coca_pytorch\coca_pytorch.py

# import torch
import torch
# import einsum and nn from torch
from torch import einsum, nn
# import the functional API as F
import torch.nn.functional as F
# import Function for the custom autograd op
from torch.autograd import Function
# import the distributed module
import torch.distributed as dist

# import rearrange and repeat from einops (this import line was missing from the listing; the code below requires it)
from einops import rearrange, repeat

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# distributed

# zero-pad a tensor along the given dimension up to the target length
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
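
# quick toy check (hedged sketch): pad a (2, 4) tensor to 5 rows along dim 0
_padded = pad_dim_to(torch.ones(2, 4), 5, dim = 0)
assert _padded.shape == (5, 4)  # rows 2..4 are zero-filled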

# gather variable-sized batches from all ranks, padding each to the largest batch first
def all_gather_variable_batch(t):
    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

    size = torch.tensor(t.shape[0], device = device, dtype = torch.long)
    sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
    dist.all_gather(sizes, size)

    sizes = torch.stack(sizes)
    max_size = sizes.amax().item()

    padded_t = pad_dim_to(t, max_size, dim = 0)
    gathered_tensors = [torch.empty_like(padded_t, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors)
    seq = torch.arange(max_size, device = device)

    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')

    gathered_tensor = gathered_tensor[mask]
    sizes = sizes.tolist()

    return gathered_tensor, sizes

# autograd-aware all-gather: on backward, each rank receives only its own slice of the gradients
class AllGather(Function):
    @staticmethod
    def forward(ctx, x):
        assert dist.is_initialized() and dist.get_world_size() > 1
        x, batch_sizes = all_gather_variable_batch(x)
        ctx.batch_sizes = batch_sizes
        return x

    @staticmethod
    def backward(ctx, grads):
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = 0)
        return grads_by_rank[rank]

# functional form of AllGather
all_gather = AllGather.apply

# normalization
# they use layernorm without bias, something that pytorch does not offer

# layer normalization with a learned gain but a fixed zero bias
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# residual

# wrap a module with a residual connection
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

# to latents

# project embeddings into a normalized latent space for the contrastive loss
class EmbedToLatents(nn.Module):
    def __init__(self, dim, dim_latents):
        super().__init__()
        self.to_latents = nn.Linear(dim, dim_latents, bias=False)

    def forward(self, x):
        latents = self.to_latents(x)
        return F.normalize(latents, dim=-1)

# rotary positional embedding
# https://arxiv.org/abs/2104.09864

# rotary positional embedding (RoPE)
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)

# rotate half of the feature dimensions (the 90-degree rotation used by RoPE)
def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)

# apply the rotary positional embedding to a tensor
def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())
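
# shape sketch (hedged, toy sizes): positions broadcast over the batch, rotation preserves the shape
_rotary = RotaryEmbedding(dim = 8)
_pos = _rotary(4, device = torch.device('cpu'))              # (4, 8) angles
_rotated = apply_rotary_pos_emb(_pos, torch.randn(1, 4, 8))  # still (1, 4, 8)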

# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202

# SwiGLU gated activation
class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

# parallel attention and feedforward block with multi-query attention
class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        # pre-normalization of the input
        self.norm = LayerNorm(dim)

        # inner dimensions for attention and the feedforward
        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        # number of heads and attention scale
        self.heads = heads
        self.scale = dim_head**-0.5
        # rotary positional embedding
        self.rotary_emb = RotaryEmbedding(dim_head)

        # one fused projection produces queries, (single-head) keys and values, and the feedforward input
        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        # feedforward output projection
        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # caches for the causal mask and rotary embeddings
        self.mask = None
        self.pos_emb = None

    # get (and cache) the causal mask
    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n].to(device)

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.mask = mask
        return mask

    # get (and cache) the rotary embeddings
    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n].to(device)

        pos_emb = self.rotary_emb(n, device=device)
        self.pos_emb = pos_emb
        return pos_emb

    # forward pass
    def forward(self, x, attn_mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device, h = x.shape[1], x.device, self.heads

        # pre-layernorm
        x = self.norm(x)

        # one fused projection yields queries, keys, values and the feedforward input
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split out the heads (queries only - keys and values are shared across heads)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale
        q = q * self.scale

        # similarity
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # causal mask
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # extra attention mask, if provided
        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        # numerically stable softmax attention
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate the values
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge the heads
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)
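
# usage sketch (hedged, toy sizes): one parallel attention + feedforward block
_block = ParallelTransformerBlock(dim = 32, dim_head = 8, heads = 4)
_out = _block(torch.randn(2, 10, 32))  # (2, 10, 32); wrap in Residual for the skip connection
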
# cross attention with multi-query attention + single-headed key/values, as in PaLM, with an optional parallel feedforward
class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # optional parallel feedforward
        ff_inner_dim = ff_mult * dim

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # pre-layernorm for both the queries and the context
        x = self.norm(x)
        context = self.context_norm(context)

        # queries
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # scale
        q = q * self.scale

        # single-headed keys / values
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # attention
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)

        # aggregate the values
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge and combine the heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # add the parallel feedforward (used in the multimodal layers)
        if exists(self.ff):
            out = out + self.ff(x)

        return out
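
# usage sketch (hedged, toy sizes): 10 queries attend over 7 context tokens (e.g. image tokens)
_cross_attn = CrossAttention(dim = 32, context_dim = 64, dim_head = 8, heads = 4, parallel_ff = True)
_out = _cross_attn(torch.randn(2, 10, 32), torch.randn(2, 7, 64))  # (2, 10, 32)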

# transformer
class CoCa(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        unimodal_depth,
        multimodal_depth,
        dim_latents = None,
        image_dim = None,
        num_img_queries=256,
        dim_head=64,
        heads=8,
        ff_mult=4,
        img_encoder=None,
        caption_loss_weight=1.,
        contrastive_loss_weight=1.,
        pad_id=0
    ):
        super().__init__()
        # model dimension
        self.dim = dim

        # padding id and loss weights
        self.pad_id = pad_id
        self.caption_loss_weight = caption_loss_weight
        self.contrastive_loss_weight = contrastive_loss_weight

        # token embeddings

        self.token_emb = nn.Embedding(num_tokens, dim)
        # learned text CLS token
        self.text_cls_token = nn.Parameter(torch.randn(dim))

        # image encoder

        self.img_encoder = img_encoder

        # attention pooling for image tokens

        self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, dim)) # num image queries for multimodal, but 1 extra CLS for contrastive learning
        self.img_attn_pool = CrossAttention(dim=dim, context_dim=image_dim, dim_head=dim_head, heads=heads, norm_context=True)

        self.img_attn_pool_norm = LayerNorm(dim)
        self.text_cls_norm = LayerNorm(dim)

        # to latents

        dim_latents = default(dim_latents, dim)
        # image and text projections into the shared contrastive latent space
        self.img_to_latents = EmbedToLatents(dim, dim_latents)
        self.text_to_latents = EmbedToLatents(dim, dim_latents)

        # temperature for the contrastive loss
        self.temperature = nn.Parameter(torch.Tensor([1.]))

        # unimodal layers

        self.unimodal_layers = nn.ModuleList([])
        for ind in range(unimodal_depth):
            self.unimodal_layers.append(
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
            )

        # multimodal layers

        self.multimodal_layers = nn.ModuleList([])
        for ind in range(multimodal_depth):
            self.multimodal_layers.append(nn.ModuleList([
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
                Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult))
            ]))

        # to logits

        self.to_logits = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias=False)
        )

        # tie the embedding weights to the output projection
        self.to_logits[-1].weight = self.token_emb.weight
        # initialize the embedding weights
        nn.init.normal_(self.token_emb.weight, std=0.02)

        # whether running in a distributed (data parallel) setting
        self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1

    # embed the text into a contrastive embedding plus per-token embeddings
    def embed_text(self, text):
        # batch size and device
        batch, device = text.shape[0], text.device

        # sequence length
        seq = text.shape[1]

        # token embeddings of the text
        text_tokens = self.token_emb(text)

        # append the text CLS token
        text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch)
        text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)

        # mask specific to the CLS token, preventing it from attending to padding
        cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        # run through the unimodal layers
        for attn_ff in self.unimodal_layers:
            text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)

        # split off the CLS token and normalize it into the text embedding
        text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1]
        text_embeds = self.text_cls_norm(text_cls_tokens)
        return text_embeds, text_tokens
    # embed images (or precomputed image tokens) into a contrastive embedding plus pooled image tokens
    def embed_image(self, images=None, image_tokens=None):
        # encode images into embeddings
        # using the img_encoder passed in at init
        # it can also accept precomputed image tokens

        # images and image tokens are mutually exclusive
        assert not (exists(images) and exists(image_tokens))

        if exists(images):
            # an img_encoder must be present for automatic image encoding
            assert exists(self.img_encoder), 'img_encoder must be passed in for automatic image encoding'
            image_tokens = self.img_encoder(images)

        # attention pool the image tokens

        img_queries = repeat(self.img_queries, 'n d -> b n d', b=image_tokens.shape[0])
        img_queries = self.img_attn_pool(img_queries, image_tokens)
        img_queries = self.img_attn_pool_norm(img_queries)

        # the first query is the CLS embedding for contrastive learning; the rest condition the multimodal decoder
        return img_queries[:, 0], img_queries[:, 1:]

    def forward(
        self,
        text,
        images=None,
        image_tokens=None,
        labels=None,
        return_loss=False,
        return_embeddings=False
    ):
        batch, device = text.shape[0], text.device

        if return_loss and not exists(labels):
            text, labels = text[:, :-1], text[:, 1:]

        text_embeds, text_tokens = self.embed_text(text)

        image_embeds, image_tokens = self.embed_image(images=images, image_tokens=image_tokens)

        # return the embeddings directly if requested

        if return_embeddings:
            return text_embeds, image_embeds

        # run through the multimodal layers

        for attn_ff, cross_attn in self.multimodal_layers:
            text_tokens = attn_ff(text_tokens)
            text_tokens = cross_attn(text_tokens, image_tokens)

        logits = self.to_logits(text_tokens)

        if not return_loss:
            return logits

        # shorthand

        ce = F.cross_entropy

        # caption loss (cross entropy over next-token predictions)

        logits = rearrange(logits, 'b n c -> b c n')
        caption_loss = ce(logits, labels, ignore_index=self.pad_id)
        caption_loss = caption_loss * self.caption_loss_weight

        # embed to latents

        text_latents = self.text_to_latents(text_embeds)
        image_latents = self.img_to_latents(image_embeds)

        # maybe distributed all gather

        if self.is_distributed:
            latents = torch.stack((text_latents, image_latents), dim=1)
            latents = all_gather(latents)
            text_latents, image_latents = latents.unbind(dim=1)

        # contrastive loss (symmetric cross entropy over the similarity matrix)

        sim = einsum('i d, j d -> i j', text_latents, image_latents)
        sim = sim * self.temperature.exp()
        contrastive_labels = torch.arange(batch, device=device)

        contrastive_loss = (ce(sim, contrastive_labels) + ce(sim.t(), contrastive_labels)) * 0.5
        contrastive_loss = contrastive_loss * self.contrastive_loss_weight

        return caption_loss + contrastive_loss
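
# end-to-end sketch (hedged: toy sizes, precomputed image tokens standing in for an image encoder)
coca = CoCa(
    dim = 512,
    num_tokens = 20000,
    unimodal_depth = 2,
    multimodal_depth = 2,
    image_dim = 256,
    num_img_queries = 16
)

text = torch.randint(0, 20000, (2, 32))
image_tokens = torch.randn(2, 49, 256)  # e.g. flattened patch embeddings from any vision backbone

loss = coca(text, image_tokens = image_tokens, return_loss = True)
loss.backward()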