Lucidrains Series: Source Code Walkthrough (Part 31)

.\lucidrains\PaLM-pytorch\palm_pytorch\triton\layernorm.py

# code taken from Phil Tillet's layernorm tutorial for Triton

# Triton - https://triton-lang.org
# Layernorm tutorial - https://triton-lang.org/master/getting-started/tutorials/05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py
# modified to be bias-less

# imports
import torch
import triton
import triton.language as tl

# fused layernorm forward kernel
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, M, V, stride, N,
                          BLOCK_SIZE: tl.constexpr):

    # each program instance normalizes one row
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    X += row * stride
    Y += row * stride

    # load the row, computing in float32 for numerical stability
    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)

    # compute mean, variance, and the normalized row
    mean = tl.sum(x, axis=0) / N

    xmean = tl.where(mask, x - mean, 0.)
    var = tl.sum(xmean * xmean, axis=0) / N
    rstd = 1 / tl.sqrt(var + 1e-5)
    xhat = xmean * rstd

    # stash the mean and reciprocal standard deviation for the backward pass
    tl.store(M + row, mean)
    tl.store(V + row, rstd)

    # scale by the weight (no bias) and write out the result
    w = tl.load(W + cols, mask=mask)
    y = xhat * w

    tl.store(Y + cols, y, mask=mask)

# fused backward kernel for the input gradient (and partial weight gradients)
@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, X, W, M, V, Lock, stride, N,
                             GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):

    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N

    X += row * stride
    DY += row * stride
    DX += row * stride

    # rows are grouped; each group shares one lock and one row of the partial dw buffer
    lock_id = row % GROUP_SIZE_M
    Lock += lock_id
    Count = Lock + GROUP_SIZE_M
    DW = DW + lock_id * N + cols

    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
    dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    mean = tl.load(M + row)
    rstd = tl.load(V + row)

    # recompute xhat, then apply the layernorm backward formula for dx
    xhat = (x - mean) * rstd
    wdy = w * dy
    xhat = tl.where(mask, xhat, 0.)
    wdy = tl.where(mask, wdy, 0.)
    mean1 = tl.sum(xhat * wdy, axis=0) / N
    mean2 = tl.sum(wdy, axis=0) / N
    dx = (wdy - (xhat * mean1 + mean2)) * rstd
    
    tl.store(DX + cols, dx, mask=mask)

    # this row's contribution to the weight gradient
    partial_dw = (dy * xhat).to(w.dtype)

    # spin until the group's lock is acquired
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)

    # the first writer initializes the buffer, later writers accumulate into it
    if count == 0:
        tl.atomic_xchg(Count, 1)
    else:
        partial_dw += tl.load(DW, mask=mask)

    tl.store(DW, partial_dw, mask=mask)

    # release the lock
    tl.atomic_xchg(Lock, 0)

# kernel that reduces the partial weight gradients into the final dw
@triton.jit
def _layer_norm_bwd_dw(DW, FINAL_DW, M, N,
                         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    pid = tl.program_id(0)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, BLOCK_SIZE_M)
        mask = (rows[:, None] < M) & (cols[None, :] < N)
        offs = rows[:, None] * N + cols[None, :]
        dw += tl.load(DW + offs, mask=mask, other=0.)

    sum_dw = tl.sum(dw, axis=0)
    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)

# LayerNorm as a torch.autograd.Function
class LayerNorm(torch.autograd.Function):

    # forward pass
    @staticmethod
    def forward(ctx, x, normalized_shape, weight):
        y = torch.empty_like(x)

        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
        rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')

        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

        _layer_norm_fwd_fused[(M,)](x_arg, y, weight, mean, rstd,
                                    x_arg.stride(0), N,
                                    BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
        ctx.save_for_backward(x, weight, mean, rstd)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        return y

    # backward pass, computing the input and weight gradients
    @staticmethod
    def backward(ctx, dy):
        # retrieve the tensors saved during the forward pass
        x, w, m, v = ctx.saved_tensors

        # heuristic for the number of rows in the partial dw buffer
        N = w.shape[0]
        GROUP_SIZE_M = 64
        if N <= 8192: GROUP_SIZE_M = 96
        if N <= 4096: GROUP_SIZE_M = 128
        if N <= 1024: GROUP_SIZE_M = 256

        # locks for synchronizing accumulation into the partial dw buffer
        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
        # buffer of partial weight gradients, one row per lock group
        _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)

        # final weight gradient
        dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
        # gradient with respect to the input
        dx = torch.empty_like(dy)

        # flatten the input to 2d
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        # compute dx and the partial dw in a single fused kernel
        _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, x, w, m, v, locks,
                                       x_arg.stride(0), N,
                                       BLOCK_SIZE_N=ctx.BLOCK_SIZE,
                                       GROUP_SIZE_M=GROUP_SIZE_M,
                                       num_warps=ctx.num_warps)
        # grid over the feature dimension for the dw reduction
        grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]

        # reduce the partial buffer into the final weight gradient
        _layer_norm_bwd_dw[grid](_dw, dw, GROUP_SIZE_M, N,
                                   BLOCK_SIZE_M=32,
                                   BLOCK_SIZE_N=128)
        # gradients returned in the order of forward's inputs
        return dx, None, dw, None

# convenience alias for the Function's apply
layernorm_without_bias = LayerNorm.apply
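
A quick way to sanity check the kernel is to compare it against PyTorch's own layer norm with a zero bias. This is a hypothetical snippet, assuming a CUDA device and a working Triton install (neither is set up above):

# hypothetical sanity check for the fused bias-less layernorm
import torch
import torch.nn.functional as F

x = torch.randn(4, 1024, device='cuda', requires_grad=True)
w = torch.randn(1024, device='cuda', requires_grad=True)

y_triton = layernorm_without_bias(x, x.shape[-1:], w)
y_torch = F.layer_norm(x, x.shape[-1:], w, torch.zeros_like(w))

assert torch.allclose(y_triton, y_torch, atol=1e-4)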

.\lucidrains\PaLM-pytorch\palm_pytorch\triton\palm.py

# imports
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn

# the custom triton kernels from above
from palm_pytorch.triton.softmax import causal_softmax
from palm_pytorch.triton.layernorm import layernorm_without_bias

# normalization

# layernorm without bias, backed by the triton kernel
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return layernorm_without_bias(x, x.shape[-1:], self.gamma)


# residual

# residual connection
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


# rotary positional embedding

# rotary positional embedding - https://arxiv.org/abs/2104.09864
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


# rotate half the feature dimensions
def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


# apply rotary positional embedding to a tensor
def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())


# feedforward

# SwiGLU gating - https://arxiv.org/abs/2002.05202
class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

# the parallel attention / feedforward transformer block
class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
        self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False))

        # for caching of rotary embeddings

        self.register_buffer("pos_emb", None, persistent=False)

    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x):
        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm

        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads

        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings

        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale

        q = q * self.scale

        # similarity

        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # attention

        attn = causal_softmax(sim)

        # aggregate values

        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)


# transformer

# the PaLM transformer, assembled as a simple nn.Sequential

def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
    net = nn.Sequential(
        # token embedding
        nn.Embedding(num_tokens, dim),
        # stack of parallel transformer blocks, each wrapped in a residual
        *[
            Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
            for _ in range(depth)
        ],
        # final normalization and projection back to the vocabulary
        LayerNorm(dim),
        nn.Linear(dim, num_tokens, bias=False)
    )

    # tie the output projection weights to the token embedding
    net[-1].weight = net[0].weight

    nn.init.normal_(net[0].weight, std=0.02)
    return net
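
As a hypothetical smoke test (the Triton kernels are GPU-only, so this assumes a CUDA device):

# hypothetical smoke test of the triton-backed PaLM
import torch

palm = PaLM(num_tokens=256, dim=512, depth=2).cuda()
tokens = torch.randint(0, 256, (1, 128)).cuda()
logits = palm(tokens)  # (1, 128, 256)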

.\lucidrains\PaLM-pytorch\palm_pytorch\triton\softmax.py

# imports
import torch
from torch import autograd
import torch.nn.functional as F

import triton
import triton.language as tl
from triton_transformer.utils import calc_num_warps

# fused causal softmax forward kernel
@triton.jit
def softmax_kernel_forward(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # each program instance handles one row of the flattened input
    row_idx = tl.program_id(0)

    row_start_ptr = input_ptr + row_idx * input_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

    # mask out columns beyond the row length
    mask = col_offsets < n_cols

    row = tl.load(input_ptrs, mask = mask, other = -float('inf'))

    # causal mask - assumes the last two dimensions of the input are square,
    # so the query position is row_idx % n_cols
    causal_mask = col_offsets > (row_idx % n_cols)
    row = row + tl.where(causal_mask, -float('inf'), 0.)

    # numerically stable softmax - subtract the row max before exponentiating
    row_minus_max = row - tl.max(row, axis=0)

    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # write out the result
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask = mask)

# softmax backward kernel
@triton.jit
def softmax_kernel_backward(
    output_ptr,
    input_ptr,
    grad_ptr,
    grad_row_stride,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)

    row_start_ptr = input_ptr + row_idx * input_row_stride
    grad_row_start_ptr = grad_ptr + row_idx * grad_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    grad_ptrs = grad_row_start_ptr + col_offsets

    mask = col_offsets < n_cols

    # load the softmax probabilities and the incoming gradient
    probs_row = tl.load(input_ptrs, mask = mask, other = 0.)
    grad_row = tl.load(grad_ptrs, mask = mask, other = 0.)

    # softmax backward: dx = p * (dy - sum(p * dy))
    dxhat = probs_row * grad_row
    softmax_grad_output = dxhat - probs_row * tl.sum(dxhat, axis = 0)

    # write out the result
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_grad_output, mask = mask)

# autograd wrapper around the triton kernels
class _softmax(autograd.Function):
    # forward pass
    @classmethod
    def forward(self, ctx, x):
        # flatten the input to 2d
        shape = x.shape
        x = x.view(-1, shape[-1])
        n_rows, n_cols = x.shape

        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        y = torch.empty_like(x)

        # launch one program per row
        softmax_kernel_forward[(n_rows,)](
            y,
            x,
            x.stride(0),
            y.stride(0),
            n_cols,
            num_warps = num_warps,
            BLOCK_SIZE = BLOCK_SIZE,
        )

        # save the probabilities for the backward pass
        if x.requires_grad:
            ctx.save_for_backward(y)
        return y.view(*shape)

    # backward pass
    @classmethod
    def backward(self, ctx, grad_probs):
        shape = grad_probs.shape
        # probabilities saved during the forward pass
        probs, = ctx.saved_tensors

        # flatten the incoming gradient to 2d
        grad_probs = grad_probs.view(-1, grad_probs.shape[-1])
        n_rows, n_cols = grad_probs.shape

        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        dx = torch.empty_like(probs)

        # launch one program per row
        softmax_kernel_backward[(n_rows,)](
            dx,
            probs,
            grad_probs,
            grad_probs.stride(0),
            probs.stride(0),
            dx.stride(0),
            n_cols,
            num_warps = num_warps,
            BLOCK_SIZE = BLOCK_SIZE
        )

        return dx.view(*shape), None

# the fused causal softmax entrypoint
causal_softmax = _softmax.apply
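
A hypothetical check of the fused kernel against a plain masked softmax (assumes a CUDA device, and that the last two dimensions of the input are square, which the kernel's causal masking requires):

# hypothetical comparison against a reference masked softmax
import torch

sim = torch.randn(2, 8, 64, 64, device='cuda')

causal_mask = torch.ones(64, 64, device='cuda', dtype=torch.bool).triu(1)
ref = sim.masked_fill(causal_mask, float('-inf')).softmax(dim=-1)

out = causal_softmax(sim)
assert torch.allclose(out, ref, atol=1e-5)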

.\lucidrains\PaLM-pytorch\palm_pytorch\triton\__init__.py

# expose PaLM from the triton subpackage
from palm_pytorch.triton.palm import PaLM

.\lucidrains\PaLM-pytorch\palm_pytorch\__init__.py

# expose PaLM at the package root
from palm_pytorch.palm_pytorch import PaLM

PaLM - Pytorch

Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways, in less than 200 lines of code.

This model is pretty much SOTA on everything language. See Yannic Kilcher's explanation.

It obviously will not scale, but it is just for educational purposes, to show the public how simple it all really is.

Jax version

Install

$ pip install PaLM-pytorch

Usage

import torch
from palm_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    heads = 8,
    dim_head = 64,
)

tokens = torch.randint(0, 20000, (1, 2048))
logits = palm(tokens) # (1, 2048, 20000)

The PaLM 540B in the paper would be

palm = PaLM(
    num_tokens = 256000,
    dim = 18432,
    depth = 118,
    heads = 48,
    dim_head = 256
)

Test on Enwik8

$ python train.py

Citations

@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. T. Kung and David D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}

.\lucidrains\PaLM-pytorch\setup.py

# imports
from setuptools import find_packages, setup

# package metadata and dependencies
setup(
    name="PaLM-pytorch",
    packages=find_packages(exclude=[]),
    version="0.2.2",
    license="MIT",
    description="PaLM: Scaling Language Modeling with Pathways - Pytorch",
    author="Phil Wang",
    author_email="lucidrains@gmail.com",
    long_description_content_type = 'text/markdown',
    url="https://github.com/lucidrains/PaLM-pytorch",
    keywords=[
        "artificial general intelligence",
        "deep learning",
        "transformers",
        "attention mechanism",
    ],
    install_requires=[
        "einops>=0.4",
        "torch>=1.6",
        "triton>=2.0dev"
    ],
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
)

.\lucidrains\PaLM-pytorch\train.py

# imports
import gzip
import random

import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# project imports
from palm_pytorch.triton import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024

# helpers

# cycle a dataloader forever
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token to a printable character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens to a string
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# instantiate the GPT-like decoder model
model = PaLM(num_tokens=256, dim=512, depth=8)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
model.cuda()

# prepare enwik8 data (np.fromstring is deprecated; frombuffer + copy is the modern equivalent)
with gzip.open("./examples/enwik8_deepspeed/data/enwik8.gz") as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset that samples random crops of seq_len + 1 tokens
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# training and validation dataloaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\attention.py

# imports
import torch
from torch import nn, einsum
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version

from einops import rearrange

# config for torch's scaled_dot_product_attention backends
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper to check whether a value exists
def exists(val):
    return val is not None

# decorator that makes a function run only once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print a message only once
print_once = once(print)

# main class

class Attention(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        use_flash_attn = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        # cached causal mask, not persisted in the state dict
        self.register_buffer("mask", None, persistent=False)

        self.use_flash_attn = use_flash_attn
        # flash attention requires pytorch 2.0+
        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        # nothing more to do if cuda is unavailable or flash attention is disabled
        if not torch.cuda.is_available() or not use_flash_attn:
            return

        # query the cuda device properties
        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        # A100 (sm80) can use the flash kernel; other gpus fall back to math / mem-efficient attention
        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    # lazily build and cache the causal mask
    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    # flash attention path
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # expand the shared single key / value head across all query heads (multi-query attention)
        k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
        v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)

        # expand the padding mask, if given, to a broadcastable shape
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # pick the backend config depending on the device
        config = self.cuda_config if is_cuda else self.cpu_config

        # run scaled dot product attention under the chosen kernel config
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0., 
                is_causal = self.causal
            )

        return out
    # regular attention path; dispatches to flash attention when enabled
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device = q.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        # dispatch to flash attention if enabled
        if self.use_flash_attn:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum("b h i d, b j d -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        return out
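
A minimal usage sketch (shapes follow the multi-query convention above: queries carry a head dimension, while keys and values share a single head):

# hypothetical usage of the Attention module
import torch

attn = Attention(causal = True, dropout = 0.1)

q = torch.randn(2, 8, 1024, 64)   # (batch, heads, seq, dim_head)
k = torch.randn(2, 1024, 64)      # (batch, seq, dim_head)
v = torch.randn(2, 1024, 64)

out = attn(q, k, v)               # (2, 8, 1024, 64)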

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\lora.py

# imports
import torch
from torch import nn

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# LoRA - https://arxiv.org/abs/2106.09685

class LoRA(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        r = 8,
        alpha = None
    ):
        super().__init__()
        # alpha defaults to the rank r
        alpha = default(alpha, r)
        self.scale = alpha / r

        # the low-rank factors; B starts at zero so the delta is initially zero
        self.A = nn.Parameter(torch.randn(dim, r))
        self.B = nn.Parameter(torch.zeros(r, dim_out))

    # the effective weight delta: (A @ B) scaled by alpha / r
    @property
    def weight(self):
        return (self.A @ self.B) * self.scale

    # apply the low-rank delta to the input
    def forward(self, x):
        return x @ self.weight
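
A hypothetical sketch of how such a delta folds into a frozen linear layer (note nn.Linear stores its weight as (out, in), while LoRA.weight above is (in, out)):

# hypothetical sketch of merging a LoRA delta into a base linear
import torch
from torch import nn

base = nn.Linear(512, 512, bias = False)
lora = LoRA(512, 512, r = 8)
nn.init.normal_(lora.B)                  # B is zero at init, randomize so the delta is nonzero

x = torch.randn(1, 512)
adapted = base(x) + lora(x)              # adapter-style forward pass

base.weight.data.add_(lora.weight.t())   # fold the delta into the base weight
merged = base(x)

assert torch.allclose(adapted, merged, atol = 1e-5)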

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\optimizer.py

# imports (the Lion import line was missing here, though the comment for it survived)
from torch.optim import AdamW, Adam
from lion_pytorch import Lion

# separate parameters into weight-decayable and non-decayable groups
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # parameters with fewer than 2 dims (biases, norm weights) get no weight decay
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

# build the optimizer
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    use_lion = True,
    **kwargs
):
    # optionally keep only the parameters that require grad
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # group parameters so weight decay is only applied where appropriate
    if group_wd_params and wd > 0:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    # Lion handles weight decay itself
    if use_lion:
        return Lion(params, lr = lr, betas = betas, weight_decay = wd)

    # plain Adam when no weight decay is wanted
    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    # otherwise AdamW
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
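
A hypothetical usage example (use_lion = True additionally requires the lion-pytorch package to be installed):

# hypothetical usage of get_optimizer
import torch
from torch import nn

model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))
optim = get_optimizer(model.parameters(), lr = 1e-4, wd = 1e-2, use_lion = False)  # AdamW with grouped weight decay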

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\palm.py

# imports
import math
import copy
from pathlib import Path
from collections import namedtuple
from functools import wraps
from itertools import zip_longest

from tqdm import tqdm
from beartype import beartype
from beartype.typing import Tuple, Optional

import torch
from torch import einsum, nn
import torch.nn.functional as F

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

# project imports (the LoRA import line was missing here, though the comment for it survived)
from palm_rlhf_pytorch.attention import Attention
from palm_rlhf_pytorch.utils import top_p, top_k, masked_mean, gumbel_sample, eval_decorator
from palm_rlhf_pytorch.lora import LoRA

# functions and decorators

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    return val if exists(val) else d

# identity function
def identity(t, *args, **kwargs):
    return t

# l2 normalize along the last dimension
def l2norm(t):
    return F.normalize(t, dim=-1)

# normalization
# they use layernorm without bias, something that pytorch does not offer
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# residual connection

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        y = self.fn(x, **kwargs)

        # if neither tensor requires grad (inference), add in place to save memory
        if not any([t.requires_grad for t in (x, y)]):
            return x.add_(y)

        return y + x

# rotary positional embedding with xpos for length extrapolation
# https://arxiv.org/abs/2104.09864
# https://arxiv.org/abs/2212.10554v1

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, scale_base=512, use_xpos=True):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self.use_xpos = use_xpos
        self.scale_base = scale_base
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale)

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim=-1)

        if not self.use_xpos:
            return freqs, torch.ones(1, device=device)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim=-1)

        return freqs, scale

# rotate half the feature dimensions
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# apply rotary positional embedding, with an optional xpos scale
def apply_rotary_pos_emb(pos, t, scale=1.):
    return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)

# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202

class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

class ParallelTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        causal=True,
        heads=8,
        qk_rmsnorm=False,
        qk_scale=8,
        ff_mult=4,
        attn_dropout=0.,
        ff_dropout=0.,
        use_xpos=True,
        xpos_scale_base=512,
        flash_attn=False,
    ):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # fused projection dims: multi-head queries, single-head keys / values, gated feedforward
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.qk_rmsnorm = qk_rmsnorm

        if qk_rmsnorm:
            self.q_scale = nn.Parameter(torch.ones(dim_head))
            self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.attend = Attention(
            causal = causal,
            dropout = attn_dropout,
            use_flash_attn = flash_attn
        )

        self.heads = heads
        self.scale = (dim_head ** -0.5) if not qk_rmsnorm else qk_scale
        self.causal = causal

        self.rotary_emb = RotaryEmbedding(dim_head, scale_base = xpos_scale_base, use_xpos = use_xpos and causal)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)

        self.flash_attn = flash_attn
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.flash_attn_dropout = attn_dropout

        # parallel feedforward tail

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Dropout(ff_dropout),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching the causal mask and rotary embeddings

        self.register_buffer("pos_emb", None, persistent=False)
        self.register_buffer("pos_emb_scale", None, persistent=False)

    def get_rotary_embedding(self, n, device):
        if exists(self.pos_emb) and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n], self.pos_emb_scale[:n]

        pos_emb, scale = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        self.register_buffer("pos_emb_scale", scale, persistent=False)
        return pos_emb, scale

    def forward(
        self,
        x,
        mask = None,
        finetune_modules = None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm

        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner projection

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # adjust with finetuning loras

        lora_q = lora_k = lora_v = lora_o = None

        if exists(finetune_modules):
            lora_q, lora_k, lora_v, lora_o = finetune_modules
            q = q + lora_q(x)
            k = k + lora_k(x)
            v = v + lora_v(x)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and it is more efficient to decode
        # https://arxiv.org/abs/1911.02150

        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # qk rmsnorm

        if self.qk_rmsnorm:
            q, k = map(l2norm, (q, k))
            q = q * self.q_scale
            k = k * self.k_scale

        # rotary embeddings with xpos decay for better length extrapolation

        positions, scale = self.get_rotary_embedding(n, device)

        q = apply_rotary_pos_emb(positions, q, scale)
        k = apply_rotary_pos_emb(positions, k, scale ** -1)

        # attention function, either regular or flash

        out = self.attend(q, k, v, mask = mask)

        # merge heads

        out = rearrange(out, "b h n d -> b n (h d)")

        attn_out = self.attn_out(out)

        ff_out = self.ff_out(ff)

        if exists(lora_o):
            attn_out = attn_out + lora_o(out)

        return attn_out + ff_out

# the PaLM transformer

@beartype
class PaLM(nn.Module):
    # initialize with the model hyperparameters
    def __init__(
        self,
        *,
        dim,                              # model dimension
        num_tokens,                       # vocabulary size
        depth,                            # number of transformer blocks
        causal = True,                    # whether to use causal attention
        dim_head = 64,                    # dimension per attention head
        heads = 8,                        # number of attention heads
        ff_mult = 4,                      # feedforward expansion factor
        attn_dropout = 0.,                # attention dropout probability
        ff_dropout = 0.,                  # feedforward dropout probability
        qk_rmsnorm = False,               # whether to rmsnorm the queries and keys
        lora_r = 8,                       # default LoRA rank r
        rotary_xpos_scale_base = 512,     # xpos scale base for the rotary embeddings
        flash_attn = False,               # whether to use flash attention
        finetune_scopes = tuple(),        # finetuning scopes to register at init
        cross_entropy_ignore_index = 0    # ignore index for the cross entropy loss
    ):
        super().__init__()
        self.dim = dim
        self.dim_head = dim_head
        self.heads = heads
        self.causal = causal
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.layers = nn.ModuleList([])

        # stack of parallel transformer blocks
        for _ in range(depth):
            block = Residual(ParallelTransformerBlock(
                dim = dim,
                causal = causal,
                dim_head = dim_head,
                heads = heads,
                qk_rmsnorm = qk_rmsnorm,
                ff_mult = ff_mult,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                xpos_scale_base = rotary_xpos_scale_base,
                flash_attn = flash_attn
            ))

            self.layers.append(block)

        self.norm = LayerNorm(dim)
        self.to_logits = nn.Linear(dim, num_tokens, bias=False)

        # tie the output projection weights to the token embedding
        self.to_logits.weight = self.token_emb.weight

        nn.init.normal_(self.token_emb.weight, std=0.02)

        # fine tuning related

        self.lora_r = lora_r
        self.finetune_modules = nn.ModuleDict({})

        for scope in finetune_scopes:
            self.add_finetune_params(scope)

        # loss related

        self.cross_entropy_ignore_index = cross_entropy_ignore_index

    # device of the model parameters
    @property
    def device(self):
        return next(self.parameters()).device

    # load weights from a checkpoint
    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    # set the probability of all dropout layers
    def set_dropout(self, dropout):
        for module in self.layers.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout
        return self

    # register LoRA finetuning parameters under a named scope
    def add_finetune_params(self, scope, lora_r = None):
        assert scope not in self.finetune_modules, f'finetune scope {scope} already found'
        dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, default(lora_r, self.lora_r), self.device

        q_inner_dim = heads * dim_head
        kv_inner_dim = dim_head

        lora_modules = nn.ModuleList([])

        for _ in range(len(self.layers)):
            lora_modules.append(nn.ModuleList([
                LoRA(dim, q_inner_dim, r = r),   # queries
                LoRA(dim, kv_inner_dim, r = r),  # keys
                LoRA(dim, kv_inner_dim, r = r),  # values
                LoRA(q_inner_dim, dim, r = r)    # wo
            ]))

        self.finetune_modules[scope] = lora_modules.to(device)

    # remove the finetuning parameters for a scope
    def remove_finetune_params(self, scope):
        assert scope in self.finetune_modules, f'finetune scope {scope} not found'
        return self.finetune_modules.pop(scope)

    # merge the finetuned actor LoRA parameters, for doing multiple rounds of finetuning off different reward models
    @torch.no_grad()
    def merge_finetune_params(self, scope):
        """ in the case one wants to merge the fine-tuned actor LORA parameters and do multiple rounds of fine tuning off different reward models """

        assert scope in self.finetune_modules, f'finetune scope {scope} not found'

        lora_modules = self.finetune_modules.pop(scope)

        for layer, (lora_q, lora_k, lora_v, lora_o) in zip(self.layers, lora_modules):
            block = layer.fn

            fused_attn_ff_weight = block.fused_attn_ff_proj.weight
            attn_out_weight = block.attn_out.weight

            fused_proj_out_dim = fused_attn_ff_weight.shape[0]

            # pack the q, k, v LoRA weight deltas and pad up to the fused projection width
            lora_qkv_weight, _ = pack([lora_q.weight, lora_k.weight, lora_v.weight], 'i *')
            lora_qkv_weight = F.pad(lora_qkv_weight, (0, fused_proj_out_dim - lora_qkv_weight.shape[1]))

            lora_qkv_weight = rearrange(lora_qkv_weight, 'i o -> o i')
            lora_o_weight = rearrange(lora_o.weight, 'i o -> o i')

            # add the deltas into the fused projection and attention output weights
            fused_attn_ff_weight.add_(lora_qkv_weight)
            attn_out_weight.add_(lora_o_weight)

    # researchers train the palm parameters first, then finetune

    # palm parameters, excluding any finetuning modules
    def palm_parameters(self):
        return set(self.parameters()) - set(self.finetune_modules.parameters())

    # finetuning parameters for a given scope
    def finetune_parameters(self, scope = 'default'):
        assert scope in self.finetune_modules, f'finetune parameters of scope {scope} not found'
        return self.finetune_modules[scope].parameters()

    # generation

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        seq_len,
        prompt = None,
        temperature = 1.,
        filter_logits_fn = top_k,
        filter_thres = 0.9,
        pad_value = 0.,
        eos_token = None,
        return_seq_without_prompt = True,
        use_tqdm = False,
        **kwargs
    ):
        # sample a random token as the prompt if none is given
        if not exists(prompt):
            prompt = torch.randint(0, self.num_tokens, (1, 1))
            prompt = prompt.to(self.device)
            return_seq_without_prompt = False

        prompt, leading_dims = pack([prompt], '* n')

        n, out = prompt.shape[-1], prompt.clone()

        wrapper_fn = identity if not use_tqdm else tqdm
        sample_num_times = max(1, seq_len - prompt.shape[-1])

        for _ in wrapper_fn(range(sample_num_times)):
            logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
            logits, embeds = logits[:, -1], embeds[:, -1]

            if exists(filter_logits_fn):
                logits = filter_logits_fn(logits, thres = filter_thres)

            sample = gumbel_sample(logits, temperature = temperature, dim = -1)
            out, _ = pack([out, sample], 'b *')

            if exists(eos_token):
                is_eos_tokens = (out == eos_token)

                if is_eos_tokens.any(dim = -1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, pad_value)
                    break

        out, = unpack(out, leading_dims, '* n')

        if not return_seq_without_prompt:
            return out

        return out[..., n:]

    # forward pass
    def forward(
        self,
        x,
        return_loss = False,
        disable_lora = False,
        finetune_scope = None,
        extra_embed = None,
        return_only_embedding = False,
        return_logits_with_embedding = False
        ):
        # when training, use the input shifted by one as the labels
        if return_loss:
            x, labels = x[:, :-1], x[:, 1:]

        # mask if encoder
        # treat any token ids below 0 as tokens to mask out - only needed for the non-autoregressive case
        if not self.causal:
            mask = x >= 0
            x = x.masked_fill(~mask, 0)
        else:
            mask = None

        # token embeddings
        x = self.token_emb(x)

        # add any extra embedding that was passed in
        if exists(extra_embed):
            x = x + extra_embed

        # gather the finetuning (LoRA) modules for the requested scope
        finetune_modules = tuple()
        if exists(finetune_scope) and not disable_lora:
            assert finetune_scope in self.finetune_modules
            finetune_modules = self.finetune_modules[finetune_scope]

        # parallel attention / feedforward blocks, passing in the finetuning loras
        for layer, finetune_modules in zip_longest(self.layers, finetune_modules):
            x = layer(x, mask = mask, finetune_modules = finetune_modules)

        # final norm
        embeds = self.norm(x)

        if return_only_embedding:
            return embeds

        # project to logits
        logits = self.to_logits(embeds)

        ret = (logits, embeds) if return_logits_with_embedding else logits

        if not return_loss:
            return ret

        # rearrange for cross entropy, which expects (batch, classes, seq)
        logits = rearrange(logits, 'b n c -> b c n')
        return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index)
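
A hypothetical end-to-end sketch of training and sampling with this class (runs on CPU, assuming the palm_rlhf_pytorch utils above are importable):

# hypothetical training step and sampling
import torch

palm = PaLM(dim = 256, num_tokens = 1000, depth = 2)

seq = torch.randint(0, 1000, (2, 129))
loss = palm(seq, return_loss = True)   # cross entropy on the shifted sequence
loss.backward()

prompt = torch.randint(0, 1000, (1, 16))
sampled = palm.generate(64, prompt = prompt)  # (1, 48) - the continuation without the prompt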

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\ppo.py

import math
from pathlib import Path
import copy
from tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange

from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque

import torch
from torch import nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from palm_rlhf_pytorch.palm import PaLM
from palm_rlhf_pytorch.reward import RewardModel
from palm_rlhf_pytorch.optimizer import get_optimizer
from palm_rlhf_pytorch.utils import masked_mean, eval_decorator

from accelerate import Accelerator

# actor critic - PaLM with lora

PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
    'actions',
    'sequence',
    'mask',
    'prompt_mask',
    'action_logits',
    'values'
])

@beartype
class ActorCritic(nn.Module):
    def __init__(
        self,
        palm: PaLM,
        critic_palm: Optional[PaLM] = None,
        pooled_values = False,
        actor_lora = True,
        critic_lora = True,
        actor_lora_r = 8,
        critic_lora_r = 8,
        actor_lora_scope = 'actor',
        critic_lora_scope = 'critic',
        actor_dropout = 0.,
        critic_dropout = 0.
    ):
        super().__init__()
        self.actor_palm = palm

        self.critic_palm = critic_palm

        if not exists(self.critic_palm):
            self.critic_palm = copy.deepcopy(palm)

        self.actor_palm.set_dropout(actor_dropout)
        self.critic_palm.set_dropout(critic_dropout)

        self.actor_lora = actor_lora
        self.critic_lora = critic_lora

        self.actor_lora_scope = actor_lora_scope if actor_lora else None
        self.critic_lora_scope = critic_lora_scope if critic_lora else None

        if self.actor_lora:
            self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)

        if self.critic_lora:
            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)

        self.pooled_values = pooled_values
        self.value_head = nn.Sequential(
            nn.Linear(palm.dim, 1),
            Rearrange('... 1 -> ...')
        )

        nn.init.zeros_(self.value_head[0].bias)
        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))

    def actor_parameters(self):
        # without lora, all actor palm parameters are trained
        if not self.actor_lora:
            return self.actor_palm.parameters()

        return [
            *self.actor_palm.finetune_parameters(self.actor_lora_scope)
        ]

    def critic_parameters(self):
        # without lora, the critic palm and value head parameters are trained
        # (the original checked self.actor_lora here, which looks like a copy-paste bug)
        if not self.critic_lora:
            return [*self.critic_palm.parameters(), *self.value_head.parameters()]

        return [
            *self.critic_palm.finetune_parameters(self.critic_lora_scope),
            *self.value_head.parameters()
        ]

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        state,
        max_seq_len,
        eos_token = None,
        return_values = False,
        **kwargs
    ):
        # generate an action sequence from the current state, up to max_seq_len
        actions = self.actor_palm.generate(
            max_seq_len,
            prompt = state,
            eos_token = eos_token,
            finetune_scope = self.actor_lora_scope,
            use_tqdm = True,
            **kwargs
        )

        # concatenate the state (prompt) and the generated actions
        sequence = torch.cat((state, actions), dim = -1)
        action_len = actions.shape[-1]
        state_len = state.shape[-1]

        # mask marking which positions belong to the prompt
        prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
        prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])

        # mask marking the generated actions
        action_mask = ~prompt_mask

        mask = None
        # if an eos token is given, mask out everything after it
        if exists(eos_token):
            mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
            mask = F.pad(mask, (1, -1), value = True) # include eos token
            action_mask &= mask

        # get the action logits and values
        action_logits, value = self.forward(
            sequence,
            mask = action_mask,
            return_values = return_values
        )

        return PPOActionCriticReturn(
            actions,
            sequence,
            mask,
            prompt_mask,
            action_logits,
            value
        )

    def forward(
        self,
        x,
        mask = None,
        return_values = True
    ):
        # action logits from the actor
        action_logits = self.actor_palm(
            x,
            finetune_scope = self.actor_lora_scope
        )

        if not return_values:
            return action_logits, None

        # critic embeddings
        critic_embeds = self.critic_palm(
            x,
            return_only_embedding = True,
            finetune_scope = self.critic_lora_scope
        )

        # optionally pool the critic embeddings over the sequence
        if self.pooled_values:
            critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
            critic_embeds = masked_mean(critic_embeds, mask, dim = 1)

        values = self.value_head(critic_embeds)

        return action_logits, values

# a Memory holds the sequence, prompt mask, mask, action probability, action log probability, reward and value for one experience step
Memory = namedtuple('Memory', [
    'sequence',
    'prompt_mask',
    'mask',
    'action_prob',
    'action_log_prob',
    'reward',
    'value'
])

# ExperienceDataset wraps the collected experience tensors as a torch Dataset
class ExperienceDataset(Dataset):
    def __init__(
        self,
        data: List[torch.Tensor],
        device = None
    ):
        super().__init__()
        self.data = data
        self.device = device

    def __len__(self):
        return self.data[0].shape[0]

    def __getitem__(self, ind):
        # index into every tensor and move the result to the target device
        return tuple(map(lambda t: t[ind].to(self.device), self.data))

# create a dataloader over the collected experience
def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
    ds = ExperienceDataset(data, device = device)  # 创建 ExperienceDataset 实例
    return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)  # 返回 DataLoader 实例

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# normalize a tensor, optionally under a mask
def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
    dim = default(dim, tuple(range(t.ndim)))
    kwargs = dict(dim = dim, keepdim = True)

    # masked mean and variance
    mean = masked_mean(t, mask = mask, **kwargs)
    mean_centered = t - mean
    var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)

    return mean_centered * var.clamp(min = eps).rsqrt()

# pad a batch of sequences, handling zero-dimensional elements
def pad_sequence_fixed(sequences, *args, **kwargs):
    first_el = sequences[0]
    has_no_dimension = first_el.ndim == 0

    # add a dimension if there is none
    if has_no_dimension:
        sequences = tuple(map(lambda t: t[None], sequences))

    out = pad_sequence(sequences, *args, **kwargs)

    if has_no_dimension:
        out = rearrange(out, '... 1 -> ...')

    return out

# numerically safe log
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# log probability of the chosen indices
def log_prob(prob, indices):
    assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
    return log(prob.gather(-1, indices[..., None])).squeeze(-1)

# shift a tensor along a dimension, padding with a value
def shift(t, value = 0, shift = 1, dim = -1):
    zeros = (0, 0) * (-dim - 1)
    return F.pad(t, (*zeros, shift, -shift), value = value)

# entropy under a mask
def masked_entropy(prob, dim = -1, mask = None):
    entropies = (prob * log(prob)).sum(dim = -1)
    return masked_mean(entropies, mask = mask).mean()

# KL divergence under a mask
def masked_kl_div(prob1, prob2, mask = None, reduce_batch = False):
    """
    need to account for variable sequence lengths, therefore not using the built-in functional version
    """
    kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1)
    loss = masked_mean(kl_divs, mask)

    if reduce_batch:
        return loss.mean()

    return loss

# PPO clipped value loss
def clipped_value_loss(values, rewards, old_values, clip):
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped.flatten() - rewards) ** 2
    value_loss_2 = (values.flatten() - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))
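
As an aside, this is the PPO-style pessimistic value update: the critic pays the worse of the clipped and unclipped squared errors. A hypothetical numeric check:

# hypothetical check of clipped_value_loss
import torch

values = torch.tensor([1.5, 0.2])
old_values = torch.tensor([1.0, 0.0])
rewards = torch.tensor([1.2, 0.1])

# clipped values are old_values + clamp(delta, -0.2, 0.2) = [1.2, 0.2]
# per-element max of squared errors: [max(0.0, 0.09), max(0.01, 0.01)] -> mean 0.05
loss = clipped_value_loss(values, rewards, old_values, clip = 0.2)
assert abs(loss.item() - 0.05) < 1e-6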

# the RLHF trainer
class RLHFTrainer(nn.Module):
    # set up the models, optimizers and hyperparameters
    def __init__(
        self,
        *,
        prompts: Optional[List[str]] = None,              # raw text prompts
        prompts_path: Optional[str] = None,               # path to a newline-separated prompt file
        prompt_token_ids: Optional[torch.Tensor] = None,  # pre-tokenized prompts
        tokenizer: Callable = None,                       # tokenizer, required for raw text prompts
        palm: PaLM,                                       # the language model
        reward_model: RewardModel,                        # the trained reward model
        critic_palm: Optional[PaLM] = None,               # optional separate critic palm
        actor_critic: Optional[ActorCritic] = None,       # optional pre-built actor critic
        actor_lr = 1e-4,                                  # actor learning rate
        critic_lr = 1e-4,                                 # critic learning rate
        actor_wd = 0.,                                    # actor weight decay
        critic_wd = 0.,                                   # critic weight decay
        actor_adam_eps = 1e-7,                            # actor adam epsilon
        critic_adam_eps = 1e-7,                           # critic adam epsilon
        actor_lora = True,                                # whether the actor uses LoRA
        critic_lora = True,                               # whether the critic uses LoRA
        actor_lora_r = 8,                                 # actor LoRA rank
        critic_lora_r = 8,                                # critic LoRA rank
        critic_pooled_values = True,                      # whether the critic pools values over the sequence
        actor_dropout = 0.,                               # actor dropout
        critic_dropout = 0.,                              # critic dropout
        betas = (0.9, 0.999),                             # adam betas
        max_norm = None,                                  # gradient clipping max norm
        eps_clip = 0.2,                                   # PPO surrogate clipping epsilon
        value_clip = 0.4,                                 # value function clipping
        beta_s = .01,                                     # beta_s parameter
        pad_value = 0.,                                   # token padding value
        minibatch_size = 16,                              # minibatch size
        epochs = 1,                                       # epochs per learn step
        kl_div_loss_weight = 0.1,                         # KL divergence loss weight
        accelerate_kwargs: dict = {},                     # kwargs forwarded to Accelerator
        use_lion = False                                  # whether to use the Lion optimizer
    ):
        super().__init__()

        self.accelerate = Accelerator(**accelerate_kwargs)

        # take care of prompts -> token ids
        assert (exists(prompts) + exists(prompts_path) + exists(prompt_token_ids)) == 1

        if exists(prompts_path):
            path = Path(prompts_path)
            prompts = path.read_text().split('\n')

        if exists(prompts):
            assert len(prompts) > 0, 'no prompts'
            assert exists(tokenizer), 'tokenizer must be passed in if raw text prompts are given'
            prompt_token_ids = tokenizer(prompts)

        self.pad_value = pad_value
        self.num_prompts = prompt_token_ids.shape[0]
        self.register_buffer('prompt_token_ids', prompt_token_ids)

        # 初始化模型
        self.palm = palm

        if not exists(actor_critic):
            actor_critic = ActorCritic(
                palm = palm,
                critic_palm = critic_palm,
                actor_lora = actor_lora,
                critic_lora = critic_lora,
                actor_lora_r = actor_lora_r,
                critic_lora_r = critic_lora_r,
                pooled_values = critic_pooled_values,
                actor_dropout = actor_dropout,
                critic_dropout = critic_dropout
            ).to(palm.device)

        self.actor_critic = actor_critic  # actor-critic model

        self.reward_model = reward_model.eval()  # reward model, kept in eval mode

        # training hyperparameters
        self.epochs = epochs
        self.minibatch_size = minibatch_size
        self.max_norm = max_norm
        self.kl_div_loss_weight = kl_div_loss_weight

        # optimizers
        self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = actor_lr, wd = actor_wd, betas = betas, eps = actor_adam_eps, use_lion = use_lion)
        self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = critic_lr, wd = critic_wd, betas = betas, eps = critic_adam_eps, use_lion = use_lion)

        # PPO hyperparameters
        self.eps_clip = eps_clip
        self.value_clip = value_clip
        self.beta_s = beta_s

        # prepare with accelerator
        (
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        ) = self.accelerate.prepare(
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        )

    # print helper
    def print(self, msg):
        return self.accelerate.print(msg)

    # save the actor-critic parameters
    def save(self, filepath = './checkpoint.pt'):
        torch.save(self.actor_critic.state_dict(), filepath)

    # load the actor-critic parameters
    def load(self, filepath = './checkpoint.pt'):
        state_dict = torch.load(filepath)
        self.actor_critic.load_state_dict(state_dict)

    # device property
    @property
    def device(self):
        return self.accelerate.device

    # generate text sequences (gradients disabled)
    @torch.no_grad()
    def generate(
        self,
        max_seq_len,
        *args,
        prompt,
        num_samples = 4,  # generate 4 samples per prompt, selecting the one with the highest reward
        **kwargs
    ):
        # only one prompt is allowed at a time for now
        assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
        # repeat the prompt to match the number of samples
        prompt = repeat(prompt, 'n -> b n', b = num_samples)

        # unwrap the actor-critic from the accelerator
        actor_critic = self.accelerate.unwrap_model(self.actor_critic)
        # unwrap the reward model from the accelerator
        reward_model = self.accelerate.unwrap_model(self.reward_model)

        # set the actor-critic to eval mode
        actor_critic.eval()

        # generate actions, sequences, masks, prompt mask and action logits
        (
            actions,
            sequences,
            mask,
            prompt_mask,
            action_logits,
            _
        ) = actor_critic.generate(
            prompt,
            *args,
            max_seq_len = max_seq_len,
            return_values = False,
            **kwargs
        )

        # score the generated sequences with the reward model
        rewards = reward_model(
            sequences,
            prompt_mask = prompt_mask,
            mask = mask,
            sample = True
        )

        # index of the sequence with the highest reward
        best_sequence_index = rewards.topk(1, dim = -1).indices

        # pick out the best sequence
        best_sequence = sequences[best_sequence_index]
        # collapse the leading singleton dimension
        best_sequence = rearrange(best_sequence, '1 ... -> ...')

        # return the best sequence
        return best_sequence

    # learning function, updating actor and critic on a deque of stored memories
    def learn(
        self,
        memories: Deque[Memory]
    ):
        ...  # the PPO update body is elided in this excerpt

    # training loop
    def train(
        self,
        num_episodes = 50000,
        max_timesteps = 500,
        update_timesteps = 5000,
        max_batch_size = 16,
        max_seq_len = 2048,
        eos_token = None,
        temperature = 1.
    ):
        # current device
        device = self.device

        # initialize the timestep counter and memory deque
        time = 0
        memories = deque([])

        # loop over episodes
        for eps in tqdm(range(num_episodes), desc='episodes'):
            # run a number of timesteps within each episode
            for timestep in range(max_timesteps):
                time += 1

                # select a random state (prompt) and get the actions
                # (sequences sampled from the palm, along with the action probabilities),
                # then score with the reward model and store

                # randomly pick a prompt index
                rand_prompt_index = randrange(0, self.num_prompts)

                # get the state (prompt) token ids
                state = self.prompt_token_ids[rand_prompt_index]

                # strip padding from the state
                state_mask = state != self.pad_value
                state = state[state_mask]

                # generate the predicted sequence
                (
                    actions,
                    sequence,
                    mask,
                    prompt_mask,
                    action_logits,
                    value
                ) = self.actor_critic.generate(
                    rearrange(state, 'n -> 1 n'),
                    max_seq_len=max_seq_len,
                    eos_token=eos_token,
                    temperature=temperature,
                    return_values=True
                )
                action_logits = shift(action_logits, shift=1, dim=-2)  # need to shift along the sequence dimension by 1, since the actions begin on the very last prompt (state) token

                action_prob = action_logits.softmax(dim=-1)

                action_len = actions.shape[-1]
                action_log_prob = log_prob(action_prob, sequence)
                action_log_prob = action_log_prob[:, -action_len:]

                actions = rearrange(actions, '1 ... -> ...')

                # get the reward from the supervised-trained reward model
                sequence = torch.cat((state, actions), dim=0)

                prompt_length = len(state)
                prompt_mask = torch.arange(sequence.shape[-1], device=device) < prompt_length

                sequence = rearrange(sequence, 'n -> 1 n')
                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
                mask = default(mask, lambda: torch.ones(sequence.shape, dtype=torch.bool, device=device))

                reward = self.reward_model(
                    sequence,
                    prompt_mask=prompt_mask,
                    mask=mask,
                    sample=True
                )

                detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')

                # store the memory for learning
                memories.append(Memory(*map(detach_to_cpu_, (
                    sequence,
                    prompt_mask,
                    mask,
                    action_prob,
                    action_log_prob,
                    reward,
                    value
                ))))

                # learn from the stored memories
                if time % update_timesteps == 0:
                    self.learn(memories)
                    memories.clear()

        print('rlhf training complete')

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\reward.py

# import necessary libraries
import copy
from pathlib import Path

from tqdm import tqdm
from beartype import beartype
from beartype.typing import Tuple, Optional

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from palm_rlhf_pytorch.utils import masked_mean, gumbel_sample
from palm_rlhf_pytorch.palm import PaLM

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# Reward Model - PaLM with a scalar head

@beartype
class RewardModel(nn.Module):
    def __init__(
        self,
        palm: PaLM,
        dropout = 0.1,
        num_binned_output = 0.,
        use_lora = True,
        lora_r = 8,
        reward_lora_scope = 'reward',
    ):
        super().__init__()

        # deep copy the PaLM model passed in
        self.palm = copy.deepcopy(palm)
        self.palm.set_dropout(dropout)

        # decide whether to use LoRA based on use_lora
        self.reward_lora_scope = reward_lora_scope if use_lora else None

        # if LoRA is enabled, add finetune parameters for the reward model
        if exists(self.reward_lora_scope):
            self.palm.add_finetune_params(reward_lora_scope, lora_r = lora_r)

        dim = palm.dim

        # whether the output is binned
        self.binned_output = num_binned_output > 1

        # learned prompt and response embeddings
        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))

        # choose the output head depending on whether binned output is needed
        if self.binned_output:
            self.to_pred = nn.Linear(dim, num_binned_output)
        else:
            self.to_pred = nn.Sequential(
                nn.Linear(dim, 1, bias = False),
                Rearrange('... 1 -> ...')
            )

    # load model parameters
    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    # parameters to finetune
    def finetune_parameters(self):
        return [
            *self.to_pred.parameters(),
            *(self.palm.finetune_parameters(self.reward_lora_scope) if exists(self.reward_lora_scope) else self.palm.parameters())
        ]

    # forward pass
    def forward(
        self,
        x,
        mask = None,
        prompt_mask = None,
        prompt_lengths = None,
        labels = None,
        sample = False,
        sample_temperature = 1.,
        disable_lora = False
    ):

        assert not (exists(prompt_mask) and exists(prompt_lengths))

        # derive the prompt mask from the prompt lengths
        if exists(prompt_lengths):
            batch, seq_len = x.shape
            arange = torch.arange(seq_len, device = x.device)
            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

        # the reward model should know which part is prompt and which is response

        extra_embed = None

        if exists(prompt_mask):
            extra_embed = torch.where(
                rearrange(prompt_mask, 'b n -> b n 1'),
                self.prompt_embed,
                self.response_embed
            )

        # get the embeddings from PaLM
        embeds = self.palm(
            x,
            extra_embed = extra_embed,
            return_only_embedding = True,
            disable_lora = disable_lora,
            finetune_scope = self.reward_lora_scope
        )

        # mean pool the embeddings
        pooled = masked_mean(embeds, mask, dim = 1)
        pred = self.to_pred(pooled)

        # if sampling and the output is binned, gumbel sample the output
        if sample and self.binned_output:
            assert not exists(labels)
            pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)

        # if no labels are given, return the prediction
        if not exists(labels):
            return pred

        # if the output is not binned, use mean squared error loss
        if not self.binned_output:
            return F.mse_loss(pred, labels)

        # otherwise, with binned output, use cross entropy loss
        return F.cross_entropy(pred, labels)

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\utils.py

# import math and torch, plus einsum, nn and nn.functional from torch
import math
import torch
from torch import einsum, nn
import torch.nn.functional as F

# import rearrange from einops
from einops import rearrange

# check whether a value exists
def exists(val):
    return val is not None

# decorators

# eval decorator, putting the model in eval mode for the duration of the call
def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

# tensor helpers

# log of a tensor, clamped to avoid negative infinity
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# masked mean; falls back to a plain mean when no mask is given
def masked_mean(seq, mask = None, dim = 1, keepdim = False):
    if not exists(mask):
        return seq.mean(dim = dim)

    if seq.ndim == 3:
        mask = rearrange(mask, 'b n -> b n 1')

    masked_seq = seq.masked_fill(~mask, 0.)
    numer = masked_seq.sum(dim = dim, keepdim = keepdim)
    denom = mask.sum(dim = dim, keepdim = keepdim)

    masked_mean = numer / denom.clamp(min = 1e-3)
    masked_mean = masked_mean.masked_fill(denom == 0, 0.)
    return masked_mean
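
A quick sketch (not repository code) of what masked_mean computes for a 2d input:

import torch

seq  = torch.tensor([[1., 2., 3., 4.]])
mask = torch.tensor([[True, True, False, False]])

# only the unmasked positions contribute to numerator and denominator
numer = seq.masked_fill(~mask, 0.).sum(dim = 1)
denom = mask.sum(dim = 1)
print(numer / denom)  # tensor([1.5000])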

# sampling helpers

# generate gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample from a tensor of logits using gumbel noise
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
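
The gumbel-max trick used here is equivalent to sampling from the categorical distribution softmax(t / temperature). A quick empirical check (illustrative, not repository code):

import torch

logits = torch.tensor([1., 2., 3.])
uniform = torch.rand(10000, 3).clamp(min = 1e-20)
gumbel = -torch.log(-torch.log(uniform))
samples = (logits + gumbel).argmax(dim = -1)

print(torch.bincount(samples, minlength = 3) / 10000.)  # empirical frequencies
print(logits.softmax(dim = -1))                         # should roughly match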

# top-p (nucleus) sampling
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# top-k sampling
def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
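
Note that thres works opposite to a raw k here: at thres = 0.9, only the top 10% of logits survive. A small sketch (illustrative only):

import math
import torch

logits = torch.tensor([[0.1, 2.0, 1.0, 3.0]])
k = math.ceil((1 - 0.9) * logits.shape[-1])  # k = 1 for 4 logits

val, ind = torch.topk(logits, k)
filtered = torch.full_like(logits, float('-inf')).scatter_(1, ind, val)
print(filtered)  # tensor([[-inf, -inf, -inf, 3.]])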

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\__init__.py

# import the PaLM class from palm_rlhf_pytorch.palm
from palm_rlhf_pytorch.palm import PaLM
# import the RewardModel class from palm_rlhf_pytorch.reward
from palm_rlhf_pytorch.reward import RewardModel
# import the RLHFTrainer and ActorCritic classes from palm_rlhf_pytorch.ppo
from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic


PaLM + RLHF - Pytorch (wip)

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Maybe I'll add retrieval functionality too, à la RETRO

If you are interested in replicating something like ChatGPT out in the open, please consider joining LAION (join us on Discord)

Potential successor: Direct Preference Optimization - all the code in this repo becomes ~ binary cross entropy loss, < 5 loc. So much for Reward models and PPO

FAQ

  • Does this contain a model for inference?

There is no trained model. This is just the ship and overall map. We still need millions of dollars of compute + data to sail to the correct point in high dimensional parameter space. Even then, you need professional sailors (like Robin Rombach of Stable Diffusion fame) to actually guide the ship through turbulent times to that point.

Community

CarperAI had been working on an RLHF framework for large language models for many months prior to the release of ChatGPT.

Yannic Kilcher is also working on an open sourced implementation

AI Coffeebreak w/ Letitia | Code Emporium | Code Emporium Part 2

Appreciation

Install

$ pip install palm-rlhf-pytorch

Usage

First train PaLM, like any other autoregressive transformer

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = True # https://arxiv.org/abs/2205.14135
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()

loss = palm(seq, return_loss = True)
loss.backward()

# after much training, you can now generate sequences

generated = palm.generate(2048) # (1, 2048)

Then train your reward model, with the curated human feedback. In the original paper, they could not get the reward model to be finetuned from a pretrained transformer without overfitting, but I gave the option to finetune with LoRA anyway, since it is still open research.

import torch
from palm_rlhf_pytorch import PaLM, RewardModel

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5 # say rating from 1 to 5
).cuda()

# mock data

seq = torch.randint(0, 20000, (1, 1024)).cuda()
prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).cuda()

# train

loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()

# after much training

reward = reward_model(seq, prompt_mask = prompt_mask)

Then you will pass your transformer and the reward model to the RLHFTrainer

import torch
from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer

# load your pretrained palm

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12
).cuda()

palm.load('./path/to/pretrained/palm.pt')

# load your pretrained reward model

reward_model = RewardModel(
    palm,
    num_binned_output = 5
).cuda()

reward_model.load('./path/to/pretrained/reward_model.pt')

# ready your list of prompts for reinforcement learning

prompts = torch.randint(0, 256, (50000, 512)).cuda() # 50k prompts

# pass it all to the trainer and train

trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    prompt_token_ids = prompts
)

trainer.train(num_episodes = 50000)

# then, if it succeeded...
# generate say 10 samples and use the reward model to return the best one

answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)

Todo

Citations

@article{Stiennon2020LearningTS,
    title   = {Learning to summarize from human feedback},
    author  = {Nisan Stiennon and Long Ouyang and Jeff Wu and Daniel M. Ziegler and Ryan J. Lowe and Chelsea Voss and Alec Radford and Dario Amodei and Paul Christiano},
    journal = {ArXiv},
    year    = {2020},
    volume  = {abs/2009.01325}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
@article{Hu2021LoRALA,
    title   = {LoRA: Low-Rank Adaptation of Large Language Models},
    author  = {Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.09685}
}
@inproceedings{Sun2022ALT,
    title     = {A Length-Extrapolatable Transformer},
    author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
    year      = {2022}
}
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}

.\lucidrains\PaLM-rlhf-pytorch\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'PaLM-rlhf-pytorch',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.2.1',  # version
  license='MIT',  # license
  description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/PaLM-rlhf-pytorch',  # repository URL
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'reinforcement learning',
    'human feedback'
  ],
  install_requires=[  # dependencies
    'accelerate',
    'beartype',
    'einops>=0.6',
    'lion-pytorch',
    'torch>=1.6',
    'tqdm'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\PaLM-rlhf-pytorch\train.py

# import necessary libraries
import gzip
import random
import tqdm
import numpy as np

import torch
from lion_pytorch import Lion
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from palm_rlhf_pytorch import PaLM
from accelerate import Accelerator

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024

# helper functions

# cycle a dataloader endlessly (needed by the loaders below)
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a token into a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode tokens into a string
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))


# initialize the accelerator
accelerator = Accelerator()
device = accelerator.device

# instantiate the PaLM model
model = PaLM(
    num_tokens=256,
    dim=512,
    depth=8,
    flash_attn=True
).to(device)

# prepare enwik8 data
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create the train and validation datasets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Lion(model.palm_parameters(), lr = LEARNING_RATE)

# prepare the model, optimizer and dataloaders
model, optim, train_loader, val_loader = accelerator.prepare(
    model, optim, train_loader, val_loader
)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)

    accelerator.print(f"training loss: {loss.item()}")
    accelerator.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            accelerator.print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        accelerator.print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(GENERATE_LENGTH, inp[None, ...])
        output_str = decode_tokens(sample[0])
        accelerator.print(output_str, "\n")

.\lucidrains\panoptic-transformer\panoptic_transformer\data.py

# import required libraries
from pathlib import Path
from random import choice
from PIL import Image
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torchvision import transforms as T

# generator that cycles through a dataloader endlessly
def cycle(dl):
    while True:
        for el in dl:
            yield el

# PathfinderXDataset, subclassing Dataset
class PathfinderXDataset(Dataset):
    def __init__(
        self,
        folder,
        augment = False
    ):
        super().__init__()
        # gather all .npy files under the folder
        metadata_files = [*Path(folder).glob(f'**/*.npy')]
        # make sure at least one metadata file was found
        assert len(metadata_files) > 0, 'no metadata file found'

        # take the first metadata file
        metadata_file = metadata_files[0]
        # load the metadata
        metadata = np.load(str(metadata_file))
        # the dataset root is two levels above the metadata file
        root_path = metadata_file.parents[1]

        self.augment = augment
        # store (path, label) tuples for the dataset
        self.data = [(str(root_path / m[0] / m[1]), int(m[3])) for m in metadata]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, ind):
        # fetch the path and label at the given index
        path, label = self.data[ind]
        # open the image file
        img = Image.open(path)

        # transform the image, with optional random flips
        img = T.Compose([
            T.RandomHorizontalFlip() if self.augment else nn.Identity(),
            T.RandomVerticalFlip() if self.augment else nn.Identity(),
            T.PILToTensor()
        ])(img)

        # convert the label to a torch tensor
        label = torch.tensor(label, dtype = torch.float32)

        if self.augment:
            # randomly choose a rotation angle
            rand_rotate = [0, 90, 180, 270]
            img = T.functional.rotate(img, choice(rand_rotate))
            # randomly choose a padding (one-pixel shift)
            rand_padding = [(0, 0, 0, 0), (1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1)]
            img = F.pad(img, choice(rand_padding))

        return img.float(), label

# build the train and validation dataloaders
def get_dataloaders(
    data_path,
    *,
    augment = True,
    frac_valids = 0.05,
    batch_size
):
    # instantiate PathfinderXDataset
    ds = PathfinderXDataset(data_path, augment = augment)

    total_samples = len(ds)
    # number of validation samples
    num_valid = int(frac_valids * total_samples)
    # number of training samples
    num_train = total_samples - num_valid

    print(f'training with {num_train} samples and validating with {num_valid} samples')

    # randomly split into train and validation sets
    train_ds, valid_ds = random_split(ds, [num_train, num_valid])

    # build the train and validation dataloaders
    train_dl = DataLoader(train_ds, batch_size = batch_size, shuffle = True)
    valid_dl = DataLoader(valid_ds, batch_size = batch_size, shuffle = True)

    return cycle(train_dl), cycle(valid_dl)

.\lucidrains\panoptic-transformer\panoptic_transformer\panoptic_transformer.py

# import torch
import torch
# import nn and einsum from torch
from torch import nn, einsum
# import rearrange from einops
from einops import rearrange
# import torch.nn.functional as F
import torch.nn.functional as F

# Attention module
class Attention(nn.Module):
    # init with dim, dim_head and heads
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        # inner dimension across all heads
        inner_dim = heads * dim_head
        # scaling factor
        self.scale = dim_head ** -0.5
        # number of heads
        self.heads = heads

        # linear projection to queries
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # linear projection to single-headed keys and values
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        # linear projection back to the model dimension
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    # forward pass
    def forward(self, x):
        # project the input x to queries q, keys k and values v
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        # split out the heads of the queries
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        # scale the queries
        q = q * self.scale

        # similarity matrix (note the single-headed keys - multi-query attention)
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # softmax over the similarities for the attention matrix
        attn = sim.softmax(dim = -1)

        # aggregate the values with the attention
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge the heads back together
        out = rearrange(out, 'b h n d -> b n (h d)')
        # final output projection
        return self.to_out(out)

# PanopticTransformer module (currently a stub)
class PanopticTransformer(nn.Module):
    # init with dim, dim_head and heads
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()

    # forward pass
    def forward(self, x):
        # identity for now - returns the input unchanged
        return x

.\lucidrains\panoptic-transformer\panoptic_transformer\__init__.py

# import the PanopticTransformer class from the panoptic_transformer package
from panoptic_transformer.panoptic_transformer import PanopticTransformer

Panoptic Transformer (wip)

Another attempt at a long-context / efficient transformer by me. This approach will completely generalize all multi-scale approaches of the past. I will be attempting the Pathfinder-X task, which so far has not been beaten by a transformer.

Update: on track to solving path-x with transformers

Training

The script will generate 25000 training samples (in the paper they used 100k; you can change it to that number if you are willing to wait).

$ ./setup.sh

.\lucidrains\panoptic-transformer\scripts\gen-pathx.py

# import required libraries
import time
import sys
import numpy as np
import os

# import the custom snakes2 module
import snakes2

# argument class holding all generation parameters
class Args:
    def __init__(self,
                 contour_path = './contour', batch_id=0, n_images = 200000,
                 window_size=[256,256], padding=22, antialias_scale = 4,
                 LABEL =1, seed_distance= 27, marker_radius = 3,
                 contour_length=15, distractor_length=5, num_distractor_snakes=6, snake_contrast_list=[1.], use_single_paddles=True,
                 max_target_contour_retrial = 4, max_distractor_contour_retrial = 4, max_paddle_retrial=2,
                 continuity = 1.4, paddle_length=5, paddle_thickness=1.5, paddle_margin_list=[4], paddle_contrast_list=[1.],
                 pause_display=False, save_images=True, save_metadata=True):

        # initialize parameters
        self.contour_path = contour_path
        self.batch_id = batch_id
        self.n_images = n_images

        self.window_size = window_size
        self.padding = padding
        self.antialias_scale = antialias_scale

        self.LABEL = LABEL
        self.seed_distance = seed_distance
        self.marker_radius = marker_radius
        self.contour_length = contour_length
        self.distractor_length = distractor_length
        self.num_distractor_snakes = num_distractor_snakes
        self.snake_contrast_list = snake_contrast_list
        self.use_single_paddles = use_single_paddles

        self.max_target_contour_retrial = max_target_contour_retrial
        self.max_distractor_contour_retrial = max_distractor_contour_retrial
        self.max_paddle_retrial = max_paddle_retrial

        self.continuity = continuity
        self.paddle_length = paddle_length
        self.paddle_thickness = paddle_thickness
        self.paddle_margin_list = paddle_margin_list # if the list has multiple elements, one is sampled per image
        self.paddle_contrast_list = paddle_contrast_list # if the list has multiple elements, one is sampled per paddle

        self.pause_display = pause_display
        self.save_images = save_images
        self.save_metadata = save_metadata

# record the start time
t = time.time()
# build the argument object
args = Args()

# read the number of machines, current id and total image count from the command line
num_machines = int(sys.argv[1])
current_id = int(sys.argv[2])
args.batch_id = current_id
total_images = int(sys.argv[3])
args.n_images = total_images/num_machines
dataset_root = './pathx-data' #'/media/data_cifs/pathfinder_seg/'

# set the dataset root according to the command line arguments
if len(sys.argv)==4:
    print('Using default path...')
elif len(sys.argv)==5:
    print('Using custom save path...')
    dataset_root = str(sys.argv[4])

# override some parameter values
args.padding = 1
args.antialias_scale = 4
args.paddle_margin_list = [2,3]
args.seed_distance = 20
args.window_size = [128,128]
args.marker_radius = 3
args.contour_length = 14
args.paddle_thickness = 1.5
args.antialias_scale = 2
args.continuity = 1.8  # from 1.8 to 0.8, in 66% steps
args.distractor_length = args.contour_length // 3
args.num_distractor_snakes = 35 / args.distractor_length
args.snake_contrast_list = [0.9]

args.use_single_paddles = False
args.segmentation_task = False # False
args.segmentation_task_double_circle = False

# set the contour path
dataset_subpath = 'curv_baseline'
args.contour_path = os.path.join(dataset_root, dataset_subpath)

# call from_wrapper in the snakes2 module with the argument object
snakes2.from_wrapper(args)

.\lucidrains\panoptic-transformer\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'panoptic-transformer', # package name
  packages = find_packages(exclude=[]), # find all packages
  version = '0.0.1', # version
  license='MIT', # license
  description = 'Panoptic Transformer', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/panoptic-transformer', # project URL
  keywords = [ # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention-mechanism',
  ],
  install_requires=[ # dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\parti-pytorch\parti_pytorch\optimizer.py

# import the AdamW and Adam optimizers from torch.optim
from torch.optim import AdamW, Adam

# separate parameters into weight-decayable and non-decayable lists
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # parameters with fewer than 2 dimensions (biases, norms) get no weight decay
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

# build the optimizer
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    **kwargs
):
    # optionally filter to parameters that require grad
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # with zero weight decay, plain Adam suffices
    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    # group parameters for selective weight decay
    if group_wd_params:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        # two groups: with weight decay, and without
        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    # AdamW with the given learning rate, weight decay, betas and epsilon
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
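
A usage sketch, assuming the get_optimizer defined above is in scope: parameters with fewer than two dimensions, such as biases and norm gains, end up in the zero-weight-decay group.

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
optim = get_optimizer(model.parameters(), lr = 1e-4, wd = 1e-2)

# two parameter groups: matrices with weight decay, vectors without
print([group['weight_decay'] for group in optim.param_groups])  # [0.01, 0.0]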

.\lucidrains\parti-pytorch\parti_pytorch\parti_pytorch.py

# import required libraries
from typing import List
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn, einsum
import torchvision.transforms as T

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from parti_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# eval decorator, running the model in eval mode for the duration of the call
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# sampling helpers

# log with an epsilon for numerical stability
def log(t, eps = 1e-20):
    return torch.log(t + eps)

# generate gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# gumbel sampling
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

# keep the top-k logits, setting the rest to negative infinity
def top_k(logits, thres = 0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# classifier-free guidance helpers

# build a boolean mask of a given shape with the given probability of True
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

# normalization

# LayerNorm module
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# 2d relative positional bias

class RelPosBias2d(nn.Module):
    def __init__(self, size, heads):
        super().__init__()
        self.pos_bias = nn.Embedding((2 * size - 1) ** 2, heads)

        arange = torch.arange(size)

        pos = torch.stack(torch.meshgrid(arange, arange, indexing = 'ij'), dim = -1)
        pos = rearrange(pos, '... c -> (...) c')
        rel_pos = rearrange(pos, 'i c -> i 1 c') - rearrange(pos, 'j c -> 1 j c')

        rel_pos = rel_pos + size - 1
        h_rel, w_rel = rel_pos.unbind(dim = -1)
        pos_indices = h_rel * (2 * size - 1) + w_rel
        self.register_buffer('pos_indices', pos_indices)

    def forward(self, qk):
        i, j = qk.shape[-2:]

        bias = self.pos_bias(self.pos_indices[:i, :(j - 1)])
        bias = rearrange(bias, 'i j h -> h i j')

        bias = F.pad(bias, (j - bias.shape[-1], 0), value = 0.) # account for the null key / value used for classifier-free guidance
        return bias
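
To see how the flat bias indices are constructed, here is the same index computation run standalone for size = 2 (a sketch, not repository code):

import torch
from einops import rearrange

size = 2
arange = torch.arange(size)
pos = torch.stack(torch.meshgrid(arange, arange, indexing = 'ij'), dim = -1)
pos = rearrange(pos, '... c -> (...) c')              # the 4 grid positions
rel_pos = rearrange(pos, 'i c -> i 1 c') - rearrange(pos, 'j c -> 1 j c')
rel_pos = rel_pos + size - 1                          # offsets shifted into [0, 2 * size - 2]
h_rel, w_rel = rel_pos.unbind(dim = -1)
print(h_rel * (2 * size - 1) + w_rel)                 # 4 x 4 table of indices into the (2 * size - 1) ** 2 embedding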

# feedforward network

def FeedForward(dim, mult = 4, dropout = 0.):
    dim_hidden = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, dim_hidden, bias = False),
        nn.GELU(),
        LayerNorm(dim_hidden),
        nn.Linear(dim_hidden, dim, bias = False)
    )

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        causal = False,
        dropout = 0.,
        norm_context = False,
        rel_pos_bias = False,
        encoded_fmap_size = None
    ):
        # call parent init
        super().__init__()
        # whether to use causal masking
        self.causal = causal
        # scaling factor
        self.scale = dim_head ** -0.5
        # input normalization
        self.norm = LayerNorm(dim)

        # inner dimension across all heads
        inner_dim = heads * dim_head
        # context dimension defaults to dim
        context_dim = default(context_dim, dim)
        # optionally normalize the context, otherwise identity
        self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()

        # query projection
        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(dim, inner_dim, bias = False),
            Rearrange('b n (h d) -> b h n d', h = heads)
        )

        # null key / value, needed in transformers for classifier-free guidance
        self.null_kv = nn.Parameter(torch.randn(dim_head))

        # single-headed key / value attention, from Shazeer's multi-query paper, adopted by AlphaCode and PaLM
        self.to_kv = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(context_dim, dim_head, bias = False)
        )

        # output projection
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(inner_dim, dim, bias = False)
        )

        # positional bias
        self.rel_pos_bias = None

        # if relative positional bias is requested
        if rel_pos_bias:
            assert exists(encoded_fmap_size)
            # initialize the relative positional bias
            self.rel_pos_bias = RelPosBias2d(encoded_fmap_size, heads)

    def forward(
        self,
        x,
        context = None,
        context_mask = None
    ):
        # batch size and device
        batch, device = x.shape[0], x.device

        # normalize the input
        x = self.norm(x)

        # queries
        q = self.to_q(x) * self.scale

        # context defaults to the input itself (self-attention)
        context = default(context, x)
        context = self.norm_context(context)

        # single head of keys / values
        kv = self.to_kv(context)

        # prepend the null key / value
        null_kv = repeat(self.null_kv, 'd -> b 1 d', b = batch)
        kv = torch.cat((null_kv, kv), dim = 1)

        # similarities
        sim = einsum('b h i d, b j d -> b h i j', q, kv)

        # relative positional bias, if present
        if exists(self.rel_pos_bias):
            pos_bias = self.rel_pos_bias(sim)
            sim = sim + pos_bias

        # masking value
        mask_value = -torch.finfo(sim.dtype).max

        # context mask, if present
        if exists(context_mask):
            context_mask = F.pad(context_mask, (1, 0), value = True)
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~context_mask, mask_value)

        # causal masking
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        # attention weights
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        # aggregate the values
        out = einsum('b h i j, b j d -> b h i d', attn, kv)

        return self.to_out(out)
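
A shape sketch (illustrative only) of the multi-query attention pattern above, where all query heads attend over a single shared key/value head plus one null token:

import torch
from torch import einsum

q  = torch.randn(2, 8, 16, 64)   # (batch, heads, seq, dim_head)
kv = torch.randn(2, 17, 64)      # (batch, 1 null + 16 context tokens, dim_head)

sim = einsum('b h i d, b j d -> b h i j', q, kv)
print(sim.shape)  # torch.Size([2, 8, 16, 17])
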
# the Parti class, subclassing nn.Module
class Parti(nn.Module):
    # initialization, accepting many parameters
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        ff_mult = 4,
        vae = None,
        vae_image_size = None,
        vae_codebook_size = None,
        t5_name = DEFAULT_T5_NAME,
        text_embed_dim = None,
        cond_drop_prob = 0.25,
        max_text_len = 128,
        ignore_index = -1
    ):
        # call parent init
        super().__init__()

        # text encoding
        text_embed_dim = default(text_embed_dim, get_encoded_dim(t5_name))
        self.encode_texts = partial(t5_encode_text, name = t5_name)
        self.max_text_len = max_text_len

        assert cond_drop_prob > 0.
        self.cond_drop_prob = cond_drop_prob # classifier-free guidance for transformers - @crowsonkb

        # VAE and image handling
        assert exists(vae) ^ exists(vae_codebook_size)
        self.vae = vae

        codebook_size = default(vae_codebook_size, vae.codebook_size)
        image_size = default(vae_image_size, vae.image_size)

        self.start_token = nn.Parameter(torch.randn(dim))
        self.image_token_embed = nn.Embedding(codebook_size, dim)

        self.image_encoded_dim = vae.get_encoded_fmap_size(image_size)

        self.axial_height_pos = nn.Parameter(torch.randn(self.image_encoded_dim, dim))
        self.axial_width_pos = nn.Parameter(torch.randn(self.image_encoded_dim, dim))

        # initial norm
        self.init_norm = LayerNorm(dim)

        self.layers = nn.ModuleList([])

        # stack depth layers of self-attention, cross-attention and feedforward
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, causal = True, encoded_fmap_size = self.image_encoded_dim, rel_pos_bias = True, dim_head = dim_head, heads = heads, dropout = dropout),
                Attention(dim, context_dim = text_embed_dim, dim_head = dim_head, heads = heads, dropout = dropout),
                FeedForward(dim, mult = ff_mult, dropout = dropout)
            ]))

        self.final_norm = LayerNorm(dim)

        self.to_logits = nn.Linear(dim, codebook_size, bias = False)
        self.to_logits.weight = self.image_token_embed.weight

        # default device
        if exists(vae):
            self.to(next(vae.parameters()).device)

        # loss related
        self.ignore_index = ignore_index

    # generation function for producing images
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        texts,
        *,
        cond_scale = 3.,
        filter_thres = 0.9,
        temperature = 1.,
        return_pil_images = False
    ):
        device = next(self.parameters()).device

        text_token_embeds, text_mask = self.encode_texts(texts, output_device = device)

        batch = text_token_embeds.shape[0]
        image_seq_len = self.image_encoded_dim ** 2
        image_tokens = torch.empty((batch, 0), device = device, dtype = torch.long)

        # autoregressively sample the image token sequence
        for _ in range(image_seq_len):
            logits = self.forward_with_cond_scale(
                text_token_embeds = text_token_embeds,
                text_mask = text_mask,
                image_token_ids = image_tokens
            )[:, -1]

            filtered_logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sampled = rearrange(sampled, 'b -> b 1')
            image_tokens = torch.cat((image_tokens, sampled), dim = -1)

        image_tokens = rearrange(image_tokens, 'b (h w) -> b h w', h = self.image_encoded_dim)

        # without a VAE, just return the image tokens
        if not exists(self.vae):
            return image_tokens

        with torch.no_grad():
            fmap = self.vae.get_fmap_from_codebook(image_tokens)
            images = self.vae.decode(fmap)

        # optionally convert the output to PIL images
        if not return_pil_images:
            return images

        pil_images = list(map(T.ToPILImage(), images.unbind(dim = 0)))
        return pil_images
    # forward with conditional scaling (classifier-free guidance)
    def forward_with_cond_scale(self, *args, cond_scale = 3, **kwargs):
        # fully conditioned logits
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        # at a cond_scale of 1, no guidance is applied
        if cond_scale == 1:
            return logits

        # otherwise compute the null (unconditioned) logits and scale the difference
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
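
A worked example (dummy numbers) of the classifier-free guidance arithmetic above, where differences from the unconditioned prediction get amplified:

import torch

cond_logits = torch.tensor([1., 2., 3.])  # conditioned on text
null_logits = torch.tensor([1., 1., 1.])  # text dropped out
cond_scale = 3.

guided = null_logits + (cond_logits - null_logits) * cond_scale
print(guided)  # tensor([1., 4., 7.])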

    # forward pass, taking text and image inputs and returning logits or a loss
    def forward(
        self,
        texts: List[str] = None,
        text_token_embeds = None,
        text_mask = None,
        images = None,
        image_token_ids = None,
        cond_drop_prob = None,
        return_loss = False
    ):
        # either texts or text embeddings must be given, and likewise images or image token ids
        assert exists(texts) ^ exists(text_token_embeds)
        assert exists(images) ^ exists(image_token_ids)
        # conditional dropout probability defaults to the module's setting
        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        # encode the images

        # if no image token ids were given, encode the images with the VAE
        if not exists(image_token_ids):
            assert exists(self.vae), 'vae must be given if you want to encode the image live'

            with torch.no_grad():
                _, image_token_ids, _ = self.vae.encode(images, return_indices_and_loss = True)

            image_token_ids = rearrange(image_token_ids, 'b ... -> b (...)')

        # if returning the loss, use the full sequence as labels and drop the last token from the input
        if return_loss:
            assert image_token_ids.shape[-1] > 1, 'not enough image tokens given to return a loss'
            image_token_ids, labels = image_token_ids[:, :-1], image_token_ids

        # image token embeddings
        image_token_emb = self.image_token_embed(image_token_ids)

        # add axial positional embedding

        axial_pos_emb = rearrange(self.axial_width_pos, 'w d -> 1 w d') + rearrange(self.axial_height_pos, 'h d -> h 1 d')
        axial_pos_emb = rearrange(axial_pos_emb, 'h w d -> (h w) d')

        batch, seq_len, device = *image_token_emb.shape[:2], image_token_emb.device

        image_token_emb = image_token_emb + axial_pos_emb[:seq_len]

        # prepend the start token

        start_tokens = repeat(self.start_token, 'd -> b 1 d', b = batch)
        image_token_emb = torch.cat((start_tokens, image_token_emb), dim = 1)

        # text

        # encode the texts if embeddings were not given
        if not exists(text_token_embeds):
            with torch.no_grad():
                text_token_embeds, text_mask = self.encode_texts(texts, output_device = device)

        # default to an all-True text mask
        if not exists(text_mask):
            text_mask = torch.ones(text_token_embeds.shape[:2], dtype = torch.bool)

        # truncate text to the maximum text length
        text_token_embeds, text_mask = map(lambda t: t[:, :self.max_text_len], (text_token_embeds, text_mask))

        # classifier-free guidance conditional dropout

        # with a nonzero dropout probability, randomly drop out the text conditioning
        if cond_drop_prob > 0:
            keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

        # attention layers

        x = image_token_emb
        x = self.init_norm(x)

        # run each layer: self-attention, cross-attention, then feedforward
        for self_attn, cross_attn, ff in self.layers:
            x = self_attn(x) + x
            x = cross_attn(x, context = text_token_embeds, context_mask = text_mask) + x
            x = ff(x) + x

        x = self.final_norm(x)

        # to logits

        logits = self.to_logits(x)

        # return the logits if no loss is needed
        if not return_loss:
            return logits

        # cross entropy loss
        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels,
            ignore_index = self.ignore_index
        )

        return loss

.\lucidrains\parti-pytorch\parti_pytorch\t5.py

# import torch
import torch
# import transformers
import transformers
# import T5Tokenizer, T5EncoderModel and T5Config from transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config

# set transformers logging to error only
transformers.logging.set_verbosity_error()

# check whether a value exists
def exists(val):
    return val is not None

# config

# maximum text length of 256
MAX_LENGTH = 256

# default T5 model name
DEFAULT_T5_NAME = 'google/t5-v1_1-base'

# cache of T5 model configurations
T5_CONFIGS = {}

# global singletons

# get the tokenizer for a given name
def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name)
    return tokenizer

# get the model for a given name
def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

# get both model and tokenizer for a given name, caching them
def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']

# get the encoder output dimension
def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoid loading the model when only the dimension is needed
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model

# encoding text

# encode text with T5
def t5_encode_text(texts, name = DEFAULT_T5_NAME, output_device = None):
    # get model and tokenizer
    t5, tokenizer = get_model_and_tokenizer(name)

    # move the model to CUDA if available
    if torch.cuda.is_available():
        t5 = t5.cuda()

    # get the device
    device = next(t5.parameters()).device

    # tokenize the texts
    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    # move the inputs onto the device
    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    # set the model to eval mode
    t5.eval()

    # disable gradient computation
    with torch.no_grad():
        # run the encoder
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    # convert the attention mask to boolean
    attn_mask = attn_mask.bool()

    # if no output device is given, return the encoded text and attention mask as-is
    if not exists(output_device):
        return encoded_text, attn_mask

    # move the encoded text and attention mask onto the output device (Tensor.to is not in-place, so reassign)
    encoded_text = encoded_text.to(output_device)
    attn_mask = attn_mask.to(output_device)

    return encoded_text, attn_mask

.\lucidrains\parti-pytorch\parti_pytorch\version.py

# current version
__version__ = '0.2.0'

.\lucidrains\parti-pytorch\parti_pytorch\vit_vqgan.py

# import necessary libraries
import copy
import math
from math import sqrt
from functools import partial, wraps

# import the vector quantization modules
from vector_quantize_pytorch import VectorQuantize as VQ, LFQ

# import PyTorch
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision

# import einops
from einops import rearrange, reduce, repeat, pack, unpack
from einops_exts import rearrange_many
from einops.layers.torch import Rearrange

# constants
MList = nn.ModuleList

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# decorators

# eval decorator
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# decorator that temporarily removes the vgg attribute for the duration of the call
def remove_vgg(fn):
    @wraps(fn)
    def inner(self, *args, **kwargs):
        has_vgg = hasattr(self, 'vgg')
        if has_vgg:
            vgg = self.vgg
            delattr(self, 'vgg')

        out = fn(self, *args, **kwargs)

        if has_vgg:
            self.vgg = vgg

        return out
    return inner

# keyword argument helpers

# pick the given keys out of a dict, popping them
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

# split a dict into two by a key predicate
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# whether a string begins with a prefix
def string_begins_with(prefix, string_input):
    return string_input.startswith(prefix)

# group dict entries by key prefix
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# group by prefix and trim the prefix off the keys
def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# tensor helpers

# log with an epsilon
def log(t, eps = 1e-10):
    return torch.log(t + eps)

# gradient penalty, for the GAN discriminator
def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]
    gradients = torch_grad(outputs = output, inputs = images,
                           grad_outputs = torch.ones(output.size(), device = images.device),
                           create_graph = True, retain_graph = True, only_inputs = True)[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# l2 normalization
def l2norm(t):
    return F.normalize(t, dim = -1)

# leaky relu activation (using the passed-in slope p)
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(p)

# safe division
def safe_div(numer, denom, eps = 1e-8):
    return numer / (denom + eps)

# GAN losses

# hinge discriminator loss
def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

# hinge generator loss
def hinge_gen_loss(fake):
    return -fake.mean()

# binary cross entropy discriminator loss
def bce_discr_loss(fake, real):
    return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()

# binary cross entropy generator loss
def bce_gen_loss(fake):
    return -log(torch.sigmoid(fake)).mean()

# gradient of the loss with respect to a layer
def grad_layer_wrt_loss(loss, layer):
    return torch_grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0].detach()

# fourier features

# sinusoidal positional embedding
class SinusoidalPosEmb(nn.Module):
    def __init__(
        self,
        dim,
        height_or_width,
        theta = 10000
    ):
        super().__init__()
        self.dim = dim
        self.theta = theta

        hw_range = torch.arange(height_or_width)
        coors = torch.stack(torch.meshgrid(hw_range, hw_range, indexing = 'ij'), dim = -1)
        coors = rearrange(coors, 'h w c -> h w c')
        self.register_buffer('coors', coors, persistent = False)
    # forward pass
    def forward(self, x):
        # half of the feature dimension
        half_dim = self.dim // 2
        # log-spaced frequency scale
        emb = math.log(self.theta) / (half_dim - 1)
        # exponentially decaying frequencies
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        # outer product of the grid coordinates and the frequencies
        emb = rearrange(self.coors, 'h w c -> h w c 1') * rearrange(emb, 'j -> 1 1 1 j')
        # concatenate the sine and cosine parts
        fourier = torch.cat((emb.sin(), emb.cos()), dim = -1)
        # broadcast the fourier features across the batch
        fourier = repeat(fourier, 'h w c d -> b (c d) h w', b = x.shape[0])
        # concatenate the input with its fourier features
        return torch.cat((x, fourier), dim = 1)
# channel-wise layer norm
class ChanLayerNorm(nn.Module):
    def __init__(
        self,
        dim,
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        # variance and mean over the channel dimension
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        # normalize, then scale by gamma
        return (x - mean) * (var + self.eps).rsqrt() * self.gamma

# cross embedding layer (a conv stem with multiple kernel sizes)
class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # channel dimension per scale halves with each larger kernel, the remainder going to the last scale
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        # run the convolutions in parallel and concatenate the feature maps along channels
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)
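
# Illustrative sketch (not from the source): with dim_out = 64 and three kernel
# sizes, the per-scale dims (32, 16, 16) sum to dim_out, so the concatenated
# output has exactly dim_out channels.
def _demo_cross_embed_layer():
    layer = CrossEmbedLayer(dim_in = 3, kernel_sizes = (3, 7, 15), dim_out = 64, stride = 1)
    x = torch.randn(2, 3, 32, 32)
    assert layer(x).shape == (2, 64, 32, 32)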

# basic groupnorm -> leaky relu -> conv block
class Block(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim)
        self.activation = leaky_relu()
        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None): # scale_shift is accepted but unused in this implementation
        x = self.groupnorm(x)
        x = self.activation(x)
        return self.project(x)

# residual block, with a 1x1 conv on the skip path when the dimensions differ
class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        groups = 8
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.block = Block(dim, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        h = self.block(x)
        return h + self.res_conv(x)

# discriminator
class Discriminator(nn.Module):
    def __init__(
        self,
        dims,
        channels = 3,
        groups = 8,
        init_kernel_size = 5,
        cross_embed_kernel_sizes = (3, 7, 15)
    ):
        super().__init__()
        init_dim, *_, final_dim = dims
        dim_pairs = zip(dims[:-1], dims[1:])

        self.layers = MList([nn.Sequential(
            CrossEmbedLayer(channels, cross_embed_kernel_sizes, init_dim, stride = 1),
            leaky_relu()
        )])

        for dim_in, dim_out in dim_pairs:
            self.layers.append(nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
                leaky_relu(),
                nn.GroupNorm(groups, dim_out),
                ResnetBlock(dim_out, dim_out),
            ))

        self.to_logits = nn.Sequential( # return 5 x 5, for patchgan-style training
            nn.Conv2d(final_dim, final_dim, 1),
            leaky_relu(),
            nn.Conv2d(final_dim, 1, 4)
        )

    def forward(self, x):
        for net in self.layers:
            x = net(x)

        return self.to_logits(x)
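
# Illustrative sketch (not from the source): with four stages the input is
# downsampled three times before the final 4x4 convolution produces patch logits.
def _demo_discriminator():
    discr = Discriminator(dims = (64, 128, 256, 512), channels = 3)
    x = torch.randn(1, 3, 256, 256)
    return discr(x) # (1, 1, h', w') grid of patch logits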

# 2d relative positional bias
class RelPosBias2d(nn.Module):
    def __init__(self, size, heads):
        super().__init__()
        # one learned bias per relative (height, width) offset, per head
        self.pos_bias = nn.Embedding((2 * size - 1) ** 2, heads)

        arange = torch.arange(size)

        # 2d grid of coordinates, flattened to (size * size, 2)
        pos = torch.stack(torch.meshgrid(arange, arange, indexing='ij'), dim=-1)
        pos = rearrange(pos, '... c -> (...) c')

        # pairwise relative positions between all coordinates
        rel_pos = rearrange(pos, 'i c -> i 1 c') - rearrange(pos, 'j c -> 1 j c')

        # shift into the non-negative range [0, 2 * size - 2]
        rel_pos = rel_pos + size - 1
        h_rel, w_rel = rel_pos.unbind(dim=-1)

        # flatten the (height, width) offsets into a single index into the bias table
        pos_indices = h_rel * (2 * size - 1) + w_rel
        self.register_buffer('pos_indices', pos_indices)

    def forward(self, qk):
        i, j = qk.shape[-2:]

        # look up the bias for every query-key pair and move heads to the front
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, 'i j h -> h i j')
        return bias
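
# Illustrative sketch (not from the source): for a feature map of side `size`, the
# table covers all (2 * size - 1)^2 relative offsets and yields a (heads, n, n) bias,
# with n = size * size, ready to be added to the attention similarity matrix.
def _demo_rel_pos_bias_2d():
    bias_mod = RelPosBias2d(size = 8, heads = 4)
    sim = torch.randn(1, 4, 64, 64) # (batch, heads, n, n) attention logits
    assert bias_mod(sim).shape == (4, 64, 64)
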
# vit encoder / decoder

class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        # depthwise convolution, serving as a positional encoding generator (PEG)
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return self.proj(x)

class SPT(nn.Module):
    """ https://arxiv.org/abs/2112.13492 """

    def __init__(self, *, dim, patch_size, channels = 3):
        super().__init__()
        patch_dim = patch_size * patch_size * 5 * channels

        # patchify the shifted-and-concatenated input, then channel layernorm and a 1x1 conv projection
        self.to_patch_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
            ChanLayerNorm(patch_dim),
            nn.Conv2d(patch_dim, dim, 1)
        )

    def forward(self, x):
        # shift the image by one pixel in each of the four directions and concat with the original
        shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
        shifted_x = list(map(lambda shift: F.pad(x, shift), shifts))
        x_with_shifts = torch.cat((x, *shifted_x), dim = 1)
        return self.to_patch_tokens(x_with_shifts)
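
# Illustrative sketch (not from the source): the four shifted copies concatenated
# with the original account for the factor of 5 in patch_dim above.
def _demo_spt():
    spt = SPT(dim = 128, patch_size = 16, channels = 3)
    x = torch.randn(2, 3, 64, 64)
    assert spt(x).shape == (2, 128, 4, 4)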

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_head = 32,
        fmap_size = None,
        rel_pos_bias = False
    ):
        super().__init__()
        # pre-normalization over channels
        self.norm = ChanLayerNorm(dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        # 1x1 conv to queries, keys, values, each followed by a depthwise conv (multi-dconv-head attention, as in Primer)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.primer_ds_convs = nn.ModuleList([PEG(inner_dim) for _ in range(3)])

        # output projection
        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)

        # optional 2d relative positional bias (requires a fixed feature map size)
        self.rel_pos_bias = None
        if rel_pos_bias:
            assert exists(fmap_size)
            self.rel_pos_bias = RelPosBias2d(fmap_size, heads)

    def forward(self, x):
        fmap_size = x.shape[-1]
        h = self.heads

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        q, k, v = [ds_conv(t) for ds_conv, t in zip(self.primer_ds_convs, (q, k, v))]
        q, k, v = rearrange_many((q, k, v), 'b (h d) x y -> b h (x y) d', h = h)

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(self.rel_pos_bias):
            sim = sim + self.rel_pos_bias(sim)

        attn = sim.softmax(dim = -1, dtype = torch.float32)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = fmap_size, y = fmap_size)
        return self.to_out(out)

def FeedForward(dim, mult = 4):
    return nn.Sequential(
        ChanLayerNorm(dim),
        nn.Conv2d(dim, dim * mult, 1, bias = False),
        nn.GELU(),
        PEG(dim * mult),
        nn.Conv2d(dim * mult, dim, 1, bias = False)
    )

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        layers,
        dim_head = 32,
        heads = 8,
        ff_mult = 4,
        fmap_size = None
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(layers):
            # each layer: positional encoding (PEG), attention, then feedforward, all residual
            self.layers.append(nn.ModuleList([
                PEG(dim = dim),
                Attention(dim = dim, dim_head = dim_head, heads = heads, fmap_size = fmap_size, rel_pos_bias = True),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.norm = ChanLayerNorm(dim)

    def forward(self, x):
        for peg, attn, ff in self.layers:
            x = peg(x) + x
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)
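
# Illustrative sketch (not from the source): the transformer operates directly on
# (b, c, h, w) feature maps; fmap_size must match the input's spatial side, since
# the relative positional bias table is built for a fixed size.
def _demo_transformer():
    model = Transformer(dim = 64, layers = 2, fmap_size = 8)
    x = torch.randn(2, 64, 8, 8)
    assert model(x).shape == x.shape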

class ViTEncDec(nn.Module):
    def __init__(
        self,
        dim,
        image_size,
        channels = 3,
        layers = 4,
        patch_size = 16,
        dim_head = 32,
        heads = 8,
        ff_mult = 4
    ):
        super().__init__()
        self.encoded_dim = dim
        self.patch_size = patch_size

        # dimension of a flattened patch, and spatial side of the encoded feature map
        input_dim = channels * (patch_size ** 2)
        fmap_size = image_size // patch_size

        # encoder: shifted patch tokenization followed by a transformer
        self.encoder = nn.Sequential(
            SPT(dim = dim, patch_size = patch_size, channels = channels),
            Transformer(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                layers = layers,
                fmap_size = fmap_size
            ),
        )

        # decoder: transformer, projection back to pixel space, then un-patchify
        self.decoder = nn.Sequential(
            Transformer(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                layers = layers,
                fmap_size = fmap_size
            ),
            nn.Sequential(
                SinusoidalPosEmb(dim // 2, height_or_width = fmap_size),
                nn.Conv2d(2 * dim, dim * 4, 3, bias = False, padding = 1),
                nn.Tanh(),
                nn.Conv2d(dim * 4, input_dim, 1, bias = False),
            ),
            Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
        )

    # spatial side of the encoded feature map for a given image size
    def get_encoded_fmap_size(self, image_size):
        return image_size // self.patch_size

    # weight of the last decoder layer, used for computing the adaptive GAN loss weight
    @property
    def last_dec_layer(self):
        return self.decoder[-2][-1].weight

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x):
        return self.decoder(x)
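
# Illustrative sketch (not from the source): an encode/decode roundtrip preserves
# the image shape; the encoded map has `dim` channels at side image_size // patch_size.
def _demo_vit_enc_dec():
    enc_dec = ViTEncDec(dim = 64, image_size = 64, patch_size = 16, layers = 1)
    x = torch.randn(1, 3, 64, 64)
    z = enc_dec.encode(x)
    assert z.shape == (1, 64, 4, 4)
    assert enc_dec.decode(z).shape == x.shape
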
# ViT VQGAN VAE
class VitVQGanVAE(nn.Module):
    def __init__(
        self,
        *,
        dim,                              # model dimension
        image_size,                       # image size
        channels = 3,                     # number of image channels
        layers = 4,                       # number of transformer layers
        l2_recon_loss = False,            # use L2 instead of L1 reconstruction loss
        use_hinge_loss = True,            # hinge vs binary cross entropy GAN loss
        vgg = None,                       # optionally pass in a VGG for the perceptual loss
        lookup_free_quantization = True,  # use lookup-free quantization (LFQ) instead of VQ
        codebook_size = 65536,            # codebook size
        vq_kwargs: dict = dict(           # kwargs forwarded to the VQ quantizer
            codebook_dim = 64,
            decay = 0.9,
            commitment_weight = 1.,
            kmeans_init = True
        ),
        lfq_kwargs: dict = dict(          # kwargs forwarded to the LFQ quantizer
            entropy_loss_weight = 0.1,
            diversity_gamma = 2.
        ),
        use_vgg_and_gan = True,           # whether to use the perceptual and GAN losses
        discr_layers = 4,                 # number of discriminator stages
        **kwargs
    ):
        super().__init__()

        # route prefixed kwargs to the quantizer and the encoder/decoder respectively
        vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
        encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)

        self.image_size = image_size
        self.channels = channels
        self.codebook_size = codebook_size

        # encoder / decoder
        self.enc_dec = ViTEncDec(
            dim = dim,
            image_size = image_size,
            channels = channels,
            layers = layers,
            **encdec_kwargs
        )

        # quantizer: lookup-free quantization (LFQ), or classic vector quantization (VQ)
        self.lookup_free_quantization = lookup_free_quantization

        if lookup_free_quantization:
            self.quantizer = LFQ(
                dim = self.enc_dec.encoded_dim,
                codebook_size = codebook_size,
                **lfq_kwargs
            )
        else:
            self.quantizer = VQ(
                dim = self.enc_dec.encoded_dim,
                codebook_size = codebook_size,
                accept_image_fmap = True,
                use_cosine_sim = True,
                **vq_kwargs
            )

        # reconstruction loss function
        self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss

        # the GAN and perceptual losses can be turned off entirely (e.g. for grayscale data)
        self.vgg = None
        self.discr = None
        self.use_vgg_and_gan = use_vgg_and_gan

        if not use_vgg_and_gan:
            return

        # perceptual loss, defaulting to VGG16 features when no network is passed in
        if exists(vgg):
            self.vgg = vgg
        else:
            self.vgg = torchvision.models.vgg16(pretrained = True)
            self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])

        # gan related losses
        layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        # discriminator
        self.discr = Discriminator(dims = dims, channels = channels)

        self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
        self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss

    @property
    def encoded_dim(self):
        return self.enc_dec.encoded_dim

    # encoded feature map size for a given image size
    def get_encoded_fmap_size(self, image_size):
        return self.enc_dec.get_encoded_fmap_size(image_size)

    # copy of the model for evaluation, with the discriminator and VGG stripped out
    def copy_for_eval(self):
        device = next(self.parameters()).device
        vae_copy = copy.deepcopy(self.cpu())

        if vae_copy.use_vgg_and_gan:
            del vae_copy.discr
            del vae_copy.vgg

        vae_copy.eval()
        return vae_copy.to(device)

    # state dict, with VGG weights excluded by the remove_vgg decorator
    @remove_vgg
    def state_dict(self, *args, **kwargs):
        return super().state_dict(*args, **kwargs)

    # load state dict, likewise ignoring VGG weights
    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        return super().load_state_dict(*args, **kwargs)

    # reconstruct a feature map from codebook indices
    def get_fmap_from_codebook(self, indices):
        if self.lookup_free_quantization:
            indices, ps = pack([indices], 'b *')
            fmap = self.quantizer.indices_to_codes(indices)
            fmap, = unpack(fmap, ps, 'b * c')
        else:
            codes = self.quantizer.codebook[indices]
            fmap = self.quantizer.project_out(codes)

        return rearrange(fmap, 'b h w c -> b c h w')
    # encode a feature map, optionally returning the codebook indices and the quantizer's auxiliary loss
    def encode(self, fmap, return_indices_and_loss = True):
        fmap = self.enc_dec.encode(fmap)

        fmap, indices, quantizer_aux_loss = self.quantizer(fmap)

        if not return_indices_and_loss:
            return fmap

        return fmap, indices, quantizer_aux_loss

    # decode a feature map back to image space
    def decode(self, fmap):
        return self.enc_dec.decode(fmap)

    # forward pass
    def forward(
        self,
        img,
        return_loss = False,
        return_discr_loss = False,
        return_recons = False,
        apply_grad_penalty = True
    ):
        batch, channels, height, width, device = *img.shape, img.device

        # validate the input against the configured image size and channel count
        assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}'
        assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'

        # encode, quantize, then decode
        fmap, indices, commit_loss = self.encode(img, return_indices_and_loss = True)

        fmap = self.decode(fmap)

        if not return_loss and not return_discr_loss:
            return fmap

        # return either the autoencoder loss or the discriminator loss, never both
        assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'

        # discriminator branch

        if return_discr_loss:
            assert exists(self.discr), 'discriminator must exist to train it'

            # stop gradients into the generator, and enable input gradients for the gradient penalty
            fmap.detach_()
            img.requires_grad_()

            fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))

            discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)

            loss = discr_loss

            # optionally add the gradient penalty on top of the base discriminator loss
            if apply_grad_penalty:
                gp = gradient_penalty(img, img_discr_logits)
                loss = loss + gp

            if return_recons:
                return loss, fmap

            return loss

        # reconstruction loss

        recon_loss = self.recon_loss_fn(fmap, img)

        # early out when the perceptual and GAN losses are disabled

        if not self.use_vgg_and_gan:
            if return_recons:
                return recon_loss, fmap

            return recon_loss

        # perceptual loss

        img_vgg_input = img
        fmap_vgg_input = fmap

        if img.shape[1] == 1:
            # repeat grayscale images to 3 channels for VGG
            img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))

        img_vgg_feats = self.vgg(img_vgg_input)
        recon_vgg_feats = self.vgg(fmap_vgg_input)
        perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)

        # generator loss

        gen_loss = self.gen_loss(self.discr(fmap))

        # adaptive weight: balance the generator loss against the perceptual loss
        # by the ratio of their gradient norms at the last decoder layer

        last_dec_layer = self.enc_dec.last_dec_layer

        norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
        norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

        adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
        adaptive_weight.clamp_(max = 1e4)

        # combined loss

        loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss

        if return_recons:
            return loss, fmap

        return loss
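
# Illustrative sketch (not from the source): one autoencoder loss computation on
# random data. use_vgg_and_gan = False keeps it light, so the returned loss is the
# reconstruction loss alone; encdec_patch_size is routed to ViTEncDec by the
# 'encdec_' prefix convention above. This assumes the LFQ quantizer accepts
# (b, c, h, w) feature maps, as vector-quantize-pytorch's LFQ does.
def _demo_vit_vqgan_vae():
    vae = VitVQGanVAE(dim = 32, image_size = 32, encdec_patch_size = 8, use_vgg_and_gan = False)
    img = torch.randn(2, 3, 32, 32)
    loss = vae(img, return_loss = True)
    loss.backward()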

.\lucidrains\parti-pytorch\parti_pytorch\vit_vqgan_trainer.py

# standard library imports
from math import sqrt
import copy
from random import choice
from pathlib import Path
from shutil import rmtree
from PIL import Image

# torch, mixed precision utilities, and data utilities
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, random_split

# torchvision, for transforms, datasets and image grids
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

from einops import rearrange

# the VAE being trained, and its optimizer factory
from parti_pytorch.vit_vqgan import VitVQGanVAE
from parti_pytorch.optimizer import get_optimizer

# exponential moving average wrapper
from ema_pytorch import EMA

# helper functions

def exists(val):
    return val is not None

def noop(*args, **kwargs):
    pass

# cycle endlessly through a dataloader
def cycle(dl):
    while True:
        for data in dl:
            yield data

def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# prompt the user for a yes / no answer
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# accumulate new values into a running log dict
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# classes

# image dataset, recursively globbing a folder for the given extensions
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# main trainer class

class VQGanVAETrainer(nn.Module):
    def __init__(
        self,
        vae,
        *,
        num_train_steps,
        batch_size,
        folder,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        ema_beta = 0.995,
        ema_update_after_step = 500,
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
        amp = False
        ):
        super().__init__()
        assert isinstance(vae, VitVQGanVAE), 'vae must be instance of VitVQGanVAE'
        image_size = vae.image_size

        # the VAE being trained, plus an exponential moving average copy for sampling
        self.vae = vae
        self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)

        # step counter, registered as a buffer so it is saved with the model
        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # split off the discriminator parameters, since the generator and discriminator get separate optimizers
        all_parameters = set(vae.parameters())
        discr_parameters = set(vae.discr.parameters())
        vae_parameters = all_parameters - discr_parameters

        self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
        self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)

        # mixed precision training
        self.amp = amp
        self.scaler = GradScaler(enabled = amp)
        self.discr_scaler = GradScaler(enabled = amp)

        # dataset

        self.ds = ImageDataset(folder, image_size = image_size)

        # optionally split off a validation set
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # infinite dataloaders for training and validation
        self.dl = cycle(DataLoader(
            self.ds,
            batch_size = batch_size,
            shuffle = True
        ))

        self.valid_dl = cycle(DataLoader(
            self.valid_ds,
            batch_size = batch_size,
            shuffle = True
        ))

        # how often to save checkpoints and sampled reconstructions
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        # how often to apply the gradient penalty on the discriminator
        self.apply_grad_penalty_every = apply_grad_penalty_every

        self.results_folder = Path(results_folder)

        # offer to clear out a previous experiment's checkpoints and results
        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)
    # one training step: update the generator, then the discriminator, then the EMA copy
    def train_step(self):
        device = next(self.vae.parameters()).device
        steps = int(self.steps.item())
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

        self.vae.train()

        logs = {}

        # update vae (generator)

        for _ in range(self.grad_accum_every):
            img = next(self.dl)
            img = img.to(device)

            with autocast(enabled = self.amp):
                loss = self.vae(
                    img,
                    return_loss = True,
                    apply_grad_penalty = apply_grad_penalty
                )

                # scale the loss for gradient accumulation and backpropagate
                self.scaler.scale(loss / self.grad_accum_every).backward()

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        self.scaler.step(self.optim)
        self.scaler.update()
        self.optim.zero_grad()

        # update discriminator

        if exists(self.vae.discr):
            self.discr_optim.zero_grad()

            discr_loss = 0
            for _ in range(self.grad_accum_every):
                img = next(self.dl)
                img = img.to(device)

                with autocast(enabled = self.amp):
                    loss = self.vae(img, return_discr_loss = True)

                    self.discr_scaler.scale(loss / self.grad_accum_every).backward()

                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})

            self.discr_scaler.step(self.discr_optim)
            self.discr_scaler.update()

            print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")

        # update the exponential moving average of the generator
        self.ema_vae.update()

        # periodically save sampled reconstructions

        if not (steps % self.save_results_every):
            for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
                model.eval()

                imgs = next(self.dl)
                imgs = imgs.to(device)

                recons = model(imgs)
                nrows = int(sqrt(self.batch_size))

                # interleave originals and reconstructions into a single grid
                imgs_and_recons = torch.stack((imgs, recons), dim = 0)
                imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')

                imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
                grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))

                logs['reconstructions'] = grid

                save_image(grid, str(self.results_folder / f'{filename}.png'))

            print(f'{steps}: saving to {str(self.results_folder)}')

        # periodically save model checkpoints

        if not (steps % self.save_model_every):
            state_dict = self.vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.pt')
            torch.save(state_dict, model_path)

            ema_state_dict = self.ema_vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
            torch.save(ema_state_dict, model_path)

            print(f'{steps}: saving model to {str(self.results_folder)}')

        # increment the step counter and return the logs
        self.steps += 1
        return logs

    # training loop: run train_step until the step budget is exhausted
    def train(self, log_fn = noop):
        device = next(self.vae.parameters()).device

        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        print('training complete')
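
# Illustrative sketch (not from the source): wiring the trainer to a VAE and an
# image folder. The folder path is hypothetical and a CUDA device is assumed; note
# the default VAE configuration will also instantiate a pretrained VGG16 for the
# perceptual loss.
def _demo_trainer():
    vae = VitVQGanVAE(dim = 32, image_size = 64).cuda()
    trainer = VQGanVAETrainer(
        vae,
        folder = './path/to/images', # hypothetical path
        num_train_steps = 10,
        batch_size = 4
    )
    trainer.train()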

.\lucidrains\parti-pytorch\parti_pytorch\__init__.py

# expose the package's main classes at the root

from parti_pytorch.parti_pytorch import Parti
from parti_pytorch.vit_vqgan import VitVQGanVAE
from parti_pytorch.vit_vqgan_trainer import VQGanVAETrainer