Lucidrains-系列项目源码解析-三十九-

Lucidrains 系列项目源码解析(三十九)

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\spherical_harmonics.py

# 从 math 模块中导入 pi 和 sqrt 函数
# 从 functools 模块中导入 reduce 函数
# 从 operator 模块中导入 mul 函数
# 导入 torch 模块
from math import pi, sqrt
from functools import reduce
from operator import mul
import torch

# 从 functools 模块中导入 lru_cache 装饰器
# 从 se3_transformer_pytorch.utils 模块中导入 cache 函数
from functools import lru_cache
from se3_transformer_pytorch.utils import cache

# 定义常量 CACHE,初始化为空字典
CACHE = {}

# 清空球谐函数缓存
def clear_spherical_harmonics_cache():
    """Empty the module-level cache of associated Legendre results."""
    CACHE.clear()

# 定义函数 lpmv_cache_key_fn,用于生成缓存键
def lpmv_cache_key_fn(l, m, x):
    """Cache key for lpmv — results are keyed by (degree, order) only, not by x."""
    return (l, m)

# 定义函数 semifactorial,使用 lru_cache 装饰器缓存结果
@lru_cache(maxsize = 1000)
def semifactorial(x):
    """Double factorial x!! = x * (x-2) * ... (1.0 when x < 2), as a float."""
    result = 1.
    k = x
    while k > 1:
        result *= k
        k -= 2
    return result

# 定义函数 pochhammer,使用 lru_cache 装饰器缓存结果
@lru_cache(maxsize = 1000)
def pochhammer(x, k):
    """Rising factorial (x)_k = x * (x+1) * ... * (x+k-1), as a float."""
    total = float(x)
    for term in range(x + 1, x + k):
        total *= term
    return total

# 定义函数 negative_lpmv,计算负的球谐函数
def negative_lpmv(l, m, y):
    """Convert P_l^{|m|} into P_l^{m} for negative m.

    For m >= 0 the tensor is returned untouched; for m < 0 it is scaled
    in place by (-1)^m / pochhammer(l + m + 1, -2m).
    """
    if m >= 0:
        return y
    y *= ((-1) ** m / pochhammer(l + m + 1, -2 * m))
    return y

# 定义函数 lpmv,使用 cache 装饰器缓存结果
# Memoized associated Legendre polynomial; results are stored in the module
# CACHE dict, keyed by (l, m) via lpmv_cache_key_fn.
@cache(cache=CACHE, key_fn=lpmv_cache_key_fn)
def lpmv(l, m, x):
    """Associated Legendre function including Condon-Shortley phase.

    Args:
        l: int degree
        m: int order
        x: float argument tensor
    Returns:
        tensor of x-shape, or None when |m| > l
    """
    m_abs = abs(m)

    if m_abs > l:
        return None

    if l == 0:
        return torch.ones_like(x)
    
    # base case of the recursion: |m| == l
    if m_abs == l:
        y = (-1)**m_abs * semifactorial(2*m_abs-1)
        y *= torch.pow(1-x*x, m_abs/2)
        return negative_lpmv(l, m, y)

    # called only for its side effect: ensures the (l-1, m) entry is in CACHE
    lpmv(l-1, m, x)

    # recurrence in l: (l - |m|) P_l^m = (2l-1) x P_{l-1}^m - (l+|m|-1) P_{l-2}^m
    y = ((2*l-1) / (l-m_abs)) * x * lpmv(l-1, m_abs, x)

    if l - m_abs > 1:
        # the (l-2, m_abs) entry is expected to have been populated by the
        # recursive calls above
        y -= ((l+m_abs-1)/(l-m_abs)) * CACHE[(l-2, m_abs)]
    
    if m < 0:
        y = negative_lpmv(l, m, y)
    return y

# 定义函数 get_spherical_harmonics_element,计算球谐函数元素
def get_spherical_harmonics_element(l, m, theta, phi):
    """Tesseral spherical harmonic with Condon-Shortley phase.

    The Tesseral spherical harmonics are also known as the real spherical
    harmonics.

    Args:
        l: int for degree
        m: int for order, where -l <= m < l
        theta: collatitude or polar angle
        phi: longitude or azimuth
    Returns:
        tensor of shape theta
    """
    m_abs = abs(m)
    assert m_abs <= l, "absolute value of order m must be <= degree l"

    # normalization constant sqrt((2l+1) / 4pi)
    N = sqrt((2*l + 1) / (4 * pi))
    leg = lpmv(l, m_abs, torch.cos(theta))

    if m == 0:
        return N * leg

    # real harmonics use cos(m*phi) for positive m, sin(|m|*phi) for negative m
    if m > 0:
        Y = torch.cos(m * phi)
    else:
        Y = torch.sin(m_abs * phi)

    Y *= leg
    # extra factor sqrt(2 (l-|m|)! / (l+|m|)!) via the rising factorial
    N *= sqrt(2. / pochhammer(l - m_abs + 1, 2 * m_abs))
    Y *= N
    return Y

# 定义函数 get_spherical_harmonics,计算球谐函数
def get_spherical_harmonics(l, theta, phi):
    """Tesseral harmonic with Condon-Shortley phase.

    Stacks every order m in [-l, l] of degree l along a new trailing axis.

    Args:
        l: int for degree
        theta: collatitude or polar angle
        phi: longitude or azimuth
    Returns:
        tensor of shape [*theta.shape, 2*l+1]
    """
    harmonics = [
        get_spherical_harmonics_element(l, m, theta, phi)
        for m in range(-l, l + 1)
    ]
    return torch.stack(harmonics, dim = -1)

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\utils.py

# 导入必要的库
import os
import sys
import time
import pickle
import gzip
import torch
import contextlib
from functools import wraps, lru_cache
from filelock import FileLock
from einops import rearrange

# 辅助函数

# 检查值是否存在
def exists(val):
    """Return True when `val` is not None."""
    return val is not None

# 返回默认值
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val

# 返回唯一值列表
def uniq(arr):
    """Deduplicate `arr` preserving first-occurrence order, returned as a list."""
    return list(dict.fromkeys(arr))

# 返回指定阶数
def to_order(degree):
    """Number of orders (2 * degree + 1) for a given rotation degree."""
    return degree * 2 + 1

# 对字典的值应用函数
def map_values(fn, d):
    """Apply `fn` to every value of dict `d`, keeping the keys."""
    return dict((key, fn(value)) for key, value in d.items())

# 安全地拼接张量
def safe_cat(arr, el, dim):
    """Concatenate `el` onto `arr` along `dim`; if `arr` is None, return `el`."""
    if arr is None:
        return el
    return torch.cat((arr, el), dim = dim)

# 将值转换为元组
def cast_tuple(val, depth):
    """Return `val` if it is already a tuple, else repeat it `depth` times."""
    if isinstance(val, tuple):
        return val
    return (val,) * depth

# 广播张量
# Broadcast-concatenate: expand all tensors to a shared shape on every axis
# except `dim`, then concatenate along `dim`.
def broadcat(tensors, dim=-1):
    """Concatenate `tensors` along `dim`, broadcasting the other axes.

    Each non-`dim` axis may have at most two distinct sizes across the
    tensors (one of which must be broadcastable, e.g. 1).
    """
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = list(shape_lens)[0]

    # normalize a negative concat axis
    dim = (dim + shape_len) if dim < 0 else dim
    # dims[i] collects the i-th axis length of every tensor
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))

    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
    # target size per expandable axis is the max across the tensors
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    # the concat axis keeps its original, per-tensor sizes
    expanded_dims.insert(dim, (dim, dims[dim]))
    # zip back into one full target shape per tensor
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)

# 批量索引选择
# Batched index select: gather slices of `values` along `dim` using `indices`.
def batched_index_select(values, indices, dim=1):
    """Gather elements of `values` along `dim` per batch, keeping value dims.

    Args:
        values: tensor whose axis `dim` is indexed; trailing axes are carried along
        indices: integer tensor of indices into axis `dim` of `values`
        dim: axis of `values` to index into (default 1)
    Returns:
        tensor shaped like `indices` followed by the trailing value dims
    """
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    # fix: the original line was missing the closing parenthesis (syntax error)
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)

# 掩码均值
# Mean over `dim` counting only positions where `mask` is True.
def masked_mean(tensor, mask, dim=-1):
    """Masked mean of `tensor` along `dim`.

    Rows with an all-False mask yield 0. Unlike the previous version, the
    caller's `tensor` is no longer mutated in place.

    Args:
        tensor: values to average
        mask: boolean tensor, broadcast against tensor's leading dims
        dim: reduction axis
    Returns:
        tensor of means with `dim` reduced
    """
    diff_len = len(tensor.shape) - len(mask.shape)
    mask = mask[(..., *((None,) * diff_len))]
    # fix: out-of-place masked_fill so the input tensor is left untouched
    tensor = tensor.masked_fill(~mask, 0.)

    total_el = mask.sum(dim=dim)
    # clamp avoids division by zero for fully-masked rows
    mean = tensor.sum(dim=dim) / total_el.clamp(min=1.)
    mean = mean.masked_fill(total_el == 0, 0.)
    return mean

# 生成均匀分布的随机张量
def rand_uniform(size, min_val, max_val):
    """Tensor of shape `size` with entries drawn uniformly from [min_val, max_val)."""
    out = torch.empty(size)
    out.uniform_(min_val, max_val)
    return out

# 快速分割张量
def fast_split(arr, splits, dim=0):
    """Yield `splits` contiguous chunks of `arr` along `dim`.

    The remainder is spread one element at a time over the leading chunks,
    so chunk lengths differ by at most one.
    """
    axis_len = arr.shape[dim]
    splits = min(axis_len, max(splits, 1))
    base, extra = divmod(axis_len, splits)
    offset = 0
    for chunk_index in range(splits):
        length = base + (1 if chunk_index < extra else 0)
        yield torch.narrow(arr, dim, offset, length)
        offset += length

# 傅立叶编码
# Fourier (sin/cos) feature encoding of scalar inputs over octave scales.
def fourier_encode(x, num_encodings=4, include_self=True, flatten=True):
    """Encode `x` with sin/cos features at `num_encodings` power-of-two scales.

    The raw value is appended when `include_self` is True. With `flatten`,
    the trailing feature axes are merged — NOTE(review): the rearrange
    pattern 'b m n ... -> b m n (...)' assumes x has exactly 3 leading
    dims (batch, i, j); confirm against callers.
    """
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x
    # scales = 1, 2, 4, ..., 2^(num_encodings - 1)
    scales = 2 ** torch.arange(num_encodings, device=device, dtype=dtype)
    x = x / scales
    x = torch.cat([x.sin(), x.cos()], dim=-1)
    x = torch.cat((x, orig_x), dim=-1) if include_self else x
    x = rearrange(x, 'b m n ... -> b m n (...)') if flatten else x
    return x

# 默认数据类型上下文管理器
# Context manager (also usable as a decorator) that temporarily switches
# torch's default dtype.
@contextlib.contextmanager
def torch_default_dtype(dtype):
    """Set torch's default dtype to `dtype` for the duration of the block.

    The previous default is restored even if the wrapped code raises.
    """
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # fix: restore on exception too, not only on normal exit
        torch.set_default_dtype(prev_dtype)

# 转换为 torch 张量的装饰器
def cast_torch_tensor(fn):
    """Decorator: coerce the single argument to a torch tensor before calling `fn`.

    Non-tensor inputs are converted using torch's current default dtype;
    tensors pass through unchanged.
    """
    @wraps(fn)
    def inner(t):
        if torch.is_tensor(t):
            return fn(t)
        tensor = torch.tensor(t, dtype=torch.get_default_dtype())
        return fn(tensor)
    return inner
# benchmark 工具函数,用于计算函数执行时间
def benchmark(fn):
    # 内部函数,记录函数执行时间并返回结果
    def inner(*args, **kwargs):
        # 记录开始时间
        start = time.time()
        # 执行函数
        res = fn(*args, **kwargs)
        # 计算时间差
        diff = time.time() - start
        # 返回时间差和结果
        return diff, res
    return inner

