Lucidrains-系列项目源码解析-五十一-
Lucidrains 系列项目源码解析(五十一)
.\lucidrains\x-transformers\x_transformers\xval.py
"""
定义了一个基于离散标记的常规变换器,但对于数字是连续的
更好地泛化了算术
https://arxiv.org/abs/2310.02989
"""
# 导入所需的库
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Callable
from collections import namedtuple
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
from x_transformers.x_transformers import (
    AttentionLayers,
    TokenEmbedding,
    ScaledSinusoidalEmbedding,
    AbsolutePositionalEmbedding,
    always,
    pad_at_dim
)
from x_transformers.autoregressive_wrapper import (
    top_k,
    top_p
)
# 常量
# 定义一个命名元组,用于表示损失的细分
LossBreakdown = namedtuple('LossBreakdown', ['cross_entropy_loss', 'numerical_mse_loss'])
# 定义一个命名元组,用于表示生成的返回结果
GenerateReturn = namedtuple('GenerateReturn', ['sampled_token_ids', 'sampled_numbers', 'is_number_mask'])
# 辅助函数
# 判断变量是否存在
def exists(val):
    """Return True when `val` is not None."""
    return not (val is None)
# 返回默认值
def default(val, d):
    """Return `val` unless it is None, else the default `d` (called when callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
# 主要类
class XValTransformerWrapper(nn.Module):
    """
    xVal transformer wrapper (https://arxiv.org/abs/2310.02989): a regular
    transformer over discrete tokens, except a designated "numerical" token has
    its embedding continuously scaled by an associated number, which has been
    shown to generalize better for arithmetic.
    """
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        numerical_token_id,          # token id whose embedding is scaled by the continuous value
        attn_layers: AttentionLayers,
        emb_dim = None,
        logits_dim = None,
        tie_embedding = False,       # tie output projection to the token embedding weights
        max_mem_len = 0,
        num_memory_tokens = None,
        emb_dropout = 0.,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False
    ):
        super().__init__()
        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)
        self.emb_dim = emb_dim
        self.token_emb = TokenEmbedding(emb_dim, num_tokens)
        self.numerical_token_id = numerical_token_id
        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len

        # positional embedding: disabled (constant 0), scaled sinusoidal, or learned absolute
        if not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb):
            self.pos_emb = always(0)
        elif scaled_sinu_pos_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)
        else:
            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)

        self.emb_dropout = nn.Dropout(emb_dropout)

        # memory tokens (learned vectors prepended to the sequence)
        num_memory_tokens = default(num_memory_tokens, 0)
        self.has_memory_tokens = num_memory_tokens > 0

        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

        # attention layers
        self.attn_layers = attn_layers

        # projection to logits, optionally weight-tied to the token embedding
        logits_dim = default(logits_dim, num_tokens)
        self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

        # scalar head predicting the continuous number at each position
        self.to_numerical_output = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... 1 -> ...')
        )

    def forward(
        self,
        x: Tensor,
        x_num: Tensor,
        return_embeddings = False,
        return_intermediates = False,
        return_mems = False,
        mask = None,
        return_attn = False,
        mems = None,
        pos = None,
        prepend_embeds = None,
        **kwargs
    ):
        # token ids and their continuous values must align one-to-one
        assert x.shape == x_num.shape

        batch = x.shape[0]

        # positions holding the numerical token get scaled by their value, others by 1
        is_number_mask = x == self.numerical_token_id

        x = self.token_emb(x)

        scale = torch.where(is_number_mask, x_num, 1.)
        scale = rearrange(scale, '... -> ... 1')

        x = x * scale

        x = x + self.pos_emb(x, pos = pos)

        # prepend memory tokens and extend the mask accordingly
        # NOTE(review): `repeat`, `pack`, `unpack`, `pad_at_dim`, `always` must be
        # in scope at module level — verify the file's imports include them
        if self.has_memory_tokens:
            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
            x, mem_ps = pack([m, x], 'b * d')

            if exists(mask):
                num_mems = m.shape[-2]
                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)

        # whether to prepend embeds, as in PaLI, for image embeddings
        if exists(prepend_embeds):
            _, prepend_dim = prepend_embeds.shape[1:]
            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'

            x = torch.cat((prepend_embeds, x), dim = -2)

        x = self.emb_dropout(x)

        # attention layers
        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)

        # splice out the memory tokens again
        if self.has_memory_tokens:
            m, x = unpack(x, mem_ps, 'b * d')
            intermediates.memory_tokens = m

        # produce (logits, numerical prediction) unless raw embeddings requested
        if not return_embeddings:
            logits = self.to_logits(x)
            numerical_pred = self.to_numerical_output(x)
            out = (logits, numerical_pred)
        else:
            out = x

        if return_intermediates:
            return out, intermediates

        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
            return out, new_mems

        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            return out, attn_maps

        return out
class XValAutoregressiveWrapper(nn.Module):
    """Autoregressive training / sampling wrapper around XValTransformerWrapper."""

    def __init__(
        self,
        net: XValTransformerWrapper,
        ignore_index = -100,          # target value excluded from both losses
        pad_value = 0,
        numerical_loss_weight = 1.    # weight of the numerical MSE term vs cross entropy
    ):
        super().__init__()
        self.net = net
        self.max_seq_len = net.max_seq_len
        self.numerical_loss_weight = numerical_loss_weight
        self.ignore_index = ignore_index

    @torch.no_grad()
    def generate(
        self,
        start_tokens: Tensor,
        start_numbers: Tensor,
        seq_len,
        filter_logits_fn: Callable = top_k,
        filter_kwargs: dict = dict(),
        temperature = 1.,
        **kwargs
    ):
        """Sample `seq_len` steps; returns (token ids, numbers with NaN at non-number positions, number mask)."""
        device = start_tokens.device
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2'
        assert start_tokens.shape == start_numbers.shape

        b, t, device = *start_tokens.shape, start_tokens.device

        self.net.eval()
        out = start_tokens
        num_out = start_numbers

        for _ in range(seq_len):
            # condition on at most the last max_seq_len positions
            x = out[:, -self.max_seq_len:]
            x_num = num_out[:, -self.max_seq_len:]

            logits, numerical_pred = self.net(x, x_num, **kwargs)

            last_logits = logits[:, -1]
            last_num_pred = numerical_pred[:, -1:]

            # filter (e.g. top-k), apply temperature, then sample one token
            filtered_logits = filter_logits_fn(last_logits, **filter_kwargs)

            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim = -1)
            num_out = torch.cat((num_out, last_num_pred), dim = -1)

        # strip the prompt, keep only newly generated positions
        out = out[:, t:]
        num_out = num_out[:, t:]

        # numbers are only meaningful where the numerical token was sampled
        is_number = out == self.net.numerical_token_id
        num_out = torch.where(is_number, num_out, float('nan'))

        self.net.train(was_training)
        return GenerateReturn(out, num_out, is_number)

    def forward(
        self,
        x: Tensor,
        x_num: Tensor,
        return_loss_breakdown = False,
        **kwargs
    ):
        """Teacher-forced loss: cross entropy on tokens + weighted MSE on numbers."""
        # shift by one for next-token prediction
        inp, target = x[:, :-1], x[:, 1:]
        x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:]

        # trim a full-length mask to match the shifted input
        mask = kwargs.get('mask', None)
        if exists(mask) and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]
            kwargs['mask'] = mask

        logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)

        # cross_entropy expects classes on dim 1
        logits = rearrange(logits, 'b n c -> b c n')

        cross_entropy_loss = F.cross_entropy(logits, target, reduction = 'none', ignore_index = self.ignore_index)

        # zero the numerical loss wherever the token target is ignored
        target_mask = target != self.ignore_index

        numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')

        numerical_mse_loss = numerical_mse_loss * target_mask

        loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight

        if exists(mask):
            loss = loss[mask]

        loss = loss.mean()

        if not return_loss_breakdown:
            return loss

        return loss, LossBreakdown(cross_entropy_loss, numerical_mse_loss)
.\lucidrains\x-transformers\x_transformers\x_transformers.py
# 导入数学库
import math
# 从 random 模块中导入 random 函数
from random import random
# 从 typing 模块中导入 Dict 类型提示
from typing import Dict
# 从 packaging 模块中导入 version 版本信息
from packaging import version
# 导入 torch 库
import torch
# 从 torch 库中导入 nn, einsum, Tensor
from torch import nn, einsum, Tensor
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F
# 从 torch.cuda.amp 模块中导入 autocast 函数
from torch.cuda.amp import autocast
# 导入 functools 模块中的 partial, wraps 函数
from functools import partial, wraps
# 导入 collections 模块中的 namedtuple 类
from collections import namedtuple
# 导入 dataclasses 模块中的 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块中导入 List, Callable, Optional, Union 类型提示
from typing import List, Callable, Optional, Union
# 从 einops 模块中导入 rearrange, repeat, reduce, pack, unpack 函数
from einops import rearrange, repeat, reduce, pack, unpack
# 从 einops.layers.torch 模块中导入 Rearrange 类
from einops.layers.torch import Rearrange
# 从 x_transformers.attend 模块中导入 Attend, Intermediates 类
from x_transformers.attend import Attend, Intermediates
# 从 x_transformers.autoregressive_wrapper 模块中导入 AutoregressiveWrapper 类
# 常量定义
# 默认头部维度
DEFAULT_DIM_HEAD = 64
# 定义 LayerIntermediates 数据类
@dataclass
class LayerIntermediates:
    # container for everything the attention stack can return besides its output
    hiddens: Optional[List[Tensor]] = None             # all hiddens, before the final norm (in a pre-norm architecture)
    last_hidden: Optional[Tensor] = None               # last hidden after all attention layers, after the final norm
    attn_intermediates: Optional[List[Intermediates]] = None
    layer_hiddens: Optional[List[Tensor]] = None
    attn_z_loss: Optional[Tensor] = None
    mems: Optional[Tensor] = None
    memory_tokens: Optional[Tensor] = None
# 辅助函数
# 判断变量是否存在
def exists(val):
    """Return True when `val` is not None."""
    return not (val is None)
# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    """Return `val` unless it is None, else the default `d` (called when callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
# 将变量转换为元组
def cast_tuple(val, depth):
    """Return `val` unchanged if it is a tuple, else repeat it `depth` times."""
    if isinstance(val, tuple):
        return val
    return (val,) * depth
# 判断一个数是否可以被另一个数整除
def divisible_by(num, den):
    """True when `num` divides evenly by `den`."""
    return not (num % den)
# 如果变量存在则执行函数
def maybe(fn):
    """Decorator: pass None straight through, otherwise apply `fn`."""
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if x is None:
            return x
        return fn(x, *args, **kwargs)
    return inner
# 至多一个为真
def at_most_one_of(*bools):
    """True when at most one of the given booleans is set."""
    return sum(1 for b in bools if b) <= 1
# 始终返回相同值
class always():
    """Callable that ignores all arguments and returns a fixed value."""
    def __init__(self, val):
        self.val = val

    def __call__(self, *_, **__):
        return self.val
# 不等于某个值
class not_equals():
    """Callable testing whether its argument differs from a stored value."""
    def __init__(self, val):
        self.val = val

    def __call__(self, x, *_, **__):
        return x != self.val
# 等于某个值
class equals():
    """Callable testing whether its argument equals a stored value."""
    def __init__(self, val):
        self.val = val

    def __call__(self, x, *_, **__):
        return x == self.val
# 创建序列模块
def Sequential(*modules):
    """nn.Sequential that silently drops None entries (for optional layers)."""
    return nn.Sequential(*(m for m in modules if m is not None))
# 张量辅助函数
# 返回张量的最小负值
def max_neg_value(tensor):
    """Most negative finite value representable in the tensor's dtype."""
    finfo = torch.finfo(tensor.dtype)
    return -finfo.max
# 对张量进行 L2 归一化
def l2norm(t, groups = 1):
    """L2-normalize the last dimension, treated as `groups` independent chunks."""
    t = t.unflatten(-1, (groups, -1))
    t = F.normalize(t, p = 2, dim = -1)
    return t.flatten(-2)
# 在指定维度上填充张量
def pad_at_dim(t, pad, dim = -1, value = 0.):
    """Pad `t` by `pad = (before, after)` along dimension `dim`, filling with `value`."""
    right_dims = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (0, 0) * right_dims + tuple(pad), value = value)
# 对多个掩码进行逻辑或操作
def or_reduce(masks):
    """Elementwise OR across a non-empty list of masks."""
    out = masks[0]
    for m in masks[1:]:
        out = out | m
    return out
# 辅助损失函数
# 计算 z 损失
def calc_z_loss(
    pre_softmax_attns: List[Tensor],
    mask = None,
    weight = 1.
):
    """
    z-loss on pre-softmax attention logits.
    Same loss applied to mixture-of-experts router logits in https://arxiv.org/abs/2202.08906;
    a footnote there notes it also stabilizes attention logits, and it was used in PaLM.
    """
    lse = 0.
    for attn in pre_softmax_attns:
        lse = lse + attn.logsumexp(dim = -1)

    loss = torch.square(lse)
    loss = loss.sum(dim = 1)  # sum over heads: 'b h n -> b n'

    if mask is None:
        return loss.mean() * weight

    loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
    return loss * weight
# 初始化辅助函数
# 初始化为零
def init_zero_(layer):
    """Zero a layer's weight (and bias, when present) in place."""
    nn.init.constant_(layer.weight, 0.)
    if layer.bias is not None:
        nn.init.constant_(layer.bias, 0.)
# 关键字参数辅助函数
# 选择并弹出键值对
def pick_and_pop(keys, d):
    """Remove `keys` from dict `d` in place and return them as a new dict."""
    return {key: d.pop(key) for key in keys}
# 根据条件将字典分组
def group_dict_by_key(cond, d):
    """Split `d` into (matching, non-matching) dicts according to `cond(key)`."""
    matching, rest = dict(), dict()
    for key, value in d.items():
        target = matching if cond(key) else rest
        target[key] = value
    return matching, rest
# 检查字符串是否以指定前缀开头
def string_begins_with(prefix, s):
    """Return True when string `s` starts with `prefix`.

    Fix: the second parameter previously shadowed the builtin `str`;
    in-file callers use it positionally via `partial`, so the rename is safe.
    """
    return s.startswith(prefix)
# 根据键的前缀对字典进行分组
def group_by_key_prefix(prefix, d):
    """Split dict `d` by whether each key starts with `prefix`."""
    return group_dict_by_key(lambda key: key.startswith(prefix), d)
# 根据前缀对字典进行分组并修剪前缀
def groupby_prefix_and_trim(prefix, d):
    """Split `d` by key prefix; return (prefix-stripped matching dict, remaining dict).

    Fix: the dict comprehension replaces the original `dict(map(...))` line,
    which was missing its closing parenthesis (a syntax error).
    """
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    # drop the prefix from the matching keys
    kwargs_without_prefix = {key[len(prefix):]: value for key, value in kwargs_with_prefix.items()}
    return kwargs_without_prefix, kwargs
# 结构化的 dropout,比传统的注意力 dropout 更有效
def dropout_seq(seq, mask, dropout):
    """Structured sequence dropout: keep a random subset of positions per batch row.

    Said to be more effective than traditional attention dropout. Masked-out
    positions are preferentially dropped; returns the shortened (seq, mask).
    """
    b, n, *_, device = *seq.shape, seq.device
    # random scores decide which positions survive
    logits = torch.randn(b, n, device=device)

    if exists(mask):
        # force masked positions to the bottom of the ranking
        mask_value = max_neg_value(logits)
        logits = logits.masked_fill(~mask, mask_value)

    keep_prob = 1. - dropout
    num_keep = max(1, int(keep_prob * n))
    keep_indices = logits.topk(num_keep, dim=1).indices

    batch_indices = torch.arange(b, device=device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')

    # gather the kept positions
    seq = seq[batch_indices, keep_indices]

    if exists(mask):
        # per-row valid counts, scaled down by the keep probability
        seq_counts = mask.sum(dim=-1)
        seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
        keep_mask = torch.arange(num_keep, device=device) < rearrange(seq_keep_counts, 'b -> b 1')

        mask = mask[batch_indices, keep_indices] & keep_mask

    return seq, mask
# 激活函数
class ReluSquared(nn.Module):
    """relu(x) squared activation."""
    def forward(self, x):
        return F.relu(x).square()
# 词嵌入
class TokenEmbedding(nn.Module):
    """Token id -> embedding lookup, optionally l2-normalized."""
    def __init__(self, dim, num_tokens, l2norm_embed=False):
        super().__init__()
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(num_tokens, dim)

    def forward(self, x):
        embeds = self.emb(x.long())
        if not self.l2norm_embed:
            return embeds
        return l2norm(embeds)
# 绝对位置嵌入
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding, scaled by dim ** -0.5 unless l2norm_embed."""
    def __init__(self, dim, max_seq_len, l2norm_embed=False):
        super().__init__()
        self.scale = 1. if l2norm_embed else dim ** -0.5
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos=None, seq_start_pos=None):
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device=device)

        if seq_start_pos is not None:
            # offset positions by each row's start, clamped so they stay valid
            pos = (pos - seq_start_pos[..., None]).clamp(min=0)

        pos_emb = self.emb(pos) * self.scale
        if not self.l2norm_embed:
            return pos_emb
        return l2norm(pos_emb)
# 缩放的正弦位置嵌入
class ScaledSinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal positional embedding with one learned scalar scale."""
    def __init__(self, dim, theta=10000):
        super().__init__()
        assert dim % 2 == 0
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        freq_seq = torch.arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        self.register_buffer('inv_freq', inv_freq, persistent=False)

    def forward(self, x, pos=None, seq_start_pos=None):
        seq_len, device = x.shape[1], x.device

        if pos is None:
            pos = torch.arange(seq_len, device=device)
        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        # outer product of positions and inverse frequencies
        emb = einsum('i, j -> i j', pos, self.inv_freq)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb * self.scale
class RelativePositionBias(nn.Module):
    """T5-style relative position bias: bucketed relative distances mapped to a per-head additive bias."""

    def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # one learned bias per (bucket, head)
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        """Map relative positions to bucket indices: half the buckets are exact, half log-spaced out to max_distance."""
        ret = 0
        n = -relative_position
        if not causal:
            # non-causal: split buckets between negative and positive offsets
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        # log-spaced buckets for larger distances, clamped to the last bucket
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        """Return a (heads, i, j) bias for i queries attending to j keys."""
        device = self.device
        # query positions occupy the last i of the j key positions
        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> h i j')
        return bias * self.scale
class DynamicPositionBias(nn.Module):
    """Small MLP mapping continuous relative distance to a (heads, i, j) attention bias."""

    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        # first layer lifts the scalar distance into `dim` features
        self.mlp.append(Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim) if norm else None,
            nn.SiLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else None,
                nn.SiLU()
            ))

        # final projection to one bias per head
        # fix: the original line was missing its closing parenthesis (syntax error)
        self.mlp.append(nn.Linear(dim, heads))

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        # only square attention maps are supported
        assert i == j
        n, device = j, self.device

        # get the (n x n) matrix of distances, offset to be non-negative indices
        seq_arange = torch.arange(n, device = device)
        context_arange = torch.arange(n, device = device)
        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
        indices += (n - 1)

        # input to continuous positions MLP: all distances in [-(n-1), n-1]
        pos = torch.arange(-n + 1, n, device = device).float()
        pos = rearrange(pos, '... -> ... 1')

        if self.log_distance:
            # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # index the MLP outputs by distance and move heads first
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias
# 返回位置偏置
class AlibiPositionalBias(nn.Module):
    """ALiBi positional bias (https://arxiv.org/abs/2108.12409): per-head linear distance penalty."""

    def __init__(self, heads, total_heads, **kwargs):
        super().__init__()
        self.heads = heads                # heads that receive the alibi slopes
        self.total_heads = total_heads    # total heads; remaining heads get a zero bias
        slopes = Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> h 1 1')
        self.register_buffer('slopes', slopes, persistent = False)
        # lazily-built cache of the bias matrix
        self.register_buffer('bias', None, persistent = False)

    def get_bias(self, i, j, device):
        """Negative absolute distance between the last i query positions and j key positions."""
        i_arange = torch.arange(j - i, j, device = device)
        j_arange = torch.arange(j, device = device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    @staticmethod
    def _get_slopes(heads):
        """Geometric sequence of slopes; exact for powers of two, interpolated otherwise."""
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self, i, j):
        h, device = self.total_heads, self.device

        # reuse the cached bias when it is large enough
        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
            return self.bias[..., -i:, -j:]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        # pad zero bias for the heads that are not alibi'd
        num_heads_unalibied = h - bias.shape[0]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
        self.register_buffer('bias', bias, persistent = False)

        return self.bias
