Lucidrains-系列项目源码解析-三-

Lucidrains 系列项目源码解析(三)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\data.py

# 导入必要的模块
from pathlib import Path
from functools import partial, wraps

# 导入 beartype 模块及相关类型
from beartype import beartype
from beartype.typing import Tuple, Union, Optional
from beartype.door import is_bearable

# 导入 torchaudio 模块及相关函数
import torchaudio
from torchaudio.functional import resample

# 导入 torch 模块及相关函数
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

# 导入自定义工具函数
from audiolm_pytorch.utils import curtail_to_multiple

# 导入 einops 模块中的函数
from einops import rearrange, reduce

# 定义一些辅助函数

# 判断值是否存在
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True

# 将值转换为元组
def cast_tuple(val, length = 1):
    """Return ``val`` unchanged if it is a tuple, else a tuple of ``length`` copies of it."""
    if isinstance(val, tuple):
        return val
    return (val,) * length

# 判断列表中的元素是否唯一
def is_unique(arr):
    """Return True when ``arr`` contains no duplicate (hashable) elements.

    Fixes a missing closing parenthesis in the original, which made the whole
    module a syntax error.
    """
    return len(set(arr)) == len(arr)

# 定义数据集类
class SoundDataset(Dataset):
    """Dataset over every audio file under ``folder`` whose extension is in ``exts``.

    ``__getitem__`` does the following:
    - loads a file and averages multi-channel audio down to mono;
    - resamples to the highest entry of ``target_sample_hz``;
    - randomly crops, or right-pads with zeros, to ``max_length`` samples when it is given;
    - resamples to each entry of ``target_sample_hz``;
    - optionally curtails each result to a multiple of the matching ``seq_len_multiple_of`` entry.

    Returns one float tensor when a single target rate is configured.
    Otherwise it returns a tuple with one tensor per target rate, in the given order.
    """

    @beartype
    def __init__(
        self,
        folder,
        target_sample_hz: Union[int, Tuple[int, ...]],  # one target sample rate, or a tuple of several
        exts = ('flac', 'wav', 'mp3', 'webm'),          # immutable default; only iterated
        max_length: Optional[int] = None,               # with several target rates, applies at the highest one
        seq_len_multiple_of: Optional[Union[int, Tuple[Optional[int], ...]]] = None
    ):
        super().__init__()
        path = Path(folder)
        assert path.exists(), f'folder "{str(path)}" does not exist'

        files = [file for ext in exts for file in path.glob(f'**/*.{ext}')]
        assert len(files) > 0, 'no sound files found'

        self.files = files

        self.max_length = max_length
        self.target_sample_hz = cast_tuple(target_sample_hz)
        num_outputs = len(self.target_sample_hz)

        # With several target rates: resample to the highest first, apply
        # max_length there, then resample down to the remaining rates.
        self.max_target_sample_hz = max(self.target_sample_hz)
        self.seq_len_multiple_of = cast_tuple(seq_len_multiple_of, num_outputs)

        assert len(self.target_sample_hz) == len(self.seq_len_multiple_of)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]

        data, sample_hz = torchaudio.load(file)

        assert data.numel() > 0, f'one of your audio file ({file}) is empty. please remove it from your folder'

        if data.shape[0] > 1:
            # average all channels down to a single (mono) channel
            data = reduce(data, 'c ... -> 1 ...', 'mean')

        # resample to the highest target rate first
        data = resample(data, sample_hz, self.max_target_sample_hz)
        sample_hz = self.max_target_sample_hz

        # crop or pad to max_length samples (at the highest rate)
        max_length = self.max_length
        audio_length = data.size(1)

        if exists(max_length):
            if audio_length > max_length:
                # Every start in [0, max_start] gives a full window. randint's upper
                # bound is exclusive, hence max_start + 1. Convert to a Python int
                # so slicing does not depend on tensor-as-index behaviour.
                max_start = audio_length - max_length
                start = int(torch.randint(0, max_start + 1, (1,)).item())
                data = data[:, start:start + max_length]
            else:
                data = F.pad(data, (0, max_length - audio_length), 'constant')

        data = rearrange(data, '1 ... -> ...')

        # resample to every target rate
        num_outputs = len(self.target_sample_hz)
        data = cast_tuple(data, num_outputs)

        data_tuple = tuple(resample(d, sample_hz, target_sample_hz) for d, target_sample_hz in zip(data, self.target_sample_hz))

        # curtail each output to its own length multiple, if one was requested
        output = []
        for resampled, multiple in zip(data_tuple, self.seq_len_multiple_of):
            if exists(multiple):
                resampled = curtail_to_multiple(resampled, multiple)

            output.append(resampled.float())

        output = tuple(output)

        # a single target rate yields a bare tensor rather than a 1-tuple
        if num_outputs == 1:
            return output[0]

        return output

# 数据加载函数

# 定义一个装饰器函数,用于处理单个或多个张量的填充
def collate_one_or_multiple_tensors(fn):
    """Lift a per-batch collate ``fn`` so it also handles tuple-valued samples.

    - If samples are not tuples, the result is ``(fn(data),)``.
    - If samples are tuples, the batch is transposed into columns.
    - A column made entirely of strings is passed through as a list.
    - Every other column is collated with ``fn``.
    """
    @wraps(fn)
    def inner(data):
        if not isinstance(data[0], tuple):
            return (fn(data),)

        return tuple(
            list(column) if is_bearable(column, Tuple[str, ...]) else fn(column)
            for column in zip(*data)
        )

    return inner

# collate by curtailing every sample to the shortest one in the batch
@collate_one_or_multiple_tensors
def curtail_to_shortest_collate(data):
    """Truncate each tensor to the shortest length along dim 0, then stack them.

    Uses ``min`` over a generator instead of ``min(*[...])``. The unpacked form
    raises ``TypeError`` for a batch of exactly one tensor, because ``min``
    then receives a single int rather than an iterable.
    """
    min_len = min(datum.shape[0] for datum in data)
    return torch.stack([datum[:min_len] for datum in data])

# 对最长的数据进行填充
@collate_one_or_multiple_tensors
def pad_to_longest_fn(data):
    """Right-pad variable-length tensors with zeros and stack them batch-first."""
    padded = pad_sequence(data, batch_first = True)
    return padded

# 获取数据加载器
def get_dataloader(ds, pad_to_longest = True, **kwargs):
    """Wrap ``ds`` in a DataLoader with a length-harmonising collate function.

    With ``pad_to_longest`` (the default), samples are zero-padded to the
    longest in the batch. Otherwise they are curtailed to the shortest.
    Remaining ``kwargs`` are forwarded to ``DataLoader``.
    """
    if pad_to_longest:
        collate_fn = pad_to_longest_fn
    else:
        collate_fn = curtail_to_shortest_collate
    return DataLoader(ds, collate_fn = collate_fn, **kwargs)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\encodec.py

# 导入所需的库和模块
from functools import reduce
from einops import rearrange, pack, unpack
import torch
from torch import nn
from torchaudio.functional import resample
from vector_quantize_pytorch import ResidualVQ
from encodec import EncodecModel
from encodec.utils import _linear_overlap_add

# 定义一个辅助函数,用于检查变量是否存在
def exists(val):
    """Return True when ``val`` is not ``None``."""
    return not (val is None)

# 获取模型中的量化器数量
def get_num_quantizers(model: EncodecModel, audio_length = 512):
    """Probe ``model`` with random mono audio and report its quantizer count.

    The count is read from dim 1 of the codes tensor of the first encoded
    frame (``encode(...)[0][0]``). It depends on the model's current target
    bandwidth.
    """
    probe = torch.randn(1, 1, audio_length)
    encoded_frames = model.encode(probe)
    first_codes = encoded_frames[0][0]
    return first_codes.shape[1]

# 定义一个包装器类,用于支持预训练的 24kHz Encodec 模型
class EncodecWrapper(nn.Module):
    """Adapter exposing the pretrained 24 kHz Encodec model through the
    interface used elsewhere in this package.

    That interface is: ``forward`` returning ``(emb, codes, None)``,
    ``decode_from_codebook_indices``, ``seq_len_multiple_of``, etc.

    NOTE: the ``num_quantizers`` argument is ignored. The actual count is
    measured from the model after ``bandwidth`` is applied.
    """

    def __init__(
        self,
        target_sample_hz = 24000,
        strides = (2, 4, 5, 8),
        num_quantizers = 8,
        bandwidth = 6.0
    ):
        super().__init__()
        # pretrained 24 kHz Encodec; its input normalization is disabled
        self.model = EncodecModel.encodec_model_24khz()
        self.model.normalize = False

        # the target bandwidth determines how many quantizers are active,
        # so probe the model instead of trusting the num_quantizers argument
        self.model.set_target_bandwidth(bandwidth)
        num_quantizers = get_num_quantizers(self.model)

        self.target_sample_hz = target_sample_hz
        assert self.target_sample_hz == 24000, "haven't done anything with non-24kHz yet"
        self.codebook_dim = 128
        self.rq_groups = 1
        self.num_quantizers = num_quantizers
        self.strides = strides

        self.rq = ResidualVQ(
            dim = 128,
            codebook_size = 1024,
            num_quantizers = num_quantizers
        )

        # copy Encodec's codebooks into the ResidualVQ module, layer by layer
        for encodec_rq_layer, rq_layer in zip(self.model.quantizer.vq.layers, self.rq.layers):
            encodec_codebook = dict(encodec_rq_layer._codebook.named_buffers()).get('embed')
            vq_codebook = dict(rq_layer._codebook.named_buffers()).get('embed')
            encodec_codebook = rearrange(encodec_codebook, '... -> 1 ...')
            vq_codebook.copy_(encodec_codebook)

    @property
    def seq_len_multiple_of(self):
        # product of all strides
        return reduce(lambda x, y: x * y, self.strides)

    @property
    def downsample_factor(self):
        return self.seq_len_multiple_of

    def forward(
        self,
        x,
        input_sample_hz = None,
        return_encoded = False,
        **kwargs
    ):
        """Encode waveforms ``x`` (``* n``) into Encodec codes.

        Returns ``(emb, codes, None)``:
        - ``codes`` has layout ``* n q``;
        - ``emb`` is ``None`` unless ``return_encoded`` is set, in which case it
          is the dequantized embedding with layout ``* n c``.
        """
        x, ps = pack([x], '* n')

        if exists(input_sample_hz):
            x = resample(x, input_sample_hz, self.target_sample_hz)

        assert not self.model.training, "Encodec is pretrained and should never be called outside eval mode."

        wav = rearrange(x, f'b t -> b {self.model.channels} t')

        with torch.inference_mode():
            encoded_frames = self.model.encode(wav)

        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
        codes = rearrange(codes, 'b q n -> b n q')

        emb = None

        if return_encoded:
            emb = self.get_emb_from_indices(codes)
            emb, = unpack(emb, ps, '* n c')

        codes, = unpack(codes, ps, '* n q')

        return emb, codes, None

    def decode_from_codebook_indices(self, quantized_indices):
        """Decode ``(batch, num_tokens, num_quantizers)`` indices into waveforms.

        The 24 kHz model does not segment audio, so all tokens form one frame.
        ``_linear_overlap_add`` expects a *list* of frames. Passing the decoded
        tensor itself made it iterate over the batch dimension and overlap-add
        different batch items into one waveform. Wrapping the frame in a
        one-element list makes the overlap-add an identity. It also keeps the
        decoder's ``(batch, channels, samples)`` layout; channels is
        ``self.model.channels``, the same value ``forward`` uses for its input.
        """
        frames = self._decode_frame(quantized_indices)
        return _linear_overlap_add([frames], self.model.segment_stride or 1)

    def get_emb_from_indices(self, indices):
        """Dequantize ``b t q`` indices into embeddings laid out as ``b n c``."""
        codes = rearrange(indices, 'b t q -> q b t')
        emb = self.model.quantizer.decode(codes)
        return rearrange(emb, 'b c n -> b n c')

    def decode(self, emb):
        """Run the Encodec decoder on ``b n c`` embeddings."""
        emb = rearrange(emb, 'b n c -> b c n')
        return self.model.decoder(emb)

    def _decode_frame(self, quantized_indices):
        # Adapted from self.model._decode_frame() (Encodec 0.1.1), with the
        # EncodedFrame already unpacked.
        # input: batch x num tokens x num quantizers (one token per frame)
        # output: the decoder's output for the whole token sequence
        codes = rearrange(quantized_indices, 'b t q -> q b t')
        emb = self.model.quantizer.decode(codes)
        # emb: batch x self.model.quantizer.dimension x T
        return self.model.decoder(emb)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\hubert_kmeans.py

# 导入必要的库
from pathlib import Path
import torch
from torch import nn, einsum
from torchaudio.functional import resample
from einops import rearrange, repeat, pack, unpack
from audiolm_pytorch.utils import curtail_to_multiple

# 定义一个空函数用于忽略警告
def noop(*args, **kwargs):
    """Accept any arguments and do nothing (used to silence ``warnings.warn``)."""
    return None

import warnings
import logging

# 设置日志级别为 ERROR
logging.root.setLevel(logging.ERROR)

# 忽略警告
warnings.warn = noop

# 导入 fairseq 和 joblib 用于 hubert 模型
import joblib
import fairseq

# 定义辅助函数
def exists(val):
    """Return True when ``val`` is not ``None``."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    if exists(val):
        return val
    return d

# 定义一个带有 kmeans 的 Hubert 模型类
class HubertWithKmeans(nn.Module):
    """HuBERT features quantized to k-means cluster ids.

    Checkpoint and k-means model can be downloaded at
    https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
    or you can train your own.
    """

    def __init__(
        self,
        checkpoint_path,
        kmeans_path,
        target_sample_hz = 16000,
        seq_len_multiple_of = None,
        output_layer = 9
    ):
        super().__init__()

        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = seq_len_multiple_of
        self.output_layer = output_layer

        model_path = Path(checkpoint_path)
        kmeans_path = Path(kmeans_path)

        assert model_path.exists(), f'path {checkpoint_path} does not exist'
        assert kmeans_path.exists(), f'path {kmeans_path} does not exist'

        # NOTE(review): torch.load and joblib.load both unpickle their input;
        # only point these paths at trusted files.
        checkpoint = torch.load(checkpoint_path)
        models, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task({checkpoint_path: checkpoint})

        self.model = models[0]
        self.model.eval()

        self.kmeans = joblib.load(kmeans_path)

        # centroids kept as a buffer so they follow the module across devices
        self.register_buffer(
            'cluster_centers',
            torch.from_numpy(self.kmeans.cluster_centers_)
        )

    @property
    def groups(self):
        return 1

    @property
    def codebook_size(self):
        return self.kmeans.n_clusters

    @property
    def downsample_factor(self):
        # todo: double check
        return 320

    @torch.inference_mode()
    def forward(
        self,
        wav_input,
        flatten = True,
        input_sample_hz = None
    ):
        """Map waveforms to the id of the nearest k-means centroid per feature frame."""
        if exists(input_sample_hz):
            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

        if exists(self.seq_len_multiple_of):
            wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

        features = self.model(
            wav_input,
            features_only = True,
            mask = False,
            output_layer = self.output_layer
        )['x']

        centers = repeat(self.cluster_centers, 'c d -> b c d', b = features.shape[0])
        # the nearest centroid is the argmax of the *negated* euclidean distance
        neg_dists = -torch.cdist(features, centers, p = 2)
        cluster_ids = neg_dists.argmax(dim = -1)

        if not flatten:
            return rearrange(cluster_ids, 'b ... -> b (...)')

        return cluster_ids

.\lucidrains\audiolm-pytorch\audiolm_pytorch\optimizer.py

# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam

# 将参数分为需要权重衰减和不需要权重衰减的两个列表
def separate_weight_decayable_params(params):
    """Split ``params`` into (decayed, not decayed) lists.

    Tensors with ``ndim >= 2`` (weight matrices / kernels) get weight decay.
    Lower-rank ones (biases, norm gains, scalars) do not.
    """
    decayed = []
    not_decayed = []
    for tensor in params:
        bucket = decayed if tensor.ndim >= 2 else not_decayed
        bucket.append(tensor)
    return decayed, not_decayed

# 获取优化器
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    use_lion = False,
    **kwargs
):
    """Build an Adam or AdamW optimizer.

    - With ``wd > 0``, AdamW is returned. When ``group_wd_params`` is set,
      tensors with ``ndim < 2`` go into a separate group with
      ``weight_decay = 0``.
    - Otherwise plain Adam is returned.
    - ``filter_by_requires_grad`` drops parameters that do not require gradients.

    NOTE(review): ``use_lion`` and ``**kwargs`` are accepted but not used here.
    """
    has_wd = wd > 0

    if filter_by_requires_grad:
        params = [p for p in params if p.requires_grad]

    if not has_wd:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    if group_wd_params:
        decayed, not_decayed = separate_weight_decayable_params(params)
        params = [
            {'params': decayed},
            {'params': not_decayed, 'weight_decay': 0},
        ]

    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\soundstream.py

# 导入必要的库
import functools
from pathlib import Path
from functools import partial, wraps
from itertools import cycle, zip_longest
from typing import Optional, List

import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.autograd import grad as torch_grad
import torch.nn.functional as F
from torch.linalg import vector_norm

import torchaudio.transforms as T
from torchaudio.functional import resample

from einops import rearrange, reduce, pack, unpack

# 导入自定义模块
from vector_quantize_pytorch import (
    GroupedResidualVQ,
    GroupedResidualLFQ,
    GroupedResidualFSQ
)

from local_attention import LocalMHA
from local_attention.transformer import FeedForward, DynamicPositionBias

from gateloop_transformer import SimpleGateLoopLayer as GateLoop

from audiolm_pytorch.utils import curtail_to_multiple

from audiolm_pytorch.version import __version__
from packaging import version
parsed_version = version.parse(__version__)

import pickle

# 辅助函数

# 检查值是否存在
def exists(val):
    """Return True when ``val`` is not ``None``."""
    return not (val is None)

# 如果值存在则返回该值,否则返回默认值
def default(val, d):
    """Return ``val`` unless it is ``None``, else the fallback ``d``."""
    if not exists(val):
        return d
    return val

# 将元组转换为指定长度的元组
def cast_tuple(t, l = 1):
    """Return ``t`` if it is already a tuple, else ``l`` repetitions of it as a tuple."""
    if isinstance(t, tuple):
        return t
    return (t,) * l

# 根据键过滤字典
def filter_by_keys(fn, d):
    """Return a new dict with only the entries of ``d`` whose key satisfies ``fn``."""
    return dict((key, value) for key, value in d.items() if fn(key))

# 映射字典键
def map_keys(fn, d):
    """Return a new dict with every key of ``d`` transformed by ``fn``.

    If two keys map to the same new key, the later entry wins.
    """
    return dict(zip(map(fn, d), d.values()))

# GAN 损失函数

# 对数函数
def log(t, eps = 1e-20):
    """Natural log of ``t`` with inputs clamped below at ``eps`` to avoid ``-inf``."""
    clamped = t.clamp(min = eps)
    return torch.log(clamped)

# Hinge 判别器损失
def hinge_discr_loss(fake, real):
    """Hinge loss for the discriminator: mean of relu(1 + fake) + relu(1 - real)."""
    fake_term = F.relu(1 + fake)
    real_term = F.relu(1 - real)
    return (fake_term + real_term).mean()

# Hinge 生成器损失
def hinge_gen_loss(fake):
    """Hinge loss for the generator: negative mean of the discriminator's fake logits."""
    mean_logit = fake.mean()
    return -mean_logit

# Leaky ReLU 激活函数
def leaky_relu(p = 0.1):
    """Return a fresh ``nn.LeakyReLU`` module with negative slope ``p``."""
    return nn.LeakyReLU(negative_slope = p)

