Lucidrains Series Source Code Walkthrough (34)

Lucidrains Series Source Code Walkthrough (Part 34)

.\lucidrains\pixel-level-contrastive-learning\pixel_level_contrastive_learning\__init__.py

# Import the PPM and PixelCL classes from the pixel_level_contrastive_learning.pixel_level_contrastive_learning module
from pixel_level_contrastive_learning.pixel_level_contrastive_learning import PPM, PixelCL

Pixel-level Contrastive Learning

Implementation of Pixel-level Contrastive Learning, proposed in the paper "Propagate Yourself", in Pytorch. In addition to doing contrastive learning on the pixel level, the online network further passes the pixel level representations to a Pixel Propagation Module and enforces a similarity loss to the target network. They beat all previous unsupervised and supervised methods in segmentation tasks.
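
To make the pixel propagation module concrete, here is a rough sketch of the propagation step (not the library's exact code; the transform argument stands in for the module's learned transform function): every pixel vector is compared with every other pixel in the feature map using a thresholded cosine similarity raised to a sharpening exponent gamma, and those similarities weight how much of each transformed pixel is propagated to each location.

import torch
import torch.nn.functional as F
from torch import nn

def pixel_propagation(x, transform, gamma = 2):
    # x: (batch, dim, height, width) pixel-level features from the online encoder
    b, d, h, w = x.shape
    flat = x.flatten(2)  # (batch, dim, pixels)
    # thresholded cosine similarity between all pairs of pixels, sharpened by gamma
    sim = F.relu(F.cosine_similarity(flat.unsqueeze(-1), flat.unsqueeze(-2), dim = 1)) ** gamma
    # propagate the transformed pixel vectors, weighted by similarity
    out = torch.einsum('b i j, b d j -> b d i', sim, transform(flat))
    return out.reshape(b, d, h, w)

# hypothetical 1x1 transform standing in for the module's transform function
transform = nn.Conv1d(256, 256, 1)
propagated = pixel_propagation(torch.randn(2, 256, 8, 8), transform)  # (2, 256, 8, 8)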

Install

$ pip install pixel-level-contrastive-learning

Usage

Below is an example of how you would use the framework to self-supervise training of a resnet, taking the output of layer 4 (8 x 8 'pixels').

import torch
from pixel_level_contrastive_learning import PixelCL
from torchvision import models
from tqdm import tqdm

resnet = models.resnet50(pretrained=True)

learner = PixelCL(
    resnet,
    image_size = 256,
    hidden_layer_pixel = 'layer4',  # leads to output of 8x8 feature map for pixel-level learning
    hidden_layer_instance = -2,     # leads to output for instance-level learning
    projection_size = 256,          # size of projection output, 256 was used in the paper
    projection_hidden_size = 2048,  # size of projection hidden dimension, paper used 2048
    moving_average_decay = 0.99,    # exponential moving average decay of target encoder
    ppm_num_layers = 1,             # number of layers for transform function in the pixel propagation module, 1 was optimal
    ppm_gamma = 2,                  # sharpness of the similarity in the pixel propagation module, already at optimal value of 2
    distance_thres = 0.7,           # ideal value is 0.7, as indicated in the paper, which makes the assumption of each feature map's pixel diagonal distance to be 1 (still unclear)
    similarity_temperature = 0.3,   # temperature for the cosine similarity for the pixel contrastive loss
    alpha = 1.,                      # weight of the pixel propagation loss (pixpro) vs pixel CL loss
    use_pixpro = True,               # do pixel pro instead of pixel contrast loss, defaults to pixpro, since it is the best one
    cutout_ratio_range = (0.6, 0.8)  # a random ratio is selected from this range for the random cutout
).cuda()

opt = torch.optim.Adam(learner.parameters(), lr=1e-4)

def sample_batch_images():
    return torch.randn(10, 3, 256, 256).cuda()

for _ in tqdm(range(100000)):
    images = sample_batch_images()
    loss = learner(images) # if positive pixel pairs is equal to zero, the loss is equal to the instance level loss

    opt.zero_grad()
    loss.backward()
    print(loss.item())
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# after much training, save the improved model for testing on downstream task
torch.save(resnet, 'improved-resnet.pt')

You can also return the number of positive pixel pairs on forward, for logging or other purposes

loss, positive_pairs = learner(images, return_positive_pairs = True)

Citations

@misc{xie2020propagate,
    title={Propagate Yourself: Exploring Pixel-Level Consistency for Unsupervised Visual Representation Learning}, 
    author={Zhenda Xie and Yutong Lin and Zheng Zhang and Yue Cao and Stephen Lin and Han Hu},
    year={2020},
    eprint={2011.10043},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

.\lucidrains\pixel-level-contrastive-learning\setup.py

# Import the setup and find_packages functions from setuptools
from setuptools import setup, find_packages

# Configure the package metadata
setup(
  # package name
  name = 'pixel-level-contrastive-learning',
  # find and include all packages
  packages = find_packages(),
  # version number
  version = '0.1.1',
  # license
  license='MIT',
  # description
  description = 'Pixel-Level Contrastive Learning',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project URL
  url = 'https://github.com/lucidrains/pixel-level-contrastive-learning',
  # keywords
  keywords = ['self-supervised learning', 'artificial intelligence'],
  # install dependencies
  install_requires=[
      'einops',
      'torch>=1.6',
      'kornia>=0.4.0'
  ],
  # classifiers
  classifiers=[
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\point-transformer-pytorch\point_transformer_pytorch\multihead_point_transformer_pytorch.py

# import the torch library
import torch
# import the nn and einsum modules from torch
from torch import nn, einsum
# import the repeat and rearrange functions from einops
from einops import repeat, rearrange

# helpers

# check whether a value exists
def exists(val):
    return val is not None

# get the maximum representable value for a tensor's dtype
def max_value(t):
    return torch.finfo(t.dtype).max

# select values along a dimension using batched indices
def batched_index_select(values, indices, dim = 1):
    # trailing dimensions of the values beyond the indexed dimension
    value_dims = values.shape[(dim + 1):]
    # shapes of the values and indices as lists
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    # expand the indices to match the trailing dimensions of the values
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)
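
# usage sketch (hypothetical shapes): for every one of 10 points in a batch of 2,
# gather the features of its 3 nearest neighbors
#   values  = torch.randn(2, 10, 64)                # (batch, points, dim)
#   indices = torch.randint(0, 10, (2, 10, 3))      # (batch, points, k) neighbor indices
#   batched_index_select(values, indices, dim = 1)  # -> (2, 10, 3, 64)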

# classes

# multi-head point transformer layer
class MultiheadPointTransformerLayer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 4,
        dim_head = 64,
        pos_mlp_hidden_dim = 64,
        attn_mlp_hidden_mult = 4,
        num_neighbors = None
    ):
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads

        self.num_neighbors = num_neighbors

        # linear projection of the input to queries, keys and values (three times the inner dimension)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # linear projection from the inner dimension back to the model dimension
        self.to_out = nn.Linear(inner_dim, dim)

        # position MLP
        self.pos_mlp = nn.Sequential(
            nn.Linear(3, pos_mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(pos_mlp_hidden_dim, inner_dim)
        )

        attn_inner_dim = inner_dim * attn_mlp_hidden_mult

        # attention MLP
        self.attn_mlp = nn.Sequential(
            nn.Conv2d(inner_dim, attn_inner_dim, 1, groups = heads),
            nn.ReLU(),
            nn.Conv2d(attn_inner_dim, inner_dim, 1, groups = heads),
        )
    # forward pass, taking the input x, the positions pos, and an optional mask
    def forward(self, x, pos, mask = None):
        # unpack the sequence length, number of heads, and number of neighbors
        n, h, num_neighbors = x.shape[1], self.heads, self.num_neighbors

        # get queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split queries, keys, values into heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # calculate relative positional embeddings
        rel_pos = rearrange(pos, 'b i c -> b i 1 c') - rearrange(pos, 'b j c -> b 1 j c')
        rel_pos_emb = self.pos_mlp(rel_pos)

        # split the relative positional embeddings into heads
        rel_pos_emb = rearrange(rel_pos_emb, 'b i j (h d) -> b h i j d', h = h)

        # use subtraction of queries from keys. this may be a better inductive bias for point clouds than the dot product
        qk_rel = rearrange(q, 'b h i d -> b h i 1 d') - rearrange(k, 'b h j d -> b h 1 j d')

        # prepare the mask
        if exists(mask):
            mask = rearrange(mask, 'b i -> b i 1') * rearrange(mask, 'b j -> b 1 j')

        # expand the values
        v = repeat(v, 'b h j d -> b h i j d', i = n)

        # if num_neighbors is specified, determine the k nearest neighbors for each point
        if exists(num_neighbors) and num_neighbors < n:
            rel_dist = rel_pos.norm(dim = -1)

            if exists(mask):
                mask_value = max_value(rel_dist)
                rel_dist.masked_fill_(~mask, mask_value)

            dist, indices = rel_dist.topk(num_neighbors, largest = False)

            indices_with_heads = repeat(indices, 'b i j -> b h i j', h = h)

            v = batched_index_select(v, indices_with_heads, dim = 3)
            qk_rel = batched_index_select(qk_rel, indices_with_heads, dim = 3)
            rel_pos_emb = batched_index_select(rel_pos_emb, indices_with_heads, dim = 3)

            if exists(mask):
                mask = batched_index_select(mask, indices, dim = 2)

        # add relative positional embeddings to the values
        v = v + rel_pos_emb

        # run through the attention MLP, making sure the relative positional embeddings were added first
        attn_mlp_input = qk_rel + rel_pos_emb
        attn_mlp_input = rearrange(attn_mlp_input, 'b h i j d -> b (h d) i j')

        sim = self.attn_mlp(attn_mlp_input)

        # masking
        if exists(mask):
            mask_value = -max_value(sim)
            mask = rearrange(mask, 'b i j -> b 1 i j')
            sim.masked_fill_(~mask, mask_value)

        # attention
        attn = sim.softmax(dim = -2)

        # aggregate
        v = rearrange(v, 'b h i j d -> b i j (h d)')
        agg = einsum('b d i j, b i j d -> b i d', attn, v)

        # combine heads and project out
        return self.to_out(agg)

.\lucidrains\point-transformer-pytorch\point_transformer_pytorch\point_transformer_pytorch.py

import torch
from torch import nn, einsum
from einops import repeat

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# get the maximum representable value for a tensor's dtype
def max_value(t):
    return torch.finfo(t.dtype).max

# select values along a dimension using batched indices
def batched_index_select(values, indices, dim = 1):
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)

# classes

class PointTransformerLayer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        pos_mlp_hidden_dim = 64,
        attn_mlp_hidden_mult = 4,
        num_neighbors = None
    ):
        super().__init__()
        self.num_neighbors = num_neighbors

        # linear projection of the input to queries, keys and values
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        # position MLP
        self.pos_mlp = nn.Sequential(
            nn.Linear(3, pos_mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(pos_mlp_hidden_dim, dim)
        )

        # attention MLP
        self.attn_mlp = nn.Sequential(
            nn.Linear(dim, dim * attn_mlp_hidden_mult),
            nn.ReLU(),
            nn.Linear(dim * attn_mlp_hidden_mult, dim),
        )

    def forward(self, x, pos, mask = None):
        n, num_neighbors = x.shape[1], self.num_neighbors

        # get queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # calculate relative positional embeddings
        rel_pos = pos[:, :, None, :] - pos[:, None, :, :]
        rel_pos_emb = self.pos_mlp(rel_pos)

        # use subtraction of queries from keys. the author believes this is a better inductive bias for point clouds than the dot product
        qk_rel = q[:, :, None, :] - k[:, None, :, :]

        # prepare the mask
        if exists(mask):
            mask = mask[:, :, None] * mask[:, None, :]

        # expand the values
        v = repeat(v, 'b j d -> b i j d', i = n)

        # if the number of nearest neighbors is set, determine the k nearest neighbors for each point
        if exists(num_neighbors) and num_neighbors < n:
            rel_dist = rel_pos.norm(dim = -1)

            if exists(mask):
                mask_value = max_value(rel_dist)
                rel_dist.masked_fill_(~mask, mask_value)

            dist, indices = rel_dist.topk(num_neighbors, largest = False)

            v = batched_index_select(v, indices, dim = 2)
            qk_rel = batched_index_select(qk_rel, indices, dim = 2)
            rel_pos_emb = batched_index_select(rel_pos_emb, indices, dim = 2)
            mask = batched_index_select(mask, indices, dim = 2) if exists(mask) else None

        # add relative positional embeddings to the values
        v = v + rel_pos_emb

        # run through the attention MLP, making sure the relative positional embeddings were added first
        sim = self.attn_mlp(qk_rel + rel_pos_emb)

        # masking
        if exists(mask):
            mask_value = -max_value(sim)
            sim.masked_fill_(~mask[..., None], mask_value)

        # attention
        attn = sim.softmax(dim = -2)

        # aggregate
        agg = einsum('b i j d, b i j d -> b i d', attn, v)
        return agg

.\lucidrains\point-transformer-pytorch\point_transformer_pytorch\__init__.py

# import the PointTransformerLayer class from the point_transformer_pytorch module
from point_transformer_pytorch.point_transformer_pytorch import PointTransformerLayer
# import the MultiheadPointTransformerLayer class from the multihead_point_transformer_pytorch module
from point_transformer_pytorch.multihead_point_transformer_pytorch import MultiheadPointTransformerLayer

Point Transformer - Pytorch

Implementation of the Point Transformer self-attention layer, in Pytorch. This simple vector-attention circuit seems to have allowed the authors to outperform all previous methods in point cloud classification and segmentation.

Install

$ pip install point-transformer-pytorch

Usage

import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4
)

feats = torch.randn(1, 16, 128)
pos = torch.randn(1, 16, 3)
mask = torch.ones(1, 16).bool()

attn(feats, pos, mask = mask) # (1, 16, 128)

This type of vector attention is much more expensive than the traditional one. In the paper, they used k-nearest neighbors on the points to exclude attention on faraway points. You can do the same with a single extra setting.

import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16          # only the 16 nearest neighbors would be attended to for each point
)

feats = torch.randn(1, 2048, 128)
pos = torch.randn(1, 2048, 3)
mask = torch.ones(1, 2048).bool()

attn(feats, pos, mask = mask) # (1, 2048, 128)

Citations

@misc{zhao2020point,
    title={Point Transformer}, 
    author={Hengshuang Zhao and Li Jiang and Jiaya Jia and Philip Torr and Vladlen Koltun},
    year={2020},
    eprint={2012.09164},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

.\lucidrains\point-transformer-pytorch\setup.py

# import the setup and find_packages functions from setuptools
from setuptools import setup, find_packages

# configure the package metadata
setup(
  name = 'point-transformer-pytorch',  # package name
  packages = find_packages(),  # find all packages
  version = '0.1.5',  # version number
  license='MIT',  # license
  description = 'Point Transformer - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/point-transformer-pytorch',  # project URL
  keywords = [  # keywords
    'artificial intelligence',
    'transformers',
    'attention mechanism',
    'point clouds'
  ],
  install_requires=[  # install dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\ponder-transformer\ponder_transformer\ponder_transformer.py

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# constants

ABS_MAX_STEPS = 100

# helper functions

def exists(val):
    return val is not None

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

def FeedForward(dim, mult = 4):
    return nn.Sequential(
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Linear(dim * mult, dim)
    )

class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        causal = False
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, mask = None):
        n, h, device = x.shape[1], self.heads, x.device
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
            sim = sim.masked_fill(~mask, mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), device = device).triu(j - i + 1).bool()
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# pondering classes and helper functions

def pad_to(t, padding, dim = -1, value = 0.):
    if dim > 0:
        dim = dim - t.ndim
    zeroes = -dim - 1
    return F.pad(t, (*((0, 0) * zeroes), *padding), value = value)

def safe_cumprod(t, eps = 1e-10, dim = -1):
    t = torch.clip(t, min = eps, max = 1.)
    return torch.exp(torch.cumsum(torch.log(t), dim = dim))

def exclusive_cumprod(t, dim = -1):
    cum_prod = safe_cumprod(t, dim = dim)
    return pad_to(cum_prod, (1, -1), value = 1., dim = dim)

def calc_geometric(l, dim = -1):
    return exclusive_cumprod(1 - l, dim = dim) * l
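
# sketch of what calc_geometric returns: for a constant halting probability lambda_p = 0.2,
#   calc_geometric(torch.full((5,), 0.2))  # -> tensor([0.2000, 0.1600, 0.1280, 0.1024, 0.0819])
# i.e. the probability of halting exactly at step n, (1 - lambda_p)^(n-1) * lambda_p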

# main classes

class Block(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        causal = False,
        ff_mult = 4
    ):
        super().__init__()
        self.causal = causal
        self.attn = PreNorm(dim, Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal))
        self.ff = PreNorm(dim, FeedForward(dim = dim, mult = ff_mult))

        self.to_halt_logits = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... () -> ...')
        )

    def forward(self, x, mask = None):
        x = self.attn(x, mask = mask) + x
        x = self.ff(x) + x

        if self.causal:
            denom = torch.arange(x.shape[-2], device = x.device)
            denom = rearrange(denom, 'n -> () n ()')
            halt_input = x.cumsum(dim = 1) / (denom + 1)
        else:
            halt_input = x.mean(dim = 1)

        halt_logits = self.to_halt_logits(halt_input)

        return x, halt_logits

class PonderTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_seq_len,
        causal = True,
        dim_head = 64,
        heads = 8,
        ponder_kl_div_loss_weight = 0.01,
        ponder_lambda_p = 0.2,
        ponder_epsilon = 0.05,
        eps = 1e-20
        ):
        # call the parent constructor
        super().__init__()
        # epsilon for numerical stability
        self.eps = eps
        # whether the transformer is causal
        self.causal = causal
        # maximum sequence length
        self.seq_len = max_seq_len
        # token embedding, mapping token ids to vectors of the given dimension
        self.token_emb = nn.Embedding(num_tokens, dim)
        # positional embedding, mapping positions to vectors of the given dimension
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # calculate the maximum number of pondering steps

        # threshold on the cumulative halting probability
        thres = 1 - ponder_epsilon
        # geometric halting probabilities
        halt_probs = calc_geometric(torch.full((ABS_MAX_STEPS,), ponder_lambda_p))
        # cumulative sum of the halting probabilities
        cum_halt_probs = halt_probs.cumsum(dim = 0)
        # the maximum number of training steps is the number of steps whose cumulative halting probability stays below the threshold
        self.train_max_steps = (cum_halt_probs < thres).sum().item()

        # lambda parameter of the geometric prior
        self.ponder_lambda_p = ponder_lambda_p
        # weight on the KL divergence between the halting distribution and the geometric prior
        self.ponder_kl_div_loss_weight = ponder_kl_div_loss_weight

        # pondering block

        # the single block that is applied repeatedly while pondering
        self.block = Block(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            causal = causal
        )

        # hidden state to 'Y' - output

        # output head, layernorm followed by a linear projection to the vocabulary
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

.\lucidrains\ponder-transformer\ponder_transformer\__init__.py

# import the PonderTransformer class from the ponder_transformer.ponder_transformer module
from ponder_transformer.ponder_transformer import PonderTransformer

Ponder(ing) Transformer

Implementation of a Transformer that learns to adapt the number of computational steps it takes depending on the difficulty of the input sequence, using the scheme from the PonderNet paper. Will also try to abstract out a pondering module that can be used with any block that returns an output with the halting probability.
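
As a rough sketch of the scheme (not this repository's exact code; the function and argument names here are hypothetical), each pondering step produces a prediction loss and a halting probability, and the training loss is the expected prediction loss under the halting distribution plus a KL term that pulls that distribution toward a geometric prior:

import torch
import torch.nn.functional as F

def ponder_loss(step_losses, halt_probs, lambda_p = 0.2, kl_weight = 0.01):
    # step_losses: (steps,) task loss of the prediction made at each pondering step
    # halt_probs:  (steps,) probability of halting at each step, summing to ~1
    steps = halt_probs.shape[0]
    # geometric prior over the halting step, with parameter lambda_p
    prior = (1 - lambda_p) ** torch.arange(steps, dtype = torch.float) * lambda_p
    prior = prior / prior.sum()
    # expected task loss under the halting distribution
    expected_loss = (halt_probs * step_losses).sum()
    # KL divergence from the halting distribution to the geometric prior (regularizer)
    kl = F.kl_div(prior.log(), halt_probs, reduction = 'sum')
    return expected_loss + kl_weight * kl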

This repository would not have been possible without repeated viewings of Yannic's educational video

Install

$ pip install ponder-transformer

Usage

import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512
)

mask = torch.ones(1, 512).bool()

x = torch.randint(0, 20000, (1, 512))
y = torch.randint(0, 20000, (1, 512))

loss = model(x, labels = y, mask = mask)
loss.backward()

Now you can set the model to .eval() mode and it will terminate early when all samples of the batch have emitted a halting signal

import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512,
    causal = True
)

x = torch.randint(0, 20000, (2, 512))
mask = torch.ones(2, 512).bool()

model.eval() # setting to eval makes it return the logits as well as the halting indices

logits, layer_indices = model(x,  mask = mask) # (2, 512, 20000), (2)

# layer indices will contain, for each batch element, which layer they exited

Citations

@misc{banino2021pondernet,
    title   = {PonderNet: Learning to Ponder}, 
    author  = {Andrea Banino and Jan Balaguer and Charles Blundell},
    year    = {2021},
    eprint  = {2107.05407},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\ponder-transformer\setup.py

# import the setup and find_packages functions from setuptools
from setuptools import setup, find_packages

# configure the package metadata
setup(
  name = 'ponder-transformer',  # package name
  packages = find_packages(),  # find all packages
  version = '0.0.8',  # version number
  license='MIT',  # license
  description = 'Ponder Transformer - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/ponder-transformer',  # project URL
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'adaptive computation time'
  ],
  install_requires=[  # install dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\product-key-memory\product_key_memory\product_key_memory.py

# import the math and torch libraries
import math
import torch
# import the nn and einsum modules from torch
from torch import nn, einsum
# import rearrange from einops, and the Rearrange and Reduce layers
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
# import the topk function from the colt5_attention package as coor_descent_topk
from colt5_attention import topk as coor_descent_topk

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise return the default
def default(val, d):
    return val if exists(val) else d

# natural log, with the input clamped to avoid values below a small epsilon
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# generate gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# initialization helpers

# initialize a tensor with mean 0 and standard deviation 1 / sqrt(dim)
def init_(t, dim = None):
    dim = default(dim, t.shape[-1])
    std = 1. / math.sqrt(dim)
    return nn.init.normal_(t, mean=0, std=std)

# optimizer helpers

# subtract the elements of list r from list l
def list_subtract(l, r):
    return [el for el in l if el not in set(r)]

# gather the value parameters of every PKM module within a model
def fetch_pkm_value_parameters(module):
    params = []
    for m in module.modules():
        if isinstance(m, PKM):
            params.append(m.values.weight)
    rest = list_subtract(module.parameters(), params)
    return params, rest

# build optimizer parameter groups, giving the PKM value parameters their own learning rate
def fetch_optimizer_parameters(module, pkm_learning_rate = 1e-2):
    pkm_params, rest = fetch_pkm_value_parameters(module)
    return [{'params': rest}, {'params': pkm_params, 'lr': pkm_learning_rate}]

# normalization

# masked 1d batchnorm wrapper
class MaskedBatchNorm1D(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(
        self,
        x,
        mask = None
    ):
        if exists(mask):
            initial_x = x
            x = x[mask]

        x = self.fn(x)

        if exists(mask):
            initial_x[mask] = x
            x = initial_x

        return x

# PKM module
class PKM(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        num_keys = 128,
        topk = 32,
        dim_head = 128,
        input_dropout = 0.,
        query_dropout = 0.,
        value_dropout = 0.,
        attn_dropout = 0.,
        use_layernorm = True,
        pre_layernorm = False,
        differentiable_topk = False,
        concat_values_and_combine = False,
        norm_output = False
    ):
        super().__init__()
        self.topk = topk
        self.heads = heads
        self.num_keys = num_keys

        dim_query = dim_head * heads * 2
        self.to_queries = nn.Linear(dim, dim_query, bias = False)

        # pre-layernorm pattern

        self.pre_layernorm = nn.LayerNorm(dim) if pre_layernorm else nn.Identity()

        # batchnorm would break causality

        self.use_layernorm = use_layernorm

        if use_layernorm:
            self.norm = nn.LayerNorm(dim_head)
        else:
            self.norm = MaskedBatchNorm1D(nn.BatchNorm1d(dim_head))

        # keys

        self.keys = nn.Parameter(torch.zeros(heads, num_keys, 2, dim_head))
        init_(self.keys)

        # values

        self.concat_values_and_combine = concat_values_and_combine

        if concat_values_and_combine:
            values = nn.Embedding(num_keys ** 2, dim_head)

            self.values = nn.Sequential(
                values,
                Reduce('b (h k) d -> b h d', 'sum', h = heads),
                Rearrange('b n d -> b (n d)'),
                nn.Linear(dim_head * heads, dim, bias = False)
            )
        else:
            values = nn.EmbeddingBag(num_keys ** 2, dim, mode = 'sum')
            self.values = values

        init_(values.weight)

        # dropouts

        self.input_dropout = nn.Dropout(input_dropout)
        self.query_dropout = nn.Dropout(query_dropout)
        self.value_dropout = nn.Dropout(value_dropout)
        self.attn_dropout = nn.Dropout(attn_dropout)

        # use a differentiable topk, based on coordinate descent

        self.differentiable_topk = differentiable_topk

        # https://arxiv.org/abs/2302.06461
        # claims to improve softmax key/value networks by simply layernorming the output

        self.output_norm = nn.LayerNorm(dim) if norm_output else nn.Identity()

    def forward(
        self,
        x,
        input_mask = None,
        gumbel_noise_scale = 0.,
        **kwargs
        ):
        # unpack the batch size, sequence length and number of heads
        b, t, h = *x.shape[:2], self.heads

        # pre-layernorm the input
        x = self.pre_layernorm(x)
        # input dropout
        x = self.input_dropout(x)

        # project the input to queries
        queries = self.to_queries(x)

        # split out the query heads

        queries = rearrange(queries, 'b t (p h d) -> (b p h) t d', p = 2, h = h)

        # normalize and dropout the queries

        norm_kwargs = dict(mask = input_mask) if not self.use_layernorm else dict()
        queries = self.norm(queries, **norm_kwargs)
        queries = self.query_dropout(queries)

        # prepare the queries

        queries = rearrange(queries, '(b p h) t d -> p b t h d', p = 2, h = h)

        # similarity to the keys

        dots = einsum('p b t h d, h n p d -> b t h p n', queries, self.keys)

        # gumbel noise

        if gumbel_noise_scale > 0.:
            dots = dots + gumbel_noise(dots) * gumbel_noise_scale

        # topk scores

        if self.differentiable_topk:
            scores, indices, *_ = coor_descent_topk(dots, k = self.topk, fused = True)
        else:
            scores, indices = dots.topk(k = self.topk, dim = -1)

        # factorize the scores along the two product-key halves

        (scores_x, scores_y), (indices_x, indices_y) = map(lambda t: t.chunk(2, dim = 3), (scores, indices))

        all_topk = self.topk ** 2

        all_scores = rearrange((
            rearrange(scores_x, '... k -> ... k 1') +
            rearrange(scores_y, '... k -> ... 1 k')
        ), 'b t h ... -> b t h (...)')

        all_indices = rearrange((
            rearrange(indices_x, '... k -> ... k 1') * self.num_keys +
            rearrange(indices_y, '... k -> ... 1 k')
        ), 'b t h ... -> b t h (...)')

        final_topk, final_indices = all_scores.topk(self.topk, dim=-1)
        value_indices = all_indices.gather(-1, final_indices)

        # attention

        attn = final_topk.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        value_indices, attn = map(lambda t: rearrange(t, 'b t h k -> (b t) (h k)'), (value_indices, attn))

        # aggregate

        if self.concat_values_and_combine:
            out = self.values(value_indices)
        else:
            out = self.values(value_indices, per_sample_weights = attn)

        out = self.value_dropout(out)

        # maybe layernorm the output

        out = self.output_norm(out)

        return rearrange(out, '(b t) d -> b t d', b = b)

.\lucidrains\product-key-memory\product_key_memory\transformer.py

# import required libraries
import json

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

from product_key_memory.product_key_memory import PKM

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise return the default
def default(val, d):
    return val if exists(val) else d

# sampling helpers

# decorator that runs a function with the model in eval mode, restoring the previous training mode afterwards
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# keep only the top-k logits, filling the rest with a large negative value
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
    probs.scatter_(1, ind, val)
    return probs
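
# usage sketch: with thres = 0.9 only the top 10% of the vocabulary is kept,
# everything else is set to a large negative value before sampling
#   logits = torch.randn(1, 256)
#   filtered = top_k(logits, thres = 0.9)  # (1, 256), all but the top 25 entries masked out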

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        dim_inner = heads * dim_head

        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(self, x):
        n, h, device = x.shape[1], self.heads, x.device

        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', h = h, qkv = 3)

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# feedforward
def FeedForward(dim, mult = 4):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Linear(dim * mult, dim)
    )

# Transformer

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        seq_len,
        pkm_layers = None,
        dim_head = 64,
        heads = 8,
        pad_value = 0,
        pkm_kwargs: dict = dict()
    ):
        super().__init__()
        self.seq_len = seq_len
        self.pad_value = pad_value

        pkm_layers = default(pkm_layers, depth // 2)
        pkm_layers = (pkm_layers,) if not isinstance(pkm_layers, tuple) else pkm_layers
        pkm_layers = set(pkm_layers)

        if len(pkm_layers) > 0:
            print(f'using PKM at layers {pkm_layers}')
            print(json.dumps(pkm_kwargs, indent = 2))
            print('\n\n')

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)

        self.layers = nn.ModuleList([])

        for ind in range(depth):
            layer = ind + 1
            use_pkm = layer in pkm_layers

            self.layers.append(nn.ModuleList([
                Attention(dim, dim_head = dim_head, heads = heads),
                FeedForward(dim) if not use_pkm else PKM(dim, **pkm_kwargs)
            ]))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    # generation function for sampling sequences from a prompt
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prompt,
        seq_len,
        temperature = 1.0,
        filter_thres = 0.9
    ):
        b, n, device = *prompt.shape, prompt.device

        out = prompt

        for _ in range(seq_len):
            logits = self.forward(out[:, -self.seq_len:], return_loss = False)
            logits = logits[:, -1]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim = -1)

            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim = -1)

        return out[:, n:]
    # forward pass, taking the input x and a flag indicating whether to return the loss
    def forward(self, x, return_loss = True):

        # if a loss should be returned, use the input shifted by one position as the labels
        if return_loss:
            x, labels = x[:, :-1], x[:, 1:]

        # token embedding
        x = self.token_emb(x)
        # add positional embeddings
        x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))

        # run through each attention and feedforward (or PKM) layer
        for attn, ff in self.layers:
            # attention with residual
            x = attn(x) + x
            # feedforward (or PKM) with residual
            x = ff(x) + x

        # project to logits
        logits = self.to_logits(x)

        # if no loss is needed, return the logits directly
        if not return_loss:
            return logits

        # move the class dimension next to the batch dimension, as expected by cross entropy
        logits = rearrange(logits, 'b n c -> b c n')

        # compute and return the cross entropy loss
        return F.cross_entropy(logits, labels)
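
# usage sketch (hypothetical hyperparameters):
#   model = Transformer(num_tokens = 256, dim = 512, depth = 4, seq_len = 1024, pkm_layers = (2,))
#   seq = torch.randint(0, 256, (1, 1025))
#   loss = model(seq)                           # next-token cross entropy over the sequence
#   loss.backward()
#   sampled = model.generate(seq[:, :128], 64)  # (1, 64) sampled continuation of the prompt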

.\lucidrains\product-key-memory\product_key_memory\__init__.py

# import PKM and the fetch_pkm_value_parameters and fetch_optimizer_parameters helpers from the product_key_memory module
from product_key_memory.product_key_memory import PKM, fetch_pkm_value_parameters, fetch_optimizer_parameters

# alias PKM as ProductKeyMemory for convenience
ProductKeyMemory = PKM

Product Key Memory

Standalone Product Key Memory module for augmenting Transformer models

Install

$ pip install product-key-memory

Usage

Replace the feedforwards in a Transformer with the following

import torch
from product_key_memory import PKM

pkm = PKM(
    dim = 512,
    heads = 4,
    dim_head = 128,       # keep at 128 for best results
    num_keys = 256,       # number of subkeys, # values will be num_keys ^ 2
    topk = 32             # the top number of subkeys to select
)

x = torch.randn(1, 1024, 512)
mask = torch.ones((1, 1024)).bool()
values = pkm(x, input_mask = mask) # (1, 1024, 512)

Learning Rates

To give different learning rates to the value parameters of the product-key-memory network, use the following helper function.

from torch.optim import Adam
from product_key_memory import fetch_pkm_value_parameters

# this helper function, for your root model, finds all the PKM models and the embedding bag weight parameters
pkm_parameters, other_parameters = fetch_pkm_value_parameters(model)

optim = Adam([
    {'params': other_parameters},
    {'params': pkm_parameters, 'lr': 1e-2}
], lr=1e-3)

Or, if product-key-memory parameters are the only other parameters you have a different learning rate for

from torch.optim import Adam
from product_key_memory import fetch_optimizer_parameters

parameters = fetch_optimizer_parameters(model) # automatically creates array of parameter settings with learning rate set at 1e-2 for pkm values
optim = Adam(parameters, lr=1e-3)

Appreciation

Special thanks go to Aran for encouraging me to look into this, and to Madison May for his educational blog post, which helped me understand this better.

Todo

Citations

@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@misc{liu2020evolving,
    title   = {Evolving Normalization-Activation Layers},
    author  = {Hanxiao Liu and Andrew Brock and Karen Simonyan and Quoc V. Le},
    year    = {2020},
    eprint  = {2004.02967},
    archivePrefix = {arXiv}
}
@article{Shen2023ASO,
    title   = {A Study on ReLU and Softmax in Transformer},
    author  = {Kai Shen and Junliang Guo and Xuejiao Tan and Siliang Tang and Rui Wang and Jiang Bian},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2302.06461},
    url     = {https://api.semanticscholar.org/CorpusID:256827573}
}

.\lucidrains\product-key-memory\setup.py

# import the setup and find_packages functions from setuptools
from setuptools import setup, find_packages

# configure the package metadata
setup(
    name = 'product_key_memory',  # package name
    packages = find_packages(),  # find all packages
    version = '0.2.10',  # version number
    license = 'MIT',  # license
    description = 'Product Key Memory',  # description
    long_description_content_type = 'text/markdown',  # long description content type
    author = 'Aran Komatsuzaki, Phil Wang',  # authors
    author_email = 'aran1234321@gmail.com, lucidrains@gmail.com',  # author emails
    url = 'https://github.com/lucidrains/product-key-memory',  # project URL
    keywords = [
        'transformers',  # keyword: transformers
        'artificial intelligence'  # keyword: artificial intelligence
    ],
    install_requires=[
        'colt5-attention>=0.10.14',  # required dependency
        'einops>=0.6',  # required dependency
        'torch'  # required dependency
    ],
    classifiers=[
        'Development Status :: 4 - Beta',  # classifier: development status is Beta
        'Intended Audience :: Developers',  # classifier: intended audience is developers
        'Topic :: Scientific/Engineering :: Artificial Intelligence',  # classifier: scientific/engineering, AI
        'License :: OSI Approved :: MIT License',  # classifier: MIT license
        'Programming Language :: Python :: 3.6',  # classifier: Python 3.6
    ],
)

.\lucidrains\product-key-memory\train.py

# import required libraries
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from product_key_memory.transformer import Transformer

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# helper functions

# decode a single token into a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens into a string
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# instantiate the Transformer model
model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    seq_len = SEQ_LEN,
    pkm_layers = (4,),
    pkm_kwargs = dict(
        heads = 4,
        num_keys = 128,
        topk = 32,
        dim_head = 128,
        input_dropout = 0.,
        query_dropout = 0.,
        value_dropout = 0.,
        attn_dropout = 0.,
        use_layernorm = True,
        pre_layernorm = True,
        differentiable_topk = False,
        concat_values_and_combine = False
    )
).cuda()

# prepare enwik8 data

# read the enwik8 data and split it into train and validation portions
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# custom dataset that samples random fixed-length subsequences
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# helper that endlessly cycles through a dataloader (needed below, since the loaders are consumed with next())
def cycle(loader):
    while True:
        for data in loader:
            yield data

# create the training and validation DataLoaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print("%s \n\n %s" % (prime, "*" * 100))

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str, "\n")

.\lucidrains\progen\generate_data.py

# import required libraries
import os
import gzip
import click
import re
import random
from math import ceil
from functools import partial
from itertools import islice, chain
from operator import itemgetter

from pyfaidx import Faidx

import numpy as np
from random import random
from pathlib import Path

import toml
from google.cloud import storage

from prefect import Parameter, task, Flow

from progen_transformer.data import with_tfrecord_writer
from progen_transformer.utils import clear_directory_

# constants
GCS_WRITE_TIMEOUT = 60 * 30
TMP_DIR = Path('./.tmp')

# functions

# order a dictionary's entries using the given key-ordering function
def order_dict_by(d, fn):
    keys = fn(d.keys())
    return dict(tuple(map(lambda k: (k, d[k]), keys)))

# extract annotations from a fasta description
def get_annotations_from_description(config, description):
    taxonomy_matches = re.findall(r'Tax=([a-zA-Z\s]*)\s[a-zA-Z\=]', description)
    annotations = dict()

    if len(taxonomy_matches) > 0:
        annotations['tax'] = taxonomy_matches[0]

    return annotations

# convert a fasta entry into sequence strings
def fasta_row_to_sequence_strings(config, fa, uid):
    seq_len = fa.index[uid].rlen
    seq = str(fa.fetch(uid, 1, seq_len))
    description = fa.get_long_name(uid)

    sequences = []
    annotations = get_annotations_from_description(config, description)
    # todo: gather annotations from GO

    if len(annotations) > 0:
        sort_annot_by = random.shuffle if not config['sort_annotations'] else sorted
        annotations = order_dict_by(annotations, sort_annot_by)

        annotation_str = [f"[{annot_name}={annot}]" for annot_name, annot in annotations.items()]
        annotation_str = ' '.join(annotation_str)

        seq_annot_pair = (annotation_str, seq)

        if random() <= config['prob_invert_seq_annotation']:
            seq_annot_pair = tuple(reversed(seq_annot_pair))

        sequence = ' # '.join(seq_annot_pair)
        sequence = sequence.encode('utf-8')
        sequences.append(sequence)

    sequence = f'# {seq}'
    sequence = sequence.encode('utf-8')
    sequences.append(sequence)

    return sequences

# write a sequence string to a gzip-compressed temporary file
def process_and_write_to_tmp_file(i, seq_str):
    filename = TMP_DIR / str(i)
    with gzip.open(str(filename), 'wb') as f:
        f.write(seq_str)

# apply a function to every element of an iterable
def foreach(fn, it):
    for el in it:
        fn(*el)

# DAG functions

# convert the fasta file into temporary files
@task
def fasta_to_tmp_files(config):
    clear_directory_(TMP_DIR)

    print('reading from fasta')
    fa = Faidx(config['read_from'], sequence_always_upper = True)

    print('filtering by length')
    it = iter(fa.index.items())
    it = filter(lambda el: el[1].rlen <= config['max_seq_len'], it)

    print('parallel processing to tmp files')
    it = islice(it, 0, config['num_samples'])
    it = map(itemgetter(0), it)

    fasta_to_seq_fn = partial(fasta_row_to_sequence_strings, config, fa)
    it = map(fasta_to_seq_fn, it)
    it = enumerate(chain.from_iterable(it))
    foreach(process_and_write_to_tmp_file, it)

# convert the temporary files into tfrecords
@task
def files_to_tfrecords(config):
    filenames = [*TMP_DIR.glob('**/*')]
    num_samples = len(filenames)
    num_valids = ceil(config['fraction_valid_data'] * num_samples)

    num_sequences_per_file = config['num_sequences_per_file']

    # split out the validation sequences

    permuted_sequences = np.random.permutation(num_samples)
    valid_seqs, train_seqs = np.split(permuted_sequences, [num_valids])

    # clear the write directory

    write_to = config['write_to']
    upload_gcs = write_to.startswith('gs://')

    if upload_gcs:
        write_to = write_to[5:]
        client = storage.Client()
        bucket_name = write_to

        bucket = client.get_bucket(bucket_name)
        bucket.delete_blobs(list(bucket.list_blobs()))

    write_to_path = Path(write_to)
    clear_directory_(write_to_path)

    # loop over the train and valid splits and write all of their sequences out as tfrecords
    for (seq_type, seqs) in (('train', train_seqs), ('valid', valid_seqs)):
        # number of files needed for this split
        num_split = ceil(seqs.shape[0] / num_sequences_per_file)
        # split the sequences so that each file holds num_sequences_per_file of them
        for file_index, indices in enumerate(np.array_split(seqs, num_split)):
            # number of sequences in this file
            num_sequences = len(indices)
            # the tfrecord filename encodes the file index, sequence count and split type
            tfrecord_filename = f'{file_index}.{num_sequences}.{seq_type}.tfrecord.gz'
            # full path to the tfrecord file
            tfrecord_path = str(write_to_path / tfrecord_filename)

            # open a tfrecord writer and write out the sequences
            with with_tfrecord_writer(tfrecord_path) as write:
                # iterate over the sequence indices for this file
                for index in indices:
                    # temporary file holding this sequence
                    filename = filenames[index]
                    # read the gzip-compressed sequence and write it into the tfrecord
                    with gzip.open(filename, 'rb') as f:
                        write(f.read())

            # upload to Google Cloud Storage if requested
            if upload_gcs:
                # create a blob for the tfrecord
                blob = bucket.blob(tfrecord_filename)
                # upload the tfrecord from the local file, with a timeout
                blob.upload_from_filename(tfrecord_path, timeout = GCS_WRITE_TIMEOUT)
# create a Flow object named 'parse-fasta'
with Flow('parse-fasta') as flow:
    # declare a required 'config' parameter
    config = Parameter('config', required = True)
    # run fasta_to_tmp_files with the config
    fasta_to_tmp_files(config = config)
    # run files_to_tfrecords with the config
    files_to_tfrecords(config = config)
@click.command()
# command line option 'data_dir', defaulting to './configs/data'
@click.option('--data_dir', default = './configs/data')
# command line option 'name', defaulting to 'default'
@click.option('--name', default = 'default')
def main(
    data_dir,
    name
):
    # convert data_dir to a Path object
    data_dir = Path(data_dir)
    # build the path to the config file
    config_path = data_dir / f'{name}.toml'
    # make sure the config file exists
    assert config_path.exists(), f'config does not exist at {str(config_path)}'

    # read and parse the toml config into a dict
    config = toml.loads(config_path.read_text())
    # run the Flow, passing in the config
    flow.run(config = config)

# run main if this script is executed directly
if __name__ == '__main__':
    main()

.\lucidrains\progen\progen_transformer\checkpoint.py

# import required modules
import time
import os, errno
from pathlib import Path
from functools import partial
# import the Google Cloud Storage client
from google.cloud import storage
from cloudpickle import pickle
from progen_transformer.utils import clear_directory_, silentremove

# filesystem checkpoint functions

# reset the filesystem checkpoints
def file_reset_checkpoint(path):
    clear_directory_(path)

# get the latest checkpoint from the filesystem
def file_get_last_checkpoint(path):
    checkpoints = sorted(path.glob('**/ckpt_*'))
    if len(checkpoints) == 0:
        return None

    with open(str(checkpoints[-1]), 'rb') as f:
        package = pickle.load(f)

    return package

# save a checkpoint to the filesystem
def file_save_checkpoint(path, package, keep_last_n = None):
    unix_time = int(time.time())
    checkpoints = sorted(path.glob('**/ckpt_*'))
    num_checkpoints = len(checkpoints)

    with open(str(path / f'ckpt_{unix_time}.pkl'), 'wb') as f:
        pickle.dump(package, f)

    if keep_last_n is None:
        return

    for path_to_rm in checkpoints[:max(0, num_checkpoints - keep_last_n)]:
        silentremove(path_to_rm)

# Google Cloud Storage checkpoint functions

GCS_READ_TIMEOUT = 60 * 30
GCS_WRITE_TIMEOUT = 60 * 30

# reset the checkpoints in Google Cloud Storage
def gcs_reset_checkpoint(bucket):
    bucket.delete_blobs(list(bucket.list_blobs()))

# get the latest checkpoint from Google Cloud Storage
def gcs_get_last_checkpoint(bucket):
    blobs = sorted(list(bucket.list_blobs()))

    if len(blobs) == 0:
        return None

    last_checkpoint = blobs[-1]

    filename = f'/tmp/{last_checkpoint.name}'
    with open(filename, 'wb') as f:
        last_checkpoint.download_to_file(f, timeout = GCS_READ_TIMEOUT)

    with open(filename, 'rb') as f:
        package = pickle.load(f)

    return package

# save a checkpoint to Google Cloud Storage
def gcs_save_checkpoint(bucket, package, keep_last_n = None):
    unix_time = int(time.time())
    blobs = sorted(list(bucket.list_blobs()))
    num_checkpoints = len(blobs)

    filename = f'ckpt_{unix_time}.pkl'
    tmp_path = f'/tmp/{filename}'

    with open(tmp_path, 'wb') as f:
        pickle.dump(package, f)

    blob = bucket.blob(filename)
    blob.upload_from_filename(tmp_path, timeout = GCS_WRITE_TIMEOUT)

    if keep_last_n is None:
        return

    bucket.delete_blobs(blobs[:max(0, num_checkpoints - keep_last_n)])

# factory function

# return the reset / get-last / save checkpoint functions appropriate for the given path
def get_checkpoint_fns(path):
    # determine whether Google Cloud Storage should be used
    use_gcs = path.startswith('gs://')

    if not use_gcs:
        obj = Path(path)
        obj.mkdir(exist_ok = True, parents = True)

        fns = (
            file_reset_checkpoint,
            file_get_last_checkpoint,
            file_save_checkpoint
        )
    else:
        client = storage.Client()
        bucket_name = path[5:]
        obj = client.get_bucket(bucket_name)

        fns = (
            gcs_reset_checkpoint,
            gcs_get_last_checkpoint,
            gcs_save_checkpoint
        )

    # bind the path or bucket object to each function and return the tuple
    fns = tuple(map(lambda fn: partial(fn, obj), fns))
    return fns
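
# usage sketch (local path shown; a 'gs://bucket-name' path works the same way):
#   reset, get_last, save = get_checkpoint_fns('./checkpoints')
#   save({'step': 100}, keep_last_n = 3)   # writes ./checkpoints/ckpt_<unix_time>.pkl
#   package = get_last()                   # loads the most recent checkpoint, or None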

.\lucidrains\progen\progen_transformer\data.py

# import required libraries
import tensorflow as tf
import numpy as np
from functools import partial
from pathlib import Path
from contextlib import contextmanager

# writing tfrecords

# write a value into the tfrecord file
def write(writer, values):
    # serialize the value into an Example protobuf byte string
    record_bytes = tf.train.Example(features = tf.train.Features(feature={
        'seq': tf.train.Feature(bytes_list = tf.train.BytesList(value=[values]))
    })).SerializeToString()

    # write the byte string into the tfrecord file
    writer.write(record_bytes)

# context manager that yields a tfrecord write function
@contextmanager
def with_tfrecord_writer(path):
    # configure the TFRecordWriter to use GZIP compression
    options = tf.io.TFRecordOptions(compression_type = 'GZIP')

    # create the TFRecordWriter
    with tf.io.TFRecordWriter(path, options = options) as writer:
        # yield a partial of the write function bound to this writer
        yield partial(write, writer)
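
# usage sketch (hypothetical path and sequences): write two byte strings into a gzip-compressed tfrecord
#   with with_tfrecord_writer('/tmp/0.2.train.tfrecord.gz') as write:
#       write(b'# MKTAYIAKQR')
#       write(b'[tax=Escherichia] # MSDNE')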

# reading tfrecords

# parse a single tfrecord example
def parse_fn(sample):
    return tf.io.parse_single_example(sample, {
        'seq': tf.io.FixedLenFeature([], tf.string)
    })

# collate a batch of byte strings into a padded integer array
def collate_fn(batch, pad_length, offset = 0):
    # convert the byte strings into numpy arrays
    tensors = [np.frombuffer(el, dtype = np.uint8).astype(np.uint16) for el in batch.numpy()]
    tensors = map(lambda t: t[..., :pad_length], tensors)
    tensors = map(lambda t: t + offset, tensors)
    padded_tensors = map(lambda t: np.pad(t, (0, pad_length - t.shape[-1])), tensors)
    return np.stack(list(padded_tensors))
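
# usage sketch: two byte sequences padded to length 8, with every token offset by +1
# so that 0 is reserved for padding
#   batch = tf.constant([b'abc', b'abcdef'])
#   collate_fn(batch, pad_length = 8, offset = 1)  # -> uint16 array of shape (2, 8)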

# create an iterator over a folder of tfrecords
def iterator_from_tfrecords_folder(folder, data_type = 'train'):
    # determine whether the folder is a GCS path
    is_gcs_path = folder.startswith('gs://')

    # gather the tfrecord filenames for the given data type
    if is_gcs_path:
        filenames = tf.io.gfile.glob(f'{folder}/*.{data_type}.tfrecord.gz')
    else:
        folder = Path(folder)
        filenames = [str(p) for p in folder.glob(f'**/*.{data_type}.tfrecord.gz')]

    # total number of sequences, recovered from the filenames
    num_seqs = sum(map(lambda t: int(t.split('.')[-4]), filenames))

    # define the iterator function
    def iter_fn(
        seq_len,
        batch_size,
        skip = 0,
        loop = False
    ):
        # create the TFRecordDataset
        dataset = tf.data.TFRecordDataset(filenames, compression_type = 'GZIP')

        # skip the requested number of samples
        dataset = dataset.skip(skip)
        dataset = dataset.map(parse_fn)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        # repeat the dataset if looping is requested
        if loop:
            dataset = dataset.repeat()

        # iterate over the dataset, collate each batch and yield it with a prepended <bos> token
        for batch in dataset:
            seq = batch['seq']
            batch_size = seq.shape[0]
            seq = collate_fn(seq, pad_length = seq_len, offset = 1)
            bos = np.zeros((batch_size, 1), dtype = np.uint16)
            seq = np.concatenate((bos, seq), axis = 1)
            yield seq

    return num_seqs, iter_fn

# tokenization

# encode a single character token
def encode_token(token):
    return ord(token) + 1

# decode a single token back to a character
def decode_token(token):
    if token < 0:
        return ''
    return str(chr(token))

# encode a sequence of tokens
def encode_tokens(tokens):
    return list(map(encode_token, tokens))

# decode a sequence of tokens
def decode_tokens(tokens, offset = 1):
    return ''.join(list(map(decode_token, tokens.astype(np.int16) - offset)))
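
# usage sketch: encoding shifts every character code up by one, so 0 can serve as padding / <bos>
#   encode_tokens('MKT')                               # -> [78, 76, 85]
#   decode_tokens(np.array([78, 76, 85]), offset = 1)  # -> 'MKT'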

.\lucidrains\progen\progen_transformer\progen.py

# import required libraries
from functools import partial

import jax
from jax import random
from jax import nn
from jax.lax import stop_gradient
import jax.numpy as np
import jmp

import haiku as hk
from haiku import initializers
from einops import rearrange, repeat

from progen_transformer.utils import exists

# constants

ATTN_MASK_VALUE = -1e10

# helper functions

# partially applied LayerNorm constructor
LayerNorm = partial(hk.LayerNorm, create_scale = True, create_offset = False, axis = -1)

# generate fixed (rotary) positional embeddings
def fixed_pos_embedding(seq, dim):
    # inverse frequencies
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    # outer product of positions and inverse frequencies
    sinusoid_inp = np.einsum("i , j -> i j", np.arange(seq), inv_freq)
    sinusoid_inp = repeat(sinusoid_inp, "b n -> b (n r)", r = 2)[None, :, :]
    return np.sin(sinusoid_inp), np.cos(sinusoid_inp)

# rotate every two elements of the last dimension
def rotate_every_two(x):
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x[..., 0], x[..., 1]
    x = np.stack((-x2, x1), axis = -1)
    return rearrange(x, "... d r -> ... (d r)")

# apply rotary positional embeddings
def apply_rotary_pos_emb(x, sincos):
    sin, cos = sincos
    rot_dim = sin.shape[-1]
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    x = (x * cos) + (rotate_every_two(x) * sin)
    return np.concatenate((x, x_pass), axis = -1)

# token shift - shift half of the feature dimensions forward by one position
def shift_tokens(x):
    x_shift, x_pass = np.array_split(x, 2, axis = -1)
    x_shift = np.pad(x_shift, ((1, 0), (0, 0)), mode = 'constant')[:-1]
    return np.concatenate((x_shift, x_pass), axis = -1)
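
# usage sketch (hypothetical shapes): rotate the queries of 8 heads, sequence length 16, head dim 64
#   sincos = fixed_pos_embedding(16, 64)   # sin / cos tables, each of shape (1, 16, 64)
#   q = np.ones((8, 16, 64))
#   apply_rotary_pos_emb(q, sincos)        # same shape, positions mixed into each head dimension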

# classes

# local attention
class LocalAttention(hk.Module):
    def __init__(
        self,
        *,
        name,
        dim,
        window_size,
        heads = 8,
        dim_head = 64,
        shift_tokens = True
    ):
        super().__init__(name = name)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.window_size = window_size
        inner_dim = dim_head * heads

        self.norm = LayerNorm()
        self.shift_tokens = shift_tokens

        self.to_qkv = hk.Linear(inner_dim * 3, with_bias = False)
        self.to_out = hk.Linear(dim)

    def __call__(self, x, *, pos_emb):
        x = self.norm(x)

        if self.shift_tokens:
            x = shift_tokens(x)

        n, h, wsz = x.shape[0], self.heads, self.window_size
        assert (n % wsz) == 0, 'sequence length must be divisible by the window size'
        window = n // wsz

        qkv = self.to_qkv(x)
        q, k, v = np.split(qkv, 3, axis = -1)
        q, k, v = map(lambda t: rearrange(t, 'n (h d) -> h n d', h = h), (q, k, v))

        q, k, v = map(lambda t: apply_rotary_pos_emb(t, pos_emb), (q, k, v))
        q, k, v = map(lambda t: rearrange(t, 'h (w n) d -> h w n d', w = window), (q, k, v))

        k, v = map(lambda t: np.pad(t, ((0, 0), (1, 0), (0, 0), (0, 0)), constant_values = 0.), (k ,v))
        k, v = map(lambda t: np.concatenate((t[:, :-1], t[:, 1:]), axis = 2), (k, v))

        sim = np.einsum('h w i d, h w j d -> h w i j', q, k) * self.scale

        mask = np.tril(np.ones((wsz, wsz * 2)), wsz)
        sim = np.where(mask, sim, ATTN_MASK_VALUE)

        sim = sim - stop_gradient(np.amax(sim, axis = -1, keepdims = True))
        attn = nn.softmax(sim, axis = -1)

        out = np.einsum('h w i j, h w j d -> h w i d', attn, v)
        out = rearrange(out, 'h w n d -> (w n) (h d)')
        return self.to_out(out)

# 前馈神经网络
class FeedForward(hk.Module):
    def __init__(
        self,
        *,
        name,
        dim,
        ff_mult = 4,
        glu = False,
        seq_len = None,
        spatial_gate = False,
        shift_tokens = True
    ):
        super().__init__(name = name)
        assert not (glu and spatial_gate), 'glu and sgu cannot be turned on at the same time'
        hidden_dim = dim * ff_mult
        hidden_dim *= (1 if not glu else 2)

        self.norm = LayerNorm()
        self.shift_tokens = shift_tokens

        self.proj_in = hk.Linear(hidden_dim)
        self.proj_out = hk.Linear(dim)

        self.glu = glu
        self.sgu = SGU(dim = hidden_dim, dim_out = hidden_dim // 2, seq_len = seq_len) if spatial_gate else None
    # 定义一个类的调用方法,接受输入 x
    def __call__(self, x):
        # 对输入 x 进行归一化处理
        x = self.norm(x)

        # 如果需要进行移位操作
        if self.shift_tokens:
            # 对 x 进行移位操作
            x = shift_tokens(x)

        # 对 x 进行投影操作
        x = self.proj_in(x)

        # 如果使用门控线性单元(GLU)
        if self.glu:
            # 将 x 拆分成两部分,分别为 x 和门控信号 gate
            x, gate = np.split(x, 2, axis=-1)
            # 对 x 进行门控线性单元激活函数处理
            x *= nn.gelu(gate)
        else:
            # 对 x 进行门控线性单元激活函数处理
            x = nn.gelu(x)

        # 如果存在自定义的门控单元(SGU)
        if exists(self.sgu):
            # 对 x 进行自定义门控单元处理
            x = self.sgu(x)

        # 对 x 进行输出投影操作
        x = self.proj_out(x)
        # 返回处理后的 x
        return x
# 定义 SGU 类,继承自 hk.Module
class SGU(hk.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        dim,
        dim_out,
        seq_len,
        eps = 1e-3
    ):
        super().__init__()
        self.eps = eps
        self.seq_len = seq_len
        self.norm = LayerNorm()
        self.proj_out = hk.Linear(dim_out)

    # 调用函数
    def __call__(self, x):
        n = self.seq_len
        # 将输入 x 沿着最后一个轴分割成两部分
        x, gate = np.split(x, 2, axis = -1)

        # 对 gate 进行归一化
        gate = self.norm(gate)

        # 初始化缩放值
        init_scale = self.eps / n
        # 初始化随机均匀分布
        init_eps = initializers.RandomUniform(minval = -init_scale, maxval = init_scale)

        # 获取参数 weights 和 biases
        weights = hk.get_parameter('spatial_weights', shape = (n, n), init = init_eps)
        biases = hk.get_parameter('spatial_biases', shape = (n, 1), init = np.ones)

        # 生成一个下三角矩阵 mask
        mask = np.tril(np.ones((n, n)))
        weights = weights * mask

        # 使用矩阵乘法计算 gate
        gate = np.einsum('n d, m n -> m d', gate, weights)
        gate += biases

        # 对输入 x 进行门控
        x = x * gate
        return self.proj_out(x)

# 定义 ProGenBase 类,继承自 hk.Module
class ProGenBase(hk.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        seq_len,
        depth,
        window_size = 256,
        global_mlp_depth = 2,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        ff_glu = True,
        attn_dim = None,
        clamp_gate = True,
        shift_tokens = True
    ):
        super().__init__()
        self.dim_head = dim_head
        self.embed = hk.Embed(num_tokens, dim)

        self.layers = []
        # 循环创建 depth 个层
        for i in range(depth):
            use_gmlp = (depth - i) <= global_mlp_depth
            use_ff_glu = not use_gmlp and ff_glu

            # 添加 LocalAttention 和 FeedForward 层到 layers 列表
            self.layers.append([
                LocalAttention(name = f'attn{i}', dim = dim, window_size = window_size, heads = heads, dim_head = dim_head, shift_tokens = shift_tokens),
                FeedForward(name = f'ff{i}', dim = dim, ff_mult = ff_mult, seq_len = seq_len, spatial_gate = use_gmlp, glu = use_ff_glu, shift_tokens = shift_tokens)
            ])

        # 定义输出层
        self.to_logits = hk.Sequential([
            LayerNorm(),
            hk.Linear(num_tokens)
        ])

    # 调用函数
    def __call__(self, x):
        n = x.shape[0]
        x = self.embed(x)
        rotary_emb = fixed_pos_embedding(n, self.dim_head)

        # 循环遍历每个层并进行操作
        for attn, ff in self.layers:
            x += attn(x, pos_emb = rotary_emb)
            x += ff(x)

        return self.to_logits(x)

# 定义 ProGen 函数
def ProGen(mixed_precision = False, mixed_precision_policy = dict(params = 'float32', compute = 'float16', output = 'float32'), **kwargs):
    # 使用 hk.transform 装饰器
    @hk.transform
    def inner(seq):
        if mixed_precision:
            serialized_policy = ','.join([f'{k}={v}' for k, v in mixed_precision_policy.items()])
            policy = jmp.get_policy(serialized_policy)
            hk.mixed_precision.set_policy(ProGenBase, policy)
        return ProGenBase(**kwargs)(seq)
    return inner

.\lucidrains\progen\progen_transformer\utils.py

# 从 math 模块中导入 ceil 函数
from math import ceil
# 导入 os 和 errno 模块
import os, errno
# 从 shutil 模块中导入 rmtree 函数
from shutil import rmtree

# 导入 jax 库
import jax
# 从 jax 库中导入 random, nn, value_and_grad, vmap, pmap, jit, lax 模块
from jax import random, nn, value_and_grad, vmap, pmap, jit, lax
# 从 jax.numpy 模块中导入 np 别名
import jax.numpy as np

# 从 einops 模块中导入 rearrange 函数

from einops import rearrange

# 辅助函数

# 定义一个空操作函数
def noop(x):
    return x

# 判断值是否存在的函数
def exists(val):
    return val is not None

# 计算对数的函数
def log(t, eps = 1e-20):
    return np.log(t + eps)

# 确认函数
def confirm(question):
    while True:
        resp = input(f'{question} (y/n) ')
        lower_resp = resp.lower()
        if lower_resp in ('y', 'n'):
            return lower_resp == 'y'

# 清空目录的函数
def clear_directory_(path):
    rmtree(str(path), ignore_errors = True)
    path.mkdir(exist_ok = True, parents = True)

# 安静删除文件的函数
def silentremove(filename):
    try:
        os.remove(filename)
    except OSError:
        pass

# 训练函数

# 计算带掩码的均值的函数
def masked_mean(t, mask, axis = None):
    return (t * mask).sum(axis = axis) / mask.sum(axis = axis)

# 交叉熵损失函数
def cross_entropy(logits, targets, axis = -1, ignore_index = 0):
    logprobs = nn.log_softmax(logits, axis = axis)

    nll = np.take_along_axis(logprobs, np.expand_dims(targets, axis = axis), axis = axis)
    nll = nll.squeeze(-1)

    # 为损失创建掩码,以便从第一个填充标记中学习
    # 填充标记被重用作字符串结束标记,以简化
    mask = (targets != ignore_index)
    eos_mask = (~mask).cumsum(axis = -1) == 1
    mask = mask | eos_mask

    ce = -masked_mean(nll, mask, axis = -1)
    return ce
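
下面的数值小例子说明该损失如何通过掩码,把第一个填充标记(作为 EOS)之后的位置排除在损失之外(假设性示例,非原文件内容):

# 带掩码交叉熵的最小示例(假设性示例)
import jax.numpy as np

logits = np.zeros((4, 5))             # 4 个位置,5 个类别,均匀 logits
targets = np.array([3, 2, 0, 0])      # 从第三个位置起为填充标记 0

loss = cross_entropy(logits, targets) # 只有前两个位置与第一个 0(EOS)参与损失
assert np.allclose(loss, np.log(5.))  # 均匀分布下每个位置的负对数似然为 log(5)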

# 获取损失函数
def get_loss_fn(model, data_parallel = False):
    def loss_fn(params, key, data):
        ids, labels = data[:-1], data[1:]
        logits = model.apply(params, key, ids)
        return cross_entropy(logits, labels, axis = -1)

    loss_fn = jit(vmap(loss_fn, in_axes = (None, None, 0), out_axes = 0))

    if data_parallel:
        loss_fn = pmap(loss_fn, in_axes = (None, None, 0), out_axes = 0)

    @value_and_grad
    def batched_loss_fn(params, key, data):
        if not data_parallel:
            values = loss_fn(params, key, data)
            return np.mean(values)

        mask = np.ones((data.shape[0],))

        device_count = jax.local_device_count()
        batch_size = data.shape[0]

        remainder = (batch_size % device_count)
        if remainder != 0:
            padding = device_count - remainder
            data = np.pad(data, ((0, padding), (0, 0)))
            mask = np.pad(mask, ((0, padding)))

        data, mask = map(lambda t: rearrange(t, '(p b) ... -> p b ...', p = device_count), (data, mask))
        values = loss_fn(params, key, data)
        return masked_mean(values, mask)

    return batched_loss_fn

# 采样函数

# 选择前 k 个值的函数
def select_top_k(tensor, k):
    values, _ = lax.top_k(tensor, k)  # 使用上方已导入的 lax 模块中的 top_k
    mask = tensor > values.min()
    return mask, np.where(mask, tensor, 0.)

# 生成 Gumbel 噪声的函数
def gumbel_noise(rng, shape):
    noise = random.uniform(rng, shape = shape, minval = 0., maxval = 1.)
    return -log(-log(noise))
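
结合 select_top_k 与 gumbel_noise 的单步采样示例:先将 logits 限制在前 k 个位置,再加上 Gumbel 噪声取 argmax(Gumbel-max 技巧),等价于在保留位置上按 softmax 概率随机采样(假设性示例,非原文件内容):

# 单步 top-k Gumbel 采样示例(假设性示例)
from jax import random
import jax.numpy as np

key = random.PRNGKey(0)
logits = np.array([0.1, 2.0, -1.0, 3.0, 0.5])

mask, topk_logits = select_top_k(logits, k = 2)  # 只保留严格大于第 k 大值的 logits,其余置 0
noise = gumbel_noise(key, logits.shape) * mask   # Gumbel 噪声同样只作用在保留位置上

sampled_index = np.argmax(topk_logits + noise)   # Gumbel-max:在保留位置上按概率采样
print(int(sampled_index))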

# 采样函数
def sample(rng, fn, params, prime, length, top_k = None, add_bos = False):
    start_pos = prime.shape[-1]
    pad_right = length - prime.shape[-1]

    padding = (0, pad_right) if not add_bos else (1, pad_right - 1)
    seq = np.pad(prime, padding)

    one_hots = np.eye(length, dtype = int)

    for curr_pos in range(start_pos, length):
        logits = fn(params, next(rng), seq)
        logits = logits[curr_pos - 1]

        noise = gumbel_noise(next(rng), logits.shape)

        if exists(top_k):
            mask, logits = select_top_k(logits, top_k)
            noise *= mask

        logits += noise
        sampled_ind = np.argmax(logits, axis = -1)

        one_hot = one_hots[curr_pos]
        seq += one_hot * sampled_ind

    # 目前,将第二个填充标记(eos)后的所有内容设置为填充
    remove_after_eos_mask = (seq == 0).cumsum(axis = -1) > 1
    seq *= ~remove_after_eos_mask

    return seq

# RNG 修复

# 硬件均匀分布函数
def hardware_uniform(
    rng_key,
    shape,
    dtype = np.float32,
    minval = np.float32(0),
    maxval = np.float32(1)
):
    del rng_key
    # 将最小值转换为指定数据类型
    minval = lax.convert_element_type(minval, dtype)
    # 将最大值转换为指定数据类型
    maxval = lax.convert_element_type(maxval, dtype)
    # 返回一个形状为 shape 的在 [minval, maxval) 范围内均匀分布的随机数
    return lax.rng_uniform(minval, maxval, shape)
# 定义一个硬件实现的伯努利分布函数,接受随机数生成器密钥、概率和形状参数
def hardware_bernoulli(rng_key, p = np.float32(0.5), shape = None):
    # 删除随机数生成器密钥参数
    del rng_key
    # 返回一个布尔数组,表示是否小于给定概率 p
    return lax.rng_uniform(0.0, 1.0, shape) < p

# 设置 JAX 库中的随机数生成器函数为硬件实现的伯努利分布函数
def set_hardware_rng_(jax):
    # 将 JAX 库中的伯努利分布函数替换为硬件实现的伯努利分布函数
    jax.random.bernoulli = hardware_bernoulli
    # 将 JAX 库中的均匀分布函数替换为硬件实现的均匀分布函数
    jax.random.uniform = hardware_uniform
    # 将 JAX 库中的源码中的均匀分布函数替换为硬件实现的均匀分布函数
    jax._src.random.uniform = hardware_uniform

.\lucidrains\progen\progen_transformer\__init__.py

# 从 progen_transformer.progen 模块中导入 ProGen 类
from progen_transformer.progen import ProGen

ProGen - (wip)

Implementation and replication of ProGen, Language Modeling for Protein Generation, in Pytorch and Jax (the weights will be made easily transferable between the two). You can think of this as GPT for protein sequences.

Requirements

We are going to use Poetry for managing the dependencies for this project. So first install it using the one-liner bash command.

Next, git clone the project and install the dependencies

$ git clone git@github.com:lucidrains/progen
$ cd progen
$ poetry install

For training on GPUs, you may need to rerun pip install with the correct CUDA version. You can follow the instructions here

# ex. CUDA 11.1
$ pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

When running any of the scripts, you'll notice that the commands are always prefixed with poetry run

Usage

from jax import random
from haiku import PRNGSequence
from progen_transformer import ProGen

model = ProGen(
    num_tokens = 256,
    dim = 512,
    seq_len = 1024,
    window_size = 256,       # local attention window size
    depth = 12,              # depth
    heads = 8,               # attention heads
    dim_head = 64,           # dimension per head
    ff_glu = True,           # use GLU in feedforward, from Noam's paper
    global_mlp_depth = 2     # last N global gmlp layers
)

rng = PRNGSequence(42)
seq = random.randint(next(rng), (1024,), 0, 256)

params = model.init(next(rng), seq)
logits = model.apply(params, next(rng), seq) # (1024, 256)

Training

Download Uniref50 from UniProt and place uniref50.fasta in the root directory

$ poetry run python generate_data.py

You should see a lot of green if everything succeeds. Then

$ poetry run python train.py

By default, the script will checkpoint and resume automatically, but if you wish to clear your progress and restart, just add a --new flag

$ poetry run python train.py --new

Model checkpoints will be saved periodically to ./ckpts

Finally, to sample from your checkpoint, just do

$ poetry run python sample.py

You can pass a prime with --prime. You can either pass the annotations, followed by #, to get the generated sequence, or pass the sequence (also followed by #) and get the generated annotations

$ poetry run python sample.py --prime "[Tax=Mammalia] #"

Mixed Precision

To use mixed precision training, you'll need to install the latest Haiku with the following command

$ pip install git+https://github.com/deepmind/dm-haiku

Then make sure to set the --mixed_precision flag when invoking the training script

$ poetry run python train.py --mixed_precision

Todo

Acknowledgements

Many thanks go out to Ben Wang, who showed that this type of large-scale training can be achieved with GPT-J

Citations

@misc{madani2020progen,
    title   = {ProGen: Language Modeling for Protein Generation}, 
    author  = {Ali Madani and Bryan McCann and Nikhil Naik and Nitish Shirish Keskar and Namrata Anand and Raphael R. Eguchi and Po-Ssu Huang and Richard Socher},
    year    = {2020},
    eprint  = {2004.03497},
    archivePrefix = {arXiv},
    primaryClass = {q-bio.BM}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}

.\lucidrains\progen\sample.py

# 导入 load_dotenv 函数,用于加载环境变量
from dotenv import load_dotenv
# 调用 load_dotenv 函数加载环境变量
load_dotenv()

# 导入 click 模块,用于创建命令行接口
import click
# 导入 humanize 模块,用于处理人类可读的数据格式
import humanize

# 导入 jax 模块及其子模块
import jax
from jax import nn, random, jit, tree_util, numpy as np

# 导入 haiku 模块中的 PRNGSequence 类
from haiku import PRNGSequence

# 导入 progen_transformer 模块及其子模块
from progen_transformer import ProGen
from progen_transformer.data import decode_tokens, encode_tokens
from progen_transformer.utils import sample, set_hardware_rng_
from progen_transformer.checkpoint import get_checkpoint_fns

# 调用 set_hardware_rng_ 函数,加速随机数生成器
set_hardware_rng_(jax)

# 定义主函数
@click.command()
# 定义命令行参数
@click.option('--seed', default = 42)
@click.option('--checkpoint_path', default = './ckpts')
@click.option('--prime', default = '')
def main(
    seed,
    checkpoint_path,
    prime,
):
    # 准备文件夹

    # 获取最后一个检查点
    _, get_last_checkpoint, _ = get_checkpoint_fns(checkpoint_path)
    last_checkpoint = get_last_checkpoint()

    # 如果没有找到最后一个检查点,则退出程序
    if last_checkpoint is None:
        exit(f'no checkpoints found at {checkpoint_path}')

    # 获取参数和序列数
    params = last_checkpoint['params']
    num_seqs = max(last_checkpoint['next_seq_index'], 0)

    # 设置模型和参数
    model_kwargs = last_checkpoint['model_config']
    model = ProGen(**model_kwargs)
    model_apply = jit(model.apply)
    rng = PRNGSequence(seed)

    # 初始化所有状态,或从检查点加载

    seq_len = model_kwargs['seq_len']
    num_params = tree_util.tree_reduce(lambda acc, el: acc + el.size, params, 0)
    num_params_readable = humanize.naturalsize(num_params)

    # 打印参数、序列长度和训练序列数
    print(f'params: {num_params_readable}')
    print(f'sequence length: {seq_len}')
    print(f'trained for {num_seqs} sequences')

    # 使用 prime 进行采样
    prime_tokens = encode_tokens(prime)
    prime_length = len(prime_tokens) + 1
    prime_tensor = np.array(prime_tokens, dtype = np.uint16)

    sampled = sample(rng, jit(model_apply), params, prime_tensor, seq_len, top_k = 25, add_bos = True)
    sampled_str = decode_tokens(sampled[prime_length:])

    # 打印采样结果
    print("\n", prime, "\n", "*" * 40, "\n", sampled_str)

# 如果当前脚本被直接执行,则调用主函数
if __name__ == '__main__':
    main()

.\lucidrains\progen\train.py

# 导入 load_dotenv 函数,用于加载环境变量
from dotenv import load_dotenv
# 调用 load_dotenv 函数加载环境变量
load_dotenv()

# 导入 click、humanize、Template、Path、tqdm、numpy 等模块
import click
import humanize
from jinja2 import Template
from pathlib import Path
import tqdm
import numpy as np

# 导入 toml 模块
import toml

# 导入 jax 相关模块和函数
import jax
from jax import nn, random, jit, tree_util, tree_map
from optax import adamw, clip_by_global_norm, chain, apply_updates, apply_every

# 导入 haiku 模块中的 PRNGSequence 类
from haiku import PRNGSequence

# 导入 progen_transformer 模块及其子模块
from progen_transformer import ProGen
from progen_transformer.data import decode_tokens, iterator_from_tfrecords_folder
from progen_transformer.utils import sample, get_loss_fn, set_hardware_rng_, confirm, exists
from progen_transformer.checkpoint import get_checkpoint_fns

# 导入 wandb 模块
import wandb

# 创建模板对象 sample_tmpl,用于生成 HTML 样式
sample_tmpl = Template("""<i>{{prime_str}}</i><br/><br/><div style="overflow-wrap: break-word;">{{sampled_str}}</div>""")

# 设置硬件随机数生成器
set_hardware_rng_(jax)

# 主函数定义,接收多个命令行参数
@click.command()
@click.option('--seed', default = 42)
@click.option('--batch_size', default = 4)
@click.option('--grad_accum_every', default = 4)
@click.option('--learning_rate', default = 2e-4)
@click.option('--weight_decay', default = 1e-3)
@click.option('--data_parallel', default = False, is_flag = True)
@click.option('--max_grad_norm', default = 0.5)
@click.option('--validate_every', default = 100)
@click.option('--sample_every', default = 500)
@click.option('--checkpoint_every', default = 1000)
@click.option('--checkpoint_path', default = './ckpts')
@click.option('--checkpoint_keep_n', default = 500)
@click.option('--config_path', default = './configs/model')
@click.option('--model_name', default = 'default')
@click.option('--prime_length', default = 25)
@click.option('--seq_len', default = 1024)
@click.option('--mixed_precision', default = False, is_flag = True)
@click.option('--data_path', default = './train_data')
@click.option('--wandb_off', default = False, is_flag = True)
@click.option('--wandb_project_name', default = 'progen-training')
@click.option('--new', default = False, is_flag = True)
def main(
    seed,
    batch_size,
    grad_accum_every,
    learning_rate,
    weight_decay,
    data_parallel,
    max_grad_norm,
    validate_every,
    sample_every,
    checkpoint_every,
    checkpoint_path,
    checkpoint_keep_n,
    config_path,
    model_name,
    prime_length,
    seq_len,
    mixed_precision,
    data_path,
    wandb_off,
    wandb_project_name,
    new
):
    # 准备文件夹

    # 获取重置、获取最新、保存检查点的函数
    reset_checkpoint, get_last_checkpoint, save_checkpoint = get_checkpoint_fns(checkpoint_path)

    # 如果设置了 new 参数,清除所有检查点并重新开始训练
    if new:
        if not confirm('are you sure you want to clear all your checkpoints and restart training?'):
            exit()
        reset_checkpoint()

    # 初始化所有状态,或从检查点加载

    # 获取最新的检查点
    last_checkpoint = get_last_checkpoint()

    # 如果最新的检查点不存在
    if not exists(last_checkpoint):
        # 获取模型配置文件路径
        config_folder_path = Path(config_path)
        config_path = config_folder_path / f'{model_name}.toml'
        # 检查模型配置文件是否存在
        assert config_path.exists(), f'path to your model config {str(config_path)} does not exist'
        # 加载模型参数
        model_kwargs = toml.loads(config_path.read_text())
    else:
        # 使用最新的检查点中的模型配置
        model_kwargs = last_checkpoint['model_config']

    # 设置模型和参数

    # 创建 ProGen 模型实例
    model = ProGen(**{
        **model_kwargs,
        'mixed_precision': mixed_precision
    })

    # 编译模型应用函数
    model_apply = jit(model.apply)
    # 创建随机数生成器
    rng = PRNGSequence(seed)
    # 获取损失函数
    loss_fn = get_loss_fn(model, data_parallel = data_parallel)

    # 优化器

    # 定义排除规范和偏置参数的函数
    exclude_norm_and_bias_params = lambda p: tree_map(lambda x: x.ndim > 1, p)

    # 构建优化器链
    optim = chain(
        clip_by_global_norm(max_grad_norm),
        adamw(learning_rate, weight_decay = weight_decay, mask = exclude_norm_and_bias_params),
        apply_every(grad_accum_every)
    )

    # 获取参数和优化器状态

    if exists(last_checkpoint):
        params = last_checkpoint['params']
        optim_state = last_checkpoint['optim_state']
        start_seq_index = last_checkpoint['next_seq_index']
    else:
        # 没有检查点时(首次训练),创建一个全零数组作为模拟数据用于初始化
        mock_data = np.zeros((model_kwargs['seq_len'],), dtype = np.uint8)
        # 使用模拟数据初始化模型参数
        params = model.init(next(rng), mock_data)
        # 使用初始化的参数初始化优化器状态
        optim_state = optim.init(params)
        # 设置起始序列索引为0
        start_seq_index = 0

    # 实验追踪器

    # 获取模型序列长度
    seq_len = model_kwargs['seq_len']
    # 计算模型参数的数量
    num_params = tree_util.tree_reduce(lambda acc, el: acc + el.size, params, 0)
    # 将参数数量转换为可读的格式
    num_params_readable = humanize.naturalsize(num_params)

    # 设置wandb配置中的参数数量
    wandb.config.num_params = num_params

    # 根据wandb_off参数决定是否禁用wandb
    wandb_kwargs = {'mode': 'disabled'} if wandb_off else {}

    # 如果存在上次的检查点信息,则恢复运行ID和恢复模式
    if exists(last_checkpoint) and exists(last_checkpoint['run_id']):
        run_id = last_checkpoint['run_id']
        wandb_kwargs = {**wandb_kwargs, 'id': run_id, 'resume': 'allow'}

    # 初始化wandb
    wandb.init(project = wandb_project_name, **wandb_kwargs)
    wandb_run_id = wandb.run.id if not wandb_off else None

    # 获取tf数据集

    # 从tfrecords文件夹中获取训练数据集
    total_train_seqs, get_train_dataset = iterator_from_tfrecords_folder(data_path, data_type = 'train')
    # 从tfrecords文件夹中获取验证数据集
    total_valid_seqs, get_valid_dataset = iterator_from_tfrecords_folder(data_path, data_type = 'valid',)

    # 断言训练数据集和验证数据集的序列数量大于0
    assert total_train_seqs > 0, 'no protein sequences found for training'
    assert total_valid_seqs > 0, 'no protein sequences found for validation'

    # 获取训练数据集和验证数据集
    train_dataset = get_train_dataset(
        seq_len = seq_len,
        batch_size = batch_size,
        skip = start_seq_index
    )

    valid_dataset = get_valid_dataset(
        seq_len = seq_len,
        batch_size = batch_size,
        loop = True
    )

    # 打印信息

    print(f'params: {num_params_readable}')
    print(f'sequence length: {seq_len}')
    print(f'num sequences: {total_train_seqs}')
    print(f'starting from sequence {start_seq_index}')

    # 训练

    # 计算有效批次大小
    effective_batch_size = batch_size * grad_accum_every
    # 计算序列索引范围
    seq_index_ranges = range(start_seq_index, total_train_seqs, effective_batch_size)    

    # 遍历序列索引范围
    for i, seq_index in tqdm.tqdm(enumerate(seq_index_ranges), mininterval = 10., desc = 'training', total = len(seq_index_ranges)):
        # 根据梯度累积次数进行训练
        for _ in range(grad_accum_every):
            data = next(train_dataset)

            # 计算损失和梯度
            loss, grads = loss_fn(params, next(rng), data)
            # 更新参数和优化器状态
            updates, optim_state = optim.update(grads, optim_state, params)
            params = apply_updates(params, updates)

        print(f'loss: {loss.item()}')
        wandb.log({'loss': loss.item()})

        if i % checkpoint_every == 0:
            # 保存检查点信息
            package = {
                'next_seq_index': seq_index + effective_batch_size,
                'params': params,
                'optim_state': optim_state,
                'model_config': model_kwargs,
                'run_id': wandb_run_id
            }

            save_checkpoint(package, checkpoint_keep_n)
            print(f"checkpoint to start at sequence index of {package['next_seq_index']}")

        if i % validate_every == 0:
            # 验证模型
            valid_data = next(valid_dataset)
            loss, _ = loss_fn(params, next(rng), valid_data)
            print(f'valid_loss: {loss.item()}')
            wandb.log({'valid_loss': loss.item()})

        if i % sample_every == 0:
            # 生成样本
            valid_data = next(valid_dataset)[0]
            prime = valid_data[:prime_length]
            prime_str = decode_tokens(prime)

            sampled = sample(rng, model_apply, params, prime, seq_len, top_k = 25)
            sampled_str = decode_tokens(sampled[prime_length:])

            print(prime_str, "\n", "*" * 40, "\n", sampled_str)
            wandb.log({'samples': wandb.Html(sample_tmpl.render(prime_str = prime_str, sampled_str = sampled_str))})
# 如果当前脚本被直接执行,则调用主函数
if __name__ == '__main__':
    main()

.\lucidrains\protein-bert-pytorch\protein_bert_pytorch\protein_bert_pytorch.py

# 导入 math、torch 库以及 torch.nn.functional 模块中的 F 函数
import math
import torch
import torch.nn.functional as F
# 从 torch 模块中导入 nn、einsum 函数
from torch import nn, einsum
# 从 einops.layers.torch 模块中导入 Rearrange、Reduce 类
from einops.layers.torch import Rearrange, Reduce
# 从 einops 模块中导入 rearrange、repeat 函数
from einops import rearrange, repeat

# helpers

# 判断变量是否存在的辅助函数
def exists(val):
    return val is not None

# 返回给定张量类型的最小负值的辅助函数
def max_neg_value(t):
    return -torch.finfo(t.dtype).max

# helper classes

# 残差连接类
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

# 全局线性自注意力类
class GlobalLinearSelfAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head,
        heads
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, feats, mask = None):
        h = self.heads
        q, k, v = self.to_qkv(feats).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        if exists(mask):
            mask = rearrange(mask, 'b n -> b () n ()')
            k = k.masked_fill(~mask, -torch.finfo(k.dtype).max)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        if exists(mask):
            v = v.masked_fill(~mask, 0.)

        context = einsum('b h n d, b h n e -> b h d e', k, v)
        out = einsum('b h d e, b h n d -> b h n e', context, q)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
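
下面是一个独立的形状示例,展示该线性自注意力模块在保持输入形状不变的同时,以 O(n) 复杂度完成注意力计算(假设性示例,非原文件内容):

# GlobalLinearSelfAttention 的形状示例(假设性示例)
attn = GlobalLinearSelfAttention(dim = 512, dim_head = 64, heads = 8)

feats = torch.randn(2, 128, 512)      # (batch, seq, dim)
mask = torch.ones(2, 128).bool()

out = attn(feats, mask = mask)        # 线性注意力:先对 k、v 聚合成上下文,再与 q 相乘
assert out.shape == (2, 128, 512)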

# 交叉注意力类
class CrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_keys,
        dim_out,
        heads,
        dim_head = 64,
        qk_activation = nn.Tanh()
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.qk_activation = qk_activation

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_keys, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim_out)

        self.null_key = nn.Parameter(torch.randn(dim_head))
        self.null_value = nn.Parameter(torch.randn(dim_head))

    def forward(self, x, context, mask = None, context_mask = None):
        b, h, device = x.shape[0], self.heads, x.device

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        null_k, null_v = map(lambda t: repeat(t, 'd -> b h () d', b = b, h = h), (self.null_key, self.null_value))
        k = torch.cat((null_k, k), dim = -2)
        v = torch.cat((null_v, v), dim = -2)

        q, k = map(lambda t: self.qk_activation(t), (q, k))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(mask) or exists(context_mask):
            i, j = sim.shape[-2:]

            if not exists(mask):
                mask = torch.ones(b, i, dtype = torch.bool, device = device)

            if exists(context_mask):
                context_mask = F.pad(context_mask, (1, 0), value = True)
            else:
                context_mask = torch.ones(b, j, dtype = torch.bool, device = device)

            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
            sim.masked_fill_(~mask, max_neg_value(sim))

        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Layer(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(
        self,
        *,
        dim,
        dim_global,
        narrow_conv_kernel = 9,
        wide_conv_kernel = 9,
        wide_conv_dilation = 5,
        attn_heads = 8,
        attn_dim_head = 64,
        attn_qk_activation = nn.Tanh(),
        local_to_global_attn = False,
        local_self_attn = False,
        glu_conv = False
    ):
        # 调用父类的初始化函数
        super().__init__()

        # 如果启用局部自注意力机制,则创建全局线性自注意力对象
        self.seq_self_attn = GlobalLinearSelfAttention(dim = dim, dim_head = attn_dim_head, heads = attn_heads) if local_self_attn else None

        # 如果启用门控线性单元,则设置卷积倍数为2,否则为1
        conv_mult = 2 if glu_conv else 1

        # 创建窄卷积层
        self.narrow_conv = nn.Sequential(
            nn.Conv1d(dim, dim * conv_mult, narrow_conv_kernel, padding = narrow_conv_kernel // 2),
            nn.GELU() if not glu_conv else nn.GLU(dim = 1)
        )

        # 计算宽卷积的填充大小
        wide_conv_padding = (wide_conv_kernel + (wide_conv_kernel - 1) * (wide_conv_dilation - 1)) // 2

        # 创建宽卷积层
        self.wide_conv = nn.Sequential(
            nn.Conv1d(dim, dim * conv_mult, wide_conv_kernel, dilation = wide_conv_dilation, padding = wide_conv_padding),
            nn.GELU() if not glu_conv else nn.GLU(dim = 1)
        )

        # 设置是否进行局部到全局的注意力计算
        self.local_to_global_attn = local_to_global_attn

        # 根据是否进行局部到全局的注意力计算,创建相应的全局信息提取层
        if local_to_global_attn:
            self.extract_global_info = CrossAttention(
                dim = dim,
                dim_keys = dim_global,
                dim_out = dim,
                heads = attn_heads,
                dim_head = attn_dim_head
            )
        else:
            self.extract_global_info = nn.Sequential(
                Reduce('b n d -> b d', 'mean'),
                nn.Linear(dim_global, dim),
                nn.GELU(),
                Rearrange('b d -> b () d')
            )

        # 创建局部层归一化层
        self.local_norm = nn.LayerNorm(dim)

        # 创建局部前馈网络
        self.local_feedforward = nn.Sequential(
            Residual(nn.Sequential(
                nn.Linear(dim, dim),
                nn.GELU(),
            )),
            nn.LayerNorm(dim)
        )

        # 创建全局关注局部的交叉注意力层
        self.global_attend_local = CrossAttention(dim = dim_global, dim_out = dim_global, dim_keys = dim, heads = attn_heads, dim_head = attn_dim_head, qk_activation = attn_qk_activation)

        # 创建全局密集层
        self.global_dense = nn.Sequential(
            nn.Linear(dim_global, dim_global),
            nn.GELU()
        )

        # 创建全局层归一化层
        self.global_norm = nn.LayerNorm(dim_global)

        # 创建全局前馈网络
        self.global_feedforward = nn.Sequential(
            Residual(nn.Sequential(
                nn.Linear(dim_global, dim_global),
                nn.GELU()
            )),
            nn.LayerNorm(dim_global),
        )

    # 前向传播函数
    def forward(self, tokens, annotation, mask = None):
        # 如果启用局部到全局的注意力计算,则提取全局信息
        if self.local_to_global_attn:
            global_info = self.extract_global_info(tokens, annotation, mask = mask)
        else:
            global_info = self.extract_global_info(annotation)

        # 处理局部(蛋白质序列)

        # 如果存在局部自注意力机制,则计算全局线性注意力
        global_linear_attn = self.seq_self_attn(tokens) if exists(self.seq_self_attn) else 0

        # 重排输入以适应卷积层的输入格式
        conv_input = rearrange(tokens, 'b n d -> b d n')

        # 如果存在掩码,则根据掩码进行填充
        if exists(mask):
            conv_input_mask = rearrange(mask, 'b n -> b () n')
            conv_input = conv_input.masked_fill(~conv_input_mask, 0.)

        # 进行窄卷积和宽卷积操作
        narrow_out = self.narrow_conv(conv_input)
        narrow_out = rearrange(narrow_out, 'b d n -> b n d')
        wide_out = self.wide_conv(conv_input)
        wide_out = rearrange(wide_out, 'b d n -> b n d')

        # 更新 tokens
        tokens = tokens + narrow_out + wide_out + global_info + global_linear_attn
        tokens = self.local_norm(tokens)

        # 应用局部前馈网络
        tokens = self.local_feedforward(tokens)

        # 处理全局(注释)

        # 全局关注局部的交叉注意力
        annotation = self.global_attend_local(annotation, tokens, context_mask = mask)
        annotation = self.global_dense(annotation)
        annotation = self.global_norm(annotation)
        annotation = self.global_feedforward(annotation)

        return tokens, annotation
# 主模型类定义
class ProteinBERT(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        num_tokens = 26,  # 标记的数量
        num_annotation = 8943,  # 注释的数量
        dim = 512,  # 维度
        dim_global = 256,  # 全局维度
        depth = 6,  # 深度
        narrow_conv_kernel = 9,  # 窄卷积核大小
        wide_conv_kernel = 9,  # 宽卷积核大小
        wide_conv_dilation = 5,  # 宽卷积膨胀率
        attn_heads = 8,  # 注意力头数
        attn_dim_head = 64,  # 注意力头维度
        attn_qk_activation = nn.Tanh(),  # 注意力激活函数
        local_to_global_attn = False,  # 是否使用局部到全局注意力
        local_self_attn = False,  # 是否使用局部自注意力
        num_global_tokens = 1,  # 全局标记数量
        glu_conv = False  # 是否使用门控线性单元卷积
    ):
        super().__init__()
        self.num_tokens = num_tokens  # 设置标记数量
        self.token_emb = nn.Embedding(num_tokens, dim)  # 标记嵌入层

        self.num_global_tokens = num_global_tokens  # 设置全局标记数量
        self.to_global_emb = nn.Linear(num_annotation, num_global_tokens * dim_global)  # 全局嵌入层

        # 创建多层神经网络
        self.layers = nn.ModuleList([Layer(dim = dim, dim_global = dim_global, narrow_conv_kernel = narrow_conv_kernel, wide_conv_dilation = wide_conv_dilation, wide_conv_kernel = wide_conv_kernel, attn_qk_activation = attn_qk_activation, local_to_global_attn = local_to_global_attn, local_self_attn = local_self_attn, glu_conv = glu_conv) for layer in range(depth)])

        self.to_token_logits = nn.Linear(dim, num_tokens)  # 标记的逻辑回归层

        self.to_annotation_logits = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),  # 减少维度
            nn.Linear(dim_global, num_annotation)  # 全局注释的逻辑回归层
        )

    # 前向传播函数
    def forward(self, seq, annotation, mask = None):
        tokens = self.token_emb(seq)  # 标记嵌入

        annotation = self.to_global_emb(annotation)  # 全局嵌入
        annotation = rearrange(annotation, 'b (n d) -> b n d', n = self.num_global_tokens)  # 重新排列全局嵌入

        for layer in self.layers:
            tokens, annotation = layer(tokens, annotation, mask = mask)  # 多层神经网络的前向传播

        tokens = self.to_token_logits(tokens)  # 标记的逻辑回归
        annotation = self.to_annotation_logits(annotation)  # 全局注释的逻辑回归
        return tokens, annotation  # 返回标记和注释

# 预训练包装器类定义
class PretrainingWrapper(nn.Module):
    # 初始化函数
    def __init__(
        self,
        model,
        random_replace_token_prob = 0.05,  # 随机替换标记的概率
        remove_annotation_prob = 0.25,  # 移除注释的概率
        add_annotation_prob = 0.01,  # 添加注释的概率
        remove_all_annotations_prob = 0.5,  # 移除所有注释的概率
        seq_loss_weight = 1.,  # 序列损失权重
        annotation_loss_weight = 1.,  # 注释损失权重
        exclude_token_ids = (0, 1, 2)   # 要排除的标记ID(用于排除填充、开始和结束标记)
    ):
        super().__init__()
        assert isinstance(model, ProteinBERT), 'model must be an instance of ProteinBERT'  # 断言模型必须是ProteinBERT的实例

        self.model = model  # 设置模型

        self.random_replace_token_prob = random_replace_token_prob  # 设置随机替换标记的概率
        self.remove_annotation_prob = remove_annotation_prob  # 设置移除注释的概率
        self.add_annotation_prob = add_annotation_prob  # 设置添加注释的概率
        self.remove_all_annotations_prob = remove_all_annotations_prob  # 设置移除所有注释的概率

        self.seq_loss_weight = seq_loss_weight  # 设置序列损失权重
        self.annotation_loss_weight = annotation_loss_weight  # 设置注释损失权重

        self.exclude_token_ids = exclude_token_ids  # 设置要排除的标记ID
    # 定义一个前向传播函数,接受序列、注释和掩码作为输入
    def forward(self, seq, annotation, mask = None):
        # 获取批量大小和设备信息
        batch_size, device = seq.shape[0], seq.device

        # 复制输入序列和注释
        seq_labels = seq
        annotation_labels = annotation

        # 如果没有提供掩码,则创建一个全为 True 的掩码
        if not exists(mask):
            mask = torch.ones_like(seq).bool()

        # 准备用于对序列进行噪声处理的掩码

        excluded_tokens_mask = mask

        # 根据排除的标记 ID,生成排除标记的掩码
        for token_id in self.exclude_token_ids:
            excluded_tokens_mask = excluded_tokens_mask & (seq != token_id)

        # 根据给定的概率生成随机替换标记的掩码
        random_replace_token_prob_mask = get_mask_subset_with_prob(excluded_tokens_mask, self.random_replace_token_prob)

        # 准备用于对注释进行噪声处理的掩码

        batch_mask = torch.ones(batch_size, device = device, dtype = torch.bool)
        batch_mask = rearrange(batch_mask, 'b -> b ()')
        remove_annotation_from_batch_mask = get_mask_subset_with_prob(batch_mask, self.remove_all_annotations_prob)

        annotation_mask = annotation > 0
        remove_annotation_prob_mask = get_mask_subset_with_prob(annotation_mask, self.remove_annotation_prob)
        add_annotation_prob_mask = get_mask_subset_with_prob(~annotation_mask, self.add_annotation_prob)
        remove_annotation_mask = remove_annotation_from_batch_mask & remove_annotation_prob_mask

        # 生成随机标记

        random_tokens = torch.randint(0, self.model.num_tokens, seq.shape, device=seq.device)

        # 确保不会用排除的标记类型(填充、开始、结束)替换标记
        for token_id in self.exclude_token_ids:
            random_replace_token_prob_mask = random_replace_token_prob_mask & (random_tokens != token_id)

        # 对序列进行噪声处理

        noised_seq = torch.where(random_replace_token_prob_mask, random_tokens, seq)

        # 对注释进行噪声处理

        noised_annotation = annotation + add_annotation_prob_mask.type(annotation.dtype)
        noised_annotation = noised_annotation * remove_annotation_mask.type(annotation.dtype)

        # 使用模型进行去噪处理

        seq_logits, annotation_logits = self.model(noised_seq, noised_annotation, mask = mask)

        # 计算损失

        seq_logits = seq_logits[mask]
        seq_labels = seq_labels[mask]

        seq_loss = F.cross_entropy(seq_logits, seq_labels, reduction = 'sum')
        annotation_loss = F.binary_cross_entropy_with_logits(annotation_logits, annotation_labels, reduction = 'sum')

        # 返回序列损失加上注释损失的加权和
        return seq_loss * self.seq_loss_weight + annotation_loss * self.annotation_loss_weight

.\lucidrains\protein-bert-pytorch\protein_bert_pytorch\__init__.py

# 从 protein_bert_pytorch 包中导入 ProteinBERT 和 PretrainingWrapper 类
from protein_bert_pytorch.protein_bert_pytorch import ProteinBERT, PretrainingWrapper

ProteinBERT - Pytorch (wip)

Implementation of ProteinBERT in Pytorch.

Original Repository

Install

$ pip install protein-bert-pytorch

Usage

import torch
from protein_bert_pytorch import ProteinBERT

model = ProteinBERT(
    num_tokens = 21,
    num_annotation = 8943,
    dim = 512,
    dim_global = 256,
    depth = 6,
    narrow_conv_kernel = 9,
    wide_conv_kernel = 9,
    wide_conv_dilation = 5,
    attn_heads = 8,
    attn_dim_head = 64
)

seq = torch.randint(0, 21, (2, 2048))
mask = torch.ones(2, 2048).bool()
annotation = torch.randint(0, 1, (2, 8943)).float()

seq_logits, annotation_logits = model(seq, annotation, mask = mask) # (2, 2048, 21), (2, 8943)

To use for pretraining

import torch
from protein_bert_pytorch import ProteinBERT, PretrainingWrapper

model = ProteinBERT(
    num_tokens = 21,
    num_annotation = 8943,
    dim = 512,
    dim_global = 256,
    depth = 6,
    narrow_conv_kernel = 9,
    wide_conv_kernel = 9,
    wide_conv_dilation = 5,
    attn_heads = 8,
    attn_dim_head = 64,
    local_to_global_attn = False,
    local_self_attn = True,
    num_global_tokens = 2,
    glu_conv = False
)

learner = PretrainingWrapper(
    model,
    random_replace_token_prob = 0.05,    # what percentage of the tokens to replace with a random one, defaults to 5% as in paper
    remove_annotation_prob = 0.25,       # what percentage of annotations to remove, defaults to 25%
    add_annotation_prob = 0.01,          # probability to add an annotation randomly, defaults to 1%
    remove_all_annotations_prob = 0.5,   # what percentage of batch items to remove annotations for completely, defaults to 50%
    seq_loss_weight = 1.,                # weight on loss of sequence
    annotation_loss_weight = 1.,         # weight on loss of annotation
    exclude_token_ids = (0, 1, 2)        # for excluding padding, start, and end tokens from being masked
)

# do the following in a loop for a lot of sequences and annotations

seq        = torch.randint(0, 21, (2, 2048))
annotation = torch.randint(0, 1, (2, 8943)).float()
mask       = torch.ones(2, 2048).bool()

loss = learner(seq, annotation, mask = mask) # scalar loss (weighted sum of the sequence and annotation losses)
loss.backward()

# save your model and evaluate it

torch.save(model, './improved-protein-bert.pt')

Citations

@article {Brandes2021.05.24.445464,
    author      = {Brandes, Nadav and Ofer, Dan and Peleg, Yam and Rappoport, Nadav and Linial, Michal},
    title       = {ProteinBERT: A universal deep-learning model of protein sequence and function},
    year        = {2021},
    doi         = {10.1101/2021.05.24.445464},
    publisher   = {Cold Spring Harbor Laboratory},
    URL         = {https://www.biorxiv.org/content/early/2021/05/25/2021.05.24.445464},
    eprint      = {https://www.biorxiv.org/content/early/2021/05/25/2021.05.24.445464.full.pdf},
    journal     = {bioRxiv}
}

.\lucidrains\protein-bert-pytorch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
    name="protein-bert-pytorch",  # 包的名称
    packages=find_packages(),  # 查找所有包
    version="0.1.0",  # 版本号
    license="MIT",  # 许可证
    description="ProteinBERT - Pytorch",  # 描述
    author="Phil Wang",  # 作者
    author_email="lucidrains@gmail.com",  # 作者邮箱
    url="https://github.com/lucidrains/protein-bert-pytorch",  # 项目链接
    keywords=[  # 关键词列表
        "artificial intelligence",
        "deep learning",
        "attention mechanism",
        "protein sequences",
        "unsupervised learning"
    ],
    install_requires=[  # 安装依赖
        "einops>=0.3",
        "torch>=1.6",
    ],
    classifiers=[  # 分类器列表
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
)

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\accelerate_utils.py

# 导入必要的模块
from functools import partial, wraps
from typing import Optional, Callable
from contextlib import nullcontext, contextmanager

from torch.nn import Module

from accelerate import Accelerator
from accelerate.tracking import WandBTracker

# 辅助函数

# 检查变量是否存在
def exists(v):
    return v is not None

# 创建一个结合两个上下文管理器的上下文管理器
@contextmanager
def combine_contexts(a, b):
    with a() as c1, b() as c2:
        yield (c1, c2)

# 在数组中查找第一个满足条件的元素
def find_first(cond: Callable, arr):
    for el in arr:
        if cond(el):
            return el

    return None

# 添加一个用于 wandb 跟踪的上下文管理器,具有特定的项目和实验名称

def add_wandb_tracker_contextmanager(
    accelerator_instance_name = 'accelerator',
    tracker_hps_instance_name = 'tracker_hps'
):
    def decorator(klass):

        @contextmanager
        def wandb_tracking(
            self,
            project: str,
            run: Optional[str] = None,
            hps: Optional[dict] = None
        ):
            maybe_accelerator = getattr(self, accelerator_instance_name, None)

            assert exists(maybe_accelerator) and isinstance(maybe_accelerator, Accelerator), f'Accelerator instance not found at self.{accelerator_instance_name}'

            hps = getattr(self, tracker_hps_instance_name, hps)

            maybe_accelerator.init_trackers(project, config = hps)

            wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), maybe_accelerator.trackers)

            assert exists(wandb_tracker), 'wandb tracking was not enabled. you need to set `log_with = "wandb"` on your accelerate kwargs'

            if exists(run):
                assert exists(wandb_tracker)
                wandb_tracker.run.name = run

            yield

            maybe_accelerator.end_training() 

        if not hasattr(klass, 'wandb_tracking'):
            klass.wandb_tracking = wandb_tracking

        return klass

    return decorator
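
该装饰器的典型用法如下:给一个带有 accelerator 属性的训练器类加上 wandb_tracking 上下文管理器,进入时初始化 tracker,退出时自动结束训练记录(假设性示例,类名与字段均为演示用,非原文件内容):

# add_wandb_tracker_contextmanager 的用法示例(假设性示例)
@add_wandb_tracker_contextmanager()
class Trainer:
    def __init__(self):
        # 需要以 log_with = 'wandb' 创建 Accelerator,否则上下文中的断言会失败
        self.accelerator = Accelerator(log_with = 'wandb')

    def train(self, steps = 10):
        for step in range(steps):
            self.accelerator.log({'loss': 0.}, step = step)

trainer = Trainer()

# 进入上下文时调用 init_trackers,退出时自动调用 end_training()
with trainer.wandb_tracking(project = 'my-project', run = 'my-run'):
    trainer.train()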

# 当在可能的 DDP 包装的主模型上找不到属性时,自动取消包装模型

class ForwardingWrapper:
  def __init__(self, parent, child):
    self.parent = parent
    self.child = child

  def __getattr__(self, key):
    if hasattr(self.parent, key):
      return getattr(self.parent, key)

    return getattr(self.child, key)

  def __call__(self, *args, **kwargs):
    call_fn = self.__getattr__('__call__')
    return call_fn(*args, **kwargs)

def auto_unwrap_model(
    accelerator_instance_name = 'accelerator',
    model_instance_name = 'model'
):
    def decorator(klass):
        _orig_init = klass.__init__

        @wraps(_orig_init)
        def __init__(self, *args, **kwargs):
            _orig_init(self, *args, **kwargs)
            model = getattr(self, model_instance_name)
            accelerator = getattr(self, accelerator_instance_name)

            assert isinstance(accelerator, Accelerator)
            forward_wrapped_model = ForwardingWrapper(model, accelerator.unwrap_model(model))
            setattr(self, model_instance_name, forward_wrapped_model)

        klass.__init__ = __init__
        return klass

    return decorator

# 梯度累积上下文管理器
# 对除最后一次迭代外的所有迭代应用 no_sync 上下文

def model_forward_contexts(
    accelerator: Accelerator,
    model: Module,
    grad_accum_steps: int = 1
):
    for i in range(grad_accum_steps):
        is_last_step = i == grad_accum_steps - 1

        maybe_no_sync = partial(accelerator.no_sync, model) if not is_last_step else nullcontext

        yield partial(combine_contexts, accelerator.autocast, maybe_no_sync)
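
model_forward_contexts 会依次产生 grad_accum_steps 个上下文管理器:除最后一步外都套上 no_sync 以跳过 DDP 的梯度同步,只有最后一步才同步。典型的梯度累积循环如下(假设性示例,模型与数据均为演示用,非原文件内容):

# 梯度累积上下文的用法示例(假设性示例)
import torch
from torch.nn import Linear

accelerator = Accelerator()
model = Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)

model, optimizer = accelerator.prepare(model, optimizer)

grad_accum_steps = 4
batches = [torch.randn(8, 10) for _ in range(grad_accum_steps)]

for batch, forward_context in zip(batches, model_forward_contexts(accelerator, model, grad_accum_steps)):
    with forward_context():
        loss = model(batch).mean()
        accelerator.backward(loss / grad_accum_steps)

optimizer.step()
optimizer.zero_grad()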

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\get_adam_optimizer.py

# 从 typing 模块导入 Tuple 类型
from typing import Tuple
# 从 torch.optim 模块导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam

# optimizer

# 将参数分为需要权重衰减和不需要权重衰减的两个列表
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []

    for param in params:
        # 根据参数的维度判断是否需要权重衰减
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)

    return wd_params, no_wd_params

# 获取 Adam 优化器
def get_adam_optimizer(
    params,
    lr: float = 1e-4,
    wd: float = 1e-2,
    betas: Tuple[float, float] = (0.9, 0.99),
    eps: float = 1e-8,
    filter_by_requires_grad = False,
    omit_gammas_and_betas_from_wd = True,
    **kwargs
):
    # 判断是否需要权重衰减
    has_weight_decay = wd > 0.

    # 根据是否需要过滤 requires_grad 来筛选参数
    if filter_by_requires_grad:
        params = [t for t in params if t.requires_grad]

    # 设置优化器的参数
    opt_kwargs = dict(
        lr = lr,
        betas = betas,
        eps = eps
    )

    # 如果不需要权重衰减,则返回 Adam 优化器
    if not has_weight_decay:
        return Adam(params, **opt_kwargs)

    # 设置带有权重衰减的优化器参数
    opt_kwargs = {'weight_decay': wd, **opt_kwargs}

    # 如果不忽略 gammas 和 betas 的权重衰减,则返回 AdamW 优化器
    if not omit_gammas_and_betas_from_wd:
        return AdamW(params, **opt_kwargs)

    # 在 transformers 中有一种早期实践,其中从权重衰减中省略了 betas 和 gammas
    # 不确定是否真的需要
    wd_params, no_wd_params = separate_weight_decayable_params(params)

    # 将参数分为需要权重衰减和不需要权重衰减的两部分
    params = [
        {'params': wd_params},
        {'params': no_wd_params, 'weight_decay': 0},
    ]

    return AdamW(params, **opt_kwargs)
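
get_adam_optimizer 的一个简单用法:当 wd > 0 时返回 AdamW,并默认把一维参数(LayerNorm 的 gamma / beta 以及偏置)单独分组、不做权重衰减(假设性示例,非原文件内容):

# get_adam_optimizer 的用法示例(假设性示例)
import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))

optim = get_adam_optimizer(
    model.parameters(),
    lr = 3e-4,
    wd = 1e-2,                            # wd > 0 时使用 AdamW
    omit_gammas_and_betas_from_wd = True  # 一维参数单独成组,weight_decay = 0
)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optim.step()
optim.zero_grad()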

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\module_device.py

# 导入必要的模块
from functools import wraps
from typing import List
from optree import tree_flatten, tree_unflatten

import torch
from torch import is_tensor
from torch.nn import Module

# 为模型提供一个 .device 属性
# 使用一个虚拟的标量张量

def module_device(
    device_property_name = 'device'
):
    # 装饰器函数,用于装饰类
    def decorator(klass):
        # 断言被装饰的类是 torch.nn.Module 的子类
        assert issubclass(klass, Module), 'should decorate a subclass of torch.nn.Module'

        # 保存原始的 __init__ 方法
        _orig_init = klass.__init__

        @wraps(_orig_init)
        def __init__(self, *args, **kwargs):
            # 调用原始的 __init__ 方法
            _orig_init(self, *args, **kwargs)

            # 在模型中注册一个名为 '_dummy' 的缓冲区,值为 torch.tensor(0),不持久化
            self.register_buffer('_dummy', torch.tensor(0), persistent = False)

        @property
        def _device_property(self):
            # 返回 '_dummy' 缓冲区的设备信息
            return self._dummy.device

        # 替换类的 __init__ 方法为自定义的 __init__ 方法
        klass.__init__ = __init__
        # 设置类的属性 device_property_name 为 _device_property
        setattr(klass, device_property_name, _device_property)
        return klass

    return decorator

# 一个装饰器,自动将传入 .forward 方法的所有张量转换为正确的设备

def autocast_device(
    methods: List[str] = ['forward']
):
    # 装饰器函数,用于装饰类
    def decorator(klass):
        # 断言被装饰的类是 torch.nn.Module 的子类
        assert issubclass(klass, Module), 'should decorate a subclass of torch.nn.Module'

        # 获取要装饰的方法的原始函数
        orig_fns = [getattr(klass, method) for method in methods]

        for method, orig_fn in zip(methods, orig_fns):

            @wraps(orig_fn)
            def fn(self, *args, **kwargs):

                # 确定设备
                # 使用上面装饰器中的虚拟张量
                # 否则查找参数并使用参数上的设备

                if hasattr(self, '_dummy'):
                    device = self._dummy.device
                else:
                    device = next(self.parameters()).device

                # 展平参数

                flattened_args, tree_spec = tree_flatten([args, kwargs])

                # 转换参数

                maybe_transformed_args = []

                for flattened_arg in flattened_args:
                    if is_tensor(flattened_arg):
                        flattened_arg = flattened_arg.to(device)

                    maybe_transformed_args.append(flattened_arg)

                # 还原参数

                args, kwargs = tree_unflatten(tree_spec, maybe_transformed_args)

                # 调用原始函数并返回其结果

                return orig_fn(self, *args, **kwargs)

            # 设置类的方法为新的 fn 函数
            setattr(klass, method, fn)

        return klass

    return decorator
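
两个装饰器可以叠加使用:module_device 通过注册的 _dummy 缓冲区为模型提供 .device 属性,autocast_device 则在调用 forward 前把传入的张量自动搬到该设备上(假设性示例,非原文件内容):

# module_device 与 autocast_device 的用法示例(假设性示例)
import torch
from torch import nn

@autocast_device()
@module_device()
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 8)

    def forward(self, x):
        return self.net(x)

model = MyModule()
print(model.device)          # 由 _dummy 缓冲区推断出的设备,初始为 cpu

x = torch.randn(2, 8)        # 即使张量原本在其他设备上,也会先被搬到 model.device 再参与计算
out = model(x)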

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\optimizer_scheduler_warmup.py

# 导入所需的模块和类
from contextlib import nullcontext
from typing import Optional, Type
from accelerate import Accelerator
from functools import partial
from torch import nn
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
import pytorch_warmup as warmup

# 定义一个辅助函数,用于检查变量是否存在
def exists(v):
    return v is not None

# 定义一个常量,为 LambdaLR 类的部分应用,设置 lr_lambda 为恒定值 1.0
ConstantLRScheduler = partial(LambdaLR, lr_lambda = lambda step: 1.)

# 定义一个带有调度器和预热的优化器类
class OptimizerWithWarmupSchedule(nn.Module):
    def __init__(
        self,
        accelerator: Accelerator,
        optimizer: Optimizer,
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        warmup_steps: int = 0,
        max_grad_norm: Optional[float] = None
    ):
        super().__init__()
        self.max_grad_norm = max_grad_norm
        has_warmup = warmup_steps > 0

        # 如果有预热步数大于0,则创建 LinearWarmup 对象,否则为 None
        self.warmup = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps) if has_warmup else None

        # 如果调度器存在,则使用给定参数创建调度器对象,否则使用常量调度器
        if exists(scheduler):
            self.scheduler = scheduler(optimizer, **scheduler_kwargs)
        else:
            self.scheduler = ConstantLRScheduler(optimizer)

        self.optimizer = optimizer

        # 准备优化器和调度器,返回准备后的优化器和调度器对象
        self.optimizer, self.scheduler = accelerator.prepare(self.optimizer, self.scheduler)
        self.accelerator = accelerator

    # 返回当前状态的字典表示
    def state_dict(self):
        pkg = dict(
            optimizer = self.optimizer.state_dict(),
            scheduler = self.scheduler.state_dict()
        )

        if exists(self.warmup):
            pkg['warmup'] = self.warmup.state_dict()

        return pkg

    # 加载状态字典表示
    def load_state_dict(self, pkg):
        self.optimizer.load_state_dict(pkg['optimizer'])
        self.scheduler.load_state_dict(pkg['scheduler'])

        if exists(self.warmup):
            self.warmup.load_state_dict(pkg['warmup'])

    # 将所有参数的梯度清零
    def zero_grad(self):
        self.optimizer.zero_grad()

    # 执行一步优化
    def step(self):
        # 如果最大梯度范数存在,则对参数进行梯度裁剪
        if exists(self.max_grad_norm):
            for param_group in self.optimizer.param_groups:
                self.accelerator.clip_grad_norm_(param_group['params'], self.max_grad_norm)

        # 执行一步优化
        self.optimizer.step()

        # 如果优化步骤未被跳过,则执行调度器的步骤
        if not self.accelerator.optimizer_step_was_skipped:
            # 根据是否存在预热对象,选择上下文管理器
            context = nullcontext if not exists(self.warmup) else self.warmup.dampening

            # 执行调度器的步骤
            with context():
                self.scheduler.step()
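
OptimizerWithWarmupSchedule 的典型用法:把 accelerate、优化器、学习率调度器与线性预热组合在一起,训练循环中只需调用 zero_grad / step(假设性示例,非原文件内容):

# OptimizerWithWarmupSchedule 的用法示例(假设性示例)
import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

accelerator = Accelerator()
model = accelerator.prepare(nn.Linear(16, 4))

optim = OptimizerWithWarmupSchedule(
    accelerator = accelerator,
    optimizer = AdamW(model.parameters(), lr = 3e-4),
    scheduler = CosineAnnealingLR,
    scheduler_kwargs = dict(T_max = 1000),
    warmup_steps = 100,
    max_grad_norm = 1.0
)

for _ in range(10):
    loss = model(torch.randn(8, 16)).sum()
    accelerator.backward(loss)
    optim.step()         # 内部完成梯度裁剪、optimizer.step 以及带预热的 scheduler.step
    optim.zero_grad()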

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\save_load.py

# 导入所需的模块
import pickle
from functools import wraps
from pathlib import Path
# 使用别名导入版本解析函数,避免与装饰器的 version 参数同名冲突
from packaging.version import parse as parse_version
import torch
from torch.nn import Module
from beartype import beartype
from beartype.typing import Optional

# 定义一个辅助函数,用于检查变量是否存在
def exists(v):
    return v is not None

# 装饰器函数,用于保存和加载模型
@beartype
def save_load(
    save_method_name = 'save',
    load_method_name = 'load',
    config_instance_var_name = '_config',
    init_and_load_classmethod_name = 'init_and_load',
    version: Optional[str] = None
):
    # 内部函数,用于实现保存和加载功能
    def _save_load(klass):
        # 断言被装饰的类是 torch.nn.Module 的子类
        assert issubclass(klass, Module), 'save_load should decorate a subclass of torch.nn.Module'

        # 保存原始的 __init__ 方法
        _orig_init = klass.__init__

        # 重写 __init__ 方法
        @wraps(_orig_init)
        def __init__(self, *args, **kwargs):
            # 序列化参数和关键字参数
            _config = pickle.dumps((args, kwargs))
            # 将序列化后的参数保存到实例变量中
            setattr(self, config_instance_var_name, _config)
            # 调用原始的 __init__ 方法
            _orig_init(self, *args, **kwargs)

        # 保存模型到文件
        def _save(self, path, overwrite = True):
            path = Path(path)
            assert overwrite or not path.exists()

            pkg = dict(
                model = self.state_dict(),
                config = getattr(self, config_instance_var_name),
                version = version,
            )

            torch.save(pkg, str(path))

        # 从文件加载模型
        def _load(self, path, strict = True):
            path = Path(path)
            assert path.exists()

            pkg = torch.load(str(path), map_location = 'cpu')

            # 如果保存时记录了版本且与当前传入的 version 不一致,则给出提示
            if exists(version) and exists(pkg['version']) and parse_version(version) != parse_version(pkg['version']):
                print(f'loading saved model at version {pkg["version"]}, but the version passed to save_load is {version}')

            self.load_state_dict(pkg['model'], strict = strict)

        # 从文件初始化并加载模型
        @classmethod
        def _init_and_load_from(cls, path, strict = True):
            path = Path(path)
            assert path.exists()
            pkg = torch.load(str(path), map_location = 'cpu')

            assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

            config = pickle.loads(pkg['config'])
            args, kwargs = config
            model = cls(*args, **kwargs)

            _load(model, path, strict = strict)
            return model

        # 设置装饰后的 __init__ 方法,以及保存、加载和初始化加载方法
        klass.__init__ = __init__
        setattr(klass, save_method_name, _save)
        setattr(klass, load_method_name, _load)
        setattr(klass, init_and_load_classmethod_name, _init_and_load_from)

        return klass

    return _save_load
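
A short usage sketch of the decorator above, focusing on the optional version argument (the README further down covers the default save / load / init_and_load flow); the class name, dim, and file path here are placeholders.

from torch import nn
from pytorch_custom_utils import save_load

@save_load(version = '0.1.0')            # recorded into every checkpoint written by .save()
class Tiny(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

model = Tiny(dim = 32)                   # (args, kwargs) are pickled into ._config by the wrapped __init__
model.save('./tiny.pt')                  # pkg = dict(model = state_dict, config = pickled config, version = '0.1.0')
model.load('./tiny.pt')                  # warns if the stored version differs from the current '0.1.0'

restored = Tiny.init_and_load('./tiny.pt')   # rebuilds Tiny(dim = 32) from the pickled config, then loads weights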

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\total_parameters.py

# 从 torch.nn 模块中导入 Module 类
from torch.nn import Module

# 为你的模型提供一个 .total_parameters 属性,该属性简单地对所有模块的参数求和

# 定义一个装饰器函数,用于为类添加 total_parameters 属性
def total_parameters(
    count_only_requires_grad = False,  # 是否只计算需要梯度的参数
    total_parameters_property_name = 'total_parameters'  # total_parameters 属性的名称
):
    # 装饰器函数
    def decorator(klass):
        # 断言 klass 是 torch.nn.Module 的子类
        assert issubclass(klass, Module), 'should decorate a subclass of torch.nn.Module'

        # 定义一个计算所有参数数量的属性
        @property
        def _total_parameters(self):
            return sum(p.numel() for p in self.parameters())

        # 定义一个计算需要梯度的参数数量的属性
        @property
        def _total_parameters_with_requires_grad(self):
            return sum(p.numel() for p in self.parameters() if p.requires_grad)

        # 根据 count_only_requires_grad 的值选择计算哪种参数数量
        fn = _total_parameters_with_requires_grad if count_only_requires_grad else  _total_parameters

        # 将计算参数数量的函数设置为 klass 的属性
        setattr(klass, total_parameters_property_name, fn)
        return klass

    return decorator
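
A minimal sketch of the decorator above; the small network is a placeholder.

from torch import nn
from pytorch_custom_utils import total_parameters

@total_parameters()                      # or total_parameters(count_only_requires_grad = True)
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))

    def forward(self, x):
        return self.net(x)

net = SmallNet()
print(net.total_parameters)              # (10*20 + 20) + (20*1 + 1) = 241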

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\utils.py

# 导入所需的模块
from typing import Tuple
import torch.nn.functional as F

# 检查变量是否存在
def exists(v):
    return v is not None

# 填充和切片

# 在指定维度上填充张量
def pad_at_dim(t, pad: Tuple[int, int], *, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# 在指定维度上切片张量
def slice_at_dim(t, dim_slice: slice, *, dim):
    dim += (t.ndim if dim < 0 else 0)
    colons = [slice(None)] * t.ndim
    colons[dim] = dim_slice
    return t[tuple(colons)]

# 根据长度填充或切片张量
def pad_or_slice_to(t, length, *, dim, pad_value = 0):
    curr_length = t.shape[dim]

    if curr_length < length:
        t = pad_at_dim(t, (0, length - curr_length), dim = dim, value = pad_value)
    elif curr_length > length:
        t = slice_at_dim(t, slice(0, length), dim = dim)

    return t

# 与掩码相关

# 计算受掩码影响的张量的均值
def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
    if not exists(mask):
        return tensor.mean(dim = dim)

    tensor.masked_fill_(~mask, 0.)

    total_el = mask.sum(dim = dim)
    num = tensor.sum(dim = dim)
    den = total_el.float().clamp(min = eps)
    mean = num / den
    mean.masked_fill_(total_el == 0, 0.)
    return mean

# 对多个掩码进行逻辑与操作
def maybe_and_mask(*masks):
    masks = [*filter(exists, masks)]
    if len(masks) == 0:
        return None

    mask, *rest_masks = masks
    for rest_mask in rest_masks:
        mask = mask & rest_mask

    return mask
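
A few toy examples of the helpers above (assuming they are in scope); note that masked_mean fills the masked-out positions of its input in place.

import torch

x = torch.arange(6).reshape(2, 3).float()

pad_at_dim(x, (0, 2), dim = -1)          # two zeros appended on the right  -> shape (2, 5)
pad_or_slice_to(x, 2, dim = -1)          # longer than 2, so it is sliced   -> shape (2, 2)
pad_or_slice_to(x, 5, dim = -1)          # shorter than 5, so it is padded  -> shape (2, 5)

mask = torch.tensor([[True, True, False], [False, False, False]])
masked_mean(x, mask, dim = -1)           # tensor([0.5, 0.0]) - the second row has no valid entries, so it is zeroed

maybe_and_mask(mask, None)               # None entries are dropped, result equals `mask`
maybe_and_mask(None, None)               # no masks at all -> None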

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\__init__.py

# 从 pytorch_custom_utils.module_device 模块中导入 module_device 和 autocast_device 函数
from pytorch_custom_utils.module_device import (
    module_device,
    autocast_device
)

# 从 pytorch_custom_utils.save_load 模块中导入 save_load 函数
from pytorch_custom_utils.save_load import save_load

# 从 pytorch_custom_utils.total_parameters 模块中导入 total_parameters 函数
from pytorch_custom_utils.total_parameters import total_parameters

# 从 pytorch_custom_utils.get_adam_optimizer 模块中导入 get_adam_optimizer 函数
from pytorch_custom_utils.get_adam_optimizer import get_adam_optimizer

# 从 pytorch_custom_utils.optimizer_scheduler_warmup 模块中导入 OptimizerWithWarmupSchedule 类
from pytorch_custom_utils.optimizer_scheduler_warmup import OptimizerWithWarmupSchedule

# 从 pytorch_custom_utils.accelerate_utils 模块中导入 add_wandb_tracker_contextmanager 和 auto_unwrap_model 函数
from pytorch_custom_utils.accelerate_utils import (
    add_wandb_tracker_contextmanager,
    auto_unwrap_model
)

Pytorch Custom Utils (wip)

Just some miscellaneous utility functions / decorators / modules related to Pytorch and Accelerate to help speed up implementation of new AI research

Install

$ pip install pytorch-custom-utils

Quick save and load

Class decorator for adding a quick save and load method to the module instance. Can also initialize the entire network with a class method, init_and_load.

ex.

import torch
from torch import nn

from pytorch_custom_utils import save_load

# decorate the entire class with `save_load` class decorator

@save_load()
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

# instantiated mlp

mlp = MLP(dim = 512)

# now you have a save and load method

mlp.save('./mlp.pt')
mlp.load('./mlp.pt')

# you can also directly initialize from the checkpoint, without having to save the corresponding hyperparameters (in this case, dim = 512)

mlp = MLP.init_and_load('./mlp.pt')

Keep track of device on module

ex.

import torch
from torch import nn

from pytorch_custom_utils import module_device

# decorate the class with `module_device` class decorator

@module_device()
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def forward(self, x):
        return self.net(x)

# instantiated mlp

mlp = MLP(dim = 512)
mlp.to(torch.device('mps'))

# now you have a convenient .device

mlp.device # mps:0

.\lucidrains\pytorch-custom-utils\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'pytorch-custom-utils',  # 包名
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.0.18',  # 版本号
  license='MIT',  # 许可证
  description = 'Pytorch Custom Utils',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  url = 'https://github.com/lucidrains/pytorch-custom-utils',  # URL
  keywords = [
    'pytorch',  # 关键字
    'accelerate'  # 关键字
  ],
  install_requires=[
    'accelerate',  # 安装依赖
    'optree',  # 安装依赖
    'pytorch-warmup',  # 安装依赖
    'torch>=2.0'  # 安装依赖
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # 分类
    'Intended Audience :: Developers',  # 分类
    'Topic :: Scientific/Engineering :: Artificial Intelligence',  # 分类
    'License :: OSI Approved :: MIT License',  # 分类
    'Programming Language :: Python :: 3.6',  # 分类
  ],
)

.\lucidrains\q-transformer\q_transformer\agent.py

# 导入必要的库
import sys
from pathlib import Path

# 导入 numpy 的相关模块
from numpy.lib.format import open_memmap

# 导入 torch 相关模块
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
from torch.utils.data import Dataset

# 导入 einops 库
from einops import rearrange

# 导入自定义的 QRoboticTransformer 类
from q_transformer.q_robotic_transformer import QRoboticTransformer

# 导入 torchtyping 库
from torchtyping import TensorType

# 导入 beartype 库
from beartype import beartype
from beartype.typing import Iterator, Tuple, Union

# 导入 tqdm 库
from tqdm import tqdm

# 确保在 64 位系统上进行训练
assert sys.maxsize > (2 ** 32), 'you need to be on 64 bit system to store > 2GB experience for your q-transformer agent'

# 定义常量
TEXT_EMBEDS_FILENAME = 'text_embeds.memmap.npy'
STATES_FILENAME = 'states.memmap.npy'
ACTIONS_FILENAME = 'actions.memmap.npy'
REWARDS_FILENAME = 'rewards.memmap.npy'
DONES_FILENAME = 'dones.memmap.npy'

DEFAULT_REPLAY_MEMORIES_FOLDER = './replay_memories_data'

# 定义辅助函数
def exists(v):
    return v is not None

def cast_tuple(t):
    return (t,) if not isinstance(t, tuple) else t

# 定义回放记忆数据集类
class ReplayMemoryDataset(Dataset):
    @beartype
    def __init__(
        self,
        folder: str = DEFAULT_REPLAY_MEMORIES_FOLDER,
        num_timesteps: int = 1
    ):
        # 确保时间步数大于等于 1
        assert num_timesteps >= 1
        self.is_single_timestep = num_timesteps == 1
        self.num_timesteps = num_timesteps

        # 检查文件夹是否存在
        folder = Path(folder)
        assert folder.exists() and folder.is_dir()

        # 打开并读取相关文件
        text_embeds_path = folder / TEXT_EMBEDS_FILENAME
        states_path = folder / STATES_FILENAME
        actions_path = folder / ACTIONS_FILENAME
        rewards_path = folder / REWARDS_FILENAME
        dones_path = folder / DONES_FILENAME

        self.text_embeds = open_memmap(str(text_embeds_path), dtype='float32', mode='r')
        self.states = open_memmap(str(states_path), dtype='float32', mode='r')
        self.actions = open_memmap(str(actions_path), dtype='int', mode='r')
        self.rewards = open_memmap(str(rewards_path), dtype='float32', mode='r')
        self.dones = open_memmap(str(dones_path), dtype='bool', mode='r')

        self.num_timesteps = num_timesteps

        # 根据结束标志计算每个 episode 的长度
        self.episode_length = (self.dones.cumsum(axis=-1) == 0).sum(axis=-1) + 1

        # 过滤出长度足够的 episode
        trainable_episode_indices = self.episode_length >= num_timesteps

        self.text_embeds = self.text_embeds[trainable_episode_indices]
        self.states = self.states[trainable_episode_indices]
        self.actions = self.actions[trainable_episode_indices]
        self.rewards = self.rewards[trainable_episode_indices]
        self.dones = self.dones[trainable_episode_indices]

        self.episode_length = self.episode_length[trainable_episode_indices]

        # 确保存在可训练的 episode
        assert self.dones.size > 0, 'no trainable episodes'

        self.num_episodes, self.max_episode_len = self.dones.shape

        timestep_arange = torch.arange(self.max_episode_len)

        timestep_indices = torch.stack(torch.meshgrid(
            torch.arange(self.num_episodes),
            timestep_arange
        ), dim=-1)

        trainable_mask = timestep_arange < rearrange(torch.from_numpy(self.episode_length) - num_timesteps, 'e -> e 1')
        self.indices = timestep_indices[trainable_mask]

    # 返回数据集的长度
    def __len__(self):
        return self.indices.shape[0]
    # 重载索引操作符,根据索引获取数据
    def __getitem__(self, idx):
        # 从索引中获取当前 episode 和 timestep 的索引
        episode_index, timestep_index = self.indices[idx]

        # 创建一个切片对象,用于获取当前 timestep 到 num_timesteps 之间的数据
        timestep_slice = slice(timestep_index, (timestep_index + self.num_timesteps))

        # 复制当前 episode 的文本嵌入数据
        text_embeds = self.text_embeds[episode_index, timestep_slice].copy()
        # 复制当前 episode 的状态数据
        states = self.states[episode_index, timestep_slice].copy()
        # 复制当前 episode 的动作数据
        actions = self.actions[episode_index, timestep_slice].copy()
        # 复制当前 episode 的奖励数据
        rewards = self.rewards[episode_index, timestep_slice].copy()
        # 复制当前 episode 的完成标志数据
        dones = self.dones[episode_index, timestep_slice].copy()

        # 获取下一个状态数据,如果当前 timestep 已经是最后一个,则获取最后一个状态数据
        next_state = self.states[episode_index, min(timestep_index + self.num_timesteps, self.max_episode_len - 1)].copy()

        # 返回文本嵌入数据、状态数据、动作数据、下一个状态数据、奖励数据、完成标志数据
        return text_embeds, states, actions, next_state, rewards, dones
# 定义一个基础环境类,用于扩展
class BaseEnvironment(Module):
    # 初始化方法,接受状态形状和文本嵌入形状作为参数
    @beartype
    def __init__(
        self,
        *,
        state_shape: Tuple[int, ...],
        text_embed_shape: Union[int, Tuple[int, ...]]
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置状态形状和文本嵌入形状属性
        self.state_shape = state_shape
        self.text_embed_shape = cast_tuple(text_embed_shape)
        # 注册一个缓冲区
        self.register_buffer('dummy', torch.zeros(0), persistent=False)

    # 返回缓冲区所在设备
    @property
    def device(self):
        return self.dummy.device

    # 初始化方法,返回指令和初始状态
    def init(self) -> Tuple[str, Tensor]:
        raise NotImplementedError

    # 前向传播方法,接受动作作为参数,返回奖励、下一个状态和是否结束的元组
    def forward(
        self,
        actions: Tensor
    ) -> Tuple[
        TensorType[(), float],     # reward
        Tensor,                    # next state
        TensorType[(), bool]       # done
    ]:
        raise NotImplementedError

# 代理类
class Agent(Module):
    # 初始化方法,接受 QRoboticTransformer 对象、环境对象和一些参数
    @beartype
    def __init__(
        self,
        q_transformer: QRoboticTransformer,
        *,
        environment: BaseEnvironment,
        memories_dataset_folder: str = DEFAULT_REPLAY_MEMORIES_FOLDER,
        num_episodes: int = 1000,
        max_num_steps_per_episode: int = 10000,
        epsilon_start: float = 0.25,
        epsilon_end: float = 0.001,
        num_steps_to_target_epsilon: int = 1000
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置 QRoboticTransformer 对象
        self.q_transformer = q_transformer
        # 设置是否在文本上进行条件
        condition_on_text = q_transformer.condition_on_text
        self.condition_on_text = condition_on_text
        # 设置环境对象
        self.environment = environment

        # 断言环境对象具有状态形状和文本嵌入形状属性
        assert hasattr(environment, 'state_shape') and hasattr(environment, 'text_embed_shape')

        # 断言参数的取值范围
        assert 0. <= epsilon_start <= 1.
        assert 0. <= epsilon_end <= 1.
        assert epsilon_start >= epsilon_end

        # 设置一些参数
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.num_steps_to_target_epsilon = num_steps_to_target_epsilon
        self.epsilon_slope = (epsilon_end - epsilon_start) / num_steps_to_target_epsilon

        self.num_episodes = num_episodes
        self.max_num_steps_per_episode = max_num_steps_per_episode

        # 创建存储回忆的文件夹
        mem_path = Path(memories_dataset_folder)
        self.memories_dataset_folder = mem_path

        mem_path.mkdir(exist_ok=True, parents=True)
        assert mem_path.is_dir()

        # 设置存储状态、动作、奖励和结束标志的文件路径
        states_path = mem_path / STATES_FILENAME
        actions_path = mem_path / ACTIONS_FILENAME
        rewards_path = mem_path / REWARDS_FILENAME
        dones_path = mem_path / DONES_FILENAME

        # 设置先验形状和动作数量
        prec_shape = (num_episodes, max_num_steps_per_episode)
        num_actions = q_transformer.num_actions
        state_shape = environment.state_shape

        # 如果在文本上进行条件
        if condition_on_text:
            text_embeds_path = mem_path / TEXT_EMBEDS_FILENAME
            text_embed_shape = environment.text_embed_shape
            self.text_embed_shape = text_embed_shape
            # 创建文本嵌入的内存映射
            self.text_embeds = open_memmap(str(text_embeds_path), dtype='float32', mode='w+', shape=(*prec_shape, *text_embed_shape))

        # 创建状态、动作、奖励和结束标志的内存映射
        self.states = open_memmap(str(states_path), dtype='float32', mode='w+', shape=(*prec_shape, *state_shape))
        self.actions = open_memmap(str(actions_path), dtype='int', mode='w+', shape=(*prec_shape, num_actions))
        self.rewards = open_memmap(str(rewards_path), dtype='float32', mode='w+', shape=prec_shape)
        self.dones = open_memmap(str(dones_path), dtype='bool', mode='w+', shape=prec_shape)

    # 根据步数获取 epsilon 值
    def get_epsilon(self, step):
        return max(self.epsilon_end, self.epsilon_slope * float(step) + self.epsilon_start)

    # 无需梯度的装饰器
    @beartype
    @torch.no_grad()
    # 定义一个方法,用于执行前向传播
    def forward(self):
        # 将 Q-Transformer 设置为评估模式
        self.q_transformer.eval()

        # 循环执行多个 episode
        for episode in range(self.num_episodes):
            # 打印当前 episode 的信息
            print(f'episode {episode}')

            # 初始化环境,获取指令和当前状态
            instruction, curr_state = self.environment.init()

            # 在每个 episode 中执行多个步骤
            for step in tqdm(range(self.max_num_steps_per_episode)):
                # 判断是否是最后一个步骤
                last_step = step == (self.max_num_steps_per_episode - 1)

                # 根据当前步骤获取 epsilon 值
                epsilon = self.get_epsilon(step)

                # 初始化文本嵌入为 None
                text_embed = None

                # 如果需要根据文本条件执行动作
                if self.condition_on_text:
                    # 获取指令的文本嵌入
                    text_embed = self.q_transformer.embed_texts([instruction])

                # 获取动作
                actions = self.q_transformer.get_actions(
                    rearrange(curr_state, '... -> 1 ...'),
                    text_embeds = text_embed,
                    prob_random_action = epsilon
                )

                # 执行动作,获取奖励、下一个状态和是否结束的标志
                reward, next_state, done = self.environment(actions)

                # 判断是否结束或是最后一个步骤
                done = done | last_step

                # 使用 memmap 存储记忆,以便后续回顾和学习

                # 如果需要根据文本条件执行动作
                if self.condition_on_text:
                    # 断言文本嵌入的形状符合预期
                    assert text_embed.shape[1:] == self.text_embed_shape
                    # 将文本嵌入存储到指定位置
                    self.text_embeds[episode, step] = text_embed

                # 存储当前状态、动作、奖励和结束标志
                self.states[episode, step]      = curr_state
                self.actions[episode, step]     = actions
                self.rewards[episode, step]     = reward
                self.dones[episode, step]       = done

                # 如果已经结束,跳出当前 episode 的循环
                if done:
                    break

                # 更新当前状态为下一个状态
                curr_state = next_state

            # 如果需要根据文本条件执行动作
            if self.condition_on_text:
                # 刷新文本嵌入的存储
                self.text_embeds.flush()

            # 刷新当前状态、动作、奖励和结束标志的存储
            self.states.flush()
            self.actions.flush()
            self.rewards.flush()
            self.dones.flush()

        # 关闭 memmap

        # 如果需要根据文本条件执行动作
        if self.condition_on_text:
            # 删除文本嵌入
            del self.text_embeds

        # 删除当前状态、动作、奖励和结束标志
        del self.states
        del self.actions
        del self.rewards
        del self.dones

        # 打印完成信息,存储的记忆位置
        print(f'completed, memories stored to {self.memories_dataset_folder.resolve()}')
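
A rough end-to-end sketch of agent.py: collecting replay memories with Agent and reading them back with ReplayMemoryDataset. The QRoboticTransformer constructor arguments are elided because they live in q_robotic_transformer.py (not shown in this excerpt); MockEnvironment from mocks.py further down stands in for a real environment, and the shapes below are placeholders.

from torch.utils.data import DataLoader

from q_transformer.q_robotic_transformer import QRoboticTransformer
from q_transformer.agent import Agent, ReplayMemoryDataset
from q_transformer.mocks import MockEnvironment

model = QRoboticTransformer(...)   # constructor arguments defined in q_robotic_transformer.py, omitted here

env = MockEnvironment(
    state_shape = (3, 6, 224, 224),    # placeholder: channels, frames, height, width
    text_embed_shape = (768,)          # placeholder text embedding dimension
)

agent = Agent(
    model,
    environment = env,
    num_episodes = 10,
    max_num_steps_per_episode = 100
)

agent()   # epsilon-greedy rollouts; states / actions / rewards / dones are written as memmaps to ./replay_memories_data

dataset = ReplayMemoryDataset(num_timesteps = 1)
dataloader = DataLoader(dataset, batch_size = 16, shuffle = True)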

.\lucidrains\q-transformer\q_transformer\attend.py

# 导入所需的模块和函数
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# 定义一个装饰器函数,确保函数只被调用一次
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# 用装饰器once包装print函数,确保只打印一次
print_once = once(print)

# 判断变量是否存在的辅助函数
def exists(val):
    return val is not None

# 如果val存在则返回val,否则返回默认值d的辅助函数
def default(val, d):
    return val if exists(val) else d

# 将多个可能的mask合并为一个mask的辅助函数
def maybe_reduce_mask_and(*maybe_masks):
    maybe_masks = [*filter(exists, maybe_masks)]

    if len(maybe_masks) == 0:
        return None

    mask, *rest_masks = maybe_masks

    for rest_mask in rest_masks:
        mask = mask & rest_mask

    return mask

# 主要的Attend类
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        causal = False,
        flash_config: dict = dict(
            enable_flash = True,
            enable_math = True,
            enable_mem_efficient = True
        )
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        if flash:
            print_once('using memory efficient attention')

        self.flash_config = flash_config

    # Flash Attention函数
    def flash_attn(self, q, k, v, mask = None, attn_mask = None):
        _, heads, q_len, dim_head, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # 检查mask是否存在并扩展到兼容的形状
        if exists(mask):
            mask = mask.expand(-1, heads, q_len, -1)

        mask = maybe_reduce_mask_and(mask, attn_mask)

        # 使用torch.backends.cuda.sdp_kernel(**self.flash_config)进行pytorch 2.0的flash attention计算
        with torch.backends.cuda.sdp_kernel(**self.flash_config):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                is_causal = self.causal,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    # 前向传播函数
    def forward(self, q, k, v, mask = None, attn_mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask, attn_mask = attn_mask)

        # 相似度计算
        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # 因果mask
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # key padding mask
        if exists(mask):
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention mask
        if exists(attn_mask):
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        # 注意力权重计算
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # 聚合值
        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
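
A quick shape check for the Attend module above (assuming it is in scope); the tensor sizes are arbitrary.

import torch

attend = Attend(dropout = 0.1, causal = True, flash = False)

q = torch.randn(2, 8, 16, 64)          # (batch, heads, query length, head dimension)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

mask = torch.ones(2, 16, dtype = torch.bool)   # key padding mask, rearranged to (b, 1, 1, j) inside forward

out = attend(q, k, v, mask = mask)     # -> (2, 8, 16, 64)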

.\lucidrains\q-transformer\q_transformer\mocks.py

from random import randrange
# 从 random 模块中导入 randrange 函数

import torch
# 导入 torch 库
from torch.utils.data import Dataset
# 从 torch.utils.data 模块中导入 Dataset 类

from beartype.typing import Tuple, Optional
# 从 beartype.typing 模块中导入 Tuple 和 Optional 类型

from torchtyping import TensorType
# 从 torchtyping 模块中导入 TensorType 类型
from q_transformer.agent import BaseEnvironment
# 从 q_transformer.agent 模块中导入 BaseEnvironment 类

class MockEnvironment(BaseEnvironment):
    # 定义 MockEnvironment 类,继承自 BaseEnvironment 类
    def init(self) -> Tuple[
        Optional[str],
        TensorType[float]
    ]:
        # 初始化方法,返回一个元组,包含可选的字符串和浮点数张量
        return 'please clean the kitchen', torch.randn(self.state_shape, device = self.device)
        # 返回指令字符串和根据状态形状和设备生成的随机张量

    def forward(self, actions) -> Tuple[
        TensorType[(), float],
        TensorType[float],
        TensorType[(), bool]
    ]:
        # 前向传播方法,接受动作参数,返回一个元组,包含标量浮点数张量、浮点数张量和布尔值张量
        rewards = torch.randn((), device = self.device)
        # 生成一个随机标量浮点数张量
        next_states = torch.randn(self.state_shape, device = self.device)
        # 生成一个随机状态形状的浮点数张量
        done = torch.zeros((), device = self.device, dtype = torch.bool)
        # 生成一个全零张量,数据类型为布尔型

        return rewards, next_states, done
        # 返回奖励、下一个状态和完成标志

class MockReplayDataset(Dataset):
    # 定义 MockReplayDataset 类,继承自 Dataset 类
    def __init__(
        self,
        length = 10000,
        num_actions = 1,
        num_action_bins = 256,
        video_shape = (6, 224, 224)
    ):
        # 初始化方法,设置数据集长度、动作数量、动作区间数量和视频形状
        self.length = length
        self.num_actions = num_actions
        self.num_action_bins = num_action_bins
        self.video_shape = video_shape

    def __len__(self):
        # 返回数据集长度
        return self.length

    def __getitem__(self, _):
        # 获取数据集中的一项
        instruction = "please clean the kitchen"
        # 指令字符串
        state = torch.randn(3, *self.video_shape)
        # 随机生成一个状态张量

        if self.num_actions == 1:
            action = torch.tensor(randrange(self.num_action_bins + 1))
        else:
            action = torch.randint(0, self.num_action_bins + 1, (self.num_actions,))
        # 根据动作数量生成动作张量

        next_state = torch.randn(3, *self.video_shape)
        # 随机生成下一个状态张量
        reward = torch.tensor(randrange(2))
        # 随机生成奖励张量
        done = torch.tensor(randrange(2), dtype = torch.bool)
        # 随机生成完成标志张量

        return instruction, state, action, next_state, reward, done
        # 返回指令、状态、动作、下一个状态、奖励和完成标志

class MockReplayNStepDataset(Dataset):
    # 定义 MockReplayNStepDataset 类,继承自 Dataset 类
    def __init__(
        self,
        length = 10000,
        num_steps = 2,
        num_actions = 1,
        num_action_bins = 256,
        video_shape = (6, 224, 224)
    ):
        # 初始化方法,设置数据集长度、步数、动作数量、动作区间数量和视频形状
        self.num_steps = num_steps
        self.time_shape = (num_steps,)
        self.length = length
        self.num_actions = num_actions
        self.num_action_bins = num_action_bins
        self.video_shape = video_shape

    def __len__(self):
        # 返回数据集长度
        return self.length

    def __getitem__(self, _):
        # 获取数据集中的一项
        action_dims = (self.num_actions,) if self.num_actions > 1 else tuple()
        # 根据动作数量设置动作维度元组

        instruction = "please clean the kitchen"
        # 指令字符串
        state = torch.randn(*self.time_shape, 3, *self.video_shape)
        # 随机生成一个时间维度状态张量
        action = torch.randint(0, self.num_action_bins + 1, (*self.time_shape, *action_dims))
        # 根据动作数量生成动作张量
        next_state = torch.randn(3, *self.video_shape)
        # 随机生成下一个状态张量
        reward = torch.randint(0, 2, self.time_shape)
        # 随机生成奖励张量
        done = torch.zeros(self.time_shape, dtype = torch.bool)
        # 生成全零完成标志张量

        return instruction, state, action, next_state, reward, done
        # 返回指令、状态、动作、下一个状态、奖励和完成标志
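
The mock datasets emit tensors with the shapes the Q-learner expects; a small sanity-check sketch (batch size and action count are arbitrary):

from torch.utils.data import DataLoader

dataset = MockReplayNStepDataset(num_steps = 2, num_actions = 4)
loader = DataLoader(dataset, batch_size = 8)

instruction, state, action, next_state, reward, done = next(iter(loader))

print(state.shape)        # torch.Size([8, 2, 3, 6, 224, 224]) - batch, timesteps, channels, frames, height, width
print(action.shape)       # torch.Size([8, 2, 4])              - batch, timesteps, num_actions
print(next_state.shape)   # torch.Size([8, 3, 6, 224, 224])
print(reward.shape)       # torch.Size([8, 2])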

.\lucidrains\q-transformer\q_transformer\optimizer.py

# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam

# 将参数分为需要权重衰减和不需要权重衰减的两个列表
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # 根据参数的维度判断是否需要权重衰减
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

# 获取 Adam 或 AdamW 优化器
def get_adam_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True
):
    # 判断是否需要权重衰减
    has_wd = wd > 0

    # 根据是否需要过滤梯度为 True 的参数来更新参数列表
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # 如果需要对参数进行分组并进行权重衰减
    if group_wd_params and has_wd:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        # 更新参数列表,分为需要权重衰减和不需要权重衰减的两组
        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    # 如果不需要权重衰减,则返回 Adam 优化器
    if not has_wd:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    # 如果需要权重衰减,则返回 AdamW 优化器
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
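
A short example of the optimizer factory above (assuming it is in scope); the model is a placeholder.

from torch import nn

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 1))

# with wd > 0, 2d+ parameters (weights) receive weight decay, 1d parameters (biases, norm gains) do not
optim = get_adam_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)

print(type(optim).__name__)                              # AdamW
print([len(g['params']) for g in optim.param_groups])    # [2, 4] - decayed weights vs non-decayed 1d params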

.\lucidrains\q-transformer\q_transformer\q_learner.py

# 导入所需的模块
from pathlib import Path
from functools import partial
from contextlib import nullcontext
from collections import namedtuple

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
from torch.utils.data import Dataset, DataLoader

# 导入自定义的类型注解模块
from torchtyping import TensorType

# 导入 einops 相关模块
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# 导入 beartype 相关模块
from beartype import beartype
from beartype.typing import Optional, Union, List, Tuple

# 导入自定义的 QRoboticTransformer 类
from q_transformer.q_robotic_transformer import QRoboticTransformer

# 导入自定义的优化器获取函数
from q_transformer.optimizer import get_adam_optimizer

# 导入 accelerate 相关模块
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# 导入 EMA 模块
from ema_pytorch import EMA

# 定义常量

# 定义 QIntermediates 命名元组,包含 Q 学习中的中间变量
QIntermediates = namedtuple('QIntermediates', [
    'q_pred_all_actions',
    'q_pred',
    'q_next',
    'q_target'
])

# 定义 Losses 命名元组,包含损失函数中的损失项
Losses = namedtuple('Losses', [
    'td_loss',
    'conservative_reg_loss'
])

# 定义辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 返回默认值
def default(val, d):
    return val if exists(val) else d

# 判断两个数是否整除
def is_divisible(num, den):
    return (num % den) == 0

# 将单个张量按指定模式打包
def pack_one(t, pattern):
    return pack([t], pattern)

# 将单个张量按指定模式解包
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# 生成数据集的无限循环迭代器
def cycle(dl):
    while True:
        for batch in dl:
            yield batch

# 张量操作辅助函数

# 从张量中选择指定索引的元素
def batch_select_indices(t, indices):
    indices = rearrange(indices, '... -> ... 1')
    selected = t.gather(-1, indices)
    return rearrange(selected, '... 1 -> ...')
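
# example: for t = [[0.1, 0.5, 0.2],
#                   [0.3, 0.9, 0.4]] and indices = [2, 0],
# batch_select_indices(t, indices) returns tensor([0.2, 0.3]),
# i.e. each row keeps only the Q value of the action index chosen for that sample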

# Q 学习在机器人变压器上的实现

# 定义 QLearner 类,继承自 Module
class QLearner(Module):

    # 初始化函数
    @beartype
    def __init__(
        self,
        model: Union[QRoboticTransformer, Module],
        *,
        dataset: Dataset,
        batch_size: int,
        num_train_steps: int,
        learning_rate: float,
        min_reward: float = 0.,
        grad_accum_every: int = 1,
        monte_carlo_return: Optional[float] = None,
        weight_decay: float = 0.,
        accelerator: Optional[Accelerator] = None,
        accelerator_kwargs: dict = dict(),
        dataloader_kwargs: dict = dict(
            shuffle = True
        ),
        q_target_ema_kwargs: dict = dict(
            beta = 0.99,
            update_after_step = 10,
            update_every = 5
        ),
        max_grad_norm = 0.5,
        n_step_q_learning = False,
        discount_factor_gamma = 0.98,
        conservative_reg_loss_weight = 1., # they claim 1. is best in paper
        optimizer_kwargs: dict = dict(),
        checkpoint_folder = './checkpoints',
        checkpoint_every = 1000,
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 判断是否有多个动作
        self.is_multiple_actions = model.num_actions > 1

        # Q-learning 相关超参数
        self.discount_factor_gamma = discount_factor_gamma
        self.n_step_q_learning = n_step_q_learning

        # 是否有保守正则化损失
        self.has_conservative_reg_loss = conservative_reg_loss_weight > 0.
        self.conservative_reg_loss_weight = conservative_reg_loss_weight

        # 注册缓冲区
        self.register_buffer('discount_matrix', None, persistent = False)

        # 在线 Q 模型
        self.model = model

        # EMA(目标)Q 模型
        self.ema_model = EMA(
            model,
            include_online_model = False,
            **q_target_ema_kwargs
        )

        # 最大梯度范数
        self.max_grad_norm = max_grad_norm

        # 获取 Adam 优化器
        self.optimizer = get_adam_optimizer(
            model.parameters(),
            lr = learning_rate,
            wd = weight_decay,
            **optimizer_kwargs
        )

        # 如果加速器不存在,则创建一个
        if not exists(accelerator):
            accelerator = Accelerator(
                kwargs_handlers = [
                    DistributedDataParallelKwargs(find_unused_parameters = True)
                ],
                **accelerator_kwargs
            )

        self.accelerator = accelerator

        # 最小奖励和蒙特卡洛回报
        self.min_reward = min_reward
        self.monte_carlo_return = monte_carlo_return

        # 创建数据加载器
        self.dataloader = DataLoader(
            dataset,
            batch_size = batch_size,
            **dataloader_kwargs
        )

        # 准备模型、EMA 模型、优化器和数据加载器
        (
            self.model,
            self.ema_model,
            self.optimizer,
            self.dataloader
        ) = self.accelerator.prepare(
            self.model,
            self.ema_model,
            self.optimizer,
            self.dataloader
        )

        # 检查点相关
        self.checkpoint_every = checkpoint_every
        self.checkpoint_folder = Path(checkpoint_folder)

        # 创建检查点文件夹
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
        assert self.checkpoint_folder.is_dir()

        # 创建一个零张量作为虚拟损失
        self.register_buffer('zero', torch.tensor(0.))

        # 训练步骤相关
        self.num_train_steps = num_train_steps
        self.grad_accum_every = grad_accum_every

        # 注册步骤计数器
        self.register_buffer('step', torch.tensor(0))

    # 保存模型
    def save(
        self,
        checkpoint_num = None,
        overwrite = True
    ):
        name = 'checkpoint'
        if exists(checkpoint_num):
            name += f'-{checkpoint_num}'

        path = self.checkpoint_folder / (name + '.pt')

        assert overwrite or not path.exists()

        # 打包模型、EMA 模型、优化器和步骤计数器
        pkg = dict(
            model = self.unwrap(self.model).state_dict(),
            ema_model = self.unwrap(self.ema_model).state_dict(),
            optimizer = self.optimizer.state_dict(),
            step = self.step.item()
        )

        # 保存模型
        torch.save(pkg, str(path))

    # 加载模型
    def load(self, path):
        path = Path(path)
        assert exists(path)

        pkg = torch.load(str(path))

        # 加载模型、EMA 模型和优化器
        self.unwrap(self.model).load_state_dict(pkg['model'])
        self.unwrap(self.ema_model).load_state_dict(pkg['ema_model'])

        self.optimizer.load_state_dict(pkg['optimizer'])
        self.step.copy_(pkg['step'])

    # 获取设备
    @property
    def device(self):
        return self.accelerator.device

    # 判断是否为主进程
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # 解包模型
    def unwrap(self, module):
        return self.accelerator.unwrap_model(module)

    # 打印信息
    def print(self, msg):
        return self.accelerator.print(msg)

    # 等待所有进程完成
    def wait(self):
        return self.accelerator.wait_for_everyone()
    def get_discount_matrix(self, timestep):
        # 检查是否已存在折扣矩阵并且其时间步长大于等于当前时间步长
        if exists(self.discount_matrix) and self.discount_matrix.shape[-1] >= timestep:
            # 如果满足条件,则返回已存在的折扣矩阵的子矩阵
            return self.discount_matrix[:timestep, :timestep]

        # 创建一个时间步长范围的张量
        timestep_arange = torch.arange(timestep, device=self.accelerator.device)
        # 计算时间步长之间的幂次
        powers = (timestep_arange[None, :] - timestep_arange[:, None])
        # 根据幂次计算折扣矩阵
        discount_matrix = torch.triu(self.discount_factor_gamma ** powers)

        # 将折扣矩阵注册为缓冲区
        self.register_buffer('discount_matrix', discount_matrix, persistent=False)
        # 返回折扣矩阵
        return self.discount_matrix
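        # for example, with discount_factor_gamma = 0.9 and timestep = 3 the matrix is
        #   [[1.00, 0.90, 0.81],
        #    [0.00, 1.00, 0.90],
        #    [0.00, 0.00, 1.00]]
        # where row q weights the rewards from step q onwards (see the einsum in n_step_q_learn)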

    def q_learn(
        self,
        text_embeds: TensorType['b', 'd', float],
        states: TensorType['b', 'c', 'f', 'h', 'w', float],
        actions: TensorType['b', int],
        next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
        reward: TensorType['b', float],
        done: TensorType['b', bool],
        *,
        monte_carlo_return=None
    ) -> Tuple[TensorType[()], QIntermediates]:
        # 'next'代表下一个时间步(无论是状态、q值、动作等)

        γ = self.discount_factor_gamma
        # 计算非终止状态的掩码
        not_terminal = (~done).float()

        # 使用在线Q机器人变换器进行预测
        q_pred_all_actions = self.model(states, text_embeds=text_embeds)
        # 选择出采取的动作对应的Q值
        q_pred = batch_select_indices(q_pred_all_actions, actions)

        # 使用指数平滑的模型副本作为未来的Q目标。比在每个批次之后将q_target设置为q_eval更稳定
        # 最大Q值被视为最优动作,隐含地是具有最高Q分数的动作
        q_next = self.ema_model(next_states, text_embeds=text_embeds).amax(dim=-1)
        q_next.clamp_(min=default(monte_carlo_return, -1e4))

        # 贝尔曼方程。最重要的代码行,希望正确执行
        q_target = reward + not_terminal * (γ * q_next)

        # 强制在线模型能够预测这个目标
        loss = F.mse_loss(q_pred, q_target)

        # 这就是全部。对于Q学习的核心,大约5行代码
        # 返回损失和一些中间结果以便记录
        return loss, QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)

    def n_step_q_learn(
        self,
        text_embeds: TensorType['b', 'd', float],
        states: TensorType['b', 't', 'c', 'f', 'h', 'w', float],
        actions: TensorType['b', 't', int],
        next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
        rewards: TensorType['b', 't', float],
        dones: TensorType['b', 't', bool],
        *,
        monte_carlo_return=None
    ) -> Tuple[TensorType[()], QIntermediates]:
        """
        einops

        b - batch
        c - channels
        f - frames
        h - height
        w - width
        t - timesteps
        a - action bins
        q - q values
        d - text cond dimension
        """

        num_timesteps, device = states.shape[1], states.device

        # fold time steps into batch

        states, time_ps = pack_one(states, '* c f h w')
        text_embeds, _ = pack_one(text_embeds, '* d')

        # repeat text embeds per timestep

        repeated_text_embeds = repeat(text_embeds, 'b ... -> (b n) ...', n = num_timesteps)

        γ = self.discount_factor_gamma

        # anything after the first done flag will be considered terminal

        dones = dones.cumsum(dim = -1) > 0
        dones = F.pad(dones, (1, 0), value = False)

        not_terminal = (~dones).float()

        # get q predictions

        actions = rearrange(actions, 'b t -> (b t)')

        q_pred_all_actions = self.model(states, text_embeds = repeated_text_embeds)
        q_pred = batch_select_indices(q_pred_all_actions, actions)
        q_pred = unpack_one(q_pred, time_ps, '*')

        q_next = self.ema_model(next_states, text_embeds = text_embeds).amax(dim = -1)
        q_next.clamp_(min = default(monte_carlo_return, -1e4))

        # prepare rewards and discount factors across timesteps

        rewards, _ = pack([rewards, q_next], 'b *')

        γ = self.get_discount_matrix(num_timesteps + 1)[:-1, :]

        # account for discounting using the discount matrix

        q_target = einsum('b t, q t -> b q', not_terminal * rewards, γ)

        # have transformer learn to predict above Q target

        loss = F.mse_loss(q_pred, q_target)

        # prepare q prediction

        q_pred_all_actions = unpack_one(q_pred_all_actions, time_ps, '* a')

        return loss, QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)

    def autoregressive_q_learn_handle_single_timestep(
        self,
        text_embeds,
        states,
        actions,
        next_states,
        rewards,
        dones,
        *,
        monte_carlo_return = None
    ):
        """
        simply detect and handle single timestep
        and use `autoregressive_q_learn` as more general function
        """
        if states.ndim == 5:
            states = rearrange(states, 'b ... -> b 1 ...')

        if actions.ndim == 2:
            actions = rearrange(actions, 'b ... -> b 1 ...')

        if rewards.ndim == 1:
            rewards = rearrange(rewards, 'b -> b 1')

        if dones.ndim == 1:
            dones = rearrange(dones, 'b -> b 1')

        return self.autoregressive_q_learn(text_embeds, states, actions, next_states, rewards, dones, monte_carlo_return = monte_carlo_return)

    def autoregressive_q_learn(
        self,
        text_embeds:    TensorType['b', 'd', float],
        states:         TensorType['b', 't', 'c', 'f', 'h', 'w', float],
        actions:        TensorType['b', 't', 'n', int],
        next_states:    TensorType['b', 'c', 'f', 'h', 'w', float],
        rewards:        TensorType['b', 't', float],
        dones:          TensorType['b', 't', bool],
        *,
        monte_carlo_return = None
    ) -> Tuple[TensorType[()], QIntermediates]:
        """
        einops

        b - batch
        c - channels
        f - frames
        h - height
        w - width
        t - timesteps
        n - number of actions
        a - action bins
        q - q values
        d - text cond dimension
        """
        # 设置默认的蒙特卡洛回报值
        monte_carlo_return = default(monte_carlo_return, -1e4)
        # 获取状态的时间步数和设备信息
        num_timesteps, device = states.shape[1], states.device

        # 将时间步折叠到批次中

        states, time_ps = pack_one(states, '* c f h w')
        actions, _ = pack_one(actions, '* n')
        text_embeds, _ = pack_one(text_embeds, '* d')

        # 每个时间步重复文本嵌入

        repeated_text_embeds = repeat(text_embeds, 'b ... -> (b n) ...', n = num_timesteps)

        # 第一个完成标志之后的任何内容都将被视为终止

        dones = dones.cumsum(dim = -1) > 0
        dones = F.pad(dones, (1, -1), value = False)

        not_terminal = (~dones).float()

        # 奖励不应在终止步骤及之后给出

        rewards = rewards * not_terminal

        # 因为希腊字母Unicode看起来很好

        γ = self.discount_factor_gamma

        # 获取每个动作的预测 Q 值
        # 解包回 (b, t, n)

        q_pred_all_actions = self.model(states, text_embeds = repeated_text_embeds, actions = actions)
        q_pred = batch_select_indices(q_pred_all_actions, actions)
        q_pred = unpack_one(q_pred, time_ps, '* n')

        # 获取 q_next

        q_next = self.ema_model(next_states, text_embeds = text_embeds)
        q_next = q_next.max(dim = -1).values
        q_next.clamp_(min = monte_carlo_return)

        # 获取目标 Q
        # 解包回 - (b, t, n)

        q_target_all_actions = self.ema_model(states, text_embeds = repeated_text_embeds, actions = actions)
        q_target = q_target_all_actions.max(dim = -1).values

        q_target.clamp_(min = monte_carlo_return)
        q_target = unpack_one(q_target, time_ps, '* n')

        # 论文的主要贡献是以下逻辑
        # 第 4.1 节 - 方程 1

        # 首先处理除最后一个动作之外的所有动作的损失

        q_pred_rest_actions, q_pred_last_action      = q_pred[..., :-1], q_pred[..., -1]
        q_target_first_action, q_target_rest_actions = q_target[..., 0], q_target[..., 1:]

        losses_all_actions_but_last = F.mse_loss(q_pred_rest_actions, q_target_rest_actions, reduction = 'none')

        # 接下来处理最后一个动作,其中包含奖励

        q_target_last_action, _ = pack([q_target_first_action[..., 1:], q_next], 'b *')

        q_target_last_action = rewards + γ * q_target_last_action

        losses_last_action = F.mse_loss(q_pred_last_action, q_target_last_action, reduction = 'none')

        # 展平并平均

        losses, _ = pack([losses_all_actions_but_last, losses_last_action], '*')

        return losses.mean(), QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)

    def learn(
        self,
        *args,
        min_reward: Optional[float] = None,
        monte_carlo_return: Optional[float] = None
    ):
        # 从参数中解包出 actions
        _, _, actions, *_ = args

        # q-learn kwargs
        # 创建包含 monte_carlo_return 参数的字典
        q_learn_kwargs = dict(
            monte_carlo_return = monte_carlo_return
        )

        # main q-learning loss, respectively
        # 1. proposed autoregressive q-learning for multiple actions - (handles single or n-step automatically)
        # 2. single action - single timestep (classic q-learning)
        # 3. single action - n-steps

        # 如果是多个动作
        if self.is_multiple_actions:
            # 使用 autoregressive_q_learn_handle_single_timestep 处理单个时间步的动作
            td_loss, q_intermediates = self.autoregressive_q_learn_handle_single_timestep(*args, **q_learn_kwargs)
            num_timesteps = actions.shape[1]

        # 如果是 n-step Q-learning
        elif self.n_step_q_learning:
            # 使用 n_step_q_learn 处理 n-step Q-learning
            td_loss, q_intermediates = self.n_step_q_learn(*args, **q_learn_kwargs)
            num_timesteps = actions.shape[1]

        else:
            # 使用 q_learn 处理单个时间步的动作
            td_loss, q_intermediates = self.q_learn(*args, **q_learn_kwargs)
            num_timesteps = 1

        # 如果没有保守正则化损失
        if not self.has_conservative_reg_loss:
            # 返回损失和 Losses 对象
            return td_loss, Losses(td_loss, self.zero)

        # 计算保守正则化
        # 论文中的第 4.2 节,方程式 2

        # 获取批次大小
        batch = actions.shape[0]

        # 获取所有动作的 Q 预测值
        q_preds = q_intermediates.q_pred_all_actions
        q_preds = rearrange(q_preds, '... a -> (...) a')

        # 获取动作的数量
        num_action_bins = q_preds.shape[-1]
        num_non_dataset_actions = num_action_bins - 1

        # 重新排列动作
        actions = rearrange(actions, '... -> (...) 1')

        # 创建数据集动作掩码
        dataset_action_mask = torch.zeros_like(q_preds).scatter_(-1, actions, torch.ones_like(q_preds))

        # 获取未选择的动作的 Q 值
        q_actions_not_taken = q_preds[~dataset_action_mask.bool()]
        q_actions_not_taken = rearrange(q_actions_not_taken, '(b t a) -> b t a', b = batch, a = num_non_dataset_actions)

        # 计算保守正则化损失
        conservative_reg_loss = ((q_actions_not_taken - (min_reward * num_timesteps)) ** 2).sum() / num_non_dataset_actions

        # 总损失
        loss =  0.5 * td_loss + \
                0.5 * conservative_reg_loss * self.conservative_reg_loss_weight

        # 损失细分
        loss_breakdown = Losses(td_loss, conservative_reg_loss)

        return loss, loss_breakdown

    # 前向传播函数
    def forward(
        self,
        *,
        monte_carlo_return: Optional[float] = None,
        min_reward: Optional[float] = None
        ):
            # 如果未提供蒙特卡洛回报和最小奖励,则使用默认值
            monte_carlo_return = default(monte_carlo_return, self.monte_carlo_return)
            min_reward = default(min_reward, self.min_reward)

            # 获取当前步数
            step = self.step.item()

            # 创建一个循环迭代器,用于遍历数据加载器
            replay_buffer_iter = cycle(self.dataloader)

            # 设置模型为训练模式
            self.model.train()
            self.ema_model.train()

            # 在训练步数小于总训练步数时执行循环
            while step < self.num_train_steps:

                # 清空梯度
                self.optimizer.zero_grad()

                # 主要的 Q-learning 算法

                # 对于每个梯度累积步骤
                for grad_accum_step in range(self.grad_accum_every):
                    is_last = grad_accum_step == (self.grad_accum_every - 1)
                    # 如果不是最后一个梯度累积步骤,则使用 partial 函数创建上下文
                    context = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext

                    # 使用自动混合精度和上下文执行学习过程
                    with self.accelerator.autocast(), context():

                        # 调用 learn 方法进行学习
                        loss, (td_loss, conservative_reg_loss) = self.learn(
                            *next(replay_buffer_iter),
                            min_reward = min_reward,
                            monte_carlo_return = monte_carlo_return
                        )

                        # 反向传播
                        self.accelerator.backward(loss / self.grad_accum_every)

                # 打印 TD 损失
                self.print(f'td loss: {td_loss.item():.3f}')

                # 限制梯度大小(变压器最佳实践)
                self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                # 执行优化器步骤
                self.optimizer.step()

                # 更新目标 EMA
                self.wait()
                self.ema_model.update()

                # 增加步数
                step += 1
                self.step.add_(1)

                # 是否进行检查点
                self.wait()

                if self.is_main and is_divisible(step, self.checkpoint_every):
                    checkpoint_num = step // self.checkpoint_every
                    self.save(checkpoint_num)

                self.wait()

            # 训练完成后打印信息
            self.print('training complete')
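
Finally, a rough sketch of driving QLearner end to end with the mock replay dataset from mocks.py; the QRoboticTransformer constructor arguments are again elided (they are not part of this excerpt) and the hyperparameters below are placeholders.

from q_transformer.q_robotic_transformer import QRoboticTransformer
from q_transformer.q_learner import QLearner
from q_transformer.mocks import MockReplayDataset

model = QRoboticTransformer(...)   # constructor arguments defined in q_robotic_transformer.py, omitted here

q_learner = QLearner(
    model,
    dataset = MockReplayDataset(),
    batch_size = 4,
    num_train_steps = 1000,
    learning_rate = 3e-4,
    grad_accum_every = 2,
    min_reward = 0.,
    conservative_reg_loss_weight = 1.
)

q_learner()   # runs the training loop above: TD loss + conservative regularization, EMA target updates, periodic checkpoints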