
Lucidrains Series Source Code Walkthrough (Part 48)

.\lucidrains\vit-pytorch\vit_pytorch\mobile_vit.py

import torch
import torch.nn as nn

from einops import rearrange
from einops.layers.torch import Reduce

# helpers

# 1x1 convolution -> batch norm -> SiLU
def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

# nxn convolution -> batch norm -> SiLU
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

# classes

# feedforward network
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# attention
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)

# transformer
class Transformer(nn.Module):
    """Transformer block described in ViT.
    Paper: https://arxiv.org/abs/2010.11929
    Based on: https://github.com/lucidrains/vit-pytorch
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads, dim_head, dropout),
                FeedForward(dim, mlp_dim, dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# MV2 (inverted residual) block
class MV2Block(nn.Module):
    """MV2 block described in MobileNetV2.
    Paper: https://arxiv.org/pdf/1801.04381
    Based on: https://github.com/tonylins/pytorch-mobilenet-v2
    """

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
    def forward(self, x):
        out = self.conv(x)

        # residual connection when stride is 1 and input / output channels match
        if self.use_res_connect:
            out = out + x
        return out

class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        # local representation
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        # global representation via transformer
        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        # project back to the local dimension and fuse with the input
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()

        # local representations
        x = self.conv1(x)
        x = self.conv2(x)

        # global representations
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

        # fuse local and global representations
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x

class MobileViT(nn.Module):
    """MobileViT.
    Paper: https://arxiv.org/abs/2110.02178
    Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
    """

    def __init__(
        self,
        image_size,
        dims,
        channels,
        num_classes,
        expansion=4,
        kernel_size=3,
        patch_size=(2, 2),
        depths=(2, 4, 3)
    ):
        super().__init__()
        assert len(dims) == 3, 'dims must be a tuple of 3'
        assert len(depths) == 3, 'depths must be a tuple of 3'

        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        init_dim, *_, last_dim = channels

        # initial strided convolution on the input image
        self.conv1 = conv_nxn_bn(3, init_dim, stride=2)

        # stem of MV2 blocks
        self.stem = nn.ModuleList([])
        self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))

        # trunk: alternating MV2 blocks and MobileViT blocks
        self.trunk = nn.ModuleList([])
        self.trunk.append(nn.ModuleList([
            MV2Block(channels[3], channels[4], 2, expansion),
            MobileViTBlock(dims[0], depths[0], channels[5],
                           kernel_size, patch_size, int(dims[0] * 2))
        ]))

        self.trunk.append(nn.ModuleList([
            MV2Block(channels[5], channels[6], 2, expansion),
            MobileViTBlock(dims[1], depths[1], channels[7],
                           kernel_size, patch_size, int(dims[1] * 4))
        ]))

        self.trunk.append(nn.ModuleList([
            MV2Block(channels[7], channels[8], 2, expansion),
            MobileViTBlock(dims[2], depths[2], channels[9],
                           kernel_size, patch_size, int(dims[2] * 4))
        ]))

        # classification head: 1x1 conv, global average pool, linear projection
        self.to_logits = nn.Sequential(
            conv_1x1_bn(channels[-2], last_dim),
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(channels[-1], num_classes, bias=False)
        )

    def forward(self, x):
        x = self.conv1(x)

        # stem
        for conv in self.stem:
            x = conv(x)

        # trunk
        for conv, attn in self.trunk:
            x = conv(x)
            x = attn(x)

        return self.to_logits(x)
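
A minimal usage sketch (the hyperparameters below are illustrative, not an official MobileViT config): channels must contain ten entries and dims three, matching how __init__ indexes them, and every stage's feature map must remain divisible by the 2x2 patch size.

import torch
from vit_pytorch.mobile_vit import MobileViT

mbvit = MobileViT(
    image_size = (256, 256),
    dims = (96, 120, 144),
    channels = (16, 32, 48, 48, 64, 64, 80, 80, 96, 384),
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
preds = mbvit(img)  # (1, 1000)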

.\lucidrains\vit-pytorch\vit_pytorch\mp3.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

# check whether a value is not None
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# cast a value to a pair (tuple of two)
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# positional embedding

# 2d sinusoidal (sincos) positional embedding
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    # unpack the patch grid shape, device and dtype
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    # build the coordinate grid
    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    # frequencies
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    # outer product of coordinates and frequencies, then concatenate sin / cos
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)
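
A quick shape check for the helper above (the patch grid here is arbitrary), assuming posemb_sincos_2d is in scope, e.g. imported from vit_pytorch.mp3: for an input of shape (batch, h, w, dim) with dim divisible by 4, it returns an (h * w, dim) embedding, which ViT.forward below adds to the flattened patch sequence.

import torch
from vit_pytorch.mp3 import posemb_sincos_2d

patches = torch.randn(1, 8, 8, 256)  # (batch, h, w, dim), dim divisible by 4
pe = posemb_sincos_2d(patches)
print(pe.shape)                      # torch.Size([64, 256])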

# feedforward

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# (cross) attention

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        b, n, _, h = *x.shape, self.heads

        x = self.norm(x)

        context = self.norm(context) if exists(context) else x

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, num_classes, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches and the dimension of each flattened patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        # stored for use by MP3 below
        self.dim = dim
        self.num_patches = num_patches

        # image -> patch embeddings
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # classification head
        self.to_latent = nn.Identity()
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        *_, h, w, dtype = *img.shape, img.dtype

        # patch embeddings plus 2d sincos positional embedding
        x = self.to_patch_embedding(img)
        pe = posemb_sincos_2d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        # transformer, then mean pool over the patch dimension
        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

# Masked Position Prediction (MP3) pre-training
class MP3(nn.Module):
    def __init__(self, vit: ViT, masking_ratio):
        super().__init__()
        self.vit = vit

        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # head that predicts the position (patch index) of each token
        dim = vit.dim
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, vit.num_patches)
        )

    def forward(self, img):
        device = img.device

        # image -> patch tokens
        tokens = self.vit.to_patch_embedding(img)
        tokens = rearrange(tokens, 'b ... d -> b (...) d')

        batch, num_patches, *_ = tokens.shape

        # masking: pick a random subset of patches to hide from the context
        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = torch.rand(batch, num_patches, device=device).argsort(dim=-1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        batch_range = torch.arange(batch, device=device)[:, None]
        tokens_unmasked = tokens[batch_range, unmasked_indices]

        # all tokens attend only to the unmasked tokens (used as cross attention context)
        attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
        logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')

        # each token must predict its own position; cross entropy over patch indices
        labels = repeat(torch.arange(num_patches, device=device), 'n -> (b n)', b=batch)
        loss = F.cross_entropy(logits, labels)

        return loss
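
A pre-training sketch using the ViT and MP3 defined in this file (hyperparameters are illustrative): the objective asks every patch token to predict its own position while attending only to the unmasked subset.

import torch
from vit_pytorch.mp3 import ViT, MP3

v = ViT(
    num_classes = 1000,
    image_size = 256,
    patch_size = 8,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1
)

mp3 = MP3(vit = v, masking_ratio = 0.75)

images = torch.randn(8, 3, 256, 256)
loss = mp3(images)
loss.backward()

# after pre-training, the wrapped ViT can be fine-tuned or used directly
preds = v(images)  # (8, 1000)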

.\lucidrains\vit-pytorch\vit_pytorch\mpp.py

import math

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat, reduce

# helpers

# check whether a value is not None
def exists(val):
    return val is not None

# boolean mask where each position is True with the given probability
def prob_mask_like(t, prob):
    batch, seq_length, _ = t.shape
    return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob

# boolean mask selecting ceil(prob * seq_len) random positions per sequence
def get_mask_subset_with_prob(patched_input, prob):
    batch, seq_len, _, device = *patched_input.shape, patched_input.device
    max_masked = math.ceil(prob * seq_len)

    rand = torch.rand((batch, seq_len), device=device)
    _, sampled_indices = rand.topk(max_masked, dim=-1)

    new_mask = torch.zeros((batch, seq_len), device=device)
    new_mask.scatter_(1, sampled_indices, 1)
    return new_mask.bool()
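
A small illustration of the helper above (numbers are arbitrary): for a batch of 2 sequences of 16 patch vectors and prob = 0.15, every row of the returned boolean mask has exactly ceil(0.15 * 16) = 3 positions set to True.

import torch
from vit_pytorch.mpp import get_mask_subset_with_prob

patches = torch.randn(2, 16, 48)
mask = get_mask_subset_with_prob(patches, 0.15)
print(mask.shape)          # torch.Size([2, 16])
print(mask.sum(dim = -1))  # tensor([3, 3])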


# MPP loss

class MPPLoss(nn.Module):
    def __init__(
        self,
        patch_size,
        channels,
        output_channel_bits,
        max_pixel_val,
        mean,
        std
    ):
        super().__init__()
        self.patch_size = patch_size
        self.channels = channels
        self.output_channel_bits = output_channel_bits
        self.max_pixel_val = max_pixel_val

        self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
        self.std = torch.tensor(std).view(-1, 1, 1) if std else None

    def forward(self, predicted_patches, target, mask):
        p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
        bin_size = mpv / (2 ** bits)

        # un-normalize the target
        if exists(self.mean) and exists(self.std):
            target = target * self.std + self.mean

        # average the target over each patch
        target = target.clamp(max=mpv)  # clamp for safety
        avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1=p, p2=p).contiguous()

        channel_bins = torch.arange(bin_size, mpv, bin_size, device=device)
        discretized_target = torch.bucketize(avg_target, channel_bins)

        bin_mask = (2 ** bits) ** torch.arange(0, c, device=device).long()
        bin_mask = rearrange(bin_mask, 'c -> () () c')

        target_label = torch.sum(bin_mask * discretized_target, dim=-1)

        loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
        return loss


# main class

class MPP(nn.Module):
    def __init__(
        self,
        transformer,
        patch_size,
        dim,
        output_channel_bits=3,
        channels=3,
        max_pixel_val=1.0,
        mask_prob=0.15,
        replace_prob=0.5,
        random_patch_prob=0.5,
        mean=None,
        std=None
    ):
        super().__init__()
        self.transformer = transformer
        self.loss = MPPLoss(patch_size, channels, output_channel_bits,
                            max_pixel_val, mean, std)

        # reuse the ViT patch embedding (minus the initial Rearrange)
        self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])

        # project transformer outputs to the discretized pixel bins
        self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))

        # vit related dimensions
        self.patch_size = patch_size

        # mpp related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob
        self.random_patch_prob = random_patch_prob

        # learned [mask] token in patch space
        self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
    def forward(self, input, **kwargs):
        transformer = self.transformer
        # clone the original image for the loss computation
        img = input.clone().detach()

        # reshape the raw image into patches
        p = self.patch_size
        input = rearrange(input,
                          'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                          p1=p,
                          p2=p)

        # choose which patches participate in the mpp objective
        mask = get_mask_subset_with_prob(input, self.mask_prob)

        # mask chosen patches with probability replace_prob
        # (with probability 1 - replace_prob the patch is left unchanged)
        masked_input = input.clone().detach()

        # if random token probability > 0 for mpp
        if self.random_patch_prob > 0:
        if self.random_patch_prob > 0:
            random_patch_sampling_prob = self.random_patch_prob / (
                1 - self.replace_prob)
            random_patch_prob = prob_mask_like(input,
                                               random_patch_sampling_prob).to(mask.device)

            bool_random_patch_prob = mask * (random_patch_prob == True)
            random_patches = torch.randint(0,
                                           input.shape[1],
                                           (input.shape[0], input.shape[1]),
                                           device=input.device)
            randomized_input = masked_input[
                torch.arange(masked_input.shape[0]).unsqueeze(-1),
                random_patches]
            masked_input[bool_random_patch_prob] = randomized_input[
                bool_random_patch_prob]

        # [mask] the input
        replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
        bool_mask_replace = (mask * replace_prob) == True
        masked_input[bool_mask_replace] = self.mask_token

        # linear embedding of the patches
        masked_input = self.patch_to_emb(masked_input)

        # add cls token to the input sequence
        b, n, _ = masked_input.shape
        cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
        masked_input = torch.cat((cls_tokens, masked_input), dim=1)

        # add positional embeddings to the input
        masked_input += transformer.pos_embedding[:, :(n + 1)]
        masked_input = transformer.dropout(masked_input)

        # get generator output and calculate the mpp loss
        masked_input = transformer.transformer(masked_input, **kwargs)
        cls_logits = self.to_bits(masked_input)
        logits = cls_logits[:, 1:, :]

        mpp_loss = self.loss(logits, img, mask)

        return mpp_loss
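
A pre-training sketch, assuming the base ViT from vit_pytorch (MPP relies on it exposing to_patch_embedding, cls_token, pos_embedding, dropout and an inner transformer, which the standard vit.py model does); the probability values are illustrative.

import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

mpp_trainer = MPP(
    transformer = model,
    patch_size = 32,
    dim = 1024,
    mask_prob = 0.15,          # probability that a patch enters the mpp objective
    random_patch_prob = 0.30,  # probability of swapping a chosen patch for another random patch
    replace_prob = 0.50        # probability of replacing a chosen patch with the [mask] token
)

images = torch.randn(8, 3, 256, 256)
loss = mpp_trainer(images)
loss.backward()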

.\lucidrains\vit-pytorch\vit_pytorch\na_vit.py

from functools import partial
from typing import List, Union

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

# check whether a value is not None
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# return a function that always returns the given value
def always(val):
    return lambda *args: val

# cast a value to a pair (tuple of two)
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# check divisibility
def divisible_by(numer, denom):
    return (numer % denom) == 0

# auto grouping images

# group images so that the total token count of each group stays within max_seq_len
def group_images_by_max_seq_len(
    images: List[Tensor],
    patch_size: int,
    calc_token_dropout = None,
    max_seq_len = 2048

) -> List[List[Tensor]]:

    calc_token_dropout = default(calc_token_dropout, always(0.))

    groups = []
    group = []
    seq_len = 0

    if isinstance(calc_token_dropout, (float, int)):
        calc_token_dropout = always(calc_token_dropout)

    for image in images:
        assert isinstance(image, Tensor)

        image_dims = image.shape[-2:]
        ph, pw = map(lambda t: t // patch_size, image_dims)

        image_seq_len = (ph * pw)
        image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))

        assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'

        if (seq_len + image_seq_len) > max_seq_len:
            groups.append(group)
            group = []
            seq_len = 0

        group.append(image)
        seq_len += image_seq_len

    if len(group) > 0:
        groups.append(group)

    return groups
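
An illustrative sketch of the grouping helper (image sizes and parameters are arbitrary): each image contributes (h // patch_size) * (w // patch_size) tokens, reduced by the token dropout fraction, and a new group is started whenever adding the next image would exceed max_seq_len. Each inner list can then be fed to NaViT as one packed sequence.

import torch
from vit_pytorch.na_vit import group_images_by_max_seq_len

images = [
    torch.randn(3, 256, 256),
    torch.randn(3, 128, 128),
    torch.randn(3, 128, 256),
    torch.randn(3, 256, 128)
]

groups = group_images_by_max_seq_len(
    images,
    patch_size = 32,
    calc_token_dropout = 0.5,  # constant token dropout of 50%
    max_seq_len = 64
)

print([len(g) for g in groups])  # [3, 1] for the sizes above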

# normalization
# they use layernorm without bias, something that pytorch does not offer

# LayerNorm without a learned bias (beta is a fixed zero buffer)
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper

# per-head RMSNorm, applied to queries and keys
class RMSNorm(nn.Module):
    def __init__(self, heads, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma

# feedforward

# feedforward
def FeedForward(dim, hidden_dim, dropout = 0.):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout)
    )

# attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.norm = LayerNorm(dim)

        self.q_norm = RMSNorm(heads, dim_head)
        self.k_norm = RMSNorm(heads, dim_head)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_mask = None
        ):
        # pre-norm, with keys / values coming from the context if given
        x = self.norm(x)
        kv_input = default(context, x)

        # project to queries, keys, values and split heads
        qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # query-key rmsnorm
        q = self.q_norm(q)
        k = self.k_norm(k)

        dots = torch.matmul(q, k.transpose(-1, -2))

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

        # attention mask (keeps packed images from attending to each other)
        if exists(attn_mask):
            dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        # aggregate values and merge heads
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])

        # depth pairs of (attention, feedforward) blocks
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

        # final norm
        self.norm = LayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        attn_mask = None
    ):
        # attention and feedforward with residual connections
        for attn, ff in self.layers:
            x = attn(x, mask = mask, attn_mask = attn_mask) + x
            x = ff(x) + x

        return self.norm(x)

class NaViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
        super().__init__()
        image_height, image_width = pair(image_size)

        # what percent of tokens to dropout
        # if int or float given, then assume constant dropout prob
        # otherwise accept a callback that in turn calculates dropout prob from height and width

        self.calc_token_dropout = None

        if callable(token_dropout_prob):
            self.calc_token_dropout = token_dropout_prob

        elif isinstance(token_dropout_prob, (float, int)):
            assert 0. < token_dropout_prob < 1.
            token_dropout_prob = float(token_dropout_prob)
            self.calc_token_dropout = lambda height, width: token_dropout_prob

        # calculate patching related stuff

        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

        patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
        patch_dim = channels * (patch_size ** 2)

        self.channels = channels
        self.patch_size = patch_size

        self.to_patch_embedding = nn.Sequential(
            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim),
        )

        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # final attention pooling queries

        self.attn_pool_queries = nn.Parameter(torch.randn(dim))
        self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

        # output to logits

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_classes, bias = False)
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
        group_images = False,
        group_max_seq_len = 2048

.\lucidrains\vit-pytorch\vit_pytorch\nest.py

from functools import partial
import torch
from torch import nn, einsum

from einops import rearrange
from einops.layers.torch import Rearrange, Reduce

# helper: cast a value to a tuple of the given depth
def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else ((val,) * depth)

# LayerNorm over the channel dimension of a (b, c, h, w) feature map
class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

# feedforward, implemented with 1x1 convolutions
class FeedForward(nn.Module):
    def __init__(self, dim, mlp_mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Conv2d(dim, dim * mlp_mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mlp_mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# attention over the spatial positions of a feature map
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dropout = 0.):
        super().__init__()
        dim_head = dim // heads
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, c, h, w, heads = *x.shape, self.heads

        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

# Aggregate: 3x3 conv + norm + max pool, used to downsample between hierarchies
def Aggregate(dim, dim_out):
    return nn.Sequential(
        nn.Conv2d(dim, dim_out, 3, padding = 1),
        LayerNorm(dim_out),
        nn.MaxPool2d(3, stride = 2, padding = 1)
    )

# transformer with a learned positional embedding
class Transformer(nn.Module):
    def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.pos_emb = nn.Parameter(torch.randn(seq_len))

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dropout = dropout),
                FeedForward(dim, mlp_mult, dropout = dropout)
            ]))
    def forward(self, x):
        *_, h, w = x.shape

        pos_emb = self.pos_emb[:(h * w)]
        pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
        x = x + pos_emb

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# NesT
class NesT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        heads,
        num_hierarchies,
        block_repeats,
        mlp_mult = 4,
        channels = 3,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        fmap_size = image_size // patch_size
        blocks = 2 ** (num_hierarchies - 1)

        # the sequence length per block is held constant across hierarchy levels
        seq_len = (fmap_size // blocks) ** 2
        hierarchies = list(reversed(range(num_hierarchies)))
        mults = [2 ** i for i in reversed(hierarchies)]

        # heads and dimensions double at each hierarchy level
        layer_heads = list(map(lambda t: t * heads, mults))
        layer_dims = list(map(lambda t: t * dim, mults))
        last_dim = layer_dims[-1]

        layer_dims = [*layer_dims, layer_dims[-1]]
        dim_pairs = zip(layer_dims[:-1], layer_dims[1:])

        # image -> patch embedding as a (b, c, h, w) feature map
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
            LayerNorm(patch_dim),
            nn.Conv2d(patch_dim, layer_dims[0], 1),
            LayerNorm(layer_dims[0])
        )

        block_repeats = cast_tuple(block_repeats, num_hierarchies)

        self.layers = nn.ModuleList([])

        # one (Transformer, Aggregate) pair per hierarchy level; the last level has no aggregation
        for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
            is_last = level == 0
            depth = block_repeat

            self.layers.append(nn.ModuleList([
                Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
                Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
            ]))

        # classification head
        self.mlp_head = nn.Sequential(
            LayerNorm(last_dim),
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(last_dim, num_classes)
        )

    def forward(self, img):
        # image -> patch embedding
        x = self.to_patch_embedding(img)
        b, c, h, w = x.shape

        num_hierarchies = len(self.layers)

        # at each level, split the feature map into non-overlapping blocks,
        # run the transformer within each block, then stitch back and aggregate
        for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
            block_size = 2 ** level
            x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
            x = transformer(x)
            x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
            x = aggregate(x)

        return self.mlp_head(x)
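
A usage sketch (hyperparameters are illustrative): with image_size = 224 and patch_size = 4 the patch grid is 56x56, and num_hierarchies = 3 processes it as 4x4, then 2x2, then a single block, with block_repeats giving the transformer depth at each level starting from the finest.

import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,        # number of hierarchy levels
    block_repeats = (2, 2, 8),  # transformer depth at each level, finest blocks first
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)
pred = nest(img)  # (1, 1000)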

.\lucidrains\vit-pytorch\vit_pytorch\parallel_vit.py

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

# cast a value to a pair (tuple of two)
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

# sum the outputs of several parallel branches applied to the same input
class Parallel(nn.Module):
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        return sum([fn(x) for fn in self.fns])

# feedforward
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer with parallel attention and feedforward branches in each layer
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])

        attn_block = lambda: Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
        ff_block = lambda: FeedForward(dim, mlp_dim, dropout = dropout)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Parallel(*[attn_block() for _ in range(num_parallel_branches)]),
                Parallel(*[ff_block() for _ in range(num_parallel_branches)]),
            ]))

    def forward(self, x):
        for attns, ffs in self.layers:
            x = attns(x) + x
            x = ffs(x) + x
        return x

# ViT
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches and the dimension of each flattened patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # image -> patch embeddings
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        # positional embedding and cls token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        # classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # prepend the cls token and add positional embeddings
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # mean pool over tokens or take the cls token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
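
A usage sketch (hyperparameters are illustrative): num_parallel_branches controls how many attention and feedforward branches are summed at every layer.

import torch
from vit_pytorch.parallel_vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 12,
    heads = 8,
    mlp_dim = 2048,
    num_parallel_branches = 2,  # parallel attention / feedforward branches per layer
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)
preds = v(img)  # (4, 1000)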

.\lucidrains\vit-pytorch\vit_pytorch\pit.py

from math import sqrt

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

# cast a value to a tuple, repeating it num times if needed
def cast_tuple(val, num):
    return val if isinstance(val, tuple) else (val,) * num

# spatial output size of a convolution
def conv_output_size(image_size, kernel_size, stride, padding = 0):
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)

# classes

# feedforward
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads

        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# depthwise convolution, used for downsampling in the pooling layer
class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)

# pooling: downsample the spatial tokens and project the cls token to the doubled dimension
class Pool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1)
        self.cls_ff = nn.Linear(dim, dim * 2)

    def forward(self, x):
        cls_token, tokens = x[:, :1], x[:, 1:]

        cls_token = self.cls_ff(cls_token)

        tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = int(sqrt(tokens.shape[1])))
        tokens = self.downsample(tokens)
        tokens = rearrange(tokens, 'b c h w -> b (h w) c')

        return torch.cat((cls_token, tokens), dim = 1)

# main class
class PiT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        channels = 3
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
        heads = cast_tuple(heads, len(depth))

        patch_dim = channels * patch_size ** 2

        # overlapping patch embedding (unfold with stride patch_size // 2)
        self.to_patch_embedding = nn.Sequential(
            nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
            Rearrange('b c n -> b n c'),
            nn.Linear(patch_dim, dim)
        )

        output_size = conv_output_size(image_size, patch_size, patch_size // 2)
        num_patches = output_size ** 2

        # positional embedding and cls token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        layers = []

        # one transformer stage per depth entry, with a Pool (which doubles dim) between stages
        for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
            not_last = ind < (len(depth) - 1)

            layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))

            if not_last:
                layers.append(Pool(dim))
                dim *= 2

        self.layers = nn.Sequential(*layers)

        # classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # prepend the cls token and add positional embeddings
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :n+1]
        x = self.dropout(x)

        # stacked transformer and pooling stages
        x = self.layers(x)

        # classify from the cls token
        return self.mlp_head(x[:, 0])
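
A usage sketch (hyperparameters are illustrative): depth must be a tuple with one entry per stage, and since the token dimension doubles after each Pool, the classification head here operates on 256 * 4 = 1024 features.

import torch
from vit_pytorch.pit import PiT

v = PiT(
    image_size = 224,
    patch_size = 14,
    dim = 256,
    num_classes = 1000,
    depth = (3, 3, 3),  # number of transformer blocks in each of the three stages
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)  # (1, 1000)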

.\lucidrains\vit-pytorch\vit_pytorch\recorder.py

from functools import wraps
import torch
from torch import nn

from vit_pytorch.vit import Attention

# find all submodules of a given type
def find_modules(nn_module, type):
    return [module for module in nn_module.modules() if isinstance(module, type)]

# Recorder: wraps a ViT and records the attention maps of every Attention layer on each forward pass
class Recorder(nn.Module):
    def __init__(self, vit, device = None):
        super().__init__()
        self.vit = vit

        self.data = None
        self.recordings = []
        self.hooks = []
        self.hook_registered = False
        self.ejected = False
        self.device = device

    # forward hook: store a detached copy of the attention matrix
    def _hook(self, _, input, output):
        self.recordings.append(output.clone().detach())

    # register a forward hook on the softmax of every Attention module
    def _register_hook(self):
        # 查找所有 transformer 模块中的 Attention 模块
        modules = find_modules(self.vit.transformer, Attention)
        # 为每个 Attention 模块注册前向钩子函数
        for module in modules:
            handle = module.attend.register_forward_hook(self._hook)
            self.hooks.append(handle)
        self.hook_registered = True

    # remove all hooks and hand back the wrapped ViT
    def eject(self):
        self.ejected = True
        # remove every registered hook
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        return self.vit

    # clear the recorded attention maps
    def clear(self):
        self.recordings.clear()

    # manually record an attention tensor
    def record(self, attn):
        recording = attn.clone().detach()
        self.recordings.append(recording)

    def forward(self, img):
        assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
        self.clear()

        # lazily register the hooks on first use
        if not self.hook_registered:
            self._register_hook()

        pred = self.vit(img)

        # move the recorded attention maps to the target device
        target_device = self.device if self.device is not None else img.device
        recordings = tuple(map(lambda t: t.to(target_device), self.recordings))

        # stack into (batch, layers, heads, n, n) if anything was recorded
        attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None
        return pred, attns
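
A usage sketch with the base ViT from vit_pytorch.vit, whose Attention module exposes the attend softmax that the hook latches onto; the recorded tensor stacks one attention map per layer.

import torch
from vit_pytorch.vit import ViT
from vit_pytorch.recorder import Recorder

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

v = Recorder(v)

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)
# attns: (1, 6, 16, 65, 65) -- (batch, depth, heads, patches + cls, patches + cls)

v = v.eject()  # recover the original ViT with all hooks removed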

.\lucidrains\vit-pytorch\vit_pytorch\regionvit.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
# helpers

# check whether a value is not None
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# cast a value to a tuple, repeating it length times if needed
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# check divisibility
def divisible_by(val, d):
    return (val % d) == 0

# helper classes

# downsample via a strided 3x3 convolution
class Downsample(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)

    def forward(self, x):
        return self.conv(x)

# PEG (positional encoding generator): depthwise convolution with a residual
class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return self.proj(x) + x

# transformer classes

# feedforward
def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim, 1)
    )

# attention
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, rel_pos_bias = None):
        h = self.heads

        # prenorm

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add relative positional bias for local tokens

        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # merge heads

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# region-to-local (R2L) transformer
class R2LTransformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        window_size,
        depth = 4,
        heads = 4,
        dim_head = 32,
        attn_dropout = 0.,
        ff_dropout = 0.,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        self.window_size = window_size
        rel_positions = 2 * window_size - 1
        self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout)
            ]))
    def forward(self, local_tokens, region_tokens):
        device = local_tokens.device
        lh, lw = local_tokens.shape[-2:]
        rh, rw = region_tokens.shape[-2:]

        # window size: how many local tokens fall under each region token
        window_size_h, window_size_w = lh // rh, lw // rw

        local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c')
        region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')

        # relative positional bias for the local tokens within a window
        h_range = torch.arange(window_size_h, device = device)
        w_range = torch.arange(window_size_w, device = device)
        grid_x, grid_y = torch.meshgrid(h_range, w_range, indexing = 'ij')
        grid = torch.stack((grid_x, grid_y))
        grid = rearrange(grid, 'c h w -> c (h w)')
        grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
        bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0)
        rel_pos_bias = self.local_rel_pos_bias(bias_indices)
        rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j')
        rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0)

        # go through the r2l transformer layers
        for attn, ff in self.layers:
            # self attention among the region tokens
            region_tokens = attn(region_tokens) + region_tokens

            # concat each region token to the local tokens of its window
            local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh)
            local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w)
            region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d')

            # self attention over region + local tokens, with the relative positional bias
            region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1)
            region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens

            # feedforward
            region_and_local_tokens = ff(region_and_local_tokens) + region_and_local_tokens

            # split back into region and local tokens
            region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:]
            local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h)
            region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw)

        # back to (b, c, h, w) feature maps
        local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw)
        region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw)
        return local_tokens, region_tokens

# RegionViT
class RegionViT(nn.Module):
    def __init__(
        self,
        *,
        dim = (64, 128, 256, 512),      # dimension at each stage
        depth = (2, 2, 8, 2),           # depth at each stage
        window_size = 7,                # window size
        num_classes = 1000,             # number of output classes
        tokenize_local_3_conv = False,  # whether to tokenize local tokens with 3 convolutions
        local_patch_size = 4,           # local patch size
        use_peg = False,                # whether to use the positional encoding generator
        attn_dropout = 0.,              # attention dropout
        ff_dropout = 0.,                # feedforward dropout
        channels = 3,                   # number of input image channels
    ):
        super().__init__()
        dim = cast_tuple(dim, 4)
        depth = cast_tuple(depth, 4)
        assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
        assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'

        self.local_patch_size = local_patch_size

        region_patch_size = local_patch_size * window_size
        self.region_patch_size = local_patch_size * window_size

        init_dim, *_, last_dim = dim

        # local and region encoders

        if tokenize_local_3_conv:
            self.local_encoder = nn.Sequential(
                nn.Conv2d(3, init_dim, 3, 2, 1),
                nn.LayerNorm(init_dim),
                nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 2, 1),
                nn.LayerNorm(init_dim),
                nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 1, 1)
            )
        else:
            self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)

        self.region_encoder = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size),
            nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1)
        )

        # layers

        current_dim = init_dim
        self.layers = nn.ModuleList([])

        # every stage after the first downsamples (and optionally applies a PEG) before its r2l transformer
        for ind, dim, num_layers in zip(range(4), dim, depth):
            not_first = ind != 0
            need_downsample = not_first
            need_peg = not_first and use_peg

            self.layers.append(nn.ModuleList([
                Downsample(current_dim, dim) if need_downsample else nn.Identity(),
                PEG(dim) if need_peg else nn.Identity(),
                R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
            ]))

            current_dim = dim

        # final logits: mean pool the region tokens, then norm and project to classes

        self.to_logits = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.LayerNorm(last_dim),
            nn.Linear(last_dim, num_classes)
        )

    # 前向传播函数
    def forward(self, x):
        *_, h, w = x.shape  # 获取输入张量的高度和宽度
        assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size'  # 断言高度和宽度必须能被区域补丁大小整除
        assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size'  # 断言高度和宽度必须能被局部补丁大小整除

        local_tokens = self.local_encoder(x)  # 使用局部编码器对输入进行编码
        region_tokens = self.region_encoder(x)  # 使用区域编码器对输入进行编码

        for down, peg, transformer in self.layers:  # 遍历层列表
            local_tokens, region_tokens = down(local_tokens), down(region_tokens)  # 对局部和区域 tokens 进行下采样
            local_tokens = peg(local_tokens)  # 使用 PEG 对局部 tokens 进行处理
            local_tokens, region_tokens = transformer(local_tokens, region_tokens)  # 使用 transformer 对局部和区域 tokens 进行处理

        return self.to_logits(region_tokens)  # 返回最终的 logits
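
A minimal usage sketch for RegionViT, assuming the module is exposed as vit_pytorch.regionvit; hyperparameters are illustrative, and 224 is chosen because it is divisible by the region patch size local_patch_size * window_size = 4 * 7 = 28.

import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # dimension at each stage
    depth = (2, 2, 8, 2),           # number of transformer blocks at each stage
    window_size = 7,                # local window size
    num_classes = 1000,
    tokenize_local_3_conv = False,  # if True, tokenize locally with three 3x3 convs instead of one 8x8 conv
    use_peg = False                 # whether to insert a positional encoding generator (PEG) at the later stages
)

img = torch.randn(1, 3, 224, 224)
preds = model(img)  # (1, 1000)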

.\lucidrains\vit-pytorch\vit_pytorch\rvt.py

# 从 math 模块中导入 sqrt, pi, log 函数
# 从 torch 模块中导入 nn, einsum, F
# 从 einops 模块中导入 rearrange, repeat
# 从 einops.layers.torch 模块中导入 Rearrange 类
from math import sqrt, pi, log

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 旋转嵌入

# 将输入张量中的每两个元素进行旋转
def rotate_every_two(x):
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d j -> ... (d j)')
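
A quick illustration of rotate_every_two (the values are hypothetical): each consecutive pair (x1, x2) along the last dimension becomes (-x2, x1), the 90 degree rotation used by rotary embeddings.

t = torch.tensor([[1., 2., 3., 4.]])
print(rotate_every_two(t))  # tensor([[-2., 1., -4., 3.]])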

# 轴向旋转嵌入类
class AxialRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim
        scales = torch.linspace(1., max_freq / 2, self.dim // 4)
        self.register_buffer('scales', scales)

    def forward(self, x):
        device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))

        seq = torch.linspace(-1., 1., steps = n, device = device)
        seq = seq.unsqueeze(-1)

        scales = self.scales[(*((None,) * (len(seq.shape) - 1)), Ellipsis)]
        scales = scales.to(x)

        seq = seq * scales * pi

        x_sinu = repeat(seq, 'i d -> i j d', j = n)
        y_sinu = repeat(seq, 'j d -> i j d', i = n)

        sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
        cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)

        sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos))
        sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
        return sin, cos

# 深度可分离卷积类
class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)

# 辅助类

# 空间卷积类
class SpatialConv(nn.Module):
    def __init__(self, dim_in, dim_out, kernel, bias = False):
        super().__init__()
        self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)
        self.cls_proj = nn.Linear(dim_in, dim_out) if dim_in != dim_out else nn.Identity()

    def forward(self, x, fmap_dims):
        cls_token, x = x[:, :1], x[:, 1:]
        x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
        x = self.conv(x)
        x = rearrange(x, 'b d h w -> b (h w) d')
        cls_token = self.cls_proj(cls_token)
        return torch.cat((cls_token, x), dim = 1)

# GEGLU 类
class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return F.gelu(gates) * x

# 前馈网络类
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
            GEGLU() if use_glu else nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# 注意力类
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, use_ds_conv = True, conv_query_kernel = 5):
        super().__init__()
        inner_dim = dim_head *  heads
        self.use_rotary = use_rotary
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.use_ds_conv = use_ds_conv

        self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)

        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    # 定义前向传播函数,接受输入 x,位置嵌入 pos_emb,特征图维度 fmap_dims
    def forward(self, x, pos_emb, fmap_dims):
        # 获取输入 x 的形状信息
        b, n, _, h = *x.shape, self.heads

        # 如果使用深度可分离卷积,则传递特定参数给 to_q 函数
        to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}

        # 对输入 x 进行归一化处理
        x = self.norm(x)

        # 将 x 传递给 to_q 函数,得到查询向量 q
        q = self.to_q(x, **to_q_kwargs)

        # 将 q 与键值对应的结果拆分为 q, k, v
        qkv = (q, *self.to_kv(x).chunk(2, dim = -1))

        # 将 q, k, v 重排维度,以适应多头注意力机制
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        # 如果使用旋转注意力机制
        if self.use_rotary:
            # 对查询和键应用二维旋转嵌入,不包括 CLS 标记
            sin, cos = pos_emb
            dim_rotary = sin.shape[-1]

            # 拆分 CLS 标记和其余部分
            (q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))

            # 处理旋转维度小于头维度的情况
            (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
            q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
            q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))

            # 拼接回 CLS 标记
            q = torch.cat((q_cls, q), dim = 1)
            k = torch.cat((k_cls, k), dim = 1)

        # 计算点积注意力得分
        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        # 经过注意力计算
        attn = self.attend(dots)
        attn = self.dropout(attn)

        # 计算输出
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        # 返回输出结果
        return self.to_out(out)
# 定义一个 Transformer 类,继承自 nn.Module
class Transformer(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, image_size, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
        # 调用父类的初始化函数
        super().__init__()
        # 初始化一个空的层列表
        self.layers = nn.ModuleList([])
        # 创建 AxialRotaryEmbedding 对象作为位置编码
        self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
        # 循环创建指定数量的层
        for _ in range(depth):
            # 每层包含注意力机制和前馈神经网络
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv),
                FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu)
            ]))
    # 前向传播函数,接受输入 x 和 fmap_dims
    def forward(self, x, fmap_dims):
        # 计算位置编码
        pos_emb = self.pos_emb(x[:, 1:])
        # 遍历每一层,依次进行注意力机制和前馈神经网络操作
        for attn, ff in self.layers:
            x = attn(x, pos_emb = pos_emb, fmap_dims = fmap_dims) + x
            x = ff(x) + x
        # 返回处理后的结果
        return x

# 定义一个 RvT 类,继承自 nn.Module
class RvT(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
        # 调用父类的初始化函数
        super().__init__()
        # 断言确保图像尺寸能够被补丁大小整除
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        # 计算补丁数量和补丁维度
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        # 初始化补丁嵌入层
        self.patch_size = patch_size
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

        # 初始化分类令牌
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # 初始化 Transformer 模型
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, image_size, dropout, use_rotary, use_ds_conv, use_glu)

        # 初始化 MLP 头部
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # 前向传播函数,接受输入图像 img
    def forward(self, img):
        # 获取输入图像的形状信息
        b, _, h, w, p = *img.shape, self.patch_size

        # 将图像转换为补丁嵌入
        x = self.to_patch_embedding(img)
        n = x.shape[1]

        # 重复分类令牌并与补丁嵌入拼接
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)

        # 计算特征图尺寸信息
        fmap_dims = {'h': h // p, 'w': w // p}
        # 使用 Transformer 处理输入数据
        x = self.transformer(x, fmap_dims = fmap_dims)

        # 返回 MLP 头部处理后的结果
        return self.mlp_head(x[:, 0])
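
A usage sketch for RvT, assuming it is importable from vit_pytorch.rvt; the argument values are illustrative and mirror the constructor defined above.

import torch
from vit_pytorch.rvt import RvT

model = RvT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    use_rotary = True,    # rotary embeddings applied to queries / keys (CLS token excluded)
    use_ds_conv = True,   # depthwise-separable conv for the queries
    use_glu = True        # GEGLU in the feedforward
)

img = torch.randn(1, 3, 256, 256)
preds = model(img)  # (1, 1000)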

.\lucidrains\vit-pytorch\vit_pytorch\scalable_vit.py

# 导入必要的库
from functools import partial
import torch
from torch import nn

# 导入 einops 库中的函数和层
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 将输入转换为元组
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 将输入转换为指定长度的元组
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# 辅助类

# 通道层归一化
class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
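
A small sanity check (illustrative values): ChanLayerNorm normalizes across the channel dimension at every spatial location, in contrast to nn.LayerNorm, which would normalize the trailing dimension.

x = torch.randn(2, 64, 8, 8)
norm = ChanLayerNorm(64)
y = norm(x)
print(y.mean(dim = 1).abs().max())  # close to 0: zero mean over channels at each position (g, b start at 1, 0)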

# 下采样
class Downsample(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)

    def forward(self, x):
        return self.conv(x)

# 位置编码器
class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return self.proj(x) + x

# 前馈网络
class FeedForward(nn.Module):
    def __init__(self, dim, expansion_factor = 4, dropout = 0.):
        super().__init__()
        inner_dim = dim * expansion_factor
        self.net = nn.Sequential(
            ChanLayerNorm(dim),
            nn.Conv2d(dim, inner_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# 注意力机制

# 可扩展的自注意力机制
class ScalableSelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_key = 32,
        dim_value = 32,
        dropout = 0.,
        reduction_factor = 1
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_key ** -0.5
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.norm = ChanLayerNorm(dim)
        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
        self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        height, width, heads = *x.shape[-2:], self.heads

        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # 分割头部

        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))

        # 相似度

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # 注意力权重

        attn = self.attend(dots)
        attn = self.dropout(attn)

        # 聚合数值

        out = torch.matmul(attn, v)

        # 合并头部

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
        return self.to_out(out)
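
A shape sketch for ScalableSelfAttention (illustrative values): with reduction_factor = 2 the keys and values come from a strided convolution, so full-resolution queries attend over a spatially reduced key/value map.

ssa = ScalableSelfAttention(dim = 64, heads = 8, reduction_factor = 2)
x = torch.randn(1, 64, 32, 32)
out = ssa(x)
print(out.shape)  # torch.Size([1, 64, 32, 32]); internally k / v are reduced to 16 x 16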

# 交互式窗口化自注意力机制
class InteractiveWindowedSelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        window_size,
        heads = 8,
        dim_key = 32,
        dim_value = 32,
        dropout = 0.
    ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化头数和缩放因子
        self.heads = heads
        self.scale = dim_key ** -0.5
        self.window_size = window_size
        # 初始化注意力机制和dropout层
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # 初始化通道层归一化和局部交互模块
        self.norm = ChanLayerNorm(dim)
        self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)

        # 初始化转换层,将输入转换为查询、键和值
        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_value * heads, 1, bias = False)

        # 初始化输出层,包括卷积层和dropout层
        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # 获取输入张量的高度、宽度、头数和窗口大小
        height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size

        # 对输入张量进行归一化
        x = self.norm(x)

        # 计算窗口的高度和宽度
        wsz_h, wsz_w = default(wsz, height), default(wsz, width)
        assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'

        # 将输入张量转换为查询、键和值
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # 获取局部交互模块的输出
        local_out = self.local_interactive_module(v)

        # 将查询、键和值分割成窗口(并拆分出头部)以进行有效的自注意力计算
        q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v))

        # 计算相似度
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # 注意力计算
        attn = self.attend(dots)
        attn = self.dropout(attn)

        # 聚合值
        out = torch.matmul(attn, v)

        # 将窗口重新整形为完整的特征图(并合并头部)
        out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)

        # 添加局部交互模块的输出
        out = out + local_out

        return self.to_out(out)
class Transformer(nn.Module):
    # 定义 Transformer 类,继承自 nn.Module
    def __init__(
        self,
        dim,
        depth,
        heads = 8,
        ff_expansion_factor = 4,
        dropout = 0.,
        ssa_dim_key = 32,
        ssa_dim_value = 32,
        ssa_reduction_factor = 1,
        iwsa_dim_key = 32,
        iwsa_dim_value = 32,
        iwsa_window_size = None,
        norm_output = True
    ):
        # 初始化函数
        super().__init__()
        # 初始化 nn.ModuleList 用于存储 Transformer 层
        self.layers = nn.ModuleList([])
        # 循环创建 Transformer 层
        for ind in range(depth):
            # 判断是否为第一层
            is_first = ind == 0

            # 添加 Transformer 层的组件到 layers 中
            self.layers.append(nn.ModuleList([
                ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout),
                FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
                PEG(dim) if is_first else None,
                FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
                InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout)
            ]))

        # 初始化最后的归一化层
        self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

    # 前向传播函数
    def forward(self, x):
        # 遍历 Transformer 层
        for ssa, ff1, peg, iwsa, ff2 in self.layers:
            # Self-Attention 操作
            x = ssa(x) + x
            # FeedForward 操作
            x = ff1(x) + x

            # 如果存在 PEG 操作,则执行
            if exists(peg):
                x = peg(x)

            # Interactive Windowed Self-Attention 操作
            x = iwsa(x) + x
            # 再次 FeedForward 操作
            x = ff2(x) + x

        # 返回归一化后的结果
        return self.norm(x)

class ScalableViT(nn.Module):
    # 定义 ScalableViT 类,继承自 nn.Module
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        heads,
        reduction_factor,
        window_size = None,
        iwsa_dim_key = 32,
        iwsa_dim_value = 32,
        ssa_dim_key = 32,
        ssa_dim_value = 32,
        ff_expansion_factor = 4,
        channels = 3,
        dropout = 0.
    ):
        # 初始化函数
        super().__init__()
        # 将图像转换为补丁
        self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)

        # 断言 depth 为元组,表示每个阶段的 Transformer 块数量
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'

        # 计算每个阶段的维度
        num_stages = len(depth)
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))

        # 定义每个阶段的超参数
        hyperparams_per_stage = [
            heads,
            ssa_dim_key,
            ssa_dim_value,
            reduction_factor,
            iwsa_dim_key,
            iwsa_dim_value,
            window_size,
        ]

        # 将超参数转换为每个阶段的形式
        hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
        assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))

        # 初始化 Transformer 层
        self.layers = nn.ModuleList([])

        # 遍历每个阶段的维度和超参数
        for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)):
            is_last = ind == (num_stages - 1)

            # 添加 Transformer 层和下采样层到 layers 中
            self.layers.append(nn.ModuleList([
                Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size, norm_output = not is_last),
                Downsample(layer_dim, layer_dim * 2) if not is_last else None
            ]))

        # MLP 头部
        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    # 前向传播函数
    def forward(self, img):
        # 将图像转换为补丁
        x = self.to_patches(img)

        # 遍历每个 Transformer 层
        for transformer, downsample in self.layers:
            x = transformer(x)

            # 如果存在下采样层,则执行
            if exists(downsample):
                x = downsample(x)

        # 返回 MLP 头部的结果
        return self.mlp_head(x)
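
A usage sketch for ScalableViT, assuming it is importable from vit_pytorch.scalable_vit; the per-stage hyperparameters below are illustrative.

import torch
from vit_pytorch.scalable_vit import ScalableViT

model = ScalableViT(
    num_classes = 1000,
    dim = 64,                            # starting dimension, doubled at every stage
    heads = (2, 4, 8, 16),               # attention heads per stage
    depth = (2, 2, 20, 2),               # transformer blocks per stage
    ssa_dim_key = (40, 40, 40, 40),      # key (and query) dimension for the scalable self-attention
    reduction_factor = (8, 4, 2, 1),     # key / value downsampling per stage for SSA
    window_size = (64, 32, None, None),  # IWSA window size per stage; None means the whole feature map
    dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = model(img)  # (1, 1000)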

.\lucidrains\vit-pytorch\vit_pytorch\sep_vit.py

# 导入必要的库
from functools import partial

import torch
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# 辅助函数

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# 辅助类

class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class OverlappingPatchEmbed(nn.Module):
    def __init__(self, dim_in, dim_out, stride = 2):
        super().__init__()
        kernel_size = stride * 2 - 1
        padding = kernel_size // 2
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)

    def forward(self, x):
        return self.conv(x)

class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return self.proj(x) + x

# 前馈网络

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            ChanLayerNorm(dim),
            nn.Conv2d(dim, inner_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# 注意力机制

class DSSA(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.window_size = window_size
        inner_dim = dim_head * heads

        self.norm = ChanLayerNorm(dim)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)

        # 窗口标记

        self.window_tokens = nn.Parameter(torch.randn(dim))

        # 窗口标记的预处理和非线性变换
        # 然后将窗口标记投影到查询和键

        self.window_tokens_to_qk = nn.Sequential(
            nn.LayerNorm(dim_head),
            nn.GELU(),
            Rearrange('b h n c -> b (h c) n'),
            nn.Conv1d(inner_dim, inner_dim * 2, 1),
            Rearrange('b (h c) n -> b h n c', h = heads),
        )

        # 窗口注意力

        self.window_attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        """
        einstein notation

        b - batch
        c - channels
        w1 - window size (height)
        w2 - also window size (width)
        i - sequence dimension (source)
        j - sequence dimension (target dimension to be reduced)
        h - heads
        x - height of feature map divided by window size
        y - width of feature map divided by window size
        """

        # 获取输入张量的形状信息
        batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
        # 检查高度和宽度是否可以被窗口大小整除
        assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
        # 计算窗口数量
        num_windows = (height // wsz) * (width // wsz)

        # 对输入张量进行归一化处理
        x = self.norm(x)

        # 将窗口折叠进行“深度”注意力 - 不确定为什么它被命名为深度,当它只是“窗口化”注意力时
        x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)

        # 添加窗口标记
        w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
        x = torch.cat((w, x), dim = -1)

        # 为查询、键、值进行投影
        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # 分离头部
        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))

        # 缩放
        q = q * self.scale

        # 相似度
        dots = einsum('b h i d, b h j d -> b h i j', q, k)

        # 注意力
        attn = self.attend(dots)

        # 聚合值
        out = torch.matmul(attn, v)

        # 分离窗口标记和窗口化特征图
        window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]

        # 如果只有一个窗口,则提前返回
        if num_windows == 1:
            fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
            return self.to_out(fmap)

        # 执行点对点注意力,这是论文中的主要创新
        window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
        windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)

        # 窗口化查询和键(在进行预归一化激活之前)
        w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)

        # 缩放
        w_q = w_q * self.scale

        # 相似度
        w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)

        w_attn = self.window_attend(w_dots)

        # 聚合来自“深度”注意力步骤的特征图(论文中最有趣的部分,我以前没有见过)
        aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)

        # 折叠回窗口,然后组合头部以进行聚合
        fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
        return self.to_out(fmap)
class Transformer(nn.Module):
    # 定义 Transformer 类,继承自 nn.Module
    def __init__(
        self,
        dim,
        depth,
        dim_head = 32,
        heads = 8,
        ff_mult = 4,
        dropout = 0.,
        norm_output = True
    ):
        # 初始化函数,接受多个参数
        super().__init__()
        # 调用父类的初始化函数
        self.layers = nn.ModuleList([])

        for ind in range(depth):
            # 遍历深度次数
            self.layers.append(nn.ModuleList([
                DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mult = ff_mult, dropout = dropout),
            ]))
            # 在 layers 中添加 DSSA 和 FeedForward 模块

        self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
        # 如果 norm_output 为 True,则使用 ChanLayerNorm,否则使用 nn.Identity

    def forward(self, x):
        # 前向传播函数
        for attn, ff in self.layers:
            # 遍历 layers 中的模块
            x = attn(x) + x
            # 对输入 x 进行注意力操作
            x = ff(x) + x
            # 对输入 x 进行前馈操作

        return self.norm(x)
        # 返回经过规范化的结果

class SepViT(nn.Module):
    # 定义 SepViT 类,继承自 nn.Module
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        heads,
        window_size = 7,
        dim_head = 32,
        ff_mult = 4,
        channels = 3,
        dropout = 0.
    ):
        # 初始化函数,接受多个参数
        super().__init__()
        # 调用父类的初始化函数
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
        # 断言 depth 是元组类型,用于指示每个阶段的 transformer 块数量

        num_stages = len(depth)
        # 获取深度的长度

        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (channels, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))
        # 计算每个阶段的维度

        strides = (4, *((2,) * (num_stages - 1)))
        # 定义步长

        hyperparams_per_stage = [heads, window_size]
        hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
        assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
        # 处理每个阶段的超参数

        self.layers = nn.ModuleList([])

        for ind, ((layer_dim_in, layer_dim), layer_depth, layer_stride, layer_heads, layer_window_size) in enumerate(zip(dim_pairs, depth, strides, *hyperparams_per_stage)):
            # 遍历每个阶段的参数
            is_last = ind == (num_stages - 1)

            self.layers.append(nn.ModuleList([
                OverlappingPatchEmbed(layer_dim_in, layer_dim, stride = layer_stride),
                PEG(layer_dim),
                Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_mult = ff_mult, dropout = dropout, norm_output = not is_last),
            ]))
            # 在 layers 中添加 OverlappingPatchEmbed、PEG 和 Transformer 模块

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )
        # 定义 MLP 头部模块

    def forward(self, x):
        # 前向传播函数
        for ope, peg, transformer in self.layers:
            # 遍历 layers 中的模块
            x = ope(x)
            # 对输入 x 进行 OverlappingPatchEmbed 操作
            x = peg(x)
            # 对输入 x 进行 PEG 操作
            x = transformer(x)
            # 对输入 x 进行 Transformer 操作

        return self.mlp_head(x)
        # 返回经过 MLP 头部处理的结果
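
A usage sketch for SepViT, assuming it is importable from vit_pytorch.sep_vit; 224 works because the stage strides (4, 2, 2, 2) keep every feature map divisible by the window size of 7.

import torch
from vit_pytorch.sep_vit import SepViT

v = SepViT(
    num_classes = 1000,
    dim = 32,               # first-stage dimension, doubled at every stage (32, 64, 128, 256)
    dim_head = 32,
    heads = (1, 2, 4, 8),   # heads per stage
    depth = (1, 2, 6, 2),   # transformer blocks per stage
    window_size = 7,        # window size for the DSSA blocks
    dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)  # (1, 1000)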

.\lucidrains\vit-pytorch\vit_pytorch\simmim.py

import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

class SimMIM(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        masking_ratio = 0.5
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)

        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        self.to_patch = encoder.to_patch_embedding[0]
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])

        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

        # simple linear head

        self.mask_token = nn.Parameter(torch.randn(encoder_dim))
        self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)

    def forward(self, img):
        device = img.device

        # get patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # for indexing purposes

        batch_range = torch.arange(batch, device = device)[:, None]

        # get positions

        pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # patch to encoder tokens and add positions

        tokens = self.patch_to_emb(patches)
        tokens = tokens + pos_emb

        # prepare mask tokens

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
        mask_tokens = mask_tokens + pos_emb

        # calculate how many patches need to be masked, and get the positions (indices) to be masked

        num_masked = int(self.masking_ratio * num_patches)
        masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
        masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()

        # mask tokens

        tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)

        # attend with vision transformer

        encoded = self.encoder.transformer(tokens)

        # get the masked tokens

        encoded_mask_tokens = encoded[batch_range, masked_indices]

        # small linear projection for predicted pixel values

        pred_pixel_values = self.to_pixels(encoded_mask_tokens)

        # get the masked patches for the final reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # calculate reconstruction loss

        recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
        return recon_loss
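
A usage sketch for SimMIM, assuming a standard vit_pytorch ViT as the encoder (argument values are illustrative): the wrapper masks a fraction of the patch tokens, reconstructs their pixels with a linear head, and returns the L1 reconstruction loss.

import torch
from vit_pytorch import ViT
from vit_pytorch.simmim import SimMIM

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mim = SimMIM(
    encoder = v,
    masking_ratio = 0.5   # mask half of the patches (the default above)
)

images = torch.randn(8, 3, 256, 256)
loss = mim(images)
loss.backward()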

.\lucidrains\vit-pytorch\vit_pytorch\simple_flash_attn_vit.py

# 导入必要的库
from collections import namedtuple
from packaging import version

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange
from einops.layers.torch import Rearrange

# 定义常量
Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# 定义辅助函数

# 将输入转换为元组
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 生成二维位置编码的正弦和余弦值
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :] 
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)

# 主类

class Attend(nn.Module):
    def __init__(self, use_flash = False):
        super().__init__()
        self.use_flash = use_flash
        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # 确定 CUDA 和 CPU 的高效注意力配置

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            self.cuda_config = Config(True, False, False)
        else:
            self.cuda_config = Config(False, True, True)

    def flash_attn(self, q, k, v):
        config = self.cuda_config if q.is_cuda else self.cpu_config

        # Flash Attention - https://arxiv.org/abs/2205.14135
        
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(q, k, v)

        return out

    def forward(self, q, k, v):
        n, device, scale = q.shape[-2], q.device, q.shape[-1] ** -0.5

        if self.use_flash:
            return self.flash_attn(q, k, v)

        # 相似度

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # 注意力

        attn = sim.softmax(dim=-1)

        # 聚合值

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out

# 类

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = Attend(use_flash = use_flash)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        out = self.attend(q, k, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    # 初始化 Transformer 模型
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash):
        # 调用父类的初始化方法
        super().__init__()
        # 创建一个空的层列表
        self.layers = nn.ModuleList([])
        # 根据深度循环创建多个 Transformer 层
        for _ in range(depth):
            # 每个 Transformer 层包含注意力机制和前馈神经网络
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash),
                FeedForward(dim, mlp_dim)
            ]))
    
    # 前向传播函数
    def forward(self, x):
        # 遍历每个 Transformer 层
        for attn, ff in self.layers:
            # 使用注意力机制处理输入,并将结果与输入相加
            x = attn(x) + x
            # 使用前馈神经网络处理输入,并将结果与输入相加
            x = ff(x) + x
        # 返回处理后的结果
        return x
# 定义一个简单的ViT模型,继承自nn.Module类
class SimpleViT(nn.Module):
    # 初始化函数,接收一系列参数来构建ViT模型
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash = True):
        super().__init__()
        # 获取图像的高度和宽度
        image_height, image_width = pair(image_size)
        # 获取patch的高度和宽度
        patch_height, patch_width = pair(patch_size)

        # 断言图像的高度和宽度必须能够被patch的高度和宽度整除
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # 计算patch的数量
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # 计算每个patch的维度
        patch_dim = channels * patch_height * patch_width

        # 将图像转换为patch嵌入
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # 创建Transformer模块
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash)

        # 将嵌入转换为潜在空间
        self.to_latent = nn.Identity()
        # 线性头部,用于分类
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # 前向传播函数,接收图像作为输入
    def forward(self, img):
        # 获取图像的形状和数据类型
        *_, h, w, dtype = *img.shape, img.dtype

        # 将图像转换为patch嵌入
        x = self.to_patch_embedding(img)
        # 生成位置编码
        pe = posemb_sincos_2d(x)
        # 将位置编码加到嵌入中
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        # 经过Transformer模块处理
        x = self.transformer(x)
        # 对所有patch的输出取平均值
        x = x.mean(dim = 1)

        # 转换为潜在空间
        x = self.to_latent(x)
        # 使用线性头部进行分类
        return self.linear_head(x)
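
A usage sketch, assuming PyTorch 2.0+ when use_flash = True (the assertion in Attend enforces this); argument values are illustrative.

import torch
from vit_pytorch.simple_flash_attn_vit import SimpleViT

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    use_flash = True   # routes attention through F.scaled_dot_product_attention
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)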

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn

# 从 einops 库中导入 rearrange 和 Rearrange 函数
from einops import rearrange
from einops.layers.torch import Rearrange

# 定义辅助函数

# 如果 t 是元组则返回 t,否则返回 (t, t)
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 生成二维正弦余弦位置编码
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    # 生成网格坐标
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    # 确保特征维度是 4 的倍数
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    # 计算 omega
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    # 拼接正弦余弦位置编码
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# 定义类

# 前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# 注意力机制类
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer 类
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# 简单 ViT 模型类
class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        ) 

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        device = img.device

        x = self.to_patch_embedding(img)
        x += self.pos_embedding.to(device, dtype=x.dtype)

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)
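
A usage sketch for SimpleViT, assuming the package-level export from vit_pytorch; values are illustrative.

import torch
from vit_pytorch import SimpleViT

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)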

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_1d.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn

# 从 einops 库中导入 rearrange 和 Rearrange 函数
from einops import rearrange
from einops.layers.torch import Rearrange

# 定义函数 posemb_sincos_1d,用于生成位置编码
def posemb_sincos_1d(patches, temperature = 10000, dtype = torch.float32):
    # 获取 patches 的形状信息
    _, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    # 生成序列 n
    n = torch.arange(n, device = device)
    # 检查 dim 是否为偶数
    assert (dim % 2) == 0, 'feature dimension must be multiple of 2 for sincos emb'
    # 计算 omega
    omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
    omega = 1. / (temperature ** omega)

    # 计算位置编码
    n = n.flatten()[:, None] * omega[None, :]
    pe = torch.cat((n.sin(), n.cos()), dim = 1)
    return pe.type(dtype)

# 定义类 FeedForward
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# 定义类 Attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义类 Transformer
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# 定义类 SimpleViT
class SimpleViT(nn.Module):
    def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()

        assert seq_len % patch_size == 0

        num_patches = seq_len // patch_size
        patch_dim = channels * patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (n p) -> b n (p c)', p = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, series):
        *_, n, dtype = *series.shape, series.dtype

        x = self.to_patch_embedding(series)
        pe = posemb_sincos_1d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

# 在主函数中创建 SimpleViT 实例 v
if __name__ == '__main__':

    v = SimpleViT(
        seq_len = 256,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048
    )

    # 生成随机时间序列数据
    time_series = torch.randn(4, 3, 256)
    # 输入时间序列数据到 SimpleViT 模型中,得到 logits
    logits = v(time_series) # (4, 1000)

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_3d.py

import torch
import torch.nn.functional as F
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

# 如果输入参数是元组,则返回元组,否则返回包含两个相同元素的元组
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 生成三维位置编码的正弦和余弦值
def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
    # 获取 patches 的形状信息
    _, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    # 生成三维网格坐标
    z, y, x = torch.meshgrid(
        torch.arange(f, device = device),
        torch.arange(h, device = device),
        torch.arange(w, device = device),
    indexing = 'ij')

    # 计算傅立叶维度
    fourier_dim = dim // 6

    # 计算温度参数
    omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
    omega = 1. / (temperature ** omega)

    # 计算位置编码
    z = z.flatten()[:, None] * omega[None, :]
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :] 

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)

    # 如果特征维度不能被6整除,则进行填充
    pe = F.pad(pe, (0, dim - (fourier_dim * 6)))
    return pe.type(dtype)

# classes

# 前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# 注意力机制类
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer 模型类
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class SimpleViT(nn.Module):
    # 初始化函数,设置模型参数和结构
    def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        # 调用父类的初始化函数
        super().__init__()
        # 获取图像大小的高度和宽度
        image_height, image_width = pair(image_size)
        # 获取图像块的高度和宽度
        patch_height, patch_width = pair(image_patch_size)

        # 断言图像高度和宽度能够被图像块的高度和宽度整除
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # 断言帧数能够被帧块大小整除
        assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'

        # 计算图像块的数量
        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
        # 计算图像块的维度
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        # 将图像块转换为嵌入向量
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # 创建 Transformer 模型
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        # 将嵌入向量转换为潜在向量
        self.to_latent = nn.Identity()
        # 线性层,用于分类
        self.linear_head = nn.Linear(dim, num_classes)

    # 前向传播函数
    def forward(self, video):
        # 获取视频的形状信息
        *_, h, w, dtype = *video.shape, video.dtype

        # 将视频转换为图像块的嵌入向量
        x = self.to_patch_embedding(video)
        # 获取位置编码
        pe = posemb_sincos_3d(x)
        # 将位置编码加到嵌入向量中
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        # 经过 Transformer 模型处理
        x = self.transformer(x)
        # 对结果进行平均池化
        x = x.mean(dim = 1)

        # 转换为潜在向量
        x = self.to_latent(x)
        # 使用线性层进行分类
        return self.linear_head(x)
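
A usage sketch for the 3D SimpleViT, assuming it is importable from vit_pytorch.simple_vit_3d; values are illustrative, and dim should be large enough that dim // 6 sinusoidal frequencies give a useful positional encoding.

import torch
from vit_pytorch.simple_vit_3d import SimpleViT

v = SimpleViT(
    image_size = 128,        # frame height / width
    image_patch_size = 16,   # spatial patch size
    frames = 16,             # number of frames
    frame_patch_size = 2,    # temporal patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = v(video)  # (4, 1000)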

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_with_fft.py

# 导入 torch 库
import torch
# 从 torch.fft 中导入 fft2 函数
from torch.fft import fft2
# 从 torch 中导入 nn 模块
from torch import nn

# 从 einops 库中导入 rearrange、reduce、pack、unpack 函数
from einops import rearrange, reduce, pack, unpack
# 从 einops.layers.torch 中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 辅助函数

# 如果输入 t 是元组,则返回 t,否则返回 (t, t)
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 生成二维位置编码的函数
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    # 生成网格坐标
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    # 确保特征维度是4的倍数
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    # 计算 omega
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    # 计算位置编码
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# 类

# 前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# 注意力机制类
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer 类
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# SimpleViT 类
class SimpleViT(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        # 调用父类初始化函数
        super().__init__()
        # 获取图像的高度和宽度
        image_height, image_width = pair(image_size)
        # 获取 patch 的高度和宽度
        patch_height, patch_width = pair(patch_size)
        # 获取频域 patch 的高度和宽度
        freq_patch_height, freq_patch_width = pair(freq_patch_size)

        # 断言图像的高度和宽度能够被 patch 的高度和宽度整除
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # 断言图像的高度和宽度能够被频域 patch 的高度和宽度整除
        assert image_height % freq_patch_height == 0 and image_width % freq_patch_width == 0, 'Image dimensions must be divisible by the freq patch size.'

        # 计算 patch 的维度
        patch_dim = channels * patch_height * patch_width
        # 计算频域 patch 的维度
        freq_patch_dim = channels * 2 * freq_patch_height * freq_patch_width

        # 将图像转换为 patch 的嵌入向量
        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # 将频域 patch 转换为嵌入向量
        self.to_freq_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) ri -> b (h w) (p1 p2 ri c)", p1 = freq_patch_height, p2 = freq_patch_width),
            nn.LayerNorm(freq_patch_dim),
            nn.Linear(freq_patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # 生成位置编码
        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        # 生成频域位置编码
        self.freq_pos_embedding = posemb_sincos_2d(
            h = image_height // freq_patch_height,
            w = image_width // freq_patch_width,
            dim = dim
        )

        # 创建 Transformer 模型
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        # 池化方式为平均池化
        self.pool = "mean"
        # 转换为潜在空间的操作
        self.to_latent = nn.Identity()

        # 线性层,用于分类
        self.linear_head = nn.Linear(dim, num_classes)

    # 前向传播函数
    def forward(self, img):
        # 获取设备和数据类型
        device, dtype = img.device, img.dtype

        # 将图像转换为 patch 的嵌入向量
        x = self.to_patch_embedding(img)
        # 对图像进行二维傅里叶变换
        freqs = torch.view_as_real(fft2(img))

        # 将频域 patch 转换为嵌入向量
        f = self.to_freq_embedding(freqs)

        # 添加位置编码
        x += self.pos_embedding.to(device, dtype = dtype)
        f += self.freq_pos_embedding.to(device, dtype = dtype)

        # 打包数据
        x, ps = pack((f, x), 'b * d')

        # 使用 Transformer 进行特征提取
        x = self.transformer(x)

        # 解包数据
        _, x = unpack(x, ps, 'b * d')
        # 对特征进行池化操作
        x = reduce(x, 'b n d -> b d', 'mean')

        # 转换为潜在空间
        x = self.to_latent(x)
        # 使用线性层进行分类
        return self.linear_head(x)

# quick smoke test when this script is run directly
if __name__ == '__main__':
    # 创建一个简单的ViT模型实例,指定参数包括类别数、图像大小、patch大小、频率patch大小、维度、深度、头数、MLP维度
    vit = SimpleViT(
        num_classes = 1000,
        image_size = 256,
        patch_size = 8,
        freq_patch_size = 8,
        dim = 1024,
        depth = 1,
        heads = 8,
        mlp_dim = 2048,
    )

    # 生成一个8个样本的随机张量,每个样本包含3个通道,大小为256x256
    images = torch.randn(8, 3, 256, 256)

    # 将图像输入ViT模型,得到输出logits
    logits = vit(images)

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_with_patch_dropout.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn

# 从 einops 库中导入 rearrange 和 Rearrange 函数
from einops import rearrange
from einops.layers.torch import Rearrange

# 辅助函数

# 如果输入 t 是元组,则返回 t,否则返回 (t, t)
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 生成二维位置编码的正弦和余弦值
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    # 获取 patches 的形状信息
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    # 创建网格矩阵 y 和 x
    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    # 确保特征维度是 4 的倍数,用于 sincos 编码
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    # 计算 omega 值
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    # 计算位置编码
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :] 
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)

# patch dropout

# randomly keeps only a subset of the patch tokens during training
class PatchDropout(nn.Module):
    def __init__(self, prob):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob

    def forward(self, x):
        # no-op at inference time or when the drop probability is zero
        if not self.training or self.prob == 0.:
            return x

        # unpack batch size, number of tokens and device
        b, n, _, device = *x.shape, x.device

        # batch indices used for advanced indexing below
        batch_indices = torch.arange(b, device = device)
        batch_indices = rearrange(batch_indices, '... -> ... 1')
        # number of tokens to keep
        num_patches_keep = max(1, int(n * (1 - self.prob)))
        # pick a random subset of token indices per sample
        patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices

        return x[batch_indices, patch_indices_keep]
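
To make the behaviour concrete, a small sketch exercising the PatchDropout module defined above (shapes only; which tokens survive is random):

import torch

dropout = PatchDropout(0.5)
tokens = torch.randn(2, 64, 128)      # (batch, num_patches, dim)

dropout.train()
print(dropout(tokens).shape)          # torch.Size([2, 32, 128]) -- a random half of the tokens is kept

dropout.eval()
print(dropout(tokens).shape)          # torch.Size([2, 64, 128]) -- no dropping at inference time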

# 类

# 定义前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# 定义注意力机制类
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义变换器类
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# 简单 ViT 模型类
class SimpleViT(nn.Module):
    # 初始化函数,设置模型参数和层结构
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5):
        # 调用父类的初始化函数
        super().__init__()
        # 获取图像的高度和宽度
        image_height, image_width = pair(image_size)
        # 获取补丁的高度和宽度
        patch_height, patch_width = pair(patch_size)

        # 断言图像的高度和宽度能够被补丁的高度和宽度整除
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # 计算补丁的数量
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # 计算每个补丁的维度
        patch_dim = channels * patch_height * patch_width

        # 定义将图像转换为补丁嵌入的层结构
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # 定义补丁的丢弃层
        self.patch_dropout = PatchDropout(patch_dropout)

        # 定义变换器层
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        # 定义转换为潜在空间的层
        self.to_latent = nn.Identity()
        # 定义线性头层
        self.linear_head = nn.Linear(dim, num_classes)

    # 前向传播函数
    def forward(self, img):
        # 获取图像的形状和数据类型
        *_, h, w, dtype = *img.shape, img.dtype

        # 将图像转换为补丁嵌入
        x = self.to_patch_embedding(img)
        # 获取位置编码
        pe = posemb_sincos_2d(x)
        # 将位置编码添加到补丁嵌入中
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        # 对补丁进行丢弃
        x = self.patch_dropout(x)

        # 使用变换器进行转换
        x = self.transformer(x)
        # 对结果进行平均池化
        x = x.mean(dim = 1)

        # 转换为潜在空间
        x = self.to_latent(x)
        # 使用线性头层进行分类预测
        return self.linear_head(x)
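
A minimal usage sketch for the patch-dropout SimpleViT above; the hyperparameter values are illustrative rather than prescribed by this file.

import torch

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    patch_dropout = 0.5    # fraction of patch tokens dropped while training
)

img = torch.randn(4, 3, 256, 256)
preds = v(img)   # (4, 1000)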

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_with_qk_norm.py

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

# 定义一个函数,如果输入参数是元组则返回元组,否则返回包含两个相同元素的元组
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 生成二维位置编码的正弦和余弦值
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper

# in latest tweet, seem to claim more stable training at higher learning rates
# unsure if this has taken off within Brain, or it has some hidden drawback

# RMS normalization applied per attention head (no mean-centering, learned gamma)
class RMSNorm(nn.Module):
    def __init__(self, heads, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim) / self.scale)

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma
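
A quick numerical sketch of the RMSNorm above: inputs are L2-normalized along the last dimension and rescaled by sqrt(dim) times the learned per-head gamma, and gamma is initialized to 1/sqrt(dim), so at initialization the module reduces to plain unit normalization.

import torch
import torch.nn.functional as F

norm = RMSNorm(heads = 2, dim = 8)
q = torch.randn(1, 2, 5, 8)           # (batch, heads, tokens, dim_head)

out = norm(q)
# at initialization, scale * gamma == 1, so the output equals the unit-normalized input
print(torch.allclose(out, F.normalize(q, dim = -1), atol = 1e-6))   # True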

# classes

# 定义一个类,实现前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# 定义一个类,实现注意力机制
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.q_norm = RMSNorm(heads, dim_head)
        self.k_norm = RMSNorm(heads, dim_head)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        q = self.q_norm(q)
        k = self.k_norm(k)

        dots = torch.matmul(q, k.transpose(-1, -2))

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义一个类,实现 Transformer 模型
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class SimpleViT(nn.Module):
    # 初始化函数,设置模型参数和层结构
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        # 调用父类的初始化函数
        super().__init__()
        # 获取图像的高度和宽度
        image_height, image_width = pair(image_size)
        # 获取补丁的高度和宽度
        patch_height, patch_width = pair(patch_size)

        # 断言图像的高度和宽度能够被补丁的高度和宽度整除
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # 计算补丁的维度
        patch_dim = channels * patch_height * patch_width

        # 定义将图像转换为补丁嵌入的层结构
        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # 生成位置编码
        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        ) 

        # 定义 Transformer 模型
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        # mean pooling over tokens
        self.pool = "mean"
        # identity mapping to the latent representation
        self.to_latent = nn.Identity()

        # linear classification head
        self.linear_head = nn.Linear(dim, num_classes)

    # 前向传播函数
    def forward(self, img):
        # 获取输入图像的设备信息
        device = img.device

        # 将输入图像转换为补丁嵌入
        x = self.to_patch_embedding(img)
        # 添加位置编码
        x += self.pos_embedding.to(device, dtype=x.dtype)

        # 经过 Transformer 模型
        x = self.transformer(x)
        # 对特征进行平均池化
        x = x.mean(dim = 1)

        # map to the latent representation
        x = self.to_latent(x)
        # classify with the linear head
        return self.linear_head(x)
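
A minimal usage sketch for the qk-norm SimpleViT above (with the linear classification head as written); the hyperparameters are illustrative.

import torch

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(4, 3, 256, 256)
preds = v(img)   # (4, 1000)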

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_with_register_tokens.py

"""
    Vision Transformers Need Registers
    https://arxiv.org/abs/2309.16588
"""

import torch
from torch import nn

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# helpers

# 定义一个函数,将输入转换为元组
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 生成二维位置编码的正弦和余弦值
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing = "ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# classes

# 定义前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# 定义注意力机制类
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义Transformer类
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# 定义简单的ViT模型类
class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        ) 

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)
    # 定义前向传播函数,接收输入图像
    def forward(self, img):
        # 获取输入图像的批量大小和设备信息
        batch, device = img.shape[0], img.device

        # 将输入图像转换为补丁嵌入
        x = self.to_patch_embedding(img)
        # 将位置嵌入添加到补丁嵌入中
        x += self.pos_embedding.to(device, dtype=x.dtype)

        # 重复注册令牌以匹配批量大小
        r = repeat(self.register_tokens, 'n d -> b n d', b = batch)

        # 打包补丁嵌入和注册令牌
        x, ps = pack([x, r], 'b * d')

        # 使用Transformer处理输入数据
        x = self.transformer(x)

        # 解包处理后的数据
        x, _ = unpack(x, ps, 'b * d')

        # 对数据进行平均池化
        x = x.mean(dim = 1)

        # 将数据转换为潜在空间
        x = self.to_latent(x)
        # 使用线性头部进行最终预测
        return self.linear_head(x)
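
A minimal usage sketch for the register-token SimpleViT above; the register tokens participate in attention but are discarded before pooling.

import torch

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    num_register_tokens = 4
)

img = torch.randn(4, 3, 256, 256)
preds = v(img)   # (4, 1000)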

.\lucidrains\vit-pytorch\vit_pytorch\t2t.py

# standard library and torch imports
import math
import torch
from torch import nn

# reuse the Transformer block from vit.py
from vit_pytorch.vit import Transformer

# einops utilities for tensor reshaping
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper to check whether a value was supplied
def exists(val):
    return val is not None

# spatial output size of a convolution / unfold with the given kernel, stride and padding
def conv_output_size(image_size, kernel_size, stride, padding):
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
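
A worked example of the formula above, using the default first T2T stage (kernel 7, stride 4, padding = stride // 2 = 2) on a 224-pixel side as an illustration:

# (224 - 7 + 2 * 2) / 4 + 1 = 56.25, truncated to 56
assert conv_output_size(224, 7, 4, 2) == 56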

# reshape a flattened sequence of tokens back into a square image-like feature map
class RearrangeImage(nn.Module):
    def forward(self, x):
        return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))

# the main T2TViT class: tokens-to-token stages followed by a standard ViT backbone
class T2TViT(nn.Module):
    def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
        super().__init__()
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        layers = []
        layer_dim = channels
        output_image_size = image_size

        # walk through each (kernel_size, stride) tokens-to-token stage
        for i, (kernel_size, stride) in enumerate(t2t_layers):
            layer_dim *= kernel_size ** 2
            is_first = i == 0
            is_last = i == (len(t2t_layers) - 1)
            output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)

            # the first stage consumes the raw image (no rearrange); the last stage skips the intermediate transformer
            layers.extend([
                RearrangeImage() if not is_first else nn.Identity(),
                nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
                Rearrange('b c n -> b n c'),
                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
            ])

        layers.append(nn.Linear(layer_dim, dim))
        self.to_patch_embedding = nn.Sequential(*layers)

        # positional embedding and class token
        self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # build a transformer from depth / heads / mlp_dim unless one was passed in
        if not exists(transformer):
            assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied'
            self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        else:
            self.transformer = transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        # classification mlp head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :n+1]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
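
A minimal usage sketch for T2TViT above; these hyperparameters are illustrative (the default t2t_layers reduce a 224x224 image to a 14x14 grid of tokens over three stages).

import torch

v = T2TViT(
    image_size = 224,
    num_classes = 1000,
    dim = 512,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    t2t_layers = ((7, 4), (3, 2), (3, 2))   # (kernel_size, stride) of each tokens-to-token stage
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)   # (1, 1000)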

.\lucidrains\vit-pytorch\vit_pytorch\twins_svt.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper methods

# split a dict into (matching, non-matching) according to a key predicate
def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# pull out the kwargs that start with a prefix and strip that prefix from their keys
def group_by_key_prefix_and_remove_prefix(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# 类

# 残差连接
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# 层归一化
class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

# 前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Conv2d(dim, dim * mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# 图像块嵌入
class PatchEmbedding(nn.Module):
    def __init__(self, *, dim, dim_out, patch_size):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.patch_size = patch_size

        self.proj = nn.Sequential(
            LayerNorm(patch_size ** 2 * dim),
            nn.Conv2d(patch_size ** 2 * dim, dim_out, 1),
            LayerNorm(dim_out)
        )

    def forward(self, fmap):
        p = self.patch_size
        fmap = rearrange(fmap, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = p, p2 = p)
        return self.proj(fmap)

# positional encoding generator (PEG): a residual depthwise convolution that injects positional information
class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = Residual(nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1))

    def forward(self, x):
        return self.proj(x)

# 局部注意力
class LocalAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., patch_size = 7):
        super().__init__()
        inner_dim = dim_head *  heads
        self.patch_size = patch_size
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = LayerNorm(dim)
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, fmap):
        fmap = self.norm(fmap)

        shape, p = fmap.shape, self.patch_size
        b, n, x, y, h = *shape, self.heads
        x, y = map(lambda t: t // p, (x, y))

        fmap = rearrange(fmap, 'b c (x p1) (y p2) -> (b x y) c p1 p2', p1 = p, p2 = p)

        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) p1 p2 -> (b h) (p1 p2) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = dots.softmax(dim = - 1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b x y h) (p1 p2) d -> b (h d) (x p1) (y p2)', h = h, x = x, y = y, p1 = p, p2 = p)
        return self.to_out(out)

class GlobalAttention(nn.Module):
    # 初始化函数,设置注意力机制的参数
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7):
        # 调用父类的初始化函数
        super().__init__()
        # 计算内部维度
        inner_dim = dim_head *  heads
        # 设置头数和缩放因子
        self.heads = heads
        self.scale = dim_head ** -0.5

        # 归一化层
        self.norm = LayerNorm(dim)

        # 转换查询向量
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        # 转换键值对
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)

        # 丢弃部分数据
        self.dropout = nn.Dropout(dropout)

        # 输出层
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    # 前向传播函数
    def forward(self, x):
        # 对输入数据进行归一化
        x = self.norm(x)

        # 获取输入数据的形状
        shape = x.shape
        b, n, _, y, h = *shape, self.heads
        # 分别计算查询、键、值
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))

        # 重排查询、键、值的维度
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

        # 计算点积
        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        # 计算注意力分布
        attn = dots.softmax(dim = -1)
        attn = self.dropout(attn)

        # 计算输出
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)
# Transformer stage used by Twins-SVT: interleaves local window attention and global sub-sampled attention
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads = 8, dim_head = 64, mlp_mult = 4, local_patch_size = 7, global_k = 7, dropout = 0., has_local = True):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                # local window attention + feedforward (skipped when has_local is False, i.e. in the last stage)
                Residual(LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size)) if has_local else nn.Identity(),
                Residual(FeedForward(dim, mlp_mult, dropout = dropout)) if has_local else nn.Identity(),
                # global sub-sampled attention + feedforward
                Residual(GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k)),
                Residual(FeedForward(dim, mlp_mult, dropout = dropout))
            ]))

    def forward(self, x):
        for local_attn, ff1, global_attn, ff2 in self.layers:
            x = local_attn(x)
            x = ff1(x)
            x = global_attn(x)
            x = ff2(x)
        return x

# TwinsSVT: four stages of patch embedding, (local + global) attention transformers and PEG positional encoding
class TwinsSVT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        s1_emb_dim = 64,
        s1_patch_size = 4,
        s1_local_patch_size = 7,
        s1_global_k = 7,
        s1_depth = 1,
        s2_emb_dim = 128,
        s2_patch_size = 2,
        s2_local_patch_size = 7,
        s2_global_k = 7,
        s2_depth = 1,
        s3_emb_dim = 256,
        s3_patch_size = 2,
        s3_local_patch_size = 7,
        s3_global_k = 7,
        s3_depth = 5,
        s4_emb_dim = 512,
        s4_patch_size = 2,
        s4_local_patch_size = 7,
        s4_global_k = 7,
        s4_depth = 4,
        peg_kernel_size = 3,
        dropout = 0.
    ):
        super().__init__()
        # capture all constructor arguments so the per-stage configs can be pulled out by prefix
        kwargs = dict(locals())

        dim = 3  # input channels
        layers = []

        for prefix in ('s1', 's2', 's3', 's4'):
            # extract the arguments belonging to this stage and strip the prefix
            config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
            is_last = prefix == 's4'

            dim_next = config['emb_dim']

            layers.append(nn.Sequential(
                PatchEmbedding(dim = dim, dim_out = dim_next, patch_size = config['patch_size']),
                Transformer(dim = dim_next, depth = 1, local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last),
                PEG(dim = dim_next, kernel_size = peg_kernel_size),
                Transformer(dim = dim_next, depth = config['depth'], local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last)
            ))

            dim = dim_next

        self.layers = nn.Sequential(
            *layers,
            # global average pool followed by the classification head
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        return self.layers(x)
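
A minimal usage sketch for TwinsSVT above with all stage widths and depths left at their defaults; a 224x224 input keeps every stage's feature map divisible by the 7x7 local windows and the global key/value stride.

import torch

model = TwinsSVT(num_classes = 1000)    # all s1..s4 settings keep their defaults

img = torch.randn(1, 3, 224, 224)
preds = model(img)   # (1, 1000)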

.\lucidrains\vit-pytorch\vit_pytorch\vit.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 库中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 定义辅助函数 pair,用于返回元组形式的输入
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 定义 FeedForward 类,继承自 nn.Module 类
class FeedForward(nn.Module):
    # 初始化函数,接受维度 dim、隐藏层维度 hidden_dim 和 dropout 概率
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # 定义神经网络结构
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    # 前向传播函数
    def forward(self, x):
        return self.net(x)

# 定义 Attention 类,继承自 nn.Module 类
class Attention(nn.Module):
    # 初始化函数,接受维度 dim、头数 heads、头维度 dim_head 和 dropout 概率
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    # 前向传播函数
    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义 Transformer 类,继承自 nn.Module 类
class Transformer(nn.Module):
    # 初始化函数,接受维度 dim、深度 depth、头数 heads、头维度 dim_head、MLP 维度 mlp_dim 和 dropout 概率
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    # 前向传播函数
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

# 定义 ViT 类,继承自 nn.Module 类
class ViT(nn.Module):
    # 初始化函数,接受关键字参数 image_size、patch_size、num_classes、dim、depth、heads、mlp_dim、pool、channels、dim_head、dropout 和 emb_dropout
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)
    # 前向传播函数,接收输入图像并进行处理
    def forward(self, img):
        # 将输入图像转换为补丁嵌入
        x = self.to_patch_embedding(img)
        # 获取补丁嵌入的形状信息
        b, n, _ = x.shape

        # 为每个样本添加类别标记
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        # 将类别标记与补丁嵌入拼接在一起
        x = torch.cat((cls_tokens, x), dim=1)
        # 添加位置编码到输入
        x += self.pos_embedding[:, :(n + 1)]
        # 对输入进行 dropout 处理
        x = self.dropout(x)

        # 使用 Transformer 处理输入
        x = self.transformer(x)

        # 对 Transformer 输出进行池化操作,取平均值或者只取第一个位置的输出
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # 将输出转换为潜在空间
        x = self.to_latent(x)
        # 使用 MLP 头部处理最终输出
        return self.mlp_head(x)
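
A minimal usage sketch for the vanilla ViT above; these hyperparameters mirror the repository's README example.

import torch

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)   # (1, 1000)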

.\lucidrains\vit-pytorch\vit_pytorch\vit_1d.py

import torch
from torch import nn

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# 导入所需的库

# 定义 FeedForward 类
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),  # 对输入进行 Layer Normalization
            nn.Linear(dim, hidden_dim),  # 线性变换
            nn.GELU(),  # GELU 激活函数
            nn.Dropout(dropout),  # Dropout 正则化
            nn.Linear(hidden_dim, dim),  # 线性变换
            nn.Dropout(dropout)  # Dropout 正则化
        )
    def forward(self, x):
        return self.net(x)

# 定义 Attention 类
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)  # 对输入进行 Layer Normalization
        self.attend = nn.Softmax(dim = -1)  # Softmax 函数
        self.dropout = nn.Dropout(dropout)  # Dropout 正则化

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # 线性变换

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),  # 线性变换
            nn.Dropout(dropout)  # Dropout 正则化
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)  # 将输入进行线性变换后切分成三部分
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)  # 重排张量形状

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # 点积计算

        attn = self.attend(dots)  # 注意力权重计算
        attn = self.dropout(attn)  # Dropout 正则化

        out = torch.matmul(attn, v)  # 加权求和
        out = rearrange(out, 'b h n d -> b n (h d)')  # 重排张量形状
        return self.to_out(out)

# 定义 Transformer 类
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# ViT adapted for 1d sequences (e.g. time series)
class ViT(nn.Module):
    def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert (seq_len % patch_size) == 0

        num_patches = seq_len // patch_size
        patch_dim = channels * patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (n p) -> b n (p c)', p = patch_size),  # 重排张量形状
            nn.LayerNorm(patch_dim),  # 对输入进行 Layer Normalization
            nn.Linear(patch_dim, dim),  # 线性变换
            nn.LayerNorm(dim),  # 对输入进行 Layer Normalization
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # 位置编码
        self.cls_token = nn.Parameter(torch.randn(dim))  # 类别标记
        self.dropout = nn.Dropout(emb_dropout)  # Dropout 正则化

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)  # Transformer 模块

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),  # 对输入进行 Layer Normalization
            nn.Linear(dim, num_classes)  # 线性变换
        )

    def forward(self, series):
        x = self.to_patch_embedding(series)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)  # 重复类别标记

        x, ps = pack([cls_tokens, x], 'b * d')  # 打包张量

        x += self.pos_embedding[:, :(n + 1)]  # 加上位置编码
        x = self.dropout(x)  # Dropout 正则化

        x = self.transformer(x)  # Transformer 模块

        cls_tokens, _ = unpack(x, ps, 'b * d')  # 解包张量

        return self.mlp_head(cls_tokens)  # MLP 头部

if __name__ == '__main__':

    v = ViT(
        seq_len = 256,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    time_series = torch.randn(4, 3, 256)
    logits = v(time_series) # (4, 1000)