# 梯度惩罚
def gradient_penalty(wave, output, weight = 10):
    """Gradient penalty: ``weight * mean((||d output / d wave||_2 - 1) ** 2)``.

    The gradient is flattened per batch item before taking its L2 norm.
    ``create_graph`` is set so the penalty itself can be backpropagated.
    """
    grads = torch_grad(
        outputs = output,
        inputs = wave,
        grad_outputs = torch.ones_like(output),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    flat = rearrange(grads, 'b ... -> b (...)')
    norms = vector_norm(flat, dim = 1)
    return weight * ((norms - 1) ** 2).mean()

# `Sequential` variant that drops `None` entries, so optional layers can be
# written inline as `layer if flag else None`

def Sequential(*mods):
    """Build an ``nn.Sequential`` from ``mods``, skipping any that are ``None``."""
    kept = [mod for mod in mods if exists(mod)]
    return nn.Sequential(*kept)

# 判别器

class MultiScaleDiscriminator(Module):
    """1-D convolutional waveform discriminator.

    An initial conv is followed by up to ``layers`` grouped, stride-4
    downsampling convs. The channel count grows 4x per layer, capped at
    ``chan_max``. A final conv head produces single-channel logits.
    """

    def __init__(
        self,
        channels = 16,
        layers = 4,
        groups = (4, 16, 64, 256),
        chan_max = 1024,
        input_channels = 1
    ):
        super().__init__()
        self.init_conv = nn.Conv1d(input_channels, channels, 15, padding = 7)
        self.conv_layers = ModuleList([])

        dim = channels

        # zip with range(layers) caps the depth at both `layers` and len(groups)
        for _, num_groups in zip(range(layers), groups):
            dim_out = min(dim * 4, chan_max)
            block = nn.Sequential(
                nn.Conv1d(dim, dim_out, 41, stride = 4, padding = 20, groups = num_groups),
                leaky_relu()
            )
            self.conv_layers.append(block)
            dim = dim_out

        self.final_conv = nn.Sequential(
            nn.Conv1d(dim, dim, 5, padding = 2),
            leaky_relu(),
            nn.Conv1d(dim, 1, 3, padding = 1),
        )

    def forward(
        self,
        x,
        return_intermediates = False
    ):
        """Return logits, or ``(logits, per-layer activations)`` when requested."""
        hidden = self.init_conv(x)
        intermediates = []

        for block in self.conv_layers:
            hidden = block(hidden)
            intermediates.append(hidden)

        logits = self.final_conv(hidden)

        if return_intermediates:
            return logits, intermediates

        return logits

# 自回归挤压激励
# https://arxiv.org/abs/1709.01507

class SqueezeExcite(Module):
    """Causal (autoregressive) squeeze-and-excitation gate.

    https://arxiv.org/abs/1709.01507

    Input is ``(batch, dim, time)``, the layout the ``nn.Conv1d`` layers in
    ``self.net`` consume. Instead of a global average over time, every step is
    gated by the running mean of all steps up to and including it, so no
    future information leaks into the gate.
    """

    def __init__(self, dim, reduction_factor = 4, dim_minimum = 8):
        super().__init__()
        dim_inner = max(dim_minimum, dim // reduction_factor)
        self.net = nn.Sequential(
            nn.Conv1d(dim, dim_inner, 1),
            nn.SiLU(),
            nn.Conv1d(dim_inner, dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # The running mean must be over the time axis (-1). The previous version
        # used dim -2, which for (batch, dim, time) input is the channel axis.
        # That averaged across channels, so each time step was gated only by
        # its own frame.
        seq, device = x.shape[-1], x.device

        cum_sum = x.cumsum(dim = -1)
        # match cum_sum's dtype so half-precision inputs stay in half precision
        denom = torch.arange(1, seq + 1, device = device, dtype = cum_sum.dtype)
        cum_mean = cum_sum / denom  # broadcasts over (batch, dim)

        # sigmoid gate computed from the causal running mean
        gate = self.net(cum_mean)

        return x * gate
# 定义一个复杂的短时傅里叶变换鉴别器

class ModReLU(Module):
    """
    https://arxiv.org/abs/1705.09792
    https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801

    Complex-valued ReLU. The magnitude becomes ``relu(|x| + b)`` with a
    learned scalar bias ``b``, and the phase of ``x`` is kept unchanged.
    """

    def __init__(self):
        super().__init__()
        self.b = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        magnitude = F.relu(torch.abs(x) + self.b)
        phase = torch.exp(1.j * torch.angle(x))
        return magnitude * phase

class ComplexConv2d(Module):
    """2-D convolution with complex64 weights.

    Weight and bias are stored as real tensors with a trailing size-2 axis
    (via ``view_as_real``). They are viewed back as complex for each call, and
    the input is cast to the weight's complex dtype.
    """

    def __init__(
        self,
        dim,
        dim_out,
        kernel_size,
        stride = 1,
        padding = 0
    ):
        super().__init__()
        # borrow nn.Conv2d's default initialisation, in complex64
        reference = nn.Conv2d(dim, dim_out, kernel_size, dtype = torch.complex64)
        self.weight = nn.Parameter(torch.view_as_real(reference.weight))
        self.bias = nn.Parameter(torch.view_as_real(reference.bias))

        self.stride = stride
        self.padding = padding

    def forward(self, x):
        weight = torch.view_as_complex(self.weight)
        bias = torch.view_as_complex(self.bias)
        x_complex = x.to(weight.dtype)
        return F.conv2d(x_complex, weight, bias, stride = self.stride, padding = self.padding)

def ComplexSTFTResidualUnit(chan_in, chan_out, strides):
    """Complex residual block followed by a strided complex downsampling conv.

    For each axis, the downsampling kernel is ``stride + 2`` and the padding
    is ``kernel // 2``.
    """
    kernel_sizes = tuple(s + 2 for s in strides)
    paddings = tuple(k // 2 for k in kernel_sizes)

    residual_branch = Sequential(
        ComplexConv2d(chan_in, chan_in, 3, padding = 1),
        ModReLU(),
        ComplexConv2d(chan_in, chan_in, 3, padding = 1)
    )
    downsample = ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings)

    return nn.Sequential(Residual(residual_branch), downsample)

class ComplexSTFTDiscriminator(Module):
    """Discriminator operating on the complex STFT of a waveform.

    The STFT is computed with ``return_complex=True`` and passed through a
    stack of complex convolutions. The final complex logits are turned into
    real values by ``abs()`` (``logits_abs=True``) or by ``view_as_real``.

    ``forward`` rearranges its input with ``'b 1 n -> b n'``, so the waveform
    must have exactly one channel.
    """

    def __init__(
        self,
        *,
        channels = 32,
        strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
        chan_mults = (1, 2, 4, 4, 8, 8),
        input_channels = 1,
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024,
        stft_normalized = False,
        stft_window_fn = torch.hann_window,
        logits_abs = True
    ):
        super().__init__()
        self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)

        dims = (channels, *(mult * channels for mult in chan_mults))
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        self.layers = ModuleList([
            ComplexSTFTResidualUnit(dim_in, dim_out, layer_stride)
            for layer_stride, (dim_in, dim_out) in zip(strides, dim_pairs)
        ])

        self.final_conv = ComplexConv2d(dims[-1], 1, (16, 1)) # todo: remove hardcoded 16

        # stft settings
        self.stft_normalized = stft_normalized
        self.stft_window_fn = stft_window_fn

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

        # how complex logits are mapped to real values
        self.logits_abs = logits_abs

    def forward(self, x, return_intermediates = False):
        """Return logits, or ``(logits, intermediates)`` when requested."""
        x = rearrange(x, 'b 1 n -> b n')

        # Reference: https://arxiv.org/pdf/2107.03312.pdf - "The STFT-based
        # discriminator ... operates on a single scale, computing the STFT with
        # a window length of W = 1024 samples and a hop length of H = 256 samples"
        window = self.stft_window_fn(self.win_length, device = x.device)

        spec = torch.stft(
            x,
            self.n_fft,
            hop_length = self.hop_length,
            win_length = self.win_length,
            window = window,
            normalized = self.stft_normalized,
            return_complex = True
        )

        # add a singleton channel axis for the complex convs
        spec = rearrange(spec, 'b ... -> b 1 ...')

        hidden = self.init_conv(spec)
        intermediates = [hidden]

        for layer in self.layers:
            hidden = layer(hidden)
            intermediates.append(hidden)

        logits = self.final_conv(hidden)
        logits = logits.abs() if self.logits_abs else torch.view_as_real(logits)

        if return_intermediates:
            return logits, intermediates

        return logits
# 定义一个名为 Residual 的类,继承自 Module 类
class Residual(Module):
    """Wrap ``fn`` with a skip connection: ``fn(x, **kwargs) + x``."""

    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x

# 定义一个名为 ChannelTranspose 的类,继承自 Module 类
class ChannelTranspose(Module):
    """Apply ``fn`` residually in channels-last layout.

    The input and output are ``b c n``. ``fn`` receives the ``b n c``
    transpose, its output is added back to that transpose, and the sum is
    transposed back to ``b c n``.
    """

    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        channels_last = rearrange(x, 'b c n -> b n c')
        summed = self.fn(channels_last, **kwargs) + channels_last
        return rearrange(summed, 'b n c -> b c n')

# 定义一个名为 CausalConv1d 的类,继承自 Module 类
class CausalConv1d(Module):
    """1-D convolution that pads only on the left, so outputs do not see later samples.

    The left padding is ``dilation * (kernel_size - 1) + (1 - stride)`` and
    uses ``pad_mode``. Any extra keyword arguments go to ``nn.Conv1d``.
    """

    def __init__(self, chan_in, chan_out, kernel_size, pad_mode = 'reflect', **kwargs):
        super().__init__()
        dilation = kwargs.get('dilation', 1)
        stride = kwargs.get('stride', 1)

        self.pad_mode = pad_mode
        self.causal_padding = dilation * (kernel_size - 1) + (1 - stride)
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs)

    def forward(self, x):
        left_padded = F.pad(x, (self.causal_padding, 0), mode = self.pad_mode)
        return self.conv(left_padded)

# 定义一个名为 CausalConvTranspose1d 的类,继承自 Module 类
class CausalConvTranspose1d(Module):
    """Transposed 1-D conv whose output is truncated to ``input_length * stride`` samples."""

    def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs):
        super().__init__()
        self.upsample_factor = stride
        self.padding = kernel_size - 1
        self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs)

    def forward(self, x):
        target_len = x.shape[-1] * self.upsample_factor
        upsampled = self.conv(x)
        return upsampled[..., :target_len]

# 定义一个名为 ResidualUnit 的函数,接受输入通道数 chan_in、输出通道数 chan_out、膨胀值 dilation 等参数
def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False, pad_mode = 'reflect'):
    """Residual block: dilated causal conv, ELU, 1x1 causal conv, ELU, optional squeeze-excite."""
    layers = [
        CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation, pad_mode = pad_mode),
        nn.ELU(),
        CausalConv1d(chan_out, chan_out, 1, pad_mode = pad_mode),
        nn.ELU(),
    ]

    if squeeze_excite:
        layers.append(SqueezeExcite(chan_out))

    return Residual(Sequential(*layers))

# 定义一个名为 EncoderBlock 的函数,接受输入通道数 chan_in、输出通道数 chan_out、步长 stride 等参数
def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'):
    """Three residual units (dilations taken cyclically) followed by a strided causal downsampling conv.

    NOTE(review): the downsampling conv is built without ``pad_mode``, so it
    always uses CausalConv1d's own default pad mode.
    """
    dilations = cycle(cycle_dilations)
    make_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode)

    units = [make_unit(chan_in, chan_in, next(dilations)) for _ in range(3)]
    downsample = CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride)

    return nn.Sequential(*units, downsample)

# 定义一个名为 DecoderBlock 的函数,接受输入通道数 chan_in、输出通道数 chan_out、步长 stride 等参数
def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'):
    """Upsample by ``stride`` with a causal transposed conv, then apply three
    residual units with dilations taken cyclically from ``cycle_dilations``."""
    # (the former even_stride / padding / output_padding locals were never used)
    residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode)

    it = cycle(cycle_dilations)
    return nn.Sequential(
        CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride),
        residual_unit(chan_out, chan_out, next(it)),
        residual_unit(chan_out, chan_out, next(it)),
        residual_unit(chan_out, chan_out, next(it)),
    )

# 定义一个名为 LocalTransformer 的类,继承自 Module 类
class LocalTransformer(Module):
    """Stack of (local windowed attention, feed-forward) pairs, each added residually.

    When ``dynamic_pos_bias`` is set, a ``DynamicPositionBias`` supplies the
    attention bias and rotary embeddings are disabled. Otherwise rotary
    embeddings are used. Extra ``kwargs`` go to every ``LocalMHA``.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        window_size,
        dynamic_pos_bias = False,
        **kwargs
        ):
        super().__init__()
        self.window_size = window_size
        self.layers = ModuleList([])

        self.pos_bias = DynamicPositionBias(dim = dim // 2, heads = heads) if dynamic_pos_bias else None

        for _ in range(depth):
            attention = LocalMHA(
                dim = dim,
                heads = heads,
                qk_rmsnorm = True,
                window_size = window_size,
                use_rotary_pos_emb = not dynamic_pos_bias,
                gate_values_per_head = True,
                use_xpos = True,
                **kwargs
            )
            feedforward = FeedForward(dim = dim)
            self.layers.append(ModuleList([attention, feedforward]))

    def forward(self, x):
        window = self.window_size

        attn_bias = None
        if exists(self.pos_bias):
            attn_bias = self.pos_bias(window, window * 2)

        for attention, feedforward in self.layers:
            x = attention(x, attn_bias = attn_bias) + x
            x = feedforward(x) + x

        return x
class FiLM(Module):
    """Feature-wise linear modulation: ``x * gamma + beta``.

    ``gamma`` and ``beta`` are the two halves (last axis) of a linear
    projection of ``cond`` to ``2 * dim`` features.
    """

    def __init__(self, dim, dim_cond):
        super().__init__()
        self.to_cond = nn.Linear(dim_cond, dim * 2)

    def forward(self, x, cond):
        scale_shift = self.to_cond(cond)
        gamma, beta = scale_shift.chunk(2, dim = -1)
        return x * gamma + beta
        # 返回经过 FiLM 操作后的结果

class SoundStream(Module):
    # 定义 SoundStream 类,继承自 Module 类
    def __init__(
        self,
        *,
        channels = 32,
        strides = (2, 4, 5, 8),
        channel_mults = (2, 4, 8, 16),
        codebook_dim = 512,
        codebook_size: Optional[int] = None,
        finite_scalar_quantizer_levels: Optional[List[int]] = None,
        rq_num_quantizers = 8,
        rq_commitment_weight = 1.,
        rq_ema_decay = 0.95,
        rq_quantize_dropout_multiple_of = 1,
        rq_groups = 1,
        rq_stochastic_sample_codes = False,
        rq_kwargs: dict = {},
        use_lookup_free_quantizer = False,              
        use_finite_scalar_quantizer = False,            
        input_channels = 1,
        discr_multi_scales = (1, 0.5, 0.25),
        stft_normalized = False,
        enc_cycle_dilations = (1, 3, 9),
        dec_cycle_dilations = (1, 3, 9),
        multi_spectral_window_powers_of_two = tuple(range(6, 12)),
        multi_spectral_n_ffts = 512,
        multi_spectral_n_mels = 64,
        recon_loss_weight = 1.,
        multi_spectral_recon_loss_weight = 1e-5,
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100,
        quantize_dropout_cutoff_index = 1,
        target_sample_hz = 16000,
        use_local_attn = True,
        attn_window_size = 128,
        attn_dim_head = 64,
        attn_heads = 8,
        attn_depth = 1,
        attn_xpos_scale_base = None,
        attn_dynamic_pos_bias = False,
        use_gate_loop_layers = False,
        squeeze_excite = False,
        complex_stft_discr_logits_abs = True,
        pad_mode = 'reflect',
        stft_discriminator: Optional[Module] = None,  
        complex_stft_discr_kwargs: dict = dict()
    @property
    def device(self):
        """Device holding the model's parameters (read from the first parameter)."""
        first_param = next(self.parameters())
        return first_param.device

    @property
    def configs(self):
        """
        The pickled constructor config (`self._configs`), deserialized.

        NOTE(review): `_configs` is also written into checkpoints by `save`;
        unpickling is only safe for trusted data.
        """
        serialized = self._configs
        return pickle.loads(serialized)

    def decode_from_codebook_indices(self, quantized_indices):
        """
        Decode integer codebook indices back into audio.

        quantized_indices - torch.long or torch.int32 tensor. A 3-dim input is read
                            as 'b n (g q)' and regrouped to 'g b n q' using
                            `self.rq_groups` before lookup.
        """
        assert quantized_indices.dtype in (torch.long, torch.int32)

        indices = quantized_indices
        if indices.ndim == 3:
            indices = rearrange(indices, 'b n (g q) -> g b n q', g = self.rq_groups)

        codes = self.rq.get_output_from_indices(indices)
        return self.decode(codes)

    def decode(self, x, quantize = False):
        """
        Run the decoder on latents laid out as (b, n, c).

        quantize - if True, pass `x` through the residual quantizer first and keep
                   only its first output
        """
        if quantize:
            x, *_ = self.rq(x)

        if exists(self.decoder_attn):
            x = self.decoder_attn(x)

        # the conv decoder takes channels before time
        return self.decoder(rearrange(x, 'b n c -> b c n'))

    def save(self, path):
        """Write the weights, the pickled constructor config and the library version to `path`."""
        pkg = {
            'model': self.state_dict(),
            'config': self._configs,
            'version': __version__,
        }
        torch.save(pkg, str(Path(path)))

    @classmethod
    def init_and_load_from(cls, path, strict = True):
        """
        Instantiate from the config stored in a checkpoint, load its weights via
        `load`, and return the model in eval mode.

        NOTE(review): torch.load and pickle.loads both execute pickled data;
        only use this with trusted checkpoint files.
        """
        path = Path(path)
        assert path.exists()

        checkpoint = torch.load(str(path), map_location = 'cpu')
        assert 'config' in checkpoint, 'model configs were not found in this saved checkpoint'

        model = cls(**pickle.loads(checkpoint['config']))
        model.load(path, strict = strict)
        model.eval()
        return model
    # 加载模型参数
    def load(self, path, strict = True):
        """
        Load weights from a checkpoint file into this model.

        If the checkpoint holds an 'ema_model' entry, only its keys starting with
        'ema_model.' are used (with that prefix stripped); otherwise 'model' is
        loaded. A notice is printed when the stored version is older than
        `parsed_version`.
        """
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')

        if 'version' in pkg and version.parse(pkg['version']) < parsed_version:
            print(f'soundstream model being loaded was trained on an older version of audiolm-pytorch ({pkg["version"]})')

        if 'ema_model' not in pkg:
            self.load_state_dict(pkg['model'], strict = strict)
            return

        prefix = 'ema_model.'
        ema_state = filter_by_keys(lambda key: key.startswith(prefix), pkg['ema_model'])
        ema_state = map_keys(lambda key: key[len(prefix):], ema_state)
        self.load_state_dict(ema_state, strict = strict)

    # 从训练器保存的对象中加载模型参数
    def load_from_trainer_saved_obj(self, path):
        """
        Load the 'model' weights from a checkpoint written by the trainer.

        The checkpoint is mapped to CPU first, as `load` and `init_and_load_from`
        already do. Without this, a checkpoint saved from a GPU fails to load on a
        machine lacking that device. `load_state_dict` then copies the tensors
        into this module's own parameters, wherever they live.
        """
        path = Path(path)
        assert path.exists()
        obj = torch.load(str(path), map_location = 'cpu')
        self.load_state_dict(obj['model'])

    # 返回非判别器参数
    def non_discr_parameters(self):
        """
        All generator-side parameters in a fixed order: encoder, decoder, the
        optional encoder/decoder attention, both FiLM conditioners, and the
        quantizer. The discriminators are excluded.
        """
        params = [*self.encoder.parameters(), *self.decoder.parameters()]

        # attention blocks are optional and may be None
        for attn in (self.encoder_attn, self.decoder_attn):
            if exists(attn):
                params.extend(attn.parameters())

        for module in (self.encoder_film, self.decoder_film, self.rq):
            params.extend(module.parameters())

        return params

    # 返回序列长度的倍数
    @property
    def seq_len_multiple_of(self):
        """Product of all strides; inputs are trimmed to a multiple of this length."""
        return functools.reduce(lambda acc, stride: acc * stride, self.strides)

    # 返回下采样因子
    @property
    def downsample_factor(self):
        """Alias for `seq_len_multiple_of` (the product of the strides)."""
        factor = self.seq_len_multiple_of
        return factor

    # 处理输入数据
    def process_input(
        self,
        x,
        input_sample_hz = None,
        curtail_from_left = False
    ):
        """
        Prepare raw audio for the encoder.

        Leading dims are flattened with `pack` (the packed shape is returned so the
        caller can unpack later). The audio is optionally resampled from
        `input_sample_hz` to `self.target_sample_hz`. Its length is then trimmed to
        a multiple of `seq_len_multiple_of`, from the left if `curtail_from_left`.
        Finally, 2-dim results get a singleton channel axis.

        Returns (x, packed_shape).
        """
        x, packed_shape = pack([x], '* n')

        if exists(input_sample_hz):
            x = resample(x, input_sample_hz, self.target_sample_hz)

        x = curtail_to_multiple(x, self.seq_len_multiple_of, from_left = curtail_from_left)

        # (b, n) -> (b, 1, n)
        x = rearrange(x, 'b n -> b 1 n') if x.ndim == 2 else x
        return x, packed_shape

    # 对音频数据进行编码
    @torch.no_grad()
    def tokenize(self, audio):
        """Switch to eval mode and return only the codebook indices for `audio`."""
        self.eval()
        codes = self.forward(audio, return_codes_only = True)
        return codes

    # 前向传播函数
    def forward(
        self,
        x,
        target = None,
        is_denoising = None, # 如果要学习教 SoundStream 进行去噪的 film conditioner - 需要在上面传入目标
        return_encoded = False,
        return_codes_only = False,
        return_discr_loss = False,
        return_discr_losses_separately = False,
        return_loss_breakdown = False,
        return_recons_only = False,
        input_sample_hz = None,
        apply_grad_penalty = False,
        curtail_from_left = False
# 定义一个默认的音频语音流函数,参数包括步长、目标采样率和 RQ 量化器数量
def AudioLMSoundStream(
    strides = (2, 4, 5, 8),
    target_sample_hz = 16000,
    rq_num_quantizers = 12,
    **kwargs
):
    """
    SoundStream preset for AudioLM: strides (2, 4, 5, 8), 16 kHz, 12 residual
    quantizers. Any other SoundStream keyword argument is passed through.
    """
    preset = dict(
        strides = strides,
        target_sample_hz = target_sample_hz,
        rq_num_quantizers = rq_num_quantizers,
    )
    return SoundStream(**preset, **kwargs)

# 定义一个默认的音乐语音流函数,参数包括步长、目标采样率和 RQ 量化器数量
def MusicLMSoundStream(
    strides = (3, 4, 5, 8),
    target_sample_hz = 24000,
    rq_num_quantizers = 12,
    **kwargs
):
    """
    SoundStream preset for MusicLM: strides (3, 4, 5, 8), 24 kHz, 12 residual
    quantizers. Any other SoundStream keyword argument is passed through.
    """
    preset = dict(
        strides = strides,
        target_sample_hz = target_sample_hz,
        rq_num_quantizers = rq_num_quantizers,
    )
    return SoundStream(**preset, **kwargs)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\t5.py

# 导入 torch 库
import torch
# 导入 transformers 库
import transformers
# 从 transformers 库中导入 T5Tokenizer, T5EncoderModel, T5Config
from transformers import T5Tokenizer, T5EncoderModel, T5Config
# 从 beartype 库中导入 beartype, Union, List
from beartype import beartype
from beartype.typing import Union, List

# set the transformers logging level to error to cut down on warning output
transformers.logging.set_verbosity_error()

# 定义一个辅助函数 exists,用于判断值是否存在
def exists(val):
    """Return True for any value other than None."""
    return not (val is None)

# 配置常量
MAX_LENGTH = 256                          # token limit per text; longer texts are truncated
DEFAULT_T5_NAME = 'google/t5-v1_1-base'   # pretrained model used when no name is given
T5_CONFIGS = dict()                       # cache: name -> dict with 'model' / 'tokenizer' / 'config'

# 全局单例变量

# 获取指定名称的 tokenizer
def get_tokenizer(name):
    """Load the pretrained T5 tokenizer called `name`."""
    return T5Tokenizer.from_pretrained(name)

# 获取指定名称的模型
def get_model(name):
    """Load the pretrained T5 encoder called `name`."""
    return T5EncoderModel.from_pretrained(name)

# 获取指定名称的模型和 tokenizer
def get_model_and_tokenizer(name):
    """
    Return (model, tokenizer) for `name`.

    Each is loaded on first request and cached in the module-level T5_CONFIGS
    dict, so later calls reuse the same objects.
    """
    global T5_CONFIGS

    entry = T5_CONFIGS.setdefault(name, dict())

    if 'model' not in entry:
        entry['model'] = get_model(name)

    if 'tokenizer' not in entry:
        entry['tokenizer'] = get_tokenizer(name)

    return entry['model'], entry['tokenizer']

# 获取编码维度
def get_encoded_dim(name):
    """
    Return the hidden size (`d_model`) of the T5 encoder `name`.

    A cached config, or the config of an already loaded model, is used when
    T5_CONFIGS has one. Otherwise the config is fetched with
    `T5Config.from_pretrained` and cached next to whatever is already stored
    for `name`.
    """
    entry = T5_CONFIGS.get(name, {})

    if 'config' in entry:
        config = entry['config']
    elif 'model' in entry:
        config = entry['model'].config
    else:
        # covers unseen names and entries holding neither a config nor a model
        # (e.g. only a tokenizer, or an empty dict left by a failed model load);
        # these previously raised a misleading 'unknown t5 name' error
        config = T5Config.from_pretrained(name)
        T5_CONFIGS.setdefault(name, {})['config'] = config

    return config.d_model

# 对文本进行编码
@beartype
def t5_encode_text(
    texts: Union[str, List[str]],
    name = DEFAULT_T5_NAME,
    output_device = None
):
    """
    Encode one or more texts with a cached pretrained T5 encoder.

    texts         - a single string or a list of strings
    name          - pretrained model name (see `get_model_and_tokenizer`)
    output_device - if given, the returned tensor is moved to this device

    Texts are padded to the longest one and truncated to MAX_LENGTH tokens.
    Returns the encoder's last hidden state with padding positions set to 0.
    """
    if isinstance(texts, str):
        texts = [texts]

    t5, tokenizer = get_model_and_tokenizer(name)

    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = 'pt',
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    t5.eval()

    with torch.inference_mode():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    # zero the embeddings at padding positions
    attn_mask = attn_mask[..., None].bool()
    encoded_text = encoded_text.masked_fill(~attn_mask, 0.)

    # Tensor.to returns a new tensor rather than moving in place, so it must be
    # reassigned; previously the result was discarded and output_device was ignored
    if exists(output_device):
        encoded_text = encoded_text.to(output_device)

    return encoded_text

.\lucidrains\audiolm-pytorch\audiolm_pytorch\trainer.py

# 导入所需的库
import re
import copy
from math import sqrt
from datetime import timedelta
from random import choice
from pathlib import Path
from shutil import rmtree
from functools import partial
from collections import Counter
from contextlib import contextmanager, nullcontext

# 导入类型提示相关的库
from beartype.typing import Union, List, Optional, Tuple, Type
from typing_extensions import Annotated

# 导入 beartype 相关的库
from beartype import beartype
from beartype.door import is_bearable
from beartype.vale import Is

# 导入 PyTorch 相关的库
import torch
import torchaudio
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
from torch.utils.data import Dataset, DataLoader, random_split

# 导入 pytorch_warmup 库
import pytorch_warmup as warmup

# 导入 einops 库
from einops import rearrange

# 导入 audiolm_pytorch 相关的库
from audiolm_pytorch.optimizer import get_optimizer
import wandb
from ema_pytorch import EMA
from audiolm_pytorch.soundstream import SoundStream
from audiolm_pytorch.encodec import EncodecWrapper
from audiolm_pytorch.audiolm_pytorch import (
    SemanticTransformer,
    SemanticTransformerWrapper,
    CoarseTransformer,
    CoarseTransformerWrapper,
    FineTransformer,
    FineTransformerWrapper,
    FairseqVQWav2Vec,
    HubertWithKmeans
)

# 导入 audiolm_pytorch 中的数据处理相关的库
from audiolm_pytorch.data import SoundDataset, get_dataloader
from audiolm_pytorch.utils import AudioConditionerBase

# 导入 audiolm_pytorch 版本相关的库
from audiolm_pytorch.version import __version__
from packaging import version

# 导入 accelerate 相关的库
from accelerate import Accelerator, DistributedType
from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate.tracking import WandBTracker

# 常量定义

DEFAULT_SAMPLE_RATE = 16_000

# LambdaLR whose multiplier is always 1., so the learning rate never changes
ConstantLRScheduler = partial(LambdaLR, lr_lambda = lambda _step: 1.)

# 确保只有一个 Trainer 实例化

ONE_TRAINER_INSTANTIATED = False

def check_one_trainer():
    """
    Allow only one trainer per process. The first call sets the module-level
    flag; any later call fails the assertion.
    """
    global ONE_TRAINER_INSTANTIATED
    already_instantiated = ONE_TRAINER_INSTANTIATED
    assert not already_instantiated, 'only one Trainer can be instantiated at a time for training'
    ONE_TRAINER_INSTANTIATED = True

DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(find_unused_parameters = True)

# keyword names used to route dataset outputs into the transformer wrappers automatically;
# each maps to the beartype hint an item must satisfy to be given that name

DATASET_FIELD_TYPE_CONFIG = {
    'raw_wave': Annotated[
        torch.Tensor,
        Is[lambda tensor: tensor.dtype == torch.float and tensor.ndim in {2, 3}]
    ],
    'text': List[str],
    'text_embeds': Annotated[
        torch.Tensor,
        Is[lambda tensor: tensor.dtype == torch.float and tensor.ndim == 3]
    ],
}

# 辅助函数

def exists(val):
    """True for anything except None."""
    return not (val is None)

def default(val, d):
    """Return `val`, or the fallback `d` when `val` is None."""
    if val is None:
        return d
    return val

def noop(*args, **kwargs):
    """Accept any arguments and do nothing (the default `log_fn` of `train`)."""
    return None

def find_first(cond, arr):
    """Return the first element of `arr` for which `cond` is truthy, or None."""
    return next((el for el in arr if cond(el)), None)

def cycle(dl):
    """Yield items from `dl` forever, restarting iteration whenever it is exhausted."""
    while True:
        for batch in dl:
            yield batch

def cast_tuple(t):
    """Return tuples and lists unchanged; wrap anything else in a 1-tuple."""
    if isinstance(t, (tuple, list)):
        return t
    return (t,)

def yes_or_no(question):
    """Prompt on stdin; return True iff the reply is 'y' or 'yes' (case-insensitive)."""
    reply = input(f'{question} (y/n) ').lower()
    return reply in ('yes', 'y')

def accum_log(log, new_logs):
    """Add every value in `new_logs` onto `log` in place (missing keys start at 0.) and return `log`."""
    for key in new_logs:
        log[key] = log.get(key, 0.) + new_logs[key]
    return log

def dict_values_to_device(d: dict, device):
    """Return a new dict in which tensor values are moved to `device`; other values are kept as-is."""
    return {key: (value.to(device) if torch.is_tensor(value) else value) for key, value in d.items()}

# 自动将数据传递到模块关键字参数路由函数

def has_duplicates(tup):
    """True if any element of `tup` appears more than once."""
    return any(count > 1 for count in Counter(tup).values())

def determine_types(data, config):
    """
    Name each element of `data` after the first entry in `config` (name -> type hint)
    whose hint it satisfies according to beartype's `is_bearable`.

    Raises TypeError when some element satisfies none of the hints.
    """
    def first_matching_name(el):
        for name, data_type in config.items():
            if is_bearable(el, data_type):
                return name
        raise TypeError(f'unable to determine type of {data}')

    return tuple(first_matching_name(el) for el in data)

def checkpoint_num_steps(checkpoint_path):
    """Returns the number of steps trained from a checkpoint based on the filename.

    The filename is assumed to look like "/path/to/semantic.transformer.20000.pt",
    which means 20k steps were trained; in that case 20000 is returned.

    Only the final path component is searched, so digits in directory names
    (e.g. "/runs/exp3/model.pt") are not mistaken for a step count. Returns 0
    when the filename contains no digits.
    """
    filename = Path(str(checkpoint_path)).name
    numbers = re.findall(r'\d+', filename)
    # the step count is the last number in the filename
    return int(numbers[-1]) if numbers else 0
# 定义一个带有调度器和热身启动的优化器类
class OptimizerWithWarmupSchedule(nn.Module):
    """
    Bundles an optimizer with a pytorch_warmup linear warmup and an LR scheduler.

    `ConstantLRScheduler` is used when no scheduler class is given. The optimizer
    and scheduler are passed through `accelerator.prepare`. The scheduler only
    advances on optimizer steps the accelerator did not skip, and it does so
    under the warmup's dampening.
    """

    @beartype
    def __init__(
        self,
        accelerator: Accelerator,
        optimizer: Optimizer,
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        warmup_steps: int = 0
    ):
        super().__init__()
        self.warmup = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps)

        if scheduler is None:
            lr_scheduler = ConstantLRScheduler(optimizer)
        else:
            lr_scheduler = scheduler(optimizer, **scheduler_kwargs)

        self.optimizer, self.scheduler = accelerator.prepare(optimizer, lr_scheduler)
        self.accelerator = accelerator

    def state_dict(self):
        """Serializable state of the optimizer, scheduler and warmup."""
        return {
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'warmup': self.warmup.state_dict(),
        }

    def load_state_dict(self, pkg):
        """Inverse of `state_dict`."""
        components = (
            ('optimizer', self.optimizer),
            ('scheduler', self.scheduler),
            ('warmup', self.warmup),
        )
        for key, component in components:
            component.load_state_dict(pkg[key])

    def zero_grad(self):
        """Clear the gradients held by the optimizer."""
        self.optimizer.zero_grad()

    def step(self):
        """Step the optimizer, then advance the schedule unless the step was skipped."""
        self.optimizer.step()

        # NOTE(review): accelerate reports skipped steps e.g. when the mixed-precision
        # grad scaler drops a step; such steps must not advance the schedule
        if self.accelerator.optimizer_step_was_skipped:
            return

        with self.warmup.dampening():
            self.scheduler.step()