# 返回偏置
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE), optionally with xpos length-extrapolation scaling."""

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            # no xpos: forward returns a scale of 1.
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len):
        """Build rotary frequencies for positions [0, seq_len)."""
        device = self.inv_freq.device
        t = torch.arange(seq_len, device = device)
        return self.forward(t)

    @autocast(enabled = False)  # rotary math kept in full precision
    def forward(self, t):
        max_pos = t.max()+1

        # outer product of positions and inverse frequencies, duplicated for sin/cos halves
        freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not exists(self.scale):
            return freqs, 1.

        # xpos decay centered on the middle position
        power = (t - (max_pos // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
# 定义一个函数,将输入张量 x 进行重新排列,将最后两个维度中的第一个维度 j 换到倒数第二个维度
def rotate_half(x):
    """Split the last dim in half as (x1, x2) and return their rotation (-x2, x1)."""
    x1, x2 = x.unflatten(-1, (2, -1)).unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)
# 定义一个函数,应用旋转位置嵌入到输入张量 t 上
@autocast(enabled = False)  # keep rotary math in full precision
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Apply rotary embedding `freqs` (and optional xpos `scale`) to the trailing positions of `t`."""
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
    # align frequencies/scales with the last seq_len positions
    freqs = freqs[-seq_len:, :]
    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale

    if t.ndim == 4 and freqs.ndim == 3:
        # broadcast batched freqs over the heads dimension
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J: only the first rot_dim features are rotated
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    return torch.cat((t, t_unrotated), dim = -1)
# norms
# 定义一个缩放层,用于对输入进行缩放
class Scale(nn.Module):
    """Multiply the wrapped fn's output (or first element of a tuple output) by a constant."""
    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)

        if isinstance(out, tuple):
            first, *rest = out
            return (first * self.value, *rest)

        return out * self.value
