Lucidrains Series Projects: Source Code Walkthrough (Part 30)

.\lucidrains\nuwa-pytorch\nuwa_pytorch\optimizer.py

# import torch
import torch
# import the AdamW and Adam optimizers from torch.optim
from torch.optim import AdamW, Adam

# separate out the parameters that should receive weight decay
def separate_weight_decayable_params(params):
    # parameters with fewer than 2 dimensions (biases, norm scales) get no weight decay
    no_wd_params = set([param for param in params if param.ndim < 2])
    # everything else does receive weight decay
    wd_params = set(params) - no_wd_params
    return wd_params, no_wd_params

# build the optimizer
def get_optimizer(
    params,
    lr = 3e-4,
    wd = 1e-1,
    filter_by_requires_grad = False
):
    # optionally keep only parameters that require gradients
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # with zero weight decay, plain Adam suffices
    if wd == 0:
        return Adam(list(params), lr = lr)

    # convert the parameters to a set
    params = set(params)
    # split into weight-decayable and non-decayable parameters
    wd_params, no_wd_params = separate_weight_decayable_params(params)

    # build the parameter groups, disabling weight decay for the second group
    param_groups = [
        {'params': list(wd_params)},
        {'params': list(no_wd_params), 'weight_decay': 0},
    ]

    # use AdamW with the given learning rate and weight decay
    return AdamW(param_groups, lr = lr, weight_decay = wd)
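
As a quick illustration, here is a minimal usage sketch (the toy model below is an assumption, not part of the repo): parameters with fewer than two dimensions, such as biases and LayerNorm weights, end up in the zero-weight-decay group.

# --- usage sketch (illustrative, not part of optimizer.py) ---
import torch.nn as nn

toy_model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
# biases and LayerNorm weights (ndim < 2) land in the group with weight_decay = 0
optim = get_optimizer(toy_model.parameters(), lr = 3e-4, wd = 1e-1)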

.\lucidrains\nuwa-pytorch\nuwa_pytorch\reversible.py

# import torch
import torch
# import the neural network module
import torch.nn as nn
# import itemgetter from the operator module
from operator import itemgetter
# import the Function class from torch.autograd.function
from torch.autograd.function import Function
# import get_device_states and set_device_states from torch.utils.checkpoint
from torch.utils.checkpoint import get_device_states, set_device_states

# routes keyword arguments into the f / g branches of the reversible layers
def route_args(router, args, depth):
    # one (f_args, g_args) pair per layer
    routed_args = [(dict(), dict()) for _ in range(depth)]
    # keep only the keys that the router knows about
    matched_keys = [key for key in args.keys() if key in router]

    # iterate over the matched keys
    for key in matched_keys:
        val = args[key]
        # walk the per-layer routes for this key
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            # attach the value to the f and / or g arguments, depending on the route
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args
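
To make the routing concrete, here is a small hypothetical example (the router and argument values are made up): for each layer, the router entry says whether a keyword argument is passed to the f branch, the g branch, or neither.

# --- illustrative example (not part of reversible.py) ---
router = {'mask': ((True, False), (True, False))}   # two layers; 'mask' goes only to f
args   = {'mask': 'some-mask', 'ignored': 123}      # keys absent from the router are dropped
routed = route_args(router, args, depth = 2)
# routed == [({'mask': 'some-mask'}, {}), ({'mask': 'some-mask'}, {})]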

# following the example at https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html for saving and restoring RNG state
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)
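
A short sketch (assumed, not from the repo) of the record / replay behavior: the first call stores the RNG state, and a later call with set_rng = True restores it, so stochastic layers such as dropout produce the same result when recomputed.

# --- usage sketch (illustrative, not part of reversible.py) ---
drop = Deterministic(nn.Dropout(p = 0.5))
x = torch.randn(2, 8)
out_recorded = drop(x, record_rng = True)   # save the RNG state, then run the net
out_replayed = drop(x, set_rng = True)      # restore the saved RNG state and rerun
assert torch.equal(out_recorded, out_replayed)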

# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send a PR back to the source repo
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx
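
The forward coupling is y1 = x1 + f(x2), y2 = x2 + g(y1), so the inputs can be reconstructed exactly as x2 = y2 - g(y1) and x1 = y1 - f(x2); backward_pass re-derives them this way instead of storing activations. Below is a small numerical check (the module choices and shapes are assumptions).

# --- reversibility check sketch (illustrative, not part of reversible.py) ---
with torch.no_grad():
    block = ReversibleBlock(nn.Linear(4, 4), nn.Linear(4, 4))
    x = torch.randn(1, 3, 8)                # last dim is split into two halves of 4
    y = block(x)
    y1, y2 = y.chunk(2, dim = 2)
    x2 = y2 - block.g(y1)                   # invert the second coupling
    x1 = y1 - block.f(x2)                   # invert the first coupling
    assert torch.allclose(torch.cat([x1, x2], dim = 2), x, atol = 1e-5)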

class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.args = args
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    # the backward pass, receiving the context and the incoming gradient
    def backward(ctx, dy):
        # fetch y and the routed args stored on the context
        y = ctx.y
        args = ctx.args
        # walk the blocks (and their args) in reverse order
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            # each block reconstructs its input and returns the updated y and dy
            y, dy = block.backward_pass(y, dy, **kwargs)
        # return the gradient (None for the non-tensor inputs)
        return dy, None, None

# a reversible sequence of (f, g) block pairs
class ReversibleSequence(nn.Module):
    # takes a list of (f, g) block pairs and an argument router
    def __init__(self, blocks, args_route = {}):
        super().__init__()
        # store the argument routing
        self.args_route = args_route
        # wrap each (f, g) pair in a ReversibleBlock
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    # forward pass
    def forward(self, x, **kwargs):
        # duplicate the input along the last dimension to form the two reversible halves
        x = torch.cat([x, x], dim=-1)

        # gather the blocks and route the keyword arguments per layer
        blocks = self.blocks
        args = route_args(self.args_route, kwargs, len(blocks))
        # repackage each layer's routed arguments as f_args / g_args dicts
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        # pair each block with its arguments
        layers_and_args = list(zip(blocks, args))

        # run the custom reversible autograd function
        out =  _ReversibleFunction.apply(x, blocks, args)
        # split the output back into its two halves along the last dimension and sum them
        return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
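
A minimal usage sketch (the dimensions and the linear "feed-forward" modules are assumptions): the input is duplicated along the last dimension on the way in, and the two halves are summed on the way out, so the outer interface keeps the original width.

# --- usage sketch (illustrative, not part of reversible.py) ---
dim = 64
blocks = [(nn.Linear(dim, dim), nn.Linear(dim, dim)) for _ in range(2)]
seq = ReversibleSequence(blocks)

x = torch.randn(1, 16, dim)
out = seq(x)        # internally width 2 * dim, summed back down to dim
assert out.shape == (1, 16, dim)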

.\lucidrains\nuwa-pytorch\nuwa_pytorch\reversible_video_audio.py

import torch
import torch.nn as nn
from torch.autograd.function import Function
from contextlib import contextmanager

from nuwa_pytorch.reversible import Deterministic

from einops import reduce

# helpers

# check that a value is not None
def exists(val):
    return val is not None

# a context manager that does nothing
@contextmanager
def null_context():
    yield

# split a tensor into two pieces at the given index along a dimension
def split_at_index(dim, index, t):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]
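
For clarity, a tiny example of the helper (the values are made up):

# --- illustrative example (not part of reversible_video_audio.py) ---
t = torch.arange(12).reshape(3, 4)
left, right = split_at_index(1, 3, t)   # split along dim 1 at index 3
# left.shape == (3, 3), right.shape == (3, 1)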

# reversible self attention block

class ReversibleSelfAttnBlock(nn.Module):
    def __init__(self, f, g, j, k):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)
        self.j = Deterministic(j)
        self.k = Deterministic(k)        

    def forward(self, x, m, _reverse = True, **kwargs):
        x1, x2 = torch.chunk(x, 2, dim = 2)
        m1, m2 = torch.chunk(m, 2, dim = 2)
        y1, y2, n1, n2 = None, None, None, None

        fn_context = torch.no_grad if _reverse else null_context
        record_rng = self.training and _reverse

        with fn_context():
            y1 = x1 + self.f(x2, record_rng = record_rng)
            y2 = x2 + self.g(y1, record_rng = record_rng)
            n1 = m1 + self.j(m2, record_rng = record_rng)
            n2 = m2 + self.k(n1, record_rng = record_rng)

        return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)

    def backward_pass(self, y, n, dy, dn, **kwargs):
        y1, y2 = torch.chunk(y, 2, dim = 2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim = 2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng = True)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng = True)
            torch.autograd.backward(fx2, dx1, retain_graph = True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim = 2)
            dx = torch.cat([dx1, dx2], dim = 2)

        n1, n2 = torch.chunk(n, 2, dim = 2)
        del n

        dn1, dn2 = torch.chunk(dn, 2, dim = 2)
        del dn

        with torch.enable_grad():
            n1.requires_grad = True
            gn1 = self.k(n1, set_rng = True)
            torch.autograd.backward(gn1, dn2)

        with torch.no_grad():
            m2 = n2 - gn1
            del n2, gn1

            dm1 = dn1 + n1.grad
            del dn1
            n1.grad = None

        with torch.enable_grad():
            m2.requires_grad = True
            fm2 = self.j(m2, set_rng = True)
            torch.autograd.backward(fm2, dm1, retain_graph=True)

        with torch.no_grad():
            m1 = n1 - fm2
            del n1, fm2

            dm2 = dn2 + m2.grad
            del dn2
            m2.grad = None

            m = torch.cat([m1, m2.detach()], dim = 2)
            dm = torch.cat([dm1, dm2], dim = 2)

        return x, m, dx, dm

class ReversibleCrossAttnBlock(nn.Module):
    def __init__(self, f, g, j, k):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)
        self.j = Deterministic(j)
        self.k = Deterministic(k)        
    # 前向传播函数,接受输入 x 和 m,以及一系列参数,返回处理后的结果
    def forward(self, x, m, *, context, context_mask, video_mask = None, audio_mask = None, _reverse = True, **kwargs):
        # 将输入 x 和 m 按照第二维度分成两部分
        x1, x2 = torch.chunk(x, 2, dim = 2)
        m1, m2 = torch.chunk(m, 2, dim = 2)
        y1, y2, n1, n2 = None, None, None, None

        # 根据 _reverse 参数选择是否启用梯度记录
        fn_context = torch.no_grad if _reverse else null_context
        record_rng = self.training and _reverse

        # 使用 fn_context 上下文管理器,根据 _reverse 参数选择是否启用梯度记录
        with fn_context():
            # 计算 y1 和 y2
            y1 = x1 + self.f(x2, context = context, context_mask = context_mask, mask = video_mask, record_rng = record_rng)
            y2 = x2 + self.g(y1, record_rng = record_rng)
            # 计算 n1 和 n2
            n1 = m1 + self.j(m2, context = context, context_mask = context_mask, mask = audio_mask, record_rng = record_rng)
            n2 = m2 + self.k(n1, record_rng = record_rng)

        # 返回拼接后的结果
        return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)

    # 反向传播函数,接受输入 y, n, dy, dn,以及一系列参数,返回处理后的结果
    def backward_pass(self, y, n, dy, dn, *, context, context_mask, video_mask = None, audio_mask = None, **kwargs):
        # 将输入 y 和 n 按照第二维度分成两部分
        y1, y2 = torch.chunk(y, 2, dim = 2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim = 2)
        del dy

        # 启用梯度记录
        with torch.enable_grad():
            y1.requires_grad = True
            # 计算 gy1
            gy1 = self.g(y1, set_rng = True)
            # 反向传播计算 dy2
            torch.autograd.backward(gy1, dy2)

        # 使用 torch.no_grad 上下文管理器,计算中间结果
        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # 启用梯度记录
        with torch.enable_grad():
            x2.requires_grad = True
            # 计算 fx2
            fx2 = self.f(x2, set_rng = True, context = context, context_mask = context_mask, mask = video_mask)
            # 反向传播计算 dx1
            torch.autograd.backward(fx2, dx1, retain_graph = True)

        # 使用 torch.no_grad 上下文管理器,计算中间结果
        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim = 2)
            dx = torch.cat([dx1, dx2], dim = 2)

        # 将输入 n 按照第二维度分成两部分
        n1, n2 = torch.chunk(n, 2, dim = 2)
        del n

        dn1, dn2 = torch.chunk(dn, 2, dim = 2)
        del dn

        # 启用梯度记录
        with torch.enable_grad():
            n1.requires_grad = True
            # 计算 gn1
            gn1 = self.k(n1, set_rng = True)
            # 反向传播计算 dn2
            torch.autograd.backward(gn1, dn2)

        # 使用 torch.no_grad 上下文管理器,计算中间结果
        with torch.no_grad():
            m2 = n2 - gn1
            del n2, gn1

            dm1 = dn1 + n1.grad
            del dn1
            n1.grad = None

        # 启用梯度记录
        with torch.enable_grad():
            m2.requires_grad = True
            # 计算 fm2
            fm2 = self.j(m2, set_rng = True, context = context, context_mask = context_mask, mask = audio_mask)
            # 反向传播计算 dm1
            torch.autograd.backward(fm2, dm1, retain_graph=True)

        # 使用 torch.no_grad 上下文管理器,计算中间结果
        with torch.no_grad():
            m1 = n1 - fm2
            del n1, fm2

            dm2 = dn2 + m2.grad
            del dn2
            m2.grad = None

            m = torch.cat([m1, m2.detach()], dim = 2)
            dm = torch.cat([dm1, dm2], dim = 2)

        # 返回结果
        return x, m, dx, dm
# reversible cross modality attention block

class ReversibleCrossModalityAttnBlock(nn.Module):
    def __init__(self, f, g, j, k):
        super().__init__()
        self.f = Deterministic(f)  # 初始化可逆函数 f
        self.g = Deterministic(g)  # 初始化可逆函数 g
        self.j = Deterministic(j)  # 初始化可逆函数 j
        self.k = Deterministic(k)  # 初始化可逆函数 k

    def forward(self, x, m, *, video_mask = None, audio_mask = None, _reverse = True, **kwargs):
        x1, x2 = torch.chunk(x, 2, dim = 2)  # 将输入 x 沿着第二维度分成两部分 x1 和 x2
        m1, m2 = torch.chunk(m, 2, dim = 2)  # 将输入 m 沿着第二维度分成两部分 m1 和 m2
        y1, y2, n1, n2 = None, None, None, None

        fn_context = torch.no_grad if _reverse else null_context  # 根据 _reverse 的值选择上下文管理器
        record_rng = self.training and _reverse

        with fn_context():
            y1 = x1 + self.f(x2, m2, record_rng = record_rng, mask = video_mask, context_mask = audio_mask)  # 计算 y1
            y2 = x2 + self.k(y1, record_rng = record_rng)  # 计算 y2
            n1 = m1 + self.j(m2, y2, record_rng = record_rng, mask = audio_mask, context_mask = video_mask)  # 计算 n1
            n2 = m2 + self.g(n1, record_rng = record_rng)  # 计算 n2

        return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)  # 返回拼接后的结果

    def backward_pass(self, y, n, dy, dn, video_mask = None, audio_mask = None, **kwargs):
        n1, n2 = torch.chunk(n, 2, dim = 2)  # 将输入 n 沿着第二维度分成两部分 n1 和 n2
        del n

        dn1, dn2 = torch.chunk(dn, 2, dim = 2)  # 将输入 dn 沿着第二维度分成两部分 dn1 和 dn2
        del dn

        y1, y2 = torch.chunk(y, 2, dim = 2)  # 将输入 y 沿着第二维度分成两部分 y1 和 y2
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim = 2)  # 将输入 dy 沿着第二维度分成两部分 dy1 和 dy2
        del dy

        with torch.enable_grad():
            n1.requires_grad = True
            gn1 = self.g(n1, set_rng = True)  # 计算 gn1
            torch.autograd.backward(gn1, dn2)  # 反向传播计算梯度

        with torch.no_grad():
            m2 = n2 - gn1  # 计算 m2
            del n2, gn1

            dm1 = dn1 + n1.grad  # 计算 dm1
            del dn1
            n1.grad = None

        with torch.enable_grad():
            m2.requires_grad = True
            y2.requires_grad = True
            fm2 = self.j(m2, y2, set_rng=True, mask = audio_mask, context_mask = video_mask)  # 计算 fm2
            torch.autograd.backward(fm2, dm1)  # 反向传播计算梯度

        with torch.no_grad():
            m1 = n1 - fm2  # 计算 m1
            del n1, fm2

            dm2 = dn2 + m2.grad  # 计算 dm2
            dx2 = dy2 + y2.grad  # 计算 dx2
            del dn2
            del dy2
            m2.grad = None
            y2.grad = None

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.k(y1, set_rng = True)  # 计算 gy1
            torch.autograd.backward(gy1, dx2)  # 反向传播计算梯度

        with torch.no_grad():
            x2 = y2 - gy1  # 计算 x2
            del y2, gy1

            dx1 = dy1 + y1.grad  # 计算 dx1
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            m2.requires_grad = True
            fx2 = self.f(x2, m2, set_rng = True, mask = video_mask, context_mask = audio_mask)  # 计算 fx2
            torch.autograd.backward(fx2, dx1)  # 反向传播计算梯度

        with torch.no_grad():
            x1 = y1 - fx2  # 计算 x1
            del y1, fx2

            dx2 = dx2 + x2.grad  # 计算 dx2
            dm2 = dm2 + m2.grad  # 计算 dm2
            x2.grad = None
            m2.grad = None

        with torch.no_grad():
            m = torch.cat([m1, m2.detach()], dim = 2)  # 拼接 m1 和 m2
            dm = torch.cat([dm1, dm2], dim = 2)  # 拼接 dm1 和 dm2

            x = torch.cat([x1, x2.detach()], dim = 2)  # 拼接 x1 和 x2
            dx = torch.cat([dx1, dx2], dim = 2)  # 拼接 dx1 和 dx2

        return x, m, dx, dm

# reversible and non-reversible apply functions

class ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, inp, ind, blocks, kwargs):
        x, m = split_at_index(1, ind, inp)  # 在指定索引处分割输入

        for block in blocks:
            x, m = block(x, m, _reverse = True, **kwargs)  # 对每个块进行前向传播

        ctx.blocks = blocks
        ctx.kwargs = kwargs
        ctx.ind = ind
        ctx.save_for_backward(x.detach(), m.detach())
        return torch.cat((x, m), dim = 1)  # 拼接结果

    @staticmethod
    # the backward pass, receiving the context and the incoming gradient
    def backward(ctx, d):
        # fetch the split index, blocks and keyword arguments from the context
        ind = ctx.ind
        blocks = ctx.blocks
        kwargs = ctx.kwargs
        # split the incoming gradient into the video and audio parts
        dy, dn = split_at_index(1, ind, d)
        # fetch the saved tensors y and n
        y, n = ctx.saved_tensors

        # walk the blocks in reverse order
        for block in blocks[::-1]:
            # each block reconstructs its inputs and returns the updated y, n, dy, dn
            y, n, dy, dn = block.backward_pass(y, n, dy, dn, **kwargs)

        # concatenate the two gradient parts back together
        d = torch.cat((dy, dn), dim=1)
        # return the gradient (None for the non-tensor inputs)
        return d, None, None, None

# alias for ReversibleFunction.apply
reversible_apply = ReversibleFunction.apply

# the non-reversible counterpart, taking the inputs, split index, blocks and keyword arguments
def irreversible_apply(inputs, ind, blocks, kwargs):
    # split the input into x (video) and m (audio) at the index
    x, m = split_at_index(1, ind, inputs)
    # run every block, updating x and m
    for block in blocks:
        x, m = block(x, m, _reverse = False, **kwargs)
    # concatenate x and m and return
    return torch.cat((x, m), dim = 1)

# the main reversible sequence class
class DualModalityReversibleSequence(nn.Module):
    # init takes the input blocks and their block types
    def __init__(self, input_blocks, block_types):
        super().__init__()
        self.block_types = block_types
        blocks = nn.ModuleList([])

        # pick the reversible wrapper class based on each block type
        for block, block_type in zip(input_blocks, block_types):
            if block_type == 'intra_modality_self_attn':
                reversible_klass = ReversibleSelfAttnBlock
            elif block_type == 'intra_modality_cross_attn':
                reversible_klass = ReversibleCrossAttnBlock
            elif block_type == 'inter_modality_cross_attn':
                reversible_klass = ReversibleCrossModalityAttnBlock
            else:
                raise ValueError(f'unknown layer type {block_type}')

            blocks.append(reversible_klass(*block))

        self.blocks = blocks

    # forward pass over video, audio, the text context and the various masks
    def forward(
        self,
        video,
        audio,
        *,
        context,
        context_mask = None,
        video_mask = None,
        audio_mask = None,
        reverse = True
    ):
        blocks = self.blocks
        # duplicate the video and audio features along the last dimension (the two reversible halves)
        video, audio = list(map(lambda t: torch.cat((t, t), dim = -1), (video, audio)))
        kwargs = {'context': context, 'context_mask': context_mask, 'video_mask': video_mask, 'audio_mask': audio_mask}

        # choose the reversible or non-reversible apply function
        fn = reversible_apply if reverse else irreversible_apply
        ind = video.shape[1]
        inp = torch.cat((video, audio), dim = 1)
        out = fn(inp, ind, blocks, kwargs)
        # split the output back into video and audio
        video, audio  = split_at_index(1, ind, out)
        # average the two reversible halves back down and return
        return list(map(lambda t: reduce(t, 'b n (c d) -> b n d', 'mean', c = 2), (video, audio)))

.\lucidrains\nuwa-pytorch\nuwa_pytorch\train_nuwa.py

# import randrange and choice from the random module (choice is used for random rotations)
from random import randrange, choice
# import Path from pathlib
from pathlib import Path

# import torch
import torch
# import nn from torch
from torch import nn
# import Dataset and DataLoader from torch.utils.data
from torch.utils.data import Dataset, DataLoader
# import pad_sequence from torch.nn.utils.rnn
from torch.nn.utils.rnn import pad_sequence
# import rearrange from einops
from einops import rearrange

# import tqdm for progress bars
from tqdm import tqdm
# import numpy
import numpy as np
# import rmtree from shutil
from shutil import rmtree

# import the tokenizer and optimizer helpers from nuwa_pytorch
from nuwa_pytorch.tokenizer import tokenizer
from nuwa_pytorch.optimizer import get_optimizer
# import gif_to_tensor from the image_utils module
from nuwa_pytorch.image_utils import gif_to_tensor
# import the NUWA class from nuwa_pytorch
from nuwa_pytorch import NUWA

# import torchvision.transforms as T
import torchvision.transforms as T
# import make_grid and save_image from torchvision.utils
from torchvision.utils import make_grid, save_image

# helper functions

# check that a value is not None
def exists(val):
    return val is not None

# no-op function
def noop(*args, **kwargs):
    pass

# turn a dataloader into an infinite iterator
def cycle(dl):
    while True:
        for data in dl:
            yield data

# cast a value to a tuple
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# ask the user a yes / no question
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# accumulate values into a running log dictionary
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# dataloader helpers

# collate function that pads the text sequences and stacks the videos
def pad_collate_fn(batch):
    texts, videos = zip(*batch)
    return pad_sequence(texts, batch_first = True), torch.stack(videos)

# data pipeline functions

# convert a dataset of raw video tensors into VQ codebook indices
def convert_video_tensor_dataset_to_indices(
    *,
    vae,
    raw_video_dataset,
    num_frames,
    path,
):
    vae_device = next(vae.parameters()).device
    num_videos = len(raw_video_dataset)
    assert num_videos > 0, 'there must be at least 1 video'

    fmap_size = vae.image_size // (vae.num_layers ** 2)
    shape = (num_videos, num_frames * fmap_size * fmap_size)

    video_indices_memmap = np.memmap(path, mode = 'w+', dtype = np.int64, shape = shape)

    for ind in tqdm(range(num_videos)):
        _, video = raw_video_dataset[ind]
        video = rearrange(video, '... -> 1 ...')
        video = video.to(vae_device)
        indices = vae.get_video_indices(video)
        indices = rearrange(indices, '1 f h w -> (f h w)')
        video_indices_memmap[ind] = indices.cpu().numpy()

    print(f'completed conversion of {num_videos} videos to indices at {path}')

# 数据集类

# Mnist 数据集类
class MnistDataset(Dataset):
    def __init__(
        self,
        num_videos,
        videos_memmap_path,
        text_memmap_path,
        num_digits = 2,
        num_frames = 10,
        image_size = 64,
        channels = 1,
        random_rotate = False
    ):
        super().__init__()
        self.num_videos = num_videos
        self.videos_memmap = np.memmap(videos_memmap_path, mode = 'r', dtype = np.uint8, shape = (num_videos, num_frames, channels, image_size, image_size))
        self.text_memmap = np.memmap(text_memmap_path, mode = 'r', dtype = np.uint8, shape = (num_videos, num_digits))
        self.random_rotate = random_rotate

    def __len__(self):
        return self.num_videos

    def __getitem__(self, idx):
        video = torch.from_numpy(self.videos_memmap[idx].copy()).float()
        label = torch.from_numpy(self.text_memmap[idx].copy())

        video /= 255
        video = video.to(torch.float32)

        text = tokenizer.encode(' '.join(map(str, label.tolist())))
        text = torch.Tensor(text).long()

        if self.random_rotate:
            video = T.functional.rotate(video, choice([0, 90, 180, 270]))

        return text, video

# 视频索引数据集类
class VideoIndicesDataset(Dataset):
    def __init__(
        self,
        *,
        videos_memmap_path,
        text_memmap_path,
        vae,
        num_videos,
        num_frames,
        num_digits = 2,
    ):
        self.num_videos = num_videos
        fmap_size = vae.image_size // (vae.num_layers ** 2)
        self.videos_memmap = np.memmap(videos_memmap_path, mode = 'r', dtype = np.int64, shape = (num_videos, num_frames * (fmap_size ** 2)))
        self.text_memmap = np.memmap(text_memmap_path, mode = 'r', dtype = np.uint8, shape = (num_videos, num_digits))

    def __len__(self):
        return self.num_videos
    # 定义一个特殊方法,用于获取数据集中指定索引位置的数据
    def __getitem__(self, idx):
        # 从内存映射中读取视频数据,并转换为PyTorch张量
        video = torch.from_numpy(self.videos_memmap[idx].copy())
        # 从内存映射中读取文本数据,并转换为PyTorch张量
        text = torch.from_numpy(self.text_memmap[idx].copy())

        # 将文本数据转换为字符串,使用空格连接后编码为token,再转换为PyTorch张量
        text = tokenizer.encode(' '.join(map(str, text.tolist())))
        text = torch.Tensor(text).long()

        # 将视频数据转换为长整型张量
        video = video.long()
        # 返回处理后的文本和视频数据
        return text, video
# 从视频文件夹中创建用于训练的数据集类
class GifVideoDataset(Dataset):
    def __init__(
        self,
        *,
        folder,  # 视频文件夹路径
        channels = 1  # 通道数,默认为1
    ):
        # 将文件夹路径转换为 Path 对象
        folder = Path(folder)
        # 获取所有 GIF 文件和对应的文本文件
        gifs = folder.glob('**/*.gif')
        txts = folder.glob('**/*.txt')

        # 获取 GIF 文件和文本文件的路径前缀
        gif_path_stems = set(map(lambda t: str(t.with_suffix('')), gifs))
        txt_path_stems = set(map(lambda t: str(t.with_suffix('')), txts))
        # 获取共同的路径前缀作为数据集的路径
        self.path_stems = list(gif_path_stems.intersection(txt_path_stems))

        self.channels = channels  # 设置通道数
        print(f'{len(self.path_stems)} video / text pairs found')  # 打印找到的视频/文本对数量

    def __len__(self):
        return len(self.path_stems)  # 返回数据集的长度

    def __getitem__(self, idx):
        path_stem = self.path_stems[idx]  # 获取指定索引的路径前缀

        txt_path = Path(f'{path_stem}.txt')  # 构建文本文件路径
        txt_str = txt_path.read_text()  # 读取文本文件内容
        text_tensor = torch.Tensor(tokenizer.encode(txt_str)).long()  # 将文本内容编码为张量

        video_tensor = gif_to_tensor(f'{path_stem}.gif', channels = self.channels)  # 将 GIF 文件转换为张量
        return text_tensor, video_tensor  # 返回文本张量和视频张量的元组

# 训练类
class NUWATrainer(nn.Module):
    def __init__(
        self,
        *,
        nuwa,  # NUWA 模型实例
        dataset,  # 数据集实例
        num_train_steps,  # 训练步数
        lr = 3e-4,  # 学习率,默认为 3e-4
        wd = 0.01,  # 权重衰减,默认为 0.01
        batch_size = 4,  # 批量大小,默认为 4
        grad_accum_every = 8,  # 梯度累积间隔,默认为 8
        max_grad_norm = 0.5,  # 最大梯度范数,默认为 0.5
        save_model_every = 2500,  # 每隔多少步保存模型,默认为 2500
        save_results_every = 1000,  # 每隔多少步保存结果,默认为 1000
        results_folder = './results-nuwa',  # 结果文件夹路径,默认为 './results-nuwa'
        num_sampled_frames = float('inf')  # 抽样帧数,默认为无穷大
    ):
        super().__init__()
        assert isinstance(nuwa, NUWA), 'nuwa must be an instance of NUWA'  # 断言 nuwa 必须是 NUWA 类的实例
        self.nuwa = nuwa  # 设置 NUWA 模型实例

        self.steps = 0  # 训练步数初始化为 0
        self.num_train_steps = num_train_steps  # 设置训练步数
        self.batch_size = batch_size  # 设置批量大小
        self.grad_accum_every = grad_accum_every  # 设置梯度累积间隔
        self.max_grad_norm = max_grad_norm  # 设置最大梯度范数

        self.optim = get_optimizer(nuwa.parameters(), lr = lr, wd = wd)  # 获取优化器

        # 数据集
        self.ds = dataset  # 设置数据集

        # 数据加载器
        self.dl = cycle(DataLoader(
            self.ds,
            batch_size = batch_size,
            collate_fn = pad_collate_fn,
            shuffle = True
        ))  # 创建循环数据加载器

        self.save_model_every = save_model_every  # 设置保存模型间隔
        self.save_results_every = save_results_every  # 设置保存结果间隔
        self.num_sampled_frames = num_sampled_frames  # 设置抽样帧数

        self.results_folder = Path(results_folder)  # 设置结果文件夹路径

        # 如果结果文件夹中有文件且确认清除之前的实验检查点和结果
        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))  # 清除之前的实验检查点和结果

        self.results_folder.mkdir(parents = True, exist_ok = True)  # 创建结果文件夹
    # 定义训练步骤函数
    def train_step(self):
        # 获取模型参数所在设备
        device = next(self.nuwa.parameters()).device
        # 设置模型为训练模式
        self.nuwa.train()

        # 初始化日志字典
        logs = {}

        # 循环执行梯度累积次数
        for _ in range(self.grad_accum_every):
            # 从数据加载器中获取文本和视频数据
            text, video = next(self.dl)
            # 将文本和视频数据移动到指定设备
            text, video = map(lambda t: t.to(device), (text, video))

            # 计算模型损失
            loss = self.nuwa(
                text = text,
                video = video,
                return_loss = True
            )
            # 累积损失到日志中
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

            # 反向传播梯度
            (loss / self.grad_accum_every).backward()

        # 打印当前步骤的损失值
        print(f'{self.steps} loss: {logs["loss"]}')

        # 对模型参数进行梯度裁剪
        torch.nn.utils.clip_grad_norm_(self.nuwa.parameters(), self.max_grad_norm)
        # 更新优化器参数
        self.optim.step()
        # 清空梯度
        self.optim.zero_grad()

        # 每隔一定步骤保存生成结果
        if not (self.steps % self.save_results_every):
            # 设置模型为评估模式
            self.nuwa.eval()
            print(f'{self.steps} sampling')

            # 随机选择一个数据样本
            rand_idx = randrange(0, len(self.ds))

            text, video = self.ds[rand_idx]
            text, video = next(self.dl)
            text = text.to(device)

            # 生成视频序列
            video = self.nuwa.generate(text = text[:1], num_frames = min(video.shape[1], self.num_sampled_frames))
            one_video = video[0].cpu().clamp(0., 1.)

            # 解码文本数据
            text_str = tokenizer.decode(text[0])

            # 保存生成的文本和视频结果
            logs['sampled_text'] = text_str
            logs['sampled_video'] = one_video.numpy()

            # 重新排列视频帧以保存为图像
            image = rearrange(one_video, 'f c h w -> c (f h) w')
            save_image(image, str(self.results_folder / f'{self.steps}.png'))

            print(f'{self.steps}: saving to {str(self.results_folder)}')

        # 每隔一定步骤保存模型
        if not (self.steps % self.save_model_every):
            # 获取模型状态字典
            state_dict = self.nuwa.state_dict()
            model_path = str(self.results_folder / f'nuwa.{self.steps}.pt')
            # save the model weights
            torch.save(state_dict, model_path)

            print(f'{self.steps}: saving model to {str(self.results_folder)}')

        # 更新步骤数
        self.steps += 1
        return logs

    # 定义训练函数
    def train(self, log_fn = noop):
        # 循环执行训练步骤直到达到指定训练步数
        while self.steps < self.num_train_steps:
            # 执行训练步骤并记录日志
            logs = self.train_step()
            log_fn(logs)

        # 打印训练完成信息
        print('training complete')

.\lucidrains\nuwa-pytorch\nuwa_pytorch\train_vqgan_vae.py

# import sqrt from math
from math import sqrt
# import the copy module
import copy
# import choice from random
from random import choice
# import Path from pathlib
from pathlib import Path
# import rmtree from shutil
from shutil import rmtree

# import torch
import torch
# import nn from torch
from torch import nn
# import numpy
import numpy as np

# import Image from PIL
from PIL import Image
# import ImageFolder from torchvision.datasets
from torchvision.datasets import ImageFolder
# import torchvision.transforms as T
import torchvision.transforms as T
# import Dataset, DataLoader and random_split from torch.utils.data
from torch.utils.data import Dataset, DataLoader, random_split
# import make_grid and save_image from torchvision.utils
from torchvision.utils import make_grid, save_image

# import rearrange from einops
from einops import rearrange
# import the VQGanVAE class from nuwa_pytorch.vqgan_vae
from nuwa_pytorch.vqgan_vae import VQGanVAE
# import get_optimizer from nuwa_pytorch.optimizer
from nuwa_pytorch.optimizer import get_optimizer

# helpers

# 定义 exists 函数,判断值是否存在
def exists(val):
    return val is not None

# 定义 noop 函数,空函数
def noop(*args, **kwargs):
    pass

# 定义 cycle 函数,循环生成数据
def cycle(dl):
    while True:
        for data in dl:
            yield data

# 定义 cast_tuple 函数,将输入转换为元组
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# 定义 yes_or_no 函数,询问用户是否为是或否
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# 定义 accum_log 函数,累积日志
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# classes

# 定义 MemmappedImageDataset 类,继承自 Dataset 类
class MemmappedImageDataset(Dataset):
    def __init__(
        self,
        *,
        path,
        shape,
        random_rotate = True
    ):
        super().__init__()
        path = Path(path)
        assert path.exists(), f'path {path} must exist'
        self.memmap = np.memmap(str(path), mode = 'r', dtype = np.uint8, shape = shape)
        self.random_rotate = random_rotate

        image_size = shape[-1]
        self.transform = T.Compose([
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return self.memmap.shape[0]

    def __getitem__(self, index):
        arr = self.memmap[index]

        if arr.shape[0] == 1:
            arr = rearrange(arr, '1 ... -> ...')

        img = Image.fromarray(arr)
        img = self.transform(img)

        if self.random_rotate:
            img = T.functional.rotate(img, choice([0, 90, 180, 270]))
        return img

# 定义 ImageDataset 类,继承自 Dataset 类
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# exponential moving average wrapper

# 定义 EMA 类,继承自 nn.Module 类
class EMA(nn.Module):
    def __init__(
        self,
        model,
        beta = 0.99,
        ema_update_after_step = 1000,
        ema_update_every = 10,
    ):
        super().__init__()
        self.beta = beta
        self.online_model = model
        self.ema_model = copy.deepcopy(model)

        self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
        self.ema_update_every = ema_update_every

        self.register_buffer('initted', torch.Tensor([False]))
        self.register_buffer('step', torch.tensor([0.]))

    def update(self):
        self.step += 1

        if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
            return

        if not self.initted:
            self.ema_model.state_dict(self.online_model.state_dict())
            self.initted.data.copy_(torch.Tensor([True]))

        self.update_moving_average(self.ema_model, self.online_model)
    # 更新移动平均模型的参数
    def update_moving_average(self, ma_model, current_model):
        # 定义计算指数移动平均的函数
        def calculate_ema(beta, old, new):
            # 如果旧值不存在,则直接返回新值
            if not exists(old):
                return new
            # 计算指数移动平均值
            return old * beta + (1 - beta) * new

        # 遍历当前模型和移动平均模型的参数,更新移动平均值
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)

        # 遍历当前模型和移动平均模型的缓冲区,更新移动平均值
        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
            ma_buffer.copy_(new_buffer_value)

    # 调用函数,返回移动平均模型的结果
    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)
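
A hypothetical usage sketch of the EMA wrapper (the toy model and the bare-bones loop are assumptions): after each optimizer step, update() advances the step counter and, once past ema_update_after_step, blends the online weights into the shadow copy; calling the wrapper forwards through the EMA model.

# --- usage sketch (illustrative, not part of train_vqgan_vae.py) ---
online = nn.Linear(8, 8)
ema = EMA(online, beta = 0.99, ema_update_after_step = 0, ema_update_every = 1)

for _ in range(3):
    # ... an optimizer step on `online` would normally happen here ...
    ema.update()

out = ema(torch.randn(1, 8))   # forwards through the EMA copy of the model
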
# main trainer class

class VQGanVAETrainer(nn.Module):
    def __init__(
        self,
        vae,
        *,
        num_train_steps,
        lr,
        batch_size,
        grad_accum_every,
        wd = 0.,
        images_memmap_path = None,
        images_memmap_shape = None,
        folder = None,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        ema_beta = 0.995,
        ema_update_after_step = 2000,
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
    ):
        super().__init__()
        assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
        image_size = vae.image_size

        self.vae = vae
        self.ema_vae = EMA(vae, ema_update_after_step = ema_update_after_step, ema_update_every = ema_update_every)

        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        all_parameters = set(vae.parameters())
        discr_parameters = set(vae.discr.parameters())
        vae_parameters = all_parameters - discr_parameters

        self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
        self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)

        # 创建数据集

        assert exists(folder) ^ exists(images_memmap_path), 'either folder or memmap path to images must be supplied'

        if exists(images_memmap_path):
            assert exists(images_memmap_shape), 'shape of memmapped images must be supplied'

        if exists(folder):
            self.ds = ImageDataset(folder, image_size = image_size)
        elif exists(images_memmap_path):
            self.ds = MemmappedImageDataset(path = images_memmap_path, shape = images_memmap_shape)

        # 划分验证集

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # 数据加载器

        self.dl = cycle(DataLoader(
            self.ds,
            batch_size = batch_size,
            shuffle = True
        ))

        self.valid_dl = cycle(DataLoader(
            self.valid_ds,
            batch_size = batch_size,
            shuffle = True
        ))

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.apply_grad_penalty_every = apply_grad_penalty_every

        self.results_folder = Path(results_folder)

        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)
    # 定义训练步骤函数
    def train_step(self):
        # 获取模型参数所在设备
        device = next(self.vae.parameters()).device
        # 获取当前步数
        steps = int(self.steps.item())
        # 是否应用梯度惩罚
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

        # 设置 VAE 模型为训练模式
        self.vae.train()

        # 初始化日志字典
        logs = {}

        # 更新 VAE(生成器)

        # 多次执行梯度累积
        for _ in range(self.grad_accum_every):
            # 获取下一个数据批次
            img = next(self.dl)
            img = img.to(device)

            # 计算损失
            loss = self.vae(
                img,
                return_loss = True,
                apply_grad_penalty = apply_grad_penalty
            )

            # 累积损失到日志中
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

            # 反向传播
            (loss / self.grad_accum_every).backward()

        # 更新优化器
        self.optim.step()
        self.optim.zero_grad()

        # 更新鉴别器

        if exists(self.vae.discr):
            self.discr_optim.zero_grad()
            discr_loss = 0

            for _ in range(self.grad_accum_every):
                img = next(self.dl)
                img = img.to(device)

                loss = self.vae(img, return_discr_loss = True)
                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})

                (loss / self.grad_accum_every).backward()

            self.discr_optim.step()

            # 打印日志
            print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")

        # 更新指数移动平均生成器
        self.ema_vae.update()

        # 定期采样结果

        if not (steps % self.save_results_every):
            for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
                model.eval()

                imgs = next(self.dl)
                imgs = imgs.to(device)

                recons = model(imgs)
                nrows = int(sqrt(self.batch_size))

                imgs_and_recons = torch.stack((imgs, recons), dim = 0)
                imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')

                imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
                grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))

                logs['reconstructions'] = grid

                save_image(grid, str(self.results_folder / f'{filename}.png'))

            print(f'{steps}: saving to {str(self.results_folder)}')

        # 定期保存模型

        if not (steps % self.save_model_every):
            state_dict = self.vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.pt')
            torch.save(state_dict, model_path)

            ema_state_dict = self.ema_vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
            torch.save(ema_state_dict, model_path)

            print(f'{steps}: saving model to {str(self.results_folder)}')

        # 更新步数
        self.steps += 1
        return logs

    # 训练函数
    def train(self, log_fn = noop):
        # 获取模型参数所在设备
        device = next(self.vae.parameters()).device

        # 在训练步数未达到总训练步数前循环执行训练步骤
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        # 训练完成
        print('training complete')

.\lucidrains\nuwa-pytorch\nuwa_pytorch\vqgan_vae.py

# import the required libraries
import copy
import math
from functools import partial, wraps
from math import sqrt

# import vector quantization from the vector-quantize-pytorch library
from vector_quantize_pytorch import VectorQuantize as VQ

# import PyTorch related libraries
import torchvision
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad

# import einops
from einops import rearrange, reduce, repeat

# constants
MList = nn.ModuleList

# helper functions

# 判断变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 装饰器

# 模型评估装饰器
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# 移除 VGG 模型装饰器
def remove_vgg(fn):
    @wraps(fn)
    def inner(self, *args, **kwargs):
        has_vgg = hasattr(self, 'vgg')
        if has_vgg:
            vgg = self.vgg
            delattr(self, 'vgg')

        out = fn(self, *args, **kwargs)

        if has_vgg:
            self.vgg = vgg

        return out
    return inner

# 关键字参数辅助函数

# 从字典中选择指定键的值并弹出这些键
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

# 根据条件将字典分组
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# 判断字符串是否以指定前缀开头
def string_begins_with(prefix, str):
    return str.startswith(prefix)

# 根据前缀将字典分组
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# group kwargs by prefix and strip the prefix from the grouped keys
def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
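
A small illustration of how the prefixed kwargs are separated (the values are made up):

# --- illustrative example (not part of vqgan_vae.py) ---
vq_kwargs, rest = groupby_prefix_and_trim('vq_', {'vq_decay': 0.9, 'dim': 64})
# vq_kwargs == {'decay': 0.9}, rest == {'dim': 64}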

# 张量辅助函数

# 计算梯度惩罚
def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]
    gradients = torch_grad(outputs = output, inputs = images,
                           grad_outputs = torch.ones(output.size(), device = images.device),
                           create_graph = True, retain_graph = True, only_inputs = True)[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# 计算 L2 范数
def l2norm(t):
    return F.normalize(t, dim = -1)

# Leaky ReLU 激活函数
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(0.1)

# 稳定的 Softmax 函数
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True).detach()
    return (t * alpha).softmax(dim = dim)

# 安全除法
def safe_div(numer, denom, eps = 1e-6):
    return numer / (denom + eps)

# GAN 损失函数

# Hinge 判别器损失
def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

# Hinge 生成器损失
def hinge_gen_loss(fake):
    return -fake.mean()

# 二元交叉熵判别器损失
def bce_discr_loss(fake, real):
    return (-log(1 - sigmoid(fake)) - log(sigmoid(real))).mean()

# 二元交叉熵生成器损失
def bce_gen_loss(fake):
    return -log(sigmoid(fake)).mean()

# 计算损失对层的梯度
def grad_layer_wrt_loss(loss, layer):
    return torch_grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0].detach()

# VQGAN VAE

# 通道层归一化
class LayerNormChan(nn.Module):
    def __init__(
        self,
        dim,
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

# 判别器模型
class Discriminator(nn.Module):
    def __init__(
        self,
        dims,
        channels = 3,
        groups = 16,
        init_kernel_size = 5
    # 定义一个继承自 nn.Module 的类,用于构建一个简单的卷积神经网络
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 将输入维度按照前后两两配对,形成一个维度对的列表
        dim_pairs = zip(dims[:-1], dims[1:])

        # 初始化网络的第一层,包括一个卷积层和激活函数
        self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])

        # 遍历维度对列表,构建网络的中间层,每层包括卷积层、归一化层和激活函数
        for dim_in, dim_out in dim_pairs:
            self.layers.append(nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
                nn.GroupNorm(groups, dim_out),
                leaky_relu()
            ))

        # 获取最后一个维度
        dim = dims[-1]
        # 构建输出层,包括两个卷积层和激活函数,用于生成输出结果
        self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
            nn.Conv2d(dim, dim, 1),
            leaky_relu(),
            nn.Conv2d(dim, 1, 4)
        )

    # 定义前向传播方法,将输入数据通过网络层进行处理,得到输出结果
    def forward(self, x):
        # 遍历网络的每一层,将输入数据依次传递给每一层
        for net in self.layers:
            x = net(x)

        # 返回经过所有网络层处理后的输出结果
        return self.to_logits(x)
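
A quick shape sketch (the dimensions and input size below are assumptions): each stride-2 stage halves the spatial resolution, and the head emits a grid of patch logits rather than a single score.

# --- shape sketch (illustrative, not part of vqgan_vae.py) ---
discr = Discriminator(dims = (32, 64, 128), channels = 3)
logits = discr(torch.randn(2, 3, 64, 64))
# two stride-2 stages: 64 -> 32 -> 16, then a 4x4 conv head -> (2, 1, 13, 13)
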
class ContinuousPositionBias(nn.Module):
    """ 定义一个连续位置偏置的类,参考 https://arxiv.org/abs/2111.09883 """

    def __init__(self, *, dim, heads, layers = 2):
        super().__init__()
        self.net = MList([])
        self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))

        self.net.append(nn.Linear(dim, heads))
        self.register_buffer('rel_pos', None, persistent = False)

    def forward(self, x):
        n, device = x.shape[-1], x.device
        fmap_size = int(sqrt(n))

        if not exists(self.rel_pos):
            pos = torch.arange(fmap_size, device = device)
            grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
            grid = rearrange(grid, 'c i j -> (i j) c')
            rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
            rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
            self.register_buffer('rel_pos', rel_pos, persistent = False)

        rel_pos = self.rel_pos.float()

        for layer in self.net:
            rel_pos = layer(rel_pos)

        bias = rearrange(rel_pos, 'i j h -> h i j')
        return x + bias

class GLUResBlock(nn.Module):
    """ 定义一个 GLUResBlock 类 """

    def __init__(self, chan, groups = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan * 2, 3, padding = 1),
            nn.GLU(dim = 1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan * 2, 3, padding = 1),
            nn.GLU(dim = 1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        return self.net(x) + x

class ResBlock(nn.Module):
    """ 定义一个 ResBlock 类 """

    def __init__(self, chan, groups = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        return self.net(x) + x

class VQGanAttention(nn.Module):
    """ 定义一个 VQGanAttention 类 """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * math.log(0.01))
        inner_dim = heads * dim_head

        self.dropout = nn.Dropout(dropout)
        self.post_norm = LayerNormChan(dim)

        self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        h = self.heads
        height, width, residual = *x.shape[-2:], x.clone()

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))

        q, k = map(l2norm, (q, k))

        sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale.exp()

        sim = self.cpb(sim)

        attn = stable_softmax(sim, dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h c j -> b h c i', attn, v)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
        out = self.to_out(out)

        return self.post_norm(out) + residual

class VQGanVAE(nn.Module):
    """ 定义一个 VQGanVAE 类 """
    # 初始化函数,设置模型的参数
    def __init__(
        self,
        *,
        dim,  # 模型的维度
        image_size,  # 图像的尺寸
        channels = 3,  # 图像的通道数,默认为3
        num_layers = 4,  # 模型的层数,默认为4
        layer_mults = None,  # 每一层的倍增因子
        l2_recon_loss = False,  # 是否使用L2重建损失,默认为False
        use_hinge_loss = True,  # 是否使用hinge损失,默认为True
        num_resnet_blocks = 1,  # ResNet块的数量,默认为1
        vgg = None,  # VGG模型
        vq_codebook_dim = 256,  # VQ编码簇的维度
        vq_codebook_size = 512,  # VQ编码簇的大小
        vq_decay = 0.8,  # VQ衰减率
        vq_commitment_weight = 1.,  # VQ损失的权重
        vq_kmeans_init = True,  # 是否使用K均值初始化VQ编码簇,默认为True
        vq_use_cosine_sim = True,  # 是否使用余弦相似度计算VQ损失,默认为True
        use_attn = True,  # 是否使用注意力机制,默认为True
        attn_dim_head = 64,  # 注意力机制的头维度
        attn_heads = 8,  # 注意力机制的头数量
        resnet_groups = 16,  # ResNet块的组数
        attn_dropout = 0.,  # 注意力机制的dropout率
        first_conv_kernel_size = 5,  # 第一个卷积层的卷积核大小
        use_vgg_and_gan = True,  # 是否同时使用VGG和GAN,默认为True
        **kwargs  # 其他参数
        ):
        # 调用父类的构造函数
        super().__init__()
        # 断言维度必须能够被 resnet_groups 整除
        assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'

        # 将参数中以 'vq_' 开头的参数提取出来
        vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)

        # 初始化一些属性
        self.image_size = image_size
        self.channels = channels
        self.num_layers = num_layers
        self.fmap_size = image_size // (num_layers ** 2)
        self.codebook_size = vq_codebook_size

        self.encoders = MList([])
        self.decoders = MList([])

        # 计算每一层的维度
        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
        assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'

        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)
        codebook_dim = layer_dims[-1]

        dim_pairs = zip(dims[:-1], dims[1:])

        append = lambda arr, t: arr.append(t)
        prepend = lambda arr, t: arr.insert(0, t)

        # 如果 num_resnet_blocks 不是元组,则转换为元组
        if not isinstance(num_resnet_blocks, tuple):
            num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)

        # 如果 use_attn 不是元组,则转换为元组
        if not isinstance(use_attn, tuple):
            use_attn = (*((False,) * (num_layers - 1)), use_attn)

        assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
        assert len(use_attn) == num_layers

        # 遍历每一层,构建编码器和解码器
        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
            prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))

            if layer_use_attn:
                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))

            for _ in range(layer_num_resnet_blocks):
                append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
                prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))

            if layer_use_attn:
                append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))

        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
        append(self.decoders, nn.Conv2d(dim, channels, 1))

        # 初始化 VQ 模块
        self.vq = VQ(
            dim = layer_dims[-1],
            codebook_dim = vq_codebook_dim,
            codebook_size = vq_codebook_size,
            decay = vq_decay,
            commitment_weight = vq_commitment_weight,
            accept_image_fmap = True,
            kmeans_init = vq_kmeans_init,
            use_cosine_sim = vq_use_cosine_sim,
            **vq_kwargs
        )

        # 重构损失函数
        self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss

        # 如果是灰度图像,则关闭 GAN 和感知损失
        self.vgg = None
        self.discr = None
        self.use_vgg_and_gan = use_vgg_and_gan

        if not use_vgg_and_gan:
            return

        # 初始化感知损失
        if exists(vgg):
            self.vgg = vgg
        else:
            self.vgg = torchvision.models.vgg16(pretrained = True)
            self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])

        # 初始化GAN相关损失
        self.discr = Discriminator(dims = dims, channels = channels)

        self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
        self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
    # 创建一个模型的副本用于评估,确保在同一设备上
    def copy_for_eval(self):
        # 获取模型参数的设备信息
        device = next(self.parameters()).device
        # 深度复制模型并将其移动到 CPU
        vae_copy = copy.deepcopy(self.cpu())

        # 如果模型使用 VGG 和 GAN,则删除相关部分
        if vae_copy.use_vgg_and_gan:
            del vae_copy.discr
            del vae_copy.vgg

        # 将模型设置为评估模式
        vae_copy.eval()
        # 将模型移动回原设备
        return vae_copy.to(device)

    # 重写父类的 state_dict 方法,移除 VGG 相关部分
    @remove_vgg
    def state_dict(self, *args, **kwargs):
        return super().state_dict(*args, **kwargs)

    # 重写父类的 load_state_dict 方法,移除 VGG 相关部分
    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        return super().load_state_dict(*args, **kwargs)

    # 返回模型的 codebook 属性,即 VQ 模块的 codebook
    @property
    def codebook(self):
        return self.vq.codebook

    # 对输入进行编码操作,通过多个编码器层
    def encode(self, fmap):
        for enc in self.encoders:
            fmap = enc(fmap)

        return self.vq(fmap)

    # 对输入进行解码操作,通过多个解码器层
    def decode(self, fmap):
        for dec in self.decoders:
            fmap = dec(fmap)

        return fmap

    # 将 codebook 索引转换为视频数据
    @torch.no_grad()
    @eval_decorator
    def codebook_indices_to_video(self, indices):
        b = indices.shape[0]
        codes = self.codebook[indices]
        codes = rearrange(codes, 'b (f h w) d -> (b f) d h w', h = self.fmap_size, w = self.fmap_size)
        video = self.decode(codes)
        return rearrange(video, '(b f) ... -> b f ...', b = b)

    # 从视频数据中获取 codebook 索引
    @torch.no_grad()
    @eval_decorator
    def get_video_indices(self, video):
        b, f, _, h, w = video.shape
        images = rearrange(video, 'b f ... -> (b f) ...')
        _, indices, _ = self.encode(images)
        return rearrange(indices, '(b f) ... -> b f ...', b = b)

    # 模型的前向传播方法,包括返回损失、重构、梯度惩罚等选项
    def forward(
        self,
        img,
        return_loss = False,
        return_discr_loss = False,
        return_recons = False,
        apply_grad_penalty = False
        ):
        # 解构赋值,获取图像的批次、通道数、高度、宽度和设备信息
        batch, channels, height, width, device = *img.shape, img.device
        # 断言输入图像的高度和宽度与设定的self.image_size相等
        assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}'
        # 断言输入图像的通道数与VQGanVAE中设定的通道数相等
        assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'

        # 编码输入图像,获取特征图、索引和commit_loss
        fmap, indices, commit_loss = self.encode(img)

        # 解码特征图
        fmap = self.decode(fmap)

        # 如果不需要返回损失和鉴别器损失,则直接返回解码后的特征图
        if not return_loss and not return_discr_loss:
            return fmap

        # 断言只能返回自编码器损失或鉴别器损失,不能同时返回
        assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'

        # 是否返回鉴别器损失
        if return_discr_loss:
            # 断言鉴别器存在
            assert exists(self.discr), 'discriminator must exist to train it'

            # 分离特征图,设置输入图像为可求导
            fmap.detach_()
            img.requires_grad_()

            # 获取特征图和输入图像的鉴别器logits
            fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))

            # 计算鉴别器损失
            loss = self.discr_loss(fmap_discr_logits, img_discr_logits)

            # 如果需要应用梯度惩罚
            if apply_grad_penalty:
                gp = gradient_penalty(img, img_discr_logits)
                loss = loss + gp

            # 如果需要返回重构图像
            if return_recons:
                return loss, fmap

            return loss

        # 重构损失
        recon_loss = self.recon_loss_fn(fmap, img)

        # 如果不使用VGG和GAN,则直接返回重构损失
        if not self.use_vgg_and_gan:
            if return_recons:
                return recon_loss, fmap

            return recon_loss

        # 感知损失
        img_vgg_input = img
        fmap_vgg_input = fmap

        # 处理灰度图像用于VGG
        if img.shape[1] == 1:
            img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))

        # 获取输入图像和重构图像的VGG特征
        img_vgg_feats = self.vgg(img_vgg_input)
        recon_vgg_feats = self.vgg(fmap_vgg_input)
        perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)

        # 生成器损失
        gen_loss = self.gen_loss(self.discr(fmap))

        # 计算自适应权重
        last_dec_layer = self.decoders[-1].weight

        norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
        norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

        adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
        adaptive_weight.clamp_(max = 1e4)
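        # 说明:自适应权重沿用 VQGAN 论文的做法,用感知损失与生成器损失对最后一层解码器权重的梯度范数之比来缩放对抗项,避免对抗损失在训练早期压过重构与感知目标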

        # 组合损失
        loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss

        # 如果需要返回重构图像
        if return_recons:
            return loss, fmap

        return loss
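
上面的 VQGanVAE 中用到的 hinge_discr_loss、hinge_gen_loss、bce_discr_loss、bce_gen_loss、gradient_penalty、grad_layer_wrt_loss、safe_div 等辅助函数定义在本节摘录之外。下面给出这些辅助函数的一个最小示意实现,仅用于说明其作用,细节以原仓库源码为准:

import torch
import torch.nn.functional as F
from torch.autograd import grad as torch_grad

# 防止除零的安全除法
def safe_div(numer, denom, eps = 1e-8):
    return numer / (denom + eps)

# 判别器 hinge 损失:希望真实样本 logits > 1,生成样本 logits < -1
def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

# 生成器 hinge 损失:希望生成样本的 logits 越大越好
def hinge_gen_loss(fake):
    return -fake.mean()

# 二元交叉熵形式的判别器 / 生成器损失
def bce_discr_loss(fake, real):
    return (-torch.log(torch.sigmoid(real) + 1e-8) - torch.log(1 - torch.sigmoid(fake) + 1e-8)).mean()

def bce_gen_loss(fake):
    return -torch.log(torch.sigmoid(fake) + 1e-8).mean()

# 求某个标量损失对指定参数(如最后一层解码器权重)的梯度,用于计算自适应权重
def grad_layer_wrt_loss(loss, layer):
    return torch_grad(outputs = loss, inputs = layer, grad_outputs = torch.ones_like(loss), retain_graph = True)[0].detach()

# 梯度惩罚:对判别器输出关于输入图像求梯度,并惩罚其范数偏离 1
def gradient_penalty(images, output, weight = 10):
    gradients = torch_grad(outputs = output, inputs = images, grad_outputs = torch.ones_like(output), create_graph = True, retain_graph = True, only_inputs = True)[0]
    gradients = gradients.reshape(images.shape[0], -1)
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()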

.\lucidrains\nuwa-pytorch\nuwa_pytorch\__init__.py

# 从 nuwa_pytorch.nuwa_pytorch 模块中导入 NUWA、NUWASketch、NUWAVideoAudio、Sparse3DNA、CrossModalityCrossAttention 类
# 以及从 nuwa_pytorch.vqgan_vae 模块中导入 VQGanVAE 类
from nuwa_pytorch.nuwa_pytorch import NUWA, NUWASketch, NUWAVideoAudio, Sparse3DNA, CrossModalityCrossAttention
from nuwa_pytorch.vqgan_vae import VQGanVAE

# 从 nuwa_pytorch.train_vqgan_vae 模块中导入 VQGanVAETrainer 类
# 以及从 nuwa_pytorch.train_nuwa 模块中导入 NUWATrainer 类
from nuwa_pytorch.train_vqgan_vae import VQGanVAETrainer
from nuwa_pytorch.train_nuwa import NUWATrainer

NÜWA - Pytorch

Join us on Discord

Implementation of NÜWA, state of the art attention network for text to video synthesis, in Pytorch. It also contains an extension into video and audio generation, using a dual decoder approach.

Yannic Kilcher

DeepReader

Status

  • March 2022 - seeing signs of life with a difficult version of moving mnist

  • April 2022 - It seems as though a diffusion based method has taken the new throne for SOTA. However, I will continue on with NUWA, extending it to use multi-headed codes + hierarchical causal transformer. I think that direction is untapped for improving on this line of work.

Install

$ pip install nuwa-pytorch

Usage

First train the VAE

import torch
from nuwa_pytorch import VQGanVAE

vae = VQGanVAE(
    dim = 512,
    channels = 3,               # default is 3, but can be changed to any value for the training of the segmentation masks (sketches)
    image_size = 256,           # image size
    num_layers = 4,             # number of downsampling layers
    num_resnet_blocks = 2,      # number of resnet blocks
    vq_codebook_size = 8192,    # codebook size
    vq_decay = 0.8              # codebook exponential decay
)

imgs = torch.randn(10, 3, 256, 256)

# alternate learning for autoencoder ...

loss = vae(imgs, return_loss = True)
loss.backward()

# and the discriminator ...

discr_loss = vae(imgs, return_discr_loss = True)
discr_loss.backward()

# do above for many steps

# return reconstructed images and make sure they look ok

recon_imgs = vae(imgs)
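
Continuing from the example above, a minimal sketch of how the alternating updates might be wired up with standard optimizers (the optimizer choice, learning rate and parameter filtering here are illustrative assumptions, not something prescribed by the library):

import torch
from torch.optim import Adam

# exclude the discriminator (and the frozen VGG) from the autoencoder optimizer
vae_params = [p for n, p in vae.named_parameters() if not n.startswith(('discr', 'vgg'))]
vae_opt = Adam(vae_params, lr = 3e-4)
discr_opt = Adam(vae.discr.parameters(), lr = 3e-4)

for _ in range(100):
    imgs = torch.randn(4, 3, 256, 256)   # substitute a real image batch here

    # autoencoder / generator step
    loss = vae(imgs, return_loss = True)
    loss.backward()
    vae_opt.step()
    vae_opt.zero_grad()

    # discriminator step
    discr_loss = vae(imgs, return_discr_loss = True)
    discr_loss.backward()
    discr_opt.step()
    discr_opt.zero_grad()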

Then, with your learned VAE

import torch
from nuwa_pytorch import NUWA, VQGanVAE

# autoencoder

vae = VQGanVAE(
    dim = 64,
    num_layers = 4,
    image_size = 256,
    num_conv_blocks = 2,
    vq_codebook_size = 8192
)

# NUWA transformer

nuwa = NUWA(
    vae = vae,
    dim = 512,
    text_num_tokens = 20000,                # number of text tokens
    text_enc_depth = 12,                    # text encoder depth
    text_enc_heads = 8,                     # number of attention heads for encoder
    text_max_seq_len = 256,                 # max sequence length of text conditioning tokens (keep at 256 as in paper, or shorter, if your text is not that long)
    max_video_frames = 10,                  # number of video frames
    image_size = 256,                       # size of each frame of video
    dec_depth = 64,                         # video decoder depth
    dec_heads = 8,                          # number of attention heads in decoder
    dec_reversible = True,                  # reversible networks - from reformer, decoupling memory usage from depth
    enc_reversible = True,                  # reversible encoders, if you need it
    attn_dropout = 0.05,                    # dropout for attention
    ff_dropout = 0.05,                      # dropout for feedforward
    sparse_3dna_kernel_size = (5, 3, 3),    # kernel size of the sparse 3dna attention. can be a single value for frame, height, width, or different values (to simulate axial attention, etc)
    sparse_3dna_dilation = (1, 2, 4),       # cycle dilation of 3d conv attention in decoder, for more range
    shift_video_tokens = True               # cheap relative positions for sparse 3dna transformer, by shifting along spatial dimensions by one
).cuda()

# data

text = torch.randint(0, 20000, (1, 256)).cuda()
video = torch.randn(1, 10, 3, 256, 256).cuda() # (batch, frames, channels, height, width)

loss = nuwa(
    text = text,
    video = video,
    return_loss = True  # set this to True, only for training, to return cross entropy loss
)

loss.backward()

# do above with as much data as possible

# then you can generate a video from text

video = nuwa.generate(text = text, num_frames = 5) # (1, 5, 3, 256, 256)

Conditioning on Sketches

In the paper, they also present a way to condition the video generation based on segmentation mask(s). You can easily do this as well, given you train a VQGanVAE on the sketches beforehand.

Then, you will use NUWASketch instead of NUWA, which can accept the sketch VAE as a reference

ex.

import torch
from nuwa_pytorch import NUWASketch, VQGanVAE

# autoencoder, one for main video, the other for the sketch

vae = VQGanVAE(
    dim = 64,
    num_layers = 4,
    image_size = 256,
    num_conv_blocks = 2,
    vq_codebook_size = 8192
)

sketch_vae = VQGanVAE(
    dim = 512,
    channels = 5,                # say the sketch has 5 classes
    num_layers = 4,
    image_size = 256,
    num_conv_blocks = 2,
    vq_codebook_size = 8192
)

# NUWA transformer for conditioning with sketches

nuwa = NUWASketch(
    vae = vae,
    sketch_vae = sketch_vae,
    dim = 512,                              # model dimensions
    sketch_enc_depth = 12,                  # sketch encoder depth
    sketch_enc_heads = 8,                   # number of attention heads for sketch encoder
    sketch_max_video_frames = 3,            # max number of frames for sketches
    sketch_enc_use_sparse_3dna = True,      # whether to use 3d-nearby attention (or full attention if False) for sketch encoding transformer
    max_video_frames = 10,                  # number of video frames
    image_size = 256,                       # size of each frame of video
    dec_depth = 64,                         # video decoder depth
    dec_heads = 8,                          # number of attention heads in decoder
    dec_reversible = True,                  # reversible networks - from reformer, decoupling memory usage from depth
    enc_reversible = True,                  # reversible encoders, if you need it
    attn_dropout = 0.05,                    # dropout for attention
    ff_dropout = 0.05,                      # dropout for feedforward
    sparse_3dna_kernel_size = (5, 3, 3),    # kernel size of the sparse 3dna attention. can be a single value for frame, height, width, or different values (to simulate axial attention, etc)
    sparse_3dna_dilation = (1, 2, 4),       # cycle dilation of 3d conv attention in decoder, for more range
    cross_2dna_kernel_size = 5,             # 2d kernel size of spatial grouping of attention from video frames to sketches
    cross_2dna_dilation = 1,                # 2d dilation of spatial attention from video frames to sketches
    shift_video_tokens = True               # cheap relative positions for sparse 3dna transformer, by shifting along spatial dimensions by one
).cuda()

# data

sketch = torch.randn(2, 2, 5, 256, 256).cuda() # (batch, frames, segmentation classes, height, width)
sketch_mask = torch.ones(2, 2).bool().cuda()   # (batch, frames) [Optional]
video = torch.randn(2, 10, 3, 256, 256).cuda() # (batch, frames, channels, height, width)

loss = nuwa(
    sketch = sketch,
    sketch_mask = sketch_mask,
    video = video,
    return_loss = True  # set this to True, only for training, to return cross entropy loss
)

loss.backward()

# do above with as much data as possible

# then you can generate a video from sketch(es)

video = nuwa.generate(sketch = sketch, num_frames = 5) # (1, 5, 3, 256, 256)

Text to Video and Audio

This repository will also offer a variant of NUWA that can produce both video and audio. For now, the audio will need to be encoded manually.

import torch
from nuwa_pytorch import NUWAVideoAudio, VQGanVAE

# autoencoder

vae = VQGanVAE(
    dim = 64,
    num_layers = 4,
    image_size = 256,
    num_conv_blocks = 2,
    vq_codebook_size = 100
)

# NUWA transformer

nuwa = NUWAVideoAudio(
    vae = vae,
    dim = 512,
    num_audio_tokens = 2048,                # codebook size for audio tokens
    num_audio_tokens_per_video_frame = 32,  # number of audio tokens per video frame
    cross_modality_attn_every = 3,          # cross modality attention every N layers
    text_num_tokens = 20000,                # number of text tokens
    text_enc_depth = 1,                     # text encoder depth
    text_enc_heads = 8,                     # number of attention heads for encoder
    text_max_seq_len = 256,                 # max sequence length of text conditioning tokens (keep at 256 as in paper, or shorter, if your text is not that long)
    max_video_frames = 10,                  # number of video frames
    image_size = 256,                       # size of each frame of video
    dec_depth = 4,                          # video decoder depth
    dec_heads = 8,                          # number of attention heads in decoder
    enc_reversible = True,                  # reversible encoders, if you need it
    dec_reversible = True,                  # quad-branched reversible network, for making depth of twin video / audio decoder independent of network depth. recommended to be turned on unless you have a ton of memory at your disposal
    attn_dropout = 0.05,                    # dropout for attention
    ff_dropout = 0.05,                      # dropout for feedforward
    sparse_3dna_kernel_size = (5, 3, 3),    # kernel size of the sparse 3dna attention. can be a single value for frame, height, width, or different values (to simulate axial attention, etc)
    sparse_3dna_dilation = (1, 2, 4),       # cycle dilation of 3d conv attention in decoder, for more range
    shift_video_tokens = True               # cheap relative positions for sparse 3dna transformer, by shifting along spatial dimensions by one
).cuda()

# data

text = torch.randint(0, 20000, (1, 256)).cuda()
audio = torch.randint(0, 2048, (1, 32 * 10)).cuda() # (batch, audio tokens per frame * max video frames)
video = torch.randn(1, 10, 3, 256, 256).cuda() # (batch, frames, channels, height, width)

loss = nuwa(
    text = text,
    video = video,
    audio = audio,
    return_loss = True  # set this to True, only for training, to return cross entropy loss
)

loss.backward()

# do above with as much data as possible

# then you can generate a video from text

video, audio = nuwa.generate(text = text, num_frames = 5) # (1, 5, 3, 256, 256), (1, 32 * 5 == 160)

Trainers

This library will offer some utilities to make training easier. For starters, you can use the VQGanVAETrainer class to take care of training the VQGanVAE. Simply wrap the model and also pass in the image folder path as well as the various training hyperparameters.

import torch
from nuwa_pytorch import VQGanVAE, VQGanVAETrainer

vae = VQGanVAE(
    dim = 64,
    image_size = 256,
    num_layers = 5,
    vq_codebook_size = 1024,
    vq_use_cosine_sim = True,
    vq_codebook_dim = 32,
    vq_orthogonal_reg_weight = 10,
    vq_orthogonal_reg_max_codes = 128,
).cuda()

trainer = VQGanVAETrainer(
    vae,                           # VAE defined above
    folder ='/path/to/images',     # path to images
    lr = 3e-4,                     # learning rate
    num_train_steps = 100000,      # number of training steps
    batch_size = 8,                # batch size
    grad_accum_every = 4           # gradient accumulation (effective batch size is (batch_size x grad_accum_every))
)

trainer.train()

# results and model checkpoints will be saved periodically to ./results

To train NUWA, first you need to organize a folder of .gif files with corresponding .txt files containing its caption. It should be organized as such.

ex.

📂video-and-text-data
 ┣ 📜cat.gif
 ┣ 📜cat.txt
 ┣ 📜dog.gif
 ┣ 📜dog.txt
 ┣ 📜turtle.gif
 ┗ 📜turtle.txt
```py

Then you will load your previously trained VQGan-VAE and train NUWA with the `GifVideoDataset` and `NUWATrainer` classes.

```py
import torch
from nuwa_pytorch import NUWA, VQGanVAE
from nuwa_pytorch.train_nuwa import GifVideoDataset, NUWATrainer

# dataset

ds = GifVideoDataset(
    folder = './path/to/videos/',
    channels = 1
)

# autoencoder

vae = VQGanVAE(
    dim = 64,
    image_size = 256,
    num_layers = 5,
    num_resnet_blocks = 2,
    vq_codebook_size = 512,
    attn_dropout = 0.1
)

vae.load_state_dict(torch.load('./path/to/trained/vae.pt'))

# NUWA transformer

nuwa = NUWA(
    vae = vae,
    dim = 512,
    text_enc_depth = 6,
    text_max_seq_len = 256,
    max_video_frames = 10,
    dec_depth = 12,
    dec_reversible = True,
    enc_reversible = True,
    attn_dropout = 0.05,
    ff_dropout = 0.05,
    sparse_3dna_kernel_size = (5, 3, 3),
    sparse_3dna_dilation = (1, 2, 4),
    shift_video_tokens = True
).cuda()

# data

trainer = NUWATrainer(
    nuwa = nuwa,                 # NUWA transformer
    dataset = ds,                # video dataset class
    num_train_steps = 1000000,   # number of training steps
    lr = 3e-4,                   # learning rate
    wd = 0.01,                   # weight decay
    batch_size = 8,              # batch size
    grad_accum_every = 4,        # gradient accumulation
    max_grad_norm = 0.5,         # gradient clipping
    num_sampled_frames = 10,     # number of frames to sample
    results_folder = './results' # folder to store checkpoints and samples
)

trainer.train()
```py

## VQ improvements

This library depends on this <a href="https://github.com/lucidrains/vector-quantize-pytorch">vector quantization</a> library, which comes with a number of improvements (improved vqgan, orthogonal codebook regularization, etc). To use any of these improvements, you can configure the vector quantizer keyword params by prepending `vq_` on `VQGanVAE` initialization.

ex. cosine sim proposed in <a href="https://arxiv.org/abs/2110.04627">improved vqgan</a>

```py
from nuwa_pytorch import VQGanVAE

vae = VQGanVAE(
    dim = 64,
    image_size = 256,
    num_layers = 4,
    vq_use_cosine_sim = True
    # VectorQuantize will be initialized with use_cosine_sim = True
    # https://github.com/lucidrains/vector-quantize-pytorch#cosine-similarity
).cuda()
```py
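
Under the hood, keyword arguments are grouped by the `vq_` prefix before being handed to the underlying `VectorQuantize` module (see the `groupby_prefix_and_trim('vq_', kwargs)` call in `VQGanVAE.__init__` above). A minimal sketch of what such a prefix-routing helper can look like (an illustrative implementation, not a verbatim copy of the repo's helper):

```py
def groupby_prefix_and_trim(prefix, d):
    # split a dict into (entries that had the prefix, with the prefix stripped) and (everything else)
    with_prefix = {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
    without_prefix = {k: v for k, v in d.items() if not k.startswith(prefix)}
    return with_prefix, without_prefix

vq_kwargs, rest = groupby_prefix_and_trim('vq_', dict(vq_use_cosine_sim = True, vq_codebook_dim = 32, first_conv_kernel_size = 5))
# vq_kwargs == {'use_cosine_sim': True, 'codebook_dim': 32}
# rest      == {'first_conv_kernel_size': 5}
```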

## Todo

- [x] complete 3dna causal attention in decoder
- [x] write up easy generation functions
- [x] make sure GAN portion of VQGan is correct, reread paper
- [x] make sure adaptive weight in vqgan is correctly built
- [x] offer new vqvae improvements (orthogonal reg and smaller codebook dimensions)
- [x] batch video tokens -> vae during video generation, to prevent oom
- [x] query chunking in 3dna attention, to put a cap on peak memory
- [x] flesh out VAE resnet blocks, offer some choices
- [x] add all stability tricks from cogview paper by default
- [x] make VQGan able to accept custom VGG for LPIPS loss (audio)
- [x] add feedforward chunking
- [x] add shift token in decoder for cheap powerful RPE
- [x] add reversible networks, to save on memory on depth
- [x] support kernel sizes different along each dimension for sparse 3dna
- [x] add some autotrainer that takes care of the alternating updates of discriminator and VQVAE generator
- [x] segmentation mask encoder, make sure embeddings can undergo 3dna attention with decoder during cross attention
- [x] finish 2d-nearby cross attention for sketches
- [x] able to add convnext blocks to other layers in vqgan vae
- [x] offer vqvae training script
- [x] handle variable lengthed sketches, accept a mask on the sketch frames dimension
- [x] take care of audio transformer and cross modality attention
- [x] add audio transformer, and build audio / video nearby cross attention
- [x] make dual decoder reversible
- [x] rotary embeddings for encoder
- [x] add cycle dilation to audio
- [x] omit vgg from VAE state dict
- [x] add cosine sim attention from swinv2 as an option
- [x] add axial positional embedding to audio
- [ ] Triton kernel for 3dna attention
- [ ] offer a colab with moving mnist example, conditioned on present digits
- [ ] build NUWA controller class that can accept text or sketch
- [ ] key masking for 3dna attention - for variable sketch length masking
- [ ] figure out spec vqgan and fit it into the framework, take care of audio encoding / decoding automatically
- [ ] turn into CLI tool, like stylegan2-pytorch
- [ ] look into integrating https://github.com/lucidrains/RQ-Transformer for both video and audio
- [ ] inference caching

## Citations

```py
@misc{wu2021nuwa,
    title   = {N\"UWA: Visual Synthesis Pre-training for Neural visUal World creAtion}, 
    author  = {Chenfei Wu and Jian Liang and Lei Ji and Fan Yang and Yuejian Fang and Daxin Jiang and Nan Duan},
    year    = {2021},
    eprint  = {2111.12417},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```py

```py
@misc{esser2021taming,
    title   = {Taming Transformers for High-Resolution Image Synthesis},
    author  = {Patrick Esser and Robin Rombach and Björn Ommer},
    year    = {2021},
    eprint  = {2012.09841},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```py

```py
@misc{iashin2021taming,
    title   = {Taming Visually Guided Sound Generation},
    author  = {Vladimir Iashin and Esa Rahtu},
    year    = {2021},
    eprint  = {2110.08791},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```py

```py
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```py

```py
@misc{kitaev2020reformer,
    title   = {Reformer: The Efficient Transformer},
    author  = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya},
    year    = {2020},
    eprint  = {2001.04451},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```py

```py
@misc{shazeer2020talkingheads,
    title   = {Talking-Heads Attention}, 
    author  = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    year    = {2020},
    eprint  = {2003.02436},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```py

```py
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}    
}
```py

```py
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
```py

```py
@inproceedings{ho2021classifierfree,
    title   = {Classifier-Free Diffusion Guidance},
    author  = {Jonathan Ho and Tim Salimans},
    booktitle = {NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications},
    year    = {2021},
    url     = {https://openreview.net/forum?id=qw8AKxfYbI}
}
```py

```py
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```py

```py
@misc{crowson2022,
    author  = {Katherine Crowson},
    url     = {https://twitter.com/RiversHaveWings/status/1478093658716966912}
}

Attention is the rarest and purest form of generosity. - Simone Weil

.\lucidrains\nuwa-pytorch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'nuwa-pytorch',
  # 查找所有包,不排除任何包
  packages = find_packages(exclude=[]),
  # 包含所有数据文件
  include_package_data = True,
  # 版本号
  version = '0.7.8',
  # 许可证类型
  license='MIT',
  # 包的描述
  description = 'NÜWA - Pytorch',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/nuwa-pytorch',
  # 关键词列表
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'transformers'
  ],
  # 安装依赖包
  install_requires=[
    'einops>=0.4.1',
    'ftfy',
    'pillow',
    'regex',
    'torch>=1.6',
    'torchvision',
    'tqdm',
    'unfoldNd',
    'vector-quantize-pytorch>=0.4.10'
  ],
  # 分类标签列表
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\NWT-pytorch\nwt_pytorch\nwt_pytorch.py

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import EinMix as Mix

# 定义一个名为Memcodes的神经网络模型,继承自nn.Module类
class Memcodes(nn.Module):
    def __init__(
        self,
        *,
        dim,  # 输入数据的维度
        num_codes,  # 编码的数量
        heads = 8,  # 多头注意力机制中的头数,默认为8
        temperature = 1.,  # 温度参数,默认为1
    ):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
        self.heads = heads
        self.dim = dim
        self.scale = (dim // heads) ** -0.5  # 缩放因子
        self.temperature = temperature
        self.num_codes = num_codes

        num_codebooks = heads
        codebook_dim = dim // heads

        # 初始化编码参数
        self.codes = nn.Parameter(torch.randn(num_codebooks, num_codes, codebook_dim))
        # 初始化转换矩阵,用于将编码转换为key
        self.to_k = Mix('h n d -> h n c', weight_shape = 'h d c', h = heads, d = codebook_dim, c = codebook_dim)
        # 初始化转换矩阵,用于将编码转换为value
        self.to_v = Mix('h n d -> h n c', weight_shape = 'h d c', h = heads, d = codebook_dim, c = codebook_dim)

    # 根据编码的索引获取编码
    def get_codes_from_indices(self, codebook_indices, *, merge_output_heads = True):
        batch = codebook_indices.shape[0]

        values = self.to_v(self.codes)
        values = repeat(values, 'h n d -> b h n d', b = batch)

        codebook_indices = repeat(codebook_indices, '... -> ... d', d = values.shape[-1])
        out = values.gather(2, codebook_indices)

        if not merge_output_heads:
            return out

        return rearrange(out, 'b h n d -> b n (h d)')

    # 前向传播函数
    def forward(self, x, *, merge_output_heads = True):
        assert x.shape[-1] == self.dim

        # 将输入数据分成多个头
        q = rearrange(x, 'b n (h d) -> b h n d', h = self.heads)

        q = q * self.scale

        # 获取编码的key和value
        k, v = self.to_k(self.codes), self.to_v(self.codes)

        # 使用直通Gumbel Softmax
        logits = einsum('b h i d, h j d -> b h i j', q, k)

        if self.training:
            attn = F.gumbel_softmax(logits, tau = self.temperature, dim = -1, hard = True)
            codebook_indices = attn.argmax(dim = -1)
        else:
            codebook_indices = logits.argmax(dim = -1)
            attn = F.one_hot(codebook_indices, num_classes = self.num_codes).float()

        out = einsum('b h i j, h j d -> b h i d', attn, v)

        if not merge_output_heads:
            return out, codebook_indices

        # 如果指定了合并头部,则合并头部
        out = rearrange(out, 'b h n d -> b n (h d)')
        return out, codebook_indices

.\lucidrains\NWT-pytorch\nwt_pytorch\__init__.py

# 从 nwt_pytorch.nwt_pytorch 模块中导入 Memcodes 类
from nwt_pytorch.nwt_pytorch import Memcodes

NWT - Pytorch (wip)

Implementation of NWT, audio-to-video generation, in Pytorch.

Generated samples

Install

$ pip install nwt-pytorch

Usage

The paper proposes a new discrete latent representation named Memcodes, which can be succinctly described as a type of multi-head hard-attention to learned memory (codebook) key / values. They claim the need for fewer codes and a smaller codebook dimension in order to achieve better reconstructions.

import torch
from nwt_pytorch import Memcodes

codebook = Memcodes(
    dim = 512,            # dimension of incoming features (codebook dimension will be dim / heads)
    heads = 8,            # number of heads, which is equivalent to the number of codebooks
    num_codes = 1024,     # number of codes per codebook
    temperature = 1.      # gumbel softmax temperature
)

x = torch.randn(1, 1024, 512)
out, codebook_indices = codebook(x) # (1, 1024, 512), (1, 1024, 8)
# (batch, seq, dimension), (batch, seq, heads)

# reconstruct output from codebook indices (codebook indices are autoregressed out from an attention net in paper)

assert torch.allclose(codebook.get_codes_from_indices(codebook_indices), out)
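
Because the hard attention uses a straight-through Gumbel-softmax, gradients flow through the discretization, so Memcodes can sit in the middle of a model and be trained end-to-end. A hypothetical sketch (the linear encoder / decoder here are made-up stand-ins, not part of this repository):

import torch
import torch.nn.functional as F
from torch import nn
from nwt_pytorch import Memcodes

encoder = nn.Linear(256, 512)    # stand-in encoder
decoder = nn.Linear(512, 256)    # stand-in decoder
codebook = Memcodes(dim = 512, heads = 8, num_codes = 1024)

x = torch.randn(1, 128, 256)
quantized, indices = codebook(encoder(x))
recon = decoder(quantized)

loss = F.mse_loss(recon, x)
loss.backward()   # gradients reach the encoder through the straight-through estimator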

Citations

@misc{mama2021nwt,
    title   = {NWT: Towards natural audio-to-video generation with representation learning}, 
    author  = {Rayhane Mama and Marc S. Tyndel and Hashiam Kadhim and Cole Clifford and Ragavan Thurairatnam},
    year    = {2021},
    eprint  = {2106.04283},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}

.\lucidrains\NWT-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'nwt-pytorch',
  # 查找并包含所有包
  packages = find_packages(),
  # 版本号
  version = '0.0.4',
  # 许可证
  license='MIT',
  # 描述
  description = 'NWT - Pytorch',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/NWT-pytorch',
  # 关键词
  keywords = [
    'artificial intelligence',
    'deep learning',
    'pytorch',
    'audio to video synthesis'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.4',
    'torch'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\nystrom-attention\nystrom_attention\nystrom_attention.py

# 从 math 模块中导入 ceil 函数
from math import ceil
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 导入 torch.nn.functional 模块并重命名为 F

import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 reduce 函数

from einops import rearrange, reduce

# 定义一个辅助函数 exists,用于判断变量是否存在
def exists(val):
    return val is not None

# 定义 Moore-Penrose 伪逆的迭代计算函数
def moore_penrose_iter_pinv(x, iters = 6):
    # 获取输入张量 x 的设备信息
    device = x.device

    # 计算 x 的绝对值
    abs_x = torch.abs(x)
    # 沿着最后一个维度求和,得到列和
    col = abs_x.sum(dim = -1)
    # 沿着倒数第二个维度求和,得到行和
    row = abs_x.sum(dim = -2)
    # 对 x 进行重排,转置操作
    z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row))

    # 创建单位矩阵
    I = torch.eye(x.shape[-1], device = device)
    I = rearrange(I, 'i j -> () i j')

    # 迭代计算 Moore-Penrose 伪逆
    for _ in range(iters):
        xz = x @ z
        z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz)))))

    return z
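
下面是一个简短的示意性对比(非源码内容),说明该迭代法在良态输入上会逼近精确伪逆,迭代次数越多越精确;Nyströmformer 论文建议 6 次迭代即可满足注意力近似的需要:

# 构造一个类似注意力矩阵的良态输入,与 torch.linalg.pinv 的结果对比
attn_like = torch.randn(1, 16, 16).softmax(dim = -1)
approx = moore_penrose_iter_pinv(attn_like, iters = 6)
exact = torch.linalg.pinv(attn_like)
max_abs_err = (approx - exact).abs().max()   # 增大 iters 可进一步减小误差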

# 主要的注意力类 NystromAttention
class NystromAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        num_landmarks = 256,
        pinv_iterations = 6,
        residual = True,
        residual_conv_kernel = 33,
        eps = 1e-8,
        dropout = 0.
    ):
        super().__init__()
        self.eps = eps
        inner_dim = heads * dim_head

        self.num_landmarks = num_landmarks
        self.pinv_iterations = pinv_iterations

        self.heads = heads
        self.scale = dim_head ** -0.5
        # 定义一个线性层,用于将输入维度转换为内部维度的三倍
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # 定义输出层,包含一个线性层和一个 dropout 层
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

        self.residual = residual
        # 如果启用残差连接
        if residual:
            kernel_size = residual_conv_kernel
            padding = residual_conv_kernel // 2
            # 定义一个卷积层,用于残差连接
            self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False)
    # 定义前向传播函数,接受输入 x,mask 和 return_attn 参数
    def forward(self, x, mask = None, return_attn = False):
        # 解包 x 的形状信息,包括 batch size (b), 序列长度 (n), 头数 (h), 地标数 (m), 伪逆迭代次数 (iters), 以及 epsilon (eps)
        b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps

        # 将序列填充,使其可以被均匀地分成 m 个地标
        remainder = n % m
        if remainder > 0:
            padding = m - (n % m)
            x = F.pad(x, (0, 0, padding, 0), value = 0)

            if exists(mask):
                mask = F.pad(mask, (padding, 0), value = False)

        # 派生查询、键、值
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # 将查询、键、值中的掩码位置设为 0
        if exists(mask):
            mask = rearrange(mask, 'b n -> b () n')
            q, k, v = map(lambda t: t * mask[..., None], (q, k, v))

        q = q * self.scale

        # 通过求和缩减生成地标,然后使用掩码计算均值
        l = ceil(n / m)
        landmark_einops_eq = '... (n l) d -> ... n d'
        q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
        k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)

        # 计算地标掩码,并准备计算掩码均值时的非掩码元素总和
        divisor = l
        if exists(mask):
            mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l)
            divisor = mask_landmarks_sum[..., None] + eps
            mask_landmarks = mask_landmarks_sum > 0

        # 如果存在掩码,则进行掩码均值计算
        q_landmarks = q_landmarks / divisor
        k_landmarks = k_landmarks / divisor

        # 相似度计算
        einops_eq = '... i d, ... j d -> ... i j'
        sim1 = einsum(einops_eq, q, k_landmarks)
        sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
        sim3 = einsum(einops_eq, q_landmarks, k)

        # 掩码处理
        if exists(mask):
            mask_value = -torch.finfo(q.dtype).max
            sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

        # 计算公式 (15) 中的等式,并聚合值
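        # 即 softmax(Q Kᵀ) ≈ softmax(Q K̃ᵀ) · pinv(softmax(Q̃ K̃ᵀ)) · softmax(Q̃ Kᵀ),其中 Q̃、K̃ 为地标(landmark)查询 / 键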
        attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3))
        attn2_inv = moore_penrose_iter_pinv(attn2, iters)

        out = (attn1 @ attn2_inv) @ (attn3 @ v)

        # 添加值的深度卷积残差
        if self.residual:
            out = out + self.res_conv(v)

        # 合并和组合头
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        out = self.to_out(out)
        out = out[:, -n:]

        # 如果需要返回注意力权重,则返回输出和注意力权重
        if return_attn:
            attn = attn1 @ attn2_inv @ attn3
            return out, attn

        return out
# transformer

# 定义一个预标准化层,包含一个 LayerNorm 层和一个传入的函数
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # 初始化 LayerNorm 层
        self.fn = fn  # 保存传入的函数

    def forward(self, x, **kwargs):
        x = self.norm(x)  # 对输入数据进行标准化
        return self.fn(x, **kwargs)  # 调用传入的函数处理标准化后的数据

# 定义一个前馈神经网络层,包含线性层、GELU 激活函数、Dropout 和另一个线性层
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),  # 第一个线性层
            nn.GELU(),  # GELU 激活函数
            nn.Dropout(dropout),  # Dropout 层
            nn.Linear(dim * mult, dim)  # 第二个线性层
        )

    def forward(self, x):
        return self.net(x)  # 前馈神经网络的前向传播

# 定义一个 Nystromformer 模型,包含多个层
class Nystromformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_landmarks = 256,
        pinv_iterations = 6,
        attn_values_residual = True,
        attn_values_residual_conv_kernel = 33,
        attn_dropout = 0.,
        ff_dropout = 0.   
    ):
        super().__init__()

        self.layers = nn.ModuleList([])  # 初始化一个空的 ModuleList
        for _ in range(depth):
            # 每一层包含一个 NystromAttention 层和一个 FeedForward 层,都经过预标准化
            self.layers.append(nn.ModuleList([
                PreNorm(dim, NystromAttention(dim = dim, dim_head = dim_head, heads = heads, num_landmarks = num_landmarks, pinv_iterations = pinv_iterations, residual = attn_values_residual, residual_conv_kernel = attn_values_residual_conv_kernel, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout))
            ]))

    def forward(self, x, mask = None):
        # 遍历每一层,依次进行注意力计算和前馈神经网络处理
        for attn, ff in self.layers:
            x = attn(x, mask = mask) + x  # 注意力计算后加上残差连接
            x = ff(x) + x  # 前馈神经网络处理后加上残差连接
        return x  # 返回处理后的数据

.\lucidrains\nystrom-attention\nystrom_attention\__init__.py

# 从 nystrom_attention 模块中导入 NystromAttention 和 Nystromformer 类
from nystrom_attention.nystrom_attention import NystromAttention, Nystromformer
# 将 Nystromformer 类赋值给 Nystromer 变量
Nystromer = Nystromformer

Nyström Attention

Implementation of Nyström Self-attention, from the paper Nyströmformer.

Yannic Kilcher video

Install

$ pip install nystrom-attention

Usage

import torch
from nystrom_attention import NystromAttention

attn = NystromAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    num_landmarks = 256,    # number of landmarks
    pinv_iterations = 6,    # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper
    residual = True         # whether to do an extra residual with the value or not. supposedly faster convergence if turned on
)

x = torch.randn(1, 16384, 512)
mask = torch.ones(1, 16384).bool()

attn(x, mask = mask) # (1, 16384, 512)

Nyströmformer, layers of Nyström attention

import torch
from nystrom_attention import Nystromformer

model = Nystromformer(
    dim = 512,
    dim_head = 64,
    heads = 8,
    depth = 6,
    num_landmarks = 256,
    pinv_iterations = 6
)

x = torch.randn(1, 16384, 512)
mask = torch.ones(1, 16384).bool()

model(x, mask = mask) # (1, 16384, 512)

You can also import it as Nyströmer if you wish

from nystrom_attention import Nystromer

Citations

@misc{xiong2021nystromformer,
    title   = {Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention},
    author  = {Yunyang Xiong and Zhanpeng Zeng and Rudrasis Chakraborty and Mingxing Tan and Glenn Fung and Yin Li and Vikas Singh},
    year    = {2021},
    eprint  = {2102.03902},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\nystrom-attention\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'nystrom-attention',
  # 查找所有包
  packages = find_packages(),
  # 版本号
  version = '0.0.12',
  # 许可证
  license='MIT',
  # 描述
  description = 'Nystrom Attention - Pytorch',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/nystrom-attention',
  # 关键词
  keywords = [
    'artificial intelligence',
    'attention mechanism'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.7.0',
    'torch>=2.0'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\omninet-pytorch\omninet_pytorch\omninet_pytorch.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 库中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 使用 PerformerAttention 作为自注意力机制,因为它有最好的报告数字
# FastAttention 用于在 check_redraw_projections 中查找并重绘投影矩阵
from performer_pytorch import FastAttention
from performer_pytorch import SelfAttention as PerformerAttention

# 辅助函数

# 判断变量是否存在的函数
def exists(val):
    return val is not None

# 获取模块所在设备的函数
def get_module_device(module):
    return next(module.parameters()).device

# 查找指定类型模块的函数
def find_modules(nn_module, type):
    return [module for module in nn_module.modules() if isinstance(module, type)]

# 类定义

# 预层归一化类
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        # 初始化 LayerNorm 归一化层
        self.norm = nn.LayerNorm(dim)
        # 初始化传入的函数
        self.fn = fn

    def forward(self, x, **kwargs):
        # 对输入进行归一化后,再传入函数进行处理
        return self.fn(self.norm(x), **kwargs)

# 前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        # 定义前馈神经网络结构
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        # 前馈神经网络前向传播
        return self.net(x)

# 自注意力机制类
class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads =  heads
        self.scale = dim_head ** -0.5
        self.causal = causal

        # 定义 Q、K、V 的线性变换层
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # 定义输出层
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        # 获取输入 x 的形状信息
        b, n, d, h, device = *x.shape, self.heads, x.device
        # 将输入 x 进行 Q、K、V 的线性变换,并分割为 Q、K、V
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        # 计算注意力分数
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        # 定义最大负值
        max_neg_value = -torch.finfo(sim.dtype).max

        # 如果存在 mask,则进行 mask 操作
        if exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
            sim.masked_fill_(~mask, max_neg_value)

        # 如果是因果注意力机制,则进行 mask 操作
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            causal_mask = rearrange(causal_mask, 'i j -> () i j')
            sim.masked_fill_(causal_mask, max_neg_value)

        # 计算注意力权重
        attn = sim.softmax(dim = -1)

        # 计算输出
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)

# 主类

class Omninet(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        pool_layer_tokens_every = 2,
        attn_dropout = 0.,
        ff_dropout = 0.,
        feature_redraw_interval = 1000
    ):
        super().__init__()

        layers = nn.ModuleList([])
        for ind in range(depth):
            num_layers = ind + 1
            should_pool = num_layers % pool_layer_tokens_every

            layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout)),
                PerformerAttention(dim = dim, heads= heads, dim_head = dim_head) if should_pool else None
            ]))

        self.layers = layers
        self.pool_num_layers = pool_layer_tokens_every

        # 跟踪重新绘制 Performer 投影矩阵的次数
        self.feature_redraw_interval = feature_redraw_interval
        self.register_buffer('calls_since_last_redraw', torch.tensor(0))

    # 修复投影矩阵的函数
    def fix_projection_matrices_(self):
        self.feature_redraw_interval = None
    # 检查是否需要重新绘制投影矩阵
    def check_redraw_projections(self):
        # 如果不处于训练状态,则直接返回
        if not self.training:
            return

        # 如果存在特征重新绘制间隔,并且自上次重新绘制以来的调用次数超过间隔
        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
            # 获取模块所在设备
            device = get_module_device(self)

            # 查找所有 FastAttention 模块
            fast_attentions = find_modules(self, FastAttention)
            # 对每个 FastAttention 模块重新绘制投影矩阵
            for fast_attention in fast_attentions:
                fast_attention.redraw_projection_matrix(device)

            # 重置自上次重新绘制以来的调用次数
            self.calls_since_last_redraw.zero_()
            return

        # 自上次重新绘制以来的调用次数加一
        self.calls_since_last_redraw += 1

    # 前向传播函数
    def forward(self, x, mask = None):
        # 检查是否需要重新绘制投影矩阵
        self.check_redraw_projections()
        # 获取池化层数
        pool_num_layers = self.pool_num_layers

        # 初始化隐藏层列表
        hiddens = [x]

        # 遍历每个注意力层、前馈层和高效注意力层
        for attn, ff, efficient_attn in self.layers:
            # 注意力层的输出加上输入,得到新的输出
            x = attn(x, mask = mask) + x
            # 前馈层的输出加上输入,得到新的输出
            x = ff(x) + x

            # 将新的输出添加到隐藏层列表中
            hiddens.append(x)
            # 如果存在高效注意力层
            if exists(efficient_attn):
                # 选择最近的池化层数量的隐藏层
                layers_to_pool = hiddens[-pool_num_layers:]
                num_layers = len(layers_to_pool)

                # 将所有隐藏层的 token 合并成一个张量
                all_tokens = torch.stack(layers_to_pool)
                all_tokens = rearrange(all_tokens, 'l b n d -> b (n l) d')

                # 初始化池化注意力层的掩码
                pool_attn_mask = None
                if exists(mask):
                    pool_attn_mask = repeat(mask, 'b n -> b (n l)', l = num_layers)

                # 对合并的 token 应用高效注意力层
                attended_tokens = efficient_attn(all_tokens, mask = pool_attn_mask)

                # 重新排列输出张量的维度
                attended_tokens = rearrange(attended_tokens, 'b n c -> b c n')
                # 对注意力输出进行最大池化
                pooled_tokens = F.max_pool1d(attended_tokens, kernel_size = num_layers, stride = num_layers)
                # 将池化后的 token 添加到输出中
                x += rearrange(pooled_tokens, 'b c n -> b n c')

        # 返回最终输出
        return x
# 定义一个名为 OmninetCausal 的类,用于处理因果关系的情况,采用轴向注意力层,直到重写线性注意力的 CUDA 内核
class OmninetCausal(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        pool_layer_tokens_every = 2,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()

        # 初始化层位置嵌入参数
        self.layer_pos_emb = nn.Parameter(torch.randn(depth + 1, dim))

        # 初始化层列表
        layers = nn.ModuleList([])
        for ind in range(depth):
            num_layers = ind + 1
            should_pool = num_layers % pool_layer_tokens_every

            # 添加每一层的注意力、前馈和轴向注意力(如果需要池化)到层列表中
            layers.append(nn.ModuleList([
                PreNorm(dim, Attention(causal = True, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout)),
                Attention(dim = dim, heads= heads, dim_head = dim_head) if should_pool else None
            ]))

        self.layers = layers
        self.pool_num_layers = pool_layer_tokens_every

    # 前向传播函数
    def forward(self, x, mask = None):
        pool_num_layers = self.pool_num_layers

        b = x.shape[0]
        pos_embs = rearrange(self.layer_pos_emb, 'n d -> () n d')

        x += pos_embs[:, 0]
        hiddens = [x]

        for ind, (attn, ff, layer_axial_attn) in enumerate(self.layers):

            # 执行注意力层操作
            x = attn(x, mask = mask) + x
            # 执行前馈层操作
            x = ff(x) + x

            x += pos_embs[:, ind + 1]
            hiddens.append(x)

            if exists(layer_axial_attn):
                layers_to_pool = hiddens[-pool_num_layers:]
                num_layers = len(layers_to_pool)

                # 重排层的 tokens,并进行轴向注意力操作
                layer_tokens = rearrange(torch.stack(layers_to_pool), 'l b n d -> (b n) l d')

                attended_tokens = layer_axial_attn(layer_tokens)
                attended_tokens = rearrange(attended_tokens, '(b n) l d -> b n l d', b = b)
                pooled_attended_tokens = attended_tokens.max(dim = -2).values
                x += pooled_attended_tokens

        return x

.\lucidrains\omninet-pytorch\omninet_pytorch\__init__.py

# 从 omninet_pytorch 模块中导入 Omninet 和 OmninetCausal 类
from omninet_pytorch.omninet_pytorch import Omninet, OmninetCausal

Omninet - Pytorch

Implementation of OmniNet, Omnidirectional Representations from Transformers, in Pytorch. The authors propose that we should be attending to all the tokens of the previous layers, leveraging recent efficient attention advances to achieve this goal.

Install

$ pip install omninet-pytorch

Usage

import torch
from omninet_pytorch import Omninet

omninet = Omninet(
    dim = 512,                     # model dimension
    depth = 6,                     # depth
    dim_head = 64,                 # dimension per head
    heads = 8,                     # number of heads
    pool_layer_tokens_every = 3,   # key to this paper - every N layers, omni attend to all tokens of all layers
    attn_dropout = 0.1,            # attention dropout
    ff_dropout = 0.1,              # feedforward dropout
    feature_redraw_interval = 1000 # how often to redraw the projection matrix for omni attention net - Performer
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

omninet(x, mask = mask) # (1, 1024, 512)

Causal case, just use the class OmninetCausal. At the moment, it isn't faithful to the paper (I am using layer axial attention with layer positional embeddings to draw up information), but will fix this once I rework the linear attention CUDA kernel.

import torch
from omninet_pytorch import OmninetCausal

omninet = OmninetCausal(
    dim = 512,                     # model dimension
    depth = 6,                     # depth
    dim_head = 64,                 # dimension per head
    heads = 8,                     # number of heads
    pool_layer_tokens_every = 3,   # key to this paper - every N layers, omni attend to all tokens of all layers
    attn_dropout = 0.1,            # attention dropout
    ff_dropout = 0.1               # feedforward dropout
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

omninet(x, mask = mask) # (1, 1024, 512)

Citations

@misc{tay2021omninet,
    title   = {OmniNet: Omnidirectional Representations from Transformers}, 
    author  = {Yi Tay and Mostafa Dehghani and Vamsi Aribandi and Jai Gupta and Philip Pham and Zhen Qin and Dara Bahri and Da-Cheng Juan and Donald Metzler},
    year    = {2021},
    eprint  = {2103.01075},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\omninet-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'omninet-pytorch', # 包的名称
  packages = find_packages(), # 查找并包含所有包
  version = '0.0.6', # 版本号
  license='MIT', # 许可证
  description = 'Omninet - Pytorch', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  url = 'https://github.com/lucidrains/omninet-pytorch', # 项目链接
  keywords = [ # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformer',
    'attention mechanism'
  ],
  install_requires=[ # 安装依赖
    'einops>=0.3',
    'torch>=1.6',
    'performer-pytorch'
  ],
  classifiers=[ # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\PaLM-jax\palm_jax\palm.py

# 导入所需的模块和库
from typing import List, Tuple

import numpy as onp
from jax import random, nn, lax, jit, numpy as np
from jax.numpy import einsum

from equinox import Module, static_field
from einops import rearrange, repeat

# bias-less layernorm

class LayerNorm(Module):
    gamma: np.ndarray
    eps: float = static_field()

    def __init__(self, dim, eps = 1e-5):
        # 初始化 LayerNorm 类,设置 gamma 和 eps 属性
        self.gamma = np.ones((dim,))
        self.eps = eps

    def __call__(self, x):
        # 计算均值和均方差
        mean = np.mean(x, axis = -1, keepdims = True)
        mean_of_squares = np.mean(np.square(x), axis = -1, keepdims = True)
        variance = mean_of_squares - np.square(mean)
        inv = lax.rsqrt(variance + self.eps)
        # 返回 LayerNorm 结果
        return inv * (x - mean) * self.gamma

# Rotary embedding

def fixed_pos_embedding(inv_freq, seq):
    # 生成固定位置嵌入的正弦和余弦值
    sinusoid_inp = einsum('i , j -> i j', np.arange(seq), inv_freq)
    sinusoid_inp = repeat(sinusoid_inp, '... d -> ... (d r)', r = 2)
    return np.sin(sinusoid_inp), np.cos(sinusoid_inp)

def rotate_every_two(x):
    # 将输入张量中的每两个元素进行旋转
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x[..., 0], x[..., 1]
    x = np.stack((-x2, x1), axis = -1)
    return rearrange(x, '... d r -> ... (d r)')

def apply_rotary_pos_emb(x, sincos):
    sin, cos = sincos
    # 应用旋转位置嵌入
    return (x * cos) + (rotate_every_two(x) * sin)
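
下面是一个简短的使用示意(非原文件内容,数值仅作演示):先用 fixed_pos_embedding 生成 sin / cos 表,再用 apply_rotary_pos_emb 把旋转位置信息施加到查询 / 键上,张量形状保持不变。

from jax import random, numpy as np

dim_head, seq = 64, 128
inv_freq = 1.0 / (10000 ** (np.arange(0, dim_head, 2) / dim_head))
sincos = fixed_pos_embedding(inv_freq, seq)                # 两个形状为 (seq, dim_head) 的数组
q = random.normal(random.PRNGKey(0), (seq, dim_head))
q = apply_rotary_pos_emb(q, sincos)                        # 形状不变,仍为 (seq, dim_head)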

# attention - multi-query, one-headed key / values variant
# feedforward - Shazeer's SwiGLU variant

class ParallelTransformerBlock(Module):
    norm: Module
    wi: np.ndarray
    attn_wo: np.ndarray
    ff_wo: np.ndarray

    heads: int = static_field()
    fused_dims: Tuple[int] = static_field()
    scale: float = static_field()
    mask_value: float = static_field()

    def __init__(
        self,
        dim,
        dim_head,
        heads,
        key,
        ff_mult = 4,
        mask_value = -1e10
    ):
        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.norm = LayerNorm(dim)
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, ff_inner_dim, ff_inner_dim)

        self.wi = random.normal(key, (dim, sum(self.fused_dims)))
        self.attn_wo = random.normal(key, (attn_inner_dim, dim))
        self.ff_wo = random.normal(key, (ff_inner_dim, dim))

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.mask_value = mask_value

    def __call__(self, x, *, pos_emb, causal_mask):
        n, split_indices = x.shape[-2], onp.cumsum(self.fused_dims[:-1])

        x = self.norm(x)

        # fused attention and feedforward projections

        q, k, v, ff, ff_gate = np.split(x @ self.wi, split_indices, axis = -1)

        # split out heads

        q = rearrange(q, '... n (h d) -> ... h n d', h = self.heads)

        # scale

        q *= self.scale

        # apply rotary embeddings

        q, k = map(lambda t: apply_rotary_pos_emb(t, pos_emb), (q, k))

        # sim

        sim = einsum('... h i d, ... j d -> ... h i j', q, k)

        # causal mask

        sim = np.where(causal_mask, sim, self.mask_value)

        # attention

        attn = nn.softmax(sim, axis = -1)

        # aggregate values

        out = einsum('... h i j, ... j d -> ... h i d', attn, v)

        # merge heads

        out = rearrange(out, '... h n d -> ... n (h d)')

        # feedforward out

        attn_out = out @ self.attn_wo

        ff_out = (ff * nn.swish(ff_gate)) @ self.ff_wo

        # combine heads out

        return attn_out + ff_out
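
To make the fused projection concrete, here is a small sketch (toy dimensions, not part of the original file) of how `fused_dims` plus `onp.cumsum` slice a single matmul output into queries, the one-headed keys and values, and the two feedforward branches:

import numpy as onp

dim, dim_head, heads, ff_mult = 16, 4, 2, 4
fused_dims = (dim_head * heads, dim_head, dim_head, dim * ff_mult, dim * ff_mult)
split_indices = onp.cumsum(fused_dims[:-1])                     # array([ 8, 12, 16, 80])

x = onp.ones((3, dim)) @ onp.ones((dim, sum(fused_dims)))       # one fused projection
q, k, v, ff, ff_gate = onp.split(x, split_indices, axis = -1)
print(q.shape, k.shape, v.shape, ff.shape, ff_gate.shape)       # (3, 8) (3, 4) (3, 4) (3, 64) (3, 64)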

# main class

class PaLM(Module):
    embedding: np.ndarray
    norm: Module
    layers: List[List[Module]]
    inv_freq: onp.ndarray = static_field()

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        dim_head,
        depth,
        heads,
        key,
        ff_mult = 4
    ):
        # initialize the Transformer model's parameters
        # 使用正态分布随机初始化嵌入矩阵,乘以0.02缩放
        self.embedding = random.normal(key, (num_tokens, dim)) * 0.02
        # 计算位置编码的倒数频率
        self.inv_freq = 1.0 / (10000 ** (np.arange(0, dim_head, 2) / dim_head))

        # 创建 Transformer 模型的多个层
        self.layers = [ParallelTransformerBlock(dim = dim, dim_head = dim_head, heads = heads, ff_mult = ff_mult, key = key) for _ in range(depth)]
        # 初始化 LayerNorm 层
        self.norm = LayerNorm(dim)

    # 定义 JIT 编译的调用函数
    @jit
    def __call__(self, x):
        # 获取输入张量 x 的最后一个维度大小
        n = x.shape[-1]
        # 使用嵌入矩阵将输入 x 映射到嵌入空间
        x = self.embedding[x]

        # 生成固定的位置编码
        rotary_emb = fixed_pos_embedding(self.inv_freq, n)
        # 生成因果掩码,下三角矩阵
        causal_mask = np.tril(np.ones((n, n)))

        # 遍历 Transformer 模型的每个层进行前向传播
        for block in self.layers:
            # 调用每个层的前向传播函数,更新输入 x
            x = block(x, pos_emb = rotary_emb, causal_mask = causal_mask) + x

        # 对输出 x 进行 LayerNorm 处理
        x = self.norm(x)
        # 返回最终输出,执行嵌入矩阵的转置乘积
        return x @ self.embedding.transpose()

.\lucidrains\PaLM-jax\palm_jax\palm_lite.py

# import log2 and floor from the math module
from math import log2, floor
# import List and Tuple types from the typing module
from typing import List, Tuple

import numpy as onp
# import random, jit, nn, lax and numpy (aliased to np) from jax
from jax import random, jit, nn, lax, numpy as np
# import einsum from jax.numpy
from jax.numpy import einsum
# import Module and static_field from equinox
from equinox import Module, static_field
# import rearrange and repeat from einops
from einops import rearrange, repeat

# 定义 RMSNorm 类,继承自 Module 类
class RMSNorm(Module):
    # 定义类属性 gamma, scale, eps
    gamma: np.ndarray
    scale: float = static_field()
    eps: float = static_field()

    # 初始化方法,接受 dim 和 eps 两个参数
    def __init__(self, dim, eps = 1e-5):
        # 初始化 gamma 为全为 1 的数组
        self.gamma = np.ones((dim,))
        self.eps = eps
        self.scale = dim ** 0.5

    # 定义 __call__ 方法,接受参数 x
    def __call__(self, x):
        # 计算 x 的平方和,并在最后一个维度上保持维度
        sum_of_squares = np.sum(np.square(x), axis = -1, keepdims = True)
        # 计算 sum_of_squares 加上 eps 的平方根的倒数
        inv_norm = lax.rsqrt(sum_of_squares + self.eps)
        # 返回 inv_norm 乘以 x 乘以 gamma 乘以 scale 的结果
        return inv_norm * x * self.gamma * self.scale
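
As a quick numerical check (illustrative only, assuming it runs in the context of this module): the class above divides each feature vector by its root-mean-square, since multiplying `rsqrt(sum(x**2))` by `scale = sqrt(dim)` is the same as dividing by `sqrt(mean(x**2))`.

x = np.array([[1., 2., 3., 4.]])
rms = np.sqrt(np.mean(np.square(x), axis = -1, keepdims = True))
norm = RMSNorm(dim = 4)
assert np.allclose(norm(x), x / rms, atol = 1e-3)               # equal up to eps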

# 定义 get_alibi_slopes 函数,接受 heads 参数
def get_alibi_slopes(heads):
    # 定义内部函数 get_slopes_power_of_2,接受 n 参数
    def get_slopes_power_of_2(n):
        # 计算起始值 start
        start = (2 ** (-2 ** -(log2(n) - 3)))
        ratio = start
        # 返回等比数列
        return [start*ratio**i for i in range(n)]

    # 如果 heads 的对数是整数
    if log2(heads).is_integer():
        # 返回 get_slopes_power_of_2(heads) 的结果
        return get_slopes_power_of_2(heads)

    # 计算最接近 heads 的 2 的幂次方
    closest_power_of_2 = 2 ** floor(log2(heads))
    # 返回 get_slopes_power_of_2(closest_power_of_2) 和 get_slopes_power_of_2(2 * closest_power_of_2) 的结果
    return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

# 定义 calc_alibi_bias 函数,接受 seq_len 和 heads 两个参数
def calc_alibi_bias(seq_len, heads):
    # 获取斜率
    slopes = get_alibi_slopes(heads)
    # 重排 slopes 数组的维度
    slopes = rearrange(onp.array(slopes), 'h -> h 1 1')
    # 生成偏置
    bias = rearrange(onp.arange(seq_len), 'j -> 1 1 j')
    return slopes * bias
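
A small illustration (not part of the original file) of the two ALiBi helpers above: for 8 heads the slopes form a geometric sequence starting at 0.5, and the bias comes out with one linear penalty ramp over key positions per head.

print(get_alibi_slopes(8))                   # [0.5, 0.25, 0.125, ..., 0.00390625]
print(calc_alibi_bias(4, heads = 8).shape)   # (8, 1, 4): one ramp over key positions per head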

# 定义 ParallelTransformerBlock 类,继承自 Module 类
class ParallelTransformerBlock(Module):
    # 定义类属性 norm, wi, attn_wo, ff_wo, heads, fused_dims, scale, mask_value
    norm: Module
    wi: np.ndarray
    attn_wo: np.ndarray
    ff_wo: np.ndarray
    heads: int = static_field()
    fused_dims: Tuple[int] = static_field()
    scale: float = static_field()
    mask_value: float = static_field()

    # 初始化方法,接受 dim, dim_head, heads, key, ff_mult, mask_value 参数
    def __init__(
        self,
        dim,
        dim_head,
        heads,
        key,
        ff_mult = 4,
        mask_value = -1e10
    ):
        # 计算注意力内部维度和前馈内部维度
        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # 初始化 norm 为 RMSNorm 类的实例
        self.norm = RMSNorm(dim)
        self.fused_dims = (attn_inner_dim, dim_head, ff_inner_dim, ff_inner_dim)

        # 初始化 wi, attn_wo, ff_wo 为随机正态分布的数组
        self.wi = random.normal(key, (dim, sum(self.fused_dims)))
        self.attn_wo = random.normal(key, (attn_inner_dim, dim))
        self.ff_wo = random.normal(key, (ff_inner_dim, dim))

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.mask_value = mask_value

    # 定义 __call__ 方法,接受 x 和 attn_bias 两个参数
    def __call__(self, x, *, attn_bias):
        # 获取 x 的倒数第二个维度的大小和分割索引
        n, split_indices = x.shape[-2], onp.cumsum(self.fused_dims[:-1])

        # 对 x 进行归一化
        x = self.norm(x)

        # 融合注意力和前馈的投影

        q, kv, ff, ff_gate = np.split(x @ self.wi, split_indices, axis = -1)

        # 分割出头部

        q = rearrange(q, '... n (h d) -> ... h n d', h = self.heads)

        # 缩放

        q *= self.scale

        # 相似度

        sim = einsum('... h i d, ... j d -> ... h i j', q, kv)

        # 因果掩码

        sim = sim + attn_bias

        # 注意力

        attn = nn.softmax(sim, axis = -1)

        # 聚合值

        out = einsum('... h i j, ... j d -> ... h i d', attn, kv)

        # 合并头部

        out = rearrange(out, '... h n d -> ... n (h d)')

        # 前馈输出

        attn_out = out @ self.attn_wo

        ff_out = (ff * nn.swish(ff_gate)) @ self.ff_wo

        # 合并头部输出

        return attn_out + ff_out

# 主类

class PaLM(Module):
    # 定义类属性 embedding, norm, layers, attn_bias
    embedding: np.ndarray
    norm: Module
    layers: List[List[Module]]
    attn_bias: onp.ndarray = static_field()

    # 初始化方法,接受 num_tokens, dim, dim_head, depth, heads, key, ff_mult, max_seq_len, mask_value 参数
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        dim_head,
        depth,
        heads,
        key,
        ff_mult = 4,
        max_seq_len = 2048,
        mask_value = -1e10
    ):
        self.embedding = random.normal(key, (num_tokens, dim)) * 0.02
        # initialize the embedding matrix from a normal distribution, scaled by 0.02

        causal_mask = onp.tril(onp.ones((max_seq_len, max_seq_len)))
        # 创建一个下三角矩阵作为因果掩码
        alibi_bias = calc_alibi_bias(max_seq_len, heads = heads)
        # 计算alibi偏置
        self.attn_bias = np.where(causal_mask, repeat(alibi_bias, 'h 1 j -> h i j', i = max_seq_len), mask_value)
        # 根据因果掩码和alibi偏置生成注意力偏置矩阵

        self.layers = [ParallelTransformerBlock(dim = dim, dim_head = dim_head, heads = heads, key = key, ff_mult = ff_mult) for _ in range(depth)]
        # 创建多个并行Transformer块
        self.norm = RMSNorm(dim)
        # 初始化RMS归一化层

    @jit
    def __call__(self, x):
        # 定义类的调用方法,输入x
        n = x.shape[-1]
        # 获取输入x的最后一个维度大小
        x = self.embedding[x]
        # 使用嵌入矩阵将输入x转换为嵌入向量

        attn_bias = self.attn_bias[..., :n, :n]
        # 获取与输入长度相关的注意力偏置

        for block in self.layers:
            # 遍历每个Transformer块
            x = block(x, attn_bias = attn_bias) + x
            # 对输入x进行Transformer块的处理,并将结果与原始输入相加

        x = self.norm(x)
        # 对处理后的结果进行RMS归一化
        return x @ self.embedding.transpose()
        # 返回结果与嵌入矩阵的转置矩阵的乘积

.\lucidrains\PaLM-jax\palm_jax\utils.py

# 导入所需的库
from jax import random
from jax.lax import top_k
import jax.numpy as np

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 计算对数,加上一个很小的值以避免出现对数零的情况
def log(t, eps = 1e-20):
    return np.log(t + eps)

# 采样函数

# 选择概率最高的前 k 个元素
def select_top_k(tensor, k):
    values, _ = top_k(tensor, k)
    mask = tensor > values.min()
    return mask, np.where(mask, tensor, 0.)

# 生成 Gumbel 噪声
def gumbel_noise(key, shape):
    noise = random.uniform(key, shape = shape, minval = 0., maxval = 1.)
    return -log(-log(noise))

# 生成样本序列
def sample(key, model, prime, length, top_k = None):
    start_pos = prime.shape[-1]
    seq = np.pad(prime, (0, length - prime.shape[-1]))
    one_hots = np.eye(length, dtype = int)

    for curr_pos in range(start_pos, length):
        logits = model(seq)
        logits = logits[curr_pos - 1]

        _, key = random.split(key)
        noise = gumbel_noise(key, logits.shape)

        if exists(top_k):
            mask, logits = select_top_k(logits, top_k)
            noise *= mask

        logits += noise
        sampled_ind = np.argmax(logits, axis = -1)

        one_hot = one_hots[curr_pos]
        seq += one_hot * sampled_ind

    return seq
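
An end-to-end sketch of this sampler (everything below is hypothetical: the model size, prompt and key are made up for illustration). `sample` relies on the Gumbel-max trick: adding Gumbel noise to the logits and taking the argmax draws a token from `softmax(logits)`, with the optional top-k filter restricting which tokens can win the argmax.

import jax
import jax.numpy as jnp

from palm_jax import PaLM
from palm_jax.utils import sample

key = jax.random.PRNGKey(0)
model = PaLM(num_tokens = 256, dim = 128, depth = 2, heads = 4, dim_head = 32, key = key)

prime = jnp.ones((16,), dtype = int)                 # toy prompt of 16 token ids
generated = sample(key, model, prime, length = 64, top_k = 25)
print(generated.shape)                               # (64,) - the prompt followed by sampled tokens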

.\lucidrains\PaLM-jax\palm_jax\__init__.py

# 从 palm_jax.palm 模块中导入 PaLM 类
from palm_jax.palm import PaLM

PaLM - Jax

Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways - in Jax using Equinox

May as well start doing more Jax work, given Facebook (Meta's) uncertain future

Pytorch version

Flax version from Enrico!

Install

$ pip install PaLM-jax

Usage

The way the model is built doesn't require vmap at all. It can have any number of leading dimensions

import jax
from palm_jax import PaLM

key = jax.random.PRNGKey(0)

model = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    heads = 8,
    dim_head = 64,
    key = key
)

seq = jax.random.randint(key, (1, 1024), 0, 20000)

logits = model(seq) # (1, 1024, 20000)
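
As a concrete illustration of the leading-dimensions claim above (shapes here are made up), the same model accepts extra leading axes directly, with no vmap:

seq = jax.random.randint(key, (2, 4, 256), 0, 20000)

logits = model(seq) # (2, 4, 256, 20000)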

The 540B PaLM in the paper would be


model = PaLM(
    num_tokens = 256000,
    dim = 18432,
    depth = 118,
    heads = 48,
    dim_head = 256,
    key = key
)

That's all it is. Attention (and scale) is all we need.

Todos

Citations

@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
@misc{press2021ALiBi,
    title   = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
    year    = {2021},
    url     = {https://ofir.io/train_short_test_long.pdf}
}
@article{Rae2021ScalingLM,
    title   = {Scaling Language Models: Methods, Analysis \& Insights from Training Gopher},
    author  = {Jack W. Rae and Sebastian Borgeaud and Trevor Cai and Katie Millican and Jordan Hoffmann and Francis Song and John Aslanides and Sarah Henderson and Roman Ring and Susannah Young and Eliza Rutherford and Tom Hennigan and Jacob Menick and Albin Cassirer and Richard Powell and George van den Driessche and Lisa Anne Hendricks and Maribeth Rauh and Po-Sen Huang and Amelia Glaese and Johannes Welbl and Sumanth Dathathri and Saffron Huang and Jonathan Uesato and John F. J. Mellor and Irina Higgins and Antonia Creswell and Nathan McAleese and Amy Wu and Erich Elsen and Siddhant M. Jayakumar and Elena Buchatskaya and David Budden and Esme Sutherland and Karen Simonyan and Michela Paganini and L. Sifre and Lena Martens and Xiang Lorraine Li and Adhiguna Kuncoro and Aida Nematzadeh and Elena Gribovskaya and Domenic Donato and Angeliki Lazaridou and Arthur Mensch and Jean-Baptiste Lespiau and Maria Tsimpoukelli and N. K. Grigorev and Doug Fritz and Thibault Sottiaux and Mantas Pajarskas and Tobias Pohlen and Zhitao Gong and Daniel Toyama and Cyprien de Masson d'Autume and Yujia Li and Tayfun Terzi and Vladimir Mikulik and Igor Babuschkin and Aidan Clark and Diego de Las Casas and Aurelia Guy and Chris Jones and James Bradbury and Matthew G. Johnson and Blake A. Hechtman and Laura Weidinger and Iason Gabriel and William S. Isaac and Edward Lockhart and Simon Osindero and Laura Rimell and Chris Dyer and Oriol Vinyals and Kareem W. Ayoub and Jeff Stanway and L. L. Bennett and Demis Hassabis and Koray Kavukcuoglu and Geoffrey Irving},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2112.11446}
}
@inproceedings{Zhang2019RootMS,
    title   = {Root Mean Square Layer Normalization},
    author  = {Biao Zhang and Rico Sennrich},
    booktitle = {NeurIPS},
    year    = {2019}
}

.\lucidrains\PaLM-jax\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的信息
setup(
  # 包名
  name = 'PaLM-jax',
  # 查找所有包,不排除任何包
  packages = find_packages(exclude=[]),
  # 版本号
  version = '0.1.2',
  # 许可证类型
  license='MIT',
  # 描述信息
  description = 'PaLM: Scaling Language Modeling with Pathways - Jax',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 项目链接
  url = 'https://github.com/lucidrains/PaLM-jax',
  # 关键词列表
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism'
  ],
  # 安装依赖
  install_requires=[
    'einops==0.4',
    'equinox>=0.5',
    'jax>=0.3.4',
    'jaxlib>=0.1',
    'optax',
    'numpy'
  ],
  # 分类标签
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\PaLM-jax\train.py

# 导入必要的库
import os
from random import randrange
from functools import partial
import tqdm
import gzip
import numpy as np

import jax
import jax.numpy as jnp
from jax import nn

# 导入自定义库
import equinox as eqx
from optax import adam, clip_by_global_norm, chain, apply_every

# 导入自定义模块
from palm_jax.palm_lite import PaLM
from palm_jax.utils import sample

# 设置环境变量
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
MAX_GRAD_NORM = 0.5
VALIDATE_EVERY  = 100
SAMPLE_EVERY  = 500
SEQ_LEN = 1024

# 定义循环生成器函数
def cycle(loader):
    while True:
        for data in loader:
            yield data

# 解码单个 token 函数
def decode_token(token):
    return str(chr(max(32, token)))

# 解码一组 tokens 函数
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# 读取 enwik8 数据集
with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    data_train, data_val = np.split(X, [int(90e6)])

# 从数据集中采样序列函数
def sample_seq_from_data(data, *, seq_len, batch_size):
    total_seq_len = data.shape[0]
    base_arange = np.arange(seq_len)
    start_indices = np.random.randint(0, total_seq_len - seq_len, (batch_size,))
    token_indices = start_indices[:, None] + base_arange
    return data[token_indices]

# 部分应用采样序列函数
sample_seq_fn = partial(sample_seq_from_data, seq_len = SEQ_LEN, batch_size = BATCH_SIZE)
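
A toy check (not part of the original script) of the windowing logic above: broadcasting the `(batch, 1)` start offsets against the `(seq_len,)` arange yields a `(batch, seq_len)` index grid, so every row is a contiguous slice of the data.

toy = np.arange(100)
windows = sample_seq_from_data(toy, seq_len = 8, batch_size = 2)
print(windows.shape)                       # (2, 8)
print(windows[:, 1:] - windows[:, :-1])    # all ones, i.e. each row is consecutive positions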

# 初始化 PRNGKey
key = jax.random.PRNGKey(0)

# 初始化 PaLM 模型
model = PaLM(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    heads = 8,
    dim_head = 64,
    key = key
)

# 交叉熵损失函数
def cross_entropy(logits, targets, axis = -1):
    logprobs = nn.log_softmax(logits, axis = axis)
    nll = jnp.take_along_axis(logprobs, jnp.expand_dims(targets, axis = axis), axis = axis)
    cross_entropy = -jnp.mean(nll)
    return cross_entropy
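
A quick sanity check (toy values, not part of the original script) of the loss above: with uniform logits over 4 classes the cross entropy equals log(4) regardless of the target index.

toy_logits = jnp.zeros((1, 1, 4))
toy_targets = jnp.array([[2]])
print(cross_entropy(toy_logits, toy_targets))   # ~1.386 == log(4)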

# 定义损失函数
@eqx.filter_value_and_grad
def loss_fn(model, data):
    inp, labels = data[:, :-1], data[:, 1:]
    logits = model(inp)
    return cross_entropy(logits, labels, axis = -1)

# 初始化优化器
optim = chain(
    clip_by_global_norm(MAX_GRAD_NORM),
    adam(LEARNING_RATE),
    apply_every(GRADIENT_ACCUMULATE_EVERY)
)

optim_state = optim.init(model)

# 训练步骤
@eqx.filter_jit(kwargs=dict(data=True))
def train_step(model, data, optim_state):
    loss, grads = loss_fn(model, data)
    updates, optim_state = optim.update(grads, optim_state)
    model = eqx.apply_updates(model, updates)
    return model, optim_state, loss

# 训练过程
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        data = sample_seq_fn(data_train)
        model, optim_state, loss = train_step(model, data, optim_state)

    print(f'loss: {loss.item()}')

    if i % SAMPLE_EVERY == 0:
        valid_data = sample_seq_fn(data_val)
        prime = valid_data[0][:100]
        prime_str = decode_tokens(prime)
        print(prime_str, "\n", "*" * 40)

        sampled = sample(key, model, prime, SEQ_LEN, top_k = 25)
        sampled_str = decode_tokens(sampled[100:])
        print(sampled_str)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

PaLM-pytorch with Deepspeed for Enwik8

Deepspeed is the framework Microsoft used to train the world's largest Attention model to date (17B parameters). They have open sourced it, and it works with PaLM Pytorch!

  1. First install Deepspeed following instructions from their official repository https://github.com/microsoft/DeepSpeed

  2. Run the following command in this folder

$ deepspeed train.py --deepspeed --deepspeed_config ds_config.json

.\lucidrains\PaLM-pytorch\examples\enwik8_deepspeed\train.py

import deepspeed
# 导入 deepspeed 库

from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
# 从 palm_pytorch 库中导入 PaLM 类和 AutoregressiveWrapper 类

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from einops import rearrange
from torch import einsum, nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import argparse
# import the required libraries (argparse is needed by add_argument below)

def add_argument():
    parser=argparse.ArgumentParser(description='enwik8')
    # 创建参数解析器对象

    parser.add_argument('--with_cuda', default=False, action='store_true',
                        help='use CPU in case there\'s no GPU support')
    parser.add_argument('--use_ema', default=False, action='store_true',
                        help='whether use exponential moving average')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        help='mini-batch size (default: 32)')
    parser.add_argument('-e', '--epochs', default=30, type=int,
                        help='number of total epochs (default: 30)')
    parser.add_argument('--local_rank', type=int, default=-1,
                       help='local rank passed from distributed launcher')
    # 添加命令行参数

    parser = deepspeed.add_config_arguments(parser)
    # 添加 deepspeed 配置参数
    args=parser.parse_args()
    return args
# 定义函数用于添加参数

# constants

EPOCHS = 20
GRADIENT_ACCUMULATE_EVERY = 4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024
# 定义常量

# helpers

def decode_token(token):
    return str(chr(max(32, token)))
# 定义函数用于解码单个 token

def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))
# 定义函数用于解码多个 tokens

# instantiate GPT-like decoder model

model = PaLM(num_tokens = 256, dim = 512, depth = 8)
# 实例化 PaLM 模型对象,设置参数

model = AutoregressiveWrapper(model, max_seq_len=2048)
# 使用 AutoregressiveWrapper 对象包装模型,设置最大序列长度

model.cuda()
# 将模型移动到 GPU 上

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# 读取并准备数据集

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq

    def __len__(self):
        return self.data.size(0) // self.seq_len
# 定义数据集类

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
# 创建训练集和验证集对象

# setup deepspeed

cmd_args = add_argument()
# 调用添加参数函数

model_engine, optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=model.parameters(), training_data=train_dataset)
# 使用 deepspeed 初始化模型引擎、优化器、训练数据加载器

# training

for _ in range(EPOCHS):
    for i, data in enumerate(trainloader):
        model_engine.train()
        # 设置模型为训练模式
        data = data.to(model_engine.local_rank)
        # 将数据移动到指定设备
        loss = model_engine(data)
        # 计算损失
        model_engine.backward(loss)
        # 反向传播
        torch.nn.utils.clip_grad_norm_(model_engine.parameters(), 0.5)
        # 对梯度进行裁剪
        model_engine.step()
        # 更新模型参数
        print(loss.item() * GRADIENT_ACCUMULATE_EVERY)
        # 打印损失值

        if i % VALIDATE_EVERY == 0:
            model.eval()
            # 设置模型为评估模式
            with torch.no_grad():
                inp = random.choice(val_dataset)[:-1]
                loss = model(inp[None, :].cuda())
                # 计算验证集损失
                print(f'validation loss: {loss.item()}')

        if i % GENERATE_EVERY == 0:
            model.eval()
            # 设置模型为评估模式
            inp = random.choice(val_dataset)[:-1]
            prime = decode_tokens(inp)
            print(prime, '\n\n', '*' * 100)
            # print the sampled prime text followed by a separator

            sample = model.generate(inp[None, ...].cuda(), GENERATE_LENGTH)
            output_str = decode_tokens(sample[0])
            print(output_str)
            # 生成文本并打印

.\lucidrains\PaLM-pytorch\palm_pytorch\autoregressive_wrapper.py

# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 从 torch 模块中导入 nn 模块
from torch import nn

# 定义一个辅助函数,用于检查值是否存在
def exists(val):
    return val is not None

# 定义一个装饰器函数,用于在模型评估时切换模型状态
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# 定义一个函数用于进行 top k 过滤
def top_k(logits, thres=0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(1, ind, val)
    return probs
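
A toy illustration (not part of the original file) of the filter above: with `thres = 0.9` and a vocabulary of 10, `k = int(0.1 * 10) = 1`, so only the single largest logit survives and everything else becomes -inf (zero probability after softmax).

logits = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0, 0.0, 0.2, 0.8, 1.5, 2.5]])
print(top_k(logits, thres = 0.9))   # only the 3.0 entry is kept, the rest are -inf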

# 定义一个自回归封装器类
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, max_seq_len=2048, pad_value=0):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.pad_value = pad_value
        self.net = net

    # 生成函数,用于生成序列
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            logits = self.net(out, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres=filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_token = out == eos_token

                if is_eos_token.any(dim=-1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        out = out[:, t:]
        return out

    # 前向传播函数,用于计算损失
    def forward(self, x, **kwargs):
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        logits = self.net(x_inp, **kwargs)
        return F.cross_entropy(rearrange(logits, "b c n -> b n c"), x_labels)

.\lucidrains\PaLM-pytorch\palm_pytorch\palm_lite.py

# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 torch 库中导入 einsum 和 nn 模块
from torch import einsum, nn
# 从 math 库中导入 log2 和 floor 函数
from math import log2, floor

# 定义函数,判断变量是否存在
def exists(val):
    return val is not None

# normalization

# 定义 RMSNorm 类,继承自 nn.Module
class RMSNorm(nn.Module):
    # 初始化函数
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        # 初始化缩放因子
        self.scale = dim ** -0.5
        # 初始化 eps
        self.eps = eps
        # 创建可学习参数 g
        self.g = nn.Parameter(torch.ones(dim))

    # 前向传播函数
    def forward(self, x):
        # 计算输入张量 x 的 L2 范数
        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
        # 返回归一化后的结果
        return x / norm.clamp(min = self.eps) * self.g

# AliBi

# 定义 AlibiPositionalBias 类,继承自 nn.Module
class AlibiPositionalBias(nn.Module):
    # 初始化函数
    def __init__(self, heads, **kwargs):
        super().__init__()
        # 初始化头数
        self.heads = heads
        # 计算斜率
        slopes = torch.Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> h 1 1')
        # 注册缓冲区 slopes 和 bias
        self.register_buffer('slopes', slopes, persistent = False)
        self.register_buffer('bias', None, persistent = False)
    
    # 获取偏置
    def get_bias(self, i, j, device):
        i_arange = torch.arange(i, device = device)
        j_arange = torch.arange(j, device = device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    # 静态方法,获取斜率
    @staticmethod
    def _get_slopes(heads):
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** floor(log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

    # 前向传播函数
    def forward(self, qk_sim):
        h, i, j, device = *qk_sim.shape[-3:], qk_sim.device

        if exists(self.bias) and self.bias.shape[-1] >= j:
            return self.bias[..., :i, :j]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        num_heads_unalibied = h - bias.shape[0]
        bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
        self.register_buffer('bias', bias, persistent=False)

        return bias

# residual

# 定义 Residual 类,继承自 nn.Module
class Residual(nn.Module):
    # 初始化函数
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    # 前向传播函数
    def forward(self, x):
        return self.fn(x) + x

# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202

# 定义 SwiGLU 类,继承自 nn.Module
class SwiGLU(nn.Module):
    # 前向传播函数
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x
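
A quick check (illustrative, assuming it runs in this module) of the gating above: the input is split in half along the last dimension and the second half gates the first through SiLU, so a `(..., 2d)` input becomes a `(..., d)` output.

x = torch.randn(2, 8)
out = SwiGLU()(x)
print(out.shape)                                         # torch.Size([2, 4])
assert torch.allclose(out, F.silu(x[:, 4:]) * x[:, :4])  # second half gates the first half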

# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

# 定义 ParallelTransformerBlock 类,继承自 nn.Module
class ParallelTransformerBlock(nn.Module):
    # 初始化函数
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        # 初始化 RMSNorm 层
        self.norm = RMSNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5

        # 初始化 AlibiPositionalBias 层
        self.alibi_pos_biases = AlibiPositionalBias(heads = self.heads)

        # 初始化线性变换层
        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching causal mask

        self.register_buffer("mask", None, persistent=False)

    # 获取掩码
    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), 1)
        self.register_buffer("mask", mask, persistent=False)
        return mask
    # 定义前向传播函数,接受输入张量 x
    def forward(self, x):

        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        # 获取输入张量 x 的形状信息
        n, device, h = x.shape[1], x.device, self.heads

        # 对输入张量 x 进行预层归一化处理
        x = self.norm(x)

        # 获取注意力查询、键或值(共享键/值是我个人的发现)和前馈内部
        q, kv, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # 分割头部
        # 他们使用多查询单键值注意力,又一篇 Noam Shazeer 的论文
        # 他们发现在一定规模之后没有性能损失,而且解码更有效
        # https://arxiv.org/abs/1911.02150

        # 重新排列查询张量 q 的形状
        q = rearrange(q, "b n (h d) -> b h n d", h = h)

        # 缩放
        q = q * self.scale

        # 相似度计算
        sim = einsum("b h i d, b j d -> b h i j", q, kv)

        # 添加 alibi 偏置
        sim = sim + self.alibi_pos_biases(sim)

        # 因果掩码
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # 注意力计算
        attn = sim.softmax(dim=-1)
        out = einsum("b h i j, b j d -> b h i d", attn, kv)

        # 合并头部
        out = rearrange(out, "b h n d -> b n (h d)")

        # 合并头部并通过注意力输出和前馈输出层
        merge_heads = self.attn_out(out) + self.ff_out(ff)
        return merge_heads
# 定义一个函数PaLM,使用关键字参数,接受模型的维度dim、标记数量num_tokens、层数depth、头部维度dim_head、头部数量heads、前馈网络倍增ff_mult作为参数
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):

    # 创建一个神经网络模型,包括嵌入层、多个平行Transformer块、RMSNorm层和线性层
    net = nn.Sequential(
        nn.Embedding(num_tokens, dim), # 嵌入层,将标记映射到指定维度的向量
        *[Residual(ParallelTransformerBlock(dim, dim_head, heads, ff_mult)) for _ in range(depth)], # 多个平行Transformer块
        RMSNorm(dim), # RMSNorm层
        nn.Linear(dim, num_tokens, bias=False) # 线性层,将维度映射回标记数量
    )

    # 将最后一层的权重设置为与第一层嵌入层的权重相同,实现权重共享
    net[-1].weight = net[0].weight

    # 对第一层嵌入层的权重进行正态分布初始化
    nn.init.normal_(net[0].weight, std=0.02)
    
    # 返回神经网络模型
    return net

# 主函数,用于测试模型的功能
if __name__ == "__main__":

    # 创建一个PaLM模型实例
    palm = PaLM(
        num_tokens = 20000,
        dim = 512,
        depth = 1,
        heads = 8,
        dim_head = 64,
    )

    # 生成随机标记序列
    tokens = torch.randint(0, 20000, (1, 2048))
    # 输入标记序列到模型,得到预测结果logits
    logits = palm(tokens) # (1, 2048, 20000)

    # 统计模型中可训练参数的数量
    n_params_torch = sum(
        p.numel() for p in palm.parameters() if p.requires_grad
    )

    # 打印模型中可训练参数的数量
    print(f"Number of parameters in torch model: {n_params_torch}")

.\lucidrains\PaLM-pytorch\palm_pytorch\palm_pytorch.py

# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 从 torch 库中导入 einsum 和 nn 模块
from torch import einsum, nn

# normalization
# they use layernorm without bias, something that pytorch does not offer

# 定义 LayerNorm 类,继承自 nn.Module
class LayerNorm(nn.Module):
    # 初始化函数
    def __init__(self, dim):
        super().__init__()
        # 创建可学习参数 gamma
        self.gamma = nn.Parameter(torch.ones(dim))
        # 创建 buffer beta
        self.register_buffer("beta", torch.zeros(dim))

    # 前向传播函数
    def forward(self, x):
        # 使用 F.layer_norm 进行层归一化
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# residual

# 定义 Residual 类,继承自 nn.Module
class Residual(nn.Module):
    # 初始化函数
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    # 前向传播函数
    def forward(self, x):
        # 返回残差连接结果
        return self.fn(x) + x

# rotary positional embedding
# https://arxiv.org/abs/2104.09864

# 定义 RotaryEmbedding 类,继承自 nn.Module
class RotaryEmbedding(nn.Module):
    # 初始化函数
    def __init__(self, dim):
        super().__init__()
        # 计算频率
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # 创建 buffer inv_freq
        self.register_buffer("inv_freq", inv_freq)

    # 前向传播函数
    def forward(self, max_seq_len, *, device):
        # 生成序列
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # 计算频率
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        # 拼接频率
        return torch.cat((freqs, freqs), dim=-1)

# 旋转位置嵌入
def rotate_half(x):
    # 重新排列张量维度
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    # 拆分张量
    x1, x2 = x.unbind(dim=-2)
    # 拼接张量
    return torch.cat((-x2, x1), dim=-1)

# 应用旋转位置嵌入
def apply_rotary_pos_emb(pos, t):
    # 计算旋转位置嵌入
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())
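
A toy check (not part of the original file) that the rotary transform above is a pure rotation: applying it to a random tensor leaves the per-position L2 norm unchanged.

rotary = RotaryEmbedding(dim = 8)
pos = rotary(4, device = torch.device('cpu'))    # (4, 8) angles for 4 positions
t = torch.randn(4, 8)
rotated = apply_rotary_pos_emb(pos, t)
assert torch.allclose(rotated.norm(dim = -1), t.norm(dim = -1), atol = 1e-5)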

# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202

# 定义 SwiGLU 类,继承自 nn.Module
class SwiGLU(nn.Module):
    # 前向传播函数
    def forward(self, x):
        # 拆分张量
        x, gate = x.chunk(2, dim=-1)
        # 使用 SiLU 激活函数
        return F.silu(gate) * x

# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

# 定义 ParallelTransformerBlock 类,继承自 nn.Module
class ParallelTransformerBlock(nn.Module):
    # 初始化函数
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        # 归一化层
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching causal mask and rotary embeddings

        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    # 获取掩码
    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    # 获取旋转嵌入
    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb
    # 定义前向传播函数,接受输入张量 x
    def forward(self, x):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # 获取输入张量 x 的形状信息
        n, device, h = x.shape[1], x.device, self.heads

        # 对输入张量 x 进行 LayerNorm 处理
        x = self.norm(x)

        # 使用融合的注意力和前馈神经网络投影层对输入张量 x 进行投影
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # 将投影后的张量按照指定维度进行分割,用于多头注意力
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # 获取旋转位置嵌入
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # 缩放
        q = q * self.scale

        # 计算相似度
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # 获取因果掩码
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # 注意力权重计算
        attn = sim.softmax(dim=-1)

        # 聚合值
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # 合并多头
        out = rearrange(out, "b h n d -> b n (h d)")
        # 返回注意力输出和前馈网络输出的和
        return self.attn_out(out) + self.ff_out(ff)
# 定义一个函数PaLM,用于创建一个Parallel Transformer模型
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
    # 创建一个神经网络模型,包括嵌入层、多个ParallelTransformerBlock、LayerNorm层和线性层
    net = nn.Sequential(
        nn.Embedding(num_tokens, dim),  # 创建一个嵌入层,将输入的token映射到指定维度的向量
        *[
            Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
            for _ in range(depth)  # 创建指定数量的ParallelTransformerBlock,并将其作为Residual块添加到模型中
        ],
        LayerNorm(dim),  # 添加LayerNorm层,用于归一化模型输出
        nn.Linear(dim, num_tokens, bias=False)  # 添加线性层,将模型输出映射到指定数量的token
    )

    # 将嵌入层的权重赋值给线性层的权重,实现权重共享
    net[-1].weight = net[0].weight

    # 对嵌入层的权重进行正态分布初始化
    nn.init.normal_(net[0].weight, std=0.02)
    
    # 返回创建的神经网络模型
    return net
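
This file ships without a quick test; a hypothetical one, mirroring the `__main__` block of palm_lite.py above, would look like this (hyperparameters are made up):

if __name__ == "__main__":
    palm = PaLM(num_tokens = 20000, dim = 512, depth = 1, heads = 8, dim_head = 64)

    tokens = torch.randint(0, 20000, (1, 1024))
    logits = palm(tokens)   # (1, 1024, 20000)
    print(logits.shape)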