Lucidrains-系列项目源码解析-十九-
Lucidrains 系列项目源码解析(十九)
.\lucidrains\gigagan-pytorch\gigagan_pytorch\version.py
# Package version string; setup.py exec()s this file to read `__version__`.
__version__ = '0.2.20'
.\lucidrains\gigagan-pytorch\gigagan_pytorch\__init__.py
# 从 gigagan_pytorch 模块中导入 GigaGAN 相关类
from gigagan_pytorch.gigagan_pytorch import (
GigaGAN,
Generator,
Discriminator,
VisionAidedDiscriminator,
AdaptiveConv2DMod,
StyleNetwork,
TextEncoder
)
# 从 gigagan_pytorch 模块中导入 UnetUpsampler 类
from gigagan_pytorch.unet_upsampler import UnetUpsampler
# 从 gigagan_pytorch 模块中导入数据相关类
from gigagan_pytorch.data import (
ImageDataset,
TextImageDataset,
MockTextImageDataset
)
# Public API of the package. Entries must be *name strings*: with class objects
# in __all__, `from gigagan_pytorch import *` raises
# "TypeError: Item in ... .__all__ must be str".
__all__ = [
    'GigaGAN',
    'Generator',
    'Discriminator',
    'VisionAidedDiscriminator',
    'AdaptiveConv2DMod',
    'StyleNetwork',
    'UnetUpsampler',
    'TextEncoder',
    'ImageDataset',
    'TextImageDataset',
    'MockTextImageDataset'
]
GigaGAN - Pytorch
Implementation of GigaGAN (project page), new SOTA GAN out of Adobe.
I will also add a few findings from Lightweight GAN, for faster convergence (skip layer excitation) and better stability (reconstruction auxiliary loss in discriminator).
It will also contain the code for the 1k - 4k upsamplers, which I find to be the highlight of this paper.
Please join if you are interested in helping out with the replication with the LAION community
Appreciation
-
StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
-
🤗 Huggingface for their accelerate library
-
All the maintainers at OpenClip, for their SOTA open sourced contrastive learning text-image models
-
Xavier for the very helpful code review, and for discussions on how the scale invariance in the discriminator should be built!
-
@CerebralSeed for pull requesting the initial sampling code for both the generator and upsampler!
-
Keerth for the code review and pointing out some discrepancies with the paper!
Install
$ pip install gigagan-pytorch
Usage
Simple unconditional GAN, for starters
# README example: train an unconditional GigaGAN on a folder of images.
# `.cuda()` below means a CUDA device is required as written.
import torch
from gigagan_pytorch import (
    GigaGAN,
    ImageDataset
)

gan = GigaGAN(
    generator = dict(
        dim_capacity = 8,
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        image_size = 256,
        dim_max = 512,
        num_skip_layers_excite = 4,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        unconditional = True
    ),
    amp = True
).cuda()

# dataset

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

dataloader = dataset.get_dataloader(batch_size = 1)

# you must then set the dataloader for the GAN before training

gan.set_dataloader(dataloader)

# training the discriminator and generator alternating
# for 100 steps in this example, batch size 1, gradient accumulated 8 times

gan(
    steps = 100,
    grad_accum_every = 8
)

# after much training

images = gan.generate(batch_size = 4) # (4, 3, 256, 256)
For unconditional Unet Upsampler
# README example: train the unconditional Unet upsampler (64x64 -> 256x256).
# `.cuda()` below means a CUDA device is required as written.
import torch
from gigagan_pytorch import (
    GigaGAN,
    ImageDataset
)

gan = GigaGAN(
    train_upsampler = True,     # set this to True
    generator = dict(
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        dim = 32,
        image_size = 256,
        input_image_size = 64,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        multiscale_input_resolutions = (128,),
        unconditional = True
    ),
    amp = True
).cuda()

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

dataloader = dataset.get_dataloader(batch_size = 1)

gan.set_dataloader(dataloader)

# training the discriminator and generator alternating
# for 100 steps in this example, batch size 1, gradient accumulated 8 times

gan(
    steps = 100,
    grad_accum_every = 8
)

# after much training
# the low-resolution input is 64x64, matching `input_image_size = 64` above

lowres = torch.randn(1, 3, 64, 64).cuda()