# 缓存函数装饰器
def cache(cache, key_fn):
    """Decorator factory memoizing results into the dict `cache`.

    The key is computed by key_fn(*args, **kwargs); on a hit the stored
    value is returned without calling the wrapped function.
    """
    def cache_inner(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            key_name = key_fn(*args, **kwargs)
            try:
                return cache[key_name]
            except KeyError:
                pass
            res = fn(*args, **kwargs)
            cache[key_name] = res
            return res

        return inner
    return cache_inner

# 在目录中进行缓存
# Persistent on-disk cache keyed by call arguments, with an in-RAM LRU layer.
def cache_dir(dirname, maxsize=128):
    '''
    Cache a function with a directory

    :param dirname: the directory path
    :param maxsize: maximum size of the RAM cache (there is no limit for the directory cache)
    '''
    def decorator(func):

        @lru_cache(maxsize=maxsize)
        @wraps(func)
        # disk-backed wrapper; lru_cache above it short-circuits repeat calls
        def wrapper(*args, **kwargs):
            # no directory configured: fall through to a plain call
            if not exists(dirname):
                return func(*args, **kwargs)

            os.makedirs(dirname, exist_ok=True)

            # index.pkl maps call keys to pickle file names; the FileLock
            # serializes index access across processes
            indexfile = os.path.join(dirname, "index.pkl")
            lock = FileLock(os.path.join(dirname, "mutex"))

            with lock:
                index = {}
                if os.path.exists(indexfile):
                    with open(indexfile, "rb") as file:
                        index = pickle.load(file)

                # NOTE(review): frozenset(kwargs) hashes only the kwarg NAMES,
                # not their values — calls differing only in kwarg values
                # collide on the same cache entry; verify this is intended
                key = (args, frozenset(kwargs), func.__defaults__)

                # reuse the existing file for this key, or register a new one
                if key in index:
                    filename = index[key]
                else:
                    index[key] = filename = f"{len(index)}.pkl.gz"
                    with open(indexfile, "wb") as file:
                        pickle.dump(index, file)

            filepath = os.path.join(dirname, filename)

            # cache hit on disk: load and return the stored result
            if os.path.exists(filepath):
                with lock:
                    with gzip.open(filepath, "rb") as file:
                        result = pickle.load(file)
                return result

            print(f"compute (unknown)... ", end="", flush=True)
            result = func(*args, **kwargs)
            print(f"save (unknown)... ", end="", flush=True)

            with lock:
                with gzip.open(filepath, "wb") as file:
                    pickle.dump(result, file)

            print("done")

            return result
        return wrapper
    return decorator

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\__init__.py

# 从 se3_transformer_pytorch 库中导入 SE3Transformer 类
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

.\lucidrains\se3-transformer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
# Package metadata for the PyPI distribution
setup(
  name = 'se3-transformer-pytorch',  # package name
  packages = find_packages(),  # discover all packages
  include_package_data = True,  # include data files
  version = '0.9.0',  # version number
  license='MIT',  # license
  description = 'SE3 Transformer - Pytorch',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/se3-transformer-pytorch',  # project homepage
  keywords = [  # search keywords
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'equivariance',
    'SE3'
  ],
  install_requires=[  # runtime dependencies
    'einops>=0.3',
    'filelock',
    'numpy',
    'torch>=1.6'
  ],
  setup_requires=[  # setup-time dependencies
    'pytest-runner',
  ],
  tests_require=[  # test-only dependencies
    'pytest',
    'lie_learn',
    'numpy',
  ],
  classifiers=[  # PyPI classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\se3-transformer-pytorch\tests\test_basis.py

# 导入 torch 库
import torch
# 从 se3_transformer_pytorch.basis 模块中导入 get_basis, get_R_tensor, basis_transformation_Q_J 函数
from se3_transformer_pytorch.basis import get_basis, get_R_tensor, basis_transformation_Q_J
# 从 se3_transformer_pytorch.irr_repr 模块中导入 irr_repr 函数

# 定义测试函数 test_basis
# get_basis should return one equivariant kernel per (degree_in, degree_out) pair.
def test_basis():
    """Check the basis dict has (max_degree + 1)^2 entries."""
    max_degree = 3
    x = torch.randn(2, 1024, 3)
    basis = get_basis(x, max_degree)
    # one basis kernel for every pair of degrees in [0, max_degree]
    assert len(basis.keys()) == (max_degree + 1) ** 2, 'correct number of basis kernels'

# 定义测试函数 test_basis_transformation_Q_J
# Q_J must intertwine the tensor-product representation with the irreducible one.
def test_basis_transformation_Q_J():
    """Check R(a,b,c) @ Q_J == Q_J @ D_J(a,b,c) for random Euler angles.

    NOTE(review): relies on `irr_repr` being imported from
    se3_transformer_pytorch.irr_repr at module level — verify the import exists.
    """
    rand_angles = torch.rand(4, 3)
    J, order_out, order_in = 1, 1, 1
    Q_J = basis_transformation_Q_J(J, order_in, order_out).float()
    assert all(torch.allclose(get_R_tensor(order_out, order_in, a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) for a, b, c in rand_angles)

.\lucidrains\se3-transformer-pytorch\tests\test_equivariance.py

import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
from se3_transformer_pytorch.irr_repr import rot
from se3_transformer_pytorch.utils import torch_default_dtype, fourier_encode

# 测试普通 SE3Transformer 模型
def test_transformer():
    """Smoke test: a plain SE3Transformer forward pass returns type-0 features."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        num_degrees = 2,
        num_neighbors = 4,
        valid_radius = 10
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    # return_type = 0 selects the scalar (degree-0) output features
    out = model(feats, coors, mask, return_type = 0)
    assert out.shape == (1, 32, 64), 'output must be of the right shape'

# 测试有因果性的 SE3Transformer 模型
def test_causal_se3_transformer():
    """Smoke test: SE3Transformer with causal attention enabled."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        num_degrees = 2,
        num_neighbors = 4,
        valid_radius = 10,
        causal = True
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    out = model(feats, coors, mask, return_type = 0)
    assert out.shape == (1, 32, 64), 'output must be of the right shape'

# 测试带全局节点的 SE3Transformer 模型
def test_se3_transformer_with_global_nodes():
    """Smoke test: SE3Transformer attending to extra global feature nodes."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        num_degrees = 2,
        num_neighbors = 4,
        valid_radius = 10,
        global_feats_dim = 16
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    # two global nodes of dimension 16
    global_feats = torch.randn(1, 2, 16)

    out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)
    assert out.shape == (1, 32, 64), 'output must be of the right shape'

# 测试带单头键值对的 SE3Transformer 模型
def test_one_headed_key_values_se3_transformer_with_global_nodes():
    """Smoke test: global nodes plus single-headed key/value projections."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        num_degrees = 2,
        num_neighbors = 4,
        valid_radius = 10,
        global_feats_dim = 16,
        one_headed_key_values = True
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    global_feats = torch.randn(1, 2, 16)

    out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)
    assert out.shape == (1, 32, 64), 'output must be of the right shape'

# 测试带边的 SE3Transformer 模型
def test_transformer_with_edges():
    """Smoke test: SE3Transformer with discrete edge tokens."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        num_degrees = 2,
        num_neighbors = 4,
        edge_dim = 4,
        num_edge_tokens = 4
    )

    feats = torch.randn(1, 32, 64)
    # integer edge token ids in [0, 4)
    edges = torch.randint(0, 4, (1, 32))
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    out = model(feats, coors, mask, edges = edges, return_type = 0)
    assert out.shape == (1, 32, 64), 'output must be of the right shape'

# 测试带连续边的 SE3Transformer 模型
def test_transformer_with_continuous_edges():
    """Smoke test: continuous pairwise edge values passed through fourier_encode."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_degrees = 2,
        output_degrees = 2,
        edge_dim = 34
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    pairwise_continuous_values = torch.randint(0, 4, (1, 32, 32, 2))

    # 2 values * (2 * 8 encodings + 1 raw) = 34 edge features, matching edge_dim
    edges = fourier_encode(
        pairwise_continuous_values,
        num_encodings = 8,
        include_self = True
    )

    out = model(feats, coors, mask, edges = edges, return_type = 1)
    assert True

# 测试不同输入维度的 SE3Transformer 模型
def test_different_input_dimensions_for_types():
    """Smoke test: per-degree input dimensions and coordinate refinement."""
    model = SE3Transformer(
        dim_in = (4, 2),
        dim = 4,
        depth = 1,
        input_degrees = 2,
        num_degrees = 2,
        output_degrees = 2,
        reduce_dim_out = True
    )

    # degree-0 (scalar) and degree-1 (vector) input features
    atom_feats  = torch.randn(2, 32, 4, 1)
    coors_feats = torch.randn(2, 32, 2, 3)

    features = {'0': atom_feats, '1': coors_feats}
    coors = torch.randn(2, 32, 3)
    mask  = torch.ones(2, 32).bool()

    # type-1 output is used as a coordinate update
    refined_coors = coors + model(features, coors, mask, return_type = 1)
    assert True

# 测试等变性
# Equivariance: rotating the input coordinates must rotate the type-1 output.
def test_equivariance():
    """f(x R) should equal f(x) R for type-1 outputs (rotation equivariance)."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        output_degrees = 2,
        fourier_encode_dist = True
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    R   = rot(15, 0, 45)
    # forward on rotated coordinates
    out1 = model(feats, coors @ R, mask, return_type = 1)
    # forward on original coordinates, rotating the output instead
    out2 = model(feats, coors, mask, return_type = 1) @ R

    # fix: use the absolute deviation — a plain max() ignores negative diffs
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'
# 测试具有 EGNN 骨干的等变性
# Equivariance check with the EGNN backbone variant.
def test_equivariance_with_egnn_backbone():
    """f(x R) should equal f(x) R when use_egnn is enabled."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        output_degrees = 2,
        fourier_encode_dist = True,
        use_egnn = True
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    R   = rot(15, 0, 45)
    # forward on rotated coordinates
    out1 = model(feats, coors @ R, mask, return_type = 1)
    # forward on original coordinates, rotating the output instead
    out2 = model(feats, coors, mask, return_type = 1) @ R

    # fix: use the absolute deviation — a plain max() ignores negative diffs
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'

# 测试旋转
# Equivariance check with rotary positional / relative-distance embeddings.
def test_rotary():
    """f(x R) should equal f(x) R with rotary embeddings enabled."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        output_degrees = 2,
        fourier_encode_dist = True,
        rotary_position = True,
        rotary_rel_dist = True
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    R   = rot(15, 0, 45)
    # forward on rotated coordinates
    out1 = model(feats, coors @ R, mask, return_type = 1)
    # forward on original coordinates, rotating the output instead
    out2 = model(feats, coors, mask, return_type = 1) @ R

    # fix: use the absolute deviation — a plain max() ignores negative diffs
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'

# 测试等变性线性投影键
# Equivariance check when keys are produced by a linear projection.
def test_equivariance_linear_proj_keys():
    """f(x R) should equal f(x) R with linear_proj_keys enabled."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        output_degrees = 2,
        fourier_encode_dist = True,
        linear_proj_keys = True
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    R   = rot(15, 0, 45)
    # forward on rotated coordinates
    out1 = model(feats, coors @ R, mask, return_type = 1)
    # forward on original coordinates, rotating the output instead
    out2 = model(feats, coors, mask, return_type = 1) @ R

    # fix: use the absolute deviation — a plain max() ignores negative diffs
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'

# 测试仅稀疏邻居的等变性
# Equivariance check attending only to adjacency-matrix neighbors, in float64
# for tighter numerics.
@torch_default_dtype(torch.float64)
def test_equivariance_only_sparse_neighbors():
    """f(x R) should equal f(x) R with sparse-neighbor attention only."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_degrees = 2,
        output_degrees = 2,
        num_neighbors = 0,
        attend_sparse_neighbors = True,
        num_adj_degrees = 2,
        adj_dim = 4
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    # banded adjacency: each node connects to itself and immediate neighbors
    seq = torch.arange(32)
    adj_mat = (seq[:, None] >= (seq[None, :] - 1)) & (seq[:, None] <= (seq[None, :] + 1))

    R   = rot(15, 0, 45)
    # forward on rotated coordinates
    out1 = model(feats, coors @ R, mask, adj_mat = adj_mat, return_type = 1)
    # forward on original coordinates, rotating the output instead
    out2 = model(feats, coors, mask, adj_mat = adj_mat, return_type = 1) @ R

    # fix: use the absolute deviation — a plain max() ignores negative diffs
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'

# 测试具有可逆网络的等变性
# Equivariance check with the reversible-residual network variant.
def test_equivariance_with_reversible_network():
    """f(x R) should equal f(x) R with reversible layers enabled."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        output_degrees = 2,
        reversible = True
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    R   = rot(15, 0, 45)
    # forward on rotated coordinates
    out1 = model(feats, coors @ R, mask, return_type = 1)
    # forward on original coordinates, rotating the output instead
    out2 = model(feats, coors, mask, return_type = 1) @ R

    # fix: use the absolute deviation — a plain max() ignores negative diffs
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'

# 测试具有类型一输入的等变性
# Equivariance check when type-1 (vector) features are part of the input.
def test_equivariance_with_type_one_input():
    """Both the vector inputs and coordinates rotate; output must rotate too."""
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        input_degrees = 2,
        output_degrees = 2
    )

    atom_features = torch.randn(1, 32, 64, 1)
    pred_coors = torch.randn(1, 32, 64, 3)

    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    R   = rot(15, 0, 45)
    # rotate both the type-1 inputs and the coordinates
    out1 = model({'0': atom_features, '1': pred_coors @ R}, coors @ R, mask, return_type = 1)
    # forward on the original inputs, rotating the output instead
    out2 = model({'0': atom_features, '1': pred_coors}, coors, mask, return_type = 1) @ R

    # fix: use the absolute deviation — a plain max() ignores negative diffs
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'

.\lucidrains\se3-transformer-pytorch\tests\test_irrep_repr.py

# 导入 torch 库
import torch
# 从 se3_transformer_pytorch.spherical_harmonics 模块中导入 clear_spherical_harmonics_cache 函数
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache
# 从 se3_transformer_pytorch.irr_repr 模块中导入 spherical_harmonics, irr_repr, compose 函数
from se3_transformer_pytorch.irr_repr import spherical_harmonics, irr_repr, compose
# 从 se3_transformer_pytorch.utils 模块中导入 torch_default_dtype 函数
from se3_transformer_pytorch.utils import torch_default_dtype

# 使用 torch.float64 作为默认数据类型
# Run in float64 for the tight tolerance below.
@torch_default_dtype(torch.float64)
def test_irr_repr():
    """
    This test tests that
    - irr_repr
    - compose
    - spherical_harmonics
    are compatible

    Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)
    with x = Z(a) Y(b) eta
    """
    # check degrees 0 through 6
    for order in range(7):
        # random point on the sphere, as angles (a, b)
        a, b = torch.rand(2)
        # random rotation as Euler angles (alpha, beta, gamma)
        alpha, beta, gamma = torch.rand(3)

        # rotate the point: angles of Z(alpha) Y(beta) Z(gamma) applied to (a, b)
        ra, rb, _ = compose(alpha, beta, gamma, a, b, 0)
        # left-hand side: harmonics evaluated at the rotated point
        Yrx = spherical_harmonics(order, ra, rb)
        # the harmonics cache is keyed by (l, m) only, so clear between inputs
        clear_spherical_harmonics_cache()

        Y = spherical_harmonics(order, a, b)
        clear_spherical_harmonics_cache()

        # right-hand side: Wigner-D matrix applied to the unrotated harmonics
        DrY = irr_repr(order, alpha, beta, gamma) @ Y

        # compare absolute deviation against the scale of Y
        d, r = (Yrx - DrY).abs().max(), Y.abs().max()
        print(d.item(), r.item())
        assert d < 1e-10 * r, d / r

.\lucidrains\se3-transformer-pytorch\tests\test_spherical_harmonics.py

# 导入必要的库
import time
import torch
import numpy as np

# 从 lie_learn 库中导入 spherical_harmonics 函数
from lie_learn.representations.SO3.spherical_harmonics import sh

# 从 se3_transformer_pytorch 库中导入 get_spherical_harmonics_element 和 benchmark 函数
from se3_transformer_pytorch.spherical_harmonics import get_spherical_harmonics_element
from se3_transformer_pytorch.utils import benchmark

# 定义测试 spherical_harmonics 函数
# Compare our spherical harmonics against lie_learn's for accuracy and speed.
def test_spherical_harmonics():
    """All (l, m) up to l=7 must match lie_learn within 1e-4 relative error,
    and the torch implementation must be faster overall."""
    dtype = torch.float64

    theta = 0.1 * torch.randn(32, 1024, 10, dtype=dtype)
    phi = 0.1 * torch.randn(32, 1024, 10, dtype=dtype)

    # accumulated runtimes: s0 ours, s1 lie_learn's
    s0 = s1 = 0
    max_error = -1.

    for l in range(8):
        for m in range(-l, l + 1):
            # fix: removed an unused `start = time.time()` — benchmark() already times the call
            diff, y = benchmark(get_spherical_harmonics_element)(l, m, theta, phi)
            y = y.type(torch.float32)
            s0 += diff

            # reference implementation
            diff, z = benchmark(sh)(l, m, theta, phi)
            s1 += diff

            # mean relative error against the reference
            error = np.mean(np.abs((y.cpu().numpy() - z) / z))
            max_error = max(max_error, error)
            print(f"l: {l}, m: {m} ", error)

    time_diff_ratio = s0 / s1

    # fix: message now matches the 1e-4 threshold actually asserted
    assert max_error < 1e-4, 'maximum error must be less than 1e-4'
    assert time_diff_ratio < 1., 'spherical harmonics must be faster than the one offered by lie_learn'

    print(f"Max error: {max_error}")
    print(f"Time diff: {time_diff_ratio}")

Segformer - Pytorch

Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch.

Install

$ pip install segformer-pytorch

Usage

For example, MiT-B0

import torch
from segformer_pytorch import Segformer

model = Segformer(
    dims = (32, 64, 160, 256),      # dimensions of each stage
    heads = (1, 2, 5, 8),           # heads of each stage
    ff_expansion = (8, 8, 4, 4),    # feedforward expansion factor of each stage
    reduction_ratio = (8, 4, 2, 1), # reduction ratio of each stage for efficient attention
    num_layers = 2,                 # num layers of each stage
    decoder_dim = 256,              # decoder dimension
    num_classes = 4                 # number of segmentation classes
)

x = torch.randn(1, 3, 256, 256)
pred = model(x) # (1, 4, 64, 64)  # output is (H/4, W/4) map of the number of segmentation classes

Make sure the keywords are at most a tuple of 4, as this repository is hard-coded to give the MiT 4 stages as done in the paper.

Citations

@misc{xie2021segformer,
    title   = {SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers}, 
    author  = {Enze Xie and Wenhai Wang and Zhiding Yu and Anima Anandkumar and Jose M. Alvarez and Ping Luo},
    year    = {2021},
    eprint  = {2105.15203},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\segformer-pytorch\segformer_pytorch\segformer_pytorch.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 从 functools 模块中导入 partial 函数
from functools import partial
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn、einsum 函数
from torch import nn, einsum
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 从 einops 模块中导入 rearrange、reduce 函数
from einops import rearrange, reduce
# 从 einops.layers.torch 模块中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 定义一个函数,用于判断变量是否存在
def exists(val):
    """Return True when `val` is not None."""
    return val is not None

# 定义一个函数,用于将变量转换为元组
def cast_tuple(val, depth):
    """Return `val` if it is already a tuple, else repeat it `depth` times."""
    if isinstance(val, tuple):
        return val
    return (val,) * depth

# 定义一个类 DsConv2d,继承自 nn.Module 类
class DsConv2d(nn.Module):
    """Depthwise-separable convolution: a depthwise k x k conv (groups = dim_in)
    followed by a pointwise 1 x 1 conv mapping to dim_out."""
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
        super().__init__()
        depthwise = nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias)
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        return self.net(x)

# 定义一个类 LayerNorm,继承自 nn.Module 类
class LayerNorm(nn.Module):
    """Channel-wise layer norm over dim 1 of a (b, c, h, w) tensor,
    with a learned per-channel scale `g` and bias `b`."""
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mean = x.mean(dim = 1, keepdim = True)
        var = x.var(dim = 1, unbiased = False, keepdim = True)
        # eps is added to the std (not the variance), matching the original
        normed = (x - mean) / (var.sqrt() + self.eps)
        return normed * self.g + self.b

# 定义一个类 PreNorm,继承自 nn.Module 类
class PreNorm(nn.Module):
    """Wrap `fn` so that its input is passed through LayerNorm first
    (pre-normalization style)."""
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed)

# 定义一个类 EfficientSelfAttention,继承自 nn.Module 类
# Self-attention with spatially reduced keys/values (SegFormer's efficient attention).
class EfficientSelfAttention(nn.Module):
    """Multi-head self-attention over a feature map.

    Keys and values are produced by a strided conv with stride
    `reduction_ratio`, shrinking the k/v sequence length by
    reduction_ratio^2 relative to the queries.
    """
    def __init__(
        self,
        *,
        dim,
        heads,
        reduction_ratio
    ):
        super().__init__()
        # scale queries by 1/sqrt(head_dim)
        self.scale = (dim // heads) ** -0.5
        self.heads = heads

        self.to_q = nn.Conv2d(dim, dim, 1, bias = False)
        # single conv emits both k and v (2 * dim channels), downsampled spatially
        self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride = reduction_ratio, bias = False)
        self.to_out = nn.Conv2d(dim, dim, 1, bias = False)

    def forward(self, x):
        h, w = x.shape[-2:]
        heads = self.heads

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
        # flatten spatial dims and fold heads into the batch axis
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = heads), (q, k, v))

        # scaled dot-product attention over the reduced k/v sequence
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        # restore the (b, c, h, w) feature-map layout
        out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h = heads, x = h, y = w)
        return self.to_out(out)

# 定义一个类 MixFeedForward,继承自 nn.Module 类
# SegFormer's Mix-FFN: pointwise expand -> depthwise 3x3 conv -> GELU -> pointwise project.
class MixFeedForward(nn.Module):
    """Feed-forward block that mixes spatial information via a 3x3
    depthwise-separable conv between the two pointwise projections."""
    def __init__(
        self,
        *,
        dim,
        expansion_factor
    ):
        super().__init__()
        hidden_dim = dim * expansion_factor
        self.net = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1),
            DsConv2d(hidden_dim, hidden_dim, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1)
        )

    def forward(self, x):
        return self.net(x)

# MiT (Mix Transformer) encoder backbone from SegFormer
class MiT(nn.Module):
    # builds four stages, each one being: overlapping patch embedding
    # followed by `num_layers` blocks of (efficient self-attention, mix-FFN)
    def __init__(
        self,
        *,
        channels,
        dims,
        heads,
        ff_expansion,
        reduction_ratio,
        num_layers
    ):
        super().__init__()
        # per-stage (kernel, stride, padding) for overlapping patch extraction
        stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))

        # prepend the input channel count so dims can be paired stage-to-stage
        dims = (channels, *dims)
        dim_pairs = list(zip(dims[:-1], dims[1:]))

        self.stages = nn.ModuleList([])

        # NOTE: the loop variables deliberately shadow the tuple-valued
        # constructor arguments with their per-stage scalar values
        for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
            # unfold extracts overlapping patches; a 1x1 conv then embeds them
            get_overlap_patches = nn.Unfold(kernel, stride = stride, padding = padding)
            overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)

            # transformer blocks for this stage
            layers = nn.ModuleList([])

            for _ in range(num_layers):
                layers.append(nn.ModuleList([
                    PreNorm(dim_out, EfficientSelfAttention(dim = dim_out, heads = heads, reduction_ratio = reduction_ratio)),
                    PreNorm(dim_out, MixFeedForward(dim = dim_out, expansion_factor = ff_expansion)),
                ]))

            self.stages.append(nn.ModuleList([
                get_overlap_patches,
                overlap_patch_embed,
                layers
            ]))

    # returns the last stage's output, or the list of all four stage outputs
    # when return_layer_outputs is True
    def forward(
        self,
        x,
        return_layer_outputs = False
    ):
        h, w = x.shape[-2:]

        layer_outputs = []
        for (get_overlap_patches, overlap_embed, layers) in self.stages:
            # (b, c, h, w) -> (b, c * k * k, num_patches)
            x = get_overlap_patches(x)

            # recover the 2d patch grid; `ratio` is the total downsampling
            # factor of this stage relative to the ORIGINAL input (h, w are
            # intentionally never updated inside the loop)
            num_patches = x.shape[-1]
            ratio = int(sqrt((h * w) / num_patches))
            x = rearrange(x, 'b c (h w) -> b c h w', h = h // ratio)

            # embed patches, then run the residual transformer blocks
            x = overlap_embed(x)
            for (attn, ff) in layers:
                x = attn(x) + x
                x = ff(x) + x

            layer_outputs.append(x)

        ret = x if not return_layer_outputs else layer_outputs
        return ret
class Segformer(nn.Module):
    """SegFormer: a MiT encoder followed by a lightweight all-MLP decoder
    that fuses the four stage outputs at a common scale.

    Every per-stage hyperparameter accepts either a single value or a
    4-tuple (one entry per stage).
    """

    def __init__(
        self,
        *,
        dims = (32, 64, 160, 256),
        heads = (1, 2, 5, 8),
        ff_expansion = (8, 8, 4, 4),
        reduction_ratio = (8, 4, 2, 1),
        num_layers = 2,
        channels = 3,
        decoder_dim = 256,
        num_classes = 4
    ):
        super().__init__()

        # normalize every per-stage hyperparameter to a 4-tuple
        dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth = 4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
        assert all(len(t) == 4 for t in (dims, heads, ff_expansion, reduction_ratio, num_layers)), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'

        # the four-stage Mix Transformer encoder
        self.mit = MiT(
            channels = channels,
            dims = dims,
            heads = heads,
            ff_expansion = ff_expansion,
            reduction_ratio = reduction_ratio,
            num_layers = num_layers
        )

        # per-stage projection to a shared decoder width, upsampled to a common scale
        self.to_fused = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(stage_dim, decoder_dim, 1),
                nn.Upsample(scale_factor = 2 ** stage_index)
            )
            for stage_index, stage_dim in enumerate(dims)
        ])

        # fuse the concatenated stage features and predict per-pixel classes
        self.to_segmentation = nn.Sequential(
            nn.Conv2d(4 * decoder_dim, decoder_dim, 1),
            nn.Conv2d(decoder_dim, num_classes, 1),
        )

    def forward(self, x):
        # collect the feature pyramid from all four encoder stages
        layer_outputs = self.mit(x, return_layer_outputs = True)

        fused = [project(stage_out) for stage_out, project in zip(layer_outputs, self.to_fused)]
        return self.to_segmentation(torch.cat(fused, dim = 1))

.\lucidrains\segformer-pytorch\segformer_pytorch\__init__.py

# 从segformer_pytorch模块中导入Segformer类
from segformer_pytorch.segformer_pytorch import Segformer

.\lucidrains\segformer-pytorch\setup.py

# import setup tooling and package discovery helper
from setuptools import setup, find_packages