# 主训练器类
class SoundStreamTrainer(nn.Module):
    # 初始化函数
    @beartype
    def __init__(
        self,
        soundstream: SoundStream,
        *,
        num_train_steps: int,
        batch_size: int,
        data_max_length: int = None,
        data_max_length_seconds: Union[int, float] = None,
        folder: str = None,
        dataset: Optional[Dataset] = None,
        val_dataset: Optional[Dataset] = None,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloader: Optional[DataLoader] = None,
        lr: float = 2e-4,
        grad_accum_every: int = 4,
        wd: float = 0.,
        warmup_steps: int = 1000,
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        discr_warmup_steps: Optional[int] = None,
        discr_scheduler: Optional[Type[_LRScheduler]] = None,
        discr_scheduler_kwargs: dict = dict(),
        max_grad_norm: float = 0.5,
        discr_max_grad_norm: float = None,
        save_results_every: int = 100,
        save_model_every: int = 1000,
        log_losses_every: int = 1,
        results_folder: str = './results',
        valid_frac: float = 0.05,
        random_split_seed: int = 42,
        use_ema: bool = True,
        ema_beta: float = 0.995,
        ema_update_after_step: int = 500,
        ema_update_every: int = 10,
        apply_grad_penalty_every: int = 4,
        dl_num_workers: int = 0,
        accelerator: Optional[Accelerator] = None,
        accelerate_kwargs: dict = dict(),
        init_process_group_timeout_seconds = 1800,
        dataloader_drop_last = True,
        split_batches = False,
        use_wandb_tracking = False,
        force_clear_prev_results: bool = None  # set to True | False to skip the prompt
    @property
    def ema_tokenizer(self):
        """The exponentially moving averaged soundstream held by the EMA wrapper."""
        ema_wrapper = self.ema_soundstream
        return ema_wrapper.ema_model

    # 对音频进行标记化处理
    def tokenize(self, audio):
        """Tokenize `audio` into codebook indices using the EMA soundstream."""
        # `ema_tokenizer` is a property on self; the bare name raised NameError
        return self.ema_tokenizer.tokenize(audio)

    # 将模型设置为指数移动平均模型
    def set_model_as_ema_model_(self):
        """ this will force the main 'online' model to have same parameters as the exponentially moving averaged model """
        assert self.use_ema
        # Copy EMA -> online, as the docstring promises. The previous code copied
        # online -> EMA, the opposite direction. The unwrapped model is used so the
        # keys match even if accelerate wrapped `self.soundstream`; `load` feeds the
        # same state dict to both models, so their keys are compatible.
        ema_state = self.ema_soundstream.ema_model.state_dict()
        self.unwrapped_soundstream.load_state_dict(ema_state)
    # 保存模型参数到指定路径
    def save(self, path):
        """
        Write a full training checkpoint to `path`.

        The checkpoint holds the soundstream weights, its pickled config, the
        generator and discriminator optimizer states, every multiscale
        discriminator optimizer, the EMA weights (when `use_ema`) and the
        library version.
        """
        pkg = {
            'model': self.accelerator.get_state_dict(self.soundstream),
            'optim': self.optim.state_dict(),
            'config': self.unwrapped_soundstream._configs,
            'discr_optim': self.discr_optim.state_dict(),
            'version': __version__,
        }

        if self.use_ema:
            pkg['ema_model'] = self.ema_soundstream.state_dict()

        for key, discr_optim in self.multiscale_discriminator_optim_iter():
            pkg[key] = discr_optim.state_dict()

        torch.save(pkg, path)

    # 获取未包装的声音流模型
    @property
    def unwrapped_soundstream(self):
        """`self.soundstream` with any accelerate/DDP wrapper stripped."""
        wrapped = self.soundstream
        return self.accelerator.unwrap_model(wrapped)

    # 加载模型参数
    def load(self, path):
        """
        Restore a checkpoint written by `save`, or a legacy bare state dict.

        A checkpoint with more than 20 top-level keys is treated as a legacy plain
        soundstream state dict, and only the weights are loaded. Otherwise the
        weights, the EMA weights (when `use_ema`) and all optimizer states are
        restored, and the step counter is set to the filename's step plus one.
        """
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')

        # legacy checkpoints were the raw state dict itself
        is_legacy = len(pkg.keys()) > 20
        if is_legacy:
            self.unwrapped_soundstream.load_state_dict(pkg)
            if self.use_ema:
                self.ema_soundstream.ema_model.load_state_dict(pkg)
            return

        if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
            print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')

        self.unwrapped_soundstream.load_state_dict(pkg['model'])

        if self.use_ema:
            assert 'ema_model' in pkg
            self.ema_soundstream.load_state_dict(pkg['ema_model'])

        self.optim.load_state_dict(pkg['optim'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])

        for key, discr_optim in self.multiscale_discriminator_optim_iter():
            discr_optim.load_state_dict(pkg[key])

        # resume on the step after the checkpoint so it is not overwritten
        next_step = checkpoint_num_steps(path) + 1
        self.steps = torch.tensor([next_step], device = self.device)

    # 遍历多尺度鉴别器
    def multiscale_discriminator_iter(self):
        """Yield (optimizer attribute name, discriminator) for each multiscale discriminator."""
        discriminators = self.unwrapped_soundstream.discriminators
        for index, discriminator in enumerate(discriminators):
            yield f'multiscale_discr_optimizer_{index}', discriminator

    # 遍历多尺度鉴别器优化器
    def multiscale_discriminator_optim_iter(self):
        """Yield (attribute name, optimizer) for each multiscale discriminator's optimizer."""
        for name, _discr in self.multiscale_discriminator_iter():
            optimizer = getattr(self, name)
            yield name, optimizer

    # 打印消息
    def print(self, msg):
        """Print through `accelerator.print` to avoid duplicate output across processes."""
        accelerator = self.accelerator
        accelerator.print(msg)

    # 记录日志
    def log(self, **logs_as_kwargs):
        """Send the given keyword values to the accelerator's trackers at the current step."""
        current_step = self.steps.item()
        self.accelerator.log(logs_as_kwargs, step = current_step)

    # 使用wandb跟踪器
    @contextmanager
    def wandb_tracker(self, project, run = None, hps = None):
        """
        Context manager that wraps a training run in accelerator tracking.

        Starts the trackers for `project` with `hps` (default
        `self.tracker_hps`) as the run config, optionally renames the wandb run
        to `run`, and always ends tracking on exit.
        """
        assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on SoundStreamTrainer'

        hps = default(hps, self.tracker_hps)

        # hps used to be computed and then dropped (config = None)
        self.accelerator.init_trackers(project, config = hps)

        if exists(run):
            wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
            assert exists(wandb_tracker)

            wandb_tracker.run.name = run

        try:
            yield
        finally:
            # close the trackers even if the code inside the `with` block raises
            self.accelerator.end_training()

    # 获取设备
    @property
    def device(self):
        """Device the accelerator assigned to this process."""
        accelerator = self.accelerator
        return accelerator.device

    # 是否分布式训练
    @property
    def is_distributed(self):
        """True unless running as a single, non-distributed process."""
        acc = self.accelerator
        return acc.distributed_type != DistributedType.NO or acc.num_processes != 1

    # 是否主进程
    @property
    def is_main(self):
        """Whether this is the global main process."""
        accelerator = self.accelerator
        return accelerator.is_main_process

    # 是否本地主进程
    @property
    def is_local_main(self):
        """Whether this is the main process on the local machine."""
        accelerator = self.accelerator
        return accelerator.is_local_main_process

    # 训练模型
    def train(self, log_fn = noop):
        """Call `train_step` until `num_train_steps` is reached, passing each step's logs to `log_fn`."""
        while self.steps < self.num_train_steps:
            log_fn(self.train_step())

        self.print('training complete')
# 语义转换器训练器

class SemanticTransformerTrainer(nn.Module):
    @beartype
    def __init__(
        self,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        transformer: SemanticTransformer,
        *,
        num_train_steps,
        batch_size,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        valid_dataset: Optional[Dataset] = None,
        data_max_length = None,
        data_max_length_seconds = None,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        init_process_group_timeout_seconds = 1800,
        use_wandb_tracking = False,
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None,
        average_valid_loss_over_grad_accum_every: bool = True, # if False, valid loss on a single batch
    # 保存模型参数到指定路径
    def save(self, path):
        """Write the transformer weights, optimizer state and library version to `path`."""
        pkg = {
            'model': self.accelerator.get_state_dict(self.transformer),
            'optim': self.optim.state_dict(),
            'version': __version__,
        }
        torch.save(pkg, path)

    # 从指定路径加载模型参数
    def load(self, path):
        """
        Load the transformer weights through the transformer's own `load`, which
        returns the checkpoint dict. Then restore the optimizer and resume on the
        step after the one encoded in the filename.
        """
        pkg = self.accelerator.unwrap_model(self.transformer).load(path)

        # trainer-specific state
        self.optim.load_state_dict(pkg['optim'])

        # + 1 so training resumes at the next step and does not overwrite the last checkpoint
        resume_step = checkpoint_num_steps(path) + 1
        self.steps = torch.tensor([resume_step], device = self.device)


    # 打印消息
    def print(self, msg):
        """Print through `accelerator.print` to avoid duplicate output across processes."""
        accelerator = self.accelerator
        accelerator.print(msg)

    # 生成结果
    def generate(self, *args, **kwargs):
        """Delegate to the training wrapper's `generate`."""
        wrapper = self.train_wrapper
        return wrapper.generate(*args, **kwargs)

    @property
    def device(self):
        """Device the accelerator assigned to this process."""
        accelerator = self.accelerator
        return accelerator.device

    @property
    def is_distributed(self):
        """True unless running as a single, non-distributed process."""
        acc = self.accelerator
        return acc.distributed_type != DistributedType.NO or acc.num_processes != 1

    @property
    def is_main(self):
        """Whether this is the global main process."""
        accelerator = self.accelerator
        return accelerator.is_main_process

    @property
    def is_local_main(self):
        """Whether this is the main process on the local machine."""
        accelerator = self.accelerator
        return accelerator.is_local_main_process

    # 将数据元组转换为关键字参数
    def data_tuple_to_kwargs(self, data):
        """
        Convert a dataset item tuple into keyword arguments for the training wrapper.

        Field names are inferred once, from the first item, by matching against
        DATASET_FIELD_TYPE_CONFIG, and are cached in `self.ds_fields`.
        """
        if self.ds_fields is None:
            self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
            assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names'

        return {field: value for field, value in zip(self.ds_fields, data)}

    @contextmanager
    def wandb_tracker(self, project, run = None, hps = None):
        """
        Context manager that wraps a training run in accelerator tracking.

        Starts the trackers for `project` with `hps` (default
        `self.tracker_hps`) as the run config, optionally renames the wandb run
        to `run`, and always ends tracking on exit.
        """
        assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on SemanticTransformerTrainer'

        hps = default(hps, self.tracker_hps)

        # hps used to be computed and then dropped (config = None)
        self.accelerator.init_trackers(project, config = hps)

        if exists(run):
            wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
            assert exists(wandb_tracker)

            wandb_tracker.run.name = run

        try:
            yield
        finally:
            # close the trackers even if the code inside the `with` block raises
            self.accelerator.end_training()
    # 定义训练步骤函数
    def train_step(self):
        """
        Run one optimization step of the semantic transformer.

        Gradients are accumulated over `grad_accum_every` batches, with gradient
        sync skipped on all but the last. They are then clipped to
        `max_grad_norm` (if set), the optimizer steps, and the loss is logged.
        On the main process this also computes a validation loss every
        `save_results_every` steps and saves a checkpoint every
        `save_model_every` steps.

        Returns the dict of accumulated training logs.
        """
        steps = int(self.steps.item())

        self.transformer.train()

        logs = {}

        for i in range(self.grad_accum_every):
            is_last = i == (self.grad_accum_every - 1)
            # only synchronize gradients across processes on the final accumulation step
            context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext

            data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))

            with self.accelerator.autocast(), context():
                loss = self.train_wrapper(**data_kwargs, return_loss = True)

                self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            # `average_valid_loss_over_grad_accum_every` is declared as a bool flag.
            # Used directly, range(True) == range(1) and dividing by True is a no-op,
            # so the flag had no effect. Map True -> grad_accum_every batches and
            # False -> a single batch; an int value is used as-is.
            num_valid_batches = self.average_valid_loss_over_grad_accum_every
            if isinstance(num_valid_batches, bool):
                num_valid_batches = self.grad_accum_every if num_valid_batches else 1

            valid_loss = 0
            unwrapped_model = self.accelerator.unwrap_model(self.train_wrapper)
            unwrapped_model.eval()

            for _ in range(num_valid_batches):
                data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
                data_kwargs = dict_values_to_device(data_kwargs, unwrapped_model.device)

                with torch.inference_mode():
                    valid_loss += unwrapped_model(**data_kwargs, return_loss = True)

            valid_loss = valid_loss.clone() # avoid inference mode to non-inference mode error
            valid_loss /= num_valid_batches

            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'semantic.transformer.{steps}.pt')
            self.save(model_path)
            if self.use_wandb_tracking:
                wandb.save(model_path)
            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.accelerator.wait_for_everyone()

        self.steps.add_(1)
        return logs

    # 训练函数
    def train(self, log_fn = noop):
        """Call `train_step` until `num_train_steps` is reached, passing each step's logs to `log_fn`."""
        while self.steps < self.num_train_steps:
            log_fn(self.train_step())

        self.print('training complete')