images = gan.generate(lowres) # (1, 3, 256, 256)
Losses
- `G` - Generator
- `MSG` - Multiscale Generator
- `D` - Discriminator
- `MSD` - Multiscale Discriminator
- `GP` - Gradient Penalty
- `SSL` - Auxiliary Reconstruction in Discriminator (from Lightweight GAN)
- `VD` - Vision-aided Discriminator
- `VG` - Vision-aided Generator
- `CL` - Generator Contrastive Loss
- `MAL` - Matching Aware Loss
A healthy run would have `G`, `MSG`, `D`, `MSD` with values hovering between `0` and `10`, and usually staying pretty constant. If at any time after 1k training steps these values persist at triple digits, that would mean something is wrong. It is ok for generator and discriminator values to occasionally dip negative, but they should swing back up to the range above.
`GP` and `SSL` should be pushed towards `0`. `GP` can occasionally spike; I like to imagine it as the networks undergoing some epiphany.
Multi-GPU Training
The `GigaGAN` class is now equipped with 🤗 Accelerator. You can easily do multi-GPU training in two steps using their `accelerate` CLI.
At the project root directory, where the training script is, run
$ accelerate config
Then, in the same directory
$ accelerate launch train.py
Todo
Citations
@misc{https://doi.org/10.48550/arxiv.2303.05511,
url = {https://arxiv.org/abs/2303.05511},
author = {Kang, Minguk and Zhu, Jun-Yan and Zhang, Richard and Park, Jaesik and Shechtman, Eli and Paris, Sylvain and Park, Taesung},
title = {Scaling up GANs for Text-to-Image Synthesis},
publisher = {arXiv},
year = {2023},
copyright = {arXiv.org perpetual, non-exclusive license}
}
@article{Liu2021TowardsFA,
title = {Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis},
author = {Bingchen Liu and Yizhe Zhu and Kunpeng Song and A. Elgammal},
journal = {ArXiv},
year = {2021},
volume = {abs/2101.04775}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@inproceedings{Karras2020ada,
title = {Training Generative Adversarial Networks with Limited Data},
author = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
booktitle = {Proc. NeurIPS},
year = {2020}
}
.\lucidrains\gigagan-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Execute version.py to define `__version__` here without importing the package
# itself (presumably so building does not need the package's runtime deps).
# The `with` block closes the file handle deterministically.
with open('gigagan_pytorch/version.py') as version_file:
    exec(version_file.read())

# Package metadata for the `gigagan-pytorch` distribution.
setup(
    name = 'gigagan-pytorch',
    packages = find_packages(exclude=[]),
    version = __version__,
    license='MIT',
    description = 'GigaGAN - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    # previously pointed at the unrelated ETSformer-pytorch repository
    url = 'https://github.com/lucidrains/gigagan-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'generative adversarial networks'
    ],
    install_requires=[
        'accelerate',
        'beartype',
        'einops>=0.6',
        'ema-pytorch',
        'kornia',
        'numerize',
        'open-clip-torch>=2.0.0,<3.0.0',
        'pillow',
        'torch>=1.6',
        'torchvision',
        'tqdm'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\global-self-attention-network\gsa_pytorch\gsa_pytorch.py
# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 torch 中导入 nn 和 einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange 函数
from einops import rearrange
# 从 inspect 中导入 isfunction 函数
# 辅助函数
# Helper: fall back to `d` when `val` is None (lazily, if `d` is a Python function).
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise return ``d``.

    When ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    called and its result returned, allowing lazily computed defaults.

    ``isfunction`` is imported locally: the module-level import is missing
    (only a comment announcing it remains), which made every fallback path
    raise ``NameError``.
    """
    from inspect import isfunction

    if val is not None:
        return val
    return d() if isfunction(d) else d
# Helper: True for any value other than None.
def exists(val):
    """Return whether ``val`` is something other than ``None``."""
    return not (val is None)
# Build the relative-position re-indexing tensor (paper Appendix B, eq. 5).
def calc_reindexing_tensor(l, L, device):
    """
    Appendix B - (5)

    Returns a float tensor of shape (l, l, 2L - 1) whose entry [x, i, r] is 1
    exactly when ``i - x`` equals the r-th offset of ``range(-(L - 1), L)``.
    """
    positions = torch.arange(l, device = device)
    offsets = torch.arange(-(L - 1), L, device = device)

    # pairwise displacement i - x, laid out as (x, i, 1) for broadcasting
    delta = positions[None, :, None] - positions[:, None, None]
    matches_offset = delta == offsets[None, None, :]
    within_window = delta.abs() <= L

    return (matches_offset & within_window).float()
# 类
# GSA: global self-attention layer
class GSA(nn.Module):
    """Global self-attention: linear content attention plus optional axial relative-position attention.

    Args:
        dim: number of input channels.
        rel_pos_length: if given, enables the relative positional branch; the
            learned row/column tables hold ``2 * rel_pos_length - 1`` offsets.
        dim_out: number of output channels (defaults to ``dim``).
        heads: number of attention heads.
        dim_key: per-head channel width of queries, keys and values.
        norm_queries: if True, softmax the queries over their channel axis.
        batch_norm: if True (and the positional branch is enabled), apply
            ``BatchNorm2d(dim_key)`` to the intermediate row-attention output.
    """

    def __init__(self, dim, *, rel_pos_length = None, dim_out = None, heads = 8, dim_key = 64, norm_queries = False, batch_norm = True):
        super().__init__()
        dim_out = default(dim_out, dim)
        dim_hidden = dim_key * heads

        self.heads = heads
        self.dim_out = dim_out
        self.rel_pos_length = rel_pos_length
        self.norm_queries = norm_queries

        # 1x1 convolutions projecting to queries / keys / values, and back out
        self.to_qkv = nn.Conv2d(dim, dim_hidden * 3, 1, bias = False)
        self.to_out = nn.Conv2d(dim_hidden, dim_out, 1)

        # always defined, so the attribute exists even without the positional branch
        self.norm = None

        if exists(rel_pos_length):
            num_rel_shifts = 2 * rel_pos_length - 1
            self.norm = nn.BatchNorm2d(dim_key) if batch_norm else None
            self.rel_rows = nn.Parameter(torch.randn(num_rel_shifts, dim_key))
            self.rel_columns = nn.Parameter(torch.randn(num_rel_shifts, dim_key))

    def forward(self, img):
        """Apply global self-attention to ``img``.

        Args:
            img: 4-D input unpacked as (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, dim_out, height, width) produced by ``to_out``.
        """
        _, _, x, y = img.shape
        h, L, device = self.heads, self.rel_pos_length, img.device

        qkv = self.to_qkv(img).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c (x y)', h = h), qkv)

        # content attention: softmax keys over spatial positions and aggregate
        # values into a (dim_key x dim_key) context per (batch * head)
        k = k.softmax(dim = -1)
        context = einsum('ndm,nem->nde', k, v)

        content_q = q if not self.norm_queries else q.softmax(dim=-2)

        content_out = einsum('nde,ndm->nem', context, content_q)
        content_out = rearrange(content_out, 'n d (x y) -> n d x y', x = x, y = y)

        # relative positional attention, computed axially (Appendix B, eqs. 6 - 8):
        # first along rows (height), then along columns (width)
        if exists(self.rel_pos_length):
            q, v = map(lambda t: rearrange(t, 'n c (x y) -> n c x y', x = x, y = y), (q, v))

            Ix = calc_reindexing_tensor(x, L, device)
            Px = einsum('xir,rd->xid', Ix, self.rel_rows)
            Sx = einsum('ndxy,xid->nixy', q, Px)
            Yh = einsum('nixy,neiy->nexy', Sx, v)

            if exists(self.norm):
                Yh = self.norm(Yh)

            Iy = calc_reindexing_tensor(y, L, device)
            Py = einsum('yir,rd->yid', Iy, self.rel_columns)
            Sy = einsum('ndxy,yid->nixy', q, Py)
            rel_pos_out = einsum('nixy,nexi->nexy', Sy, Yh)

            content_out = content_out + rel_pos_out.contiguous()

        content_out = rearrange(content_out, '(b h) c x y -> b (h c) x y', h = h)
        return self.to_out(content_out)
.\lucidrains\global-self-attention-network\gsa_pytorch\__init__.py
# 从 gsa_pytorch 模块中导入 GSA 类
from gsa_pytorch.gsa_pytorch import GSA
Global Self-attention Network
An implementation of Global Self-Attention Network, which proposes an all-attention vision backbone that achieves better results than convolutions with fewer parameters and less compute.
They use a previously discovered linear attention variant with a small modification for further gains (no normalization of the queries), paired with relative positional attention, computed axially for efficiency.
The result is an extremely simple circuit composed of 8 einsums, 1 softmax, and normalization.
Install
$ pip install gsa-pytorch
Usage
# README example: a single GSA layer mapping 3 input channels to 64 output channels.
import torch
from gsa_pytorch import GSA

gsa = GSA(
    dim = 3,
    dim_out = 64,
    dim_key = 32,
    heads = 8,
    rel_pos_length = 256  # in paper, set to max(height, width). you can also turn this off by omitting this line
)

x = torch.randn(1, 3, 256, 256)
gsa(x) # (1, 64, 256, 256)
Citations
@inproceedings{
anonymous2021global,
title={Global Self-Attention Networks},
author={Anonymous},
booktitle={Submitted to International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=KiFeuZu24k},
note={under review}
}
.\lucidrains\global-self-attention-network\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Package metadata for the `gsa-pytorch` distribution.
_KEYWORDS = [
    'artificial intelligence',
    'attention mechanism',
    'image recognition'
]

_INSTALL_REQUIRES = [
    'torch>=1.6',
    'einops>=0.3'
]

_CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name = 'gsa-pytorch',
    packages = find_packages(),
    version = '0.2.2',
    license = 'MIT',
    description = 'Global Self-attention Network (GSA) - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/global-self-attention-network',
    keywords = _KEYWORDS,
    install_requires = _INSTALL_REQUIRES,
    classifiers = _CLASSIFIERS,
)
.\lucidrains\glom-pytorch\glom_pytorch\glom_pytorch.py
# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn 和 functional 模块
import torch.nn.functional as F
# 从 torch 模块中导入 einsum 函数
from torch import nn, einsum
# 从 einops 模块中导入 rearrange 和 repeat 函数,以及 torch 模块中的 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# 常量定义
# Logit that ConsensusAttention writes into the self-similarity (diagonal) entries
# when attend_self is False. It is a small constant, not a large negative value,
# so self-attention is not fully masked out.
TOKEN_ATTEND_SELF_VALUE = -5e-4
# 辅助函数
# Helper: True for any value other than None.
def exists(val):
    """Return whether ``val`` is something other than ``None``."""
    return not (val is None)
# Helper: `val` unless it is None, else the fallback `d` (never called).
def default(val, d):
    """Return ``val`` when it is not ``None``, otherwise ``d`` unchanged."""
    if val is not None:
        return val
    return d
# 类定义
# Per-level feedforward, run for every level at once via grouped 1x1 convolutions.
class GroupedFeedForward(nn.Module):
    """Apply an independent MLP to each of ``groups`` levels.

    Input and output use the ``'b n l d'`` layout (batch, tokens, levels, dim),
    with ``l == groups``. Levels are folded into the channel axis so that a
    grouped ``Conv1d`` keeps the per-level weights separate.
    """

    def __init__(self, *, dim, groups, mult = 4):
        super().__init__()
        channels = dim * groups
        hidden = channels * mult

        self.net = nn.Sequential(
            Rearrange('b n l d -> b (l d) n'),
            nn.Conv1d(channels, hidden, 1, groups = groups),
            nn.GELU(),
            nn.Conv1d(hidden, channels, 1, groups = groups),
            Rearrange('b (l d) n -> b n l d', l = groups),
        )

    def forward(self, levels):
        out = self.net(levels)
        return out
# Consensus attention between columns (patch positions), run separately per level.
class ConsensusAttention(nn.Module):
    """Attention across patch positions, computed independently for each level.

    Queries and values are the raw level embeddings; keys are L2-normalized
    along the last dimension.

    Args:
        num_patches_side: patches per side of the (square) patch grid; used only
            to build the locality mask.
        attend_self: if False, the diagonal similarity logits are overwritten
            with TOKEN_ATTEND_SELF_VALUE.
        local_consensus_radius: if > 0, patch pairs whose grid distance exceeds
            this radius are excluded from attention.
    """
    def __init__(self, num_patches_side, attend_self = True, local_consensus_radius = 0):
        super().__init__()
        self.attend_self = attend_self
        self.local_consensus_radius = local_consensus_radius

        # precompute a (1, n, n) boolean mask of patch pairs farther apart than the radius
        if self.local_consensus_radius > 0:
            # (row, col) coordinates of every patch on the grid
            coors = torch.stack(torch.meshgrid(
                torch.arange(num_patches_side),
                torch.arange(num_patches_side)
            )).float()

            coors = rearrange(coors, 'c h w -> (h w) c')
            dist = torch.cdist(coors, coors)
            mask_non_local = dist > self.local_consensus_radius
            mask_non_local = rearrange(mask_non_local, 'i j -> () i j')
            self.register_buffer('non_local_mask', mask_non_local)

    def forward(self, levels):
        """Return the consensus update, same layout as ``levels`` ('b n l d' per the einsums below)."""
        _, n, _, d, device = *levels.shape, levels.device
        # note: `v` is unused; the output einsum reads `levels` directly (same tensor)
        q, k, v = levels, F.normalize(levels, dim = -1), levels

        # similarity between patch positions i and j, separately for each level l
        sim = einsum('b i l d, b j l d -> b l i j', q, k) * (d ** -0.5)

        if not self.attend_self:
            self_mask = torch.eye(n, device = device, dtype = torch.bool)
            self_mask = rearrange(self_mask, 'i j -> () () i j')
            sim.masked_fill_(self_mask, TOKEN_ATTEND_SELF_VALUE)

        if self.local_consensus_radius > 0:
            # most negative finite value, so non-local pairs get ~zero weight after softmax
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(self.non_local_mask, max_neg_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b l i j, b j l d -> b i l d', attn, levels)
        return out
# 主类定义
# GLOM main module
class Glom(nn.Module):
    """GLOM: a column of ``levels`` embeddings per image patch, refined iteratively.

    Each iteration replaces every level with the average of its previous state,
    a bottom-up prediction (from the level below), a top-down prediction (from
    the level above, plus a positional embedding), and a consensus-attention
    term across patches.
    """
    def __init__(
        self,
        *,
        dim = 512,
        levels = 6,
        image_size = 224,
        patch_size = 14,
        consensus_self = False,
        local_consensus_radius = 0
    ):
        super().__init__()
        # number of patches along each side of the patch grid
        num_patches_side = (image_size // patch_size)
        num_patches = num_patches_side ** 2
        self.levels = levels

        # patchify and linearly embed each patch; `* 3` hard-codes 3 input channels
        self.image_to_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_size ** 2 * 3, dim)
        )
        self.pos_emb = nn.Embedding(num_patches, dim)

        # initial embeddings for all levels of a column
        self.init_levels = nn.Parameter(torch.randn(levels, dim))

        # bottom-up and top-down networks
        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels)
        self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1)

        # consensus attention
        self.attention = ConsensusAttention(num_patches_side, attend_self = consensus_self, local_consensus_radius = local_consensus_radius)

    def forward(self, img, iters = None, levels = None, return_all = False):
        """Run ``iters`` GLOM update steps on ``img``.

        Args:
            img: image batch, split into patches by the ``'b c (h p1) (w p2)'``
                pattern in ``image_to_tokens``.
            iters: number of update iterations; defaults to ``2 * self.levels``.
            levels: optional column state to resume from; when omitted,
                ``init_levels`` is repeated to (batch, patches, levels, dim).
            return_all: if True, return the stacked states of every iteration.

        Returns:
            The final column state, or, with ``return_all``, a tensor stacking
            the initial state plus the state after each iteration
            (``iters + 1`` entries along the first dimension).
        """
        b, device = img.shape[0], img.device
        # default to twice the number of levels, so information can propagate all the way up and back down
        iters = default(iters, self.levels * 2)

        tokens = self.image_to_tokens(img)
        n = tokens.shape[1]

        pos_embs = self.pos_emb(torch.arange(n, device = device))
        pos_embs = rearrange(pos_embs, 'n d -> () n () d')

        # the patch tokens act as an extra level below the lowest column level
        bottom_level = tokens
        bottom_level = rearrange(bottom_level, 'b n d -> b n () d')

        if not exists(levels):
            levels = repeat(self.init_levels, 'l d -> b n l d', b = b, n = n)

        # states after every iteration, starting with the initial state
        hiddens = [levels]

        # each level averages 4 terms (previous state, bottom-up, top-down, consensus);
        # the top level gets no top-down contribution, so it divides by 3
        num_contributions = torch.empty(self.levels, device = device).fill_(4)
        num_contributions[-1] = 3

        for _ in range(iters):
            # prepend the input tokens as level 0, so levels are 1-indexed below
            levels_with_input = torch.cat((bottom_level, levels), dim = -2)

            # bottom-up: (input, levels 1 .. L-1) -> predictions for levels 1 .. L
            bottom_up_out = self.bottom_up(levels_with_input[..., :-1, :])

            # top-down: levels 2 .. L (plus positional embedding) -> predictions for levels 1 .. L-1;
            # zero-pad the top level, which has nothing above it
            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs)
            top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.)

            consensus = self.attention(levels)

            # weighted average of all contributions per level
            levels_sum = torch.stack((levels, bottom_up_out, top_down_out, consensus)).sum(dim = 0)
            levels_mean = levels_sum / rearrange(num_contributions, 'l -> () () l ()')

            levels = levels_mean
            hiddens.append(levels)

        if return_all:
            return torch.stack(hiddens)

        return levels
.\lucidrains\glom-pytorch\glom_pytorch\__init__.py
# 从 glom_pytorch 模块中导入 Glom 类
from glom_pytorch.glom_pytorch import Glom
GLOM - Pytorch
An implementation of Glom, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns) for learning emergent part-whole hierarchies from data.
Yannic Kilcher's video was instrumental in helping me to understand this paper
Install
$ pip install glom-pytorch
Usage
# README example: run GLOM for 12 iterations on a single image.
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
# (224 // 14) ** 2 = 256 patches
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch - patches - levels - dimension)
Pass the `return_all = True` keyword argument on forward, and you will be returned all the column and level states per iteration (including the initial state, so number of iterations + 1). You can then use this to attach any losses to any level outputs at any time step.
It also gives you access to all the level data across iterations for clustering, from which one can inspect for the theorized islands in the paper.
# README example: collect the column state after every iteration.
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# get the top level outputs after 7 iterations (index 0 of all_levels is the initial state)
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)
Denoising self-supervised learning for encouraging emergence, as described by Hinton
# README example: denoising self-supervision from the top-level embeddings.
import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange

from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)

all_levels = model(noised_img, return_all = True)

# project each 512-d patch embedding back to a 14x14x3 patch, then reassemble the image
patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

top_level = all_levels[7, :, :, -1]  # top level embeddings after 7 iterations (index 0 is the initial state)
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising

loss = F.mse_loss(img, recon_img)
loss.backward()
You can pass in the state of the column and levels back into the model to continue where you left off (perhaps if you are processing consecutive frames of a slow video, as mentioned in the paper)
# README example: carry the column state across consecutive frames.
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10)  # image 2 for 10 iterations
levels3 = model(img3, levels = levels2, iters = 6)   # image 3 for 6 iterations
Appreciation
Thanks goes out to Cfoster0 for reviewing the code
Todo
Citations
@misc{hinton2021represent,
title = {How to represent part-whole hierarchies in a neural network},
author = {Geoffrey Hinton},
year = {2021},
eprint = {2102.12627},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
.\lucidrains\glom-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Package metadata for the `glom-pytorch` distribution.
_KEYWORDS = [
    'artificial intelligence',
    'deep learning'
]

_INSTALL_REQUIRES = [
    'einops>=0.3',
    'torch>=1.6'
]

_CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name = 'glom-pytorch',
    packages = find_packages(),
    version = '0.0.14',
    license = 'MIT',
    description = 'Glom - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/glom-pytorch',
    keywords = _KEYWORDS,
    install_requires = _INSTALL_REQUIRES,
    classifiers = _CLASSIFIERS,
)
.\lucidrains\gradnorm-pytorch\gradnorm_pytorch\gradnorm_pytorch.py
# 导入必要的库
from functools import cache, partial
import torch
import torch.distributed as dist
from torch.autograd import grad
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList, Parameter
from einops import rearrange, repeat
from accelerate import Accelerator
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Optional, Union, List, Dict, Tuple, NamedTuple
# 辅助函数
# Helper: True for any value other than None.
def exists(v):
    """Return whether ``v`` is something other than ``None``."""
    return not (v is None)
# Helper: `v` unless it is None, else the fallback `d`.
def default(v, d):
    """Return ``v`` when it is not ``None``, otherwise ``d``."""
    if v is not None:
        return v
    return d
# 张量辅助函数
# L1-normalize a tensor along one dimension.
def l1norm(t, dim = -1):
    """Divide ``t`` by its L1 norm along ``dim`` (via ``F.normalize`` with ``p = 1``).

    Note this returns the normalized tensor, not the norm itself.
    """
    return F.normalize(t, dim = dim, p = 1)
# 分布式计算辅助函数
# True when torch.distributed is initialized with more than one process.
# The result is memoized by @cache, so it reflects the state at the first call.
@cache
def is_distributed():
    """Return whether a multi-process ``torch.distributed`` group is active."""
    if not dist.is_initialized():
        return False
    return dist.get_world_size() > 1
# Average a tensor across processes when running distributed; otherwise a no-op.
def maybe_distributed_mean(t):
    """Return the mean of ``t`` over all ranks, or ``t`` itself when not distributed.

    In the distributed case ``t`` is summed in place by ``all_reduce`` before
    being divided by the world size.
    """
    if is_distributed():
        dist.all_reduce(t)
        return t / dist.get_world_size()
    return t
# 主类
class GradNormLossWeighter(Module):
@beartype
def __init__(
self,
*,
num_losses: Optional[int] = None,
loss_weights: Optional[Union[
List[float],
Tensor
]] = None,
loss_names: Optional[Tuple[str, ...]] = None,
learning_rate = 1e-4,
restoring_force_alpha = 0.,
grad_norm_parameters: Optional[Parameter] = None,
accelerator: Optional[Accelerator] = None,
frozen = False,
initial_losses_decay = 1.,
update_after_step = 0.,
update_every = 1.
):
super().__init__()
assert exists(num_losses) or exists(loss_weights)
if exists(loss_weights):
if isinstance(loss_weights, list):
loss_weights = torch.tensor(loss_weights)
num_losses = default(num_losses, loss_weights.numel())
else:
loss_weights = torch.ones((num_losses,), dtype = torch.float32)
assert len(loss_weights) == num_losses
assert num_losses > 1, 'only makes sense if you have multiple losses'
assert loss_weights.ndim == 1, 'loss weights must be 1 dimensional'
self.accelerator = accelerator
self.num_losses = num_losses
self.frozen = frozen
self.loss_names = loss_names
assert not exists(loss_names) or len(loss_names) == num_losses
assert restoring_force_alpha >= 0.
self.alpha = restoring_force_alpha
self.has_restoring_force = self.alpha > 0
self._grad_norm_parameters = [grad_norm_parameters] # hack
# 损失权重,可以是学习得到的或静态的
self.register_buffer('loss_weights', loss_weights)
self.learning_rate = learning_rate
# 初始损失
# 如果初始损失衰减设置为小于1,则会对初始损失进行 EMA 平滑处理
assert 0 <= initial_losses_decay <= 1.
self.initial_losses_decay = initial_losses_decay
self.register_buffer('initial_losses', torch.zeros(num_losses))
# 用于在最后重新归一化损失权重
self.register_buffer('loss_weights_sum', self.loss_weights.sum())
# 用于梯度累积
self.register_buffer('loss_weights_grad', torch.zeros_like(loss_weights), persistent = False)
# 步数,用于可能的调度等
self.register_buffer('step', torch.tensor(0.))
# 可以较少频繁更新,以节省计算资源
self.update_after_step = update_after_step
self.update_every = update_every
self.register_buffer('initted', torch.tensor(False))
@property
def grad_norm_parameters(self):
return self._grad_norm_parameters[0]
def backward(self, *args, **kwargs):
return self.forward(*args, **kwargs)
@beartype
# 定义一个 forward 方法,用于前向传播
def forward(
self,
losses: Union[
Dict[str, Tensor], # 损失值可以是字典类型,键为字符串,值为张量
List[Tensor], # 损失值可以是张量列表
Tuple[Tensor], # 损失值可以是元组中的张量
Tensor # 损失值可以是单个张量
],
activations: Optional[Tensor] = None, # 激活值,默认为 None,在论文中,他们使用了从骨干层次的倒数第二个参数的梯度范数。但这也可以是激活值(例如,共享的图像被馈送到多个鉴别器)
freeze = False, # 可以选择在前向传播时冻结可学习的损失权重
scale = 1., # 缩放因子,默认为 1
grad_step = True, # 是否进行梯度步骤,默认为 True
**backward_kwargs # 其他后向传播参数
.\lucidrains\gradnorm-pytorch\gradnorm_pytorch\mocks.py
# 导入 torch 中的 nn 模块
from torch import nn
# Toy network with a shared backbone and several scalar "loss" heads.
class MockNetworkWithMultipleLosses(nn.Module):
    """Shared MLP backbone followed by ``num_losses`` linear heads.

    Each head's mean output is treated as one loss, which makes this a small
    fixture for exercising multi-loss weighting.

    Args:
        dim: feature dimension of the backbone input and output.
        num_losses: number of heads, and therefore losses returned.
    """

    def __init__(
        self,
        dim,
        num_losses = 2
    ):
        super().__init__()

        self.backbone = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

        heads = [nn.Linear(dim, 1) for _ in range(num_losses)]
        self.discriminators = nn.ModuleList(heads)

    def forward(self, x):
        """Return ``(losses, backbone_output)``; ``losses`` is a list with one scalar mean per head."""
        features = self.backbone(x)
        losses = [head(features).mean() for head in self.discriminators]
        return losses, features
.\lucidrains\gradnorm-pytorch\gradnorm_pytorch\__init__.py
# 从 gradnorm_pytorch.gradnorm_pytorch 模块中导入 GradNormLossWeighter 类
# 从 gradnorm_pytorch.mocks 模块中导入 MockNetworkWithMultipleLosses 类
from gradnorm_pytorch.gradnorm_pytorch import GradNormLossWeighter
from gradnorm_pytorch.mocks import MockNetworkWithMultipleLosses
GradNorm - Pytorch
A practical implementation of GradNorm, Gradient Normalization for Adaptive Loss Balancing, in Pytorch
I am increasingly coming across neural network architectures that require more than 3 auxiliary losses, so I will build out an installable package that easily handles loss balancing in distributed settings, gradient accumulation, etc. I am also open to incorporating any follow-up research; just let me know in the issues.
Will be dog-fooded for SoundStream, MagViT2 as well as MetNet3
Appreciation
- StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
Install
$ pip install gradnorm-pytorch
Usage
# README example: GradNorm loss weighting on a mock multi-loss network.
import torch

from gradnorm_pytorch import (
    GradNormLossWeighter,
    MockNetworkWithMultipleLosses
)

# a mock network with multiple discriminator losses

network = MockNetworkWithMultipleLosses(
    dim = 512,
    num_losses = 4
)

# backbone shared parameter

backbone_parameter = network.backbone[-1].weight

# grad norm based loss weighter

loss_weighter = GradNormLossWeighter(
    num_losses = 4,
    learning_rate = 1e-4,
    # 0. is perfectly balanced losses, while anything greater than 0 would account for the relative training rates of each loss. in the paper, they go as high as 3.
    restoring_force_alpha = 0.,
    grad_norm_parameters = backbone_parameter
)

# mock input

mock_input = torch.randn(2, 512)

losses, backbone_output_activations = network(mock_input)

# backwards with the loss weights
# will update on each backward based on gradnorm algorithm

loss_weighter.backward(losses, retain_graph = True)

# if you would like to update the loss weights wrt activations just do the following instead

loss_weighter.backward(losses, backbone_output_activations)
You can also switch it to basic static loss weighting, in case you want to run experiments against fixed weighting.
loss_weighter = GradNormLossWeighter(
loss_weights = [1., 10., 5., 2.],
...,
frozen = True
)
# or you can also freeze it on invoking the instance
loss_weighter.backward(..., freeze = True)
For use with 🤗 Huggingface Accelerate, just pass in the `Accelerator` instance into the keyword `accelerator` on initialization.
ex.
accelerator = Accelerator()
network = accelerator.prepare(network)
loss_weighter = GradNormLossWeighter(
...,
accelerator = accelerator
)
# backwards will now use accelerator
Todo
Citations
@article{Chen2017GradNormGN,
title = {GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks},
author = {Zhao Chen and Vijay Badrinarayanan and Chen-Yu Lee and Andrew Rabinovich},
journal = {ArXiv},
year = {2017},
volume = {abs/1711.02257},
url = {https://api.semanticscholar.org/CorpusID:4703661}
}
.\lucidrains\gradnorm-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Package metadata for the `gradnorm-pytorch` distribution.
# The package imports `functools.cache` (added in Python 3.9), so the advertised
# minimum Python version is 3.9 rather than 3.6.
setup(
    name = 'gradnorm-pytorch',
    packages = find_packages(exclude=[]),
    version = '0.0.26',
    license='MIT',
    description = 'GradNorm - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/gradnorm-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'loss balancing',
        'gradient normalization'
    ],
    python_requires = '>=3.9',
    install_requires=[
        'accelerate',
        'beartype',
        'einops>=0.7.0',
        'torch>=2.0'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.9',
    ],
)
.\lucidrains\graph-transformer-pytorch\graph_transformer_pytorch\graph_transformer_pytorch.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn, einsum
from torch import nn, einsum
# 从 einops 库中导入 rearrange, repeat
from einops import rearrange, repeat
# 从 rotary_embedding_torch 库中导入 RotaryEmbedding, apply_rotary_emb
# helpers
# Helper: True for any value other than None.
def exists(val):
    """Return whether ``val`` is something other than ``None``."""
    return not (val is None)
# Helper: `val` unless it is None, else the fallback `d`.
def default(val, d):
    """Return ``val`` when it is not ``None``, otherwise ``d``."""
    if val is not None:
        return val
    return d
# Short alias for nn.ModuleList used throughout this module
# (unrelated to typing.List despite the name).
List = nn.ModuleList
# normalizations
# Pre-normalization wrapper: LayerNorm, then the wrapped function.
class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input, then call ``fn`` with any extra arguments."""

    def __init__(
        self,
        dim,
        fn
    ):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        return self.fn(self.norm(x), *args, **kwargs)
# gated residual
# Plain (ungated) residual connection.
class Residual(nn.Module):
    """Return ``x + res``."""

    def forward(self, x, res):
        out = x + res
        return out
# Residual connection with a learned sigmoid gate.
class GatedResidual(nn.Module):
    """Blend ``x`` and ``res`` with a gate computed from ``[x, res, x - res]``.

    Output is ``gate * x + (1 - gate) * res``, where ``gate`` comes from a
    bias-free linear projection to one value per position, followed by a sigmoid.
    """

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(dim * 3, 1, bias = False),
            nn.Sigmoid()
        )

    def forward(self, x, res):
        gate = self.proj(torch.cat((x, res, x - res), dim = -1))
        return gate * x + (1 - gate) * res
# attention
# Multi-head attention over nodes, with edge features added to keys and values.
class Attention(nn.Module):
    """Node attention in which each (i, j) pair's edge projection is added to key j and value j.

    Args:
        dim: node feature dimension.
        pos_emb: optional rotary embedding module applied to queries and keys.
        dim_head: per-head width.
        heads: number of heads.
        edge_dim: edge feature dimension (defaults to ``dim``).

    NOTE(review): ``apply_rotary_emb`` is called when ``pos_emb`` is set, but no
    import for it appears in this listing (only a comment mentioning
    rotary_embedding_torch). Confirm the import exists.
    """
    def __init__(
        self,
        dim,
        pos_emb = None,
        dim_head = 64,
        heads = 8,
        edge_dim = None
    ):
        super().__init__()
        edge_dim = default(edge_dim, dim)

        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.pos_emb = pos_emb

        self.to_q = nn.Linear(dim, inner_dim)
        self.to_kv = nn.Linear(dim, inner_dim * 2)
        # a single edge projection is reused as both the key offset and the value offset
        self.edges_to_kv = nn.Linear(edge_dim, inner_dim)

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, nodes, edges, mask = None):
        """Attend over nodes.

        Args:
            nodes: node features, (batch, nodes, dim) per the 'b j d' patterns below.
            edges: edge features; presumably (batch, nodes, nodes, edge_dim) so that
                they broadcast against the (b, 1, j, d) keys. TODO confirm.
            mask: optional boolean (batch, nodes); a pair (i, j) is attended only
                if both nodes are unmasked.
        """
        h = self.heads

        q = self.to_q(nodes)
        k, v = self.to_kv(nodes).chunk(2, dim = -1)

        e_kv = self.edges_to_kv(edges)

        # fold heads into the batch dimension
        q, k, v, e_kv = map(lambda t: rearrange(t, 'b ... (h d) -> (b h) ... d', h = h), (q, k, v, e_kv))

        if exists(self.pos_emb):
            freqs = self.pos_emb(torch.arange(nodes.shape[1], device = nodes.device))
            freqs = rearrange(freqs, 'n d -> () n d')
            q = apply_rotary_emb(freqs, q)
            k = apply_rotary_emb(freqs, k)

        ek, ev = e_kv, e_kv

        # broadcast node keys/values over the query axis, then add the per-pair edge terms
        k, v = map(lambda t: rearrange(t, 'b j d -> b () j d '), (k, v))

        k = k + ek
        v = v + ev

        sim = einsum('b i d, b i j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') & rearrange(mask, 'b j -> b () j')
            mask = repeat(mask, 'b i j -> (b h) i j', h = h)
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(~mask, max_neg_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b i j, b i j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)
# optional feedforward
# Optional position-wise feedforward network.
def FeedForward(dim, ff_mult = 4):
    """Build ``Linear(dim -> dim * ff_mult)``, ``GELU``, ``Linear(dim * ff_mult -> dim)``."""
    hidden = dim * ff_mult
    layers = [
        nn.Linear(dim, hidden),
        nn.GELU(),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*layers)
# classes

class GraphTransformer(nn.Module):
    """Stack of edge-aware attention layers (optionally with feedforwards) over graph nodes.

    Args:
        dim: node feature dimension.
        depth: number of layers.
        dim_head: per-head dimension of attention.
        edge_dim: edge feature dimension; defaults to ``dim``.
        heads: number of attention heads.
        gated_residual: wrap each sub-layer in a ``GatedResidual`` (default) or,
            when False, a plain ``Residual``.
        with_feedforwards: add a pre-norm feedforward after each attention layer.
        norm_edges: apply LayerNorm to the incoming edge features.
        rel_pos_emb: apply rotary position embedding (for ordered nodes).
        accept_adjacency_matrix: allow an ``adj_mat`` whose 0/1 entries are
            embedded and added to the edge features.
    """
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        edge_dim = None,
        heads = 8,
        gated_residual = True,
        with_feedforwards = False,
        norm_edges = False,
        rel_pos_emb = False,
        accept_adjacency_matrix = False
    ):
        super().__init__()
        self.layers = List([])
        edge_dim = default(edge_dim, dim)
        self.norm_edges = nn.LayerNorm(edge_dim) if norm_edges else nn.Identity()
        # one embedding vector for "no edge" (0) and one for "edge" (1)
        self.adj_emb = nn.Embedding(2, edge_dim) if accept_adjacency_matrix else None
        # a single rotary embedding shared by all attention layers
        pos_emb = RotaryEmbedding(dim_head) if rel_pos_emb else None

        # honour `gated_residual`; it used to be accepted but ignored, so a
        # GatedResidual was always built
        def make_residual():
            return GatedResidual(dim) if gated_residual else Residual()

        for _ in range(depth):
            self.layers.append(List([
                List([
                    PreNorm(dim, Attention(dim, pos_emb = pos_emb, edge_dim = edge_dim, dim_head = dim_head, heads = heads)),
                    make_residual()
                ]),
                List([
                    PreNorm(dim, FeedForward(dim)),
                    make_residual()
                ]) if with_feedforwards else None
            ]))

    def forward(
        self,
        nodes,
        edges = None,
        adj_mat = None,
        mask = None
    ):
        """Run every layer over ``nodes``.

        Args:
            nodes: node features, unpacked as ``(batch, seq, dim)``.
            edges: optional pairwise edge features (README example shape
                ``(batch, seq, seq, edge_dim)``).
            adj_mat: optional ``(batch, seq, seq)`` adjacency matrix of 0/1
                values. It requires ``accept_adjacency_matrix = True``.
            mask: optional boolean node mask, forwarded to every attention layer.

        Returns:
            ``(nodes, edges)``: the updated nodes and the (possibly normalised)
            input edges. The adjacency embedding is not folded into the returned edges.

        Raises:
            AssertionError: if neither ``edges`` nor ``adj_mat`` is given, or
                ``adj_mat`` is invalid or not enabled.
        """
        batch, seq, _ = nodes.shape

        # attention needs an edge signal to project; fail early with a clear
        # message instead of passing the integer 0 into a Linear layer
        assert exists(edges) or exists(adj_mat), 'either edges or adj_mat must be passed in'

        if exists(edges):
            edges = self.norm_edges(edges)

        if exists(adj_mat):
            assert exists(self.adj_emb), 'accept_adjacency_matrix must be set to True'
            assert adj_mat.shape == (batch, seq, seq)
            adj_mat = self.adj_emb(adj_mat.long())

        # combine whichever edge signals were provided
        all_edges = default(edges, 0) + default(adj_mat, 0)

        for attn_block, ff_block in self.layers:
            attn, attn_residual = attn_block
            nodes = attn_residual(attn(nodes, all_edges, mask = mask), nodes)

            if exists(ff_block):
                ff, ff_residual = ff_block
                nodes = ff_residual(ff(nodes), nodes)

        return nodes, edges