# package metadata for distribution on PyPI
setup(
  name = 'segformer-pytorch', # distribution name
  packages = find_packages(), # include every package found in the source tree
  version = '0.0.6', # release version
  license='MIT', # license identifier
  description = 'Segformer - Pytorch', # short summary
  author = 'Phil Wang', # author name
  author_email = 'lucidrains@gmail.com', # author contact
  url = 'https://github.com/lucidrains/segformer-pytorch', # project homepage
  keywords = [ # search keywords for PyPI
    'artificial intelligence',
    'deep learning',
    'transformers',
    'image segmentation'
  ],
  install_requires=[ # runtime dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[ # PyPI trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Self-Rewarding Language Model

Implementation of the training framework proposed in Self-Rewarding Language Model, from MetaAI

They really took the title of the DPO paper to heart.

This library also contains an implementation of SPIN, which Teknium of Nous Research has expressed optimism for.

Appreciation

Install

$ pip install self-rewarding-lm-pytorch

Usage

import torch
from torch import Tensor

from self_rewarding_lm_pytorch import (
    SelfRewardingTrainer,
    create_mock_dataset
)

from x_transformers import TransformerWrapper, Decoder

transformer = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 1,
        heads = 8
    )
)

sft_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), torch.tensor(1)))
prompt_dataset = create_mock_dataset(100, lambda: 'mock prompt')

def decode_tokens(tokens: Tensor) -> str:
    decode_token = lambda token: str(chr(max(32, token)))
    return ''.join(list(map(decode_token, tokens)))

def encode_str(seq_str: str) -> Tensor:
    return Tensor(list(map(ord, seq_str)))

trainer = SelfRewardingTrainer(
    transformer,
    finetune_configs = dict(
        train_sft_dataset = sft_dataset,
        self_reward_prompt_dataset = prompt_dataset,
        dpo_num_train_steps = 1000
    ),
    tokenizer_decode = decode_tokens,
    tokenizer_encode = encode_str,
    accelerate_kwargs = dict(
        cpu = True
    )
)

trainer(overwrite_checkpoints = True)

# checkpoints after each finetuning stage will be saved to ./checkpoints

SPIN can be trained as follows - it can also be added to the fine-tuning pipeline as shown in the final example in the readme.

import torch

from self_rewarding_lm_pytorch import (
    SPINTrainer,
    create_mock_dataset
)

from x_transformers import TransformerWrapper, Decoder

transformer = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

sft_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), torch.tensor(1)))

spin_trainer = SPINTrainer(
    transformer,
    max_seq_len = 16,
    train_sft_dataset = sft_dataset,
    checkpoint_every = 100,
    spin_kwargs = dict(
        λ = 0.1,
    ),
)

spin_trainer()

Say you want to experiment with your own reward prompt (other than LLM-as-Judge). First you need to import the RewardConfig, then pass it into the trainer as self_reward_prompt_config


# first import

from self_rewarding_lm_pytorch import RewardConfig

# then say you want to try asking the transformer nicely

# reward_regex_template is the string that will be looked for in the LLM response, for parsing out the reward where {{ reward }} is defined as a number

trainer = SelfRewardingTrainer(
    transformer,
    ...,
    self_reward_prompt_config = RewardConfig(
        prompt_template = """
        Pretty please rate the following user prompt and response
        User: {{ prompt }}
        Response: {{ response }}

        Format your score as follows:
        Rating: <rating as integer from 0 - 10>
        """,
        reward_regex_template = """
        Rating: {{ reward }}
        """
    )
)

Finally, if you would like to experiment with arbitrary orders of fine-tuning, you will also have that flexibility, by passing in FinetuneConfig instances into finetune_configs as a list

ex. say you want to carry out research on interleaving SPIN, External Rewarding, and Self-Rewarding

This idea originated from Teknium from a private discord channel.


# import the configs

from self_rewarding_lm_pytorch import (
    SFTConfig,
    SelfRewardDPOConfig,
    ExternalRewardDPOConfig,
    SelfPlayConfig,
)

trainer = SelfRewardingTrainer(
    model,
    finetune_configs = [
        SFTConfig(...),
        SelfPlayConfig(...),
        ExternalRewardDPOConfig(...),
        SelfRewardDPOConfig(...),
        SelfPlayConfig(...),
        SelfRewardDPOConfig(...)
    ],
    ...
)

trainer()

# checkpoints after each finetuning stage will be saved to ./checkpoints

Todo

Citation

@misc{yuan2024selfrewarding,
    title   = {Self-Rewarding Language Models}, 
    author  = {Weizhe Yuan and Richard Yuanzhe Pang and Kyunghyun Cho and Sainbayar Sukhbaatar and Jing Xu and Jason Weston},
    year    = {2024},
    eprint  = {2401.10020},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{Chen2024SelfPlayFC,
    title   = {Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models},
    author  = {Zixiang Chen and Yihe Deng and Huizhuo Yuan and Kaixuan Ji and Quanquan Gu},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2401.01335},
    url     = {https://api.semanticscholar.org/CorpusID:266725672}
}
@article{Rafailov2023DirectPO,
    title   = {Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
    author  = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.18290},
    url     = {https://api.semanticscholar.org/CorpusID:258959321}
}
@inproceedings{Guo2024DirectLM,
    title   = {Direct Language Model Alignment from Online AI Feedback},
    author  = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Rame and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:267522951}
}

.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\dpo.py

# 导入必要的库
import os
from pathlib import Path
from copy import deepcopy
from functools import lru_cache
from collections import namedtuple
from dataclasses import dataclass

# 导入类型提示相关库
from beartype import beartype
from beartype.typing import Optional, Callable, Union, List
from torchtyping import TensorType

# 导入 PyTorch 相关库
import torch
from torch.nn import Module, Dropout
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist

# 导入加速库
from accelerate import Accelerator

# 导入 einops 和 einx 库
from einops import rearrange
from einx import get_at

# 导入 numpy 相关库
from numpy.lib.format import open_memmap

# 导入自定义工具函数
from pytorch_custom_utils import (
    get_adam_optimizer,
    OptimizerWithWarmupSchedule
)

# 导入加速相关工具函数
from pytorch_custom_utils.accelerate_utils import (
    model_forward_contexts
)

# 导入自定义工具函数
from pytorch_custom_utils.utils import (
    masked_mean,
    maybe_and_mask
)

# 导入进度条库
from tqdm import tqdm

# 导入 EMA 库
from ema_pytorch import EMA

# helper functions

# presence check: anything other than None counts as existing
def exists(v):
    return not (v is None)

# fallback helper: return `v` when it is not None, otherwise the default `d`
def default(v, d):
    if exists(v):
        return v
    return d

# endlessly re-iterate a dataloader (restarts from the top on exhaustion)
def cycle(dl):
    while True:
        yield from dl

# True only when torch.distributed is initialized with a world size above 1.
# Cached, since the answer cannot change within a run.
@lru_cache(maxsize=None)
def is_distributed():
    if not dist.is_initialized():
        return False
    return dist.get_world_size() > 1

# per-token log-probabilities of `seq` under `model`: log-softmax over the
# vocab, gathered at the observed token ids; returns shape (batch, seq_len)
def log_prob_from_model_and_seq(model, seq):
    log_probs = model(seq).log_softmax(dim = -1)
    return get_at('b n [c], b n -> b n', log_probs, seq)

# boolean mask shaped like `seq` that is True over the prompt portion
# (positions strictly below the given length) of each row
def prompt_mask_from_len(lengths, seq):
    positions = torch.arange(seq.shape[-1], device = seq.device)
    return positions < rearrange(lengths, '... -> ... 1')

# in-place: set the dropout probability of every Dropout submodule of `model`
def set_dropout_(model: Module, prob: float):
    dropout_layers = (m for m in model.modules() if isinstance(m, Dropout))
    for layer in dropout_layers:
        layer.p = prob

# Adam optimizer whose learning rate decays linearly from
# `start_learning_rate` down to `end_learning_rate` over `num_decay_steps`
def adam_optimizer_with_linear_decay(
    model: Module,
    start_learning_rate: float,
    end_learning_rate: float,
    num_decay_steps: int,
    accelerator: Accelerator,
    weight_decay: float,
    adam_kwargs: dict = dict(),
) -> OptimizerWithWarmupSchedule:
    """Build Adam wrapped with an optional linear LR-decay schedule.

    Fixes two defects in the original:
    - `adam_kwargs` was accepted but never forwarded to the optimizer factory
    - the locally computed `scheduler` was discarded and `LinearLR` was
      passed unconditionally; the conditional was clearly intended, so the
      schedule is now only attached when the rate actually changes
    """

    adam = get_adam_optimizer(
        model.parameters(),
        lr = start_learning_rate,
        wd = weight_decay,
        **adam_kwargs
    )

    scheduler = None
    scheduler_kwargs = dict()

    if start_learning_rate != end_learning_rate:
        scheduler = LinearLR
        scheduler_kwargs = dict(
            start_factor = 1.,
            end_factor = end_learning_rate / start_learning_rate,
            total_iters = num_decay_steps
        )

    return OptimizerWithWarmupSchedule(
        optimizer = adam,
        accelerator = accelerator,
        scheduler = scheduler,
        scheduler_kwargs = scheduler_kwargs
    )

# early stopping

# result returned by EarlyStopper.forward
@dataclass
class EarlyStopperReturn:
    # True when training should halt now
    should_stop: bool
    # validation score produced by the evaluator this round
    score: float

# Early stopper: tracks validation scores across rounds, checkpoints the
# model every round, and signals a stop (rolling back to the previous
# checkpoint) once the score stops improving.
class EarlyStopper(Module):
    @beartype
    def __init__(
        self,
        model: Module,
        evaluator: Module,
        accelerator: Accelerator,
        calculate_should_stop: Callable[..., bool] = lambda scores: len(scores) > 1 and scores[-1] > scores[-2],
        early_stop_checkpoint_folder: str = './early-stop-checkpoint'
    ):
        super().__init__()
        self.model = model
        self.evaluator = evaluator
        self.accelerator = accelerator

        self.scores: List[Union[int, float]] = []
        self.calculate_should_stop = calculate_should_stop

        self.early_stop_checkpoint_folder = Path(early_stop_checkpoint_folder)
        self.early_stop_checkpoint_folder.mkdir(exist_ok=True, parents=True)

        # non-zero means "stop"; a tensor so it can be all-reduced across ranks
        self.register_buffer('break_signal', torch.tensor(0.))

    # remove all stale early-stop checkpoints
    def clear_early_checkpoint_folder(self):
        for file in self.early_stop_checkpoint_folder.glob('*.pt'):
            os.remove(file)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # block until all processes reach this point
    def wait(self):
        return self.accelerator.wait_for_everyone()

    # save the model state dict to <checkpoint folder>/<path> (main process only)
    def save(self, path: str, overwrite: bool = False):
        self.wait()

        if self.is_main:

            path = self.early_stop_checkpoint_folder / path

            # refuse to clobber an existing checkpoint unless told to
            assert not path.exists() or overwrite, f'file already exists'

            pkg = dict(
                model = self.model.state_dict()
            )

            torch.save(pkg, str(path))

    @torch.no_grad()
    def forward(self) -> EarlyStopperReturn:
        """Evaluate once and decide whether to stop.

        On stop, the model is restored from the previous round's checkpoint;
        otherwise the current state is checkpointed. On non-main ranks the
        returned score is None.
        """
        self.model.eval()

        score = None
        should_stop = False  # fix: previously unbound on non-main, non-distributed paths

        if self.is_main:

            # compute this round's validation score on the main process
            score = self.evaluator(self.model)

            if torch.is_tensor(score):
                assert score.numel() == 1
                score = score.flatten().item()

            assert isinstance(score, (int, float))

            self.scores.append(score)

            should_stop = self.calculate_should_stop(self.scores)

            if should_stop:
                self.break_signal.copy_(1.)

        # propagate the main-process stop decision to all ranks
        if is_distributed():
            dist.all_reduce(self.break_signal)
            should_stop = self.break_signal.item() == 1.

        # on stop, roll back to the checkpoint from before the score degraded
        if should_stop:
            prev_checkpoint_filename = f'model.ckpt.{len(self.scores) - 1}.pt'
            prev_checkpoint_path = self.early_stop_checkpoint_folder / prev_checkpoint_filename
            pkg = torch.load(str(prev_checkpoint_path))

            self.model.load_state_dict(pkg['model'])
        else:
            # otherwise checkpoint the current state for a possible future rollback
            checkpoint_filename = f'model.ckpt.{len(self.scores)}.pt'
            self.save(checkpoint_filename)

        # fix: arguments were previously passed positionally in the wrong
        # order (score first), silently filling `should_stop` with the score
        return EarlyStopperReturn(should_stop = should_stop, score = score)
# reads the dataset from two memmap numpy files

# the dataset contains the preferred and unpreferred sequences with shape - (<num samples>, <preference (2) - preferred followed by unpreferred>, <sequence length>)
# and the prompt lengths with shape (<num samples>,)

class DPODataset(Dataset):
    """Paired-preference dataset backed by two numpy memmap files.

    `preference_seq` has shape (num_samples, 2, seq_len) where index 0 of
    the middle axis is the preferred sequence and index 1 the unpreferred
    one; `prompt_len` has shape (num_samples,) with each pair's prompt length.
    """

    def __init__(
        self,
        data_folder: str = './',
        preference_seq_memmap_file: str = 'preference_seq.memmap.npy',
        prompt_len_memmap_file: str = 'prompt_len.memmap.npy',
    ):
        self.data_folder = Path(data_folder)
        assert self.data_folder.exists() and self.data_folder.is_dir()

        seq_path = self.data_folder / preference_seq_memmap_file
        len_path = self.data_folder / prompt_len_memmap_file

        assert seq_path.exists()
        assert len_path.exists()

        # opened read-only; rows are copied out lazily in __getitem__
        self.paired_sequences = open_memmap(str(seq_path), dtype = 'int', mode = 'r')
        self.prompt_len = open_memmap(str(len_path), dtype = 'int', mode = 'r')

        self.seq_len = self.paired_sequences.shape[1]
        assert self.paired_sequences.shape[0] == self.prompt_len.shape[0]

    def __len__(self):
        return self.paired_sequences.shape[0]

    def __getitem__(self, idx):
        pair = self.paired_sequences[idx].copy()
        length = self.prompt_len[idx].copy()

        preferred, unpreferred = pair
        return preferred, unpreferred, length

# main class

# DPO (Direct Preference Optimization) loss module. Holds the trainable
# policy model plus an EMA copy that serves as the frozen reference model.
class DPO(Module):
    def __init__(
        self,
        model: Module,
        *,
        beta = 0.1,
        ref_model_ema_decay = 1.,
        pad_id: Optional[int] = None,
        ema_kwargs: dict = dict()
    ):
        super().__init__()
        self.policy_model = model

        # reference model as an exponential moving average of the policy
        self.ref_model = EMA(
            model,
            beta = ref_model_ema_decay,
            **ema_kwargs
        )

        self.beta = beta
        self.pad_id = pad_id

    def update_reference_model_with_policy(self):
        # hard-sync the reference model to the current policy weights
        self.ref_model.copy_params_from_model_to_ema()

    def update_ema(self):
        self.ref_model.update()

    def parameters(self):
        # only the policy model is optimized; the reference model is frozen
        return self.policy_model.parameters()

    @property
    def device(self):
        return next(self.parameters()).device

    @autocast(enabled = False)
    def forward(
        self,
        preferred_seq: TensorType['b', 'n', int],
        unpreferred_seq: TensorType['b', 'n', int],
        prompt_len: TensorType['b', int],
        preferred_seq_mask: Optional[TensorType['b', 'n', bool]] = None,
        unpreferred_seq_mask: Optional[TensorType['b', 'n', bool]] = None
    ):
        """Compute the DPO loss for a batch of preference pairs.

        Shapes: b - batch, n - sequence length.
        Following Appendix B in https://arxiv.org/abs/2305.18290

        Fix: the original signature was missing its closing `):` and the
        body was indented at signature level (syntax error, extraction
        artifact) - restored to valid method syntax.
        """
        self.policy_model.train()

        # masks covering the prompt portion of each sequence
        preferred_prompt_mask = prompt_mask_from_len(prompt_len, preferred_seq)
        unpreferred_prompt_mask = prompt_mask_from_len(prompt_len, unpreferred_seq)

        if exists(self.pad_id):
            # derive sequence masks from the pad id (mutually exclusive with
            # explicitly passed masks), then zero out the pad positions
            assert not exists(preferred_seq_mask)
            assert not exists(unpreferred_seq_mask)

            preferred_seq_mask = preferred_seq != self.pad_id
            preferred_seq.masked_fill_(~preferred_seq_mask, 0)

            unpreferred_seq_mask = unpreferred_seq != self.pad_id
            unpreferred_seq.masked_fill_(~unpreferred_seq_mask, 0)

        # reference log-probabilities are computed without gradients
        with torch.no_grad():
            self.ref_model.eval()
            ref_preferred_logprob = log_prob_from_model_and_seq(self.ref_model, preferred_seq)
            ref_unpreferred_logprob = log_prob_from_model_and_seq(self.ref_model, unpreferred_seq)

        policy_preferred_logprob = log_prob_from_model_and_seq(self.policy_model, preferred_seq)
        policy_unpreferred_logprob = log_prob_from_model_and_seq(self.policy_model, unpreferred_seq)

        # average per-token log-probs over the response positions only
        # (prompt positions, and padding when a mask exists, are excluded)
        policy_preferred_logprob, ref_preferred_logprob = [masked_mean(seq, maybe_and_mask(preferred_seq_mask, ~preferred_prompt_mask)) for seq in (policy_preferred_logprob, ref_preferred_logprob)]
        policy_unpreferred_logprob, ref_unpreferred_logprob = [masked_mean(seq, maybe_and_mask(unpreferred_seq_mask, ~unpreferred_prompt_mask)) for seq in (policy_unpreferred_logprob, ref_unpreferred_logprob)]

        # DPO loss (eq. 7 of the paper)
        policy_logratios = policy_preferred_logprob - policy_unpreferred_logprob
        ref_logratios = ref_preferred_logprob - ref_unpreferred_logprob

        losses = -F.logsigmoid(self.beta * (policy_logratios - ref_logratios))

        return losses.mean()
# trainer class

class DPOTrainer(Module):
    # DPO fine-tuning loop: wraps a DPO module with an accelerator, an Adam
    # optimizer with linear LR decay, optional early stopping and dataloaders
    @beartype
    def __init__(
        self,
        dpo: Union[DPO, Module],
        *,
        dataset_generator: Optional[Callable[[], Dataset]] = None,
        accelerator: Optional[Accelerator] = None,
        batch_size: int = 16,
        grad_accum_steps: int = 2,
        num_decay_steps: int = 1000,
        num_train_steps: Optional[int] = None,
        learning_rate: float = 3e-4,
        weight_decay: float = 0.,
        train_dataset: Optional[Dataset] = None,
        valid_dataset: Optional[Dataset] = None,
        start_learning_rate: float = 1e-6,
        end_learning_rate: float = 1e-7,
        early_stopper: Optional[EarlyStopper] = None,
        dropout: float = 0.1,
        check_early_stop_every: int = 200,
        early_stopper_eval_module: Optional[Module] = None,
        adam_kwargs: dict = dict(),
        accelerate_kwargs: dict = dict(),
        dpo_kwargs: dict = dict(
            beta = 0.1,
            ref_model_ema_decay = 1.
        ),
        early_stopper_kwargs: dict = dict()
    ):
        # NOTE(review): `learning_rate` and `early_stopper` are accepted but
        # never read below - start/end_learning_rate and
        # early_stopper_eval_module are what actually take effect
        super().__init__()

        # wrap a bare policy model into a DPO module when needed
        if not isinstance(dpo, DPO):
            dpo = DPO(dpo, **dpo_kwargs)

        # apply the configured dropout probability throughout
        set_dropout_(dpo, dropout)

        # build an accelerator if one was not supplied
        if not exists(accelerator):
            accelerator = Accelerator(**accelerate_kwargs)

        self.accelerator = accelerator

        # accelerate-prepared DPO module
        self.model = accelerator.prepare(dpo)
        self.dropout = dropout

        # optional callable producing a fresh training dataset per call
        self.dataset_generator = dataset_generator

        self.batch_size = batch_size
        self.grad_accum_steps = grad_accum_steps

        # optimizer with linear decay from start to end learning rate
        self.optimizer = adam_optimizer_with_linear_decay(
            dpo,
            start_learning_rate,
            end_learning_rate,
            num_decay_steps = num_decay_steps,
            accelerator = accelerator,
            weight_decay = weight_decay,
            adam_kwargs = adam_kwargs
        )

        # optional early stopping driven by an external evaluation module
        self.early_stopper = None
        if exists(early_stopper_eval_module):
            self.early_stopper = EarlyStopper(
                dpo.policy_model,
                evaluator = early_stopper_eval_module,
                accelerator = self.accelerator,
                **early_stopper_kwargs
            )

        # how often (in train steps) to consult the early stopper
        self.check_early_stop_every = check_early_stop_every

        # pre-build the train dataloader when a dataset was given up front
        self.train_dataloader = None
        if exists(train_dataset):
            self.train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, drop_last = True)
            self.train_dataloader = accelerator.prepare(self.train_dataloader)

        # optional validation dataloader
        self.valid_dataloader = None
        if exists(valid_dataset):
            self.valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size)

        # step counter and optional hard stop
        self.steps = 0
        self.num_train_steps = num_train_steps

    # model with accelerate wrappers removed
    @property
    def unwrapped_model(self):
        return self.accelerator.unwrap_model(self.model)

    # whether this process is the main one
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # print only from the main process
    def print(self, *msg):
        self.accelerator.print(*msg)

    # block until all processes reach this point
    def wait(self):
        return self.accelerator.wait_for_everyone()

    # log metrics against the current step
    def log(self, **data):
        self.accelerator.log(data, step = self.steps)

    # run the DPO fine-tuning loop
    def forward(
        self,
        train_self_reward_dataset: Optional[Dataset] = None
        ):
            # a supplied dataset generator takes precedence over the argument
            if exists(self.dataset_generator):
                train_self_reward_dataset = self.dataset_generator()

            # sync the frozen reference model to the current policy
            self.model.update_reference_model_with_policy()

            # start each run with a clean early-stop checkpoint folder
            if exists(self.early_stopper):
                self.early_stopper.clear_early_checkpoint_folder()

            train_dataloader = self.train_dataloader

            # build a dataloader on the fly when one was not prepared in __init__
            if not exists(train_dataloader):
                assert exists(train_self_reward_dataset)
                train_dataloader = DataLoader(train_self_reward_dataset, batch_size = self.batch_size, drop_last = True, shuffle = True)
                train_dataloader = self.accelerator.prepare(train_dataloader)

            # endless iterator over batches
            iter_dl = cycle(train_dataloader)

            pbar = tqdm(desc = 'dpo fine-tuning', total = self.num_train_steps)

            # re-apply dropout (the prepared model may have been reset)
            set_dropout_(self.model, self.dropout)

            # training loop
            while True:
                self.model.train()

                # gradient accumulation over grad_accum_steps micro-batches
                for forward_context in model_forward_contexts(self.accelerator, self.model, self.grad_accum_steps):
                    with forward_context():
                        batch = next(iter_dl)

                        # compute and backprop the (scaled) DPO loss
                        dpo_loss = self.model(*batch)
                        self.accelerator.backward(dpo_loss / self.grad_accum_steps)

                # report the last micro-batch's loss
                self.print(f'dpo loss: {dpo_loss.item():.3f}')
                self.log(loss = dpo_loss.item())

                # optimizer step and gradient reset
                self.optimizer.step()
                self.optimizer.zero_grad()

                # synchronize processes
                self.wait()

                # keep the EMA reference model up to date
                self.unwrapped_model.update_ema()

                # advance the step counter and progress bar
                self.steps += 1
                pbar.update(1)

                # stop after the configured number of train steps, if set
                if exists(self.num_train_steps) and self.steps >= self.num_train_steps:
                    break

                # synchronize before the early-stop check
                self.wait()

                if not (self.steps % self.check_early_stop_every) and exists(self.early_stopper):

                    # evaluate and possibly roll back / stop
                    early_stop_return = self.early_stopper()

                    if self.is_main:
                        self.print(f'valid dpo loss: {early_stop_return.score:.3f}')
                        self.log(dpo_valid_score = early_stop_return.score)

                    if early_stop_return.should_stop:
                        self.print('early stopping')
                        break

            # finish up
            pbar.close()
            self.print('dpo training finished')