# 定义粗糙变换器训练器类
class CoarseTransformerTrainer(nn.Module):
    # 初始化方法
    @beartype
    def __init__(
        self,
        transformer: CoarseTransformer,  # 粗糙变换器对象
        codec: Union[SoundStream, EncodecWrapper],  # 编解码器对象
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],  # 可选的音频向量化器对象
        *,
        num_train_steps,  # 训练步数
        batch_size,  # 批量大小
        audio_conditioner: Optional[AudioConditionerBase] = None,  # 可选的音频调节器对象
        dataset: Optional[Dataset] = None,  # 可选的数据集对象
        valid_dataset: Optional[Dataset] = None,  # 可选的验证数据集对象
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'),  # 数据集字段元组
        data_max_length = None,  # 数据最大长度
        data_max_length_seconds = None,  # 数据最大长度(秒)
        folder = None,  # 文件夹路径
        lr = 3e-4,  # 学习率
        grad_accum_every = 1,  # 梯度累积频率
        wd = 0.,  # 权重衰减
        max_grad_norm = 0.5,  # 最大梯度范数
        valid_frac = 0.05,  # 验证集比例
        random_split_seed = 42,  # 随机拆分种子
        save_results_every = 100,  # 每隔多少步保存结果
        save_model_every = 1000,  # 每隔多少步保存模型
        results_folder = './results',  # 结果文件夹路径
        accelerate_kwargs: dict = dict(),  # 加速参数字典
        init_process_group_timeout_seconds = 1800,  # 初始化进程组超时时间(秒)
        split_batches = False,  # 是否拆分批次
        drop_last = False,  # 是否丢弃最后一批
        force_clear_prev_results = None,  # 强制清除之前的结果
        use_wandb_tracking = False,  # 是否使用WandB跟踪
        average_valid_loss_over_grad_accum_every: bool = True,  # 是否在梯度累积频率上平均验证损失
    # 保存方法
    def save(self, path):
        """Write transformer weights, optimizer state and package version to ``path`` via ``torch.save``."""
        checkpoint = {
            'model': self.accelerator.get_state_dict(self.transformer),
            'optim': self.optim.state_dict(),
            'version': __version__,
        }
        torch.save(checkpoint, path)

    # 加载方法
    def load(self, path):
        """Restore transformer and optimizer state from the checkpoint at ``path``.

        The unwrapped transformer's ``load`` is expected to return the saved package
        (it is indexed with ``'optim'`` below). Afterwards the step counter is set to
        ``checkpoint_num_steps(path) + 1``.
        """
        unwrapped_transformer = self.accelerator.unwrap_model(self.transformer)
        checkpoint = unwrapped_transformer.load(path)

        # trainer-specific state
        self.optim.load_state_dict(checkpoint['optim'])

        # resume from the next step so the last checkpoint is not overwritten
        resume_step = checkpoint_num_steps(path) + 1
        self.steps = torch.tensor([resume_step], device = self.device)

    # 打印方法
    def print(self, msg):
        """Print ``msg`` through ``accelerator.print`` (avoids duplicated output across processes)."""
        emit = self.accelerator.print
        emit(msg)

    # 生成方法
    def generate(self, *args, **kwargs):
        """Delegate to ``self.train_wrapper.generate`` with all arguments and return its result."""
        wrapper = self.train_wrapper
        return wrapper.generate(*args, **kwargs)

    # WandB跟踪器上下文管理器
    @contextmanager
    def wandb_tracker(self, project, run = None, hps = None):
        """Context manager that runs the enclosed block with accelerate's trackers initialized.

        Args:
            project: project name passed to ``accelerator.init_trackers``.
            run: optional run name; when given, it is assigned to the wandb tracker's run.
            hps: hyperparameters recorded as the run config; falls back to
                ``self.tracker_hps`` when not given.

        Raises:
            AssertionError: if ``use_wandb_tracking`` is not enabled, or if a run name
                is given but no wandb tracker is registered.
        """
        assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on CoarseTransformerTrainer'

        hps = default(hps, self.tracker_hps)

        # hand the resolved hyperparameters to the trackers so they appear in the run config
        self.accelerator.init_trackers(project, config = hps)

        try:
            if exists(run):
                wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
                assert exists(wandb_tracker)

                wandb_tracker.run.name = run

            yield
        finally:
            # close the trackers even if training inside the `with` block raises
            self.accelerator.end_training()

    # 设备属性
    @property
    def device(self):
        """Device that accelerate assigned to this process."""
        accelerator = self.accelerator
        return accelerator.device

    # 是否分布式属性
    @property
    def is_distributed(self):
        """True unless running as a single process with no distributed backend."""
        accelerator = self.accelerator
        return accelerator.distributed_type != DistributedType.NO or accelerator.num_processes != 1

    # 是否主进程属性
    @property
    def is_main(self):
        """Whether this is accelerate's global main process."""
        accelerator = self.accelerator
        return accelerator.is_main_process

    # 是否本地主进程属性
    @property
    def is_local_main(self):
        """Whether this is the main process on the local machine."""
        accelerator = self.accelerator
        return accelerator.is_local_main_process
    # 定义训练步骤函数
    def train_step(self):
        """Run one optimizer step of the coarse transformer, plus periodic validation and checkpointing.

        Gradients are accumulated over ``self.grad_accum_every`` micro-batches drawn from
        ``self.dl_iter``. Every ``save_results_every`` steps the main process computes a
        validation loss; every ``save_model_every`` steps it saves a checkpoint into
        ``results_folder``. Finally the step counter is incremented.

        Returns:
            dict: accumulated training logs (key ``'loss'``).
        """
        # NOTE(review): `device` is assigned but never used in this method
        device = self.device

        # current step as a Python int (self.steps is a tensor, see .item() / .add_ below)
        steps = int(self.steps.item())

        # back to training mode (the validation branch below calls eval() on the unwrapped train wrapper)
        self.transformer.train()

        # logs accumulated over this step
        logs = {}

        # gradient accumulation over grad_accum_every micro-batches
        for i in range(self.grad_accum_every):
            is_last = i == (self.grad_accum_every - 1)
            # disable gradient synchronization (accelerator.no_sync) on all but the last
            # micro-batch, so gradients are only synchronized once per optimizer step
            context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext

            # pair each element of the batch tuple with its dataset field name
            data_kwargs = dict(zip(self.ds_fields, next(self.dl_iter)))

            # forward pass under accelerate's autocast (mixed precision, if configured)
            with self.accelerator.autocast(), context():
                loss = self.train_wrapper(
                    **data_kwargs,
                    return_loss = True
                )

                # scale so that the gradients summed over micro-batches form an average
                self.accelerator.backward(loss / self.grad_accum_every)

            # accumulate the scaled loss for logging
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # optional gradient clipping
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)

        # optimizer update
        self.optim.step()
        self.optim.zero_grad()

        # log the training loss
        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # synchronize all processes before the periodic validation / checkpointing below

        self.accelerator.wait_for_everyone()

        # periodic validation, main process only
        if self.is_main and not (steps % self.save_results_every):
            valid_loss = 0
            unwrapped_model = self.accelerator.unwrap_model(self.train_wrapper)

            # average the validation loss over several validation batches
            # NOTE(review): `average_valid_loss_over_grad_accum_every` is annotated `bool` in
            # __init__; if stored unchanged, range(True) runs a single batch and False would
            # leave valid_loss as int 0 (no .clone(), division by zero) - confirm how __init__ stores it
            for i in range(self.average_valid_loss_over_grad_accum_every):
                data_kwargs = dict(zip(self.ds_fields, next(self.valid_dl_iter)))
                data_kwargs = dict_values_to_device(data_kwargs, unwrapped_model.device)

                with torch.no_grad():
                    unwrapped_model.eval()

                    valid_loss += unwrapped_model(
                        **data_kwargs,
                        return_loss = True
                    )

            valid_loss = valid_loss.clone() # needed in the inference_mode variants of this loop before in-place ops; harmless under no_grad
            valid_loss /= self.average_valid_loss_over_grad_accum_every

            # log the validation loss
            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # periodic checkpoint, main process only (also passed to wandb.save when tracking)
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'coarse.transformer.{steps}.pt')
            self.save(model_path)
            if self.use_wandb_tracking:
                wandb.save(model_path)
            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.accelerator.wait_for_everyone()

        # advance the step counter
        self.steps.add_(1)
        return logs

    # 训练函数
    def train(self, log_fn = noop):
        """Train until ``self.steps`` reaches ``self.num_train_steps``.

        Args:
            log_fn: called with the logs dict produced by each training step.
        """
        while True:
            if not (self.steps < self.num_train_steps):
                break
            log_fn(self.train_step())

        self.print('training complete')
# 定义一个 FineTransformerTrainer 类,用于训练 FineTransformer 模型
class FineTransformerTrainer(nn.Module):
    # 初始化方法
    @beartype
    def __init__(
        self,
        transformer: FineTransformer,  # 接收 FineTransformer 模型
        codec: Union[SoundStream, EncodecWrapper],  # 接收音频流或编码器包装器
        *,
        num_train_steps,  # 训练步数
        batch_size,  # 批量大小
        audio_conditioner: Optional[AudioConditionerBase] = None,  # 可选的音频调节器
        dataset: Optional[Dataset] = None,  # 可选的数据集
        valid_dataset: Optional[Dataset] = None,  # 可选的验证数据集
        data_max_length = None,  # 数据最大长度
        data_max_length_seconds = None,  # 数据最大长度(秒)
        dataset_normalize = False,  # 是否对数据集进行归一化
        folder = None,  # 文件夹路径
        lr = 3e-4,  # 学习率
        grad_accum_every = 1,  # 梯度累积频率
        wd = 0.,  # 权重衰减
        max_grad_norm = 0.5,  # 最大梯度范数
        valid_frac = 0.05,  # 验证集比例
        random_split_seed = 42,  # 随机拆分种子
        save_results_every = 100,  # 每隔多少步保存结果
        save_model_every = 1000,  # 每隔多少步保存模型
        results_folder = './results',  # 结果保存文件夹路径
        accelerate_kwargs: dict = dict(),  # 加速参数
        init_process_group_timeout_seconds = 1800,  # 初始化进程组超时时间(秒)
        split_batches = False,  # 是否拆分批次
        drop_last = False,  # 是否丢弃最后一批次
        use_wandb_tracking = False,  # 是否使用 WandB 追踪
        force_clear_prev_results = None,  # 强制清除之前的结果
        average_valid_loss_over_grad_accum_every: bool = True,  # 是否在梯度累积频率上计算验证损失的平均值
    # 保存模型方法
    def save(self, path):
        """Serialize transformer weights, optimizer state and package version to ``path``."""
        model_state = self.accelerator.get_state_dict(self.transformer)
        optim_state = self.optim.state_dict()
        torch.save({'model': model_state, 'optim': optim_state, 'version': __version__}, path)

    # 加载模型方法
    def load(self, path):
        """Restore transformer and optimizer state from the checkpoint at ``path``.

        The unwrapped transformer's ``load`` is expected to return the saved package
        (its ``'optim'`` entry is used here). The step counter then becomes
        ``checkpoint_num_steps(path) + 1``.
        """
        pkg = self.accelerator.unwrap_model(self.transformer).load(path)

        # trainer-specific state
        self.optim.load_state_dict(pkg['optim'])

        # + 1 to start from the next step and avoid overwriting the last checkpoint
        next_step = checkpoint_num_steps(path) + 1
        self.steps = torch.tensor([next_step], device = self.device)

    # 打印方法
    def print(self, msg):
        """Print ``msg`` via ``accelerator.print`` so it is not duplicated across processes."""
        printer = self.accelerator.print
        printer(msg)

    # 生成方法
    def generate(self, *args, **kwargs):
        """Forward every argument to ``self.train_wrapper.generate`` and return what it returns."""
        generate_fn = self.train_wrapper.generate
        return generate_fn(*args, **kwargs)

    # WandB 追踪上下文管理器
    @contextmanager
    def wandb_tracker(self, project, run = None, hps = None):
        """Context manager that runs the enclosed block with accelerate's trackers initialized.

        Args:
            project: project name passed to ``accelerator.init_trackers``.
            run: optional run name; when given, it is assigned to the wandb tracker's run.
            hps: hyperparameters recorded as the run config; falls back to
                ``self.tracker_hps`` when not given.

        Raises:
            AssertionError: if ``use_wandb_tracking`` is not enabled, or if a run name
                is given but no wandb tracker is registered.
        """
        assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on FineTransformerTrainer'

        hps = default(hps, self.tracker_hps)

        # hand the resolved hyperparameters to the trackers so they appear in the run config
        self.accelerator.init_trackers(project, config = hps)

        try:
            if exists(run):
                wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
                assert exists(wandb_tracker)

                wandb_tracker.run.name = run

            yield
        finally:
            # close the trackers even if training inside the `with` block raises
            self.accelerator.end_training()

    # 设备属性
    @property
    def device(self):
        """Device accelerate placed this process on."""
        accelerator = self.accelerator
        return accelerator.device

    # 是否分布式属性
    @property
    def is_distributed(self):
        """True unless this is a single process with no distributed backend."""
        accelerator = self.accelerator
        return accelerator.distributed_type != DistributedType.NO or accelerator.num_processes != 1

    # 是否主进程属性
    @property
    def is_main(self):
        """Whether this is accelerate's global main process."""
        accelerator = self.accelerator
        return accelerator.is_main_process

    # 是否本地主进程属性
    @property
    def is_local_main(self):
        """Whether this is the main process on the local machine."""
        accelerator = self.accelerator
        return accelerator.is_local_main_process

    # 数据元组转关键字参数方法
    def data_tuple_to_kwargs(self, data):
        """Map one batch tuple from a dataloader to keyword arguments for the train wrapper.

        On first use the field names are inferred from the batch contents with
        ``determine_types(data, DATASET_FIELD_TYPE_CONFIG)`` and cached on ``self.ds_fields``.

        Raises:
            AssertionError: if the inferred field names contain duplicates.
        """
        if not exists(self.ds_fields):
            ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
            # validate before caching: caching first would leave duplicate names on self
            # after a failed check, and later calls would then skip the check entirely
            assert not has_duplicates(ds_fields), 'dataset fields must not have duplicate field names'
            self.ds_fields = ds_fields

        return dict(zip(self.ds_fields, data))
    # 定义训练步骤函数
    def train_step(self):
        """Run one optimizer step of the fine transformer, plus periodic validation and checkpointing.

        Gradients are accumulated over ``self.grad_accum_every`` micro-batches drawn from
        ``self.dl_iter`` (converted to kwargs by ``data_tuple_to_kwargs``). Every
        ``save_results_every`` steps the main process computes a validation loss; every
        ``save_model_every`` steps it saves a checkpoint into ``results_folder``. Finally
        the step counter is incremented.

        Returns:
            dict: accumulated training logs (key ``'loss'``).
        """
        # NOTE(review): `device` is assigned but never used in this method
        device = self.device

        # current step as a Python int (self.steps is a tensor, see .item() / .add_ below)
        steps = int(self.steps.item())

        # back to training mode (the validation branch below calls eval() on the unwrapped train wrapper)
        self.transformer.train()

        # logs accumulated over this step
        logs = {}

        # gradient accumulation over grad_accum_every micro-batches
        for i in range(self.grad_accum_every):
            is_last = i == (self.grad_accum_every - 1)
            # disable gradient synchronization (accelerator.no_sync) on all but the last
            # micro-batch, so gradients are only synchronized once per optimizer step
            context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext

            # batch tuple -> keyword arguments (infers field names on first use)
            data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))

            # forward pass under accelerate's autocast (mixed precision, if configured)
            with self.accelerator.autocast(), context():
                loss = self.train_wrapper(**data_kwargs, return_loss = True)

                # scale so that the gradients summed over micro-batches form an average
                self.accelerator.backward(loss / self.grad_accum_every)

            # accumulate the scaled loss for logging
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # optional gradient clipping
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)

        # optimizer update
        self.optim.step()
        self.optim.zero_grad()

        # log the training loss
        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # synchronize all processes before the periodic validation / checkpointing below
        self.accelerator.wait_for_everyone()

        # periodic validation, main process only
        if self.is_main and not (steps % self.save_results_every):
            unwrapped_model = self.accelerator.unwrap_model(self.train_wrapper)
            valid_loss = 0

            # average the validation loss over several validation batches
            # NOTE(review): `average_valid_loss_over_grad_accum_every` is annotated `bool` in
            # __init__; if stored unchanged, range(True) runs a single batch and False would
            # leave valid_loss as int 0 (no .clone(), division by zero) - confirm how __init__ stores it
            for i in range(self.average_valid_loss_over_grad_accum_every):
                data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
                data_kwargs = dict_values_to_device(data_kwargs, unwrapped_model.device)

                with torch.inference_mode():
                    unwrapped_model.eval()
                    valid_loss += unwrapped_model(**data_kwargs, return_loss = True)

            valid_loss = valid_loss.clone() # inference tensors cannot be modified in-place outside inference_mode; clone before the /= below
            valid_loss /= self.average_valid_loss_over_grad_accum_every

            # log the validation loss
            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # periodic checkpoint, main process only (also passed to wandb.save when tracking)
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'fine.transformer.{steps}.pt')
            self.save(model_path)
            if self.use_wandb_tracking:
                wandb.save(model_path)
            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.accelerator.wait_for_everyone()

        # advance the step counter
        self.steps.add_(1)
        return logs

    # 训练函数
    def train(self, log_fn = noop):
        """Repeatedly call ``train_step`` until ``self.steps`` reaches ``self.num_train_steps``.

        Args:
            log_fn: receives the logs dict returned by each training step.
        """
        while bool(self.steps < self.num_train_steps):
            log_fn(self.train_step())

        self.print('training complete')

.\lucidrains\audiolm-pytorch\audiolm_pytorch\utils.py

# 从 torch 模块中导入 nn 模块

from torch import nn

# 定义函数

def round_down_nearest_multiple(num, divisor):
    """Return the largest multiple of ``divisor`` that is <= ``num`` (for positive ``divisor``).

    >>> round_down_nearest_multiple(10, 4)
    8
    """
    return num // divisor * divisor

def curtail_to_multiple(t, mult, from_left = False):
    """Trim the last dimension of ``t`` down to a multiple of ``mult``.

    Args:
        t: tensor-like object with ``.shape`` that supports ``t[..., slice]`` indexing.
        mult: positive integer the resulting length must be a multiple of.
        from_left: if True, drop the excess from the start and keep the tail;
            otherwise drop it from the end and keep the head.

    Returns:
        ``t`` sliced along its last dimension to length
        ``round_down_nearest_multiple(t.shape[-1], mult)`` (empty if shorter than ``mult``).
    """
    data_len = t.shape[-1]
    rounded_seq_len = round_down_nearest_multiple(data_len, mult)

    if from_left:
        # start at data_len - rounded_seq_len rather than -rounded_seq_len:
        # when rounded_seq_len is 0, slice(-0, None) == slice(0, None) would keep everything
        seq_slice = slice(data_len - rounded_seq_len, None)
    else:
        seq_slice = slice(None, rounded_seq_len)

    return t[..., seq_slice]

# 基类

class AudioConditionerBase(nn.Module):
    """Empty base class for audio conditioners.

    Adds nothing to ``nn.Module``; it exists so that conditioner implementations
    share a common type (the trainers annotate ``audio_conditioner`` with it).
    """

.\lucidrains\audiolm-pytorch\audiolm_pytorch\version.py

# Package version string.
__version__ = "2.0.7"

.\lucidrains\audiolm-pytorch\audiolm_pytorch\vq_wav2vec.py

# 导入所需的模块
from pathlib import Path
# 导入 torch 模块
import torch
# 导入 torch 中的 nn 模块
from torch import nn
# 导入 einops 中的 rearrange 函数
from einops import rearrange
# 导入 fairseq 模块
import fairseq
# 导入 torchaudio 中的 resample 函数
from torchaudio.functional import resample
# 导入自定义的 curtail_to_multiple 函数
from audiolm_pytorch.utils import curtail_to_multiple
# 导入 logging 模块
import logging
# Raise the root logger's threshold to ERROR on import (presumably to silence
# fairseq's verbose output - confirm). NOTE: this is a process-wide side effect:
# every logger without its own explicit level inherits this threshold.
logging.getLogger().setLevel(logging.ERROR)

# 定义一个函数,用于判断值是否存在
def exists(val):
    """Return True for any value other than ``None`` (falsy values such as 0 still count)."""
    return not (val is None)

# 定义 FairseqVQWav2Vec 类
class FairseqVQWav2Vec(nn.Module):
    """
    Wraps a fairseq vq-wav2vec model (put in eval mode) to turn raw audio into codebook indices.

    checkpoint path can be found at https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/README.md#vq-wav2vec
    specifically download the kmeans model for now

    $ wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt
    """

    def __init__(
        self,
        checkpoint_path,
        target_sample_hz = 24000,
        seq_len_multiple_of = None
    ):
        """
        Args:
            checkpoint_path: path to a fairseq vq-wav2vec checkpoint file.
            target_sample_hz: rate that `forward` resamples to when it is given `input_sample_hz`.
            seq_len_multiple_of: if set, `forward` trims the input's last dimension
                down to a multiple of this value.
        """
        super().__init__()
        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = seq_len_multiple_of

        path = Path(checkpoint_path)
        assert path.exists(), f'path {checkpoint_path} does not exist'

        # `load_model_ensemble_and_task` takes a list of checkpoint *filenames* and loads each
        # file itself, so the path is passed by name only - pre-loading it with `torch.load`
        # would just read the (large) checkpoint a second time and discard the result.
        # NOTE: fairseq unpickles the checkpoint; only load files from trusted sources.
        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])

        self.model = model[0]
        self.model.eval()

        # sanity check that the loaded model has a vector quantizer with an embedding table
        assert hasattr(self.model, 'vector_quantizer') and hasattr(self.model.vector_quantizer, 'embedding'), 'the vq wav2vec model does not seem to be valid'

    @property
    def groups(self):
        """Number of groups reported by the model's vector quantizer."""
        return self.model.vector_quantizer.groups

    @property
    def downsample_factor(self):
        """Hard-coded downsampling factor (presumably input samples per output frame)."""
        # todo: double check architecture
        return 80

    @property
    def codebook_size(self):
        """Size of dimension 0 of the vector quantizer's embedding table."""
        return self.model.vector_quantizer.embedding.shape[0]

    @torch.inference_mode()
    def forward(
        self,
        wav_input,
        flatten = True,
        input_sample_hz = None
    ):
        """Quantize a batch of waveforms into codebook indices.

        Args:
            wav_input: waveform tensor; presumably (batch, samples) - TODO confirm against
                the fairseq feature extractor.
            flatten: if True, merge all dimensions after the batch dimension into one.
            input_sample_hz: if given, resample from this rate to `target_sample_hz` first.

        Returns:
            codebook indices from `vector_quantizer.forward_idx`; shape (batch, N) when flattened.
        """
        if exists(input_sample_hz):
            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

        if exists(self.seq_len_multiple_of):
            wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

        embed = self.model.feature_extractor(wav_input)
        _, codebook_indices = self.model.vector_quantizer.forward_idx(embed)

        if not flatten:
            return codebook_indices

        return rearrange(codebook_indices, 'b ... -> b (...)')

.\lucidrains\audiolm-pytorch\audiolm_pytorch\__init__.py

# torch is imported here to check its version
import torch
from packaging import version

# On torch >= 2.0, register einops operations with torch.compile.
# Importing the helper alone does nothing - it has to be called.
if version.parse(torch.__version__) >= version.parse('2.0.0'):
    from einops._torch_specific import allow_ops_in_compiled_graph
    allow_ops_in_compiled_graph()

from audiolm_pytorch.audiolm_pytorch import AudioLM
from audiolm_pytorch.soundstream import SoundStream, AudioLMSoundStream, MusicLMSoundStream
from audiolm_pytorch.encodec import EncodecWrapper

from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper, SemanticTransformerWrapper

from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer, FineTransformerTrainer, CoarseTransformerTrainer

from audiolm_pytorch.audiolm_pytorch import get_embeds

AudioLM - Pytorch

Implementation of AudioLM, a Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

It also extends the work for conditioning with classifier free guidance with T5. This allows for one to do text-to-audio or TTS, not offered in the paper. Yes, this means VALL-E can be trained from this repository. It is essentially the same.

Please join us on Discord if you are interested in replicating this work in the open.

This repository now also contains a MIT licensed version of SoundStream. It is also compatible with EnCodec, which is also MIT-licensed at the time of writing.

Update: AudioLM was essentially used to 'solve' music generation in the new MusicLM

In the future, this movie clip would no longer make any sense. You would just prompt an AI instead.

Appreciation

  • Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research

  • 🤗 Huggingface for their amazing accelerate and transformers libraries

  • MetaAI for Fairseq and the liberal license

  • @eonglints and Joseph for offering their professional advice and expertise as well as pull requests!

  • @djqualia, @yigityu, @inspirit, and @BlackFox1197 for helping with the debugging of soundstream

  • Allen and LWprogramming for reviewing the code and submitting bug fixes!

  • Ilya for finding an issue with multi-scale discriminator downsampling and for soundstream trainer improvements

  • Andrey for identifying a missing loss in soundstream and guiding me through the proper mel spectrogram hyperparameters

  • Alejandro and Ilya for sharing their results with training soundstream, and for working through a few issues with the local attention positional embeddings

  • LWprogramming for adding Encodec compatibility!

  • LWprogramming for finding an issue with handling of the EOS token when sampling from the FineTransformer!

  • @YoungloLee for identifying a big bug in the 1d causal convolution for soundstream related to padding not accounting for strides!

  • Hayden for pointing out some discrepancies in the multi-scale discriminator for Soundstream

Install

$ pip install audiolm-pytorch

Usage

SoundStream & Encodec

There are two options for the neural codec. If you want to use the pretrained 24kHz Encodec, just create an Encodec object as follows:

from audiolm_pytorch import EncodecWrapper
encodec = EncodecWrapper()
# Now you can use the encodec variable in the same way you'd use the soundstream variables below.

Otherwise, to stay more true to the original paper, you can use SoundStream. First, SoundStream needs to be trained on a large corpus of audio data

from audiolm_pytorch import SoundStream, SoundStreamTrainer