.\lucidrains\graph-transformer-pytorch\graph_transformer_pytorch\__init__.py
# 从 graph_transformer_pytorch 包中导入 GraphTransformer 类
from graph_transformer_pytorch.graph_transformer_pytorch import GraphTransformer
Graph Transformer - Pytorch
Implementation of Graph Transformer in Pytorch, for potential use in replicating Alphafold2. This was recently used by both Costa et al. and the Baker lab for transforming MSA and pairwise embeddings into 3D coordinates.
Install
$ pip install graph-transformer-pytorch
Usage
import torch
from graph_transformer_pytorch import GraphTransformer
model = GraphTransformer(
dim = 256,
depth = 6,
edge_dim = 512, # optional - if left out, the edge dimension is assumed to be the same as the node dimension above
with_feedforwards = True, # whether to add a feedforward after each attention layer, suggested by literature to be needed
gated_residual = True, # to use the gated residual to prevent over-smoothing
rel_pos_emb = True # set to True if the nodes are ordered, defaults to False
)
nodes = torch.randn(1, 128, 256)
edges = torch.randn(1, 128, 128, 512)
mask = torch.ones(1, 128).bool()
nodes, edges = model(nodes, edges, mask = mask)
nodes.shape # (1, 128, 256) - project to R^3 for coordinates
If you want it to handle an adjacency matrix
import torch
from graph_transformer_pytorch import GraphTransformer
model = GraphTransformer(
dim = 256,
depth = 6,
edge_dim = 512,
with_feedforwards = True,
gated_residual = True,
rel_pos_emb = True,
accept_adjacency_matrix = True # set this to True
)
nodes = torch.randn(2, 128, 256)
adj_mat = torch.randint(0, 2, (2, 128, 128))
mask = torch.ones(2, 128).bool()
nodes, edges = model(nodes, adj_mat = adj_mat, mask = mask)
nodes.shape # (2, 128, 256) - project to R^3 for coordinates
Citations
@article {Costa2021.06.02.446809,
author = {Costa, Allan and Ponnapati, Manvitha and Jacobson, Joseph M. and Chatterjee, Pranam},
title = {Distillation of MSA Embeddings to Folded Protein Structures with Graph Transformers},
year = {2021},
doi = {10.1101/2021.06.02.446809},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809},
eprint = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809.full.pdf},
journal = {bioRxiv}
}
@article {Baek2021.06.14.448402,
author = {Baek, Minkyung and DiMaio, Frank and Anishchenko, Ivan and Dauparas, Justas and Ovchinnikov, Sergey and Lee, Gyu Rie and Wang, Jue and Cong, Qian and Kinch, Lisa N. and Schaeffer, R. Dustin and Mill{\'a}n, Claudia and Park, Hahnbeom and Adams, Carson and Glassman, Caleb R. and DeGiovanni, Andy and Pereira, Jose H. and Rodrigues, Andria V. and van Dijk, Alberdina A. and Ebrecht, Ana C. and Opperman, Diederik J. and Sagmeister, Theo and Buhlheller, Christoph and Pavkov-Keller, Tea and Rathinaswamy, Manoj K and Dalwadi, Udit and Yip, Calvin K and Burke, John E and Garcia, K. Christopher and Grishin, Nick V. and Adams, Paul D. and Read, Randy J. and Baker, David},
title = {Accurate prediction of protein structures and interactions using a 3-track network},
year = {2021},
doi = {10.1101/2021.06.14.448402},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402},
eprint = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402.full.pdf},
journal = {bioRxiv}
}
@misc{shi2021masked,
title = {Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification},
author = {Yunsheng Shi and Zhengjie Huang and Shikun Feng and Hui Zhong and Wenjin Wang and Yu Sun},
year = {2021},
eprint = {2009.03509},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\graph-transformer-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# package metadata for graph-transformer-pytorch

setup(
    # identity
    name = 'graph-transformer-pytorch',
    version = '0.1.1',
    license = 'MIT',
    packages = find_packages(),
    # description and discovery
    description = 'Graph Transformer - Pytorch',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/graph-transformer-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'graphs'
    ],
    # authorship
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    # runtime dependencies
    install_requires = [
        'einops>=0.3',
        'rotary-embedding-torch',
        'torch>=1.6'
    ],
    # trove classifiers
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
.\lucidrains\h-transformer-1d\h_transformer_1d\autoregressive_wrapper.py
import torch
from torch import nn
import torch.nn.functional as F
# helpers

def exists(val):
    """Return ``True`` if ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
# decorator: run a model method in eval mode, then restore the previous mode

def eval_decorator(fn):
    """Wrap ``fn(model, ...)`` so it runs with ``model`` in eval mode.

    The model's previous ``training`` flag is restored afterwards, even if
    ``fn`` raises. ``functools.wraps`` keeps ``fn``'s name and docstring.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore even on error so a failed sampling call does not leave
            # the model stuck in eval mode
            model.train(was_training)
    return inner
# top-k filtering

def top_k(logits, thres = 0.9):
    """Keep the largest ``(1 - thres)`` fraction of logits on the last axis; set the rest to -inf.

    At least one logit is always kept. Without this floor, float truncation
    (e.g. ``int((1 - 0.9) * 10) == 0``) could keep none, which makes every
    logit -inf and the following softmax NaN.

    Args:
        logits: tensor of logits; filtering is applied along the last axis.
        thres: fraction of logits to discard.

    Returns:
        Tensor shaped like ``logits``, with non-top-k entries set to -inf.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    # scatter along the same (last) axis that topk selected on
    probs.scatter_(-1, ind, val)
    return probs
# autoregressive wrapper

class AutoregressiveWrapper(nn.Module):
    """Wrap a token-level network for next-token training and sampling.

    ``net`` must expose ``max_seq_len``. It is assumed to map token ids
    ``(b, n)`` to logits ``(b, n, num_tokens)`` — inferred from the
    ``[:, -1, :]`` indexing and the ``cross_entropy`` call below.

    Args:
        net: the wrapped model.
        ignore_index: target value ignored by the loss.
        pad_value: token written over everything after an ``eos_token`` when sampling.
    """
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index
        self.net = net
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        """Sample ``seq_len`` new tokens after ``start_tokens``.

        Args:
            start_tokens: prompt ids, either ``(n,)`` or ``(b, n)``.
            seq_len: number of tokens to sample.
            eos_token: optional id. Once every row contains it, sampling stops
                and every position after the first occurrence is set to ``pad_value``.
            temperature: softmax temperature.
            filter_logits_fn: ``fn(logits, thres = ...)`` applied before sampling.
                It was previously ignored, and ``top_k`` was always used.
            filter_thres: threshold passed to ``filter_logits_fn``.
            **kwargs: forwarded to ``net``.

        Returns:
            Only the generated tokens (the prompt is stripped), with the same
            rank as ``start_tokens``.
        """
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        _, t = start_tokens.shape
        out = start_tokens

        for _ in range(seq_len):
            # condition on at most the last max_seq_len tokens
            x = out[:, -self.max_seq_len:]
            logits = self.net(x, **kwargs)[:, -1, :]

            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim = -1)

            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim = -1)

            if exists(eos_token):
                # NOTE(review): the check also covers the prompt, so an eos
                # inside the prompt counts as already finished
                is_eos_tokens = (out == eos_token)

                if is_eos_tokens.any(dim = -1).all():
                    # mask out everything after the eos tokens
                    # (this previously referenced an undefined name)
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    def forward(self, x, **kwargs):
        """Return the next-token cross-entropy loss for token ids ``x`` of shape ``(b, n)``."""
        xi = x[:, :-1]
        xo = x[:, 1:]
        out = self.net(xi, **kwargs)
        # cross_entropy expects the class axis second
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
.\lucidrains\h-transformer-1d\h_transformer_1d\h_transformer_1d.py
# 从 math 模块中导入 log2 和 ceil 函数
from math import log2, ceil
# 从 functools 模块中导入 wraps 函数
from functools import wraps
import torch
# 从 torch 模块中导入 nn, einsum, diagonal 和 nn.functional 模块
from torch import nn, einsum, diagonal
import torch.nn.functional as F
# 从 rotary_embedding_torch 模块中导入 apply_rotary_emb 和 RotaryEmbedding 类
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding
# 从 einops 模块中导入 rearrange, reduce, repeat 函数
from einops import rearrange, reduce, repeat
# 从 h_transformer_1d.reversible 模块中导入 ReversibleSequence 和 SequentialSequence 类
from h_transformer_1d.reversible import ReversibleSequence, SequentialSequence
# helpers

def exists(val):
    """Tell whether ``val`` holds a value, i.e. is anything but ``None``."""
    return not (val is None)
# masked sum / mean

def masked_aggregate(tensor, mask = None, dim = -1, average = True):
    """Sum or average ``tensor`` along ``dim``, counting only positions where ``mask`` is True.

    Args:
        tensor: values to aggregate.
        mask: optional boolean mask. Trailing singleton axes are appended so it
            broadcasts against ``tensor``. When omitted, a plain sum/mean is used.
        dim: axis to reduce.
        average: mean when True, sum when False.

    Returns:
        The reduced tensor. Masked means with no selected elements are 0.
    """
    if mask is None:
        reducer = torch.mean if average else torch.sum
        return reducer(tensor, dim = dim)

    extra_axes = len(tensor.shape) - len(mask.shape)
    mask = mask[(..., *([None] * extra_axes))]

    zeroed = tensor.masked_fill(~mask, 0.)
    count = mask.sum(dim = dim)
    total = zeroed.sum(dim = dim)

    if not average:
        return total

    # guard against division by zero, then force empty groups to exactly 0
    result = total / count.clamp(min = 1.)
    result.masked_fill_(count == 0, 0.)
    return result
# token shift along the sequence axis

def shift(t, amount, mask = None):
    """Shift ``t`` by ``amount`` steps along its second-to-last axis, zero-filling vacated slots.

    A positive ``amount`` moves content toward later positions. When ``mask`` is
    given, masked-out positions are zeroed before shifting. ``amount == 0``
    returns ``t`` unchanged (same object).
    """
    if amount == 0:
        return t

    if mask is not None:
        t = t.masked_fill(~mask[..., None], 0.)

    # pad spec covers (last axis left, right, second-to-last axis front, back);
    # the negative back pad crops, so the length is preserved
    pad_spec = (0, 0, amount, -amount)
    return F.pad(t, pad_spec, value = 0.)
# helper classes

class PreNorm(nn.Module):
    """Apply LayerNorm over the last axis, then call the wrapped module ``fn``."""
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
# position-wise feedforward

class FeedForward(nn.Module):
    """Position-wise MLP: ``Linear(dim, dim * mult) -> GELU -> Linear(dim * mult, dim)``."""
    def __init__(
        self,
        dim,
        *,
        mult = 4
    ):
        super().__init__()
        hidden_dim = dim * mult
        self.net = nn.Sequential(*[
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        ])

    def forward(self, x):
        out = self.net(x)
        return out
# token shifting

class PreShiftTokens(nn.Module):
    """Shift leading feature slices of ``x`` along the sequence, then call ``fn``.

    The last axis is split into chunks of ``dim // len(shifts)`` features.
    The first ``len(shifts)`` chunks are shifted by the corresponding amount
    (via ``shift``), and any remaining chunks pass through unchanged. A ``mask``
    keyword, if present, is used for the shift and is also forwarded to ``fn``.
    """
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        num_segments = len(self.shifts)
        chunk_size = x.shape[-1] // num_segments
        chunks = x.split(chunk_size, dim = -1)

        shifted = [
            shift(chunk, amount, mask = mask)
            for chunk, amount in zip(chunks[:num_segments], self.shifts)
        ]
        untouched = chunks[num_segments:]

        x = torch.cat((*shifted, *untouched), dim = -1)
        return self.fn(x, **kwargs)
# hierarchical attention helper functions

def cast_for_op(cast_type, fn):
    """Wrap ``fn`` so it runs on ``t`` cast to ``cast_type``; the result is cast back to ``t``'s dtype."""
    @wraps(fn)
    def inner(t, *args, **kwargs):
        original_dtype = t.dtype
        result = fn(t.type(cast_type), *args, **kwargs)
        return result.type(original_dtype)
    return inner
# swap adjacent pairs along the sequence axis

def flip_every_two(t):
    """Swap each adjacent pair along axis 1 of ``t`` (positions 0<->1, 2<->3, ...).

    Axis 1 must have even length. Per the original author's note, the flip
    makes attention cover the off-diagonal blocks of the attention matrix.
    """
    paired = rearrange(t, 'b (n r) ... -> b n r ...', r = 2)
    swapped = paired.flip((2,))
    return rearrange(swapped, 'b n r ... -> b (n r) ...')
# attention

class HAttention1D(nn.Module):
    """Non-causal hierarchical 1D attention; only the constructor appears in this listing.

    NOTE(review): no ``forward`` is defined here, so calling an instance falls
    through to ``nn.Module.forward`` and raises ``NotImplementedError``.
    Confirm against the upstream source.

    Args:
        dim: model (input/output) dimension.
        heads: number of attention heads.
        dim_head: per-head dimension.
        block_size: block size, stored as ``self.block_size``.
        pos_emb: optional positional-embedding module, stored as ``self.pos_emb``.
        eps: small constant, stored as ``self.eps``.
        **kwargs: accepted and ignored.
    """
    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_head = 64,
        block_size = 16,
        pos_emb = None,
        eps = 1e-8,
        **kwargs
    ):
        super().__init__()
        self.heads = heads
        self.block_size = block_size
        self.eps = eps
        self.scale = dim_head ** -0.5

        self.pos_emb = pos_emb

        projected_dim = heads * dim_head
        # fused query/key/value projection, then output projection back to dim
        self.to_qkv = nn.Linear(dim, projected_dim * 3, bias = False)
        self.to_out = nn.Linear(projected_dim, dim)
# causal attention

class CausalHAttention1D(nn.Module):
    """Causal hierarchical 1D attention; only the constructor appears in this listing.

    It builds the q/k/v and output projections and precomputes a boolean
    ``mask`` buffer. NOTE(review): no ``forward`` is defined for this class
    here; confirm against the upstream source.

    Args:
        dim: model (input/output) dimension.
        max_seq_len: maximum sequence length; sizes the mask buffer.
        heads: number of attention heads.
        dim_head: per-head dimension.
        block_size: block size; ``max_seq_len // block_size`` sets the number of levels.
        eps: small constant stored as ``self.eps`` (not used by this constructor).
        pos_emb: optional positional-embedding module stored as ``self.pos_emb``.
    """
    def __init__(
        self,
        dim,
        *,
        max_seq_len,
        heads = 8,
        dim_head = 64,
        block_size = 16,
        eps = 1e-8,
        pos_emb = None
    ):
        super().__init__()
        self.eps = eps
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.block_size = block_size
        inner_dim = heads * dim_head
        self.pos_emb = pos_emb
        # fused projection producing queries, keys and values (3 * inner_dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # projection from concatenated heads back to dim
        self.to_out = nn.Linear(inner_dim, dim)
        # derive the mask
        # number of coarsened levels. A non-positive value (e.g. when
        # max_seq_len == block_size) makes the loop below run zero times.
        num_levels = int(log2(max_seq_len // block_size)) - 1
        root_seq = torch.arange(max_seq_len)
        seqs = [root_seq]
        seq = root_seq
        # Each level takes the pairwise max (halving the length) and repeats it
        # back to full length. For every position this gives the largest index in
        # its aligned group of 2 ** (ind + 1) positions.
        for ind in range(num_levels):
            seq = rearrange(seq, '(n r) -> n r', r = 2)
            seq = seq.max(dim = -1).values
            expanded_mask_seq = repeat(seq, 'n -> (n r)', r = (2 ** (ind + 1)))
            seqs.append(expanded_mask_seq)
        # one row per level (the first row is the identity sequence), each of length max_seq_len
        seq_keys = torch.stack(seqs, dim = 0)
        # True where the group maximum lies after the position itself, i.e. the
        # group at that level contains future positions
        mask = seq_keys > rearrange(root_seq, 'n -> () n')
        # registered as a buffer so it follows .to()/.cuda() and is saved in state_dict
        self.register_buffer('mask', mask)
# main class

class HTransformer1D(nn.Module):
    """Token-level H-Transformer-1D model producing per-token logits.

    Tokens are embedded, then passed through ``depth`` pre-norm (attention,
    feedforward) residual pairs (sequential or reversible), then a
    LayerNorm + Linear head.

    Args:
        num_tokens: vocabulary size.
        dim: model dimension.
        depth: number of (attention, feedforward) pairs.
        max_seq_len: maximum sequence length. It must be divisible by
            ``block_size``, and ``max_seq_len // block_size`` must be a power of 2.
        causal: use ``CausalHAttention1D`` instead of ``HAttention1D``.
        heads: number of attention heads.
        dim_head: per-head dimension.
        ff_mult: hidden-size multiplier of the feedforward.
        block_size: attention block size (Nr).
        pos_emb: accepted but NOT used; a ``RotaryEmbedding`` over ``dim_head``
            is always created.
        reversible: use ``ReversibleSequence`` instead of ``SequentialSequence``.
        shift_tokens: shift feature slices along the sequence before attention
            and feedforward.
    """
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        causal = False,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        block_size = 128,
        pos_emb = None,
        reversible = False,
        shift_tokens = False
    ):
        super().__init__()
        assert (max_seq_len % block_size) == 0, 'maximum sequence length must be divisible by the block size'
        num_blocks = max_seq_len // block_size
        assert log2(num_blocks).is_integer(), f'number of blocks {num_blocks} must be a power of 2'

        self.token_emb = nn.Embedding(num_tokens, dim)
        # one rotary embedding shared by all attention layers
        self.pos_emb = RotaryEmbedding(dim = dim_head)
        self.max_seq_len = max_seq_len

        layers = nn.ModuleList([])

        attn_class = CausalHAttention1D if causal else HAttention1D
        attn_kwargs = dict(max_seq_len = max_seq_len) if causal else dict()

        # Token-shift offsets: a causal model may only pull from earlier
        # positions (0, 1); a non-causal one can also pull from the next (-1).
        # This was previously keyed on `shift_tokens`, which made the
        # non-causal range unreachable (the ranges are only used when shifting).
        shift_token_ranges = (0, 1) if causal else (-1, 0, 1)

        for _ in range(depth):
            attn = attn_class(dim, dim_head = dim_head, heads = heads, block_size = block_size, pos_emb = self.pos_emb, **attn_kwargs)
            ff = FeedForward(dim, mult = ff_mult)

            if shift_tokens:
                attn, ff = map(lambda t: PreShiftTokens(shift_token_ranges, t), (attn, ff))

            attn, ff = map(lambda t: PreNorm(dim, t), (attn, ff))
            layers.append(nn.ModuleList([attn, ff]))

        execute_type = ReversibleSequence if reversible else SequentialSequence
        # route the `mask` kwarg to attention (f) only, never to the feedforward (g)
        route_attn = ((True, False),) * depth
        attn_route_map = {'mask': route_attn}

        self.layers = execute_type(layers, args_route = {**attn_route_map})

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x, mask = None):
        """Return logits ``(b, n, num_tokens)`` for token ids ``x`` of shape ``(b, n)``.

        Args:
            x: token ids; ``n`` must not exceed ``max_seq_len``.
            mask: optional mask, forwarded to the attention layers only.
        """
        _, n = x.shape
        assert n <= self.max_seq_len, 'sequence length must be less than or equal to the maximum sequence length'
        x = self.token_emb(x)
        x = self.layers(x, mask = mask)
        return self.to_logits(x)
.\lucidrains\h-transformer-1d\h_transformer_1d\reversible.py
# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 从 operator 模块中导入 itemgetter 函数
from operator import itemgetter
# 从 torch.autograd.function 模块中导入 Function 类
from torch.autograd.function import Function
# 从 torch.utils.checkpoint 模块中导入 get_device_states 和 set_device_states 函数
# distribute forward kwargs to the f / g halves of each layer

def route_args(router, args, depth):
    """Split keyword arguments between the ``f`` and ``g`` sub-layers of each layer.

    Args:
        router: maps a kwarg name to a per-layer sequence of ``(to_f, to_g)`` flags.
        args: the keyword arguments received by the sequence's ``forward``.
        depth: number of layers.

    Returns:
        A list of ``depth`` ``(f_kwargs, g_kwargs)`` dict pairs. Kwargs whose
        name is not in ``router`` are dropped.
    """
    routed = [({}, {}) for _ in range(depth)]

    for name in [k for k in args.keys() if k in router]:
        value = args[name]
        for layer_idx, (current, flags) in enumerate(zip(routed, router[name])):
            to_f, to_g = flags
            f_kwargs, g_kwargs = current
            if to_f:
                f_kwargs = {**f_kwargs, name: value}
            if to_g:
                g_kwargs = {**g_kwargs, name: value}
            routed[layer_idx] = (f_kwargs, g_kwargs)

    return routed