.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\mocks.py

# 导入 functools 模块中的 wraps 函数
# 导入 typing 模块中的 Type 和 Any 类型
# 导入 torch.utils.data 模块中的 Dataset 类

from functools import wraps
from typing import Type, Any
from torch.utils.data import Dataset

# Decorator factory: the decorated function ignores its arguments entirely
# and always yields `val` (calling it first when `val` is itself callable).
def always(val):
    def decorator(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            return val() if callable(val) else val
        return inner
    return decorator

# Build a torch Dataset of `length` items, every item being `output`
# (or, when `output` is callable, a fresh `output()` per access).
def create_mock_dataset(
    length: int,
    output: Any
) -> Dataset:

    produce = output if callable(output) else (lambda: output)

    class MockDataset(Dataset):
        def __len__(self):
            return length

        def __getitem__(self, idx):
            return produce()

    return MockDataset()

.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\sampling_utils.py

import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 中的函数模块
from torch import Tensor  # 导入 PyTorch 中的张量
from torch.nn import Module  # 导入 PyTorch 中的神经网络模块
from torch.nn.utils.rnn import pad_sequence  # 导入 PyTorch 中的序列填充函数

from beartype import beartype  # 导入 beartype 库中的类型检查装饰器
from beartype.typing import Optional, Callable, List, Tuple  # 导入 beartype 库中的类型注解

from einops import rearrange  # 导入 einops 库中的重排函数

def exists(v):
    """True when `v` is not None."""
    return not (v is None)

def default(v, d):
    """Return `v` when present, otherwise the fallback `d`."""
    if exists(v):
        return v
    return d

# sampling helpers

def log(t, eps = 1e-20):
    """Numerically safe log: clamps `t` to at least `eps` before the logarithm."""
    return t.clamp(min = eps).log()

def gumbel_noise(t):
    """Standard Gumbel noise matching the shape/device/dtype of `t`."""
    uniform = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(uniform))

def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True, eps = 1e-10):
    """Sample token indices from logits `t` via the temperature-scaled Gumbel-max trick."""
    scaled = t / max(temperature, eps)
    return (scaled + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim)

# nucleus

def top_p(logits, thres = 0.9):
    """Nucleus (top-p) filtering: keep the smallest set of logits whose
    cumulative probability exceeds `thres`; everything else becomes -inf."""
    descending, original_positions = torch.sort(logits, descending = True)
    cumulative = F.softmax(descending, dim = -1).cumsum(dim = -1)

    # shift right by one so the first token crossing the threshold is kept
    to_remove = F.pad(cumulative > thres, (1, -1), value = False)
    descending[to_remove] = float('-inf')

    # scatter the filtered values back into their original positions
    return descending.scatter(1, original_positions, descending)

# topk

def top_k(logits, frac_num_tokens = 0.1, k: Optional[int] = None):
    """Top-k filtering: keep the `k` largest logits (default: ceil of a
    `frac_num_tokens` fraction of the vocabulary), set the rest to -inf."""
    # fix: `ceil` was used without being imported anywhere in this module,
    # so the default-k path raised NameError
    from math import ceil

    num_tokens = logits.shape[-1]

    k = default(k, ceil(frac_num_tokens * num_tokens))
    k = min(k, num_tokens)  # never request more than the vocab size

    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# decoding

@torch.no_grad()
@beartype
def sample(
    net: Module,
    prompts,
    seq_len: int,
    temperature = 1.,
    filter_fn: Callable = top_p,
    filter_kwargs: dict = dict(),
    pad_id: int = -1,
    eos_id: Optional[int] = None,
    output_keep_prompt = False
):
    """Batched autoregressive decoding with logit filtering and gumbel sampling.

    Args:
        net: language model mapping token ids (batch, seq) -> logits (batch, seq, vocab)
        prompts: (batch, prompt_len) tensor padded with `pad_id`, or a list/tuple of 1-d tensors
        seq_len: total length to decode up to (prompt included)
        temperature: gumbel sampling temperature
        filter_fn / filter_kwargs: logit filter (e.g. top_p / top_k) applied before sampling
        pad_id: padding token id
        eos_id: optional end-of-sequence id; decoding stops early once every row contains it
        output_keep_prompt: when True, return the full (batch, seq_len) tensor including prompts;
            otherwise return a tuple of 1-d tensors holding only each row's generated continuation
    """
    device = next(net.parameters()).device
    net.eval()

    if isinstance(prompts, (tuple, list)):
        prompts = pad_sequence(prompts, batch_first = True, padding_value = pad_id)

    batch = prompts.shape[0]

    batch_arange = torch.arange(batch, device = device)[..., None]

    # per-row prompt lengths; each row writes its next token at its own index
    prompt_lens = (prompts != pad_id).sum(dim = -1)
    curr_seq_indices = prompt_lens[..., None]

    out = prompts.clone()

    while (curr_seq_indices < seq_len).any():
        # grow the output by one padded position each iteration
        out = F.pad(out, (0, 1), value = pad_id)

        # replace padding with token 0 for the network input
        net_input = out.masked_fill(out == pad_id, 0)

        logits = net(net_input)

        # logits at each row's current decode position
        logits = logits[batch_arange, curr_seq_indices]
        logits = rearrange(logits, 'b 1 d -> b d')

        logits = filter_fn(logits, **filter_kwargs)
        sampled_tokens = gumbel_sample(logits, temperature = temperature, dim = -1)

        out[batch_arange, curr_seq_indices] = sampled_tokens

        curr_seq_indices += 1
        curr_seq_indices.clamp_(max = seq_len)

        if not exists(eos_id):
            continue

        # early exit once every row has produced an eos token
        is_eos_mask = out == eos_id
        all_eos = is_eos_mask.any(dim = -1).all()

        if all_eos:
            break

    out = out[:, :seq_len]

    if exists(eos_id):
        # pad out everything strictly after the first eos token in each row
        is_eos_mask = out == eos_id
        after_eos_mask = F.pad(is_eos_mask.cumsum(dim = -1) > 0, (1, -1), value = False)
        out = out.masked_fill_(after_eos_mask, pad_id)

    if output_keep_prompt:
        return out

    prompt_mask = torch.arange(out.shape[-1], device = device) < prompt_lens[..., None]

    # fix: `&` binds tighter than `!=`, so the original `out != pad_id & ~prompt_mask`
    # parsed as `out != (pad_id & ~prompt_mask)`; the comparison must be parenthesized
    generated_seq_mask = (out != pad_id) & ~prompt_mask
    seq_lens = generated_seq_mask.sum(dim = -1).tolist()

    return out[generated_seq_mask].split(seq_lens)

.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\self_rewarding_lm_pytorch.py

# 导入所需的库
import re
import sys
from functools import partial
from random import randrange
from copy import deepcopy
from pathlib import Path
from dataclasses import dataclass, field
from functools import wraps
from textwrap import dedent

# 导入类型提示相关的库
from beartype import beartype
from beartype.typing import Optional, Dict, List, Tuple, Union, Callable
from torchtyping import TensorType

# 导入 PyTorch 相关的库
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Dropout
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# 导入 NumPy 相关的库
import numpy as np
from numpy.lib.format import open_memmap

# 导入自定义的模块
from self_rewarding_lm_pytorch.dpo import (
    DPO,
    DPODataset,
    DPOTrainer,
    EarlyStopper,
    set_dropout_,
    adam_optimizer_with_linear_decay
)

from self_rewarding_lm_pytorch.spin import (
    SPIN,
    SPINTrainer
)

from einops import rearrange, repeat

from accelerate import Accelerator

from pytorch_custom_utils.utils import pad_or_slice_to

from pytorch_custom_utils.accelerate_utils import (
    model_forward_contexts
)

from self_rewarding_lm_pytorch.sampling_utils import (
    sample,
    top_p,
    top_k
)

from self_rewarding_lm_pytorch.mocks import always

from tqdm import tqdm

# 如果系统是 32 位,则给出警告
if sys.maxsize <= (2 ** 32):
    print('you need to be on 64 bit system to use memmapped files of > 2GB')

# 基本模板引擎
import jinja2
jinja2_env = jinja2.Environment()

# 从 Jinja 模板中查找变量
def find_variables_from_jinja_template(template: str):
    """Return the set of undeclared variable names referenced by a Jinja template string."""
    from jinja2 import meta
    parsed = jinja2_env.parse(template)
    return meta.find_undeclared_variables(parsed)

# 辅助函数
# 判断变量是否存在
def exists(v):
    """Whether `v` holds a value (i.e. is not None)."""
    return not (v is None)

# 如果变量存在则返回变量,否则返回默认值
def default(v, d):
    """`v` when it exists, otherwise the fallback `d`."""
    if exists(v):
        return v
    return d

# 返回数组的第一个元素
def first(arr):
    """First element of an indexable sequence."""
    return arr[0]

# 无限循环生成器
def cycle(dl):
    """Endlessly re-iterate `dl`, yielding batch after batch."""
    while True:
        yield from dl

# 返回输入本身
def identity(t, *args, **kwargs):
    """Return `t` unchanged, ignoring any extra positional/keyword arguments."""
    return t

# 根据长度生成掩码
def prompt_mask_from_len(length, seq):
    """Boolean mask over `seq` positions, True where the index falls inside the prompt (index < length)."""
    device = seq.device
    positions = torch.arange(seq.shape[-1], device = device)
    return positions < rearrange(length, '... -> ... 1')

# 将输入转换为元组
def cast_tuple(t, length = 1, validate = False):
    """Coerce `t` into a tuple of `length` copies unless it already is a tuple.

    With `validate=True`, assert the resulting tuple has exactly `length` items.
    """
    if isinstance(t, tuple):
        out = t
    else:
        out = (t,) * length
    if validate:
        assert len(out) == length
    return out

# 转换输入数据类型的装饰器
def cast_input(cast_fn):
    """Decorator factory: apply `cast_fn` to the first positional argument before calling the wrapped function."""
    def decorator(fn):
        @wraps(fn)
        def inner(t, *args, **kwargs):
            return fn(cast_fn(t), *args, **kwargs)
        return inner

    return decorator