soundstream = SoundStream(
    codebook_size = 4096,
    rq_num_quantizers = 8,
    rq_groups = 2,                       # this paper proposes using multi-headed residual vector quantization - https://arxiv.org/abs/2305.02765
    use_lookup_free_quantizer = True,    # whether to use residual lookup free quantization - there are now reports of successful usage of this unpublished technique
    use_finite_scalar_quantizer = False, # whether to use residual finite scalar quantization
    attn_window_size = 128,              # local attention receptive field at bottleneck
    attn_depth = 2                       # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = '/path/to/audio/files',
    batch_size = 4,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length_seconds = 2,  # train on 2 second audio
    num_train_steps = 1_000_000
).cuda()

trainer.train()

# after a lot of training, you can test the autoencoding as so

soundstream.eval() # your soundstream must be in eval mode, to avoid having the residual dropout of the residual VQ necessary for training

audio = torch.randn(10080).cuda()
recons = soundstream(audio, return_recons_only = True) # (1, 10080) - 1 channel

Your trained SoundStream can then be used as a generic tokenizer for audio


audio = torch.randn(1, 512 * 320)

codes = soundstream.tokenize(audio)

# you can now train anything with the codebook ids

recon_audio_from_codes = soundstream.decode_from_codebook_indices(codes)

# sanity check

assert torch.allclose(
    recon_audio_from_codes,
    soundstream(audio, return_recons_only = True)
)

You can also use soundstreams that are specific to AudioLM and MusicLM by importing AudioLMSoundStream and MusicLMSoundStream respectively

from audiolm_pytorch import AudioLMSoundStream, MusicLMSoundStream

soundstream = AudioLMSoundStream(...) # say you want the hyperparameters as in Audio LM paper

# rest is the same as above

As of version 0.17.0, you can now invoke the class method on SoundStream to load from checkpoint files, without having to remember your configurations.

from audiolm_pytorch import SoundStream

soundstream = SoundStream.init_and_load_from('./path/to/checkpoint.pt')

To use Weights & Biases tracking, first set use_wandb_tracking = True on the SoundStreamTrainer, then do the following


trainer = SoundStreamTrainer(
    soundstream,
    ...,
    use_wandb_tracking = True
)

# wrap .train() with contextmanager, specifying project and run name

with trainer.wandb_tracker(project = 'soundstream', run = 'baseline'):
    trainer.train()

Hierarchical Transformers

Then three separate transformers (SemanticTransformer, CoarseTransformer, FineTransformer) need to be trained

ex. SemanticTransformer

import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

# hubert checkpoints can be downloaded at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    flash_attn = True
).cuda()


trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder ='/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

trainer.train()

ex. CoarseTransformer

import torch
from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerTrainer

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt')

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6,
    flash_attn = True
)

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    codec = soundstream,
    wav2vec = wav2vec,
    folder = '/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1_000_000
)

trainer.train()

ex. FineTransformer

import torch
from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer

soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt')

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6,
    flash_attn = True
)

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    codec = soundstream,
    folder = '/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1_000_000
)

trainer.train()

All together now

from audiolm_pytorch import AudioLM

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

generated_wav = audiolm(batch_size = 1)

# or with priming

generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8))

# or with text condition, if given

generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echos of bells'])

Text Conditioned Audio Synthesis

Update: Looks like this will work, given 'VALL-E'

ex. Semantic Transformer

import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = 500,
    dim = 1024,
    depth = 6,
    has_condition = True,               # this will have to be set to True
    cond_as_self_attn_prefix = True     # whether to condition as prefix to self attention, instead of cross attention, as was done in 'VALL-E' paper
).cuda()

# mock text audio dataset (as an example)

# you will have to extend your own from `Dataset`, and return an audio tensor as well as a string (the audio description) in any order (the framework will autodetect and route it into the transformer)

from torch.utils.data import Dataset

class MockTextAudioDataset(Dataset):
    """Toy text-audio dataset: every item is (fixed caption, random waveform)."""

    def __init__(self, length = 100, audio_length = 320 * 32):
        super().__init__()
        self.len = length
        self.audio_length = audio_length

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # the caption never changes; the audio is fresh Gaussian noise on every access
        caption = 'audio caption'
        return caption, torch.randn(self.audio_length)

dataset = MockTextAudioDataset()

# instantiate semantic transformer trainer and train

trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    dataset = dataset,
    batch_size = 4,
    grad_accum_every = 8,
    data_max_length = 320 * 32,
    num_train_steps = 1_000_000
)

trainer.train()

# after much training above

sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_size = 1, max_length = 2) # (1, < 128) - may terminate early if it detects [eos]

Multi-GPU

Because all the trainer classes use 🤗 Accelerate, you can easily do multi-GPU training by using the `accelerate` command as follows:

At the project root

$ accelerate config

Then, in the same directory

$ accelerate launch train.py

Todo

Citations

@inproceedings{Borsos2022AudioLMAL,
  title  = {AudioLM: a Language Modeling Approach to Audio Generation},
  author = {Zal{\'a}n Borsos and Rapha{\"e}l Marinier and Damien Vincent and Eugene Kharitonov and Olivier Pietquin and Matthew Sharifi and Olivier Teboul and David Grangier and Marco Tagliasacchi and Neil Zeghidour},
  year   = {2022}
}
@misc{https://doi.org/10.48550/arxiv.2107.03312,
  title  = {SoundStream: An End-to-End Neural Audio Codec},
  author = {Zeghidour, Neil and Luebs, Alejandro and Omran, Ahmed and Skoglund, Jan and Tagliasacchi, Marco},
  publisher = {arXiv},
  url    = {https://arxiv.org/abs/2107.03312},
  year   = {2021}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam M. Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@article{Ho2022ClassifierFreeDG,
    title   = {Classifier-Free Diffusion Guidance},
    author  = {Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2207.12598}
}
@misc{crowson2022,
    author  = {Katherine Crowson},
    url     = {https://twitter.com/rivershavewings}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Liu2022FCMFC,
    title   = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners},
    author  = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.13432}
}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Li2021LocalViTBL,
    title   = {LocalViT: Bringing Locality to Vision Transformers},
    author  = {Yawei Li and K. Zhang and Jie Cao and Radu Timofte and Luc Van Gool},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2104.05707}
}
@article{Defossez2022HighFN,
    title   = {High Fidelity Neural Audio Compression},
    author  = {Alexandre D{\'e}fossez and Jade Copet and Gabriel Synnaeve and Yossi Adi},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.13438}
}
@article{Hu2017SqueezeandExcitationN,
    title   = {Squeeze-and-Excitation Networks},
    author  = {Jie Hu and Li Shen and Gang Sun},
    journal = {2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    year    = {2017},
    pages   = {7132-7141}
}
@inproceedings{Yang2023HiFiCodecGV,
    title   = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
    author  = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou},
    year    = {2023}
}
@article{Kazemnejad2023TheIO,
    title   = {The Impact of Positional Encoding on Length Generalization in Transformers},
    author  = {Amirhossein Kazemnejad and Inkit Padhi and Karthikeyan Natesan Ramamurthy and Payel Das and Siva Reddy},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.19466}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Katsch2023GateLoopFD,
    title   = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
    author  = {Tobias Katsch},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265018962}
}

.\lucidrains\audiolm-pytorch\setup.py

# Packaging script for audiolm-pytorch.
from setuptools import setup, find_packages

# Execute the version module so that `__version__` is defined in this namespace.
# The context manager closes the file handle instead of leaking it.
with open('audiolm_pytorch/version.py', encoding = 'utf-8') as version_file:
  exec(version_file.read())

# Package metadata
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = __version__,
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/audiolm-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'audio generation'
  ],
  install_requires=[
    'accelerate>=0.24.0',
    'beartype>=0.16.1',
    'einops>=0.7.0',
    'ema-pytorch>=0.2.2',
    'encodec',
    'fairseq',
    'wandb',
    'gateloop-transformer>=0.2.3',
    'joblib',
    'local-attention>=1.9.0',
    'pytorch-warmup',
    'scikit-learn',
    'sentencepiece',
    'torch>=2.1',
    'torchaudio',
    'transformers',
    'tqdm',
    'vector-quantize-pytorch>=1.12.5'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\autoregressive-linear-attention-cuda\autoregressive_linear_attention_cuda\autoregressive_linear_attention_cuda.py

def calculate_area(length, width):
    """Return the area of a rectangle, i.e. ``length * width``."""
    return length * width

.\lucidrains\autoregressive-linear-attention-cuda\autoregressive_linear_attention_cuda\__init__.py

def calculate_area(length, width):
    """Compute a rectangle's area from its side lengths."""
    product = length * width
    return product

Linear Attention - Autoregressive CUDA kernel (wip)

CUDA implementation of autoregressive linear attention, with all the latest research findings.

Citations

@inproceedings{katharopoulos-et-al-2020,
    author    = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
    title     = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
    booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
    year      = {2020},
    url       = {https://arxiv.org/abs/2006.16236}
}
@article{Nguyen2022MomentumTC,
    title   = {Momentum Transformer: Closing the Performance Gap Between Self-attention and Its Linearization},
    author  = {Tan Minh Nguyen and Richard Baraniuk and Robert M. Kirby and Stanley J. Osher and Bao Wang},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.00579}
}
@article{Zhai2021AnAF,
    title   = {An Attention Free Transformer},
    author  = {Shuangfei Zhai and Walter A. Talbott and Nitish Srivastava and Chen Huang and Hanlin Goh and Ruixiang Zhang and Joshua M. Susskind},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2105.14103}
}
@inproceedings{Peng2023RWKVRR,
    title   = {RWKV: Reinventing RNNs for the Transformer Era},
    author  = {Bo Peng and Eric Alcaide and Quentin Anthony and Alon Albalak and Samuel Arcadinho and Huanqi Cao and Xin Cheng and Michael Chung and Matteo Grella and GV KranthiKiran and Xuzheng He and Haowen Hou and Przemyslaw Kazienko and Jan Kocon and Jiaming Kong and Bartlomiej Koptyra and Hayden Lau and Krishna Sri Ipsit Mantri and Ferdinand Mom and Atsushi Saito and Xiangru Tang and Bolun Wang and Johan S. Wind and Stansilaw Wozniak and Ruichong Zhang and Zhenyuan Zhang and Qihang Zhao and Peng Zhou and Jian Zhu and Rui-Jie Zhu},
    year    = {2023}
}

.\lucidrains\autoregressive-linear-attention-cuda\setup.py

# Packaging script for autoregressive-linear-attention-cuda.
from setuptools import setup, find_packages

_KEYWORDS = [
  'artificial intelligence',
  'deep learning',
  'transformers',
  'attention mechanism',
  'linear attention',
  'cuda'
]

_CLASSIFIERS = [
  'Development Status :: 4 - Beta',
  'Intended Audience :: Developers',
  'Topic :: Scientific/Engineering :: Artificial Intelligence',
  'License :: OSI Approved :: MIT License',
  'Programming Language :: Python :: 3.6',
]

setup(
  name = 'autoregressive-linear-attention-cuda',
  packages = find_packages(exclude = []),
  version = '0.0.1',
  license = 'MIT',
  description = 'Autoregressive Linear Attention CUDA kernel',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/autoregressive-linear-attention-cuda',
  keywords = _KEYWORDS,
  install_requires = ['torch>=1.6'],
  classifiers = _CLASSIFIERS,
)

.\lucidrains\axial-attention\axial_attention\axial_attention.py

# standard library
from operator import itemgetter

# third-party
import torch
from torch import nn

# local: reversible execution used by AxialImageTransformer(reversible = True)
from axial_attention.reversible import ReversibleSequence

# 辅助函数

def exists(val):
    """Return ``True`` unless ``val`` is ``None``."""
    return not (val is None)

def map_el_ind(arr, ind):
    """Return a list containing ``element[ind]`` for every element of ``arr``."""
    return [element[ind] for element in arr]

def sort_and_return_indices(arr):
    """Sort ``arr`` and report where each sorted value originally sat.

    Returns ``(sorted_values, source_positions)``. Ties are ordered by their
    original position. When ``arr`` is a permutation of ``range(n)``,
    ``source_positions`` is the inverse permutation.
    """
    pairs = sorted(zip(arr, range(len(arr))))
    sorted_values = [value for value, _ in pairs]
    source_positions = [position for _, position in pairs]
    return sorted_values, source_positions

def calculate_permutations(num_dimensions, emb_dim):
    """Build one axis permutation per axial axis for axial attention.

    The tensor is assumed to have ``num_dimensions + 2`` axes: axis 0, the
    embedding axis ``emb_dim``, and the remaining axial axes. For each axial
    axis, the returned permutation moves that axis and then the embedding axis
    to the end. All other axes are kept in ascending order in front of them.
    The inverse permutation restores the original layout.

    Args:
        num_dimensions: number of axial axes.
        emb_dim: index of the embedding axis. Values <= 0 are shifted by
            ``num_dimensions + 2``, so ``-1`` means the last axis.

    Returns:
        A list with one permutation (a list of axis indices) per axial axis.
    """
    total_dimensions = num_dimensions + 2
    emb_dim = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions)
    axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]

    permutations = []
    for axial_dim in axial_dims:
        last_two_dims = [axial_dim, emb_dim]
        # Build the leading axes in explicit ascending order rather than
        # relying on the iteration order of a set.
        dims_rest = [ind for ind in range(total_dimensions) if ind not in last_two_dims]
        permutations.append([*dims_rest, *last_two_dims])

    return permutations

# 辅助类

class ChanLayerNorm(nn.Module):
    """Layer normalisation over the channel axis (dim 1) of a channel-first tensor.

    Uses the biased variance and adds ``eps`` to the standard deviation, not to
    the variance. It then applies a learned per-channel scale ``g`` and shift
    ``b``, each of shape ``(1, dim, 1, 1)``.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        sigma = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mu = torch.mean(x, dim = 1, keepdim = True)
        normed = (x - mu) / (sigma + self.eps)
        return normed * self.g + self.b

class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then pass the result to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

class Sequential(nn.Module):
    """Plain residual execution of a sequence of ``(f, g)`` pairs.

    For each pair, two residual updates are applied in order:
    ``x = x + f(x)`` and then ``x = x + g(x)``.
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks

    def forward(self, x):
        out = x
        for first_fn, second_fn in self.blocks:
            out = out + first_fn(out)
            out = out + second_fn(out)
        return out

class PermuteToFrom(nn.Module):
    """Run ``fn`` on a tensor viewed as a batch of ``(t, d)`` sequences.

    ``permutation`` moves the two axes that ``fn`` should see to the end. All
    leading axes are flattened into one batch axis before ``fn`` is called,
    and afterwards the original shape and axis order are restored.
    """

    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        # The stable argsort of the permutation is its inverse. Ties, if any,
        # keep their original position order.
        inverse = sorted(range(len(permutation)), key = permutation.__getitem__)
        self.permutation = permutation
        self.inv_permutation = inverse

    def forward(self, x, **kwargs):
        moved = x.permute(*self.permutation).contiguous()
        moved_shape = moved.shape
        seq_len, feat_dim = moved_shape[-2], moved_shape[-1]

        # Collapse every leading axis into a single batch axis.
        flat = moved.reshape(-1, seq_len, feat_dim)
        flat = self.fn(flat, **kwargs)

        # Undo the flattening, then the permutation.
        restored = flat.reshape(*moved_shape)
        return restored.permute(*self.inv_permutation).contiguous()

class AxialPositionalEmbedding(nn.Module):
    """Learned positional embedding made of one broadcastable parameter per axial axis.

    For an input with ``len(shape) + 2`` axes, parameter ``param_i`` has size
    ``dim`` on the embedding axis, ``shape[i]`` on the i-th axial axis and 1
    everywhere else. ``forward`` adds all of them to the input by broadcasting.

    Args:
        dim: size of the embedding axis.
        shape: sizes of the axial axes, in order.
        emb_dim_index: position of the embedding axis. Negative values count
            from the end, e.g. ``-1`` for channel-last inputs.
    """

    def __init__(self, dim, shape, emb_dim_index = 1):
        super().__init__()
        total_dimensions = len(shape) + 2
        # Normalise negative indices so the embedding axis is never also
        # treated as an axial axis (e.g. emb_dim_index = -2 used to collide).
        if emb_dim_index < 0:
            emb_dim_index += total_dimensions
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]

        self.num_axials = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            param_shape = [1] * total_dimensions
            param_shape[emb_dim_index] = dim
            param_shape[axial_dim_index] = axial_dim
            # Registered as attributes param_0, param_1, ...
            setattr(self, f'param_{i}', nn.Parameter(torch.randn(*param_shape)))

    def forward(self, x):
        for i in range(self.num_axials):
            x = x + getattr(self, f'param_{i}')
        return x

class SelfAttention(nn.Module):
    """Multi-head dot-product attention over 3-D inputs ``(batch, seq, dim)``.

    Queries are projected from ``x``. Keys and values are projected from
    ``kv``, which defaults to ``x``. Each head has ``dim_heads`` features
    (``dim // heads`` unless given), and the concatenated heads are projected
    back to ``dim``.
    """

    def __init__(self, dim, heads, dim_heads = None):
        super().__init__()
        if dim_heads is None:
            dim_heads = dim // heads
        self.dim_heads = dim_heads
        dim_hidden = dim_heads * heads

        self.heads = heads
        self.to_q = nn.Linear(dim, dim_hidden, bias = False)
        self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias = False)
        self.to_out = nn.Linear(dim_hidden, dim)

    def forward(self, x, kv = None):
        if kv is None:
            kv = x

        queries = self.to_q(x)
        keys, values = self.to_kv(kv).chunk(2, dim = -1)

        batch, _, hidden = queries.shape
        heads, head_dim = self.heads, self.dim_heads

        def split_heads(t):
            # (batch, n, heads * head_dim) -> (batch * heads, n, head_dim)
            return t.reshape(batch, -1, heads, head_dim).transpose(1, 2).reshape(batch * heads, -1, head_dim)

        queries, keys, values = (split_heads(t) for t in (queries, keys, values))

        scores = torch.einsum('bie,bje->bij', queries, keys) * (head_dim ** -0.5)
        weights = scores.softmax(dim = -1)
        mixed = torch.einsum('bij,bje->bie', weights, values)

        # (batch * heads, n, head_dim) -> (batch, n, heads * head_dim)
        merged = mixed.reshape(batch, heads, -1, head_dim).transpose(1, 2).reshape(batch, -1, hidden)
        return self.to_out(merged)

class AxialAttention(nn.Module):
    """Self-attention applied separately along every axial axis of a tensor.

    The input must have ``num_dimensions + 2`` axes, with ``dim`` features on
    axis ``dim_index``. Values <= 0 are shifted by ``num_dimensions + 2``, so
    ``-1`` is the last axis. One ``SelfAttention`` is built per axial axis.
    With ``sum_axial_out=True``, each attention sees the original input and
    the results are summed. Otherwise the attentions are chained, each one
    consuming the previous output.
    """

    def __init__(self, dim, num_dimensions = 2, heads = 8, dim_heads = None, dim_index = -1, sum_axial_out = True):
        assert (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
        super().__init__()
        self.dim = dim
        self.total_dimensions = num_dimensions + 2
        if dim_index > 0:
            self.dim_index = dim_index
        else:
            self.dim_index = dim_index + self.total_dimensions

        self.axial_attentions = nn.ModuleList([
            PermuteToFrom(perm, SelfAttention(dim, heads, dim_heads))
            for perm in calculate_permutations(num_dimensions, dim_index)
        ])
        self.sum_axial_out = sum_axial_out

    def forward(self, x):
        assert len(x.shape) == self.total_dimensions, 'input tensor does not have the correct number of dimensions'
        assert x.shape[self.dim_index] == self.dim, 'input tensor does not have the correct input dimension'

        if not self.sum_axial_out:
            # Chained: each axial attention consumes the previous output.
            out = x
            for axial_attn in self.axial_attentions:
                out = axial_attn(out)
            return out

        # Parallel: every axial attention sees the same input; sum the results.
        return sum(axial_attn(x) for axial_attn in self.axial_attentions)

class AxialImageTransformer(nn.Module):
    """Stack of axial self-attention and convolutional feed-forward layers for images.

    Each of the ``depth`` layers appends two ``(f, g)`` pairs to the executed
    sequence:
      * the two axial attentions (one per spatial axis), each pre-normalised
        with ``LayerNorm``;
      * two ``ChanLayerNorm -> Conv2d -> LeakyReLU -> Conv2d`` feed-forwards.
    The feed-forwards use ``Conv2d``, so inputs are expected channel-first:
    ``(batch, dim, height, width)``.

    Args:
        dim: number of channels.
        depth: number of layers.
        heads: attention heads per axial attention.
        dim_heads: per-head dimension (defaults to ``dim // heads``).
        dim_index: channel axis passed to ``calculate_permutations``.
        reversible: run the layers with ``ReversibleSequence``, which
            recomputes activations in backward, instead of ``Sequential``.
        axial_pos_emb_shape: if given, the spatial shape used to build an
            ``AxialPositionalEmbedding`` that is added to the input first.
    """

    def __init__(self, dim, depth, heads = 8, dim_heads = None, dim_index = 1, reversible = True, axial_pos_emb_shape = None):
        super().__init__()
        permutations = calculate_permutations(2, dim_index)

        def get_ff():
            # channel-wise norm, then a 3x3 conv feed-forward with 4x hidden expansion
            return nn.Sequential(
                ChanLayerNorm(dim),
                nn.Conv2d(dim, dim * 4, 3, padding = 1),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(dim * 4, dim, 3, padding = 1)
            )

        self.pos_emb = AxialPositionalEmbedding(dim, axial_pos_emb_shape, dim_index) if exists(axial_pos_emb_shape) else nn.Identity()

        layers = nn.ModuleList([])
        for _ in range(depth):
            attn_functions = nn.ModuleList([PermuteToFrom(permutation, PreNorm(dim, SelfAttention(dim, heads, dim_heads))) for permutation in permutations])
            conv_functions = nn.ModuleList([get_ff(), get_ff()])
            layers.append(attn_functions)
            layers.append(conv_functions)

        # ReversibleSequence must be imported from axial_attention.reversible at module level.
        execute_type = ReversibleSequence if reversible else Sequential
        self.layers = execute_type(layers)

    def forward(self, x):
        x = self.pos_emb(x)
        return self.layers(x)

.\lucidrains\axial-attention\axial_attention\reversible.py

# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 从 torch.autograd.function 中导入 Function 类
from torch.autograd.function import Function
# 从 torch.utils.checkpoint 中导入 get_device_states 和 set_device_states 函数
from torch.utils.checkpoint import get_device_states, set_device_states

# Wraps `net` so that the RNG state used by one call can be captured and then
# replayed on a later call.
# Adapted from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    """Run ``net`` with optionally recorded or replayed RNG state.

    ``forward(..., record_rng=True)`` snapshots the CPU RNG state before
    running ``net``. If CUDA has been initialised, it also saves the CUDA RNG
    states for the devices reported by
    ``torch.utils.checkpoint.get_device_states(*args)``.

    ``forward(..., set_rng=True)`` restores that snapshot inside
    ``torch.random.fork_rng`` and then runs ``net``. The replayed call draws
    the same random numbers as the recorded one, and the global RNG state is
    reset when the fork context exits.
    """
    # net: the wrapped module/callable whose randomness should be reproducible
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None    # CPU RNG state saved by record_rng
        self.cuda_in_fwd = None  # set to True once CUDA RNG states have been saved
        self.gpu_devices = None  # devices whose CUDA RNG states were saved
        self.gpu_states = None   # saved CUDA RNG states, parallel to gpu_devices

    # Save the CPU RNG state and, when CUDA is initialised, the CUDA RNG states
    # of the devices returned by get_device_states(*args).
    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        # NOTE(review): `torch.cuda._initialized` is a private attribute
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    # Call net(*args, **kwargs).
    #   record_rng: snapshot the RNG state first (see record_rng).
    #   set_rng: replay the snapshot from an earlier record_rng call inside a
    #            forked RNG context; requires that record_rng has run, since
    #            cpu_state is None otherwise.
    def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        # Fork the CUDA RNG only for devices whose state was actually recorded.
        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# Reversible residual coupling block. The input is split into two halves along
# dim 1, and the block computes y1 = x1 + f(x2) and y2 = x2 + g(y1). The input
# can be recovered from the output (x2 = y2 - g(y1), x1 = y1 - f(x2)), so
# forward runs under no_grad and backward_pass recomputes what it needs.
# Adapted from https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# TODO: once multi-GPU works, refactor and send a PR back to the source.
class ReversibleBlock(nn.Module):
    # f and g are wrapped in Deterministic so their RNG state can be replayed
    # when they are re-run in backward_pass.
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    # Forward coupling, computed without building an autograd graph. The RNG
    # state is recorded only in training mode. f_args / g_args are extra
    # keyword arguments for f and g; the {} defaults are only unpacked and
    # never mutated.
    def forward(self, x, f_args={}, g_args={}):
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=1)

    # Given this block's output y and the gradient dy w.r.t. it, reconstruct
    # the block input x and return (x, dx). The torch.autograd.backward calls
    # also accumulate the parameter gradients of f and g as a side effect.
    def backward_pass(self, y, dy, f_args={}, g_args={}):
        y1, y2 = torch.chunk(y, 2, dim=1)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=1)
        del dy

        # Re-run g(y1) with the recorded RNG and backprop dy2 through it.
        # This fills y1.grad with the g-branch contribution.
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            # invert y2 = x2 + g(y1)
            x2 = y2 - gy1
            del y2, gy1

            # total gradient w.r.t. y1, which equals the gradient w.r.t. x1
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # Re-run f(x2) and backprop dx1 through it to get f's contribution to x2.grad.
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            # invert y1 = x1 + f(x2)
            x1 = y1 - fx2
            del y1, fx2

            # gradient w.r.t. x2: the direct path through y2 plus the f-branch
            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=1)
            dx = torch.cat([dx1, dx2], dim=1)

        return x, dx

class IrreversibleBlock(nn.Module):
    """Ordinary, autograd-tracked version of the two-branch coupling block.

    Splits ``x`` into halves ``(x1, x2)`` along dim 1 and returns
    ``cat([y1, y2], dim=1)``, where ``y1 = x1 + f(x2)`` and
    ``y2 = x2 + g(y1)``.
    """

    def __init__(self, f, g):
        super().__init__()
        self.f = f
        self.g = g

    def forward(self, x, f_args, g_args):
        left, right = torch.chunk(x, 2, dim = 1)
        new_left = left + self.f(right, **f_args)
        new_right = right + self.g(new_left, **g_args)
        return torch.cat([new_left, new_right], dim = 1)

# Custom autograd Function that drives a sequence of reversible blocks.
# forward keeps only the final output (detached), the blocks and their kwargs
# on ctx. backward walks the blocks in reverse, and each block reconstructs its
# own input and gradient via backward_pass.
class _ReversibleFunction(Function):
    # x: input tensor.
    # blocks: a sliceable sequence of blocks exposing backward_pass;
    #         ReversibleSequence passes an nn.ModuleList of ReversibleBlock.
    # kwargs: keyword arguments (f_args / g_args) given to every block.
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        ctx.kwargs = kwargs
        for block in blocks:
            x = block(x, **kwargs)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    # dy: gradient w.r.t. the output. Returns the gradient w.r.t. x. blocks and
    # kwargs are not differentiable inputs, hence the two Nones.
    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        kwargs = ctx.kwargs
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None

class ReversibleSequence(nn.Module):
    """Stack of ``ReversibleBlock`` s executed through ``_ReversibleFunction``.

    The input is duplicated along dim 1 to form the two coupled streams and
    run through every block. The two output halves are then averaged. Extra
    keyword arguments are forwarded to ``f`` and/or ``g`` as selected by the
    two booleans in ``arg_route``.
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])

    def forward(self, x, arg_route=(True, True), **kwargs):
        route_f, route_g = arg_route
        block_kwargs = {
            'f_args': kwargs if route_f else {},
            'g_args': kwargs if route_g else {},
        }
        doubled = torch.cat((x, x), dim=1)
        out = _ReversibleFunction.apply(doubled, self.blocks, block_kwargs)
        halves = out.chunk(2, dim=1)
        return torch.stack(halves).mean(dim=0)

