Lucidrains Series Projects: Source Code Walkthrough (Part 49)

.\lucidrains\vit-pytorch\vit_pytorch\vit_3d.py

import torch  # import PyTorch
from torch import nn  # import the nn module from PyTorch

from einops import rearrange, repeat  # import rearrange and repeat from einops
from einops.layers.torch import Rearrange  # import the torch Rearrange layer from einops

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)  # return t if it is a tuple, otherwise return (t, t)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),  # layer normalization over the input
            nn.Linear(dim, hidden_dim),  # linear projection
            nn.GELU(),  # GELU activation
            nn.Dropout(dropout),  # dropout layer
            nn.Linear(hidden_dim, dim),  # linear projection
            nn.Dropout(dropout)  # dropout layer
        )
    def forward(self, x):
        return self.net(x)  # forward pass

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)  # layer normalization over the input
        self.attend = nn.Softmax(dim = -1)  # softmax layer
        self.dropout = nn.Dropout(dropout)  # dropout layer

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # linear projection producing Q, K, V

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),  # linear projection
            nn.Dropout(dropout)  # dropout layer
        ) if project_out else nn.Identity()  # use nn.Sequential if project_out is true, otherwise nn.Identity

    def forward(self, x):
        x = self.norm(x)  # layer normalization
        qkv = self.to_qkv(x).chunk(3, dim = -1)  # split the projected result into Q, K, V
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)  # rearrange Q, K, V into heads

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # scaled dot product of Q and K

        attn = self.attend(dots)  # attention weights
        attn = self.dropout(attn)  # dropout

        out = torch.matmul(attn, v)  # weighted sum of the values
        out = rearrange(out, 'b h n d -> b n (h d)')  # merge the heads back
        return self.to_out(out)  # return the output

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),  # attention layer
                FeedForward(dim, mlp_dim, dropout = dropout)  # feedforward layer
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x  # residual connection around attention
            x = ff(x) + x  # residual connection around the feedforward
        return x  # return the output

class ViT(nn.Module):
    def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'

        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),  # rearrange the video into flattened spatio-temporal patches
            nn.LayerNorm(patch_dim),  # layer normalization
            nn.Linear(patch_dim, dim),  # linear projection
            nn.LayerNorm(dim),  # layer normalization
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # positional embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # class token
        self.dropout = nn.Dropout(emb_dropout)  # dropout layer

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)  # transformer module

        self.pool = pool  # pooling type
        self.to_latent = nn.Identity()  # identity mapping into the latent space

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),  # layer normalization
            nn.Linear(dim, num_classes)  # linear projection
        )  # MLP head
    # forward pass, takes a video tensor as input
    def forward(self, video):
        # convert the video into patch embeddings
        x = self.to_patch_embedding(video)
        # get batch size, number of patches and embedding dimension
        b, n, _ = x.shape

        # repeat the class token to match the batch size
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        # concatenate the class token with the patch embeddings
        x = torch.cat((cls_tokens, x), dim=1)
        # add the positional embedding
        x += self.pos_embedding[:, :(n + 1)]
        # apply dropout
        x = self.dropout(x)

        # process the sequence with the transformer
        x = self.transformer(x)

        # pool according to the configured pooling type
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # map to the latent space
        x = self.to_latent(x)
        # classify with the MLP head
        return self.mlp_head(x)
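
To make the interface concrete, here is a minimal usage sketch for this 3D ViT (the hyperparameter values below are illustrative assumptions, not taken from this file): the input video tensor has shape (batch, channels, frames, height, width), matching the Rearrange pattern in to_patch_embedding.

import torch
from vit_pytorch.vit_3d import ViT

v = ViT(
    image_size = 128,          # image height/width
    image_patch_size = 16,     # spatial patch size
    frames = 16,               # number of frames
    frame_patch_size = 2,      # temporal patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = v(video)                          # (4, 1000)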

.\lucidrains\vit-pytorch\vit_pytorch\vit_for_small_dataset.py

# import sqrt from the math module
from math import sqrt
# import torch
import torch
# import the functional API and the nn module from torch
import torch.nn.functional as F
from torch import nn
# import rearrange and repeat from einops, and the Rearrange layer from einops.layers.torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper: return t if it is a tuple, otherwise return (t, t)
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# FeedForward module
class FeedForward(nn.Module):
    # takes dimension dim, hidden dimension hidden_dim and dropout rate
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # define the network
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    # forward pass
    def forward(self, x):
        return self.net(x)

# LSA module: attention with a learned temperature and a masked diagonal
class LSA(nn.Module):
    # takes dimension dim, number of heads, head dimension dim_head and dropout rate
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    # forward pass
    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp()

        mask = torch.eye(dots.shape[-1], device = dots.device, dtype = torch.bool)
        mask_value = -torch.finfo(dots.dtype).max
        dots = dots.masked_fill(mask, mask_value)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer module
class Transformer(nn.Module):
    # takes dimension dim, depth, number of heads, head dimension dim_head, MLP dimension mlp_dim and dropout rate
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    # forward pass
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# SPT (shifted patch tokenization) module
class SPT(nn.Module):
    # takes dimension dim, patch size patch_size and number of channels
    def __init__(self, *, dim, patch_size, channels = 3):
        super().__init__()
        patch_dim = patch_size * patch_size * 5 * channels

        self.to_patch_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim)
        )

    # forward pass: concatenate the image with four shifted copies along the channel dimension, then tokenize into patches
    def forward(self, x):
        shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
        shifted_x = list(map(lambda shift: F.pad(x, shift), shifts))
        x_with_shifts = torch.cat((x, *shifted_x), dim = 1)
        return self.to_patch_tokens(x_with_shifts)

# ViT class
class ViT(nn.Module):
    # initialization, sets up the model
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        # call the parent constructor
        super().__init__()
        # image height and width
        image_height, image_width = pair(image_size)
        # patch height and width
        patch_height, patch_width = pair(patch_size)

        # the image dimensions must be divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # dimension of each flattened patch
        patch_dim = channels * patch_height * patch_width
        # pooling must be either 'cls' or 'mean'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # shifted patch tokenization for the patch embedding
        self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels)

        # positional embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # class token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # embedding dropout
        self.dropout = nn.Dropout(emb_dropout)

        # transformer
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # pooling type
        self.pool = pool
        # identity mapping into the latent space
        self.to_latent = nn.Identity()

        # MLP head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # forward pass
    def forward(self, img):
        # convert the image into patch tokens
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # repeat the class token to match the batch size
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # concatenate the class token with the patch tokens
        x = torch.cat((cls_tokens, x), dim=1)
        # add the positional embedding
        x += self.pos_embedding[:, :(n + 1)]
        # apply dropout
        x = self.dropout(x)

        # process the sequence with the transformer
        x = self.transformer(x)

        # pool according to the configured pooling type
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # map to the latent space
        x = self.to_latent(x)
        # classify with the MLP head
        return self.mlp_head(x)
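
A minimal usage sketch (hyperparameters are illustrative assumptions); the interface is the same as the standard ViT, with SPT and LSA swapped in internally.

import torch
from vit_pytorch.vit_for_small_dataset import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)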

.\lucidrains\vit-pytorch\vit_pytorch\vit_with_patch_dropout.py

# import torch
import torch
# import the nn module from torch
from torch import nn

# import rearrange and repeat from einops, and the Rearrange layer
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

# return t if it is a tuple, otherwise return (t, t)
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

# PatchDropout module
class PatchDropout(nn.Module):
    # takes a drop probability prob
    def __init__(self, prob):
        super().__init__()
        # the probability must lie in [0, 1)
        assert 0 <= prob < 1.
        self.prob = prob

    # forward pass
    def forward(self, x):
        # no-op when not training or when the probability is 0
        if not self.training or self.prob == 0.:
            return x

        # get the shape of the input and its device
        b, n, _, device = *x.shape, x.device

        # batch indices, shaped for advanced indexing
        batch_indices = torch.arange(b, device = device)
        batch_indices = rearrange(batch_indices, '... -> ... 1')
        # number of patches to keep
        num_patches_keep = max(1, int(n * (1 - self.prob)))
        # randomly select which patches to keep per sample
        patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices

        return x[batch_indices, patch_indices_keep]
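
As an aside, a small sketch of what this module does at train time (shapes are illustrative, assuming PatchDropout is importable from this module): taking the topk of random scores amounts to keeping a random subset of int(n * (1 - prob)) patches per batch element.

import torch
from vit_pytorch.vit_with_patch_dropout import PatchDropout

tokens = torch.randn(2, 8, 16)   # (batch, patches, dim)
dropper = PatchDropout(0.25)
dropper.train()                  # patches are only dropped in training mode

out = dropper(tokens)
print(out.shape)                 # torch.Size([2, 6, 16]), since int(8 * (1 - 0.25)) == 6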

# FeedForward module
class FeedForward(nn.Module):
    # takes dimension dim, hidden dimension hidden_dim and dropout rate
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # define the network
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    # forward pass
    def forward(self, x):
        return self.net(x)

# Attention module
class Attention(nn.Module):
    # takes dimension dim, number of heads, head dimension dim_head and dropout rate
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    # forward pass
    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer module
class Transformer(nn.Module):
    # takes dimension dim, depth, number of heads, head dimension dim_head, MLP dimension mlp_dim and dropout rate
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        # stack depth transformer layers
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    # forward pass
    def forward(self, x):
        # iterate over the transformer layers
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# ViT class with patch dropout
class ViT(nn.Module):
    # initialization, sets up the model
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., patch_dropout = 0.25):
        # call the parent constructor
        super().__init__()
        # image height and width
        image_height, image_width = pair(image_size)
        # patch height and width
        patch_height, patch_width = pair(patch_size)

        # the image dimensions must be divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # dimension of each flattened patch
        patch_dim = channels * patch_height * patch_width
        # pooling must be either 'cls' (cls token) or 'mean' (mean pooling)
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # convert the image into patch embeddings
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        # positional embedding (no slot for the cls token, since the cls token is only added after patch dropout)
        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
        # class token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # patch dropout layer
        self.patch_dropout = PatchDropout(patch_dropout)
        # embedding dropout
        self.dropout = nn.Dropout(emb_dropout)

        # transformer
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # pooling type
        self.pool = pool
        # identity mapping into the latent space
        self.to_latent = nn.Identity()

        # MLP head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # forward pass
    def forward(self, img):
        # convert the image into patch embeddings
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # add the positional embedding
        x += self.pos_embedding

        # drop a random subset of patches
        x = self.patch_dropout(x)

        # repeat the class token to match the batch size
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)

        # concatenate the class token with the remaining patches
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.dropout(x)

        # process the sequence with the transformer
        x = self.transformer(x)

        # pool according to the configured pooling type
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # map to the latent space
        x = self.to_latent(x)
        # classify with the MLP head
        return self.mlp_head(x)
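
A minimal usage sketch (hyperparameters are illustrative assumptions): patch_dropout sets the fraction of patch tokens dropped during training, while at evaluation time all patches are kept.

import torch
from vit_pytorch.vit_with_patch_dropout import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    patch_dropout = 0.5   # fraction of patches dropped during training
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)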

.\lucidrains\vit-pytorch\vit_pytorch\vit_with_patch_merger.py

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helpers

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# cast the input to a tuple
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# patch merger class

# PatchMerger module: learned queries attend over the token sequence and pool it down to num_tokens_out tokens
class PatchMerger(nn.Module):
    def __init__(self, dim, num_tokens_out):
        super().__init__()
        self.scale = dim ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.queries = nn.Parameter(torch.randn(num_tokens_out, dim))

    def forward(self, x):
        x = self.norm(x)
        sim = torch.matmul(self.queries, x.transpose(-1, -2)) * self.scale
        attn = sim.softmax(dim = -1)
        return torch.matmul(attn, x)
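
As an aside, a quick sketch of the reduction this module performs (shapes illustrative, assuming PatchMerger is importable from this module): the learned queries attend over the N input tokens and merge them down to num_tokens_out output tokens.

import torch
from vit_pytorch.vit_with_patch_merger import PatchMerger

merger = PatchMerger(dim = 256, num_tokens_out = 8)
tokens = torch.randn(2, 64, 256)   # (batch, tokens, dim)
merged = merger(tokens)
print(merged.shape)                # torch.Size([2, 8, 256])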

# classes

# FeedForward module
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# Attention module
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer module
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])

        self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
        self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for index, (attn, ff) in enumerate(self.layers):
            x = attn(x) + x
            x = ff(x) + x

            if index == self.patch_merge_layer_index:
                x = self.patch_merger(x)

        return self.norm(x)

class ViT(nn.Module):
    # initialization, sets up the model
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        # call the parent constructor
        super().__init__()
        # image height and width
        image_height, image_width = pair(image_size)
        # patch height and width
        patch_height, patch_width = pair(patch_size)

        # the image dimensions must be divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # dimension of each flattened patch
        patch_dim = channels * patch_height * patch_width

        # convert the image into patch embeddings
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # positional embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # embedding dropout
        self.dropout = nn.Dropout(emb_dropout)

        # transformer with a patch merger inserted part-way through
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens)

        # MLP head with mean pooling over the merged tokens
        self.mlp_head = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.Linear(dim, num_classes)
        )

    # forward pass
    def forward(self, img):
        # convert the image into patch embeddings
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # add the positional embedding to the patch embeddings
        x += self.pos_embedding[:, :n]
        x = self.dropout(x)

        # process the sequence with the transformer
        x = self.transformer(x)

        # classify with the MLP head
        return self.mlp_head(x)
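
A minimal usage sketch (hyperparameters are illustrative assumptions): with depth = 12 and patch_merge_layer = 6, the token sequence is merged down to patch_merge_num_tokens tokens after the sixth block.

import torch
from vit_pytorch.vit_with_patch_merger import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 12,
    heads = 8,
    mlp_dim = 2048,
    patch_merge_layer = 6,        # merge after the 6th layer
    patch_merge_num_tokens = 8    # merge down to 8 tokens
)

img = torch.randn(4, 3, 256, 256)
preds = v(img)  # (4, 1000)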

.\lucidrains\vit-pytorch\vit_pytorch\vivit.py

# import torch
import torch
# import the nn module from torch
from torch import nn

# import rearrange, repeat, reduce from einops
from einops import rearrange, repeat, reduce
# import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange

# helpers

# check whether a value exists
def exists(val):
    return val is not None

# cast the input to a tuple
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

# feedforward module
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # define the network
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# attention module
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer module
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# video vision transformer with factorized spatial and temporal attention
class ViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        image_patch_size,
        frames,
        frame_patch_size,
        num_classes,
        dim,
        spatial_depth,
        temporal_depth,
        heads,
        mlp_dim,
        pool = 'cls',
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.
    ):
        # call the parent constructor
        super().__init__()
        # unpack image size and image patch size
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        # the image dimensions must be divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # the number of frames must be divisible by the frame patch size
        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'

        # number of spatial patches and temporal patches
        num_image_patches = (image_height // patch_height) * (image_width // patch_width)
        num_frame_patches = (frames // frame_patch_size)

        # dimension of each flattened patch
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        # pooling must be either 'cls' or 'mean'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # whether to use global average pooling instead of cls tokens
        self.global_average_pool = pool == 'mean'

        # convert the video into patch embeddings, keeping the temporal axis separate
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # positional embedding over (temporal patches, spatial patches)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # spatial and temporal cls tokens (not used with global average pooling)
        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None

        # spatial and temporal transformers
        self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
        self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)

        # pooling type and mapping into the latent space
        self.pool = pool
        self.to_latent = nn.Identity()

        # MLP head
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, video):
        # convert the video into patch embeddings
        x = self.to_patch_embedding(video)
        b, f, n, _ = x.shape

        # add the positional embedding
        x = x + self.pos_embedding[:, :f, :n]

        # prepend the spatial cls token, if present
        if exists(self.spatial_cls_token):
            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
            x = torch.cat((spatial_cls_tokens, x), dim = 2)

        # apply dropout
        x = self.dropout(x)

        # fold the frame axis into the batch for spatial attention
        x = rearrange(x, 'b f n d -> (b f) n d')

        # attend across space
        x = self.spatial_transformer(x)

        # restore the frame axis
        x = rearrange(x, '(b f) n d -> b f n d', b = b)

        # take the spatial cls token, or global average pool, to get one token per frame for temporal attention
        x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')

        # prepend the temporal cls token, if present
        if exists(self.temporal_cls_token):
            temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
            x = torch.cat((temporal_cls_tokens, x), dim = 1)

        # attend across time
        x = self.temporal_transformer(x)

        # take the temporal cls token, or global average pool
        x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')

        # map to the latent space and classify with the MLP head
        x = self.to_latent(x)
        return self.mlp_head(x)
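
A minimal usage sketch for this factorized video ViT (hyperparameters are illustrative assumptions): spatial_depth and temporal_depth set the depths of the two transformers, and the input video has shape (batch, channels, frames, height, width).

import torch
from vit_pytorch.vivit import ViT

v = ViT(
    image_size = 128,
    image_patch_size = 16,
    frames = 16,
    frame_patch_size = 2,
    num_classes = 1000,
    dim = 1024,
    spatial_depth = 6,     # depth of the spatial transformer
    temporal_depth = 6,    # depth of the temporal transformer
    heads = 8,
    mlp_dim = 2048
)

video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = v(video)                          # (4, 1000)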

.\lucidrains\vit-pytorch\vit_pytorch\xcit.py

# import randrange from the random module
from random import randrange

# import torch and related submodules
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F

# import einops functions and layers
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# helpers

# check whether a value exists
def exists(val):
    return val is not None

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# L2-normalize a tensor along its last dimension
def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

# stochastically drop whole layers (layer dropout)
def dropout_layers(layers, dropout):
    if dropout == 0:
        return layers

    num_layers = len(layers)
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout

    # make sure at least one layer is kept
    if all(to_drop):
        rand_index = randrange(num_layers)
        to_drop[rand_index] = False

    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
    return layers

# classes

# LayerScale: scale the output of the wrapped module by a learned per-channel factor, initialized small for deeper networks
class LayerScale(Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif 18 < depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        self.fn = fn
        self.scale = nn.Parameter(torch.full((dim,), init_eps))

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

# FeedForward module
class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# Attention module
class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        h = self.heads

        x = self.norm(x)
        context = x if not exists(context) else torch.cat((x, context), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(sim)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# XCAttention: cross-covariance attention, attending over feature channels rather than tokens
class XCAttention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    # forward pass
    def forward(self, x):
        # number of heads
        h = self.heads
        # pack the (possibly 2d) token grid into a single sequence axis, remembering the packing for later
        x, ps = pack_one(x, 'b * d')

        # normalize the input
        x = self.norm(x)
        # project to queries, keys, values and split along the last dimension
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # arrange heads so that attention happens across the feature dimension (d attends over n)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h d n', h=h), (q, k, v))

        # L2-normalize queries and keys
        q, k = map(l2norm, (q, k))

        # cross-covariance similarity, scaled by the learned temperature
        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.temperature.exp()

        # attention
        attn = self.attend(sim)
        # dropout on the attention matrix
        attn = self.dropout(attn)

        # aggregate the values
        out = einsum('b h i j, b h j n -> b h i n', attn, v)
        # merge the heads back
        out = rearrange(out, 'b h d n -> b n (h d)')

        # restore the original packing
        out = unpack_one(out, ps, 'b * d')
        # project out
        return self.to_out(out)
class LocalPatchInteraction(Module):
    # local patch interaction: depthwise convolutions over the 2d patch grid
    def __init__(self, dim, kernel_size = 3):
        super().__init__()

        # the kernel size must be odd so that padding preserves the spatial size
        assert (kernel_size % 2) == 1
        padding = kernel_size // 2

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b h w c -> b c h w'),
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),  # depthwise convolution
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),  # depthwise convolution
            Rearrange('b c h w -> b h w c'),
        )

    def forward(self, x):
        return self.net(x)

class Transformer(Module):
    # transformer with LayerScale and layer dropout
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])

        self.layer_dropout = layer_dropout

        for ind in range(depth):
            # 1-indexed layer number, used to pick the LayerScale init value
            layer = ind + 1
            self.layers.append(ModuleList([
                LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            ]))

    def forward(self, x, context = None):
        # stochastically drop whole layers during training
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in layers:
            # attention with residual connection
            x = attn(x, context = context) + x
            # feedforward with residual connection
            x = ff(x) + x

        return x

class XCATransformer(Module):
    # transformer built from cross-covariance attention, local patch interaction and feedforward blocks
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size = 3, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])

        self.layer_dropout = layer_dropout

        for ind in range(depth):
            # 1-indexed layer number, used to pick the LayerScale init value
            layer = ind + 1
            self.layers.append(ModuleList([
                LayerScale(dim, XCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
                LayerScale(dim, LocalPatchInteraction(dim, local_patch_kernel_size), depth = layer),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            ]))

    def forward(self, x):
        # stochastically drop whole layers during training
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for cross_covariance_attn, local_patch_interaction, ff in layers:
            # cross-covariance attention with residual connection
            x = cross_covariance_attn(x) + x
            # local patch interaction with residual connection
            x = local_patch_interaction(x) + x
            # feedforward with residual connection
            x = ff(x) + x

        return x

class XCiT(Module):
    # XCiT model
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        cls_depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        local_patch_kernel_size = 3,
        layer_dropout = 0.
    ):
        super().__init__()
        # the image dimensions must be divisible by the patch size
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches and dimension of each flattened patch
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        # convert the image into a 2d grid of patch embeddings
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # positional embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        # class token
        self.cls_token = nn.Parameter(torch.randn(dim))

        # embedding dropout
        self.dropout = nn.Dropout(emb_dropout)

        # cross-covariance attention transformer over the patch tokens
        self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)

        # final layer normalization
        self.final_norm = nn.LayerNorm(dim)

        # class-attention transformer: the cls token attends over the patch tokens
        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)

        # MLP head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
    # forward pass, takes an image and returns class logits
    def forward(self, img):
        # convert the image into a 2d grid of patch embeddings
        x = self.to_patch_embedding(img)

        # pack the grid into a single sequence axis
        x, ps = pack_one(x, 'b * d')

        # get the shape and add the positional embedding
        b, n, _ = x.shape
        x += self.pos_embedding[:, :n]

        # restore the 2d grid, as required by the local patch interaction layers
        x = unpack_one(x, ps, 'b * d')

        # apply dropout
        x = self.dropout(x)

        # process with the cross-covariance attention transformer
        x = self.xcit_transformer(x)

        # final layer normalization
        x = self.final_norm(x)

        # repeat the cls token for the batch
        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)

        # flatten the grid and let the cls token attend over it with the class-attention transformer
        x = rearrange(x, 'b ... d -> b (...) d')
        cls_tokens = self.cls_transformer(cls_tokens, context = x)

        # classify with the MLP head
        return self.mlp_head(cls_tokens[:, 0])
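
A minimal usage sketch for XCiT (hyperparameters are illustrative assumptions): depth controls the cross-covariance attention stage and cls_depth the class-attention stage that follows it.

import torch
from vit_pytorch.xcit import XCiT

v = XCiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,          # depth of the cross-covariance attention stage
    cls_depth = 2,       # depth of the class-attention stage
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)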

.\lucidrains\vit-pytorch\vit_pytorch\__init__.py

# import the ViT class from vit_pytorch.vit
from vit_pytorch.vit import ViT
# import the SimpleViT class from vit_pytorch.simple_vit
from vit_pytorch.simple_vit import SimpleViT

# import the MAE class from vit_pytorch.mae
from vit_pytorch.mae import MAE
# import the Dino class from vit_pytorch.dino
from vit_pytorch.dino import Dino

.\lucidrains\VN-transformer\denoise.py

# import PyTorch
import torch
# import the PyTorch functional API
import torch.nn.functional as F
# import the Adam optimizer
from torch.optim import Adam

# import rearrange and repeat from einops
from einops import rearrange, repeat

# import sidechainnet and the VNTransformer class
import sidechainnet as scn
from VN_transformer import VNTransformer

# constants
BATCH_SIZE = 1
GRADIENT_ACCUMULATE_EVERY = 16
MAX_SEQ_LEN = 256
DEFAULT_TYPE = torch.float64

# set the default PyTorch dtype
torch.set_default_dtype(DEFAULT_TYPE)

# infinite data generator, skipping sequences longer than len_thres
def cycle(loader, len_thres = MAX_SEQ_LEN):
    while True:
        for data in loader:
            # skip batches whose sequence length exceeds len_thres
            if data.seqs.shape[1] > len_thres:
                continue
            yield data

# instantiate the VNTransformer model
transformer = VNTransformer(
    num_tokens = 24,
    dim = 64,
    depth = 4,
    dim_head = 64,
    heads = 8,
    dim_feat = 64,
    bias_epsilon = 1e-6,
    l2_dist_attn = True,
    flash_attn = False
).cuda()

# load the sidechainnet dataset
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

# data generator over the training split
dl = cycle(data['train'])
# Adam optimizer
optim = Adam(transformer.parameters(), lr = 1e-4)

# training loop
for _ in range(10000):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        # fetch one batch
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # move sequences to the GPU and convert one-hot vectors to indices
        seqs = seqs.cuda().argmax(dim = -1)
        # move coordinates to the GPU and cast to the default dtype
        coords = coords.cuda().type(torch.get_default_dtype())
        # move masks to the GPU and cast to bool
        masks = masks.cuda().bool()

        # sequence length
        l = seqs.shape[1]
        # separate the 14 atoms per residue
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

        # keep only the backbone coordinates
        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')

        # repeat the sequence and mask to match the three backbone atoms per residue
        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)

        # add gaussian noise to the coordinates
        noised_coords = coords + torch.randn_like(coords).cuda()

        # run the transformer
        type1_out, _ = transformer(
            noised_coords,
            feats = seq,
            mask = masks
        )

        # denoised coordinates
        denoised_coords = noised_coords + type1_out

        # mean squared error loss on the masked positions
        loss = F.mse_loss(denoised_coords[masks], coords[masks])
        # backpropagate the scaled loss
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # log the current loss
    print('loss:', loss.item())
    # optimizer step
    optim.step()
    # reset gradients
    optim.zero_grad()

VN (Vector Neuron) Transformer

A Transformer made of Rotation-equivariant Attention using Vector Neurons

Open Review

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

Install

$ pip install VN-transformer

Usage

import torch
from VN_transformer import VNTransformer

model = VNTransformer(
    dim = 64,
    depth = 2,
    dim_head = 64,
    heads = 8,
    dim_feat = 64,       # will default to early fusion, since this was the best performing
    bias_epsilon = 1e-6  # in this paper, they propose breaking equivariance with a tiny bit of bias noise in the VN linear. they claim this leads to improved stability. setting this to 0 would turn off the epsilon approximate equivariance
)

coors = torch.randn(1, 32, 3)    # (batch, sequence, spatial coordinates)
feats = torch.randn(1, 32, 64)

coors_out, feats_out = model(coors, feats = feats) # (1, 32, 3), (1, 32, 64)

Tests

Confidence in equivariance

$ python setup.py test

Example

First install sidechainnet

$ pip install sidechainnet

Then run the protein backbone denoising task

$ python denoise.py

It does not perform as well as En-Transformer or Equiformer

Citations

@inproceedings{Assaad2022VNTransformerRA,
    title   = {VN-Transformer: Rotation-Equivariant Attention for Vector Neurons},
    author  = {Serge Assaad and C. Downey and Rami Al-Rfou and Nigamaa Nayakanti and Benjamin Sapp},
    year    = {2022}
}
@article{Deng2021VectorNA,
    title   = {Vector Neurons: A General Framework for SO(3)-Equivariant Networks},
    author  = {Congyue Deng and Or Litany and Yueqi Duan and Adrien Poulenard and Andrea Tagliasacchi and Leonidas J. Guibas},
    journal = {2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
    year    = {2021},
    pages   = {12180-12189},
    url     = {https://api.semanticscholar.org/CorpusID:233394028}
}
@inproceedings{Kim2020TheLC,
    title   = {The Lipschitz Constant of Self-Attention},
    author  = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
    booktitle = {International Conference on Machine Learning},
    year    = {2020},
    url     = {https://api.semanticscholar.org/CorpusID:219530837}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}

.\lucidrains\VN-transformer\setup.py

# import the setup and package-discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'VN-transformer',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.1.0',  # version number
  license='MIT',  # license
  description = 'Vector Neuron Transformer (VN-Transformer)',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/VN-transformer',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'equivariance',
    'vector neurons',
    'transformers',
    'attention mechanism'
  ],
  install_requires=[  # install dependencies
    'einops>=0.6.0',
    'torch>=1.6'
  ],
  setup_requires=[  # setup dependencies
    'pytest-runner',
  ],
  tests_require=[  # test dependencies
    'pytest'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.8',
  ],
)

.\lucidrains\VN-transformer\tests\test.py

# import pytest
import pytest

# import torch
import torch
# import VNTransformer, VNInvariant, VNAttention and the rot helper
from VN_transformer.VN_transformer import VNTransformer, VNInvariant, VNAttention
from VN_transformer.rotations import rot

# set the default torch dtype to float64
torch.set_default_dtype(torch.float64)

# test the invariant layer
def test_vn_invariant():
    # VNInvariant layer with input dimension 64
    layer = VNInvariant(64)

    # random tensor of shape (1, 32, 64, 3)
    coors = torch.randn(1, 32, 64, 3)

    # random rotation matrix R
    R = rot(*torch.randn(3))
    # run the layer on the original and on the rotated input
    out1 = layer(coors)
    out2 = layer(coors @ R)

    # the two outputs should agree within tolerance (invariance)
    assert torch.allclose(out1, out2, atol = 1e-6)

# test equivariance
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_equivariance(l2_dist_attn):

    # VNTransformer model
    model = VNTransformer(
        dim = 64,
        depth = 2,
        dim_head = 64,
        heads = 8,
        l2_dist_attn = l2_dist_attn
    )

    # random coordinates of shape (1, 32, 3)
    coors = torch.randn(1, 32, 3)
    # boolean mask of shape (1, 32), all True
    mask  = torch.ones(1, 32).bool()

    # random rotation matrix R
    R   = rot(*torch.randn(3))
    # rotating the input then applying the model should equal applying the model then rotating the output
    out1 = model(coors @ R, mask = mask)
    out2 = model(coors, mask = mask) @ R

    # the two outputs should agree within tolerance
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'

# test equivariance of perceiver-style VN attention
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_perceiver_vn_attention_equivariance(l2_dist_attn):

    # VNAttention module with latents
    model = VNAttention(
        dim = 64,
        dim_head = 64,
        heads = 8,
        num_latents = 2,
        l2_dist_attn = l2_dist_attn
    )

    # random tensor of shape (1, 32, 64, 3)
    coors = torch.randn(1, 32, 64, 3)
    # boolean mask of shape (1, 32), all True
    mask  = torch.ones(1, 32).bool()

    # random rotation matrix R
    R   = rot(*torch.randn(3))
    # rotating the input then applying the model should equal applying the model then rotating the output
    out1 = model(coors @ R, mask = mask)
    out2 = model(coors, mask = mask) @ R

    # the output should contain 2 latent tokens
    assert out1.shape[1] == 2
    # the two outputs should agree within tolerance
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'

# test SO(3) equivariance with early fusion of features
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_equivariance_with_early_fusion(l2_dist_attn):

    # VNTransformer model with a feature dimension
    model = VNTransformer(
        dim = 64,
        depth = 2,
        dim_head = 64,
        heads = 8,
        dim_feat = 64,
        l2_dist_attn = l2_dist_attn
    )

    # random features of shape (1, 32, 64)
    feats = torch.randn(1, 32, 64)
    # random coordinates of shape (1, 32, 3)
    coors = torch.randn(1, 32, 3)
    # boolean mask of shape (1, 32), all True
    mask  = torch.ones(1, 32).bool()

    # random rotation matrix R
    R   = rot(*torch.randn(3))
    # apply the model to the rotated coordinates
    out1, _ = model(coors @ R, feats = feats, mask = mask, return_concatted_coors_and_feats = False)

    # apply the model to the original coordinates, then rotate the output
    out2, _ = model(coors, feats = feats, mask = mask, return_concatted_coors_and_feats = False)
    out2 = out2 @ R

    # the two outputs should agree within tolerance
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'

# test SE(3) equivariance with early fusion of features
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_se3_equivariance_with_early_fusion(l2_dist_attn):

    # VNTransformer model with translation equivariance enabled
    model = VNTransformer(
        dim = 64,
        depth = 2,
        dim_head = 64,
        heads = 8,
        dim_feat = 64,
        translation_equivariance = True,
        l2_dist_attn = l2_dist_attn
    )

    # random features of shape (1, 32, 64)
    feats = torch.randn(1, 32, 64)
    # random coordinates of shape (1, 32, 3)
    coors = torch.randn(1, 32, 3)
    # boolean mask of shape (1, 32), all True
    mask  = torch.ones(1, 32).bool()

    # random translation vector T and rotation matrix R
    T   = torch.randn(3)
    R   = rot(*torch.randn(3))
    # apply the model to the translated and rotated coordinates
    out1, _ = model((coors + T) @ R, feats = feats, mask = mask, return_concatted_coors_and_feats = False)

    # apply the model to the original coordinates, then translate and rotate the output
    out2, _ = model(coors, feats = feats, mask = mask, return_concatted_coors_and_feats = False)
    out2 = (out2 + T) @ R

    # the two outputs should agree within tolerance
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'

.\lucidrains\VN-transformer\VN_transformer\attend.py

# imports
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# named tuple holding the three boolean flags of a flash attention config
FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper to check whether a value exists
def exists(val):
    return val is not None

# decorator that makes a function run only once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print that only fires once
print_once = once(print)

# main class: Attend
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        l2_dist = False
    ):
        super().__init__()
        assert not (flash and l2_dist), 'flash attention is not compatible with l2 distance'
        self.l2_dist = l2_dist

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine the efficient attention configs for cuda and cpu

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(False, True, True)

    # flash attention
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # if a mask exists, expand it to a compatible shape
        # the mask comes in as B L and needs to be expanded to B H N L

        if exists(mask):
            mask = mask.expand(-1, heads, q_len, -1)

        # select the config depending on whether the tensors are on cuda

        config = self.cuda_config if is_cuda else self.cpu_config

        # use torch.backends.cuda.sdp_kernel(**config._asdict()) to invoke pytorch 2.0 flash / efficient attention
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.
            )

        return out
    # forward pass taking queries (q), keys (k), values (v) and an optional mask
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # sequence lengths of queries and keys, and the device
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        # scale: inverse square root of the feature dimension
        scale = q.shape[-1] ** -0.5

        # if a mask exists but is not 4-dimensional, rearrange it for broadcasting
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        # if flash attention is enabled, dispatch to flash_attn
        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # l2 distance

        if self.l2_dist:
            # -cdist squared == (-q^2 + 2qk - k^2)
            # so it can be computed from the qk similarity above
            q_squared = reduce(q ** 2, 'b h i d -> b h i 1', 'sum')
            k_squared = reduce(k ** 2, 'b h j d -> b h 1 j', 'sum')
            sim = sim * 2 - q_squared - k_squared

        # key padding mask

        if exists(mask):
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
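
A minimal usage sketch for Attend (shapes are illustrative): it expects per-head tensors of shape (batch, heads, seq, dim_head) and an optional boolean key-padding mask of shape (batch, seq).

import torch
from VN_transformer.attend import Attend

attend = Attend(dropout = 0., flash = False, l2_dist = True)

q = torch.randn(2, 8, 32, 64)   # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 32, 64)
v = torch.randn(2, 8, 32, 64)
mask = torch.ones(2, 32).bool() # key padding mask

out = attend(q, k, v, mask = mask)  # (2, 8, 32, 64)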

.\lucidrains\VN-transformer\VN_transformer\rotations.py

# 导入 torch 库
import torch
# 从 torch 库中导入 sin, cos, atan2, acos 函数
from torch import sin, cos, atan2, acos

# 定义绕 z 轴旋转的函数,参数为旋转角度 gamma
def rot_z(gamma):
    # 返回绕 z 轴旋转的旋转矩阵
    return torch.tensor([
        [cos(gamma), -sin(gamma), 0],
        [sin(gamma), cos(gamma), 0],
        [0, 0, 1]
    ], dtype=gamma.dtype)

# 定义绕 y 轴旋转的函数,参数为旋转角度 beta
def rot_y(beta):
    # 返回绕 y 轴旋转的旋转矩阵
    return torch.tensor([
        [cos(beta), 0, sin(beta)],
        [0, 1, 0],
        [-sin(beta), 0, cos(beta)]
    ], dtype=beta.dtype)

# 定义绕任意轴旋转的函数,参数为三个旋转角度 alpha, beta, gamma
def rot(alpha, beta, gamma):
    # 返回绕任意轴旋转的旋转矩阵,先绕 z 轴旋转 alpha,再绕 y 轴旋转 beta,最后绕 z 轴旋转 gamma
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
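
These helpers are handy for equivariance checks. A minimal sketch, with the random angles and point cloud as illustrative assumptions:

import torch
from VN_transformer.rotations import rot

R = rot(*torch.randn(3))      # random rotation from three Euler angles
points = torch.randn(16, 3)   # a small point cloud
rotated = points @ R.T        # apply the rotation to every point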

.\lucidrains\VN-transformer\VN_transformer\VN_transformer.py

# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch 中导入 nn, einsum, Tensor
from torch import nn, einsum, Tensor

# 从 einops 中导入 rearrange, repeat, reduce
from einops import rearrange, repeat, reduce
# 从 einops.layers.torch 中导入 Rearrange, Reduce
from einops.layers.torch import Rearrange, Reduce
# 从 VN_transformer.attend 中导入 Attend

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 计算两个向量的内积
def inner_dot_product(x, y, *, dim = -1, keepdim = True):
    return (x * y).sum(dim = dim, keepdim = keepdim)

# layernorm

# LayerNorm 类
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# equivariant modules

# VNLinear 类
class VNLinear(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        bias_epsilon = 0.
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in))

        self.bias = None
        self.bias_epsilon = bias_epsilon

        # 在这篇论文中,他们提出使用一个小偏置进行准等变性,通过 epsilon 可控,他们声称这样可以获得更好的稳定性和结果

        if bias_epsilon > 0.:
            self.bias = nn.Parameter(torch.randn(dim_out))

    def forward(self, x):
        out = einsum('... i c, o i -> ... o c', x, self.weight)

        if exists(self.bias):
            bias = F.normalize(self.bias, dim = -1) * self.bias_epsilon
            out = out + rearrange(bias, '... -> ... 1')

        return out

# VNReLU 类
class VNReLU(nn.Module):
    def __init__(self, dim, eps = 1e-6):
        super().__init__()
        self.eps = eps
        self.W = nn.Parameter(torch.randn(dim, dim))
        self.U = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x):
        q = einsum('... i c, o i -> ... o c', x, self.W)
        k = einsum('... i c, o i -> ... o c', x, self.U)

        qk = inner_dot_product(q, k)

        k_norm = k.norm(dim = -1, keepdim = True).clamp(min = self.eps)
        q_projected_on_k = q - inner_dot_product(q, k / k_norm) * k

        out = torch.where(
            qk >= 0.,
            q,
            q_projected_on_k
        )

        return out

# VNAttention 类
class VNAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dim_coor = 3,
        bias_epsilon = 0.,
        l2_dist_attn = False,
        flash = False,
        num_latents = None   # 设置此参数将启用类似于 perceiver 的跨注意力机制,从潜在变量到序列,潜在变量由 VNWeightedPool 推导而来
    ):
        super().__init__()
        assert not (l2_dist_attn and flash), 'l2 distance attention is not compatible with flash attention'

        self.scale = (dim_coor * dim_head) ** -0.5
        dim_inner = dim_head * heads
        self.heads = heads

        self.to_q_input = None
        if exists(num_latents):
            self.to_q_input = VNWeightedPool(dim, num_pooled_tokens = num_latents, squeeze_out_pooled_dim = False)

        self.to_q = VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon)
        self.to_k = VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon)
        self.to_v = VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon)
        self.to_out = VNLinear(dim_inner, dim, bias_epsilon = bias_epsilon)

        if l2_dist_attn and not exists(num_latents):
            # 对于 l2 距离注意力,查询和键是相同的,不是 perceiver-like 注意力
            self.to_k = self.to_q

        self.attend = Attend(flash = flash, l2_dist = l2_dist_attn)
    # 定义一个前向传播函数,接受输入 x 和可选的 mask 参数
    def forward(self, x, mask = None):
        """
        einstein notation
        b - batch
        n - sequence
        h - heads
        d - feature dimension (channels)
        c - coordinate dimension (3 for 3d space)
        i - source sequence dimension
        j - target sequence dimension
        """

        # 获取输入 x 的最后一个维度,即特征维度的大小
        c = x.shape[-1]

        # 如果存在 self.to_q_input 方法,则使用该方法处理输入 x 和 mask,否则直接使用 x
        if exists(self.to_q_input):
            q_input = self.to_q_input(x, mask = mask)
        else:
            q_input = x

        # 分别通过 self.to_q、self.to_k、self.to_v 方法处理 q_input,得到 q、k、v
        q, k, v = self.to_q(q_input), self.to_k(x), self.to_v(x)
        # 将 q、k、v 重排维度,将其转换为 'b h n (d c)' 的形式
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) c -> b h n (d c)', h = self.heads), (q, k, v))

        # 调用 attend 方法进行注意力计算
        out = self.attend(q, k, v, mask = mask)

        # 将输出 out 重排维度,将其转换为 'b n (h d) c' 的形式
        out = rearrange(out, 'b h n (d c) -> b n (h d) c', c = c)
        # 返回处理后的输出结果
        return self.to_out(out)
# 定义 VNFeedForward 函数,返回由 VNLinear、VNReLU 和另一个 VNLinear 组成的序列模块
def VNFeedForward(dim, mult = 4, bias_epsilon = 0.):
    # 计算内部维度
    dim_inner = int(dim * mult)
    # 返回一个包含上述三个层的序列模块
    return nn.Sequential(
        VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon),  # VNLinear 线性层
        VNReLU(dim_inner),  # VNReLU 激活函数
        VNLinear(dim_inner, dim, bias_epsilon = bias_epsilon)  # 另一个 VNLinear 线性层
    )

# 定义一个 VNLayerNorm 类,包含 LayerNorm 层
class VNLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-6):
        super().__init__()
        self.eps = eps
        self.ln = LayerNorm(dim)  # LayerNorm 层

    def forward(self, x):
        norms = x.norm(dim = -1)
        x = x / rearrange(norms.clamp(min = self.eps), '... -> ... 1')
        ln_out = self.ln(norms)
        return x * rearrange(ln_out, '... -> ... 1')

# 定义一个 VNWeightedPool 类,包含权重参数和池化操作
class VNWeightedPool(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        num_pooled_tokens = 1,
        squeeze_out_pooled_dim = True
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.weight = nn.Parameter(torch.randn(num_pooled_tokens, dim, dim_out))  # 权重参数
        self.squeeze_out_pooled_dim = num_pooled_tokens == 1 and squeeze_out_pooled_dim

    def forward(self, x, mask = None):
        if exists(mask):
            mask = rearrange(mask, 'b n -> b n 1 1')
            x = x.masked_fill(~mask, 0.)
            numer = reduce(x, 'b n d c -> b d c', 'sum')
            denom = mask.sum(dim = 1)
            mean_pooled = numer / denom.clamp(min = 1e-6)
        else:
            mean_pooled = reduce(x, 'b n d c -> b d c', 'mean')

        out = einsum('b d c, m d e -> b m e c', mean_pooled, self.weight)

        if not self.squeeze_out_pooled_dim:
            return out

        out = rearrange(out, 'b 1 d c -> b d c')
        return out
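
A small sketch of the pooling module above, reducing a masked sequence of vector features to a single equivariant token (shapes are assumptions):

import torch

pool = VNWeightedPool(dim = 64)    # num_pooled_tokens defaults to 1

x = torch.randn(2, 32, 64, 3)      # (batch, seq, dim, coordinate dim)
mask = torch.ones(2, 32).bool()

pooled = pool(x, mask = mask)      # (2, 64, 3), pooled token dimension squeezed out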

# 定义一个 VNTransformerEncoder 类,包含多层 VNAttention、VNLayerNorm 和 VNFeedForward
class VNTransformerEncoder(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        dim_coor = 3,
        ff_mult = 4,
        final_norm = False,
        bias_epsilon = 0.,
        l2_dist_attn = False,
        flash_attn = False
    ):
        super().__init__()
        self.dim = dim
        self.dim_coor = dim_coor

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                VNAttention(dim = dim, dim_head = dim_head, heads = heads, bias_epsilon = bias_epsilon, l2_dist_attn = l2_dist_attn, flash = flash_attn),  # VNAttention 层
                VNLayerNorm(dim),  # VNLayerNorm 层
                VNFeedForward(dim = dim, mult = ff_mult, bias_epsilon = bias_epsilon),  # VNFeedForward 层
                VNLayerNorm(dim)  # 另一个 VNLayerNorm 层
            ]))

        self.norm = VNLayerNorm(dim) if final_norm else nn.Identity()

    def forward(
        self,
        x,
        mask = None
    ):
        *_, d, c = x.shape

        assert x.ndim == 4 and d == self.dim and c == self.dim_coor, f'input needs to be in the shape of (batch, seq, dim ({self.dim}), coordinate dim ({self.dim_coor}))'

        for attn, attn_post_ln, ff, ff_post_ln in self.layers:
            x = attn_post_ln(attn(x, mask = mask)) + x
            x = ff_post_ln(ff(x)) + x

        return self.norm(x)

# 定义一个 VNInvariant 类,包含 MLP 模块
class VNInvariant(nn.Module):
    def __init__(
        self,
        dim,
        dim_coor = 3,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            VNLinear(dim, dim_coor),  # VNLinear 线性层
            VNReLU(dim_coor),  # VNReLU 激活函数
            Rearrange('... d e -> ... e d')  # 重新排列维度
        )

    def forward(self, x):
        return einsum('b n d i, b n i o -> b n o', x, self.mlp(x))

# 定义一个 VNTransformer 类,包含多个参数和模块
class VNTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens = None,
        dim_feat = None,
        dim_head = 64,
        heads = 8,
        dim_coor = 3,
        reduce_dim_out = True,
        bias_epsilon = 0.,
        l2_dist_attn = False,
        flash_attn = False,
        translation_equivariance = False,
        translation_invariant = False
    ):
        # 调用父类的构造函数
        super().__init__()
        # 如果 num_tokens 存在,则创建一个维度为 dim 的嵌入层
        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None

        # 设置特征维度为 dim_feat 或默认为 0
        dim_feat = default(dim_feat, 0)
        self.dim_feat = dim_feat
        # 计算坐标总维度,包括坐标和特征
        self.dim_coor_total = dim_coor + dim_feat

        # 确保平移等变性和平移不变性最多只能有一个为真
        assert (int(translation_equivariance) + int(translation_invariant)) <= 1
        self.translation_equivariance = translation_equivariance
        self.translation_invariant = translation_invariant

        # 定义输入投影层
        self.vn_proj_in = nn.Sequential(
            Rearrange('... c -> ... 1 c'),
            VNLinear(1, dim, bias_epsilon = bias_epsilon)
        )

        # 创建 VNTransformerEncoder 编码器
        self.encoder = VNTransformerEncoder(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            bias_epsilon = bias_epsilon,
            dim_coor = self.dim_coor_total,
            l2_dist_attn = l2_dist_attn,
            flash_attn = flash_attn
        )

        # 如果需要减少输出维度,则定义输出投影层
        if reduce_dim_out:
            self.vn_proj_out = nn.Sequential(
                VNLayerNorm(dim),
                VNLinear(dim, 1, bias_epsilon = bias_epsilon),
                Rearrange('... 1 c -> ... c')
            )
        else:
            self.vn_proj_out = nn.Identity()

    def forward(
        self,
        coors,
        *,
        feats = None,
        mask = None,
        return_concatted_coors_and_feats = False
    ):
        # 如果需要平移等变性或平移不变性,则计算坐标的平均值并减去
        if self.translation_equivariance or self.translation_invariant:
            coors_mean = reduce(coors, '... c -> c', 'mean')
            coors = coors - coors_mean

        x = coors

        # 如果存在特征,则将特征拼接到坐标中
        if exists(feats):
            if feats.dtype == torch.long:
                assert exists(self.token_emb), 'num_tokens must be given to the VNTransformer (to build the Embedding), if the features are to be given as indices'
                feats = self.token_emb(feats)

            assert feats.shape[-1] == self.dim_feat, f'dim_feat should be set to {feats.shape[-1]}'
            x = torch.cat((x, feats), dim = -1)

        assert x.shape[-1] == self.dim_coor_total

        # 输入投影层
        x = self.vn_proj_in(x)
        # 编码器
        x = self.encoder(x, mask = mask)
        # 输出投影层
        x = self.vn_proj_out(x)

        # 提取坐标和特征
        coors_out, feats_out = x[..., :3], x[..., 3:]

        # 如果需要平移等变性,则将坐标输出加上坐标平均值
        if self.translation_equivariance:
            coors_out = coors_out + coors_mean

        # 如果没有特征,则返回坐标输出
        if not exists(feats):
            return coors_out

        # 如果需要返回拼接的坐标和特征,则返回拼接后的结果
        if return_concatted_coors_and_feats:
            return torch.cat((coors_out, feats_out), dim = -1)

        # 否则返回坐标和特征分开的结果
        return coors_out, feats_out
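
A minimal usage sketch for the full VNTransformer on a point cloud with discrete per-point features; the hyperparameters are illustrative. Note that token features are embedded to dim, so dim_feat must equal dim in that case.

import torch
from VN_transformer import VNTransformer

model = VNTransformer(
    dim = 64,
    depth = 2,
    dim_feat = 64,    # must match the embedded feature width
    num_tokens = 32   # only needed when features are passed as indices
)

coors = torch.randn(1, 32, 3)            # (batch, seq, xyz)
feats = torch.randint(0, 32, (1, 32))    # discrete per-point features
mask  = torch.ones(1, 32).bool()

coors_out, feats_out = model(coors, feats = feats, mask = mask)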

.\lucidrains\VN-transformer\VN_transformer\__init__.py

# 从VN_transformer.VN_transformer模块中导入以下类和函数
from VN_transformer.VN_transformer import (
    VNTransformer,         # 导入VNTransformer类
    VNLinear,              # 导入VNLinear类
    VNLayerNorm,           # 导入VNLayerNorm类
    VNFeedForward,         # 导入VNFeedForward类
    VNAttention,           # 导入VNAttention类
    VNWeightedPool,        # 导入VNWeightedPool类
    VNTransformerEncoder,  # 导入VNTransformerEncoder类
    VNInvariant            # 导入VNInvariant类
)

Voicebox - Pytorch

Implementation of Voicebox, new SOTA Text-to-Speech model from MetaAI, in Pytorch. Press release

In this work, we will use rotary embeddings. The authors seem unaware that ALiBi cannot be straightforwardly used for bidirectional models.

The paper also addresses the issue of the time embedding being incorrectly subjected to relative distances (they concatenate the time embedding along the frame dimension of the audio tokens). This repository will use adaptive normalization instead, as applied successfully in Paella.

Appreciation

  • Translated for awarding me the Imminent Grant to advance the state of open sourced text-to-speech solutions. This project was started and will be completed under this grant.

  • StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • Bryan Chiang for the ongoing code review, sharing his expertise on TTS, and pointing me to an open sourced implementation of conditional flow matching

  • Manmay for getting the repository started with the alignment code

  • @chenht2010 for finding a bug with rotary positions, and for validating that the code in the repository converges

  • Lucas Newman for (yet again) pull requesting all the training code for Spear-TTS conditioned Voicebox training!

  • Lucas Newman has demonstrated that the whole system works with Spear-TTS conditioning. Training converges even better than Soundstorm

Install

$ pip install voicebox-pytorch

Usage

Training and sampling with TextToSemantic module from SpearTTS

import torch

from voicebox_pytorch import (
    VoiceBox,
    EncodecVoco,
    ConditionalFlowMatcherWrapper,
    HubertWithKmeans,
    TextToSemantic
)

# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert

wav2vec = HubertWithKmeans(
    checkpoint_path = '/path/to/hubert/checkpoint.pt',
    kmeans_path = '/path/to/hubert/kmeans.bin'
)

text_to_semantic = TextToSemantic(
    wav2vec = wav2vec,
    dim = 512,
    source_depth = 1,
    target_depth = 1,
    use_openai_tokenizer = True
)

text_to_semantic.load('/path/to/trained/spear-tts/model.pt')

model = VoiceBox(
    dim = 512,
    audio_enc_dec = EncodecVoco(),
    num_cond_tokens = 500,
    depth = 2,
    dim_head = 64,
    heads = 16
)

cfm_wrapper = ConditionalFlowMatcherWrapper(
    voicebox = model,
    text_to_semantic = text_to_semantic
)

# mock data

audio = torch.randn(2, 12000)

# train

loss = cfm_wrapper(audio)
loss.backward()

# after much training

texts = [
    'the rain in spain falls mainly in the plains',
    'she sells sea shells by the seashore'
]

cond = torch.randn(2, 12000)
sampled = cfm_wrapper.sample(cond = cond, texts = texts) # (2, 1, <audio length>)

For unconditional training, condition_on_text on VoiceBox must be set to False

import torch
from voicebox_pytorch import (
    VoiceBox,
    ConditionalFlowMatcherWrapper
)

model = VoiceBox(
    dim = 512,
    num_cond_tokens = 500,
    depth = 2,
    dim_head = 64,
    heads = 16,
    condition_on_text = False
)

cfm_wrapper = ConditionalFlowMatcherWrapper(
    voicebox = model
)

# mock data

x = torch.randn(2, 1024, 512)

# train

loss = cfm_wrapper(x)

loss.backward()

# after much training

cond = torch.randn(2, 1024, 512)

sampled = cfm_wrapper.sample(cond = cond) # (2, 1024, 512)

Todo

Citations

@article{Le2023VoiceboxTM,
    title   = {Voicebox: Text-Guided Multilingual Universal Speech Generation at Scale},
    author  = {Matt Le and Apoorv Vyas and Bowen Shi and Brian Karrer and Leda Sari and Rashel Moritz and Mary Williamson and Vimal Manohar and Yossi Adi and Jay Mahadeokar and Wei-Ning Hsu},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.15687},
    url     = {https://api.semanticscholar.org/CorpusID:259275061}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@misc{torchdiffeq,
    author  = {Chen, Ricky T. Q.},
    title   = {torchdiffeq},
    year    = {2018},
    url     = {https://github.com/rtqichen/torchdiffeq},
}
@inproceedings{lienen2022torchode,
    title     = {torchode: A Parallel {ODE} Solver for PyTorch},
    author    = {Marten Lienen and Stephan G{\"u}nnemann},
    booktitle = {The Symbiosis of Deep Learning and Differential Equations II, NeurIPS},
    year      = {2022},
    url       = {https://openreview.net/forum?id=uiKVKTiUYB0}
}
@article{siuzdak2023vocos,
    title   = {Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for high-quality audio synthesis},
    author  = {Siuzdak, Hubert},
    journal = {arXiv preprint arXiv:2306.00814},
    year    = {2023}
}
@misc{darcet2023vision,
    title   = {Vision Transformers Need Registers},
    author  = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    year    = {2023},
    eprint  = {2309.16588},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Dehghani2023ScalingVT,
    title   = {Scaling Vision Transformers to 22 Billion Parameters},
    author  = {Mostafa Dehghani and Josip Djolonga and Basil Mustafa and Piotr Padlewski and Jonathan Heek and Justin Gilmer and Andreas Steiner and Mathilde Caron and Robert Geirhos and Ibrahim M. Alabdulmohsin and Rodolphe Jenatton and Lucas Beyer and Michael Tschannen and Anurag Arnab and Xiao Wang and Carlos Riquelme and Matthias Minderer and Joan Puigcerver and Utku Evci and Manoj Kumar and Sjoerd van Steenkiste and Gamaleldin F. Elsayed and Aravindh Mahendran and Fisher Yu and Avital Oliver and Fantine Huot and Jasmijn Bastings and Mark Collier and Alexey A. Gritsenko and Vighnesh Birodkar and Cristina Nader Vasconcelos and Yi Tay and Thomas Mensink and Alexander Kolesnikov and Filip Paveti'c and Dustin Tran and Thomas Kipf and Mario Luvci'c and Xiaohua Zhai and Daniel Keysers and Jeremiah Harmsen and Neil Houlsby},
    booktitle = {International Conference on Machine Learning},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:256808367}
}
@inproceedings{Katsch2023GateLoopFD,
    title   = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
    author  = {Tobias Katsch},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265018962}
}

.\lucidrains\voicebox-pytorch\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'voicebox-pytorch',  # 包名
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.5.0',  # 版本号
  license='MIT',  # 许可证
  description = 'Voicebox - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  url = 'https://github.com/lucidrains/voicebox-pytorch',  # URL
  keywords = [  # 关键词
    'artificial intelligence',
    'deep learning',
    'text to speech'
  ],
  install_requires=[  # 安装依赖
    'accelerate',
    'audiolm-pytorch>=1.2.28',
    'naturalspeech2-pytorch>=0.1.8',
    'beartype',
    'einops>=0.6.1',
    'gateloop-transformer>=0.2.4',
    'spear-tts-pytorch>=0.4.0',
    'torch>=2.0',
    'torchdiffeq',
    'torchode',
    'vocos'
  ],
  classifiers=[  # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\voicebox-pytorch\voicebox_pytorch\attend.py

# 从 functools 模块导入 wraps 函数
# 从 packaging 模块导入 version 类
# 从 collections 模块导入 namedtuple 类
# 导入 torch 库
# 从 torch 模块中导入 nn, einsum 函数
# 从 torch.nn 模块中导入 functional 模块
# 从 einops 模块中导入 rearrange, reduce 函数
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# 定义一个命名元组 FlashAttentionConfig,包含三个布尔类型的字段
FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# 定义一个辅助函数,判断值是否存在
def exists(val):
    return val is not None

# 定义一个辅助函数,如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义一个装饰器函数,确保被装饰的函数只能调用一次
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# 定义一个打印函数,使用 once 装饰器确保只打印一次
print_once = once(print)

# 主类 Attend
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        scale = None
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.scale = scale

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # 确定 CUDA 和 CPU 的高效注意力配置

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(False, True, True)

    # Flash Attention 函数
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, dim_head, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # 如果给定了自定义 scale,先抵消 sdpa 内部的默认缩放,再乘以自定义 scale

        if exists(self.scale):
            q = q * (self.scale / (dim_head ** -0.5))

        # 检查 mask 是否存在并扩展到兼容的形状

        if exists(mask):
            mask = mask.expand(-1, heads, q_len, -1)

        # 检查是否有兼容的设备用于 Flash Attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # 使用 torch.backends.cuda.sdp_kernel 函数应用 Flash Attention

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    # 前向传播函数
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # 相似度计算

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask

        if exists(mask):
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # 注意力计算

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # 聚合值

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out

.\lucidrains\voicebox-pytorch\voicebox_pytorch\data.py

# 导入必要的模块
from pathlib import Path
from functools import wraps

# 从 einops 模块中导入 rearrange 函数
from einops import rearrange

# 从 beartype 模块中导入 beartype 函数和 is_bearable 函数,以及 Optional、Tuple 和 Union 类型
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Optional, Tuple, Union

# 导入 torch 模块
import torch
# 从 torch.nn.utils.rnn 模块中导入 pad_sequence 函数
from torch.nn.utils.rnn import pad_sequence
# 从 torch.utils.data 模块中导入 Dataset 和 DataLoader 类
from torch.utils.data import Dataset, DataLoader

# 导入 torchaudio 模块
import torchaudio

# utilities

# 判断值是否存在的函数
def exists(val):
    return val is not None

# 将值转换为元组的函数
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# dataset functions

# 定义 AudioDataset 类,继承自 Dataset 类
class AudioDataset(Dataset):
    # 初始化函数
    @beartype
    def __init__(
        self,
        folder,
        audio_extension = ".flac"
    ):
        super().__init__()
        # 将文件夹路径转换为 Path 对象
        path = Path(folder)
        # 断言文件夹存在
        assert path.exists(), 'folder does not exist'

        self.audio_extension = audio_extension

        # 获取文件夹下所有指定扩展名的文件列表
        files = list(path.glob(f'**/*{audio_extension}'))
        # 断言找到了文件
        assert len(files) > 0, 'no files found'

        self.files = files

    # 返回数据集的长度
    def __len__(self):
        return len(self.files)

    # 获取指定索引处的数据
    def __getitem__(self, idx):
        file = self.files[idx]

        # 加载音频文件
        wave, _ = torchaudio.load(file)
        # 重新排列音频数据的维度
        wave = rearrange(wave, '1 ... -> ...')

        return wave

# dataloader functions

# 定义装饰器函数,用于处理单个或多个张量的数据
def collate_one_or_multiple_tensors(fn):
    @wraps(fn)
    def inner(data):
        is_one_data = not isinstance(data[0], tuple)

        if is_one_data:
            data = fn(data)
            return (data,)

        outputs = []
        for datum in zip(*data):
            if is_bearable(datum, Tuple[str, ...]):
                output = list(datum)
            else:
                output = fn(datum)

            outputs.append(output)

        return tuple(outputs)

    return inner

# 裁剪数据到最短长度的函数
@collate_one_or_multiple_tensors
def curtail_to_shortest_collate(data):
    min_len = min(*[datum.shape[0] for datum in data])
    data = [datum[:min_len] for datum in data]
    return torch.stack(data)

# 填充数据到最长长度的函数
@collate_one_or_multiple_tensors
def pad_to_longest_fn(data):
    return pad_sequence(data, batch_first = True)

# 获取 DataLoader 对象的函数
def get_dataloader(ds, pad_to_longest = True, **kwargs):
    collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate
    return DataLoader(ds, collate_fn = collate_fn, **kwargs)
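
Putting the dataset and dataloader helpers together, a minimal sketch (the folder path is a placeholder):

from voicebox_pytorch.data import AudioDataset, get_dataloader

ds = AudioDataset('/path/to/audio', audio_extension = '.flac')
dl = get_dataloader(ds, pad_to_longest = True, batch_size = 4, shuffle = True)

wave, = next(iter(dl))   # padded batch of raw waveforms, shape (4, max_len)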

.\lucidrains\voicebox-pytorch\voicebox_pytorch\optimizer.py

# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam

# 将参数分为需要权重衰减和不需要权重衰减的两个列表
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # 根据参数的维度判断是否需要权重衰减
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

# 根据参数设置创建优化器
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True
):
    # 判断是否需要权重衰减
    has_wd = wd > 0

    # 根据是否需要过滤梯度为零的参数来更新参数列表
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # 如果需要对参数进行分组并应用权重衰减
    if group_wd_params and has_wd:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        # 将参数分为需要权重衰减和不需要权重衰减的两组
        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    # 如果不需要权重衰减,则使用 Adam 优化器
    if not has_wd:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    # 如果需要权重衰减,则使用 AdamW 优化器
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
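
A short sketch of how the trainer below uses this helper; the module is an arbitrary stand-in:

import torch.nn as nn
from voicebox_pytorch.optimizer import get_optimizer

model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))

# AdamW, with weight decay applied only to parameters of ndim >= 2 (weights, not biases or norms)
optim = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)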

.\lucidrains\voicebox-pytorch\voicebox_pytorch\trainer.py

# 导入正则表达式模块
import re
# 从路径模块中导入 Path 类
from pathlib import Path
# 从 shutil 模块中导入 rmtree 函数
from shutil import rmtree
# 从 functools 模块中导入 partial 函数
from functools import partial
# 从 contextlib 模块中导入 nullcontext 上下文管理器
from contextlib import nullcontext

# 导入 beartype 模块中的 beartype 装饰器
from beartype import beartype

# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn 模块
from torch import nn
# 从 torch.optim.lr_scheduler 模块中导入 CosineAnnealingLR 类
from torch.optim.lr_scheduler import CosineAnnealingLR
# 从 torch.utils.data 模块中导入 Dataset 类和 random_split 函数
from torch.utils.data import Dataset, random_split

# 从 voicebox_pytorch.voicebox_pytorch 模块中导入 ConditionalFlowMatcherWrapper 类
from voicebox_pytorch.voicebox_pytorch import ConditionalFlowMatcherWrapper
# 从 voicebox_pytorch.data 模块中导入 get_dataloader 函数
from voicebox_pytorch.data import get_dataloader
# 从 voicebox_pytorch.optimizer 模块中导入 get_optimizer 函数

from voicebox_pytorch.optimizer import get_optimizer

# 从 accelerate 模块中导入 Accelerator 类和 DistributedType 类
from accelerate import Accelerator, DistributedType
# 从 accelerate.utils 模块中导入 DistributedDataParallelKwargs 类
from accelerate.utils import DistributedDataParallelKwargs

# helpers

# 定义一个函数,判断值是否存在
def exists(val):
    return val is not None

# 定义一个空函数,不做任何操作
def noop(*args, **kwargs):
    pass

# 定义一个循环生成器函数,用于循环遍历数据集
def cycle(dl):
    while True:
        for data in dl:
            yield data

# 定义一个函数,将输入转换为元组
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# 定义一个函数,询问用户是或否的问题
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# 定义一个函数,累积日志信息
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# 定义一个函数,从检查点文件名中获取训练步数
def checkpoint_num_steps(checkpoint_path):
    """Returns the number of steps trained from a checkpoint based on the filename.

    Filename format assumed to be something like "/path/to/voicebox.20000.pt" which is
    for 20k train steps. Returns 20000 in that case.
    """
    # 使用正则表达式查找文件名中的数字
    results = re.findall(r'\d+', str(checkpoint_path))

    # 如果没有找到数字,则返回 0
    if len(results) == 0:
        return 0

    # 返回最后一个找到的数字
    return int(results[-1])

# 定义一个 VoiceBoxTrainer 类,继承自 nn.Module
class VoiceBoxTrainer(nn.Module):
    # 使用 beartype 装饰器对初始化方法进行类型检查
    @beartype
    def __init__(
        self,
        cfm_wrapper: ConditionalFlowMatcherWrapper,
        *,
        batch_size,
        dataset: Dataset,
        num_train_steps = None,
        num_warmup_steps = None,
        num_epochs = None,
        lr = 3e-4,
        initial_lr = 1e-5,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        log_every = 10,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        force_clear_prev_results = None,
        split_batches = False,
        drop_last = False,
        accelerate_kwargs: dict = dict(),
        ):
        # 调用父类的构造函数
        super().__init__()

        # 设置分布式数据并行的参数
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)

        # 初始化加速器
        self.accelerator = Accelerator(
            kwargs_handlers = [ddp_kwargs],
            split_batches = split_batches,
            **accelerate_kwargs
        )

        # 设置模型包装器
        self.cfm_wrapper = cfm_wrapper

        # 注册缓冲区
        self.register_buffer('steps', torch.Tensor([0]))

        # 设置批量大小和梯度累积步数
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # 初始化优化器
        self.optim = get_optimizer(
            cfm_wrapper.parameters(),
            lr = lr,
            wd = wd
        )

        self.lr = lr
        self.initial_lr = initial_lr

        # 设置最大梯度范数
        self.max_grad_norm = max_grad_norm

        # 创建数据集
        self.ds = dataset

        # 划分验证集
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        assert exists(num_train_steps) or exists(num_epochs), 'either num_train_steps or num_epochs must be specified'

        if exists(num_epochs):
            self.num_train_steps = len(dataset) // batch_size * num_epochs
        else:
            self.num_train_steps = num_train_steps
        self.scheduler = CosineAnnealingLR(self.optim, T_max=self.num_train_steps)
        self.num_warmup_steps = num_warmup_steps if exists(num_warmup_steps) else 0
        
        # 初始化数据加载器
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        # 使用加速器准备模型、优化器、调度器和数据加载器
        (
            self.cfm_wrapper,
            self.optim,
            self.scheduler,
            self.dl
        ) = self.accelerator.prepare(
            self.cfm_wrapper,
            self.optim,
            self.scheduler,
            self.dl
        )

        # 初始化数据加载器迭代器
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        # 设置日志、保存模型和保存结果的频率
        self.log_every = log_every
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        # 设置结果文件夹路径
        self.results_folder = Path(results_folder)

        # 如果是主进程并且需要清除之前的结果,则清除结果文件夹
        if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        # 创建结果文件夹
        self.results_folder.mkdir(parents = True, exist_ok = True)
        
        # 设置超参数
        hps = {
            "num_train_steps": self.num_train_steps,
            "num_warmup_steps": self.num_warmup_steps,
            "learning_rate": self.lr,
            "initial_learning_rate": self.initial_lr,
            "wd": wd
        }
        # 初始化加速器的跟踪器
        self.accelerator.init_trackers("voicebox", config=hps)

    # 保存模型的方法
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.cfm_wrapper),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        # 保存模型参数、优化器状态和调度器状态到指定路径
        torch.save(pkg, path)
    # 加载模型参数和优化器状态
    def load(self, path):
        # 解封装模型
        cfm_wrapper = self.accelerator.unwrap_model(self.cfm_wrapper)
        # 加载模型参数
        pkg = cfm_wrapper.load(path)

        # 加载优化器状态
        self.optim.load_state_dict(pkg['optim'])
        # 加载调度器状态
        self.scheduler.load_state_dict(pkg['scheduler'])

        # 从下一步开始,避免覆盖最后一个检查点
        self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    # 打印消息
    def print(self, msg):
        self.accelerator.print(msg)

    # 生成结果
    def generate(self, *args, **kwargs):
        return self.cfm_wrapper.generate(*args, **kwargs)

    # 获取设备
    @property
    def device(self):
        return self.accelerator.device

    # 是否分布式
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # 是否主进程
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # 是否本地主进程
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # 热身
    def warmup(self, step):
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr
    
    # 训练步骤
    def train_step(self):
        steps = int(self.steps.item())

        self.cfm_wrapper.train()
        
        # 根据调度表调整学习率
        
        if steps < self.num_warmup_steps:
            # 应用热身

            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            # 热身期后,开始应用学习率退火
            
            self.scheduler.step()

        # 日志

        logs = {}

        # 训练步骤

        for grad_accum_step in range(self.grad_accum_every):
            is_last = grad_accum_step == (self.grad_accum_every - 1)
            context = partial(self.accelerator.no_sync, self.cfm_wrapper) if not is_last else nullcontext

            wave, = next(self.dl_iter)

            with self.accelerator.autocast(), context():
                loss = self.cfm_wrapper(wave)

                self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.cfm_wrapper.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        # 日志

        if not steps % self.log_every:
            self.print(f"{steps}: loss: {logs['loss']:0.3f}")

        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # 每隔一段时间采样结果

        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            wave, = next(self.valid_dl_iter)
            unwrapped_model = self.accelerator.unwrap_model(self.cfm_wrapper)

            with torch.inference_mode():
                unwrapped_model.eval()

                wave = wave.to(unwrapped_model.device)
                valid_loss = unwrapped_model(wave)

                self.print(f'{steps}: valid loss {valid_loss:0.3f}')
                self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # 每隔一段时间保存模型

        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'voicebox.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    # 训练
    def train(self, log_fn = noop):
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
        self.accelerator.end_training()
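
A minimal end-to-end training sketch using the trainer above, for unconditional training on raw audio; the hyperparameters and dataset path are illustrative assumptions.

from voicebox_pytorch import VoiceBox, EncodecVoco, ConditionalFlowMatcherWrapper
from voicebox_pytorch.trainer import VoiceBoxTrainer
from voicebox_pytorch.data import AudioDataset

model = VoiceBox(
    dim = 512,
    audio_enc_dec = EncodecVoco(),
    num_cond_tokens = 500,
    depth = 2,
    dim_head = 64,
    heads = 16,
    condition_on_text = False
)

cfm_wrapper = ConditionalFlowMatcherWrapper(voicebox = model)

trainer = VoiceBoxTrainer(
    cfm_wrapper = cfm_wrapper,
    dataset = AudioDataset('/path/to/audio'),
    batch_size = 4,
    num_train_steps = 10000,
    num_warmup_steps = 1000
)

trainer.train()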

.\lucidrains\voicebox-pytorch\voicebox_pytorch\voicebox_pytorch.py

import math
import logging
from random import random
from functools import partial
from pathlib import Path

import torch
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
from torch.nn import Module
import torch.nn.functional as F
from torch.cuda.amp import autocast

import torchode as to
from torchdiffeq import odeint

from beartype import beartype
from beartype.typing import Tuple, Optional, List, Union

from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack

from voicebox_pytorch.attend import Attend

from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss, BinLoss, maximum_path
from naturalspeech2_pytorch.utils.tokenizer import Tokenizer
from naturalspeech2_pytorch.naturalspeech2_pytorch import generate_mask_from_repeats

from audiolm_pytorch import EncodecWrapper
from spear_tts_pytorch import TextToSemantic

from gateloop_transformer import SimpleGateLoopLayer as GateLoop

import torchaudio.transforms as T
from torchaudio.functional import DB_to_amplitude, resample

from vocos import Vocos

LOGGER = logging.getLogger(__file__)

# helper functions

# 检查值是否存在
def exists(val):
    return val is not None

# 返回输入值
def identity(t):
    return t

# 返回输入值或默认值
def default(val, d):
    return val if exists(val) else d

# 检查是否可以被整除
def divisible_by(num, den):
    return (num % den) == 0

# 检查是否为奇数
def is_odd(n):
    return not divisible_by(n, 2)

# 随机返回 True 或 False
def coin_flip():
    return random() < 0.5

# 将张量打包成指定模式
def pack_one(t, pattern):
    return pack([t], pattern)

# 将打包的张量解包成指定模式
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# tensor helpers

# 根据概率生成掩码张量
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob

# 将多个掩码张量按位与操作
def reduce_masks_with_and(*masks):
    masks = [*filter(exists, masks)]

    if len(masks) == 0:
        return None

    mask, *rest_masks = masks

    for rest_mask in rest_masks:
        mask = mask & rest_mask

    return mask

# 对一维张量进行插值
def interpolate_1d(t, length, mode='bilinear'):
    " pytorch does not offer interpolation 1d, so hack by converting to 2d "

    dtype = t.dtype
    t = t.float()

    implicit_one_channel = t.ndim == 2
    if implicit_one_channel:
        t = rearrange(t, 'b n -> b 1 n')

    t = rearrange(t, 'b d n -> b d n 1')
    t = F.interpolate(t, (length, 1), mode=mode)
    t = rearrange(t, 'b d n 1 -> b d n')

    if implicit_one_channel:
        t = rearrange(t, 'b 1 n -> b n')

    t = t.to(dtype)
    return t

# 裁剪或填充张量至目标长度
def curtail_or_pad(t, target_length):
    length = t.shape[-2]

    if length > target_length:
        t = t[..., :target_length, :]
    elif length < target_length:
        t = F.pad(t, (0, 0, 0, target_length - length), value=0.)

    return t

# mask construction helpers

# 根据起始和结束索引生成掩码张量
def mask_from_start_end_indices(seq_len: int, start: Tensor, end: Tensor):
    assert start.shape == end.shape
    device = start.device

    seq = torch.arange(seq_len, device=device, dtype=torch.long)
    seq = seq.reshape(*((-1,) * start.ndim), seq_len)
    seq = seq.expand(*start.shape, seq_len)

    mask = seq >= start[..., None].long()
    mask &= seq < end[..., None].long()
    return mask

# 根据分数长度生成掩码张量
def mask_from_frac_lengths(seq_len: int, frac_lengths: Tensor):
    device = frac_lengths.device

    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    rand = torch.zeros_like(frac_lengths, device=device).float().uniform_(0, 1)
    start = (max_start * rand).clamp(min=0)
    end = start + lengths

    return mask_from_start_end_indices(seq_len, start, end)
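
For illustration, a small sketch of the two mask helpers above (values are assumptions): each row gets a random contiguous span of roughly the requested fraction set to True, which is how spans are selected for infilling-style training.

import torch

frac_lengths = torch.zeros(2).uniform_(0.4, 0.6)   # mask 40-60% of each sequence
mask = mask_from_frac_lengths(seq_len = 10, frac_lengths = frac_lengths)
# mask: bool tensor of shape (2, 10), True inside the sampled span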

# sinusoidal positions

# 用于 @crowsonkb 的学习正弦位置编码类
class LearnedSinusoidalPosEmb(Module):
    """ used by @crowsonkb """
    # 初始化函数,接受维度参数
    def __init__(self, dim):
        # 调用父类的初始化函数
        super().__init__()
        # 断言维度是2的倍数
        assert divisible_by(dim, 2)
        # 计算维度的一半
        half_dim = dim // 2
        # 初始化权重参数为服从标准正态分布的张量
        self.weights = nn.Parameter(torch.randn(half_dim))

    # 前向传播函数,接受输入张量 x
    def forward(self, x):
        # 重新排列输入张量 x 的维度,增加一个维度
        x = rearrange(x, 'b -> b 1')
        # 计算频率,乘以权重参数和 2π
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # 将正弦和余弦值拼接在一起,沿着最后一个维度
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        # 返回傅立叶变换后的张量
        return fouriered
# 旋转位置嵌入
# https://arxiv.org/abs/2104.09864

class RotaryEmbedding(Module):
    def __init__(self, dim, theta = 50000):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    @property
    def device(self):
        return self.inv_freq.device

    @autocast(enabled = False)
    @beartype
    def forward(self, t: Union[int, Tensor]):
        if not torch.is_tensor(t):
            t = torch.arange(t, device = self.device)

        t = t.type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()
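
A brief sketch of applying the rotary embedding to per-head queries or keys (shapes are assumptions):

import torch

rotary = RotaryEmbedding(dim = 64)

q = torch.randn(1, 8, 128, 64)               # (batch, heads, seq, dim_head)
freqs = rotary(128)                          # (seq, dim_head) rotation angles
q_rotated = apply_rotary_pos_emb(freqs, q)   # same shape, position encoded in the phase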

# 卷积位置生成模块

class ConvPositionEmbed(Module):
    def __init__(
        self,
        dim,
        *,
        kernel_size,
        groups = None
    ):
        super().__init__()
        assert is_odd(kernel_size)
        groups = default(groups, dim) # 默认情况下进行全深度卷积

        self.dw_conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
            nn.GELU()
        )

    def forward(self, x, mask = None):

        if exists(mask):
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.)

        x = rearrange(x, 'b n c -> b c n')
        x = self.dw_conv1d(x)
        out = rearrange(x, 'b c n -> b n c')

        if exists(mask):
            out = out.masked_fill(~mask, 0.)

        return out
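
A minimal sketch of the convolutional position module; in the model the output is typically added back to the sequence as a positional bias (shapes are assumptions):

import torch

conv_pos = ConvPositionEmbed(dim = 512, kernel_size = 31)   # kernel size must be odd

x = torch.randn(2, 1024, 512)      # (batch, frames, dim)
mask = torch.ones(2, 1024).bool()

x = conv_pos(x, mask = mask) + x   # (2, 1024, 512)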

# 规范化

class RMSNorm(Module):
    def __init__(
        self,
        dim
    ):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

class AdaptiveRMSNorm(Module):
    def __init__(
        self,
        dim,
        cond_dim = None
    ):
        super().__init__()
        cond_dim = default(cond_dim, dim)
        self.scale = dim ** 0.5

        self.to_gamma = nn.Linear(cond_dim, dim)
        self.to_beta = nn.Linear(cond_dim, dim)

        # 初始化为单位矩阵

        nn.init.zeros_(self.to_gamma.weight)
        nn.init.ones_(self.to_gamma.bias)

        nn.init.zeros_(self.to_beta.weight)
        nn.init.zeros_(self.to_beta.bias)

    def forward(self, x, *, cond):
        normed = F.normalize(x, dim = -1) * self.scale

        gamma, beta = self.to_gamma(cond), self.to_beta(cond)
        gamma, beta = map(lambda t: rearrange(t, 'b d -> b 1 d'), (gamma, beta))

        return normed * gamma + beta
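
A small sketch of the adaptive RMSNorm, conditioned on a per-sample embedding such as the flow-matching time step (shapes are assumptions). Since to_gamma is initialized to output ones and to_beta zeros, the layer starts out as a plain RMSNorm.

import torch

ada_norm = AdaptiveRMSNorm(dim = 512, cond_dim = 512)

x = torch.randn(2, 1024, 512)    # sequence features
t = torch.randn(2, 512)          # conditioning embedding, one vector per sample

normed = ada_norm(x, cond = t)   # (2, 1024, 512)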

# 注意力

class MultiheadRMSNorm(Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.gamma * self.scale

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0,
        flash = False,
        qk_norm = False,
        qk_norm_scale = 10
    ):
        super().__init__()
        self.heads = heads
        dim_inner = dim_head * heads

        scale = qk_norm_scale if qk_norm else None

        self.attend = Attend(dropout, flash = flash, scale = scale)

        self.qk_norm = qk_norm

        if qk_norm:
            self.q_norm = MultiheadRMSNorm(dim_head, heads = heads)
            self.k_norm = MultiheadRMSNorm(dim_head, heads = heads)

        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)
    # 定义一个前向传播函数,接受输入张量 x,掩码 mask 和旋转嵌入 rotary_emb
    def forward(self, x, mask = None, rotary_emb = None):
        # 获取头数
        h = self.heads

        # 将输入张量 x 分别映射为查询 q,键 k,值 v
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        # 将查询 q,键 k,值 v 重排维度,以适应多头注意力机制
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # 如果启用了查询和键的归一化
        if self.qk_norm:
            # 对查询 q 进行归一化
            q = self.q_norm(q)
            # 对键 k 进行归一化
            k = self.k_norm(k)

        # 如果存在旋转嵌入
        if exists(rotary_emb):
            # 对查询 q 和键 k 应用旋转位置嵌入
            q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))

        # 进行注意力计算,得到输出 out
        out = self.attend(q, k, v, mask = mask)

        # 重排输出 out 的维度,以适应后续全连接层
        out = rearrange(out, 'b h n d -> b n (h d)')
        # 将输出 out 传递给输出层,返回结果
        return self.to_out(out)
# 定义 GEGLU 类,用于实现 Gated GLU 激活函数
class GEGLU(Module):
    # 前向传播函数
    def forward(self, x):
        # 将输入张量 x 按照最后一个维度分成两部分,x 和 gate
        x, gate = x.chunk(2, dim = -1)
        # 对 gate 部分应用 GELU 激活函数,然后与 x 相乘
        return F.gelu(gate) * x

# 定义 FeedForward 函数,用于创建前馈神经网络层
def FeedForward(dim, mult = 4, dropout = 0.):
    # 计算内部维度
    dim_inner = int(dim * mult * 2 / 3)
    # 返回一个包含线性层、GEGLU 激活函数、Dropout 层和线性层的序列模块
    return nn.Sequential(
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim)
    )

# 定义 Transformer 类,用于实现 Transformer 模型
class Transformer(Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        num_register_tokens = 0.,
        attn_flash = False,
        adaptive_rmsnorm = False,
        adaptive_rmsnorm_cond_dim_in = None,
        use_unet_skip_connection = False,
        skip_connect_scale = None,
        attn_qk_norm = False,
        use_gateloop_layers = False,
        gateloop_use_jax = False,
    ):
        super().__init__()
        # 断言深度是偶数
        assert divisible_by(depth, 2)
        # 初始化层列表
        self.layers = nn.ModuleList([])

        # 创建旋转嵌入层
        self.rotary_emb = RotaryEmbedding(dim = dim_head)

        # 设置注册令牌数量
        self.num_register_tokens = num_register_tokens
        self.has_register_tokens = num_register_tokens > 0

        # 如果存在注册令牌,则创建注册令牌参数
        if self.has_register_tokens:
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        # 根据是否自适应 RMSNorm 选择不同的 RMSNorm 类
        if adaptive_rmsnorm:
            rmsnorm_klass = partial(AdaptiveRMSNorm, cond_dim = adaptive_rmsnorm_cond_dim_in)
        else:
            rmsnorm_klass = RMSNorm

        # 设置跳跃连接的缩放因子
        self.skip_connect_scale = default(skip_connect_scale, 2 ** -0.5)

        # 循环创建 Transformer 层
        for ind in range(depth):
            layer = ind + 1
            has_skip = use_unet_skip_connection and layer > (depth // 2)

            self.layers.append(nn.ModuleList([
                nn.Linear(dim * 2, dim) if has_skip else None,
                GateLoop(dim = dim, use_jax_associative_scan = gateloop_use_jax, post_ln = True) if use_gateloop_layers else None,
                rmsnorm_klass(dim = dim),
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash, qk_norm = attn_qk_norm),
                rmsnorm_klass(dim = dim),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        # 创建最终的 RMSNorm 层
        self.final_norm = RMSNorm(dim)

    # 获取设备信息
    @property
    def device(self):
        return next(self.parameters()).device

    # 前向传播函数
    def forward(
        self,
        x,
        mask = None,
        adaptive_rmsnorm_cond = None
        ):
            # 获取输入张量的批量大小、序列长度等信息
            batch, seq_len, *_ = x.shape

            # 在左侧添加注册令牌

            if self.has_register_tokens:
                # 重复注册令牌以匹配批量大小
                register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)

                # 将注册令牌和输入张量打包
                x, ps = pack([register_tokens, x], 'b * d')

                # 如果存在掩码,则在左侧填充
                if exists(mask):
                    mask = F.pad(mask, (self.num_register_tokens, 0), value = True)

            # 跟踪跳跃连接

            skip_connects = []

            # 旋转嵌入

            positions = seq_len

            if self.has_register_tokens:
                # 创建主要位置和注册位置
                main_positions = torch.arange(seq_len, device = self.device, dtype = torch.long)
                register_positions = torch.full((self.num_register_tokens,), -10000, device = self.device, dtype = torch.long)
                positions = torch.cat((register_positions, main_positions))

            # 计算旋转嵌入
            rotary_emb = self.rotary_emb(positions)

            # 自适应 RMSNorm

            rmsnorm_kwargs = dict()
            if exists(adaptive_rmsnorm_cond):
                rmsnorm_kwargs = dict(cond = adaptive_rmsnorm_cond)

            # 通过注意力层

            for skip_combiner, maybe_gateloop, attn_prenorm, attn, ff_prenorm, ff in self.layers:

                # 在论文中,他们使用类似 U-Net 的跳跃连接
                # 不清楚这有多大帮助,因为除了简短的一两句提到外,没有给出任何消融或进一步的数字

                if not exists(skip_combiner):
                    skip_connects.append(x)
                else:
                    skip_connect = skip_connects.pop() * self.skip_connect_scale
                    x = torch.cat((x, skip_connect), dim = -1)
                    x = skip_combiner(x)

                if exists(maybe_gateloop):
                    x = maybe_gateloop(x) + x

                # 计算注意力输入
                attn_input = attn_prenorm(x, **rmsnorm_kwargs)
                x = attn(attn_input, mask = mask, rotary_emb = rotary_emb) + x

                # 计算前馈神经网络输入
                ff_input = ff_prenorm(x, **rmsnorm_kwargs) 
                x = ff(ff_input) + x

            # 移除注册令牌

            if self.has_register_tokens:
                _, x = unpack(x, ps, 'b * d')

            # 返回最终规范化结果
            return self.final_norm(x)
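
A minimal sketch of running this Transformer backbone on its own (hyperparameters are illustrative; depth must be even):

import torch

transformer = Transformer(dim = 512, depth = 2, dim_head = 64, heads = 8)

x = torch.randn(2, 1024, 512)
mask = torch.ones(2, 1024).bool()

out = transformer(x, mask = mask)   # (2, 1024, 512)
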
# 定义音频编码器解码器的基类
class AudioEncoderDecoder(nn.Module):
    pass

# 定义 MelVoco 类,继承自 AudioEncoderDecoder
class MelVoco(AudioEncoderDecoder):
    def __init__(
        self,
        *,
        log = True,
        n_mels = 100,
        sampling_rate = 24000,
        f_max = 8000,
        n_fft = 1024,
        win_length = 640,
        hop_length = 160,
        pretrained_vocos_path = 'charactr/vocos-mel-24khz'
    ):
        super().__init__()
        self.log = log
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.f_max = f_max
        self.win_length = win_length
        self.hop_length = hop_length
        self.sampling_rate = sampling_rate

        # 加载预训练的 Vocos 模型
        self.vocos = Vocos.from_pretrained(pretrained_vocos_path)

    @property
    def downsample_factor(self):
        raise NotImplementedError

    @property
    def latent_dim(self):
        return self.n_mels

    # 对音频进行编码
    def encode(self, audio):
        # 对音频进行短时傅里叶变换
        stft_transform = T.Spectrogram(
            n_fft = self.n_fft,
            win_length = self.win_length,
            hop_length = self.hop_length,
            window_fn = torch.hann_window
        )

        spectrogram = stft_transform(audio)

        # 对频谱图进行梅尔频谱变换
        mel_transform = T.MelScale(
            n_mels = self.n_mels,
            sample_rate = self.sampling_rate,
            n_stft = self.n_fft // 2 + 1,
            f_max = self.f_max
        )

        mel = mel_transform(spectrogram)

        # 如果需要对梅尔频谱进行对数变换
        if self.log:
            mel = T.AmplitudeToDB()(mel)

        mel = rearrange(mel, 'b d n -> b n d')
        return mel

    # 对梅尔频谱进行解码
    def decode(self, mel):
        mel = rearrange(mel, 'b n d -> b d n')

        # 如果需要对梅尔频谱进行反对数变换
        if self.log:
            mel = DB_to_amplitude(mel, ref = 1., power = 0.5)

        return self.vocos.decode(mel)
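
A hypothetical round trip through MelVoco, assuming the `vocos` package is installed and the pretrained 'charactr/vocos-mel-24khz' checkpoint can be downloaded; the waveform below is random noise standing in for one second of 24 kHz speech:

import torch

mel_voco = MelVoco()                 # defaults: 100 mel bands, 24 kHz, hop length 160

audio = torch.randn(1, 24000)        # (batch, samples)

latents = mel_voco.encode(audio)     # (batch, frames, n_mels) log-mel latents
recon = mel_voco.decode(latents)     # waveform reconstructed by the Vocos vocoder

print(latents.shape, recon.shape)
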

# 定义 EncodecVoco 类,继承自 AudioEncoderDecoder
class EncodecVoco(AudioEncoderDecoder):
    def __init__(
        self,
        *,
        sampling_rate = 24000,
        pretrained_vocos_path = 'charactr/vocos-encodec-24khz',
        bandwidth_id = 2
    ):
        super().__init__()
        self.sampling_rate = sampling_rate
        self.encodec = EncodecWrapper()
        # 加载预训练的 Vocos 模型
        self.vocos = Vocos.from_pretrained(pretrained_vocos_path)

        # 注册缓冲区,存储带宽 ID
        self.register_buffer('bandwidth_id', torch.tensor([bandwidth_id]))

    @property
    def downsample_factor(self):
        return self.encodec.downsample_factor

    @property
    def latent_dim(self):
        return self.encodec.codebook_dim

    # 对音频进行编码
    def encode(self, audio):
        encoded_audio, _, _ = self.encodec(audio, return_encoded = True)
        return encoded_audio

    # 解码为编码
    def decode_to_codes(self, latents):
        _, codes, _ = self.encodec.rq(latents)
        codes = rearrange(codes, 'b n q -> b q n')
        return codes

    # 解码编码为音频
    def decode(self, latents):
        codes = self.decode_to_codes(latents)

        all_audios = []
        for code in codes:
            features = self.vocos.codes_to_features(code)
            audio = self.vocos.decode(features, bandwidth_id = self.bandwidth_id)
            all_audios.append(audio)

        return torch.stack(all_audios)
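
A similar hypothetical round trip through EncodecVoco, which additionally needs the EncodecWrapper from audiolm-pytorch and the 'charactr/vocos-encodec-24khz' checkpoint at runtime; again the input is random noise in place of real 24 kHz audio:

import torch

codec = EncodecVoco()

audio = torch.randn(1, 24000)

latents = codec.encode(audio)        # (batch, frames, codebook_dim) continuous latents
recon = codec.decode(latents)        # re-quantized to codes, then vocoded per sample

print(latents.shape, recon.shape)
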

# 定义 DurationPredictor 类,继承自 Module
class DurationPredictor(Module):
    @beartype
    def __init__(
        self,
        *,
        audio_enc_dec: Optional[AudioEncoderDecoder] = None,
        tokenizer: Optional[Tokenizer] = None,
        num_phoneme_tokens: Optional[int] = None,
        dim_phoneme_emb = 512,
        dim = 512,
        depth = 10,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        ff_dropout = 0.,
        conv_pos_embed_kernel_size = 31,
        conv_pos_embed_groups = None,
        attn_dropout = 0,
        attn_flash = False,
        attn_qk_norm = True,
        use_gateloop_layers = False,
        p_drop_prob = 0.2, # p_drop in paper
        frac_lengths_mask: Tuple[float, float] = (0.1, 1.),
        aligner_kwargs: dict = dict(dim_in = 80, attn_channels = 80)
    ):
        # 调用父类的构造函数
        super().__init__()

        # 音频编码器/解码器
        self.audio_enc_dec = audio_enc_dec

        # 如果音频编码器/解码器存在且维度不等于音频编码器/解码器的潜在维度,则创建输入投影层
        if exists(audio_enc_dec) and dim != audio_enc_dec.latent_dim:
            self.proj_in = nn.Linear(audio_enc_dec.latent_dim, dim)
        else:
            self.proj_in = nn.Identity()

        # 与音素相关

        # 如果传入了音素标记器和音素标记数,则抛出断言错误
        assert not (exists(tokenizer) and exists(num_phoneme_tokens)), 'if a phoneme tokenizer was passed into duration module, number of phoneme tokens does not need to be specified'

        # 如果音素标记器和音素标记数都不存在,则默认使用英语音素和 espeak 创建标记器
        if not exists(tokenizer) and not exists(num_phoneme_tokens):
            tokenizer = Tokenizer()

        # 如果存在音素标记器,则设置音素标记数为标记器的词汇量大小
        if exists(tokenizer):
            num_phoneme_tokens = tokenizer.vocab_size

        self.tokenizer = tokenizer

        # 创建音素嵌入层
        self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens, dim_phoneme_emb)

        self.p_drop_prob = p_drop_prob
        self.frac_lengths_mask = frac_lengths_mask

        # 创建线性层,用于将音频编码器/解码器输出和音素嵌入层输出连接起来
        self.to_embed = nn.Linear(dim + dim_phoneme_emb, dim)

        # 创建空条件参数
        self.null_cond = nn.Parameter(torch.zeros(dim), requires_grad = False)

        # 创建卷积位置嵌入层
        self.conv_embed = ConvPositionEmbed(
            dim = dim,
            kernel_size = conv_pos_embed_kernel_size,
            groups = conv_pos_embed_groups
        )

        # 创建 Transformer 模型
        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout,
            attn_dropout=attn_dropout,
            attn_flash = attn_flash,
            attn_qk_norm = attn_qk_norm,
            use_gateloop_layers = use_gateloop_layers
        )

        # 创建预测层
        self.to_pred = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... 1 -> ...')
        )

        # 对齐器相关

        # 如果使用具有 80 个通道的 mel 频谱,则将 attn_channels 设置为 80
        # 假设输入维度为具有 80 个通道的 spec
        self.aligner = Aligner(dim_hidden = dim_phoneme_emb, **aligner_kwargs)
        self.align_loss = ForwardSumLoss()

    @property
    def device(self):
        # 返回模型参数所在的设备
        return next(self.parameters()).device

    def align_phoneme_ids_with_durations(self, phoneme_ids, durations):
        # 生成重复掩码
        repeat_mask = generate_mask_from_repeats(durations.clamp(min = 1))
        # 将音素标记与持续时间对齐
        aligned_phoneme_ids = einsum('b i, b i j -> b j', phoneme_ids.float(), repeat_mask.float()).long()
        return aligned_phoneme_ids
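
A standalone toy illustration (separate from the class above) of what align_phoneme_ids_with_durations computes: each phoneme id is repeated according to its integer duration, giving a frame-level phoneme sequence. The module does this in a batched way with a repeat mask and an einsum; for a single example the effect matches repeat_interleave:

import torch

phoneme_ids = torch.tensor([7, 3, 9])        # three phonemes
durations = torch.tensor([2, 1, 3])          # frames predicted for each phoneme

aligned = torch.repeat_interleave(phoneme_ids, durations)
print(aligned)                               # tensor([7, 7, 3, 9, 9, 9])
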

    @torch.inference_mode()
    @beartype
    def forward_with_cond_scale(
        self,
        *args,
        texts: Optional[List[str]] = None,
        phoneme_ids = None,
        cond_scale = 1.,
        return_aligned_phoneme_ids = False,
        **kwargs
    ):
        if exists(texts):
            phoneme_ids = self.tokenizer.texts_to_tensor_ids(texts)

        forward_kwargs = dict(
            return_aligned_phoneme_ids = False,
            phoneme_ids = phoneme_ids
        )

        durations = self.forward(*args, cond_drop_prob = 0., **forward_kwargs, **kwargs)

        if cond_scale == 1.:
            if not return_aligned_phoneme_ids:
                return durations

            return durations, self.align_phoneme_ids_with_durations(phoneme_ids, durations)

        null_durations = self.forward(*args, cond_drop_prob = 1., **forward_kwargs, **kwargs)
        scaled_durations = null_durations + (durations - null_durations) * cond_scale

        if not return_aligned_phoneme_ids:
            return scaled_durations

        return scaled_durations, self.align_phoneme_ids_with_durations(phoneme_ids, scaled_durations)

    @beartype
    def forward_aligner(
        self,
        x: FloatTensor,     # (b, t, c)
        x_mask: IntTensor,  # (b, 1, t)
        y: FloatTensor,     # (b, t, c)
        y_mask: IntTensor   # (b, 1, t)
    # 定义函数的返回类型为元组,包含四个张量
    ) -> Tuple[
        FloatTensor,        # alignment_hard: (b, t)
        FloatTensor,        # alignment_soft: (b, tx, ty)
        FloatTensor,        # alignment_logprob: (b, 1, ty, tx)
        BoolTensor          # alignment_mas: (b, tx, ty)
    ]:
        # 创建注意力掩码,用于限制注意力的计算范围
        attn_mask = rearrange(x_mask, 'b 1 t -> b 1 t 1') * rearrange(y_mask, 'b 1 t -> b 1 1 t')
        # 调用aligner模型计算软对齐和对数概率
        alignment_soft, alignment_logprob = self.aligner(rearrange(y, 'b t c -> b c t'), x, x_mask)

        # 断言软对齐张量中不包含NaN值
        assert not torch.isnan(alignment_soft).any()

        # 使用最大路径算法计算最佳对齐路径
        alignment_mas = maximum_path(
            rearrange(alignment_soft, 'b 1 t1 t2 -> b t2 t1').contiguous(),
            rearrange(attn_mask, 'b 1 t1 t2 -> b t1 t2').contiguous()
        )

        # 计算硬对齐张量
        alignment_hard = torch.sum(alignment_mas, -1).float()
        # 重新排列软对齐张量的维度
        alignment_soft = rearrange(alignment_soft, 'b 1 t1 t2 -> b t2 t1')
        # 返回硬对齐、软对齐、对数概率和对齐掩码
        return alignment_hard, alignment_soft, alignment_logprob, alignment_mas
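
A standalone toy illustration of the last step above: maximum_path returns a monotonic binary alignment (phonemes x mel frames) that assigns every frame to exactly one phoneme, so summing it along the frame axis yields the integer duration of each phoneme. Shapes here are made up for the example:

import torch

alignment_mas = torch.tensor([
    [1, 1, 0, 0, 0, 0],    # phoneme 0 covers frames 0-1
    [0, 0, 1, 0, 0, 0],    # phoneme 1 covers frame 2
    [0, 0, 0, 1, 1, 1],    # phoneme 2 covers frames 3-5
]).unsqueeze(0)             # add a batch dimension -> (1, phonemes, frames)

alignment_hard = alignment_mas.sum(dim = -1).float()
print(alignment_hard)       # tensor([[2., 1., 3.]])
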

    # 定义前向传播函数,接受多个参数
    @beartype
    def forward(
        self,
        *,
        cond,
        texts: Optional[List[str]] = None,
        phoneme_ids = None,
        cond_drop_prob = 0.,
        target = None,
        cond_mask = None,
        mel = None,
        phoneme_len = None,
        mel_len = None,
        phoneme_mask = None,
        mel_mask = None,
        self_attn_mask = None,
        return_aligned_phoneme_ids = False
    ):
        # 获取输入的 batch 大小、序列长度和条件维度
        batch, seq_len, cond_dim = cond.shape

        # 对条件进行投影
        cond = self.proj_in(cond)

        # 如果未提供音素 id,则使用分词器将文本转换为音素 id
        if not exists(phoneme_ids):
            assert exists(self.tokenizer)
            phoneme_ids = self.tokenizer.texts_to_tensor_ids(texts)

        # 如果未提供条件掩码,则根据条件生成掩码
        if not exists(cond_mask):
            if coin_flip():
                frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
                cond_mask = mask_from_frac_lengths(seq_len, frac_lengths)
            else:
                cond_mask = prob_mask_like((batch, seq_len), self.p_drop_prob, self.device)

        # 根据条件掩码对条件进行掩码处理
        cond = cond * rearrange(~cond_mask, '... -> ... 1')

        # 如果条件丢弃概率大于 0,则对条件进行丢弃处理
        if cond_drop_prob > 0.:
            cond_drop_mask = prob_mask_like(cond.shape[:1], cond_drop_prob, cond.device)

            cond = torch.where(
                rearrange(cond_drop_mask, '... -> ... 1 1'),
                self.null_cond,
                cond
            )

        # 音素 id 为 -1 表示填充
        if not exists(self_attn_mask):
            self_attn_mask = phoneme_ids != -1

        # 将音素 id 限制在大于等于 0 的范围内
        phoneme_ids = phoneme_ids.clamp(min=0)

        # 获取音素嵌入
        phoneme_emb = self.to_phoneme_emb(phoneme_ids)

        # 强制条件与输入音素具有相同的长度
        cond = curtail_or_pad(cond, phoneme_ids.shape[-1])

        # 合并音素嵌入、条件
        embed = torch.cat((phoneme_emb, cond), dim=-1)
        x = self.to_embed(embed)

        # 进行卷积嵌入
        x = self.conv_embed(x, mask=self_attn_mask) + x

        # 进行 transformer 操作
        x = self.transformer(
            x,
            mask=self_attn_mask
        )

        # 预测持续时间
        durations = self.to_pred(x)

        # 如果不是训练阶段,则返回持续时间
        if not self.training:
            if not return_aligned_phoneme_ids:
                return durations

            return durations, self.align_phoneme_ids_with_durations(phoneme_ids, durations)

        # aligner
        # use alignment_hard to oversample phonemes
        # the duration predictor should predict the durations of the unmasked phonemes, with the masked alignment_hard as the target
        assert all([exists(el) for el in (phoneme_len, mel_len, phoneme_mask, mel_mask)]), 'phoneme_len, mel_len, phoneme_mask and mel_mask must be passed in to train the duration predictor module'

        alignment_hard, _, alignment_logprob, _ = self.forward_aligner(phoneme_emb, phoneme_mask, mel, mel_mask)
        target = alignment_hard

        if exists(self_attn_mask):
            loss_mask = cond_mask & self_attn_mask
        else:
            loss_mask = self_attn_mask

        if not exists(loss_mask):
            return F.l1_loss(x, target)

        loss = F.l1_loss(x, target, reduction='none')
        loss = loss.masked_fill(~loss_mask, 0.)

        # 掩码平均值
        num = reduce(loss, 'b n -> b', 'sum')
        den = loss_mask.sum(dim=-1).clamp(min=1e-5)
        loss = num / den
        loss = loss.mean()
        
        if not return_aligned_phoneme_ids:
            return loss

        # 对齐器损失
        align_loss = self.align_loss(alignment_logprob, phoneme_len, mel_len)
        loss = loss + align_loss

        return loss
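
A minimal inference-time sketch of DurationPredictor. The vocabulary size below is made up so the default espeak-based Tokenizer is not constructed; the full training call would additionally require mel, phoneme_len, mel_len, phoneme_mask and mel_mask for the aligner loss:

import torch

duration_predictor = DurationPredictor(
    num_phoneme_tokens = 256,   # hypothetical phoneme vocabulary size
    dim = 512,
    depth = 2                   # shallow, just to keep the example light
)
duration_predictor.eval()

cond = torch.randn(1, 128, 512)                # acoustic conditioning, (batch, frames, dim)
phoneme_ids = torch.randint(0, 256, (1, 32))   # (batch, phonemes); -1 would mark padding

with torch.no_grad():
    durations = duration_predictor(cond = cond, phoneme_ids = phoneme_ids)

print(durations.shape)                         # (1, 32), one predicted duration per phoneme
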
# VoiceBox 类,继承自 Module 类
class VoiceBox(Module):
    # 初始化方法
    def __init__(
        self,
        *,
        num_cond_tokens = None, # 条件标记数量,默认为 None
        audio_enc_dec: Optional[AudioEncoderDecoder] = None, # 音频编码器解码器,默认为 None
        dim_in = None, # 输入维度,默认为 None
        dim_cond_emb = 1024, # 条件嵌入维度,默认为 1024
        dim = 1024, # 维度,默认为 1024
        depth = 24, # 深度,默认为 24
        dim_head = 64, # 头维度,默认为 64
        heads = 16, # 头数,默认为 16
        ff_mult = 4, # FeedForward 层倍数,默认为 4
        ff_dropout = 0., # FeedForward 层的 dropout,默认为 0
        time_hidden_dim = None, # 时间隐藏维度,默认为 None
        conv_pos_embed_kernel_size = 31, # 卷积位置嵌入的卷积核大小,默认为 31
        conv_pos_embed_groups = None, # 卷积位置嵌入的分组,默认为 None
        attn_dropout = 0, # 注意力 dropout,默认为 0
        attn_flash = False, # 是否使用 Flash 注意力,默认为 False
        attn_qk_norm = True, # 注意力的 QK 归一化,默认为 True
        use_gateloop_layers = False, # 是否使用 Gateloop 层,默认为 False
        num_register_tokens = 16, # 寄存器标记数量,默认为 16
        p_drop_prob = 0.3, # p_drop 在论文中的概率,默认为 0.3
        frac_lengths_mask: Tuple[float, float] = (0.7, 1.), # 长度掩码的分数,默认为 (0.7, 1)
        condition_on_text = True # 是否基于文本条件,默认为 True
    ):
        super().__init__() # 调用父类的初始化方法
        dim_in = default(dim_in, dim) # 如果输入维度为 None,则使用默认维度

        time_hidden_dim = default(time_hidden_dim, dim * 4) # 如果时间隐藏维度为 None,则使用默认维度

        self.audio_enc_dec = audio_enc_dec # 设置音频编码器解码器

        if exists(audio_enc_dec) and dim != audio_enc_dec.latent_dim: # 如果音频编码器解码器存在且维度不等于潜在维度
            self.proj_in = nn.Linear(audio_enc_dec.latent_dim, dim) # 使用线性层进行投影
        else:
            self.proj_in = nn.Identity() # 否则使用恒等映射

        # 正弦位置嵌入
        self.sinu_pos_emb = nn.Sequential(
            LearnedSinusoidalPosEmb(dim), # 学习的正弦位置嵌入
            nn.Linear(dim, time_hidden_dim), # 线性层
            nn.SiLU() # SiLU 激活函数
        )

        assert not (condition_on_text and not exists(num_cond_tokens)), 'number of conditioning tokens must be specified (whether phonemes or semantic token ids) if training conditional voicebox'

        if not condition_on_text: # 如果不基于文本条件
            dim_cond_emb = 0 # 条件嵌入维度为 0

        self.dim_cond_emb = dim_cond_emb # 设置条件嵌入维度
        self.condition_on_text = condition_on_text # 设置是否基于文本条件
        self.num_cond_tokens = num_cond_tokens # 设置条件标记数量

        if condition_on_text: # 如果基于文本条件
            self.null_cond_id = num_cond_tokens # 使用最后一个音素标记作为 CFG 的空标记
            self.to_cond_emb = nn.Embedding(num_cond_tokens + 1, dim_cond_emb) # 条件嵌入层

        self.p_drop_prob = p_drop_prob # 设置 p_drop 概率
        self.frac_lengths_mask = frac_lengths_mask # 设置长度掩码

        self.to_embed = nn.Linear(dim_in * 2 + dim_cond_emb, dim) # 输入到嵌入的线性层

        self.null_cond = nn.Parameter(torch.zeros(dim_in), requires_grad = False) # 空条件参数

        self.conv_embed = ConvPositionEmbed(
            dim = dim,
            kernel_size = conv_pos_embed_kernel_size,
            groups = conv_pos_embed_groups
        ) # 卷积位置嵌入层

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout,
            attn_dropout= attn_dropout,
            attn_flash = attn_flash,
            attn_qk_norm = attn_qk_norm,
            num_register_tokens = num_register_tokens,
            adaptive_rmsnorm = True,
            adaptive_rmsnorm_cond_dim_in = time_hidden_dim,
            use_gateloop_layers = use_gateloop_layers
        ) # Transformer 模型

        dim_out = audio_enc_dec.latent_dim if exists(audio_enc_dec) else dim_in # 输出维度

        self.to_pred = nn.Linear(dim, dim_out, bias = False) # 预测线性层

    @property
    def device(self):
        return next(self.parameters()).device # 返回参数的设备

    @torch.inference_mode()
    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs) # 前向传播计算 logits

        if cond_scale == 1.: # 如果条件缩放为 1
            return logits # 返回 logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) # 使用条件概率为 1 计算 logits
        return null_logits + (logits - null_logits) * cond_scale # 返回缩放后的结果
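
The guidance step above is a plain linear extrapolation between the unconditional and conditional predictions. A tiny numeric sketch with made-up values:

import torch

cond_out = torch.tensor([1.0, 2.0, 3.0])    # prediction with conditioning
null_out = torch.tensor([0.5, 0.5, 0.5])    # prediction with the condition dropped
cond_scale = 2.0                            # > 1 pushes further toward the conditional output

guided = null_out + (cond_out - null_out) * cond_scale
print(guided)                               # tensor([1.5000, 3.5000, 5.5000])
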

    def forward(
        self,
        x,
        *,
        times,
        cond_token_ids,
        self_attn_mask = None,
        cond_drop_prob = 0.1,
        target = None,
        cond = None,
        cond_mask = None
        ):
            # project the input, in case the codebook dimension does not equal the model dimension

            x = self.proj_in(x)

            cond = default(cond, target)

            if exists(cond):
                cond = self.proj_in(cond)

            # 获取形状信息

            batch, seq_len, cond_dim = cond.shape
            assert cond_dim == x.shape[-1]

            # 自动管理时间维度的形状,用于odeint times

            if times.ndim == 0:
                times = repeat(times, '-> b', b = cond.shape[0])

            if times.ndim == 1 and times.shape[0] == 1:
                times = repeat(times, '1 -> b', b = cond.shape[0])

            # 如果未提供条件掩码,则构建条件掩码

            if self.training:
                if not exists(cond_mask):
                    frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
                    cond_mask = mask_from_frac_lengths(seq_len, frac_lengths)
            else:
                if not exists(cond_mask):
                    cond_mask = torch.ones((batch, seq_len), device = cond.device, dtype = torch.bool)

            cond_mask_with_pad_dim = rearrange(cond_mask, '... -> ... 1')

            # 如第3.2节所述

            cond = cond * ~cond_mask_with_pad_dim

            # 无分类器指导

            cond_ids = cond_token_ids

            if cond_drop_prob > 0.:
                cond_drop_mask = prob_mask_like(cond.shape[:1], cond_drop_prob, self.device)

                cond = torch.where(
                    rearrange(cond_drop_mask, '... -> ... 1 1'),
                    self.null_cond,
                    cond
                )

                cond_ids = torch.where(
                    rearrange(cond_drop_mask, '... -> ... 1'),
                    self.null_cond_id,
                    cond_token_ids
                )

            # 音素或语义条件嵌入

            cond_emb = None

            if self.condition_on_text:
                cond_emb = self.to_cond_emb(cond_ids)

                cond_emb_length = cond_emb.shape[-2]
                if cond_emb_length != seq_len:
                    cond_emb = rearrange(cond_emb, 'b n d -> b d n')
                    cond_emb = interpolate_1d(cond_emb, seq_len)
                    cond_emb = rearrange(cond_emb, 'b d n -> b n d')

                    if exists(self_attn_mask):
                        self_attn_mask = interpolate_1d(self_attn_mask, seq_len)

            # 连接源信号、语义/音素条件嵌入和条件,并进行投影

            to_concat = [*filter(exists, (x, cond_emb, cond))]
            embed = torch.cat(to_concat, dim = -1)

            x = self.to_embed(embed)

            x = self.conv_embed(x, mask = self_attn_mask) + x

            time_emb = self.sinu_pos_emb(times)

            # 注意力

            x = self.transformer(
                x,
                mask = self_attn_mask,
                adaptive_rmsnorm_cond = time_emb
            )

            x = self.to_pred(x)

            # 如果未传入目标,则只返回对数

            if not exists(target):
                return x

            loss_mask = reduce_masks_with_and(cond_mask, self_attn_mask)

            if not exists(loss_mask):
                return F.mse_loss(x, target)

            loss = F.mse_loss(x, target, reduction = 'none')

            loss = reduce(loss, 'b n d -> b n', 'mean')
            loss = loss.masked_fill(~loss_mask, 0.)

            # 掩码均值

            num = reduce(loss, 'b n -> b', 'sum')
            den = loss_mask.sum(dim = -1).clamp(min = 1e-5)
            loss = num / den

            return loss.mean()
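
A minimal sketch of calling VoiceBox directly (normally it is driven through the ConditionalFlowMatcherWrapper below); all sizes are made up and no pretrained weights are involved:

import torch

model = VoiceBox(
    dim = 512,
    dim_cond_emb = 512,
    num_cond_tokens = 256,    # e.g. phoneme or semantic vocabulary size
    depth = 2,
    heads = 8
)
model.eval()

x = torch.randn(1, 128, 512)                      # noised input latents (batch, frames, dim_in)
cond = torch.randn(1, 128, 512)                   # acoustic prompt latents
cond_token_ids = torch.randint(0, 256, (1, 128))  # frame-aligned conditioning token ids
times = torch.rand(1)                             # flow time in [0, 1)

with torch.no_grad():
    pred_flow = model(x, times = times, cond = cond, cond_token_ids = cond_token_ids, cond_drop_prob = 0.)

print(pred_flow.shape)                            # (1, 128, 512)
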
# 对 CNF 的包装器

# 判断输入是否可能是音频数据,根据其形状来判断
def is_probably_audio_from_shape(t):
    return exists(t) and (t.ndim == 2 or (t.ndim == 3 and t.shape[1] == 1))

# 条件流匹配器的包装器类
class ConditionalFlowMatcherWrapper(Module):
    # 初始化方法
    @beartype
    def __init__(
        self,
        voicebox: VoiceBox,
        text_to_semantic: Optional[TextToSemantic] = None,
        duration_predictor: Optional[DurationPredictor] = None,
        sigma = 0.,
        ode_atol = 1e-5,
        ode_rtol = 1e-5,
        use_torchode = False,
        torchdiffeq_ode_method = 'midpoint',   # 使用中点法作为 torchdiffeq 的方法,与论文中一致
        torchode_method_klass = to.Tsit5,      # 使用 tsit5 作为 torchode 的方法,因为 torchode 没有中点法(由 Bryan @b-chiang 推荐)
        cond_drop_prob = 0.
    ):
        super().__init__()
        self.sigma = sigma

        self.voicebox = voicebox
        self.condition_on_text = voicebox.condition_on_text

        # 断言条件,确保不在不条件下使用 TextToSemantic
        assert not (not self.condition_on_text and exists(text_to_semantic)), 'TextToSemantic should not be passed in if not conditioning on text'
        # 断言条件,确保在使用 TextToSemantic 时存在 wav2vec 模块
        assert not (exists(text_to_semantic) and not exists(text_to_semantic.wav2vec)), 'the wav2vec module must exist on the TextToSemantic, if being used to condition on text'

        self.text_to_semantic = text_to_semantic
        self.duration_predictor = duration_predictor

        # 断言条件,确保在条件下使用 TextToSemantic 或 DurationPredictor
        if self.condition_on_text and (exists(text_to_semantic) or exists(duration_predictor)):
            assert exists(text_to_semantic) ^ exists(duration_predictor), 'you should use either TextToSemantic from Spear-TTS, or DurationPredictor for the text / phoneme to audio alignment, but not both'

        self.cond_drop_prob = cond_drop_prob

        self.use_torchode = use_torchode
        self.torchode_method_klass = torchode_method_klass

        self.odeint_kwargs = dict(
            atol = ode_atol,
            rtol = ode_rtol,
            method = torchdiffeq_ode_method
        )

    # 获取设备信息
    @property
    def device(self):
        return next(self.parameters()).device

    # 加载模型
    def load(self, path, strict = True):
        # 返回 pkg 以便训练器可以访问
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')
        self.load_state_dict(pkg['model'], strict = strict)
        return pkg

    # 采样方法
    @torch.inference_mode()
    def sample(
        self,
        *,
        cond = None,
        texts: Optional[List[str]] = None,
        text_token_ids: Optional[Tensor] = None,
        semantic_token_ids: Optional[Tensor] = None,
        phoneme_ids: Optional[Tensor] = None,
        cond_mask = None,
        steps = 3,
        cond_scale = 1.,
        decode_to_audio = True,
        decode_to_codes = False,
        max_semantic_token_ids = 2048,
        spec_decode = False,
        spec_decode_gamma = 5 # may need to be higher, since speech may be easier than text; needs testing
    ):
        ...  # the body of the sampling routine is omitted in this excerpt

    # forward pass used for training
    def forward(
        self,
        x1,
        *,
        mask = None,
        semantic_token_ids = None,
        phoneme_ids = None,
        cond = None,
        cond_mask = None,
        input_sampling_rate = None # 如果未给出,则假定与音频编码器解码器采样率相同,如果给出,则重新采样
        ):
        """
        following eq (5) (6) in https://arxiv.org/pdf/2306.15687.pdf
        """

        # 获取输入张量 x1 的批量大小、序列长度、数据类型和标准差
        batch, seq_len, dtype, σ = *x1.shape[:2], x1.dtype, self.sigma

        # 如果输入是原始音频,则转换为音频编码器/解码器传入的格式
        input_is_raw_audio, cond_is_raw_audio = map(is_probably_audio_from_shape, (x1, cond))

        if input_is_raw_audio:
            raw_audio = x1

        if any([input_is_raw_audio, cond_is_raw_audio]):
            assert exists(self.voicebox.audio_enc_dec), 'audio_enc_dec must be set on VoiceBox to train directly on raw audio'

            audio_enc_dec_sampling_rate = self.voicebox.audio_enc_dec.sampling_rate
            input_sampling_rate = default(input_sampling_rate, audio_enc_dec_sampling_rate)

            with torch.no_grad():
                self.voicebox.audio_enc_dec.eval()

                if input_is_raw_audio:
                    x1 = resample(x1, input_sampling_rate, audio_enc_dec_sampling_rate)
                    x1 = self.voicebox.audio_enc_dec.encode(x1)

                if exists(cond) and cond_is_raw_audio:
                    cond = resample(cond, input_sampling_rate, audio_enc_dec_sampling_rate)
                    cond = self.voicebox.audio_enc_dec.encode(cond)

        # 设置文本条件,可以来自持续时间模型(作为音素 id)或来自文本到语义模块,使用 wav2vec 编码的语义 id(通常是 hubert)

        assert self.condition_on_text or not (exists(semantic_token_ids) or exists(phoneme_ids)), 'semantic or phoneme ids should not be passed in if not conditioning on text'

        cond_token_ids = None

        if self.condition_on_text:
            if exists(self.text_to_semantic) or exists(semantic_token_ids):
                assert not exists(phoneme_ids), 'phoneme ids are not needed for conditioning with spear-tts text-to-semantic'

                if not exists(semantic_token_ids):
                    assert input_is_raw_audio
                    wav2vec = self.text_to_semantic.wav2vec
                    wav2vec_input = resample(raw_audio, input_sampling_rate, wav2vec.target_sample_hz)
                    semantic_token_ids = wav2vec(wav2vec_input).clone()

                cond_token_ids = semantic_token_ids
            else:
                assert exists(phoneme_ids)
                cond_token_ids = phoneme_ids

        # 主要的条件流程逻辑在下面

        # x0 是高斯噪声

        x0 = torch.randn_like(x1)

        # 随机时间

        times = torch.rand((batch,), dtype=dtype, device=self.device)
        t = rearrange(times, 'b -> b 1 1')

        # 采样 xt(论文中的 w)

        w = (1 - (1 - σ) * t) * x0 + t * x1

        flow = x1 - (1 - σ) * x0

        # 预测

        self.voicebox.train()

        loss = self.voicebox(
            w,
            cond=cond,
            cond_mask=cond_mask,
            times=times,
            target=flow,
            self_attn_mask=mask,
            cond_token_ids=cond_token_ids,
            cond_drop_prob=self.cond_drop_prob
        )

        return loss
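
A standalone sketch of the flow matching construction used above (eq. (5)-(6) of the Voicebox paper): interpolate between gaussian noise x0 and data x1 at a random time t, and regress the vector field x1 - (1 - σ) x0. Shapes and σ are made up for illustration, and a zero tensor stands in for the network prediction:

import torch
import torch.nn.functional as F

sigma = 0.
x1 = torch.randn(2, 16, 8)                  # "data" latents (batch, frames, dim)
x0 = torch.randn_like(x1)                   # gaussian noise
t = torch.rand(2).view(2, 1, 1)             # one time per batch element

w = (1 - (1 - sigma) * t) * x0 + t * x1     # noisy sample fed to the network
flow = x1 - (1 - sigma) * x0                # regression target

pred = torch.zeros_like(flow)               # stand-in for the model output
loss = F.mse_loss(pred, flow)
print(loss.item())
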

.\lucidrains\voicebox-pytorch\voicebox_pytorch\__init__.py

# 从 voicebox_pytorch.voicebox_pytorch 模块中导入 Transformer, EncodecVoco, VoiceBox, DurationPredictor, ConditionalFlowMatcherWrapper 类
from voicebox_pytorch.voicebox_pytorch import (
    Transformer,
    EncodecVoco,
    VoiceBox,
    DurationPredictor,
    ConditionalFlowMatcherWrapper,
)

# 从 voicebox_pytorch.trainer 模块中导入 VoiceBoxTrainer 类
from voicebox_pytorch.trainer import (
    VoiceBoxTrainer
)

# 从 spear_tts_pytorch 模块中导入 TextToSemantic 类
from spear_tts_pytorch import TextToSemantic

# 从 audiolm_pytorch 模块中导入 HubertWithKmeans 类
from audiolm_pytorch import HubertWithKmeans
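
A rough end-to-end training sketch tying the exported pieces together, loosely following the repository's usage pattern; the sizes are illustrative and x stands in for pre-encoded audio latents (with an audio_enc_dec set on VoiceBox, raw waveforms could be passed instead):

import torch
from voicebox_pytorch import VoiceBox, ConditionalFlowMatcherWrapper

model = VoiceBox(
    dim = 512,
    num_cond_tokens = 256,
    depth = 2,
    dim_head = 64,
    heads = 8
)

cfm_wrapper = ConditionalFlowMatcherWrapper(voicebox = model)

x = torch.randn(2, 256, 512)                     # (batch, frames, latent dim)
phoneme_ids = torch.randint(0, 256, (2, 256))    # frame-aligned phoneme / semantic ids

loss = cfm_wrapper(x, phoneme_ids = phoneme_ids, cond = x)
loss.backward()

# after training, cfm_wrapper.sample(...) runs the ODE solver to generate latents
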


x-clip

A concise but complete implementation of CLIP with various experimental improvements from recent papers

Install

$ pip install x-clip

Usage

import torch
from x_clip import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    visual_patch_dropout = 0.5,             # patch dropout probability, used in Kaiming He's FLIP to save compute and improve end results - 0.5 is good value, 0.75 on high end is tolerable
    use_all_token_embeds = False,           # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                  # whether to do self supervised learning on images
    use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,            # weight for text MLM loss
    image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
)

# mock data

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

# train

loss = clip(
    text,
    images,
    freeze_image_encoder = False,   # whether to freeze image encoder if using a pretrained image net, proposed by LiT paper
    return_loss = True              # needs to be set to True to return contrastive loss
)

loss.backward()

You can also pass in an external visual transformer / residual net. You simply have to make sure your image encoder returns a set of embeddings in the shape of batch x seq x dim, and make sure dim_image is properly specified as the dimension of the returned embeddings. Below is an example using vision transformer from vit_pytorch

$ pip install vit_pytorch>=0.25.6
import torch
from x_clip import CLIP

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

base_vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

vit = Extractor(
    base_vit,
    return_embeddings_only = True
)

clip = CLIP(
    image_encoder = vit,
    dim_image = 512,           # must be set as the same dimensions as the vision transformer above
    dim_text = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = clip(text, images, return_loss = True)
loss.backward()

Finally, one can also have the text transformer be externally defined. It will need to return the embeddings including the CLS token, for now.

import torch
from x_clip import CLIP, TextTransformer

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

base_vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

image_encoder = Extractor(
    base_vit,
    return_embeddings_only = True
)

text_encoder = TextTransformer(
    dim = 512,
    num_tokens = 10000,
    max_seq_len = 256,
    depth = 6,
    heads = 8
)

clip = CLIP(
    image_encoder = image_encoder,
    text_encoder = text_encoder,
    dim_image = 512,
    dim_text = 512,
    dim_latent = 512
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = clip(text, images, return_loss = True)
loss.backward()

Multiview CL Losses

This repository also supports multiview contrastive learning loss, as proposed in DeCLIP. Just pass in the augmented text and/or augmented image, and it will be auto-calculated, weighed by multiview_loss_weight set on initialization.

ex.

import torch
from x_clip import CLIP, TextTransformer

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

base_vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

image_encoder = Extractor(
    base_vit,
    return_embeddings_only = True
)

text_encoder = TextTransformer(
    dim = 512,
    num_tokens = 10000,
    max_seq_len = 256 + 1,
    depth = 6,
    heads = 8
)

clip = CLIP(
    image_encoder = image_encoder,
    text_encoder = text_encoder,
    dim_image = 512,
    dim_text = 512,
    dim_latent = 512,
    extra_latent_projection = True,
    multiview_loss_weight = 0.1         # weight multiview contrastive loss by 0.1
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

aug_text = torch.randint(0, 10000, (4, 256))  # augmented text (backtranslation or EDA), same dimensions as text
aug_images = torch.randn(4, 3, 256, 256)      # augmented images, same dimension as images above
loss = clip(
    text,
    images,
    aug_text = aug_text,           # pass in augmented texts
    aug_image = aug_images,        # pass in augmented images
    return_loss = True,
    freeze_image_encoder = True
)

loss.backward()

You can even send in more than one augmented text or image

# ...

aug_texts = (
    torch.randint(0, 10000, (4, 256)),
    torch.randint(0, 10000, (4, 256)),
)

aug_images = (
    torch.randn(4, 3, 256, 256),
    torch.randn(4, 3, 256, 256),
)

loss = clip(
    text,
    images,
    aug_text = aug_texts,
    aug_image = aug_images,
    return_loss = True,
    freeze_image_encoder = True
)

loss.backward()

Custom Vision Self-supervised Learning Module

You can pass in your own vision self-supervised learning module through the visual_ssl keyword as so

import torch
from x_clip import CLIP
from x_clip.visual_ssl import SimSiam

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

base_vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

image_encoder = Extractor(
    base_vit,
    return_embeddings_only = True
)

visual_ssl = SimSiam(                 # SimSiam defined externally - needs to be a module that accepts an image of the same dimensions as CLIP and returns a scalar loss
    image_encoder,
    image_size = 256,
    hidden_layer = -1
)

clip = CLIP(
    image_encoder = image_encoder,
    dim_image = 512,
    dim_text = 512,
    dim_latent = 512,
    use_mlm = True,
    visual_ssl = visual_ssl,           # SSL module passed into CLIP
    use_all_token_embeds = False,
    extra_latent_projection = False,
    mlm_random_token_prob = 0.1
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = clip(text, images, return_loss = True)
loss.backward()

Citations

@misc{radford2021learning,
    title   = {Learning Transferable Visual Models From Natural Language Supervision}, 
    author  = {Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
    year    = {2021},
    eprint  = {2103.00020},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yao2021filip,
    title   = {FILIP: Fine-grained Interactive Language-Image Pre-Training}, 
    author  = {Lewei Yao and Runhui Huang and Lu Hou and Guansong Lu and Minzhe Niu and Hang Xu and Xiaodan Liang and Zhenguo Li and Xin Jiang and Chunjing Xu},
    year    = {2021},
    eprint  = {2111.07783},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{fürst2021cloob,
    title   = {CLOOB: Modern Hopfield Networks with InfoLOOB Outperform CLIP},
    author  = {Andreas Fürst and Elisabeth Rumetshofer and Viet Tran and Hubert Ramsauer and Fei Tang and Johannes Lehner and David Kreil and Michael Kopp and Günter Klambauer and Angela Bitto-Nemling and Sepp Hochreiter},
    year    = {2021},
    eprint  = {2110.11316},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{yeh2021decoupled,
    title   = {Decoupled Contrastive Learning},
    author  = {Chun-Hsiao Yeh and Cheng-Yao Hong and Yen-Chi Hsu and Tyng-Luh Liu and Yubei Chen and Yann LeCun},
    year    = {2021},
    eprint  = {2110.06848},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{zhai2021lit,
    title   = {LiT: Zero-Shot Transfer with Locked-image Text Tuning},
    author  = {Xiaohua Zhai and Xiao Wang and Basil Mustafa and Andreas Steiner and Daniel Keysers and Alexander Kolesnikov and Lucas Beyer},
    year    = {2021},
    eprint  = {2111.07991},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{li2021supervision,
    title   = {Supervision Exists Everywhere: A Data Efficient Contrastive Language-Image Pre-training Paradigm},
    author  = {Yangguang Li and Feng Liang and Lichen Zhao and Yufeng Cui and Wanli Ouyang and Jing Shao and Fengwei Yu and Junjie Yan},
    year    = {2021},
    eprint  = {2110.05208},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@Article{mu2021slip,
    author  = {Norman Mu and Alexander Kirillov and David Wagner and Saining Xie},
    title   = {SLIP: Self-supervision meets Language-Image Pre-training},
    journal = {arXiv preprint arXiv:2112.12750},
    year    = {2021},
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}
@inproceedings{Li2022ScalingLP,
    title   = {Scaling Language-Image Pre-training via Masking},
    author  = {Yanghao Li and Haoqi Fan and Ronghang Hu and Christoph Feichtenhofer and Kaiming He},
    year    = {2022}
}
@article{Liu2022PatchDropoutEV,
    title   = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
    author  = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.07220}
}
@misc{shi2023enhance,
    title   = {Enhance audio generation controllability through representation similarity regularization}, 
    author  = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
    year    = {2023},
    eprint  = {2309.08773},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}