# 转换输出数据类型的装饰器
def cast_output(cast_fn):
    """Decorator factory: apply `cast_fn` to the wrapped function's return value."""
    def decorator(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            return cast_fn(fn(*args, **kwargs))
        return inner

    return decorator

# 常量
# llm-as-judge prompt
# https://openreview.net/forum?id=uccHPGDlao

# 默认的评分模板
DEFAULT_LLM_AS_JUDGE_PROMPT = """
Review the user’s question and the corresponding response using the additive 5-point
scoring system described below. Points are accumulated based on the satisfaction of each
criterion:
- Add 1 point if the response is relevant and provides some information related to
the user’s inquiry, even if it is incomplete or contains some irrelevant content.
- Add another point if the response addresses a substantial portion of the user’s question,
but does not completely resolve the query or provide a direct answer.
- Award a third point if the response answers the basic elements of the user’s question in a
useful way, regardless of whether it seems to have been written by an AI Assistant or if it
has elements typically found in blogs or search results.
- Grant a fourth point if the response is clearly written from an AI Assistant’s perspective,
addressing the user’s question directly and comprehensively, and is well-organized and
helpful, even if there is slight room for improvement in clarity, conciseness or focus.
- Bestow a fifth point for a response that is impeccably tailored to the user’s question
"""
# 定义默认的奖励正则表达式模板
DEFAULT_REWARD_REGEX_TEMPLATE = """
Score: {{ reward }}
"""

# 创建解析奖励函数,根据奖励正则表达式模板
def create_parse_reward_fn(reward_regex_template):
    """Build a parser that extracts a float reward from an llm-as-judge response.

    `reward_regex_template` is a jinja template containing exactly one variable,
    `reward`, which is rendered into a capturing regex group.
    Returns a function mapping the judge's raw text to a float, or None when
    no score can be parsed.
    """
    # fix: the assert message previously said '"score" variable' although the
    # required template variable is named 'reward'
    assert find_variables_from_jinja_template(reward_regex_template) == {'reward'}, 'reward template must include "reward" variable'
    # raw string so the escaped dot reaches the regex engine unmangled
    reward_regex_str = jinja2_env.from_string(reward_regex_template).render(reward = r"([0-9\.]+)")

    def parse_reward_fn(llm_response: str) -> Optional[float]:
        result = re.search(rf"{reward_regex_str}", llm_response)

        # fix: original tested `result.groups == 0` (a bound method compared to 0,
        # always False) and then called `.isnumeric()` on `result.groups(1)` — a
        # tuple — which raised AttributeError on every successful match.
        if not exists(result) or len(result.groups()) == 0:
            return None

        try:
            return float(result.group(1))
        except ValueError:
            # captured text (e.g. a lone '.') is not a parseable number
            return None

    return parse_reward_fn

# reward configuration
@dataclass
class RewardConfig:
    """Bundles the llm-as-judge prompt template with the logic for parsing a
    numeric reward out of the judge's reply. Call `init()` before use."""
    prompt_template: str                                              # jinja template using `prompt` and `response` variables
    reward_regex_template: Optional[str] = None                       # jinja template with a `reward` variable, rendered into a regex
    parse_reward: Optional[Callable[[str], Optional[float]]] = None   # custom parser; derived from the regex template when absent
    template_fn: Optional[Callable[..., str]] = None                  # set by init(): renders the judge prompt
    auto_dedent: bool = True                                          # dedent the templates before use

    # prepare the config: dedent templates, build template_fn, derive parse_reward
    def init(self):

        # maybe dedent the templates
        if self.auto_dedent:
            self.prompt_template = dedent(self.prompt_template)

            if exists(self.reward_regex_template):
                self.reward_regex_template = dedent(self.reward_regex_template)

        # build the function that renders the judge prompt from (prompt, response)
        prompt_template = self.prompt_template
        assert find_variables_from_jinja_template(prompt_template) == {'prompt', 'response'}, 'template must include prompt and response templating variables'
        self.template_fn = jinja2_env.from_string(prompt_template).render

        # derive a reward parser from the regex template when none was supplied
        if not exists(self.parse_reward):
            assert exists(self.reward_regex_template), 'reward_regex_template must be given if parse_reward is not passed in'
            self.parse_reward = create_parse_reward_fn(self.reward_regex_template)

        return self

# registry of reward-prompt configurations, keyed by name;
# 'default' uses the paper's llm-as-judge prompt and score regex
SELF_REWARD_PROMPT_CONFIG = dict(
    default = RewardConfig(
        prompt_template = DEFAULT_LLM_AS_JUDGE_PROMPT,
        reward_regex_template = DEFAULT_REWARD_REGEX_TEMPLATE
    )
)

# a preference pair is only usable when the chosen and rejected rewards differ
default_is_valid_reward_pair = lambda preferred_reward, unpreferred_reward: (preferred_reward != unpreferred_reward).all()

# default strategy for choosing the (preferred, unpreferred) response pair
@beartype
def default_pick_paired_rewards_fn(rewards: Tensor):
    """Return stacked (argmax, argmin) indices along the last axis, treating NaN entries as unselectable."""
    nan_mask = torch.isnan(rewards)
    for_max, for_min = rewards.clone(), rewards.clone()
    # push NaNs out of contention for both the max and the min
    for_max[nan_mask] = -1e6
    for_min[nan_mask] = 1e6
    return torch.stack((for_max.argmax(dim = -1), for_min.argmin(dim = -1)))

# SFT训练器类
class SFTTrainer(Module):
    """Supervised fine-tuning trainer.

    Trains `model` with next-token cross entropy over the response portion of
    each sequence (prompt tokens are masked out of the loss), using `accelerate`
    for distributed preparation and gradient accumulation.
    """

    @beartype
    def __init__(
        self,
        model: Module,
        *,
        accelerator: Accelerator,
        train_dataset: Union[List[Dataset], Dataset],
        valid_dataset: Optional[Dataset] = None,
        batch_size: int = 16,
        grad_accum_steps: int = 2,
        num_epochs: int = 3,
        start_learning_rate: float = 5.5e-6,
        end_learning_rate: float = 1.1e-6,
        learning_rate_num_decay_steps: Optional[int] = None,
        dropout: float = 0.,
        weight_decay: float = 0.,
        ignore_index: int = -1,
        adam_kwargs: dict = dict(),
        valid_every: int = 1
    ):
        super().__init__()
        self.accelerator = accelerator
        self.model = model
        self.dropout = dropout

        self.num_epochs = num_epochs
        self.ignore_index = ignore_index

        # several training datasets are simply concatenated into one
        if isinstance(train_dataset, list):
            train_dataset = ConcatDataset(train_dataset)

        self.train_dataloader = DataLoader(train_dataset, batch_size = batch_size, drop_last = True, shuffle = True)

        # one optimizer step per `grad_accum_steps` batches, for `num_epochs` passes
        self.num_train_steps = len(self.train_dataloader) // grad_accum_steps * num_epochs
        self.grad_accum_steps = grad_accum_steps

        # let accelerate wrap the model and dataloader for the current setup
        (
            self.model,
            self.train_dataloader
        ) = self.accelerator.prepare(
            self.model,
            self.train_dataloader
        )

        # default decay horizon: half the training set size
        if not exists(learning_rate_num_decay_steps):
            learning_rate_num_decay_steps = len(train_dataset) // 2

        self.optimizer = adam_optimizer_with_linear_decay(
            model,
            start_learning_rate,
            end_learning_rate,
            num_decay_steps = learning_rate_num_decay_steps,
            accelerator = accelerator,
            weight_decay = weight_decay,
            adam_kwargs = adam_kwargs
        )

        self.valid_every = valid_every

        self.valid_dataloader = None
        if exists(valid_dataset):
            self.valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size)

        # optimizer-step counter, used for logging and validation cadence
        self.steps = 0

    def log(self, **data):
        """Forward metric key/values to the accelerate tracker at the current step."""
        self.accelerator.log(data, step = self.steps)

    def wait(self):
        """Barrier: block until every distributed process reaches this point."""
        return self.accelerator.wait_for_everyone()

    def get_cross_entropy_loss(
        self,
        seq: TensorType['batch', 'seq', int],
        prompt_len_or_mask: Union[
            TensorType['batch', int],
            TensorType['batch', 'seq', bool]
        ]
    ):
        """Next-token cross entropy over the response tokens only.

        `prompt_len_or_mask` is either per-row prompt lengths (long tensor) or a
        precomputed boolean prompt mask; prompt positions are set to
        `ignore_index` in the labels so they do not contribute to the loss.
        """
        if prompt_len_or_mask.dtype == torch.long:
            prompt_mask = prompt_mask_from_len(prompt_len_or_mask, seq)
        else:
            prompt_mask = prompt_len_or_mask

        # shift: inputs are seq[:-1], labels are seq[1:]
        seq, labels = seq[:, :-1], seq[:, 1:]

        # mask out prompt positions from the labels
        labels.masked_fill_(prompt_mask[:, 1:], self.ignore_index)

        logits = self.model(seq)

        return F.cross_entropy(
            rearrange(logits, 'b n l -> b l n'),
            labels,
            ignore_index = self.ignore_index
        )

    def forward(self):
        """Run SFT for `num_train_steps` optimizer steps, validating every `valid_every` steps."""

        train_dl_iter = cycle(self.train_dataloader)

        set_dropout_(self.model, self.dropout)

        for _ in tqdm(range(self.num_train_steps), desc = 'sft fine-tuning'):
            self.model.train()

            # accumulate gradients over `grad_accum_steps` micro-batches
            for forward_context in model_forward_contexts(self.accelerator, self.model, self.grad_accum_steps):
                with forward_context():
                    seq, prompt_len_or_mask = next(train_dl_iter)

                    loss = self.get_cross_entropy_loss(seq, prompt_len_or_mask)

                    self.accelerator.backward(loss / self.grad_accum_steps)

            self.optimizer.step()
            self.optimizer.zero_grad()

            # logs the last micro-batch's loss of this optimizer step
            self.log(loss = loss.item())

            self.steps += 1

            # fix: the original referenced an undefined name `step` here, which
            # raised NameError on the first validation check; use the counter
            if exists(self.valid_dataloader) and not (self.steps % self.valid_every):
                self.wait()

                # validation runs on the main process only
                if self.accelerator.is_main_process:
                    total_valid_loss = 0.
                    total_batches = 0.

                    self.model.eval()

                    with torch.no_grad():
                        for seq, prompt_len_or_mask in self.valid_dataloader:
                            batch = seq.shape[0]

                            loss = self.get_cross_entropy_loss(seq, prompt_len_or_mask)

                            # weight by batch size for a correct dataset-level mean
                            total_valid_loss += loss.item() * batch
                            total_batches += batch

                    valid_loss = total_valid_loss / total_batches

                    self.log(valid_loss = valid_loss)
# 定义一个 DPODatasetGenerator 类,用于生成奖励数据集