.\lucidrains\axial-attention\axial_attention\__init__.py

# 从 axial_attention.axial_attention 模块中导入 AxialAttention, AxialPositionalEmbedding, AxialImageTransformer, SelfAttention 类
from axial_attention.axial_attention import AxialAttention, AxialPositionalEmbedding, AxialImageTransformer, SelfAttention

Axial Attention

PyPI version

Implementation of Axial attention in Pytorch. A simple but powerful technique to attend to multi-dimensional data efficiently. It has worked wonders for me and many other researchers.

Simply add some positional encoding to your data and pass it into this handy class, specifying which dimension is considered the embedding, and how many axial dimensions to rotate through. All the permuting and reshaping will be taken care of for you.

This paper was actually rejected on the basis of being too simple. And yet, it has since been used successfully in a number of applications, among them weather prediction and all-attention image segmentation. It just goes to show.

Install

$ pip install axial_attention

Usage

Image

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 3, 256, 256)

attn = AxialAttention(
    dim = 3,               # embedding dimension
    dim_index = 1,         # where is the embedding dimension
    dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
    heads = 1,             # number of heads for multi-head attention
    num_dimensions = 2,    # number of axial dimensions (images is 2, video is 3, or more)
    sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)

attn(img) # (1, 3, 256, 256)

Channel-last image latents

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 20, 20, 512)

attn = AxialAttention(
    dim = 512,           # embedding dimension
    dim_index = -1,      # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 2,  # number of axial dimensions (images is 2, video is 3, or more)
)

attn(img) # (1, 20, 20, 512)

Video

import torch
from axial_attention import AxialAttention

video = torch.randn(1, 5, 128, 256, 256)

attn = AxialAttention(
    dim = 128,           # embedding dimension
    dim_index = 2,       # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 3,  # number of axial dimensions (images is 2, video is 3, or more)
)

attn(video) # (1, 5, 128, 256, 256)

Image Transformer, with reversible network

import torch
from torch import nn
from axial_attention import AxialImageTransformer

conv1x1 = nn.Conv2d(3, 128, 1)

transformer = AxialImageTransformer(
    dim = 128,
    depth = 12,
    reversible = True
)

img = torch.randn(1, 3, 512, 512)

transformer(conv1x1(img)) # (1, 128, 512, 512)

With axial positional embedding

import torch
from axial_attention import AxialAttention, AxialPositionalEmbedding

img = torch.randn(1, 512, 20, 20)

attn = AxialAttention(
    dim = 512,
    heads = 8,
    dim_index = 1
)

pos_emb = AxialPositionalEmbedding(
    dim = 512,
    shape = (20, 20)
)

img = pos_emb(img)  # (1, 512, 20, 20)  - now positionally embedded
img = attn(img)     # (1, 512, 20, 20)

Citation

@misc{ho2019axial,
    title  = {Axial Attention in Multidimensional Transformers},
    author = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
    year   = {2019},
    archivePrefix = {arXiv}
}
@misc{wang2020axialdeeplab,
    title   = {Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation},
    author  = {Huiyu Wang and Yukun Zhu and Bradley Green and Hartwig Adam and Alan Yuille and Liang-Chieh Chen},
    year    = {2020},
    eprint  = {2003.07853},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{huang2019ccnet,
    title   = {Ccnet: Criss-cross attention for semantic segmentation},
    author  = {Huang, Zilong and Wang, Xinggang and Huang, Lichao and Huang, Chang and Wei, Yunchao and Liu, Wenyu},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
    pages   = {603--612},
    year    = {2019}
}

.\lucidrains\axial-attention\setup.py

# Packaging script for axial_attention.
from setuptools import setup, find_packages

_CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
  name = 'axial_attention',
  packages = find_packages(),
  version = '0.6.1',
  license = 'MIT',
  description = 'Axial Attention',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/axial-attention',
  keywords = ['attention', 'artificial intelligence'],
  install_requires = ['torch'],
  classifiers = _CLASSIFIERS,
)

.\lucidrains\axial-positional-embedding\axial_positional_embedding\axial_positional_embedding.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 operator 模块中导入 mul 函数
from operator import mul
# 从 functools 模块中导入 reduce 函数
from functools import reduce

# Axial positional embedding for token sequences.
class AxialPositionalEmbedding(nn.Module):
    """Positional embedding factorised over an axial grid.

    The maximum sequence length is ``prod(axial_shape)``. A sequence position
    is treated as a row-major index into that grid. One learned parameter is
    kept per grid axis and broadcast over the whole grid. The per-axis
    embeddings are then combined in one of two ways:
      * summed, when ``axial_dims`` is None (every axis then has width ``dim``);
      * concatenated along the feature axis, in which case ``axial_dims`` must
        sum to ``dim``.

    Args:
        dim: embedding size.
        axial_shape: grid sizes; their product is the maximum sequence length.
        axial_dims: optional per-axis embedding widths.
    """
    def __init__(self, dim, axial_shape, axial_dims = None):
        super().__init__()

        self.dim = dim
        self.shape = axial_shape
        self.max_seq_len = reduce(mul, axial_shape, 1)

        # Without explicit axial_dims, each axis gets a full-width embedding and the results are summed.
        self.summed = axial_dims is None
        axial_dims = ((dim,) * len(axial_shape)) if self.summed else axial_dims

        assert len(self.shape) == len(axial_dims), 'number of axial dimensions must equal the number of dimensions in the shape'
        assert self.summed or sum(axial_dims) == dim, f'axial dimensions must sum up to the target dimension {dim}'

        # Parameters are registered on this module as weights_0, weights_1, ...
        self.weights = ParameterList(self, 'weights', len(axial_shape))

        for ind, (axis_len, axial_dim) in enumerate(zip(self.shape, axial_dims)):
            # Shape (1, 1, ..., axis_len, ..., 1, axial_dim): non-singleton only on its own axis.
            ax_shape = [1] * len(self.shape)
            ax_shape[ind] = axis_len
            ax_shape = (1, *ax_shape, axial_dim)
            ax_emb = nn.Parameter(torch.zeros(ax_shape).normal_(0, 1))
            self.weights.append(ax_emb)

    def forward(self, x):
        """Return the positional embedding for the first ``t`` positions.

        ``x`` is unpacked as ``(b, t, e)``. The result has shape
        ``(b, t, dim)`` and is cast with ``.to(x)`` to x's dtype/device. It is
        not added to ``x``.
        """
        b, t, _ = x.shape
        assert (t <= self.max_seq_len), f'Sequence length ({t}) must be less than or equal to the maximum sequence length allowed ({self.max_seq_len})'
        embs = []

        # Broadcast each axis embedding over the full grid, then flatten the grid row-major.
        for ax_emb in self.weights.to_list():
            axial_dim = ax_emb.shape[-1]
            expand_shape = (b, *self.shape, axial_dim)
            emb = ax_emb.expand(expand_shape).reshape(b, self.max_seq_len, axial_dim)
            embs.append(emb)

        pos_emb = sum(embs) if self.summed else torch.cat(embs, dim=-1)
        return pos_emb[:, :t].to(x)

class ParameterList:
    """Minimal stand-in for a parameter list that stores items as attributes.

    Each appended value is set on ``kls`` as ``'{prefix}_{index}'``, and
    ``to_list`` reads back the first ``length`` of them in order. This is kept
    until https://github.com/pytorch/pytorch/issues/36035 is resolved.
    """
    def __init__(self, kls, prefix, length):
        self.ind = 0
        self.kls = kls
        self.prefix = prefix
        self.length = length

    def _keyname(self, prefix, ind):
        return '{}_{}'.format(prefix, ind)

    def append(self, x):
        name = self._keyname(self.prefix, self.ind)
        setattr(self.kls, name, x)
        self.ind = self.ind + 1

    def to_list(self):
        names = (self._keyname(self.prefix, i) for i in range(self.length))
        return [getattr(self.kls, name) for name in names]

# Axial positional embedding for channel-first images.
class AxialPositionalEmbeddingImage(nn.Module):
    """Axial positional embedding for images of shape ``(b, c, h, w)``.

    The image is flattened to ``(b, h * w, c)`` to look up the sequence
    embedding, and the embedding is returned reshaped to ``(b, c, h, w)``.
    Only the embedding is returned; it is not added to ``img``.
    """
    def __init__(self, dim, axial_shape, axial_dims = None):
        super().__init__()
        assert len(axial_shape) == 2, 'Axial shape must have 2 dimensions for images'
        self.pos_emb = AxialPositionalEmbedding(dim, axial_shape, axial_dims)

    def forward(self, img):
        b, c, h, w = img.shape
        tokens = img.permute(0, 2, 3, 1).reshape(b, h * w, c)
        emb = self.pos_emb(tokens)
        emb = emb.reshape(b, h, w, c)
        return emb.permute(0, 3, 1, 2)

.\lucidrains\axial-positional-embedding\axial_positional_embedding\__init__.py

# 从 axial_positional_embedding 库中导入 AxialPositionalEmbedding 和 AxialPositionalEmbeddingImage 类
from axial_positional_embedding.axial_positional_embedding import AxialPositionalEmbedding, AxialPositionalEmbeddingImage

Axial Positional Embedding

PyPI version

A type of positional embedding that is very effective when working with attention networks on multi-dimensional data, or for language models in general.

Install

$ pip install axial-positional-embedding

Usage

import torch
from axial_positional_embedding import AxialPositionalEmbedding

pos_emb = AxialPositionalEmbedding(
    dim = 512,
    axial_shape = (64, 64),          # axial shape will multiply up to the maximum sequence length allowed (64 * 64 = 4096)
    axial_dims = (256, 256)          # if not specified, dimensions will default to 'dim' for all axials and summed at the end. if specified, each axial will have the specified dimension and be concatted together. the concatted dimensions needs to sum up to the `dim` (256 + 256 = 512)
)

tokens = torch.randn(1, 1024, 512)  # assume these are token embeddings
tokens = pos_emb(tokens) + tokens   # add positional embedding to token embeddings

Citations

@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@misc{ho2019axial,
    title = {Axial Attention in Multidimensional Transformers},
    author = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
    year = {2019},
    archivePrefix = {arXiv}
}

.\lucidrains\axial-positional-embedding\setup.py

# Packaging script for axial_positional_embedding.
from setuptools import setup, find_packages

_CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
  name = 'axial_positional_embedding',
  packages = find_packages(),
  version = '0.2.1',
  license = 'MIT',
  description = 'Axial Positional Embedding',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/axial-positional-embedding',
  keywords = ['transformers', 'artificial intelligence'],
  install_requires = ['torch'],
  classifiers = _CLASSIFIERS,
)

.\lucidrains\bidirectional-cross-attention\bidirectional_cross_attention\bidirectional_cross_attention.py

import torch
from torch import nn
from einops import rearrange
from torch import einsum

def exists(val):
    """Report whether ``val`` carries a value, i.e. is not ``None``."""
    missing = val is None
    return not missing

def default(val, d):
    """Return ``val``, or the fallback ``d`` when ``val`` is ``None``."""
    if val is None:
        return d
    return val

# Bidirectional cross attention: one shared similarity matrix updates both the
# sequence and its context in a single step.
class BidirectionalCrossAttention(nn.Module):
    """
    Cross attention that updates both an input sequence and its context.

    Queries and keys share one projection per side, so a single similarity
    matrix of shape (batch, heads, seq_len, context_len) is computed. A softmax
    over the context axis gives the attention used to update ``x``. A softmax
    over the sequence axis gives the attention used to update ``context``.

    Keyword Args:
        dim: feature size of ``x``.
        heads: number of attention heads.
        dim_head: feature size of each head.
        context_dim: feature size of ``context``; defaults to ``dim``.
        dropout: dropout probability applied to both attention maps.
        talking_heads: if True, mix each attention map across heads with a
            bias-free 1x1 convolution.
        prenorm: if True, apply LayerNorm to ``x`` and ``context`` first.
    """

    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        context_dim = None,
        dropout = 0.,
        talking_heads = False,
        prenorm = False,
    ):
        super().__init__()
        context_dim = default(context_dim, dim)
        inner_dim = heads * dim_head

        # optional pre-normalization (identity when prenorm is off)
        if prenorm:
            self.norm = nn.LayerNorm(dim)
            self.context_norm = nn.LayerNorm(context_dim)
        else:
            self.norm = nn.Identity()
            self.context_norm = nn.Identity()

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.dropout = nn.Dropout(dropout)
        self.context_dropout = nn.Dropout(dropout)

        # The projections are created in the same order as before, so parameter
        # initialisation proceeds identically.
        # shared query/key projection and value projection, for each side
        self.to_qk = nn.Linear(dim, inner_dim, bias = False)
        self.context_to_qk = nn.Linear(context_dim, inner_dim, bias = False)

        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.context_to_v = nn.Linear(context_dim, inner_dim, bias = False)

        # project the merged heads back to each side's feature size
        self.to_out = nn.Linear(inner_dim, dim)
        self.context_to_out = nn.Linear(inner_dim, context_dim)

        # talking heads: a 1x1 conv whose channels are the attention heads
        def make_talking_heads():
            if talking_heads:
                return nn.Conv2d(heads, heads, 1, bias = False)
            return nn.Identity()

        self.talking_heads = make_talking_heads()
        self.context_talking_heads = make_talking_heads()

    def forward(
        self,
        x,
        context,
        mask = None,
        context_mask = None,
        return_attn = False,
        rel_pos_bias = None
    ):
        """
        Attend ``x`` to ``context`` and ``context`` to ``x`` at the same time.

        Args:
            x: (batch, seq_len, dim) sequence.
            context: (batch, context_len, context_dim) context.
            mask: optional (batch, seq_len) boolean mask of valid ``x`` positions.
            context_mask: optional (batch, context_len) boolean mask of valid
                ``context`` positions.
            return_attn: if True, also return both attention maps.
            rel_pos_bias: optional tensor added to the similarity logits.

        Returns:
            ``(out, context_out)``, followed by ``(attn, context_attn)`` when
            ``return_attn`` is set.
        """
        batch, seq_len, context_len = x.shape[0], x.shape[-2], context.shape[-2]
        heads, device = self.heads, x.device

        x = self.norm(x)
        context = self.context_norm(context)

        qk = self.to_qk(x)
        v = self.to_v(x)
        context_qk = self.context_to_qk(context)
        context_v = self.context_to_v(context)

        # (b, n, h * d) -> (b, h, n, d)
        qk, context_qk, v, context_v = [
            rearrange(t, 'b n (h d) -> b h n d', h = heads)
            for t in (qk, context_qk, v, context_v)
        ]

        # one similarity matrix shared by both directions: (b, h, seq_len, context_len)
        sim = einsum('b h i d, b h j d -> b h i j', qk, context_qk) * self.scale

        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        if exists(mask) or exists(context_mask):
            # a missing mask means "every position is valid"
            if not exists(mask):
                mask = torch.ones((batch, seq_len), device = device, dtype = torch.bool)
            if not exists(context_mask):
                context_mask = torch.ones((batch, context_len), device = device, dtype = torch.bool)

            # pair (i, j) is kept only when both x[i] and context[j] are valid
            pair_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~pair_mask, -torch.finfo(sim.dtype).max)

        # normalise over the context axis (to update x) and over the sequence axis (to update context)
        attn = sim.softmax(dim = -1)
        context_attn = sim.softmax(dim = -2)

        attn = self.dropout(attn)
        context_attn = self.context_dropout(context_attn)

        attn = self.talking_heads(attn)
        context_attn = self.context_talking_heads(context_attn)

        # x aggregates context values; context aggregates x values
        out = einsum('b h i j, b h j d -> b h i d', attn, context_v)
        context_out = einsum('b h j i, b h j d -> b h i d', context_attn, v)

        # merge heads and project back to each side's feature size
        out = self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
        context_out = self.context_to_out(rearrange(context_out, 'b h n d -> b n (h d)'))

        if not return_attn:
            return out, context_out
        return out, context_out, attn, context_attn

.\lucidrains\bidirectional-cross-attention\bidirectional_cross_attention\__init__.py

# 从bidirectional_cross_attention包中导入BidirectionalCrossAttention类
from bidirectional_cross_attention.bidirectional_cross_attention import BidirectionalCrossAttention

Bidirectional Cross Attention

A simple cross attention that updates both the source and the target in one step. The key insight is that shared query / key attention produces a single attention matrix, which can be used twice, once in each direction, to update both sides. It was used in a contracting project for predicting DNA / protein binding (linked from the original README).

Install

$ pip install bidirectional-cross-attention

Usage

import torch
from bidirectional_cross_attention import BidirectionalCrossAttention

video = torch.randn(1, 4096, 512)
audio = torch.randn(1, 8192, 386)

video_mask = torch.ones((1, 4096)).bool()
audio_mask = torch.ones((1, 8192)).bool()

joint_cross_attn = BidirectionalCrossAttention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    context_dim = 386
)

video_out, audio_out = joint_cross_attn(
    video,
    audio,
    mask = video_mask,
    context_mask = audio_mask
)

# attended output should have the same shape as input

assert video_out.shape == video.shape
assert audio_out.shape == audio.shape

Todo

Citations

As far as I know, I came up with it, but if you discover this in the literature, do let me know and I will cite it appropriately.

.\lucidrains\bidirectional-cross-attention\setup.py

# Packaging script for the bidirectional-cross-attention distribution.
from setuptools import setup, find_packages

# Distribution metadata, handed to setuptools.setup() as keyword arguments.
package_metadata = dict(
  name = 'bidirectional-cross-attention',
  packages = find_packages(exclude=[]),
  version = '0.0.5',
  license = 'MIT',
  description = 'Bidirectional Cross Attention',
  long_description_content_type = 'text/markdown',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/bidirectional-cross-attention',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention mechanism'
  ],
  # runtime dependencies
  install_requires = [
    'einops>=0.7',
    'torch>=2.0',
  ],
  classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

setup(**package_metadata)

.\lucidrains\big-sleep\big_sleep\biggan.py

# 导入所需的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import json
import copy
import logging
import os
import shutil
import tempfile
from functools import wraps
from hashlib import sha256
import sys
from io import open

import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm

# 尝试导入 Python 3 版本的 urllib.parse,如果失败则导入 Python 2 版本的 urlparse
try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

# Cache directory for downloaded weights and configs. The
# PYTORCH_PRETRAINED_BIGGAN_CACHE environment variable overrides it; otherwise it
# is ~/.pytorch_pretrained_biggan. It is a pathlib.Path when pathlib (with
# Path.home) is importable, and a plain string in the fallback branch.
try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                                   Path.home() / '.pytorch_pretrained_biggan'))
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                              os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_biggan'))

logger = logging.getLogger(__name__)  # module-level logger

# Download URLs of the pretrained weights, keyed by model name.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin",
}

# Download URLs of the matching JSON configs, keyed by the same model names.
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
}

WEIGHTS_NAME = 'pytorch_model.bin'  # weights file name looked up inside a local model directory
CONFIG_NAME = 'config.json'  # config file name looked up inside a local model directory

# Map a URL (plus optional ETag) to a deterministic cache file name.
def url_to_filename(url, etag=None):
    """
    Return a repeatable, filesystem-safe name for `url`.

    The name is the SHA-256 hex digest of the UTF-8 encoded URL. When a
    non-empty `etag` is given, a period and the SHA-256 hex digest of the
    ETag are appended.
    """
    name = sha256(url.encode('utf-8')).hexdigest()
    if not etag:
        return name
    return '{}.{}'.format(name, sha256(etag.encode('utf-8')).hexdigest())