# stochastic layer dropping

def layer_drop(layers, prob):
    """Drop each element of ``layers`` independently with probability ``prob``.

    At least one element is kept: if everything would be dropped, the first is returned.
    """
    drop_flags = torch.empty(len(layers)).uniform_(0, 1) < prob
    kept = [layer for layer, dropped in zip(layers, drop_flags) if not dropped]
    if len(kept) == 0:
        return layers[:1]
    return kept
# record / replay RNG state around a module call

class Deterministic(nn.Module):
    """Wrap ``net`` so the RNG state of one call can be recorded and replayed on a later call.

    ``forward(..., record_rng=True)`` saves the CPU RNG state (and, if CUDA has
    been initialised, the GPU states for the argument tensors' devices).
    ``forward(..., set_rng=True)`` runs ``net`` inside ``torch.random.fork_rng``
    with those states restored, so randomness inside ``net`` repeats exactly.
    """
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        """Save the current RNG state(s) for a later ``set_rng=True`` call."""
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            # imported here: this module never imported these helpers, which
            # raised NameError on the CUDA path
            from torch.utils.checkpoint import get_device_states
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        # fork so that restoring the saved state does not disturb the global RNG
        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                from torch.utils.checkpoint import set_device_states
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)
# reversible block, inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU works, refactor and send a PR back to the source

class ReversibleBlock(nn.Module):
    """Reversible residual coupling: ``y1 = x1 + f(x2)``, ``y2 = x2 + g(y1)``.

    ``f`` and ``g`` are wrapped in ``Deterministic`` so the RNG state used in
    the forward pass can be replayed when they are recomputed in ``backward_pass``.
    """
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        """Apply the coupling to ``x``, split into two halves along dim 2.

        This runs under ``torch.no_grad()`` and keeps no autograd graph, because
        ``backward_pass`` recomputes what it needs from the outputs. RNG state
        is recorded only in training mode.
        """
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None
        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        """Reconstruct this block's input from its output and backpropagate through it.

        Args:
            y: the block output (as returned by ``forward``).
            dy: gradient with respect to ``y``.

        Returns:
            ``(x, dx)``: the reconstructed input and the gradient with respect to it.
            Parameter gradients of ``f`` and ``g`` accumulate as a side effect of
            the ``torch.autograd.backward`` calls.
        """
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y
        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy
        # recompute g(y1) with the recorded RNG state and push dy2 through it;
        # this fills y1.grad and g's parameter grads
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)
        with torch.no_grad():
            # invert y2 = x2 + g(y1)
            x2 = y2 - gy1
            del y2, gy1
            # total gradient into y1: the direct path plus the path through g
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None
        # recompute f(x2) and push dx1 through it; fills x2.grad and f's parameter grads
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)
        with torch.no_grad():
            # invert y1 = x1 + f(x2)
            x1 = y1 - fx2
            del y1, fx2
            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None
            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)
        return x, dx
# reversible autograd function

class _ReversibleFunction(Function):
    """Autograd function that runs reversible blocks without storing their activations.

    Only the final output is kept. ``backward`` walks the blocks in reverse,
    and each one reconstructs its input from its output via ``backward_pass``.
    """
    @staticmethod
    # ctx: autograd context; x: input; blocks: ReversibleBlocks;
    # args: one kwargs dict per block, e.g. {'f_args': ..., 'g_args': ...}
    def forward(ctx, x, blocks, args):
        ctx.args = args
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        # detached view of the output: the starting point for reconstruction in backward
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        """Propagate ``dy`` back through the blocks in reverse order."""
        y = ctx.y
        args = ctx.args
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            # each block returns its reconstructed input and the gradient w.r.t. it
            y, dy = block.backward_pass(y, dy, **kwargs)
        # gradient only for x; blocks and args are not differentiable inputs
        return dy, None, None
class SequentialSequence(nn.Module):
    """Apply ``(f, g)`` layer pairs in order with residuals: ``x = x + f(x); x = x + g(x)``.

    Args:
        layers: sequence of ``(f, g)`` pairs.
        args_route: maps a forward kwarg name to one ``(to_f, to_g)`` flag pair
            per layer, selecting which sub-layer receives that kwarg. Defaults
            to no routing.
        layer_dropout: probability of skipping each layer pair during training.
    """
    def __init__(self, layers, args_route = None, layer_dropout = 0.):
        super().__init__()
        # give each instance its own dict instead of a shared mutable default
        args_route = {} if args_route is None else args_route
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        if self.training and self.layer_dropout > 0:
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)

        for (f, g), (f_args, g_args) in layers_and_args:
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x
class ReversibleSequence(nn.Module):
    """Run ``(f, g)`` pairs as reversible blocks, trading recomputation for activation memory.

    The input is duplicated along the last axis to form the two streams. The
    blocks run through ``_ReversibleFunction``, and the two output halves are
    summed at the end.

    Args:
        blocks: sequence of ``(f, g)`` pairs, each wrapped in a ``ReversibleBlock``.
        args_route: maps a forward kwarg name to one ``(to_f, to_g)`` flag pair
            per block. Defaults to no routing.
        layer_dropout: probability of skipping each block during training.
    """
    def __init__(self, blocks, args_route = None, layer_dropout = 0.):
        super().__init__()
        # give each instance its own dict instead of a shared mutable default
        self.args_route = {} if args_route is None else args_route
        self.layer_dropout = layer_dropout
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        # two identical streams, concatenated on the last axis
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        layers_and_args = list(zip(blocks, args))

        if self.training and self.layer_dropout > 0:
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)
            blocks, args = map(lambda ind: list(map(itemgetter(ind), layers_and_args)), (0, 1))

        out = _ReversibleFunction.apply(x, blocks, args)
        # merge the two streams by summation
        return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
.\lucidrains\h-transformer-1d\h_transformer_1d\__init__.py
# 从 h_transformer_1d.h_transformer_1d 模块中导入 HTransformer1D 类
from h_transformer_1d.h_transformer_1d import HTransformer1D
H-Transformer-1D
Implementation of H-Transformer-1D, Transformer using hierarchical Attention for sequence learning with subquadratic costs. The encoder (non-autoregressive) flavor of this architecture currently holds the throne for Long Range Arena, a benchmark for efficient transformers.
Install
$ pip install h-transformer-1d
Usage
import torch
from h_transformer_1d import HTransformer1D
model = HTransformer1D(
num_tokens = 256, # number of tokens
dim = 512, # dimension
depth = 12, # depth
causal = False, # autoregressive or not
max_seq_len = 8192, # maximum sequence length
heads = 8, # heads
dim_head = 64, # dimension per head
block_size = 128, # block size
reversible = True, # use reversibility, to save on memory with increased depth
shift_tokens = True # whether to shift half the feature space by one along the sequence dimension, for faster convergence (experimental feature)
)
x = torch.randint(0, 256, (1, 8000)) # variable sequence length
mask = torch.ones((1, 8000)).bool() # variable mask length
# network will automatically pad to power of 2, do hierarchical attention, etc
logits = model(x, mask = mask) # (1, 8000, 256)
Citations
@misc{zhu2021htransformer1d,
title = {H-Transformer-1D: Fast One-Dimensional Hierarchical Attention for Sequences},
author = {Zhenhai Zhu and Radu Soricut},
year = {2021},
eprint = {2107.11906},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = {aug},
year = {2021},
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196578},
url = {https://doi.org/10.5281/zenodo.5196578}
}
.\lucidrains\h-transformer-1d\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# package metadata for h-transformer-1d

setup(
    # identity
    name = 'h-transformer-1d',
    version = '0.1.9',
    license = 'MIT',
    packages = find_packages(),
    # description and discovery
    description = 'H-Transformer 1D - Pytorch',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/h-transformer-1d',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'efficient attention'
    ],
    # authorship
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    # runtime dependencies
    install_requires = [
        'einops>=0.3',
        'rotary-embedding-torch>=0.5.3',
        'torch>=1.6'
    ],
    # trove classifiers
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\h-transformer-1d\train.py
# 导入所需的模块和类
from h_transformer_1d import HTransformer1D
from h_transformer_1d.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# training hyperparameters

NUM_BATCHES = 100_000             # optimizer steps
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4     # micro-batches per optimizer step
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100              # steps between validation passes
GENERATE_EVERY = 500              # steps between sampling runs
GENERATE_LENGTH = 512             # tokens sampled per run
SEQ_LEN = 4096                    # training context length
# helpers

def decode_token(token):
    """Map a byte value to a character; control bytes below 32 become a space."""
    return chr(token) if token >= 32 else chr(32)
# decode a token sequence into a string

def decode_tokens(tokens):
    """Decode a sequence of byte values into a string, rendering control bytes (< 32) as spaces."""
    return ''.join(str(chr(max(32, token))) for token in tokens)
# instantiate a GPT-like (causal) decoder model

model = HTransformer1D(
    num_tokens = 256,
    dim = 512,
    max_seq_len = SEQ_LEN,
    depth = 8,
    heads = 8,
    causal = True,
    reversible = True
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data: the first 90M bytes for training, the rest (up to 5M) for validation

with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring's binary mode is deprecated; frombuffer replaces it.
    # .copy() yields a writable array, which torch.from_numpy expects.
    X = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# dataset of random fixed-length windows

class TextSamplerDataset(Dataset):
    """Sample random windows of ``seq_len + 1`` tokens from a 1-D token tensor.

    The index argument is ignored, so every access draws a fresh random window.
    Windows are returned as int64 tensors on the GPU.
    """
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return self.data.size(0) // self.seq_len

    def __getitem__(self, index):
        start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        window = self.data[start: start + self.seq_len + 1]
        return window.long().cuda()
# endlessly iterate a DataLoader (used below but previously never defined)

def cycle(loader):
    """Yield batches from ``loader`` forever, starting a new pass whenever it is exhausted."""
    while True:
        for data in loader:
            yield data

# training / validation datasets and their infinite loaders

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer, named so it no longer shadows the `torch.optim as optim` import

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    # accumulate gradients over several micro-batches before each step
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
    optimizer.step()
    optimizer.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # previously mixed an f-string with %-placeholders and printed the raw tuple
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
.\lucidrains\halonet-pytorch\halonet_pytorch\halonet_pytorch.py
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
# 导入所需的库
# helpers for the relative positional embedding

def to(x):
    """Return ``x``'s device and dtype as a kwargs dict for tensor factory functions."""
    return dict(device = x.device, dtype = x.dtype)
# normalise a scalar-or-tuple argument to a tuple

def pair(x):
    """Return ``x`` unchanged if it is a tuple, otherwise the 2-tuple ``(x, x)``."""
    if isinstance(x, tuple):
        return x
    return (x, x)
# insert and broadcast a new axis

def expand_dim(t, dim, k):
    """Insert a new axis at ``dim`` and broadcast it to size ``k`` (via ``expand``, no copy)."""
    t = t.unsqueeze(dim = dim)
    target_shape = [-1 for _ in t.shape]
    target_shape[dim] = k
    return t.expand(*target_shape)
# convert relative-position logits to absolute-position logits

def rel_to_abs(x):
    """Re-index logits from relative offsets to absolute key positions using pad-and-reshape.

    Args:
        x: tensor unpacked as ``(b, l, m)``. ``m`` counts relative offsets; in
            this file it is ``2 * rel_size - 1`` (see ``RelPosEmb``), and
            ``m >= l`` is needed for the padding below.

    Returns:
        Tensor of shape ``(b, l, r)`` with ``r = (m + 1) // 2``.
    """
    b, l, m = x.shape
    r = (m + 1) // 2
    # append one zero column: (b, l, m + 1)
    col_pad = torch.zeros((b, l, 1), **to(x))
    x = torch.cat((x, col_pad), dim = 2)
    # flatten rows, then pad with m - l zeros so the total length is (l + 1) * m
    flat_x = rearrange(x, 'b l c -> b (l c)')
    flat_pad = torch.zeros((b, m - l), **to(x))
    flat_x_padded = torch.cat((flat_x, flat_pad), dim = 1)
    # re-split into rows of length m: the extra column shifts each successive
    # row by one, which is what realigns relative offsets to absolute positions
    final_x = flat_x_padded.reshape(b, l + 1, m)
    # keep the first l rows and the last r columns
    final_x = final_x[:, :l, -r:]
    return final_x
# 1D relative-position logits along the last spatial axis

def relative_logits_1d(q, rel_k):
    """Compute relative-position logits along the third axis of ``q``, broadcast over a new axis.

    Args:
        q: tensor unpacked as ``(b, h, w, d)``. Here ``h``/``w`` are spatial
            sizes, not attention heads.
        rel_k: relative embeddings with ``d`` columns. Its row count sets
            ``r = (rows + 1) // 2``.

    Returns:
        Tensor of shape ``(b, h, r, w, r)``. The third axis is a broadcast
        (``expand``) copy, so the logits do not vary along it.
    """
    b, h, w, _ = q.shape
    r = (rel_k.shape[0] + 1) // 2
    # dot product of every query with every relative-offset embedding
    logits = einsum('b x y d, r d -> b x y r', q, rel_k)
    # fold rows into the batch so rel_to_abs sees (batch, l, m)
    logits = rearrange(logits, 'b x y r -> (b x) y r')
    logits = rel_to_abs(logits)
    logits = logits.reshape(b, h, w, r)
    # broadcast over a new axis of size r (the other spatial axis of the key window)
    logits = expand_dim(logits, dim = 2, k = r)
    return logits
# relative positional embedding

class RelPosEmb(nn.Module):
    """Learned 2D relative-position bias for blocked attention.

    Args:
        block_size: side length of the query block.
        rel_size: side length of the key window (HaloAttention passes
            ``block_size + 2 * halo_size``).
        dim_head: embedding dimension; must match the last axis of ``q``.
    """
    def __init__(
        self,
        block_size,
        rel_size,
        dim_head
    ):
        super().__init__()
        height = width = rel_size
        scale = dim_head ** -0.5
        self.block_size = block_size
        # one embedding per relative offset along each axis: 2 * rel_size - 1 rows each
        self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
        self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)

    def forward(self, q):
        """Return relative-position logits for queries ``q``.

        Args:
            q: ``(b, block_size * y, d)``, reshaped to ``(b, block_size, y, d)``.
                HaloAttention passes flattened ``block_size x block_size`` blocks.

        Returns:
            ``(b, block_size * y, rel_size * rel_size)``: the sum of the width and height terms.
        """
        block = self.block_size
        q = rearrange(q, 'b (x y) c -> b x y c', x = block)
        # width term: offsets along the last spatial axis
        rel_logits_w = relative_logits_1d(q, self.rel_width)
        rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')
        # height term: swap the spatial axes, reuse the 1D routine, then swap back in the output
        q = rearrange(q, 'b x y d -> b y x d')
        rel_logits_h = relative_logits_1d(q, self.rel_height)
        rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')
        return rel_logits_w + rel_logits_h
# classes