# 定义一个缩放归一化层
class ScaleNorm(nn.Module):
    """Scale-norm: x / ||x|| * g, with a single learned scalar g initialized to dim ** -0.5."""
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))

    def forward(self, x):
        norm = x.norm(dim = -1, keepdim = True).clamp(min = self.eps)
        return x / norm * self.g
# 定义一个 LayerNorm 层
class LayerNorm(nn.Module):
    def __init__(self, dim):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
        """
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        # beta is a fixed zero buffer (not learned) so F.layer_norm gets a bias argument
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# torch >= 2.1.0 supports bias = False natively, so swap in nn.LayerNorm there
if version.parse(torch.__version__) >= version.parse('2.1.0'):
    LayerNorm = partial(nn.LayerNorm, bias = False)
# 定义一个 RMSNorm 层
class RMSNorm(nn.Module):
    """RMS norm with learned per-dimension gain: normalize(x) * sqrt(dim) * g."""
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.g
# 定义一个简单的 RMSNorm 层
class SimpleRMSNorm(nn.Module):
    """Parameter-free RMS norm: normalize(x) * sqrt(dim)."""
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale
# residual and residual gates
# 定义一个残差连接层
class Residual(nn.Module):
    """Residual add, with optional learned per-dim and/or constant residual scaling."""
    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
        super().__init__()
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
        self.scale_residual_constant = scale_residual_constant

    def forward(self, x, residual):
        if self.residual_scale is not None:
            residual = residual * self.residual_scale

        if self.scale_residual_constant != 1:
            residual = residual * self.scale_residual_constant

        return x + residual
# 定义一个 GRU 门控单元层
class GRUGating(nn.Module):
    """Gate the residual connection with a GRU cell (x as input, residual as hidden state)."""
    def __init__(self, dim, scale_residual = False, **kwargs):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

    def forward(self, x, residual):
        if self.residual_scale is not None:
            residual = residual * self.residual_scale

        # GRUCell expects 2d input, so fold batch and sequence together
        dim = x.shape[-1]
        gated_output = self.gru(
            x.reshape(-1, dim),
            residual.reshape(-1, dim)
        )
        return gated_output.reshape_as(x)
# token shifting
# 定义一个函数,对输入张量进行平移操作
def shift(t, amount, mask = None):
    """Shift the sequence dim (-2) forward by `amount`, zero-filling vacated positions.

    Masked-out positions are zeroed before shifting so padding never leaks in.
    """
    if amount == 0:
        return t

    # never shift further than the sequence length
    amount = min(amount, t.shape[1])

    if mask is not None:
        t = t.masked_fill(~mask[..., None], 0.)

    # pad `amount` at the front of dim -2 and trim the same amount off the end
    return F.pad(t, (0, 0, amount, -amount), value = 0.)
# 定义一个 ShiftTokens 类,用于对输入进行平移操作
class ShiftTokens(nn.Module):
    """Split features into segments, time-shift each segment by its own amount, then call fn."""

    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        # mask is forwarded to shift() so padding positions stay zero
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        segments = len(shifts)
        # split the feature dim into one chunk per shift amount
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim=-1)
        # any leftover features (from uneven division) pass through unshifted
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
        x = torch.cat((*segments_to_shift, *rest), dim=-1)
        return self.fn(x, **kwargs)
# 定义 GLU 类,用于实现门控线性单元
class GLU(nn.Module):
    """Gated linear unit: project to 2x width, gate one half with `activation`."""
    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        mult_bias = False
    ):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)
        # optional learned multiplicative bias on the gated output
        self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.

    def forward(self, x):
        projected = self.proj(x)
        x, gate = projected.chunk(2, dim = -1)
        return x * self.act(gate) * self.mult_bias
# 定义 FeedForward 类,用于实现前馈神经网络
class FeedForward(nn.Module):
    """Transformer feedforward block: project up (optionally gated), activate, project back down."""

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,                 # inner dim multiplier
        glu = False,              # use a gated linear unit input projection
        glu_mult_bias = False,
        swish = False,
        relu_squared = False,
        post_act_ln = False,      # LayerNorm after the activation
        dropout = 0.,
        no_bias = False,
        zero_init_output = False  # zero-init last linear for stable residual branches
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        # choose the activation
        if relu_squared:
            activation = ReluSquared()
        elif swish:
            activation = nn.SiLU()
        else:
            activation = nn.GELU()

        # input projection: gated (GLU) or plain linear + activation
        if glu:
            project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim, bias = not no_bias),
                activation
            )

        self.ff = Sequential(
            project_in,
            LayerNorm(inner_dim) if post_act_ln else None,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out, bias = not no_bias)
        )

        if zero_init_output:
            init_zero_(self.ff[-1])

    def forward(self, x):
        return self.ff(x)
# 定义 Attention 类,用于实现注意力机制
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head = DEFAULT_DIM_HEAD,
dim_context = None,
heads = 8,
causal = False,
flash = False,
talking_heads = False,
head_scale = False,
sparse_topk = None,
num_mem_kv = 0,
dropout = 0.,
on_attn = False,
gate_value_heads = False,
swiglu_values = False,
gate_values = False,
zero_init_output = False,
max_attend_past = None,
qk_norm = False,
qk_norm_groups = 1,
qk_norm_scale = 10,
qk_norm_dim_scale = False,
one_kv_head = False,
kv_heads = None,
shared_kv = False,
value_dim_head = None,
tensor_product = False, # https://arxiv.org/abs/2208.06061
add_zero_kv = False, # same as add_zero_attn in pytorch
rotary_embed_values = False,
onnxable = False
# 前向传播函数
def forward(
self,
x,
context = None,
mask = None,
context_mask = None,
attn_mask = None,
rel_pos = None,
rotary_pos_emb = None,
prev_attn = None,
mem = None,
mem_mask = None,
return_intermediates = False,
cache: Optional[Intermediates] = None,
class AttentionLayers(nn.Module):
# 初始化函数,设置模型参数
def __init__(
self,
dim,
depth,
heads = 8,
causal = False,
cross_attend = False,
only_cross = False,
use_scalenorm = False,
use_rmsnorm = False,
use_simple_rmsnorm = False,
alibi_pos_bias = False,
alibi_num_heads = None,
rel_pos_bias = False,
rel_pos_num_buckets = 32,
rel_pos_max_distance = 128,
dynamic_pos_bias = False,
dynamic_pos_bias_log_distance = False,
dynamic_pos_bias_mlp_depth = 2,
dynamic_pos_bias_norm = False,
rotary_pos_emb = False,
rotary_emb_dim = None,
rotary_xpos = False,
rotary_interpolation_factor = 1.,
rotary_xpos_scale_base = 512,
rotary_base_rescale_factor = 1.,
custom_layers = None,
sandwich_coef = None,
par_ratio = None,
weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
residual_attn = False,
cross_residual_attn = False,
macaron = False,
pre_norm = True,
pre_norm_has_final_norm = True,
gate_residual = False,
scale_residual = False,
scale_residual_constant = 1.,
shift_tokens = 0,
sandwich_norm = False,
resi_dual = False,
resi_dual_scale = 1.,
zero_init_branch_output = False,
layer_dropout = 0.,
cross_attn_tokens_dropout = 0.,
disable_abs_pos_emb = None,
**kwargs
# 前向传播函数,接收输入数据并进行模型计算
def forward(
self,
x,
context = None,
mask = None,
context_mask = None,
attn_mask = None,
self_attn_kv_mask = None,
mems = None,
mem_masks = None,
seq_start_pos: Optional[Tensor] = None,
cache: Optional[LayerIntermediates] = None,
cache_age = 1,
return_hiddens = False,
rotary_pos_emb = None
class Encoder(AttentionLayers):
    """AttentionLayers stack with causality forced off."""
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        kwargs.update(causal = False)
        super().__init__(**kwargs)
class Decoder(AttentionLayers):
    """AttentionLayers stack with causality forced on."""
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        kwargs.update(causal = True)
        super().__init__(**kwargs)
class PrefixDecoder(AttentionLayers):
    """Decoder whose first `prefix_attn_len` positions attend bidirectionally; the rest stay causal."""

    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        # causality is implemented manually in forward via attn_mask
        super().__init__(causal = False, **kwargs)

    def forward(
        self,
        x,
        *args,
        attn_mask = None,
        prefix_attn_len = None,  # int or per-batch tensor of prefix lengths
        **kwargs
    ):
        b, n, device = x.shape[0], x.shape[1], x.device
        # upper-triangular (future) positions to be blocked
        causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)

        # start from a standard causal (lower-triangular) allow-mask
        forwarded_mask = ~causal_mask

        if exists(prefix_attn_len):
            if isinstance(prefix_attn_len, int):
                # broadcast a scalar prefix length across the batch
                prefix_attn_len = torch.full((b,), prefix_attn_len, device = device)

            # positions inside the prefix may be attended to bidirectionally
            prefix_mask = torch.arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
            forwarded_mask = forwarded_mask | prefix_mask

        if exists(attn_mask):
            # intersect with any user-supplied attention mask
            forwarded_mask = forwarded_mask & attn_mask

        return super().forward(x, *args, attn_mask = forwarded_mask, **kwargs)
class CrossAttender(AttentionLayers):
    """Attention stack composed solely of cross-attention layers."""

    def __init__(self, **kwargs):
        super().__init__(cross_attend = True, only_cross = True, **kwargs)
class ViTransformerWrapper(nn.Module):
    """Vision-transformer wrapper: patchifies an image, runs an Encoder over
    the patch embeddings, and optionally projects to class logits."""

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        attn_layers: Encoder,
        channels = 3,
        num_classes = None,
        post_emb_norm = False,
        num_register_tokens = 0,
        emb_dropout = 0.
    ):
        super().__init__()
        assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'

        dim = attn_layers.dim
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        # learned absolute positional embedding, one vector per patch
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))

        has_register_tokens = num_register_tokens > 0
        self.has_register_tokens = has_register_tokens

        if has_register_tokens:
            # extra learned tokens appended to the sequence and discarded
            # again after attention
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        # flattened patch -> model dim, with norms on both sides
        self.patch_to_embedding = nn.Sequential(
            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim)
        )

        self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers

        # classification head; identity when no class count is given
        self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()

    def forward(
        self,
        img,
        return_embeddings = False,
        return_logits_and_embeddings = False
    ):
        b, p = img.shape[0], self.patch_size

        # (b, c, h, w) -> (b, num_patches, patch_dim)
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        n = x.shape[1]

        x = x + self.pos_embedding[:, :n]

        x = self.post_emb_norm(x)
        x = self.dropout(x)

        if self.has_register_tokens:
            r = repeat(self.register_tokens, 'n d -> b n d', b = b)
            x, ps = pack((x, r), 'b * d')

        embed = self.attn_layers(x)

        if self.has_register_tokens:
            # drop the register tokens before pooling
            embed, _ = unpack(embed, ps, 'b * d')

        assert at_most_one_of(return_embeddings, return_logits_and_embeddings)

        # NOTE(review): mlp_head is always set (Identity fallback), so this
        # condition effectively reduces to `return_embeddings`
        if not exists(self.mlp_head) or return_embeddings:
            return embed

        # mean-pool over the patch dimension, then classify
        pooled = embed.mean(dim = -2)
        logits = self.mlp_head(pooled)

        if not return_logits_and_embeddings:
            return logits

        return logits, embed
# 初始化函数,设置模型参数
def __init__(
    self,
    *,
    num_tokens,                                   # vocabulary size
    max_seq_len,                                  # maximum sequence length (0 disables absolute pos emb)
    attn_layers: AttentionLayers,                 # the attention stack being wrapped
    embed_num_tokens: Dict[str, int] = dict(),    # named auxiliary categorical embeddings, empty by default
    emb_dim = None,                               # embedding dim, defaults to the model dim
    max_mem_len = 0,                              # transformer-XL style memory length
    shift_mem_down = 0,                           # how many layers to shift memories down
    emb_dropout = 0.,                             # dropout applied to the embedded input
    post_emb_norm = False,                        # normalize right after embedding
    num_memory_tokens = None,                     # number of learned memory tokens
    memory_tokens_interspersed_every = None,      # interval at which memory tokens are interspersed
    tie_embedding = False,                        # tie the output projection to the input embedding
    logits_dim = None,                            # output logits dim, defaults to num_tokens
    use_abs_pos_emb = True,                       # use absolute positional embedding
    scaled_sinu_pos_emb = False,                  # use scaled sinusoidal positional embedding instead
    l2norm_embed = False,                         # l2-normalize the embeddings
    emb_frac_gradient = 1.,                       # fraction of the gradient allowed into the embedding
    attn_z_loss_weight = 1e-4,                    # weight for the attention z-loss
):
    """Token-in, logits-out transformer wrapper around an AttentionLayers stack."""
    super().__init__()

    dim = attn_layers.dim
    # embedding dim falls back to the attention-layer dim
    emb_dim = default(emb_dim, dim)
    self.emb_dim = emb_dim
    self.num_tokens = num_tokens
    self.max_seq_len = max_seq_len
    self.max_mem_len = max_mem_len
    self.shift_mem_down = shift_mem_down

    self.l2norm_embed = l2norm_embed
    self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)

    # absolute pos emb is skipped when disabled, when the attention layers
    # disable it themselves, or when max_seq_len is 0
    no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)

    if no_abs_pos_emb:
        self.pos_emb = always(0)
    elif scaled_sinu_pos_emb:
        self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
    else:
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)

    # optional extra embeddings (e.g. segment / type ids), keyed by name
    self.embeds = None

    if len(embed_num_tokens) > 0:
        self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})

    # fraction of the gradient routed to the embedding table
    self.emb_frac_gradient = emb_frac_gradient

    self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
    self.emb_dropout = nn.Dropout(emb_dropout)

    # project embedding dim to model dim when they differ
    self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
    self.attn_layers = attn_layers

    self.init_()

    logits_dim = default(logits_dim, num_tokens)
    # either a learned output projection, or weight tying with the input embedding
    self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

    # learned memory tokens
    num_memory_tokens = default(num_memory_tokens, 0)
    self.num_memory_tokens = num_memory_tokens
    if num_memory_tokens > 0:
        self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

    self.memory_tokens_interspersed_every = memory_tokens_interspersed_every

    # kv-cache decoding is only possible without memory tokens; caching past
    # max_seq_len is only possible without absolute positional embeddings
    self.can_cache_kv = self.num_memory_tokens == 0
    self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
# 初始化函数,根据是否进行L2归一化初始化权重
def init_(self):
    """Initialize embedding weights: tiny-std normal when l2norm_embed is on
    (so normalization dominates), kaiming-normal otherwise."""
    if not self.l2norm_embed:
        nn.init.kaiming_normal_(self.token_emb.emb.weight)
        return

    nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
    if not isinstance(self.pos_emb, always):
        nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
# 前向传播函数
def forward(
self,
x, # 输入数据
return_embeddings = False, # 是否返回嵌入结果
return_logits_and_embeddings = False, # 是否返回logits和嵌入结果
return_intermediates = False, # 是否返回中间结果
mask = None, # 掩码
return_mems = False, # 是否返回记忆
return_attn = False, # 是否返回注意力
mems = None, # 记忆
mem_masks = None, # 记忆掩码
pos = None, # 位置编码
prepend_embeds = None, # 前置嵌入
prepend_mask = None, # 前置掩码
embed_ids: Dict[str, Tensor] = dict(), # 嵌入ID的字典
sum_embeds = None, # 嵌入求和
return_attn_z_loss = False, # 是否返回注意力z损失
attn_z_loss_weight = 1e-4, # 注意力z损失的权重
seq_start_pos = None, # 序列起始位置
cache: Optional[LayerIntermediates] = None, # 缓存
**kwargs # 其他参数
class XTransformer(nn.Module):
    """Full encoder-decoder transformer; `enc_`/`dec_`-prefixed kwargs are
    routed to the encoder and decoder respectively."""

    def __init__(
        self,
        *,
        dim,
        tie_token_emb = False,
        ignore_index = -100,
        pad_value = 0,
        cross_attn_tokens_dropout = 0.,
        **kwargs
    ):
        super().__init__()

        # split kwargs by prefix: enc_* -> encoder, dec_* -> decoder
        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)

        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'

        # pop wrapper-level kwargs out so that only attention-layer kwargs remain
        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
        enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
        enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
        enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
        enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)

        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
        dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
        dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
        dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)

        # probability of dropping encoder output tokens during training
        self.cross_attn_tokens_dropout = cross_attn_tokens_dropout

        self.encoder = TransformerWrapper(
            **enc_transformer_kwargs,
            attn_layers = Encoder(dim = dim, **enc_kwargs)
        )

        self.decoder = TransformerWrapper(
            **dec_transformer_kwargs,
            attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
        )

        if tie_token_emb:
            # share the token embedding between encoder and decoder
            self.decoder.token_emb = self.encoder.token_emb

        # wrap the decoder for autoregressive training / sampling
        self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
        """Encode `seq_in`, then autoregressively decode `seq_len` tokens."""
        encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
        return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)

    def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
        """Teacher-forced training pass; returns the decoder output (loss/logits)."""

        enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)

        if exists(src_prepend_embeds) and exists(mask):
            # extend the mask to cover the prepended embeddings
            mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)

        if self.training and self.cross_attn_tokens_dropout > 0:
            # structured dropout of encoder tokens (and their mask entries)
            enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)

        out = self.decoder(tgt, context = enc, context_mask = mask)
        return out
.\lucidrains\x-transformers\x_transformers\__init__.py
# 从 x_transformers.x_transformers 模块中导入以下类
from x_transformers.x_transformers import (
XTransformer, # XTransformer 类,用于定义 Transformer 模型
Encoder, # Encoder 类,用于定义编码器
Decoder, # Decoder 类,用于定义解码器
PrefixDecoder, # PrefixDecoder 类,用于定义前缀解码器
CrossAttender, # CrossAttender 类,用于定义交叉注意力机制
Attention, # Attention 类,用于定义注意力机制
TransformerWrapper, # TransformerWrapper 类,用于包装 Transformer 模型
ViTransformerWrapper # ViTransformerWrapper 类,用于包装 Vision Transformer 模型
)
# 从 x_transformers.autoregressive_wrapper 模块中导入 AutoregressiveWrapper 类
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
# 从 x_transformers.nonautoregressive_wrapper 模块中导入 NonAutoregressiveWrapper 类
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
# 从 x_transformers.continuous 模块中导入以下类
from x_transformers.continuous import (
ContinuousTransformerWrapper, # ContinuousTransformerWrapper 类,用于包装连续 Transformer 模型
ContinuousAutoregressiveWrapper # ContinuousAutoregressiveWrapper 类,用于包装连续自回归模型
)
# 从 x_transformers.xval 模块中导入以下类
from x_transformers.xval import (
XValTransformerWrapper, # XValTransformerWrapper — wraps a transformer with continuous numerical (xVal) token handling; not cross-validation
XValAutoregressiveWrapper # XValAutoregressiveWrapper 类,用于包装交叉验证自回归模型
)
# 从 x_transformers.xl_autoregressive_wrapper 模块中导入 XLAutoregressiveWrapper 类
from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
# 从 x_transformers.dpo 模块中导入 DPO 类
from x_transformers.dpo import (
DPO # DPO — Direct Preference Optimization
)
x-unet
Implementation of a U-net complete with efficient attention as well as the latest research findings
Install
$ pip install x-unet
Usage
import torch
from x_unet import XUnet
unet = XUnet(
dim = 64,
channels = 3,
dim_mults = (1, 2, 4, 8),
nested_unet_depths = (7, 4, 2, 1), # nested unet depths, from unet-squared paper
consolidate_upsample_fmaps = True, # whether to consolidate outputs from all upsample blocks, used in unet-squared paper
)
img = torch.randn(1, 3, 256, 256)
out = unet(img) # (1, 3, 256, 256)
For 3d (video or CT / MRI scans)
import torch
from x_unet import XUnet
unet = XUnet(
dim = 64,
frame_kernel_size = 3, # set this to greater than 1
channels = 3,
dim_mults = (1, 2, 4, 8),
nested_unet_depths = (5, 4, 2, 1), # nested unet depths, from unet-squared paper
consolidate_upsample_fmaps = True, # whether to consolidate outputs from all upsample blocks, used in unet-squared paper
weight_standardize = True
)
video = torch.randn(1, 3, 10, 128, 128) # (batch, channels, frames, height, width)
out = unet(video) # (1, 3, 10, 128, 128)
Todo
Citations
@article{Ronneberger2015UNetCN,
title = {U-Net: Convolutional Networks for Biomedical Image Segmentation},
author = {Olaf Ronneberger and Philipp Fischer and Thomas Brox},
journal = {ArXiv},
year = {2015},
volume = {abs/1505.04597}
}
@article{Qin2020U2NetGD,
title = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection},
author = {Xuebin Qin and Zichen Vincent Zhang and Chenyang Huang and Masood Dehghan and Osmar R Zaiane and Martin J{\"a}gersand},
journal = {ArXiv},
year = {2020},
volume = {abs/2005.09007}
}
@inproceedings{Henry2020QueryKeyNF,
title = {Query-Key Normalization for Transformers},
author = {Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen},
booktitle = {FINDINGS},
year = {2020}
}
@article{Qiao2019WeightS,
title = {Weight Standardization},
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
journal = {ArXiv},
year = {2019},
volume = {abs/1903.10520}
}
@article{Shleifer2021NormFormerIT,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Sam Shleifer and Jason Weston and Myle Ott},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.09456}
}
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
@inproceedings{Woo2023ConvNeXtVC,
title = {ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
author = {Sanghyun Woo and Shoubhik Debnath and Ronghang Hu and Xinlei Chen and Zhuang Liu and In-So Kweon and Saining Xie},
year = {2023}
}
.\lucidrains\x-unet\setup.py
# Packaging helpers
from setuptools import setup, find_packages

# Package metadata for x-unet
setup(
  name = 'x-unet',                                  # package name on PyPI
  packages = find_packages(exclude=[]),             # include every package
  version = '0.3.1',
  license='MIT',
  description = 'X-Unet',
  long_description_content_type = 'text/markdown',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/x-unet',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'biomedical segmentation',
    'medical deep learning',
    'unets',
  ],
  install_requires=[
    'beartype',
    'einops>=0.4',
    'torch>=1.6',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\x-unet\x_unet\x_unet.py
# 导入必要的库
from functools import partial
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
# 导入 einops 库中的函数和类
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# 导入 beartype 库中的函数和类型
from beartype import beartype
from beartype.typing import Tuple, Union, Optional
# 辅助函数
# 检查值是否存在
def exists(val):
    """Return True when *val* is not None."""
    return val is not None
# 返回值或默认值
def default(val, d):
    """Return *val* when it is set (not None), otherwise the fallback *d*."""
    if val is None:
        return d
    return val
# 检查一个数是否为2的幂
def is_power_two(n):
    """Whether *n* is an exact power of two (its log2 is integral)."""
    log = math.log2(n)
    return log == int(log)
# 检查一个数是否可以被另一个数整除
def divisible_by(num, denom):
    """True when *num* divides evenly by *denom*."""
    remainder = num % denom
    return remainder == 0
# 将值转换为元组
def cast_tuple(val, length = None):
    """Coerce *val* to a tuple; scalars are repeated *length* times (once by
    default), and an explicit *length* is asserted on the result."""
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        repeats = length if length is not None else 1
        output = (val,) * repeats

    if length is not None:
        assert len(output) == length

    return output
# 辅助类
# 上采样函数
def Upsample(dim, dim_out):
    """2x spatial (h, w) transposed-conv upsample; the frame dim is untouched."""
    return nn.ConvTranspose3d(dim, dim_out, (1, 4, 4), (1, 2, 2), (0, 1, 1))
# 下采样函数
def Downsample(dim, dim_out):
    """2x spatial downsample via space-to-depth followed by a 1x1 conv."""
    return nn.Sequential(
        Rearrange('b c f (h s1) (w s2) -> b (c s1 s2) f h w', s1 = 2, s2 = 2),
        nn.Conv3d(dim * 4, dim_out, 1)
    )
# 标准化
# 残差连接
class Residual(nn.Module):
    """Wraps a callable with a skip connection: x -> fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        out = self.fn(x)
        return out + x