class DPODatasetGenerator(Module):
    # 初始化方法,接受以下参数
    @beartype
    def __init__(
        self,
        model: Module,  # 模型对象
        prompt_dataset: Dataset,  # 提示数据集
        num_preference_pairs: int,  # 偏好对数量
        accelerator: Accelerator,  # 加速器对象
        tokenizer_encode: Callable[[str], TensorType['seq', int]],  # 编码器函数
        tokenizer_decode: Callable[[TensorType['seq', int]], str],  # 解码器函数
        self_reward_model: Optional[Module] = None,  # 自我奖励模型,默认为 None
        batch_size: int = 16,  # 批处理大小,默认为 16
        num_candidate_responses: int = 4,  # 候选响应数量,默认为 4
        gen_temperature: float = 0.7,  # 生成温度,默认为 0.7
        gen_filter_fn = top_p,  # 生成过滤函数,默认为 top_p
        gen_filter_kwargs: dict = dict(thres = 0.9),  # 生成过滤函数的参数,默认为 {'thres': 0.9}
        eval_temperature: float = 0.7,  # 评估温度,默认为 0.7
        eval_filter_fn = top_p,  # 评估过滤函数,默认为 top_p
        eval_filter_kwargs: dict = dict(thres = 0.9),  # 评估过滤函数的参数,默认为 {'thres': 0.9}
        num_evals_to_average: int = 3,  # 平均评估次数,默认为 3
        *,
        reward_config: RewardConfig,  # 奖励配置对象
        reward_model: Optional[Module] = None,  # 奖励模型,默认为 None
        data_folder: str = './',  # 数据文件夹,默认为当前目录
        preference_seq_memmap_file: str = 'preference_seq.memmap.npy',  # 偏好序列内存映射文件名,默认为 'preference_seq.memmap.npy'
        prompt_len_memmap_file: str = 'prompt_len.memmap.npy',  # 提示长度内存映射文件名,默认为 'prompt_len.memmap.npy'
        self_reward_memmap_file: str = 'self_reward.memmap.npy',  # 自我奖励内存映射文件名,默认为 'self_reward.memmap.npy'
        preference_max_seq_len: int = 1024,  # 偏好最大序列长度,默认为 1024
        generate_reward_max_seq_len: int = 256,  # 生成奖励最大序列长度,默认为 256
        is_valid_reward: Callable[[float], bool] = lambda *args: True,  # 是否有效奖励的函数,默认为始终返回 True
        is_valid_reward_pair: Optional[Callable[[float, float], bool]] = None,  # 是否有效奖励对的函数,默认为 None
        pick_paired_rewards: Callable[[Tensor], Tensor] = default_pick_paired_rewards_fn,  # 选择配对奖励的函数,默认为 default_pick_paired_rewards_fn
        pad_id: int = -1  # 填充 ID,默认为 -1
    # 初始化函数,继承父类的初始化方法
    def __init__(
        self,
        model,
        num_candidate_responses,
        self_reward_model,
        reward_config,
        batch_size,
        prompt_dataset,
        gen_filter_fn,
        gen_filter_kwargs,
        gen_temperature,
        eval_filter_fn,
        eval_filter_kwargs,
        eval_temperature,
        tokenizer_encode,
        tokenizer_decode,
        num_evals_to_average,
        is_valid_reward,
        is_valid_reward_pair,
        pick_paired_rewards,
        reward_model,
        generate_reward_max_seq_len,
        num_preference_pairs,
        preference_max_seq_len,
        pad_id,
        data_folder,
        preference_seq_memmap_file,
        prompt_len_memmap_file,
        self_reward_memmap_file,
        accelerator
    ):
        # NOTE(review): the keyword-rich signature directly above this def has no
        # closing paren or body — this looks like a doc-extraction artifact that
        # split one __init__ into two; reconcile against the upstream source.
        super().__init__()

        self.model = model
        self.num_candidate_responses = num_candidate_responses

        # fall back to the policy model itself for self-rewarding
        self.self_reward_model = default(self_reward_model, model)
        self.reward_config = reward_config.init()

        self.batch_size = batch_size
        self.prompt_dataset = prompt_dataset
        self.prompt_dataloader = DataLoader(prompt_dataset, batch_size = batch_size, shuffle = True)

        # sampling controls for candidate-response generation
        self.gen_filter_fn = gen_filter_fn
        self.gen_filter_kwargs = gen_filter_kwargs
        self.gen_temperature = gen_temperature

        # sampling controls for the llm-as-judge evaluation passes
        self.eval_filter_fn = eval_filter_fn
        self.eval_filter_kwargs = eval_filter_kwargs
        self.eval_temperature = eval_temperature

        # wrap the tokenizer callables so ids are always long tensors / plain ints
        self.tokenizer_encode = cast_output(lambda t: t.long())(tokenizer_encode)
        self.tokenizer_decode = cast_input(lambda t: t.long() if torch.is_tensor(t) else [*map(int, t)])(tokenizer_decode)

        self.num_evals_to_average = num_evals_to_average

        # validation hooks applied to sampled reward pairs before they enter the
        # generated preference dataset

        self.is_valid_reward = is_valid_reward
        self.is_valid_reward_pair = default(is_valid_reward_pair, lambda *args: True)
        self.pick_paired_rewards = pick_paired_rewards

        # external reward model, if one was passed in

        self.has_external_reward_model = exists(reward_model)
        self.reward_model = reward_model

        # shapes and padding

        self.generate_reward_max_seq_len = generate_reward_max_seq_len

        self.num_preference_pairs = num_preference_pairs

        self.preference_max_seq_len = preference_max_seq_len

        self.pad_id = pad_id

        # (pair index, chosen/rejected, token position)
        memmap_shape = (num_preference_pairs, 2, preference_max_seq_len)

        # saved so a DPO dataset instance can be constructed and returned at the end

        self.dpo_dataset_kwargs = dict(
            data_folder = data_folder,
            preference_seq_memmap_file = preference_seq_memmap_file,
            prompt_len_memmap_file = prompt_len_memmap_file
        )

        # memmapped .npy files backing the generated preference data

        self.data_folder_path = Path(data_folder)
        self.data_folder_path.mkdir(exist_ok = True, parents = True)

        self.preference_seq_memmap_path = self.data_folder_path / preference_seq_memmap_file
        self.prompt_len_memmap_path = self.data_folder_path / prompt_len_memmap_file
        # NOTE(review): attribute name is misspelled ('mmemap'); kept as-is since
        # other code in this class reads it by this exact name
        self.self_reward_mmemap_path = self.data_folder_path / self_reward_memmap_file

        # NOTE(review): dtype='int' maps to the platform default integer — confirm
        # this matches the dtype expected by the consuming DPO dataset
        self.preference_seq_memmap = open_memmap(str(self.preference_seq_memmap_path), dtype = 'int', mode = 'w+', shape = memmap_shape)
        self.prompt_len_memmap = open_memmap(str(self.prompt_len_memmap_path), dtype = 'int', mode = 'w+', shape = (num_preference_pairs,))
        self.self_reward_memmap_file = open_memmap(str(self.self_reward_mmemap_path), dtype = 'float32', mode = 'w+', shape = (num_preference_pairs, 2))

        self.accelerator = accelerator

    # device chosen by accelerate for this process
    @property
    def device(self):
        """Device that `accelerate` assigned to the current process."""
        return self.accelerator.device

    # score a single (prompt, response) pair with the llm-as-judge
    def generate_reward(
        self,
        prompt: str,
        response: str
    ) -> Optional[float]:

        """
        main contribution of the paper is the logic in this function
        in paper, they sample it 3 times and then average

        Renders the judge prompt, samples `num_evals_to_average` judge responses
        in one batch, parses a score out of each, and returns the mean score —
        or None when no response yields a parseable score.
        """

        device = next(self.model.parameters()).device

        template_fn = self.reward_config.template_fn
        parse_reward = self.reward_config.parse_reward

        # render the judge prompt and tokenize it
        reward_prompt_str = template_fn(prompt=prompt, response=response)
        reward_prompt = self.tokenizer_encode(reward_prompt_str).to(device)

        # repeat the prompt so all evaluations are sampled in a single batch
        reward_prompt = repeat(reward_prompt, 'n -> b n', b=self.num_evals_to_average)

        reward_prompt = reward_prompt.to(device)
        # fix: original read the undefined local `self_reward_model` before
        # assignment (UnboundLocalError); the model lives on the instance
        self_reward_model = self.self_reward_model.to(device)

        # sample the judge's responses
        reward_responses = sample(
            self_reward_model,
            prompts=reward_prompt,
            seq_len=self.generate_reward_max_seq_len,
            temperature=self.eval_temperature,
            filter_fn=self.eval_filter_fn,
            filter_kwargs=self.eval_filter_kwargs
        )

        # decode each response, stripping padding
        reward_responses_as_str: List[str] = [self.tokenizer_decode(resp[resp != self.pad_id].cpu()) for resp in reward_responses]

        # parse a score out of each decoded response
        rewards: List[Optional[float]] = [parse_reward(resp_str) for resp_str in reward_responses_as_str]

        rewards = [*filter(exists, rewards)] # for now, just filter out any failed responses

        # no evaluation could be parsed into a score
        if len(rewards) == 0:
            return None

        avg_reward = Tensor(rewards).mean().item()
        return avg_reward

    @torch.no_grad()
# 定义了一个类 FinetuneConfig,用于存储微调配置信息
class FinetuneConfig:
    pass

# 导入 partial 函数,用于创建带有默认参数的函数
default_dict = partial(field, default_factory = dict)

# supervised fine-tuning stage configuration
@dataclass
class SFTConfig(FinetuneConfig):
    """Configuration for one supervised fine-tuning (SFT) stage."""
    train_dataset: Union[Dataset, List[Dataset]]   # a single dataset or several (concatenated by the trainer)
    valid_dataset: Optional[Dataset] = None        # optional validation split
    dropout: float = 0.1                           # dropout probability applied during this stage
    trainer_kwargs: dict = default_dict()          # extra kwargs forwarded to the SFT trainer

# self-reward DPO stage configuration
@dataclass
class SelfRewardDPOConfig(FinetuneConfig):
    """Configuration for one self-rewarding DPO iteration: generate preference
    pairs with the model judging its own outputs, then train with DPO."""
    prompt_dataset: Dataset                              # prompts to generate candidate responses for
    num_generated_preference_pairs: int                  # how many (chosen, rejected) pairs to generate
    dpo_beta: float = 0.1                                # DPO beta coefficient
    max_seq_len: int = 1024                              # max preference sequence length
    rewarding_model: Optional[Module] = None             # separate judge model; defaults to the policy itself
    self_reward_config_keyname: str = 'default'          # which RewardConfig to use from the registry
    is_valid_reward: Callable[[float], bool] = lambda reward: reward >= 0   # filter for individual rewards
    is_valid_reward_pair: Callable[[Tensor, Tensor], bool] = default_is_valid_reward_pair   # filter for reward pairs
    pick_paired_rewards_fn: Callable[[Tensor], Tensor] = default_pick_paired_rewards_fn     # chooses (best, worst) indices
    dropout: float = 0.1                                 # dropout probability during this stage
    early_stopper_eval_module: Optional[Module] = None   # optional early-stopping evaluator
    num_train_steps: Optional[int] = None                # annotation fixed: was Optional[Module] (cf. dpo_num_train_steps: Optional[int])
    num_candidate_responses: int = 4                     # candidate responses sampled per prompt
    num_sampled_reward_responses: int = 3                # judge samples averaged per candidate
    gen_temperature: float = 0.7                         # generation temperature
    gen_filter_fn: Callable = top_p                      # generation logit filter
    gen_filter_kwargs: dict = default_dict()             # generation filter kwargs
    eval_temperature: float = 0.7                        # judge-sampling temperature
    eval_filter_fn: Callable = top_p                     # judge-sampling logit filter
    eval_filter_kwargs: dict = default_dict()            # judge-sampling filter kwargs
    trainer_kwargs: dict = field(default_factory = dict) # extra kwargs forwarded to the DPO trainer
    reward_generator_kwargs: dict = default_dict()       # extra kwargs forwarded to the reward dataset generator

# DPO stage configuration using an external reward model as the judge
@dataclass
class ExternalRewardDPOConfig(FinetuneConfig):
    """Configuration for a DPO iteration whose rewards come from an external model."""
    reward_model: Module                                  # external reward model
    dpo_beta: float = 0.1                                 # DPO beta coefficient
    max_seq_len: int = 1024                               # max preference sequence length
    gen_temperature: float = 0.7                          # generation temperature
    gen_filter_fn: Callable = top_p                       # generation logit filter
    gen_filter_kwargs: dict = default_dict()              # generation filter kwargs
    dropout: float = 0.1                                  # dropout probability during this stage
    trainer_kwargs: dict = default_dict()                 # extra kwargs forwarded to the DPO trainer
    reward_generator_kwargs: dict = default_dict()        # extra kwargs forwarded to the reward dataset generator

# self-play (SPIN) stage configuration
@dataclass
class SelfPlayConfig(FinetuneConfig):
    """Configuration for one SPIN self-play fine-tuning stage."""
    train_dataset: Dataset                     # real sequences the policy plays against
    valid_dataset: Optional[Dataset] = None    # optional validation split
    max_seq_len: int = 1024                    # max generated sequence length
    spin_λ: float = 0.1                        # SPIN regularization coefficient λ
    dropout: float = 0.1                       # dropout probability during this stage
    temperature: float = 0.7                   # generation temperature
    filter_fn: Callable = top_p                # generation logit filter
    filter_kwargs: dict = default_dict()       # generation filter kwargs
    trainer_kwargs: dict = default_dict()      # extra kwargs forwarded to the SPIN trainer
    spin_kwargs: dict =  default_dict()        # extra kwargs forwarded to the SPIN module

# builds the finetuning schedule from the paper: one SFT stage followed by two
# self-reward DPO iterations
@beartype
def create_default_paper_config(
    *,
    # fix: the Union subscript was missing its closing bracket
    # (`Union[Dataset, List[Dataset],`), which swallowed the next parameter's
    # annotation into the subscript
    train_sft_dataset: Union[Dataset, List[Dataset]],
    self_reward_prompt_dataset: Union[Dataset, Tuple[Dataset, Dataset]],
    valid_sft_dataset: Optional[Dataset] = None,
    num_generated_preference_pairs = (3964, 6942),
    early_stopper_eval_module: Optional[Module] = None,
    dpo_num_train_steps: Optional[int] = None,
    sft_config: dict = dict(),
    self_reward_config: dict = dict()
) -> List[FinetuneConfig]:
    """Return the default three-stage finetune schedule.

    Args:
        train_sft_dataset: dataset(s) for the SFT stage
        self_reward_prompt_dataset: prompt dataset(s) for the two DPO iterations
            (one dataset is reused for both; a 2-tuple supplies one per iteration)
        valid_sft_dataset: optional SFT validation split
        num_generated_preference_pairs: pairs to generate per DPO iteration
        early_stopper_eval_module / dpo_num_train_steps: DPO stage controls
        sft_config / self_reward_config: extra kwargs for the respective configs
    """

    prompt_dataset_iter1, prompt_dataset_iter2 = cast_tuple(self_reward_prompt_dataset, 2, validate = True)
    num_generated_iter1, num_generated_iter2 = num_generated_preference_pairs

    return [
        SFTConfig(
            train_dataset = train_sft_dataset,
            valid_dataset = valid_sft_dataset,
            **sft_config
        ),
        SelfRewardDPOConfig(
            num_generated_preference_pairs = num_generated_iter1,
            prompt_dataset = prompt_dataset_iter1,
            num_train_steps = dpo_num_train_steps,
            early_stopper_eval_module = early_stopper_eval_module,
            **self_reward_config
        ),
        SelfRewardDPOConfig(
            num_generated_preference_pairs = num_generated_iter2,
            prompt_dataset = prompt_dataset_iter2,
            num_train_steps = dpo_num_train_steps,
            early_stopper_eval_module = early_stopper_eval_module,
            **self_reward_config
        )
    ]

# 定义了一个类 SelfRewardingTrainer,继承自 Module,用于自奖励训练
class SelfRewardingTrainer(Module):
    @beartype  # 类型注解装饰器
    # 初始化方法,接受模型、微调配置、编码和解码函数等参数
    def __init__(
        self,
        model: Module,
        *,
        finetune_configs: Union[Dict, List[FinetuneConfig]],
        tokenizer_encode: Callable[[str], TensorType['seq', int]],
        tokenizer_decode: Callable[[TensorType['seq', int]], str],
        self_reward_prompt_config: Union[RewardConfig, Dict[str, RewardConfig]] = SELF_REWARD_PROMPT_CONFIG,
        pad_id: int = -1,
        checkpoints_folder: str = './checkpoints',
        accelerate_kwargs: dict = dict()
    # the model without accelerate's distributed wrappers
    @property
    def unwrapped_model(self):
        """Underlying model with any `accelerate` wrapping removed."""
        return self.accelerator.unwrap_model(self.model)

    # process-aware printing, delegated to accelerate
    def print(self, *msg):
        """Print via the accelerator (accelerate handles multi-process printing)."""
        self.accelerator.print(*msg)

    # distributed barrier
    def wait(self):
        """Block until every distributed process reaches this point."""
        return self.accelerator.wait_for_everyone()

    # persist model weights into the checkpoints folder (main process only)
    def save(self, path: str, overwrite: bool = False):
        """Save the unwrapped model's state dict to `checkpoints_folder / path`."""
        self.wait()

        # only the main process writes to disk
        if not self.accelerator.is_main_process:
            return

        path = self.checkpoints_folder / path

        # refuse to clobber an existing checkpoint unless explicitly allowed
        assert not path.exists() or overwrite, f'file already exists'

        pkg = dict(
            model = self.unwrapped_model.state_dict()
        )

        torch.save(pkg, str(path))

    # run every finetuning stage in order, checkpointing after each one
    def forward(
        self,
        overwrite_checkpoints: bool = False
    ):
        """Execute each configured trainer sequentially, saving a checkpoint per stage."""

        for finetuning_stage, (trainer_type, trainer) in enumerate(self.trainers, start = 1):
            trainer()

            self.save(f'{finetuning_stage}.{trainer_type}.ckpt.pt', overwrite = overwrite_checkpoints)

        self.print(f'self-reward training done')

.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\spin.py

from pathlib import Path
# 导入 Path 模块,用于处理文件路径

from beartype import beartype
from beartype.typing import Optional, Callable, Union
# 导入 beartype 模块,用于类型注解

from torchtyping import TensorType
# 导入 TensorType 类型注解

import torch
# 导入 torch 模块

from torch.nn import Module, Dropout
# 从 torch.nn 模块中导入 Module 和 Dropout 类

import torch.nn.functional as F
# 导入 torch.nn.functional 模块,用于神经网络函数

from torch.cuda.amp import autocast
# 导入 autocast 函数,用于混合精度训练

from torch.utils.data import Dataset, DataLoader
# 导入 Dataset 和 DataLoader 类,用于处理数据集和数据加载

from torch.nn.utils.rnn import pad_sequence
# 导入 pad_sequence 函数,用于填充序列

from accelerate import Accelerator
# 导入 Accelerator 类,用于加速训练

from einops import rearrange
# 导入 rearrange 函数,用于重排张量维度

from einx import get_at
# 导入 get_at 函数,用于获取张量的特定位置

from pytorch_custom_utils.utils import (
    masked_mean,
    maybe_and_mask
)
# 从 pytorch_custom_utils.utils 模块中导入 masked_mean 和 maybe_and_mask 函数

from pytorch_custom_utils.accelerate_utils import (
    model_forward_contexts
)
# 从 pytorch_custom_utils.accelerate_utils 模块中导入 model_forward_contexts 函数

from self_rewarding_lm_pytorch.dpo import (
    adam_optimizer_with_linear_decay
)
# 从 self_rewarding_lm_pytorch.dpo 模块中导入 adam_optimizer_with_linear_decay 函数

from self_rewarding_lm_pytorch.sampling_utils import (
    sample,
    top_p,
    top_k
)
# 从 self_rewarding_lm_pytorch.sampling_utils 模块中导入 sample、top_p 和 top_k 函数

from tqdm import tqdm
# 导入 tqdm 模块,用于显示进度条

from ema_pytorch import EMA
# 导入 EMA 类,用于指数移动平均

# helper functions

def exists(v):
    """True when `v` is not None."""
    return not (v is None)

def cycle(dl):
    """Loop over the dataloader `dl` forever, yielding its batches in order."""
    while True:
        yield from dl

def log_prob_from_model_and_seq(model, seq):
    """Per-position log-probability the model assigns to the tokens in `seq`."""
    token_logits = model(seq)
    token_log_probs = token_logits.log_softmax(dim = -1)
    # gather, for every position, the log-prob of the observed token id
    return get_at('b n [c], b n -> b n', token_log_probs, seq)

def prompt_mask_from_len(lengths, seq):
    """Boolean mask marking the first `lengths` (prompt) positions of `seq`."""
    total_len = seq.shape[-1]
    positions = torch.arange(total_len, device = seq.device)
    return positions < rearrange(lengths, '... -> ... 1')

def set_dropout_(model: Module, prob: float):
    """In-place: set the drop probability of every Dropout submodule of `model`."""
    dropouts = (m for m in model.modules() if isinstance(m, Dropout))
    for dropout in dropouts:
        dropout.p = prob

# main class

class SPIN(Module):
    """Self-Play Fine-tuning (SPIN) loss module - https://arxiv.org/abs/2401.01335

    Wraps a policy language model together with an EMA copy of it that serves
    as the (frozen) reference model in the SPIN objective.

    NOTE(review): the `forward` signature in this copy was corrupted by
    extraction (missing closing `):` and the body de-indented to class level) -
    reconstructed here.
    """

    def __init__(
        self,
        model: Module,
        *,
        λ = 0.1,
        pad_id: Optional[int] = None,
        ref_model_ema_decay = 1.,
        ema_kwargs: dict = dict()
    ):
        super().__init__()
        self.policy_model = model

        # reference model tracks the policy via an exponential moving average
        # (a decay of 1. effectively freezes it until explicitly re-synced)
        self.ref_model = EMA(
            model,
            beta = ref_model_ema_decay,
            **ema_kwargs
        )

        self.λ = λ            # inverse temperature on the logit difference (eq. 4.7)
        self.pad_id = pad_id  # token id used for padding, if sequences are padded

    def update_reference_model_with_policy(self):
        # hard-copy the current policy weights into the reference model
        self.ref_model.copy_params_from_model_to_ema()

    def update_ema(self):
        # advance the reference model's exponential moving average
        self.ref_model.update()

    def parameters(self):
        # only the policy model receives gradient updates
        return self.policy_model.parameters()

    @property
    def device(self):
        return next(self.parameters()).device

    @autocast(enabled = False)
    def forward(
        self,
        generated_seq: TensorType['b', 'n', int],
        real_seq: TensorType['b', 'n', int],
        prompt_len: TensorType['b', int],
        generated_seq_mask: Optional[TensorType['b', 'n', bool]] = None,
        real_seq_mask: Optional[TensorType['b', 'n', bool]] = None
    ):
        """Return the scalar SPIN loss (equation 4.7 of the paper).

        b - batch
        n - sequence length
        """
        self.policy_model.train()

        # masks marking the prompt positions - only response tokens are scored

        real_prompt_mask = prompt_mask_from_len(prompt_len, real_seq)
        generated_prompt_mask = prompt_mask_from_len(prompt_len, generated_seq)

        # derive padding masks from the pad id, when configured

        if exists(self.pad_id):
            # caller must not pass explicit masks when pad_id handles them
            assert not exists(generated_seq_mask)
            assert not exists(real_seq_mask)

            generated_seq_mask = generated_seq != self.pad_id
            generated_seq.masked_fill_(~generated_seq_mask, 0)

            real_seq_mask = real_seq != self.pad_id
            real_seq.masked_fill_(~real_seq_mask, 0)

        # reference model log-probs, computed without gradients

        with torch.no_grad():
            self.ref_model.eval()
            ref_generated_logprob = log_prob_from_model_and_seq(self.ref_model, generated_seq)
            ref_real_logprob = log_prob_from_model_and_seq(self.ref_model, real_seq)

        # policy model log-probs

        policy_generated_logprob = log_prob_from_model_and_seq(self.policy_model, generated_seq)
        policy_real_logprob = log_prob_from_model_and_seq(self.policy_model, real_seq)

        # masked mean over the (variable length) response tokens

        policy_generated_logprob, ref_generated_logprob = [masked_mean(seq, maybe_and_mask(generated_seq_mask, ~generated_prompt_mask)) for seq in (policy_generated_logprob, ref_generated_logprob)]
        policy_real_logprob, ref_real_logprob = [masked_mean(seq, maybe_and_mask(real_seq_mask, ~real_prompt_mask)) for seq in (policy_real_logprob, ref_real_logprob)]

        # SPIN loss (eq. 4.7): reward real continuations over self-generated ones

        losses = -F.logsigmoid(self.λ * ((policy_real_logprob - ref_real_logprob) - (policy_generated_logprob - ref_generated_logprob)))

        return losses.mean()
class SPINTrainer(Module):
    """Trainer running the SPIN self-play procedure (Algorithm 1 of
    https://arxiv.org/abs/2401.01335v1) over a supervised fine-tuning dataset.
    """

    def __init__(
        self,
        model: Union[Module, SPIN],
        *,
        train_sft_dataset: Dataset,
        max_seq_len: int,
        valid_sft_dataset: Optional[Dataset] = None,
        valid_every = 100,
        accelerator: Optional[Accelerator] = None,
        accelerate_kwargs: dict = dict(),
        batch_size = 16,
        grad_accum_steps = 2,
        epochs = 2,
        start_learning_rate = 1e-6,
        end_learning_rate = 1e-7,
        learning_rate_num_decay_steps = 1000,
        dropout = 0.,
        weight_decay = 0.,
        adam_kwargs: dict = dict(),
        temperature = 0.7,
        filter_fn = top_p,
        filter_kwargs = dict(thres = 0.9),
        pad_id: int = -1,
        ref_model_ema_decay = 1.,
        checkpoint_every = None,
        checkpoint_folder = './spin-checkpoints',
        spin_kwargs: dict = dict(
            λ = 0.1,
        )
    ):
        super().__init__()

        # use the supplied accelerator, or construct one

        self.accelerator = accelerator
        if not exists(self.accelerator):
            self.accelerator = Accelerator(**accelerate_kwargs)

        # wrap a bare language model into the SPIN loss module

        if not isinstance(model, SPIN):
            model = SPIN(
                model,
                pad_id = pad_id,
                ref_model_ema_decay = ref_model_ema_decay,
                **spin_kwargs
            )

        self.model = model
        self.dropout = dropout
        self.train_dataloader = DataLoader(train_sft_dataset, batch_size = batch_size, shuffle = True, drop_last = True)

        # one optimizer step is taken per `grad_accum_steps` batches

        self.grad_accum_steps = grad_accum_steps
        self.num_train_steps = len(self.train_dataloader) // self.grad_accum_steps * epochs

        self.optimizer = adam_optimizer_with_linear_decay(
            model,
            start_learning_rate,
            end_learning_rate,
            num_decay_steps = learning_rate_num_decay_steps,
            accelerator = self.accelerator,
            weight_decay = weight_decay,
            adam_kwargs = adam_kwargs
        )

        # let accelerate wrap the model and the dataloader

        (
            self.model,
            self.train_dataloader
        ) = self.accelerator.prepare(
            self.model,
            self.train_dataloader
        )

        self.max_seq_len = max_seq_len
        self.pad_id = pad_id

        # sampling

        self.temperature = temperature
        self.filter_fn = filter_fn
        self.filter_kwargs = filter_kwargs

        # validation

        self.valid_dataloader = None
        self.valid_every = valid_every

        if exists(valid_sft_dataset):
            self.valid_dataloader = DataLoader(valid_sft_dataset, batch_size = batch_size)

        # checkpointing

        self.should_checkpoint = exists(checkpoint_every)
        self.checkpoint_every = checkpoint_every

        if self.should_checkpoint:
            self.checkpoint_folder = Path(checkpoint_folder)
            self.checkpoint_folder.mkdir(exist_ok = True, parents = True)

        self.steps = 0

    @property
    def is_main(self):
        # whether this process is the main accelerate process
        return self.accelerator.is_main_process

    @property
    def unwrapped_model(self):
        # the model without accelerate's wrappers
        return self.accelerator.unwrap_model(self.model)

    def print(self, *msg):
        # print only on the main process
        self.accelerator.print(*msg)

    def log(self, **data):
        # log metrics tagged with the current step
        self.accelerator.log(data, step = self.steps)

    def wait(self):
        # synchronize all processes
        return self.accelerator.wait_for_everyone()

    def save(self, path: str, overwrite: bool = False):
        """Save the unwrapped model's state dict - main process only."""
        self.wait()

        if self.is_main:

            path = self.checkpoint_folder / path

            assert not path.exists() or overwrite, f'file already exists'

            pkg = dict(
                model = self.unwrapped_model.state_dict()
            )

            torch.save(pkg, str(path))

    def calc_spin_loss(
        self,
        real_seq: TensorType['b', 'n', int],
        prompt_len: TensorType['b', int]
    ):
        """Sample self-generated continuations for the prompts inside
        `real_seq` and return the SPIN loss against the real continuations."""

        # recover the raw (variable length) prompts from the left of each sequence
        prompt_mask = prompt_mask_from_len(prompt_len, real_seq)
        prompts = real_seq[prompt_mask].split(prompt_len.tolist())

        # self-play: let the current policy generate its own continuations
        generated_seqs = sample(
            self.unwrapped_model.policy_model,
            prompts = prompts,
            seq_len = self.max_seq_len,
            temperature = self.temperature,
            filter_fn = self.filter_fn,
            filter_kwargs = self.filter_kwargs,
            output_keep_prompt = True
        )

        spin_loss = self.model(
            real_seq = real_seq,
            generated_seq = generated_seqs,
            prompt_len = prompt_len
        )

        return spin_loss

    def forward(self, overwrite_checkpoints: bool = True):
        """
        Algorithm 1 - https://arxiv.org/abs/2401.01335v1
        """

        # the reference model starts as an exact copy of the policy
        self.model.update_reference_model_with_policy()

        self.steps = 0

        set_dropout_(self.model, self.dropout)

        train_dataloader_iter = cycle(self.train_dataloader)

        for _ in tqdm(range(self.num_train_steps), desc = 'spin fine-tuning'):

            self.model.train()

            # accumulate gradients over `grad_accum_steps` forward passes
            for forward_context in model_forward_contexts(self.accelerator, self.model, self.grad_accum_steps):
                with forward_context():
                    real_seq, prompt_len = next(train_dataloader_iter)

                    train_loss = self.calc_spin_loss(real_seq, prompt_len)

                    # normalize the loss for gradient accumulation
                    self.accelerator.backward(train_loss / self.grad_accum_steps)

            self.print(f'train spin loss: {train_loss.item():.3f}')
            self.log(loss = train_loss.item())

            self.optimizer.step()
            self.optimizer.zero_grad()

            self.steps += 1

            self.wait()

            self.unwrapped_model.update_ema()

            # periodic validation
            # fix: the interval test was inverted (`valid_every % steps`), which
            # fires on divisors of `valid_every` instead of every `valid_every` steps
            if exists(self.valid_dataloader) and not (self.steps % self.valid_every):
                self.wait()

                if self.is_main:
                    total_loss = 0.
                    total_batches = 0.

                    with torch.no_grad():
                        self.model.eval()

                        for valid_seq, prompt_len in tqdm(self.valid_dataloader, desc = 'valid spin'):
                            batch = valid_seq.shape[0]
                            valid_spin_loss = self.calc_spin_loss(valid_seq, prompt_len)

                            # weight by batch size for a correct dataset-level mean
                            total_batches += batch
                            total_loss += valid_spin_loss * batch

                        valid_loss = total_loss / total_batches

                        self.print(f'valid spin loss: {valid_loss.item():.3f}')
                        self.log(valid_spin_loss = valid_loss.item())

            # periodic checkpointing (same inverted-modulo fix as above)
            if self.should_checkpoint and not (self.steps % self.checkpoint_every):
                checkpoint_num = self.steps // self.checkpoint_every
                self.save(f'spin.ckpt.{checkpoint_num}.pt', overwrite = overwrite_checkpoints)

        self.print(f'self-play training complete')

.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\__init__.py

# 导入自我奖励语言模型训练器和奖励配置
from self_rewarding_lm_pytorch.self_rewarding_lm_pytorch import (
    SelfRewardingTrainer,
    RewardConfig
)

# 导入 SPIN 模型和 SPIN 训练器
from self_rewarding_lm_pytorch.spin import (
    SPIN,
    SPINTrainer,
)

# 导入 DPO 模型和 DPO 训练器
from self_rewarding_lm_pytorch.dpo import (
    DPO,
    DPOTrainer,
)

# 导入创建模拟数据集的函数
from self_rewarding_lm_pytorch.mocks import create_mock_dataset

# 导入自我奖励语言模型微调配置
from self_rewarding_lm_pytorch.self_rewarding_lm_pytorch import (
    SFTConfig,
    SelfRewardDPOConfig,
    ExternalRewardDPOConfig,
    SelfPlayConfig,
    create_default_paper_config
)

.\lucidrains\self-rewarding-lm-pytorch\setup.py

# import the setup and find_packages utilities
from setuptools import setup, find_packages

# package metadata for self-rewarding-lm-pytorch
setup(
  name = 'self-rewarding-lm-pytorch',  # package name
  packages = find_packages(exclude=[]),  # include all packages
  version = '0.2.8',  # version
  license='MIT',  # license
  description = 'Self Rewarding LM - Pytorch',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/self-rewarding-lm-pytorch',  # repository url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'self rewarding',
    'direct preference optimization'
  ],
  install_requires=[  # runtime dependencies
    'accelerate',
    'beartype',
    'einops>=0.7.0',
    'einx[torch]>=0.1.3',
    'ema-pytorch>=0.3.3',
    'Jinja2',
    'numpy',
    'pytorch-custom-utils>=0.0.17',
    'torch>=2.0',
    'torchtyping',
    'tqdm'
  ],
  classifiers=[  # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Simple Hierarchical Transformer

Experiments around a simple idea for inducing multiple hierarchical predictive coding models within a GPT. It is so simple, it may not work. But then again, deep learning progress is built on the bedrock of simple ideas. Worth a shot.

So far, the idea has passed the litmus test from a research friend. Will bring it to completion in the next week or so. If it does not work out, I'll leave the negative experimental results as well as the repository around, and maybe some PhD student can build upon it.

Update: I think it is working 🤞

Appreciation

Install

$ pip install simple-hierarchical-transformer

Usage

Three hierarchies, all servicing predicting the next token

import torch
from simple_hierarchical_transformer import HierarchicalTransformer

model = HierarchicalTransformer(
    num_tokens = 20000,                # number of tokens
    dim = 512,                         # model dimensions
    depth = 6,                         # depth
    dim_head = 64,                     # dimension per attention head
    heads = 8,                         # attention heads
    seq_len = 2048,                    # sequence lengths
    hierarchies = (1, 2, 8),           # hierarchies - here we have 1x (like in a regular transformer), then 2x and 8x compressed hierarchical tokens that undergo their own transformer blocks. information is pooled into one hierarchy at each layer
    window_sizes = (32, 64, None)      # local attention window sizes - the idea is that the higher hierarchies can pass distant information to the local one. None stands for full receptive field. Setting 0 would turn off attention at this hierarchy altogether (while token shift will still be in effect in each layer)
)

ids = torch.randint(0, 20000, (1, 2048))

loss, _ = model(ids, return_loss = True)
loss.backward()

# after much training

logits = model(ids)

By not specifying hierarchies and window_sizes, you basically default to a regular autoregressive transformer with attention across full sequence length.


# non-hierarchical transformer

model = HierarchicalTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 8,
    dim_head = 64,
    heads = 8,
    seq_len = 2048,
    hierarchies = 1,        # implied 1 if not set
    window_sizes = None     # implied None (full sequence length) if not set
)

Now something more complex. Experiments show that as you compress up the hierarchies, you need greater model dimensions for appropriate capacity.

model = HierarchicalTransformer(
    num_tokens = 256,
    dim = (128, 256, 512, 1024),
    depth = 8,
    seq_len = 1024,
    use_flash_attn = True,
    ff_mult = (2, 2, 4, 4),
    dim_head = (16, 32, 64, 64),
    heads = (2, 4, 8, 8),
    hierarchies = (1, 2, 4, 16),
    hierarchical_stride = (1, 1, 1, 8),  # this would determine the stride when compressing, and when concatting the hierarchical tokens to the fine tokens, the past tokens will be repeated this amount of time. causality is not violated as using the trick from hourglass transformers where sequence is shifted by compression factor - 1. recommend sticking with 1 except for highly compressed hierarchies, as it becomes very uncompetitive with baseline and generations look off
    window_sizes = (16, 32, 64, None)
).cuda()

# hierarchies
# 1x - dim 128 - attention (2 heads, 16 dim, receptive field 16)
# 2x - dim 256 - attention (4 heads, 32 dim, receptive field 32)
# 4x - dim 512 - attention (8 heads, 64 dim, receptive field 64)
# 16x - dim 1024 - attention (8 heads, 64 dim, receptive field of all)

Todo

Citations

Closest idea would be hourglass transformers.

And my renewed interest in hierarchical approaches came from reading this.

@article{Nawrot2021HierarchicalTA,
    title   = {Hierarchical Transformers Are More Efficient Language Models},
    author  = {Piotr Nawrot and Szymon Tworkowski and Michal Tyrolski and Lukasz Kaiser and Yuhuai Wu and Christian Szegedy and Henryk Michalewski},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.13711}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{Sun2022ALT,
    title     = {A Length-Extrapolatable Transformer},
    author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
    year      = {2022}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
@article{Piergiovanni2023Mirasol3BAM,
    title   = {Mirasol3B: A Multimodal Autoregressive model for time-aligned and contextual modalities},
    author  = {A. J. Piergiovanni and Isaac Noble and Dahun Kim and Michael S. Ryoo and Victor Gomes and Anelia Angelova},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2311.05698},
    url     = {https://api.semanticscholar.org/CorpusID:265129010}
}

.\lucidrains\simple-hierarchical-transformer\setup.py

# import the setup and find_packages utilities
from setuptools import setup, find_packages

# package metadata for simple-hierarchical-transformer
setup(
  name = 'simple-hierarchical-transformer',  # package name
  packages = find_packages(exclude=[]),  # include all packages
  version = '0.2.0',  # version
  license='MIT',  # license
  description = 'Simple Hierarchical Transformer',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/simple-hierarchical-transformer',  # repository url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'hierarchical'
  ],
  install_requires=[  # runtime dependencies
    'accelerate',
    'einops>=0.7.0',
    'local-attention',
    'torch>=2.0'
  ],
  classifiers=[  # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\simple-hierarchical-transformer\simple_hierarchical_transformer\attention.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn, einsum 模块
from torch import nn, einsum
# 从 torch.nn 模块中导入 Module 类
from torch.nn import Module
# 从 torch.nn.functional 模块中导入 F 别名
import torch.nn.functional as F

# 导入 namedtuple 类
from collections import namedtuple
# 导入 wraps 函数
from functools import wraps
# 从 packaging 模块中导入 version 类
from packaging import version
# 从 einops 库中导入 rearrange 函数

from einops import rearrange

# 定义常量 Config,使用 namedtuple 创建一个命名元组
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# 定义辅助函数

# predicate: value is present
def exists(val):
    """True unless `val` is None."""
    return not (val is None)

# decorator: the wrapped single-argument function runs on the first call only;
# subsequent calls are silently ignored (returning None)
def once(fn):
    has_run = False

    @wraps(fn)
    def wrapper(x):
        nonlocal has_run
        if not has_run:
            has_run = True
            return fn(x)
    return wrapper

# a print that only emits on its first call (used below for one-time backend notices)
print_once = once(print)

# 主类

# 定义 Attend 类,继承自 Module 类
class Attend(Module):
    """Attention core: either plain math attention or pytorch 2.0
    scaled-dot-product ('flash') attention, with backend selection per device.
    """

    def __init__(
        self,
        causal = False,
        use_flash_attn = False
    ):
        super().__init__()
        # whether to apply a causal (autoregressive) mask
        self.causal = causal
        # lazily-built causal mask cache (non-persistent, excluded from state dict)
        self.register_buffer("mask", None, persistent=False)

        # whether to route through F.scaled_dot_product_attention
        self.use_flash_attn = use_flash_attn
        # flash attention requires the pytorch 2.0 sdp kernels
        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention backend configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        # nothing more to configure without cuda or without flash attention
        if not torch.cuda.is_available() or not use_flash_attn:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        # compute capability 8.0 == A100: prefer the flash kernel
        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            # other GPUs: fall back to math / memory-efficient kernels
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def get_mask(self, n, device):
        """Return an (n, n) strict upper-triangular causal mask, cached as a buffer."""
        # reuse the cached mask when it is already big enough
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        # True marks positions to be masked out (future tokens)
        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask = None):
        """Run attention through F.scaled_dot_product_attention."""
        # q assumed 4-d (b, h, q_len, d); also grab key length and cuda flag
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # expand a (b, j) key-padding mask to the (b, h, i, j) shape sdp expects
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # pick the backend config matching the tensor's device
        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attention under the selected sdp backends
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                is_causal = self.causal
            )

        return out

    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device = q.shape[-2], q.device

        # standard 1/sqrt(d) attention scaling
        scale = q.shape[-1] ** -0.5

        if self.use_flash_attn:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out

.\lucidrains\simple-hierarchical-transformer\simple_hierarchical_transformer\simple_hierarchical_transformer.py

# stdlib imports (restored: `log2`, `ceil`, `partial`, `zip_longest` are used in
# this module and named by the comments below, but their import statements were
# missing from this copy)
from math import log2, ceil
from functools import partial
from itertools import zip_longest

# 导入 torch 库
import torch
# 从 torch.nn.functional 模块中导入 F
import torch.nn.functional as F
# 从 torch.cuda.amp 模块中导入 autocast 函数
from torch.cuda.amp import autocast
# 从 torch 模块中导入 nn, einsum, Tensor
from torch import nn, einsum, Tensor
# 从 torch.nn 模块中导入 Module, ModuleList
from torch.nn import Module, ModuleList
# 从 einops 模块中导入 rearrange, repeat
from einops import rearrange, repeat
# 从 einops.layers.torch 模块中导入 Rearrange
from einops.layers.torch import Rearrange
# 从 simple_hierarchical_transformer.attention 模块中导入 Attend
from simple_hierarchical_transformer.attention import Attend
# 从 typing 模块中导入 Tuple
from typing import Tuple
# NOTE(review): a `from local_attention import LocalMHA` import is documented by
# the original comments but its statement is missing from this copy - confirm
# against upstream before relying on `LocalMHA` later in the file
# bias-free linear layer (partial application of nn.Linear)
Linear = partial(nn.Linear, bias = False)
# NOTE(review): the original comment here described a causal/prenorm LocalMHA
# partial being defined at this point, but the statement itself is missing from
# this copy - confirm against upstream

# helper predicate: a value "exists" when it is not None
def exists(val):
    """Return True when `val` is not None."""
    return val is not None

# helper: test whether a compression/stride factor is a positive power of two
def is_power_of_two(n):
    """Return True when `n` is a positive integral power of two.

    Uses the `m & (m - 1) == 0` bit trick instead of `log2(n).is_integer()`:
    exact for arbitrarily large ints (no float rounding) and returns False for
    non-positive input instead of raising ValueError. Fractional inputs (e.g.
    0.5) now return False, which matches this module's use for integer factors.
    """
    if n <= 0:
        return False
    m = int(n)
    # must be integral and have a single bit set
    return m == n and (m & (m - 1)) == 0

# helper: check a collection contains no duplicates (elements must be hashable)
def all_unique(arr):
    """Return True when every element of `arr` is distinct."""
    # fix: original was missing the closing parenthesis on `len(arr` (SyntaxError)
    return len(set(arr)) == len(arr)

# helper: pair up functions with tensors and apply each function to its tensor
def apply_fns(fns, tensors):
    """Return [fns[0](tensors[0]), fns[1](tensors[1]), ...] as a list."""
    results = []
    for fn, tensor in zip(fns, tensors):
        results.append(fn(tensor))
    return results

# helper: promote a scalar to a tuple of the given length (tuples pass through)
def cast_tuple(t, length = 1):
    """Return `t` unchanged if already a tuple, else repeat it `length` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * length

# helper: first non-None value among the arguments
def default(*vals):
    """Return the first value in `vals` that is not None, or None if all are."""
    return next((v for v in vals if v is not None), None)

# decorator: run the wrapped function with the model in eval mode, then
# restore whichever train/eval mode the model was in before the call
def eval_decorator(fn):
    def wrapper(model, *args, **kwargs):
        prev_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        model.train(prev_mode)
        return result
    return wrapper

# tensor helper: unit-normalize along the feature (last) dimension
def l2norm(t):
    """L2-normalize `t` along its last dimension."""
    return F.normalize(t, p = 2, dim = -1)

# loss helper: one minus the mean cosine similarity between x and y
def cosine_sim_loss(x, y):
    """Return 1 - mean cosine similarity of `x` and `y` over the last dim."""
    x, y = l2norm(x), l2norm(y)
    similarity = einsum('b n d, b n d -> b n', x, y)
    return 1. - similarity.mean()

# sampling helper: numerically safe log, clamping the input below by eps
def log(t, eps = 1e-20):
    """Return log(max(t, eps)) elementwise."""
    return torch.log(t.clamp(min = eps))

# sampling helper: draw standard Gumbel noise with the same shape as `t`
def gumbel_noise(t):
    uniform = torch.zeros_like(t).uniform_(0, 1)
    # Gumbel(0, 1) via the double-log transform of a uniform sample
    return -log(-log(uniform))

# sampling helper: Gumbel-max sampling from logits at the given temperature
def gumbel_sample(t, temperature = 1., dim = -1):
    scaled = t / max(temperature, 1e-10)
    return (scaled + gumbel_noise(t)).argmax(dim = dim)

# sampling helper: keep only the top-k logits, masking the rest to -inf-like
def top_k(logits, thres = 0.9):
    """Mask `logits` down to their top k entries along the last dimension.

    k = int((1 - thres) * vocab_size); all other positions are filled with the
    dtype's most-negative value so softmax assigns them ~zero probability.
    Fix: scatter along dim -1 (was hard-coded to dim 1, which only worked for
    2-D logits) - now supports batched/3-D logits as well.
    """
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
    probs.scatter_(-1, ind, val)
    return probs

# rotary positional embedding, with optional xpos scaling for length extrapolation
class RotaryEmbedding(Module):
    def __init__(
        self,
        dim,
        scale_base = 512,
        use_xpos = True
    ):
        super().__init__()
        # standard RoPE inverse frequencies over even channel indices
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # xpos per-frequency decay scales (length-extrapolatable attention)
        self.use_xpos = use_xpos
        self.scale_base = scale_base
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale)

    # device of the registered buffers
    @property
    def device(self):
        return next(self.buffers()).device

    # compute (freqs, scale) for seq_len positions; autocast disabled so the
    # trigonometric inputs stay in full precision
    @autocast(enabled = False)
    def forward(self, seq_len):
        device = self.device
        t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        # duplicate so the rotation covers both halves of the feature dim
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not self.use_xpos:
            # without xpos, the scale is a constant 1
            return freqs, torch.ones(1, device = device)

        # xpos scale decays symmetrically around the sequence midpoint
        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale

# rotate halves of the feature dim: (x1, x2) -> (-x2, x1)
def rotate_half(x):
    """Split the last dim in two and rotate: returns cat(-second, first)."""
    first, second = x.chunk(2, dim = -1)
    return torch.cat((-second, first), dim = -1)

# apply rotary position embedding `pos` to `t`, with optional xpos scale
def apply_rotary_pos_emb(pos, t, scale = 1.):
    n = t.shape[-2]

    # align positions (and per-position scales) with the trailing n entries of t
    pos = pos[..., -n:, :]
    if not isinstance(scale, (int, float)):
        scale = scale[..., -n:, :]

    rotated = rotate_half(t)
    return (t * pos.cos() * scale) + (rotated * pos.sin() * scale)

# rotate queries and keys with the (freqs, scale) pair from RotaryEmbedding;
# keys use the inverse scale, per the xpos scheme
@autocast(enabled = False)
def apply_rotary_pos_emb_qk(rotary_emb, q, k):
    freqs, scale = rotary_emb
    rotated_q = apply_rotary_pos_emb(freqs, q, scale)
    rotated_k = apply_rotary_pos_emb(freqs, k, scale ** -1)
    return rotated_q, rotated_k

# half-channel token shift: the second half of the channels sees the previous position
def token_shift(t):
    unshifted, to_shift = t.chunk(2, dim = -1)
    # shift forward by one along the sequence dim (pad front, drop last)
    shifted = F.pad(to_shift, (0, 0, 1, -1))
    return torch.cat((unshifted, shifted), dim = -1)
# hierarchy related classes

# right-pad the sequence dim (dim -2) with zeros up to a multiple of `mult`;
# returns (possibly padded tensor, original sequence length)
def pad_seq_to_multiple(t, mult):
    orig_len = t.shape[-2]
    target_len = ceil(orig_len / mult) * mult
    padding = target_len - orig_len

    if padding == 0:
        # already a multiple - nothing to do
        return t, orig_len

    padded = F.pad(t, (0, 0, 0, padding), value = 0.)
    return padded, orig_len

# truncate the sequence dim (dim -2) down to the nearest multiple of `mult`
def curtail_seq_to_multiple(t, mult):
    orig_len = t.shape[-2]
    kept_len = (orig_len // mult) * mult

    if kept_len == orig_len:
        # already a multiple - return unchanged
        return t

    return t[..., :kept_len, :]

# concatenate per-hierarchy token streams along the feature dim; compressed
# streams (stride > 1) are first repeated along the sequence dim, then every
# stream is cut to the shortest resulting sequence length
def hierarchical_cat(tokens, strides: Tuple[int, ...]):
    assert len(tokens) == len(strides)

    # fast path: no compressed hierarchies, plain feature concat
    if all(s == 1 for s in strides):
        return torch.cat(tokens, dim = -1)

    expanded = [repeat(t, 'b n d -> b (n s) d', s = s) for t, s in zip(tokens, strides)]
    shortest = min(t.shape[-2] for t in expanded)
    trimmed = [t[..., :shortest, :] for t in expanded]
    return torch.cat(trimmed, dim = -1)

# 1d convolution made causal by left-padding with kernel_size - 1 zeros,
# so each output position only sees current and past inputs
class CausalConv(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        stride = 1
    ):
        super().__init__()
        # amount of zero padding on the past (left) side of the temporal dim
        self.causal_padding = kernel_size - 1
        self.conv = nn.Conv1d(dim_in, dim_out, kernel_size, stride = stride)

    def forward(self, x):
        # pad only on the left, then convolve
        return self.conv(F.pad(x, (self.causal_padding, 0)))

# compresses a token stream by a power-of-two factor using a strided causal
# conv stack; can optionally learn to reconstruct the original token ids
# from the compressed representation (auxiliary `recon` loss)
class Compress(Module):
    def __init__(
        self,
        *,
        dim,
        dim_out,
        num_tokens = None,      # vocab size; required only when should_recon
        stride = 1,
        compress_factor = 1,    # 1 means no compression
        expansion_factor = 4,   # hidden width multiplier of the conv stack
        dim_head = 64,          # NOTE(review): unused here — possibly leftover from an attention variant, confirm
        heads = 8,              # NOTE(review): unused here — possibly leftover from an attention variant, confirm
        ignore_index = 0,       # label id ignored by the recon cross entropy
        should_recon = False
    ):
        super().__init__()
        # compression factor must be a positive power of two
        assert compress_factor > 0 and is_power_of_two(compress_factor)

        self.stride = stride
        self.no_compress = compress_factor == 1
        self.compress_factor = compress_factor

        self.should_recon = should_recon

        # no compression: plain projection, or identity when dims already match
        if self.no_compress:
            self.compress_fn = Linear(dim, dim_out) if dim != dim_out else nn.Identity()
            return

        dim_inner = int(dim * expansion_factor)

        # compression path: causal conv (kernel = compress factor, strided)
        # -> SiLU -> pointwise conv back to dim_out
        self.compress_fn = nn.Sequential(
            Rearrange('b n d -> b d n'),
            CausalConv(dim, dim_inner, compress_factor, stride = stride),
            nn.SiLU(),
            nn.Conv1d(dim_inner, dim_out, 1),
            Rearrange('b d n -> b n d')
        )

        # head predicting the `compress_factor` covered token ids per slot
        if should_recon:
            assert exists(num_tokens)
            self.to_recon = Linear(dim_out, compress_factor * num_tokens)

        # NOTE(review): only set on the compressed path; recon() early-returns
        # before reading it when no_compress, so this appears safe — confirm
        self.ignore_index = ignore_index

    # auxiliary reconstruction loss: from each compressed position, predict
    # the original token ids it covers
    def recon(self, h, ids):
        assert self.should_recon

        # nothing to reconstruct without compression; zero differentiable loss
        if self.no_compress:
            return torch.zeros((), device = h.device).requires_grad_()

        c = self.compress_factor
        seq_len = ids.shape[-1]

        recon_logits = self.to_recon(h)
        recon_logits = rearrange(recon_logits, 'b n (c d) -> (b c) d n', c = c)

        # build c shifted views of the ids (front-padded with ignore_index)
        # so each compressed slot is paired with the c tokens it covers
        recon_ids = F.pad(ids, (c - 1, 0), value = self.ignore_index)
        recon_ids = tuple(recon_ids[:, i:(seq_len + i)] for i in range(c))
        recon_ids = torch.stack(recon_ids, dim = 1)
        recon_ids = rearrange(recon_ids, 'b c n -> (b c) n')

        # subsample the targets to match the conv's temporal stride
        if self.stride > 1:
            recon_ids = recon_ids[..., ::self.stride]

        recon_loss = F.cross_entropy(recon_logits, recon_ids, ignore_index = self.ignore_index)
        return recon_loss

    def forward(self, x):
        return self.compress_fn(x)

# fuses the per-hierarchy token streams into one representation:
# align/concatenate along features, then project through a small MLP
class HierarchicalMerge(Module):
    def __init__(
        self,
        dims: Tuple[int, ...],
        dim_out,
        h_strides = 1
    ):
        super().__init__()
        total_dim = sum(dims)

        # one stride per hierarchy (broadcast a scalar if given)
        strides = cast_tuple(h_strides, len(dims))
        assert len(strides) == len(dims)

        self.strides = strides

        # norm -> widen -> SiLU -> project down to the merged dimension
        self.net = nn.Sequential(
            RMSNorm(total_dim),
            nn.Linear(total_dim, dim_out * 2),
            nn.SiLU(),
            nn.Linear(dim_out * 2, dim_out)
        )

    def forward(self, tokens):
        # upsample/trim each hierarchy so lengths match, concat, then project
        merged = hierarchical_cat(tokens, self.strides)
        return self.net(merged)
# root-mean-square layer norm: unit-normalize along the feature dim,
# then rescale by sqrt(dim) and a learned per-channel gain
class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        # sqrt(dim) restores the magnitude lost by unit normalization
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim=-1)
        return normed * self.scale * self.gamma

class FeedForward(Module):
    """Pre-norm feedforward block: RMSNorm -> expand by `mult` -> GELU -> project back."""

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = int(dim * mult)

        self.net = nn.Sequential(
            RMSNorm(dim),
            Linear(dim, hidden),
            nn.GELU(),
            Linear(hidden, dim)
        )

    def forward(self, x):
        return self.net(x)

# causal multi-head self-attention with rotary position embeddings
class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8,
        use_flash_attn=False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        # pre-norm and rotary positional embedding over the head dimension
        self.norm = RMSNorm(dim)
        self.rotary_emb = RotaryEmbedding(dim_head)

        # delegated attention kernel (optionally flash attention)
        self.attend = Attend(causal=True, use_flash_attn=use_flash_attn)

        # fused qkv projection, and the output projection
        self.to_qkv = Linear(dim, inner_dim * 3)
        self.to_out = Linear(inner_dim, dim)

    def forward(self, x):
        seq_len = x.shape[-2]
        x = self.norm(x)

        # project to q, k, v and split out the heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=self.heads) for t in qkv)

        # rotate queries and keys by position
        pos_emb = self.rotary_emb(seq_len)
        q, k = apply_rotary_pos_emb_qk(pos_emb, q, k)

        attended = self.attend(q, k, v)

        # merge heads and project out
        attended = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(attended)

# one transformer block at a given hierarchy: optional (possibly local)
# attention plus feedforward, each fed token-shifted inputs
class HierarchicalBlock(Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8,
        window_size=None,    # None = full attention, 0 = no attention, >0 = local window
        compress_factor=1,   # this hierarchy's compression factor (power of two)
        stride=1,            # stride used by the hierarchy's compressor
        ff_mult=4
    ):
        super().__init__()
        self.stride = stride

        assert is_power_of_two(compress_factor)
        self.compress_factor = compress_factor
        self.no_compress = compress_factor == 1

        assert not exists(window_size) or window_size >= 0
        # window_size == 0 disables attention entirely for this block
        self.has_attn = window_size != 0

        self.attn = None

        if self.has_attn:
            attn_klass = Attention
            # a concrete window size selects windowed local attention instead
            if exists(window_size):
                attn_klass = partial(LocalMHA, window_size=window_size)

            self.attn = attn_klass(dim=dim, dim_head=dim_head, heads=heads)

        self.ff = FeedForward(dim=dim, mult=ff_mult)

    def forward(self, x):
        c = self.compress_factor
        axial_dim = c // self.stride

        # pad so the sequence length divides evenly into axial groups
        x, orig_seq_len = pad_seq_to_multiple(x, axial_dim)

        # when this hierarchy is compressed, fold the interleaved (strided)
        # subsequences into the batch so the block runs along each strand
        if not self.no_compress:
            x = rearrange(x, 'b (n c) d -> (b c) n d', c=axial_dim)

        # attention residual branch (inputs token-shifted)
        if exists(self.attn):
            x = self.attn(token_shift(x)) + x

        # feedforward residual branch (inputs token-shifted)
        x = self.ff(token_shift(x)) + x

        # undo the axial fold back to the full interleaved sequence
        if not self.no_compress:
            x = rearrange(x, '(b c) n d -> b (n c) d', c=axial_dim)

        # drop any padding added above
        return x[:, :orig_seq_len]

# 定义 HierarchicalTransformer 类
class HierarchicalTransformer(Module):
    # 初始化函数,设置模型参数
    def __init__(
        self,
        *,
        num_tokens,  # 标记数量
        dim,  # 向量维度
        depth,  # 深度
        seq_len = 2048,  # 序列长度,默认为2048
        dim_head = 64,  # 头部维度
        heads = 8,  # 头部数量
        ff_mult = 4,  # FeedForward 层的倍数
        hierarchies = 1,  # 分层数量
        window_sizes = None,  # 窗口大小
        hierarchical_stride = 1,  # 分层步长
        hierarchy_merge_all = False,  # 是否将汇总的分层信息传递回所有分层或只传递给一个进行预测
        predict_hierarchy = None,  # 预测分层
        predict_use_all_hierarchy = False,  # 是否使用所有分层进行预测
        recon_loss_weight = 0.1,  # 重构损失权重
        hierarchical_ar_loss_weight = 0.25,  # 分层自回归损失权重
        ignore_index = 0,  # 忽略的索引
        use_flash_attn = False,  # 是否使用 Flash Attention
    # autoregressively sample `seq_len` new tokens continuing `prompt`,
    # using top-k filtering followed by gumbel sampling at `temperature`
    @torch.no_grad()  # no gradients needed for sampling
    @eval_decorator  # puts the model in eval mode for the duration
    def generate(
        self,
        prompt,
        seq_len,
        temperature = 1.0,
        filter_thres = 0.9,
        **kwargs
    ):
        b, t, device = *prompt.shape, prompt.device

        out = prompt

        # append one sampled token at a time, feeding back at most the
        # trailing self.seq_len tokens as context
        for _ in range(seq_len):
            logits = self.forward(out[:, -self.seq_len:], **kwargs)[:, -1]
            filtered_logits = top_k(logits, thres = filter_thres)
            sample = gumbel_sample(filtered_logits, temperature = temperature)
            sample = rearrange(sample, 'b -> b 1')
            out = torch.cat((out, sample), dim = -1)

        return out[:, t:]  # return only the newly generated continuation

    @property
    def device(self):
        # infer the module's device from its first parameter
        return next(self.parameters()).device

    # 前向传播函数
    def forward(
        self,
        ids,  # 标识符
        return_loss = False,  # 是否返回损失
        return_hierarchical_token_embeds = False,  # 是否返回分层标记嵌入
        return_hierarchical_embeds = False,  # 是否返回分层嵌入
        ablate_hierarchical_merge = False  # 是否消融分层合并
        ):
        """
        einops notation:

        b - batch
        n - sequence length
        c - compression factor
        d - dimension
        """

        # 如果是训练阶段,预测序列中的下一个标记

        if return_loss:
            ids, labels = ids[:, :-1], ids[:, 1:]

        # 断言序列长度

        assert ids.shape[-1] <= self.seq_len

        # 获取标记嵌入,并填充到压缩因子的倍数

        x = self.token_emb(ids)

        # 对于每个层次结构,适当地压缩标记嵌入到层次嵌入中

        tokens = []

        for compress in self.compressors:
            tokens.append(compress(x))

        # 后嵌入规范化

        tokens = apply_fns(self.post_token_emb_norms, tokens)

        # 如果想要所有压缩后的标记嵌入
        # 仅用于研究空间

        if return_hierarchical_token_embeds:
            return tokens

        # 层次结构

        for layer, merge in zip_longest(self.layers, self.hierarchical_merges):

            tokens = apply_fns(layer, tokens)

            # 汇总所有层次的信息
            # 然后更新将用于进行最终自回归预测的标记

            if not self.need_hierarchical_merge or ablate_hierarchical_merge:
                continue

            pooled = merge(tokens)

            if self.hierarchy_merge_all:
                tokens = [(t + p[..., ::s, :]) for t, p, s in zip(tokens, pooled.split(self.dims, dim = -1), self.h_strides)]
            else:
                predict_tokens = tokens[self.predict_hierarchy_index]
                predict_tokens = predict_tokens + pooled
                tokens[self.predict_hierarchy_index] = predict_tokens

        # 最终规范化嵌入

        embeds = apply_fns(self.norms, tokens)

        # 如果想要所有规范化的层次嵌入

        if return_hierarchical_embeds:
            return embeds

        # 选择将进行预测的层次嵌入

        if self.predict_use_all_hierarchy:
            predict_embed = hierarchical_cat(embeds, self.h_strides)
        else:
            predict_embed = embeds[self.predict_hierarchy_index]

        # 用于预测下一个标记的对数

        logits = self.to_logits(predict_embed)

        if not return_loss:
            return logits

        # 自回归损失(预测编码)

        logits = rearrange(logits, 'b n c -> b c n')
        ce_loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index)

        # 层次标记的重建损失

        recon_losses = self.zeros.requires_grad_()

        if self.should_recon:
            for compress, t in zip(self.compressors, embeds):
                recon_loss = compress.recon(t, ids)
                recon_losses = recon_losses + recon_loss

        # 层次自回归损失

        hierarchical_ar_losses = self.zeros.requires_grad_()

        for h_embed, maybe_h_pred_linear in zip(embeds, self.to_hierarchical_preds):
            if not exists(maybe_h_pred_linear):
                continue

            h_pred = maybe_h_pred_linear(h_embed)
            h_ar_loss = cosine_sim_loss(h_pred[:, :-1], h_embed[:, 1:])

            hierarchical_ar_losses = hierarchical_ar_losses + h_ar_loss

        # 总损失

        total_loss = ce_loss + \
                     recon_losses * self.recon_loss_weight + \
                     hierarchical_ar_losses * self.hierarchical_ar_loss_weight

        return total_loss, (ce_loss, recon_losses, hierarchical_ar_losses)

.\lucidrains\simple-hierarchical-transformer\simple_hierarchical_transformer\__init__.py

# 从 simple_hierarchical_transformer.simple_hierarchical_transformer 模块中导入 HierarchicalTransformer 类
from simple_hierarchical_transformer.simple_hierarchical_transformer import HierarchicalTransformer

.\lucidrains\simple-hierarchical-transformer\train.py

# 导入必要的库
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# 导入自定义的简单分层Transformer模型
from simple_hierarchical_transformer import HierarchicalTransformer

# 导入加速器库
from accelerate import Accelerator

# initialize the accelerate runtime (handles device placement / DDP)
accelerator = Accelerator()

# device and rank-aware print helper
device = accelerator.device
acc_print = accelerator.print

# training hyperparameters
NUM_BATCHES = int(1e5)
BATCH_SIZE = 2
GRADIENT_ACCUMULATE_EVERY = 8
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
SEQ_LEN = 2048
GENERATE_LENGTH = 1024

# helpers

# endlessly loop over a dataloader
def cycle(loader):
    while True:
        yield from loader

def decode_token(token):
    """Map a byte value to a printable character (control chars become space)."""
    # chr() already returns a str; the redundant str() wrapper was removed
    return chr(max(32, token))

def decode_tokens(tokens):
    """Decode an iterable of byte values into a string."""
    # join consumes the map lazily; no need to materialize a list first
    return "".join(map(decode_token, tokens))

# instantiate the hierarchical transformer: two hierarchies (compress
# factors 1 and 2) with local attention windows 32 and 64
model = HierarchicalTransformer(
    num_tokens = 256,
    dim = 1024,
    depth = 8,
    seq_len = SEQ_LEN,
    hierarchies = (1, 2),
    window_sizes = (32, 64),
    use_flash_attn = True
).to(device)

# prepare enwik8 data: first 90MB for training, next 5MB for validation
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset yielding random fixed-length windows (plus one target token)
# from a long 1d token tensor
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # pick a random window start; the slice is seq_len + 1 long so the
        # model gets both inputs and shifted next-token targets
        start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        window = self.data[start : start + self.seq_len + 1].long()
        return window.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# build train / validation samplers and endless cycling loaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# let accelerate wrap model / optimizer / loaders for the target device
model, optim, train_loader, val_loader = accelerator.prepare(
    model, optim, train_loader, val_loader
)

# training loop: gradient accumulation, periodic validation and sampling
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    # accumulate gradients over several micro-batches before stepping
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss, (ce_loss, recon_loss, prophet_loss) = model(next(train_loader), return_loss = True)
        accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)

    acc_print(f"training loss: {ce_loss.item()}")
    accelerator.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            _, (ce_loss, *_) = model(next(val_loader), return_loss = True)
            acc_print(f"validation loss: {ce_loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        # bug fix: the original passed a %-style format string plus a tuple
        # to acc_print, which printed the literal "%s \n\n %s" and the tuple
        # instead of the prime text; use a real f-string
        acc_print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        acc_print(output_str, "\n")

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Sinkhorn Transformer with Deepspeed for Enwik8

Deepspeed is the framework Microsoft used to train the world's largest attention model to date (Turing-NLG, 17 billion parameters). They have open sourced it, and it works with Sinkhorn Transformers!

  1. First install Deepspeed following instructions from their official repository https://github.com/microsoft/DeepSpeed

  2. Run the following command in this folder

$ deepspeed train.py --deepspeed --deepspeed_config ds_config.json

.\lucidrains\sinkhorn-transformer\examples\enwik8_deepspeed\train.py

import deepspeed  # 导入deepspeed库

from sinkhorn_transformer import SinkhornTransformerLM  # 从sinkhorn_transformer库中导入SinkhornTransformerLM类
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper  # 从sinkhorn_transformer库中导入AutoregressiveWrapper类

import argparse  # 导入argparse库
import random  # 导入random库
import tqdm  # 导入tqdm库
import gzip  # 导入gzip库
import numpy as np  # 导入numpy库,并重命名为np
import torch  # 导入torch库
import torch.optim as optim  # 从torch库中导入optim模块
from torch.nn import functional as F  # 从torch库中导入functional模块,并重命名为F
from torch.utils.data import DataLoader, Dataset  # 从torch.utils.data库中导入DataLoader和Dataset类

def add_argument():
    """Parse the enwik8 training CLI arguments, including deepspeed's own flags."""
    parser=argparse.ArgumentParser(description='enwik8')

    parser.add_argument('--with_cuda', default=False, action='store_true',
                        help='use CPU in case there\'s no GPU support')
    parser.add_argument('--use_ema', default=False, action='store_true',
                        help='whether use exponential moving average')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        help='mini-batch size (default: 32)')
    parser.add_argument('-e', '--epochs', default=30, type=int,
                        help='number of total epochs (default: 30)')
    parser.add_argument('--local_rank', type=int, default=-1,
                       help='local rank passed from distributed launcher')

    # registers deepspeed's --deepspeed / --deepspeed_config arguments
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    return args

# constants

VALIDATE_EVERY  = 100  # run validation every N steps
GENERATE_EVERY  = 500  # sample a generation every N steps
GENERATE_LENGTH = 1024  # number of tokens to generate per sample
SEQ_LEN = 4096  # training context length

# helpers

def decode_token(token):
    """Map a byte value to a printable character (control chars become space)."""
    # chr() already returns a str; the redundant str() wrapper was removed
    return chr(max(32, token))

def decode_tokens(tokens):
    """Decode an iterable of byte values into a string."""
    # join consumes the map lazily; no need to materialize a list first
    return ''.join(map(decode_token, tokens))

# instantiate model

model = SinkhornTransformerLM(
    num_tokens = 256,  # byte-level vocabulary
    emb_dim = 128,
    dim = 512,
    depth = 8,
    max_seq_len = SEQ_LEN,
    heads = 8,
    bucket_size = 128,
    ff_chunks = 10,  # chunked feedforward to reduce peak memory
    causal = True,  # autoregressive
    reversible = True,  # reversible layers trade compute for memory
    attn_dropout = 0.1,
    n_local_attn_heads = 4
)

model = AutoregressiveWrapper(model)  # adds loss computation and sampling
model.cuda()

# prepare enwik8 data: first 90MB for training, next 5MB for validation

with gzip.open('./data/enwik8.gz') as file:
    # bug fix: np.fromstring is deprecated (removed in numpy >= 1.24); use
    # frombuffer, and .copy() so the array is writable (torch.from_numpy
    # warns / misbehaves on read-only buffers)
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    """Randomly samples (seq_len + 1)-long windows from a 1d token tensor."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # random window start; +1 so inputs and shifted targets both fit
        start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        return self.data[start: start + self.seq_len + 1].long()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# build train / validation window samplers
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)

# setup deepspeed

# deepspeed wraps the model/optimizer and builds the distributed dataloader
cmd_args = add_argument()
model_engine, optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=model.parameters(),  training_data=train_dataset)

# training loop: deepspeed handles backward/step; periodic validation and sampling

for i, data in enumerate(trainloader):
    model_engine.train()
    data = data.to(model_engine.local_rank)
    loss = model_engine(data, return_loss = True)
    model_engine.backward(loss)
    model_engine.step()
    # NOTE(review): the * 4 looks like a hand-tuned logging rescale — confirm
    print(loss.item() * 4)

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            inp = random.choice(val_dataset)
            loss = model(inp[None, :].cuda(), return_loss = True)
            print(f'validation loss: {loss.item()}')

    # only the main rank prints generations
    if model_engine.local_rank == 0 and i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # bug fix: the original called print(f'%s \n\n %s', (prime, ...)),
        # which printed the literal format string plus a tuple instead of
        # the prime text; use a real f-string
        print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp.cuda(), GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
posted @ 2024-06-28 14:02  绝不原创的飞龙  阅读(6)  评论(0编辑  收藏  举报