class HaloAttention(nn.Module):
    """Blocked local self-attention in which each block also sees a "halo" of neighbouring pixels.

    Queries come from non-overlapping ``block_size x block_size`` blocks of the
    feature map. Keys/values come from the same block enlarged by ``halo_size``
    pixels on every side (zero-padded at the borders, and masked out). A learned
    relative-position bias (``RelPosEmb``) is added to the logits.

    Args:
        dim: channel count of the input feature map.
        block_size: side length of each query block; ``h`` and ``w`` must be divisible by it.
        halo_size: extra pixels on each side of a block used for keys/values; must be > 0.
        dim_head: per-head dimension.
        heads: number of attention heads.
    """
    def __init__(
        self,
        *,
        dim,
        block_size,
        halo_size,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        assert halo_size > 0, 'halo size must be greater than 0'
        self.dim = dim
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.block_size = block_size
        self.halo_size = halo_size
        inner_dim = dim_head * heads
        self.rel_pos_emb = RelPosEmb(
            block_size = block_size,
            rel_size = block_size + (halo_size * 2),
            dim_head = dim_head
        )
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        """Apply halo attention to a feature map ``x`` of shape ``(b, c, h, w)``; returns the same shape."""
        b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.device
        assert h % block == 0 and w % block == 0, 'fmap dimensions must be divisible by the block size'
        assert c == self.dim, f'channels for input ({c}) does not equal to the correct dimension ({self.dim})'

        # queries: one row per (image, block); the flattened block pixels form the sequence
        q_inp = rearrange(x, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = block, p2 = block)

        # keys/values: each block enlarged by `halo` pixels on every side
        kv_inp = F.unfold(x, kernel_size = block + halo * 2, stride = block, padding = halo)
        kv_inp = rearrange(kv_inp, 'b (c j) i -> (b i) j c', c = c)

        # derive queries, keys, values and split heads into the batch axis
        q = self.to_q(q_inp)
        k, v = self.to_kv(kv_inp).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = heads), (q, k, v))

        q *= self.scale

        sim = einsum('b i d, b j d -> b i j', q, k)

        # relative positional bias
        sim += self.rel_pos_emb(q)

        # Mask the zero padding that unfold introduces at the borders.
        # Unfolding a map of ones gives 1 for real pixels and 0 for padding,
        # so the padding is where the mask is False.
        mask = torch.ones(1, 1, h, w, device = device)
        mask = F.unfold(mask, kernel_size = block + (halo * 2), stride = block, padding = halo)
        mask = repeat(mask, '() j i -> (b i h) () j', b = b, h = heads)
        mask = mask.bool()

        max_neg_value = -torch.finfo(sim.dtype).max
        # fill where the key is padding. The previous code filled where the mask
        # was True, which masked every real pixel and left only the padding.
        sim.masked_fill_(~mask, max_neg_value)

        attn = sim.softmax(dim = -1)

        # aggregate values, merge heads and project out
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = heads)
        out = self.to_out(out)

        # fold the blocks back into the original feature map layout
        out = rearrange(out, '(b h w) (p1 p2) c -> b c (h p1) (w p2)', b = b, h = (h // block), w = (w // block), p1 = block, p2 = block)
        return out
.\lucidrains\halonet-pytorch\halonet_pytorch\__init__.py
# 从 halonet_pytorch.halonet_pytorch 模块中导入 HaloAttention 类
from halonet_pytorch.halonet_pytorch import HaloAttention
HaloNet - Pytorch
Implementation of the Attention layer from the paper, Scaling Local Self-Attention For Parameter Efficient Visual Backbones. This repository will only house the attention layer and not much more.
Install
$ pip install halonet-pytorch
Usage
import torch
from halonet_pytorch import HaloAttention
attn = HaloAttention(
dim = 512, # dimension of feature map
block_size = 8, # neighborhood block size (feature map must be divisible by this)
halo_size = 4, # halo size (block receptive field)
dim_head = 64, # dimension of each head
heads = 4 # number of attention heads
).cuda()
fmap = torch.randn(1, 512, 32, 32).cuda()
attn(fmap) # (1, 512, 32, 32)
Citations
@misc{vaswani2021scaling,
title = {Scaling Local Self-Attention For Parameter Efficient Visual Backbones},
author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and Jonathon Shlens},
year = {2021},
eprint = {2103.12731},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
.\lucidrains\halonet-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Package metadata for halonet-pytorch, grouped by purpose.
setup(
    # identity
    name = 'halonet-pytorch',
    version = '0.0.4',
    description = 'HaloNet - Pytorch',
    url = 'https://github.com/lucidrains/halonet-pytorch',
    # authorship and licensing
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    license = 'MIT',
    # contents and runtime requirements
    packages = find_packages(),
    install_requires = [
        'einops>=0.3',
        'torch>=1.6',
    ],
    # discovery metadata
    keywords = [
        'artificial intelligence',
        'deep learning',
        'attention mechanism',
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\hamburger-pytorch\hamburger_pytorch\hamburger_pytorch.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn, einsum 模块
from torch import nn, einsum
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 contextlib 模块中导入 contextmanager 上下文管理器
from contextlib import contextmanager
# 从 einops 模块中导入 repeat, rearrange 函数
from einops import repeat, rearrange
# 辅助函数
# A context manager that does nothing; used where a real one (e.g. no_grad) is optional.
@contextmanager
def null_context():
    """Enter and exit without side effects; the ``as`` target is ``None``."""
    yield None
# True for any value other than None (falsy values such as 0 or '' still "exist").
def exists(val):
    return not (val is None)
# Return `val` unless it is None, in which case return the fallback `d`.
def default(val, d):
    if val is None:
        return d
    return val
# 类
# Nonnegative matrix factorization solved with multiplicative updates.
class NMF(nn.Module):
    """Approximate a nonnegative input as ``D @ C`` using ``K`` multiplicative-update steps.

    Args:
        dim: number of rows of ``D`` (the input's feature dimension).
        n: number of columns of ``C``; the output's last dimension.
        ratio: the factorization rank is ``dim // ratio``.
        K: number of multiplicative-update iterations.
        eps: added to every denominator to avoid division by zero.
    """

    def __init__(
        self,
        dim,
        n,
        ratio = 8,
        K = 6,
        eps = 2e-8
    ):
        super().__init__()
        rank = dim // ratio
        # both factors start as uniform random samples; D is drawn before C
        self.K = K
        self.D = nn.Parameter(torch.zeros(dim, rank).uniform_(0, 1))
        self.C = nn.Parameter(torch.zeros(rank, n).uniform_(0, 1))
        self.eps = eps

    def forward(self, x):
        """Factorize ``x`` (batched as ``(b, dim, m)`` for the matmuls to line up) and return ``D @ C``."""
        batch = x.shape[0]
        eps = self.eps

        # multiplicative updates are only valid for a nonnegative input
        x = F.relu(x)

        # give every batch element its own copy of the learned factors
        D = repeat(self.D, 'd r -> b d r', b = batch)
        C = repeat(self.C, 'r n -> b r n', b = batch)

        def transpose(tensor):
            return rearrange(tensor, 'b i j -> b j i')

        # counts K-1 down to 0; only the final step runs with gradients enabled
        # ('One-step Gradient'), every earlier step runs under torch.no_grad()
        for step in range(self.K - 1, -1, -1):
            grad_ctx = null_context() if step == 0 else torch.no_grad()
            with grad_ctx:
                # both updates read the factors from the previous iteration
                C_next = C * ((transpose(D) @ x) / ((transpose(D) @ D @ C) + eps))
                D_next = D * ((x @ transpose(C)) / ((D @ C @ transpose(C)) + eps))
            C, D = C_next, D_next

        return D @ C
# Hamburger block: 1x1 conv "lower bread" -> NMF "ham" -> 1x1 conv "upper bread".
class Hamburger(nn.Module):
    """Matrix-decomposition block that returns a tensor shaped like its input.

    Args:
        dim: channel count of the input (second dimension).
        n: column count of the NMF factor ``C``. It must equal the product of the
            input's trailing dimensions for the final reshape to succeed.
        inner_dim: channel count inside the sandwich; defaults to ``dim``.
        ratio: NMF rank ratio, forwarded to ``NMF``.
        K: NMF iteration count, forwarded to ``NMF``.
    """

    def __init__(
        self,
        *,
        dim,
        n,
        inner_dim = None,
        ratio = 8,
        K = 6
    ):
        super().__init__()
        hidden = default(inner_dim, dim)
        self.lower_bread = nn.Conv1d(dim, hidden, 1, bias = False)
        self.ham = NMF(hidden, n, ratio = ratio, K = K)
        self.upper_bread = nn.Conv1d(hidden, dim, 1, bias = False)

    def forward(self, x):
        original_shape = x.shape
        # merge every dimension after the channel axis into one, so Conv1d can process it
        out = x.flatten(2)
        for layer in (self.lower_bread, self.ham, self.upper_bread):
            out = layer(out)
        return out.reshape(original_shape)
.\lucidrains\hamburger-pytorch\hamburger_pytorch\__init__.py
# 从hamburger_pytorch包中导入Hamburger类
from hamburger_pytorch.hamburger_pytorch import Hamburger
🍔 - Pytorch
Pytorch implementation of the hamburger module from the ICLR 2021 paper Is Attention Better Than Matrix Decomposition? Following Betteridge's law, the answer according to the paper is "No" for segmentation and GANs.
This repository will contain the NMF-MU (nonnegative matrix factorization w/ multiplicative update) module sandwiched by linear projections.
Update: I tried this, but did not get better results than simply using linear attention.
Install
$ pip install hamburger-pytorch
Usage
import torch
from hamburger_pytorch import Hamburger
hamburger = Hamburger(
dim = 512, # input dimension
n = 32 * 32, # n will be size of the sequence, in this case, height times width of the images
ratio = 8, # matrix factorization ratio, recommended to be at 8
K = 6 # number of iterations, optimal at 6 as shown in paper
)
x = torch.randn(1, 512, 32, 32)
hamburger(x) + x # (1, 512, 32, 32)
Citations
@inproceedings{
anonymous2021is,
title={Is Attention Better Than Matrix Decomposition?},
author={Anonymous},
booktitle={Submitted to International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=1FvkSpWosOl},
note={under review}
}
.\lucidrains\hamburger-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# Package metadata for hamburger-pytorch, grouped by purpose.
setup(
    # identity
    name = 'hamburger-pytorch',
    version = '0.0.3',
    description = 'Hamburger - Pytorch',
    url = 'https://github.com/lucidrains/hamburger-pytorch',
    # authorship and licensing
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    license = 'MIT',
    # contents and runtime requirements
    packages = find_packages(),
    install_requires = [
        'torch',
        'einops>=0.3',
    ],
    # discovery metadata
    keywords = [
        'artificial intelligence',
        'attention mechanism',
        'matrix factorization',
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\holodeck-pytorch\holodeck_pytorch\holodeck_pytorch.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn 模块中导入 F 函数
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 定义辅助函数
# True for any value other than None.
def exists(val):
    return not (val is None)
# Return `val` unless it is None, in which case return the fallback `d`.
def default(val, d):
    if val is None:
        return d
    return val
# Attention with multi-head queries over a single shared key/value head.
class Attention(nn.Module):
    """Multi-head queries attending over one key head and one value head shared by all query heads.

    Args:
        dim: feature size of the query input ``x``.
        dim_head: per-head feature size.
        dim_context: feature size of ``context``; defaults to ``dim``.
        heads: number of query heads.
        norm_context: if True, LayerNorm ``context`` before projecting it.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        dim_context = None,
        heads = 8,
        norm_context = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads
        dim_context = default(dim_context, dim)

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # a single key head and a single value head (2 * dim_head), shared across query heads
        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None
    ):
        """
        Args:
            x: queries of shape ``(b, n, dim)``.
            context: optional keys/values source; when omitted, ``x`` attends to itself.
            mask: optional boolean ``(b, j)`` key mask; False positions are excluded.
            attn_bias: optional additive bias, broadcastable to the ``(b, h, i, j)`` scores.
        Returns:
            tensor of shape ``(b, n, dim)``.
        """
        # pre-norm: normalize x BEFORE it is used as the key/value source, so
        # self-attention keys/values see the same normalized input as the queries
        x = self.norm(x)

        if exists(context):
            context = self.context_norm(context)

        kv_input = default(context, x)

        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # k has no head axis; the einsum broadcasts it across all query heads
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        mask_value = -torch.finfo(sim.dtype).max

        if exists(attn_bias):
            sim = sim + attn_bias

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, mask_value)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# Main class (work in progress).
class Holodeck(nn.Module):
    """Placeholder module: currently returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # no processing is implemented yet; pass the input straight through
        output = x
        return output
.\lucidrains\holodeck-pytorch\holodeck_pytorch\__init__.py
# 从 holodeck_pytorch 模块中导入 Holodeck 类
from holodeck_pytorch.holodeck_pytorch import Holodeck
Holodeck - Pytorch (wip)
Implementation of a holodeck, written in Pytorch.
Citations
@article{Wu20234DGS,
title = {4D Gaussian Splatting for Real-Time Dynamic Scene Rendering},
author = {Guanjun Wu and Taoran Yi and Jiemin Fang and Lingxi Xie and Xiaopeng Zhang and Wei Wei and Wenyu Liu and Qi Tian and Xinggang Wang},
journal = {ArXiv},
year = {2023},
volume = {abs/2310.08528},
url = {https://api.semanticscholar.org/CorpusID:263908793}
}
@inproceedings{Singer2023TextTo4DDS,
title = {Text-To-4D Dynamic Scene Generation},
author = {Uriel Singer and Shelly Sheynin and Adam Polyak and Oron Ashual and Iurii Makarov and Filippos Kokkinos and Naman Goyal and Andrea Vedaldi and Devi Parikh and Justin Johnson and Yaniv Taigman},
year = {2023}
}
@inproceedings{Bauer2023SpatialFS,
title = {Spatial Functa: Scaling Functa to ImageNet Classification and Generation},
author = {M. Bauer and Emilien Dupont and Andy Brock and Dan Rosenbaum and Jonathan Schwarz and Hyunjik Kim},
year = {2023}
}
.\lucidrains\holodeck-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Package metadata for holodeck-pytorch, grouped by purpose.
setup(
    # identity
    name = 'holodeck-pytorch',
    version = '0.0.1',
    description = 'Holodeck - Pytorch',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/holodeck-pytorch',
    # authorship and licensing
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    license = 'MIT',
    # contents and runtime requirements
    packages = find_packages(exclude = []),
    install_requires = [
        'einops>=0.4',
        'torch>=1.6',
    ],
    # discovery metadata
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
        'denoising diffusion',
        'temporal scene representations',
        'hypernetworks',
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
.\lucidrains\hourglass-transformer-pytorch\hourglass_transformer_pytorch\autoregressive_wrapper.py
import torch
from torch import nn
import torch.nn.functional as F
# helper function
# True for any value other than None.
def exists(val):
    return not (val is None)
# Decorator: run a model method in eval mode, then restore the model's previous training mode.
def eval_decorator(fn):
    """Wrap ``fn(model, *args, **kwargs)`` so it runs with ``model.eval()``.

    The previous ``model.training`` flag is restored afterwards, even when ``fn`` raises.
    The wrapper keeps ``fn``'s name and docstring, via ``functools.wraps``.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore even on exceptions, so a failed call cannot leave the model stuck in eval mode
            model.train(was_training)
    return inner
# top k filtering
# Top-k filtering: keep the largest logits and set the rest to -inf.
def top_k(logits, thres = 0.9):
    """Keep the top ``(1 - thres)`` fraction of ``logits`` along the last dimension.

    Every other position is set to ``-inf``. At least one logit is always kept, so a
    following softmax never receives an all ``-inf`` row (which would yield NaNs).

    Args:
        logits: tensor whose last dimension holds the vocabulary scores.
        thres: fraction of logits to discard.
    Returns:
        tensor shaped like ``logits``.
    """
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = torch.topk(logits, k, dim = -1)
    probs = torch.full_like(logits, float('-inf'))
    # scatter along the same (last) dimension that topk selected from
    probs.scatter_(-1, ind, val)
    return probs
# Wrapper that adds next-token training loss and sampling to a sequence model.
class AutoregressiveWrapper(nn.Module):
    """Autoregressive training and sampling around ``net``.

    ``net`` must expose ``max_seq_len``. It is called as ``net(tokens, **kwargs)``, and
    (per the indexing below) returns logits shaped ``(batch, seq, num_tokens)``.

    Args:
        net: the wrapped token model.
        pad_value: token written after an end-of-sequence token during generation;
            also used as ``ignore_index`` in the training loss.
    """

    def __init__(self, net, pad_value = 0):
        super().__init__()
        assert hasattr(net, 'max_seq_len'), 'your transformer class must have max_seq_len set to the maximum sequence length'
        self.pad_value = pad_value
        self.net = net
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs):
        """Sample ``seq_len`` tokens after ``start_tokens``, which has shape ``(batch, prompt_len)``.

        Each step works as follows:
            - feed at most the last ``max_seq_len`` tokens;
            - keep the top logits via ``top_k``;
            - sample from the temperature-scaled softmax.

        If ``eos_token`` is given, sampling can stop early. This happens once every row of
        the running sequence contains it; the check also sees the prompt tokens. Everything
        after the first eos in each row is then replaced by ``pad_value``.

        Returns:
            only the generated continuation (the prompt is stripped).
        """
        _, prompt_len = start_tokens.shape
        out = start_tokens

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]

            logits = self.net(x, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim = -1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim = -1)

            if exists(eos_token):
                is_eos_tokens = (out == eos_token)

                if is_eos_tokens.any(dim = -1).all():
                    # mask out everything after the eos tokens (the eos itself is kept)
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        return out[:, prompt_len:]

    def forward(self, x, **kwargs):
        """Teacher-forced next-token loss: predict ``x[:, 1:]`` from ``x[:, :-1]``.

        Targets equal to ``pad_value`` are ignored.
        """
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        logits = self.net(x_inp, **kwargs)
        return F.cross_entropy(logits.transpose(1, 2), x_labels, ignore_index = self.pad_value)
.\lucidrains\hourglass-transformer-pytorch\hourglass_transformer_pytorch\hourglass_transformer_pytorch.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块、einsum 函数
from torch import nn, einsum
# 从 torch.nn.functional 库中导入 F 模块
import torch.nn.functional as F
# 从 einops 库中导入 rearrange、reduce、repeat 函数
from einops import rearrange, reduce, repeat
# helpers
# True for any value other than None.
def exists(val):
    return not (val is None)
# Return `val` unless it is None, in which case return the fallback `d`.
def default(val, d):
    if val is None:
        return d
    return val
# Right-pad a tensor along one dimension up to the next multiple of `multiple`.
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    """Pad ``tensor`` at the end of ``dim`` with ``value`` so its length there is divisible by ``multiple``.

    ``dim`` must be negative (counted from the last dimension), because the ``F.pad``
    spec built here lists dimensions from the last one backwards.

    Returns:
        ``tensor`` itself when no padding is needed, otherwise a padded copy.
    """
    seq_len = tensor.shape[dim]
    # integer arithmetic: equivalent to ceil(seq_len / multiple) * multiple - seq_len
    # (the previous version called math.ceil without importing math)
    remainder = -seq_len % multiple
    if remainder == 0:
        return tensor
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(tensor, (*pad_offset, 0, remainder), value = value)
# Return tuples unchanged; otherwise repeat `val` `depth` times inside a tuple.
def cast_tuple(val, depth = 1):
    if isinstance(val, tuple):
        return val
    return (val,) * depth
# factory
# Factory: a plain Transformer for an integer depth, otherwise a (possibly nested) HourglassTransformer.
def get_hourglass_transformer(
    dim,
    *,
    depth,
    shorten_factor,
    attn_resampling,
    updown_sample_type,
    **kwargs
):
    """Build the transformer described by ``depth``.

    Args:
        depth: an int (a flat Transformer with that many layers), or a 3-tuple
            ``(pre_depth, valley_depth, post_depth)`` describing an hourglass.
        shorten_factor: must be falsy for an int ``depth``; forwarded to the hourglass otherwise.
        attn_resampling, updown_sample_type: forwarded to ``HourglassTransformer`` only.
        **kwargs: forwarded to whichever class is built.
    """
    is_flat = isinstance(depth, int)
    assert is_flat or (isinstance(depth, tuple) and len(depth) == 3), 'depth must be either an integer or a tuple of 3, indicating (pre_transformer_depth, <nested-hour-glass-config>, post_transformer_depth)'
    assert not (is_flat and shorten_factor), 'there does not need to be a shortening factor when only a single transformer block is indicated (depth of one integer value)'

    if not is_flat:
        return HourglassTransformer(
            dim = dim,
            depth = depth,
            shorten_factor = shorten_factor,
            attn_resampling = attn_resampling,
            updown_sample_type = updown_sample_type,
            **kwargs
        )

    return Transformer(dim = dim, depth = depth, **kwargs)
# up and down sample classes
# Downsampling by mean-pooling consecutive groups of tokens.
class NaiveDownsample(nn.Module):
    """Shorten ``(b, n * s, d)`` to ``(b, n, d)`` by averaging each run of ``s = shorten_factor`` tokens."""

    def __init__(self, shorten_factor):
        super().__init__()
        self.shorten_factor = shorten_factor

    def forward(self, x):
        # the sequence length must already be divisible by shorten_factor
        pooled = reduce(x, 'b (n s) d -> b n d', 'mean', s = self.shorten_factor)
        return pooled
# Upsampling by repeating every token.
class NaiveUpsample(nn.Module):
    """Lengthen ``(b, n, d)`` to ``(b, n * s, d)`` by repeating each token ``s = shorten_factor`` times in place."""

    def __init__(self, shorten_factor):
        super().__init__()
        self.shorten_factor = shorten_factor

    def forward(self, x):
        expanded = repeat(x, 'b n d -> b (n s) d', s = self.shorten_factor)
        return expanded
# Downsampling by concatenating groups of tokens and projecting them back to `dim`.
class LinearDownsample(nn.Module):
    """Map ``(b, n * s, dim)`` to ``(b, n, dim)``.

    Each group of ``s = shorten_factor`` tokens is concatenated into one vector of size
    ``s * dim``, which is then projected linearly back to ``dim``.
    """

    def __init__(self, dim, shorten_factor):
        super().__init__()
        self.proj = nn.Linear(dim * shorten_factor, dim)
        self.shorten_factor = shorten_factor

    def forward(self, x):
        grouped = rearrange(x, 'b (n s) d -> b n (s d)', s = self.shorten_factor)
        return self.proj(grouped)
# Upsampling by projecting each token to `shorten_factor` tokens.
class LinearUpsample(nn.Module):
    """Map ``(b, n, dim)`` to ``(b, n * s, dim)``.

    Each token is projected to ``s * dim`` features, which are then split into
    ``s = shorten_factor`` consecutive tokens.
    """

    def __init__(self, dim, shorten_factor):
        super().__init__()
        self.proj = nn.Linear(dim, dim * shorten_factor)
        self.shorten_factor = shorten_factor

    def forward(self, x):
        projected = self.proj(x)
        return rearrange(projected, 'b n (s d) -> b (n s) d', s = self.shorten_factor)
# classes
# Pre-norm residual wrapper.
class PreNormResidual(nn.Module):
    """Compute ``fn(LayerNorm(x), **kwargs) + x``. Extra kwargs go to ``fn`` only."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        residual = x
        return self.fn(normed, **kwargs) + residual
# Multi-head (cross-)attention with optional key mask and causal masking.
class Attention(nn.Module):
    """Multi-head attention from ``x`` to ``context``; ``context`` defaults to ``x``.

    Args:
        dim: feature size of ``x`` and of ``context``.
        heads: number of attention heads.
        dim_head: per-head feature size.
        dropout: dropout applied to the attention weights.
        causal: if True, query ``i`` cannot attend to keys after it. When ``context`` is
            longer than ``x``, the queries are aligned to the end of the keys.
    """

    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context = None, mask = None):
        """
        Args:
            x: queries of shape ``(b, i, dim)``.
            context: optional keys/values source of shape ``(b, j, dim)``.
            mask: optional boolean ``(b, j)`` key mask; False positions are excluded.
        Returns:
            tensor of shape ``(b, i, dim)``.
        """
        h = self.heads
        kv_input = default(context, x)

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = h)

        q = split_heads(self.to_q(x)) * self.scale
        k, v = map(split_heads, self.to_kv(kv_input).chunk(2, dim = -1))

        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        # the most negative finite value of the dtype, used instead of -inf
        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            key_mask = rearrange(mask, 'b j -> b () () j')
            sim = sim.masked_fill(~key_mask, mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            # True strictly above the diagonal, offset so the last query aligns with the last key
            future = torch.ones(i, j, device = x.device, dtype = torch.bool).triu_(j - i + 1)
            sim = sim.masked_fill(rearrange(future, 'i j -> () () i j'), mask_value)

        attn = self.dropout(sim.softmax(dim = -1))

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
def FeedForward(dim, mult = 4, dropout = 0.):
    """Position-wise MLP: Linear(dim -> dim*mult) -> GELU -> Dropout -> Linear(dim*mult -> dim)."""
    hidden = dim * mult
    layers = [
        nn.Linear(dim, hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*layers)
# transformer classes
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm (attention, feed-forward) residual layers.

    Args:
        dim: model feature size.
        depth: number of layers.
        causal: forwarded to every attention layer.
        heads, dim_head, attn_dropout: attention configuration.
        ff_mult, ff_dropout: feed-forward configuration.
        norm_out: if True, apply a final LayerNorm; otherwise the output is unnormalized.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        causal = False,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        norm_out = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            attn_block = PreNormResidual(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout, causal = causal))
            ff_block = PreNormResidual(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
            self.layers.append(nn.ModuleList([attn_block, ff_block]))

        self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity()

    def forward(self, x, context = None, mask = None):
        # context and mask reach only the attention sub-layer of each layer
        for attn_block, ff_block in self.layers:
            x = ff_block(attn_block(x, context = context, mask = mask))
        return self.norm(x)
class HourglassTransformer(nn.Module):
    """Hourglass transformer: full-resolution layers, a shortened "valley", then full-resolution layers again.

    Args:
        dim: model feature size.
        depth: tuple ``(pre_layers_depth, valley_depth, post_layers_depth)``. ``valley_depth``
            may itself be a tuple; the valley is then another hourglass, built via
            ``get_hourglass_transformer``.
        shorten_factor: pooling factor for this level. A tuple/list supplies this level's
            factor first and passes the remaining factors on to the nested valley.
        attn_resampling: if True, add depth-1 transformers that refine the pooled tokens
            (before the valley) and the upsampled tokens (after the valley) with attention.
        updown_sample_type: 'naive' (mean-pool / repeat) or 'linear' (projection + reshape).
        heads, dim_head: attention configuration shared by the inner transformers.
        causal: if True, the pre/valley/post transformers are causal. The sequence is also
            shifted right by ``shorten_factor - 1`` before pooling.
        norm_out: if True, apply a final LayerNorm.
    """
    def __init__(
        self,
        dim,
        *,
        depth,
        shorten_factor = 2,
        attn_resampling = True,
        updown_sample_type = 'naive',
        heads = 8,
        dim_head = 64,
        causal = False,
        norm_out = False
    ):
        super().__init__()
        assert len(depth) == 3, 'depth should be a tuple of length 3'
        assert updown_sample_type in {'naive', 'linear'}, 'downsample / upsample type must be either naive (average pool and repeat) or linear (linear projection and reshape)'
        pre_layers_depth, valley_depth, post_layers_depth = depth
        # split the shortening factor between this level and the nested valley:
        # - a tuple/list gives its first factor to this level and the rest to the valley;
        # - a flat (int-depth) valley needs no factor;
        # - a nested valley otherwise reuses the same factor
        if isinstance(shorten_factor, (tuple, list)):
            shorten_factor, *rest_shorten_factor = shorten_factor
        elif isinstance(valley_depth, int):
            shorten_factor, rest_shorten_factor = shorten_factor, None
        else:
            shorten_factor, rest_shorten_factor = shorten_factor, shorten_factor
        transformer_kwargs = dict(
            dim = dim,
            heads = heads,
            dim_head = dim_head
        )
        self.causal = causal
        self.shorten_factor = shorten_factor
        if updown_sample_type == 'naive':
            # mean-pool down, repeat up
            self.downsample = NaiveDownsample(shorten_factor)
            self.upsample = NaiveUpsample(shorten_factor)
        elif updown_sample_type == 'linear':
            # linear projection + reshape in both directions
            self.downsample = LinearDownsample(dim, shorten_factor)
            self.upsample = LinearUpsample(dim, shorten_factor)
        else:
            raise ValueError(f'unknown updown_sample_type keyword value - must be either naive or linear for now')
        # the valley: a plain Transformer or another hourglass, depending on valley_depth
        self.valley_transformer = get_hourglass_transformer(
            shorten_factor = rest_shorten_factor,
            depth = valley_depth,
            attn_resampling = attn_resampling,
            updown_sample_type = updown_sample_type,
            causal = causal,
            **transformer_kwargs
        )
        # optional depth-1 transformers for attention resampling; note they are built without `causal`
        self.attn_resampling_pre_valley = Transformer(depth = 1, **transformer_kwargs) if attn_resampling else None
        self.attn_resampling_post_valley = Transformer(depth = 1, **transformer_kwargs) if attn_resampling else None
        # full-resolution transformers before and after the valley
        self.pre_transformer = Transformer(depth = pre_layers_depth, causal = causal, **transformer_kwargs)
        self.post_transformer = Transformer(depth = post_layers_depth, causal = causal, **transformer_kwargs)
        # optional final LayerNorm
        self.norm_out = nn.LayerNorm(dim) if norm_out else nn.Identity()
    def forward(self, x, mask = None):
        """
        Args:
            x: input of shape ``(batch, seq, dim)``; the rearranges below rely on this layout.
            mask: optional boolean ``(batch, seq)`` key mask; False positions are excluded.
        Returns:
            tensor with the same shape as ``x``.
        """
        # s = shorten factor, b = batch size, n = original sequence length
        s, b, n = self.shorten_factor, *x.shape[:2]
        # top half of the hourglass: pre-valley transformer layers
        x = self.pre_transformer(x, mask = mask)
        # pad the sequence to a multiple of the shorten factor so it can be pooled
        x = pad_to_multiple(x, s, dim = -2)
        if exists(mask):
            padded_mask = pad_to_multiple(mask, s, dim = -1, value = False)
        # keep the (unshifted) residual; it is added back after upsampling and is also the
        # key/value source for the post-valley attention resampling
        x_residual = x.clone()
        # when causal, shift right by (s - 1) so a pooled bucket cannot see future tokens
        if self.causal:
            shift = s - 1
            x = F.pad(x, (0, 0, shift, -shift), value = 0.)
            if exists(mask):
                padded_mask = F.pad(padded_mask, (shift, -shift), value = False)
        # downsample (mean pool or linear, depending on updown_sample_type)
        downsampled = self.downsample(x)
        if exists(mask):
            # a pooled token is kept if any token in its bucket was kept
            downsampled_mask = reduce(padded_mask, 'b (n s) -> b n', 'sum', s = s) > 0
        else:
            downsampled_mask = None
        # pre-valley "attention resampling": each pooled token attends to the s tokens of its own bucket
        if exists(self.attn_resampling_pre_valley):
            if exists(mask):
                attn_resampling_mask = rearrange(padded_mask, 'b (n s) -> (b n) s', s = s)
            else:
                attn_resampling_mask = None
            downsampled = self.attn_resampling_pre_valley(
                rearrange(downsampled, 'b n d -> (b n) () d'),
                rearrange(x, 'b (n s) d -> (b n) s d', s = s),
                mask = attn_resampling_mask
            )
            downsampled = rearrange(downsampled, '(b n) () d -> b n d', b = b)
        # the "valley": a regular transformer or another hourglass
        x = self.valley_transformer(downsampled, mask = downsampled_mask)
        valley_out = x.clone()
        # upsample back to the padded length
        x = self.upsample(x)
        # add the residual
        x = x + x_residual
        # post-valley "attention resampling": each bucket's tokens attend to that bucket's valley output
        if exists(self.attn_resampling_post_valley):
            x = self.attn_resampling_post_valley(
                rearrange(x, 'b (n s) d -> (b n) s d', s = s),
                rearrange(valley_out, 'b n d -> (b n) () d')
            )
            x = rearrange(x, '(b n) s d -> b (n s) d', b = b)
        # drop any padding added for pooling, restoring the original length
        x = x[:, :n]
        # bottom half of the hourglass: post-valley transformer layers
        x = self.post_transformer(x, mask = mask)
        return self.norm_out(x)
# Main language-model class.
class HourglassTransformerLM(nn.Module):
    """Token language model: embeddings -> hourglass (or plain) transformer -> vocabulary logits.

    Args:
        num_tokens: vocabulary size.
        dim: model feature size.
        max_seq_len: number of learned absolute positions (the longest supported input).
        depth: int or nested 3-tuple, as accepted by ``get_hourglass_transformer``.
        shorten_factor: pooling factor(s); leave as None for an integer ``depth``.
        heads, dim_head: attention configuration.
        attn_resampling, updown_sample_type: hourglass resampling options.
        causal: whether attention is causal (the default, for autoregressive use).
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_seq_len,
        depth,
        shorten_factor = None,
        heads = 8,
        dim_head = 64,
        attn_resampling = True,
        updown_sample_type = 'naive',
        causal = True
    ):
        super().__init__()
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.transformer = get_hourglass_transformer(
            dim = dim,
            depth = depth,
            shorten_factor = shorten_factor,
            attn_resampling = attn_resampling,
            updown_sample_type = updown_sample_type,
            dim_head = dim_head,
            heads = heads,
            causal = causal,
            norm_out = True
        )

        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x, mask = None):
        """Map token ids ``(b, n)`` to logits ``(b, n, num_tokens)``."""
        tokens = self.token_emb(x)
        positions = torch.arange(tokens.shape[-2], device = x.device)
        hidden = tokens + rearrange(self.pos_emb(positions), 'n d -> () n d')
        hidden = self.transformer(hidden, mask = mask)
        return self.to_logits(hidden)
.\lucidrains\hourglass-transformer-pytorch\hourglass_transformer_pytorch\__init__.py
# 从 hourglass_transformer_pytorch.hourglass_transformer_pytorch 模块中导入 HourglassTransformerLM 和 HourglassTransformer 类
from hourglass_transformer_pytorch.hourglass_transformer_pytorch import HourglassTransformerLM, HourglassTransformer
Hourglass Transformer - Pytorch
Implementation of Hourglass Transformer, in Pytorch.
Install
$ pip install hourglass-transformer-pytorch
Usage
import torch
from hourglass_transformer_pytorch import HourglassTransformerLM
model = HourglassTransformerLM(
num_tokens = 256, # number of tokens
dim = 512, # feature dimension
max_seq_len = 1024, # maximum sequence length
heads = 8, # attention heads
dim_head = 64, # dimension per attention head
shorten_factor = 2, # shortening factor
depth = (4, 2, 4), # tuple of 3, standing for pre-transformer-layers, valley-transformer-layers (after downsample), post-transformer-layers (after upsample) - the valley transformer layers can be yet another nested tuple, in which case it will shorten again recursively
)
x = torch.randint(0, 256, (1, 1024))
logits = model(x) # (1, 1024, 256)
For something more sophisticated, use two hourglasses, with one nested within the other:
import torch
from hourglass_transformer_pytorch import HourglassTransformerLM
model = HourglassTransformerLM(
num_tokens = 256,
dim = 512,
max_seq_len = 1024,
shorten_factor = (2, 4), # 2x for first hour glass, 4x for second
depth = (4, (2, 1, 2), 3), # 4@1 -> 2@2 -> 1@4 -> 2@2 -> 3@1
)
x = torch.randint(0, 256, (1, 1024))
logits = model(x)
A Funnel Transformer would be approximately:
import torch
from hourglass_transformer_pytorch import HourglassTransformerLM
model = HourglassTransformerLM(
num_tokens = 20000,
dim = 512,
max_seq_len = 1024,
causal = False,
attn_resampling = False,
shorten_factor = 2,
depth = (2, (2, (2, 2, 2), 2), 2)
)
x = torch.randint(0, 20000, (1, 1024))
logits = model(x)
For images, the authors found that linear projections worked much better than average pooling and repeating for the down- and upsampling functions. You can use this by setting updown_sample_type = 'linear'
import torch
from hourglass_transformer_pytorch import HourglassTransformer
model = HourglassTransformer(
dim = 512,
shorten_factor = 2,
depth = (4, 2, 4),
updown_sample_type = 'linear'
)
img_tokens = torch.randn(1, 1024, 512)
model(img_tokens) # (1, 1024, 512)
Although results were not presented in the paper, you can also use the Hourglass Transformer in this repository non-autoregressively.
import torch
from hourglass_transformer_pytorch import HourglassTransformerLM
model = HourglassTransformerLM(
num_tokens = 20000,
dim = 512,
max_seq_len = 1024,
shorten_factor = 2,
depth = (4, 2, 4),
causal = False # set this to False
)
x = torch.randint(0, 256, (1, 1024))
mask = torch.ones((1, 1024)).bool()
logits = model(x, mask = mask) # (1, 1024, 20000)
Enwik8 autoregressive example
$ python train.py
Todo
Citations
@misc{nawrot2021hierarchical,
title = {Hierarchical Transformers Are More Efficient Language Models},
author = {Piotr Nawrot and Szymon Tworkowski and Michał Tyrolski and Łukasz Kaiser and Yuhuai Wu and Christian Szegedy and Henryk Michalewski},
year = {2021},
eprint = {2110.13711},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\hourglass-transformer-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Package metadata for hourglass-transformer-pytorch, grouped by purpose.
setup(
    # identity
    name = 'hourglass-transformer-pytorch',
    version = '0.0.6',
    description = 'Hourglass Transformer',
    url = 'https://github.com/lucidrains/hourglass-transformer-pytorch',
    # authorship and licensing
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    license = 'MIT',
    # contents and runtime requirements
    packages = find_packages(),
    install_requires = [
        'einops',
        'torch>=1.6',
    ],
    # discovery metadata
    keywords = [
        'artificial intelligence',
        'attention mechanism',
        'transformers',
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\hourglass-transformer-pytorch\train.py
# 导入所需的模块和类
from hourglass_transformer_pytorch import HourglassTransformerLM
from hourglass_transformer_pytorch.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# Training hyperparameters.
NUM_BATCHES = 100_000            # total optimizer steps
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4    # micro-batches accumulated per optimizer step
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100             # steps between validation-loss reports
GENERATE_EVERY = 500             # steps between sampling runs
GENERATE_LENGTH = 512            # tokens generated per sample
SEQ_LEN = 512                    # training context length
# 定义辅助函数
# Decode one token id to a character; ids below 32 (control codes) become a space.
def decode_token(token):
    return chr(token if token > 32 else 32)
# Decode a sequence of token ids into a string, one character per token.
def decode_tokens(tokens):
    return ''.join(decode_token(tok) for tok in tokens)
# Instantiate a GPT-like decoder model and wrap it for autoregressive training / sampling.
model = HourglassTransformerLM(
    num_tokens = 256,
    dim = 512,
    max_seq_len = SEQ_LEN,
    depth = (4, 2, 4),
    shorten_factor = 2,
    heads = 8
)

model = AutoregressiveWrapper(model)
model.cuda()

# Prepare enwik8: read the first 95M bytes, split into 90M train / 5M validation.
with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring in binary mode is deprecated; np.frombuffer returns a read-only
    # view, so copy it to get a writable array for torch.from_numpy
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# Dataset of random fixed-length windows over a 1-D token tensor.
class TextSamplerDataset(Dataset):
    """Each item is a random window of ``seq_len + 1`` tokens (long dtype), moved to CUDA.

    The requested index is ignored: every access draws a fresh random start. ``len`` is
    the number of whole ``seq_len`` windows that fit in ``data``.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        max_start = self.data.size(0) - self.seq_len
        start = torch.randint(0, max_start, (1,))
        window = self.data[start: start + self.seq_len + 1]
        return window.long().cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len
def cycle(loader):
    """Yield batches from ``loader`` endlessly, re-iterating it each time it is exhausted."""
    while True:
        for data in loader:
            yield data

# Infinite data loaders over random windows of the train / validation splits.
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# Optimizer; named `optimizer` so it does not shadow the `optim` module imported above.
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# Training loop.
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    # accumulate gradients over several micro-batches before each optimizer step
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        # prompt: a random validation window without its final (target) token
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # print the prompt followed by a separator line
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)
.\lucidrains\HTM-pytorch\htm_pytorch\htm_pytorch.py
# 从 math 模块中导入 ceil 函数
from math import ceil
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn 和 einsum
from torch import nn, einsum
# 从 torch.nn.functional 模块中导入 F
import torch.nn.functional as F
# 从 einops 模块中导入 rearrange 和 repeat
from einops import rearrange, repeat
# helpers
# True for any value other than None.
def exists(val):
    return not (val is None)
# Return `val` unless it is None, in which case return the fallback `d`.
def default(val, d):
    if val is None:
        return d
    return val
# Left-pad a tensor along one dimension up to the next multiple of `multiple`.
def pad_to_multiple(t, multiple, dim = -2, value = 0.):
    """Pad ``t`` at the *front* of ``dim`` with ``value`` until that length is divisible by ``multiple``.

    ``dim`` must be negative, because the ``F.pad`` spec is built from the last dimension
    backwards. Returns ``t`` itself when no padding is needed.
    """
    length = t.shape[dim]
    target = ceil(length / multiple) * multiple
    missing = target - length
    if not missing:
        return t
    # (0, 0) pairs leave the dimensions after `dim` untouched
    trailing = (0, 0) * (-dim - 1)
    return F.pad(t, (*trailing, missing, 0), value = value)
# positional encoding
# Sinusoidal positional encoding.
class SinusoidalPosition(nn.Module):
    """Sinusoidal embeddings for the ``seq_len = x.shape[-2]`` positions of ``x``.

    Positions run in *reverse* order, from ``seq_len - 1`` down to ``0``. Row ``r`` of the
    ``(seq_len, 2 * len(inv_freqs))`` output is ``[sin(p * inv_freqs), cos(p * inv_freqs)]``
    for position ``p = seq_len - 1 - r``.

    Args:
        dim: embedding size; one frequency is created per ``min_timescale`` step of ``range(0, dim)``.
        min_timescale: step of that range (2 by default, i.e. one frequency per sin/cos pair).
        max_timescale: base of the geometric frequency progression.
    """

    def __init__(
        self,
        dim,
        min_timescale = 2.,
        max_timescale = 1e4
    ):
        super().__init__()
        freqs = torch.arange(0, dim, min_timescale)
        inv_freqs = max_timescale ** (-freqs / dim)
        self.register_buffer('inv_freqs', inv_freqs)

    def forward(self, x):
        seq_len = x.shape[-2]
        # create positions on the buffer's device, so the module keeps working after .cuda()
        seq = torch.arange(seq_len - 1, -1, -1., device = self.inv_freqs.device)
        # outer product: (n, 1) * (1, d) -> (n, d)
        sinusoidal_inp = seq[:, None] * self.inv_freqs[None, :]
        pos_emb = torch.cat((sinusoidal_inp.sin(), sinusoidal_inp.cos()), dim = -1)
        return pos_emb
# multi-head attention
# Multi-head attention in which every query position has its own set of memory slots.
class Attention(nn.Module):
    """Multi-head attention over per-query memories.

    The einsum subscripts require ``x`` to be 4-D, ``(b, m, i, dim)``, and ``mems`` to be
    5-D, ``(b, m, i, j, dim)``. Each query ``(m, i)`` attends only over its own ``j`` slots.

    Args:
        dim: input / output feature size.
        dim_head: per-head feature size.
        heads: number of heads.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(
        self,
        x,
        mems,
        mask = None
    ):
        """
        Args:
            x: queries, ``(b, m, i, dim)``.
            mems: memories, ``(b, m, i, j, dim)``.
            mask: optional boolean mask. After being repeated per head it must broadcast
                to the ``(b*h, m, i, j)`` scores; False positions are excluded.
        Returns:
            tensor of shape ``(b, m, i, dim)``.
        """
        h = self.heads

        q = self.to_q(x)
        k, v = self.to_kv(mems).chunk(2, dim = -1)
        # fold the heads into the batch dimension
        q, k, v = (rearrange(t, 'b ... (h d) -> (b h) ... d', h = h) for t in (q, k, v))

        sim = einsum('b m i d, b m i j d -> b m i j', q * self.scale, k)

        if exists(mask):
            head_mask = repeat(mask, 'b ... -> (b h) ...', h = h)
            sim = sim.masked_fill(~head_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... i j d -> ... i d', attn, v)
        return self.to_out(rearrange(out, '(b h) ... d -> b ... (h d)', h = h))
# main class
# 定义 HTMAttention 类,实现 HTMAttention 模型
class HTMAttention(nn.Module):
    """Hierarchical Transformer Memory attention.

    Memories are split into chunks of `mem_chunk_size` tokens and each chunk is
    summarized by a (masked) mean. Every query selects its `topk_mems` best chunks
    with single-head attention over the summaries. It then runs multi-head attention
    inside each selected chunk, and the per-chunk results are combined with the
    softmax of the chunk-selection scores.
    """

    def __init__(
        self,
        dim,
        heads,
        topk_mems = 2,
        mem_chunk_size = 32,
        dim_head = 64,
        add_pos_enc = True,
        eps = 1e-5
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.scale = dim ** -0.5

        self.to_summary_queries = nn.Linear(dim, dim)
        self.to_summary_keys = nn.Linear(dim, dim)
        self.attn = Attention(dim = dim, heads = heads, dim_head = dim_head)

        self.topk_mems = topk_mems
        self.mem_chunk_size = mem_chunk_size
        self.pos_emb = SinusoidalPosition(dim = dim) if add_pos_enc else None

    def forward(
        self,
        queries,
        memories,
        mask = None,
        chunk_attn_mask = None
    ):
        """
        Args:
            queries: (batch, n, dim) query tokens.
            memories: (batch, num_memories, dim) memory tokens, any length.
            mask: optional boolean (batch, num_memories) mask, True = valid memory.
            chunk_attn_mask: optional boolean mask over the (query, chunk) selection
                logits, True = chunk may be selected.

        Returns:
            (batch, n, dim) attended output.
        """
        dim, query_len, mem_chunk_size, scale, eps = self.dim, queries.shape[1], self.mem_chunk_size, self.scale, self.eps

        # without a mask, the zero padding added below would be averaged into the
        # last chunk summary and attended to like real memory, so mark only the
        # given memory tokens as valid
        if not exists(mask) and memories.shape[-2] % mem_chunk_size != 0:
            mask = torch.ones(memories.shape[:-1], dtype = torch.bool, device = memories.device)

        # pad memories (and the mask) to a multiple of the chunk size, then chunk them
        memories = pad_to_multiple(memories, mem_chunk_size, dim = -2, value = 0.)
        memories = rearrange(memories, 'b (n c) d -> b n c d', c = mem_chunk_size)

        if exists(mask):
            mask = pad_to_multiple(mask, mem_chunk_size, dim = -1, value = False)
            mask = rearrange(mask, 'b (n c) -> b n c', c = mem_chunk_size)

        # never ask topk for more chunks than exist
        topk_mems = min(self.topk_mems, memories.shape[1])

        # summarize each chunk by its mean, ignoring masked tokens
        if exists(mask):
            mean_mask = rearrange(mask, '... -> ... ()')
            memories = memories.masked_fill(~mean_mask, 0.)
            numer = memories.sum(dim = 2)
            denom = mean_mask.sum(dim = 2)
            summarized_memories = numer / (denom + eps)
        else:
            summarized_memories = memories.mean(dim = 2)

        # single-head attention of the queries over the chunk summaries; the
        # summaries are detached, so no gradient reaches the memories this way
        summary_queries = self.to_summary_queries(queries)
        summary_keys = self.to_summary_keys(summarized_memories.detach())

        sim = einsum('b i d, b j d -> b i j', summary_queries, summary_keys) * scale
        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # a chunk is selectable only if it holds at least one valid token
            chunk_mask = mask.any(dim = 2)
            chunk_mask = rearrange(chunk_mask, 'b j -> b () j')
            sim = sim.masked_fill(~chunk_mask, mask_value)

        if exists(chunk_attn_mask):
            sim = sim.masked_fill(~chunk_attn_mask, mask_value)

        topk_logits, topk_indices = sim.topk(k = topk_mems, dim = -1)
        weights = topk_logits.softmax(dim = -1)

        # one copy of the queries per selected chunk
        queries = repeat(queries, 'b n d -> b k n d', k = topk_mems)

        # gather the top-k chunks for every query
        memories = repeat(memories, 'b m j d -> b m i j d', i = query_len)
        mem_topk_indices = repeat(topk_indices, 'b i m -> b m i j d', j = mem_chunk_size, d = dim)
        selected_memories = memories.gather(1, mem_topk_indices)

        # positional encoding over the tokens within a chunk
        if exists(self.pos_emb):
            pos_emb = self.pos_emb(memories)
            selected_memories = selected_memories + rearrange(pos_emb, 'n d -> () () () n d')

        # gather the matching token masks
        selected_mask = None
        if exists(mask):
            mask = repeat(mask, 'b m j -> b m i j', i = query_len)
            mask_topk_indices = repeat(topk_indices, 'b i m -> b m i j', j = mem_chunk_size)
            selected_mask = mask.gather(1, mask_topk_indices)

        # multi-head attention within each selected chunk
        within_mem_output = self.attn(
            queries,
            selected_memories.detach(),
            mask = selected_mask
        )

        # combine the per-chunk outputs with the chunk-selection weights
        weighted_output = within_mem_output * rearrange(weights, 'b i m -> b m i ()')
        output = weighted_output.sum(dim = 1)
        return output
# 定义一个 HTMBlock 类,继承自 nn.Module
class HTMBlock(nn.Module):
    """Pre-norm residual wrapper around HTMAttention.

    The queries are layer-normalized before attention. The attention output is
    added to the original, un-normalized queries (the skip connection).
    """

    def __init__(self, dim, **kwargs):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # kwargs are forwarded to HTMAttention (heads, topk_mems, mem_chunk_size, ...)
        self.attn = HTMAttention(dim = dim, **kwargs)

    def forward(
        self,
        queries,
        memories,
        **kwargs
    ):
        """Return attn(norm(queries), memories) + queries; kwargs go to HTMAttention.forward."""
        normed_queries = self.norm(queries)
        # the residual must carry the input itself: adding the normalized queries
        # would discard the input's scale and break the residual stream
        return self.attn(normed_queries, memories, **kwargs) + queries
.\lucidrains\HTM-pytorch\htm_pytorch\__init__.py
# 从 htm_pytorch 包中导入 HTMAttention 和 HTMBlock 类
from htm_pytorch.htm_pytorch import HTMAttention, HTMBlock
Hierarchical Transformer Memory (HTM) - Pytorch
Implementation of Hierarchical Transformer Memory (HTM) for PyTorch. This DeepMind paper proposes a simple method that allows transformers to efficiently attend to memories of the past. See also the original JAX repository.
Install
$ pip install htm-pytorch
Usage
import torch
from htm_pytorch import HTMAttention
attn = HTMAttention(
dim = 512,
heads = 8, # number of heads for within-memory attention
dim_head = 64, # dimension per head for within-memory attention
topk_mems = 8, # how many memory chunks to select for each query
mem_chunk_size = 32, # number of tokens in each memory chunk
add_pos_enc = True # whether to add positional encoding to the memories
)
queries = torch.randn(1, 128, 512) # queries
memories = torch.randn(1, 20000, 512) # memories, of any size
mask = torch.ones(1, 20000).bool() # memory mask
attended = attn(queries, memories, mask = mask) # (1, 128, 512)
If you want the entire HTM Block (which contains the layernorm for the input followed by a skip connection), just import HTMBlock
instead
import torch
from htm_pytorch import HTMBlock
block = HTMBlock(
dim = 512,
topk_mems = 8,
mem_chunk_size = 32
)
queries = torch.randn(1, 128, 512)
memories = torch.randn(1, 20000, 512)
mask = torch.ones(1, 20000).bool()
out = block(queries, memories, mask = mask) # (1, 128, 512)
Citations
@misc{lampinen2021mental,
title = {Towards mental time travel: a hierarchical memory for reinforcement learning agents},
author = {Andrew Kyle Lampinen and Stephanie C. Y. Chan and Andrea Banino and Felix Hill},
year = {2021},
eprint = {2105.14039},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\HTM-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
KEYWORDS = [
    'artificial intelligence',
    'deep learning',
    'attention-mechanism',
    'memory'
]

# runtime dependencies installed alongside the package
INSTALL_REQUIRES = [
    'einops>=0.3',
    'torch>=1.6'
]

CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

# package metadata for htm-pytorch
setup(
    name = 'htm-pytorch',
    packages = find_packages(),
    version = '0.0.4',
    license = 'MIT',
    description = 'Hierarchical Transformer Memory - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/htm-pytorch',
    keywords = KEYWORDS,
    install_requires = INSTALL_REQUIRES,
    classifiers = CLASSIFIERS,
)
.\lucidrains\imagen-pytorch\imagen_pytorch\cli.py
import click
import torch
from pathlib import Path
import pkgutil
from imagen_pytorch import load_imagen_from_checkpoint
from imagen_pytorch.version import __version__
from imagen_pytorch.data import Collator
from imagen_pytorch.utils import safeget
from imagen_pytorch import ImagenTrainer, ElucidatedImagenConfig, ImagenConfig
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
import json
# 定义一个函数,用于检查值是否存在
def exists(val):
    """Return True unless `val` is None (falsy values such as 0 or '' still exist)."""
    return not (val is None)
# 定义一个简单的字符串处理函数,将特殊字符替换为下划线,并截取指定长度
def simple_slugify(text: str, max_length = 255):
    """Turn `text` into a string usable as a single file name.

    The following mappings are applied:
    - hyphens and spaces become underscores;
    - commas are removed;
    - '|' becomes '--';
    - forward and back slashes become underscores, so the result cannot point
      into another directory.

    Leading and trailing '-', '_', '.', and slash characters are then stripped,
    and the result is truncated to `max_length` characters.
    """
    # every character is mapped independently; this matches applying the
    # replacements one after another, because no replacement produces a
    # character that a later one would touch
    table = str.maketrans({
        '-': '_',
        ',': None,
        ' ': '_',
        '|': '--',
        '/': '_',
        '\\': '_',
    })
    return text.translate(table).strip('-_./\\')[:max_length]
# 主函数
def main():
    """Placeholder entry point; it performs no work and returns None."""
# 创建一个命令组
@click.group()
def imagen():
    # Root click command group; the sample / config / train commands attach to it
    # through @imagen.command. Documented with comments instead of a docstring
    # because click would display a docstring as the group's --help text.
    return None
# 创建一个命令,用于从 Imagen 模型检查点中进行采样
@imagen.command(help = 'Sample from the Imagen model checkpoint')
@click.option('--model', default = './imagen.pt', help = 'path to trained Imagen model')
@click.option('--cond_scale', default = 5, help = 'conditioning scale (classifier free guidance) in decoder')
@click.option('--load_ema', default = True, help = 'load EMA version of unets if available')
@click.argument('text')
def sample(
    model,
    cond_scale,
    load_ema,
    text
):
    # Sample one image for TEXT from a saved Imagen checkpoint and save it as
    # ./<slugified text>.png.
    model_path = Path(model)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'model not found at {full_model_path}'

    # the checkpoint is opened here only to report its version; map it to the CPU
    # so GPU-saved checkpoints can be inspected anywhere, and drop it right away
    # instead of holding a second copy while the model is loaded below
    loaded = torch.load(str(model_path), map_location = 'cpu')
    version = safeget(loaded, 'version')
    del loaded
    print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')

    imagen = load_imagen_from_checkpoint(str(model_path), load_ema_if_available = load_ema)

    # same guard as the train command: .cuda() fails on machines without CUDA
    if torch.cuda.is_available():
        imagen.cuda()

    pil_image = imagen.sample([text], cond_scale = cond_scale, return_pil_images = True)

    # keep slug + suffix within the usual 255-character file name limit
    suffix = '.png'
    image_path = f'./{simple_slugify(text, max_length = 255 - len(suffix))}{suffix}'
    pil_image[0].save(image_path)
    print(f'image saved to {image_path}')
# 创建一个命令,用于生成 Imagen 模型的配置
@imagen.command(help = 'Generate a config for the Imagen model')
@click.option('--path', default = './imagen_config.json', help = 'Path to the Imagen model config')
def config(
    path
):
    # Copy the default_config.json bundled with the package to `path`.
    raw_config = pkgutil.get_data(__name__, 'default_config.json')
    # pkgutil.get_data returns None when the package loader cannot provide resources
    assert exists(raw_config), 'default_config.json could not be read from the imagen_pytorch package'
    data = raw_config.decode("utf-8")
    # write with the same encoding that was used to decode, independent of the locale
    with open(path, 'w', encoding = 'utf-8') as f:
        f.write(data)
# 创建一个命令,用于训练 Imagen 模型
@imagen.command(help = 'Train the Imagen model')
@click.option('--config', default = './imagen_config.json', help = 'Path to the Imagen model config')
@click.option('--unet', default = 1, help = 'Unet to train', type = click.IntRange(1, 3, False, True, True))
@click.option('--epoches', default = 50, help = 'Amount of epoches to train for')
def train(
    config,
    unet,
    epoches,
):
    # Train unet number `unet` for `epoches` epochs from a JSON config.
    #
    # Config keys read here:
    # - required: checkpoint_path, imagen, trainer, dataset (with batch_size),
    #   dataset_name, image_label, text_label;
    # - optional: type, max_batch_size, url_label, validate_at_every,
    #   sample_texts, sample_at_every, save_at_every.
    # Training resumes from checkpoint_path if it exists, and the model is saved
    # there at the end.
    config_path = Path(config)
    full_config_path = str(config_path.resolve())
    assert config_path.exists(), f'config not found at {full_config_path}'

    with open(config_path, 'r', encoding = 'utf-8') as f:
        config_data = json.load(f)

    assert 'checkpoint_path' in config_data, 'checkpoint path not found in config'

    model_path = Path(config_data['checkpoint_path'])
    full_model_path = str(model_path.resolve())

    # 'type' is optional; anything other than 'elucidated' selects the regular Imagen
    imagen_config_klass = ElucidatedImagenConfig if config_data.get('type') == 'elucidated' else ImagenConfig
    imagen = imagen_config_klass(**config_data['imagen']).create()

    trainer = ImagenTrainer(
        imagen = imagen,
        **config_data['trainer']
    )

    # resume from an existing checkpoint; it is opened on the CPU here only to
    # report the version it was saved with
    if model_path.exists():
        loaded = torch.load(str(model_path), map_location = 'cpu')
        version = safeget(loaded, 'version')
        del loaded
        print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')
        trainer.load(model_path)

    if torch.cuda.is_available():
        trainer = trainer.cuda()

    size = config_data['imagen']['image_sizes'][unet - 1]

    max_batch_size = config_data.get('max_batch_size', 1)

    # PIL mode for the collator:
    # L = greyscale, LA = greyscale + alpha, RGB = colour, RGBA = colour + alpha
    channels = 'RGB'
    if 'channels' in config_data['imagen']:
        num_channels = config_data['imagen']['channels']
        assert num_channels > 0 and num_channels < 5, 'Imagen only support 1 to 4 channels L, LA, RGB, RGBA'
        channels = {1: 'L', 2: 'LA', 4: 'RGBA'}.get(num_channels, 'RGB')

    assert 'batch_size' in config_data['dataset'], 'A batch_size is required in the config file'

    # train and valid splits are merged so the trainer can make its own validation split
    ds = load_dataset(config_data['dataset_name'])

    if 'train' in ds and 'valid' in ds:
        train_ds = concatenate_datasets([ds['train'], ds['valid']])
    elif 'train' in ds:
        train_ds = ds['train']
    elif 'valid' in ds:
        train_ds = ds['valid']
    else:
        train_ds = ds

    assert train_ds is not None, 'No train dataset could be fetched from the dataset name provided'

    trainer.add_train_dataset(
        ds = train_ds,
        collate_fn = Collator(
            image_size = size,
            image_label = config_data['image_label'],
            text_label = config_data['text_label'],
            # optional: with no url label the collator reads images from image_label
            url_label = config_data.get('url_label'),
            name = imagen.text_encoder_name,
            channels = channels
        ),
        **config_data['dataset']
    )

    should_validate = trainer.split_valid_from_train and 'validate_at_every' in config_data
    should_sample = 'sample_texts' in config_data and 'sample_at_every' in config_data
    should_save = 'save_at_every' in config_data

    valid_at_every = config_data['validate_at_every'] if should_validate else 0
    assert isinstance(valid_at_every, int), 'validate_at_every must be an integer'
    sample_at_every = config_data['sample_at_every'] if should_sample else 0
    assert isinstance(sample_at_every, int), 'sample_at_every must be an integer'
    save_at_every = config_data['save_at_every'] if should_save else 0
    assert isinstance(save_at_every, int), 'save_at_every must be an integer'
    sample_texts = config_data['sample_texts'] if should_sample else []
    assert isinstance(sample_texts, list), 'sample_texts must be a list'
    assert not should_sample or len(sample_texts) > 0, 'sample_texts must not be empty when sample_at_every is set'

    def is_due(epoch, every):
        # a non-positive interval disables the action (and avoids `epoch % 0`);
        # nothing periodic runs at epoch 0
        return every > 0 and epoch > 0 and epoch % every == 0

    for i in range(epoches):
        for _ in tqdm(range(len(trainer.train_dl))):
            loss = trainer.train_step(unet_number = unet, max_batch_size = max_batch_size)
            print(f'loss: {loss}')

        if should_validate and trainer.is_main and is_due(i, valid_at_every):
            valid_loss = trainer.valid_step(unet_number = unet, max_batch_size = max_batch_size)
            print(f'valid loss: {valid_loss}')

        if should_sample and trainer.is_main and is_due(i, sample_at_every):
            # `texts` is a list of strings, not a list wrapping that list
            images = trainer.sample(texts = sample_texts, batch_size = 1, return_pil_images = True, stop_at_unet_number = unet)
            # one file per epoch and prompt, so earlier samples are not overwritten
            for text_index, image in enumerate(images):
                image.save(f'./sample-{i}-{text_index}.png')

        if should_save and trainer.is_main and is_due(i, save_at_every):
            trainer.save(model_path)

    trainer.save(model_path)
.\lucidrains\imagen-pytorch\imagen_pytorch\configs.py
# 导入必要的模块和类
from pydantic import BaseModel, model_validator
from typing import List, Optional, Union, Tuple
from enum import Enum
# 导入自定义模块中的类和函数
from imagen_pytorch.imagen_pytorch import Imagen, Unet, Unet3D, NullUnet
from imagen_pytorch.trainer import ImagenTrainer
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
from imagen_pytorch.t5 import DEFAULT_T5_NAME, get_encoded_dim
# 定义一些辅助函数
# 判断值是否存在
def exists(val):
    """Tell whether `val` holds a value, i.e. is not None."""
    return not (val is None)
# 返回默认值
def default(val, d):
    """Return `val`, or the fallback `d` when `val` is None."""
    if val is None:
        return d
    return val
# 定义一个接受内部类型的列表或元组
def ListOrTuple(inner_type):
    """Type accepting a list or a tuple of any length whose items are `inner_type`.

    `Tuple[inner_type]` would only describe a 1-element tuple; the variadic form
    `Tuple[inner_type, ...]` is what "a tuple of inner_type" means.
    """
    return Union[List[inner_type], Tuple[inner_type, ...]]
# 定义一个接受内部类型的单个值或列表
def SingleOrList(inner_type):
    """Type accepting one `inner_type` value or a list/tuple of such values."""
    collection_type = ListOrTuple(inner_type)
    return Union[inner_type, collection_type]
# 噪声调度
# 定义一个枚举类,表示噪声调度的类型
class NoiseSchedule(Enum):
    # supported noise schedule names; each member's value is its own name
    # (documented with comments: a docstring would surface in pydantic schemas)
    cosine = 'cosine'
    linear = 'linear'
# 允许额外字段的基础模型类
class AllowExtraBaseModel(BaseModel):
    # Base for the config models: unknown fields are kept (so they are included in
    # the dumped kwargs), and enum fields are stored by their value.
    # Pydantic v2 configuration; the nested `class Config` form is deprecated in v2.
    # A plain dict is accepted because ConfigDict is a TypedDict.
    model_config = {'extra': 'allow', 'use_enum_values': True}
# imagen pydantic 类
# 空 Unet 配置类
class NullUnetConfig(BaseModel):
    # Config entry that builds a NullUnet. `is_null` is required so that an entry
    # can be told apart from the real unet configs; create() itself ignores it.
    is_null: bool

    def create(self):
        """Build the NullUnet (no arguments are forwarded)."""
        null_unet = NullUnet()
        return null_unet
# Unet 配置类
class UnetConfig(AllowExtraBaseModel):
    # Validated constructor arguments for an image Unet; extra keys pass through.
    dim: int
    dim_mults: ListOrTuple(int)
    # evaluated once, when the class is defined
    text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
    cond_dim: Optional[int] = None
    channels: int = 3
    attn_dim_head: int = 32
    attn_heads: int = 16

    def create(self):
        """Instantiate the Unet from all fields, including extra ones."""
        # model_dump() replaces the v1-style .dict(), which is deprecated in pydantic v2
        return Unet(**self.model_dump())
# Unet3D 配置类
class Unet3DConfig(AllowExtraBaseModel):
    # Validated constructor arguments for a video Unet3D; extra keys pass through.
    dim: int
    dim_mults: ListOrTuple(int)
    # evaluated once, when the class is defined
    text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
    cond_dim: Optional[int] = None
    channels: int = 3
    attn_dim_head: int = 32
    attn_heads: int = 16

    def create(self):
        """Instantiate the Unet3D from all fields, including extra ones."""
        # model_dump() replaces the v1-style .dict(), which is deprecated in pydantic v2
        return Unet3D(**self.model_dump())
# Imagen 配置类
class ImagenConfig(AllowExtraBaseModel):
    # Validated constructor arguments for Imagen: one unet config per cascade
    # stage, each paired with its image size.
    unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
    image_sizes: ListOrTuple(int)
    video: bool = False
    timesteps: SingleOrList(int) = 1000
    noise_schedules: SingleOrList(NoiseSchedule) = 'cosine'
    text_encoder_name: str = DEFAULT_T5_NAME
    channels: int = 3
    loss_type: str = 'l2'
    cond_drop_prob: float = 0.5

    @model_validator(mode="after")
    def check_image_sizes(self):
        # every unet needs exactly one target image size
        if len(self.image_sizes) != len(self.unets):
            raise ValueError(f'image sizes length {len(self.image_sizes)} must be equivalent to the number of unets {len(self.unets)}')
        return self

    def create(self):
        """Build the Imagen model, instantiating each unet from its config."""
        # model_dump() replaces the v1-style .dict(), which is deprecated in pydantic v2
        decoder_kwargs = self.model_dump()
        unets_kwargs = decoder_kwargs.pop('unets')
        is_video = decoder_kwargs.pop('video', False)

        unets = []

        for unet, unet_kwargs in zip(self.unets, unets_kwargs):
            if isinstance(unet, NullUnetConfig):
                unet_klass = NullUnet
            elif is_video:
                unet_klass = Unet3D
            else:
                unet_klass = Unet

            unets.append(unet_klass(**unet_kwargs))

        imagen = Imagen(unets, **decoder_kwargs)

        # model_dump() returns a fresh dict, so no extra copy is needed
        imagen._config = self.model_dump()
        return imagen
# ElucidatedImagen 配置类
class ElucidatedImagenConfig(AllowExtraBaseModel):
    # Validated constructor arguments for ElucidatedImagen: one unet config per
    # cascade stage, each paired with its image size. The sampling and training
    # hyperparameters may be given once or as a list/tuple (SingleOrList).
    unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
    image_sizes: ListOrTuple(int)
    video: bool = False
    text_encoder_name: str = DEFAULT_T5_NAME
    channels: int = 3
    cond_drop_prob: float = 0.5
    num_sample_steps: SingleOrList(int) = 32
    sigma_min: SingleOrList(float) = 0.002
    sigma_max: SingleOrList(int) = 80
    sigma_data: SingleOrList(float) = 0.5
    rho: SingleOrList(int) = 7
    P_mean: SingleOrList(float) = -1.2
    P_std: SingleOrList(float) = 1.2
    S_churn: SingleOrList(int) = 80
    S_tmin: SingleOrList(float) = 0.05
    S_tmax: SingleOrList(int) = 50
    S_noise: SingleOrList(float) = 1.003

    @model_validator(mode="after")
    def check_image_sizes(self):
        # every unet needs exactly one target image size
        if len(self.image_sizes) != len(self.unets):
            raise ValueError(f'image sizes length {len(self.image_sizes)} must be equivalent to the number of unets {len(self.unets)}')
        return self

    def create(self):
        """Build the ElucidatedImagen model, instantiating each unet from its config."""
        # model_dump() replaces the v1-style .dict(), which is deprecated in pydantic v2
        decoder_kwargs = self.model_dump()
        unets_kwargs = decoder_kwargs.pop('unets')
        is_video = decoder_kwargs.pop('video', False)

        unets = []

        # the unet class is chosen per entry: NullUnet for null configs,
        # otherwise Unet3D for video and Unet for images
        for unet, unet_kwargs in zip(self.unets, unets_kwargs):
            if isinstance(unet, NullUnetConfig):
                unet_klass = NullUnet
            elif is_video:
                unet_klass = Unet3D
            else:
                unet_klass = Unet

            unets.append(unet_klass(**unet_kwargs))

        imagen = ElucidatedImagen(unets, **decoder_kwargs)

        # model_dump() returns a fresh dict, so no extra copy is needed
        imagen._config = self.model_dump()
        return imagen
# 定义一个配置类 ImagenTrainerConfig,继承自 AllowExtraBaseModel
class ImagenTrainerConfig(AllowExtraBaseModel):
    # Validated arguments for an ImagenTrainer together with its Imagen model.
    # `imagen` holds the raw model config. It is validated with
    # ElucidatedImagenConfig when `elucidated` is set, and with ImagenConfig
    # otherwise. `video` is forwarded into that model config. All remaining
    # fields (and any extras) go to ImagenTrainer.
    imagen: dict
    elucidated: bool = False
    video: bool = False
    use_ema: bool = True
    lr: SingleOrList(float) = 1e-4
    eps: SingleOrList(float) = 1e-8
    beta1: float = 0.9
    beta2: float = 0.99
    max_grad_norm: Optional[float] = None
    group_wd_params: bool = True
    warmup_steps: SingleOrList(Optional[int]) = None
    cosine_decay_max_steps: SingleOrList(Optional[int]) = None

    def create(self):
        """Build the Imagen model from `imagen` and wrap it in an ImagenTrainer."""
        trainer_kwargs = self.model_dump()

        imagen_config = trainer_kwargs.pop('imagen')
        elucidated = trainer_kwargs.pop('elucidated')
        # `video` configures the model, not the trainer: take it out of the trainer
        # kwargs and pass it into the model config
        video = trainer_kwargs.pop('video')

        imagen_config_klass = ElucidatedImagenConfig if elucidated else ImagenConfig
        imagen = imagen_config_klass(**{**imagen_config, 'video': video}).create()

        return ImagenTrainer(imagen, **trainer_kwargs)
.\lucidrains\imagen-pytorch\imagen_pytorch\data.py
# 导入所需的库
from pathlib import Path
from functools import partial
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from imagen_pytorch import t5
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
# 导入自定义的文件工具函数
from datasets.utils.file_utils import get_datasets_user_agent
import io
import urllib
# 设置用户代理
USER_AGENT = get_datasets_user_agent()
# 辅助函数
# 检查值是否存在
def exists(val):
    """True for every value except None."""
    return not (val is None)
# 无限循环生成数据集
def cycle(dl):
    """Yield the items of `dl` over and over, re-iterating it after every full pass.

    Unlike itertools.cycle, items are not cached: `dl` is iterated afresh on each
    pass, so a DataLoader with shuffle=True produces a new order every pass. If a
    pass yields nothing (an empty iterable, or a one-shot iterator that is already
    exhausted), the generator stops instead of spinning forever.
    """
    while True:
        produced = False
        for data in dl:
            produced = True
            yield data
        if not produced:
            return
# 将图像转换为指定类型
def convert_image_to(img_type, image):
    """Return `image` in PIL mode `img_type`, converting only when its mode differs."""
    if image.mode == img_type:
        return image
    return image.convert(img_type)
# 数据集、数据加载器、数据整理器
# 数据整理器类
class Collator:
    """Collate function that turns dataset rows into (image tensor, text embedding) batches.

    Each row carries its image under `image_label`, or, when `url_label` is not None,
    a URL under `url_label` that is downloaded. Rows whose image cannot be obtained
    or converted are skipped. Texts are encoded one at a time with
    t5.t5_encode_text and right-padded to the longest text in the batch. Returns
    None when no row in the batch was usable.
    """

    def __init__(self, image_size, url_label, text_label, image_label, name, channels):
        self.url_label = url_label
        self.text_label = text_label
        self.image_label = image_label
        self.download = url_label is not None
        self.name = name
        # PIL mode that images are converted to (e.g. 'RGB', 'L', 'RGBA')
        self.channels = channels
        self.transform = T.Compose([
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor(),
        ])

    def __call__(self, batch):
        texts = []
        images = []
        for item in batch:
            try:
                if self.download:
                    image = self.fetch_single_image(item[self.url_label])
                else:
                    image = item[self.image_label]
                image = self.transform(image.convert(self.channels))
            except Exception:
                # best effort: drop rows with a missing or unreadable image, while
                # still letting KeyboardInterrupt / SystemExit propagate
                continue

            text = t5.t5_encode_text([item[self.text_label]], name = self.name)
            # a single text was encoded, so axis 0 is presumably the batch axis;
            # removing only that axis keeps the sequence axis even for one-token
            # texts, where a bare squeeze() would drop it too
            texts.append(text.squeeze(0))
            images.append(image)

        if len(texts) == 0:
            return None

        texts = pad_sequence(texts, True)

        newbatch = list(zip(images, texts))
        return torch.utils.data.dataloader.default_collate(newbatch)

    def fetch_single_image(self, image_url, timeout = 1):
        """Download `image_url` and return it as an RGB PIL image, or None on any failure."""
        # `import urllib` alone does not load the `request` submodule; without this,
        # the AttributeError would be swallowed below and every download would
        # silently come back as None
        import urllib.request

        try:
            request = urllib.request.Request(
                image_url,
                data = None,
                headers = {"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout = timeout) as req:
                image = Image.open(io.BytesIO(req.read())).convert('RGB')
        except Exception:
            image = None
        return image
# Dataset class
class Dataset(Dataset):
    """Dataset of every image under `folder` (searched recursively) with one of `exts`.

    Each image is optionally converted to PIL mode `convert_image_to_type`, then
    resized, randomly flipped horizontally, center-cropped to `image_size` and
    converted to a tensor.
    """

    def __init__(
        self,
        folder,
        image_size,
        exts = ('jpg', 'jpeg', 'png', 'tiff'),  # immutable default; any iterable works
        convert_image_to_type = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        convert_fn = partial(convert_image_to, convert_image_to_type) if exists(convert_image_to_type) else nn.Identity()

        self.transform = T.Compose([
            T.Lambda(convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # the context manager closes the file handle once the image has been
        # decoded and transformed, instead of leaving it for garbage collection
        with Image.open(path) as img:
            return self.transform(img)
# 获取图像数据加载器
def get_images_dataloader(
    folder,
    *,
    batch_size,
    image_size,
    shuffle = True,
    cycle_dl = False,
    pin_memory = True
):
    """Build a DataLoader over the images in `folder`.

    When `cycle_dl` is set, the loader is wrapped with `cycle` so that it can be
    drawn from indefinitely.
    """
    loader = DataLoader(
        Dataset(folder, image_size),
        batch_size = batch_size,
        shuffle = shuffle,
        pin_memory = pin_memory
    )
    return cycle(loader) if cycle_dl else loader
.\lucidrains\imagen-pytorch\imagen_pytorch\elucidated_imagen.py
# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 从 random 模块中导入 random 函数
from random import random
# 从 functools 模块中导入 partial 函数
from functools import partial
# 从 contextlib 模块中导入 contextmanager 和 nullcontext
from contextlib import contextmanager, nullcontext
# 从 typing 模块中导入 List 和 Union
from typing import List, Union
# 从 collections 模块中导入 namedtuple
from collections import namedtuple
# 从 tqdm.auto 模块中导入 tqdm 函数
from tqdm.auto import tqdm
# 导入 torch 库
import torch
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F
# 从 torch 模块中导入 nn 模块
from torch import nn
# 从 torch.cuda.amp 模块中导入 autocast 函数
from torch.cuda.amp import autocast
# 从 torch.nn.parallel 模块中导入 DistributedDataParallel 类
from torch.nn.parallel import DistributedDataParallel
# 从 torchvision.transforms 模块中导入 T 别名
import torchvision.transforms as T
# 导入 kornia.augmentation 模块
import kornia.augmentation as K
# 从 einops 模块中导入 rearrange、repeat 和 reduce 函数
from einops import rearrange, repeat, reduce
# 从 imagen_pytorch.imagen_pytorch 模块中导入各种函数和类
from imagen_pytorch.imagen_pytorch import (
GaussianDiffusionContinuousTimes,
Unet,
NullUnet,
first,
exists,
identity,
maybe,
default,
cast_tuple,
cast_uint8_images_to_float,
eval_decorator,
pad_tuple_to_length,
resize_image_to,
calc_all_frame_dims,
safe_get_tuple_index,
right_pad_dims_to,
module_device,
normalize_neg_one_to_one,
unnormalize_zero_to_one,
compact,
maybe_transform_dict_key
)
# 从 imagen_pytorch.imagen_video 模块中导入 Unet3D、resize_video_to 和 scale_video_time 函数
from imagen_pytorch.imagen_video import (
Unet3D,
resize_video_to,
scale_video_time
)
# 从 imagen_pytorch.t5 模块中导入 t5_encode_text、get_encoded_dim 和 DEFAULT_T5_NAME 常量
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
# 定义常量 Hparams_fields
# per-unet hyperparameter names of ElucidatedImagen, in namedtuple field order
Hparams_fields = [
    'num_sample_steps',                        # number of sampling steps
    'sigma_min', 'sigma_max',                  # range of noise levels
    'sigma_data',                              # standard deviation of the data distribution
    'rho',                                     # shape of the sampling schedule
    'P_mean', 'P_std',                         # log-normal training noise parameters
    'S_churn', 'S_tmin', 'S_tmax', 'S_noise',  # stochastic sampling parameters
]

# one value per name in Hparams_fields
Hparams = namedtuple('Hparams', Hparams_fields)
# 定义辅助函数 log
def log(t, eps = 1e-20):
    """Natural logarithm of tensor `t`, with values clamped to at least `eps` first."""
    return t.clamp(min = eps).log()
# 主类 ElucidatedImagen
class ElucidatedImagen(nn.Module):
# 初始化方法
def __init__(
self,
unets,
*,
image_sizes, # 用于级联 ddpm 的图像大小
text_encoder_name = DEFAULT_T5_NAME,
text_embed_dim = None,
channels = 3,
cond_drop_prob = 0.1,
random_crop_sizes = None,
resize_mode = 'nearest',
temporal_downsample_factor = 1,
resize_cond_video_frames = True,
lowres_sample_noise_level = 0.2, # 低分辨率采样噪声级别
per_sample_random_aug_noise_level = False, # 是否在每个批次元素上接收随机增强噪声值
condition_on_text = True,
auto_normalize_img = True, # 是否自动归一化图像
dynamic_thresholding = True,
dynamic_thresholding_percentile = 0.95, # 动态阈值百分位数
only_train_unet_number = None,
lowres_noise_schedule = 'linear',
num_sample_steps = 32, # 采样步数
sigma_min = 0.002, # 最小噪声水平
sigma_max = 80, # 最大噪声水平
sigma_data = 0.5, # 数据分布的标准差
rho = 7, # 控制采样计划
P_mean = -1.2, # 训练时噪声抽取的对数正态分布均值
P_std = 1.2, # 训练时噪声抽取的对数正态分布标准差
S_churn = 80, # 随机采样参数
S_tmin = 0.05,
S_tmax = 50,
S_noise = 1.003,
# 强制取消条件性
def force_unconditional_(self):
    """Switch the model and every one of its unets to unconditional mode, in place."""
    self.condition_on_text = False
    self.unconditional = True
    for each_unet in self.unets:
        setattr(each_unet, 'cond_on_text', False)
# 返回属性 device 的值
@property
def device(self):
    """Device of the internal `_temp` tensor, used as the model's device."""
    temp_tensor = self._temp
    return temp_tensor.device
# 获取指定编号的 UNet 模型
def get_unet(self, unet_number):
    """Return unet `unet_number` (1-based) after placing it on self.device.

    On first use, an nn.ModuleList `self.unets` is replaced by a plain list, so the
    unets are no longer registered submodules. When the requested unet differs from
    the previously requested one, it is moved to self.device and all others to 'cpu'.
    """
    assert 0 < unet_number <= len(self.unets)
    wanted = unet_number - 1

    if isinstance(self.unets, nn.ModuleList):
        plain_unets = list(self.unets)
        delattr(self, 'unets')
        self.unets = plain_unets

    if wanted != self.unet_being_trained_index:
        for position, candidate in enumerate(self.unets):
            target = self.device if position == wanted else 'cpu'
            candidate.to(target)

    self.unet_being_trained_index = wanted
    return self.unets[wanted]
# 将所有 UNet 模型重置到同一设备上
def reset_unets_all_one_device(self, device = None):
    """Re-register all unets as an nn.ModuleList and move them to one device.

    `device` defaults to self.device. The unet_being_trained_index marker is reset to -1.
    """
    target_device = default(device, self.device)
    self.unets = nn.ModuleList(list(self.unets))
    self.unets.to(target_device)
    self.unet_being_trained_index = -1
# 使用上下文管理器将指定 UNet 移动到 GPU 上
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
    """Context manager that keeps only one unet on self.device and the rest on the CPU.

    Exactly one of `unet_number` (1-based) or `unet` must be given. On exit, every
    unet is moved back to the device it was on before entering, even if the body
    raised.
    NOTE(review): relies on `self.unets.to(...)`, i.e. on self.unets being an
    nn.ModuleList; a plain list (as left by get_unet) has no `.to`.
    """
    assert exists(unet_number) ^ exists(unet)

    if exists(unet_number):
        unet = self.unets[unet_number - 1]

    cpu = torch.device('cpu')

    # remember where every unet lives so the placement can be restored
    devices = [module_device(u) for u in self.unets]
    self.unets.to(cpu)
    unet.to(self.device)

    try:
        yield
    finally:
        for u, device in zip(self.unets, devices):
            u.to(device)
# 重写 state_dict 函数
def state_dict(self, *args, **kwargs):
    """Return the module state dict after gathering every unet on one device.

    get_unet may have turned self.unets into a plain list, which would not be a
    registered submodule; resetting re-wraps it in an nn.ModuleList first.
    """
    self.reset_unets_all_one_device()
    state = super().state_dict(*args, **kwargs)
    return state
# 重写 load_state_dict 函数
def load_state_dict(self, *args, **kwargs):
    """Load a state dict after gathering every unet on one device.

    Re-wrapping self.unets in an nn.ModuleList first makes their parameters
    addressable by the state-dict keys again.
    """
    self.reset_unets_all_one_device()
    result = super().load_state_dict(*args, **kwargs)
    return result
# 动态阈值
def threshold_x_start(self, x_start, dynamic_threshold = True):
    """Bring a predicted x_start back into [-1, 1].

    Static mode clamps to [-1, 1]. Dynamic mode works per sample. It takes the
    `dynamic_thresholding_percentile` quantile s of |x_start| (raised to at least
    1), clamps to [-s, s] and divides by s.
    """
    if dynamic_threshold:
        flat_abs = rearrange(x_start, 'b ... -> b (...)').abs()
        s = torch.quantile(flat_abs, self.dynamic_thresholding_percentile, dim = -1)
        s.clamp_(min = 1.)
        s = right_pad_dims_to(x_start, s)
        return x_start.clamp(-s, s) / s

    return x_start.clamp(-1., 1.)
# 衍生的预处理参数 - 表 1
def c_skip(self, sigma_data, sigma):
    """Preconditioning skip scale: sigma_data^2 / (sigma^2 + sigma_data^2)."""
    data_var = sigma_data ** 2
    return data_var / (sigma ** 2 + data_var)
def c_out(self, sigma_data, sigma):
    """Preconditioning output scale: sigma * sigma_data / sqrt(sigma_data^2 + sigma^2)."""
    inv_norm = (sigma_data ** 2 + sigma ** 2) ** -0.5
    return sigma * sigma_data * inv_norm
def c_in(self, sigma_data, sigma):
    """Preconditioning input scale: 1 / sqrt(sigma^2 + sigma_data^2)."""
    return (sigma ** 2 + sigma_data ** 2) ** -0.5
def c_noise(self, sigma):
    """Noise conditioning fed to the network: a quarter of the (clamped) log of sigma."""
    return 0.25 * log(sigma)
# 预处理网络输出
def preconditioned_network_forward(
    self,
    unet_forward,
    noised_images,
    sigma,
    *,
    sigma_data,
    clamp = False,
    dynamic_threshold = True,
    **kwargs
):
    """Run `unet_forward` with the c_skip / c_out / c_in / c_noise preconditioning.

    `sigma` may be a Python number, which is broadcast to the whole batch, or a
    per-sample tensor. The network receives c_in * noised_images and c_noise(sigma).
    The result is c_skip * noised_images + c_out * network_output, thresholded with
    threshold_x_start when `clamp` is set. Extra kwargs are passed to `unet_forward`.
    """
    batch, device = noised_images.shape[0], noised_images.device

    # accept ints as well as floats (e.g. an integer sigma_max from a config);
    # float() keeps the tensor floating point, exactly as for a float sigma
    if isinstance(sigma, (int, float)):
        sigma = torch.full((batch,), float(sigma), device = device)

    padded_sigma = self.right_pad_dims_to_datatype(sigma)

    net_out = unet_forward(
        self.c_in(sigma_data, padded_sigma) * noised_images,
        self.c_noise(sigma),
        **kwargs
    )

    out = self.c_skip(sigma_data, padded_sigma) * noised_images + self.c_out(sigma_data, padded_sigma) * net_out

    if not clamp:
        return out

    return self.threshold_x_start(out, dynamic_threshold)
# 采样
# 采样计划
def sample_schedule(
    self,
    num_sample_steps,
    rho,
    sigma_min,
    sigma_max
):
    """Return the sampling noise levels from sigma_max down to sigma_min, followed by 0.

    sigma_i = (sigma_max^(1/rho) + i / (N - 1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    for i = 0 .. N - 1, where N = num_sample_steps, then a trailing 0. The result is
    a float32 tensor of N + 1 values on self.device. With a single step the schedule
    is [sigma_max, 0].
    """
    N = num_sample_steps
    inv_rho = 1 / rho

    steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)

    # for N == 1 the denominator N - 1 would be 0 and turn the schedule into NaN;
    # step 0 maps to sigma_max regardless of the denominator, so use 1 there
    denom = max(N - 1, 1)
    sigmas = (sigma_max ** inv_rho + steps / denom * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho

    sigmas = F.pad(sigmas, (0, 1), value = 0.)  # last step is sigma value of 0.
    return sigmas
@torch.no_grad()
def one_unet_sample(
self,
unet,
shape,
*,
unet_number,
clamp = True,
dynamic_threshold = True,
cond_scale = 1.,
use_tqdm = True,
inpaint_videos = None,
inpaint_images = None,
inpaint_masks = None,
inpaint_resample_times = 5,
init_images = None,
skip_steps = None,
sigma_min = None,
sigma_max = None,
**kwargs
@torch.no_grad()
@eval_decorator
def sample(
self,
texts: List[str] = None,
text_masks = None,
text_embeds = None,
cond_images = None,
cond_video_frames = None,
post_cond_video_frames = None,
inpaint_videos = None,
inpaint_images = None,
inpaint_masks = None,
inpaint_resample_times = 5,
init_images = None,
skip_steps = None,
sigma_min = None,
sigma_max = None,
video_frames = None,
batch_size = 1,
cond_scale = 1.,
lowres_sample_noise_level = None,
start_at_unet_number = 1,
start_image_or_video = None,
stop_at_unet_number = None,
return_all_unet_outputs = False,
return_pil_images = False,
use_tqdm = True,
use_one_unet_in_gpu = True,
device = None,
# training
# 计算损失权重
def loss_weight(self, sigma_data, sigma):
    """Per-noise-level loss weight: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    numerator = sigma ** 2 + sigma_data ** 2
    return numerator * (sigma * sigma_data) ** -2
# 生成服从指定均值和标准差的噪声分布
def noise_distribution(self, P_mean, P_std, batch_size):
    """Draw `batch_size` training noise levels exp(P_mean + P_std * N(0, 1)) on self.device."""
    log_sigma = P_mean + P_std * torch.randn((batch_size,), device = self.device)
    return log_sigma.exp()
def forward(
self,
images, # 重命名为 images 或 video
unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
texts: List[str] = None,
text_embeds = None,
text_masks = None,
unet_number = None,
cond_images = None,
**kwargs