# 层归一化
class LayerNorm(nn.Module):
    """Channel-wise layer norm over dim=1 of a (b, c, f, h, w) tensor, with a
    learned per-channel gain and no bias."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        # looser epsilon outside of float32 (e.g. half precision)
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        mean = torch.mean(x, dim = 1, keepdim = True)
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        return (x - mean) / (var + eps).sqrt() * self.gamma
# 权重标准化卷积
class WeightStandardizedConv3d(nn.Conv3d):
    """Conv3d whose weights are standardized (zero mean, unit variance per
    output channel) on every forward pass (weight standardization)."""

    def forward(self, x):
        # looser epsilon outside of float32
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        w = self.weight
        mean = w.mean(dim = (1, 2, 3, 4), keepdim = True)
        var = w.var(dim = (1, 2, 3, 4), unbiased = False, keepdim = True)
        w = (w - mean) * (var + eps).rsqrt()

        return F.conv3d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
# ResNet 块
# 块类
class Block(nn.Module):
    """Basic unit: conv -> group norm -> SiLU."""

    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        weight_standardize = False,
        frame_kernel_size = 1
    ):
        super().__init__()
        pad_kwargs = partial(kernel_and_same_pad, frame_kernel_size)
        conv_klass = WeightStandardizedConv3d if weight_standardize else nn.Conv3d

        self.proj = conv_klass(dim, dim_out, **pad_kwargs(3, 3))
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.norm(self.proj(x)))
# ResNet 块类
class ResnetBlock(nn.Module):
    """Two blocks (the second optionally a nested residual unet, per the
    unet-squared paper) plus a 1x1-conv residual path."""

    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        frame_kernel_size = 1,
        nested_unet_depth = 0,
        nested_unet_dim = 32,
        weight_standardize = False
    ):
        super().__init__()
        self.block1 = Block(dim, dim_out, groups = groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)

        if nested_unet_depth > 0:
            # replace the second block with a small nested unet
            self.block2 = NestedResidualUnet(dim_out, depth = nested_unet_depth, M = nested_unet_dim, frame_kernel_size = frame_kernel_size, weight_standardize = weight_standardize, add_residual = True)
        else:
            self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)

        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        hidden = self.block2(self.block1(x))
        return hidden + self.res_conv(x)
# ConvNeXT 2
# 全局响应归一化
class GRN(nn.Module):
    """Global response normalization, from the updated ConvNeXt (V2) paper."""

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        # both learned parameters start at zero, so the layer is an exact
        # identity at initialization (output reduces to x)
        self.gamma = nn.Parameter(torch.zeros(dim, 1, 1, 1))
        self.bias = nn.Parameter(torch.zeros(dim, 1, 1, 1))

    def forward(self, x):
        # per-channel spatial L2 norm, normalized by its mean across the
        # (size-1) last dim, clamped for stability
        l2 = x.norm(p = 2, dim = (2, 3, 4), keepdim = True)
        scale = l2 / l2.mean(dim = -1, keepdim = True).clamp(min = self.eps)
        return x * scale * self.gamma + self.bias + x
# 定义一个卷积块类,用于构建下一个卷积块
class ConvNextBlock(nn.Module):
    """ConvNeXt-style block: depthwise conv, then a norm/pointwise-GELU/GRN
    net, an optional nested unet, plus a 1x1-conv residual path."""

    def __init__(
        self,
        dim,
        dim_out,
        *,
        mult = 2,
        frame_kernel_size = 1,
        nested_unet_depth = 0,
        nested_unet_dim = 32
    ):
        super().__init__()
        pad_kwargs = partial(kernel_and_same_pad, frame_kernel_size)

        # depthwise conv (one group per channel)
        self.ds_conv = nn.Conv3d(dim, dim, **pad_kwargs(7, 7), groups = dim)

        inner_dim = dim_out * mult
        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Conv3d(dim, inner_dim, **pad_kwargs(3, 3), groups = dim_out),
            nn.GELU(),
            GRN(inner_dim),
            nn.Conv3d(inner_dim, dim_out, **pad_kwargs(3, 3), groups = dim_out)
        )

        # optional nested residual unet (unet-squared)
        self.nested_unet = NestedResidualUnet(dim_out, depth = nested_unet_depth, M = nested_unet_dim, add_residual = True) if nested_unet_depth > 0 else nn.Identity()

        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        hidden = self.ds_conv(x)
        hidden = self.net(hidden)
        hidden = self.nested_unet(hidden)
        return hidden + self.res_conv(x)
# 前馈神经网络
def FeedForward(dim, mult = 4.):
    """Pointwise-conv feedforward with pre-norm, GELU and an inner norm
    (normformer-style), wrapped in a residual."""
    inner_dim = int(dim * mult)
    return Residual(nn.Sequential(
        LayerNorm(dim),
        nn.Conv3d(dim, inner_dim, 1, bias = False),
        nn.GELU(),
        LayerNorm(inner_dim), # properly credit assign normformer
        nn.Conv3d(inner_dim, dim, 1, bias = False)
    ))
# 注意力机制
class Attention(nn.Module):
    """Multi-head self-attention over the flattened (f, h, w) volume, with a
    residual connection around the whole layer."""

    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 64
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head

        self.norm = LayerNorm(dim)
        self.to_qkv = nn.Conv3d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv3d(inner_dim, dim, 1, bias = False)

    def forward(self, x):
        f, h, w = x.shape[-3:]
        skip = x.clone()

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
        # flatten the volume into a token axis per head
        q, k, v = map(lambda t: rearrange(t, 'b (h c) ... -> b h (...) c', h = self.heads), (q, k, v))

        q = q * self.scale
        scores = einsum('b h i d, b h j d -> b h i j', q, k)
        attn = scores.softmax(dim = -1)

        agg = einsum('b h i j, b h j d -> b h i d', attn, v)
        agg = rearrange(agg, 'b h (f x y) d -> b (h d) f x y', f = f, x = h, y = w)
        return self.to_out(agg) + skip
# Transformer 块
class TransformerBlock(nn.Module):
    """A stack of `depth` (attention -> feedforward) layers.

    Fixes the original, which accepted a required `depth` keyword but
    silently ignored it, always applying exactly one attn + ff pair.
    `depth = 1` reproduces the old behavior.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, **kwargs),  # residual lives inside Attention
                FeedForward(dim)           # residual lives inside FeedForward
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)
        return x