# Look up the URL/ETag recorded for a cached file.
def filename_to_url(filename, cache_dir=None):
    """
    Return the ``(url, etag)`` pair stored next to the cached file `filename`.

    The pair is read from the ``<filename>.json`` sidecar inside `cache_dir`
    (default: PYTORCH_PRETRAINED_BIGGAN_CACHE). `etag` may be None.

    Raises:
        EnvironmentError: if the cached file or its sidecar does not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    meta_path = cache_path + '.json'
    # check the data file first, then its metadata sidecar
    for required in (cache_path, meta_path):
        if not os.path.exists(required):
            raise EnvironmentError("file {} not found".format(required))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    return metadata['url'], metadata['etag']

# Resolve a URL (downloaded and cached) or an existing local path to a local file path.
def cached_path(url_or_filename, cache_dir=None):
    """
    Return a local path for `url_or_filename`.

    http://, https:// and s3:// URLs are fetched through `get_from_cache`.
    Anything else must be an existing local path, which is returned unchanged.

    Raises:
        EnvironmentError: for a scheme-less path that does not exist.
        ValueError: for input that is neither a supported URL nor an
            existing path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3:
        # normalise pathlib objects to plain strings
        if isinstance(url_or_filename, Path):
            url_or_filename = str(url_or_filename)
        if isinstance(cache_dir, Path):
            cache_dir = str(cache_dir)

    scheme = urlparse(url_or_filename).scheme
    if scheme in ('http', 'https', 's3'):
        return get_from_cache(url_or_filename, cache_dir)
    if os.path.exists(url_or_filename):
        return url_or_filename
    if scheme == '':
        raise EnvironmentError("file {} not found".format(url_or_filename))
    raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


# Split an s3://bucket/key URL into its bucket name and key.
def split_s3_path(url):
    """
    Return ``(bucket_name, key)`` for an ``s3://bucket/key`` URL.

    At most one leading '/' is removed from the key.

    Raises:
        ValueError: if the URL has no bucket or no path.
    """
    parts = urlparse(url)
    bucket, key = parts.netloc, parts.path
    if not (bucket and key):
        raise ValueError("bad s3 path {}".format(url))
    if key[:1] == "/":
        key = key[1:]
    return bucket, key


# Decorator that turns botocore "not found" errors into EnvironmentError.
def s3_request(func):
    """
    Wrap an S3 helper ``func(url, ...)`` so a missing object is reported as
    ``EnvironmentError("file <url> not found")``.

    Every other ``ClientError`` is re-raised unchanged.

    botocore error codes are strings. A HEAD-style call reports "404", while
    other calls can report names such as "NoSuchKey". The previous code did
    ``int(code)``, which raised ValueError on non-numeric codes and hid the
    original error. The code is now compared as a string instead.
    """
    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            error_code = str(exc.response.get("Error", {}).get("Code", ""))
            if error_code in ("404", "NoSuchKey"):
                raise EnvironmentError("file {} not found".format(url))
            raise

    return wrapper


# Fetch the ETag of an S3 object (missing objects raise EnvironmentError via s3_request).
@s3_request
def s3_etag(url):
    """Return the ETag of the S3 object addressed by the ``s3://bucket/key`` `url`."""
    s3 = boto3.resource("s3")
    bucket_name, key = split_s3_path(url)
    return s3.Object(bucket_name, key).e_tag


# Download an S3 object into an open file object.
@s3_request
def s3_get(url, temp_file):
    """Download the S3 object at `url` into the writable binary file object `temp_file`."""
    s3 = boto3.resource("s3")
    bucket_name, key = split_s3_path(url)
    bucket = s3.Bucket(bucket_name)
    bucket.download_fileobj(key, temp_file)


# Stream an HTTP(S) download into an open binary file object.
def http_get(url, temp_file):
    """
    Download `url` into `temp_file` in 1024-byte chunks, showing a tqdm progress bar.

    Raises:
        requests.HTTPError: if the GET response has an error status. requests'
            exceptions derive from IOError, so callers that catch
            EnvironmentError still see it.
    """
    # The context manager releases the streamed connection even if the loop fails.
    with requests.get(url, stream=True) as req:
        # Without this check, an error response body (e.g. a 403/404 page) would be
        # written to temp_file and then cached as if it were the requested file.
        req.raise_for_status()
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        progress = tqdm(unit="B", total=total)
        try:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # skip keep-alive chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            progress.close()


# Download (if necessary) and cache the file at `url`, returning its local path.
def get_from_cache(url, cache_dir=None):
    """
    Return the path of the locally cached copy of `url`, downloading it first if needed.

    The cache file name comes from the URL and the remote ETag (see
    `url_to_filename`), so a changed remote ETag yields a new cache entry. Next
    to every cached file, a ``<file>.json`` sidecar records its ``url`` and ``etag``.

    Args:
        url: an ``s3://`` URL or an HTTP(S) URL.
        cache_dir: cache directory. Defaults to PYTORCH_PRETRAINED_BIGGAN_CACHE
            and is created if missing.

    Raises:
        IOError: if the HTTP HEAD request does not return status 200.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # The ETag becomes part of the cache key.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download into a temporary file first, so an interrupted download never
        # touches the cache directory.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # Flush so all bytes are on disk, then rewind, because
            # shutil.copyfileobj() copies from the current position.
            temp_file.flush()
            temp_file.seek(0)

            # Write the metadata sidecar before publishing the data file. If we are
            # interrupted in between, cache_path still does not exist, so the next
            # call downloads again and rewrites the sidecar.
            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w', encoding="utf-8") as meta_file:
                json.dump(meta, meta_file)

            # Copy into a sibling partial file, then rename it into place. The
            # previous code copied straight to cache_path, so an interrupted copy
            # left a truncated file that later calls treated as a valid cache hit.
            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            partial_path = cache_path + '.incomplete'
            with open(partial_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)
            os.replace(partial_path, cache_path)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


# Read a de-duplicated collection of lines from a text file.
def read_set_from_file(filename):
    """
    Return the set of lines in the UTF-8 file `filename`.

    Trailing whitespace, including the newline, is stripped from each line.
    """
    with open(filename, 'r', encoding='utf-8') as handle:
        return {line.rstrip() for line in handle}

# Extract a file's extension.
def get_file_extension(path, dot=True, lower=True):
    """
    Return the extension of `path` as given by ``os.path.splitext``.

    Args:
        dot: keep the leading '.' when True.
        lower: lower-case the result when True.
    """
    ext = os.path.splitext(path)[1]
    if not dot:
        ext = ext[1:]
    if lower:
        ext = ext.lower()
    return ext

# Configuration object for BigGAN (defaults describe the 128x128 model).
class BigGANConfig(object):
    """ Configuration class to store the configuration of a `BigGAN`.
        Defaults are for the 128x128 model.
        layers tuple are (up-sample in the layer ?, input channels, output channels)
    """
    def __init__(self,
                 output_dim=128,
                 z_dim=128,
                 class_embed_dim=128,
                 channel_width=128,
                 num_classes=1000,
                 layers=None,
                 attention_layer_position=8,
                 eps=1e-4,
                 n_stats=51):
        """Constructs BigGANConfig.

        Args:
            layers: list of ``(up_sample, in_multiplier, out_multiplier)`` tuples,
                one per generator block. ``None`` (the default) selects the
                128x128 layout below.

        A fresh default list is built for each instance. The previous default
        was a single list object shared by every config, so mutating one
        config's ``layers`` silently changed all other default configs.
        """
        if layers is None:
            layers = [(False, 16, 16),
                      (True, 16, 16),
                      (False, 16, 16),
                      (True, 16, 8),
                      (False, 8, 8),
                      (True, 8, 4),
                      (False, 4, 4),
                      (True, 4, 2),
                      (False, 2, 2),
                      (True, 2, 1)]
        self.output_dim = output_dim
        self.z_dim = z_dim
        self.class_embed_dim = class_embed_dim
        self.channel_width = channel_width
        self.num_classes = num_classes
        self.layers = layers
        self.attention_layer_position = attention_layer_position
        self.eps = eps
        self.n_stats = n_stats

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters.

        Starts from the defaults and overwrites every key present in
        `json_object`. It builds ``cls()`` rather than a hard-coded
        ``BigGANConfig()``, so subclasses get instances of themselves.
        """
        config = cls()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a UTF-8 JSON file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a (deep-copied) Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to an indented, key-sorted JSON string ending in a newline."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

# Spectrally-normalised 2D convolution.
def snconv2d(eps=1e-12, **kwargs):
    """Build ``nn.Conv2d(**kwargs)`` and wrap it with spectral normalization using `eps`."""
    conv = nn.Conv2d(**kwargs)
    return nn.utils.spectral_norm(conv, eps=eps)

# Spectrally-normalised linear layer.
def snlinear(eps=1e-12, **kwargs):
    """Build ``nn.Linear(**kwargs)`` and wrap it with spectral normalization using `eps`."""
    linear = nn.Linear(**kwargs)
    return nn.utils.spectral_norm(linear, eps=eps)

# Spectrally-normalised embedding table.
def sn_embedding(eps=1e-12, **kwargs):
    """Build ``nn.Embedding(**kwargs)`` and wrap it with spectral normalization using `eps`."""
    embedding = nn.Embedding(**kwargs)
    return nn.utils.spectral_norm(embedding, eps=eps)

# Self-attention layer used inside the BigGAN generator.
class SelfAttn(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_channels, eps=1e-12):
        """
        Args:
            in_channels: channel count of the input feature map.
            eps: epsilon for the spectral-norm wrapped 1x1 convolutions.
        """
        super(SelfAttn, self).__init__()
        self.in_channels = in_channels
        key_channels = in_channels // 8
        value_channels = in_channels // 2

        # Bias-free, spectrally-normalised 1x1 projections. They are created in
        # the same order as before, so parameter initialisation is unchanged.
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=key_channels,
                                        kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=key_channels,
                                      kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=value_channels,
                                    kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_o_conv = snconv2d(in_channels=value_channels, out_channels=in_channels,
                                         kernel_size=1, bias=False, eps=eps)
        # keys/values are computed on a 2x2 max-pooled map (a quarter of the positions)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        # residual gate, initialised to zero
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``x + gamma * attention(x)``; `x` is indexed as (batch, channels, h, w)."""
        _, channels, height, width = x.size()
        n_query = height * width
        n_pooled = n_query // 4

        # queries at full resolution, keys at pooled resolution
        theta = self.snconv1x1_theta(x).view(-1, channels // 8, n_query)
        phi = self.maxpool(self.snconv1x1_phi(x)).view(-1, channels // 8, n_pooled)
        attn = self.softmax(torch.bmm(theta.permute(0, 2, 1), phi))

        # attention-weighted values, reshaped back to the spatial grid
        g = self.maxpool(self.snconv1x1_g(x)).view(-1, channels // 2, n_pooled)
        attn_g = torch.bmm(g, attn.permute(0, 2, 1)).view(-1, channels // 2, height, width)

        return x + self.gamma * self.snconv1x1_o_conv(attn_g)


class BigGANBatchNorm(nn.Module):
    """ This is a batch norm module that can handle conditional input and can be provided with pre-computed
        activation means and variances for various truncation parameters.
        We cannot just rely on torch.batch_norm since it cannot handle
        batched weights (pytorch 1.0.1). We compute batch_norm ourselves without updating running means and variances.
        If you want to train this model you should add running means and variance computation logic.
    """
    def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
        """
        Args:
            num_features: number of channels normalized.
            condition_vector_dim: size of the conditioning vector; required when `conditional`.
            n_stats: number of pre-computed statistic rows. Row k corresponds to
                truncation k / (n_stats - 1).
            eps: added to the variance for numerical stability (also passed to the
                spectral-norm layers).
            conditional: if True, scale/offset are predicted from the conditioning
                vector; otherwise a per-channel affine (weight, bias) is used.
        """
        super(BigGANBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.conditional = conditional

        # pre-computed statistics, one row per truncation grid point
        self.register_buffer('running_means', torch.zeros(n_stats, num_features))
        self.register_buffer('running_vars', torch.ones(n_stats, num_features))
        self.step_size = 1.0 / (n_stats - 1)

        if conditional:
            assert condition_vector_dim is not None
            self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
            self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
        else:
            # Identity affine init. These used to be torch.Tensor(num_features), i.e.
            # uninitialised memory, and checkpoints are loaded with strict=False, so a
            # missing key left garbage weights.
            self.weight = torch.nn.Parameter(torch.ones(num_features))
            self.bias = torch.nn.Parameter(torch.zeros(num_features))

    def forward(self, x, truncation, condition_vector=None):
        """
        Normalize `x` with statistics interpolated at `truncation`.

        Args:
            x: activations. The conditional path broadcasts per-channel stats as
                (1, C, 1, 1), so a (batch, C, H, W) layout is expected.
            truncation: position on the statistics grid (row = truncation / step_size).
            condition_vector: conditioning input, used only when `conditional`.
        """
        # Locate truncation on the grid: position = start_idx + coef, with 0 <= coef < 1.
        coef, start_idx = math.modf(truncation / self.step_size)
        start_idx = int(start_idx)
        last_idx = self.running_means.shape[0] - 1
        if coef != 0.0 and start_idx < last_idx:  # interpolate
            # Linear interpolation: the lower row gets weight (1 - coef), the upper row
            # weight coef. The weights used to be swapped, so the statistics jumped from
            # row start_idx to (almost) row start_idx + 1 as soon as truncation moved
            # past a grid point.
            running_mean = self.running_means[start_idx] * (1 - coef) + self.running_means[start_idx + 1] * coef
            running_var = self.running_vars[start_idx] * (1 - coef) + self.running_vars[start_idx + 1] * coef
        else:
            # exactly on a grid point (or at/after the last row): no interpolation
            start_idx = min(start_idx, last_idx)
            running_mean = self.running_means[start_idx]
            running_var = self.running_vars[start_idx]

        if self.conditional:
            running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
            bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)

            out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
                               training=False, momentum=0.0, eps=self.eps)

        return out

class GenBlock(nn.Module):
    """
    Bottleneck residual block of the BigGAN generator.

    Main path: four (conditional BN -> ReLU -> conv) stages: a 1x1 reduce, two
    3x3 convs and a 1x1 expand. Optional 2x nearest upsampling happens before
    the first 3x3 conv. Shortcut: keeps the first half of the input channels
    when in_size != out_size, and is upsampled the same way.
    """

    def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
                 n_stats=51, eps=1e-12):
        super(GenBlock, self).__init__()
        self.up_sample = up_sample
        self.drop_channels = in_size != out_size
        hidden = in_size // reduction_factor

        def cond_bn(features):
            return BigGANBatchNorm(features, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)

        # Submodules are created in the same order as before, so parameter
        # initialisation and state_dict keys are unchanged.
        self.bn_0 = cond_bn(in_size)
        self.conv_0 = snconv2d(in_channels=in_size, out_channels=hidden, kernel_size=1, eps=eps)

        self.bn_1 = cond_bn(hidden)
        self.conv_1 = snconv2d(in_channels=hidden, out_channels=hidden, kernel_size=3, padding=1, eps=eps)

        self.bn_2 = cond_bn(hidden)
        self.conv_2 = snconv2d(in_channels=hidden, out_channels=hidden, kernel_size=3, padding=1, eps=eps)

        self.bn_3 = cond_bn(hidden)
        self.conv_3 = snconv2d(in_channels=hidden, out_channels=out_size, kernel_size=1, eps=eps)

        self.relu = nn.ReLU()

    def forward(self, x, cond_vector, truncation):
        """Apply the block to `x`, conditioned on `cond_vector` at the given `truncation`."""
        stages = (
            (self.bn_0, self.conv_0),
            (self.bn_1, self.conv_1),
            (self.bn_2, self.conv_2),
            (self.bn_3, self.conv_3),
        )
        h = x
        for stage, (bn, conv) in enumerate(stages):
            h = self.relu(bn(h, truncation, cond_vector))
            # upsampling sits between the second ReLU and the first 3x3 conv
            if stage == 1 and self.up_sample:
                h = F.interpolate(h, scale_factor=2, mode='nearest')
            h = conv(h)

        shortcut = x
        if self.drop_channels:
            shortcut = shortcut[:, :shortcut.shape[1] // 2, ...]
        if self.up_sample:
            shortcut = F.interpolate(shortcut, scale_factor=2, mode='nearest')

        return h + shortcut


class Generator(nn.Module):
    """
    BigGAN generator trunk.

    `cond_vector` is expected to hold one row per consumer: row 0 feeds the
    initial linear projection, and each following row conditions one GenBlock,
    in order. Every row is unsqueezed to a batch of one, so a single image is
    produced per call.
    """

    def __init__(self, config):
        super(Generator, self).__init__()
        self.config = config
        ch = config.channel_width
        # each conditioning row is sized 2 * z_dim
        cond_dim = 2 * config.z_dim

        # projects a conditioning row to a 4x4 map with 16 * ch channels
        self.gen_z = snlinear(in_features=cond_dim,
                              out_features=4 * 4 * 16 * ch, eps=config.eps)

        modules = []
        for position, layer in enumerate(config.layers):
            up_sample, in_mult, out_mult = layer[0], layer[1], layer[2]
            in_channels = ch * in_mult
            if position == config.attention_layer_position:
                # self-attention is inserted just before the block at this position
                modules.append(SelfAttn(in_channels, eps=config.eps))
            modules.append(GenBlock(in_channels,
                                    ch * out_mult,
                                    cond_dim,
                                    up_sample=up_sample,
                                    n_stats=config.n_stats,
                                    eps=config.eps))
        self.layers = nn.ModuleList(modules)

        self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
        self.relu = nn.ReLU()
        # outputs ch channels; forward keeps only the first three as RGB
        self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps)
        self.tanh = nn.Tanh()

    def forward(self, cond_vector, truncation):
        """Generate one image from the per-layer conditioning rows in `cond_vector`."""
        z = self.gen_z(cond_vector[0].unsqueeze(0))

        # reshape in TF weight layout (N, H, W, C), then move channels first
        z = z.view(-1, 4, 4, 16 * self.config.channel_width).permute(0, 3, 1, 2).contiguous()

        latent_idx = 1
        for layer in self.layers:
            if not isinstance(layer, GenBlock):
                z = layer(z)
                continue
            z = layer(z, cond_vector[latent_idx].unsqueeze(0), truncation)
            latent_idx += 1

        z = self.relu(self.bn(z, truncation))
        z = self.conv_to_rgb(z)[:, :3, ...]
        return self.tanh(z)

class BigGAN(nn.Module):
    """BigGAN Generator."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Build a BigGAN and load pretrained weights into it.

        Args:
            pretrained_model_name_or_path: a key of PRETRAINED_MODEL_ARCHIVE_MAP
                (downloaded and cached), or a directory containing WEIGHTS_NAME and
                CONFIG_NAME.
            cache_dir: forwarded to `cached_path`.
            *inputs, **kwargs: forwarded to the constructor after the config.

        Raises:
            EnvironmentError: if the weights or config cannot be resolved
                (logged, then re-raised).

        Note: weights are loaded with ``strict=False``, so missing or unexpected
        keys are silently ignored.
        """
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

        try:
            resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error("Wrong model name, should be a valid path to a folder containing "
                         "a %s file and a %s file or a model name in %s",
                         WEIGHTS_NAME, CONFIG_NAME, ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP))
            raise

        # lazy %-style arguments: the message is only formatted if INFO is enabled
        logger.info("loading model %s from cache at %s", pretrained_model_name_or_path, resolved_model_file)

        config = BigGANConfig.from_json_file(resolved_config_file)
        logger.info("Model config %s", config)

        model = cls(config, *inputs, **kwargs)
        # NOTE(security): torch.load can unpickle arbitrary Python objects (depending on
        # the installed torch version's `weights_only` default), which can execute code.
        # Only load checkpoints from trusted sources.
        state_dict = torch.load(resolved_model_file, map_location='cpu' if not torch.cuda.is_available() else None)
        model.load_state_dict(state_dict, strict=False)
        return model

    def __init__(self, config):
        """Create the class-embedding layer and the generator described by `config`."""
        super(BigGAN, self).__init__()
        self.config = config
        # bias-free linear map from a num_classes-wide class vector to a z_dim embedding
        self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
        self.generator = Generator(config)

    def forward(self, z, class_label, truncation):
        """
        Generate an image.

        Args:
            z: latent rows, concatenated with the class embeddings along dim 1.
            class_label: class vectors of width ``config.num_classes``.
            truncation: value in (0, 1] used to pick the batch-norm statistics.
        """
        assert 0 < truncation <= 1

        embed = self.embeddings(class_label)
        cond_vector = torch.cat((z, embed), dim=1)

        z = self.generator(cond_vector, truncation)
        return z

.\lucidrains\big-sleep\big_sleep\big_sleep.py

# 导入必要的库
import os
import sys
import subprocess
import signal
import string
import re

from datetime import datetime
from pathlib import Path
import random

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from torchvision.utils import save_image
import torchvision.transforms as T
from PIL import Image
from tqdm import tqdm, trange

from big_sleep.ema import EMA
from big_sleep.resample import resample
from big_sleep.biggan import BigGAN
from big_sleep.clip import load, tokenize

# Big Sleep needs a CUDA device: fail at import time when none is available.
assert torch.cuda.is_available(), 'CUDA must be available in order to use Big Sleep'

# Graceful Ctrl+C handling: the SIGINT handler registered below only prints a
# notice and sets the module-level `terminate` flag, which the rest of the
# program can poll to stop cleanly.
terminate = False


def signal_handling(signum, frame):
    """SIGINT handler: announce the interrupt and set the global `terminate` flag."""
    global terminate
    print('detecting keyboard interrupt, gracefully exiting')
    terminate = True


signal.signal(signal.SIGINT, signal_handling)

# helper functions

def exists(val):
    """Return True when `val` is not None."""
    return not (val is None)

def open_folder(path):
    """
    Open `path` (or, for a file, its containing directory) in the platform's file browser.

    Best-effort: returns without doing anything for paths that are not
    directories and for unsupported platforms. Failures of the launcher
    command are ignored.
    """
    if os.path.isfile(path):
        path = os.path.dirname(path)

    if not os.path.isdir(path):
        return

    if sys.platform == 'darwin':
        cmd_list = ['open', '--', path]
    elif sys.platform in ('linux2', 'linux'):
        cmd_list = ['xdg-open', path]
    elif sys.platform in ('win32', 'win64'):
        # str() first: a pathlib.Path has its own .replace() (a file rename), not str.replace
        cmd_list = ['explorer', str(path).replace('/', '\\')]
    else:
        return

    try:
        subprocess.check_call(cmd_list)
    except (subprocess.CalledProcessError, OSError):
        # deliberately best-effort: not being able to open a browser window is not fatal
        pass

def create_text_path(text=None, img=None, encoding=None):
    """
    Build a filesystem-friendly base name (max 255 chars) describing a run's inputs.

    Args:
        text: optional prompt text, used as the start of the name.
        img: optional image. A path (str or os.PathLike) contributes its file
            name without directories or extension. Any other object contributes
            "PIL_img".
        encoding: if given, the whole name becomes "your_encoding".

    Afterwards '-' becomes '_', ',' is dropped, ' ' becomes '_', '|' becomes
    '--', and leading/trailing '-'/'_' are stripped.
    """
    input_name = ""
    if text is not None:
        input_name += text
    if img is not None:
        if isinstance(img, (str, os.PathLike)):
            # basename + splitext: previously "".join(img.split(".")[:-1]) turned an
            # extension-less path such as "photo" into "" and cut names at the first
            # dot in a directory (e.g. "dir.v2/photo" -> "dir"); Path objects were
            # also mistaken for PIL images.
            img_name = os.path.splitext(os.path.basename(os.fspath(img)))[0]
        else:
            img_name = "PIL_img"
        input_name += "_" + img_name
    if encoding is not None:
        input_name = "your_encoding"
    return input_name.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:255]

# 张量辅助函数
def differentiable_topk(x, k, temperature=1.):
    """Soft, differentiable top-k selection along the last axis of a 2-D tensor.

    Runs *k* rounds. Each round takes a temperature-scaled softmax over the
    entries not yet chosen and keeps only the largest probability, at its
    position. Before the next round, that position is masked with -inf. The
    kept probabilities of all rounds are summed, so the result has the same
    shape as *x* with at most *k* non-zero entries per row.

    Args:
        x: 2-D tensor of logits, shape ``(rows, cols)``.
        k: number of entries to pick per row.
        temperature: softmax temperature; smaller values give sharper picks.

    Returns:
        Tensor shaped like *x*.
    """
    rows, cols = x.shape  # unpacking also rejects inputs that are not 2-D
    picks = []
    remaining = x
    for step in range(k):
        probs = (remaining / temperature).softmax(dim=-1)
        best_val, best_idx = probs.topk(1, dim=-1)
        picks.append(torch.zeros_like(remaining).scatter_(-1, best_idx, best_val))
        if step != k - 1:
            # hide the chosen entry from the following rounds
            remaining = remaining.scatter(-1, best_idx, float('-inf'))
    # (rows, k, cols) -> sum over the k rounds -> (rows, cols)
    return torch.stack(picks, dim=1).sum(dim=1)

def create_clip_img_transform(image_width):
    """Build the preprocessing pipeline that turns an image into a CLIP input tensor.

    Steps: ``T.Resize(image_width)``, a square center crop of side
    *image_width*, conversion to a tensor, and per-channel normalisation with
    the CLIP mean/std constants below.

    Args:
        image_width: side length in pixels of the square output.

    Returns:
        A ``T.Compose`` callable.
    """
    # per-channel RGB normalisation statistics used for CLIP inputs
    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]
    steps = [
        T.Resize(image_width),
        T.CenterCrop((image_width, image_width)),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
    return T.Compose(steps)

def rand_cutout(image, size, center_bias=False, center_focus=2):
    """Crop a random ``size`` x ``size`` window from a batch of images.

    Offsets come from Python's ``random`` module. Without *center_bias*, both
    offsets are uniform over every valid position. With it, they are drawn
    from a Gaussian centred on the middle position (std = centre /
    *center_focus*), and any offset outside the valid range is redrawn
    uniformly.

    Args:
        image: tensor indexed as ``image[:, :, a, b]``. Its last dimension is
            used as the extent of both cropped axes, so square images are
            presumably expected (TODO confirm for non-square input).
        size: crop side length; ``random.randint`` fails if it exceeds the width.
        center_bias: bias crops towards the image centre.
        center_focus: larger values concentrate the Gaussian more tightly.

    Returns:
        ``image[:, :, x:x + size, y:y + size]`` for the drawn offsets.
    """
    lo, hi = 0, image.shape[-1] - size

    if not center_bias:
        x = random.randint(lo, hi)
        y = random.randint(lo, hi)
        return image[:, :, x:x + size, y:y + size]

    mid = hi / 2
    spread = mid / center_focus
    # draw order matters for reproducibility: both Gaussians first, then any redraws
    x = int(random.gauss(mu=mid, sigma=spread))
    y = int(random.gauss(mu=mid, sigma=spread))
    if not lo <= x <= hi:
        x = random.randint(lo, hi)
    if not lo <= y <= hi:
        y = random.randint(lo, hi)
    return image[:, :, x:x + size, y:y + size]
# BigGAN 的可学习隐变量与生成器封装

class Latents(torch.nn.Module):
    """Learnable BigGAN inputs: a noise vector and a class-logit vector per latent slot.

    Args:
        num_latents: number of latent slots (rows of both parameters).
        num_classes: length of each class-logit vector.
        z_dim: length of each noise vector.
        max_classes: when given, ``forward`` returns a soft top-``max_classes``
            selection of the class logits (via ``differentiable_topk``) instead
            of an element-wise sigmoid. It must satisfy ``0 < max_classes <= num_classes``.
        class_temperature: temperature passed to ``differentiable_topk``.
    """

    def __init__(
        self,
        num_latents = 15,
        num_classes = 1000,
        z_dim = 128,
        max_classes = None,
        class_temperature = 2.
    ):
        super().__init__()
        noise_shape = (num_latents, z_dim)
        class_shape = (num_latents, num_classes)
        # noise vectors, initialised from a standard normal
        self.normu = torch.nn.Parameter(torch.zeros(*noise_shape).normal_(std = 1))
        # class logits centred at -3.9, so their sigmoid starts close to 0
        self.cls = torch.nn.Parameter(torch.zeros(*class_shape).normal_(mean = -3.9, std = .3))
        # threshold read by BigSleep's latent regulariser; a buffer, hence not trained
        self.register_buffer('thresh_lat', torch.tensor(1))

        assert not exists(max_classes) or 0 < max_classes <= num_classes, f'max_classes must be between 0 and {num_classes}'
        self.max_classes = max_classes
        self.class_temperature = class_temperature

    def forward(self):
        """Return ``(noise, classes)``, where *classes* are the processed class logits."""
        if not exists(self.max_classes):
            return self.normu, torch.sigmoid(self.cls)
        picked = differentiable_topk(self.cls, self.max_classes, temperature = self.class_temperature)
        return self.normu, picked

class Model(nn.Module):
    """Pretrained BigGAN-deep generator driven by EMA-wrapped learnable ``Latents``.

    Args:
        image_size: generator resolution; one of 128, 256 or 512.
        max_classes: forwarded to ``Latents``.
        class_temperature: forwarded to ``Latents``.
        ema_decay: decay handed to the ``EMA`` wrapper around the latents.
    """

    def __init__(
        self,
        image_size,
        max_classes = None,
        class_temperature = 2.,
        ema_decay = 0.99
    ):
        super().__init__()
        assert image_size in (128, 256, 512), 'image size must be one of 128, 256, or 512'
        self.biggan = BigGAN.from_pretrained(f'biggan-deep-{image_size}')
        # kept so init_latents() can rebuild the latents later (e.g. on reset)
        self.max_classes = max_classes
        self.class_temperature = class_temperature
        self.ema_decay = ema_decay

        self.init_latents()

    def init_latents(self):
        """(Re)create the learnable latents, sized from the BigGAN config and wrapped in EMA."""
        cfg = self.biggan.config
        fresh = Latents(
            # presumably one latent per generator layer plus one for the input -- confirm
            num_latents = len(cfg.layers) + 1,
            num_classes = cfg.num_classes,
            z_dim = cfg.z_dim,
            max_classes = self.max_classes,
            class_temperature = self.class_temperature,
        )
        self.latents = EMA(fresh, self.ema_decay)

    def forward(self):
        """Run BigGAN (in eval mode) on the current latents and return ``(x + 1) / 2``.

        The trailing ``1`` passed to BigGAN is presumably the truncation value -- confirm.
        The rescaling presumably maps the generator's [-1, 1] output to [0, 1].
        """
        self.biggan.eval()
        latent_inputs = self.latents()
        images = self.biggan(*latent_inputs, 1)
        return (images + 1) / 2


class BigSleep(nn.Module):
    """CLIP-guided BigGAN: generates images and scores them against prompt embeddings.

    Args:
        num_cutouts: number of random crops of the generated image embedded by CLIP per step.
        loss_coef: scale applied to the cosine-similarity loss.
        image_size: BigGAN resolution (128, 256 or 512).
        bilinear: resize crops bilinearly instead of with nearest-neighbour.
        max_classes: optional top-k limit on the class vectors (see ``Latents``).
        class_temperature: temperature of that top-k selection.
        experimental_resample: resize crops with ``resample`` instead of ``F.interpolate``.
        ema_decay: EMA decay of the latents.
        center_bias: bias random crops towards the image centre.
        larger_clip: use CLIP ViT-L/14 instead of ViT-B/32.
    """

    def __init__(
        self,
        num_cutouts = 128,
        loss_coef = 100,
        image_size = 512,
        bilinear = False,
        max_classes = None,
        class_temperature = 2.,
        experimental_resample = False,
        ema_decay = 0.99,
        center_bias = False,
        larger_clip = False
    ):
        super().__init__()
        self.loss_coef = loss_coef
        self.image_size = image_size
        self.num_cutouts = num_cutouts
        self.experimental_resample = experimental_resample
        self.center_bias = center_bias

        # keyword arguments for F.interpolate when resizing crops
        self.interpolation_settings = {'mode': 'bilinear', 'align_corners': False} if bilinear else {'mode': 'nearest'}

        model_name = 'ViT-B/32' if not larger_clip else 'ViT-L/14'
        # CLIP model plus the image-normalisation function returned alongside it
        self.perceptor, self.normalize_image = load(model_name, jit = False)

        self.model = Model(
            image_size = image_size,
            max_classes = max_classes,
            class_temperature = class_temperature,
            ema_decay = ema_decay
        )

    def reset(self):
        """Re-initialise the generator's learnable latents."""
        self.model.init_latents()

    def sim_txt_to_img(self, text_embed, img_embed, text_type="max"):
        """Similarity loss between one prompt embedding and the crop embeddings.

        Returns ``sign * loss_coef * mean(cosine_similarity)``. The sign is -1
        for ``text_type="max"``, so minimising pulls the image towards the
        prompt, and +1 for ``"min"``, so minimising pushes it away.
        """
        sign = 1 if text_type == "min" else -1
        return sign * self.loss_coef * torch.cosine_similarity(text_embed, img_embed, dim = -1).mean()

    def _make_cutouts(self, image):
        """Sample ``num_cutouts`` random crops of *image*, resize each to 224x224, and concatenate them."""
        pieces = []
        for _ in range(self.num_cutouts):
            # crop side: image_size * N(0.8, 0.3) clipped to [0.5, 0.95]
            size = int(self.image_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .95))
            piece = rand_cutout(image, size, center_bias=self.center_bias)
            if self.experimental_resample:
                piece = resample(piece, (224, 224))
            else:
                piece = F.interpolate(piece, (224, 224), **self.interpolation_settings)
            pieces.append(piece)
        return torch.cat(pieces)

    @staticmethod
    def _latent_loss(latents, latent_thres):
        """Regulariser pushing each latent row towards standard-normal statistics.

        Terms:
        - ``|1 - std|`` and ``|mean|`` of each row, averaged over rows;
        - ``4 * max(mean square, latent_thres)``, which has no gradient while
          the mean square stays below the threshold;
        - per row, ``|skew|`` and ``|excess kurtosis|``, each divided by the
          number of rows.
        """
        num_latents = latents.shape[0]
        loss = torch.abs(1 - torch.std(latents, dim=1)).mean() + \
               torch.abs(torch.mean(latents, dim = 1)).mean() + \
               4 * torch.max(torch.square(latents).mean(), latent_thres)

        for array in latents:
            mean = torch.mean(array)
            diffs = array - mean
            var = torch.mean(torch.pow(diffs, 2.0))
            std = torch.pow(var, 0.5)
            zscores = diffs / std
            skews = torch.mean(torch.pow(zscores, 3.0))
            kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0

            loss = loss + torch.abs(kurtoses) / num_latents + torch.abs(skews) / num_latents
        return loss

    @staticmethod
    def _class_loss(soft_one_hot_classes):
        """Penalise all but the largest class weight of each row: mean of ``(50 * w) ** 2``.

        The number of penalised entries is ``num_classes - 1``, derived from the
        input instead of assuming BigGAN's 1000 classes.
        """
        num_classes = soft_one_hot_classes.shape[1]
        smallest = torch.topk(soft_one_hot_classes, largest = False, dim = 1, k = num_classes - 1)[0]
        return ((50 * smallest) ** 2).mean()

    def forward(self, text_embeds, text_min_embeds=None, return_loss = True):
        """Generate images and, optionally, compute the training losses.

        Args:
            text_embeds: iterable of embeddings to pull the image towards.
            text_min_embeds: iterable of embeddings to push the image away from
                (``None`` or empty for none).
            return_loss: if False, only the generated images are returned.

        Returns:
            The images when *return_loss* is False, otherwise
            ``(images, (lat_loss, cls_loss, sim_loss))``.
        """
        if text_min_embeds is None:
            text_min_embeds = []

        out = self.model()

        if not return_loss:
            return out

        # embed the normalised random crops with CLIP
        into = self.normalize_image(self._make_cutouts(out))
        image_embed = self.perceptor.encode_image(into)

        latents, soft_one_hot_classes = self.model.latents()
        latent_thres = self.model.latents.model.thresh_lat

        lat_loss = self._latent_loss(latents, latent_thres)
        cls_loss = self._class_loss(soft_one_hot_classes)

        results = [self.sim_txt_to_img(txt_embed, image_embed) for txt_embed in text_embeds]
        results += [self.sim_txt_to_img(txt_min_embed, image_embed, "min") for txt_min_embed in text_min_embeds]
        sim_loss = sum(results).mean()

        return out, (lat_loss, cls_loss, sim_loss)
