Lucidrains Series Source Code Analysis (Part 28)

.\lucidrains\multistream-transformers\multistream_transformers\__init__.py

# Import the MultistreamTransformer class from the multistream_transformers package
from multistream_transformers.multistream_transformers import MultistreamTransformer

Multistream Transformers

Implementation of Multistream Transformers in Pytorch.

This repository deviates slightly from the paper: instead of using a skip connection across all streams, it uses attention pooling across all tokens at the same position. This produced the best results in my experiments when the number of streams is greater than 2.

Install

$ pip install multistream-transformers

Usage

import torch
from multistream_transformers import MultistreamTransformer

model = MultistreamTransformer(
    num_tokens = 256,         # number of tokens
    dim = 512,                # dimension
    depth = 4,                # depth
    causal = True,            # autoregressive or not
    max_seq_len = 1024,       # maximum sequence length
    num_streams = 2           # number of streams - 1 would make it a regular transformer
)

x = torch.randint(0, 256, (2, 1024))
mask = torch.ones((2, 1024)).bool()

logits = model(x, mask = mask) # (2, 1024, 256)

Citations

@misc{burtsev2021multistream,
    title   = {Multi-Stream Transformers}, 
    author  = {Mikhail Burtsev and Anna Rumshisky},
    year    = {2021},
    eprint  = {2107.10342},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\multistream-transformers\setup.py

# Import setuptools' setup and find_packages helpers
from setuptools import setup, find_packages

# Package metadata
setup(
  name = 'multistream-transformers',  # package name
  packages = find_packages(),  # find and include all packages
  version = '0.0.4',  # version number
  license='MIT',  # license
  description = 'Multistream Transformers - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/multistream-transformers',  # project URL
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'transformers'
  ],
  install_requires=[  # dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\multistream-transformers\train.py

# Import required libraries
from multistream_transformers import MultistreamTransformer
from multistream_transformers.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# Constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# Helper functions

# Yield batches from a dataloader indefinitely
def cycle(loader):
    while True:
        for data in loader:
            yield data

# Decode a single token to a character
def decode_token(token):
    return str(chr(max(32, token)))

# Decode a sequence of tokens to a string
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# Instantiate a GPT-like decoder model

model = MultistreamTransformer(
    num_tokens = 256,
    dim = 512,
    max_seq_len = SEQ_LEN,
    depth = 4,
    heads = 8,
    causal = True,
    num_streams = 2
)

model = AutoregressiveWrapper(model)
model.cuda()

# Prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# Dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# Create DataLoaders for the training and validation sets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# Optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train the model
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\attend.py

# Import required modules and functions
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

# Import the custom FlashAttentionFunction
from memory_efficient_attention_pytorch.flash_attention import FlashAttentionFunction

# Named tuple AttentionConfig with three boolean fields
AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# Helper to check whether a value exists
def exists(val):
    return val is not None

# Decorator that ensures the wrapped function only runs once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# A print function that only prints once
print_once = once(print)

# Main class
class Attend(nn.Module):
    def __init__(
        self,
        scale = 8,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        # flash attention requires PyTorch 2.0 or above
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine the efficient attention configs for cuda and cpu
        self.cuda_config = None
        self.no_hardware_detected = False

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, False)

    # flash attention path
    def flash_attn(self, q, k, v, mask = None):
        default_scale = q.shape[-1] ** -0.5

        is_cuda = q.is_cuda

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # rescale the inputs so the custom scale matches the default scale used internally
        rescale = self.scale / default_scale
        q = q * (rescale ** 0.5)
        k = k * (rescale ** 0.5)

        # use the naive implementation if not on cuda or no suitable hardware was detected
        use_naive = not is_cuda or not exists(self.cuda_config)

        if not is_cuda or self.no_hardware_detected:
            return FlashAttentionFunction.apply(q, k, v, mask, False, 512, 512)

        # attempt to use pytorch 2.0's flash attention implementation
        try:
            with torch.backends.cuda.sdp_kernel(**self.cuda_config._asdict()):
                out = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask = mask,
                    dropout_p = self.dropout if self.training else 0.
                )
        except:
            print_once('no hardware detected, falling back to naive implementation from memory-efficient-attention-pytorch library')
            self.no_hardware_detected = True

            out = FlashAttentionFunction.apply(q, k, v, mask, False, 512, 512)

        return out
    # forward pass, taking queries (q), keys (k), values (v), an optional mask, and a flag to force the non-flash path
    def forward(self, q, k, v, mask = None, force_non_flash = False):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # if flash is enabled and not explicitly disabled, use the flash attention path
        if self.flash and not force_non_flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        # masking
        if exists(mask):
            mask_value = -torch.finfo(sim.dtype).max
            sim = sim.masked_fill(~mask, mask_value)

        # attention
        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
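As a quick reference, here is a minimal, hypothetical usage sketch of the Attend module defined above; the tensor shapes are arbitrary, and the module is assumed to be importable from muse_maskgit_pytorch.attend as laid out in this file.

import torch
from muse_maskgit_pytorch.attend import Attend

attend = Attend(scale = 8, dropout = 0., flash = False)

# (batch, heads, sequence length, head dimension)
q = torch.randn(2, 8, 64, 32)
k = torch.randn(2, 8, 64, 32)
v = torch.randn(2, 8, 64, 32)

# boolean mask broadcastable to the (batch, heads, i, j) attention matrix
mask = torch.ones(2, 8, 64, 64).bool()

out = attend(q, k, v, mask = mask)  # (2, 8, 64, 32)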

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\muse_maskgit_pytorch.py

# Attention module
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        cross_attend = False,
        scale = 8,
        flash = True,
        dropout = 0.
    ):
        super().__init__()
        # scaling factor
        self.scale = scale
        # number of heads
        self.heads =  heads
        # inner dimension
        inner_dim = dim_head * heads

        # whether this layer cross-attends
        self.cross_attend = cross_attend
        # normalization layer
        self.norm = LayerNorm(dim)

        # attention backend
        self.attend = Attend(
            flash = flash,
            dropout = dropout,
            scale = scale
        )

        # null key / value pair
        self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head))

        # query projection
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # key / value projection
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # query scale
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        # key scale
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # output projection
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context = None,
        context_mask = None
        ):
        # context must be passed in if and only if this is a cross-attention layer
        assert not (exists(context) ^ self.cross_attend)

        # sequence length of the input x
        n = x.shape[-2]
        # number of heads h and whether this call is cross-attention
        h, is_cross_attn = self.heads, exists(context)

        # normalize the input
        x = self.norm(x)

        # choose the key / value input depending on whether cross-attending
        kv_input = context if self.cross_attend else x

        # compute queries q, keys k, values v, splitting the kv projection along the last dimension
        q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))

        # rearrange q, k, v so the heads dimension h comes second
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # expand the null key / value pair across the batch
        nk, nv = self.null_kv
        nk, nv = map(lambda t: repeat(t, 'h 1 d -> b h 1 d', b = x.shape[0]), (nk, nv))

        # prepend the null key / value pair to k and v
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # l2-normalize queries and keys
        q, k = map(l2norm, (q, k))
        # apply the learned scales to queries and keys
        q = q * self.q_scale
        k = k * self.k_scale

        # if a context mask exists, expand it to the attention matrix shape and pad for the null key / value
        if exists(context_mask):
            context_mask = repeat(context_mask, 'b j -> b h i j', h = h, i = n)
            context_mask = F.pad(context_mask, (1, 0), value = True)

        # attention
        out = self.attend(q, k, v, mask = context_mask)

        # merge the heads back into the feature dimension
        out = rearrange(out, 'b h n d -> b n (h d)')
        # output projection
        return self.to_out(out)
# TransformerBlocks stacks multiple transformer layers
class TransformerBlocks(nn.Module):
    def __init__(
        self,
        *,
        dim,  # model dimension
        depth,  # number of stacked transformer layers
        dim_head = 64,  # dimension per attention head
        heads = 8,  # number of attention heads
        ff_mult = 4,  # feedforward expansion factor
        flash = True  # whether to use flash attention
    ):
        super().__init__()
        self.layers = nn.ModuleList([])  # empty module list

        for _ in range(depth):  # stack `depth` transformer layers
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash),  # self-attention
                Attention(dim = dim, dim_head = dim_head, heads = heads, cross_attend = True, flash = flash),  # cross-attention
                FeedForward(dim = dim, mult = ff_mult)  # feedforward
            ]))

        self.norm = LayerNorm(dim)  # final layer norm

    def forward(self, x, context = None, context_mask = None):  # forward pass
        for attn, cross_attn, ff in self.layers:  # iterate over the transformer layers
            x = attn(x) + x  # self-attention with residual connection

            x = cross_attn(x, context = context, context_mask = context_mask) + x  # cross-attention with residual connection

            x = ff(x) + x  # feedforward with residual connection

        return self.norm(x)  # final layer norm

# Transformer class, conditioned on text
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,  # number of tokens
        dim,  # model dimension
        seq_len,  # sequence length
        dim_out = None,  # output dimension
        t5_name = DEFAULT_T5_NAME,  # T5 model name
        self_cond = False,  # whether to use self-conditioning
        add_mask_id = False,  # whether to add a mask token
        **kwargs
    ):
        super().__init__()
        self.dim = dim  # model dimension
        self.mask_id = num_tokens if add_mask_id else None  # mask token id

        self.num_tokens = num_tokens  # number of tokens
        self.token_emb = nn.Embedding(num_tokens + int(add_mask_id), dim)  # token embedding
        self.pos_emb = nn.Embedding(seq_len, dim)  # positional embedding
        self.seq_len = seq_len  # sequence length

        self.transformer_blocks = TransformerBlocks(dim = dim, **kwargs)  # transformer layers
        self.norm = LayerNorm(dim)  # layer norm

        self.dim_out = default(dim_out, num_tokens)  # output dimension
        self.to_logits = nn.Linear(dim, self.dim_out, bias = False)  # projection to logits

        # text conditioning

        self.encode_text = partial(t5_encode_text, name = t5_name)  # text encoder

        text_embed_dim = get_encoded_dim(t5_name)  # dimension of the encoded text

        self.text_embed_proj = nn.Linear(text_embed_dim, dim, bias = False) if text_embed_dim != dim else nn.Identity()  # project text embeddings to the model dimension

        # optional self-conditioning

        self.self_cond = self_cond  # whether to use self-conditioning
        self.self_cond_to_init_embed = FeedForward(dim)  # feedforward applied to the self-conditioning embedding

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3.,  # conditioning scale
        return_embed = False,  # whether to also return the embedding
        **kwargs
    ):
        if cond_scale == 1:  # a conditioning scale of 1 is just the conditioned forward pass
            return self.forward(*args, return_embed = return_embed, cond_drop_prob = 0., **kwargs)

        logits, embed = self.forward(*args, return_embed = True, cond_drop_prob = 0., **kwargs)  # conditioned forward pass, returning the embedding

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)  # unconditioned forward pass (conditioning dropped)

        scaled_logits = null_logits + (logits - null_logits) * cond_scale  # classifier-free guidance

        if return_embed:  # optionally return the embedding as well
            return scaled_logits, embed

        return scaled_logits
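    # Note on the guidance formula above: with cond_scale = 3., the returned logits are
    # null_logits + 3 * (logits - null_logits), i.e. the conditioned prediction is pushed
    # three times as far away from the unconditioned prediction (classifier-free guidance).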

    def forward_with_neg_prompt(
        self,
        text_embed: torch.Tensor,
        neg_text_embed: torch.Tensor,
        cond_scale = 3.,  # conditioning scale
        return_embed = False,
        **kwargs
    ):
        neg_logits = self.forward(text_embeds = neg_text_embed, cond_drop_prob = 0., **kwargs)  # forward pass with the negative text embedding
        pos_logits, embed = self.forward(return_embed = True, text_embeds = text_embed, cond_drop_prob = 0., **kwargs)  # forward pass with the positive text embedding

        logits = neg_logits + (pos_logits - neg_logits) * cond_scale  # interpolate away from the negative prompt

        if return_embed:  # optionally return the embedding as well
            return logits, embed

        return logits

    def forward(
        self,
        x,
        return_embed = False,
        return_logits = False,
        labels = None,
        ignore_index = 0,
        self_cond_embed = None,
        cond_drop_prob = 0.,
        conditioning_token_ids: Optional[torch.Tensor] = None,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[torch.Tensor] = None
        ):
        # device, batch size and sequence length of the input
        device, b, n = x.device, *x.shape
        # the sequence length must not exceed self.seq_len
        assert n <= self.seq_len

        # prepare the text conditioning

        # exactly one of texts and text_embeds must be given
        assert exists(texts) ^ exists(text_embeds)

        # if raw texts were given, encode them to text_embeds with self.encode_text
        if exists(texts):
            text_embeds = self.encode_text(texts)

        # project the text embeddings to the model dimension to get the context
        context = self.text_embed_proj(text_embeds)

        # build the context mask, indicating which positions actually contain text
        context_mask = (text_embeds != 0).any(dim=-1)

        # if cond_drop_prob is greater than 0, randomly drop the conditioning (for classifier-free guidance)
        if cond_drop_prob > 0.:
            mask = prob_mask_like((b, 1), 1. - cond_drop_prob, device)
            context_mask = context_mask & mask

        # if conditioning token ids are given, embed them and concatenate onto the context
        if exists(conditioning_token_ids):
            conditioning_token_ids = rearrange(conditioning_token_ids, 'b ... -> b (...)')
            cond_token_emb = self.token_emb(conditioning_token_ids)
            context = torch.cat((context, cond_token_emb), dim=-2)
            context_mask = F.pad(context_mask, (0, conditioning_token_ids.shape[-1]), value=True)

        # embed the input tokens and add positional embeddings
        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(n, device=device))

        # self-conditioning: default the self-conditioning embedding to zeros if not given
        if self.self_cond:
            if not exists(self_cond_embed):
                self_cond_embed = torch.zeros_like(x)
            x = x + self.self_cond_to_init_embed(self_cond_embed)

        # run the transformer blocks
        embed = self.transformer_blocks(x, context=context, context_mask=context_mask)

        # project the embeddings to logits
        logits = self.to_logits(embed)

        # if return_embed is True, return both logits and embeddings
        if return_embed:
            return logits, embed

        # if no labels are given, return the logits
        if not exists(labels):
            return logits

        # compute the loss, depending on self.dim_out
        if self.dim_out == 1:
            loss = F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
        else:
            loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index=ignore_index)

        # if return_logits is False, return just the loss
        if not return_logits:
            return loss

        # return both the loss and the logits
        return loss, logits
# Self-critic wrapper
class SelfCritic(nn.Module):
    # takes a network (the generator transformer) and adds a scalar prediction head
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.to_pred = nn.Linear(net.dim, 1)

    # forward pass with conditioning scale
    def forward_with_cond_scale(self, x, *args, **kwargs):
        _, embeds = self.net.forward_with_cond_scale(x, *args, return_embed=True, **kwargs)
        return self.to_pred(embeds)

    # forward pass with a negative prompt
    def forward_with_neg_prompt(self, x, *args, **kwargs):
        _, embeds = self.net.forward_with_neg_prompt(x, *args, return_embed=True, **kwargs)
        return self.to_pred(embeds)

    # forward pass
    def forward(self, x, *args, labels=None, **kwargs):
        _, embeds = self.net(x, *args, return_embed=True, **kwargs)
        logits = self.to_pred(embeds)

        # if no labels are given, return the logits
        if not exists(labels):
            return logits

        # rearrange the logits and compute binary cross-entropy
        logits = rearrange(logits, '... 1 -> ...')
        return F.binary_cross_entropy_with_logits(logits, labels)

# Specialized transformers

# MaskGitTransformer: a Transformer with a mask token added
class MaskGitTransformer(Transformer):
    def __init__(self, *args, **kwargs):
        # 'add_mask_id' must not be passed in explicitly
        assert 'add_mask_id' not in kwargs
        super().__init__(*args, add_mask_id=True, **kwargs)

# TokenCritic: a Transformer with a scalar output per token
class TokenCritic(Transformer):
    def __init__(self, *args, **kwargs):
        # 'dim_out' must not be passed in explicitly
        assert 'dim_out' not in kwargs
        super().__init__(*args, dim_out=1, **kwargs)

# Classifier-free guidance functions

# Tensor filled with samples from a uniform distribution
def uniform(shape, min=0, max=1, device=None):
    return torch.zeros(shape, device=device).float().uniform_(min, max)

# Boolean mask where each entry is True with the given probability
def prob_mask_like(shape, prob, device=None):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return uniform(shape, device=device) < prob

# Sampling helpers

# Numerically stable natural log
def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))

# Gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# Sample from logits via Gumbel noise (temperature-controlled argmax)
def gumbel_sample(t, temperature=1., dim=-1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)

# Keep only the top-k logits, setting the rest to -inf
def top_k(logits, thres=0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = logits.topk(k, dim=-1)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(2, ind, val)
    return probs

# Noise schedules

# Cosine schedule
def cosine_schedule(t):
    return torch.cos(t * math.pi * 0.5)
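# For reference, cosine_schedule maps t in [0, 1] to a masking fraction that starts at 1 and
# decays to 0: t = 0. -> 1.0, t = 0.5 -> ~0.707, t = 1. -> 0. MaskGit draws a random t per
# training sample, so small t values mask nearly the whole token sequence.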

# Main MaskGit class

@beartype
class MaskGit(nn.Module):
    def __init__(
        self,
        image_size,
        transformer: MaskGitTransformer,
        noise_schedule: Callable = cosine_schedule,
        token_critic: Optional[TokenCritic] = None,
        self_token_critic=False,
        vae: Optional[VQGanVAE] = None,
        cond_vae: Optional[VQGanVAE] = None,
        cond_image_size=None,
        cond_drop_prob=0.5,
        self_cond_prob=0.9,
        no_mask_token_prob=0.,
        critic_loss_weight=1.
        ):
        super().__init__()
        # if a VAE is given, keep a frozen copy for evaluation, otherwise None
        self.vae = vae.copy_for_eval() if exists(vae) else None

        # if a conditioning VAE is given, put it in eval mode, otherwise share the main VAE
        if exists(cond_vae):
            self.cond_vae = cond_vae.eval()
        else:
            self.cond_vae = self.vae

        # if a conditioning VAE is given, the conditioning image size must also be specified
        assert not (exists(cond_vae) and not exists(cond_image_size)), 'cond_image_size must be specified if conditioning'

        # image size and conditioning image size
        self.image_size = image_size
        self.cond_image_size = cond_image_size
        self.resize_image_for_cond_image = exists(cond_image_size)

        # conditioning dropout probability (for classifier-free guidance)
        self.cond_drop_prob = cond_drop_prob

        # transformer and whether it self-conditions
        self.transformer = transformer
        self.self_cond = transformer.self_cond
        # the codebook sizes of the VAE and conditioning VAE must equal the transformer's number of tokens
        assert self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens, 'transformer num_tokens must be set to be equal to the vae codebook size'

        # mask id and noise schedule
        self.mask_id = transformer.mask_id
        self.noise_schedule = noise_schedule

        # a self token critic and an external token critic cannot both be used
        assert not (self_token_critic and exists(token_critic))
        self.token_critic = token_critic

        # if self_token_critic is set, wrap the transformer in a SelfCritic
        if self_token_critic:
            self.token_critic = SelfCritic(transformer)

        # weight of the critic loss
        self.critic_loss_weight = critic_loss_weight

        # self-conditioning probability
        self.self_cond_prob = self_cond_prob

        # probability of leaving a selected token unmasked, so the transformer produces better
        # embeddings for all tokens, as done in the original BERT paper; may be needed for self-conditioning
        self.no_mask_token_prob = no_mask_token_prob

    # save the model parameters to a path
    def save(self, path):
        torch.save(self.state_dict(), path)

    # load the model parameters from a path
    def load(self, path):
        path = Path(path)
        assert path.exists()
        state_dict = torch.load(str(path))
        self.load_state_dict(state_dict)

    # generation method, producing images from text prompts
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        texts: List[str],
        negative_texts: Optional[List[str]] = None,
        cond_images: Optional[torch.Tensor] = None,
        fmap_size = None,
        temperature = 1.,
        topk_filter_thres = 0.9,
        can_remask_prev_masked = False,
        force_not_use_token_critic = False,
        timesteps = 18,  # ideal number of steps is 18, per the maskgit paper
        cond_scale = 3,
        critic_noise_scale = 1
    ):
        ...  # generation body not included in this excerpt

    # forward pass, used for training
    def forward(
        self,
        images_or_ids: torch.Tensor,
        ignore_index = -1,
        cond_images: Optional[torch.Tensor] = None,
        cond_token_ids: Optional[torch.Tensor] = None,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[torch.Tensor] = None,
        cond_drop_prob = None,
        train_only_generator = False,
        sample_temperature = None
        ):
            # tokenize if raw images were passed in

            if images_or_ids.dtype == torch.float:
                assert exists(self.vae), 'vqgan vae must be passed in if training from raw images'
                assert all([height_or_width == self.image_size for height_or_width in images_or_ids.shape[-2:]]), 'the image you passed in is not of the correct dimensions'

                with torch.no_grad():
                    _, ids, _ = self.vae.encode(images_or_ids)
            else:
                assert not self.resize_image_for_cond_image, 'you cannot pass in raw image token ids if you want the framework to autoresize image for conditioning super res transformer'
                ids = images_or_ids

            # take care of the conditioning image, if specified

            if self.resize_image_for_cond_image:
                cond_images_or_ids = F.interpolate(images_or_ids, self.cond_image_size, mode='nearest')

            # get some basic variables

            ids = rearrange(ids, 'b ... -> b (...)')

            batch, seq_len, device, cond_drop_prob = *ids.shape, ids.device, default(cond_drop_prob, self.cond_drop_prob)

            # tokenize the conditioning images, if needed

            assert not (exists(cond_images) and exists(cond_token_ids)), 'if conditioning on low resolution, cannot pass in both images and token ids'

            if exists(cond_images):
                assert exists(self.cond_vae), 'cond vqgan vae must be passed in'
                assert all([height_or_width == self.cond_image_size for height_or_width in cond_images.shape[-2:]])

                with torch.no_grad():
                    _, cond_token_ids, _ = self.cond_vae.encode(cond_images)

            # prepare the mask

            rand_time = uniform((batch,), device=device)
            rand_mask_probs = self.noise_schedule(rand_time)
            num_token_masked = (seq_len * rand_mask_probs).round().clamp(min=1)

            mask_id = self.mask_id
            batch_randperm = torch.rand((batch, seq_len), device=device).argsort(dim=-1)
            mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1')

            mask_id = self.transformer.mask_id
            labels = torch.where(mask, ids, ignore_index)

            if self.no_mask_token_prob > 0.:
                no_mask_mask = get_mask_subset_prob(mask, self.no_mask_token_prob)
                mask &= ~no_mask_mask

            x = torch.where(mask, mask_id, ids)
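            # Illustrative example: with seq_len = 256 and a sampled rand_mask_probs of 0.64,
            # num_token_masked is 164, so 164 randomly chosen positions are replaced by mask_id,
            # and only those positions contribute to the cross-entropy loss via `labels`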

            # get the text embeddings

            if exists(texts):
                text_embeds = self.transformer.encode_text(texts)
                texts = None

            # self-conditioning

            self_cond_embed = None

            if self.transformer.self_cond and random() < self.self_cond_prob:
                with torch.no_grad():
                    _, self_cond_embed = self.transformer(
                        x,
                        text_embeds=text_embeds,
                        conditioning_token_ids=cond_token_ids,
                        cond_drop_prob=0.,
                        return_embed=True
                    )

                    self_cond_embed.detach_()

            # get the generator's cross-entropy loss

            ce_loss, logits = self.transformer(
                x,
                text_embeds=text_embeds,
                self_cond_embed=self_cond_embed,
                conditioning_token_ids=cond_token_ids,
                labels=labels,
                cond_drop_prob=cond_drop_prob,
                ignore_index=ignore_index,
                return_logits=True
            )

            if not exists(self.token_critic) or train_only_generator:
                return ce_loss

            # token critic loss

            sampled_ids = gumbel_sample(logits, temperature=default(sample_temperature, random()))

            critic_input = torch.where(mask, sampled_ids, x)
            critic_labels = (ids != critic_input).float()

            bce_loss = self.token_critic(
                critic_input,
                text_embeds=text_embeds,
                conditioning_token_ids=cond_token_ids,
                labels=critic_labels,
                cond_drop_prob=cond_drop_prob
            )

            return ce_loss + self.critic_loss_weight * bce_loss
# Muse class, combining a base and a super-resolution MaskGit model
@beartype
class Muse(nn.Module):
    def __init__(
        self,
        base: MaskGit,  # base MaskGit model
        superres: MaskGit  # super-resolution MaskGit model
    ):
        super().__init__()
        self.base_maskgit = base.eval()  # base model, set to eval mode

        assert superres.resize_image_for_cond_image  # the super-resolution model must be conditioned on a resized image
        self.superres_maskgit = superres.eval()  # super-resolution model, set to eval mode

    # forward pass, wrapped in torch.no_grad()
    @torch.no_grad()
    def forward(
        self,
        texts: List[str],  # list of input text prompts
        cond_scale = 3.,  # conditioning scale, default 3
        temperature = 1.,  # sampling temperature, default 1
        timesteps = 18,  # number of timesteps, default 18
        superres_timesteps = None,  # super-resolution timesteps, defaults to `timesteps`
        return_lowres = False,  # whether to also return the low-resolution images
        return_pil_images = True  # whether to return PIL images, default True
    ):
        # generate the low-resolution images with the base model
        lowres_image = self.base_maskgit.generate(
            texts = texts,
            cond_scale = cond_scale,
            temperature = temperature,
            timesteps = timesteps
        )

        # generate the high-resolution images with the super-resolution model, conditioned on the low-resolution output
        superres_image = self.superres_maskgit.generate(
            texts = texts,
            cond_scale = cond_scale,
            cond_images = lowres_image,
            temperature = temperature,
            timesteps = default(superres_timesteps, timesteps)
        )

        # optionally convert to PIL images
        if return_pil_images:
            # convert the low-resolution images to a list of PIL images
            lowres_image = list(map(T.ToPILImage(), lowres_image))
            # convert the high-resolution images to a list of PIL images
            superres_image = list(map(T.ToPILImage(), superres_image))

        # if the low-resolution images are not requested, return only the high-resolution images
        if not return_lowres:
            return superres_image

        # return both the high- and low-resolution images
        return superres_image, lowres_image
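To tie the classes above together, here is a minimal, hypothetical training and sampling sketch (not taken from the repository itself): it assumes the package's top-level __init__ exports VQGanVAE, MaskGitTransformer and MaskGit, and uses toy hyperparameters; seq_len must equal the squared size of the VAE's encoded feature map (16 x 16 = 256 for 256-pixel images with 4 encoder layers).

import torch
from muse_maskgit_pytorch import VQGanVAE, MaskGitTransformer, MaskGit

vae = VQGanVAE(dim = 256, codebook_size = 65536)

transformer = MaskGitTransformer(
    num_tokens = 65536,   # must match the VAE codebook size
    seq_len = 256,        # encoded feature map size squared
    dim = 512,
    depth = 2,
    heads = 8
)

base_maskgit = MaskGit(
    vae = vae,
    transformer = transformer,
    image_size = 256,
    cond_drop_prob = 0.25
)

# one training step: raw images plus text prompts in, scalar loss out
images = torch.randn(1, 3, 256, 256)
loss = base_maskgit(images, texts = ['a photo of a dog'])
loss.backward()

# sampling with classifier-free guidance
generated = base_maskgit.generate(texts = ['a photo of a dog'], cond_scale = 3.)

A second MaskGit with cond_image_size set can then be paired with this base model inside the Muse wrapper above for the full two-stage, low-resolution to high-resolution pipeline.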

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\t5.py

# Import logging, typing helpers, torch and transformers
import logging
from typing import List, Union

import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config

from beartype import beartype

# Set the transformers logging level to error
transformers.logging.set_verbosity_error()

# Check whether a value exists
def exists(val):
    return val is not None

# Configuration
MAX_LENGTH = 256
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
T5_CONFIGS = {}

# Global singletons

# Get the tokenizer for the given model name
def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name)
    return tokenizer

# Get the encoder model for the given model name
def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

# Get both the encoder model and the tokenizer, caching them
def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']

# Get the encoder output dimension
def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoid loading the model weights just to get the dimension
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model

# Encoding text

# texts may be a single string or a list of strings
@beartype
def t5_encode_text(
    texts: Union[str, List[str]],
    name = DEFAULT_T5_NAME,
    output_device = None
):
    if isinstance(texts, str):
        texts = [texts]

    # get the encoder model and tokenizer for the given name
    t5, tokenizer = get_model_and_tokenizer(name)

    # move the model to CUDA if available
    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    # tokenize the texts
    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    attn_mask = attn_mask.bool()
    encoded_text = encoded_text.masked_fill(~attn_mask[..., None], 0.)

    # if an output device is given, move the encoded text to it
    if not exists(output_device):
        return encoded_text

    encoded_text = encoded_text.to(output_device)
    return encoded_text
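A minimal, hypothetical usage sketch of the encoder above (the first call downloads the default google/t5-v1_1-base weights):

from muse_maskgit_pytorch.t5 import t5_encode_text, get_encoded_dim

embeds = t5_encode_text(['a puppy playing in the snow', 'a red sports car'])
print(embeds.shape)                             # (2, padded seq len, 768) for t5-v1_1-base
print(get_encoded_dim('google/t5-v1_1-base'))   # 768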

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\trainers.py

# Import sqrt from math
from math import sqrt
# Import choice from random
from random import choice
# Import Path from pathlib
from pathlib import Path
# Import rmtree from shutil
from shutil import rmtree
# Import partial from functools
from functools import partial

# Import the beartype decorator
from beartype import beartype

# Import torch
import torch
# Import the nn module
from torch import nn
# Import the Adam optimizer
from torch.optim import Adam
# Import Dataset, DataLoader and random_split
from torch.utils.data import Dataset, DataLoader, random_split

# Import torchvision transforms as T
import torchvision.transforms as T
# Import ImageFolder
from torchvision.datasets import ImageFolder
# Import make_grid and save_image
from torchvision.utils import make_grid, save_image

# Import the VQGanVAE class
from muse_maskgit_pytorch.vqgan_vae import VQGanVAE

# Import rearrange from einops
from einops import rearrange

# Import Accelerator, DistributedType and DistributedDataParallelKwargs from accelerate
from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs

# Import EMA from ema_pytorch
from ema_pytorch import EMA

# Import Image and ImageFile from PIL
from PIL import Image, ImageFile
# Allow loading truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Helper functions

# Check whether a value exists
def exists(val):
    return val is not None

# Return the input unchanged
def identity(t, *args, **kwargs):
    return t

# Do nothing
def noop(*args, **kwargs):
    pass

# Find the index of the first element satisfying a condition
def find_index(arr, cond):
    for ind, el in enumerate(arr):
        if cond(el):
            return ind
    return None

# Find and pop the first element satisfying a condition
def find_and_pop(arr, cond, default = None):
    ind = find_index(arr, cond)

    if exists(ind):
        return arr.pop(ind)

    if callable(default):
        return default()

    return default

# Yield batches from a dataloader indefinitely
def cycle(dl):
    while True:
        for data in dl:
            yield data

# Cast a value to a tuple
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# Ask the user a yes/no question
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# Accumulate values into a running log
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# Cast a value to a pair
def pair(val):
    return val if isinstance(val, tuple) else (val, val)

# Convert an image to the given mode if necessary
def convert_image_to_fn(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# Image-related helpers and dataset

# Image dataset class
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# Main trainer class

# VQGanVAETrainer, type-checked with beartype
@beartype
class VQGanVAETrainer(nn.Module):
    def __init__(
        self,
        vae: VQGanVAE,
        *,
        folder,
        num_train_steps,
        batch_size,
        image_size,
        lr = 3e-4,
        grad_accum_every = 1,
        max_grad_norm = None,
        discr_max_grad_norm = None,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        use_ema = True,
        ema_beta = 0.995,
        ema_update_after_step = 0,
        ema_update_every = 1,
        apply_grad_penalty_every = 4,
        accelerate_kwargs: dict = dict()
        ):
        super().__init__()

        # prepare the accelerator kwargs handlers
        kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', [])

        # find and pop any DistributedDataParallelKwargs object, creating a default one if absent
        ddp_kwargs = find_and_pop(
            kwargs_handlers,
            lambda x: isinstance(x, DistributedDataParallelKwargs),
            partial(DistributedDataParallelKwargs, find_unused_parameters = True)
        )

        # force find_unused_parameters to True
        ddp_kwargs.find_unused_parameters = True
        kwargs_handlers.append(ddp_kwargs)
        accelerate_kwargs.update(kwargs_handlers = kwargs_handlers)

        # instantiate the accelerator
        self.accelerator = Accelerator(**accelerate_kwargs)

        # the VAE being trained
        self.vae = vae

        # training hyperparameters
        self.register_buffer('steps', torch.Tensor([0]))
        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # separate the discriminator parameters from the VAE parameters
        all_parameters = set(vae.parameters())
        discr_parameters = set(vae.discr.parameters())
        vae_parameters = all_parameters - discr_parameters
        self.vae_parameters = vae_parameters

        # optimizers
        self.optim = Adam(vae_parameters, lr = lr)
        self.discr_optim = Adam(discr_parameters, lr = lr)
        self.max_grad_norm = max_grad_norm
        self.discr_max_grad_norm = discr_max_grad_norm

        # create the dataset
        self.ds = ImageDataset(folder, image_size)

        # split off a validation set
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # create the dataloaders
        self.dl = DataLoader(
            self.ds,
            batch_size = batch_size,
            shuffle = True
        )

        self.valid_dl = DataLoader(
            self.valid_ds,
            batch_size = batch_size,
            shuffle = True
        )

        # have the accelerator prepare the model, optimizers and dataloaders
        (
            self.vae,
            self.optim,
            self.discr_optim,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.vae,
            self.optim,
            self.discr_optim,
            self.dl,
            self.valid_dl
        )

        # whether to keep an exponential moving average of the VAE
        self.use_ema = use_ema

        # if using EMA, create the EMA wrapper and prepare it with the accelerator
        if use_ema:
            self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
            self.ema_vae = self.accelerator.prepare(self.ema_vae)

        # infinite dataloader iterators
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        # how often to save the model and the results
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        # how often to apply the gradient penalty
        self.apply_grad_penalty_every = apply_grad_penalty_every

        # results folder
        self.results_folder = Path(results_folder)

        # if the results folder is not empty, ask whether to clear previous checkpoints and results
        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        # create the results folder
        self.results_folder.mkdir(parents = True, exist_ok = True)

    # save the model
    def save(self, path):
        # only the local main process saves
        if not self.accelerator.is_local_main_process:
            return

        # save the model parameters and the optimizer state dicts
        pkg = dict(
            model = self.accelerator.get_state_dict(self.vae),
            optim = self.optim.state_dict(),
            discr_optim = self.discr_optim.state_dict()
        )
        torch.save(pkg, path)
    # load the model parameters and optimizer state
    def load(self, path):
        # convert to a Path object
        path = Path(path)
        # the path must exist
        assert path.exists()
        # load the saved package
        pkg = torch.load(path)

        # get the unwrapped VAE model
        vae = self.accelerator.unwrap_model(self.vae)
        # load the model parameters
        vae.load_state_dict(pkg['model'])

        # load the optimizer states
        self.optim.load_state_dict(pkg['optim'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])

    # print a message via the accelerator
    def print(self, msg):
        self.accelerator.print(msg)

    # the device used by the accelerator
    @property
    def device(self):
        return self.accelerator.device

    # whether training is distributed
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process
    # a single training step
    def train_step(self):
        # device
        device = self.device

        # current step count
        steps = int(self.steps.item())
        # whether to apply the gradient penalty this step
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

        # put the VAE in training mode
        self.vae.train()
        # get the discriminator (unwrapping the DDP module if distributed)
        discr = self.vae.module.discr if self.is_distributed else self.vae.discr
        # if using EMA, get the EMA VAE (unwrapping if distributed)
        if self.use_ema:
            ema_vae = self.ema_vae.module if self.is_distributed else self.ema_vae

        # logs for this step
        logs = {}

        # update the VAE (generator)

        # accumulate gradients over grad_accum_every sub-batches
        for _ in range(self.grad_accum_every):
            # fetch the next batch of images
            img = next(self.dl_iter)
            img = img.to(device)

            # compute the loss under autocast (mixed precision)
            with self.accelerator.autocast():
                # VAE loss
                loss = self.vae(
                    img,
                    add_gradient_penalty = apply_grad_penalty,
                    return_loss = True
                )

            # backpropagate
            self.accelerator.backward(loss / self.grad_accum_every)

            # accumulate the loss into the logs
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # clip gradients if a max norm is set
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.vae.parameters(), self.max_grad_norm)

        # optimizer step
        self.optim.step()
        self.optim.zero_grad()

        # update the discriminator

        if exists(discr):
            self.discr_optim.zero_grad()

            for _ in range(self.grad_accum_every):
                img = next(self.dl_iter)
                img = img.to(device)

                loss = self.vae(img, return_discr_loss = True)

                self.accelerator.backward(loss / self.grad_accum_every)

                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})

            if exists(self.discr_max_grad_norm):
                self.accelerator.clip_grad_norm_(discr.parameters(), self.discr_max_grad_norm)

            self.discr_optim.step()

            # log the losses
            self.print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")

        # update the exponential moving average of the generator

        if self.use_ema:
            ema_vae.update()

        # periodically sample reconstructions

        if not (steps % self.save_results_every):
            vaes_to_evaluate = ((self.vae, str(steps)),)

            if self.use_ema:
                vaes_to_evaluate = ((ema_vae.ema_model, f'{steps}.ema'),) + vaes_to_evaluate

            for model, filename in vaes_to_evaluate:
                model.eval()

                valid_data = next(self.valid_dl_iter)
                valid_data = valid_data.to(device)

                recons = model(valid_data, return_recons = True)

                # save a grid of originals and reconstructions

                imgs_and_recons = torch.stack((valid_data, recons), dim = 0)
                imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')

                imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
                grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))

                logs['reconstructions'] = grid

                save_image(grid, str(self.results_folder / f'{filename}.png'))

            self.print(f'{steps}: saving to {str(self.results_folder)}')

        # periodically save the model
        self.accelerator.wait_for_everyone()
        if self.is_main and not (steps % self.save_model_every):
            state_dict = self.accelerator.unwrap_model(self.vae).state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.pt')
            self.accelerator.save(state_dict, model_path)

            if self.use_ema:
                ema_state_dict = self.accelerator.unwrap_model(self.ema_vae).state_dict()
                model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
                self.accelerator.save(ema_state_dict, model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        # increment the step counter and return the logs
        self.steps += 1
        return logs
    # training loop; accepts a logging function (defaults to a no-op)
    def train(self, log_fn = noop):
        # device of the VAE parameters
        device = next(self.vae.parameters()).device

        # keep stepping until the target number of training steps is reached
        while self.steps < self.num_train_steps:
            # run one training step and collect the logs
            logs = self.train_step()
            # pass the logs to the logging function
            log_fn(logs)

        # training complete
        self.print('training complete')
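A minimal, hypothetical usage sketch for the trainer above, mirroring the repository README; the image folder path and hyperparameters are placeholders, and both classes are assumed to be exported from the package's top-level __init__.

from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer

vae = VQGanVAE(
    dim = 256,
    codebook_size = 65536
)

trainer = VQGanVAETrainer(
    vae,
    folder = './path/to/images',   # placeholder folder of training images
    image_size = 128,
    batch_size = 4,
    grad_accum_every = 8,
    num_train_steps = 50000
)

trainer.train()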

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\vqgan_vae.py

# Import required modules
from pathlib import Path
import copy
import math
from math import sqrt
from functools import partial, wraps

# Vector quantization modules
from vector_quantize_pytorch import VectorQuantize as VQ, LFQ

# PyTorch
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad

# torchvision
import torchvision

# einops
from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange

# Constants
MList = nn.ModuleList

# Helper functions

# Check whether a value exists
def exists(val):
    return val is not None

# Return a default if the value does not exist
def default(val, d):
    return val if exists(val) else d

# Decorators

# Run a model method in eval mode, restoring the previous training mode afterwards
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# Temporarily remove the VGG attribute while calling a method (e.g. when saving state dicts)
def remove_vgg(fn):
    @wraps(fn)
    def inner(self, *args, **kwargs):
        has_vgg = hasattr(self, '_vgg')
        if has_vgg:
            vgg = self._vgg
            delattr(self, '_vgg')

        out = fn(self, *args, **kwargs)

        if has_vgg:
            self._vgg = vgg

        return out
    return inner

# Keyword argument helpers

# Pop the given keys from a dict and return them as a new dict
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

# Split a dict into two dicts based on a key condition
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# Whether a string begins with the given prefix
def string_begins_with(prefix, string_input):
    return string_input.startswith(prefix)

# Split a dict into keys with and without the given prefix
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# Split a dict by prefix and strip the prefix from the matching keys
def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
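# For example, groupby_prefix_and_trim('vq_', {'vq_decay': 0.8, 'channels': 3}) returns
# ({'decay': 0.8}, {'channels': 3}): prefixed keys have the prefix stripped and are split out,
# while the remaining keys are passed through untouched.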

# Tensor helpers

# Numerically stable log
def log(t, eps = 1e-10):
    return torch.log(t + eps)

# Gradient penalty
def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]

    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones(output.size(), device = images.device),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# Leaky ReLU with configurable negative slope
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(p)

# Division clamped away from zero
def safe_div(numer, denom, eps = 1e-8):
    return numer / denom.clamp(min = eps)

# GAN losses

# Hinge discriminator loss
def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

# Hinge generator loss
def hinge_gen_loss(fake):
    return -fake.mean()

# BCE discriminator loss
def bce_discr_loss(fake, real):
    return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()

# BCE generator loss
def bce_gen_loss(fake):
    return -log(torch.sigmoid(fake)).mean()

# Gradient of a loss with respect to a given layer
def grad_layer_wrt_loss(loss, layer):
    return torch_grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0].detach()

# VQGAN-VAE

# Channel-wise layer norm
class LayerNormChan(nn.Module):
    def __init__(
        self,
        dim,
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * var.clamp(min = self.eps).rsqrt() * self.gamma

# Discriminator: a simple convolutional network
class Discriminator(nn.Module):
    def __init__(
        self,
        dims,
        channels = 3,
        groups = 16,
        init_kernel_size = 5
    ):
        super().__init__()
        # pair up consecutive dimensions
        dim_pairs = zip(dims[:-1], dims[1:])

        # first layer: a convolution plus activation
        self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])

        # intermediate layers: strided convolution, group norm and activation
        for dim_in, dim_out in dim_pairs:
            self.layers.append(nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
                nn.GroupNorm(groups, dim_out),
                leaky_relu()
            ))

        # final dimension
        dim = dims[-1]
        # output head: two convolutions returning a 5 x 5 map, for PatchGAN-esque training
        self.to_logits = nn.Sequential(
            nn.Conv2d(dim, dim, 1),
            leaky_relu(),
            nn.Conv2d(dim, 1, 4)
        )

    # forward pass: run the input through every layer, then the output head
    def forward(self, x):
        for net in self.layers:
            x = net(x)

        return self.to_logits(x)
# ResnetEncDec: a ResNet encoder / decoder
class ResnetEncDec(nn.Module):
    def __init__(
        self,
        dim,
        *,
        channels = 3,
        layers = 4,
        layer_mults = None,
        num_resnet_blocks = 1,
        resnet_groups = 16,
        first_conv_kernel_size = 5
    ):
        super().__init__()
        # the dimension must be divisible by resnet_groups (groups for the group norm)
        assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'

        self.layers = layers

        # encoders and decoders as module lists
        self.encoders = MList([])
        self.decoders = MList([])

        # default layer multipliers: powers of two
        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
        # there must be one multiplier per layer
        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'

        # dimensions per layer
        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        # dimension of the encoded representation
        self.encoded_dim = dims[-1]

        # input / output dimensions per layer
        dim_pairs = zip(dims[:-1], dims[1:])

        # helpers for appending / prepending
        append = lambda arr, t: arr.append(t)
        prepend = lambda arr, t: arr.insert(0, t)

        # if num_resnet_blocks is not a tuple, only the last layer gets resnet blocks
        if not isinstance(num_resnet_blocks, tuple):
            num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)

        # there must be one resnet block count per layer
        assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'

        # build the encoder and decoder layer by layer
        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks in zip(range(layers), dim_pairs, num_resnet_blocks):
            # downsampling convolution plus activation for the encoder
            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
            # upsampling transposed convolution plus activation for the decoder
            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))

            # resnet blocks for the encoder, GLU resnet blocks for the decoder
            for _ in range(layer_num_resnet_blocks):
                append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
                prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))

        # first convolution of the encoder
        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
        # final convolution of the decoder
        append(self.decoders, nn.Conv2d(dim, channels, 1))

    # size of the encoded feature map for a given image size
    def get_encoded_fmap_size(self, image_size):
        return image_size // (2 ** self.layers)

    # weight of the last decoder layer
    @property
    def last_dec_layer(self):
        return self.decoders[-1].weight

    # encode
    def encode(self, x):
        for enc in self.encoders:
            x = enc(x)
        return x

    # decode
    def decode(self, x):
        for dec in self.decoders:
            x = dec(x)
        return x

# GLUResBlock: a residual block with GLU activations
class GLUResBlock(nn.Module):
    def __init__(self, chan, groups = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan * 2, 3, padding = 1),
            nn.GLU(dim = 1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan * 2, 3, padding = 1),
            nn.GLU(dim = 1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan, 1)
        )

    # residual forward pass
    def forward(self, x):
        return self.net(x) + x

# ResBlock: a standard residual block
class ResBlock(nn.Module):
    def __init__(self, chan, groups = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 1)
        )

    # residual forward pass
    def forward(self, x):
        return self.net(x) + x

# 定义 VQGanVAE 类,继承自 nn.Module
class VQGanVAE(nn.Module):
    # 初始化函数,设置模型的各种参数
    def __init__(
        self,
        *,
        dim,  # 模型的维度
        channels = 3,  # 输入图像的通道数,默认为3
        layers = 4,  # 模型的层数,默认为4
        l2_recon_loss = False,  # 是否使用L2重构损失,默认为False
        use_hinge_loss = True,  # 是否使用hinge loss,默认为True
        vgg = None,  # VGG模型,默认为None
        lookup_free_quantization = True,  # 是否使用无查找表的量化,默认为True
        codebook_size = 65536,  # 量化码书的大小,默认为65536
        vq_kwargs: dict = dict(  # VQ模型的参数,默认为一些参数设置
            codebook_dim = 256,
            decay = 0.8,
            commitment_weight = 1.,
            kmeans_init = True,
            use_cosine_sim = True,
        ),
        lfq_kwargs: dict = dict(  # LFQ模型的参数,默认为一些参数设置
            diversity_gamma = 4.
        ),
        use_vgg_and_gan = True,  # 是否使用VGG和GAN,默认为True
        discr_layers = 4,  # 判别器的层数,默认为4
        **kwargs  # 其他参数
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 将参数按照前缀分组并修剪
        vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
        encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)

        # 设置模型的一些属性
        self.channels = channels
        self.codebook_size = codebook_size
        self.dim_divisor = 2 ** layers

        enc_dec_klass = ResnetEncDec

        # 创建编码器解码器对象
        self.enc_dec = enc_dec_klass(
            dim = dim,
            channels = channels,
            layers = layers,
            **encdec_kwargs
        )

        self.lookup_free_quantization = lookup_free_quantization

        # 根据是否使用无查找表的量化选择量化器类型
        if lookup_free_quantization:
            self.quantizer = LFQ(
                dim = self.enc_dec.encoded_dim,
                codebook_size = codebook_size,
                **lfq_kwargs
            )
        else:
            self.quantizer = VQ(
                dim = self.enc_dec.encoded_dim,
                codebook_size = codebook_size,
                accept_image_fmap = True,
                **vq_kwargs
            )

        # 重构损失函数选择
        self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss

        # 如果是灰度图像,则关闭GAN和感知损失
        self._vgg = None
        self.discr = None
        self.use_vgg_and_gan = use_vgg_and_gan

        if not use_vgg_and_gan:
            return

        # 感知损失
        if exists(vgg):
            self._vgg = vgg

        # GAN相关损失
        layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        self.discr = Discriminator(dims = dims, channels = channels)

        self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
        self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss

    # 获取设备信息
    @property
    def device(self):
        return next(self.parameters()).device

    # 获取VGG模型
    @property
    def vgg(self):
        if exists(self._vgg):
            return self._vgg

        vgg = torchvision.models.vgg16(pretrained = True)
        vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
        self._vgg = vgg.to(self.device)
        return self._vgg

    # 获取编码后的维度
    @property
    def encoded_dim(self):
        return self.enc_dec.encoded_dim

    # 获取编码特征图的大小
    def get_encoded_fmap_size(self, image_size):
        return self.enc_dec.get_encoded_fmap_size(image_size)

    # 复制模型用于评估
    def copy_for_eval(self):
        device = next(self.parameters()).device
        vae_copy = copy.deepcopy(self.cpu())

        if vae_copy.use_vgg_and_gan:
            del vae_copy.discr
            del vae_copy._vgg

        vae_copy.eval()
        return vae_copy.to(device)

    # 获取模型状态字典
    @remove_vgg
    def state_dict(self, *args, **kwargs):
        return super().state_dict(*args, **kwargs)

    # 加载模型状态字典
    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        return super().load_state_dict(*args, **kwargs)

    # 保存模型
    def save(self, path):
        torch.save(self.state_dict(), path)

    # 加载模型
    def load(self, path):
        path = Path(path)
        assert path.exists()
        state_dict = torch.load(str(path))
        self.load_state_dict(state_dict)

    # 编码函数
    def encode(self, fmap):
        fmap = self.enc_dec.encode(fmap)
        fmap, indices, vq_aux_loss = self.quantizer(fmap)
        return fmap, indices, vq_aux_loss
    # 从编码后的 ids 解码生成图像
    def decode_from_ids(self, ids):
        
        # 如果使用无查找表量化 (LFQ),先将 ids 展平成 (b, n) 形状
        if self.lookup_free_quantization:
            ids, ps = pack([ids], 'b *')
            # 使用量化器的 indices_to_codes 将离散索引还原为对应的 codes
            fmap = self.quantizer.indices_to_codes(ids)
            # 恢复原来的空间形状
            fmap, = unpack(fmap, ps, 'b * c')
        else:
            # 根据 ids 获取 codebook 中对应的 codes
            codes = self.codebook[ids]
            # 投影 codes 生成 fmap
            fmap = self.quantizer.project_out(codes)

        # 重新排列 fmap 的维度
        fmap = rearrange(fmap, 'b h w c -> b c h w')
        # 调用 decode 方法生成图像
        return self.decode(fmap)

    # 解码生成图像
    def decode(self, fmap):
        return self.enc_dec.decode(fmap)

    # 前向传播函数
    def forward(
        self,
        img,
        return_loss = False,
        return_discr_loss = False,
        return_recons = False,
        add_gradient_penalty = True
    ):
        # 获取图像的批次、通道数、高度、宽度和设备信息
        batch, channels, height, width, device = *img.shape, img.device

        # 检查高度和宽度是否能被 dim_divisor 整除
        for dim_name, size in (('height', height), ('width', width)):
            assert (size % self.dim_divisor) == 0, f'{dim_name} must be divisible by {self.dim_divisor}'

        # 检查通道数是否与 VQGanVAE 中设置的通道数相等
        assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'

        # 编码输入图像
        fmap, indices, commit_loss = self.encode(img)

        # 解码生成图像
        fmap = self.decode(fmap)

        # 如果不需要返回损失,则直接返回生成图像
        if not return_loss and not return_discr_loss:
            return fmap

        # 确保只返回自编码器损失或鉴别器损失
        assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'

        # 如果需要返回鉴别器损失
        if return_discr_loss:
            assert exists(self.discr), 'discriminator must exist to train it'

            # 分离 fmap 的梯度
            fmap.detach_()
            img.requires_grad_()

            # 获取 fmap 和输入图像的鉴别器 logits
            fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))

            # 计算鉴别器损失
            discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
            loss = discr_loss

            # 如果需要,添加梯度惩罚(否则直接使用鉴别器损失,避免 loss 未定义)
            if add_gradient_penalty:
                gp = gradient_penalty(img, img_discr_logits)
                loss = discr_loss + gp

            # 如果需要返回重构图像,则返回损失和 fmap
            if return_recons:
                return loss, fmap

            return loss

        # 计算重构损失
        recon_loss = self.recon_loss_fn(fmap, img)

        # 如果不使用 VGG 和 GAN,则直接返回重构损失
        if not self.use_vgg_and_gan:
            if return_recons:
                return recon_loss, fmap

            return recon_loss

        # 计算感知损失
        img_vgg_input = img
        fmap_vgg_input = fmap

        if img.shape[1] == 1:
            # 处理灰度图像用于 VGG
            img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))

        img_vgg_feats = self.vgg(img_vgg_input)
        recon_vgg_feats = self.vgg(fmap_vgg_input)
        perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)

        # 生成器损失
        gen_loss = self.gen_loss(self.discr(fmap))

        # 计算自适应权重
        last_dec_layer = self.enc_dec.last_dec_layer

        norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
        norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

        adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
        adaptive_weight.clamp_(max = 1e4)

        # 组合损失
        loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss

        # 如果需要返回重构图像,则返回损失和 fmap
        if return_recons:
            return loss, fmap

        return loss
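
# 使用示意(非源码,仅作演示;假设已安装 vector-quantize-pytorch 等依赖):
#
#   vae = VQGanVAE(dim = 128, codebook_size = 65536, use_vgg_and_gan = False)
#   img = torch.randn(1, 3, 64, 64)        # 高宽需能被 2**layers = 16 整除
#   recon = vae(img)                       # 仅返回重构结果
#   loss = vae(img, return_loss = True)    # use_vgg_and_gan = False 时返回重构损失
#   loss.backward()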

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\__init__.py

# 从 muse_maskgit_pytorch.vqgan_vae 模块中导入 VQGanVAE 类
from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
# 从 muse_maskgit_pytorch.muse_maskgit_pytorch 模块中导入 Transformer, MaskGit, Muse, MaskGitTransformer, TokenCritic 类
from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer, MaskGit, Muse, MaskGitTransformer, TokenCritic
# 从 muse_maskgit_pytorch.trainers 模块中导入 VQGanVAETrainer 类
from muse_maskgit_pytorch.trainers import VQGanVAETrainer

Muse - Pytorch

Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch

Please join us on Discord if you are interested in helping out with the replication with the LAION community

Install

$ pip install muse-maskgit-pytorch

Usage

First train your VAE - VQGanVAE

import torch
from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer

vae = VQGanVAE(
    dim = 256,
    codebook_size = 65536
)

# train on folder of images, as many images as possible

trainer = VQGanVAETrainer(
    vae = vae,
    image_size = 128,             # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it
    folder = '/path/to/images',
    batch_size = 4,
    grad_accum_every = 8,
    num_train_steps = 50000
).cuda()

trainer.train()

Then pass the trained VQGanVAE and a Transformer to MaskGit

import torch
from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer

# first instantiate your vae

vae = VQGanVAE(
    dim = 256,
    codebook_size = 65536
).cuda()

vae.load('/path/to/vae.pt') # you will want to load the exponentially moving averaged VAE

# then you plug the vae and transformer into your MaskGit as so

# (1) create your transformer / attention network

transformer = MaskGitTransformer(
    num_tokens = 65536,       # must be same as codebook size above
    seq_len = 256,            # must be equivalent to fmap_size ** 2 in vae
    dim = 512,                # model dimension
    depth = 8,                # depth
    dim_head = 64,            # attention head dimension
    heads = 8,                # attention heads,
    ff_mult = 4,              # feedforward expansion factor
    t5_name = 't5-small',     # name of your T5
)

# (2) pass your trained VAE and the base transformer to MaskGit

base_maskgit = MaskGit(
    vae = vae,                 # vqgan vae
    transformer = transformer, # transformer
    image_size = 256,          # image size
    cond_drop_prob = 0.25,     # conditional dropout, for classifier free guidance
).cuda()

# ready your training text and images

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 256, 256).cuda()

# feed it into your maskgit instance, with return_loss set to True

loss = base_maskgit(
    images,
    texts = texts
)

loss.backward()

# do this for a long time on much data
# then...

images = base_maskgit.generate(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.) # conditioning scale for classifier free guidance

images.shape # (3, 3, 256, 256)

Training the super-resolution maskgit requires you to change 1 field on MaskGit instantiation (you will now need to pass in cond_image_size, the previous image size being conditioned on)

Optionally, you can pass in a different VAE as cond_vae for the conditioning low-resolution image. By default it will use the same vae for tokenizing both the super and low resolution images.

import torch
import torch.nn.functional as F
from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer

# first instantiate your ViT VQGan VAE
# a VQGan VAE made of transformers

vae = VQGanVAE(
    dim = 256,
    codebook_size = 65536
).cuda()

vae.load('./path/to/vae.pt') # you will want to load the exponentially moving averaged VAE

# then you plug the VqGan VAE into your MaskGit as so

# (1) create your transformer / attention network

transformer = MaskGitTransformer(
    num_tokens = 65536,       # must be same as codebook size above
    seq_len = 1024,           # must be equivalent to fmap_size ** 2 in vae
    dim = 512,                # model dimension
    depth = 2,                # depth
    dim_head = 64,            # attention head dimension
    heads = 8,                # attention heads,
    ff_mult = 4,              # feedforward expansion factor
    t5_name = 't5-small',     # name of your T5
)

# (2) pass your trained VAE and the base transformer to MaskGit

superres_maskgit = MaskGit(
    vae = vae,
    transformer = transformer,
    cond_drop_prob = 0.25,
    image_size = 512,                     # larger image size
    cond_image_size = 256,                # conditioning image size <- this must be set
).cuda()

# ready your training text and images

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 512, 512).cuda()

# feed it into your maskgit instance, with return_loss set to True

loss = superres_maskgit(
    images,
    texts = texts
)

loss.backward()

# do this for a long time on much data
# then...

images = superres_maskgit.generate(
    texts = [
        'a whale breaching from afar',
        'young girl blowing out candles on her birthday cake',
        'fireworks with blue and green sparkles',
        'waking up to a psychedelic landscape'
    ],
    cond_images = F.interpolate(images, 256),  # conditioning images must be passed in for generating from superres
    cond_scale = 3.
)

images.shape # (4, 3, 512, 512)

All together now

from muse_maskgit_pytorch import Muse

base_maskgit.load('./path/to/base.pt')

superres_maskgit.load('./path/to/superres.pt')

# pass in the trained base_maskgit and superres_maskgit from above

muse = Muse(
    base = base_maskgit,
    superres = superres_maskgit
)

images = muse([
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'waking up to a psychedelic landscape'
])

images # List[PIL.Image.Image]

Appreciation

  • StabilityAI for the sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • 🤗 Huggingface for the transformers and accelerate library, both of which are wonderful

Todo

Citations

@inproceedings{Chang2023MuseTG,
    title   = {Muse: Text-To-Image Generation via Masked Generative Transformers},
    author  = {Huiwen Chang and Han Zhang and Jarred Barber and AJ Maschinot and Jos{\'e} Lezama and Lu Jiang and Ming-Hsuan Yang and Kevin P. Murphy and William T. Freeman and Michael Rubinstein and Yuanzhen Li and Dilip Krishnan},
    year    = {2023}
}
@article{Chen2022AnalogBG,
    title   = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
    author  = {Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.04202}
}
@misc{jabri2022scalable,
    title   = {Scalable Adaptive Computation for Iterative Generation},
    author  = {Allan Jabri and David Fleet and Ting Chen},
    year    = {2022},
    eprint  = {2212.11972},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@article{Lezama2022ImprovedMI,
    title   = {Improved Masked Image Generation with Token-Critic},
    author  = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.04439}
}
@inproceedings{Nijkamp2021SCRIPTSP,
    title   = {SCRIPT: Self-Critic PreTraining of Transformers},
    author  = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
    booktitle = {North American Chapter of the Association for Computational Linguistics},
    year    = {2021}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@misc{mentzer2023finite,
    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},
    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
    year    = {2023},
    eprint  = {2309.15505},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\muse-maskgit-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的信息
setup(
  # 包的名称
  name = 'muse-maskgit-pytorch',
  # 查找所有包,不排除任何包
  packages = find_packages(exclude=[]),
  # 版本号
  version = '0.3.5',
  # 许可证类型
  license='MIT',
  # 描述信息
  description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 项目链接
  url = 'https://github.com/lucidrains/muse-maskgit-pytorch',
  # 关键词列表
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'text-to-image'
  ],
  # 安装依赖列表
  install_requires=[
    'accelerate',
    'beartype',
    'einops>=0.7',
    'ema-pytorch>=0.2.2',
    'memory-efficient-attention-pytorch>=0.1.4',
    'pillow',
    'sentencepiece',
    'torch>=1.6',
    'transformers',
    'torchvision',
    'tqdm',
    'vector-quantize-pytorch>=1.11.8'
  ],
  # 分类列表
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\musiclm-pytorch\musiclm_pytorch\distributed.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.autograd 模块中导入 Function 类
from torch.autograd import Function
# 导入 torch.distributed 并命名为 dist
import torch.distributed as dist
# 导入 torch.nn.functional 并命名为 F(下面的 pad_dim_to 会用到)
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange

# 分布式辅助函数

# 基础辅助函数(下文的 all_gather_variable_dim 依赖 exists 与 pad_dim_to,此处补全以保证代码可独立运行)

# 判断值是否存在
def exists(val):
    return val is not None

# 将张量在指定维度上用零填充到给定长度
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# 定义一个函数,用于在所有进程中收集具有相同维度的张量
def all_gather_same_dim(t):
    # 获取世界大小
    world_size = dist.get_world_size()
    # 创建一个空列表,用于存储收集到的张量
    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
    # 在所有进程中收集张量
    dist.all_gather(gathered_tensors, t)
    return gathered_tensors

# 定义一个函数,用于在所有进程中收集具有可变维度的张量
def all_gather_variable_dim(t, dim = 0, sizes = None):
    # 获取设备、进程编号和世界大小
    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

    # 如果 sizes 不存在
    if not exists(sizes):
        # 创建一个张量,表示张量在指定维度上的大小
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        # 在所有进程中收集大小信息
        sizes = all_gather_same_dim(size)
        sizes = torch.stack(sizes)

    # 如果所有进程收集到的大小信息都相同
    if torch.unique(sizes).numel() == 1:
        # 在所有进程中收集张量
        gathered_tensors = all_gather_same_dim(t)
        return torch.cat(gathered_tensors, dim = dim), sizes

    # 获取最大的大小
    max_size = sizes.amax().item()

    # 将张量在指定维度上填充到最大大小
    padded_t = pad_dim_to(t, max_size, dim = dim)
    # 在所有进程中收集填充后的张量
    gathered_tensors = all_gather_same_dim(padded_t)

    # 拼接所有进程中收集到的张量
    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
    # 创建一个序列
    seq = torch.arange(max_size, device = device)

    # 创建一个掩码,用于选择有效的数据
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    # 根据掩码选择有效的数据
    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes

# 定义一个自定义函数类 AllGatherFunction
class AllGatherFunction(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes, all_reduce_grads):
        # 调用 all_gather_variable_dim 函数
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.dim = dim
        ctx.all_reduce_grads = all_reduce_grads
        ctx.batch_sizes = batch_sizes.tolist()
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        # 获取批次大小和进程编号
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        # 如果需要对梯度进行全局归约
        if ctx.all_reduce_grads:
            dist.all_reduce(grads)

        # 根据批次大小拆分梯度
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None, None

# 定义一个类 AllGather,继承自 nn.Module
class AllGather(nn.Module):
    def __init__(
        self,
        dim,
        *,
        all_reduce_grads = False
    ):
        super().__init__()
        self.dim = dim
        self.all_reduce_grads = all_reduce_grads
        # 判断是否处于分布式环境中
        self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1

    def forward(
        self,
        x,
        sizes = None
    ):
        # 如果不处于分布式环境中,直接返回输入张量
        if not self.is_distributed:
            return x, None

        # 调用 AllGatherFunction 类的 apply 方法
        return AllGatherFunction.apply(x, self.dim, sizes, self.all_reduce_grads)
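
# 使用示意(非源码,仅作演示):在未初始化 torch.distributed 的单进程环境下,
# AllGather 会原样返回输入;多卡训练时则会沿指定维度收集所有进程的张量。
#
#   all_gather = AllGather(dim = 1)
#   x = torch.randn(2, 4, 128)
#   out, sizes = all_gather(x)   # 单进程下 out 即 x,sizes 为 None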

.\lucidrains\musiclm-pytorch\musiclm_pytorch\musiclm_pytorch.py

# 导入数学库
import math
# 从 functools 库中导入 wraps 和 partial 函数
from functools import wraps, partial

# 导入 torch 库
import torch
# 导入 torch.nn.functional 并命名为 F
import torch.nn.functional as F
# 从 torch 模块中导入 nn 和 einsum 函数
from torch import nn, einsum

# 从 torchaudio.transforms 模块中导入 Spectrogram, TimeStretch, FrequencyMasking, TimeMasking 函数
from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking

# 从 audiolm_pytorch 库中导入 AudioLM 类和 AudioConditionerBase 类
from audiolm_pytorch import AudioLM
from audiolm_pytorch.utils import AudioConditionerBase

# 导入 torch.distributed 并命名为 dist
import torch.distributed as dist
# 从 musiclm_pytorch.distributed 模块中导入 AllGather 函数
from musiclm_pytorch.distributed import AllGather

# 从 x_clip.tokenizer 模块中导入 tokenizer 函数
from x_clip.tokenizer import tokenizer
# 从 vector_quantize_pytorch 库中导入 ResidualVQ 类
from vector_quantize_pytorch import ResidualVQ

# 从 einops 模块中导入 rearrange, repeat, reduce, pack, unpack 函数
from einops import rearrange, repeat, reduce, pack, unpack
# 从 einops.layers.torch 模块中导入 Rearrange 函数
from einops.layers.torch import Rearrange

# 从 beartype.typing 模块中导入 List, Optional, Tuple 类
# 从 beartype 模块中导入 beartype 函数
from beartype.typing import List, Optional, Tuple
from beartype import beartype

# functions

# 定义函数 exists,判断值是否存在
def exists(val):
    return val is not None

# 定义函数 first,返回列表的第一个元素
def first(it):
    return it[0]

# 定义函数 default,如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义函数 round_down_nearest_multiple,返回最接近的整数倍数
def round_down_nearest_multiple(n, divisor):
    return n // divisor * divisor

# 定义函数 Sequential,返回过滤掉空值后的 nn.Sequential 对象
def Sequential(*modules):
    return nn.Sequential(*filter(exists, modules))

# decorators

# 定义装饰器 once,确保函数只调用一次
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# 使用 once 装饰 print 函数,确保只打印一次
print_once = once(print)

# tensor functions

# 定义函数 log,计算张量的对数
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# 定义函数 l2norm,计算张量的 L2 范数
def l2norm(t):
    return F.normalize(t, p = 2, dim = -1)

# 定义函数 matrix_diag,返回张量的对角线元素
def matrix_diag(t):
    device = t.device
    i, j = t.shape[-2:]
    num_diag_el = min(i, j)
    i_range = torch.arange(i, device = device)
    j_range = torch.arange(j, device = device)
    diag_mask = rearrange(i_range, 'i -> i 1') == rearrange(j_range, 'j -> 1 j')
    diag_el = t.masked_select(diag_mask)
    return rearrange(diag_el, '(b d) -> b d', d = num_diag_el)

# 2d sinusoidal positional embedding
# simple vit paper shows it is good enough compared to learned

# 定义函数 posemb_sincos_2d,生成二维正弦余弦位置嵌入
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :] 

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    pe = pe.type(dtype)

    return rearrange(pe, '(h w) d -> h w d', h = h, w = w)
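
# 使用示意(非源码,仅作演示):特征维度 dim 需为 4 的倍数
#
#   patches = torch.randn(1, 8, 8, 512)
#   pe = posemb_sincos_2d(patches)   # (8, 8, 512)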

# biasless layernorm

# 定义类 LayerNorm,实现无偏差的 LayerNorm
class LayerNorm(nn.Module):
    def __init__(self, dim, scale = True):
        super().__init__()
        self.learned_gamma = nn.Parameter(torch.ones(dim)) if scale else None

        self.register_buffer('gamma', torch.ones(dim), persistent = False)
        self.register_buffer('beta', torch.zeros(dim), persistent = False)

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], default(self.learned_gamma, self.gamma), self.beta)

# feedforward

# 定义类 GEGLU,实现 GEGLU 激活函数
class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

# 定义函数 FeedForward,实现前馈神经网络
def FeedForward(dim, mult = 4, dropout = 0.):
    dim_hidden = int(dim * mult * 2 / 3)

    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, dim_hidden * 2, bias = False),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim, bias = False)
    )

# attention

# 定义类 Attention,实现注意力机制
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        scale = 8
    ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化头数和缩放比例
        self.heads = heads
        self.scale = scale
        self.causal = causal
        # 计算每个头的内部维度
        inner_dim = dim_head * heads

        # 初始化 LayerNorm 层
        self.norm = LayerNorm(dim)

        # 初始化注意力机制的 dropout 层
        self.attn_dropout = nn.Dropout(dropout)

        # 初始化查询、键、值的线性变换层
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # 初始化查询和键的缩放参数
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # 初始化输出层
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        rel_pos_bias = None,
        mask = None
    ):
        # 获取输入张量 x 的形状和设备信息
        b, n, _, device = *x.shape, x.device

        # 对输入进行 LayerNorm 处理
        x = self.norm(x)

        # 对输入进行查询、键、值的线性变换
        q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)

        # 将查询、键、值分割为多头注意力
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # 对查询和键进行 L2 归一化
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # 计算相似度
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # 如果存在相对位置偏置,则加上
        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        # 如果存在掩码,则进行掩码处理
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # 如果启用因果注意力,则进行因果掩码处理
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # 注意力权重计算
        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # 聚合
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # 合并多头
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# transformer

# 定义 Transformer 类,用于实现 Transformer 模型
class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        # 循环创建 Transformer 层,并添加到 layers 中
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
            ]))

    # 前向传播函数
    def forward(
        self,
        x,
        rel_pos_bias = None,
        mask = None,
        return_all_layers = False
    ):
        layers = []

        # 遍历 Transformer 层,依次进行注意力计算和前馈计算
        for attn, ff in self.layers:
            x = attn(x, rel_pos_bias = rel_pos_bias, mask = mask) + x
            x = ff(x) + x
            layers.append(x)

        # 如果不需要返回所有层的结果,则返回最后一层的结果
        if not return_all_layers:
            return x

        # 返回所有层的结果
        return x, torch.stack(layers[:-1])
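
# 使用示意(非源码,仅作演示):
#
#   transformer = Transformer(dim = 512, depth = 2)
#   x = torch.randn(2, 64, 512)
#   out = transformer(x)                                    # (2, 64, 512)
#   out, layers = transformer(x, return_all_layers = True)  # layers 为除最后一层外各层输出的堆叠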

# contrastive losses

# 定义 SoftmaxContrastiveLearning 类,用于实现 Softmax 对比学习
class SoftmaxContrastiveLearning(nn.Module):
    def __init__(
        self,
        *,
        layers = 1,
        decoupled_contrastive_learning = False,
        init_temp = 10
    ):
        super().__init__()
        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
        self.decoupled_contrastive_learning = decoupled_contrastive_learning

        self.all_gather = AllGather(dim = 2)

    # 获取设备信息
    @property
    def device(self):
        return next(self.parameters()).device

    # 前向传播函数
    def forward(self, audio_latents, text_latents):
        if audio_latents.ndim == 2:
            audio_latents = rearrange(audio_latents, '... -> 1 ...')

        if text_latents.ndim == 2:
            text_latents = rearrange(text_latents, '... -> 1 ...')

        batch = audio_latents.shape[1]

        # 分布式环境下,先在所有进程间收集 (all gather) 两种 latents
        if self.all_gather.is_distributed:
            latents = torch.stack((audio_latents, text_latents))
            latents, _ = self.all_gather(latents)
            audio_latents, text_latents = latents

        # 计算相似度矩阵
        sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)

        sims = sims * self.temperatures.exp()

        cosine_sims_exp = sims.exp()

        numerator = matrix_diag(cosine_sims_exp)

        # 如果使用分离式对比学习,进行额外处理
        if self.decoupled_contrastive_learning:
            eye = torch.eye(batch, device = self.device, dtype = torch.bool)
            cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)

        denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
        denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')

        contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))

        contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
        return contrastive_loss.sum()

# 定义 SigmoidContrastiveLearning 类,用于实现 Sigmoid 对比学习
class SigmoidContrastiveLearning(nn.Module):
    """ https://arxiv.org/abs/2303.15343 """

    def __init__(
        self,
        *,
        layers = 1,
        init_temp = 10,
        init_bias = -10
    ):
        super().__init__()
        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
        self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)

        self.all_gather = AllGather(dim = 1, all_reduce_grads = True)

    # 获取设备信息
    @property
    def device(self):
        return next(self.parameters()).device
    # 定义一个前向传播函数,接受音频和文本的潜在表示作为输入
    def forward(self, audio_latents, text_latents):
        # 获取当前设备
        device = self.device

        # 如果音频潜在表示的维度为2,则重新排列为 '... -> 1 ...'
        if audio_latents.ndim == 2:
            audio_latents = rearrange(audio_latents, '... -> 1 ...')

        # 如果文本潜在表示的维度为2,则重新排列为 '... -> 1 ...'
        if text_latents.ndim == 2:
            text_latents = rearrange(text_latents, '... -> 1 ...')

        # 使用 all_gather 在所有进程间收集文本潜在表示,并返回收集后的结果和各进程的批大小
        text_latents, rank_sizes = self.all_gather(text_latents)

        # 获取文本潜在表示的第二维大小
        n = text_latents.shape[1]

        # 计算音频潜在表示和文本潜在表示之间的相似度
        sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)

        # 对相似度进行温度调节和偏置处理
        sims = sims * self.temperatures.exp() + self.bias

        # 创建一个对角线为1的标签矩阵
        labels = torch.eye(n, device=device)

        # 如果rank_sizes存在,则根据rank_sizes将标签拆分为不同的部分
        if exists(rank_sizes):
            labels_by_ranks = labels.split(rank_sizes.tolist(), dim=0)
            labels = labels_by_ranks[dist.get_rank()]

        # 将标签矩阵重新排列为 'i j -> 1 i j',并进行处理得到最终的标签
        labels = 2 * rearrange(labels, 'i j -> 1 i j') - torch.ones_like(sims)

        # 计算损失函数,返回负对数sigmoid损失的总和除以n
        return -F.logsigmoid(labels * sims).sum() / n
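
# 使用示意(非源码,仅作演示):两种对比损失的调用接口相同,
# 输入为已做 L2 归一化的音频 / 文本 latents,返回一个标量损失。
#
#   contrast = SoftmaxContrastiveLearning()
#   audio_latents = l2norm(torch.randn(4, 128))
#   text_latents  = l2norm(torch.randn(4, 128))
#   loss = contrast(audio_latents, text_latents)
#
#   sig_contrast = SigmoidContrastiveLearning()
#   sig_loss = sig_contrast(audio_latents, text_latents)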
# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

# 定义一个函数,用于将输入转换为元组
def pair(t):
    return (t, t) if not isinstance(t, tuple) else t

# 定义一个音频频谱变换器类
class AudioSpectrogramTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        patch_size = 16,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        accept_spec = False,
        accept_spec_time_first = True,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
        spec_hop_length = None,
        spec_pad = 0,
        spec_center = True,
        spec_pad_mode = 'reflect',
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80,
        patch_dropout_prob = 0.25
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth

        self.patch_size = pair(patch_size)
        patch_input_dim = self.patch_size[0] * self.patch_size[1]

        # 将输入转换为补丁令牌
        self.to_patch_tokens = Sequential(
            Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1 = self.patch_size[0], p2 = self.patch_size[1]),
            nn.LayerNorm(patch_input_dim),
            nn.Linear(patch_input_dim, dim),
            nn.LayerNorm(dim)
        )

        self.accept_spec = accept_spec
        self.accept_spec_time_first = accept_spec_time_first

        # 创建频谱对象
        self.spec = Spectrogram(
            n_fft = spec_n_fft,
            power = spec_power,
            win_length = spec_win_length,
            hop_length = spec_hop_length,
            pad = spec_pad,
            center = spec_center,
            pad_mode = spec_pad_mode
        )

        # SpecAugment - 在音频领域中被广泛使用
        self.aug = torch.nn.Sequential(
            TimeStretch(spec_aug_stretch_factor, fixed_rate = True),
            FrequencyMasking(freq_mask_param = spec_aug_freq_mask),
            TimeMasking(time_mask_param = spec_aug_time_mask),
        )

        # 创建变换器
        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout
        )

        self.norm = LayerNorm(dim)

        # 补丁丢弃概率
        self.patch_dropout_prob = patch_dropout_prob

        # 2D动态位置偏差
        mlp_hidden_dim = dim // 4

        self.dynamic_pos_bias_mlp = nn.Sequential(
            nn.Linear(2, mlp_hidden_dim),
            nn.SiLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.SiLU(),
            nn.Linear(mlp_hidden_dim, heads),
            Rearrange('... i j h -> ... h i j')
        )

    def forward(
        self,
        x,
        force_no_patch_dropout = False,
        return_all_layers = False
        ):
        # 获取输入张量的批次大小和设备信息
        batch, device = x.shape[0], x.device
        # 断言输入张量的维度是否符合要求
        assert (self.accept_spec and x.ndim == 3) or (not self.accept_spec and x.ndim == 2)

        if self.accept_spec and self.accept_spec_time_first:
            # 如果接受频谱数据且要求时间维度在前,则重新排列输入张量的维度
            x = rearrange(x, 'b t f -> b f t')

        if not self.accept_spec:
            # 如果不接受频谱数据,则对输入进行频谱转换
            x = self.spec(x)

        if self.training:
            # 如果处于训练模式,则对输入进行数据增强
            x = self.aug(x)

        # 如果音频生成的二维频谱图不是patch大小的整数倍,则自动裁剪
        height, width = x.shape[-2:]
        patch_height, patch_width = self.patch_size

        rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))

        if (height, width) != (rounded_height, rounded_width): # 只是持续打印直到修复
            print_once(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

        x = x[..., :rounded_height, :rounded_width]

        # 转换为patches
        x = self.to_patch_tokens(x)

        # 获取沿高度和宽度的patch数量
        _, num_patch_height, num_patch_width, _ = x.shape

        # 获取2D相对位置
        grid = torch.stack(torch.meshgrid(
            torch.arange(num_patch_height, device = device),
            torch.arange(num_patch_width, device = device)
        , indexing = 'ij'), dim = -1)

        grid = rearrange(grid, '... c -> (...) c')

        # 2D正弦余弦位置嵌入
        x = x + posemb_sincos_2d(x)

        x = rearrange(x, 'b ... c -> b (...) c')

        # patch丢弃
        if self.training and self.patch_dropout_prob > 0. and not force_no_patch_dropout:
            n, device = x.shape[1], x.device

            batch_indices = torch.arange(batch, device = device)
            batch_indices = rearrange(batch_indices, '... -> ... 1')
            num_patches_keep = max(1, int(n * (1 - self.patch_dropout_prob)))
            patch_indices_keep = torch.randn(batch, n, device = device).topk(num_patches_keep, dim = -1).indices

            x = x[batch_indices, patch_indices_keep]

            grid = repeat(grid, '... -> b ...', b = batch)
            grid = grid[batch_indices, patch_indices_keep]

        # 2D相对位置偏差
        rel_dist = rearrange(grid, '... i c -> ... i 1 c') - rearrange(grid, '... j c -> ... 1 j c')
        rel_pos_bias = self.dynamic_pos_bias_mlp(rel_dist.float())

        # 注意力机制
        x, all_layers = self.transformer(x, rel_pos_bias = rel_pos_bias, return_all_layers = True)

        # 最终全局平均和规范化(最近的论文表明这比CLS token更优越)
        x = reduce(x, 'b n d -> b d', 'mean')

        out = self.norm(x)

        if not return_all_layers:
            return out

        return out, all_layers
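
# 使用示意(非源码,仅作演示;参数取值仅为示例):
#
#   audio_transformer = AudioSpectrogramTransformer(dim = 512, depth = 6, heads = 8, dim_head = 64)
#   audio_transformer.eval()           # 推理时关闭 SpecAugment 与 patch dropout
#   wavs = torch.randn(2, 1024)        # 原始波形
#   embeds = audio_transformer(wavs)   # (2, 512)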
# 文本转换器类
class TextTransformer(nn.Module):
    # 初始化函数
    @beartype
    def __init__(
        self,
        dim,  # 维度
        depth,  # 深度
        num_tokens = tokenizer.vocab_size,  # 标记数量,默认为tokenizer的词汇量
        max_seq_len = 256,  # 最大序列长度,默认为256
        dim_head = 64,  # 头部维度,默认为64
        heads = 8,  # 头部数量,默认为8
        attn_dropout = 0.,  # 注意力丢弃率,默认为0
        ff_dropout = 0.,  # 前馈丢弃率,默认为0
        ff_mult = 4,  # 前馈倍数,默认为4
        pad_id = 0  # 填充标记ID,默认为0
    ):
        super().__init__()
        self.dim = dim  # 维度

        self.token_emb = nn.Embedding(num_tokens, dim)  # 标记嵌入层
        self.pos_emb = nn.Embedding(max_seq_len, dim)  # 位置嵌入层

        self.depth = depth  # 深度
        self.max_seq_len = max_seq_len  # 最大序列长度

        self.cls_token = nn.Parameter(torch.randn(dim))  # 类别标记

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )  # 转换器模型

        self.pad_id = pad_id  # 填充标记ID
        self.norm = LayerNorm(dim)  # 归一化层

    # 设备属性
    @property
    def device(self):
        return next(self.parameters()).device

    # 前向传播函数
    @beartype
    def forward(
        self,
        x = None,  # 输入张量,默认为None
        raw_texts: Optional[List[str]] = None,  # 原始文本列表,默认为None
        mask = None,  # 掩码,默认为None
        return_all_layers = False  # 是否返回所有层,默认为False
    ):
        assert exists(x) ^ exists(raw_texts)  # 断言,x和raw_texts必须有且只有一个存在

        if exists(raw_texts):
            x = tokenizer.tokenize(raw_texts).to(self.device)  # 使用tokenizer对原始文本进行标记化,并转移到指定设备

        if not exists(mask):
            mask = x != self.pad_id  # 生成掩码,排除填充标记

        b, n, device = *x.shape, x.device  # 获取张量形状和设备信息

        # 标记嵌入 + 位置嵌入
        x = self.token_emb(x)  # 标记嵌入
        assert n <= self.max_seq_len, f'text sequence length {n} must be less than {self.max_seq_len}'  # 断言,文本序列长度必须小于等于最大序列长度
        x = x + self.pos_emb(torch.arange(n, device = device))  # 加上位置嵌入

        # 类别标记,类似于bert
        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)  # 重复类别标记
        x, ps = pack([cls_tokens, x], 'b * d')  # 打包张量

        # 考虑使用自注意力掩码对类别标记进行注意力
        mask = F.pad(mask, (1, 0), value = True)  # 对掩码进行填充

        # 注意力
        x, all_layers = self.transformer(x, mask = mask, return_all_layers = True)  # 使用transformer进行注意力计算

        # 解包类别标记
        cls_tokens, _ = unpack(x, ps, 'b * d')  # 解包张量

        out = self.norm(cls_tokens)  # 归一化类别标记

        if not return_all_layers:
            return out  # 返回输出

        return out, all_layers  # 返回输出和所有层
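
# 使用示意(非源码,仅作演示):
#
#   text_transformer = TextTransformer(dim = 512, depth = 6, heads = 8, dim_head = 64)
#   texts = torch.randint(0, 20000, (2, 256))   # 已分词的文本,也可直接传 raw_texts
#   embeds = text_transformer(texts)            # (2, 512)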

# 分层对比损失
def interspersed_indices(layers, total_layers):
    assert total_layers >= layers  # 断言,总层数必须大于等于层数
    step = total_layers / layers  # 计算步长
    return (torch.arange(0, layers) * step).floor().long()  # 返回分散的索引
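
# 例如 interspersed_indices(2, 6) 返回 tensor([0, 3]),即从 6 层中等间隔选出 2 层用于分层对比损失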

# 多层对比损失类
class MultiLayerContrastiveLoss(nn.Module):
    def __init__(
        self,
        *,
        audio_dim,  # 音频维度
        text_dim,  # 文本维度
        dim_latent,  # 潜在维度
        layers,  # 层数
        decoupled_contrastive_learning = False,  # 是否解耦对比学习,默认为False
        sigmoid_contrastive_loss = False  # 是否使用sigmoid对比损失,默认为False
    ):
        super().__init__()
        self.layers = layers  # 层数

        self.audio_norm = LayerNorm(audio_dim, scale = False)  # 音频归一化层
        self.audio_gamma = nn.Parameter(torch.ones(layers, 1, audio_dim))  # 音频gamma参数
        self.audio_latent_weight = nn.Parameter(torch.randn(layers, audio_dim, dim_latent))  # 音频潜在权重
        self.audio_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent))  # 音频潜在偏置

        self.text_norm = LayerNorm(text_dim, scale = False)  # 文本归一化层
        self.text_gamma = nn.Parameter(torch.ones(layers, 1, text_dim))  # 文本gamma参数
        self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent))  # 文本潜在权重
        self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent))  # 文本潜在偏置

        klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)  # 根据sigmoid_contrastive_loss选择对比学习类
        self.contrast = klass(layers = layers)  # 对比学习实例化
    # 定义一个前向传播函数,接收音频和文本的特征层作为参数
    def forward(self, *, audio_layers, text_layers):
        # 获取设备和批次大小
        device, batch = audio_layers.device, audio_layers.shape[1]

        # 对音频特征层进行降维处理,计算平均值
        audio_gap = reduce(audio_layers, 'l b n d -> l b d', 'mean')
        # 对音频特征进行归一化处理,并乘以音频的缩放参数
        audio_embeds = self.audio_norm(audio_gap) * self.audio_gamma
        # 使用音频的权重和偏置计算音频的潜在特征
        audio_latents = einsum('l b d, l d e -> l b e', audio_embeds, self.audio_latent_weight) + self.audio_latent_bias
        # 对音频的潜在特征进行L2范数归一化处理
        audio_latents = l2norm(audio_latents)

        # 获取文本特征层中的分类标记
        text_cls_tokens = text_layers[:, :, 0]
        # 对文本特征进行归一化处理,并乘以文本的缩放参数
        text_embeds = self.text_norm(text_cls_tokens) * self.text_gamma
        # 使用文本的权重和偏置计算文本的潜在特征
        text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias
        # 对文本的潜在特征进行L2范数归一化处理
        text_latents = l2norm(text_latents)

        # 返回音频和文本潜在特征的对比结果
        return self.contrast(audio_latents, text_latents)
# 主要类

class MuLaN(nn.Module):
    # 初始化 MuLaN 类
    @beartype
    def __init__(
        self,
        audio_transformer: AudioSpectrogramTransformer,
        text_transformer: TextTransformer,
        dim_latent = 128,                       # 设置默认 latent 维度为 128
        decoupled_contrastive_learning = True,  # 是否使用 decoupled 对比学习,默认为 True
        hierarchical_contrastive_loss = False,  # 是否使用 hierarchical 对比损失,默认为 False
        hierarchical_contrastive_loss_layers = None,  # hierarchical 对比损失的层数,默认为 None
        sigmoid_contrastive_loss = False  # 是否使用 sigmoid 对比损失,默认为 False
    ):
        super().__init__()
        self.dim_latent = dim_latent

        self.audio = audio_transformer
        self.text = text_transformer

        # 将文本转换为 latent 向量
        self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
        # 将音频转换为 latent 向量
        self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)

        # 根据 sigmoid_contrastive_loss 的值选择对比学习方法
        klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
        self.contrast = klass()

        self.multi_layer_contrastive_learning = None

        # 如果启用 hierarchical 对比损失
        if hierarchical_contrastive_loss:
            # 计算层数
            num_layers = default(hierarchical_contrastive_loss_layers, min(audio_transformer.depth, text_transformer.depth) - 1)
            assert num_layers > 0

            # 注册文本层索引和音频层索引
            self.register_buffer('text_layers_indices', interspersed_indices(num_layers, text_transformer.depth))
            self.register_buffer('audio_layers_indices', interspersed_indices(num_layers, audio_transformer.depth))

            # 初始化多层对比损失
            self.multi_layer_contrastive_learning = MultiLayerContrastiveLoss(
                audio_dim = self.audio.dim,
                text_dim = self.text.dim,
                dim_latent = dim_latent,
                layers = num_layers,
                decoupled_contrastive_learning = decoupled_contrastive_learning,
                sigmoid_contrastive_loss = sigmoid_contrastive_loss
            )

    # 获取音频 latent 向量
    def get_audio_latents(
        self,
        wavs,
        return_all_layers = False
    ):
        # 获取音频嵌入和层信息
        audio_embeds, audio_layers = self.audio(wavs, return_all_layers = True)
        audio_latents = self.audio_to_latents(audio_embeds)
        out = l2norm(audio_latents)

        if not return_all_layers:
            return out

        return out, audio_layers

    # 获取文本 latent 向量
    @beartype
    def get_text_latents(
        self,
        texts = None,
        raw_texts: Optional[List[str]] = None,
        return_all_layers = False
    ):
        # 获取文本嵌入和层信息
        text_embeds, text_layers = self.text(texts, raw_texts = raw_texts, return_all_layers = True)
        text_latents = self.text_to_latents(text_embeds)
        out = l2norm(text_latents)

        if not return_all_layers:
            return out

        return out, text_layers

    # MuLaN 类的前向传播函数
    @beartype
    def forward(
        self,
        wavs,
        texts = None,
        raw_texts: Optional[List[str]] = None,
        return_latents = False,
        return_similarities = False,
        return_pairwise_similarities = False
    ):
        # 获取输入张量的批次大小和设备信息
        batch, device = wavs.shape[0], wavs.device

        # 获取音频的潜在空间表示和层表示
        audio_latents, audio_layers = self.get_audio_latents(wavs, return_all_layers=True)
        
        # 获取文本的潜在空间表示和层表示
        text_latents, text_layers = self.get_text_latents(texts, raw_texts=raw_texts, return_all_layers=True)

        # 如果需要返回潜在空间表示,则直接返回音频和文本的潜在空间表示
        if return_latents:
            return audio_latents, text_latents

        # 如果需要返回相似度,则计算音频和文本潜在空间表示之间的相似度
        if return_similarities:
            return einsum('i d, i d -> i', audio_latents, text_latents)

        # 如果需要返回成对相似度,则计算音频和文本潜在空间表示之间的余弦相似度矩阵
        if return_pairwise_similarities:
            cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)
            return cosine_sim

        # 计算对比损失
        cl_loss = self.contrast(audio_latents, text_latents)

        # 如果没有多层对比学习模块,则直接返回对比损失
        if not exists(self.multi_layer_contrastive_learning):
            return cl_loss

        # 从音频和文本层表示中选择指定索引的层
        audio_layers = audio_layers[self.audio_layers_indices]
        text_layers = text_layers[self.text_layers_indices]

        # 根据 ViCHA 论文中的建议,是否在所有层之间进行对比损失
        hierarchical_cl_loss = self.multi_layer_contrastive_learning(
            audio_layers=audio_layers,
            text_layers=text_layers
        )

        # 返回对比损失和多层对比学习损失的总和
        return cl_loss + hierarchical_cl_loss
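
# 使用示意(非源码,仅作演示;参考仓库 README 的训练用法,参数取值仅为示例):
#
#   audio_transformer = AudioSpectrogramTransformer(dim = 512, depth = 6)
#   text_transformer = TextTransformer(dim = 512, depth = 6)
#   mulan = MuLaN(audio_transformer = audio_transformer, text_transformer = text_transformer)
#
#   wavs = torch.randn(2, 1024)
#   texts = torch.randint(0, 20000, (2, 256))
#   loss = mulan(wavs, texts = texts)   # 对比损失
#   loss.backward()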
# 定义 MuLaNEmbedQuantizer 类,继承自 AudioConditionerBase 类
class MuLaNEmbedQuantizer(AudioConditionerBase):
    # 初始化函数
    @beartype
    def __init__(
        self,
        mulan: MuLaN,  # MuLaN 对象
        conditioning_dims: Tuple[int, ...],  # 条件维度元组
        rq_num_quantizers = 8,  # RQ 量化器数量,默认为 8
        rq_ema_decay = 0.9,  # RQ 指数移动平均衰减率,默认为 0.9
        codebook_size = 1024,  # 代码簿大小,默认为 1024
        namespaces: Tuple[str, ...] = ('semantic', 'coarse', 'fine'),  # 命名空间元组,默认包含 'semantic', 'coarse', 'fine'
    ):
        super().__init__()  # 调用父类的初始化函数
        self.mulan = mulan  # 初始化 MuLaN 对象

        assert len(namespaces) > 0  # 断言命名空间数量大于 0
        self.namespaces = namespaces  # 初始化命名空间
        self.conditioning_dims = conditioning_dims  # 初始化条件维度

        assert len(conditioning_dims) == len(namespaces), 'number of conditioning dimensions must be equal to number of namespaces'  # 断言条件维度数量等于命名空间数量

        dim = mulan.dim_latent  # 获取 MuLaN 对象的潜在维度

        # 初始化 RQ 对象
        self.rq = ResidualVQ(
            dim = dim,
            num_quantizers = rq_num_quantizers,
            codebook_size = codebook_size,
            decay = rq_ema_decay,
            commitment_weight = 0,    # 只使用 EMA 更新代码簿
            kmeans_init = True,
            threshold_ema_dead_code = 2,
            quantize_dropout = False  # 不使用量化丢弃
        )

        self.dim = dim  # 初始化维度
        self.num_codebooks = rq_num_quantizers  # 初始化代码簿数量

        self.cond_embeddings = nn.ParameterDict({})  # 初始化条件嵌入字典

        # 遍历命名空间和条件维度,初始化条件嵌入
        for namespace, conditioning_dim in zip(namespaces, conditioning_dims):
            cond_embeddings = nn.Parameter(torch.randn(rq_num_quantizers, codebook_size, conditioning_dim))
            nn.init.normal_(cond_embeddings, std = 0.02)

            self.cond_embeddings[namespace] = cond_embeddings

        self.set_default_namespace(namespaces[0])  # 设置默认命名空间为第一个命名空间

    # 返回参数
    def parameters(self):
        return self.cond_embeddings.parameters()

    # 设置默认命名空间
    def set_default_namespace(self, namespace):
        self._default_namespace = namespace

    # 前向传播函数
    def forward(
        self,
        wavs = None,  # 音频数据,默认为 None
        texts = None,  # 文本数据,默认为 None
        namespace = None  # 命名空间,默认为 None
    ):
        assert exists(wavs) ^ exists(texts)  # 断言音频数据或文本数据必须存在其中一个

        namespace = default(namespace, self._default_namespace)  # 获取命名空间,默认为默认命名空间
        assert namespace in self.namespaces, f'namespace {namespace} not found'  # 断言命名空间必须在命名空间列表中

        cond_embeddings = self.cond_embeddings[namespace]  # 获取对应命名空间的条件嵌入

        with torch.no_grad():  # 禁用梯度计算
            self.mulan.eval()  # 设置 MuLaN 为评估模式

            # 由于对比学习的训练方式,音频与文本共享同一个联合嵌入空间

            if exists(wavs):  # 如果音频数据存在
                latents = self.mulan.get_audio_latents(wavs)  # 获取音频潜在表示
            elif exists(texts):  # 如果文本数据存在
                latents = self.mulan.get_text_latents(texts)  # 获取文本潜在表示

        _, indices, _ = self.rq(latents)  # 使用 RQ 对象对 latents 进行残差向量量化,得到离散索引

        batch, num_codebooks, dim = indices.shape[0], self.num_codebooks, cond_embeddings.shape[-1]  # 获取批次大小、代码簿数量和维度

        cond_embeddings = repeat(cond_embeddings, 'q c d -> b q c d', b = batch)  # 重复条件嵌入
        indices = repeat(indices, 'b q -> b q 1 d', q = num_codebooks, d = dim)  # 重复索引

        cond_embeddings = cond_embeddings.gather(2, indices)  # 根据索引获取条件嵌入
        return rearrange(cond_embeddings, 'b q 1 d -> b q d')  # 重新排列条件嵌入维度
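
# 使用示意(非源码,仅作演示;conditioning_dims 的取值仅为示例):
#
#   quantizer = MuLaNEmbedQuantizer(
#       mulan = mulan,                           # 上面训练好的 MuLaN
#       conditioning_dims = (1024, 1024, 1024),  # 与三个 namespace 一一对应的条件维度
#       namespaces = ('semantic', 'coarse', 'fine')
#   )
#
#   wavs = torch.randn(2, 1024)
#   conds = quantizer(wavs = wavs, namespace = 'semantic')   # (2, 8, 1024),8 为 rq_num_quantizers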

# 定义 MusicLM 类,继承自 nn.Module
class MusicLM(nn.Module):
    # 初始化函数
    @beartype
    def __init__(
        self,
        audio_lm: AudioLM,  # AudioLM 对象
        mulan_embed_quantizer: MuLaNEmbedQuantizer  # MuLaNEmbedQuantizer 对象
    ):
        super().__init__()  # 调用父类的初始化函数
        assert not exists(audio_lm.audio_conditioner), 'mulan must not have been passed into AudioLM. it will be managed externally now, embedding the text into the joint embedding space for text-to-audio synthesis'

        self.mulan_embed_quantizer = mulan_embed_quantizer  # 初始化 MuLaNEmbedQuantizer 对象
        self.audio_lm = audio_lm  # 初始化 AudioLM 对象

    # 设备属性
    @property
    def device(self):
        return next(self.parameters()).device  # 返回参数的设备

    # 前向传播函数
    @torch.no_grad()
    def forward(
        self,
        text: str,  # 文本数据
        num_samples = 1,  # 样本数量,默认为 1
        **audio_lm_kwargs  # 音频 LM 参数
        ):
        # 调用 eval 方法
        self.eval()

        # 使用分词器对文本进行分词,并将结果转移到指定设备上
        texts = tokenizer.tokenize([text]).to(self.device)

        # 使用 mulan_embed_quantizer 对文本进行嵌入量化
        text_embeds = self.mulan_embed_quantizer(texts=texts)

        # 无法处理变长音频

        # 初始化一个空列表用于存储生成的音乐样本
        samples = []

        # 生成指定数量的音乐样本
        for _ in range(num_samples):
            # 使用 audio_lm 生成音乐,传入文本嵌入和其他参数
            music = self.audio_lm(text_embeds=text_embeds, **audio_lm_kwargs)
            samples.append(music)

        # 如果只生成一个样本,则直接返回该样本
        if num_samples == 1:
            return first(samples)

        # 获取 mulan_embed_quantizer 中的 mulan 模型
        mulan = self.mulan_embed_quantizer.mulan

        # 计算所有样本与文本的相似度,找到相似度最高的样本
        sims = torch.cat([mulan(texts=texts, wavs=music, return_similarities=True) for music in samples], dim=0)
        top_matching_index = sims.topk(1, dim=0).indices.item()

        # 返回相似度最高的样本
        return samples[top_matching_index]

.\lucidrains\musiclm-pytorch\musiclm_pytorch\trainer.py

# 导入必要的库
import copy
from math import sqrt
from random import choice
from pathlib import Path
from shutil import rmtree
from functools import wraps, partial

from typing_extensions import Annotated

from beartype import beartype
from beartype.door import is_bearable
from beartype.vale import Is
from beartype.typing import Union, List, Optional, Tuple, Callable

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

from lion_pytorch import Lion
from musiclm_pytorch import MuLaN
from einops import rearrange
from accelerate import Accelerator

# 用于将数据集产出的数据按类型自动路由到模型 forward 的对应关键字参数

DATASET_FIELD_TYPE_CONFIG = dict(
    wavs = Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}]
    ],
    raw_texts = List[str],
    texts = Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.long and t.ndim == 2]
    ],
)

# 辅助函数

def exists(val):
    return val is not None

def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

def noop(*args, **kwargs):
    pass

def cycle(dl):
    while True:
        for data in dl:
            yield data

def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# 自动将数据路由到模块关键字参数的函数

def has_duplicates(tup):
    counts = dict()
    for el in tup:
        if el not in counts:
            counts[el] = 0
        counts[el] += 1
    return any(filter(lambda count: count > 1, counts.values()))

def determine_types(data, config):
    output = []
    for el in data:
        for name, data_type in config.items():
            if is_bearable(el, data_type):
                output.append(name)
                break
        else:
            raise TypeError(f'unable to determine type of {data}')

    return tuple(output)
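
The config above, together with determine_types, is what lets the trainer route each element of a dataset tuple to the right keyword argument purely by its runtime type. A minimal sketch of that routing (hypothetical toy data, assumed to run in the same module as the definitions above):

import torch

wavs      = torch.randn(4, 16000)               # float tensor, 2 dims  -> matched as 'wavs'
texts     = torch.randint(0, 20000, (4, 256))   # long tensor, 2 dims   -> matched as 'texts'
raw_texts = ['a piano melody'] * 4              # list of strings       -> matched as 'raw_texts'

fields = determine_types((wavs, texts, raw_texts), DATASET_FIELD_TYPE_CONFIG)
print(fields)  # ('wavs', 'texts', 'raw_texts')

# MuLaNTrainer.data_tuple_to_kwargs later builds dict(zip(fields, data)) from this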

# 优化器函数

def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params
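
separate_weight_decayable_params is a small helper for building optimizer parameter groups in which biases and norm parameters (ndim < 2) are excluded from weight decay. A hypothetical usage sketch, not wired into the trainer below, assuming mulan is a constructed MuLaN model:

from torch.optim import AdamW

wd_params, no_wd_params = separate_weight_decayable_params(mulan.parameters())

optim = AdamW([
    {'params': wd_params, 'weight_decay': 1e-2},
    {'params': no_wd_params, 'weight_decay': 0.}
], lr = 3e-4)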

# 数据加载器函数

def collate_one_or_multiple_tensors(fn):
    @wraps(fn)
    def inner(data):
        is_one_data = not isinstance(data[0], tuple)

        if is_one_data:
            data = torch.stack(data)
            return (data,)

        outputs = []
        for datum in zip(*data):
            if is_bearable(datum, Tuple[str, ...]):
                output = list(datum)
            else:
                output = fn(datum)

            outputs.append(output)

        return tuple(outputs)

    return inner

@collate_one_or_multiple_tensors
def curtail_to_shortest_collate(data):
    min_len = min(*[datum.shape[0] for datum in data])
    data = [datum[:min_len] for datum in data]
    return torch.stack(data)

@collate_one_or_multiple_tensors
def pad_to_longest_fn(data):
    return pad_sequence(data, batch_first = True)

def get_dataloader(ds, pad_to_longest = True, **kwargs):
    collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate
    return DataLoader(ds, collate_fn = collate_fn, **kwargs)
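
The two collate functions give two ways of batching variable-length clips: pad_to_longest_fn zero-pads to the longest clip in the batch, while curtail_to_shortest_collate trims every clip to the shortest. A small sketch with a hypothetical toy dataset (assumed to run alongside the definitions above):

import torch
from torch.utils.data import Dataset

class ToyAudioTextDataset(Dataset):
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        wav  = torch.randn(self.lengths[idx])    # variable-length mono waveform
        text = torch.randint(0, 20000, (32,))    # fixed-length token ids
        return wav, text

ds = ToyAudioTextDataset([1000, 800, 1200])

wavs, texts = next(iter(get_dataloader(ds, batch_size = 3, pad_to_longest = True)))
print(wavs.shape)   # torch.Size([3, 1200]) - zero padded to the longest clip

wavs, texts = next(iter(get_dataloader(ds, batch_size = 3, pad_to_longest = False)))
print(wavs.shape)   # torch.Size([3, 800]) - curtailed to the shortest clip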

# MuLaN 训练器

@beartype
class MuLaNTrainer(nn.Module):
    # 初始化函数,接受多个参数,设置默认值和必须参数
    def __init__(
        self,
        mulan: MuLaN,
        dataset: Dataset,
        *,
        num_train_steps = None,
        batch_size,
        data_max_length = None,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
        betas = (0.9, 0.99),
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        use_lion = False,
        force_clear_prev_results = None  # set to True | False to skip the prompt
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 断言批处理大小大于1,用于对比学习(最好尽可能大)
        assert batch_size > 1, 'batch size must be greater than 1 for contrastive learning (but ideally as large as possible)'

        # 初始化加速器
        self.accelerator = Accelerator(**accelerate_kwargs)

        # 设置参数
        self.mulan = mulan
        self.register_buffer('steps', torch.Tensor([0]))
        self.num_train_steps = default(num_train_steps, len(dataset)) # 默认为1个epoch
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # 选择优化器
        optim_klass = Lion if use_lion else Adam
        self.optim = optim_klass(mulan.parameters(), lr = lr, betas = betas)

        # 设置最大梯度范数
        self.max_grad_norm = max_grad_norm
        self.data_max_length = data_max_length

        # 创建数据集
        self.ds = dataset
        self.ds_fields = None

        # 划分验证集
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # 创建数据加载器
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, pad_to_longest = False, drop_last = True)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, pad_to_longest = False, drop_last = True)

        # 准备加速器
        (
            self.mulan,
            self.optim,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.mulan,
            self.optim,
            self.dl,
            self.valid_dl
        )

        # 创建数据加载器迭代器
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        # 设置模型保存频率
        self.save_model_every = save_model_every

        # 设置超参数
        hps = dict(
            num_train_steps = num_train_steps,
            data_max_length = data_max_length,
            learning_rate = lr
        )

        # 初始化跟踪器
        self.accelerator.init_trackers("mulan", config = hps)

        # 设置结果文件夹
        self.results_folder = Path(results_folder)

        # 清除之前的结果
        if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

        # 将模型移动到设备
        self.mulan.to(self.device)

    # 保存模型
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.mulan),
            optim = self.optim.state_dict()
        )
        torch.save(pkg, path)

    # 加载模型
    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')

        mulan = self.accelerator.unwrap_model(self.mulan)
        mulan.load_state_dict(pkg['model'])
        self.optim.load_state_dict(pkg['optim'])
    # 打印消息,调用加速器对象的打印方法
    def print(self, msg):
        self.accelerator.print(msg)

    # 返回加速器对象的设备属性
    @property
    def device(self):
        return self.accelerator.device

    # 返回是否为分布式训练
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # 返回是否为主进程
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # 返回是否为本地主进程
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # 将数据元组转换为关键字参数
    def data_tuple_to_kwargs(self, data):
        # 如果数据字段不存在,则根据数据和数据集字段类型配置确定数据字段
        if not exists(self.ds_fields):
            self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
            assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names'

        # 将数据字段和数据组成字典
        data_kwargs =  dict(zip(self.ds_fields, data))

        # 截取音频数据长度
        wavs = data_kwargs['wavs']
        data_kwargs.update(wavs = wavs[..., :self.data_max_length])

        return data_kwargs

    # 训练步骤
    def train_step(self):
        # 获取设备
        device = self.device

        # 获取步数
        steps = int(self.steps.item())

        # 模型训练
        self.mulan.train()

        # 日志
        logs = {}

        # 更新生成器
        for _ in range(self.grad_accum_every):
            data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))

            loss = self.mulan(**data_kwargs)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # 梯度裁剪
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.mulan.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        # 打印日志
        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step = steps)

        # 定期保存模型
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'mulan.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    # 训练方法
    def train(self, log_fn: Callable = noop):

        # 循环训练步骤
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
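
The README further below only shows calling MuLaN's contrastive loss directly; driving the full loop with MuLaNTrainer would look roughly like the following sketch. The mulan and ds objects are assumptions: any Dataset yielding (wav_tensor, token_tensor) or (wav_tensor, raw_text) pairs that match DATASET_FIELD_TYPE_CONFIG should work.

trainer = MuLaNTrainer(
    mulan = mulan,                # a constructed (untrained) MuLaN instance
    dataset = ds,                 # paired audio / text dataset
    batch_size = 4,               # must be > 1 for the contrastive loss
    data_max_length = 320 * 32,   # waveforms are cropped to this many samples
    num_train_steps = 10_000
)

trainer.train()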

.\lucidrains\musiclm-pytorch\musiclm_pytorch\__init__.py

# 从 musiclm_pytorch 包中导入以下模块
from musiclm_pytorch.musiclm_pytorch import (
    MuLaN,  # 导入 MuLaN 类
    MuLaNEmbedQuantizer,  # 导入 MuLaNEmbedQuantizer 类
    MusicLM,  # 导入 MusicLM 类
    AudioSpectrogramTransformer,  # 导入 AudioSpectrogramTransformer 类
    TextTransformer,  # 导入 TextTransformer 类
    SigmoidContrastiveLearning,  # 导入 SigmoidContrastiveLearning 类
    SoftmaxContrastiveLearning  # 导入 SoftmaxContrastiveLearning 类
)

# 从 musiclm_pytorch 包中导入 MuLaNTrainer 模块
from musiclm_pytorch.trainer import MuLaNTrainer

MusicLM - Pytorch

Implementation of MusicLM, Google's new SOTA model for music generation using attention networks, in Pytorch.

They are basically using text-conditioned AudioLM, but surprisingly with the embeddings from a text-audio contrastive learned model named MuLan. MuLan is what will be built out in this repository, with AudioLM modified from the other repository to support the music generation needs here.

Please join us on Discord if you are interested in helping out with the replication alongside the LAION community

What's AI by Louis Bouchard

Appreciation

Install

$ pip install musiclm-pytorch

Usage

MuLaN first needs to be trained

import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

# get a ton of <sound, text> pairs and train

wavs = torch.randn(2, 1024)
texts = torch.randint(0, 20000, (2, 256))

loss = mulan(wavs, texts)
loss.backward()

# after much training, you can embed sounds and text into a joint embedding space
# for conditioning the audio LM

embeds = mulan.get_audio_latents(wavs)  # during training

embeds = mulan.get_text_latents(texts)  # during inference

To obtain the conditioning embeddings for the three transformers that are a part of AudioLM, you must use the MuLaNEmbedQuantizer as so

from musiclm_pytorch import MuLaNEmbedQuantizer

# setup the quantizer with the namespaced conditioning embeddings, unique per quantizer as well as namespace (per transformer)

quantizer = MuLaNEmbedQuantizer(
    mulan = mulan,                          # pass in trained mulan from above
    conditioning_dims = (1024, 1024, 1024), # say all three transformers have model dimensions of 1024
    namespaces = ('semantic', 'coarse', 'fine')
)

# now say you want the conditioning embeddings for semantic transformer

wavs = torch.randn(2, 1024)
conds = quantizer(wavs = wavs, namespace = 'semantic') # (2, 8, 1024) - 8 is number of quantizers

To train (or finetune) the three transformers that are a part of AudioLM, you simply follow the instructions over at audiolm-pytorch for training, but pass in the MuLaNEmbedQuantizer instance to the training classes under the keyword audio_conditioner

ex. SemanticTransformerTrainer

import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    audio_text_condition = True      # this must be set to True (same for CoarseTransformer and FineTransformers)
).cuda()

trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    audio_conditioner = quantizer,   # pass in the MulanEmbedQuantizer instance above
    folder ='/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

trainer.train()

After much training on all three transformers (semantic, coarse, fine), you will pass your finetuned or trained-from-scratch AudioLM and MuLaN wrapped in MuLaNEmbedQuantizer to the MusicLM

# you need the trained AudioLM (audio_lm) from above
# with the MulanEmbedQuantizer (mulan_embed_quantizer)

from musiclm_pytorch import MusicLM

musiclm = MusicLM(
    audio_lm = audio_lm,                 # `AudioLM` from https://github.com/lucidrains/audiolm-pytorch
    mulan_embed_quantizer = quantizer    # the `MuLaNEmbedQuantizer` from above
)

music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples = 4) # sample 4 and pick the top match with mulan

Todo

Citations

@inproceedings{Agostinelli2023MusicLMGM,
    title     = {MusicLM: Generating Music From Text},
    author    = {Andrea Agostinelli and Timo I. Denk and Zal{\'a}n Borsos and Jesse Engel and Mauro Verzetti and Antoine Caillon and Qingqing Huang and Aren Jansen and Adam Roberts and Marco Tagliasacchi and Matthew Sharifi and Neil Zeghidour and C. Frank},
    year      = {2023}
}
@article{Huang2022MuLanAJ,
    title   = {MuLan: A Joint Embedding of Music Audio and Natural Language},
    author  = {Qingqing Huang and Aren Jansen and Joonseok Lee and Ravi Ganti and Judith Yue Li and Daniel P. W. Ellis},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.12415}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@article{Liu2022PatchDropoutEV,
    title   = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
    author  = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.07220}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@inproceedings{Shukor2022EfficientVP,
    title   = {Efficient Vision-Language Pretraining with Visual Concepts and Hierarchical Alignment},
    author  = {Mustafa Shukor and Guillaume Couairon and Matthieu Cord},
    booktitle = {British Machine Vision Conference},
    year    = {2022}
}
@inproceedings{Zhai2023SigmoidLF,
    title   = {Sigmoid Loss for Language Image Pre-Training},
    author  = {Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer},
    year    = {2023}
}

The only truth is music. - Jack Kerouac

Music is the universal language of mankind. - Henry Wadsworth Longfellow

.\lucidrains\musiclm-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'musiclm-pytorch', # 包的名称
  packages = find_packages(exclude=[]), # 查找所有包
  version = '0.2.8', # 版本号
  license='MIT', # 许可证
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  long_description_content_type = 'text/markdown', # 长描述内容类型
  url = 'https://github.com/lucidrains/musiclm-pytorch', # 项目链接
  keywords = [ # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'text to music',
    'contrastive learning'
  ],
  install_requires=[ # 安装依赖列表
    'accelerate',
    'audiolm-pytorch>=0.17.0',
    'beartype',
    'einops>=0.6',
    'lion-pytorch',
    'vector-quantize-pytorch>=1.0.0',
    'x-clip',
    'torch>=1.12',
    'torchaudio'
  ],
  classifiers=[ # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\n-grammer-pytorch\n_grammer_pytorch\n_grammer_pytorch.py

# 基于 jax 代码的实现
# https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ngrammer.py

import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
import sympy

# 辅助函数

def exists(val):
    return val is not None

def sum_squares(t, dim = -1):
    return (t ** 2).sum(dim = dim)

# 与 bigram 相关的函数

def multi_way_hash_ids(x, a, b, prime, buckets):
    return ((x * a + b) % prime) % buckets

def get_bigram_ids(ids, vocab_size, segment_pos = None):
    # ids 的形状为 (batch, seq, heads)

    ids = ids.long()
    ids_0 = F.pad(ids, (0, 0, 0, 1))
    ids_1 = F.pad(ids, (0, 0, 1, 0))

    if exists(segment_pos):
        segment_pos = rearrange(segment_pos, 'b n -> b n 1')
        mask = (segment_pos == 0).long()
        mask = 1 - mask
        mask = F.pad(mask, (0, 0, 0, 1))
        ids_1 *= mask

    ngram_ids = ids_0 + ids_1 * vocab_size
    ngram_ids = ngram_ids[:, :-1]
    return ngram_ids
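
A tiny worked example of how get_bigram_ids composes the ids (illustrative values only): with vocab_size = 10 and token ids [3, 7, 2], each position encodes the pair (previous, current) as previous * 10 + current, with an implicit 0 before the first token.

import torch
from einops import rearrange

ids = torch.tensor([[3, 7, 2]])           # (batch, seq)
ids = rearrange(ids, 'b n -> b n 1')      # add a head dimension

print(get_bigram_ids(ids, vocab_size = 10).squeeze(-1))
# tensor([[ 3, 37, 72]])  ->  pairs (0,3), (3,7), (7,2)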

# 与优化器相关的函数

def get_ngrammer_parameters(module):
    params = set()
    for m in module.modules():
        if isinstance(m, Ngrammer):
            params.update(m.parameters())
    rest = set(module.parameters()) - params
    return list(params), list(rest)

def get_ngrammer_param_groups(module, ngrammer_learning_rate = 1e-2):
    ngrammer_params, rest = get_ngrammer_parameters(module)
    return [{'params': rest}, {'params': ngrammer_params, 'lr': ngrammer_learning_rate}]

# layernorm

class MultiheadLayerNorm(nn.Module):
    def __init__(self, dim, heads = 1, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(heads, dim))
        self.b = nn.Parameter(torch.zeros(heads, dim))

    def forward(self, x):
        std = torch.var(x, dim = -1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

# 类

class VectorQuantization(nn.Module):
    def __init__(
        self,
        *,
        num_clusters,
        num_heads,
        dim_per_head,
        decay = 0.999,
        epsilon = 1e-6
    ):
        super().__init__()
        self.decay = decay
        self.epsilon = epsilon
        self.num_heads = num_heads
        self.dim_per_head = dim_per_head
        self.num_clusters = num_clusters

        self.register_buffer('means', torch.randn(num_heads, num_clusters, dim_per_head))

    def forward(
        self,
        x,
        mask = None
    ):
        h, dim_head, num_clusters, eps, decay, means = self.num_heads, self.dim_per_head, self.num_clusters, self.epsilon, self.decay, self.means
        assert x.shape[-1] == (h * dim_head), f'input embedding feature dimension must be {h * dim_head}'

        # 将输入中的头部分离出来

        x = rearrange(x, 'b n (h d) -> b n h d', h = h)

        # 获取输入嵌入与均值之间的距离

        dists = (
            rearrange(sum_squares(x), 'b n h -> b n h 1')
            - 2 * einsum('b n h d, h k d -> b n h k', x, means)
            + rearrange(sum_squares(means), 'h k -> 1 1 h k')
        )

        # 获取簇 id

        cluster_ids = dists.argmin(dim = -1)

        if self.training:
            # 获取 one hot 编码,用于计算每个均值的匹配数

            nearest_one_hot = F.one_hot(cluster_ids, num_classes = num_clusters)
            per_cluster_count = nearest_one_hot.sum(dim = (0, 1))

            # 每个最近质心的输入之和。

            sum_x = einsum('b n h k, b n h d -> h k d', nearest_one_hot.float(), x)

            # 计算新的均值

            new_means = sum_x / (eps + rearrange(per_cluster_count, '... -> ... 1'))

            # 指数移动平均

            updated_means = (1. - decay) * new_means + decay * means

            self.means.data.copy_(updated_means)

        return cluster_ids
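
A quick shape check for the quantizer (random inputs, illustrative sizes): each token's features are split across heads and mapped to per-head cluster ids, with the codebook means updated by an exponential moving average while in training mode.

import torch

vq = VectorQuantization(num_clusters = 64, num_heads = 4, dim_per_head = 32)

x = torch.randn(2, 128, 4 * 32)      # (batch, seq, num_heads * dim_per_head)
cluster_ids = vq(x)

print(cluster_ids.shape)             # torch.Size([2, 128, 4]) - one cluster id per head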

class Ngrammer(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(
        self,
        *,
        unigram_vocab_size,  # 单字词汇表大小
        dim_per_head,  # 每个头的维度
        num_heads = 1,  # 头的数量,默认为1
        ngram_emb_dim = 8,  # n-gram嵌入维度,默认为8
        ngram_vocab_size = 768 * 256,  # n-gram词汇表大小,默认为768 * 256
        concat_ngrams = True  # 是否连接n-gram,默认为True
    ):
        super().__init__()
        # 断言,确保当连接n-gram时,每个头的维度不能小于n-gram嵌入维度
        assert not (concat_ngrams and dim_per_head <= ngram_emb_dim), 'unigram head dimension cannot be smaller than ngram embedding dimension when concatting'
        # 断言,确保当不连接n-gram时,每个头的维度必须等于n-gram嵌入维度
        assert not (not concat_ngrams and dim_per_head != ngram_emb_dim), 'unigram head dimension must be equal to ngram embedding dimension if not concatting'

        # 初始化模型参数
        self.num_heads = num_heads
        self.ngram_vocab_size = ngram_vocab_size
        self.unigram_vocab_size = unigram_vocab_size
        self.concat_ngrams = concat_ngrams

        # 初始化模型的嵌入层
        self.embeddings = nn.ModuleList([])

        # 初始化n-gram的LayerNorm
        self.ngram_layernorm = MultiheadLayerNorm(ngram_emb_dim, heads = num_heads)
        # 初始化嵌入的LayerNorm
        self.embeds_layernorm = MultiheadLayerNorm(dim_per_head, heads = num_heads)

        # 初始化n-gram的Embedding层
        self.ngram_embeds = nn.Embedding(ngram_vocab_size * num_heads, ngram_emb_dim)

        # 生成质数列表,用于多头哈希计算
        primes = list(sympy.primerange(ngram_vocab_size + 1, 2 * ngram_vocab_size))[:num_heads]
        self.register_buffer('primes', torch.tensor(primes), persistent = False)

    # 前向传播函数
    def forward(
        self,
        embeds,  # 嵌入
        cluster_ids,  # 聚类ID
        mask = None,  # 掩码,默认为None
        segment_pos = None  # 分段位置,默认为None
    ):
        # 获取模型参数
        num_heads, vocab_size, unigram_vocab_size, device = self.num_heads, self.ngram_vocab_size, self.unigram_vocab_size, embeds.device

        # 如果聚类ID的维度为2,则重复扩展为多头
        if cluster_ids.ndim == 2:
            cluster_ids = repeat(cluster_ids, '... -> ... h', h = num_heads)

        # 获取n-gram聚类ID
        ngram_cluster_ids = get_bigram_ids(cluster_ids, unigram_vocab_size, segment_pos)

        # 准备用于并行计算多头哈希ID的头范围
        head_range = torch.arange(num_heads, device = device)
        head_range = rearrange(head_range, 'h -> 1 1 h')
        primes = rearrange(self.primes, 'h -> 1 1 h')

        # 多头哈希ID计算
        ngram_ids = multi_way_hash_ids(ngram_cluster_ids, head_range + 1, head_range + 1, primes, vocab_size)

        # 根据头编号适当地移动词汇范围
        ngram_ids = ngram_ids + (vocab_size * head_range)

        # 一次性获取所有n-gram嵌入,并进行多头LayerNorm
        ngram_embeds = self.ngram_embeds(ngram_ids)
        normed_ngram_embeds = self.ngram_layernorm(ngram_embeds)

        # 多头LayerNorm输入
        embeds = rearrange(embeds, 'b n (h d) -> b n h d', h = num_heads)
        normed_embeds = self.embeds_layernorm(embeds)

        # 连接原始单字嵌入和bigram
        if self.concat_ngrams:
            input_sliced_dim = normed_embeds.shape[-1] - normed_ngram_embeds.shape[-1]
            out = torch.cat((
                normed_embeds[..., :input_sliced_dim],
                normed_ngram_embeds
            ), dim = -1)
        else:
            out = normed_embeds + normed_ngram_embeds

        # 展平
        out = rearrange(out, 'b n ... -> b n (...)')

        # 如果需要,进行掩码
        if exists(mask):
            out = out * rearrange(mask, 'b n -> b n 1').float()

        return out
# 主类定义

class VQNgrammer(nn.Module):
    def __init__(
        self,
        *,
        num_clusters,  # 聚类中心数量
        num_heads,  # 多头注意力机制中头的数量
        dim_per_head,  # 每个头的维度
        ngram_vocab_size = 768 * 256,  # N-gram词汇表大小,默认为768*256
        ngram_emb_dim = 8,  # N-gram嵌入维度,默认为8
        concat_ngrams = True,  # 是否连接N-gram
        decay = 0.999,  # 衰减率,默认为0.999
        epsilon = 1e-6  # 防止除零错误的小值,默认为1e-6
    ):
        super().__init__()
        assert ngram_vocab_size < (num_clusters ** 2), 'the ngram vocab size should be less than the number of clusters squared'

        # 初始化向量量化模块
        self.vq = VectorQuantization(
            num_clusters = num_clusters,
            num_heads = num_heads,
            dim_per_head = dim_per_head,
            decay = decay,
            epsilon = epsilon
        )

        # 初始化N-gram模块
        self.ngram = Ngrammer(
            unigram_vocab_size = num_clusters,
            ngram_vocab_size = ngram_vocab_size,
            ngram_emb_dim = ngram_emb_dim,
            concat_ngrams = concat_ngrams,
            num_heads = num_heads,
            dim_per_head = dim_per_head
        )

    def forward(
        self,
        x,
        mask = None,
        segment_pos = None
    ):

        # 使用向量量化模块对输入进行聚类
        cluster_ids = self.vq(x, mask = mask)

        # 使用N-gram模块处理输入数据
        out = self.ngram(
            x,
            cluster_ids = cluster_ids,
            mask = mask,
            segment_pos = segment_pos
        )

        return out

.\lucidrains\n-grammer-pytorch\n_grammer_pytorch\__init__.py

# 从 n_grammer_pytorch.n_grammer_pytorch 模块中导入 VQNgrammer, Ngrammer, get_ngrammer_parameters, get_ngrammer_param_groups 类/函数
from n_grammer_pytorch.n_grammer_pytorch import VQNgrammer, Ngrammer, get_ngrammer_parameters, get_ngrammer_param_groups

N-Grammer - Pytorch

Implementation of N-Grammer, augmenting Transformers with latent n-grams, in Pytorch

Install

$ pip install n-grammer-pytorch

Usage

import torch
from n_grammer_pytorch import VQNgrammer

vq_ngram = VQNgrammer(
    num_clusters = 1024,             # number of clusters
    dim_per_head = 32,               # dimension per head
    num_heads = 16,                  # number of heads
    ngram_vocab_size = 768 * 256,    # ngram vocab size
    ngram_emb_dim = 16,              # ngram embedding dimension
    decay = 0.999                    # exponential moving decay value
)

x = torch.randn(1, 1024, 32 * 16)
vq_ngram(x) # (1, 1024, 32 * 16)

Learning Rates

Like product key memories, Ngrammer parameters need to have a higher learning rate (1e-2 was recommended in the paper). The repository offers an easy way to generate the parameter groups.

from torch.optim import Adam
from n_grammer_pytorch import get_ngrammer_parameters

# this helper function, for your root model, finds all the VQNgrammer models and the embedding parameters
ngrammer_parameters, other_parameters = get_ngrammer_parameters(transformer)

optim = Adam([
    {'params': other_parameters},
    {'params': ngrammer_parameters, 'lr': 1e-2}
], lr = 3e-4)

Or, even more simply

from torch.optim import Adam
from n_grammer_pytorch import get_ngrammer_param_groups

param_groups = get_ngrammer_param_groups(model) # automatically creates array of parameter settings with learning rate set at 1e-2 for ngrammer parameter values
optim = Adam(param_groups, lr = 3e-4)

Citations

@inproceedings{thai2020using,
    title   = {N-grammer: Augmenting Transformers with latent n-grams},
    author  = {Anonymous},
    year    = {2021},
    url     = {https://openreview.net/forum?id=GxjCYmQAody}
}

.\lucidrains\n-grammer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'n-grammer-pytorch',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.0.14',  # 版本号
  license='MIT',  # 许可证
  description = 'N-Grammer - Pytorch',  # 描述
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/n-grammer-pytorch',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'n-grams',
    'memory'
  ],
  install_requires=[  # 安装依赖
    'einops>=0.3',
    'sympy',
    'torch>=1.6'
  ],
  classifiers=[  # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\aligner.py

from typing import Tuple
import numpy as np

import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F

from einops import rearrange, repeat

from beartype import beartype
from beartype.typing import Optional

# 检查变量是否存在
def exists(val):
    return val is not None

# 定义对齐模型类
class AlignerNet(Module):
    """alignment model https://arxiv.org/pdf/2108.10447.pdf """
    def __init__(
        self,
        dim_in=80,
        dim_hidden=512,
        attn_channels=80,
        temperature=0.0005,
    ):
        super().__init__()
        self.temperature = temperature

        # 定义关键字层
        self.key_layers = nn.ModuleList([
            nn.Conv1d(
                dim_hidden,
                dim_hidden * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_hidden * 2, attn_channels, kernel_size=1, padding=0, bias=True)
        ])

        # 定义查询层
        self.query_layers = nn.ModuleList([
            nn.Conv1d(
                dim_in,
                dim_in * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_in * 2, dim_in, kernel_size=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_in, attn_channels, kernel_size=1, padding=0, bias=True)
        ])

    # 前向传播函数
    @beartype
    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        mask: Optional[Tensor] = None
    ):
        key_out = keys
        for layer in self.key_layers:
            key_out = layer(key_out)

        query_out = queries
        for layer in self.query_layers:
            query_out = layer(query_out)

        key_out = rearrange(key_out, 'b c t -> b t c')
        query_out = rearrange(query_out, 'b c t -> b t c')

        attn_logp = torch.cdist(query_out, key_out)
        attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...')

        if exists(mask):
            mask = rearrange(mask.bool(), '... c -> ... 1 c')
            attn_logp.data.masked_fill_(~mask, -torch.finfo(attn_logp.dtype).max)

        attn = attn_logp.softmax(dim = -1)
        return attn, attn_logp

# 填充张量函数
def pad_tensor(input, pad, value=0):
    pad = [item for sublist in reversed(pad) for item in sublist]  # Flatten the tuple
    assert len(pad) // 2 == len(input.shape), 'Padding dimensions do not match input dimensions'
    return F.pad(input, pad, mode='constant', value=value)

# 最大路径函数
def maximum_path(value, mask, const=None):
    device = value.device
    dtype = value.dtype
    if not exists(const):
        const = torch.tensor(float('-inf')).to(device)  # Patch for Sphinx complaint
    value = value * mask

    b, t_x, t_y = value.shape
    direction = torch.zeros(value.shape, dtype=torch.int64, device=device)
    v = torch.zeros((b, t_x), dtype=torch.float32, device=device)
    x_range = torch.arange(t_x, dtype=torch.float32, device=device).view(1, -1)

    for j in range(t_y):
        v0 = pad_tensor(v, ((0, 0), (1, 0)), value = const)[:, :-1]
        v1 = v
        max_mask = v1 >= v0
        v_max = torch.where(max_mask, v1, v0)
        direction[:, :, j] = max_mask

        index_mask = x_range <= j
        v = torch.where(index_mask.view(1,-1), v_max + value[:, :, j], const)

    direction = torch.where(mask.bool(), direction, 1)

    path = torch.zeros(value.shape, dtype=torch.float32, device=device)
    index = mask[:, :, 0].sum(1).long() - 1
    index_range = torch.arange(b, device=device)

    for j in reversed(range(t_y)):
        path[index_range, index, j] = 1
        index = index + direction[index_range, index, j] - 1

    path = path * mask.float()
    path = path.to(dtype=dtype)
    return path
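
maximum_path is a pure-PyTorch dynamic-programming pass that converts a soft score map into a hard monotonic 0/1 path. An illustrative call with random inputs (shapes only, not a meaningful alignment):

import torch

value = torch.randn(1, 4, 6)    # (batch, t_x, t_y) soft alignment scores
mask  = torch.ones(1, 4, 6)     # all positions valid

path = maximum_path(value, mask)
print(path.shape)               # torch.Size([1, 4, 6])
print(path.sum(dim = 1))        # exactly one active t_x position per t_y step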

# 前向求和损失类
class ForwardSumLoss(Module):
    def __init__(
        self,
        blank_logprob = -1
    # 初始化类,继承父类的属性和方法
    ):
        super().__init__()
        # 设置空白标签的对数概率
        self.blank_logprob = blank_logprob

        # 创建 CTC 损失函数对象
        self.ctc_loss = torch.nn.CTCLoss(
            blank = 0,  # 设置空白标签的值为0
            zero_infinity = True  # 设置是否将无穷大值转换为零
        )

    # 前向传播函数
    def forward(self, attn_logprob, key_lens, query_lens):
        # 获取设备信息和空白标签对数概率
        device, blank_logprob  = attn_logprob.device, self.blank_logprob
        # 获取输入的最大键长度
        max_key_len = attn_logprob.size(-1)

        # 重新排列输入数据的维度为[query_len, batch_size, key_len]
        attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')

        # 添加空白标签
        attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), value = blank_logprob)

        # 转换为对数概率
        # 注意:屏蔽超出键长度的概率
        mask_value = -torch.finfo(attn_logprob.dtype).max
        attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value)

        attn_logprob = attn_logprob.log_softmax(dim = -1)

        # 目标序列
        target_seqs = torch.arange(1, max_key_len + 1, device=device, dtype=torch.long)
        target_seqs = repeat(target_seqs, 'n -> b n', b = key_lens.numel())

        # 计算 CTC 损失
        cost = self.ctc_loss(attn_logprob, target_seqs, query_lens, key_lens)

        return cost
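
A shape-level sketch of calling ForwardSumLoss (random log-probabilities, illustrative lengths): the soft attention from AlignerNet has shape (batch, 1, query_len, key_len), and the loss scores it against all monotonic alignments via CTC, given the true key (text) and query (mel) lengths per sample.

import torch

loss_fn = ForwardSumLoss()

batch, query_len, key_len = 2, 200, 35
attn_logprob = torch.randn(batch, 1, query_len, key_len)

key_lens   = torch.tensor([35, 30])     # number of valid text tokens per sample
query_lens = torch.tensor([200, 180])   # number of valid mel frames per sample

loss = loss_fn(attn_logprob, key_lens, query_lens)
print(loss)   # scalar loss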
class BinLoss(Module):
    # 定义一个继承自 Module 的 BinLoss 类
    def forward(self, attn_hard, attn_logprob, key_lens):
        # 前向传播函数,接受注意力机制的硬分配、对数概率和键长度作为输入
        batch, device = attn_logprob.shape[0], attn_logprob.device
        # 获取 batch 大小和设备信息
        max_key_len = attn_logprob.size(-1)
        # 获取键的最大长度

        # 重新排列输入为 [query_len, batch_size, key_len]
        attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')
        attn_hard = rearrange(attn_hard, 'b t c -> c b t')
        # 重新排列注意力机制的输入形状

        mask_value = -torch.finfo(attn_logprob.dtype).max
        # 创建一个用于掩码的值

        attn_logprob.masked_fill_(torch.arange(max_key_len, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value)
        # 使用掩码值对注意力对数概率进行填充
        attn_logprob = attn_logprob.log_softmax(dim = -1)
        # 对注意力对数概率进行 log_softmax 操作

        return (attn_hard * attn_logprob).sum() / batch
        # 返回加权后的结果除以 batch 大小

class Aligner(Module):
    # 定义一个继承自 Module 的 Aligner 类
    def __init__(
        self,
        dim_in,
        dim_hidden,
        attn_channels=80,
        temperature=0.0005
    ):
        # 初始化函数,接受输入维度、隐藏维度、注意力通道数和温度参数
        super().__init__()
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.attn_channels = attn_channels
        self.temperature = temperature
        # 设置类的属性

        self.aligner = AlignerNet(
            dim_in = self.dim_in, 
            dim_hidden = self.dim_hidden,
            attn_channels = self.attn_channels,
            temperature = self.temperature
        )
        # 初始化 AlignerNet 模型

    def forward(
        self,
        x,
        x_mask,
        y,
        y_mask
    ):
        # 前向传播函数,接受输入 x、x_mask、y、y_mask
        alignment_soft, alignment_logprob = self.aligner(y, rearrange(x, 'b d t -> b t d'), x_mask)
        # 使用 AlignerNet 模型计算软对齐和对数概率

        x_mask = rearrange(x_mask, '... i -> ... i 1')
        y_mask = rearrange(y_mask, '... j -> ... 1 j')
        attn_mask = x_mask * y_mask
        attn_mask = rearrange(attn_mask, 'b 1 i j -> b i j')
        # 生成注意力掩码

        alignment_soft = rearrange(alignment_soft, 'b 1 c t -> b t c')
        alignment_mask = maximum_path(alignment_soft, attn_mask)
        # 重新排列软对齐结果并计算最大路径

        alignment_hard = torch.sum(alignment_mask, -1).int()
        # 计算硬对齐结果
        return alignment_hard, alignment_soft, alignment_logprob, alignment_mask
        # 返回硬对齐结果、软对齐结果、对数概率和对齐掩码

if __name__ == '__main__':
    # 如果作为脚本运行
    batch_size = 10
    seq_len_y = 200   # 序列 y 的长度
    seq_len_x = 35
    feature_dim = 80  # 特征维度

    x = torch.randn(batch_size, 512, seq_len_x)
    y = torch.randn(batch_size, seq_len_y, feature_dim)
    y = y.transpose(1,2) #dim-1 is the channels for conv
    # 生成输入 x 和 y,并对 y 进行转置

    # 创建掩码
    x_mask = torch.ones(batch_size, 1, seq_len_x)
    y_mask = torch.ones(batch_size, 1, seq_len_y)

    align = Aligner(dim_in = 80, dim_hidden=512, attn_channels=80)
    # 初始化 Aligner 模型
    alignment_hard, alignment_soft, alignment_logprob, alignment_mask = align(x, x_mask, y, y_mask)
    # 进行对齐操作