# 特征图整合器
class FeatureMapConsolidator(nn.Module):
    """Projects auxiliary feature maps to target channel counts, resizes them
    to x's spatial size, and concatenates them to x along the channel dim.

    `resize_fmap_before` picks whether resizing happens before or after the
    per-fmap conv. Fixes the original forward, which resized twice when the
    flag was True and never when it was False (breaking the final concat on
    mismatched spatial sizes).
    """

    def __init__(
        self,
        dim,
        *,
        dim_ins = tuple(),
        dim_outs = tuple(),
        resize_fmap_before = True,
        conv_block_fn = None
    ):
        super().__init__()
        assert len(dim_ins) == len(dim_outs)
        self.needs_consolidating = len(dim_ins) > 0

        block_fn = default(conv_block_fn, Block)

        # one projection conv per incoming feature map
        self.fmap_convs = nn.ModuleList([block_fn(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        self.resize_fmap_before = resize_fmap_before

        self.final_dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    def resize_fmaps(self, fmaps, height, width):
        # interpolate only (h, w); the frame dim is preserved
        return [F.interpolate(fmap, (fmap.shape[-3], height, width)) for fmap in fmaps]

    def forward(self, x, fmaps = None):
        target_height, target_width = x.shape[-2:]

        fmaps = default(fmaps, tuple())

        if not self.needs_consolidating:
            return x

        if self.resize_fmap_before:
            fmaps = self.resize_fmaps(fmaps, target_height, target_width)

        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]

        if not self.resize_fmap_before:
            # resize after projection instead
            outs = self.resize_fmaps(outs, target_height, target_width)

        return torch.cat((x, *outs), dim = 1)
# 定义一个函数,返回一个类型为 type 或者包含 type 的元组
def MaybeTuple(type):
    """Annotation helper: either a bare `type` or a tuple of them."""
    return Union[type, Tuple[type, ...]]
# 根据卷积核大小计算 padding 大小
def kernel_and_same_pad(*kernel_size):
    """Conv kwargs producing 'same' padding for odd kernel sizes."""
    paddings = tuple(k // 2 for k in kernel_size)
    return dict(kernel_size = kernel_size, padding = paddings)
# 定义 XUnet 类
class XUnet(nn.Module):
# 初始化函数
@beartype
def __init__(
self,
dim,
init_dim = None,
out_dim = None,
frame_kernel_size = 1,
dim_mults: MaybeTuple(int) = (1, 2, 4, 8),
num_blocks_per_stage: MaybeTuple(int) = (2, 2, 2, 2),
num_self_attn_per_stage: MaybeTuple(int) = (0, 0, 0, 1),
nested_unet_depths: MaybeTuple(int) = (0, 0, 0, 0),
nested_unet_dim = 32,
channels = 3,
use_convnext = False,
resnet_groups = 8,
consolidate_upsample_fmaps = True,
skip_scale = 2 ** -0.5,
weight_standardize = False,
attn_heads: MaybeTuple(int) = 8,
attn_dim_head: MaybeTuple(int) = 32
def forward(self, x):
    """Run the unet; accepts (b, c, h, w) images or (b, c, f, h, w) videos,
    depending on how the unet was configured."""
    is_image = x.ndim == 4

    # validate that the input dimensionality matches the configuration
    assert not (is_image and not self.train_as_images), 'you specified a frame kernel size for the convolutions in this unet, but you are passing in images'
    assert not (not is_image and self.train_as_images), 'you specified no frame kernel size dimension, yet you are passing in a video. fold the frame dimension into the batch'

    # treat an image as a single-frame video
    if is_image:
        x = rearrange(x, 'b c h w -> b c 1 h w')

    x = self.init_conv(x)

    # keep a copy for the final residual concat
    r = x.clone()

    down_hiddens = []
    up_hiddens = []

    # downsampling path, collecting skip connections
    for init_block, blocks, attn_blocks, downsample in self.downs:
        x = init_block(x)

        for block in blocks:
            x = block(x)

        for attn_block in attn_blocks:
            x = attn_block(x)

        down_hiddens.append(x)
        x = downsample(x)

    # bottleneck with residual attention
    x = self.mid(x)
    x = self.mid_attn(x) + x
    x = self.mid_after(x)

    up_hiddens.append(x)
    x = self.mid_upsample(x)

    # upsampling path, consuming the (scaled) skips in reverse order
    for init_block, blocks, attn_blocks, upsample in self.ups:
        x = torch.cat((x, down_hiddens.pop() * self.skip_scale), dim=1)

        x = init_block(x)

        for block in blocks:
            x = block(x)

        for attn_block in attn_blocks:
            x = attn_block(x)

        up_hiddens.insert(0, x)
        x = upsample(x)

    # optionally merge feature maps from every upsample stage (unet-squared)
    x = self.consolidator(x, up_hiddens)

    # final residual from right after the initial conv
    x = torch.cat((x, r), dim = 1)

    out = self.final_conv(x)

    # squeeze the frame dim back out for image inputs
    if is_image:
        out = rearrange(out, 'b c 1 h w -> b c h w')

    return out
# 定义 PixelShuffleUpsample 类
class PixelShuffleUpsample(nn.Module):
    """Spatial (h, w) upsample via 1x1 conv + pixel shuffle, with the conv
    initialized so the shuffle groups start identical (checkerboard-free)."""

    def __init__(
        self,
        dim,
        dim_out = None,
        scale_factor = 2
    ):
        super().__init__()
        self.scale_squared = scale_factor ** 2
        dim_out = default(dim_out, dim)

        conv = nn.Conv3d(dim, dim_out * self.scale_squared, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            Rearrange('b (c r s) f h w -> b c f (h r) (w s)', r = scale_factor, s = scale_factor)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        # initialize one sub-kernel, then replicate it across the shuffle
        # groups so every upsampled pixel starts out identical
        o, i, *rest_dims = conv.weight.shape
        seed_weight = torch.empty(o // self.scale_squared, i, *rest_dims)
        nn.init.kaiming_uniform_(seed_weight)
        seed_weight = repeat(seed_weight, 'o ... -> (o r) ...', r = self.scale_squared)

        conv.weight.data.copy_(seed_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
# 定义 NestedResidualUnet 类
class NestedResidualUnet(nn.Module):
    """Small residual unet used as a drop-in block inside the main unet
    (the unet-squared idea)."""

    def __init__(
        self,
        dim,
        *,
        depth,
        M = 32,                    # hidden channel width inside the nested unet
        frame_kernel_size = 1,
        add_residual = False,      # add the input back onto the output
        groups = 4,
        skip_scale = 2 ** -0.5,
        weight_standardize = False
    ):
        super().__init__()

        self.depth = depth
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        # choose the conv flavor once
        conv = WeightStandardizedConv3d if weight_standardize else nn.Conv3d

        for ind in range(depth):
            is_first = ind == 0
            dim_in = dim if is_first else M

            # strided conv halves (h, w); the frame dim is untouched
            down = nn.Sequential(
                conv(dim_in, M, (1, 4, 4), stride = (1, 2, 2), padding = (0, 1, 1)),
                nn.GroupNorm(groups, M),
                nn.SiLU()
            )

            self.downs.append(down)

            # the upsample consumes the concat of (x, skip), hence 2 * M input channels
            up = nn.Sequential(
                PixelShuffleUpsample(2 * M, dim_in),
                nn.GroupNorm(groups, dim_in),
                nn.SiLU()
            )

            self.ups.append(up)

        self.mid = nn.Sequential(
            conv(M, M, **kernel_and_same_pad(frame_kernel_size, 3, 3)),
            nn.GroupNorm(groups, M),
            nn.SiLU()
        )

        self.skip_scale = skip_scale
        self.add_residual = add_residual

    def forward(self, x, residual = None):
        """Run the nested unet; (h, w) must be divisible by 2 ** depth."""
        is_video = x.ndim == 5  # NOTE(review): computed but never used below

        if self.add_residual:
            residual = default(residual, x.clone())

        *_, h, w = x.shape

        layers = len(self.ups)

        # NOTE(review): both assertions check the same divisibility
        # requirement (layers == depth), just with different messages
        for dim_name, size in (('height', h), ('width', w)):
            assert divisible_by(size, 2 ** layers), f'{dim_name} dimension {size} must be divisible by {2 ** layers} ({layers} layers in nested unet)'
            assert (size % (2 ** self.depth)) == 0, f'the unet has too much depth for the image {dim_name} ({size}) being passed in'

        # hiddens

        hiddens = []

        # unet

        # downsampling path, collecting skip connections
        for down in self.downs:
            x = down(x)
            hiddens.append(x.clone().contiguous())

        x = self.mid(x)

        # upsampling path, consuming scaled skips in reverse order
        for up in reversed(self.ups):
            x = torch.cat((x, hiddens.pop() * self.skip_scale), dim = 1)
            x = up(x)

        # optional residual with a final activation
        if self.add_residual:
            x = x + residual
            x = F.silu(x)

        return x
.\lucidrains\x-unet\x_unet\__init__.py
# 从 x_unet 模块中导入 XUnet 和 NestedResidualUnet 类
from x_unet.x_unet import XUnet, NestedResidualUnet
Zorro - Pytorch
Implementation of Zorro, Masked Multimodal Transformer, in Pytorch. This is a Deepmind work that claims a special masking strategy within a transformer helps them achieve SOTA on a few multimodal benchmarks.
Appreciation
- Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research
Install
$ pip install zorro-pytorch
Usage
import torch
from zorro_pytorch import Zorro, TokenTypes as T
model = Zorro(
dim = 512, # model dimensions
depth = 6, # depth
dim_head = 64, # attention dimension heads
heads = 8, # attention heads
ff_mult = 4, # feedforward multiple
num_fusion_tokens = 16, # number of fusion tokens
audio_patch_size = 16, # audio patch size, can also be Tuple[int, int]
video_patch_size = 16, # video patch size, can also be Tuple[int, int]
video_temporal_patch_size = 2, # video temporal patch size
video_channels = 3, # video channels
return_token_types = (
T.AUDIO,
T.AUDIO,
T.FUSION,
T.GLOBAL,
T.VIDEO,
T.VIDEO,
T.VIDEO,
) # say you want to return 2 tokens for audio, 1 token for fusion, 3 for video - for whatever self-supervised learning, supervised learning, etc etc
)
video = torch.randn(2, 3, 8, 32, 32) # (batch, channels, time, height, width)
audio = torch.randn(2, 1024 * 10) # (batch, time)
return_tokens = model(audio = audio, video = video) # (2, 6, 512) - all 6 tokens as indicated above are returned
# say you only want 1 audio and 1 video token, for contrastive learning
audio_token, video_token = model(audio = audio, video = video, return_token_indices = (0, 3)).unbind(dim = -2) # (2, 512), (2, 512)
Citations
@inproceedings{Recasens2023ZorroTM,
title = {Zorro: the masked multimodal transformer},
author = {Adri{\`a} Recasens and Jason Lin and Jo{\~a}o Carreira and Drew Jaegle and Luyu Wang and Jean-Baptiste Alayrac and Pauline Luc and Antoine Miech and Lucas Smaira and Ross Hemsley and Andrew Zisserman},
year = {2023}
}
.\lucidrains\zorro-pytorch\setup.py
# Packaging helpers
from setuptools import setup, find_packages

# Package metadata for zorro-pytorch
setup(
  # package name on PyPI
  name = 'zorro-pytorch',
  # include every package, excluding none
  packages = find_packages(exclude=[]),
  version = '0.1.1',
  license='MIT',
  description = 'Zorro - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/zorro-pytorch',
  # search keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'multimodal fusion'
  ],
  # runtime dependencies
  install_requires=[
    'beartype',
    'einops>=0.4',
    'torch>=1.6',
    'torchaudio'
  ],
  # trove classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\zorro-pytorch\zorro_pytorch\zorro_pytorch.py
# 导入所需的模块和类
from enum import Enum
import functools
from functools import wraps
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
from beartype import beartype
from beartype.typing import Tuple, Optional, Union
from torchaudio.transforms import Spectrogram
# 定义枚举类型 TokenTypes,包含音频、视频、融合和全局四种类型
class TokenTypes(Enum):
    # modality tags for each token in the multimodal sequence
    AUDIO = 0
    VIDEO = 1
    FUSION = 2   # fusion (bottleneck) tokens
    GLOBAL = 3   # global summary tokens
# 定义一些通用的函数
# 判断变量是否存在
def exists(val):
    """Return True when *val* is not None."""
    return val is not None
# 返回参数列表中第一个存在的参数,如果都不存在则返回 None
def default(*args):
    """Return the first argument that is not None, or None when all are."""
    for candidate in args:
        if candidate is not None:
            return candidate
    return None
# 返回小于等于 n 的最接近的 divisor 的倍数
def round_down_nearest_multiple(n, divisor):
    """Largest multiple of *divisor* that is <= *n*."""
    quotient = n // divisor
    return quotient * divisor
# 将输入转换为元组,如果输入不是元组则返回 (t, t)
def pair(t):
    """Broadcast a scalar to a 2-tuple; tuples pass through unchanged."""
    if isinstance(t, tuple):
        return t
    return (t, t)
# 对可迭代对象进行累积乘法
def cum_mul(it):
    """Product of all elements of *it*; 1 for an empty iterable."""
    total = 1
    for val in it:
        total = total * val
    return total
# 判断 numer 是否能被 denom 整除
def divisible_by(numer, denom):
    """True when *numer* divides evenly by *denom*."""
    remainder = numer % denom
    return remainder == 0
# 装饰器
# 保证函数只调用一次的装饰器
def once(fn):
    """Decorator: *fn* runs only on the first call; later calls return None."""
    has_run = False

    @wraps(fn)
    def wrapper(x):
        nonlocal has_run
        if has_run:
            return
        has_run = True
        return fn(x)

    return wrapper
# print that fires only on the first call (subsequent calls are no-ops)
print_once = once(print)
# 无偏置的 Layernorm 类
class LayerNorm(nn.Module):
    """LayerNorm without a learned bias (beta is a frozen zero buffer)."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        # not a Parameter: stays zero, but moves with .to(device)/.half()
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
# GEGLU 激活函数
class GEGLU(nn.Module):
    """Gated GELU: split the last dim in half, gate one half by GELU of the other."""

    def forward(self, x):
        value, gate = x.chunk(2, dim = -1)
        return value * F.gelu(gate)
# feedforward block

def FeedForward(dim, mult = 4):
    """Pre-norm GEGLU feedforward; hidden width is scaled by 2/3 to offset the GEGLU split."""
    hidden = int(dim * mult * 2 / 3)

    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden * 2, bias = False),
        GEGLU(),
        nn.Linear(hidden, dim, bias = False)
    )
# attention

class Attention(nn.Module):
    """Multi-head attention with pre-norm, optional cross-attention context, and an optional boolean mask.

    `attn_mask` entries that are False are excluded from attention (filled with
    the most negative representable value before the softmax).
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        # queries come from `x`; keys/values come from the (possibly different) context
        self.to_q = nn.Linear(dim, hidden_dim, bias = False)
        self.to_kv = nn.Linear(dim, hidden_dim * 2, bias = False)
        self.to_out = nn.Linear(hidden_dim, dim, bias = False)

    def forward(
        self,
        x,
        context = None,
        attn_mask = None
    ):
        x = self.norm(x)

        # self-attention unless an explicit context is supplied (cross-attention)
        source = x if context is None else context

        queries = self.to_q(x)
        keys, values = self.to_kv(source).chunk(2, dim = -1)

        split_heads = lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads)
        q, k, v = map(split_heads, (queries, keys, values))

        # scaled dot-product similarity
        sim = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)

        if attn_mask is not None:
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# main class

class Zorro(nn.Module):
    """Zorro multimodal transformer over audio + video with modality-isolated attention.

    Audio and video tokens may only attend within their own modality, while a set
    of learned fusion tokens attends across everything (the "zorro" mask). A final
    attention-pooling step extracts one output token per entry in
    `return_token_types`, where each pooled query may only look at its own
    modality's tokens (GLOBAL queries may look at all tokens).
    """

    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        num_fusion_tokens = 16,
        audio_patch_size: Union[int, Tuple[int, int]] = 16,
        video_patch_size: Union[int, Tuple[int, int]] = 16,
        video_temporal_patch_size = 2,
        video_channels = 3,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
        spec_hop_length = None,
        spec_pad = 0,
        spec_center = True,
        spec_pad_mode = 'reflect',
        # NOTE(review): the three spec_aug_* parameters are accepted but never used
        # in this class body — presumably intended for SpecAugment; confirm upstream
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80,
        return_token_types: Tuple[TokenTypes] = (TokenTypes.AUDIO, TokenTypes.VIDEO, TokenTypes.FUSION)
    ):
        super().__init__()

        # pooled outputs: one learned query token per requested return type
        self.max_return_tokens = len(return_token_types)
        self.return_token_types = return_token_types

        return_token_types_tensor = torch.tensor(list(map(lambda t: t.value, return_token_types)))
        self.register_buffer('return_token_types_tensor', return_token_types_tensor, persistent = False)

        self.return_tokens = nn.Parameter(torch.randn(self.max_return_tokens, dim))
        self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

        # audio input: raw waveform -> spectrogram -> non-overlapping 2d patches
        self.audio_patch_size = audio_patch_height, audio_patch_width = pair(audio_patch_size)

        self.spec = Spectrogram(
            n_fft = spec_n_fft,
            power = spec_power,
            win_length = spec_win_length,
            hop_length = spec_hop_length,
            pad = spec_pad,
            center = spec_center,
            pad_mode = spec_pad_mode
        )

        audio_input_dim = cum_mul(self.audio_patch_size)

        self.audio_to_tokens = nn.Sequential(
            Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1 = audio_patch_height, p2 = audio_patch_width),
            nn.LayerNorm(audio_input_dim),
            nn.Linear(audio_input_dim, dim),
            nn.LayerNorm(dim)
        )

        # video input: spatiotemporal patches of size (time, height, width)
        self.video_patch_size = (video_temporal_patch_size, *pair(video_patch_size))

        video_input_dim = cum_mul(self.video_patch_size) * video_channels
        video_patch_time, video_patch_height, video_patch_width = self.video_patch_size

        self.video_to_tokens = nn.Sequential(
            Rearrange('b c (t p1) (h p2) (w p3) -> b t h w (c p1 p2 p3)', p1 = video_patch_time, p2 = video_patch_height, p3 = video_patch_width),
            nn.LayerNorm(video_input_dim),
            nn.Linear(video_input_dim, dim),
            nn.LayerNorm(dim)
        )

        # learned fusion tokens that are allowed to attend across modalities
        self.fusion_tokens = nn.Parameter(torch.randn(num_fusion_tokens, dim))

        # transformer trunk: depth x (attention, feedforward)
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.norm = LayerNorm(dim)

    def forward(
        self,
        *,
        audio,
        video,
        return_token_indices: Optional[Tuple[int]] = None
    ):
        """Encode an (audio, video) pair and return the pooled per-modality tokens.

        Args:
            audio: raw waveform tensor, batch first; converted to a spectrogram internally.
            video: video tensor whose trailing (time, height, width) dims must be
                divisible by `self.video_patch_size`.
            return_token_indices: optional unique subset of the configured return
                tokens to pool (indices into `return_token_types`).

        Returns:
            Tensor of shape (batch, num_return_tokens, dim).
        """
        batch, device = audio.shape[0], audio.device

        # validate the video can be cut into whole patches
        assert all([divisible_by(video_dim, patch_dim) for patch_dim, video_dim in zip(self.video_patch_size, tuple(video.shape[-3:]))]), f'video shape {video.shape[-3:]} needs to be divisible by {self.video_patch_size}'

        # if the audio's 2d spectrogram is not a multiple of the patch size, crop it down
        audio = self.spec(audio)

        height, width = audio.shape[-2:]
        patch_height, patch_width = self.audio_patch_size

        rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))

        if (height, width) != (rounded_height, rounded_width):  # just warn (once) until this is fixed upstream
            print_once(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

        audio = audio[..., :rounded_height, :rounded_width]

        # patchify into tokens
        audio_tokens = self.audio_to_tokens(audio)
        video_tokens = self.video_to_tokens(video)
        fusion_tokens = repeat(self.fusion_tokens, 'n d -> b n d', b = batch)

        # flatten and concatenate all tokens in the fixed order [audio | fusion | video]
        audio_tokens, fusion_tokens, video_tokens = map(lambda t: rearrange(t, 'b ... d -> b (...) d'), (audio_tokens, fusion_tokens, video_tokens))

        tokens, ps = pack((
            audio_tokens,
            fusion_tokens,
            video_tokens
        ), 'b * d')

        # build the zorro attention mask from per-token modality labels
        token_types = torch.tensor(list((
            *((TokenTypes.AUDIO.value,) * audio_tokens.shape[-2]),
            *((TokenTypes.FUSION.value,) * fusion_tokens.shape[-2]),
            *((TokenTypes.VIDEO.value,) * video_tokens.shape[-2]),
        )), device = device, dtype = torch.long)

        token_types_attend_from = rearrange(token_types, 'i -> i 1')
        token_types_attend_to = rearrange(token_types, 'j -> 1 j')

        # every modality, including fusion, can attend to itself
        zorro_mask = token_types_attend_from == token_types_attend_to

        # fusion tokens can additionally attend to everything
        zorro_mask = zorro_mask | (token_types_attend_from == TokenTypes.FUSION.value)

        # attention and feedforward trunk (residual connections)
        for attn, ff in self.layers:
            tokens = attn(tokens, attn_mask = zorro_mask) + tokens
            tokens = ff(tokens) + tokens

        tokens = self.norm(tokens)

        # final attention pooling - each modality pool token can only attend to its own tokens
        return_tokens = self.return_tokens
        return_token_types_tensor = self.return_token_types_tensor

        if exists(return_token_indices):
            assert len(set(return_token_indices)) == len(return_token_indices), 'all indices must be unique'
            assert all([indice < self.max_return_tokens for indice in return_token_indices]), 'indices must range from 0 to max_num_return_tokens - 1'

            return_token_indices = torch.tensor(return_token_indices, dtype = torch.long, device = device)

            return_token_types_tensor = return_token_types_tensor[return_token_indices]
            return_tokens = return_tokens[return_token_indices]

        return_tokens = repeat(return_tokens, 'n d -> b n d', b = batch)
        pool_mask = rearrange(return_token_types_tensor, 'i -> i 1') == token_types_attend_to

        # global queries can attend to all tokens.
        # BUGFIX: the equality must be parenthesized — `|` binds tighter than `==` in
        # Python, so the unparenthesized original computed `(pool_mask | types) == GLOBAL`,
        # which zeroed out every non-global pool row (compare the parenthesized form used
        # for `zorro_mask` above).
        pool_mask = pool_mask | (rearrange(return_token_types_tensor, 'i -> i 1') == torch.ones_like(token_types_attend_to, dtype = torch.long) * TokenTypes.GLOBAL.value)

        pooled_tokens = self.attn_pool(return_tokens, context = tokens, attn_mask = pool_mask) + return_tokens

        return pooled_tokens
.\lucidrains\zorro-pytorch\zorro_pytorch\__init__.py
# Re-export the Zorro class and the TokenTypes enum from the zorro_pytorch.zorro_pytorch module
from zorro_pytorch.zorro_pytorch import Zorro, TokenTypes