class Imagine(nn.Module):
    """Driver that optimises BigGAN latents so the generated image matches the prompts.

    Main keyword-only arguments:
        text: text prompt; ``|`` separates several phrases, all of which are maximised.
        img: image (file path or PIL image) whose CLIP embedding is also maximised.
        encoding: precomputed embedding, used instead of text/img when given.
        text_min: phrases (``|``-separated) whose similarity is minimised.
        lr: Adam learning rate for the latents.
        epochs, iterations: training length (iterations per epoch).
        save_every: save the image every this many iterations within an epoch.
        seed: seeds ``torch`` and Python's ``random`` (crop positions) when given.
        torch_deterministic: request deterministic PyTorch algorithms.
    The remaining arguments are forwarded to ``BigSleep`` or control file saving.
    """

    def __init__(
        self,
        *,
        text=None,
        img=None,
        encoding=None,
        text_min = "",
        lr = .07,
        image_size = 512,
        gradient_accumulate_every = 1,
        save_every = 50,
        epochs = 20,
        iterations = 1050,
        save_progress = False,
        bilinear = False,
        open_folder = True,
        seed = None,
        append_seed = False,
        torch_deterministic = False,
        max_classes = None,
        class_temperature = 2.,
        save_date_time = False,
        save_best = False,
        experimental_resample = False,
        ema_decay = 0.99,
        num_cutouts = 128,
        center_bias = False,
        larger_clip = False
    ):
        super().__init__()

        if torch_deterministic:
            assert not bilinear, 'the deterministic (seeded) operation does not work with interpolation (PyTorch 1.7.1)'
            # torch.set_deterministic was deprecated in PyTorch 1.8 in favour of
            # torch.use_deterministic_algorithms and is absent from current releases
            if hasattr(torch, 'use_deterministic_algorithms'):
                torch.use_deterministic_algorithms(True)
            else:
                torch.set_deterministic(True)

        self.seed = seed
        self.append_seed = append_seed

        if exists(seed):
            print(f'setting seed of {seed}')
            if seed == 0:
                print('you can override this with --seed argument in the command line, or --random for a randomly chosen one')
            torch.manual_seed(seed)
            # crop positions (rand_cutout) are drawn from Python's `random`; seed it too
            random.seed(seed)

        self.epochs = epochs
        self.iterations = iterations

        model = BigSleep(
            image_size = image_size,
            bilinear = bilinear,
            max_classes = max_classes,
            class_temperature = class_temperature,
            experimental_resample = experimental_resample,
            ema_decay = ema_decay,
            num_cutouts = num_cutouts,
            center_bias = center_bias,
            larger_clip = larger_clip
        ).cuda()
        self.model = model

        self.lr = lr
        self.optimizer = Adam(model.model.latents.model.parameters(), lr)
        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every

        self.save_progress = save_progress
        self.save_date_time = save_date_time

        self.save_best = save_best
        self.current_best_score = 0

        self.open_folder = open_folder
        # train_step saves at every (i + 1) % save_every == 0 within each epoch
        self.total_image_updates = self.epochs * (self.iterations // self.save_every)
        # CLIP embeddings to maximise ("max") and minimise ("min") similarity with
        self.encoded_texts = {
            "max": [],
            "min": []
        }
        self.clip_transform = create_clip_img_transform(224)
        self.set_clip_encoding(text=text, img=img, encoding=encoding, text_min=text_min)

    @property
    def seed_suffix(self):
        """``'.<seed>'`` for output file names when ``append_seed`` is set and a seed exists, else ``''``."""
        return f'.{self.seed}' if self.append_seed and exists(self.seed) else ''

    def set_text(self, text):
        """Replace the prompt with *text*; this also clears any ``text_min`` penalty."""
        self.set_clip_encoding(text = text)

    def create_clip_encoding(self, text=None, img=None, encoding=None):
        """Return one embedding for the given inputs, or None if all of them are None.

        An explicit *encoding* (moved to the GPU) wins. Otherwise the text and
        image embeddings are averaged when both are given; if not, whichever is
        present is used. Also records *text* and *img* on ``self``.
        """
        self.text = text
        self.img = img
        if encoding is not None:
            encoding = encoding.cuda()
        elif text is not None and img is not None:
            encoding = (self.create_text_encoding(text) + self.create_img_encoding(img)) / 2
        elif text is not None:
            encoding = self.create_text_encoding(text)
        elif img is not None:
            encoding = self.create_img_encoding(img)
        return encoding

    def create_text_encoding(self, text):
        """Tokenize *text* and embed it with CLIP's text encoder (no gradients)."""
        tokenized_text = tokenize(text).cuda()
        with torch.no_grad():
            text_encoding = self.model.perceptor.encode_text(tokenized_text).detach()
        return text_encoding

    def create_img_encoding(self, img):
        """Embed an image (file path or PIL image) with CLIP's image encoder (no gradients)."""
        if isinstance(img, str):
            # load fully and close the file; CLIP normalisation expects 3 channels
            with Image.open(img) as opened:
                img = opened.convert("RGB")
        elif isinstance(img, Image.Image) and img.mode != "RGB":
            # e.g. RGBA or greyscale inputs would break the 3-channel normalisation
            img = img.convert("RGB")
        normed_img = self.clip_transform(img).unsqueeze(0).cuda()
        with torch.no_grad():
            img_encoding = self.model.perceptor.encode_image(normed_img).detach()
        return img_encoding

    def encode_multiple_phrases(self, text, img=None, encoding=None, text_type="max"):
        """Encode each ``|``-separated phrase of *text* into ``self.encoded_texts[text_type]``.

        Inputs that yield no embedding (no text, image or encoding) are dropped.
        """
        phrases = text.split("|") if text is not None and "|" in text else [text]
        encodings = [self.create_clip_encoding(text=phrase, img=img, encoding=encoding) for phrase in phrases]
        self.encoded_texts[text_type] = [enc for enc in encodings if enc is not None]

    def encode_max_and_min(self, text, img=None, encoding=None, text_min=""):
        """Encode the phrases to maximise (text/img/encoding) and to minimise (*text_min*)."""
        self.encode_multiple_phrases(text, img=img, encoding=encoding)
        if text_min:
            # phrases to avoid are encoded from their text alone. Mixing in the image
            # or the precomputed encoding would make the "min" targets identical to
            # (encoding) or partly built from (img) the "max" targets.
            self.encode_multiple_phrases(text_min, text_type="min")
        else:
            # drop penalties left over from a previous prompt
            self.encoded_texts["min"] = []

    def set_clip_encoding(self, text=None, img=None, encoding=None, text_min=""):
        """Set the prompts, derive the output file name, and encode everything with CLIP."""
        text_min = text_min or ""
        self.current_best_score = 0
        self.text = text
        self.text_min = text_min

        # the "_wout_<text_min>" suffix only names the output file; the prompt
        # that gets encoded must remain the user's text
        path_text = text
        if len(text_min) > 0:
            path_text = text + "_wout_" + text_min[:255] if text is not None else "wout_" + text_min[:255]
        text_path = create_text_path(text=path_text, img=img, encoding=encoding)
        if self.save_date_time:
            text_path = datetime.now().strftime("%y%m%d-%H%M%S-") + text_path

        self.text_path = text_path
        self.filename = Path(f'./{text_path}{self.seed_suffix}.png')
        self.encode_max_and_min(text, img=img, encoding=encoding, text_min=text_min)
        # create_clip_encoding overwrites self.text / self.img per phrase; restore
        # the values that describe the whole request
        self.text = text
        self.img = img

    def reset(self):
        """Re-initialise the latents and rebuild the optimizer over them."""
        self.model.reset()
        self.model = self.model.cuda()
        # same parameter set as in __init__ (`latents.model`)
        self.optimizer = Adam(self.model.model.latents.model.parameters(), self.lr)

    def train_step(self, epoch, i, pbar=None):
        """One optimisation step (with gradient accumulation); periodically save the image.

        Returns:
            ``(out, total_loss)``: the last model output and the accumulated loss.
        """
        total_loss = 0

        for _ in range(self.gradient_accumulate_every):
            out, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
            loss = sum(losses) / self.gradient_accumulate_every
            total_loss += loss
            loss.backward()

        self.optimizer.step()
        self.model.model.latents.update()
        self.optimizer.zero_grad()

        if (i + 1) % self.save_every == 0:
            with torch.no_grad():
                # evaluate with the EMA-wrapped latents in eval mode
                self.model.model.latents.eval()
                out, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
                top_score, best = torch.topk(losses[2], k=1, largest=False)
                image = self.model.model()[best].cpu()
                self.model.model.latents.train()

                save_image(image, str(self.filename))
                if pbar is not None:
                    pbar.update(1)
                else:
                    print(f'image updated at "./{str(self.filename)}"')

                if self.save_progress:
                    total_iterations = epoch * self.iterations + i
                    num = total_iterations // self.save_every
                    save_image(image, Path(f'./{self.text_path}.{num}{self.seed_suffix}.png'))

                if self.save_best and top_score.item() < self.current_best_score:
                    self.current_best_score = top_score.item()
                    save_image(image, Path(f'./{self.text_path}{self.seed_suffix}.best.png'))

        return out, total_loss

    def forward(self):
        """Run the full training loop, stopping early once the SIGINT flag `terminate` is set."""
        penalizing = ""
        if len(self.text_min) > 0:
            penalizing = f'penalizing "{self.text_min}"'
        print(f'Imagining "{self.text_path}" {penalizing}...')

        if not self.encoded_texts["max"] and not self.encoded_texts["min"]:
            raise ValueError('nothing to optimise: provide text, img, encoding or text_min')

        with torch.no_grad():
            # one throw-away forward pass before training (original note: works
            # around a CLIP/CUDA issue)
            self.model(self.encoded_texts["max"], self.encoded_texts["min"])

        if self.open_folder:
            open_folder('./')
            self.open_folder = False

        image_pbar = tqdm(total=self.total_image_updates, desc='image update', position=2, leave=True)
        epoch_pbar = trange(self.epochs, desc = '      epochs', position=0, leave=True)
        for epoch in (ep for ep in epoch_pbar if not terminate):
            pbar = trange(self.iterations, desc='   iteration', position=1, leave=True)
            image_pbar.update(0)
            for i in (it for it in pbar if not terminate):
                out, loss = self.train_step(epoch, i, image_pbar)
                pbar.set_description(f'loss: {loss.item():04.2f}')
posted @ 2024-06-28 14:02  绝不原创的飞龙  阅读(23)  评论(0编辑  收藏  举报