Lucidrains Series Source Code Analysis (Part 51)

.\lucidrains\x-transformers\x_transformers\xval.py

"""
定义了一个基于离散标记的常规变换器,但对于数字是连续的
更好地泛化了算术
https://arxiv.org/abs/2310.02989
"""

# imports (repeat, pack, unpack, always and pad_at_dim are used below and were missing from the transcribed imports)

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from typing import Callable
from collections import namedtuple

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from x_transformers.x_transformers import (
    AttentionLayers,
    TokenEmbedding,
    ScaledSinusoidalEmbedding,
    AbsolutePositionalEmbedding,
    always,
    pad_at_dim
)

from x_transformers.autoregressive_wrapper import (
    top_k,
    top_p
)

# constants

# named tuple for the breakdown of the loss
LossBreakdown = namedtuple('LossBreakdown', ['cross_entropy_loss', 'numerical_mse_loss'])

# named tuple for the return value of generate
GenerateReturn = namedtuple('GenerateReturn', ['sampled_token_ids', 'sampled_numbers', 'is_number_mask'])

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# main classes

class XValTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        numerical_token_id,
        attn_layers: AttentionLayers,
        emb_dim = None,
        logits_dim = None,
        tie_embedding = False,
        max_mem_len = 0,
        num_memory_tokens = None,
        emb_dropout = 0.,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False
    ):
        super().__init__()
        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.emb_dim = emb_dim
        self.token_emb = TokenEmbedding(emb_dim, num_tokens)

        self.numerical_token_id = numerical_token_id

        self.max_seq_len = max_seq_len

        self.max_mem_len = max_mem_len

        if not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb):
            self.pos_emb = always(0)  # absolute positional embedding unused or disabled, so it is a constant 0
        elif scaled_sinu_pos_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)  # scaled sinusoidal positional embedding
        else:
            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)  # learned absolute positional embedding

        self.emb_dropout = nn.Dropout(emb_dropout)

        # memory tokens

        num_memory_tokens = default(num_memory_tokens, 0)
        self.has_memory_tokens = num_memory_tokens > 0

        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))  # learned memory tokens

        # attention layers

        self.attn_layers = attn_layers

        # projection to logits

        logits_dim = default(logits_dim, num_tokens)
        self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

        self.to_numerical_output = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... 1 -> ...')
        )

    def forward(
        self,
        x: Tensor,
        x_num: Tensor,
        return_embeddings = False,
        return_intermediates = False,
        return_mems = False,
        mask = None,
        return_attn = False,
        mems = None,
        pos = None,
        prepend_embeds = None,
        **kwargs
        ):
        # token ids and numbers must have the same shape
        assert x.shape == x_num.shape

        batch = x.shape[0]

        # mask for the positions holding the numerical token
        is_number_mask = x == self.numerical_token_id

        # token embedding
        x = self.token_emb(x)

        # scale the embedding of each numerical token by its number value (the core xVal idea)
        scale = torch.where(is_number_mask, x_num, 1.)
        scale = rearrange(scale, '... -> ... 1')

        x = x * scale

        # add positional embedding
        x = x + self.pos_emb(x, pos = pos)

        # memory tokens

        if self.has_memory_tokens:
            # expand the memory tokens across the batch
            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
            # pack the memory tokens together with the input
            x, mem_ps = pack([m, x], 'b * d')

            if exists(mask):
                num_mems = m.shape[-2]
                # pad the mask to account for the prepended memory tokens
                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)

        # whether to prepend embeds, as in PaLI with image embeddings
        if exists(prepend_embeds):
            _, prepend_dim = prepend_embeds.shape[1:]
            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'

            x = torch.cat((prepend_embeds, x), dim = -2)

        # embedding dropout
        x = self.emb_dropout(x)

        # attention layers

        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)

        # splice out the memory tokens

        if self.has_memory_tokens:
            m, x = unpack(x, mem_ps, 'b * d')
            intermediates.memory_tokens = m

        # project to logits and the numerical prediction, unless embeddings were requested
        if not return_embeddings:
            logits = self.to_logits(x)
            numerical_pred = self.to_numerical_output(x)
            out = (logits, numerical_pred)
        else:
            out = x

        if return_intermediates:
            return out, intermediates

        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
            return out, new_mems

        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            return out, attn_maps

        return out

class XValAutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net: XValTransformerWrapper,
        ignore_index = -100,
        pad_value = 0,
        numerical_loss_weight = 1.
    ):
        super().__init__()
        self.net = net
        self.max_seq_len = net.max_seq_len

        # relative weight of the numerical (mse) loss against the cross entropy loss
        self.numerical_loss_weight = numerical_loss_weight

        # target index to ignore when computing the loss
        self.ignore_index = ignore_index

    @torch.no_grad()
    def generate(
        self,
        start_tokens: Tensor,
        start_numbers: Tensor,
        seq_len,
        filter_logits_fn: Callable = top_k,
        filter_kwargs: dict = dict(),
        temperature = 1.,
        **kwargs
    ):
        device = start_tokens.device
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2'
        assert start_tokens.shape == start_numbers.shape

        b, t, device = *start_tokens.shape, start_tokens.device

        self.net.eval()
        out = start_tokens
        num_out = start_numbers

        for _ in range(seq_len):
            # operate on at most the trailing max_seq_len tokens and numbers
            x = out[:, -self.max_seq_len:]
            x_num = num_out[:, -self.max_seq_len:]

            logits, numerical_pred = self.net(x, x_num, **kwargs)

            # take the predictions at the last position
            last_logits = logits[:, -1]
            last_num_pred = numerical_pred[:, -1:]

            # filter the logits (top-k by default), then sample from the tempered softmax distribution
            filtered_logits = filter_logits_fn(last_logits, **filter_kwargs)

            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            # append the sampled token and predicted number to the running outputs
            out = torch.cat((out, sample), dim = -1)
            num_out = torch.cat((num_out, last_num_pred), dim = -1)

        # strip off the starting tokens
        out = out[:, t:]
        num_out = num_out[:, t:]

        # set the numbers at non-numerical token positions to NaN
        is_number = out == self.net.numerical_token_id
        num_out = torch.where(is_number, num_out, float('nan'))

        # restore the network's training state and return
        self.net.train(was_training)
        return GenerateReturn(out, num_out, is_number)

    def forward(
        self,
        x: Tensor,
        x_num: Tensor,
        return_loss_breakdown = False,
        **kwargs
    ):
        # shift inputs and targets by one position for autoregressive training
        inp, target = x[:, :-1], x[:, 1:]
        x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:]

        # trim the mask to match the shifted input, if one was passed
        mask = kwargs.get('mask', None)
        if exists(mask) and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]
            kwargs['mask'] = mask

        logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)

        # cross entropy expects (batch, classes, seq)
        logits = rearrange(logits, 'b n c -> b c n')

        cross_entropy_loss = F.cross_entropy(logits, target, reduction = 'none', ignore_index = self.ignore_index)

        # mask ignored targets out of the numerical loss
        target_mask = target != self.ignore_index

        numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')

        numerical_mse_loss = numerical_mse_loss * target_mask

        # combine the two losses, weighting the numerical one
        loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight

        if exists(mask):
            loss = loss[mask]

        loss = loss.mean()

        if not return_loss_breakdown:
            return loss

        # return the total loss along with its breakdown
        return loss, LossBreakdown(cross_entropy_loss, numerical_mse_loss)
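
A minimal usage sketch of the two wrappers (the hyperparameters and the choice of numerical_token_id are illustrative, not from the source):

import torch
from x_transformers import Decoder
from x_transformers.xval import XValTransformerWrapper, XValAutoregressiveWrapper

model = XValTransformerWrapper(
    num_tokens = 256,
    numerical_token_id = 255,     # id reserved for the number placeholder token
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = XValAutoregressiveWrapper(model)

ids  = torch.randint(0, 256, (1, 1024))   # token ids, with 255 marking number positions
nums = torch.randn(1, 1024)               # the continuous value carried at each position

loss = wrapper(ids, nums)
loss.backward()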

.\lucidrains\x-transformers\x_transformers\x_transformers.py

# imports (the AutoregressiveWrapper import statement was missing from the transcription, though its comment was present; it is used by XTransformer below)

import math
from random import random
from typing import Dict
from packaging import version

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from torch.cuda.amp import autocast

from functools import partial, wraps
from collections import namedtuple
from dataclasses import dataclass
from typing import List, Callable, Optional, Union

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

from x_transformers.attend import Attend, Intermediates
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

# constants

DEFAULT_DIM_HEAD = 64

# dataclass collecting the intermediate values of a forward pass
@dataclass
class LayerIntermediates:
    hiddens:            Optional[List[Tensor]] = None   # all hiddens, before the final norm (in a pre-norm architecture)
    last_hidden:        Optional[Tensor] = None         # very last hidden after all attention layers, after the final norm
    attn_intermediates: Optional[List[Intermediates]] = None
    layer_hiddens:      Optional[List[Tensor]] = None
    attn_z_loss:        Optional[Tensor] = None
    mems:               Optional[Tensor] = None
    memory_tokens:      Optional[Tensor] = None

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# cast a value to a tuple of the given depth
def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth

# whether num is evenly divisible by den
def divisible_by(num, den):
    return (num % den) == 0

# apply fn only if the input exists
def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

# at most one of the booleans may be true
def at_most_one_of(*bools):
    return sum(map(int, bools)) <= 1

# callable that always returns the same value
class always():
    def __init__(self, val):
        self.val = val
    def __call__(self, *args, **kwargs):
        return self.val

# callable that tests for inequality with a value
class not_equals():
    def __init__(self, val):
        self.val = val
    def __call__(self, x, *args, **kwargs):
        return x != self.val

# callable that tests for equality with a value
class equals():
    def __init__(self, val):
        self.val = val
    def __call__(self, x, *args, **kwargs):
        return x == self.val

# nn.Sequential that silently drops None modules
def Sequential(*modules):
    return nn.Sequential(*filter(exists, modules))

# tensor helper functions

# most negative value representable in the tensor's dtype
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

# l2 normalize a tensor, optionally in groups along the last dimension
def l2norm(t, groups = 1):
    t = rearrange(t, '... (g d) -> ... g d', g = groups)
    t = F.normalize(t, p = 2, dim = -1)
    return rearrange(t, '... g d -> ... (g d)')

# pad a tensor at the given dimension
def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)
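
A quick sketch of what pad_at_dim does (shapes chosen for illustration):

t = torch.ones(2, 3)
padded = pad_at_dim(t, (1, 2), dim = -1)  # pad 1 element on the left, 2 on the right of the last dim
assert padded.shape == (2, 6)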

# logical-or reduce a list of masks
def or_reduce(masks):
    head, *body = masks
    for rest in body:
        head = head | rest
    return head

# auxiliary loss functions

# compute the z-loss over pre-softmax attention logits
def calc_z_loss(
    pre_softmax_attns: List[Tensor],
    mask = None,
    weight = 1.
):
    # the same loss applied to the mixture-of-experts router logits in https://arxiv.org/abs/2202.08906
    # in the paper, in a small footnote, they mention applying it to attention logits with stabilizing effects
    # also used in PaLM as one of the measures

    lse = 0.

    for attn in pre_softmax_attns:
        lse = lse + attn.logsumexp(dim = -1)

    loss = torch.square(lse)
    loss = reduce(loss, 'b h n -> b n', 'sum')

    if not exists(mask):
        return loss.mean() * weight

    loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
    return loss * weight
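
A sketch of how the z-loss would be invoked, assuming a list of pre-softmax attention logits of shape (batch, heads, i, j) gathered from each layer:

pre_softmax_attns = [torch.randn(2, 8, 128, 128)]        # one entry per attention layer
z_loss = calc_z_loss(pre_softmax_attns, weight = 1e-4)   # scalar, added to the main loss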

# init helpers

# initialize a layer's weight (and bias, if present) to zero
def init_zero_(layer):
    nn.init.constant_(layer.weight, 0.)
    if exists(layer.bias):
        nn.init.constant_(layer.bias, 0.)

# keyword argument helpers

# pop the given keys out of a dict and return them as a new dict
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

# split a dict in two according to a condition on its keys
def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)
# whether a string begins with the given prefix
def string_begins_with(prefix, str):
    return str.startswith(prefix)

# group a dict by whether its keys carry the given prefix
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# group by prefix, and strip the prefix from the matching keys
def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
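
A sketch of how the prefix routing behaves; XTransformer below uses exactly this with 'enc_' and 'dec_' prefixes:

kwargs = dict(enc_depth = 6, enc_heads = 8, dec_depth = 4)
enc_kwargs, rest = groupby_prefix_and_trim('enc_', kwargs)
# enc_kwargs == {'depth': 6, 'heads': 8}, rest == {'dec_depth': 4}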

# structured dropout of whole sequence positions, more effective than traditional attention dropouts
def dropout_seq(seq, mask, dropout):
    b, n, *_, device = *seq.shape, seq.device
    # random scores used to rank which positions to keep
    logits = torch.randn(b, n, device=device)

    if exists(mask):
        # ensure masked-out positions rank last
        mask_value = max_neg_value(logits)
        logits = logits.masked_fill(~mask, mask_value)

    # number of positions to keep
    keep_prob = 1. - dropout
    num_keep = max(1, int(keep_prob * n))
    keep_indices = logits.topk(num_keep, dim=1).indices

    batch_indices = torch.arange(b, device=device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')

    # gather the kept subset of the sequence
    seq = seq[batch_indices, keep_indices]

    if exists(mask):
        # per-sample count of valid positions, and how many of them survive
        seq_counts = mask.sum(dim=-1)
        seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
        keep_mask = torch.arange(num_keep, device=device) < rearrange(seq_keep_counts, 'b -> b 1')

        # update the mask accordingly
        mask = mask[batch_indices, keep_indices] & keep_mask

    return seq, mask
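
A sketch of dropout_seq on dummy data; whole positions are dropped, not individual activations:

seq  = torch.randn(2, 10, 512)
mask = torch.ones(2, 10, dtype = torch.bool)
seq, mask = dropout_seq(seq, mask, dropout = 0.3)  # keeps int(0.7 * 10) = 7 positions per sample
print(seq.shape)  # torch.Size([2, 7, 512])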

# activation functions
class ReluSquared(nn.Module):
    def forward(self, x):
        return F.relu(x) ** 2

# token embedding
class TokenEmbedding(nn.Module):
    def __init__(self, dim, num_tokens, l2norm_embed=False):
        super().__init__()
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(num_tokens, dim)

    def forward(self, x):
        token_emb = self.emb(x.long())
        return l2norm(token_emb) if self.l2norm_embed else token_emb

# absolute positional embedding
class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len, l2norm_embed=False):
        super().__init__()
        self.scale = dim ** -0.5 if not l2norm_embed else 1.
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos=None, seq_start_pos=None):
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if not exists(pos):
            pos = torch.arange(seq_len, device=device)

        if exists(seq_start_pos):
            pos = (pos - seq_start_pos[..., None]).clamp(min=0)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        return l2norm(pos_emb) if self.l2norm_embed else pos_emb

# scaled sinusoidal positional embedding
class ScaledSinusoidalEmbedding(nn.Module):
    def __init__(self, dim, theta=10000):
        super().__init__()
        assert divisible_by(dim, 2)
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        freq_seq = torch.arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        self.register_buffer('inv_freq', inv_freq, persistent=False)

    def forward(self, x, pos=None, seq_start_pos=None):
        seq_len, device = x.shape[1], x.device

        if not exists(pos):
            pos = torch.arange(seq_len, device=device)

        if exists(seq_start_pos):
            pos = pos - seq_start_pos[..., None]

        emb = einsum('i, j -> i j', pos, self.inv_freq)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb * self.scale

class RelativePositionBias(nn.Module):
    def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # embedding from relative position bucket to a per-head bias
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    # map relative positions to bucket indices (T5 style)
    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        ret = 0
        n = -relative_position
        # in the non-causal case, use half the buckets for each direction
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        # the first half of the buckets map distances exactly
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # larger distances fall into logarithmically spaced buckets
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        device = self.device
        # query and key positions
        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        # relative positions, bucketed
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
        # look up the per-head bias and rearrange to (heads, i, j)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> h i j')
        return bias * self.scale
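
A sketch of the bias this module produces for 128 queries attending over 128 keys:

rel_pos = RelativePositionBias(scale = 1., causal = True, heads = 8)
bias = rel_pos(128, 128)
print(bias.shape)  # torch.Size([8, 128, 128]), added to the attention logits before softmax
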
class DynamicPositionBias(nn.Module):
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        # an MLP that maps a continuous relative distance to a per-head bias
        self.mlp = nn.ModuleList([])

        self.mlp.append(Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim) if norm else None,
            nn.SiLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else None,
                nn.SiLU()
            ))

        # final projection to one bias per head
        self.mlp.append(nn.Linear(dim, heads))

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        assert i == j
        n, device = j, self.device

        # get the (n x n) matrix of distances
        seq_arange = torch.arange(n, device = device)
        context_arange = torch.arange(n, device = device)
        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
        indices += (n - 1)

        # input to continuous positions MLP
        pos = torch.arange(-n + 1, n, device = device).float()
        pos = rearrange(pos, '... -> ... 1')

        if self.log_distance:
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # index the per-distance biases into an (h, i, j) tensor
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias

class AlibiPositionalBias(nn.Module):
    def __init__(self, heads, total_heads, **kwargs):
        super().__init__()
        self.heads = heads
        self.total_heads = total_heads

        # fixed per-head slopes, plus a cached bias buffer
        slopes = Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> h 1 1')
        self.register_buffer('slopes', slopes, persistent = False)
        self.register_buffer('bias', None, persistent = False)

    # linear distance penalty between query position i and key position j
    def get_bias(self, i, j, device):
        i_arange = torch.arange(j - i, j, device = device)
        j_arange = torch.arange(j, device = device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    @staticmethod
    def _get_slopes(heads):
        # slopes form a geometric sequence when the number of heads is a power of 2
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        # otherwise, interleave the slopes of the nearest powers of 2
        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self, i, j):
        h, device = self.total_heads, self.device

        # return the cached bias if it is already large enough
        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
            return self.bias[..., -i:, -j:]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        # pad with zeros for any heads not receiving an alibi bias, then cache
        num_heads_unalibied = h - bias.shape[0]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
        self.register_buffer('bias', bias, persistent = False)

        return self.bias
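
A sketch of the ALiBi bias; each head gets the same linear distance penalty, scaled by its slope:

alibi = AlibiPositionalBias(heads = 8, total_heads = 8)
bias = alibi(128, 128)
print(bias.shape)  # torch.Size([8, 128, 128]), non-positive values growing with distance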

class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence lengths without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        # inverse frequencies, registered as a buffer
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            # without xpos, there is no length-dependent scale
            self.register_buffer('scale', None)
            return

        # xpos scale
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    # compute frequencies for a given sequence length
    def forward_from_seq_len(self, seq_len):
        device = self.inv_freq.device

        t = torch.arange(seq_len, device = device)
        return self.forward(t)

    # disable automatic mixed precision for numerical stability
    @autocast(enabled = False)
    def forward(self, t):
        max_pos = t.max()+1

        # outer product of positions and inverse frequencies
        freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not exists(self.scale):
            return freqs, 1.

        # xpos: scale decays with distance from the center of the sequence
        power = (t - (max_pos // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
# split the last dimension into two halves and rotate them: (x1, x2) -> (-x2, x1)
def rotate_half(x):
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

# apply rotary positional embedding to a tensor t
@autocast(enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
    # keep only the frequencies (and scales) for the trailing seq_len positions
    freqs = freqs[-seq_len:, :]
    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale

    # broadcast batched frequencies over the head dimension
    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    # concatenate the rotated and pass-through portions
    return torch.cat((t, t_unrotated), dim = -1)
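
A sketch of applying partial rotary embeddings to a (batch, heads, seq, dim_head) query tensor:

rotary = RotaryEmbedding(dim = 32)          # rotate the first 32 of the 64 head dimensions
freqs, scale = rotary.forward_from_seq_len(128)

q = torch.randn(1, 8, 128, 64)
q = apply_rotary_pos_emb(q, freqs, scale)   # last 32 dimensions pass through unrotated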

# norms

# scale the output of a wrapped function by a constant
class Scale(nn.Module):
    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        scale_fn = lambda t: t * self.value

        # if the wrapped function returns a tuple, only scale the first element
        if not isinstance(out, tuple):
            return scale_fn(out)

        return (scale_fn(out[0]), *out[1:])

# ScaleNorm: normalize by the vector norm, with a single learned scalar gain
class ScaleNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))

    def forward(self, x):
        norm = torch.norm(x, dim = -1, keepdim = True)
        return x / norm.clamp(min = self.eps) * self.g

# bias-less LayerNorm
class LayerNorm(nn.Module):
    def __init__(self, dim):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
        """
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# for pytorch >= 2.1.0, nn.LayerNorm supports bias = False directly, so use it instead
if version.parse(torch.__version__) >= version.parse('2.1.0'):
    LayerNorm = partial(nn.LayerNorm, bias = False)

# RMSNorm: divide by the root-mean-square of the features, with a learned gain
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.g

# RMSNorm without the learned gain
class SimpleRMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale
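
As a sanity check on what these compute: F.normalize(x, dim = -1) is x / ||x||, and multiplying by sqrt(dim) turns that into division by the root-mean-square of the features:

x = torch.randn(2, 8, 512)
out = SimpleRMSNorm(512)(x)
# equivalent to: x / (x.norm(dim = -1, keepdim = True) / 512 ** 0.5)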

# residual and residual gates

# plain residual connection, with optional learned and constant scaling of the residual
class Residual(nn.Module):
    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
        super().__init__()
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
        self.scale_residual_constant = scale_residual_constant

    def forward(self, x, residual):
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        if self.scale_residual_constant != 1:
            residual = residual * self.scale_residual_constant

        return x + residual

# gated residual connection using a GRU cell
class GRUGating(nn.Module):
    def __init__(self, dim, scale_residual = False, **kwargs):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

    def forward(self, x, residual):
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        # the branch output acts as the GRU input, the residual as its hidden state
        gated_output = self.gru(
            rearrange(x, 'b n d -> (b n) d'),
            rearrange(residual, 'b n d -> (b n) d')
        )

        return gated_output.reshape_as(x)

# token shifting

# shift a tensor along the sequence dimension by the given amount
def shift(t, amount, mask = None):
    if amount == 0:
        return t
    else:
        # cap the shift amount at the sequence length
        amount = min(amount, t.shape[1])

    # zero out masked positions before shifting
    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    # pad at the front and trim at the back of the sequence dimension
    return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)

# split features into segments and shift each segment by a different amount before calling fn
class ShiftTokens(nn.Module):
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        segments = len(shifts)
        # number of feature channels per shifted segment
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim=-1)
        # segments to shift, plus any leftover features that pass through unshifted
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
        # concatenate shifted segments with the rest, then call the wrapped function
        x = torch.cat((*segments_to_shift, *rest), dim=-1)
        return self.fn(x, **kwargs)
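
A sketch of wrapping a feedforward with ShiftTokens so half its feature channels see the previous timestep:

ff = FeedForward(dim = 512)
shifted_ff = ShiftTokens((0, 1), ff)  # first half of channels unshifted, second half shifted by 1

x = torch.randn(1, 10, 512)
out = shifted_ff(x)  # (1, 10, 512)
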
# gated linear unit
class GLU(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        mult_bias = False
    ):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)
        self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.

    def forward(self, x):
        # project, then split into the value and its gate
        x, gate = self.proj(x).chunk(2, dim = -1)
        return x * self.act(gate) * self.mult_bias

# feedforward block
class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        glu = False,
        glu_mult_bias = False,
        swish = False,
        relu_squared = False,
        post_act_ln = False,
        dropout = 0.,
        no_bias = False,
        zero_init_output = False
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        # choose the activation
        if relu_squared:
            activation = ReluSquared()
        elif swish:
            activation = nn.SiLU()
        else:
            activation = nn.GELU()

        # choose the input projection (gated or plain)
        if glu:
            project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim, bias = not no_bias),
                activation
            )

        # assemble the feedforward network
        self.ff = Sequential(
            project_in,
            LayerNorm(inner_dim) if post_act_ln else None,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out, bias = not no_bias)
        )

        # optionally initialize the last linear layer to zero
        if zero_init_output:
            init_zero_(self.ff[-1])

    def forward(self, x):
        return self.ff(x)

# attention layer
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = DEFAULT_DIM_HEAD,
        dim_context = None,
        heads = 8,
        causal = False,
        flash = False,
        talking_heads = False,
        head_scale = False,
        sparse_topk = None,
        num_mem_kv = 0,
        dropout = 0.,
        on_attn = False,
        gate_value_heads = False,
        swiglu_values = False,
        gate_values = False,
        zero_init_output = False,
        max_attend_past = None,
        qk_norm = False,
        qk_norm_groups = 1,
        qk_norm_scale = 10,
        qk_norm_dim_scale = False,
        one_kv_head = False,
        kv_heads = None,
        shared_kv = False,
        value_dim_head = None,
        tensor_product = False,      # https://arxiv.org/abs/2208.06061
        add_zero_kv = False,         # same as add_zero_attn in pytorch
        rotary_embed_values = False,
        onnxable = False
    # forward
    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        attn_mask = None,
        rel_pos = None,
        rotary_pos_emb = None,
        prev_attn = None,
        mem = None,
        mem_mask = None,
        return_intermediates = False,
        cache: Optional[Intermediates] = None,
class AttentionLayers(nn.Module):
    # init, sets up the stack of attention and feedforward layers and all configuration
    def __init__(
        self,
        dim,
        depth,
        heads = 8,
        causal = False,
        cross_attend = False,
        only_cross = False,
        use_scalenorm = False,
        use_rmsnorm = False,
        use_simple_rmsnorm = False,
        alibi_pos_bias = False,
        alibi_num_heads = None,
        rel_pos_bias = False,
        rel_pos_num_buckets = 32,
        rel_pos_max_distance = 128,
        dynamic_pos_bias = False,
        dynamic_pos_bias_log_distance = False,
        dynamic_pos_bias_mlp_depth = 2,
        dynamic_pos_bias_norm = False,
        rotary_pos_emb = False,
        rotary_emb_dim = None,
        rotary_xpos = False,
        rotary_interpolation_factor = 1.,
        rotary_xpos_scale_base = 512,
        rotary_base_rescale_factor = 1.,
        custom_layers = None,
        sandwich_coef = None,
        par_ratio = None,
        weight_tie_layers = False,   # Albert - https://arxiv.org/abs/1909.11942
        layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
        residual_attn = False,
        cross_residual_attn = False,
        macaron = False,
        pre_norm = True,
        pre_norm_has_final_norm = True,
        gate_residual = False,
        scale_residual = False,
        scale_residual_constant = 1.,
        shift_tokens = 0,
        sandwich_norm = False,
        resi_dual = False,
        resi_dual_scale = 1.,
        zero_init_branch_output = False,
        layer_dropout = 0.,
        cross_attn_tokens_dropout = 0.,
        disable_abs_pos_emb = None,
        **kwargs
    # forward
    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        attn_mask = None,
        self_attn_kv_mask = None,
        mems = None,
        mem_masks = None,
        seq_start_pos: Optional[Tensor] = None,
        cache: Optional[LayerIntermediates] = None,
        cache_age = 1,
        return_hiddens = False,
        rotary_pos_emb = None
class Encoder(AttentionLayers):
    # non-causal attention layers
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal = False, **kwargs)

class Decoder(AttentionLayers):
    # causal attention layers
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal = True, **kwargs)

class PrefixDecoder(AttentionLayers):
    # non-causal over a prefix, causal over the rest
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal = False, **kwargs)

    def forward(
        self,
        x,
        *args,
        attn_mask = None,
        prefix_attn_len = None,
        **kwargs
    ):
        b, n, device = x.shape[0], x.shape[1], x.device
        # boolean causal mask, True above the diagonal
        causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)

        # start from a purely causal attention pattern
        forwarded_mask = ~causal_mask

        if exists(prefix_attn_len):
            if isinstance(prefix_attn_len, int):
                # broadcast a scalar prefix length across the batch
                prefix_attn_len = torch.full((b,), prefix_attn_len, device = device)

            # positions within the prefix may be attended to bidirectionally
            prefix_mask = torch.arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')

            forwarded_mask = forwarded_mask | prefix_mask

        if exists(attn_mask):
            # intersect with any user-provided attention mask
            forwarded_mask = forwarded_mask & attn_mask

        return super().forward(x, *args, attn_mask = forwarded_mask, **kwargs)

class CrossAttender(AttentionLayers):
    # attention layers that only cross attend to a context
    def __init__(self, **kwargs):
        super().__init__(cross_attend = True, only_cross = True, **kwargs)

class ViTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        attn_layers: Encoder,
        channels = 3,
        num_classes = None,
        post_emb_norm = False,
        num_register_tokens = 0,
        emb_dropout = 0.
    ):
        super().__init__()
        assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
        dim = attn_layers.dim
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        # learned positional embedding over the patches
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))

        # optional register tokens, prepended to the patch sequence
        has_register_tokens = num_register_tokens > 0
        self.has_register_tokens = has_register_tokens

        if has_register_tokens:
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        # patch to embedding projection
        self.patch_to_embedding = nn.Sequential(
            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim)
        )

        self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers

        # classification head, if a number of classes was given
        self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()

    def forward(
        self,
        img,
        return_embeddings = False,
        return_logits_and_embeddings = False
    ):
        b, p = img.shape[0], self.patch_size

        # split the image into patches and flatten each patch
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)

        n = x.shape[1]

        # add positional embedding
        x = x + self.pos_embedding[:, :n]

        x = self.post_emb_norm(x)
        x = self.dropout(x)

        # prepend register tokens, if any
        if self.has_register_tokens:
            r = repeat(self.register_tokens, 'n d -> b n d', b = b)
            x, ps = pack((x, r), 'b * d')

        embed = self.attn_layers(x)

        # splice the register tokens back out
        if self.has_register_tokens:
            embed, _ = unpack(embed, ps, 'b * d')

        assert at_most_one_of(return_embeddings, return_logits_and_embeddings)

        if not exists(self.mlp_head) or return_embeddings:
            return embed

        # mean pool over the sequence and classify
        pooled = embed.mean(dim = -2)
        logits = self.mlp_head(pooled)

        if not return_logits_and_embeddings:
            return logits

        return logits, embed
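
A minimal usage sketch for the vision transformer wrapper (hyperparameters are illustrative):

import torch

model = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

img = torch.randn(1, 3, 256, 256)
logits = model(img)  # (1, 1000)
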
class TransformerWrapper(nn.Module):
    # the class declaration above was dropped in transcription; this is the standard token-in, logits-out wrapper
    def __init__(
        self,
        *,
        num_tokens,                                  # vocabulary size
        max_seq_len,                                 # maximum sequence length
        attn_layers: AttentionLayers,                # the attention layers
        embed_num_tokens: Dict[str, int] = dict(),   # extra categorical embeddings, name -> number of tokens
        emb_dim = None,                              # embedding dimension, defaults to the model dimension
        max_mem_len = 0,                             # maximum memory length
        shift_mem_down = 0,                          # how many layers to shift the memories down
        emb_dropout = 0.,                            # dropout on the embeddings
        post_emb_norm = False,                       # whether to normalize right after embedding
        num_memory_tokens = None,                    # number of memory tokens
        memory_tokens_interspersed_every = None,     # intersperse memory tokens every n positions
        tie_embedding = False,                       # whether to tie input and output embeddings
        logits_dim = None,                           # dimension of the logits, defaults to num_tokens
        use_abs_pos_emb = True,                      # whether to use absolute positional embeddings
        scaled_sinu_pos_emb = False,                 # whether to use scaled sinusoidal positional embeddings
        l2norm_embed = False,                        # whether to l2-normalize the embeddings
        emb_frac_gradient = 1.,                      # fraction of the gradient that should go to the embedding
        attn_z_loss_weight = 1e-4,                   # weight for the attention z-loss
    ):
        super().__init__()

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)
        self.emb_dim = emb_dim
        self.num_tokens = num_tokens

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.shift_mem_down = shift_mem_down

        self.l2norm_embed = l2norm_embed
        self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)

        # whether absolute positional embeddings are needed at all
        no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)

        # choose the positional embedding
        if no_abs_pos_emb:
            self.pos_emb = always(0)
        elif scaled_sinu_pos_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
        else:
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)

        # additional named embeddings

        self.embeds = None

        if len(embed_num_tokens) > 0:
            self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})

        # fraction of the gradient that should go to the embedding
        self.emb_frac_gradient = emb_frac_gradient

        self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
        self.emb_dropout = nn.Dropout(emb_dropout)

        # project the embedding to the model dimension, if they differ
        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers

        self.init_()

        # output projection to logits, optionally tied to the token embedding
        logits_dim = default(logits_dim, num_tokens)
        self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

        # memory tokens
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

        self.memory_tokens_interspersed_every = memory_tokens_interspersed_every

        # whether kv caching during decoding is possible
        self.can_cache_kv = self.num_memory_tokens == 0
        self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

    # weight init, depending on whether embeddings are l2-normalized
    def init_(self):
        if self.l2norm_embed:
            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
            if not isinstance(self.pos_emb, always):
                nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
            return

        nn.init.kaiming_normal_(self.token_emb.emb.weight)

    # 前向传播函数
    def forward(
        self,
        x,  # 输入数据
        return_embeddings = False,  # 是否返回嵌入结果
        return_logits_and_embeddings = False,  # 是否返回logits和嵌入结果
        return_intermediates = False,  # 是否返回中间结果
        mask = None,  # 掩码
        return_mems = False,  # 是否返回记忆
        return_attn = False,  # 是否返回注意力
        mems = None,  # 记忆
        mem_masks = None,  # 记忆掩码
        pos = None,  # 位置编码
        prepend_embeds = None,  # 前置嵌入
        prepend_mask = None,  # 前置掩码
        embed_ids: Dict[str, Tensor] = dict(),  # 嵌入ID的字典
        sum_embeds = None,  # 嵌入求和
        return_attn_z_loss = False,  # 是否返回注意力z损失
        attn_z_loss_weight = 1e-4,  # 注意力z损失的权重
        seq_start_pos = None,  # 序列起始位置
        cache: Optional[LayerIntermediates] = None,  # 缓存
        **kwargs  # 其他参数
class XTransformer(nn.Module):
    # full encoder-decoder transformer
    def __init__(
        self,
        *,
        dim,
        tie_token_emb = False,
        ignore_index = -100,
        pad_value = 0,
        cross_attn_tokens_dropout = 0.,
        **kwargs
    ):
        super().__init__()

        # route keyword arguments by the 'enc_' and 'dec_' prefixes
        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)

        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'

        # pull out the wrapper-level kwargs for the encoder, with defaults
        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
        enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
        enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
        enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
        enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)

        # likewise for the decoder
        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
        dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
        dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
        dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)

        # structured dropout of the encoder output tokens the decoder cross attends to
        self.cross_attn_tokens_dropout = cross_attn_tokens_dropout

        # build the encoder and decoder
        self.encoder = TransformerWrapper(
            **enc_transformer_kwargs,
            attn_layers = Encoder(dim = dim, **enc_kwargs)
        )

        self.decoder = TransformerWrapper(
            **dec_transformer_kwargs,
            attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
        )

        # optionally tie the encoder and decoder token embeddings
        if tie_token_emb:
            self.decoder.token_emb = self.encoder.token_emb

        # wrap the decoder for autoregressive training and generation
        self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
        # encode the input sequence, then decode autoregressively with it as context
        encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
        return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)

    def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):

        # encode the source sequence
        enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)

        # extend the mask to cover any prepended embeddings
        if exists(src_prepend_embeds) and exists(mask):
            mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)

        # during training, optionally drop out some of the encoder tokens the decoder attends to
        if self.training and self.cross_attn_tokens_dropout > 0:
            enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)

        # decode with cross attention to the encodings
        out = self.decoder(tgt, context = enc, context_mask = mask)
        return out
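
A minimal usage sketch of the full encoder-decoder, with the 'enc_' / 'dec_' kwargs routed as described above (hyperparameters are illustrative):

import torch

model = XTransformer(
    dim = 512,
    tie_token_emb = True,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024
)

src = torch.randint(0, 256, (1, 1024))
tgt = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()

loss = model(src, tgt, mask = src_mask)
loss.backward()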

.\lucidrains\x-transformers\x_transformers\__init__.py

# classes from x_transformers.x_transformers
from x_transformers.x_transformers import (
    XTransformer,          # full encoder-decoder transformer
    Encoder,               # non-causal attention layers
    Decoder,               # causal attention layers
    PrefixDecoder,         # prefix (partially bidirectional) decoder
    CrossAttender,         # cross-attention-only layers
    Attention,             # the attention layer itself
    TransformerWrapper,    # wrapper adding token and positional embeddings
    ViTransformerWrapper   # wrapper for vision transformers
)

# autoregressive training / generation wrapper
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

# non-autoregressive generation wrapper
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

# wrappers for continuous (non-token) inputs
from x_transformers.continuous import (
    ContinuousTransformerWrapper,
    ContinuousAutoregressiveWrapper
)

# xVal wrappers, which treat numbers continuously (see xval.py above)
from x_transformers.xval import (
    XValTransformerWrapper,
    XValAutoregressiveWrapper
)

# autoregressive wrapper with Transformer-XL style recurrence
from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper

# DPO (Direct Preference Optimization)
from x_transformers.dpo import (
    DPO
)

x-unet

Implementation of a U-net complete with efficient attention as well as the latest research findings

Install

$ pip install x-unet

Usage

import torch
from x_unet import XUnet

unet = XUnet(
    dim = 64,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    nested_unet_depths = (7, 4, 2, 1),     # nested unet depths, from unet-squared paper
    consolidate_upsample_fmaps = True,     # whether to consolidate outputs from all upsample blocks, used in unet-squared paper
)

img = torch.randn(1, 3, 256, 256)
out = unet(img) # (1, 3, 256, 256)

For 3d (video or CT / MRI scans)

import torch
from x_unet import XUnet

unet = XUnet(
    dim = 64,
    frame_kernel_size = 3,                 # set this to greater than 1
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    nested_unet_depths = (5, 4, 2, 1),     # nested unet depths, from unet-squared paper
    consolidate_upsample_fmaps = True,     # whether to consolidate outputs from all upsample blocks, used in unet-squared paper
    weight_standardize = True
)

video = torch.randn(1, 3, 10, 128, 128)    # (batch, channels, frames, height, width)
out = unet(video) # (1, 3, 10, 128, 128)

Todo

Citations

@article{Ronneberger2015UNetCN,
    title   = {U-Net: Convolutional Networks for Biomedical Image Segmentation},
    author  = {Olaf Ronneberger and Philipp Fischer and Thomas Brox},
    journal = {ArXiv},
    year    = {2015},
    volume  = {abs/1505.04597}
}
@article{Qin2020U2NetGD,
    title   = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection},
    author  = {Xuebin Qin and Zichen Vincent Zhang and Chenyang Huang and Masood Dehghan and Osmar R Zaiane and Martin J{\"a}gersand},
    journal = {ArXiv},
    year    = {2020},
    volume  = {abs/2005.09007}
}
@inproceedings{Henry2020QueryKeyNF,
    title   = {Query-Key Normalization for Transformers},
    author  = {Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen},
    booktitle = {FINDINGS},
    year    = {2020}
}
@article{Qiao2019WeightS,
    title   = {Weight Standardization},
    author  = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1903.10520}
}
@article{Shleifer2021NormFormerIT,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Sam Shleifer and Jason Weston and Myle Ott},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.09456}
}
@article{Sunkara2022NoMS,
    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
    author  = {Raja Sunkara and Tie Luo},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.03641}
}
@inproceedings{Woo2023ConvNeXtVC,
    title   = {ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
    author  = {Sanghyun Woo and Shoubhik Debnath and Ronghang Hu and Xinlei Chen and Zhuang Liu and In-So Kweon and Saining Xie},
    year    = {2023}
}

.\lucidrains\x-unet\setup.py

# import the setup and find_packages functions
from setuptools import setup, find_packages

# set up the package metadata
setup(
  name = 'x-unet',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.3.1',  # version number
  license='MIT',  # license
  description = 'X-Unet',  # description
  long_description_content_type = 'text/markdown',  # long description content type
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/x-unet',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'biomedical segmentation',
    'medical deep learning',
    'unets',
  ],
  install_requires=[  # dependencies
    'beartype',
    'einops>=0.4',
    'torch>=1.6',
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\x-unet\x_unet\x_unet.py

# import the necessary libraries
from functools import partial
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
# import functions and layer classes from the einops library
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# import the decorator and typing helpers from the beartype library
from beartype import beartype
from beartype.typing import Tuple, Union, Optional

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# check whether a number is a power of two
def is_power_two(n):
    return math.log2(n).is_integer()

# check whether one number is divisible by another
def divisible_by(num, denom):
    return (num % denom) == 0

# cast a value to a tuple, optionally asserting its length
def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output

# helper classes

# upsample function
def Upsample(dim, dim_out):
    return nn.ConvTranspose3d(dim, dim_out, (1, 4, 4), (1, 2, 2), (0, 1, 1))

# downsample function (space-to-depth rearrange followed by a 1x1 conv)
def Downsample(dim, dim_out):
    return nn.Sequential(
        Rearrange('b c f (h s1) (w s2) -> b (c s1 s2) f h w', s1 = 2, s2 = 2),
        nn.Conv3d(dim * 4, dim_out, 1)
    )
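A quick shape check (our sketch) of this downsample, which follows the 'No More Strided Convolutions or Pooling' paper cited further down: the rearrange folds each 2x2 spatial patch into the channel dimension, so no information is discarded before the 1x1 convolution mixes it.

import torch

down = Downsample(dim = 32, dim_out = 64)
x = torch.randn(1, 32, 4, 16, 16)         # (batch, channels, frames, height, width)
assert down(x).shape == (1, 64, 4, 8, 8)  # height and width halved, channels remapped by the 1x1 conv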

# normalization

# residual connection
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

# layer normalization (over the channel dimension of 5d feature maps)
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + eps).sqrt() * self.gamma

# weight standardized convolution
class WeightStandardizedConv3d(nn.Conv3d):
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight

        mean = reduce(weight, 'o ... -> o 1 1 1 1', 'mean')
        var = reduce(weight, 'o ... -> o 1 1 1 1', partial(torch.var, unbiased = False))
        weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

# resnet blocks

# basic conv block: conv, group norm, activation
class Block(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        weight_standardize = False,
        frame_kernel_size = 1
    ):
        super().__init__()
        kernel_conv_kwargs = partial(kernel_and_same_pad, frame_kernel_size)
        conv = nn.Conv3d if not weight_standardize else WeightStandardizedConv3d

        self.proj = conv(dim, dim_out, **kernel_conv_kwargs(3, 3))
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return self.act(x)

# resnet block
class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        frame_kernel_size = 1,
        nested_unet_depth = 0,
        nested_unet_dim = 32,
        weight_standardize = False
    ):
        super().__init__()
        self.block1 = Block(dim, dim_out, groups = groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)

        if nested_unet_depth > 0:
            self.block2 = NestedResidualUnet(dim_out, depth = nested_unet_depth, M = nested_unet_dim, frame_kernel_size = frame_kernel_size, weight_standardize = weight_standardize, add_residual = True)
        else:
            self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)

        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        return h + self.res_conv(x)

# ConvNeXT 2

# global response normalization
class GRN(nn.Module):
    """ global response normalization, proposed in updated convnext paper """
    # initialize with the feature dimension and an epsilon for numerical stability
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        # gamma (scale) and bias both start at zero, so the block is near-identity at init
        self.gamma = nn.Parameter(torch.zeros(dim, 1, 1, 1))
        self.bias = nn.Parameter(torch.zeros(dim, 1, 1, 1))

    def forward(self, x):
        # l2 norm of each channel over the frame and spatial dimensions
        spatial_l2_norm = x.norm(p = 2, dim = (2, 3, 4), keepdim = True)
        # divide by the mean norm over the last dimension (clamped by eps for stability)
        feat_norm = spatial_l2_norm / spatial_l2_norm.mean(dim = -1, keepdim = True).clamp(min = self.eps)
        # scale the normalized response, add the learned bias, and keep a residual path
        return x * feat_norm * self.gamma + self.bias + x
# ConvNeXt block (v2, with global response normalization)
class ConvNextBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        *,
        mult = 2,
        frame_kernel_size = 1,
        nested_unet_depth = 0,
        nested_unet_dim = 32
    ):
        super().__init__()
        kernel_conv_kwargs = partial(kernel_and_same_pad, frame_kernel_size)

        # depthwise convolution
        self.ds_conv = nn.Conv3d(dim, dim, **kernel_conv_kwargs(7, 7), groups = dim)

        inner_dim = dim_out * mult

        # the block's inner network: norm, conv, GELU, GRN, conv
        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Conv3d(dim, inner_dim, **kernel_conv_kwargs(3, 3), groups = dim_out),
            nn.GELU(),
            GRN(inner_dim),
            nn.Conv3d(inner_dim, dim_out, **kernel_conv_kwargs(3, 3), groups = dim_out)
        )

        # optional nested residual unet
        self.nested_unet = NestedResidualUnet(dim_out, depth = nested_unet_depth, M = nested_unet_dim, add_residual = True) if nested_unet_depth > 0 else nn.Identity()

        # residual 1x1 convolution
        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        h = self.ds_conv(x)
        h = self.net(h)
        h = self.nested_unet(h)
        return h + self.res_conv(x)

# feedforward network
def FeedForward(dim, mult = 4.):
    inner_dim = int(dim * mult)
    return Residual(nn.Sequential(
        LayerNorm(dim),
        nn.Conv3d(dim, inner_dim, 1, bias = False),
        nn.GELU(),
        LayerNorm(inner_dim),   # properly credit assign normformer
        nn.Conv3d(inner_dim, dim, 1, bias = False)
    ))

# attention
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 64
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        self.norm = LayerNorm(dim)

        self.to_qkv = nn.Conv3d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv3d(inner_dim, dim, 1, bias = False)

    def forward(self, x):
        f, h, w = x.shape[-3:]

        residual = x.clone()

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) ... -> b h (...) c', h = self.heads), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h (f x y) d -> b (h d) f x y', f = f, x = h, y = w)
        return self.to_out(out) + residual

# transformer block
class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        **kwargs
    ):
        super().__init__()
        self.attn = Attention(dim, **kwargs)
        self.ff = FeedForward(dim)

    def forward(self, x):
        x = self.attn(x)
        x = self.ff(x)
        return x

# feature map consolidator
class FeatureMapConsolidator(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_ins = tuple(),
        dim_outs = tuple(),
        resize_fmap_before = True,
        conv_block_fn = None
    ):
        super().__init__()
        assert len(dim_ins) == len(dim_outs)
        self.needs_consolidating = len(dim_ins) > 0

        block_fn = default(conv_block_fn, Block)

        # one conv block per incoming feature map
        self.fmap_convs = nn.ModuleList([block_fn(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        self.resize_fmap_before = resize_fmap_before

        self.final_dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    # resize feature maps to the target height and width
    def resize_fmaps(self, fmaps, height, width):
        return [F.interpolate(fmap, (fmap.shape[-3], height, width)) for fmap in fmaps]
    # forward pass taking the input x and optional auxiliary feature maps
    def forward(self, x, fmaps = None):
        # the target spatial size comes from the input
        target_height, target_width = x.shape[-2:]

        # default to an empty tuple when no feature maps are given
        fmaps = default(fmaps, tuple())

        # nothing to consolidate, return the input unchanged
        if not self.needs_consolidating:
            return x

        # optionally resize the feature maps before projecting them
        if self.resize_fmap_before:
            fmaps = self.resize_fmaps(fmaps, target_height, target_width)

        # project each feature map with its conv block
        outs = []
        for fmap, conv in zip(fmaps, self.fmap_convs):
            outs.append(conv(fmap))

        # otherwise resize the projected outputs afterwards
        if not self.resize_fmap_before:
            outs = self.resize_fmaps(outs, target_height, target_width)

        # concatenate x with all projected feature maps along the channel dimension
        return torch.cat((x, *outs), dim = 1)
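A shape sketch (ours) of the consolidator: each auxiliary feature map is projected to its target channel count and resized to the main feature map's resolution, then everything is concatenated along channels.

import torch

consolidator = FeatureMapConsolidator(
    dim = 64,
    dim_ins = (32, 64),
    dim_outs = (16, 16)
)

x = torch.randn(1, 64, 1, 16, 16)
fmaps = [torch.randn(1, 32, 1, 8, 8), torch.randn(1, 64, 1, 4, 4)]

out = consolidator(x, fmaps)
assert out.shape[1] == consolidator.final_dim_out == 96  # 64 + 16 + 16 channels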
# a helper that types a value as either `type` itself or a tuple of `type`
def MaybeTuple(type):
    return Union[type, Tuple[type, ...]]

# compute 'same' padding from a kernel size
def kernel_and_same_pad(*kernel_size):
    paddings = tuple(map(lambda k: k // 2, kernel_size))
    return dict(kernel_size = kernel_size, padding = paddings)
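For example (ours), the computed padding preserves spatial shape for odd kernel sizes:

assert kernel_and_same_pad(1, 3, 3) == dict(kernel_size = (1, 3, 3), padding = (0, 1, 1))
assert kernel_and_same_pad(7, 7) == dict(kernel_size = (7, 7), padding = (3, 3))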

# main XUnet class
class XUnet(nn.Module):

    # initialization
    @beartype
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        frame_kernel_size = 1,
        dim_mults: MaybeTuple(int) = (1, 2, 4, 8),
        num_blocks_per_stage: MaybeTuple(int) = (2, 2, 2, 2),
        num_self_attn_per_stage: MaybeTuple(int) = (0, 0, 0, 1),
        nested_unet_depths: MaybeTuple(int) = (0, 0, 0, 0),
        nested_unet_dim = 32,
        channels = 3,
        use_convnext = False,
        resnet_groups = 8,
        consolidate_upsample_fmaps = True,
        skip_scale = 2 ** -0.5,
        weight_standardize = False,
        attn_heads: MaybeTuple(int) = 8,
        attn_dim_head: MaybeTuple(int) = 32
    ):
        # ... (the __init__ body is elided in this excerpt)

    def forward(self, x):
        is_image = x.ndim == 4

        # validation

        assert not (is_image and not self.train_as_images), 'you specified a frame kernel size for the convolutions in this unet, but you are passing in images'
        assert not (not is_image and self.train_as_images), 'you specified no frame kernel size dimension, yet you are passing in a video. fold the frame dimension into the batch'

        # cast images to videos with a single frame

        if is_image:
            x = rearrange(x, 'b c h w -> b c 1 h w')

        # initial convolution

        x = self.init_conv(x)

        # residual

        r = x.clone()

        # downsampling and upsampling

        down_hiddens = []
        up_hiddens = []

        for init_block, blocks, attn_blocks, downsample in self.downs:
            x = init_block(x)

            for block in blocks:
                x = block(x)

            for attn_block in attn_blocks:
                x = attn_block(x)

            down_hiddens.append(x)
            x = downsample(x)

        x = self.mid(x)
        x = self.mid_attn(x) + x
        x = self.mid_after(x)

        up_hiddens.append(x)
        x = self.mid_upsample(x)


        for init_block, blocks, attn_blocks, upsample in self.ups:
            x = torch.cat((x, down_hiddens.pop() * self.skip_scale), dim=1)

            x = init_block(x)

            for block in blocks:
                x = block(x)

            for attn_block in attn_blocks:
                x = attn_block(x)

            up_hiddens.insert(0, x)
            x = upsample(x)

        # consolidate feature maps

        x = self.consolidator(x, up_hiddens)

        # final residual

        x = torch.cat((x, r), dim = 1)

        # final convolution

        out = self.final_conv(x)

        if is_image:
            out = rearrange(out, 'b c 1 h w -> b c h w')

        return out

# PixelShuffleUpsample class
class PixelShuffleUpsample(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        scale_factor = 2
    ):
        super().__init__()
        self.scale_squared = scale_factor ** 2
        dim_out = default(dim_out, dim)
        conv = nn.Conv3d(dim, dim_out * self.scale_squared, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            Rearrange('b (c r s) f h w -> b c f (h r) (w s)', r = scale_factor, s = scale_factor)
        )

        self.init_conv_(conv)

    # initialize the conv weights so that, after the pixel shuffle, the upsample starts as nearest-neighbor (avoids checkerboard artifacts)
    def init_conv_(self, conv):
        o, i, *rest_dims = conv.weight.shape
        conv_weight = torch.empty(o // self.scale_squared, i, *rest_dims)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.scale_squared)

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        x = self.net(x)
        return x
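The weight repetition above is an ICNR-style initialization: since the scale_squared copies of each output channel share identical weights, every 2x2 block of the pixel-shuffled output starts out constant, i.e. the module begins as nearest-neighbor upsampling. A quick verification (our sketch):

import torch
from einops import repeat

up = PixelShuffleUpsample(dim = 8, dim_out = 8, scale_factor = 2)
x = torch.randn(1, 8, 2, 4, 4)
out = up(x)  # (1, 8, 2, 8, 8)

# at init, each 2x2 output block is one repeated value, i.e. nearest-neighbor upsampling
nearest = repeat(out[..., ::2, ::2], 'b c f h w -> b c f (h r) (w s)', r = 2, s = 2)
assert torch.allclose(out, nearest)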

# NestedResidualUnet class
class NestedResidualUnet(nn.Module):
    # set up the depth, the downsampling / upsampling stacks, and the middle block
    def __init__(
        self,
        dim,
        *,
        depth,
        M = 32,
        frame_kernel_size = 1,
        add_residual = False,
        groups = 4,
        skip_scale = 2 ** -0.5,
        weight_standardize = False
    ):
        super().__init__()

        self.depth = depth
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        # choose the conv type depending on whether weights should be standardized
        conv = WeightStandardizedConv3d if weight_standardize else nn.Conv3d

        # build the downsampling blocks
        for ind in range(depth):
            is_first = ind == 0
            dim_in = dim if is_first else M

            down = nn.Sequential(
                conv(dim_in, M, (1, 4, 4), stride = (1, 2, 2), padding = (0, 1, 1)),
                nn.GroupNorm(groups, M),
                nn.SiLU()
            )

            self.downs.append(down)

            # build the matching upsampling block
            up = nn.Sequential(
                PixelShuffleUpsample(2 * M, dim_in),
                nn.GroupNorm(groups, dim_in),
                nn.SiLU()
            )

            self.ups.append(up)

        # middle block
        self.mid = nn.Sequential(
            conv(M, M, **kernel_and_same_pad(frame_kernel_size, 3, 3)),
            nn.GroupNorm(groups, M),
            nn.SiLU()
        )

        # scaling for the skip connections, and whether to add a final residual
        self.skip_scale = skip_scale
        self.add_residual = add_residual

    # forward pass
    def forward(self, x, residual = None):
        # whether the input is a video (5d tensor)
        is_video = x.ndim == 5

        # when adding a residual, default it to a copy of the input
        if self.add_residual:
            residual = default(residual, x.clone())

        # spatial dimensions of the input
        *_, h, w = x.shape

        # number of unet layers
        layers = len(self.ups)

        # height and width must survive `depth` halvings
        for dim_name, size in (('height', h), ('width', w)):
            assert divisible_by(size, 2 ** layers), f'{dim_name} dimension {size} must be divisible by {2 ** layers} ({layers} layers in nested unet)'
            assert (size % (2 ** self.depth)) == 0, f'the unet has too much depth for the image {dim_name} ({size}) being passed in'

        # hiddens

        hiddens = []

        # unet

        # downsampling path
        for down in self.downs:
            x = down(x)
            hiddens.append(x.clone().contiguous())

        # middle block
        x = self.mid(x)

        # upsampling path, consuming the stored hiddens as skip connections
        for up in reversed(self.ups):
            x = torch.cat((x, hiddens.pop() * self.skip_scale), dim = 1)
            x = up(x)

        # final residual connection
        if self.add_residual:
            x = x + residual
            x = F.silu(x)

        return x
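A minimal usage sketch (ours): the nested unet preserves both the channel count and the resolution, so it can be dropped in as a residual block, as ResnetBlock does above.

import torch

unet = NestedResidualUnet(64, depth = 2, M = 32)
x = torch.randn(1, 64, 4, 32, 32)  # height and width must be divisible by 2 ** depth
assert unet(x).shape == x.shape    # channels and resolution are preserved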

.\lucidrains\x-unet\x_unet\__init__.py

# import the XUnet and NestedResidualUnet classes from the x_unet module
from x_unet.x_unet import XUnet, NestedResidualUnet

Zorro - Pytorch

Implementation of Zorro, Masked Multimodal Transformer, in Pytorch. This is a Deepmind work that claims a special masking strategy within a transformer helps them achieve SOTA on a few multimodal benchmarks.

Appreciation

  • Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research

Install

$ pip install zorro-pytorch

Usage

import torch
from zorro_pytorch import Zorro, TokenTypes as T

model = Zorro(
    dim = 512,                        # model dimensions
    depth = 6,                        # depth
    dim_head = 64,                    # attention dimension heads
    heads = 8,                        # attention heads
    ff_mult = 4,                      # feedforward multiple
    num_fusion_tokens = 16,           # number of fusion tokens
    audio_patch_size = 16,            # audio patch size, can also be Tuple[int, int]
    video_patch_size = 16,            # video patch size, can also be Tuple[int, int]
    video_temporal_patch_size = 2,    # video temporal patch size
    video_channels = 3,               # video channels
    return_token_types = (
        T.AUDIO,
        T.AUDIO,
        T.FUSION,
        T.GLOBAL,
        T.VIDEO,
        T.VIDEO,
        T.VIDEO,
    ) # say you want to return 2 tokens for audio, 1 fusion token, 1 global token, and 3 for video - for whatever self-supervised learning, supervised learning, etc etc
)

video = torch.randn(2, 3, 8, 32, 32) # (batch, channels, time, height, width)
audio = torch.randn(2, 1024 * 10)    # (batch, time)

return_tokens = model(audio = audio, video = video) # (2, 7, 512) - all 7 tokens as indicated above are returned

# say you only want 1 audio and 1 video token, for contrastive learning

audio_token, video_token = model(audio = audio, video = video, return_token_indices = (0, 4)).unbind(dim = -2) # (2, 512), (2, 512) - index 0 is the first audio token, index 4 the first video token
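For completeness, a minimal symmetric InfoNCE loss over those two pooled tokens (our illustration, not part of this repository; the temperature is arbitrary):

import torch
import torch.nn.functional as F

audio_latents = F.normalize(audio_token, dim = -1)
video_latents = F.normalize(video_token, dim = -1)

sim = audio_latents @ video_latents.t() / 0.07  # (batch, batch) cosine similarities with temperature

labels = torch.arange(sim.shape[0])
contrastive_loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2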

Citations

@inproceedings{Recasens2023ZorroTM,
  title  = {Zorro: the masked multimodal transformer},
  author = {Adri{\`a} Recasens and Jason Lin and Jo{\~a}o Carreira and Drew Jaegle and Luyu Wang and Jean-Baptiste Alayrac and Pauline Luc and Antoine Miech and Lucas Smaira and Ross Hemsley and Andrew Zisserman},
  year   = {2023}
}

.\lucidrains\zorro-pytorch\setup.py

# import the setup and find_packages functions
from setuptools import setup, find_packages

# set up the package metadata
setup(
  # package name
  name = 'zorro-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.1.1',
  # license type
  license='MIT',
  # description
  description = 'Zorro - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # long description content type
  long_description_content_type = 'text/markdown',
  # project url
  url = 'https://github.com/lucidrains/zorro-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'multimodal fusion'
  ],
  # dependencies
  install_requires=[
    'beartype',
    'einops>=0.4',
    'torch>=1.6',
    'torchaudio'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\zorro-pytorch\zorro_pytorch\zorro_pytorch.py

# import the required modules and classes
from enum import Enum
import functools
from functools import wraps

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from beartype import beartype
from beartype.typing import Tuple, Optional, Union

from torchaudio.transforms import Spectrogram

# TokenTypes enum: audio, video, fusion and global token types
class TokenTypes(Enum):
    AUDIO = 0
    VIDEO = 1
    FUSION = 2
    GLOBAL = 3

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the first argument that exists, or None if none do
def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

# round n down to the nearest multiple of divisor
def round_down_nearest_multiple(n, divisor):
    return n // divisor * divisor

# cast the input to a pair, returning (t, t) if it is not already a tuple
def pair(t):
    return (t, t) if not isinstance(t, tuple) else t

# cumulative product of an iterable
def cum_mul(it):
    return functools.reduce(lambda x, y: x * y, it, 1)

# check whether numer is divisible by denom
def divisible_by(numer, denom):
    return (numer % denom) == 0

# decorators

# decorator that ensures a function only runs once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print, decorated with once, so it only ever prints a single time
print_once = once(print)
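For example (ours), repeated warnings collapse into a single print:

print_once('spectrogram was cropped')  # prints on the first call
print_once('spectrogram was cropped')  # every later call is silently ignored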

# layernorm without a learned bias (beta is a fixed buffer of zeros)
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# GEGLU activation
class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

# feedforward network with GEGLU gating
def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult * 2 / 3)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        nn.Linear(inner_dim, dim, bias = False)
    )
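The 2 / 3 factor keeps the parameter count on par with a conventional non-gated feedforward, since the GEGLU in-projection is twice as wide. A quick check (ours), using a dim divisible by 3 so the arithmetic is exact:

dim = 768
inner_dim = int(dim * 4 * 2 / 3)                          # 2048, matching the FeedForward above with mult = 4
gated_params   = dim * (inner_dim * 2) + inner_dim * dim  # doubled in-projection (for the gate) + out-projection
vanilla_params = dim * (dim * 4) + (dim * 4) * dim        # standard dim -> 4 * dim -> dim mlp
assert gated_params == vanilla_params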

# attention
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context = None,
        attn_mask = None
    ):
        x = self.norm(x)
        kv_x = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(kv_x).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(attn_mask):
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# main Zorro class
class Zorro(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        num_fusion_tokens = 16,
        audio_patch_size: Union[int, Tuple[int, int]] = 16,
        video_patch_size: Union[int, Tuple[int, int]] = 16,
        video_temporal_patch_size = 2,
        video_channels = 3,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
        spec_hop_length = None,
        spec_pad = 0,
        spec_center = True,
        spec_pad_mode = 'reflect',
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80,
        return_token_types: Tuple[TokenTypes] = (TokenTypes.AUDIO, TokenTypes.VIDEO, TokenTypes.FUSION)
        ):
        super().__init__()
        # maximum number of returnable tokens
        self.max_return_tokens = len(return_token_types)

        # store the return token types, and register a tensor version as a non-persistent buffer
        self.return_token_types = return_token_types
        return_token_types_tensor = torch.tensor(list(map(lambda t: t.value, return_token_types)))
        self.register_buffer('return_token_types_tensor', return_token_types_tensor, persistent=False)

        # learned return tokens, pooled from the sequence by cross attention
        self.return_tokens = nn.Parameter(torch.randn(self.max_return_tokens, dim))
        self.attn_pool = Attention(dim=dim, dim_head=dim_head, heads=heads)

        # audio input

        # audio patch size
        self.audio_patch_size = audio_patch_height, audio_patch_width = pair(audio_patch_size)

        # spectrogram transform for the raw audio
        self.spec = Spectrogram(
            n_fft=spec_n_fft,
            power=spec_power,
            win_length=spec_win_length,
            hop_length=spec_hop_length,
            pad=spec_pad,
            center=spec_center,
            pad_mode=spec_pad_mode
        )

        # patchify the spectrogram into audio tokens
        audio_input_dim = cum_mul(self.audio_patch_size)
        self.audio_to_tokens = nn.Sequential(
            Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1=audio_patch_height, p2=audio_patch_width),
            nn.LayerNorm(audio_input_dim),
            nn.Linear(audio_input_dim, dim),
            nn.LayerNorm(dim)
        )

        # video input

        # video patch size (time, height, width)
        self.video_patch_size = (video_temporal_patch_size, *pair(video_patch_size))

        # patchify the video into video tokens
        video_input_dim = cum_mul(self.video_patch_size) * video_channels
        video_patch_time, video_patch_height, video_patch_width = self.video_patch_size

        self.video_to_tokens = nn.Sequential(
            Rearrange('b c (t p1) (h p2) (w p3) -> b t h w (c p1 p2 p3)', p1=video_patch_time, p2=video_patch_height, p3=video_patch_width),
            nn.LayerNorm(video_input_dim),
            nn.Linear(video_input_dim, dim),
            nn.LayerNorm(dim)
        )

        # fusion tokens

        self.fusion_tokens = nn.Parameter(torch.randn(num_fusion_tokens, dim))

        # transformer

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim, mult=ff_mult)
            ]))

        self.norm = LayerNorm(dim)

    def forward(
        self,
        *,
        audio,
        video,
        return_token_indices: Optional[Tuple[int]] = None
        ):
        # batch size and device, taken from the audio tensor
        batch, device = audio.shape[0], audio.device

        # validate that the video can be patchified
        assert all([divisible_by(numer, denom) for denom, numer in zip(self.video_patch_size, tuple(video.shape[-3:]))]), f'video shape {video.shape[-3:]} needs to be divisible by {self.video_patch_size}'

        # if the audio yields a 2d spectrogram that is not a multiple of the patch size, auto-crop it
        audio = self.spec(audio)

        height, width = audio.shape[-2:]
        patch_height, patch_width = self.audio_patch_size

        rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))

        if (height, width) != (rounded_height, rounded_width): # just print a warning until it is fixed
            print_once(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

        audio = audio[..., :rounded_height, :rounded_width]

        # to tokens
        audio_tokens = self.audio_to_tokens(audio)
        video_tokens = self.video_to_tokens(video)
        fusion_tokens = repeat(self.fusion_tokens, 'n d -> b n d', b = batch)

        # construct the full token sequence
        audio_tokens, fusion_tokens, video_tokens = map(lambda t: rearrange(t, 'b ... d -> b (...) d'), (audio_tokens, fusion_tokens, video_tokens))
        tokens, ps = pack((
            audio_tokens,
            fusion_tokens,
            video_tokens
        ), 'b * d')

        # construct the mask (i.e. the zorro masking strategy)
        token_types = torch.tensor(list((
            *((TokenTypes.AUDIO.value,) * audio_tokens.shape[-2]),
            *((TokenTypes.FUSION.value,) * fusion_tokens.shape[-2]),
            *((TokenTypes.VIDEO.value,) * video_tokens.shape[-2]),
        )), device = device, dtype = torch.long)

        token_types_attend_from = rearrange(token_types, 'i -> i 1')
        token_types_attend_to = rearrange(token_types, 'j -> 1 j')

        # the logic is that each modality, including fusion, attends to itself
        zorro_mask = token_types_attend_from == token_types_attend_to

        # fusion tokens can attend to everything
        zorro_mask = zorro_mask | (token_types_attend_from == TokenTypes.FUSION.value)

        # attention and feedforward
        for attn, ff in self.layers:
            tokens = attn(tokens, attn_mask = zorro_mask) + tokens
            tokens = ff(tokens) + tokens

        tokens = self.norm(tokens)

        # final attention pooling - each modality pool token can only attend to its own tokens
        return_tokens = self.return_tokens
        return_token_types_tensor = self.return_token_types_tensor

        if exists(return_token_indices):
            assert len(set(return_token_indices)) == len(return_token_indices), 'all indices must be unique'
            assert all([indice < self.max_return_tokens for indice in return_token_indices]), 'indices must range from 0 to max_num_return_tokens - 1'

            return_token_indices = torch.tensor(return_token_indices, dtype = torch.long, device = device)

            return_token_types_tensor = return_token_types_tensor[return_token_indices]
            return_tokens = return_tokens[return_token_indices]

        return_tokens = repeat(return_tokens, 'n d -> b n d', b = batch)
        pool_mask = rearrange(return_token_types_tensor, 'i -> i 1') == token_types_attend_to
        # global queries can attend to all tokens (parenthesized so the comparison binds before the bitwise or)
        pool_mask = pool_mask | (rearrange(return_token_types_tensor, 'i -> i 1') == torch.ones_like(token_types_attend_to, dtype=torch.long) * TokenTypes.GLOBAL.value)

        pooled_tokens = self.attn_pool(return_tokens, context = tokens, attn_mask = pool_mask) + return_tokens

        return pooled_tokens
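To make the zorro masking concrete, here is a toy rendering (ours) of zorro_mask for a sequence of one audio, one fusion and one video token:

import torch
from zorro_pytorch import TokenTypes

token_types = torch.tensor([TokenTypes.AUDIO.value, TokenTypes.FUSION.value, TokenTypes.VIDEO.value])
attend_from = token_types[:, None]
attend_to = token_types[None, :]

zorro_mask = (attend_from == attend_to) | (attend_from == TokenTypes.FUSION.value)
print(zorro_mask.int())
# tensor([[1, 0, 0],   audio attends only to audio
#         [1, 1, 1],   fusion attends to everything
#         [0, 0, 1]])  video attends only to video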

.\lucidrains\zorro-pytorch\zorro_pytorch\__init__.py

# import the Zorro class and the TokenTypes enum from the zorro_pytorch.zorro_pytorch module
from zorro_pytorch.zorro_pytorch import Zorro, TokenTypes