Lucidrains Series: Source Code Walkthrough (Part 7)

.\lucidrains\compressive-transformer-pytorch\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'compressive-transformer-pytorch', # package name
  packages = find_packages(exclude=['examples']), # find and include all packages except examples
  version = '0.4.0', # version number
  license='MIT', # license
  description = 'Implementation of Compressive Transformer in Pytorch', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/compressive-transformer-pytorch', # project URL
  keywords = [ # keywords
      'attention',
      'artificial intelligence',
      'transformer',
      'deep learning'
  ],
  install_requires=[ # dependencies
      'torch',
      'mogrifier'
  ],
  classifiers=[ # classifiers
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\conformer\conformer\conformer.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

# import the required libraries

# helper functions

# definitions of the helper functions

def exists(val):
    return val is not None

# returns whether a value is not None

def default(val, d):
    return val if exists(val) else d

# returns the value if it exists, otherwise the default

def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)

# computes the padding amounts for a 'same'-length convolution given the kernel size
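
A quick sanity check (not part of the repository) of what calc_same_padding returns: for an odd kernel size, padding the sequence on both sides by the returned amounts and then convolving leaves the length unchanged. The tensors and layer below are made-up placeholders.

import torch
import torch.nn.functional as F
from torch import nn

pad = calc_same_padding(31)            # (15, 15) for an odd kernel
x = torch.randn(1, 64, 100)            # (batch, channels, length)
x = F.pad(x, pad)                      # length 100 -> 130
x = nn.Conv1d(64, 64, 31)(x)           # a kernel of 31 brings the length back to 100
print(x.shape)                         # torch.Size([1, 64, 100])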

# helper classes

# definitions of the helper classes

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

# definition of the Swish activation class

class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()

# definition of the GLU (gated linear unit) activation class

class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

# definition of the depthwise 1d convolution class

# attention, feedforward, and conv module

# definitions of the attention, feedforward and convolution modules

class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

# definition of the Scale wrapper class

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# definition of the pre-normalization wrapper class

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        max_pos_emb = 512
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads= heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.max_pos_emb = max_pos_emb
        self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None
    ):
        n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context)
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # shaw's relative positional embedding

        seq = torch.arange(n, device = device)
        dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
        dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
        rel_pos_emb = self.rel_pos_emb(dist).to(q)
        pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
        dots = dots + pos_attn

        if exists(mask) or exists(context_mask):
            mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device))
            context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device))
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return self.dropout(out)

# definition of the attention class

class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):  # initializer for the feedforward module
        super().__init__()  # call the parent class initializer
        self.net = nn.Sequential(  # sequential container of layers
            nn.Linear(dim, dim * mult),  # linear projection from dim to dim * mult
            Swish(),  # Swish activation
            nn.Dropout(dropout),  # dropout to reduce overfitting
            nn.Linear(dim * mult, dim),  # linear projection back from dim * mult to dim
            nn.Dropout(dropout)  # dropout once more
        )

    def forward(self, x):  # forward pass
        return self.net(x)  # apply the network to the input

# definition of the ConformerConvModule class, inheriting from nn.Module
class ConformerConvModule(nn.Module):
    # initializer, accepting several hyperparameters
    def __init__(
        self,
        dim,
        causal = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        # call the parent class initializer
        super().__init__()

        # inner (expanded) dimension
        inner_dim = dim * expansion_factor
        # padding: 'same' padding when non-causal, left padding only when causal
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        # define the network
        self.net = nn.Sequential(
            nn.LayerNorm(dim),  # LayerNorm
            Rearrange('b n c -> b c n'),  # rearrange to channels-first for the convolutions
            nn.Conv1d(dim, inner_dim * 2, 1),  # pointwise convolution
            GLU(dim=1),  # GLU activation
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),  # depthwise convolution
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),  # BatchNorm1d, or Identity when causal
            Swish(),  # Swish activation
            nn.Conv1d(inner_dim, dim, 1),  # pointwise convolution back to dim
            Rearrange('b c n -> b n c'),  # rearrange back to channels-last
            nn.Dropout(dropout)  # dropout
        )

    # forward pass
    def forward(self, x):
        return self.net(x)

# definition of the ConformerBlock class, inheriting from nn.Module
class ConformerBlock(nn.Module):
    # initializer, accepting several hyperparameters
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        # call the parent class initializer
        super().__init__()
        # build the submodules
        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)  # first FeedForward
        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)  # Attention
        self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)  # ConformerConvModule
        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)  # second FeedForward

        self.attn = PreNorm(dim, self.attn)  # pre-norm around attention
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))  # half-step feedforward with pre-norm
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))  # half-step feedforward with pre-norm

        self.post_norm = nn.LayerNorm(dim)  # final LayerNorm

    # forward pass
    def forward(self, x, mask = None):
        x = self.ff1(x) + x
        x = self.attn(x, mask = mask) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x

# definition of the Conformer class, inheriting from nn.Module
class Conformer(nn.Module):
    # initializer, accepting several hyperparameters
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        # call the parent class initializer
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([])

        # create depth ConformerBlock layers and append them to the module list
        for _ in range(depth):
            self.layers.append(ConformerBlock(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                conv_expansion_factor = conv_expansion_factor,
                conv_kernel_size = conv_kernel_size,
                conv_causal = conv_causal
            ))

    # forward pass
    def forward(self, x):
        # run each ConformerBlock in sequence
        for block in self.layers:
            x = block(x)

        return x

.\lucidrains\conformer\conformer\__init__.py

# import the ConformerConvModule, ConformerBlock and Conformer classes from the conformer.conformer module
from conformer.conformer import ConformerConvModule, ConformerBlock, Conformer

Conformer

Implementation of the convolutional module from the Conformer paper, for improving the local inductive bias in Transformers.

Install

$ pip install conformer

Usage

The Conformer convolutional module, the main novelty of the paper

import torch
from conformer import ConformerConvModule

layer = ConformerConvModule(
    dim = 512,
    causal = False,             # auto-regressive or not - 1d conv will be made causal with padding if so
    expansion_factor = 2,       # what multiple of the dimension to expand for the depthwise convolution
    kernel_size = 31,           # kernel size, 17 - 31 was said to be optimal
    dropout = 0.                # dropout at the very end
)

x = torch.randn(1, 1024, 512)
x = layer(x) + x

1 Conformer Block

import torch
from conformer import ConformerBlock

block = ConformerBlock(
    dim = 512,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    conv_expansion_factor = 2,
    conv_kernel_size = 31,
    attn_dropout = 0.,
    ff_dropout = 0.,
    conv_dropout = 0.
)

x = torch.randn(1, 1024, 512)

block(x) # (1, 1024, 512)

Conformer - just multiple ConformerBlock from above

import torch
from conformer import Conformer

conformer = Conformer(
    dim = 512,
    depth = 12,          # 12 blocks
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    conv_expansion_factor = 2,
    conv_kernel_size = 31,
    attn_dropout = 0.,
    ff_dropout = 0.,
    conv_dropout = 0.
)

x = torch.randn(1, 1024, 512)

conformer(x) # (1, 1024, 512)

Todo

Citations

@misc{gulati2020conformer,
    title   = {Conformer: Convolution-augmented Transformer for Speech Recognition},
    author  = {Anmol Gulati and James Qin and Chung-Cheng Chiu and Niki Parmar and Yu Zhang and Jiahui Yu and Wei Han and Shibo Wang and Zhengdong Zhang and Yonghui Wu and Ruoming Pang},
    year    = {2020},
    eprint  = {2005.08100},
    archivePrefix = {arXiv},
    primaryClass = {eess.AS}
}

.\lucidrains\conformer\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'conformer',
  # find and include all packages
  packages = find_packages(),
  # version number
  version = '0.3.2',
  # license
  license='MIT',
  # description
  description = 'The convolutional module from the Conformer paper',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project URL
  url = 'https://github.com/lucidrains/conformer',
  # keywords
  keywords = [
      'artificial intelligence',
      'deep learning',
      'transformers',
      'audio'
  ],
  # dependencies
  install_requires=[
      'einops>=0.6.1',
      'torch'
  ],
  # classifiers
  classifiers=[
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\contrastive-learner\contrastive_learner\contrastive_learner.py

# import the required libraries
import copy
import random
from functools import wraps

import torch
from torch import nn
import torch.nn.functional as F

from torchvision.models import resnet50
from kornia import augmentation as augs
from kornia import filters

# helper functions

# a function that simply returns its input
def identity(x): return x

# return the default if the value is None
def default(val, def_val):
    return def_val if val is None else val

# flatten a tensor to (batch, -1)
def flatten(t):
    return t.reshape(t.shape[0], -1)

# safely concatenate a tensor onto an accumulator that may be None
def safe_concat(arr, el, dim=0):
    if arr is None:
        return el
    return torch.cat((arr, el), dim=dim)

# singleton decorator that caches the result of a method on the instance
def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

# loss functions

# contrastive loss
def contrastive_loss(queries, keys, temperature = 0.1):
    b, device = queries.shape[0], queries.device
    logits = queries @ keys.t()
    logits = logits - logits.max(dim=-1, keepdim=True).values
    logits /= temperature
    return F.cross_entropy(logits, torch.arange(b, device=device))

# NT-Xent loss
def nt_xent_loss(queries, keys, temperature = 0.1):
    b, device = queries.shape[0], queries.device

    n = b * 2
    projs = torch.cat((queries, keys))
    logits = projs @ projs.t()

    mask = torch.eye(n, device=device).bool()
    logits = logits[~mask].reshape(n, n - 1)
    logits /= temperature

    labels = torch.cat(((torch.arange(b, device=device) + b - 1), torch.arange(b, device=device)), dim=0)
    loss = F.cross_entropy(logits, labels, reduction='sum')
    loss /= n
    return loss
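
A minimal sketch (not from the repository) of how the two loss functions above are called, assuming the contrastive_loss and nt_xent_loss definitions above are in scope: queries and keys are the projected embeddings of two augmented views of the same batch, both of shape (batch, dim); the tensors below are random placeholders.

import torch

queries = torch.randn(8, 128)
keys    = torch.randn(8, 128)

loss_a = contrastive_loss(queries, keys, temperature = 0.1)  # the matching index is the positive pair
loss_b = nt_xent_loss(queries, keys, temperature = 0.1)      # SimCLR-style loss over the 2N concatenated views
print(loss_a.item(), loss_b.item())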

# augmentation utilities

# randomly apply an augmentation function with probability p
class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p
    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

# exponential moving average

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

# update the moving average of the model parameters
def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)
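
A small illustration (not from the repository), assuming the EMA class and update_moving_average function above are in scope: the target copy below drifts toward the online module by a factor of (1 - beta) per call, which is how the momentum key encoder is kept in sync later in this file. The two Linear modules are made-up stand-ins for the encoders.

import copy
from torch import nn

online = nn.Linear(4, 4)
target = copy.deepcopy(online)
updater = EMA(0.99)

# typically called once after each optimizer step on the online encoder
update_moving_average(updater, target, online)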

# hidden layer extractor class

class OutputHiddenLayer(nn.Module):
    def __init__(self, net, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.hidden = None
        self._register_hook()

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _register_hook(self):
        def hook(_, __, output):
            self.hidden = output

        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(hook)

    def forward(self, x):
        if self.layer == -1:
            return self.net(x)

        _ = self.net(x)
        hidden = self.hidden
        self.hidden = None
        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

# main class

class ContrastiveLearner(nn.Module):
    # initializer, setting up the model's parameters and attributes
    def __init__(self, net, image_size, hidden_layer = -2, project_hidden = True, project_dim=128, augment_both=True, use_nt_xent_loss=False, augment_fn = None, use_bilinear = False, use_momentum = False, momentum_value = 0.999, key_encoder = None, temperature = 0.1):
        # call the parent class initializer
        super().__init__()
        # wrap the network so the chosen hidden layer's output can be extracted
        self.net = OutputHiddenLayer(net, layer=hidden_layer)

        # default augmentation pipeline
        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size))
        )

        # set the augmentation function
        self.augment = default(augment_fn, DEFAULT_AUG)

        # whether to augment both the query and the key
        self.augment_both = augment_both

        # temperature and whether to use the NT-Xent loss
        self.temperature = temperature
        self.use_nt_xent_loss = use_nt_xent_loss

        # whether to project the hidden representation
        self.project_hidden = project_hidden
        self.projection = None
        self.project_dim = project_dim

        # whether to use a bilinear weight between queries and keys
        self.use_bilinear = use_bilinear
        self.bilinear_w = None

        # whether to use a momentum (EMA) key encoder
        self.use_momentum = use_momentum
        self.ema_updater = EMA(momentum_value)
        self.key_encoder = key_encoder

        # accumulators for queries and keys
        self.queries = None
        self.keys = None

        # send a mock image tensor to instantiate the parameters
        self.forward(torch.randn(1, 3, image_size, image_size))

    # get the key encoder (a momentum copy of the network)
    @singleton('key_encoder')
    def _get_key_encoder(self):
        key_encoder = copy.deepcopy(self.net)
        key_encoder._register_hook()
        return key_encoder

    # get the bilinear weight matrix
    @singleton('bilinear_w')
    def _get_bilinear(self, hidden):
        _, dim = hidden.shape
        return nn.Parameter(torch.eye(dim)).to(hidden)

    # get the projection head
    @singleton('projection')
    def _get_projection_fn(self, hidden):
        _, dim = hidden.shape

        return nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.LeakyReLU(inplace=True),
            nn.Linear(dim, self.project_dim, bias = False)
        ).to(hidden)

    # reset the moving average (discard the key encoder)
    def reset_moving_average(self):
        assert self.use_momentum, 'must be using momentum method for key encoder'
        del self.key_encoder
        self.key_encoder = None

    # update the moving average of the key encoder towards the current network
    def update_moving_average(self):
        assert self.key_encoder is not None, 'key encoder has not been created yet'
        update_moving_average(self.ema_updater, self.key_encoder, self.net)

    # compute the loss over the accumulated queries and keys
    def calculate_loss(self):
        assert self.queries is not None and self.keys is not None, 'no queries or keys accumulated'
        loss_fn = nt_xent_loss if self.use_nt_xent_loss else contrastive_loss
        loss = loss_fn(self.queries, self.keys, temperature = self.temperature)
        self.queries = self.keys = None
        return loss

    # forward pass
    def forward(self, x, accumulate = False):
        # unpack the input shape and device
        b, c, h, w, device = *x.shape, x.device
        transform_fn = self.augment if self.augment_both else identity

        # query encoder
        query_encoder = self.net
        queries = query_encoder(transform_fn(x))

        # key encoder (momentum copy if enabled, otherwise the same network)
        key_encoder = self.net if not self.use_momentum else self._get_key_encoder()
        keys = key_encoder(self.augment(x))

        if self.use_momentum:
            keys = keys.detach()

        queries, keys = map(flatten, (queries, keys))

        if self.use_bilinear:
            W = self._get_bilinear(keys)
            keys = (W @ keys.t()).t()

        project_fn = self._get_projection_fn(queries) if self.project_hidden else identity
        queries, keys = map(project_fn, (queries, keys))

        self.queries = safe_concat(self.queries, queries)
        self.keys = safe_concat(self.keys, keys)

        return self.calculate_loss() if not accumulate else None

.\lucidrains\contrastive-learner\contrastive_learner\__init__.py

# import the ContrastiveLearner class from the contrastive_learner.contrastive_learner module
from contrastive_learner.contrastive_learner import ContrastiveLearner

Contrastive learning in Pytorch, made simple

PyPI version

It seems we have lift-off for self-supervised learning on images.

This is a simple to use Pytorch wrapper to enable contrastive self-supervised learning on any visual neural network. At the moment, it contains enough settings for one to train on either of the schemes used in SimCLR or CURL.

You can wrap any neural network that accepts a visual input, be it a resnet, policy network, or the discriminator of a GAN. The rest is taken care of.

Issues

It has surfaced that the results of CURL are not reproducible. It is recommended that you go with the SimCLR settings until further notice.

Install

$ pip install contrastive-learner

Usage

SimCLR (projection head with normalized temperature-scaled cross-entropy loss)

import torch
from contrastive_learner import ContrastiveLearner
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = ContrastiveLearner(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',  # layer name where output is hidden dimension. this can also be an integer specifying the index of the child
    project_hidden = True,     # use projection head
    project_dim = 128,         # projection head dimensions, 128 from paper
    use_nt_xent_loss = True,   # the above mentioned loss, abbreviated
    temperature = 0.1,         # temperature
    augment_both = True        # augment both query and key
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_batch_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_batch_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

CURL (with momentum averaged key encoder)

import torch
from contrastive_learner import ContrastiveLearner
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = ContrastiveLearner(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = True,         # use momentum for key encoder
    momentum_value = 0.999,
    project_hidden = False,      # no projection heads
    use_bilinear = True,         # in paper, logits is bilinear product of query / key
    use_nt_xent_loss = False,    # use regular contrastive loss
    augment_both = False         # in curl, only the key is augmented
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_batch_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_batch_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of key encoder

Advanced

If you want to accumulate queries and keys to do contrastive loss on a bigger batch, use the accumulate keyword on the forward pass.

for _ in range(100):
    for _ in range(5):
        images = sample_batch_images()
        _ = learner(images, accumulate=True)  # accumulate queries and keys
    loss = learner.calculate_loss()           # calculate similarity on all accumulated
    opt.zero_grad()
    loss.backward()
    opt.step()

By default, this will use the augmentations recommended in the SimCLR paper, mainly color jitter, gaussian blur, and random resize crop. However, if you would like to specify your own augmentations, you can simply pass in an augment_fn in the constructor. Augmentations must work in the tensor space. If you decide to use torchvision augmentations, make sure the function first converts to PIL with .toPILImage() and then back to tensors with .ToTensor()

custom_augment_fn = nn.Sequential(
    kornia.augmentations.RandomHorizontalFlip()
)

learner = ContrastiveLearner(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    project_hidden = True,
    project_dim = 128,
    use_nt_xent_loss = True,
    augment_fn = custom_augment_fn
)

Citations

@misc{chen2020simple,
    title   = {A Simple Framework for Contrastive Learning of Visual Representations},
    author  = {Ting Chen and Simon Kornblith and Mohammad Norouzi and Geoffrey Hinton},
    year    = {2020}
}
@misc{srinivas2020curl,
    title   = {CURL: Contrastive Unsupervised Representations for Reinforcement Learning},
    author  = {Aravind Srinivas and Michael Laskin and Pieter Abbeel},
    year    = {2020}
}

.\lucidrains\contrastive-learner\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'contrastive_learner',
  # find and include all packages
  packages = find_packages(),
  # version number
  version = '0.1.1',
  # license
  license='MIT',
  # description
  description = 'Self-supervised contrastive learning made simple',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project URL
  url = 'https://github.com/lucidrains/contrastive-learner',
  # keywords
  keywords = ['self-supervised learning', 'artificial intelligence'],
  # dependencies
  install_requires=[
      'torch',
      'kornia'
  ],
  # classifiers
  classifiers=[
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\coordinate-descent-attention\coordinate_descent_attention\autoregressive_wrapper.py

# import torch
import torch
# import the nn module from torch
from torch import nn
# import torch.nn.functional as F
import torch.nn.functional as F

# import rearrange from einops
from einops import rearrange

# helper functions

# returns whether a value is not None
def exists(val):
    return val is not None

# decorator that runs a function with the model in eval mode, then restores the previous mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        # remember whether the model was in training mode
        was_training = model.training
        # switch the model to eval mode
        model.eval()
        # call the wrapped function with the model, args and kwargs
        out = fn(model, *args, **kwargs)
        # restore the previous training mode
        model.train(was_training)
        return out
    return inner

# top-k filtering

# keep the top-k logits determined by the threshold and mask out the rest
def top_k(logits, thres = 0.9):
    # number of logits to keep
    k = int((1 - thres) * logits.shape[-1])
    # top-k values and their indices
    val, ind = torch.topk(logits, k)
    # tensor of the same shape filled with the most negative representable value
    probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
    # scatter the top-k values back into place
    probs.scatter_(1, ind, val)
    return probs
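
A small sketch (not part of the repository) of how top_k above feeds into sampling, using a made-up batch of logits: roughly the top 10% of logits survive the filter, the rest become a large negative value, and softmax then concentrates the probability mass on the survivors before sampling.

import torch
import torch.nn.functional as F

logits = torch.randn(1, 256)                  # (batch, vocab)
filtered = top_k(logits, thres = 0.9)         # keeps int(0.1 * 256) = 25 logits
probs = F.softmax(filtered / 1.0, dim = -1)   # temperature of 1.0
sample = torch.multinomial(probs, 1)          # (batch, 1) sampled token index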

# autoregressive wrapper class
class AutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net,        
        pad_value = 0
    ):
        super().__init__()
        # store attributes
        self.seq_len = net.seq_len
        self.pad_value = pad_value
        self.net = net

    # sequence generation, run without gradients and in eval mode
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prompt,
        seq_len,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        # unpack the prompt shape and device
        b, t, device = *prompt.shape, prompt.device

        out = prompt

        # generate the sequence token by token
        for _ in range(seq_len):
            # feed the last self.seq_len tokens into the network and take the logits for the final position
            logits = self.net(out[:, -self.seq_len:], **kwargs)[:, -1]

            # filter the logits with top-k
            filtered_logits = top_k(logits, thres = filter_thres)
            # convert to a probability distribution
            probs = F.softmax(filtered_logits / temperature, dim = -1)

            # sample one token from the distribution
            sample = torch.multinomial(probs, 1)
            # append the sampled token to the output sequence
            out = torch.cat((out, sample), dim = -1)

        # strip off the prompt and return only the generated tokens
        out = out[:, t:]
        return out

    # forward pass (training): next-token prediction with cross entropy
    def forward(self, x, **kwargs):
        # split the input into inputs and shifted labels
        x, labels = x[:, :-1], x[:, 1:]
        # run the network to get the logits
        logits = self.net(x, **kwargs)
        # move the class dimension into position 1, as expected by F.cross_entropy
        logits = rearrange(logits, "b c n -> b n c")
        # compute the cross entropy loss
        return F.cross_entropy(logits, labels)

.\lucidrains\coordinate-descent-attention\coordinate_descent_attention\coordinate_descent_attention.py

# import torch
import torch
# import torch.nn.functional as F
import torch.nn.functional as F
# import nn and einsum from torch
from torch import nn, einsum
# import rearrange and repeat from einops
from einops import rearrange, repeat

# import coor_descent and triton_coor_descent from colt5_attention
from colt5_attention import coor_descent
from colt5_attention.triton_coor_descent import triton_coor_descent

# helpers

# exists: returns whether a value is not None
def exists(val):
    return val is not None

# default: return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# classes

# FeedForward class, inheriting from nn.Module
class FeedForward(nn.Module):
    # initializer
    def __init__(
        self,
        dim,
        mult = 4,
        use_coor_descent = False,
        coor_descent_iters = 20,
        coor_descent_sparsity_k = None,
        coor_descent_eps = 1e-1,
        coor_descent_eps_init = 4.,
        coor_descent_eps_decay = 0.7,
    ):
        super().__init__()

        dim_hidden = int(dim * mult)

        self.use_coor_descent = use_coor_descent

        self.coor_descent_iters = coor_descent_iters
        self.coor_descent_sparsity_k = default(coor_descent_sparsity_k, dim_hidden // 10)
        self.coor_descent_eps = coor_descent_eps
        self.coor_descent_eps_init = coor_descent_eps_init
        self.coor_descent_eps_decay = coor_descent_eps_decay

        self.proj_in = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim_hidden),
        )

        self.proj_out = nn.Linear(dim_hidden, dim)

    # forward pass
    def forward(self, x):
        x = self.proj_in(x)

        if self.use_coor_descent:
            x = triton_coor_descent(
                x,
                n_iters = self.coor_descent_iters,
                k = self.coor_descent_sparsity_k,
                eps = self.coor_descent_eps,
                eps_init = self.coor_descent_eps_init,
                eps_decay = self.coor_descent_eps_decay,
                checkpoint_segments = self.coor_descent_iters // 5
            )
        else:
            x = F.gelu(x)

        return self.proj_out(x)

# Attention class, inheriting from nn.Module
class Attention(nn.Module):
    # initializer
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        use_coor_descent = False,
        coor_descent_iters = 20,
        coor_descent_sparsity_k = 1,
        coor_descent_eps = 1e-1,
        coor_descent_eps_init = 4.,
        coor_descent_eps_decay = 0.7,
        attn_null_kv = 0,
        learned_sparsity_k = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        dim_inner = dim_head * heads

        self.use_coor_descent = use_coor_descent

        self.coor_descent_iters = coor_descent_iters
        self.coor_descent_sparsity_k = coor_descent_sparsity_k

        self.coor_descent_eps = coor_descent_eps
        self.coor_descent_eps_init = coor_descent_eps_init
        self.coor_descent_eps_decay = coor_descent_eps_decay

        self.to_learned_k = None
        if learned_sparsity_k:
            self.to_learned_k = nn.Linear(dim, heads)
            nn.init.constant_(self.to_learned_k.bias, -10)

        self.norm = nn.LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, heads, attn_null_kv, dim_head))

        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)
    # forward pass, taking the input tensor x
    def forward(self, x):
        # unpack batch size b, sequence length n, number of heads h, device and dtype
        b, n, h, device, dtype = *x.shape[:2], self.heads, x.device, x.dtype
        # normalize the input
        x = self.norm(x)

        # derive queries, keys, values and split them into heads

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # append the null key / value pairs, if any

        if self.null_kv.numel() > 0:
            nk, nv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.null_kv)
            k = torch.cat((nk, k), dim = -2)
            v = torch.cat((nv, v), dim = -2)

        # compute similarity

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

        # use coordinate descent, or fall back to softmax

        if self.use_coor_descent:

            if exists(self.to_learned_k):
                sparsity_k = self.to_learned_k(x).sigmoid() * (self.coor_descent_sparsity_k - 1) + 1
                sparsity_k = rearrange(sparsity_k, 'b i h -> (b h i)')
            else:
                sparsity_k = torch.ones(i, device = device, dtype = dtype) * self.coor_descent_sparsity_k

            causal_mask = repeat(causal_mask, 'i j -> b h i j', b = sim.shape[0], h = sim.shape[1])

            attn = triton_coor_descent(
                sim,
                n_iters = self.coor_descent_iters,
                k = sparsity_k,
                eps = self.coor_descent_eps,
                eps_decay = self.coor_descent_eps_decay,
                eps_init = self.coor_descent_eps_init,
                mask = ~causal_mask,
                checkpoint_segments = self.coor_descent_iters // 5
            )

        else:
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
            attn = sim.softmax(dim = -1)

        # aggregate values

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer class, inheriting from nn.Module
class Transformer(nn.Module):
    # initializer, accepting several hyperparameters
    def __init__(
        self,
        *,
        num_tokens,  # number of tokens in the vocabulary
        dim,  # model dimension
        seq_len,  # sequence length
        depth,  # number of layers
        dim_head = 64,  # dimension per attention head
        heads = 8,  # number of attention heads
        ff_mult = 4,  # feedforward expansion multiple
        attn_use_coor_descent = False,  # whether to use coordinate descent in attention
        ff_use_coor_descent = False,  # whether to use coordinate descent in the feedforward
        attn_coor_descent_sparsity_k = 2,  # sparsity k for attention coordinate descent
        ff_coor_descent_sparsity_k = 2,  # sparsity k for feedforward coordinate descent
        coor_descent_iters = 15,  # number of coordinate descent iterations
        coor_descent_eps = 1e-1,  # coordinate descent epsilon
        attn_null_kv = 0,  # number of null key / values for attention
        learned_sparsity_k = False  # whether the sparsity k is learned
    ):
        super().__init__()
        self.seq_len = seq_len  # store the sequence length

        # token and positional embedding layers
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)

        self.layers = nn.ModuleList([])  # list of layers

        # shared coordinate descent keyword arguments
        coor_kwargs = dict(
            coor_descent_iters = coor_descent_iters,
            coor_descent_eps = coor_descent_eps,
        )

        # create an attention and a feedforward layer for each level of depth
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(
                    dim,
                    dim_head = dim_head,
                    heads = heads,
                    use_coor_descent = attn_use_coor_descent,
                    coor_descent_sparsity_k = attn_coor_descent_sparsity_k,
                    attn_null_kv = attn_null_kv,
                    learned_sparsity_k = learned_sparsity_k,
                    **coor_kwargs
                ),
                FeedForward(
                    dim,
                    ff_mult,
                    use_coor_descent = ff_use_coor_descent,
                    coor_descent_sparsity_k = ff_coor_descent_sparsity_k,
                    **coor_kwargs
                )
            ]))

        # output head: LayerNorm followed by a linear projection to the logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    # forward pass
    def forward(self, x):
        n, device = x.shape[-1], x.device
        assert n <= self.seq_len  # sequence length must not exceed the configured maximum

        x = self.token_emb(x)  # token embedding
        x = x + self.pos_emb(torch.arange(n, device = device))  # add positional embedding

        # run each attention and feedforward layer, with residual connections
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.to_logits(x)  # project to logits

.\lucidrains\coordinate-descent-attention\coordinate_descent_attention\__init__.py

# import the Transformer and Attention classes from the coordinate_descent_attention package
from coordinate_descent_attention.coordinate_descent_attention import Transformer, Attention
# import the AutoregressiveWrapper class from the coordinate_descent_attention package
from coordinate_descent_attention.autoregressive_wrapper import AutoregressiveWrapper

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Coordinate Descent Attention (wip)

Implementation of an Attention layer where each head can attend to more than just one token, using coordinate descent to pick topk. Perhaps the number of tokens to attend to can even be learned.

In the case that the experiments above fail, I will use the repo for a few other ideas, among them getting coordinate descent routing working for autoregressive transformers.

Ongoing experiments

Update: I don't think the improvements are worth it. The memory usage becomes impractical as the number of iterations goes up as well. I'll keep playing around with topk attention though, because it bothers me that softmax becomes a bottleneck for the tokens far in the future, especially as sequence lengths go above 8k

Update: Using a kernel written in Triton, it is a bit more viable, but still too much if number of iterations is high

Update: by doing recomputes in segments of iterations, it is now feasible, if it were to actually yield any improvements

Appreciation

  • StabilityAI for the sponsorship to carry out independent research

Install

$ pip install coordinate-descent-attention

Usage

import torch
from coordinate_descent_attention import Transformer

model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 2,
    seq_len = 2048,
    dim_head = 64,
    heads = 8,
    attn_use_coor_descent = True   # set to True to switch from softmax to coordinate descent on qk similarity matrix
).cuda()

x = torch.randint(0, 256, (1, 2048)).cuda()

logits = model(x)

Todo

Citations

@article{Wright2015CoordinateDA,
    title   = {Coordinate descent algorithms},
    author  = {Stephen J. Wright},
    journal = {Mathematical Programming},
    year    = {2015},
    volume  = {151},
    pages   = {3-34}
}
@inproceedings{Gupta2021MemoryefficientTV,
    title   = {Memory-efficient Transformers via Top-k Attention},
    author  = {Ankit Gupta and Guy Dar and Shaya Goodman and David Ciprut and Jonathan Berant},
    booktitle = {SUSTAINLP},
    year    = {2021}
}
@article{Zhao2019ExplicitST,
    title   = {Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection},
    author  = {Guangxiang Zhao and Junyang Lin and Zhiyuan Zhang and Xuancheng Ren and Qi Su and Xu Sun},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1912.11637}
}
@article{Schmitzer2016StabilizedSS,
    title   = {Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems},
    author  = {Bernhard Schmitzer},
    journal = {ArXiv},
    year    = {2016},
    volume  = {abs/1610.06519}
}

.\lucidrains\coordinate-descent-attention\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'coordinate-descent-attention',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.0.11',
  # license
  license='MIT',
  # description
  description = 'Coordinate Descent Attention - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # long description content type
  long_description_content_type = 'text/markdown',
  # project URL
  url = 'https://github.com/lucidrains/coodinate-descent-attention',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention mechanism'
  ],
  # dependencies
  install_requires=[
    'einops>=0.6.1',
    'torch>=1.6',
    'colt5-attention>=0.9.0'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\coordinate-descent-attention\train.py

# import the required libraries
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# import the package's own modules
from coordinate_descent_attention import Transformer, AutoregressiveWrapper

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# helper functions

# decode a single token to a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens to a string
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# endlessly cycle through a dataloader (needed below, since the loaders are consumed with next())
def cycle(loader):
    while True:
        for data in loader:
            yield data

# instantiate the Transformer model
model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    seq_len = SEQ_LEN,
    attn_use_coor_descent = True,
    ff_use_coor_descent = True,
    attn_coor_descent_sparsity_k = 2,
    ff_coor_descent_sparsity_k = 128,
    coor_descent_iters = 25
)

# wrap the model with the autoregressive wrapper and move it to the GPU
model = AutoregressiveWrapper(model).cuda()

# prepare the enwik8 data

# read the data and split it into train / validation
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create DataLoaders for the training and validation sets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward(loss / GRADIENT_ACCUMULATE_EVERY)

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print(f"%s \n\n %s", (prime, "*" * 100))

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str, "\n")

.\lucidrains\cross-transformers-pytorch\cross_transformers_pytorch\cross_transformers_pytorch.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

class CrossTransformer(nn.Module):
    def __init__(
        self,
        dim = 512,
        dim_key = 128,
        dim_value = 128
    ):
        # initialize the CrossTransformer
        super().__init__()
        # scale factor: inverse square root of the key dimension
        self.scale = dim_key ** -0.5
        # 1x1 convolution producing queries / keys
        self.to_qk = nn.Conv2d(dim, dim_key, 1, bias = False)
        # 1x1 convolution producing values
        self.to_v = nn.Conv2d(dim, dim_value, 1, bias = False)

    def forward(self, model, img_query, img_supports):
        """
        dimensions names:
        
        b - batch
        k - num classes
        n - num images in a support class
        c - channels
        h, i - height
        w, j - width
        """

        # shape of the support set images
        b, k, *_ = img_supports.shape

        # run the query image through the model
        query_repr = model(img_query)
        *_, h, w = query_repr.shape

        # fold the support set images into the batch dimension
        img_supports = rearrange(img_supports, 'b k n c h w -> (b k n) c h w', b = b)
        # run the support set images through the model
        supports_repr = model(img_supports)

        # project the query representation to queries and values
        query_q, query_v = self.to_qk(query_repr), self.to_v(query_repr)

        # project the support representations to keys and values
        supports_k, supports_v = self.to_qk(supports_repr), self.to_v(supports_repr)
        # unfold the support set dimensions again
        supports_k, supports_v = map(lambda t: rearrange(t, '(b k n) c h w -> b k n c h w', b = b, k = k), (supports_k, supports_v))

        # similarity between the query image and the support set images
        sim = einsum('b c h w, b k n c i j -> b k h w n i j', query_q, supports_k) * self.scale
        sim = rearrange(sim, 'b k h w n i j -> b k h w (n i j)')

        # softmax over all support spatial positions
        attn = sim.softmax(dim = -1)
        attn = rearrange(attn, 'b k h w (n i j) -> b k h w n i j', i = h, j = w)

        # aggregate the support values
        out = einsum('b k h w n i j, b k n c i j -> b k c h w', attn, supports_v)

        # flatten the spatial dimensions
        out = rearrange(out, 'b k c h w -> b k (c h w)')
        query_v = rearrange(query_v, 'b c h w -> b () (c h w)')

        # squared euclidean distance between the query values and the aligned prototypes
        euclidean_dist = ((query_v - out) ** 2).sum(dim = -1) / (h * w)
        return -euclidean_dist

.\lucidrains\cross-transformers-pytorch\cross_transformers_pytorch\__init__.py

# import the CrossTransformer class from the cross_transformers_pytorch package
from cross_transformers_pytorch.cross_transformers_pytorch import CrossTransformer

Cross Transformers - Pytorch (wip)

Implementation of Cross Transformer for spatially-aware few-shot transfer, in Pytorch

Install

$ pip install cross-transformers-pytorch

Usage

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models
from cross_transformers_pytorch import CrossTransformer

resnet = models.resnet34(pretrained = True)
model = nn.Sequential(*[*resnet.children()][:-2])

cross_transformer = CrossTransformer(
    dim = 512,
    dim_key = 128,
    dim_value = 128
)

# (batch, channels, height, width)
img_query = torch.randn(1, 3, 224, 224)

# (batch, classes, num supports, channels, height, width)
img_supports = torch.randn(1, 2, 4, 3, 224, 224)

labels = torch.randint(0, 2, (1,))

dists = cross_transformer(model, img_query, img_supports) # (1, 2)

loss = F.cross_entropy(dists, labels)
loss.backward()

Citations

@misc{doersch2020crosstransformers,
    title={CrossTransformers: spatially-aware few-shot transfer}, 
    author={Carl Doersch and Ankush Gupta and Andrew Zisserman},
    year={2020},
    eprint={2007.11498},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

.\lucidrains\cross-transformers-pytorch\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'cross-transformers-pytorch',  # package name
  packages = find_packages(),  # find and include all packages
  version = '0.0.2',  # version number
  license='MIT',  # license
  description = 'Cross Transformers - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/cross-transformers-pytorch',  # project URL
  keywords = [  # keywords
    'artificial intelligence',
    'attention mechanism',
    'cross attention',
    'few shot learning'
  ],
  install_requires=[  # dependencies
    'torch>=1.6',
    'einops>=0.3'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\DALLE-pytorch\dalle_pytorch\attention.py

# imports: isfunction from inspect, ceil from math, torch / nn / einsum / functional,
# rearrange and repeat from einops, and apply_rotary_emb from rotary_embedding_torch
from inspect import isfunction
from math import ceil

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat

from rotary_embedding_torch import apply_rotary_emb

def exists(val):
    # returns whether a value is not None
    return val is not None

def uniq(arr):
    # returns the unique elements of an array, preserving order
    return {el: True for el in arr}.keys()

def default(val, d):
    # return the value if it exists, otherwise the default (calling it if it is a function)
    if exists(val):
        return val
    return d() if isfunction(d) else d

def max_neg_value(t):
    # the most negative value representable for the tensor's dtype
    return -torch.finfo(t.dtype).max

def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    # numerically stable softmax: scale down, subtract the max, scale back up
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True).detach()
    return (t * alpha).softmax(dim = dim)
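
A quick numerical check (not from the repository), assuming the stable_softmax definition above is in scope: since alpha is a power of two, dividing and re-multiplying is exact in floating point, so the result should agree with torch.softmax up to rounding error.

import torch

t = torch.randn(2, 8, 64, 64)
diff = (stable_softmax(t) - torch.softmax(t, dim = -1)).abs().max()
print(diff)  # on the order of 1e-8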

def apply_pos_emb(pos_emb, qkv):
    # apply rotary positional embedding to the query, key, value tensors
    n = qkv[0].shape[-2]
    pos_emb = pos_emb[..., :n, :]
    return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))

# definition of the Attention class
class Attention(nn.Module):
    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False,
                 static_mask = None):
        # initialize the Attention class
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head ** -0.5

        self.stable = stable
        self.causal = causal
        self.register_buffer('static_mask', static_mask, persistent=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
        # forward pass
        b, n, _, h, device = *x.shape, self.heads, x.device
        softmax = torch.softmax if not self.stable else stable_softmax
        offset = cache.get('offset', 0) if exists(cache) else 0

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))

        q = q * self.scale

        if offset > 0:
            k_top, v_top = cache[cache_key]
            k = torch.cat([k_top, k], dim=-2)
            v = torch.cat([v_top, v], dim=-2)
        if exists(cache):
            cache[cache_key] = k, v

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        mask_value = max_neg_value(dots)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.causal and offset == 0:  # causality is naturally enforced for the cached inference
            i, j = dots.shape[-2:]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            dots.masked_fill_(mask, mask_value)

        if exists(self.static_mask):
            dots.masked_fill_(~self.static_mask[offset:offset + n, :offset + n], mask_value)

        attn = softmax(dots, dim=-1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

# definition of the SparseConvCausalAttention class, implementing sparse convolutional causal attention
class SparseConvCausalAttention(nn.Module):
    # initializer, setting up the model parameters and hyperparameters
    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
        # call the parent class initializer
        super().__init__()
        # the kernel size must be odd
        assert kernel_size % 2 == 1, 'kernel size must be odd'

        # inner dimension
        inner_dim = dim_head *  heads
        # sequence length
        self.seq_len = seq_len
        # number of heads
        self.heads = heads
        # scale factor
        self.scale = dim_head ** -0.5
        # image size
        self.image_size = image_size
        # kernel size
        self.kernel_size = kernel_size
        # dilation
        self.dilation = dilation

        # whether to use the stable softmax
        self.stable = stable

        # linear layer producing queries, keys and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # output layer: linear projection followed by dropout
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    # forward pass, taking the input x, a mask and a rotary positional embedding
    def forward(self, x, mask = None, rotary_pos_emb = None):
        # unpack batch size b, sequence length n, number of heads h, image size, kernel size, dilation, maximum sequence length and device
        b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
        # choose the softmax depending on whether the stable variant is requested
        softmax = torch.softmax if not self.stable else stable_softmax

        # image sequence length
        img_seq_len = img_size ** 2
        # text length
        text_len = seq_len + 1 - img_seq_len

        # padding

        # amount of padding needed
        padding = seq_len - n + 1
        # if the mask is None, create an all-True mask
        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())

        # pad the input x
        x = F.pad(x, (0, 0, 0, padding), value = 0)
        # crop the mask to the text length
        mask = mask[:, :text_len]

        # derive query / keys / values

        # project the input x to queries, keys, values
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # rearrange queries, keys, values by head
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        # apply the rotary positional embedding to queries, keys, values if present
        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        # scale the queries
        q *= self.scale

        # split into text queries, image queries, text keys, image keys, text values, image values
        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention

        # dot-product attention scores
        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        # value used for masking
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        # causal mask for the text tokens
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        # text attention weights
        attn_text = softmax(dots_text, dim = -1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # image attention

        # effective kernel size accounting for dilation
        effective_kernel_size = (kernel_size - 1) * dilation + 1
        same_padding = effective_kernel_size // 2
        causal_padding = (same_padding * 2, 0, same_padding * 2, 0)

        # rearrange the image keys and values into 2d feature maps
        k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img))
        # pad the image keys and values
        k_img, v_img = map(lambda t: F.pad(t, causal_padding), (k_img, v_img))
        k_img, v_img = map(lambda t: F.unfold(t, kernel_size, dilation = dilation), (k_img, v_img))
        k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img))

        # let the image attend to all of the text

        dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
        dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)

        # build the padding mask by padding and unfolding a tensor of ones
        i, j = dots_image.shape[-2:]
        ones = torch.ones((img_seq_len,), device = device)
        ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
        ones = F.pad(ones, causal_padding, value = 0.)
        ones = F.unfold(ones, kernel_size, dilation = dilation)
        ones = rearrange(ones, 'b j i -> b i j')

        # mask for the image attention
        padding_mask = ones == 0.

        # concatenate the text mask with the image causal mask
        padding_mask = repeat(padding_mask, '() i j -> b i j', b = b * h)
        mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h)
        mask = torch.cat((~mask, padding_mask), dim = -1)

        # the image can attend to all of the text

        dots = torch.cat((dots_image_to_text, dots_image), dim = -1)
        dots.masked_fill_(mask, mask_value)

        attn = softmax(dots, dim = -1)

        # aggregate

        attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]

        out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image, v_img)
        out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # combine the attended text and image values

        out = torch.cat((out_text, out_image), dim = 1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out =  self.to_out(out)
        return out[:, :n]
# 稀疏轴向因果注意力机制

class SparseAxialCausalAttention(nn.Module):
    # 初始化函数,定义稀疏轴向因果注意力机制的参数
    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
        super().__init__()
        # 断言轴向参数只能是0(沿高度)或1(沿宽度)
        assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
        self.axis = axis

        # 计算内部维度
        inner_dim = dim_head *  heads
        self.seq_len = seq_len
        self.heads = heads
        # 缩放因子
        self.scale = dim_head ** -0.5
        self.image_size = image_size

        # 是否稳定
        self.stable = stable

        # 线性变换,将输入维度映射到内部维度的3倍(用于查询、键、值)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # 输出层,包含线性变换和dropout
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    # 定义前向传播函数,接受输入 x,mask 和旋转位置嵌入 rotary_pos_emb
    def forward(self, x, mask = None, rotary_pos_emb = None):
        # 解包 x 的形状信息,包括 batch 大小 b,序列长度 n,头数 h,图像大小 img_size,轴 axis,序列长度 seq_len,设备 device
        b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
        # 根据是否稳定计算 softmax 函数
        softmax = torch.softmax if not self.stable else stable_softmax

        # 计算图像序列长度和文本序列长度
        img_seq_len = img_size ** 2
        text_len = seq_len + 1 - img_seq_len

        # 填充

        # 计算需要填充的长度
        padding = seq_len - n + 1
        # 如果 mask 为 None,则创建全为 True 的 mask 张量
        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())

        # 对输入 x 进行填充
        x = F.pad(x, (0, 0, 0, padding), value = 0)
        mask = mask[:, :text_len]

        # 求解查询 / 键 / 值

        # 将输入 x 转换为查询、键、值,并按维度 -1 切分
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        # 如果存在旋转位置嵌入,则应用到查询、键、值上
        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        # 缩放查询
        q *= self.scale

        # 拆分文本查询、图像查询、文本键、图像键、文本值、图像值
        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # 文本注意力

        # 计算文本查询和文本键的点积
        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        # 创建文本因果 mask
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        # 计算文本注意力权重
        attn_text = softmax(dots_text, dim = -1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # 图像注意力

        # 根据轴 axis 拆分图像查询、图像键、图像值
        split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
        merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'

        # 拆分轴

        q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))

        # 相似度

        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)

        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)

        # mask 以使图像对文本有完全注意力,但沿轴是因果的

        bh, x, i, j = dots.shape
        causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
        causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x)

        mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i)
        mask = torch.cat((~mask, causal_mask), dim = -1)

        dots.masked_fill_(mask, mask_value)

        # 注意力

        attn = softmax(dots, dim = -1)

        # 聚合

        attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]

        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # 合并轴

        out_image = rearrange(out_image, merge_axis_einops, x = img_size)

        # 合并文本和图像的注意力值

        out = torch.cat((out_text, out_image), dim = 1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out =  self.to_out(out)
        return out[:, :n]
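
A minimal shape-check sketch of the axial attention block above (the sizes are hypothetical, and it assumes the helper functions defined earlier in this file, such as exists, default and max_neg_value, are in scope): with image_size = 8 the last 8 * 8 = 64 positions are treated as image tokens and everything before them as text, so a seq_len of 128 works out to 64 text-side positions followed by 64 image positions. Image positions attend causally along one spatial axis and fully to the text.

import torch

attn = SparseAxialCausalAttention(dim = 256, seq_len = 64 + 8 * 8, image_size = 8, axis = 0, heads = 4, dim_head = 32)
x = torch.randn(2, 64 + 8 * 8, 256)
out = attn(x)    # (2, 128, 256)
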
# 定义 SparseAttention 类,继承自 Attention 类
class SparseAttention(Attention):
    # 初始化函数
    def __init__(
        self,
        *args,
        block_size = 16,  # 定义块大小,默认为16
        text_seq_len = 256,  # 定义文本序列长度,默认为256
        num_random_blocks = None,  # 定义随机块数,默认为None
        **kwargs
    ):
        super().__init__(*args, **kwargs)  # 调用父类的初始化函数
        from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig  # 导入相关模块
        self.block_size = block_size  # 设置块大小

        num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)  # 计算随机块数
        global_block_indices = list(range(ceil(text_seq_len / block_size)))  # 计算全局块索引

        # 初始化稀疏自注意力机制
        self.attn_fn = SparseSelfAttention(
            sparsity_config = VariableSparsityConfig(
                num_heads = self.heads,
                block = self.block_size,
                num_random_blocks = num_random_blocks,
                global_block_indices = global_block_indices,
                attention = 'unidirectional' if self.causal else 'bidirectional'
            ),
            max_seq_length = self.seq_len,
            attn_mask_mode = 'add'
        )

    # 前向传播函数
    def forward(self, x, mask = None, rotary_pos_emb = None):
        b, n, _, h, device = *x.shape, self.heads, x.device  # 获取输入张量的形状和设备信息
        remainder = n % self.block_size  # 计算余数
        mask = default(mask, lambda: torch.ones(b, n, device = device).bool())  # 设置默认掩码

        if remainder > 0:
            padding = self.block_size - remainder  # 计算填充大小
            x = F.pad(x, (0, 0, 0, padding), value = 0)  # 对输入张量进行填充
            mask = F.pad(mask, (0, padding), value = False)  # 对掩码进行填充

        qkv = self.to_qkv(x).chunk(3, dim = -1)  # 将输入张量转换为查询、键、值
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)  # 重排查询、键、值的维度

        if exists(rotary_pos_emb):  # 如果存在旋转位置编码
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))  # 应用位置编码

        key_pad_mask = None  # 初始化键掩码
        if exists(mask):  # 如果存在掩码
            key_pad_mask = ~mask  # 生成键掩码

        attn_mask = None  # 初始化注意力掩码
        if self.causal:  # 如果是因果注意力
            i, j = q.shape[-2], k.shape[-2]  # 获取查询和键的长度
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()  # 生成上三角掩码
            attn_mask = torch.zeros(i, j, device = device).to(q)  # 初始化注意力掩码
            mask_value = max_neg_value(q) / 2  # 计算掩码值
            attn_mask.masked_fill_(mask, mask_value)  # 填充注意力掩码

        # 使用稀疏自注意力机制进行计算
        out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
        out = rearrange(out, 'b h n d -> b n (h d)')  # 重排输出维度
        out = self.to_out(out)  # 输出层处理
        return out[:, :n]  # 返回结果

.\lucidrains\DALLE-pytorch\dalle_pytorch\dalle_pytorch.py

# 从 math 模块中导入 log2 和 sqrt 函数
from math import log2, sqrt
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 和 einsum 模块
from torch import nn, einsum
# 从 torch.nn.functional 模块中导入 F
import torch.nn.functional as F
# 导入 numpy 库
import numpy as np

# 导入自定义模块
from axial_positional_embedding import AxialPositionalEmbedding
from einops import rearrange

# 从 dalle_pytorch 库中导入 distributed_utils 模块
from dalle_pytorch import distributed_utils
# 从 dalle_pytorch.vae 模块中导入 OpenAIDiscreteVAE 和 VQGanVAE 类
from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
# 从 dalle_pytorch.transformer 模块中导入 Transformer 和 DivideMax 类
from dalle_pytorch.transformer import Transformer, DivideMax

# helpers

# 定义函数,判断变量是否存在
def exists(val):
    return val is not None

# 定义函数,返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义类,始终返回指定值
class always():
    def __init__(self, val):
        self.val = val
    def __call__(self, x, *args, **kwargs):
        return self.val

# 判断张量是否为空
def is_empty(t):
    return t.nelement() == 0

# 计算带掩码的平均值
def masked_mean(t, mask, dim = 1):
    t = t.masked_fill(~mask[:, :, None], 0.)
    return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]

# 生成与给定形状相同的概率掩码
def prob_mask_like(shape, prob, device):
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

# 设置模型参数是否需要梯度
def set_requires_grad(model, value):
    for param in model.parameters():
        param.requires_grad = value

# 评估装饰器
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# sampling helpers

# 计算对数
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# 生成 Gumbel 噪声
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# Gumbel 采样
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

# Top-k 采样
def top_k(logits, thres = 0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
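
As a quick illustration of how the sampling helpers above fit together (a minimal sketch, not part of the original file; the logits tensor and sizes are made up), top_k first masks everything outside the kept fraction to -inf, and gumbel_sample then draws a single token id from what remains:

import torch

logits = torch.randn(1, 1000)                          # hypothetical batch of 1 over a 1000-token vocabulary
filtered = top_k(logits, thres = 0.5)                  # keep roughly the top 50% of logits, the rest become -inf
token_id = gumbel_sample(filtered, temperature = 1.)   # one sampled token id per batch element
print(token_id.shape)                                  # torch.Size([1])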

# 共享嵌入层
class SharedEmbedding(nn.Embedding):
    def __init__(self, linear, start_index, end_index, **kwargs):
        super().__init__(end_index - start_index, linear.weight.shape[1], **kwargs)
        del self.weight

        self.linear = linear
        self.start_index = start_index
        self.end_index = end_index

    def forward(self, input):
        return F.embedding(
            input, self.linear.weight[self.start_index:self.end_index], self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)
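
To make the weight tying in SharedEmbedding concrete, here is a minimal sketch (the dimensions and the to_logits layer are hypothetical, not from the original file): the embedding keeps no weight matrix of its own and instead reads a slice of an nn.Linear's weight, so the input embedding and the output projection share parameters.

import torch
from torch import nn

to_logits = nn.Linear(64, 1000, bias = False)     # projects 64-dim hidden states to 1000 token logits
text_emb = SharedEmbedding(to_logits, 0, 600)     # token ids 0..599 reuse rows 0..599 of to_logits.weight
tokens = torch.randint(0, 600, (2, 16))
embeddings = text_emb(tokens)                     # (2, 16, 64), gradients flow into to_logits.weight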

# 离散 VAE 类

# ResNet 块
class ResBlock(nn.Module):
    def __init__(self, chan):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        return self.net(x) + x

# 离散 VAE 类
class DiscreteVAE(nn.Module):
    def __init__(
        self,
        image_size = 256,
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 0,
        hidden_dim = 64,
        channels = 3,
        smooth_l1_loss = False,
        temperature = 0.9,
        straight_through = False,
        reinmax = False,
        kl_div_loss_weight = 0.,
        normalization = ((*((0.5,) * 3), 0), (*((0.5,) * 3), 1))
    ):
        # 调用父类的构造函数
        super().__init__()
        # 断言图片大小必须是2的幂次方
        assert log2(image_size).is_integer(), 'image size must be a power of 2'
        # 断言层数必须大于等于1
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
        # 判断是否有残差块
        has_resblocks = num_resnet_blocks > 0

        # 初始化各种参数
        self.channels = channels
        self.image_size = image_size
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.reinmax = reinmax

        # 创建编码簿
        self.codebook = nn.Embedding(num_tokens, codebook_dim)

        hdim = hidden_dim

        # 初始化编码器和解码器通道数
        enc_chans = [hidden_dim] * num_layers
        dec_chans = list(reversed(enc_chans))

        enc_chans = [channels, *enc_chans]

        dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
        dec_chans = [dec_init_chan, *dec_chans]

        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

        enc_layers = []
        dec_layers = []

        # 创建编码器和解码器的层
        for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
            enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
            dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))

        # 添加残差块
        for _ in range(num_resnet_blocks):
            dec_layers.insert(0, ResBlock(dec_chans[1]))
            enc_layers.append(ResBlock(enc_chans[-1]))

        if num_resnet_blocks > 0:
            dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))

        enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
        dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))

        # 创建编码器和解码器
        self.encoder = nn.Sequential(*enc_layers)
        self.decoder = nn.Sequential(*dec_layers)

        # 设置损失函数和 KL 散度损失权重
        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        self.kl_div_loss_weight = kl_div_loss_weight

        # 处理类内的归一化
        self.normalization = tuple(map(lambda t: t[:channels], normalization))

        # 注册外部参数
        self._register_external_parameters()

    def _register_external_parameters(self):
        """Register external parameters for DeepSpeed partitioning."""
        if (
                not distributed_utils.is_distributed
                or not distributed_utils.using_backend(
                    distributed_utils.DeepSpeedBackend)
        ):
            return

        deepspeed = distributed_utils.backend.backend_module
        deepspeed.zero.register_external_parameter(self, self.codebook.weight)

    def norm(self, images):
        if not exists(self.normalization):
            return images

        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
        means, stds = map(lambda t: rearrange(t, 'c -> () c () ()'), (means, stds))
        images = images.clone()
        images.sub_(means).div_(stds)
        return images

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        logits = self(images, return_logits = True)
        codebook_indices = logits.argmax(dim = 1).flatten(1)
        return codebook_indices

    def decode(
        self,
        img_seq
    ):
        image_embeds = self.codebook(img_seq)
        b, n, d = image_embeds.shape
        h = w = int(sqrt(n))

        image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)
        images = self.decoder(image_embeds)
        return images

    def forward(
        self,
        img,
        return_loss = False,
        return_recons = False,
        return_logits = False,
        temp = None
        ):
        # 从输入参数中获取图像、标记数量、图像大小和 KL 散度损失权重
        device, num_tokens, image_size, kl_div_loss_weight = img.device, self.num_tokens, self.image_size, self.kl_div_loss_weight
        # 断言输入图像的形状符合要求
        assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}'

        # 对输入图像进行归一化处理
        img = self.norm(img)

        # 将归一化后的图像输入编码器获取 logits
        logits = self.encoder(img)

        # 如果需要返回 logits,则直接返回,用于 DALL-E 训练中获取硬图像索引
        if return_logits:
            return logits

        # 获取温度参数,默认为 self.temperature
        temp = default(temp, self.temperature)

        # 使用 Gumbel Softmax 采样生成 one-hot 编码
        one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=self.straight_through)

        # 如果使用 straight-through 和 reinmax
        if self.straight_through and self.reinmax:
            # 使用 reinmax 提高二阶精度 - https://arxiv.org/abs/2304.08612
            # 算法 2
            one_hot = one_hot.detach()
            π0 = logits.softmax(dim=1)
            π1 = (one_hot + (logits / temp).softmax(dim=1)) / 2
            π1 = ((log(π1) - logits).detach() + logits).softmax(dim=1)
            π2 = 2 * π1 - 0.5 * π0
            one_hot = π2 - π2.detach() + one_hot

        # 使用 one-hot 编码和 codebook 权重进行采样
        sampled = einsum('b n h w, n d -> b d h w', one_hot, self.codebook.weight)
        # 将采样结果输入解码器获取输出
        out = self.decoder(sampled)

        # 如果不需要返回损失,则直接返回输出
        if not return_loss:
            return out

        # 重构损失
        recon_loss = self.loss_fn(img, out)

        # KL 散度
        logits = rearrange(logits, 'b n h w -> b (h w) n')
        log_qy = F.log_softmax(logits, dim=-1)
        log_uniform = torch.log(torch.tensor([1. / num_tokens], device=device))
        kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target=True)

        # 计算总损失
        loss = recon_loss + (kl_div * kl_div_loss_weight)

        # 如果不需要返回重构图像,则直接返回总损失
        if not return_recons:
            return loss

        # 返回总损失和输出图像
        return loss, out
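
A minimal usage sketch of the DiscreteVAE above (the hyperparameters are hypothetical, chosen only so the example runs quickly): training uses the reconstruction (and optional KL) loss, while get_codebook_indices yields the discrete image tokens the DALL-E stage is trained on.

import torch

vae = DiscreteVAE(
    image_size = 64,        # must be a power of 2
    num_layers = 2,         # 64 -> 16 feature map, i.e. 16 * 16 = 256 image tokens
    num_tokens = 512,
    codebook_dim = 256,
    hidden_dim = 32
)

images = torch.randn(1, 3, 64, 64)
loss = vae(images, return_loss = True)        # scalar reconstruction loss for training the VAE
indices = vae.get_codebook_indices(images)    # (1, 256) codebook indices in [0, 512)
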
# 主要的 CLIP 类
class CLIP(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        dim_text = 512,  # 文本维度
        dim_image = 512,  # 图像维度
        dim_latent = 512,  # 潜在维度
        num_text_tokens = 10000,  # 文本标记数量
        text_enc_depth = 6,  # 文本编码器深度
        text_seq_len = 256,  # 文本序列长度
        text_heads = 8,  # 文本注意力头数
        num_visual_tokens = 512,  # 视觉标记数量
        visual_enc_depth = 6,  # 视觉编码器深度
        visual_heads = 8,  # 视觉注意力头数
        visual_image_size = 256,  # 视觉图像大小
        visual_patch_size = 32,  # 视觉图像块大小
        channels = 3  # 通道数
    ):
        super().__init__()
        # 创建文本嵌入层
        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
        # 创建文本位置嵌入层
        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
        # 创建文本变换器
        self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads, rotary_emb = False)
        # 创建文本到潜在空间的线性层
        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

        # 确保图像尺寸能够被图像块大小整除
        assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (visual_image_size // visual_patch_size) ** 2
        patch_dim = channels * visual_patch_size ** 2

        self.visual_patch_size = visual_patch_size
        # 创建图像块到嵌入空间的线性层
        self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
        # 创建图像位置嵌入层
        self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
        # 创建视觉变换器
        self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads, rotary_emb = False)
        # 创建图像到潜在空间的线性层
        self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

        # 温度参数
        self.temperature = nn.Parameter(torch.tensor(1.))

    # 前向传播函数
    def forward(
        self,
        text,
        image,
        text_mask = None,
        return_loss = False
    ):
        b, device, p = text.shape[0], text.device, self.visual_patch_size

        # 文本嵌入
        text_emb = self.text_emb(text)
        text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))

        # 图像块提取
        image_patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        image_emb = self.to_visual_embedding(image_patches)
        image_emb += self.visual_pos_emb(torch.arange(image_emb.shape[1], device = device))

        # 文本编码
        enc_text = self.text_transformer(text_emb, mask = text_mask)
        # 图像编码
        enc_image = self.visual_transformer(image_emb)

        # 计算文本潜在空间表示
        if exists(text_mask):
            text_latents = masked_mean(enc_text, text_mask, dim = 1)
        else:
            text_latents = enc_text.mean(dim = 1)

        # 计算图像潜在空间表示
        image_latents = enc_image.mean(dim = 1)

        # 线性变换
        text_latents = self.to_text_latent(text_latents)
        image_latents = self.to_visual_latent(image_latents)

        # 归一化
        text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents))

        temp = self.temperature.exp()

        # 如果不需要计算损失,则返回相似度
        if not return_loss:
            sim = einsum('n d, n d -> n', text_latents, image_latents) * temp
            return sim

        # 计算损失
        sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
        labels = torch.arange(b, device = device)
        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
        return loss
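
A minimal usage sketch of the CLIP module above (batch size and shapes are hypothetical; it relies on the Transformer class imported from dalle_pytorch.transformer): with return_loss = True it produces the symmetric contrastive loss, and the trained model can later be passed to DALLE.generate_images to rank samples.

import torch

clip = CLIP(
    dim_text = 512, dim_image = 512, dim_latent = 512,
    num_text_tokens = 10000, text_seq_len = 256,
    visual_image_size = 256, visual_patch_size = 32
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones(4, 256).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)   # symmetric cross entropy over the 4 x 4 similarity matrix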

# 主要的 DALL-E 类
class DALLE(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        dim,
        vae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth,
        heads = 8,
        dim_head = 64,
        reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0,
        sparse_attn = False,
        attn_types = None,
        loss_img_weight = 7,
        stable = False,
        sandwich_norm = False,
        shift_tokens = True,
        rotary_emb = True,
        shared_attn_ids = None,
        shared_ff_ids = None,
        share_input_output_emb = False,
        optimize_for_inference = False,
    ):
        # 调用父类的构造函数
        super().__init__()
        # 断言确保 vae 是 DiscreteVAE、OpenAIDiscreteVAE 或 VQGanVAE 的实例
        assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE, OpenAIDiscreteVAE or VQGanVAE'

        # 获取图像大小、图像标记数量、图像特征图大小和图像序列长度
        image_size = vae.image_size
        num_image_tokens = vae.num_tokens
        image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
        image_seq_len = image_fmap_size ** 2

        # 为每个位置(文本序列长度)保留唯一的填充标记
        num_text_tokens = num_text_tokens + text_seq_len
        # 创建文本位置嵌入和图像位置嵌入
        self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>
        self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size)) if not rotary_emb else always(0)

        # 设置文本标记数量和图像标记数量
        self.num_text_tokens = num_text_tokens
        self.num_image_tokens = num_image_tokens

        # 设置文本序列长度和图像序列长度
        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        # 计算总序列长度和总标记数量
        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens
        self.total_tokens = total_tokens
        self.total_seq_len = seq_len

        # 冻结 VAE 不参与训练
        self.vae = vae
        set_requires_grad(self.vae, False)

        # 创建 Transformer 模型
        self.transformer = Transformer(
            dim = dim,
            causal = True,
            seq_len = seq_len,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            reversible = reversible,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
            image_fmap_size = image_fmap_size,
            sparse_attn = sparse_attn,
            stable = stable,
            sandwich_norm = sandwich_norm,
            shift_tokens = shift_tokens,
            rotary_emb = rotary_emb,
            shared_attn_ids = shared_attn_ids,
            shared_ff_ids = shared_ff_ids,
            optimize_for_inference = optimize_for_inference,
        )

        # 设置稳定性参数
        self.stable = stable

        # 如果稳定性为真,使用 DivideMax 进行归一化
        if stable:
            self.norm_by_max = DivideMax(dim = -1)

        # 转换为 logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        # 如果共享输入输出嵌入,创建共享嵌入层,否则创建独立嵌入层
        if share_input_output_emb:
            self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
            self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
        else:
            self.text_emb = nn.Embedding(num_text_tokens, dim)
            self.image_emb = nn.Embedding(num_image_tokens, dim)

        # 创建序列范围和 logits 范围
        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        # 创建 logits 掩码
        logits_mask = (
            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
        )

        # 注册 logits 掩码为缓冲区
        self.register_buffer('logits_mask', logits_mask, persistent=False)
        self.loss_img_weight = loss_img_weight


    @torch.no_grad()
    @eval_decorator
    def generate_texts(
        self,
        tokenizer,
        text = None,
        *,
        filter_thres = 0.5,
        temperature = 1.
        ):
        # 获取文本序列长度
        text_seq_len = self.text_seq_len
        # 如果文本为空或者为None,则将文本tokens设置为0,并移至GPU
        if text is None or text == "":
            text_tokens = torch.tensor([[0]]).cuda()
        else:
            # 将文本编码为tokens,并移至GPU
            text_tokens = torch.tensor(tokenizer.tokenizer.encode(text)).cuda().unsqueeze(0)

        # 循环直到文本tokens长度达到指定长度
        for _ in range(text_tokens.shape[1], text_seq_len):
            # 获取当前设备
            device = text_tokens.device

            # 获取文本tokens的嵌入
            tokens = self.text_emb(text_tokens)
            # 添加文本位置嵌入
            tokens += self.text_pos_emb(torch.arange(text_tokens.shape[1], device=device))

            # 获取tokens序列长度
            seq_len = tokens.shape[1]

            # 使用transformer处理tokens
            output_transf = self.transformer(tokens)

            # 如果启用了稳定性,对输出进行归一化
            if self.stable:
                output_transf = self.norm_by_max(output_transf)

            # 获取logits
            logits = self.to_logits(output_transf)

            # 对logits进行掩码,确保文本预测文本(除了最后一个token),图像预测图像
            logits_mask = self.logits_mask[:, :seq_len]
            max_neg_value = -torch.finfo(logits.dtype).max
            logits.masked_fill_(logits_mask, max_neg_value)
            logits = logits[:, -1, :]

            # 从logits中筛选出top k的token
            filtered_logits = top_k(logits, thres=filter_thres)
            # 使用Gumbel采样获取样本
            sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

            # 将新样本添加到文本tokens中
            text_tokens = torch.cat((text_tokens, sample[:, None]), dim=-1)

        # 创建填充tokens集合
        padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
        # 解码文本tokens,获取文本列表
        texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
        return text_tokens, texts

    @torch.no_grad()
    @eval_decorator
    def generate_images(
        self,
        text,
        *,
        clip=None,
        filter_thres=0.5,
        temperature=1.,
        img=None,
        num_init_img_tokens=None,
        cond_scale=1.,
        use_cache=False,
    ):
        # 获取VAE模型、文本序列长度、图像序列长度、文本tokens数量
        vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
        # 计算总长度
        total_len = text_seq_len + image_seq_len

        # 确保文本在指定范围内
        text = text[:, :text_seq_len]
        out = text

        # 如果存在图像输入
        if exists(img):
            # 获取图像大小
            image_size = vae.image_size
            assert img.shape[1] == 3 and img.shape[2] == image_size and img.shape[3] == image_size, f'input image must have the correct image size {image_size}'

            # 获取图像的codebook索引
            indices = vae.get_codebook_indices(img)
            # 设置初始图像tokens数量
            num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len))  # OpenAI used 14 * 32 initial tokens to prime
            assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'

            indices = indices[:, :num_img_tokens]
            out = torch.cat((out, indices), dim=-1)

        prev_cache = None
        cache = {} if use_cache else None
        # 循环直到out的长度达到总长度
        for cur_len in range(out.shape[1], total_len):
            is_image = cur_len >= text_seq_len

            text, image = out[:, :text_seq_len], out[:, text_seq_len:]

            # 使用条件缩放处理文本和图像
            logits = self.forward_with_cond_scale(text, image, cond_scale=cond_scale, cache=cache)
            logits = logits[:, -1, :]

            # 从logits中筛选出top k的token
            filtered_logits = top_k(logits, thres=filter_thres)
            # 使用Gumbel采样获取样本
            sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

            # 如果是图像token,减去num_text_tokens的偏移量
            sample -= (num_text_tokens if is_image else 0)
            out = torch.cat((out, sample[:, None]), dim=-1)

        # 获取文本序列和图像序列
        text_seq = out[:, :text_seq_len]
        img_seq = out[:, -image_seq_len:]
        # 解码图像序列
        images = vae.decode(img_seq)

        # 如果存在clip模型
        if exists(clip):
            # 使用clip模型评分
            scores = clip(text_seq, images, return_loss=False)
            return images, scores

        return images
    # 定义一个带有条件缩放参数的前向传播函数
    def forward_with_cond_scale(self, *args, cond_scale = 1, cache = None, **kwargs):
        # 如果条件缩放参数为1,则直接调用原始的前向传播函数
        if cond_scale == 1:
            return self(*args, **kwargs)

        # 如果缓存存在,则复制缓存,否则设为None
        prev_cache = cache.copy() if exists(cache) else None
        # 调用原始的前向传播函数,传入缓存参数
        logits = self(*args, cache = cache, **kwargs)

        # Katherine Crowson的发现
        # https://twitter.com/RiversHaveWings/status/1478093658716966912
        # 传入空条件概率为1的参数,调用原始的前向传播函数
        null_cond_logits = self(*args, null_cond_prob = 1., cache = prev_cache, **kwargs)
        # 返回空条件logits加上(原始logits减去空条件logits)乘以条件缩放参数的结果
        return null_cond_logits + (logits - null_cond_logits) * cond_scale
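
Concretely, with cond_scale = 3 the returned logits are null_cond_logits + 3 * (logits - null_cond_logits): the prediction made without text conditioning serves as a baseline, and the difference toward the text-conditioned prediction is amplified threefold, which is the classifier-free-guidance-style trick referenced in the comment above.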

    # 定义一个前向传播函数,接受文本、图像、是否返回损失、空条件概率和缓存等参数
    def forward(
        self,
        text,
        image = None,
        return_loss = False,
        null_cond_prob = 0.,
        cache = None,
    ):
        # 检查传入的文本张量是否与指定的文本序列长度相匹配
        assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
        # 获取文本张量的批次大小、设备信息和总序列长度
        batch, device, total_seq_len = text.shape[0], text.device, self.total_seq_len

        # 以 <null_cond_prob> 的概率随机移除文本条件

        if null_cond_prob > 0:
            # 创建一个与文本张量形状相同的概率掩码,用于随机移除文本条件
            null_mask = prob_mask_like((batch,), null_cond_prob, device=device)
            # 将文本张量中的部分内容根据概率掩码置零
            text *= rearrange(~null_mask, 'b -> b 1')

        # 确保文本标记中的填充获得唯一的填充标记ID

        # 生成文本范围,用于替换文本张量中的填充标记
        text_range = torch.arange(self.text_seq_len, device=device) + (self.num_text_tokens - self.text_seq_len)
        text = torch.where(text == 0, text_range, text)

        # 添加 <bos> 标记

        # 在文本张量的开头添加一个零值填充
        text = F.pad(text, (1, 0), value=0)

        # 对文本进行嵌入处理
        tokens = self.text_emb(text)
        # 添加文本位置编码
        tokens += self.text_pos_emb(torch.arange(text.shape[1], device=device))

        seq_len = tokens.shape[1]

        # 如果存在图像且图像不为空
        if exists(image) and not is_empty(image):
            is_raw_image = len(image.shape) == 4

            if is_raw_image:
                # 获取图像的代码簿索引
                image_size = self.vae.image_size
                channels = self.vae.channels
                assert tuple(image.shape[1:]) == (channels, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training'

                image = self.vae.get_codebook_indices(image)

            image_len = image.shape[1]
            image_emb = self.image_emb(image)

            # 添加图像位置编码
            image_emb += self.image_pos_emb(image_emb)

            # 将文本和图像嵌入连接起来
            tokens = torch.cat((tokens, image_emb), dim=1)

            seq_len += image_len

        # 在训练时,如果长度超过总文本+图像长度,则移除最后一个标记,因为不需要对其进行训练

        if tokens.shape[1] > total_seq_len:
            seq_len -= 1
            tokens = tokens[:, :-1]

        # 如果启用稳定性训练
        if self.stable:
            alpha = 0.1
            # 对 tokens 进行稳定性训练
            tokens = tokens * alpha + tokens.detach() * (1 - alpha)

        # 如果存在缓存且缓存中有 'offset' 键
        if exists(cache) and cache.get('offset'):
            # 仅保留 tokens 的最后一个标记
            tokens = tokens[:, -1:]
        # 使用 transformer 进行处理,传入缓存信息
        out = self.transformer(tokens, cache=cache)

        # 如果启用稳定性训练
        if self.stable:
            # 对输出进行最大归一化
            out = self.norm_by_max(out)

        # 将输出转换为 logits
        logits = self.to_logits(out)

        # 对 logits 进行掩码处理,确保文本预测文本(除最后一个标记),图像预测图像

        logits_mask = self.logits_mask[:, :seq_len]
        if exists(cache) and cache.get('offset'):
            logits_mask = logits_mask[:, -1:]
        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(logits_mask, max_neg_value)

        # 如果存在缓存
        if exists(cache):
            # 更新缓存中的 'offset' 键
            cache['offset'] = cache.get('offset', 0) + logits.shape[1]

        # 如果不需要返回损失值
        if not return_loss:
            return logits

        # 断言在训练时必须提供图像
        assert exists(image), 'when training, image must be supplied'

        # 对图像进行偏移处理
        offsetted_image = image + self.num_text_tokens
        # 创建标签,用于计算损失
        labels = torch.cat((text[:, 1:], offsetted_image), dim=1)

        # 重新排列 logits 的维度
        logits = rearrange(logits, 'b n c -> b c n')

        # 计算文本损失和图像损失
        loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
        loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])

        # 计算总损失
        loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
        return loss
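
A minimal end-to-end sketch (hypothetical sizes, not part of the original file) tying the pieces above together: the frozen DiscreteVAE supplies image tokens, the DALLE forward pass returns the weighted autoregressive loss, and generate_images samples new images at inference time.

import torch

vae = DiscreteVAE(image_size = 64, num_layers = 2, num_tokens = 512, codebook_dim = 256, hidden_dim = 32)

dalle = DALLE(
    dim = 256,
    vae = vae,                 # image_seq_len is derived from the vae: a 16 x 16 feature map gives 256 image tokens
    num_text_tokens = 10000,
    text_seq_len = 128,
    depth = 2,
    heads = 4,
    dim_head = 32
)

text = torch.randint(0, 10000, (2, 128))
images = torch.randn(2, 3, 64, 64)

loss = dalle(text, images, return_loss = True)   # cross entropy over text tokens plus loss_img_weight-weighted image tokens
loss.backward()

sampled = dalle.generate_images(text[:1])        # (1, 3, 64, 64), decoded from sampled image tokens by the vae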

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\deepspeed_backend.py

import json
import os

import torch

from .distributed_backend import DistributedBackend


class DeepSpeedBackend(DistributedBackend):
    """使用 DeepSpeed 引擎的分布式后端。"""

    BACKEND_MODULE_NAME = 'deepspeed'
    BACKEND_NAME = 'DeepSpeed'

    def wrap_arg_parser(self, parser):
        if not self.has_backend():
            parser.add_argument(
                '--deepspeed',
                type=lambda _: False,
                help=(
                    'whether to use DeepSpeed '
                    "(ignored since it's not available)"
                ),
            )
        else:
            parser = self.backend_module.add_config_arguments(parser)

        parser.add_argument(
            '--local_rank',
            type=int,
            default=-1,
            help='local rank passed from distributed launcher',
        )
        return parser

    def _initialize(self):
        self.backend_module.init_distributed()
        if torch.cuda.is_available():
            torch.cuda.set_device(self._get_local_rank())

    @staticmethod
    def _require_torch_distributed_init():
        """当 `torch.distributed` 尚未初始化时引发错误。"""
        assert torch.distributed.is_initialized(), \
            ('`torch.distributed` 未初始化;请在脚本开头调用 '
             '`DeepSpeedBackend.initialize`')

    def _get_world_size(self):
        self._require_torch_distributed_init()
        return torch.distributed.get_world_size()

    def _get_rank(self):
        self._require_torch_distributed_init()
        return torch.distributed.get_rank()

    def _get_local_rank(self):
        self._require_torch_distributed_init()
        return int(os.environ['LOCAL_RANK'])

    def _local_barrier(self):
        self._require_torch_distributed_init()
        torch.distributed.barrier()

    def _check_args(self, args, optimizer, lr_scheduler, kwargs):
        """在检查传递给 `distribute` 的值后,返回适当的优化器和学习率调度器。"""
        self._check_argvs(args, optimizer, lr_scheduler, kwargs)
        (optimizer, lr_scheduler) = self._check_config(
            args, optimizer, lr_scheduler, kwargs)
        return (optimizer, lr_scheduler)

    def _check_argvs(self, args, optimizer, lr_scheduler, kwargs):
        """对给定的命令行参数应用几个合理性检查。"""
        has_json_config = (hasattr(args, 'deepspeed_config')
                           and args.deepspeed_config is not None)
        has_dict_config = 'config_params' in kwargs
        if (
                # 没有给定配置
                (not has_json_config and not has_dict_config)
                # JSON 配置文件不存在
                or (not has_dict_config
                    and not os.path.isfile(args.deepspeed_config))
        ):
            # 让 DeepSpeed 处理这些参数错误。
            return

        if not args.deepspeed:
            print(
                'WARNING: DeepSpeed backend selected; setting `args.deepspeed = True`'
            )
            args.deepspeed = True

        if has_json_config and has_dict_config:
            print(
                'WARNING: DeepSpeed configuration was given as both a JSON file and a Python dictionary. The Python dictionary takes precedence.'
            )
    def _check_config(self, args, optimizer, lr_scheduler, kwargs):
        """Return an appropriate optimizer and learning rate scheduler
        for the DeepSpeed configuration.
        """
        # 检查 DeepSpeed 配置,根据情况返回优化器和学习率调度器
        if 'config_params' in kwargs:
            config = kwargs['config_params']
        else:
            with open(args.deepspeed_config, 'r') as json_config_file:
                config = json.load(json_config_file)

        if 'optimizer' in config and optimizer is not None:
            print(
                'WARNING: Optimizer encountered in both DeepSpeed config and '
                'keyword arguments. Optimizer in DeepSpeed config '
                'takes precedence.'
            )
            optimizer = None

        if 'scheduler' in config and lr_scheduler is not None:
            print(
                'WARNING: Learning rate scheduler encountered in both '
                'DeepSpeed config and keyword arguments. Learning rate '
                'scheduler in DeepSpeed config takes precedence.'
            )
            # 对于 LR 调度器,JSON 配置已经具有优先权。我们这样做是为了向前兼容。
            lr_scheduler = None

        return (optimizer, lr_scheduler)

    def _distribute(
            self,
            args=None,
            model=None,
            optimizer=None,
            model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **kwargs,
    ):
        """Return a distributed model engine, optimizer, dataloader, and
        learning rate scheduler. These are obtained by wrapping the
        given values with the backend.

        For the other or other possible arguments,
        see `deepspeed.initialize`.
        """
        (optimizer, lr_scheduler) = self._check_args(
            args, optimizer, lr_scheduler, kwargs)

        return self.backend_module.initialize(
            args=args,
            model=model,
            optimizer=optimizer,
            model_parameters=model_parameters,
            training_data=training_data,
            lr_scheduler=lr_scheduler,
            **kwargs,
        )

    def _average_all(self, tensor):
        self._require_torch_distributed_init()
        # We copy because modification happens in-place
        averaged = tensor.detach().clone()
        # We use `all_reduce` because it is better supported than `reduce`
        torch.distributed.all_reduce(averaged, torch.distributed.ReduceOp.SUM)
        return averaged / self.get_world_size()

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\distributed_backend.py

"""
An abstract backend for distributed deep learning.

Provides several standard utility methods under a common API.
Please check the documentation of the class `DistributedBackend` for
details to implement a new backend.
"""

from importlib import import_module


class DistributedBackend:
    """An abstract backend class for distributed deep learning.

    Provides several standard utility methods under a common API.
    Variables that must be overridden:
    - BACKEND_MODULE_NAME
    - BACKEND_NAME
    Methods that must be overridden:
    - wrap_arg_parser
    - _initialize
    - _get_world_size
    - _get_rank
    - _get_local_rank
    - _local_barrier
    - _distribute
    - _average_all
    """

    BACKEND_MODULE_NAME = None
    """Name of the module to import for the backend."""
    BACKEND_NAME = None
    """Name of the backend for printing."""

    ROOT_RANK = 0

    backend_module = None
    """The module to access the backend."""
    is_initialized = False
    """Whether the backend is initialized."""

    def __init__(self):
        if self.BACKEND_MODULE_NAME is None:
            raise NotImplementedError('BACKEND_MODULE_NAME is not set')
        if self.BACKEND_NAME is None:
            raise NotImplementedError('BACKEND_NAME is not set')

    def has_backend(self):
        """Return whether the backend module is now imported."""
        try:
            self.backend_module = import_module(self.BACKEND_MODULE_NAME)
        except ModuleNotFoundError:
            return False
        return True

    def check_batch_size(self, batch_size):
        """Check whether the batch size makes sense for distribution."""
        assert batch_size >= self.get_world_size(), \
            (f"batch size can't be smaller than number of processes "
             f'({batch_size} < {self.get_world_size()})')

    def wrap_arg_parser(self, parser):
        """Add arguments to support optional distributed backend usage."""
        raise NotImplementedError

    def initialize(self):
        """Initialize the distributed backend."""
        self._initialize()
        self.is_initialized = True

    def _initialize(self):
        """Initialize the distributed backend."""
        raise NotImplementedError

    def require_init(self):
        """Raise an error when the backend has not been initialized yet."""
        assert self.is_initialized, \
            (f'{self.BACKEND_NAME} backend has not been initialized; please call '
             f'`distributed_utils.initialize` at the start of your script to '
             f'allow optional distributed usage')

    def get_world_size(self):
        """Return the amount of distributed processes."""
        self.require_init()
        return self._get_world_size()

    def _get_world_size(self):
        """Return the amount of distributed processes."""
        raise NotImplementedError

    def get_rank(self):
        """Return the global rank of the calling worker process."""
        self.require_init()
        return self._get_rank()

    def _get_rank(self):
        """Return the global rank of the calling worker process."""
        raise NotImplementedError

    def get_local_rank(self):
        """Return the local rank of the calling worker process.
        The local rank is the rank based on a single node's processes.
        """
        self.require_init()
        return self._get_local_rank()

    def _get_local_rank(self):
        """Return the local rank of the calling worker process.
        The local rank is the rank based on a single node's processes.
        """
        raise NotImplementedError

    def is_root_worker(self):
        """Return whether the calling worker has the root rank."""
        return self.get_rank() == self.ROOT_RANK

    def is_local_root_worker(self):
        """Return whether the calling worker has the root rank on this node."""
        return self.get_local_rank() == self.ROOT_RANK
    def local_barrier(self):
        """Wait until all processes on this node have called this function."""
        # 确保初始化已完成
        self.require_init()
        # 调用本地屏障函数
        self._local_barrier()

    def _local_barrier(self):
        """Wait until all processes on this node have called this function."""
        # 抛出未实现错误
        raise NotImplementedError

    def distribute(
            self,
            args=None,
            model=None,
            optimizer=None,
            model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **kwargs,
    ):
        """Return a distributed model engine, optimizer, dataloader, and
        learning rate scheduler. These are obtained by wrapping the
        given values with the backend.
        """
        # 确保初始化已完成
        self.require_init()
        # 调用分发函数
        return self._distribute(
            args,
            model,
            optimizer,
            model_parameters,
            training_data,
            lr_scheduler,
            **kwargs,
        )

    def _distribute(
            self,
            args=None,
            model=None,
            optimizer=None,
            model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **kwargs,
    ):
        """Return a distributed model engine, optimizer, dataloader, and
        learning rate scheduler. These are obtained by wrapping the
        given values with the backend.
        """
        # 抛出未实现错误
        raise NotImplementedError

    def average_all(self, tensor):
        """Return the average of `tensor` over all workers."""
        # 确保初始化已完成
        self.require_init()
        # 返回所有工作进程上张量的平均值
        return self._average_all(tensor)

    def _average_all(self, tensor):
        """Return the average of `tensor` over all workers."""
        # 抛出未实现错误
        raise NotImplementedError

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\dummy_backend.py

# 导入分布式后端类 DistributedBackend
from .distributed_backend import DistributedBackend

# 定义一个虚拟的分布式后端类 DummyBackend,继承自 DistributedBackend
class DummyBackend(DistributedBackend):
    """Acts like a distributed backend.

    Used as a stand-in replacement to obtain a non-distributed program.
    """

    # 定义一个常量 BACKEND_MODULE_NAME 为 'NO MODULE'
    BACKEND_MODULE_NAME = 'NO MODULE'
    # 定义一个常量 BACKEND_NAME 为 'Dummy'
    BACKEND_NAME = 'Dummy'

    # 检查是否存在后端
    def has_backend(self):
        return True

    # 包装参数解析器,返回原参数解析器
    def wrap_arg_parser(self, parser):
        return parser

    # 初始化方法,不做任何操作
    def _initialize(self):
        pass

    # 获取世界大小,返回 1
    def _get_world_size(self):
        return 1

    # 获取当前进程的排名,返回 ROOT_RANK
    def _get_rank(self):
        return self.ROOT_RANK

    # 获取本地排名,返回 ROOT_RANK
    def _get_local_rank(self):
        return self.ROOT_RANK

    # 本地屏障,不做任何操作
    def _local_barrier(self):
        pass

    # 分发方法,返回模型、优化器、数据加载器和学习率调度器
    def _distribute(
            self,
            _args=None,
            model=None,
            optimizer=None,
            _model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **_kwargs,
    ):
        """Return the model, optimizer, dataloader, and learning rate scheduler
        as is.
        """
        return (model, optimizer, training_data, lr_scheduler)

    # 对所有张量进行平均操作,返回原张量
    def _average_all(self, tensor):
        return tensor

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\horovod_backend.py

import torch
# 导入 torch 库

from .distributed_backend import DistributedBackend
# 从当前目录下的 distributed_backend 模块中导入 DistributedBackend 类

class HorovodBackend(DistributedBackend):
    """Distributed backend using Horovod."""
    # 使用 Horovod 的分布式后端

    BACKEND_MODULE_NAME = 'horovod.torch'
    BACKEND_NAME = 'Horovod'
    # 定义后端模块名和后端名称

    def wrap_arg_parser(self, parser):
        return parser
    # 包装参数解析器

    def check_batch_size(self, batch_size):
        # Horovod 使用本地批大小来确定有效批大小
        pass
    # 检查批大小

    def _initialize(self):
        self.backend_module.init()
        # 初始化后端模块
        if torch.cuda.is_available():
            torch.cuda.set_device(self._get_local_rank())
        # 如果 CUDA 可用,则设置当前设备为本地排名对应的设备

    def _get_world_size(self):
        return self.backend_module.size()
    # 获取世界大小

    def _get_rank(self):
        return self.backend_module.rank()
    # 获取排名

    def _get_local_rank(self):
        return self.backend_module.local_rank()
    # 获取本地排名

    def _local_barrier(self):
        # 实际上是全局屏障,但对我们的目的有效
        self.backend_module.join()
    # 本地屏障

    def _distribute(
            self,
            _args=None,
            model=None,
            optimizer=None,
            _model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **_kwargs,
    ):
        optimizer = self.backend_module.DistributedOptimizer(optimizer)
        # 使用后端模块的 DistributedOptimizer 对象对优化器进行分布式处理
        self.backend_module.broadcast_parameters(
            model.state_dict(), root_rank=self.ROOT_RANK)
        # 广播模型参数
        self.backend_module.broadcast_optimizer_state(
            optimizer, root_rank=self.ROOT_RANK)
        # 广播优化器状态
        return (model, optimizer, training_data, lr_scheduler)
    # 分发模型、优化器、训练数据和学习率调度器

    def _average_all(self, tensor):
        # 默认情况下,减少操作是平均值
        averaged = self.backend_module.allreduce(tensor)
        # 对张量进行全局平均值操作
        return averaged
    # 对所有张量进行平均值操作

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\__init__.py

# 从当前目录中导入 DeepSpeedBackend 模块
from .deepspeed_backend import DeepSpeedBackend
# 从当前目录中导入 DistributedBackend 模块
from .distributed_backend import DistributedBackend
# 从当前目录中导入 DummyBackend 模块
from .dummy_backend import DummyBackend
# 从当前目录中导入 HorovodBackend 模块
from .horovod_backend import HorovodBackend

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_utils.py

"""
Utility functions for optional distributed execution.

To use,
1. set the `BACKENDS` to the ones you want to make available,
2. in the script, wrap the argument parser with `wrap_arg_parser`,
3. in the script, set and use the backend by calling
   `set_backend_from_args`.

You can check whether a backend is in use with the `using_backend`
function.
"""

from dalle_pytorch.distributed_backends import \
    DeepSpeedBackend, \
    DummyBackend, \
    HorovodBackend

_DEFAULT_BACKEND = DummyBackend()
"""Which backend to use by default. Assumed to be _not_ distributed."""

BACKENDS = [
    _DEFAULT_BACKEND,
    DeepSpeedBackend(),
    HorovodBackend(),
]

is_distributed = None
"""Whether we are distributed."""
backend = None
"""Backend in usage."""


def wrap_arg_parser(parser):
    """Add arguments to support optional distributed backend usage."""
    parser.add_argument(
        '--distributed_backend',
        '--distr_backend',
        type=str,
        default=None,
        help='which distributed backend to use. Do not distribute by default',
    )
    for distr_backend in BACKENDS:
        parser = distr_backend.wrap_arg_parser(parser)
    return parser


def set_backend_from_args(args):
    """Set and return the backend based on the given `args`."""
    global is_distributed, backend

    # Handle this specially for backwards compatibility.
    if args.deepspeed:
        args.distributed_backend = DeepSpeedBackend.BACKEND_NAME

    if not args.distributed_backend:
        is_distributed = False
        backend = _DEFAULT_BACKEND
        return backend

    backend_name = args.distributed_backend.lower()
    for distr_backend in BACKENDS:
        if distr_backend.BACKEND_NAME.lower() == backend_name:
            backend = distr_backend
            if not backend.has_backend():
                raise ModuleNotFoundError(
                    f'{backend.BACKEND_NAME} backend selected but '
                    'module not available'
                )

            print(f'Using {backend.BACKEND_NAME} for distributed execution')
            is_distributed = True
            return backend

    raise ValueError(
        'unknown backend; please check `distributed_utils.BACKENDS`')


def require_set_backend():
    """Raise an `AssertionError` when the backend has not been set."""
    assert backend is not None, (
        'distributed backend is not set. Please call '
        '`distributed_utils.set_backend_from_args` at the start of your script'
    )


def using_backend(test_backend):
    """Return whether the backend is set to `test_backend`.

    `test_backend` may be a string of the name of the backend or
    its class.
    """
    require_set_backend()
    if isinstance(test_backend, str):
        return backend.BACKEND_NAME == test_backend
    return isinstance(backend, test_backend)
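
A minimal sketch of the intended usage described in the module docstring above (the training-script details are hypothetical, not part of the original file): wrap the argument parser, pick the backend from the parsed arguments, then let the backend wrap the model and optimizer.

import argparse
import torch
from dalle_pytorch import distributed_utils

parser = argparse.ArgumentParser()
parser = distributed_utils.wrap_arg_parser(parser)
args = parser.parse_args()

distr_backend = distributed_utils.set_backend_from_args(args)
distr_backend.initialize()

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())

(distr_model, distr_optimizer, _, _) = distr_backend.distribute(
    args = args,
    model = model,
    optimizer = optimizer,
    model_parameters = model.parameters(),
    training_data = None,
)

if distributed_utils.using_backend(distributed_utils.HorovodBackend):
    print('running under Horovod')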

.\lucidrains\DALLE-pytorch\dalle_pytorch\loader.py

from pathlib import Path
from random import randint, choice
import PIL
from torch.utils.data import Dataset
from torchvision import transforms as T

class TextImageDataset(Dataset):
    def __init__(self,
                 folder,
                 text_len=256,
                 image_size=128,
                 truncate_captions=False,
                 resize_ratio=0.75,
                 transparent=False,
                 tokenizer=None,
                 shuffle=False
                 ):
        """
        @param folder: 包含图像和文本文件的文件夹,它们通过其路径的相应“stem”匹配
        @param truncate_captions: 如果标题太长,将截断标题而不是抛出异常
        """
        super().__init__()
        self.shuffle = shuffle
        path = Path(folder)

        # 获取所有文本文件和图像文件的路径
        text_files = [*path.glob('**/*.txt')]
        image_files = [
            *path.glob('**/*.png'), *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
        ]

        # 将文本文件和图像文件的stem作为key,文件路径作为value存储在字典中
        text_files = {text_file.stem: text_file for text_file in text_files}
        image_files = {image_file.stem: image_file for image_file in image_files}

        # 获取文本文件和图像文件stem的交集作为keys
        keys = (image_files.keys() & text_files.keys())

        self.keys = list(keys)
        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}
        self.text_len = text_len
        self.truncate_captions = truncate_captions
        self.resize_ratio = resize_ratio
        self.tokenizer = tokenizer

        # 根据是否透明设置图像模式
        image_mode = 'RGBA' if transparent else 'RGB'

        # 图像转换操作
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert(image_mode)
            if img.mode != image_mode else img),
            T.RandomResizedCrop(image_size,
                                scale=(self.resize_ratio, 1.),
                                ratio=(1., 1.)),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.keys)

    def random_sample(self):
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def __getitem__(self, ind):
        key = self.keys[ind]

        text_file = self.text_files[key]
        image_file = self.image_files[key]

        # 读取文本文件内容并按换行符分割
        descriptions = text_file.read_text().split('\n')
        descriptions = list(filter(lambda t: len(t) > 0, descriptions))
        try:
            description = choice(descriptions)
        except IndexError as zero_captions_in_file_ex:
            print(f"An exception occurred trying to load file {text_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # 对文本进行标记化处理
        tokenized_text = self.tokenizer.tokenize(
            description,
            self.text_len,
            truncate_text=self.truncate_captions
        ).squeeze(0)
        try:
            image_tensor = self.image_transform(PIL.Image.open(image_file))
        except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions:
            print(f"An exception occurred trying to load file {image_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # 成功返回标记化的文本和图像张量
        return tokenized_text, image_tensor
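
A minimal usage sketch of TextImageDataset (the folder path and tokenizer choice are assumptions, not from the original file): every image needs a sibling .txt file with the same stem containing one caption per line, and the dataset returns (tokenized caption, image tensor) pairs.

from torch.utils.data import DataLoader
from dalle_pytorch.tokenizer import SimpleTokenizer   # assumed tokenizer exposing tokenize(text, length, truncate_text=...)

dataset = TextImageDataset(
    './captioned-images',        # hypothetical folder of image / caption pairs
    text_len = 128,
    image_size = 128,
    truncate_captions = True,
    tokenizer = SimpleTokenizer(),
    shuffle = True
)

loader = DataLoader(dataset, batch_size = 8, shuffle = True)
text_tokens, images = next(iter(loader))   # (8, 128) token ids and (8, 3, 128, 128) image tensors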

.\lucidrains\DALLE-pytorch\dalle_pytorch\reversible.py

# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 从 operator 模块中导入 itemgetter 函数
from operator import itemgetter
# 从 torch.autograd.function 模块中导入 Function 类
from torch.autograd.function import Function
# 从 torch.utils.checkpoint 模块中导入 get_device_states 和 set_device_states 函数
from torch.utils.checkpoint import get_device_states, set_device_states

# 用于将参数路由到可逆层函数中的函数
def route_args(router, args, depth):
    # 初始化路由后的参数列表
    routed_args = [(dict(), dict()) for _ in range(depth)]
    # 获取参数中与路由器匹配的键
    matched_keys = [key for key in args.keys() if key in router]

    # 遍历匹配的键
    for key in matched_keys:
        val = args[key]
        # 遍历路由后的参数列表和路由器中的路由
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            # 根据路由将参数添加到对应的函数参数中
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args
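
A small, self-contained illustration of what `route_args` returns (the router and kwargs below are hypothetical, not taken from the source):

```py
router = {
    'mask': ((True, False), (True, False)),          # send 'mask' only to the f branch at both depths
    'rotary_pos_emb': ((True, True), (True, True))   # send rotary embeddings to both f and g
}
kwargs = {'mask': 'MASK', 'rotary_pos_emb': 'ROT', 'unrelated': 123}

routed = route_args(router, kwargs, depth = 2)
# routed[0] == ({'mask': 'MASK', 'rotary_pos_emb': 'ROT'}, {'rotary_pos_emb': 'ROT'})
# 'unrelated' is dropped, since its key does not appear in the router
```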

# 参考示例 https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 中的保存和设置随机数生成器
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)
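
The point of recording and replaying the RNG state is that stochastic layers (e.g. dropout) must produce exactly the same mask when they are re-run during the backward pass. A quick sketch with a toy module (not from the source):

```py
import torch
from torch import nn

det = Deterministic(nn.Dropout(p = 0.5))

x = torch.ones(1, 8)
a = det(x, record_rng = True)   # records the CPU (and CUDA, if initialized) RNG state, then runs dropout
b = det(x, set_rng = True)      # restores that state before re-running -> identical dropout mask
assert torch.equal(a, b)
```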

# 受 https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 启发
# 一旦多 GPU 确认工作正常,重构并将 PR 发回源代码
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx

class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.args = args
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    # 定义反向传播函数,接收上下文和梯度作为参数
    def backward(ctx, dy):
        # 获取上下文中的 y 和 args
        y = ctx.y
        args = ctx.args
        # 反向遍历上下文中的 blocks 和 args
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            # 调用每个 block 的反向传播函数,更新 y 和 dy
            y, dy = block.backward_pass(y, dy, **kwargs)
        # 返回更新后的梯度
        return dy, None, None
class SequentialSequence(nn.Module):
    # 定义一个顺序执行的神经网络模块
    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
        super().__init__()
        # 断言每个参数路由映射的深度与顺序层的数量相同
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        # 根据参数路由和关键字参数获取参数
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        for (f, g), (f_args, g_args) in layers_and_args:
            # 执行顺序层中的函数 f 和 g,并将结果与输入 x 相加
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x

class ReversibleSequence(nn.Module):
    # 定义一个可逆的神经网络模块
    def __init__(self, blocks, args_route = {}):
        super().__init__()
        self.args_route = args_route
        # 创建包含可逆块的模块列表
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        # 在最后一个维度上将输入 x 进行拼接
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        # 根据参数路由和关键字参数获取参数
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        # 调用自定义的可逆函数 _ReversibleFunction 来执行可逆操作
        out =  _ReversibleFunction.apply(x, blocks, args)
        # 在最后一个维度上将输出拆分成两部分,然后取平均值
        return torch.stack(out.chunk(2, dim=-1)).mean(dim=0)
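
A minimal sketch of running a `ReversibleSequence` end to end (the `f`/`g` blocks below are toy linear layers; in DALLE-pytorch they would be attention and feedforward modules):

```py
import torch
from torch import nn

dim, depth = 64, 4
blocks = [(nn.Linear(dim, dim), nn.Linear(dim, dim)) for _ in range(depth)]
net = ReversibleSequence(blocks)

x = torch.randn(2, 16, dim, requires_grad = True)
out = net(x)                  # the blocks run under no_grad; intermediate activations are not stored
out.mean().backward()         # backward reconstructs each block's inputs via backward_pass
print(out.shape)              # torch.Size([2, 16, 64])
```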

.\lucidrains\DALLE-pytorch\dalle_pytorch\transformer.py

# 导入必要的库
from collections import deque
from collections.abc import Iterable
from functools import partial
from itertools import islice, cycle

import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

# 导入自定义模块
from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention

# 导入旋转嵌入模块
from rotary_embedding_torch import RotaryEmbedding, broadcat

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 返回默认值
def default(val, d):
    return val if exists(val) else d

# 将变量转换为元组
def cast_tuple(val, depth = 1):
    return val if isinstance(val, Iterable) else (val,) * depth

# 类

# 最大值分割类
class DivideMax(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        maxes = x.amax(dim = self.dim, keepdim = True).detach()
        return x / maxes

# 非缓存类
class NonCached(nn.Module):
    """
    A wrapper for layers that don't support the inference cache themselves.
    Reconstructs the full sequence before the layer and
    cuts the suffix of the outputs after the layer.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *, cache = None, cache_key = None, **kwargs):
        n = x.shape[-2]
        if exists(cache):
            if cache_key in cache:
                x = torch.cat([cache[cache_key], x], dim=-2)
            cache[cache_key] = x

        out = self.fn(x, **kwargs)

        return out[:, -n:]
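
Sketch of the wrapper in action during incremental decoding (the cache dict and the `nn.Identity` stand-in are hypothetical):

```py
import torch
from torch import nn

layer = NonCached(nn.Identity())
cache = {}

out1 = layer(torch.randn(1, 3, 8), cache = cache, cache_key = 'layer0')  # (1, 3, 8)
out2 = layer(torch.randn(1, 1, 8), cache = cache, cache_key = 'layer0')  # (1, 1, 8)
# on the second call the wrapped fn actually saw the full (1, 4, 8) sequence,
# but only the suffix for the newly fed token is returned
```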

# 缓存类
class CachedAs(nn.Module):
    """
    A wrapper that defines a key for the inference cache.
    """

    def __init__(self, cache_key, fn):
        super().__init__()
        self.cache_key = cache_key
        self.fn = fn

    def forward(self, x, *, cache=None, **kwargs):
        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)

# 层缩放类
class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

# 层归一化类

class PreNorm(nn.Module):
    def __init__(self, dim, fn, sandwich = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)

# 前馈类

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, dropout = 0., mult = 4.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, cache=None, cache_key=None):
        return self.net(x)

# 标记移位类

class PreShiftToken(nn.Module):
    def __init__(self, fn, image_size, seq_len):
        super().__init__()
        self.fn = fn
        self.image_size = image_size
        self.seq_len = seq_len
        self.img_seq_len = image_size ** 2
        self.text_len = seq_len - self.img_seq_len + 1
    # 定义前向传播函数,接受输入 x,缓存 cache,缓存键 cache_key,以及其他关键字参数 kwargs
    def forward(self, x, cache=None, cache_key=None, **kwargs):
        # 获取序列长度、图像大小、文本长度
        seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len

        # 如果缓存存在且缓存键存在于缓存中
        if exists(cache) and cache_key in cache:
            # 从缓存中获取偏移量
            offset = cache['offset']
            # 断言偏移量大于等于文本长度,不支持文本的缓存推断
            assert offset >= text_len, "cached inference for text is not supported"
            # 从缓存中获取队列 q
            q = cache[cache_key]
            # 断言 q 是双端队列且长度为图像大小
            assert isinstance(q, deque) and len(q) == image_size

            # 将输入 x 按照最后一个维度分割成四部分
            x_top, x_left, *x_pass = x[:, -1].chunk(4, dim=-1)

            # 将 x_top 和 x_left 添加到队列 q 中
            q.append((x_top, x_left))
            # 弹出队列 q 中的第一个元素,并更新 x_top 和 x_left
            x_top = q.popleft()[0]
            x_left = q[-2][1]
            # 如果偏移量减去文本长度对图像大小取模等于 0,则将 x_left 置零
            if (offset - text_len) % image_size == 0:
                x_left = torch.zeros_like(x_left)

            # 将 x_top、x_left 和其他部分拼接在一起
            x = torch.cat((x_top, x_left, *x_pass), dim=-1)
            # 调用 self.fn 函数,传入 x[:, None] 作为输入,同时传入缓存和其他关键字参数
            return self.fn(x[:, None], cache=cache, **kwargs)

        # 获取输入 x 的形状中的第二个维度大小
        n = x.shape[1]
        # 计算需要填充的数量
        padding = seq_len - n + 1

        # 如果序列长度小于文本长度,则没有图像令牌需要移动
        if n < text_len:
            return self.fn(x, **kwargs)

        # 获取文本和图像令牌
        x_text, x_img = x[:, :text_len], x[:, text_len:]
        # 对图像令牌进行填充
        x_img = F.pad(x_img, (0, 0, 0, padding))
        # 重新排列图像令牌的形状
        x_img = rearrange(x_img, 'b (h w) d -> b h w d', h=image_size)

        # 对文本令牌进行左移 1 位
        x_text_shift, x_text_pass = x_text.chunk(2, dim=-1)
        x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1))
        x_text = torch.cat((x_text_shift, x_text_pass), dim=-1)

        # 对图像令牌进行从上和从左的移动
        x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim=-1)
        x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1))
        x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1))
        x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim=-1)

        # 将文本和图像序列合并在一起
        x_img = rearrange(x_img, 'b h w d -> b (h w) d')
        x_img = x_img[:, :-padding]
        x = torch.cat((x_text, x_img), dim=1)

        # 如果缓存存在
        if exists(cache):
            # 创建虚拟的顶部和左侧令牌
            dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim=-1)
            dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)

            # 创建双端队列 q
            q = deque()
            x_img = x_img[:, -image_size:]
            # 将虚拟令牌添加到队列 q 中,直到队列大小为图像大小
            for _ in range(image_size - x_img.shape[1]):
                q.append((dummy_top, dummy_left))
            # 将图像令牌添加到队列 q 中
            for i in range(x_img.shape[1]):
                q.append(x_img[:, i].chunk(4, dim=-1)[:2])
            # 将队列 q 存入缓存中
            cache[cache_key] = q

        # 调用 self.fn 函数,传入 x 作为输入,同时传入缓存和其他关键字参数
        return self.fn(x, cache=cache, **kwargs)
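
The shifts above are all implemented with `F.pad` along the sequence (or row/column) axis: a slice of the feature channels at each position is replaced by the same slice from the previous position. A tiny illustration (hypothetical shapes):

```py
import torch
import torch.nn.functional as F

x = torch.arange(1., 5.).view(1, 4, 1)   # (batch, seq, dim), values [1, 2, 3, 4]
shifted = F.pad(x, (0, 0, 1, -1))        # prepend a zero step, drop the last step
print(shifted.flatten().tolist())        # [0.0, 1.0, 2.0, 3.0] -> each position sees its predecessor
```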
# 主要的Transformer类
class Transformer(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        reversible = False,
        causal = True,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        attn_types = None,
        image_fmap_size = None,
        sparse_attn = False,
        stable = False,
        sandwich_norm = False,
        shift_tokens = False,
        rotary_emb = True,
        shared_attn_ids = None,
        shared_ff_ids = None,
        optimize_for_inference = False,  # 使用缓存友好的掩码注意力代替稀疏注意力
    ):
        ...  # 构造函数主体此处从略

    # 前向传播函数
    def forward(self, x, **kwargs):
        return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)

    # 获取注意力掩码函数
    def _get_attention_mask(self, attn_type):
        # 计算图像序列长度
        img_seq_len = self.image_fmap_size ** 2
        # 计算文本长度
        text_len = self.seq_len + 1 - img_seq_len

        # 创建静态掩码
        static_mask = torch.zeros(self.seq_len, self.seq_len, dtype=torch.bool)
        static_mask[:, :text_len] = True
        # 根据不同的注意力类型生成不同的静态掩码
        if attn_type == 'axial_row':
            for row in range(self.image_fmap_size):
                begin = text_len + row * self.image_fmap_size
                end = text_len + (row + 1) * self.image_fmap_size
                static_mask[begin:end, begin:end] = True
        elif attn_type == 'axial_col':
            for col in range(self.image_fmap_size):
                begin = text_len + col
                static_mask[begin::self.image_fmap_size, begin::self.image_fmap_size] = True
        else:
            raise ValueError(f'attention type "{attn_type}" can\'t be simulated with a static mask')
        return static_mask
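
For intuition, here is the 'axial_row' mask pattern computed standalone for a tiny, made-up configuration (text_len = 2, a 2x2 image feature map), rather than through the method above:

```py
import torch

text_len, fmap_size = 2, 2
seq_len = text_len + fmap_size ** 2

mask = torch.zeros(seq_len, seq_len, dtype = torch.bool)
mask[:, :text_len] = True                     # every position may attend to all text positions
for row in range(fmap_size):
    begin = text_len + row * fmap_size
    end = text_len + (row + 1) * fmap_size
    mask[begin:end, begin:end] = True         # image positions attend within their own row

print(mask[2].tolist())  # [True, True, True, True, False, False]
```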

.\lucidrains\DALLE-pytorch\dalle_pytorch\vae.py

# 导入所需的库
import io
import sys
import os
import requests
import PIL
import warnings
import hashlib
import urllib
import yaml
from pathlib import Path
from tqdm import tqdm
from math import sqrt, log
from packaging import version

# 导入第三方库
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel, GumbelVQ
import importlib

# 导入 PyTorch 库
import torch
from torch import nn
import torch.nn.functional as F

# 导入 einops 库
from einops import rearrange

# 导入 dalle_pytorch 库中的 distributed_utils 模块
from dalle_pytorch import distributed_utils

# 常量定义

CACHE_PATH = os.path.expanduser("~/.cache/dalle")

OPENAI_VAE_ENCODER_PATH = 'https://cdn.openai.com/dall-e/encoder.pkl'
OPENAI_VAE_DECODER_PATH = 'https://cdn.openai.com/dall-e/decoder.pkl'

VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1'
VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1'

# 辅助方法

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def load_model(path):
    with open(path, 'rb') as f:
        return torch.load(f, map_location = torch.device('cpu'))

def map_pixels(x, eps = 0.1):
    return (1 - 2 * eps) * x + eps

def unmap_pixels(x, eps = 0.1):
    return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1)
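
`map_pixels` squeezes [0, 1] pixel values into [eps, 1 - eps], as in OpenAI's released dVAE preprocessing, and `unmap_pixels` inverts it. A quick sanity check:

```py
import torch

x = torch.rand(2, 3, 8, 8)
assert torch.allclose(unmap_pixels(map_pixels(x)), x, atol = 1e-6)
print(map_pixels(torch.tensor([0., 1.])))   # tensor([0.1000, 0.9000])
```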

def download(url, filename = None, root = CACHE_PATH):
    if (
            not distributed_utils.is_distributed
            or distributed_utils.backend.is_local_root_worker()
    ):
        os.makedirs(root, exist_ok = True)
    filename = default(filename, os.path.basename(url))

    download_target = os.path.join(root, filename)
    download_target_tmp = os.path.join(root, f'tmp.{filename}')

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if (
            distributed_utils.is_distributed
            and not distributed_utils.backend.is_local_root_worker()
            and not os.path.isfile(download_target)
    ):
        # 如果文件尚不存在,则等待根工作节点下载
        distributed_utils.backend.local_barrier()

    if os.path.isfile(download_target):
        return download_target

    with urllib.request.urlopen(url) as source, open(download_target_tmp, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    os.rename(download_target_tmp, download_target)
    if (
            distributed_utils.is_distributed
            and distributed_utils.backend.is_local_root_worker()
    ):
        distributed_utils.backend.local_barrier()
    return download_target

def make_contiguous(module):
    with torch.no_grad():
        for param in module.parameters():
            param.set_(param.contiguous())

# 获取包版本信息

def get_pkg_version(pkg_name):
    from pkg_resources import get_distribution
    return get_distribution(pkg_name).version

# 预训练的 OpenAI 离散 VAE

class OpenAIDiscreteVAE(nn.Module):
    def __init__(self):
        super().__init__()
        assert version.parse(get_pkg_version('torch')) < version.parse('1.11.0'), 'torch version must be <= 1.10 in order to use OpenAI discrete vae'

        # 加载编码器和解码器模型
        self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH))
        self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))
        make_contiguous(self)

        self.channels = 3
        self.num_layers = 3
        self.image_size = 256
        self.num_tokens = 8192

    @torch.no_grad()
    def get_codebook_indices(self, img):
        # 映射像素值
        img = map_pixels(img)
        # 获取编码器的输出
        z_logits = self.enc.blocks(img)
        # 获取最大概率的索引
        z = torch.argmax(z_logits, dim = 1)
        return rearrange(z, 'b h w -> b (h w)')
    # 解码函数,将图像序列解码为图像
    def decode(self, img_seq):
        # 获取图像序列的形状
        b, n = img_seq.shape
        # 重新排列图像序列的形状,将其转换为二维图像
        img_seq = rearrange(img_seq, 'b (h w) -> b h w', h = int(sqrt(n)))

        # 将图像序列转换为 one-hot 编码
        z = F.one_hot(img_seq, num_classes = self.num_tokens)
        # 重新排列 one-hot 编码的形状
        z = rearrange(z, 'b h w c -> b c h w').float()
        # 使用解码器解码 one-hot 编码的数据
        x_stats = self.dec(z).float()
        # 将解码后的数据映射回像素值范围
        x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
        # 返回解码后的图像
        return x_rec

    # 前向传播函数,抛出未实现异常
    def forward(self, img):
        raise NotImplementedError
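
A hedged round-trip sketch with the pretrained OpenAI dVAE (weights are downloaded and cached on first use, and the torch < 1.11 assertion above must hold):

```py
import torch

vae = OpenAIDiscreteVAE()

images = torch.rand(1, 3, 256, 256)
codes = vae.get_codebook_indices(images)   # (1, 1024) codebook ids, values in [0, 8192)
recon = vae.decode(codes)                  # (1, 3, 256, 256) reconstruction
```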
# 从 Taming Transformers 论文中获取 VQGAN 模型
# https://arxiv.org/abs/2012.09841

# 从字符串中获取对象
def get_obj_from_str(string, reload=False):
    # 拆分字符串,获取模块和类名
    module, cls = string.rsplit(".", 1)
    if reload:
        # 导入模块并重新加载
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

# 根据配置实例化对象
def instantiate_from_config(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
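
`instantiate_from_config` simply resolves a dotted class path and calls it with the given params. A hedged example with a hand-written config dict (the real config comes from the VQGAN .yaml loaded below with OmegaConf):

```py
config = {
    'target': 'torch.nn.Linear',                     # dotted path resolved by get_obj_from_str
    'params': {'in_features': 4, 'out_features': 2}
}
layer = instantiate_from_config(config)              # equivalent to torch.nn.Linear(4, 2)
```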

# VQGAN VAE 类
class VQGanVAE(nn.Module):
    def __init__(self, vqgan_model_path=None, vqgan_config_path=None):
        super().__init__()

        if vqgan_model_path is None:
            model_filename = 'vqgan.1024.model.ckpt'
            config_filename = 'vqgan.1024.config.yml'
            download(VQGAN_VAE_CONFIG_PATH, config_filename)
            download(VQGAN_VAE_PATH, model_filename)
            config_path = str(Path(CACHE_PATH) / config_filename)
            model_path = str(Path(CACHE_PATH) / model_filename)
        else:
            model_path = vqgan_model_path
            config_path = vqgan_config_path

        config = OmegaConf.load(config_path)

        model = instantiate_from_config(config["model"])

        state = torch.load(model_path, map_location = 'cpu')['state_dict']
        model.load_state_dict(state, strict = False)

        print(f"Loaded VQGAN from {model_path} and {config_path}")

        self.model = model

        # 计算分辨率缩放因子 f
        f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0]

        self.num_layers = int(log(f)/log(2))
        self.channels = 3
        self.image_size = 256
        self.num_tokens = config.model.params.n_embed
        self.is_gumbel = isinstance(self.model, GumbelVQ)

        self._register_external_parameters()

    def _register_external_parameters(self):
        """为 DeepSpeed 分区注册外部参数"""
        if (
                not distributed_utils.is_distributed
                or not distributed_utils.using_backend(
                    distributed_utils.DeepSpeedBackend)
        ):
            return

        deepspeed = distributed_utils.backend.backend_module
        deepspeed.zero.register_external_parameter(
            self, self.model.quantize.embed.weight if self.is_gumbel else self.model.quantize.embedding.weight)

    @torch.no_grad()
    def get_codebook_indices(self, img):
        b = img.shape[0]
        img = (2 * img) - 1
        _, _, [_, _, indices] = self.model.encode(img)
        if self.is_gumbel:
            return rearrange(indices, 'b h w -> b (h w)', b=b)
        return rearrange(indices, '(b n) -> b n', b = b)

    def decode(self, img_seq):
        b, n = img_seq.shape
        one_hot_indices = F.one_hot(img_seq, num_classes = self.num_tokens).float()
        z = one_hot_indices @ self.model.quantize.embed.weight if self.is_gumbel \
            else (one_hot_indices @ self.model.quantize.embedding.weight)

        z = rearrange(z, 'b (h w) c -> b c h w', h = int(sqrt(n)))
        img = self.model.decode(z)

        img = (img.clamp(-1., 1.) + 1) * 0.5
        return img

    def forward(self, img):
        raise NotImplementedError

.\lucidrains\DALLE-pytorch\dalle_pytorch\version.py

# 定义变量 __version__,赋值为字符串 '1.6.6'
__version__ = '1.6.6'

.\lucidrains\DALLE-pytorch\dalle_pytorch\__init__.py

# 从dalle_pytorch包中导入DALLE, CLIP, DiscreteVAE类
# 从dalle_pytorch包中导入OpenAIDiscreteVAE, VQGanVAE类
from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE
from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE

# 从pkg_resources模块中导入get_distribution函数
from pkg_resources import get_distribution
# 从dalle_pytorch.version模块中导入__version__变量
from dalle_pytorch.version import __version__

.\lucidrains\DALLE-pytorch\generate.py

# 导入必要的库
import argparse
from pathlib import Path
from tqdm import tqdm

# 导入 torch 库
import torch

# 导入 einops 库中的 repeat 函数
from einops import repeat

# 导入 vision 相关库
from PIL import Image
from torchvision.utils import make_grid, save_image

# 导入 dalle_pytorch 库中的类和工具
from dalle_pytorch import __version__
from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer

# 参数解析
parser = argparse.ArgumentParser()

# 添加参数
parser.add_argument('--dalle_path', type = str, required = True,
                    help='path to your trained DALL-E')

parser.add_argument('--vqgan_model_path', type=str, default = None,
                   help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')

parser.add_argument('--vqgan_config_path', type=str, default = None,
                   help='path to your trained VQGAN config. This should be a .yaml file.  (only valid when taming option is enabled)')

parser.add_argument('--text', type = str, required = True,
                    help='your text prompt')

parser.add_argument('--num_images', type = int, default = 128, required = False,
                    help='number of images')

parser.add_argument('--batch_size', type = int, default = 4, required = False,
                    help='batch size')

parser.add_argument('--top_k', type = float, default = 0.9, required = False,
                    help='top k filter threshold')

parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False,
                    help='output directory')

parser.add_argument('--bpe_path', type = str,
                    help='path to your huggingface BPE json file')

parser.add_argument('--hug', dest='hug', action = 'store_true')

parser.add_argument('--chinese', dest='chinese', action = 'store_true')

parser.add_argument('--taming', dest='taming', action='store_true')

parser.add_argument('--gentxt', dest='gentxt', action='store_true')

# 解析参数
args = parser.parse_args()

# 辅助函数
def exists(val):
    return val is not None

# 根据参数设置 tokenizer
if exists(args.bpe_path):
    klass = HugTokenizer if args.hug else YttmTokenizer
    tokenizer = klass(args.bpe_path)
elif args.chinese:
    tokenizer = ChineseTokenizer()

# 加载 DALL-E 模型
dalle_path = Path(args.dalle_path)
assert dalle_path.exists(), 'trained DALL-E must exist'

load_obj = torch.load(str(dalle_path))
dalle_params, vae_params, weights, vae_class_name, version = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights'), load_obj.pop('vae_class_name', None), load_obj.pop('version', None)

# 友好打印
if exists(version):
    print(f'Loading a model trained with DALLE-pytorch version {version}')
else:
    print('You are loading a model trained on an older version of DALL-E pytorch - it may not be compatible with the most recent version')

# 加载 VAE 模型
if args.taming:
    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
elif vae_params is not None:
    vae = DiscreteVAE(**vae_params)
else:
    vae = OpenAIDiscreteVAE()

assert not (exists(vae_class_name) and vae.__class__.__name__ != vae_class_name), f'you trained DALL-E using {vae_class_name} but are trying to generate with {vae.__class__.__name__} - please make sure you are passing in the correct paths and settings for the VAE to use for generation'

# 重建 DALL-E 模型
dalle = DALLE(vae = vae, **dalle_params).cuda()
dalle.load_state_dict(weights)

# 生成图片
image_size = vae.image_size
texts = args.text.split('|')

for j, text in tqdm(enumerate(texts)):
    if args.gentxt:
        text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres = args.top_k)
        text = gen_texts[0]
    else:
        text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()

    text_tokens = repeat(text_tokens, '() n -> b n', b = args.num_images)

    outputs = []
    # 使用 tqdm 分块处理文本标记,每块大小为 args.batch_size,显示进度条描述为生成图像的文本
    for text_chunk in tqdm(text_tokens.split(args.batch_size), desc = f'generating images for - {text}'):
        # 生成图像,根据文本块和筛选阈值 args.top_k
        output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
        # 将生成的图像添加到输出列表中
        outputs.append(output)

    # 将所有输出图像拼接成一个张量
    outputs = torch.cat(outputs)

    # 保存所有图像

    # 定义文件名为文本
    file_name = text 
    # 定义输出目录为 args.outputs_dir 下的文件名替换空格为下划线后的前100个字符
    outputs_dir = Path(args.outputs_dir) / file_name.replace(' ', '_')[:(100)]
    # 创建输出目录,如果不存在则创建,存在则忽略
    outputs_dir.mkdir(parents = True, exist_ok = True)

    # 遍历输出图像,保存为 PNG 格式
    for i, image in tqdm(enumerate(outputs), desc = 'saving images'):
        # 保存图像为 PNG 格式,文件名为序号.png,进行归一化
        save_image(image, outputs_dir / f'{i}.png', normalize=True)
        # 将文本写入 caption.txt 文件
        with open(outputs_dir / 'caption.txt', 'w') as f:
            f.write(file_name)

    # 打印生成的图像数量和输出目录路径
    print(f'created {args.num_images} images at "{str(outputs_dir)}"')

DALL-E in Pytorch

Train DALL-E w/ DeepSpeed
Join us on Discord
Released DALLE Models
Web-Hostable DALLE Checkpoints

Yannic Kilcher's video

Implementation / replication of DALL-E (paper), OpenAI's Text to Image Transformer, in Pytorch. It will also contain CLIP for ranking the generations.


Quick Start

Deep Daze or Big Sleep are great alternatives!

For generating video and audio, please see NÜWA

Appreciation

This library could not have been possible without the contributions of janEbert, Clay, robvanvolt, Romain Beaumont, and Alexander! 🙏

Status

  • Hannu has managed to train a small 6 layer DALL-E on a dataset of just 2000 landscape images! (2048 visual tokens)

  • Kobiso, a research engineer from Naver, has trained on the CUB200 dataset here, using full and deepspeed sparse attention

  • (3/15/21) afiaka87 has managed one epoch using a reversible DALL-E and the dVAE here

  • TheodoreGalanos has trained on 150k layouts with the following results

  • Rom1504 has trained on 50k fashion images with captions with a really small DALL-E (2 layers) for just 24 hours with the following results

  • afiaka87 trained for 6 epochs on the same dataset as before thanks to the efficient 16k VQGAN with the following results

Thanks to the amazing "mega b#6696" you can generate from this checkpoint in colab -

Run inference on the Afiaka checkpoint in Colab

  • (5/2/21) First 1.3B DALL-E from 🇷🇺 has been trained and released to the public! 🎉

  • (4/8/22) Moving onwards to DALLE-2!

Install

$ pip install dalle-pytorch

Usage

Train VAE

import torch
from dalle_pytorch import DiscreteVAE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,           # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    num_tokens = 8192,        # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim = 512,       # codebook dimension
    hidden_dim = 64,          # hidden dimension
    num_resnet_blocks = 1,    # number of resnet blocks
    temperature = 0.9,        # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through = False, # straight-through for gumbel softmax. unclear if it is better one way or the other
)

images = torch.randn(4, 3, 256, 256)

loss = vae(images, return_loss = True)
loss.backward()

# train with a lot of data to learn a good codebook

Train DALL-E with pretrained VAE from above

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
)

dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,    # vocab size for text
    text_seq_len = 256,         # text sequence length
    depth = 12,                 # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(text, images, return_loss = True)
loss.backward()

# do the above for a long time with a lot of data ... then

images = dalle.generate_images(text)
images.shape # (4, 3, 256, 256)

To prime with a starting crop of an image, simply pass two more arguments

img_prime = torch.randn(4, 3, 256, 256)

images = dalle.generate_images(
    text,
    img = img_prime,
    num_init_img_tokens = (14 * 32)  # you can set the size of the initial crop, defaults to a little less than ~1/2 of the tokens, as done in the paper
)

images.shape # (4, 3, 256, 256)

You may also want to generate text using DALL-E. For that call this function:

text_tokens, texts = dalle.generate_texts(tokenizer, text)

OpenAI's Pretrained VAE

You can also skip the training of the VAE altogether, using the pretrained model released by OpenAI! The wrapper class should take care of downloading and caching the model for you auto-magically.

import torch
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

vae = OpenAIDiscreteVAE()       # loads pretrained OpenAI VAE

dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,    # vocab size for text
    text_seq_len = 256,         # text sequence length
    depth = 1,                  # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(text, images, return_loss = True)
loss.backward()

Taming Transformer's Pretrained VQGAN VAE

You can also use the pretrained VAE offered by the authors of Taming Transformers! Currently only the VAE with a codebook size of 1024 is offered, with the hope that it may train a little faster than OpenAI's, which has a size of 8192.

In contrast to OpenAI's VAE, it also has an extra layer of downsampling, so the image sequence length is 256 instead of 1024 (this will lead to a 16x reduction in training costs, when you do the math). Whether it will generalize as well as the original DALL-E is up to the citizen scientists out there to discover.

Update - it works!

from dalle_pytorch import VQGanVAE

vae = VQGanVAE()

# the rest is the same as the above example

The default VQGAN is the codebook-size-1024 one trained on ImageNet. If you wish to use a different one, you can pass the .ckpt file and the .yaml file via vqgan_model_path and vqgan_config_path. These options can be used both in the train_dalle.py script and as arguments to the VQGanVAE class. Other pretrained VQGANs can be found in the taming-transformers readme. If you want to train a custom one you can follow this guide

Adjust text conditioning strength

Recently there has surfaced a new technique for guiding diffusion models without a classifier. The gist of the technique involves randomly dropping out the text condition during training, and at inference time, deriving the rough direction from unconditional to conditional distributions.

Katherine Crowson outlined in a tweet how this could work for autoregressive attention models. I have decided to include her idea in this repository for further exploration. One only has to account for two extra keyword arguments on training (null_cond_prob) and generation (cond_scale).

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
)

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 12,
    heads = 16,
    dim_head = 64,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(
    text,
    images,
    return_loss = True,
    null_cond_prob = 0.2  # firstly, set this to the probability of dropping out the condition, 20% is recommended as a default
)

loss.backward()

# do the above for a long time with a lot of data ... then

images = dalle.generate_images(
    text,
    cond_scale = 3. # secondly, set this to a value greater than 1 to increase the conditioning beyond average
)

images.shape # (4, 3, 256, 256)

That's it!

Ranking the generations

Train CLIP

import torch
from dalle_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()

To get the similarity scores from your trained Clipper, just do

images, scores = dalle.generate_images(text, mask = mask, clip = clip)

scores.shape # (4,)
images.shape # (4, 3, 256, 256)

# do your topk here, in paper they sampled 512 and chose top 32

Or you can just use the official CLIP model to rank the images from DALL-E

Scaling depth

In the blog post, they used 64 layers to achieve their results. I added reversible networks, from the Reformer paper, in order for users to attempt to scale depth at the cost of compute. Reversible networks allow you to scale to any depth at no memory cost, but a little over 2x compute cost (each layer is rerun on the backward pass).

Simply set the reversible keyword to True for the DALLE class

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 16,
    reversible = True  # <-- reversible networks https://arxiv.org/abs/2001.04451
)

Sparse Attention

The blogpost alluded to a mixture of different types of sparse attention, used mainly on the image (while the text presumably had full causal attention). I have done my best to replicate these types of sparse attention, on the scant details released. Primarily, it seems as though they are doing causal axial row / column attention, combined with a causal convolution-like attention.

By default DALLE will use full attention for all layers, but you can specify the attention type per layer as follows.

  • full full attention

  • axial_row axial attention, along the rows of the image feature map

  • axial_col axial attention, along the columns of the image feature map

  • conv_like convolution-like attention, for the image feature map

The sparse attention only applies to the image. Text will always receive full attention, as said in the blogpost.

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 16,
    reversible = True,
    attn_types = ('full', 'axial_row', 'axial_col', 'conv_like')  # cycles between these four types of attention
)

Deepspeed Sparse Attention

You can also train with Microsoft Deepspeed's Sparse Attention, with any combination of dense and sparse attention that you'd like. However, you will have to endure the installation process.

First, you need to install Deepspeed with Sparse Attention

$ sh install_deepspeed.sh

Next, you need to install the pip package triton. It will need to be a version < 1.0 because that's what Microsoft used.

$ pip install triton==0.4.2

If both of the above succeeded, now you can train with Sparse Attention!

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 8,
    attn_types = ('full', 'sparse')  # interleave sparse and dense attention for 64 layers
)

Training

This section will outline how to train the discrete variational autoencoder as well as the final multi-modal transformer (DALL-E). We are going to use Weights & Biases for all the experiment tracking.

(You can also do everything in this section in a Google Colab, link below)

Open In Colab Train in Colab

$ pip install wandb

Followed by

$ wandb login

VAE

To train the VAE, you just need to run

$ python train_vae.py --image_folder /path/to/your/images

If you installed everything correctly, a link to the experiments page should show up in your terminal. You can follow your link there and customize your experiment, like the example layout below.

You can of course open up the training script at ./train_vae.py, where you can modify the constants, what is passed to Weights & Biases, or any other tricks you know to make the VAE learn better.

Model will be saved periodically to ./vae.pt

In the experiment tracker, you will have to monitor the hard reconstruction, as we are essentially teaching the network to compress images into discrete visual tokens for use in the transformer as a visual vocabulary.

Weights and Biases will allow you to monitor the temperature annealing, image reconstructions (encoder and decoder working properly), as well as to watch out for codebook collapse (where the network decides to only use a few tokens out of what you provide it).

Once you have trained a decent VAE to your satisfaction, you can move on to the next step with your model weights at ./vae.pt.

DALL-E Training

Training using an Image-Text-Folder

Now you just have to invoke the ./train_dalle.py script, indicating which VAE model you would like to use, as well as the path to your folder of images and text.

The dataset I am currently working with contains a folder of images and text files, arbitrarily nested in subfolders, where each text file name corresponds with the image name, and where each text file contains multiple descriptions, delimited by newlines. The script will find and pair all the image and text files with the same names, and randomly select one of the textual descriptions during batch creation.

ex.

📂image-and-text-data
 ┣ 📜cat.png
 ┣ 📜cat.txt
 ┣ 📜dog.jpg
 ┣ 📜dog.txt
 ┣ 📜turtle.jpeg
 ┗ 📜turtle.txt
```py

ex. `cat.txt`

```py
A black and white cat curled up next to the fireplace
A fireplace, with a cat sleeping next to it
A black cat with a red collar napping
```py

If you have a dataset with its own directory structure for tying together image and text descriptions, do let me know in the issues, and I'll see if I can accommodate it in the script.

```py
$ python train_dalle.py --vae_path ./vae.pt --image_text_folder /path/to/data
```py

You likely will not finish DALL-E training as quickly as you did your Discrete VAE. To resume from where you left off, just run the same script, but with the path to your DALL-E checkpoints.

```py
$ python train_dalle.py --dalle_path ./dalle.pt --image_text_folder /path/to/data
```py

## Training using WebDataset

WebDataset files are regular .tar(.gz) files which can be streamed and used for DALLE-pytorch training.
You just need to provide the image (first comma separated argument) and caption (second comma separated argument)
column key after the --wds argument. The --image_text_folder points to your .tar(.gz) file instead of the data folder.

```py
$ python train_dalle.py --wds img,cap --image_text_folder /path/to/data.tar(.gz)
```py

Distributed training with deepspeed works the same way, e.g.:

```py
$ deepspeed train_dalle.py --wds img,cap --image_text_folder /path/to/data.tar(.gz) --fp16 --deepspeed
```py

If your dataset consists of shards (the dataset split into several .tar(.gz) files), this is also supported:

```py
$ deepspeed train_dalle.py --wds img,cap --image_text_folder /path/to/shardfolder --fp16 --deepspeed
```py

You can stream the data from an HTTP server or Google Cloud Storage like this:

```py
$ deepspeed train_dalle.py --image_text_folder "http://storage.googleapis.com/nvdata-openimages/openimages-train-{000000..000554}.tar" --wds jpg,json --taming --truncate_captions --random_resize_crop_lower_ratio=0.8 --attn_types=full --epochs=2 --fp16 --deepspeed
```py

In order to convert your image-text folder to the WebDataset format, you can use one of several methods:
four examples are given at https://www.youtube.com/watch?v=v_PacO-3OGQ, and there is also a small helper script that supports splitting your dataset
into shards of .tar.gz files: https://github.com/robvanvolt/DALLE-datasets/blob/main/wds_create_shards.py

### DALL-E with OpenAI's VAE

You can now also train DALL-E without having to train the Discrete VAE at all, courtesy to their open-sourcing their model. You simply have to invoke the `train_dalle.py` script without specifying the `--vae_path`

```py
$ python train_dalle.py --image_text_folder /path/to/coco/dataset
```py

### DALL-E with Taming Transformer's VQVAE

Just use the `--taming` flag. Highly recommended you use this VAE over the OpenAI one!

```py
$ python train_dalle.py --image_text_folder /path/to/coco/dataset --taming
```py

### Generation

Once you have successfully trained DALL-E, you can then use the saved model for generation!

```py
$ python generate.py --dalle_path ./dalle.pt --text 'fireflies in a field under a full moon'
```py

You should see your images saved as `./outputs/{your prompt}/{image number}.png`

To generate multiple images, just pass in your text with '|' character as a separator.

ex.

```py
$ python generate.py --dalle_path ./dalle.pt --text 'a dog chewing a bone|a cat chasing mice|a frog eating a fly'
```py

Note that DALL-E is a full image+text language model. As a consequence you can also generate text using a dalle model.

```py
$ python generate.py --dalle_path ./dalle.pt --text 'a dog chewing a bone' --gentxt
```py

This will complete the provided text, save it in a caption.txt and generate the corresponding images.

### Docker

You can use a docker container to make sure the version of Pytorch and Cuda are correct for training DALL-E. Docker (https://docs.docker.com/get-docker/) and the Docker Container Runtime should be installed.

To build:

```py
docker build -t dalle docker
```py

To run in an interactive shell:

```py
docker run --gpus all -it --mount src="$(pwd)",target=/workspace/dalle,type=bind dalle:latest bash
```py

### Distributed Training

#### DeepSpeed

Thanks to <a href="https://github.com/janEbert">janEbert</a>, the repository is now equipped so you can train DALL-E with Microsoft's <a href="https://www.deepspeed.ai/">Deepspeed</a>!

You can simply replace any `$ python <file>.py [args...]` command with

```py
$ deepspeed <file>.py [args...] --deepspeed
```py

to use the aforementioned DeepSpeed library for distributed training, speeding up your experiments.

Modify the `deepspeed_config` dictionary in `train_dalle.py` or
`train_vae.py` according to the DeepSpeed settings you'd like to use
for each one. See the [DeepSpeed configuration
docs](https://www.deepspeed.ai/docs/config-json/) for more
information.

#### DeepSpeed - 32 and 16 bit Precision
As of DeepSpeed version 0.3.16, ZeRO optimizations can be used with
single-precision floating point numbers. If you are using an older
version, you'll have to pass the `--fp16` flag to be able to enable
ZeRO optimizations.


#### DeepSpeed - Apex Automatic Mixed Precision.
Automatic mixed precision is a stable alternative to fp16 which still provides a decent speedup.
In order to run with Apex AMP (through DeepSpeed), you will need to install DeepSpeed using either the Dockerfile or the bash script.

Then you will need to install apex from source.
This may take a while and you may see some compilation warnings which can be ignored.
```py
sh install_apex.sh
```py

Now, run `train_dalle.py` with `deepspeed` instead of `python` as done here:
```py
deepspeed train_dalle.py \
    --taming \
    --image_text_folder 'DatasetsDir' \
    --distr_backend 'deepspeed' \
    --amp
```py

#### Horovod

[Horovod](https://horovod.ai) offers a stable way for data parallel
training.

After [installing
Horovod](https://github.com/lucidrains/DALLE-pytorch/wiki/Horovod-Installation),
replace any `$ python <file>.py [args...]` command with

```py
$ horovodrun -np <num-gpus> <file>.py [args...] --distributed_backend horovod
```py

to use the Horovod library for distributed training, speeding up your
experiments. This will multiply your effective batch size per training
step by `<num-gpus>`, so you may need to rescale the learning rate
accordingly.

#### Custom Tokenizer

This repository supports custom tokenization with <a href="https://github.com/VKCOM/YouTokenToMe">YouTokenToMe</a>, if you wish to use it instead of the default simple tokenizer. Simply pass in an extra `--bpe_path` when invoking `train_dalle.py` and `generate.py`, with the path to your BPE model file.

The only requirement is that you use `0` as the padding during tokenization

ex.

```py
$ python train_dalle.py --image_text_folder ./path/to/data --bpe_path ./path/to/bpe.model
```py

To create a BPE model file from scratch, firstly

```py
$ pip install youtokentome
```py

Then you need to prepare a big text file that is a representative sample of the type of text you want to encode. You can then invoke the `youtokentome` command-line tools. You'll also need to specify the vocab size you wish to use, in addition to the corpus of text.

```py
$ yttm bpe --vocab_size 8000 --data ./path/to/big/text/file.txt --model ./path/to/bpe.model
```

That's it! The BPE model file is now saved to ./path/to/bpe.model and you can begin training!

Chinese

You can train with a pretrained chinese tokenizer offered by Huggingface 🤗 by simply passing in an extra flag --chinese

ex.

$ python train_dalle.py --chinese --image_text_folder ./path/to/data
$ python generate.py --chinese --text '追老鼠的猫'

Citations

@misc{ramesh2021zeroshot,
    title   = {Zero-Shot Text-to-Image Generation}, 
    author  = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
    year    = {2021},
    eprint  = {2102.12092},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{unpublished2021clip,
    title  = {CLIP: Connecting Text and Images},
    author = {Alec Radford and Ilya Sutskever and Jong Wook Kim and Gretchen Krueger and Sandhini Agarwal},
    year   = {2021}
}
@misc{kitaev2020reformer,
    title   = {Reformer: The Efficient Transformer},
    author  = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya},
    year    = {2020},
    eprint  = {2001.04451},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{esser2021taming,
    title   = {Taming Transformers for High-Resolution Image Synthesis},
    author  = {Patrick Esser and Robin Rombach and Björn Ommer},
    year    = {2021},
    eprint  = {2012.09841},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = {aug},
    year         = {2021},
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{ho2021classifierfree,
    title   = {Classifier-Free Diffusion Guidance},
    author  = {Jonathan Ho and Tim Salimans},
    booktitle = {NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications},
    year    = {2021},
    url     = {https://openreview.net/forum?id=qw8AKxfYbI}
}
@misc{crowson2022,
    author  = {Katherine Crowson},
    url     = {https://twitter.com/RiversHaveWings/status/1478093658716966912}
}
@article{Liu2023BridgingDA,
    title   = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
    author  = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2304.08612}
}

Those who do not want to imitate anything, produce nothing. - Dali
