Lucidrains Series Source Code Walkthrough - Part 33

Lucidrains Series Source Code Walkthrough (Part 33)

.\lucidrains\perfusion-pytorch\perfusion_pytorch\save_load.py

# import required modules
from pathlib import Path
import torch
from torch import nn
from torch.nn import Module
from beartype import beartype
from perfusion_pytorch.embedding import EmbeddingWrapper
from perfusion_pytorch.perfusion import Rank1EditModule

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# saving and loading of the extra fine-tuned parameters that are necessary

# save function: writes the concept-specific parameters of the model to the given path
@beartype
def save(
    text_image_model: Module,
    path: str
):
    # convert the path to a Path object
    path = Path(path)
    # create the parent directory of the path if it does not exist
    path.parents[0].mkdir(exist_ok=True, parents=True)

    embed_params = None
    key_value_params = []
    C_inv = None

    # iterate over all modules of the model
    for module in text_image_model.modules():
        # if the module is an EmbeddingWrapper
        if isinstance(module, EmbeddingWrapper):
            # make sure there is only one wrapped EmbeddingWrapper
            assert not exists(embed_params), 'there should only be one wrapped EmbeddingWrapper'
            embed_params = module.concepts.data

        # if the module is a Rank1EditModule
        elif isinstance(module, Rank1EditModule):
            # collect the module's concept parameters
            key_value_params.append([
                module.ema_concept_text_encs.data,
                module.concept_outputs.data
            ])

            C_inv = module.C_inv.data

    # make sure the C_inv parameter was found
    assert exists(C_inv), 'Rank1EditModule not found. you likely did not wire up the text to image model correctly'

    # pack the parameters into a dictionary
    pkg = dict(
        embed_params=embed_params,
        key_value_params=key_value_params,
        C_inv=C_inv
    )

    # save the parameters to the given path
    torch.save(pkg, f'{str(path)}')
    print(f'saved to {str(path)}')

# load function: restores saved parameters from the given path into the model
@beartype
def load(
    text_image_model: Module,
    path: str
):
    # convert the path to a Path object
    path = Path(path)
    # check that the file exists
    assert path.exists(), f'file not found at {str(path)}'

    # load the saved parameters
    pkg = torch.load(str(path))

    embed_params = pkg['embed_params']
    key_value_params = pkg['key_value_params']
    C_inv = pkg['C_inv']

    # iterate over all modules of the model
    for module in text_image_model.modules():
        # if the module is an EmbeddingWrapper
        if isinstance(module, EmbeddingWrapper):
            # copy the loaded parameters into the module
            module.concepts.data.copy_(embed_params)

        # if the module is a Rank1EditModule
        elif isinstance(module, Rank1EditModule):
            # make sure what was saved matches what is being loaded
            assert len(key_value_params) > 0, 'mismatch between what was saved vs what is being loaded'
            concept_input, concept_output = key_value_params.pop(0)
            module.ema_concept_text_encs.data.copy_(concept_input)
            module.concept_outputs.data.copy_(concept_output)

            module.C_inv.copy_(C_inv)
            module.initted.copy_(torch.tensor([True]))

    print(f'loaded concept params from {str(path)}')
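
A minimal usage sketch of the two functions above (the model variable and checkpoint path are hypothetical; it assumes the text-to-image model has already been wrapped with EmbeddingWrapper and Rank1EditModule, as the functions require):

from perfusion_pytorch import save, load

# ... fine-tune the wrapped text-to-image model on the new concept ...

save(wired_text_image_model, './checkpoints/dog-concept.pt')   # persists only the small set of concept parameters

# later, restore the concept parameters into a freshly wrapped copy of the same base model
load(wired_text_image_model, './checkpoints/dog-concept.pt')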

.\lucidrains\perfusion-pytorch\perfusion_pytorch\__init__.py

# import Rank1EditModule, calculate_input_covariance, loss_fn_weighted_by_mask, merge_rank1_edit_modules and make_key_value_proj_rank1_edit_modules_ from perfusion_pytorch.perfusion
from perfusion_pytorch.perfusion import (
    Rank1EditModule,
    calculate_input_covariance,
    loss_fn_weighted_by_mask,
    merge_rank1_edit_modules,
    make_key_value_proj_rank1_edit_modules_
)

# import EmbeddingWrapper, OpenClipEmbedWrapper and merge_embedding_wrappers from perfusion_pytorch.embedding
from perfusion_pytorch.embedding import (
    EmbeddingWrapper,
    OpenClipEmbedWrapper,
    merge_embedding_wrappers
)

# import save and load from perfusion_pytorch.save_load
from perfusion_pytorch.save_load import (
    save,
    load
)

# import get_finetune_parameters and get_finetune_optimizer from perfusion_pytorch.optimizer
from perfusion_pytorch.optimizer import (
    get_finetune_parameters,
    get_finetune_optimizer
)

Perfusion - Pytorch

Implementation of Key-Locked Rank One Editing. Project page

The selling point of this paper is extremely low extra parameters per added concept, down to 100kb.

It seems they successfully applied the Rank-1 editing technique from a memory editing paper for LLM, with a few improvements. They also identified that the keys determine the "where" of the new concept, while the values determine the "what", and propose local / global-key locking to a superclass concept (while learning the values).
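
As a rough intuition for what a rank-one edit does, here is a toy sketch in plain PyTorch (illustrative only, not the paper's exact update, which additionally involves the inverse input covariance C^{-1} and key-locking): given a linear projection W, a concept key direction k and a desired output v, a rank-one update makes W map k to v while leaving directions orthogonal to k untouched.

import torch

dim_in, dim_out = 768, 320
W = torch.randn(dim_out, dim_in)          # pretrained key or value projection
k = torch.randn(dim_in)                   # text encoding direction of the (superclass) concept
v = torch.randn(dim_out)                  # desired output for that concept

W_edited = W + torch.outer(v - W @ k, k) / k.dot(k)   # rank-one edit

assert torch.allclose(W_edited @ k, v, atol = 1e-3)   # the concept key now maps to the desired value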

For researchers out there, if this paper checks out, the tools in this repository should work for any other text-to-<insert modality> network using cross attention conditioning. Just a thought

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors out there

  • Yoad Tewel for the multiple code reviews and clarifying emails

  • Brad Vidler for precomputing the covariance matrix for the CLIP used in Stable Diffusion 1.5!

  • All the maintainers at OpenClip, for their SOTA open sourced contrastive learning text-image models

Install

$ pip install perfusion-pytorch

Usage

import torch
from torch import nn

from perfusion_pytorch import Rank1EditModule

to_keys = nn.Linear(768, 320, bias = False)
to_values = nn.Linear(768, 320, bias = False)

wrapped_to_keys = Rank1EditModule(
    to_keys,
    is_key_proj = True
)

wrapped_to_values = Rank1EditModule(
    to_values
)

text_enc = torch.randn(4, 77, 768)                  # regular input
text_enc_with_superclass = torch.randn(4, 77, 768)  # init_input in algorithm 1, for key-locking
concept_indices = torch.randint(0, 77, (4,))        # index where the concept or superclass concept token is in the sequence
key_pad_mask = torch.ones(4, 77).bool()

keys = wrapped_to_keys(
    text_enc,
    concept_indices = concept_indices,
    text_enc_with_superclass = text_enc_with_superclass,
)

values = wrapped_to_values(
    text_enc,
    concept_indices = concept_indices,
    text_enc_with_superclass = text_enc_with_superclass,
)

# after much training ...

wrapped_to_keys.eval()
wrapped_to_values.eval()

keys = wrapped_to_keys(text_enc)

values = wrapped_to_values(text_enc)

The repository also contains an EmbeddingWrapper that makes it easy to train on a new concept (and for eventual inference with multiple concepts)

import torch
from torch import nn

from perfusion_pytorch import EmbeddingWrapper

embed = nn.Embedding(49408, 512) # open clip embedding, somewhere in the module tree of stable diffusion

# wrap it, and will automatically create a new concept for learning, based on the superclass embed string

wrapped_embed = EmbeddingWrapper(
    embed,
    superclass_string = 'dog'
)

# now just pass in your prompts with the superclass id

embeds_with_new_concept, embeds_with_superclass, embed_mask, concept_indices = wrapped_embed([
    'a portrait of dog',
    'dog running through a green field',
    'a man walking his dog'
]) # (3, 77, 512), (3, 77, 512), (3, 77), (3,)

# now pass both embeds through clip text transformer
# the embed_mask needs to be passed to the cross attention as key padding mask
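
As a generic illustration of how a key padding mask is consumed inside cross attention (a plain PyTorch sketch mirroring the usual masked_fill pattern, not tied to any particular attention implementation in this repository): padded text positions are pushed to a large negative value before the softmax, so they receive essentially zero attention weight.

import torch

sim = torch.randn(3, 8, 1024, 77)               # attention logits: (batch, heads, image queries, text keys)
embed_mask = torch.ones(3, 77).bool()           # True where the text token is real, False where padded
embed_mask[:, 60:] = False                      # pretend the last 17 positions are padding

sim = sim.masked_fill(~embed_mask[:, None, None, :], -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)                    # padded keys now get ~0 attention weight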

If you can identify the CLIP instance within the stable diffusion instance, you can also pass it directly to the OpenClipEmbedWrapper to gain everything you need on forward for the cross attention layers

ex.

from perfusion_pytorch import OpenClipEmbedWrapper

texts = [
    'a portrait of dog',
    'dog running through a green field',
    'a man walking his dog'
]

wrapped_clip_with_new_concept = OpenClipEmbedWrapper(
    stable_diffusion.path.to.clip,
    superclass_string = 'dog'
)

text_enc, superclass_enc, mask, indices = wrapped_clip_with_new_concept(texts)

# (3, 77, 512), (3, 77, 512), (3, 77), (3,)

Todo

Citations

@article{Tewel2023KeyLockedRO,
    title   = {Key-Locked Rank One Editing for Text-to-Image Personalization},
    author  = {Yoad Tewel and Rinon Gal and Gal Chechik and Yuval Atzmon},
    journal = {ACM SIGGRAPH 2023 Conference Proceedings},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:258436985}
}
@inproceedings{Meng2022LocatingAE,
    title   = {Locating and Editing Factual Associations in GPT},
    author  = {Kevin Meng and David Bau and Alex Andonian and Yonatan Belinkov},
    booktitle = {Neural Information Processing Systems},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:255825985}
}

.\lucidrains\perfusion-pytorch\setup.py

# import setuptools' setup and the package-finding helper
from setuptools import setup, find_packages

# define the package metadata
setup(
  name = 'perfusion-pytorch',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.1.23',  # version number
  license='MIT',  # license
  description = 'Perfusion - Pytorch',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/perfusion-pytorch',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'memory editing',
    'text-to-image'
  ],
  install_requires=[  # dependencies
    'beartype',
    'einops>=0.6.1',
    'open-clip-torch',
    'opt-einsum',
    'torch>=2.0'
  ],
  include_package_data = True,  # include data files
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Phasic Policy Gradient - Pytorch

An implementation of Phasic Policy Gradient, a proposed improvement on top of Proximal Policy Optimization (PPO), in Pytorch. It will be my very first project in Reinforcement Learning.

Install

$ pip install -r requirements.txt

Use

$ python train.py --render

Citations

@misc{cobbe2020phasic,
    title={Phasic Policy Gradient},
    author={Karl Cobbe and Jacob Hilton and Oleg Klimov and John Schulman},
    year={2020},
    eprint={2009.04416},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

.\lucidrains\phasic-policy-gradient\train.py

# import required libraries
import os
import fire
from collections import deque, namedtuple

from tqdm import tqdm
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.distributions import Categorical
import torch.nn.functional as F

import gym

# constants
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# named tuples for the rollout and auxiliary buffers
Memory = namedtuple('Memory', ['state', 'action', 'action_log_prob', 'reward', 'done', 'value'])
AuxMemory = namedtuple('Memory', ['state', 'target_value', 'old_values'])

# experience dataset wrapping the rollout tensors
class ExperienceDataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data

    def __len__(self):
        return len(self.data[0])

    def __getitem__(self, ind):
        return tuple(map(lambda t: t[ind], self.data))

# create a shuffled dataloader over the experience tensors
def create_shuffled_dataloader(data, batch_size):
    ds = ExperienceDataset(data)
    return DataLoader(ds, batch_size = batch_size, shuffle = True)

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# normalize a tensor to zero mean and unit variance
def normalize(t, eps = 1e-5):
    return (t - t.mean()) / (t.std() + eps)
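
# clipped_value_loss is referenced by the learn / learn_aux methods below but is missing from this excerpt;
# the following is a standard PPO-style clipped value loss consistent with how it is called here (a reconstruction, not verbatim from the repository)
def clipped_value_loss(values, rewards, old_values, clip):
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped.flatten() - rewards) ** 2
    value_loss_2 = (values.flatten() - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))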

# backprop the (mean) loss and take an optimizer step
def update_network_(loss, optimizer):
    optimizer.zero_grad()
    loss.mean().backward()
    optimizer.step()

# orthogonal initialization for linear layers
def init_(m):
    if isinstance(m, nn.Linear):
        gain = torch.nn.init.calculate_gain('tanh')
        torch.nn.init.orthogonal_(m.weight, gain)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

# Actor network
class Actor(nn.Module):
    def __init__(self, state_dim, hidden_dim, num_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh()
        )

        self.action_head = nn.Sequential(
            nn.Linear(hidden_dim, num_actions),
            nn.Softmax(dim=-1)
        )

        self.value_head = nn.Linear(hidden_dim, 1)
        self.apply(init_)

    def forward(self, x):
        hidden = self.net(x)
        return self.action_head(hidden), self.value_head(hidden)

# Critic network
class Critic(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )
        self.apply(init_)

    def forward(self, x):
        return self.net(x)

# PPG agent class
class PPG:
    def __init__(
        self,
        state_dim,
        num_actions,
        actor_hidden_dim,
        critic_hidden_dim,
        epochs,
        epochs_aux,
        minibatch_size,
        lr,
        betas,
        lam,
        gamma,
        beta_s,
        eps_clip,
        value_clip
    ):
        self.actor = Actor(state_dim, actor_hidden_dim, num_actions).to(device)
        self.critic = Critic(state_dim, critic_hidden_dim).to(device)
        self.opt_actor = Adam(self.actor.parameters(), lr=lr, betas=betas)
        self.opt_critic = Adam(self.critic.parameters(), lr=lr, betas=betas)

        self.minibatch_size = minibatch_size

        self.epochs = epochs
        self.epochs_aux = epochs_aux

        self.lam = lam
        self.gamma = gamma
        self.beta_s = beta_s

        self.eps_clip = eps_clip
        self.value_clip = value_clip

    # save the model parameters
    def save(self):
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict()
        }, f'./ppg.pt')

    # load the model parameters
    def load(self):
        # check whether a checkpoint file exists
        if not os.path.exists('./ppg.pt'):
            return

        # load the parameters from the file
        data = torch.load(f'./ppg.pt')
        # restore the actor parameters
        self.actor.load_state_dict(data['actor'])
        # restore the critic parameters
        self.critic.load_state_dict(data['critic'])

    # learn function: policy-phase training on the collected rollouts
    def learn(self, memories, aux_memories, next_state):
        # retrieve and prepare the training data from the rollout memories
        states = []
        actions = []
        old_log_probs = []
        rewards = []
        masks = []
        values = []

        for mem in memories:
            states.append(mem.state)
            actions.append(torch.tensor(mem.action))
            old_log_probs.append(mem.action_log_prob)
            rewards.append(mem.reward)
            masks.append(1 - float(mem.done))
            values.append(mem.value)

        # calculate the generalized advantage estimates
        next_state = torch.from_numpy(next_state).to(device)
        next_value = self.critic(next_state).detach()
        values = values + [next_value]

        returns = []
        gae = 0
        for i in reversed(range(len(rewards))):
            delta = rewards[i] + self.gamma * values[i + 1] * masks[i] - values[i]
            gae = delta + self.gamma * self.lam * masks[i] * gae
            returns.insert(0, gae + values[i])

        # convert the lists to torch tensors
        to_torch_tensor = lambda t: torch.stack(t).to(device).detach()

        states = to_torch_tensor(states)
        actions = to_torch_tensor(actions)
        old_values = to_torch_tensor(values[:-1])
        old_log_probs = to_torch_tensor(old_log_probs)

        rewards = torch.tensor(returns).float().to(device)

        # store the states and target values into the auxiliary memory buffer for later training
        aux_memory = AuxMemory(states, rewards, old_values)
        aux_memories.append(aux_memory)

        # prepare the dataloader for policy-phase training
        dl = create_shuffled_dataloader([states, actions, old_log_probs, rewards, old_values], self.minibatch_size)

        # policy phase training, similar to the original PPO
        for _ in range(self.epochs):
            for states, actions, old_log_probs, rewards, old_values in dl:
                action_probs, _ = self.actor(states)
                values = self.critic(states)
                dist = Categorical(action_probs)
                action_log_probs = dist.log_prob(actions)
                entropy = dist.entropy()

                # calculate the clipped surrogate objective, the classic PPO loss
                ratios = (action_log_probs - old_log_probs).exp()
                advantages = normalize(rewards - old_values.detach())
                surr1 = ratios * advantages
                surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages
                policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropy

                # update the policy network
                update_network_(policy_loss, self.opt_actor)

                # calculate the value loss and update the value network, separately from the policy network
                value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip)

                update_network_(value_loss, self.opt_critic)
    # auxiliary learning function, trained on the auxiliary memories
    def learn_aux(self, aux_memories):
        # gather the states, target values and old values
        states = []
        rewards = []
        old_values = []
        for state, reward, old_value in aux_memories:
            states.append(state)
            rewards.append(reward)
            old_values.append(old_value)

        # concatenate them into single tensors
        states = torch.cat(states)
        rewards = torch.cat(rewards)
        old_values = torch.cat(old_values)

        # get the old action predictions, for minimizing kl divergence and clipping
        old_action_probs, _ = self.actor(states)
        old_action_probs.detach_()

        # prepare the dataloader for auxiliary-phase training
        dl = create_shuffled_dataloader([states, old_action_probs, rewards, old_values], self.minibatch_size)

        # the proposed auxiliary phase training
        # the value is distilled into the policy network, while making sure the policy does not change its action predictions (kl divergence loss)
        for epoch in range(self.epochs_aux):
            for states, old_action_probs, rewards, old_values in tqdm(dl, desc=f'auxiliary epoch {epoch}'):
                action_probs, policy_values = self.actor(states)
                action_logprobs = action_probs.log()

                # the policy loss consists of the kl divergence loss and the auxiliary value loss
                aux_loss = clipped_value_loss(policy_values, rewards, old_values, self.value_clip)
                loss_kl = F.kl_div(action_logprobs, old_action_probs, reduction='batchmean')
                policy_loss = aux_loss + loss_kl

                # update the policy network
                update_network_(policy_loss, self.opt_actor)

                # the paper states that training the value network extra during the auxiliary phase is important
                values = self.critic(states)
                value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip)

                # update the value network
                update_network_(value_loss, self.opt_critic)
# main function
def main(
    env_name = 'LunarLander-v2',          # gym environment name
    num_episodes = 50000,                 # total number of training episodes
    max_timesteps = 500,                  # maximum number of timesteps per episode
    actor_hidden_dim = 32,                # hidden dimension of the actor network
    critic_hidden_dim = 256,              # hidden dimension of the critic network
    minibatch_size = 64,                  # minibatch size per update
    lr = 0.0005,                          # learning rate
    betas = (0.9, 0.999),                 # Adam optimizer betas
    lam = 0.95,                           # GAE lambda
    gamma = 0.99,                         # discount factor
    eps_clip = 0.2,                       # PPO epsilon clip
    value_clip = 0.4,                     # value function clip
    beta_s = .01,                         # entropy bonus weight
    update_timesteps = 5000,              # number of timesteps between policy updates
    num_policy_updates_per_aux = 32,      # number of policy updates per auxiliary phase
    epochs = 1,                           # epochs per policy phase
    epochs_aux = 6,                       # epochs per auxiliary phase
    seed = None,                          # random seed
    render = False,                       # whether to render the environment
    render_every_eps = 250,               # render every n episodes
    save_every = 1000,                    # save the model every n episodes
    load = False,                         # whether to load a previously saved model
    monitor = False                       # whether to monitor (record) the environment
):
    env = gym.make(env_name)  # create the environment

    if monitor:
        env = gym.wrappers.Monitor(env, './tmp/', force=True)  # wrap the environment for monitoring

    state_dim = env.observation_space.shape[0]  # state space dimension
    num_actions = env.action_space.n  # number of discrete actions

    memories = deque([])  # rollout memory buffer
    aux_memories = deque([])  # auxiliary memory buffer

    agent = PPG(  # instantiate the PPG agent
        state_dim,
        num_actions,
        actor_hidden_dim,
        critic_hidden_dim,
        epochs,
        epochs_aux,
        minibatch_size,
        lr,
        betas,
        lam,
        gamma,
        beta_s,
        eps_clip,
        value_clip
    )

    if load:
        agent.load()  # load a previously saved model

    if exists(seed):  # if a random seed was given
        torch.manual_seed(seed)  # seed pytorch
        np.random.seed(seed)  # seed numpy

    time = 0  # global timestep counter
    updated = False  # whether the model has been updated at least once
    num_policy_updates = 0  # number of policy updates so far

    for eps in tqdm(range(num_episodes), desc='episodes'):  # loop over episodes
        render_eps = render and eps % render_every_eps == 0  # whether to render this episode
        state = env.reset()  # reset the environment
        for timestep in range(max_timesteps):  # loop over timesteps
            time += 1  # increment the global timestep counter

            if updated and render_eps:  # render if already updated and this episode should be rendered
                env.render()

            state = torch.from_numpy(state).to(device)  # convert the state to a torch tensor
            action_probs, _ = agent.actor(state)  # get the action probabilities
            value = agent.critic(state)  # get the value estimate

            dist = Categorical(action_probs)  # build a categorical distribution
            action = dist.sample()  # sample an action
            action_log_prob = dist.log_prob(action)  # log probability of the sampled action
            action = action.item()  # convert the action to a python scalar

            next_state, reward, done, _ = env.step(action)  # step the environment

            memory = Memory(state, action, action_log_prob, reward, done, value)  # build a memory entry
            memories.append(memory)  # append it to the rollout buffer

            state = next_state  # advance the state

            if time % update_timesteps == 0:  # if it is time to update
                agent.learn(memories, aux_memories, next_state)  # run the policy phase
                num_policy_updates += 1  # increment the policy update counter
                memories.clear()  # clear the rollout buffer

                if num_policy_updates % num_policy_updates_per_aux == 0:  # if it is time for the auxiliary phase
                    agent.learn_aux(aux_memories)  # run the auxiliary phase
                    aux_memories.clear()  # clear the auxiliary buffer

                updated = True  # mark the model as updated

            if done:  # if the episode ended
                if render_eps:  # if this episode was being rendered
                    updated = False  # mark as not updated
                break

        if render_eps:  # close the render window if it was used
            env.close()

        if eps % save_every == 0:  # periodically save the model
            agent.save()

if __name__ == '__main__':
    fire.Fire(main)  # run the main function via the Fire CLI

.\lucidrains\phenaki-pytorch\phenaki_pytorch\attention.py

# attention module (the file's imports and small helpers such as exists, default, LayerNorm and l2norm precede this excerpt)
class Attention(nn.Module):
    # initialize the attention module
    def __init__(
        self,
        dim,
        dim_context = None,
        dim_head = 64,
        heads = 8,
        causal = False,
        num_null_kv = 0,
        norm_context = True,
        dropout = 0.,
        scale = 8
    ):
        # call the parent constructor
        super().__init__()
        # number of attention heads
        self.heads = heads
        # whether the attention is causal
        self.causal = causal
        # scaling factor
        self.scale = scale
        # inner dimension
        inner_dim = dim_head * heads
        # default the context dimension to the input dimension if not given
        dim_context = default(dim_context, dim)

        # if causal, use AlibiPositionalBias for the relative positional bias
        if causal:
            self.rel_pos_bias = AlibiPositionalBias(heads = heads)

        # attention dropout layer
        self.attn_dropout = nn.Dropout(dropout)

        # layernorm for the input
        self.norm = LayerNorm(dim)
        # layernorm for the context (if context normalization is requested)
        self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity()

        # number of null key / value pairs
        self.num_null_kv = num_null_kv
        # null key / value parameters
        self.null_kv = nn.Parameter(torch.randn(heads, 2 * num_null_kv, dim_head))

        # query projection
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # key / value projection
        self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)

        # query scale parameter
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        # key scale parameter
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # output projection
        self.to_out = nn.Linear(inner_dim, dim, bias = False)
    # forward pass (the signature and the kv_input line are reconstructed here so the excerpt is runnable)
    def forward(self, x, mask = None, context = None, attn_bias = None):
        # batch size, device and dtype of the input tensor x
        batch, device, dtype = x.shape[0], x.device, x.dtype

        # if a context is given, normalize it
        if exists(context):
            context = self.context_norm(context)

        # the keys / values attend to the context if given, otherwise to the input itself
        kv_input = default(context, x)

        # normalize the input tensor x
        x = self.norm(x)

        # project the input into queries (q), keys (k) and values (v)
        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

        # split the heads out of the query, key and value tensors
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # repeat the null key / value pairs to match the batch size
        nk, nv = repeat(self.null_kv, 'h (n r) d -> b h n r d', b = batch, r = 2).unbind(dim = -2)

        # concatenate the null key / value pairs with the projected keys and values
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # l2-normalize the queries and keys, then apply the learned scales (cosine-sim style attention)
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # compute the similarity between queries and keys
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        i, j = sim.shape[-2:]

        # if an attention bias is given, pad it for the null keys and add it to the similarity matrix
        if exists(attn_bias):
            attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value = 0.)
            sim = sim + attn_bias

        # if a key padding mask is given, mask out the padded positions
        if exists(mask):
            mask = F.pad(mask, (self.num_null_kv, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # if causal, add the relative position bias and apply the causal mask
        if self.causal:
            sim = sim + self.rel_pos_bias(sim)

            causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # softmax over the similarity matrix to get the attention weights
        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate the values with the attention weights
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge the heads back and project out
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# AlibiPositionalBias class, implementing the ALiBi relative positional bias
class AlibiPositionalBias(nn.Module):
    def __init__(self, heads):
        super().__init__()
        self.heads = heads
        # compute the per-head slopes
        slopes = torch.Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> h 1 1')
        # register the slopes and the (lazily computed) bias as buffers
        self.register_buffer('slopes', slopes, persistent = False)
        self.register_buffer('bias', None, persistent = False)

    # compute the bias values
    def get_bias(self, i, j, device):
        i_arange = torch.arange(j - i, j, device = device)
        j_arange = torch.arange(j, device = device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    # compute the per-head slopes
    @staticmethod
    def _get_slopes(heads):
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

    # forward pass
    def forward(self, sim):
        h, i, j, device = *sim.shape[-3:], sim.device

        if exists(self.bias) and self.bias.shape[-1] >= j:
            return self.bias[..., :i, :j]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        num_heads_unalibied = h - bias.shape[0]
        bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
        self.register_buffer('bias', bias, persistent = False)

        return self.bias

# ContinuousPositionBias class, for continuous relative position bias
class ContinuousPositionBias(nn.Module):
    """ from https://arxiv.org/abs/2111.09883 """

    def __init__(
        self,
        *,
        dim,
        heads,
        num_dims = 2, # 2 for images, 3 for video
        layers = 2,
        log_dist = True,
        cache_rel_pos = False
    ):
        super().__init__()
        self.num_dims = num_dims
        self.log_dist = log_dist

        self.net = nn.ModuleList([])
        self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), leaky_relu()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))

        self.net.append(nn.Linear(dim, heads))

        self.cache_rel_pos = cache_rel_pos
        self.register_buffer('rel_pos', None, persistent = False)

    # forward pass
    def forward(self, *dimensions, device = torch.device('cpu')):

        if not exists(self.rel_pos) or not self.cache_rel_pos:
            positions = [torch.arange(d, device = device) for d in dimensions]
            grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'))
            grid = rearrange(grid, 'c ... -> (...) c')
            rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')

            if self.log_dist:
                rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)

            self.register_buffer('rel_pos', rel_pos, persistent = False)

        rel_pos = self.rel_pos.float()

        for layer in self.net:
            rel_pos = layer(rel_pos)

        return rearrange(rel_pos, 'i j h -> h i j')

# Transformer class
class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_context = None,
        causal = False,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        peg = False,
        peg_causal = False,
        attn_num_null_kv = 2,
        has_cross_attn = False,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        # call the parent constructor
        super().__init__()
        # empty module list to hold the layers
        self.layers = nn.ModuleList([])

        # for each of the depth layers, append the sub-modules
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                # a PEG (positional encoding generator) module if requested, otherwise None
                PEG(dim = dim, causal = peg_causal) if peg else None,
                # the self-attention module
                Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal, dropout = attn_dropout),
                # a cross-attention module if requested, otherwise None
                Attention(dim = dim, dim_head = dim_head, dim_context = dim_context, heads = heads, causal = False, num_null_kv = attn_num_null_kv, dropout = attn_dropout) if has_cross_attn else None,
                # the feedforward module
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        # final layernorm
        self.norm_out = LayerNorm(dim)

    @beartype
    def forward(
        self,
        x,
        video_shape: Tuple[int, int, int, int] = None,
        attn_bias = None,
        context = None,
        self_attn_mask = None,
        cross_attn_context_mask = None
    ):

        # iterate over the layers
        for peg, self_attn, cross_attn, ff in self.layers:
            # if a PEG module exists, apply it and add the residual
            if exists(peg):
                x = peg(x, shape = video_shape) + x

            # self-attention with residual
            x = self_attn(x, attn_bias = attn_bias, mask = self_attn_mask) + x

            # cross-attention with residual, if a cross-attention module and a context exist
            if exists(cross_attn) and exists(context):
                x = cross_attn(x, context = context, mask = cross_attn_context_mask) + x

            # feedforward with residual
            x = ff(x) + x

        # final layernorm and return
        return self.norm_out(x)

.\lucidrains\phenaki-pytorch\phenaki_pytorch\cvivit.py

# import required libraries
from pathlib import Path
import copy
import math
from functools import wraps

import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.autograd import grad as torch_grad

import torchvision

# import functions from the einops library
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# import custom modules
from vector_quantize_pytorch import VectorQuantize, LFQ
from phenaki_pytorch.attention import Attention, Transformer, ContinuousPositionBias

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# check whether one number is divisible by another
def divisible_by(numer, denom):
    return (numer % denom) == 0

# leaky relu activation
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(p)

# decorator that temporarily removes the vgg attribute (so it is excluded from state dicts)
def remove_vgg(fn):
    @wraps(fn)
    def inner(self, *args, **kwargs):
        has_vgg = hasattr(self, 'vgg')
        if has_vgg:
            vgg = self.vgg
            delattr(self, 'vgg')

        out = fn(self, *args, **kwargs)

        if has_vgg:
            self.vgg = vgg

        return out
    return inner

# cast a single value to a pair
def pair(val):
    ret = (val, val) if not isinstance(val, tuple) else val
    assert len(ret) == 2
    return ret

# cast a single value to a tuple of the given length
def cast_tuple(val, l = 1):
    return val if isinstance(val, tuple) else (val,) * l

# gradient penalty for the discriminator
def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]

    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones(output.size(), device = images.device),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# l2-normalize a tensor along its last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# safe division, avoiding division by zero
def safe_div(numer, denom, eps = 1e-8):
    return numer / (denom + eps)

# GAN loss functions

# hinge loss (discriminator)
def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

# hinge loss (generator)
def hinge_gen_loss(fake):
    return -fake.mean()

# binary cross entropy loss (discriminator)
def bce_discr_loss(fake, real):
    return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()

# binary cross entropy loss (generator)
def bce_gen_loss(fake):
    return -log(torch.sigmoid(fake)).mean()

# gradient of the loss with respect to a given layer
def grad_layer_wrt_loss(loss, layer):
    return torch_grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0].detach()

# discriminator modules

class DiscriminatorBlock(nn.Module):
    def __init__(
        self,
        input_channels,
        filters,
        downsample = True
    ):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu()
        )

        self.downsample = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
            nn.Conv2d(filters * 4, filters, 1)
        ) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)

        if exists(self.downsample):
            x = self.downsample(x)

        x = (x + res) * (1 / math.sqrt(2))
        return x


class Discriminator(nn.Module):
    def __init__(
        self,
        *,
        dim,
        image_size,
        channels = 3,
        attn_res_layers = (16,),
        max_dim = 512
    # initializer (continues the parameter list above)
    ):
        # call the parent constructor
        super().__init__()
        # cast the image size to a pair
        image_size = pair(image_size)
        # minimum image resolution
        min_image_resolution = min(image_size)

        # number of layers, based on the minimum resolution
        num_layers = int(math.log2(min_image_resolution) - 2)
        # cast the attention resolution layers to a tuple
        attn_res_layers = cast_tuple(attn_res_layers, num_layers)

        # list of discriminator blocks
        blocks = []

        # dimensions for each layer
        layer_dims = [channels] + [(dim * 4) * (2 ** i) for i in range(num_layers + 1)]
        # cap each layer dimension at the maximum dimension
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        # pair up the input / output dimensions of each layer
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        # lists of blocks and attention blocks
        blocks = []
        attn_blocks = []

        # current image resolution
        image_resolution = min_image_resolution

        # iterate over the input / output dimensions of each layer
        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            # layer index (1-based)
            num_layer = ind + 1
            # whether this is not the last layer
            is_not_last = ind != (len(layer_dims_in_out) - 1)

            # create the discriminator block
            block = DiscriminatorBlock(in_chan, out_chan, downsample = is_not_last)
            blocks.append(block)

            # add an attention block at the requested resolutions
            attn_block = None
            if image_resolution in attn_res_layers:
                attn_block = Attention(dim = out_chan)

            attn_blocks.append(attn_block)

            # halve the image resolution
            image_resolution //= 2

        # convert the lists into module lists
        self.blocks = nn.ModuleList(blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)

        # dimension of the last layer
        dim_last = layer_dims[-1]

        # downsample factor
        downsample_factor = 2 ** num_layers
        # size of the final feature map
        last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))

        # latent dimension
        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        # output head producing the logits
        self.to_logits = nn.Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding = 1),
            leaky_relu(),
            Rearrange('b ... -> b (...)'),
            nn.Linear(latent_dim, 1),
            Rearrange('b 1 -> b')
        )

    # forward pass
    def forward(self, x):

        # iterate over the blocks and attention blocks
        for block, attn_block in zip(self.blocks, self.attn_blocks):
            # apply the discriminator block
            x = block(x)

            # if an attention block exists at this resolution
            if exists(attn_block):
                x, ps = pack([x], 'b c *')
                x = rearrange(x, 'b c n -> b n c')
                x = attn_block(x) + x
                x = rearrange(x, 'b n c -> b c n')
                x, = unpack(x, ps, 'b c *')

        # return the logits
        return self.to_logits(x)
# pick the specified frame from each video in the batch
def pick_video_frame(video, frame_indices):
    # batch size and device of the video
    batch, device = video.shape[0], video.device
    # move the channel dimension after the frame dimension
    video = rearrange(video, 'b c f ... -> b f c ...')
    # batch indices tensor
    batch_indices = torch.arange(batch, device=device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')
    # select the requested frame from each video
    images = video[batch_indices, frame_indices]
    # move the channel dimension back to the front
    images = rearrange(images, 'b 1 c ... -> b c ...')
    return images

# CViViT: a 3D ViT with factorized spatial and temporal attention, made into a vqgan-vae style autoencoder
class CViViT(nn.Module):
    def __init__(
        self,
        *,
        dim,  # model dimension
        codebook_size,  # codebook size
        image_size,  # image size
        patch_size,  # image patch size
        temporal_patch_size,  # temporal patch size
        spatial_depth,  # spatial transformer depth
        temporal_depth,  # temporal transformer depth
        discr_base_dim=16,  # discriminator base dimension
        dim_head=64,  # attention head dimension
        heads=8,  # number of attention heads
        channels=3,  # number of channels
        use_vgg_and_gan=True,  # whether to use VGG (perceptual) and GAN losses
        vgg=None,  # optional VGG model
        discr_attn_res_layers=(16,),  # resolutions at which the discriminator uses attention
        use_hinge_loss=True,  # whether to use hinge loss
        attn_dropout=0.,  # attention dropout rate
        ff_dropout=0.,  # feedforward dropout rate
        lookup_free_quantization=True,  # whether to use lookup-free quantization
        lookup_free_quantization_kwargs: dict = {}  # kwargs for lookup-free quantization
        ):
        """
        einstein notations:

        b - batch
        c - channels
        t - time
        d - feature dimension
        p1, p2, pt - image patch sizes and then temporal patch size
        """

        super().__init__()

        self.image_size = pair(image_size)
        self.patch_size = pair(patch_size)
        patch_height, patch_width = self.patch_size

        self.temporal_patch_size = temporal_patch_size

        self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim, heads = heads)

        image_height, image_width = self.image_size
        assert (image_height % patch_height) == 0 and (image_width % patch_width) == 0

        self.to_patch_emb_first_frame = nn.Sequential(
            Rearrange('b c 1 (h p1) (w p2) -> b 1 h w (c p1 p2)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(channels * patch_width * patch_height),
            nn.Linear(channels * patch_width * patch_height, dim),
            nn.LayerNorm(dim)
        )

        self.to_patch_emb = nn.Sequential(
            Rearrange('b c (t pt) (h p1) (w p2) -> b t h w (c pt p1 p2)', p1 = patch_height, p2 = patch_width, pt = temporal_patch_size),
            nn.LayerNorm(channels * patch_width * patch_height * temporal_patch_size),
            nn.Linear(channels * patch_width * patch_height * temporal_patch_size, dim),
            nn.LayerNorm(dim)
        )

        transformer_kwargs = dict(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            peg = True,
            peg_causal = True,
        )

        self.enc_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs)
        self.enc_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs)

        # offer look up free quantization
        # https://arxiv.org/abs/2310.05737

        self.lookup_free_quantization = lookup_free_quantization

        if lookup_free_quantization:
            self.vq = LFQ(dim = dim, codebook_size = codebook_size, **lookup_free_quantization_kwargs)
        else:
            self.vq = VectorQuantize(dim = dim, codebook_size = codebook_size, use_cosine_sim = True)

        self.dec_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs)
        self.dec_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs)

        self.to_pixels_first_frame = nn.Sequential(
            nn.Linear(dim, channels * patch_width * patch_height),
            Rearrange('b 1 h w (c p1 p2) -> b c 1 (h p1) (w p2)', p1 = patch_height, p2 = patch_width)
        )

        self.to_pixels = nn.Sequential(
            nn.Linear(dim, channels * patch_width * patch_height * temporal_patch_size),
            Rearrange('b t h w (c pt p1 p2) -> b c (t pt) (h p1) (w p2)', p1 = patch_height, p2 = patch_width, pt = temporal_patch_size),
        )

        # turn off GAN and perceptual loss if grayscale

        self.vgg = None
        self.discr = None
        self.use_vgg_and_gan = use_vgg_and_gan

        if not use_vgg_and_gan:
            return

        # perceptual loss

        if exists(vgg):
            self.vgg = vgg
        else:
            self.vgg = torchvision.models.vgg16(pretrained = True)
            self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])

        # gan related losses

        self.discr = Discriminator(
            image_size = self.image_size,
            dim = discr_base_dim,
            channels = channels,
            attn_res_layers = discr_attn_res_layers
        )

        self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
        self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
    # calculate the token mask for a video, given a per-frame mask
    def calculate_video_token_mask(self, videos, video_frame_mask):
        # unpack the video height and width
        *_, h, w = videos.shape
        # patch height and width
        ph, pw = self.patch_size

        # the number of frames minus the first must be divisible by the temporal patch size
        assert torch.all(((video_frame_mask.sum(dim = -1) - 1) % self.temporal_patch_size) == 0), 'number of frames must be divisible by temporal patch size, subtracting off the first frame'
        # split the mask into the first frame and the remaining frames
        first_frame_mask, rest_frame_mask = video_frame_mask[:, :1], video_frame_mask[:, 1:]
        # group the remaining frame mask by temporal patch size
        rest_vq_mask = rearrange(rest_frame_mask, 'b (f p) -> b f p', p = self.temporal_patch_size)
        # concatenate the first frame mask with the any-reduced remaining frame mask
        video_mask = torch.cat((first_frame_mask, rest_vq_mask.any(dim = -1)), dim = -1)
        # repeat the video mask across the spatial token positions
        return repeat(video_mask, 'b f -> b (f hw)', hw = (h // ph) * (w // pw))

    # shape of the video in patches
    def get_video_patch_shape(self, num_frames, include_first_frame = True):
        patch_frames = 0

        if include_first_frame:
            num_frames -= 1
            patch_frames += 1

        patch_frames += (num_frames // self.temporal_patch_size)

        return (patch_frames, *self.patch_height_width)

    # number of tokens per image (a single frame)
    @property
    def image_num_tokens(self):
        return int(self.image_size[0] / self.patch_size[0]) * int(self.image_size[1] / self.patch_size[1])

    # number of video frames corresponding to a given number of tokens
    def frames_per_num_tokens(self, num_tokens):
        tokens_per_frame = self.image_num_tokens

        assert (num_tokens % tokens_per_frame) == 0, f'number of tokens must be divisible by number of tokens per frame {tokens_per_frame}'
        assert (num_tokens > 0)

        pseudo_frames = num_tokens // tokens_per_frame
        return (pseudo_frames - 1) * self.temporal_patch_size + 1

    # number of tokens corresponding to a given number of frames
    def num_tokens_per_frames(self, num_frames, include_first_frame = True):
        image_num_tokens = self.image_num_tokens

        total_tokens = 0

        if include_first_frame:
            num_frames -= 1
            total_tokens += image_num_tokens

        assert (num_frames % self.temporal_patch_size) == 0

        return total_tokens + int(num_frames / self.temporal_patch_size) * image_num_tokens
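
    # a worked example of the token / frame accounting above (illustrative numbers):
    # with image_size = 256 and patch_size = 32, image_num_tokens = (256 // 32) * (256 // 32) = 64 tokens per frame
    # with temporal_patch_size = 4, num_tokens_per_frames(17) = 64 (first frame) + (16 // 4) * 64 = 320 tokens
    # and frames_per_num_tokens(320) = (320 // 64 - 1) * 4 + 1 = 17 frames, i.e. the inverse mapping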

    # make a copy of the model for evaluation
    def copy_for_eval(self):
        device = next(self.parameters()).device
        vae_copy = copy.deepcopy(self.cpu())

        if vae_copy.use_vgg_and_gan:
            del vae_copy.discr
            del vae_copy.vgg

        vae_copy.eval()
        return vae_copy.to(device)

    # override state_dict (excluding vgg)
    @remove_vgg
    def state_dict(self, *args, **kwargs):
        return super().state_dict(*args, **kwargs)

    # override load_state_dict (excluding vgg)
    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        return super().load_state_dict(*args, **kwargs)

    # load model weights from a path
    def load(self, path):
        path = Path(path)
        assert path.exists()
        pt = torch.load(str(path))
        self.load_state_dict(pt)

    # decode from codebook indices
    def decode_from_codebook_indices(self, indices):
        if self.lookup_free_quantization:
            codes = self.vq.indices_to_codes(indices)
        else:
            codes = self.vq.codebook[indices]

        return self.decode(codes)

    # height and width in patches
    @property
    def patch_height_width(self):
        return self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]

    # encode tokens
    def encode(
        self,
        tokens
    ):
        b = tokens.shape[0]
        h, w = self.patch_height_width

        video_shape = tuple(tokens.shape[:-1])

        tokens = rearrange(tokens, 'b t h w d -> (b t) (h w) d')

        attn_bias = self.spatial_rel_pos_bias(h, w, device = tokens.device)

        tokens = self.enc_spatial_transformer(tokens, attn_bias = attn_bias, video_shape = video_shape)

        tokens = rearrange(tokens, '(b t) (h w) d -> b t h w d', b = b, h = h , w = w)

        # encode - temporal

        tokens = rearrange(tokens, 'b t h w d -> (b h w) t d')

        tokens = self.enc_temporal_transformer(tokens, video_shape = video_shape)

        tokens = rearrange(tokens, '(b h w) t d -> b t h w d', b = b, h = h, w = w)

        return tokens

    # decode tokens back into video
    def decode(
        self,
        tokens
        ):
        # batch size of the tokens
        b = tokens.shape[0]
        # height and width in patches
        h, w = self.patch_height_width

        # if the tokens are flattened, restore the (t h w) layout
        if tokens.ndim == 3:
            tokens = rearrange(tokens, 'b (t h w) d -> b t h w d', h = h, w = w)

        # video shape tuple
        video_shape = tuple(tokens.shape[:-1])

        # decode - temporal

        # fold the spatial positions into the batch for temporal attention
        tokens = rearrange(tokens, 'b t h w d -> (b h w) t d')

        # temporal decoding transformer
        tokens = self.dec_temporal_transformer(tokens, video_shape = video_shape)

        # restore the (b t h w d) layout
        tokens = rearrange(tokens, '(b h w) t d -> b t h w d', b = b, h = h, w = w)

        # decode - spatial

        # fold the time steps into the batch for spatial attention
        tokens = rearrange(tokens, 'b t h w d -> (b t) (h w) d')

        # spatial relative position bias
        attn_bias = self.spatial_rel_pos_bias(h, w, device = tokens.device)

        # spatial decoding transformer
        tokens = self.dec_spatial_transformer(tokens, attn_bias = attn_bias, video_shape = video_shape)

        # restore the (b t h w d) layout
        tokens = rearrange(tokens, '(b t) (h w) d -> b t h w d', b = b, h = h , w = w)

        # to pixels

        # split off the first frame token from the rest
        first_frame_token, rest_frames_tokens = tokens[:, :1], tokens[:, 1:]

        # project the first frame to pixels
        first_frame = self.to_pixels_first_frame(first_frame_token)

        # project the remaining frames to pixels
        rest_frames = self.to_pixels(rest_frames_tokens)

        # concatenate into the reconstructed video
        recon_video = torch.cat((first_frame, rest_frames), dim = 2)

        # return the reconstructed video
        return recon_video

    def forward(
        self,
        video,
        mask = None,
        return_recons = False,
        return_recons_only = False,
        return_discr_loss = False,
        apply_grad_penalty = True,
        return_only_codebook_ids = False

.\lucidrains\phenaki-pytorch\phenaki_pytorch\cvivit_trainer.py

# import sqrt from the math module
from math import sqrt
# import choice from the random module
from random import choice
# import Path from pathlib
from pathlib import Path
# import rmtree from shutil
from shutil import rmtree

# import the beartype decorator
from beartype import beartype

# import torch
import torch
# import nn from torch
from torch import nn
# import Dataset, DataLoader and random_split from torch.utils.data
from torch.utils.data import Dataset, DataLoader, random_split

# import torchvision transforms under the alias T
import torchvision.transforms as T
# import ImageFolder from torchvision.datasets
from torchvision.datasets import ImageFolder
# import make_grid and save_image from torchvision.utils
from torchvision.utils import make_grid, save_image

# import rearrange from einops
from einops import rearrange

# import get_optimizer from phenaki_pytorch.optimizer
from phenaki_pytorch.optimizer import get_optimizer

# import EMA from ema_pytorch
from ema_pytorch import EMA

# import CViViT from phenaki_pytorch.cvivit
from phenaki_pytorch.cvivit import CViViT
# import ImageDataset, VideoDataset and video_tensor_to_gif from phenaki_pytorch.data
from phenaki_pytorch.data import ImageDataset, VideoDataset, video_tensor_to_gif

# import Accelerator and DistributedType from accelerate (the import statement was missing from this excerpt; both names are used below)
from accelerate import Accelerator, DistributedType

# helpers

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# no-op function
def noop(*args, **kwargs):
    pass

# cycle through a dataloader forever
def cycle(dl):
    while True:
        for data in dl:
            yield data

# cast a value to a tuple
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# ask the user a yes / no question
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# accumulate values into a log dictionary
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# main trainer class

# CViViTTrainer class, type-checked with the beartype decorator
@beartype
class CViViTTrainer(nn.Module):
    # initializer
    def __init__(
        self,
        vae: CViViT,
        *,
        num_train_steps,
        batch_size,
        folder,
        train_on_images = False,
        num_frames = 17,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        discr_max_grad_norm = None,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        use_ema = True,
        ema_beta = 0.995,
        ema_update_after_step = 0,
        ema_update_every = 1,
        apply_grad_penalty_every = 4,
        accelerate_kwargs: dict = dict()
    ):
        # call the parent constructor
        super().__init__()
        # image size of the VAE
        image_size = vae.image_size

        # instantiate the accelerator
        self.accelerator = Accelerator(**accelerate_kwargs)

        # the VAE being trained
        self.vae = vae

        # whether to keep an exponential moving average of the VAE
        self.use_ema = use_ema
        # on the main process, if EMA is enabled
        if self.is_main and use_ema:
            # instantiate the EMA copy of the VAE
            self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)

        # register a 'steps' buffer to track the number of training steps
        self.register_buffer('steps', torch.Tensor([0]))

        # number of training steps, batch size and gradient accumulation steps
        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # split the parameters into discriminator parameters and VAE parameters
        all_parameters = set(vae.parameters())
        discr_parameters = set(vae.discr.parameters())
        vae_parameters = all_parameters - discr_parameters

        self.vae_parameters = vae_parameters

        # optimizers for the VAE and the discriminator
        self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
        self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)

        # gradient clipping thresholds
        self.max_grad_norm = max_grad_norm
        self.discr_max_grad_norm = discr_max_grad_norm

        # create the dataset
        dataset_klass = ImageDataset if train_on_images else VideoDataset
        if train_on_images:
            self.ds = ImageDataset(folder, image_size)
        else:
            self.ds = VideoDataset(folder, image_size, num_frames = num_frames)

        # split off a validation set
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # create the dataloaders
        self.dl = DataLoader(
            self.ds,
            batch_size = batch_size,
            shuffle = True
        )

        self.valid_dl = DataLoader(
            self.valid_ds,
            batch_size = batch_size,
            shuffle = True
        )

        # let the accelerator prepare the model, optimizers and dataloader
        (
            self.vae,
            self.optim,
            self.discr_optim,
            self.dl
        ) = self.accelerator.prepare(
            self.vae,
            self.optim,
            self.discr_optim,
            self.dl
        )

        # infinite iterators over the dataloaders
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        # how often to save the model and the results
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        # how often to apply the gradient penalty
        self.apply_grad_penalty_every = apply_grad_penalty_every

        # results folder
        self.results_folder = Path(results_folder)

        # if the results folder is not empty, ask whether to clear previous checkpoints and results
        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        # create the results folder
        self.results_folder.mkdir(parents = True, exist_ok = True)

    # save the model
    def save(self, path):
        if not self.accelerator.is_local_main_process:
            return

        pkg = dict(
            model = self.accelerator.get_state_dict(self.vae),
            optim = self.optim.state_dict(),
            discr_optim = self.discr_optim.state_dict()
        )
        torch.save(pkg, path)

    # load the model
    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(path)

        vae = self.accelerator.unwrap_model(self.vae)
        vae.load_state_dict(pkg['model'])

        self.optim.load_state_dict(pkg['optim'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])

    # print through the accelerator
    def print(self, msg):
        self.accelerator.print(msg)

    # device being used
    @property
    def device(self):
        return self.accelerator.device

    # whether training is distributed
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # whether the current process is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # whether the current process is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # training loop, taking an optional logging function (defaults to a no-op)
    def train(self, log_fn = noop):
        # device of the VAE parameters
        device = next(self.vae.parameters()).device

        # run training steps until the requested number of steps is reached
        while self.steps < self.num_train_steps:
            # run a single training step and get back the logs
            logs = self.train_step()
            # pass the logs to the logging function
            log_fn(logs)

        # print that training is complete
        self.print('training complete')
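
A minimal usage sketch based only on the constructors shown above (hyperparameters and the folder path are hypothetical; the trainer's train_step method is defined elsewhere in the file):

from phenaki_pytorch.cvivit import CViViT
from phenaki_pytorch.cvivit_trainer import CViViTTrainer

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4
)

trainer = CViViTTrainer(
    cvivit,
    folder = './path/to/videos',
    batch_size = 4,
    num_train_steps = 10000
)

trainer.train()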

.\lucidrains\phenaki-pytorch\phenaki_pytorch\data.py

# 导入所需的库
from pathlib import Path
import cv2
from PIL import Image
from functools import partial
from typing import Tuple, List
from beartype.door import is_bearable
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader as PytorchDataLoader
from torchvision import transforms as T, utils
from einops import rearrange

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 返回输入值
def identity(t, *args, **kwargs):
    return t

# 将输入值转换为元组
def pair(val):
    return val if isinstance(val, tuple) else (val, val)

# 调整帧数
def cast_num_frames(t, *, frames):
    f = t.shape[1]
    if f == frames:
        return t
    if f > frames:
        return t[:, :frames]
    return F.pad(t, (0, 0, 0, 0, 0, frames - f))

# 将图像转换为指定格式
def convert_image_to_fn(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# 图像相关的辅助函数和数据集

# 图像数据集类
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# 处理读取和写入 GIF

# 通道数对应的图像模式
CHANNELS_TO_MODE = {
    1 : 'L',
    3 : 'RGB',
    4 : 'RGBA'
}

# 读取 GIF 中的所有图像
def seek_all_images(img, channels = 3):
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    mode = CHANNELS_TO_MODE[channels]

    i = 0
    while True:
        try:
            img.seek(i)
            yield img.convert(mode)
        except EOFError:
            break
        i += 1

# 将视频张量转换为 GIF
def video_tensor_to_gif(
    tensor,
    path,
    duration = 120,
    loop = 0,
    optimize = True
):
    images = map(T.ToPILImage(), tensor.unbind(dim = 1))
    first_img, *rest_imgs = images
    first_img.save(path, save_all = True, append_images = rest_imgs, duration = duration, loop = loop, optimize = optimize)
    return images

# GIF 转换为张量
def gif_to_tensor(
    path,
    channels = 3,
    transform = T.ToTensor()
):
    img = Image.open(path)
    tensors = tuple(map(transform, seek_all_images(img, channels = channels)))
    return torch.stack(tensors, dim = 1)

# 处理读取和写入 MP4

# 将视频转换为张量
def video_to_tensor(
    path: str,              # 要导入的视频路径
    num_frames = -1,        # 要存储在输出张量中的帧数
    crop_size = None
) -> torch.Tensor:          # 形状为 (通道数, 帧数, 高度, 宽度)

    video = cv2.VideoCapture(path)

    frames = []
    check = True

    while check:
        check, frame = video.read()

        if not check:
            continue

        if exists(crop_size):
            frame = crop_center(frame, *pair(crop_size))

        frames.append(rearrange(frame, '... -> 1 ...'))

    frames = np.array(np.concatenate(frames[:-1], axis = 0))  # 将帧列表转换为 numpy 数组
    frames = rearrange(frames, 'f h w c -> c f h w')

    frames_torch = torch.tensor(frames).float()

    return frames_torch[:, :num_frames, :, :]

# 将张量转换为视频
def tensor_to_video(
    tensor,                # Pytorch 视频张量
    path: str,             # 要保存的视频路径
    fps = 25,              # 保存视频的帧率
    # 定义视频格式为 MP4V
    video_format = 'MP4V'
):
    # 将张量移回 CPU
    tensor = tensor.cpu()

    # 获取张量的帧数、高度和宽度
    num_frames, height, width = tensor.shape[-3:]

    # 使用指定的视频格式创建 VideoWriter 对象
    fourcc = cv2.VideoWriter_fourcc(*video_format) # Changes in this line can allow for different video formats.
    video = cv2.VideoWriter(path, fourcc, fps, (width, height))

    frames = []

    # 遍历每一帧,将张量转换为 numpy 数组并写入视频
    for idx in range(num_frames):
        numpy_frame = tensor[:, idx, :, :].numpy()
        numpy_frame = np.uint8(rearrange(numpy_frame, 'c h w -> h w c'))
        video.write(numpy_frame)

    # 释放视频对象
    video.release()

    # 关闭所有 OpenCV 窗口
    cv2.destroyAllWindows()

    # 返回视频对象
    return video

# 将图像中心裁剪为指定大小
def crop_center(
    img,        # tensor
    cropx,      # Length of the final image in the x direction.
    cropy       # Length of the final image in the y direction.
) -> torch.Tensor:
    y, x, c = img.shape
    startx = x // 2 - cropx // 2
    starty = y // 2 - cropy // 2
    return img[starty:(starty + cropy), startx:(startx + cropx), :]

# 视频数据集类
class VideoDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        channels = 3,
        num_frames = 17,
        horizontal_flip = False,
        force_num_frames = True,
        exts = ['gif', 'mp4']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.channels = channels
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # 定义数据转换流程
        self.transform = T.Compose([
            T.Resize(image_size),
            T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

        # 定义将视频路径转换为张量的函数
        self.gif_to_tensor = partial(gif_to_tensor, channels = self.channels, transform = self.transform)
        self.mp4_to_tensor = partial(video_to_tensor, crop_size = self.image_size)

        # 定义将帧数转换为指定数量的函数
        self.cast_num_frames_fn = partial(cast_num_frames, frames = num_frames) if force_num_frames else identity

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        ext = path.suffix

        # 根据文件扩展名选择相应的处理方式
        if ext == '.gif':
            tensor = self.gif_to_tensor(path)
        elif ext == '.mp4':
            tensor = self.mp4_to_tensor(str(path))
        else:
            raise ValueError(f'unknown extension {ext}')

        # 转换帧数并返回张量
        return self.cast_num_frames_fn(tensor)

# 重写数据加载器以能够整理字符串
def collate_tensors_and_strings(data):
    if is_bearable(data, List[torch.Tensor]):
        return (torch.stack(data, dim = 0),)

    data = zip(*data)
    output = []

    for datum in data:
        if is_bearable(datum, Tuple[torch.Tensor, ...]):
            datum = torch.stack(datum, dim = 0)
        elif is_bearable(datum, Tuple[str, ...]):
            datum = list(datum)
        else:
            raise ValueError('detected invalid type being passed from dataset')

        output.append(datum)

    return tuple(output)

# 创建数据加载器
def DataLoader(*args, **kwargs):
    return PytorchDataLoader(*args, collate_fn = collate_tensors_and_strings, **kwargs)
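
下面给出一个示意性的小例子(非原始仓库代码,仅用于说明上面自定义 DataLoader 所用的整理函数):collate_tensors_and_strings 会把张量堆叠成批次,把字符串保留为列表,因此数据集返回的 (视频张量, 文本描述) 元组可以直接放入同一个 batch。

# 示意性示例:假设每个样本返回 (视频张量, 文本描述) 元组
samples = [
    (torch.randn(3, 4, 8, 8), 'first caption'),
    (torch.randn(3, 4, 8, 8), 'second caption'),
]

videos, captions = collate_tensors_and_strings(samples)
print(videos.shape)  # torch.Size([2, 3, 4, 8, 8]),张量被堆叠成一个批次
print(captions)      # ['first caption', 'second caption'],字符串保留为列表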

.\lucidrains\phenaki-pytorch\phenaki_pytorch\optimizer.py

# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam

# 将参数分为需要权重衰减和不需要权重衰减的两个列表
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # 根据参数的维度判断是否需要权重衰减
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

# 获取优化器
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    **kwargs
):
    # 根据是否需要梯度过滤参数
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # 如果权重衰减为0,则使用 Adam 优化器
    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    # 如果需要对参数进行分组权重衰减
    if group_wd_params:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        # 将参数分为需要权重衰减和不需要权重衰减的两组
        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    # 使用 AdamW 优化器,设置学习率、权重衰减、动量参数和 epsilon
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
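
下面补充一个示意性的调用示例(非原始仓库代码):偏置、LayerNorm 等一维参数会被分到不做权重衰减的组,二维以上的权重则正常衰减。

# 示意性示例:对一个小模型调用 get_optimizer
from torch import nn

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))

# wd > 0 且 group_wd_params = True 时,返回带两个参数组的 AdamW
optimizer = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)
print(type(optimizer).__name__)  # AdamW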

.\lucidrains\phenaki-pytorch\phenaki_pytorch\phenaki_pytorch.py

# 导入数学库
import math
# 导入 functools 库
import functools
# 从 contextlib 库中导入 nullcontext
from contextlib import nullcontext
# 从 functools 库中导入 partial 和 wraps
from functools import partial, wraps

# 从 typing 模块中导入 Optional, List, Union
from typing import Optional, List, Union
# 从 beartype 库中导入 beartype
from beartype import beartype

# 导入 torch 库
import torch
# 从 torch.nn.functional 中导入 F
import torch.nn.functional as F
# 从 torch 中导入 nn, einsum
from torch import nn, einsum

# 从 einops 库中导入 rearrange, repeat, pack, unpack
from einops import rearrange, repeat, pack, unpack
# 从 einops.layers.torch 中导入 Rearrange
from einops.layers.torch import Rearrange

# 从 phenaki_pytorch.t5 中导入 t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
from phenaki_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# 从 phenaki_pytorch.cvivit 中导入 CViViT
from phenaki_pytorch.cvivit import CViViT
# 从 phenaki_pytorch.attention 中导入 Attention, Transformer, ContinuousPositionBias
from phenaki_pytorch.attention import Attention, Transformer, ContinuousPositionBias

# helpers

# 定义函数 exists,判断值是否存在
def exists(val):
    return val is not None

# 定义函数 default,返回值或默认值
def default(val, d):
    return val if exists(val) else d

# 定义函数 cast_tuple,将值转换为元组
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else (val,) * length

# 定义函数 reduce_mult,对数组中的元素进行累乘
def reduce_mult(arr):
    return functools.reduce(lambda x, y: x * y, arr)

# 定义函数 divisible_by,判断两个数是否整除
def divisible_by(numer, denom):
    return (numer % denom) == 0

# tensor helpers

# 定义函数 get_mask_subset_with_prob,根据概率获取掩码子集
def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device

    num_tokens = mask.sum(dim = -1)
    num_pads = seq_len - num_tokens
    num_masked = (prob * num_tokens).round().clamp(min = 1)

    randperm_indices = torch.rand((batch, seq_len), device = device).argsort(dim = -1)
    randperm_indices -= rearrange(num_pads, 'b -> b 1')
    randperm_indices.masked_fill_(randperm_indices < 0, seq_len) # set to max out of bounds, so never chosen

    mask_subset = randperm_indices < rearrange(num_masked, 'b -> b 1')
    return mask_subset
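
下面是一个示意性的小例子(非原始仓库代码),说明 get_mask_subset_with_prob 如何按比例随机选出需要被遮蔽的位置:

# 示意性示例:batch 为 1,序列长度 6,其中 4 个有效 token、2 个 padding
mask = torch.tensor([[True, True, True, True, False, False]])
subset = get_mask_subset_with_prob(mask, prob = 0.5)
print(subset.shape)         # torch.Size([1, 6]),与输入 mask 同形状的布尔张量
print(subset.sum().item())  # 2,即 round(0.5 * 4) 个位置被随机选中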

# decorators

# 定义装饰器 eval_decorator,用于在评估模型时切换模型状态
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# classifier free guidance functions

# 定义函数 uniform,生成指定形状的均匀分布张量
def uniform(shape, device):
    return torch.zeros(shape, device = device).float().uniform_(0, 1)

# 定义函数 prob_mask_like,生成概率掩码张量
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

# tensor helper functions

# 定义函数 log,计算张量的对数
def log(t, eps = 1e-10):
    return torch.log(t + eps)

# sampling helpers

# 定义函数 gumbel_noise,生成古贝尔噪声
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# 定义函数 gumbel_sample,使用古贝尔噪声进行采样
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

# 定义函数 top_k,根据阈值获取前 k 个概率最大的位置
def top_k(logits, thres = 0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
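
下面用一个示意性的小例子(非原始仓库代码)把 top_k 与 gumbel_sample 串联起来:先把 logits 过滤到前若干名,再加 Gumbel 噪声按温度采样出 token id。

# 示意性示例:对随机 logits 做 top-k 过滤后采样
logits = torch.randn(2, 10)             # (batch, num_tokens)
filtered = top_k(logits, thres = 0.9)   # 只保留前 10% 的 logits,其余置为 -inf
token_ids = gumbel_sample(filtered, temperature = 0.8)
print(token_ids.shape)                  # torch.Size([2]),每个样本采样出一个 token id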

# mask git

# 定义 MaskGit 类
class MaskGit(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        max_seq_len,
        gradient_shrink_alpha = 0.1,
        heads = 8,
        dim_head = 64,
        unconditional = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        **kwargs
    # 初始化函数,设置模型的维度、mask_id、是否无条件生成等参数
    ):
        super().__init__()
        self.dim = dim

        self.mask_id = num_tokens
        self.unconditional = unconditional

        # 创建 token embedding 层,num_tokens + 1 个 token,最后一个用作 mask_id
        self.token_emb = nn.Embedding(num_tokens + 1, dim)

        self.max_seq_len = max_seq_len
        # 创建位置编码 embedding 层
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # 设置梯度缩放参数
        self.gradient_shrink_alpha = gradient_shrink_alpha

        # 创建连续位置偏置
        self.continuous_pos_bias = ContinuousPositionBias(dim = dim_head, heads = heads, num_dims = 3)

        # 创建 Transformer 模型
        self.transformer = Transformer(
            dim = dim,
            attn_num_null_kv = 2,
            has_cross_attn = not self.unconditional,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            peg = True,
            **kwargs
        )

        # 创建输出层,将 dim 维度映射到 num_tokens
        self.to_logits = nn.Linear(dim, num_tokens)

    # 带条件缩放的前向传播函数
    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        # 调用前向传播函数,cond_drop_prob 为 0
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1:
            return logits

        # 调用前向传播函数,cond_drop_prob 为 1
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    # 前向传播函数
    def forward(
        self,
        x,
        cond_drop_prob = 0.,
        text_mask = None,
        video_mask = None,
        video_patch_shape = None,
        return_embeds = False,
        **kwargs
    ):
        assert x.ndim in {2, 4}, 'video token ids must be of shape (batch, seq) or (batch, frame, height, width)'

        if x.ndim == 4:
            video_patch_shape = x.shape[1:]
            x = rearrange(x, 'b ... -> b (...)')

        b, n, device = *x.shape, x.device

        # 如果 text_mask 不存在,则创建全为 True 的 mask
        if not exists(text_mask):
            text_mask = torch.ones((b, n), device = device, dtype = torch.bool)

        assert exists(video_patch_shape), 'video patch shape must be given'

        # 计算相对位置偏置
        rel_pos_bias = self.continuous_pos_bias(*video_patch_shape, device = device)

        # 如果 cond_drop_prob 大于 0,则生成保留 mask
        if cond_drop_prob > 0:
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

        video_shape = (b, *video_patch_shape)

        # 对输入进行 token embedding
        x = self.token_emb(x)

        # 断言视频 token 序列长度不超过 max_seq_len
        assert n <= self.max_seq_len, f'the video token sequence length you are passing in ({n}) is greater than the `max_seq_len` ({self.max_seq_len}) set on your `MaskGit`'
        x = self.pos_emb(torch.arange(n, device = device)) + x

        # 梯度缩放
        x = x * self.gradient_shrink_alpha + x.detach() * (1 - self.gradient_shrink_alpha)

        # Transformer 模型的前向传播
        x = self.transformer(
            x,
            video_shape = video_shape,
            attn_bias = rel_pos_bias,
            self_attn_mask = video_mask,
            cross_attn_context_mask = text_mask,
            **kwargs
        )

        # 如果需要返回嵌入向量,则直接返回
        if return_embeds:
            return x

        return self.to_logits(x)
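
上面的 forward_with_cond_scale 实现的是 classifier-free guidance:用有条件 logits 与无条件 logits 的差值乘以 cond_scale 再加回无条件 logits,从而放大文本条件的影响。下面用一个示意性的数值例子(非原始仓库代码)演示这条公式:

# 示意性示例:guided = null_logits + (cond_logits - null_logits) * cond_scale
cond_logits = torch.tensor([2.0, 0.5, -1.0])   # 条件前向 (cond_drop_prob = 0) 的输出
null_logits = torch.tensor([1.0, 1.0, 1.0])    # 无条件前向 (cond_drop_prob = 1) 的输出
cond_scale = 3.

guided = null_logits + (cond_logits - null_logits) * cond_scale
print(guided)  # tensor([ 4.0000, -0.5000, -5.0000]),条件方向被放大 3 倍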
# 定义 TokenCritic 类,继承自 nn.Module
class TokenCritic(nn.Module):
    def __init__(
        self,
        *,
        dim,  # 维度
        num_tokens,  # token 数量
        max_seq_len,  # 最大序列长度
        has_cross_attn = False,  # 是否有跨注意力
        attn_dropout = 0.,  # 注意力丢弃率
        ff_dropout = 0.,  # FeedForward 层丢弃率
        **kwargs
    ):
        super().__init__()
        self.has_cross_attn = has_cross_attn

        self.mask_id = num_tokens  # 定义 mask_id 为 num_tokens

        self.token_emb = nn.Embedding(num_tokens + 1, dim)  # 创建 token 的嵌入层,最后一个 token 用作 mask_id
        self.pos_emb = nn.Embedding(max_seq_len, dim)  # 创建位置嵌入层

        self.transformer = Transformer(
            dim = dim,
            peg = True,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            has_cross_attn = has_cross_attn,
            **kwargs
        )  # 创建 Transformer 模型

        self.to_logits = nn.Sequential(
            nn.Linear(dim, 1),  # 线性层
            Rearrange('... 1 -> ...')  # 重排维度
        )  # 创建输出 logits 的序列

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,  # 条件缩放
        **kwargs
    ):
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)  # 调用 forward 方法获取 logits

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)  # 调用 forward 方法获取 null_logits
        return null_logits + (logits - null_logits) * cond_scale  # 返回根据条件缩放计算后的结果

    def forward(
        self,
        x,
        text_mask = None,
        cond_drop_prob = None,
        context = None,
        video_mask = None,
        video_patch_shape = None,
        **kwargs
    ):
        if exists(video_patch_shape):
            video_shape = (x.shape[0], *video_patch_shape)
        else:
            video_shape = x.shape

        x = rearrange(x, 'b ... -> b (...)')  # 重排输入数据的维度
        b, n, device = *x.shape, x.device

        if not exists(text_mask):
            text_mask = torch.ones((b, n), device = device, dtype = torch.bool)  # 如果不存在文本 mask,则创建全为 True 的 mask

        if exists(context) and cond_drop_prob > 0:
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)  # 根据条件概率创建 mask
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask  # 更新文本 mask

        x = self.token_emb(x)  # 对输入数据进行 token 嵌入
        x = self.pos_emb(torch.arange(n, device = device)) + x  # 添加位置嵌入

        x = self.transformer(
            x,
            video_shape = video_shape,
            context = context,
            self_attn_mask = video_mask,
            cross_attn_context_mask = text_mask,
            **kwargs
        )  # 调用 Transformer 模型进行计算

        return self.to_logits(x)  # 返回 logits

# 定义 SelfCritic 类,继承自 nn.Module,受 Nijkamp 等人启发
@beartype
class SelfCritic(nn.Module):
    def __init__(
        self,
        maskgit: MaskGit  # 接收 MaskGit 类型参数
    ):
        super().__init__()
        self.maskgit = maskgit

        self.to_pred = nn.Sequential(
            nn.Linear(maskgit.dim, 1),  # 线性层
            Rearrange('... 1 -> ...')  # 重排维度
        )  # 创建输出预测的序列

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,  # 条件缩放
        **kwargs
    ):
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)  # 调用 forward 方法获取 logits

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)  # 调用 forward 方法获取 null_logits
        return null_logits + (logits - null_logits) * cond_scale  # 返回根据条件缩放计算后的结果

    def forward(self, x, *args, **kwargs):
        embeds = self.maskgit(x, *args, return_embeds = True, **kwargs)  # 调用 maskgit 方法获取嵌入
        return self.to_pred(embeds)  # 返回预测结果

# 定义 Phenaki 类,继承自 nn.Module
@beartype
class Phenaki(nn.Module):
    def __init__(
        self,
        *,
        maskgit: MaskGit,  # MaskGit 类型参数
        cvivit: CViViT,  # CViViT 类型参数
        critic: Optional[Union[TokenCritic, SelfCritic]] = None,  # 可选的 TokenCritic 或 SelfCritic 类型参数
        steps = 18,  # 步数
        t5_name = DEFAULT_T5_NAME,  # T5 模型名称
        sample_temperature = 0.,  # 采样温度
        text_embed_dim = None,  # 文本嵌入维度
        cond_drop_prob = 0.25,  # 条件丢弃概率
        max_text_len = 128,  # 最大文本长度
        self_token_critic = False,  # 是否使用自身 TokenCritic
        critic_loss_weight = 1.,  # TokenCritic 权重
        critic_noise_anneal_schedule = 'decay',  # TokenCritic 噪声退火计划
        critic_train_sample_temperature = 1.  # TokenCritic 训练采样温度
    # 初始化函数,继承父类的初始化方法
    ):
        super().__init__()

        # 复制cvivit用于评估
        self.cvivit = cvivit.copy_for_eval()

        # 设置maskgit属性
        self.maskgit = maskgit
        self.unconditional = maskgit.unconditional

        # 设置mask_id属性
        self.mask_id = maskgit.mask_id

        # 断言条件,确保self_token_critic和critic不存在,或者critic存在
        assert not (self_token_critic and exists(critic))

        # 如果self_token_critic为真,则创建SelfCritic对象
        if self_token_critic:
            critic = SelfCritic(maskgit)

        # 如果critic存在,则将其设置为评估模式
        if exists(critic):
            critic = critic.eval()

        # 断言条件,确保critic不存在或者self_token_critic为真,或者maskgit.unconditional为假且critic具有交叉注意力
        assert not exists(critic) or self_token_critic or (not maskgit.unconditional) == critic.has_cross_attn

        # 设置critic相关属性
        self.critic = critic
        self.critic_noise_anneal_schedule = critic_noise_anneal_schedule
        self.critic_loss_weight = critic_loss_weight
        self.critic_train_sample_temperature = critic_train_sample_temperature

        # 设置步数和采样温度
        self.steps = steps
        self.sample_temperature = sample_temperature

        # 文本条件
        text_embed_dim = default(text_embed_dim, get_encoded_dim(t5_name))
        self.encode_texts = partial(t5_encode_text, name = t5_name)
        self.text_embed_dim = text_embed_dim
        self.max_text_len = max_text_len

        # 断言条件,确保cond_drop_prob大于0
        assert cond_drop_prob > 0.
        # 设置cond_drop_prob属性,用于transformers的分类器自由引导
        self.cond_drop_prob = cond_drop_prob # classifier free guidance for transformers - @crowsonkb

    # 采样图像函数
    def sample_images(
        self,
        *,
        texts: Union[List[str], str] = None,
        batch_size = 1,
        cond_scale = 3.,
        starting_temperature = 0.9,
        noise_K = 1.
    ):
        # 生成单帧视频
        single_framed_video = self.sample(
            texts = texts,
            num_frames = 1,
            cond_scale = cond_scale,
            starting_temperature = starting_temperature,
            noise_K = noise_K
        )

        # 去掉单帧的时间维度,返回图像张量
        return rearrange(single_framed_video, '... c 1 h w -> ... c h w')

    # 采样函数
    @eval_decorator
    @torch.no_grad()
    def sample(
        self,
        *,
        num_frames,
        texts: Union[List[str], str] = None,
        prime_frames = None,
        batch_size = 1,
        cond_scale = 3.,
        starting_temperature = 0.9,
        noise_K = 1. # 用于token-critic论文第3.2节中critic分数的噪声超参数,需要找到正确的值
    def forward(
        self,
        videos = None,
        *,
        texts: Optional[List[str]] = None,
        video_codebook_ids = None,
        video_frame_mask = None,
        text_embeds = None,
        cond_drop_prob = None,
        only_train_generator = False,
        only_train_critic = False
# 定义一个名为 make_video 的函数,用于生成视频

@beartype
# 使用 beartype 装饰器对函数参数进行类型检查
def make_video(
    phenaki: Phenaki,  # 接受 Phenaki 对象作为参数
    texts: List[str],  # 接受一个字符串列表作为参数
    num_frames,  # 接受一个整数作为参数,表示帧数
    prime_lengths  # 接受一个整数或整数元组作为参数,表示前置长度
):
    num_scenes = len(texts)  # 获取文本列表的长度,即场景数
    num_frames = cast_tuple(num_frames, num_scenes)  # 将 num_frames 转换为元组,长度与场景数相同

    prime_lengths = cast_tuple(prime_lengths, num_scenes - 1)  # 将 prime_lengths 转换为元组,长度为场景数减一
    prime_lengths = (*prime_lengths, 0)  # 在 prime_lengths 元组末尾添加一个 0,表示最后一个场景无需前置长度

    entire_video = []  # 初始化整个视频列表
    video_prime = None  # 初始化视频前置
    scenes = []  # 初始化场景列表

    # 遍历文本、帧数、前置长度三个参数的元素,生成视频
    for text, scene_num_frames, next_scene_prime_length in zip(texts, num_frames, prime_lengths):
        # 从 Phenaki 对象中生成视频,传入文本、视频前置、场景帧数
        video = phenaki.sample(texts=text, prime_frames=video_prime, num_frames=scene_num_frames)
        scenes.append(video)  # 将生成的视频添加到场景列表中

        video_prime = video[:, :, -next_scene_prime_length:]  # 更新视频前置为当前视频的最后 next_scene_prime_length 帧

    # 将所有场景的视频拼接在一起,沿着第二维度拼接,返回拼接后的视频和场景列表
    return torch.cat(scenes, dim=2), scenes

.\lucidrains\phenaki-pytorch\phenaki_pytorch\phenaki_trainer.py

# 导入数学库
import math
# 导入复制库
import copy
# 导入路径库
from pathlib import Path
# 导入随机库
from random import random, choices
# 导入偏函数库
from functools import partial
# 导入命名元组库
from collections import namedtuple
# 导入 CPU 核心数库
from multiprocessing import cpu_count

# 导入 beartype 库
from beartype import beartype
# 导入 beartype.door 库
from beartype.door import is_bearable
# 导入 beartype.vale 库
from beartype.vale import Is
# 导入类型提示库
from typing import Optional, List, Iterable, Tuple
# 导入类型扩展库
from typing_extensions import Annotated

# 导入 PyTorch 库
import torch
# 从 PyTorch 中导入神经网络库和张量乘法库
from torch import nn, einsum
# 从 PyTorch 中导入函数库
import torch.nn.functional as F
# 从 PyTorch 中导入数据集库
from torch.utils.data import Dataset
# 从 PyTorch 中导入优化器库
from torch.optim import Adam

# 从 torchvision 中导入变换库
from torchvision import transforms as T
# 从 torchvision 中导入图像处理库
from torchvision.utils import make_grid, save_image

# 从 einops 中导入重排库和减少库
from einops import rearrange, reduce
# 从 einops.layers.torch 中导入重排层
from einops.layers.torch import Rearrange

# 从 PIL 中导入图像库
from PIL import Image
# 从 tqdm.auto 中导入进度条库
from tqdm.auto import tqdm

# 从 phenaki_pytorch.optimizer 中导入获取优化器函数
from phenaki_pytorch.optimizer import get_optimizer
# 从 accelerate 中导入加速器库和分布式类型
from accelerate import Accelerator, DistributedType

# 从 phenaki_pytorch.phenaki_pytorch 中导入 Phenaki 类
from phenaki_pytorch.phenaki_pytorch import Phenaki

# 从 phenaki_pytorch.data 中导入图像数据集、视频数据集、视频张量转 GIF、数据加载器
from phenaki_pytorch.data import ImageDataset, VideoDataset, video_tensor_to_gif, DataLoader

# 常量

# 数据集字段类型配置
DATASET_FIELD_TYPE_CONFIG = dict(
    videos = Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.float and t.ndim in {4, 5}]
    ],
    texts = List[str],
    video_codebook_ids = Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.long]
    ],
    video_frame_mask = Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.bool]
    ],
    text_embeds = Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.float and t.ndim == 3]
    ],
)

# 辅助函数

# 检查变量是否存在
def exists(x):
    return x is not None

# 返回默认值
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# 返回输入值
def identity(t, *args, **kwargs):
    return t

# 无限循环生成数据
def cycle(dl):
    while True:
        for data in dl:
            yield data

# 检查整数是否有平方根
def has_int_squareroot(num):
    return (math.sqrt(num) ** 2) == num

# 将数字分组
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

# 将元素转移到指定设备
def elements_to_device_if_tensor(arr, device):
    output = []
    for el in arr:
        if isinstance(el, torch.Tensor):
            el = el.to(device)
        output.append(el)
    return output

# 分割可迭代对象
def split_iterable(it, split_size):
    accum = []
    for ind in range(math.ceil(len(it) / split_size)):
        start_index = ind * split_size
        accum.append(it[start_index: (start_index + split_size)])
    return accum

# 分割数据
def split(t, split_size = None):
    if not exists(split_size):
        return t

    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim = 0)

    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    # 既不是张量也不是可迭代对象时抛出类型错误
    raise TypeError('split 的输入必须是张量或可迭代对象')

# 查找第一个符合条件的元素
def find_first(cond, arr):
    for el in arr:
        if cond(el):
            return el
    return None

# 分割参数和关键字参数
def split_args_and_kwargs(*args, batch_size = None, split_size = None, **kwargs):
    all_args = (*args, *kwargs.values())
    len_all_args = len(all_args)

    if not exists(batch_size):
        first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
        assert exists(first_tensor)
        batch_size = len(first_tensor)

    split_size = default(split_size, batch_size)
    num_chunks = math.ceil(batch_size / split_size)

    dict_len = len(kwargs)
    dict_keys = kwargs.keys()
    split_kwargs_index = len_all_args - dict_len

    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
    chunk_sizes = tuple(map(len, split_all_args[0]))
    # 遍历元组中的每个元素,元素包含一个 chunk_size 和对应的参数列表
    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
        # 将参数列表拆分为位置参数和关键字参数值
        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
        # 将关键字参数的键和值组成字典
        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
        # 计算当前 chunk 的大小占总 batch 大小的比例
        chunk_size_frac = chunk_size / batch_size
        # 生成当前 chunk 的比例和参数元组
        yield chunk_size_frac, (chunked_args, chunked_kwargs)
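
下面给出一个示意性的调用示例(非原始仓库代码),演示 split_args_and_kwargs 如何把一个 batch 的关键字参数按 split_size 切块,并给出每块占整个 batch 的比例:

# 示意性示例:5 个样本按 split_size = 2 切成 3 块
texts = ['a', 'b', 'c', 'd', 'e']
videos = torch.randn(5, 3, 4, 8, 8)

for frac, (args, kwargs) in split_args_and_kwargs(texts = texts, videos = videos, split_size = 2):
    print(frac, kwargs['videos'].shape[0], kwargs['texts'])
# 0.4 2 ['a', 'b']
# 0.4 2 ['c', 'd']
# 0.2 1 ['e']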
# 简单的文本转换函数,将特定字符替换为指定字符,去除空格和特殊字符,并截取指定长度
def simple_slugify(text, max_length = 255):
    return text.replace('-', '_').replace(',', '').replace(' ', '_').replace('|', '--').strip('-_')[:max_length]

# 检查元组中是否存在重复元素
def has_duplicates(tup):
    counts = dict()
    for el in tup:
        if el not in counts:
            counts[el] = 0
        counts[el] += 1
    return any(filter(lambda count: count > 1, counts.values()))

# 根据配置确定数据的类型
def determine_types(data, config):
    output = []
    for el in data:
        for name, data_type in config.items():
            if is_bearable(el, data_type):
                output.append(name)
                break
        else:
            raise TypeError(f'unable to determine type of {data}')

    return tuple(output)
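
下面是一个示意性的小例子(非原始仓库代码),说明 determine_types 如何依据 DATASET_FIELD_TYPE_CONFIG 推断数据集各字段的名称:

# 示意性示例:一个 (视频张量, 文本列表) 的数据元组
data = (torch.randn(2, 3, 4, 8, 8), ['caption one', 'caption two'])
print(determine_types(data, DATASET_FIELD_TYPE_CONFIG))  # ('videos', 'texts')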

# 训练器类
@beartype
class PhenakiTrainer(object):
    def __init__(
        self,
        phenaki: Phenaki,
        *,
        folder = None,
        train_on_images = False,
        batch_size = 16,
        grad_accum_every = 1,
        num_frames = 17,
        sample_num_frames = None,
        train_lr = 1e-4,
        train_num_steps = 100000,
        max_grad_norm = None,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        wd = 0,
        save_and_sample_every = 1000,
        num_samples = 25,
        results_folder = './results',
        amp = False,
        fp16 = False,
        split_batches = True,
        convert_image_to = None,
        sample_texts_file_path = None,  # path to a text file with video captions, delimited by newline
        sample_texts: Optional[List[str]] = None,
        dataset: Optional[Dataset] = None,
        dataset_fields: Optional[Tuple[str, ...]] = None
    ):
        # 调用父类的构造函数
        super().__init__()
        # 导入 phenaki 模块中的 maskgit 和 cvivit
        maskgit = phenaki.maskgit
        cvivit = phenaki.cvivit

        # 确保 cvivit 在 phenaki 中存在
        assert exists(cvivit), 'cvivit must be present on phenaki'

        # 定义加速器
        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = 'fp16' if fp16 else 'no'
        )

        # 设置加速器的本地自动混合精度
        self.accelerator.native_amp = amp

        # 设置模型为 phenaki
        self.model = phenaki

        # 确保样本数量具有整数平方根
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        # 设置是否无条件生成
        self.unconditional = maskgit.unconditional

        # 训练相关变量
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every
        self.max_grad_norm = max_grad_norm
        self.train_num_steps = train_num_steps
        self.image_size = cvivit.image_size

        # 采样相关变量
        self.num_samples = num_samples
        self.sample_texts = None

        # 如果存在采样文本文件路径,则读取文本内容
        if exists(sample_texts_file_path):
            sample_texts_file_path = Path(sample_texts_file_path)
            assert sample_texts_file_path.exists()
            captions = sample_texts_file_path.read_text().split('\n')
            self.sample_texts = list(filter(len, captions))

        # 如果直接传入了采样文本列表,则使用它
        elif exists(sample_texts):
            self.sample_texts = sample_texts

        # 如果是无条件生成或存在采样文本,则继续,否则报错
        assert maskgit.unconditional or exists(self.sample_texts), 'if maskgit is to be trained text conditioned, `sample_texts` List[str] or `sample_texts_file_path` must be given'

        # 设置保存和采样频率
        self.save_and_sample_every = save_and_sample_every

        # 数据集和数据加载器
        dataset_klass = ImageDataset if train_on_images else VideoDataset
        self.sample_num_frames = default(sample_num_frames, num_frames)
        self.train_on_images = train_on_images

        # 如果存在数据集,则使用该数据集,否则根据训练类型选择数据集
        if dataset:
            self.ds = dataset
        elif train_on_images:
            assert exists(folder)
            self.ds = ImageDataset(folder, self.image_size)
        else:
            assert exists(folder)
            self.ds = VideoDataset(folder, self.image_size, num_frames = num_frames)

        # 创建数据加载器
        dl = DataLoader(self.ds, batch_size = batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)

        # 如果存在数据集字段,则检查字段是否合法
        if exists(dataset_fields):
            assert not has_duplicates(dataset_fields), 'dataset fields must not have duplicate field names'
            valid_dataset_fields = set(DATASET_FIELD_TYPE_CONFIG.keys())
            assert len(set(dataset_fields) - valid_dataset_fields) == 0, f'dataset fields must be one of {valid_dataset_fields}'

        self.dataset_fields = dataset_fields

        # 优化器
        self.opt = get_optimizer(maskgit.parameters(), lr = train_lr, wd = wd, betas = adam_betas)

        # 步数计数器
        self.step = 0

        # 准备模型、数据加载器和优化器
        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

        # 设置结果文件夹
        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(parents = True, exist_ok = True)

    # 将数据元组转换为关键字参数
    def data_tuple_to_kwargs(self, data):
        if not exists(self.dataset_fields):
            self.dataset_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
            assert not has_duplicates(self.dataset_fields), 'dataset fields must not have duplicate field names'

        return dict(zip(self.dataset_fields, data))

    # 打印消息
    def print(self, msg):
        self.accelerator.print(msg)

    # 设备属性
    @property
    def device(self):
        return self.accelerator.device

    # 是否分布式属性
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # 是否主进程属性
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # 是否本地主进程属性
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process
    # 保存模型的当前状态
    def save(self, milestone):
        # 如果不是本地主进程,则直接返回
        if not self.accelerator.is_local_main_process:
            return

        # 构建保存的数据字典
        data = {
            'step': self.step,  # 保存当前步数
            'model': self.accelerator.get_state_dict(self.model),  # 保存模型的状态字典
            'opt': self.opt.state_dict(),  # 保存优化器的状态字典
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None  # 保存混合精度训练器的状态字典
        }

        # 将数据保存到文件中
        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    # 加载指定里程碑的模型状态
    def load(self, milestone):
        # 获取加速器和设备
        accelerator = self.accelerator
        device = accelerator.device

        # 从文件中加载数据
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device)

        # 获取模型并加载状态
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        # 加载步数和优化器状态
        self.step = data['step']
        self.opt.load_state_dict(data['opt'])

        # 如果混合精度训练器存在且数据中也存在,则加载混合精度训练器状态
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    # 训练步骤函数
    def train_step(
        self,
        only_train_generator=False,  # 是否只训练生成器
        only_train_critic=False  # 是否只训练评论家
    ):
        # 获取加速器和设备
        accelerator = self.accelerator
        device = self.device

        # 初始化总损失
        total_loss = 0.

        # 循环执行梯度累积
        for _ in range(self.grad_accum_every):
            # 从数据加载器中获取数据
            data = next(self.dl)
            # 将数据转移到指定设备
            data = elements_to_device_if_tensor(data, device)
            # 将数据转换为关键字参数
            data_kwargs = self.data_tuple_to_kwargs(data)

            # 检查是否训练图像,数据维度是否正确
            assert not (self.train_on_images and data_kwargs['videos'].ndim != 4), 'you have it set to train on images, but the dataset is not returning tensors of 4 dimensions (batch, channels, height, width)'

            # 使用混合精度进行训练
            with self.accelerator.autocast():
                # 模型前向传播计算损失
                loss = self.model(**{
                    **data_kwargs,
                    'only_train_generator': only_train_generator,
                    'only_train_critic': only_train_critic
                })

                # 将损失除以梯度累积次数
                loss = loss / self.grad_accum_every
                # 累加总损失
                total_loss += loss.item()

            # 反向传播
            self.accelerator.backward(loss)

        # 如果存在最大梯度范数,则进行梯度裁剪
        if exists(self.max_grad_norm):
            accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        # 等待所有进程完成
        accelerator.wait_for_everyone()

        # 更新优化器参数
        self.opt.step()
        self.opt.zero_grad()

        # 等待所有进程完成
        accelerator.wait_for_everyone()

        # 如果是主进程且满足保存和采样间隔条件
        if self.is_main and self.step % self.save_and_sample_every == 0:
            # 模型转为评估模式
            self.model.eval()
            milestone = self.step // self.save_and_sample_every

            # 是否传入文本
            sample_kwargs = dict()

            if not self.unconditional:
                texts = choices(self.sample_texts, k = self.num_samples)
            else:
                texts = (None,) * self.num_samples

            sample_kwargs = {'texts': texts}

            # 选择采样方法
            if self.train_on_images:
                sample_method = self.model.sample_images
            else:
                sample_method = partial(self.model.sample, num_frames = self.sample_num_frames)

            # 分组评估,适当拆分参数
            with torch.no_grad():
                groups = num_to_groups(self.num_samples, self.batch_size)
                args_kwargs_iter = split_args_and_kwargs(batch_size = self.num_samples, split_size = self.batch_size, **sample_kwargs)

                all_sampled = []
                for group_batch_size, (_, (_, kwargs)) in zip(groups, args_kwargs_iter):
                    _kwargs = kwargs if not self.unconditional else dict()
                    sampled = sample_method(num_frames = self.sample_num_frames, batch_size = group_batch_size, **_kwargs)
                    all_sampled.append(sampled)

            # 保存视频和图像
            if not self.train_on_images:
                sampled_videos = torch.cat(all_sampled, dim = 0)
                milestone_folder = self.results_folder / f'videos.{milestone}'
                milestone_folder.mkdir(parents = True, exist_ok = True)

                for ind, (video_tensor, video_caption) in enumerate(zip(sampled_videos.unbind(dim = 0), texts)):
                    slugged_video_caption = simple_slugify(video_caption) if exists(video_caption) else str(ind)
                    video_tensor_to_gif(video_tensor, str(milestone_folder / f'{slugged_video_caption}.gif'))
            else:
                nrows = int(math.sqrt(self.num_samples))

                # 拼接所有采样得到的图像并裁剪到 [0, 1]
                sampled_images = torch.cat(all_sampled, dim = 0).detach().cpu().float().clamp(0., 1.)
                grid = make_grid(sampled_images, nrow = nrows, normalize = True, value_range = (0, 1))

                save_image(grid, str(self.results_folder / f'{milestone}.png'))

            # 保存检查点
            self.save(milestone)

        # 更新步数
        self.step += 1
        return total_loss
    # 定义 train 方法,用于训练模型
    def train(
        self,
        only_train_generator = False,
        only_train_critic = False
    ):
        # 使用 tqdm 创建一个进度条,设置初始值为 self.step,总步数为 self.train_num_steps,如果不是主进程则禁用
        with tqdm(
            initial = self.step,
            total = self.train_num_steps,
            disable = not self.is_main
        ) as pbar:
            # 当 self.step 小于 self.train_num_steps 时循环
            while self.step < self.train_num_steps:
                # 调用 train_step 方法进行训练,传入参数 only_train_generator 和 only_train_critic
                loss = self.train_step(
                    only_train_generator = only_train_generator,
                    only_train_critic = only_train_critic
                )
                # 设置进度条的描述为当前 loss 值,保留四位小数
                pbar.set_description(f'loss: {loss:.4f}')
                # 更新进度条
                pbar.update(1)
        # 训练完成后打印信息
        self.print('training complete')

.\lucidrains\phenaki-pytorch\phenaki_pytorch\t5.py

# 导入 torch 库
import torch
# 导入 transformers 库
import transformers
# 从 transformers 库中导入 T5Tokenizer, T5EncoderModel, T5Config
from transformers import T5Tokenizer, T5EncoderModel, T5Config

# 减少警告信息,只使用编码器
transformers.logging.set_verbosity_error()

# 辅助函数
def exists(val):
    return val is not None

# 配置
MAX_LENGTH = 256
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
T5_CONFIGS = {}

# 全局单例
# 获取指定名称的 tokenizer
def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name)
    return tokenizer

# 获取指定名称的模型
def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

# 获取指定名称的模型和 tokenizer
def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()

    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)

    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']

# 获取编码维度
def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config = config)

    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]

    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config

    else:
        raise ValueError(f'unknown t5 name {name}')

    return config.d_model

# 编码文本
def t5_encode_text(
    texts,
    name = DEFAULT_T5_NAME,
    output_device = None
):
    # 获取模型和 tokenizer
    t5, tokenizer = get_model_and_tokenizer(name)

    # 如果 CUDA 可用,则将模型移至 CUDA
    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    # 对文本进行编码
    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = 'pt',
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    attn_mask = attn_mask[..., None].bool()

    # 如果输出设备不存在,则返回编码文本
    if not exists(output_device):
        encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
        return encoded_text

    encoded_text = encoded_text.to(output_device)
    attn_mask = attn_mask.to(output_device)

    encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
    return encoded_text
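
下面补充一个示意性的调用示例(非原始仓库代码,首次运行会下载 T5 权重):t5_encode_text 返回的嵌入可直接作为 MaskGit 交叉注意力的文本条件。

# 示意性示例:编码两条文本描述
text_embeds = t5_encode_text([
    'a whale breaching from afar',
    'fireworks with blue and green sparkles'
])
print(text_embeds.shape)  # (2, 序列长度, 768),768 为 google/t5-v1_1-base 的 d_model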

.\lucidrains\phenaki-pytorch\phenaki_pytorch\__init__.py

# 从 phenaki_pytorch 模块中导入 Phenaki, CViViT, MaskGit, TokenCritic, make_video 函数
from phenaki_pytorch.phenaki_pytorch import Phenaki, CViViT, MaskGit, TokenCritic, make_video

# 从 phenaki_pytorch 模块中导入 CViViTTrainer 类
from phenaki_pytorch.cvivit_trainer import CViViTTrainer

# 从 phenaki_pytorch 模块中导入 PhenakiTrainer 类
from phenaki_pytorch.phenaki_trainer import PhenakiTrainer

Phenaki - Pytorch

Implementation of Phenaki Video, which uses Mask GIT to produce text guided videos of up to 2 minutes in length, in Pytorch. It will also combine another technique involving a token critic for potentially even better generations

Please join us on Discord if you are interested in replicating this work in the open

AI Coffeebreak explanation

Appreciation

  • Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research

  • 🤗 Huggingface for their amazing transformers and accelerate library

  • Guillem for his ongoing contributions

  • You? If you are a great machine learning engineer and / or researcher, feel free to contribute to the frontier of open source generative AI

Install

$ pip install phenaki-pytorch

Usage

C-ViViT

import torch
from phenaki_pytorch import CViViT, CViViTTrainer

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
).cuda()

trainer = CViViTTrainer(
    cvivit,
    folder = '/path/to/images/or/videos',
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = False,  # you can train on images first, before fine tuning on video, for sample efficiency
    use_ema = False,          # recommended to be turned on (keeps exponential moving averaged cvivit) unless if you don't have enough resources
    num_train_steps = 10000
)

trainer.train()               # reconstructions and checkpoints will be saved periodically to ./results

Phenaki

import torch
from phenaki_pytorch import CViViT, MaskGit, Phenaki

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = (256, 128),  # video with rectangular screen allowed
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

cvivit.load('/path/to/trained/cvivit.pt')

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

videos = torch.randn(3, 3, 17, 256, 128).cuda() # (batch, channels, frames, height, width)
mask = torch.ones((3, 17)).bool().cuda() # [optional] (batch, frames) - allows for co-training videos of different lengths as well as video and images in the same batch

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

loss = phenaki(videos, texts = texts, video_frame_mask = mask)
loss.backward()

# do the above for many steps, then ...

video = phenaki.sample(texts = 'a squirrel examines an acorn', num_frames = 17, cond_scale = 5.) # (1, 3, 17, 256, 128)

# so in the paper, they do not really achieve 2 minutes of coherent video
# at each new scene with new text conditioning, they condition on the previous K frames
# you can easily achieve this with this framework as so

video_prime = video[:, :, -3:] # (1, 3, 3, 256, 128) # say K = 3

video_next = phenaki.sample(texts = 'a cat watches the squirrel from afar', prime_frames = video_prime, num_frames = 14) # (1, 3, 14, 256, 128)

# the total video

entire_video = torch.cat((video, video_next), dim = 2) # (1, 3, 17 + 14, 256, 128)

# and so on...

Or just import the make_video function

# ... above code

from phenaki_pytorch import make_video

entire_video, scenes = make_video(phenaki, texts = [
    'a squirrel examines an acorn buried in the snow',
    'a cat watches the squirrel from a frosted window sill',
    'zoom out to show the entire living room, with the cat residing by the window sill'
], num_frames = (17, 14, 14), prime_lengths = (5, 5))

entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 128)

# scenes - List[Tensor[3]] - video segment of each scene

That's it!

Token Critic

A new paper suggests that instead of relying on the predicted probabilities of each token as a measure of confidence, one can train an extra critic to decide what to iteratively mask during sampling. You can optionally train this critic for potentially better generations as shown below

import torch
from phenaki_pytorch import CViViT, MaskGit, TokenCritic, Phenaki

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = (256, 128),
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)

# (1) define the critic

critic = TokenCritic(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    has_cross_attn = True
)

trainer = Phenaki(
    maskgit = maskgit,
    cvivit = cvivit,
    critic = critic    # and then (2) pass it into Phenaki
).cuda()

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

videos = torch.randn(3, 3, 3, 256, 128).cuda() # (batch, channels, frames, height, width)

loss = trainer(videos = videos, texts = texts)
loss.backward()

Or even simpler, just reuse MaskGit itself as a Self Critic (Nijkamp et al), by setting self_token_critic = True on the initialization of Phenaki

phenaki = Phenaki(
    ...,
    self_token_critic= True  # set this to True
)

Now your generations should be greatly improved!

Phenaki Trainer

This repository will also endeavor to allow the researcher to train on text-to-image and then text-to-video. Similarly, for unconditional training, the researcher should be able to first train on images and then fine tune on video. Below is an example for text-to-video

import torch
from torch.utils.data import Dataset
from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

cvivit.load('/path/to/trained/cvivit.pt')

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    unconditional = False
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

# mock text video dataset
# you will have to extend your own, and return the (<video tensor>, <caption>) tuple

class MockTextVideoDataset(Dataset):
    def __init__(
        self,
        length = 100,
        image_size = 256,
        num_frames = 17
    ):
        super().__init__()
        self.num_frames = num_frames
        self.image_size = image_size
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        video = torch.randn(3, self.num_frames, self.image_size, self.image_size)
        caption = 'video caption'
        return video, caption

dataset = MockTextVideoDataset()

# pass in the dataset

trainer = PhenakiTrainer(
    phenaki = phenaki,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = False, # if your mock dataset above returns (images, caption) pairs, set this to True
    dataset = dataset,       # pass in your dataset here
    sample_texts_file_path = '/path/to/captions.txt' # each caption should be on a new line, during sampling, will be randomly drawn
)

trainer.train()

Unconditional is as follows

ex. unconditional images and video training

import torch
from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

cvivit.load('/path/to/trained/cvivit.pt')

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    unconditional = True
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

# pass in the folder to images or video

trainer = PhenakiTrainer(
    phenaki = phenaki,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = True,                # for sake of example, bottom is folder of images
    folder = '/path/to/images/or/video'
)

trainer.train()

Todo

Citations

@article{Villegas2022PhenakiVL,
    title   = {Phenaki: Variable Length Video Generation From Open Domain Textual Description},
    author  = {Ruben Villegas and Mohammad Babaeizadeh and Pieter-Jan Kindermans and Hernan Moraldo and Han Zhang and Mohammad Taghi Saffar and Santiago Castro and Julius Kunze and D. Erhan},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.02399}
}
@article{Chang2022MaskGITMG,
    title   = {MaskGIT: Masked Generative Image Transformer},
    author  = {Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11305-11315}
}
@article{Lezama2022ImprovedMI,
    title   = {Improved Masked Image Generation with Token-Critic},
    author  = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.04439}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{press2021ALiBi,
    title   = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
    year    = {2021},
    url     = {https://ofir.io/train_short_test_long.pdf}
}
@article{Liu2022SwinTV,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11999-12009}
}
@inproceedings{Nijkamp2021SCRIPTSP,
    title   = {SCRIPT: Self-Critic PreTraining of Transformers},
    author  = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
    booktitle = {North American Chapter of the Association for Computational Linguistics},
    year    = {2021}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@misc{mentzer2023finite,
    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},
    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
    year    = {2023},
    eprint  = {2309.15505},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\phenaki-pytorch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的信息
setup(
  name = 'phenaki-pytorch',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.4.2',  # 版本号
  license='MIT',  # 许可证
  description = 'Phenaki - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  url = 'https://github.com/lucidrains/phenaki-pytorch',  # URL
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanisms',
    'text-to-video'
  ],
  install_requires = [  # 安装依赖列表
    'accelerate',
    'beartype',
    'einops>=0.7',
    'ema-pytorch>=0.2.2',
    'opencv-python',
    'pillow',
    'numpy',
    'sentencepiece',
    'torch>=1.6',
    'torchtyping',
    'torchvision',
    'transformers>=4.20.1',
    'tqdm',
    'vector-quantize-pytorch>=1.11.8'
  ],
  classifiers=[  # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\pi-GAN-pytorch\pi_gan_pytorch\coordconv.py

# 以下代码改编自下方链接中的 CoordConv 实现
# https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py
import torch
import torch.nn as nn

# 定义一个名为AddCoords的类,继承自nn.Module
class AddCoords(nn.Module):

    # 初始化函数,接受一个布尔值参数with_r,默认为False
    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    # 前向传播函数,接受一个输入张量input_tensor
    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)
        """
        # 获取输入张量的维度信息
        batch_size, _, x_dim, y_dim = input_tensor.size()

        # 创建xx_channel和yy_channel张量,用于表示坐标信息
        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)

        # 对坐标信息进行归一化处理
        xx_channel = xx_channel.float() / (x_dim - 1)
        yy_channel = yy_channel.float() / (y_dim - 1)

        # 将坐标信息映射到[-1, 1]范围内
        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        # 将坐标信息扩展到batch维度,并转置维度
        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)

        # 将坐标信息与输入张量拼接在一起
        ret = torch.cat([
            input_tensor,
            xx_channel.type_as(input_tensor),
            yy_channel.type_as(input_tensor)], dim=1)

        # 如果with_r为True,则计算距离信息并拼接到结果中
        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)

        return ret

# 定义一个名为CoordConv的类,继承自nn.Module
class CoordConv(nn.Module):

    # 初始化函数,接受输入通道数in_channels、输出通道数out_channels和其他关键字参数
    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
        super().__init__()
        # 创建AddCoords对象,传入with_r参数
        self.addcoords = AddCoords(with_r=with_r)
        # 计算输入尺寸大小
        in_size = in_channels+2
        if with_r:
            in_size += 1
        # 创建卷积层对象
        self.conv = nn.Conv2d(in_size, out_channels, **kwargs)

    # 前向传播函数,接受输入张量x
    def forward(self, x):
        # 将输入张量经过AddCoords处理后再经过卷积层处理
        ret = self.addcoords(x)
        ret = self.conv(ret)
        return ret
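
下面给出一个简单的形状检查示例(仅作示意,假设上面的 AddCoords 与 CoordConv 定义已在当前作用域中,张量尺寸均为任意选取):AddCoords 会在通道维上额外拼接 xx、yy 两个坐标通道(with_r=True 时再加一个半径通道),CoordConv 则在卷积之前先完成这一拼接。

# 形状检查示例(示意)
x = torch.randn(2, 8, 16, 16)                        # (batch, channel, h, w)

add_coords = AddCoords(with_r = True)
print(add_coords(x).shape)                            # torch.Size([2, 11, 16, 16]),多出 xx、yy、r 三个通道

conv = CoordConv(8, 32, with_r = True, kernel_size = 3, padding = 1)
print(conv(x).shape)                                  # torch.Size([2, 32, 16, 16])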

.\lucidrains\pi-GAN-pytorch\pi_gan_pytorch\nerf.py

# 从给定链接中获取的代码,需要从3D输入重构为5D输入(包含光线方向)

import torch
import torch.nn.functional as F
from einops import repeat, rearrange

# 创建二维网格
def meshgrid_xy(tensor1, tensor2):
    ii, jj = torch.meshgrid(tensor1, tensor2)
    return ii.transpose(-1, -2), jj.transpose(-1, -2)

# 计算累积乘积(不包括当前元素)
def cumprod_exclusive(tensor):
    cumprod = torch.cumprod(tensor, dim = -1)
    cumprod = torch.roll(cumprod, 1, -1)
    cumprod[..., 0] = 1.
    return cumprod
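
cumprod_exclusive 计算的是“排除自身”的累积乘积:第 i 个位置给出前 i-1 个元素的乘积,首位恒为 1;在体渲染中它作用在 (1 - alpha) 上,即光线到达某采样点之前的透射率。下面是一个小的数值示例(仅作示意):

# 数值示例(示意)
t = torch.tensor([0.9, 0.8, 0.5])
print(cumprod_exclusive(t))                           # tensor([1.0000, 0.9000, 0.7200])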

# 获取光线束
def get_ray_bundle(height, width, focal_length, tform_cam2world):
    ii, jj = meshgrid_xy(
      torch.arange(width).to(tform_cam2world),
      torch.arange(height).to(tform_cam2world)
    )

    directions = torch.stack([(ii - width * .5) / focal_length,
                            -(jj - height * .5) / focal_length,
                            -torch.ones_like(ii)
                           ], dim=-1)
    ray_directions = torch.sum(directions[..., None, :] * tform_cam2world[:3, :3], dim=-1)
    ray_origins = tform_cam2world[:3, -1].expand(ray_directions.shape)
    return ray_origins, ray_directions

# 从光线计算查询点
def compute_query_points_from_rays(
    ray_origins,
    ray_directions,
    near_thresh,
    far_thresh,
    num_samples,
    randomize = True
):
    depth_values = torch.linspace(near_thresh, far_thresh, num_samples).to(ray_origins)
    if randomize is True:
        noise_shape = list(ray_origins.shape[:-1]) + [num_samples]
        depth_values = depth_values \
            + torch.rand(noise_shape).to(ray_origins) * (far_thresh
                - near_thresh) / num_samples
    query_points = ray_origins[..., None, :] + ray_directions[..., None, :] * depth_values[..., :, None]
    return query_points, depth_values

# 渲染体密度
def render_volume_density(
    radiance_field,
    ray_origins,
    depth_values
):
    sigma_a = F.relu(radiance_field[..., 3])
    rgb = torch.sigmoid(radiance_field[..., :3])
    one_e_10 = torch.tensor([1e10], dtype=ray_origins.dtype, device=ray_origins.device)
    dists = torch.cat((depth_values[..., 1:] - depth_values[..., :-1],
                  one_e_10.expand(depth_values[..., :1].shape)), dim=-1)
    alpha = 1. - torch.exp(-sigma_a * dists)
    weights = alpha * cumprod_exclusive(1. - alpha + 1e-10)

    rgb_map = (weights[..., None] * rgb).sum(dim=-2)
    depth_map = (weights * depth_values).sum(dim=-1)
    acc_map = weights.sum(-1)

    return rgb_map, depth_map, acc_map
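
上面的 render_volume_density 实现的是标准的 NeRF 体渲染累积公式(此处仅为帮助理解的文字推导):对第 i 个采样点,不透明度 alpha_i = 1 - exp(-sigma_i * delta_i),权重 w_i = alpha_i * prod_{j<i} (1 - alpha_j)(由 cumprod_exclusive 给出,式中加 1e-10 仅为数值稳定);最终 rgb_map = sum_i w_i * c_i,depth_map = sum_i w_i * t_i,acc_map = sum_i w_i。最后一个采样间隔被置为 1e10,相当于把剩余的透射率全部分配给最远的采样点。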

# 从NERF模型获取图像
def get_image_from_nerf_model(
    model,
    latents,
    height,
    width,
    focal_length = 140,
    tform_cam2world = torch.eye(4),
    near_thresh = 2.,
    far_thresh = 6.,
    depth_samples_per_ray = 32
):
    tform_cam2world = tform_cam2world.to(latents)

    ray_origins, ray_directions = get_ray_bundle(height, width, focal_length,
                                               tform_cam2world)

    query_points, depth_values = compute_query_points_from_rays(
      ray_origins, ray_directions, near_thresh, far_thresh, depth_samples_per_ray
    )

    flattened_query_points = query_points.reshape((-1, 3))

    images = []
    for latent in latents.unbind(0):
        predictions = []
        predictions.append(model(latent, flattened_query_points))

        radiance_field_flattened = torch.cat(predictions, dim=0)

        unflattened_shape = list(query_points.shape[:-1]) + [4]
        radiance_field = torch.reshape(radiance_field_flattened, unflattened_shape)

        rgb_predicted, _, _ = render_volume_density(radiance_field, ray_origins, depth_values)
        image = rearrange(rgb_predicted, 'h w c -> c h w')
        images.append(image)

    return torch.stack(images)
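
下面用一个假想的极简模型对 get_image_from_nerf_model 做一次形状检查(仅作示意;DummyNerf 是为演示虚构的,真实使用中传入的是下一个文件里的 SirenGenerator)。这里假定模型接口为 model(latent, points),points 形状为 (N, 3),输出形状为 (N, 4)(RGB + 密度),与上面函数的调用方式一致:

from torch import nn

# 假想的占位模型:忽略 latent,把 (N, 3) 坐标映射为 (N, 4)
class DummyNerf(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3, 4)

    def forward(self, latent, points):
        return self.net(points)

latents = torch.randn(2, 256)                         # 两个任意的 latent 向量
images = get_image_from_nerf_model(DummyNerf(), latents, height = 32, width = 32)
print(images.shape)                                   # torch.Size([2, 3, 32, 32])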

.\lucidrains\pi-GAN-pytorch\pi_gan_pytorch\pi_gan_pytorch.py

# 导入所需的库
import math
from pathlib import Path
from functools import partial

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad

from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

from tqdm import trange
from PIL import Image
import torchvision
from torchvision.utils import save_image
import torchvision.transforms as T

# 导入自定义模块
from pi_gan_pytorch.coordconv import CoordConv
from pi_gan_pytorch.nerf import get_image_from_nerf_model
from einops import rearrange, repeat

# 检查是否有可用的 CUDA 设备
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'

# 定义一些辅助函数

def exists(val):
    return val is not None

def leaky_relu(p = 0.2):
    return nn.LeakyReLU(p)

def to_value(t):
    return t.clone().detach().item()

def get_module_device(module):
    return next(module.parameters()).device

# 定义损失函数

def gradient_penalty(images, output, weight = 10):
    batch_size, device = images.shape[0], images.device
    gradients = torch_grad(outputs=output, inputs=images,
                           grad_outputs=torch.ones(output.size(), device=device),
                           create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradients = gradients.reshape(batch_size, -1)
    l2 = ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
    return weight * l2

# 定义正弦激活函数

class Sine(nn.Module):
    def __init__(self, w0 = 1.):
        super().__init__()
        self.w0 = w0
    def forward(self, x):
        return torch.sin(self.w0 * x)

# 定义 Siren 层

class Siren(nn.Module):
    def __init__(self, dim_in, dim_out, w0 = 1., c = 6., is_first = False, use_bias = True, activation = None):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c = c, w0 = w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation

    def init_(self, weight, bias, c, w0):
        dim = self.dim_in

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if bias is not None:
            bias.uniform_(-w_std, w_std)

    def forward(self, x, gamma = None, beta = None):
        out =  F.linear(x, self.weight, self.bias)

        # FiLM modulation

        if exists(gamma):
            out = out * gamma

        if exists(beta):
            out = out + beta

        out = self.activation(out)
        return out

# 定义映射网络

class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, lr_mul = 0.1, bias = True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))

        self.lr_mul = lr_mul

    def forward(self, input):
        # 若未启用偏置,则传入 None,避免访问不存在的 bias 属性
        bias = self.bias * self.lr_mul if hasattr(self, 'bias') else None
        return F.linear(input, self.weight * self.lr_mul, bias = bias)

class MappingNetwork(nn.Module):
    def __init__(self, *, dim, dim_out, depth = 3, lr_mul = 0.1):
        super().__init__()

        layers = []
        for i in range(depth):
            layers.extend([EqualLinear(dim, dim, lr_mul), leaky_relu()])

        self.net = nn.Sequential(*layers)

        self.to_gamma = nn.Linear(dim, dim_out)
        self.to_beta = nn.Linear(dim, dim_out)

    def forward(self, x):
        x = F.normalize(x, dim = -1)
        x = self.net(x)
        return self.to_gamma(x), self.to_beta(x)
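
MappingNetwork 输出的 gamma、beta 用于对 Siren 层做 FiLM 调制。下面是一个形状示意(仅作示意,假设上面的 Siren 与 MappingNetwork 已在当前作用域中,维度均为任意选取):

mapping = MappingNetwork(dim = 64, dim_out = 128)
siren_layer = Siren(dim_in = 3, dim_out = 128, is_first = True)

latent = torch.randn(64)                              # 单个 latent 向量
gamma, beta = mapping(latent)                         # 两者形状均为 (128,)

coords = torch.randn(1024, 3)                         # 1024 个三维坐标
out = siren_layer(coords, gamma, beta)                # sin(w0 * ((W x + b) * gamma + beta))
print(out.shape)                                      # torch.Size([1024, 128])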

# 定义 Siren 网络

class SirenNet(nn.Module):
    # 初始化神经网络模型
    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0 = 1., w0_initial = 30., use_bias = True, final_activation = None):
        # 调用父类的初始化方法
        super().__init__()
        # 创建一个空的神经网络层列表
        self.layers = nn.ModuleList([])

        # 循环创建指定数量的 Siren 层
        for ind in range(num_layers):
            # 判断是否是第一层
            is_first = ind == 0
            # 根据是否是第一层选择不同的参数
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden

            # 将创建的 Siren 层添加到神经网络层列表中
            self.layers.append(Siren(
                dim_in = layer_dim_in,
                dim_out = dim_hidden,
                w0 = layer_w0,
                use_bias = use_bias,
                is_first = is_first
            ))

        # 创建最后一层 Siren 层
        self.last_layer = Siren(dim_in = dim_hidden, dim_out = dim_out, w0 = w0, use_bias = use_bias, activation = final_activation)

    # 前向传播函数
    def forward(self, x, gamma, beta):
        # 遍历神经网络层列表,依次进行前向传播
        for layer in self.layers:
            x = layer(x, gamma, beta)
        # 返回最后一层的前向传播结果
        return self.last_layer(x)
# 定义 Siren 生成器类
class SirenGenerator(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        dim,
        dim_hidden,
        siren_num_layers = 8
    ):
        super().__init__()

        # 创建映射网络对象
        self.mapping = MappingNetwork(
            dim = dim,
            dim_out = dim_hidden
        )

        # 创建 Siren 网络对象
        self.siren = SirenNet(
            dim_in = 3,
            dim_hidden = dim_hidden,
            dim_out = dim_hidden,
            num_layers = siren_num_layers
        )

        # 创建输出 alpha 的线性层
        self.to_alpha = nn.Linear(dim_hidden, 1)

        # 创建 Siren 网络对象用于生成 RGB
        self.to_rgb_siren = Siren(
            dim_in = dim_hidden,
            dim_out = dim_hidden
        )

        # 创建输出 RGB 的线性层
        self.to_rgb = nn.Linear(dim_hidden, 3)

    # 前向传播函数
    def forward(self, latent, coors, batch_size = 8192):
        # 获取 gamma 和 beta
        gamma, beta = self.mapping(latent)

        outs = []
        # 分批处理坐标
        for coor in coors.split(batch_size):
            # 重排 gamma 和 beta 的维度
            gamma_, beta_ = map(lambda t: rearrange(t, 'n -> () n'), (gamma, beta))
            # 使用 Siren 网络生成 x
            x = self.siren(coor, gamma_, beta_)
            # 生成 alpha
            alpha = self.to_alpha(x)

            # 使用 Siren 网络生成 RGB
            x = self.to_rgb_siren(x, gamma, beta)
            rgb = self.to_rgb(x)
            # 拼接 RGB 和 alpha
            out = torch.cat((rgb, alpha), dim = -1)
            outs.append(out)

        return torch.cat(outs)

# 定义生成器类
class Generator(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        image_size,
        dim,
        dim_hidden,
        siren_num_layers
    ):
        super().__init__()
        self.dim = dim
        self.image_size = image_size

        # 创建 Siren 生成器对象
        self.nerf_model = SirenGenerator(
            dim = dim,
            dim_hidden = dim_hidden,
            siren_num_layers = siren_num_layers
        )

    # 设置图像尺寸
    def set_image_size(self, image_size):
        self.image_size = image_size

    # 前向传播函数
    def forward(self, latents):
        image_size = self.image_size
        device, b = latents.device, latents.shape[0]

        # 从 Siren 生成器模型获取生成的图像
        generated_images = get_image_from_nerf_model(
            self.nerf_model,
            latents,
            image_size,
            image_size
        )

        return generated_images

# 定义判别器块类
class DiscriminatorBlock(nn.Module):
    # 初始化函数
    def __init__(self, dim, dim_out):
        super().__init__()
        # 创建 CoordConv 层
        self.res = CoordConv(dim, dim_out, kernel_size = 1, stride = 2)

        # 创建网络序列
        self.net = nn.Sequential(
            CoordConv(dim, dim_out, kernel_size = 3, padding = 1),
            leaky_relu(),
            CoordConv(dim_out, dim_out, kernel_size = 3, padding = 1),
            leaky_relu()
        )

        # 下采样层
        self.down = nn.AvgPool2d(2)

    # 前向传播函数
    def forward(self, x):
        res = self.res(x)
        x = self.net(x)
        x = self.down(x)
        x = x + res
        return x

# 定义判别器类
class Discriminator(nn.Module):
    # 初始化函数
    def __init__(
        self,
        image_size,
        init_chan = 64,
        max_chan = 400,
        init_resolution = 32,
        add_layer_iters = 10000
    ):
        # 调用父类的构造函数
        super().__init__()
        # 计算图像大小的对数值
        resolutions = math.log2(image_size)
        # 断言图像大小必须是2的幂
        assert resolutions.is_integer(), 'image size must be a power of 2'
        # 断言初始分辨率必须是2的幂
        assert math.log2(init_resolution).is_integer(), 'initial resolution must be power of 2'

        # 将对数值转换为整数
        resolutions = int(resolutions)
        # 计算层数
        layers = resolutions - 1

        # 计算通道数列表
        chans = list(reversed(list(map(lambda t: 2 ** (11 - t), range(layers)))))
        # 将通道数限制在最大通道数以内
        chans = list(map(lambda n: min(max_chan, n), chans))
        # 添加初始通道数到通道数列表
        chans = [init_chan, *chans]
        # 获取最终通道数
        final_chan = chans[-1]

        # 初始化 from_rgb_layers 和 layers
        self.from_rgb_layers = nn.ModuleList([])
        self.layers = nn.ModuleList([])
        self.image_size = image_size
        self.resolutions = list(map(lambda t: 2 ** (7 - t), range(layers)))

        # 遍历分辨率、输入通道数、输出通道数,创建 from_rgb_layer 和 DiscriminatorBlock
        for resolution, in_chan, out_chan in zip(self.resolutions, chans[:-1], chans[1:]):

            from_rgb_layer = nn.Sequential(
                CoordConv(3, in_chan, kernel_size = 1),
                leaky_relu()
            ) if resolution >= init_resolution else None

            self.from_rgb_layers.append(from_rgb_layer)

            self.layers.append(DiscriminatorBlock(
                dim = in_chan,
                dim_out = out_chan
            ))

        # 创建最终卷积层
        self.final_conv = CoordConv(final_chan, 1, kernel_size = 2)

        # 初始化 alpha、resolution 和 iterations
        self.add_layer_iters = add_layer_iters
        self.register_buffer('alpha', torch.tensor(0.))
        self.register_buffer('resolution', torch.tensor(init_resolution))
        self.register_buffer('iterations', torch.tensor(0.))

    # 增加分辨率
    def increase_resolution_(self):
        if self.resolution >= self.image_size:
            return

        self.alpha += self.alpha + (1 - self.alpha)
        self.iterations.fill_(0.)
        self.resolution *= 2

    # 更新迭代次数
    def update_iter_(self):
        self.iterations += 1
        self.alpha -= (1 / self.add_layer_iters)
        self.alpha.clamp_(min = 0.)

    # 前向传播函数
    def forward(self, img):
        x = img

        for resolution, from_rgb, layer in zip(self.resolutions, self.from_rgb_layers, self.layers):
            if self.resolution < resolution:
                continue

            if self.resolution == resolution:
                x = from_rgb(x)

            if bool(resolution == (self.resolution // 2)) and bool(self.alpha > 0):
                x_down = F.interpolate(img, scale_factor = 0.5)
                x = x * (1 - self.alpha) + from_rgb(x_down) * self.alpha

            x = layer(x)

        out = self.final_conv(x)
        return out
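
判别器采用渐进式增长:初始只在 init_resolution(默认 32)的分辨率上判别,训练中每隔 add_layer_iters 步调用一次 increase_resolution_() 把分辨率翻倍,并用 alpha 在随后的 add_layer_iters 步内从旧分辨率输入平滑过渡到新分辨率分支。下面是一个形状示意(仅作示意,假设上面的定义已在当前作用域中):

D = Discriminator(image_size = 128, init_resolution = 32)

imgs_32 = torch.randn(4, 3, 32, 32)                   # 训练初期输入与当前分辨率一致的 32x32 图像
print(D(imgs_32).shape)                                # torch.Size([4, 1, 1, 1])

D.increase_resolution_()                               # 分辨率翻倍到 64
imgs_64 = torch.randn(4, 3, 64, 64)
print(D(imgs_64).shape)                                # torch.Size([4, 1, 1, 1])
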
# 定义 piGAN 类
class piGAN(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        dim,
        init_resolution = 32,
        generator_dim_hidden = 256,
        siren_num_layers = 6,
        add_layer_iters = 10000
    ):
        super().__init__()
        self.dim = dim

        # 初始化生成器 G
        self.G = Generator(
            image_size = image_size,
            dim = dim,
            dim_hidden = generator_dim_hidden,
            siren_num_layers = siren_num_layers
        )

        # 初始化判别器 D
        self.D = Discriminator(
            image_size = image_size,
            add_layer_iters = add_layer_iters,
            init_resolution = init_resolution
        )

# 定义数据集相关函数

# 无限循环迭代器
def cycle(iterable):
    while True:
        for i in iterable:
            yield i

# 调整图像大小至最小尺寸
def resize_to_minimum_size(min_size, image):
    if max(*image.size) < min_size:
        return torchvision.transforms.functional.resize(image, min_size)
    return image

# 图像数据集类
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        transparent = False,
        aug_prob = 0.,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
        assert len(self.paths) > 0, f'No images were found in {folder} for training'
        self.create_transform(image_size)

    # 创建图像转换函数
    def create_transform(self, image_size):
        self.transform = T.Compose([
            T.Lambda(partial(resize_to_minimum_size, image_size)),
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# 训练器类

# 生成器采样函数
def sample_generator(G, batch_size):
    dim = G.dim
    rand_latents = torch.randn(batch_size, dim).cuda()
    return G(rand_latents)

class Trainer(nn.Module):
    def __init__(
        self,
        *,
        gan,
        folder,
        add_layers_iters = 10000,
        batch_size = 8,
        gradient_accumulate_every = 4,
        sample_every = 100,
        log_every = 10,
        num_train_steps = 50000,
        lr_gen = 5e-5,
        lr_discr = 4e-4,
        target_lr_gen = 1e-5,
        target_lr_discr = 1e-4,
        lr_decay_span = 10000
    ):
        super().__init__()
        gan.D.add_layer_iters = add_layers_iters
        self.add_layers_iters = add_layers_iters

        # 将 gan 移至 GPU
        self.gan = gan.cuda()

        # 初始化判别器和生成器的优化器
        self.optim_D = Adam(self.gan.D.parameters(), betas=(0, 0.9), lr = lr_discr)
        self.optim_G = Adam(self.gan.G.parameters(), betas=(0, 0.9), lr = lr_gen)

        # 定义判别器和生成器的学习率衰减函数
        D_decay_fn = lambda i: max(1 - i / lr_decay_span, 0) + (target_lr_discr / lr_discr) * min(i / lr_decay_span, 1)
        G_decay_fn = lambda i: max(1 - i / lr_decay_span, 0) + (target_lr_gen / lr_gen) * min(i / lr_decay_span, 1)

        # 初始化判别器和生成器的学习率调度器
        self.sched_D = LambdaLR(self.optim_D, D_decay_fn)
        self.sched_G = LambdaLR(self.optim_G, G_decay_fn)

        self.iterations = 0
        self.batch_size = batch_size
        self.num_train_steps = num_train_steps

        self.log_every = log_every
        self.sample_every = sample_every
        self.gradient_accumulate_every = gradient_accumulate_every

        # 初始化数据集和数据加载器
        self.dataset = ImageDataset(folder = folder, image_size = gan.D.resolution.item())
        self.dataloader = cycle(DataLoader(self.dataset, batch_size = batch_size, shuffle = True, drop_last = True))

        self.last_loss_D = 0
        self.last_loss_G = 0
    # 定义每一步训练的操作
    def step(self):
        # 获取GAN模型的判别器D、生成器G、批量大小batch_size、维度dim、梯度累积次数accumulate_every
        D, G, batch_size, dim, accumulate_every = self.gan.D, self.gan.G, self.batch_size, self.gan.dim, self.gradient_accumulate_every

        # 设置适当的图像大小
        if self.iterations % self.add_layers_iters == 0:
            if self.iterations != 0:
                D.increase_resolution_()

            # 获取图像大小
            image_size = D.resolution.item()
            G.set_image_size(image_size)
            self.dataset.create_transform(image_size)

        # 是否应用梯度惩罚
        apply_gp = self.iterations % 4 == 0

        # 训练判别器
        D.train()
        loss_D = 0

        for _ in range(accumulate_every):
            # 获取下一个批量图像数据
            images = next(self.dataloader)
            images = images.cuda().requires_grad_()
            real_out = D(images)

            # 生成假图像
            fake_imgs = sample_generator(G, batch_size)
            fake_out = D(fake_imgs.clone().detach())

            # 计算判别器的 hinge 损失
            divergence = (F.relu(1 + real_out) + F.relu(1 - fake_out)).mean()
            loss = divergence

            if apply_gp:
                gp = gradient_penalty(images, real_out)
                self.last_loss_gp = to_value(gp)
                loss = loss + gp

            (loss / accumulate_every).backward()
            loss_D += to_value(divergence) / accumulate_every

        self.last_loss_D = loss_D

        self.optim_D.step()
        self.optim_D.zero_grad()

        # 训练生成器
        G.train()
        loss_G = 0

        for _ in range(accumulate_every):
            fake_out = sample_generator(G, batch_size)
            loss = D(fake_out).mean()
            (loss / accumulate_every).backward()
            loss_G += to_value(loss) / accumulate_every

        self.last_loss_G = loss_G

        self.optim_G.step()
        self.optim_G.zero_grad()

        # 更新调度器
        self.sched_D.step()
        self.sched_G.step()

        self.iterations += 1
        D.update_iter_()

    # 前向传播函数
    def forward(self):
        for _ in trange(self.num_train_steps):
            self.step()

            # 每隔一定步数打印损失信息
            if self.iterations % self.log_every == 0:
                print(f'I: {self.gan.D.resolution.item()} | D: {self.last_loss_D:.2f} | G: {self.last_loss_G:.2f} | GP: {self.last_loss_gp:.2f}')

            # 每隔一定步数保存生成的图像
            if self.iterations % self.sample_every == 0:
                i = self.iterations // self.sample_every
                imgs = sample_generator(self.gan.G, 4)
                imgs.clamp_(0., 1.)
                save_image(imgs, f'./{i}.png', nrow=2)

.\lucidrains\pi-GAN-pytorch\pi_gan_pytorch\__init__.py

# 从 pi_gan_pytorch.pi_gan_pytorch 模块中导入 Generator, Discriminator, piGAN, Trainer 类
from pi_gan_pytorch.pi_gan_pytorch import Generator, Discriminator, piGAN, Trainer

π-GAN - Pytorch (wip)

Implementation of π-GAN, for 3d-aware image synthesis, in Pytorch.

Project video from authors

Install

$ pip install pi-gan-pytorch

Usage

from pi_gan_pytorch import piGAN, Trainer

gan = piGAN(
    image_size = 128,
    dim = 512
).cuda()

trainer = Trainer(
    gan = gan,
    folder = '/path/to/images'
)

trainer()

Citations

@misc{chan2020pigan,
    title={pi-GAN: Periodic Implicit Generative Adversarial Networks for 3D-Aware Image Synthesis}, 
    author={Eric R. Chan and Marco Monteiro and Petr Kellnhofer and Jiajun Wu and Gordon Wetzstein},
    year={2020},
    eprint={2012.00926},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

.\lucidrains\pi-GAN-pytorch\setup.py

# 导入设置安装和查找包的函数
from setuptools import setup, find_packages

# 设置包的信息
setup(
  # 包的名称
  name = 'pi-gan-pytorch',
  # 查找所有包
  packages = find_packages(),
  # 版本号
  version = '0.0.11',
  # 许可证
  license='MIT',
  # 描述
  description = 'π-GAN - Pytorch',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/pi-gan-pytorch',
  # 关键词
  keywords = [
    'artificial intelligence',
    'generative adversarial network'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.3',
    'pillow',
    'torch>=1.6',
    'torchvision',
    'tqdm'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\pixel-level-contrastive-learning\pixel_level_contrastive_learning\pixel_level_contrastive_learning.py

# 导入数学库
import math
# 导入复制库
import copy
# 导入随机库
import random
# 导入wraps和partial函数
from functools import wraps, partial
# 从数学库中导入floor函数
from math import floor

# 导入torch库
import torch
# 从torch中导入nn和einsum模块
from torch import nn, einsum
# 从torch.nn中导入functional模块
import torch.nn.functional as F

# 从kornia库中导入augmentation、filters和color模块
from kornia import augmentation as augs
from kornia import filters, color

# 从einops库中导入rearrange函数
from einops import rearrange

# 辅助函数

# 返回输入的张量
def identity(t):
    return t

# 如果输入值为None,则返回默认值
def default(val, def_val):
    return def_val if val is None else val

# 根据概率返回True或False
def rand_true(prob):
    return random.random() < prob

# 缓存装饰器,用于缓存计算结果
def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn
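
singleton 装饰器把方法第一次调用的返回值缓存到指定的实例属性中,之后的调用直接返回缓存结果(本仓库用它来惰性创建投影网络)。下面是一个小示例(仅作示意,Example、get_value 等名字均为虚构):

class Example:
    def __init__(self):
        self.cached_value = None                      # 缓存属性,初始为 None

    @singleton('cached_value')
    def get_value(self, x):
        print('computing...')
        return x * 2

ex = Example()
print(ex.get_value(3))                                # 打印 computing... 然后输出 6
print(ex.get_value(10))                               # 直接返回缓存的 6,不再重新计算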

# 获取模块所在设备
def get_module_device(module):
    return next(module.parameters()).device

# 设置模型参数是否需要梯度
def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val

# 随机生成cutout的坐标和比例
def cutout_coordinates(image, ratio_range = (0.6, 0.8)):
    _, _, orig_h, orig_w = image.shape

    ratio_lo, ratio_hi = ratio_range
    random_ratio = ratio_lo + random.random() * (ratio_hi - ratio_lo)
    w, h = floor(random_ratio * orig_w), floor(random_ratio * orig_h)
    coor_x = floor((orig_w - w) * random.random())
    coor_y = floor((orig_h - h) * random.random())
    return ((coor_y, coor_y + h), (coor_x, coor_x + w)), random_ratio

# 对cutout后的图像进行插值缩放
def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'):
    shape = image.shape
    output_size = default(output_size, shape[2:])
    (y0, y1), (x0, x1) = coordinates
    cutout_image = image[:, :, y0:y1, x0:x1]
    return F.interpolate(cutout_image, size = output_size, mode = mode)
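
cutout_coordinates 先按随机比例在原图上取一个矩形区域,cutout_and_resize 再把该区域裁出并插值回指定大小;PixelCL 在前向过程中用这一对函数从同一张图生成两个(可能部分重叠的)视图。下面是一个小示例(仅作示意,尺寸任意):

img = torch.randn(1, 3, 64, 64)

coords, ratio = cutout_coordinates(img, ratio_range = (0.6, 0.8))
print(coords, ratio)                                  # 坐标与比例均为随机,例如 ((3, 45), (11, 53)) 0.66

crop = cutout_and_resize(img, coords, output_size = (64, 64), mode = 'nearest')
print(crop.shape)                                     # torch.Size([1, 3, 64, 64])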

# 数据增强工具

# 随机应用函数
class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p
    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

# 指数移动平均

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

# 更新移动平均值
def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)
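
EMA 的更新规则为 new_avg = old * beta + new * (1 - beta),首次(old 为 None)直接取新值。一个小的数值示例(仅作示意):

ema = EMA(beta = 0.99)
print(ema.update_average(None, 10.0))                 # 10.0,首次直接取新值
print(ema.update_average(10.0, 0.0))                  # 9.9,即 10.0 * 0.99 + 0.0 * 0.01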

# 损失函数

# 计算损失函数
def loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
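
对 L2 归一化后的向量有恒等式 ||x̂ - ŷ||² = 2 - 2 * cos(x, y),因此这个 loss_fn 等价于归一化后的均方误差(即 BYOL 式的损失),取值范围为 [0, 4]。一个小示例(仅作示意):

x = torch.tensor([[1., 0.]])
y = torch.tensor([[0., 1.]])
print(loss_fn(x, y))                                  # tensor([2.]),正交向量
print(loss_fn(x, x))                                  # tensor([0.]),完全一致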

# 类

# 多层感知器
class MLP(nn.Module):
    def __init__(self, chan, chan_out = 256, inner_dim = 2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(chan, inner_dim),
            nn.BatchNorm1d(inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, chan_out)
        )

    def forward(self, x):
        return self.net(x)

# 卷积多层感知器
class ConvMLP(nn.Module):
    def __init__(self, chan, chan_out = 256, inner_dim = 2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, inner_dim, 1),
            nn.BatchNorm2d(inner_dim),
            nn.ReLU(),
            nn.Conv2d(inner_dim, chan_out, 1)
        )

    def forward(self, x):
        return self.net(x)

# 像素传播模块(Pixel Propagation Module, PPM)
class PPM(nn.Module):
    # 初始化函数,设置网络的参数
    def __init__(
        self,
        *,
        chan,
        num_layers = 1,
        gamma = 2):
        # 调用父类的初始化函数
        super().__init__()
        # 设置网络的 gamma 参数
        self.gamma = gamma

        # 根据 num_layers 的值选择不同的转换网络
        if num_layers == 0:
            # 如果 num_layers 为 0,则使用恒等映射
            self.transform_net = nn.Identity()
        elif num_layers == 1:
            # 如果 num_layers 为 1,则使用一个卷积层
            self.transform_net = nn.Conv2d(chan, chan, 1)
        elif num_layers == 2:
            # 如果 num_layers 为 2,则使用两个卷积层和批归一化层
            self.transform_net = nn.Sequential(
                nn.Conv2d(chan, chan, 1),
                nn.BatchNorm2d(chan),
                nn.ReLU(),
                nn.Conv2d(chan, chan, 1)
            )
        else:
            # 如果 num_layers 不是 0、1 或 2,则抛出数值错误
            raise ValueError('num_layers must be one of 0, 1, or 2')

    # 前向传播函数,定义网络的计算流程
    def forward(self, x):
        # 对输入张量 x 进行维度扩展
        xi = x[:, :, :, :, None, None]
        xj = x[:, :, None, None, :, :]
        # 计算相似度矩阵,使用余弦相似度并进行非负化和幂运算
        similarity = F.relu(F.cosine_similarity(xi, xj, dim = 1)) ** self.gamma

        # 对输入张量 x 进行变换
        transform_out = self.transform_net(x)
        # 使用 einsum 函数将相似度矩阵和变换后的张量进行乘积和重组
        out = einsum('b x y h w, b c h w -> b c x y', similarity, transform_out)
        # 返回计算结果
        return out
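
PPM 不改变特征图的形状:输出中每个位置的特征,是所有位置经 transform_net 变换后的特征按(非负余弦相似度的 gamma 次幂)加权求和的结果。一个形状示意(仅作示意,尺寸任意):

ppm = PPM(chan = 32, num_layers = 1, gamma = 2)

feat = torch.randn(2, 32, 7, 7)                       # (batch, channel, h, w)
print(ppm(feat).shape)                                # torch.Size([2, 32, 7, 7]),形状不变
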
# 一个用于基础神经网络的包装类
# 将管理隐藏层输出的拦截并将其传递到投影器和预测器网络中

class NetWrapper(nn.Module):
    def __init__(
        self,
        *,
        net,
        projection_size,
        projection_hidden_size,
        layer_pixel = -2,
        layer_instance = -2
    ):
        super().__init__()
        self.net = net
        self.layer_pixel = layer_pixel
        self.layer_instance = layer_instance

        self.pixel_projector = None
        self.instance_projector = None

        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.hidden_pixel = None
        self.hidden_instance = None
        self.hook_registered = False

    # 查找指定层
    def _find_layer(self, layer_id):
        if type(layer_id) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(layer_id, None)
        elif type(layer_id) == int:
            children = [*self.net.children()]
            return children[layer_id]
        return None

    # 钩子函数,用于拦截像素层输出
    def _hook_pixel(self, _, __, output):
        setattr(self, 'hidden_pixel', output)

    # 钩子函数,用于拦截实例层输出
    def _hook_instance(self, _, __, output):
        setattr(self, 'hidden_instance', output)

    # 注册钩子函数
    def _register_hook(self):
        pixel_layer = self._find_layer(self.layer_pixel)
        instance_layer = self._find_layer(self.layer_instance)

        assert pixel_layer is not None, f'hidden layer ({self.layer_pixel}) not found'
        assert instance_layer is not None, f'hidden layer ({self.layer_instance}) not found'

        pixel_layer.register_forward_hook(self._hook_pixel)
        instance_layer.register_forward_hook(self._hook_instance)
        self.hook_registered = True

    # 获取像素投影器
    @singleton('pixel_projector')
    def _get_pixel_projector(self, hidden):
        _, dim, *_ = hidden.shape
        projector = ConvMLP(dim, self.projection_size, self.projection_hidden_size)
        return projector.to(hidden)

    # 获取实例投影器
    @singleton('instance_projector')
    def _get_instance_projector(self, hidden):
        _, dim = hidden.shape
        projector = MLP(dim, self.projection_size, self.projection_hidden_size)
        return projector.to(hidden)

    # 获取表示
    def get_representation(self, x):
        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden_pixel = self.hidden_pixel
        hidden_instance = self.hidden_instance
        self.hidden_pixel = None
        self.hidden_instance = None
        assert hidden_pixel is not None, f'hidden pixel layer {self.layer_pixel} never emitted an output'
        assert hidden_instance is not None, f'hidden instance layer {self.layer_instance} never emitted an output'
        return hidden_pixel, hidden_instance

    # 前向传播
    def forward(self, x):
        pixel_representation, instance_representation = self.get_representation(x)
        instance_representation = instance_representation.flatten(1)

        pixel_projector = self._get_pixel_projector(pixel_representation)
        instance_projector = self._get_instance_projector(instance_representation)

        pixel_projection = pixel_projector(pixel_representation)
        instance_projection = instance_projector(instance_representation)
        return pixel_projection, instance_projection

# 主类

class PixelCL(nn.Module):
    # 初始化函数,设置模型参数和数据增强方式等
    def __init__(
        self,
        net,
        image_size,
        hidden_layer_pixel = -2,
        hidden_layer_instance = -2,
        projection_size = 256,
        projection_hidden_size = 2048,
        augment_fn = None,
        augment_fn2 = None,
        prob_rand_hflip = 0.25,
        moving_average_decay = 0.99,
        ppm_num_layers = 1,
        ppm_gamma = 2,
        distance_thres = 0.7,
        similarity_temperature = 0.3,
        alpha = 1.,
        use_pixpro = True,
        cutout_ratio_range = (0.6, 0.8),
        cutout_interpolate_mode = 'nearest',
        coord_cutout_interpolate_mode = 'bilinear'
    ):
        # 调用父类的初始化函数
        super().__init__()

        # 默认的数据增强方式
        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomSolarize(p=0.5),
            augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
        )

        # 设置数据增强方式
        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)
        self.prob_rand_hflip = prob_rand_hflip

        # 在线编码器
        self.online_encoder = NetWrapper(
            net = net,
            projection_size = projection_size,
            projection_hidden_size = projection_hidden_size,
            layer_pixel = hidden_layer_pixel,
            layer_instance = hidden_layer_instance
        )

        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature
        self.alpha = alpha

        self.use_pixpro = use_pixpro

        # 如果使用像素级处理
        if use_pixpro:
            self.propagate_pixels = PPM(
                chan = projection_size,
                num_layers = ppm_num_layers,
                gamma = ppm_gamma
            )

        self.cutout_ratio_range = cutout_ratio_range
        self.cutout_interpolate_mode = cutout_interpolate_mode
        self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode

        # 实例级别预测器
        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

        # 获取网络设备并将 wrapper 设置为相同设备
        device = get_module_device(net)
        self.to(device)

        # 发送一个模拟图像张量以实例化单例参数
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    # 获取目标编码器的单例函数
    @singleton('target_encoder')
    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder

    # 重置移动平均值
    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    # 更新移动平均值
    def update_moving_average(self):
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)