Lucidrains Series Source Code Walkthrough (Part 14)
Electra - Pytorch
A simple working wrapper for fast pretraining of language models, as detailed in the ELECTRA paper (Clark et al., 2020; see Citations). It speeds up training by a factor of 4 compared to normal masked language modeling, and eventually reaches better performance if trained for even longer. Special thanks to Erik Nijkamp for taking the time to replicate the results for GLUE.
Install
$ pip install electra-pytorch
Usage
The following example uses reformer-pytorch, which can be installed with pip.
import torch
from torch import nn
from reformer_pytorch import ReformerLM
from electra_pytorch import Electra
# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator
generator = ReformerLM(
num_tokens = 20000,
emb_dim = 128,
dim = 256, # smaller hidden dimension
heads = 4, # fewer heads
ff_mult = 2, # smaller feed forward intermediate dimension
dim_head = 64,
depth = 12,
max_seq_len = 1024
)
discriminator = ReformerLM(
num_tokens = 20000,
emb_dim = 128,
dim = 1024,
dim_head = 64,
heads = 16,
depth = 12,
ff_mult = 4,
max_seq_len = 1024
)
# (2) weight tie the token and positional embeddings of generator and discriminator
generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb
# weight tie any other embeddings if available, token type embeddings, etc.
# (3) instantiate electra
trainer = Electra(
generator,
discriminator,
discr_dim = 1024, # the embedding dimension of the discriminator
discr_layer = 'reformer', # the name of the discriminator layer whose output is used to predict whether each token is original or replaced
mask_token_id = 2, # the token id reserved for masking
pad_token_id = 0, # the token id for padding
mask_prob = 0.15, # masking probability for masked language modeling
mask_ignore_token_ids = [] # ids of tokens to ignore for mask modeling ex. (cls, sep)
)
# (4) train
data = torch.randint(0, 20000, (1, 1024))
results = trainer(data)
results.loss.backward()
# after much training, the discriminator should have improved
torch.save(discriminator, f'./pretrained-model.pt')
If you would rather not have the framework auto-magically intercept the hidden output of the discriminator, you can pass in the discriminator yourself, with the extra linear layer [dim x 1] already attached, as follows.
import torch
from torch import nn
from reformer_pytorch import ReformerLM
from electra_pytorch import Electra
# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator
generator = ReformerLM(
num_tokens = 20000,
emb_dim = 128,
dim = 256, # smaller hidden dimension
heads = 4, # fewer heads
ff_mult = 2, # smaller feed forward intermediate dimension
dim_head = 64,
depth = 12,
max_seq_len = 1024
)
discriminator = ReformerLM(
num_tokens = 20000,
emb_dim = 128,
dim = 1024,
dim_head = 64,
heads = 16,
depth = 12,
ff_mult = 4,
max_seq_len = 1024,
return_embeddings = True
)
# (2) weight tie the token and positional embeddings of generator and discriminator
generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb
# weight tie any other embeddings if available, token type embeddings, etc.
# (3) instantiate electra
discriminator_with_adapter = nn.Sequential(discriminator, nn.Linear(1024, 1))
trainer = Electra(
generator,
discriminator_with_adapter,
mask_token_id = 2, # the token id reserved for masking
pad_token_id = 0, # the token id for padding
mask_prob = 0.15, # masking probability for masked language modeling
mask_ignore_token_ids = [] # ids of tokens to ignore for mask modeling ex. (cls, sep)
)
# (4) train
data = torch.randint(0, 20000, (1, 1024))
results = trainer(data)
results.loss.backward()
# after much training, the discriminator should have improved
torch.save(discriminator, f'./pretrained-model.pt')
Important details for successful training
The generator should be roughly a quarter to at most one half of the discriminator's size for effective training. Any greater and the generator will be too good and the adversarial game collapses. This was done by reducing the hidden dimension, feed forward hidden dimension, and number of attention heads in the paper.
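To make the generator/discriminator interplay concrete, here is a minimal pure-Python sketch of how one replaced-token-detection example could be built from a token sequence. It is not the library's code: `make_rtd_example` is a hypothetical helper, and a uniform random sampler stands in for the generator. The keyword names mirror the `Electra` arguments shown above.

```python
import random

def make_rtd_example(tokens, *, num_tokens=20000, mask_prob=0.15, mask_token_id=2,
                     pad_token_id=0, mask_ignore_token_ids=(), seed=0):
    """Hypothetical helper: build generator input, discriminator input and labels."""
    rng = random.Random(seed)
    never_mask = set(mask_ignore_token_ids) | {pad_token_id}

    # 1) pick positions to mask, skipping padding and ignored ids (e.g. cls, sep)
    masked_positions = [i for i, tok in enumerate(tokens)
                        if tok not in never_mask and rng.random() < mask_prob]

    # 2) the generator sees the sequence with those positions replaced by the mask id
    generator_input = list(tokens)
    for i in masked_positions:
        generator_input[i] = mask_token_id

    # 3) stand-in generator (assumption): sample a replacement uniformly at random
    discriminator_input = list(tokens)
    for i in masked_positions:
        discriminator_input[i] = rng.randrange(num_tokens)

    # 4) label 1 where the token was actually changed; a sample that happens
    #    to equal the original token counts as "original" (label 0)
    labels = [int(new != old) for new, old in zip(discriminator_input, tokens)]
    return generator_input, discriminator_input, labels

tokens = [101, 7, 8, 9, 10, 11, 12, 102, 0, 0]
gen_in, disc_in, labels = make_rtd_example(
    tokens, mask_prob=0.5, mask_ignore_token_ids=(101, 102), seed=1)
print(gen_in, disc_in, labels)
```

A real generator proposes plausible replacements rather than random ones, which is why its capacity relative to the discriminator matters: the better its samples, the harder the discriminator's per-token classification becomes.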
Testing
$ python setup.py test
Training
- Download the OpenWebText dataset.
$ mkdir data
$ cd data
$ pip3 install gdown
$ gdown --id 1EA5V0oetDCOke7afsktL_JDQ-ETtNOvx
$ tar -xf openwebtext.tar.xz
$ wget https://storage.googleapis.com/electra-data/vocab.txt
$ cd ..
- Tokenize dataset.
$ python pretraining/openwebtext/preprocess.py
- Pre-train.
$ python pretraining/openwebtext/pretrain.py
- Download GLUE dataset.
$ python examples/glue/download.py
- Fine-tune on the MRPC sub-task of the GLUE benchmark.
$ python examples/glue/run.py --model_name_or_path output/yyyy-mm-dd-hh-mm-ss/ckpt/200000
Citations
@misc{clark2020electra,
title={ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators},
author={Kevin Clark and Minh-Thang Luong and Quoc V. Le and Christopher D. Manning},
year={2020},
eprint={2003.10555},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
.\lucidrains\electra-pytorch\setup.py
# Import the setup helper and the package discovery function
from setuptools import setup, find_packages

# Package metadata
setup(
    name = 'electra-pytorch',              # package name
    packages = find_packages(),            # discover all packages
    version = '0.1.2',                     # version number
    license='MIT',                         # license
    description = 'Electra - Pytorch',     # short description
    author = 'Erik Nijkamp, Phil Wang',    # authors
    author_email = 'erik.nijkamp@gmail.com, lucidrains@gmail.com',  # author emails
    url = 'https://github.com/lucidrains/electra-pytorch',          # project URL
    # keywords
    keywords = [
        'transformers',
        'artificial intelligence',
        'pretraining'
    ],
    # runtime dependencies
    install_requires=[
        'torch>=1.6.0',
        'transformers==3.0.2',
        'scipy',
        'sklearn'
    ],
    # dependencies needed to run setup.py itself
    setup_requires=[
        'pytest-runner'
    ],
    # test dependencies
    tests_require=[
        'pytest',
        'reformer-pytorch'
    ],
    # PyPI classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.7',
    ],
)
.\lucidrains\electra-pytorch\tests\test_electra_pytorch.py
import torch
from torch import nn
from reformer_pytorch import ReformerLM
from electra_pytorch import Electra

# Test Electra with automatic interception of the discriminator's hidden layer
def test_electra():
    # Generator: a ReformerLM
    generator = ReformerLM(
        num_tokens = 20000,
        dim = 512,
        depth = 1,
        max_seq_len = 1024
    )
    # Discriminator: a deeper ReformerLM
    discriminator = ReformerLM(
        num_tokens = 20000,
        dim = 512,
        depth = 2,
        max_seq_len = 1024
    )
    # Tie the token and positional embeddings of generator and discriminator
    generator.token_emb = discriminator.token_emb
    generator.pos_emb = discriminator.pos_emb
    # Build the Electra trainer
    trainer = Electra(
        generator,
        discriminator,
        num_tokens = 20000,
        discr_dim = 512,
        discr_layer = 'reformer',
        pad_token_id = 1,
        mask_ignore_token_ids = [2, 3]
    )
    # Random token ids as input
    data = torch.randint(0, 20000, (1, 1024))
    # Forward pass through the trainer
    results = trainer(data)
    # Backpropagate the combined loss
    results.loss.backward()

# Test Electra without the automatic interception: the discriminator is
# passed in with its own [dim x 1] adapter already attached
def test_electra_without_magic():
    # Generator: a ReformerLM
    generator = ReformerLM(
        num_tokens = 20000,
        dim = 512,
        depth = 1,
        max_seq_len = 1024
    )
    # Discriminator: returns embeddings instead of token logits
    discriminator = ReformerLM(
        num_tokens = 20000,
        dim = 512,
        depth = 2,
        max_seq_len = 1024,
        return_embeddings = True
    )
    # Tie the token and positional embeddings of generator and discriminator
    generator.token_emb = discriminator.token_emb
    generator.pos_emb = discriminator.pos_emb
    # Discriminator followed by the adapter head
    discriminator_with_adapter = nn.Sequential(
        discriminator,
        nn.Linear(512, 1),
        nn.Sigmoid()
    )
    # Build the Electra trainer
    trainer = Electra(
        generator,
        discriminator_with_adapter,
        num_tokens = 20000,
        pad_token_id = 1,
        mask_ignore_token_ids = [2, 3]
    )
    # Random token ids as input
    data = torch.randint(0, 20000, (1, 1024))
    # Forward pass through the trainer
    results = trainer(data)
    # Backpropagate the combined loss
    results.loss.backward()
.\lucidrains\ema-pytorch\ema_pytorch\ema_pytorch.py
from copy import deepcopy
from functools import partial

import torch
from torch import nn, Tensor
from torch.nn import Module

# beartype provides runtime type checking
from beartype import beartype
from beartype.typing import Set, Optional

# Return True if the value is not None
def exists(val):
    return val is not None

# Return the device of the module's first parameter
def get_module_device(m: Module):
    return next(m.parameters()).device

# Copy src into tgt in place, optionally moving src to tgt's device first
def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
    if auto_move_device:
        src = src.to(tgt.device)
    tgt.copy_(src)

# Interpolate tgt toward src in place: tgt += weight * (src - tgt)
def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
    if auto_move_device:
        src = src.to(tgt.device)
    tgt.lerp_(src, weight)
# EMA wrapper: keeps an exponential-moving-average shadow copy of a model
class EMA(Module):
    """
    Implements exponential moving average shadowing for your model.
    Utilizes an inverse decay schedule to manage longer term training runs.
    By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
    @crowsonkb's notes on EMA Warmup:
    If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).
    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 2/3.
        min_value (float): The minimum EMA decay rate. Default: 0.
    """

    # beartype type-checks the constructor arguments at call time
    @beartype
    def __init__(
        self,
        model: Module,
        ema_model: Optional[Module] = None,           # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
        beta = 0.9999,
        update_after_step = 100,
        update_every = 10,
        inv_gamma = 1.0,
        power = 2 / 3,
        min_value = 0.0,
        param_or_buffer_names_no_ema: Set[str] = set(),
        ignore_names: Set[str] = set(),
        ignore_startswith_names: Set[str] = set(),
        include_online_model = True,                  # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
        allow_different_devices = False               # if the EMA model is on a different device (say CPU), automatically move the tensor
    ):
        super().__init__()
        self.beta = beta

        # with beta == 1 the averaged weights never move, so updates are skipped
        self.is_frozen = beta == 1.

        # whether to include the online model within the module tree, so that state_dict also saves it
        self.include_online_model = include_online_model

        if include_online_model:
            self.online_model = model
        else:
            self.online_model = [model] # hack: wrapping in a list hides the model from the module tree

        # the EMA model
        self.ema_model = ema_model

        if not exists(self.ema_model):
            try:
                self.ema_model = deepcopy(model)
            except Exception as e:
                print(f'Error: While trying to deepcopy model: {e}')
                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
                exit()

        self.ema_model.requires_grad_(False)

        # names of floating point or complex parameters and buffers
        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}

        # tensor update functions
        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)

        # update hyperparameters
        self.update_every = update_every
        self.update_after_step = update_after_step

        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value

        assert isinstance(param_or_buffer_names_no_ema, (set, list))
        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer

        self.ignore_names = ignore_names
        self.ignore_startswith_names = ignore_startswith_names

        # whether to manage an EMA model that is kept on a different device
        self.allow_different_devices = allow_different_devices

        # init and step states
        self.register_buffer('initted', torch.tensor(False))
        self.register_buffer('step', torch.tensor(0))

    @property
    def model(self):
        return self.online_model if self.include_online_model else self.online_model[0]

    def eval(self):
        return self.ema_model.eval()

    def restore_ema_model_device(self):
        device = self.initted.device
        self.ema_model.to(device)

    def get_params_iter(self, model):
        for name, param in model.named_parameters():
            if name not in self.parameter_names:
                continue
            yield name, param

    def get_buffers_iter(self, model):
        for name, buffer in model.named_buffers():
            if name not in self.buffer_names:
                continue
            yield name, buffer

    def copy_params_from_model_to_ema(self):
        copy = self.inplace_copy

        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
            copy(ma_params.data, current_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
            copy(ma_buffers.data, current_buffers.data)

    def copy_params_from_ema_to_model(self):
        copy = self.inplace_copy

        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
            copy(current_params.data, ma_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
            copy(current_buffers.data, ma_buffers.data)

    # Current decay value according to the warmup schedule
    def get_current_decay(self):
        # steps elapsed since EMA updates began, clamped at 0
        epoch = (self.step - self.update_after_step - 1).clamp(min=0.)
        # inverse decay schedule
        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power

        # no decay before the warmup has started
        if epoch.item() <= 0:
            return 0.

        # clamp the decay to [min_value, beta]
        return value.clamp(min=self.min_value, max=self.beta).item()

    # Called once per training step
    def update(self):
        # read the current step, then advance the counter
        step = self.step.item()
        self.step += 1

        # only act on every `update_every`-th call
        if (step % self.update_every) != 0:
            return

        # until `update_after_step` is passed, just mirror the online model
        if step <= self.update_after_step:
            self.copy_params_from_model_to_ema()
            return

        # on the first real update, copy the online weights and mark as initialized
        if not self.initted.item():
            self.copy_params_from_model_to_ema()
            self.initted.data.copy_(torch.tensor(True))

        # blend the online weights into the EMA model
        self.update_moving_average(self.ema_model, self.model)

    # Blend the current model's parameters and buffers into the EMA model
    @torch.no_grad()
    def update_moving_average(self, ma_model, current_model):
        # a frozen EMA is never updated
        if self.is_frozen:
            return

        copy, lerp = self.inplace_copy, self.inplace_lerp
        current_decay = self.get_current_decay()

        for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
            # skip names that are ignored outright
            if name in self.ignore_names:
                continue

            # skip names that start with an ignored prefix
            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            # names excluded from EMA are copied directly
            if name in self.param_or_buffer_names_no_ema:
                copy(ma_params.data, current_params.data)
                continue

            # ema = ema + (1 - decay) * (current - ema)
            lerp(ma_params.data, current_params.data, 1. - current_decay)

        # same rules for buffers
        for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                copy(ma_buffer.data, current_buffer.data)
                continue

            lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)

    # Calling the wrapper runs the EMA model
    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)
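The warmup numbers quoted in the `EMA` docstring can be checked with a scalar version of `get_current_decay`. The sketch below is a plain-Python restatement of that formula, not the library's API; `warmup_decay` is a hypothetical name, and `beta=1.0` is passed in the example calls only to switch off the upper clamp.

```python
def warmup_decay(step, *, beta=0.9999, update_after_step=100,
                 inv_gamma=1.0, power=2 / 3, min_value=0.0):
    """Scalar restatement of EMA.get_current_decay (hypothetical helper)."""
    # steps elapsed since EMA updates began, clamped at 0
    epoch = max(step - update_after_step - 1, 0)
    if epoch <= 0:
        return 0.0
    value = 1 - (1 + epoch / inv_gamma) ** -power
    # clamp to [min_value, beta]
    return min(max(value, min_value), beta)

# power = 2/3: about 0.999 after ~31.6K steps
print(warmup_decay(31_623, update_after_step=0, beta=1.0))
# power = 3/4: about 0.999 after 10K steps
print(warmup_decay(10_000, update_after_step=0, beta=1.0, power=3 / 4))
# with the default beta the schedule saturates at 0.9999
print(warmup_decay(10**9))
```

A smaller `power` ramps the decay more slowly, so early, noisy weights are forgotten faster before the average settles at `beta`.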
.\lucidrains\ema-pytorch\ema_pytorch\post_hoc_ema.py
from pathlib import Path
from copy import deepcopy
from functools import partial

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList

import numpy as np

from beartype import beartype
from beartype.typing import Set, Tuple, Optional

# Return True if the value is not None
def exists(val):
    return val is not None

# Return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# Return the first element of a sequence
def first(arr):
    return arr[0]

# Return the device of the module's first parameter
def get_module_device(m: Module):
    return next(m.parameters()).device

# Copy src into tgt in place, optionally moving src to tgt's device first
def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
    if auto_move_device:
        src = src.to(tgt.device)
    tgt.copy_(src)

# Interpolate tgt toward src in place: tgt += weight * (src - tgt)
def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
    if auto_move_device:
        src = src.to(tgt.device)
    tgt.lerp_(src, weight)

# Convert a relative standard deviation (sigma_rel) to gamma by taking the
# largest real root of a cubic in gamma
def sigma_rel_to_gamma(sigma_rel):
    t = sigma_rel ** -2
    return np.roots([1, 7, 16 - t, 12 - t]).real.max().item()
# EMA module that uses the hyperparameters from the paper https://arxiv.org/abs/2312.02696
class KarrasEMA(Module):
    """
    exponential moving average module that uses hyperparameters from the paper https://arxiv.org/abs/2312.02696
    can either use gamma or sigma_rel from paper
    """

    @beartype
    def __init__(
        self,
        model: Module,
        sigma_rel: Optional[float] = None,
        gamma: Optional[float] = None,
        ema_model: Optional[Module] = None,           # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
        update_every: int = 100,
        frozen: bool = False,
        param_or_buffer_names_no_ema: Set[str] = set(),
        ignore_names: Set[str] = set(),
        ignore_startswith_names: Set[str] = set(),
        allow_different_devices = False               # if the EMA model is on a different device (say CPU), automatically move the tensor
    ):
        super().__init__()

        assert exists(sigma_rel) ^ exists(gamma), 'either sigma_rel or gamma is given. gamma is derived from sigma_rel as in the paper, then beta is derived from gamma'

        if exists(sigma_rel):
            gamma = sigma_rel_to_gamma(sigma_rel)

        self.gamma = gamma
        self.frozen = frozen

        # wrapping in a list hides the online model from the module tree
        self.online_model = [model]

        # ema model
        self.ema_model = ema_model

        if not exists(self.ema_model):
            try:
                self.ema_model = deepcopy(model)
            except Exception as e:
                print(f'Error: While trying to deepcopy model: {e}')
                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
                exit()

        self.ema_model.requires_grad_(False)

        # parameter and buffer names
        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}

        # tensor update functions
        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)

        # updating hyperparameters
        self.update_every = update_every

        assert isinstance(param_or_buffer_names_no_ema, (set, list))
        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer

        self.ignore_names = ignore_names
        self.ignore_startswith_names = ignore_startswith_names

        # whether to manage if EMA model is kept on a different device
        self.allow_different_devices = allow_different_devices

        # init and step states
        self.register_buffer('initted', torch.tensor(False))
        self.register_buffer('step', torch.tensor(0))

    @property
    def model(self):
        return first(self.online_model)

    # Decay factor for the current step, derived from gamma
    @property
    def beta(self):
        return (1 - 1 / (self.step + 1)) ** (1 + self.gamma)

    # Put the EMA model in eval mode
    def eval(self):
        return self.ema_model.eval()

    # Move the EMA model back to the device this wrapper lives on
    def restore_ema_model_device(self):
        device = self.initted.device
        self.ema_model.to(device)

    # Iterate over the tracked (floating point or complex) parameters
    def get_params_iter(self, model):
        for name, param in model.named_parameters():
            if name not in self.parameter_names:
                continue
            yield name, param

    # Iterate over the tracked (floating point or complex) buffers
    def get_buffers_iter(self, model):
        for name, buffer in model.named_buffers():
            if name not in self.buffer_names:
                continue
            yield name, buffer

    # Copy parameters and buffers from the online model to the EMA model
    def copy_params_from_model_to_ema(self):
        copy = self.inplace_copy

        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
            copy(ma_params.data, current_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
            copy(ma_buffers.data, current_buffers.data)

    # Copy parameters and buffers from the EMA model to the online model
    def copy_params_from_ema_to_model(self):
        copy = self.inplace_copy

        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
            copy(current_params.data, ma_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
            copy(current_buffers.data, ma_buffers.data)

    # Advance the step counter and, every `update_every` calls, update the moving average
    def update(self):
        step = self.step.item()
        self.step += 1

        if (step % self.update_every) != 0:
            return

        if not self.initted.item():
            self.copy_params_from_model_to_ema()
            self.initted.data.copy_(torch.tensor(True))

        self.update_moving_average(self.ema_model, self.model)

    # Iterate over all EMA parameters and buffers that are actually averaged
    def iter_all_ema_params_and_buffers(self):
        for name, ma_params in self.get_params_iter(self.ema_model):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                continue

            yield ma_params

        for name, ma_buffer in self.get_buffers_iter(self.ema_model):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                continue

            yield ma_buffer

    # Blend the current model's parameters and buffers into the EMA model
    @torch.no_grad()
    def update_moving_average(self, ma_model, current_model):
        if self.frozen:
            return

        copy, lerp = self.inplace_copy, self.inplace_lerp
        current_decay = self.beta

        for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                copy(ma_params.data, current_params.data)
                continue

            lerp(ma_params.data, current_params.data, 1. - current_decay)

        for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                copy(ma_buffer.data, current_buffer.data)
                continue

            lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)

    # Calling the wrapper runs the EMA model
    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)
# Post-hoc EMA wrapper
# Solves for the weights that combine all checkpoints into a newly synthesized EMA with the desired gamma
# Algorithm 3 from the paper, reimplemented in torch

# Inner product between two power-function EMA profiles (t_a, gamma_a) and (t_b, gamma_b)
def p_dot_p(t_a, gamma_a, t_b, gamma_b):
    t_ratio = t_a / t_b
    t_exp = torch.where(t_a < t_b , gamma_b , -gamma_a)
    t_max = torch.maximum(t_a , t_b)
    num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
    den = (gamma_a + gamma_b + 1) * t_max
    return num / den

# Solve the linear system A w = b for the checkpoint weights
def solve_weights(t_i, gamma_i, t_r, gamma_r):
    rv = lambda x: x.double().reshape(-1, 1)   # as a column vector
    cv = lambda x: x.double().reshape(1, -1)   # as a row vector
    A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
    b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
    return torch.linalg.solve(A, b)
# Post-hoc EMA: tracks several KarrasEMA models and checkpoints them periodically
class PostHocEMA(Module):

    @beartype
    def __init__(
        self,
        model: Module,
        sigma_rels: Optional[Tuple[float, ...]] = None,
        gammas: Optional[Tuple[float, ...]] = None,
        checkpoint_every_num_steps: int = 1000,
        checkpoint_folder: str = './post-hoc-ema-checkpoints',
        **kwargs
    ):
        super().__init__()
        assert exists(sigma_rels) ^ exists(gammas)

        if exists(sigma_rels):
            gammas = tuple(map(sigma_rel_to_gamma, sigma_rels))

        assert len(gammas) > 1, 'at least 2 ema models with different gammas in order to synthesize new ema models of a different gamma'
        assert len(set(gammas)) == len(gammas), 'calculated gammas must be all unique'

        self.gammas = gammas
        self.num_ema_models = len(gammas)

        self._model = [model]
        self.ema_models = ModuleList([KarrasEMA(model, gamma = gamma, **kwargs) for gamma in gammas])

        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
        assert self.checkpoint_folder.is_dir()

        self.checkpoint_every_num_steps = checkpoint_every_num_steps
        self.ema_kwargs = kwargs

    # The online model
    @property
    def model(self):
        return first(self._model)

    # The step counter, shared by all EMA models
    @property
    def step(self):
        return first(self.ema_models).step

    # The device the wrapper lives on
    @property
    def device(self):
        return self.step.device

    # Note: as written, this copies the online model's parameters into each EMA model
    # (it delegates to copy_params_from_model_to_ema), despite the method's name
    def copy_params_from_ema_to_model(self):
        for ema_model in self.ema_models:
            ema_model.copy_params_from_model_to_ema()

    # Update every EMA model, checkpointing at the configured interval
    def update(self):
        for ema_model in self.ema_models:
            ema_model.update()

        if not (self.step.item() % self.checkpoint_every_num_steps):
            self.checkpoint()

    # Save a half-precision state dict per EMA model, named '<gamma index>.<step>.pt'
    def checkpoint(self):
        step = self.step.item()

        for ind, ema_model in enumerate(self.ema_models):
            filename = f'{ind}.{step}.pt'
            path = self.checkpoint_folder / filename

            pkg = deepcopy(ema_model).half().state_dict()
            torch.save(pkg, str(path))

    # Synthesize an EMA model for a new gamma (or sigma_rel) from the saved checkpoints
    @beartype
    def synthesize_ema_model(
        self,
        gamma: Optional[float] = None,
        sigma_rel: Optional[float] = None,
        step: Optional[int] = None,
    ) -> KarrasEMA:
        # exactly one of gamma and sigma_rel must be given
        assert exists(gamma) ^ exists(sigma_rel)
        device = self.device

        if exists(sigma_rel):
            gamma = sigma_rel_to_gamma(sigma_rel)

        synthesized_ema_model = KarrasEMA(
            model = self.model,
            gamma = gamma,
            **self.ema_kwargs
        )

        # collect the gamma and timestep of every checkpoint on disk
        gammas = []
        timesteps = []
        checkpoints = [*self.checkpoint_folder.glob('*.pt')]

        for file in checkpoints:
            gamma_ind, timestep = map(int, file.stem.split('.'))
            # use a separate name so the target gamma above is not overwritten
            ckpt_gamma = self.gammas[gamma_ind]
            gammas.append(ckpt_gamma)
            timesteps.append(timestep)

        # default to the latest checkpointed timestep
        step = default(step, max(timesteps))
        assert step <= max(timesteps), f'you can only synthesize for a timestep that is less than the max timestep {max(timesteps)}'

        # line up with Algorithm 3
        gamma_i = Tensor(gammas, device = device)
        t_i = Tensor(timesteps, device = device)

        gamma_r = Tensor([gamma], device = device)
        t_r = Tensor([step], device = device)

        # solve for the weights that combine all checkpoints into the synthesized one (least squares, as in the paper)
        weights = solve_weights(t_i, gamma_i, t_r, gamma_r)
        weights = weights.squeeze(-1)

        # accumulate the weighted checkpoints into the synthesized model, one at a time
        tmp_ema_model = KarrasEMA(
            model = self.model,
            gamma = gamma,
            **self.ema_kwargs
        )

        for ind, (checkpoint, weight) in enumerate(zip(checkpoints, weights.tolist())):
            is_first = ind == 0

            # load the checkpoint into the temporary EMA model
            ckpt_state_dict = torch.load(str(checkpoint))
            tmp_ema_model.load_state_dict(ckpt_state_dict)

            # add the weighted checkpoint to the synthesized model
            for ckpt_tensor, synth_tensor in zip(tmp_ema_model.iter_all_ema_params_and_buffers(), synthesized_ema_model.iter_all_ema_params_and_buffers()):
                if is_first:
                    synth_tensor.zero_()

                synth_tensor.add_(ckpt_tensor * weight)

        # return the synthesized model
        return synthesized_ema_model

    # Calling the wrapper returns the outputs of all EMA models
    def __call__(self, *args, **kwargs):
        return tuple(ema_model(*args, **kwargs) for ema_model in self.ema_models)
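The `beta` property of `KarrasEMA` can be explored on scalars. The following plain-Python sketch uses the same formula, but simplifies the bookkeeping: it assumes one update per step, with the step index starting at 0 so that the first beta is 0. `karras_beta` and `ema_weights` are hypothetical names. Under those assumptions the weight that the k-th sample carries after T updates is (k^(gamma+1) - (k-1)^(gamma+1)) / T^(gamma+1), which the loop at the end compares against.

```python
def karras_beta(step, gamma):
    """Same formula as the KarrasEMA.beta property, on plain numbers."""
    return (1 - 1 / (step + 1)) ** (1 + gamma)

def ema_weights(num_updates, gamma):
    """Weight each past sample ends up with after `num_updates` EMA updates."""
    weights = []
    for step in range(num_updates):
        beta = karras_beta(step, gamma)
        # existing samples are scaled down by beta ...
        weights = [w * beta for w in weights]
        # ... and the newest sample enters with weight (1 - beta)
        weights.append(1 - beta)
    return weights

gamma, total = 6.94, 50
weights = ema_weights(total, gamma)
print(sum(weights))              # the weights form a convex combination
print(weights[0], weights[-1])   # oldest sample vs. newest sample

# compare against the closed form stated above
p = gamma + 1
for k, w in enumerate(weights, start=1):
    assert abs(w - (k ** p - (k - 1) ** p) / total ** p) < 1e-9
```

Larger gamma concentrates the average on recent samples, while gamma close to 0 approaches a uniform average over the whole run.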
.\lucidrains\ema-pytorch\ema_pytorch\__init__.py
# Export the EMA class
from ema_pytorch.ema_pytorch import EMA
# Export the KarrasEMA and PostHocEMA classes
from ema_pytorch.post_hoc_ema import (
    KarrasEMA,
    PostHocEMA
)
EMA - Pytorch
A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model
Install
$ pip install ema-pytorch
Usage
import torch
from ema_pytorch import EMA
# your neural network as a pytorch module
net = torch.nn.Linear(512, 512)
# wrap your neural network, specify the decay (beta)
ema = EMA(
net,
beta = 0.9999, # exponential moving average factor
update_after_step = 100, # only after this number of .update() calls will it start updating
update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call)
)
# mutate your network, with SGD or otherwise
with torch.no_grad():
    net.weight.copy_(torch.randn_like(net.weight))
    net.bias.copy_(torch.randn_like(net.bias))
# you will call the update function on your moving average wrapper
ema.update()
# then, later on, you can invoke the EMA model the same way as your network
data = torch.randn(1, 512)
output = net(data)
ema_output = ema(data)
# if you want to save your ema model, it is recommended you save the entire wrapper
# as it contains the number of steps taken (there is a warmup logic in there, recommended by @crowsonkb, validated for a number of projects now)
# however, if you wish to access the copy of your model with EMA, then it will live at ema.ema_model
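The interaction between `update_after_step` and `update_every` is easy to misread, so here is a small plain-Python sketch of the branching in `EMA.update()` as listed earlier in this walkthrough. `classify_update_calls` is a hypothetical helper that only labels what each call would do; it does not touch any weights.

```python
def classify_update_calls(num_calls, *, update_after_step=100, update_every=10):
    """Label each .update() call as 'skip', 'copy' or 'ema' (hypothetical helper)."""
    actions = []
    for step in range(num_calls):  # `step` is the counter value before it is incremented
        if step % update_every != 0:
            actions.append('skip')   # not an update step: nothing happens
        elif step <= update_after_step:
            actions.append('copy')   # still warming up: EMA mirrors the online model
        else:
            actions.append('ema')    # online weights are blended into the EMA model
    return actions

actions = classify_update_calls(121)
print([i for i, a in enumerate(actions) if a == 'copy'])
print([i for i, a in enumerate(actions) if a == 'ema'])
```

With the settings from the usage example, calls 0, 10, ..., 100 copy the online weights and the first blended update happens on call 110. In the source, that first blended call also copies the online weights once more before interpolating.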
In order to use the post-hoc synthesized EMA, proposed by Karras et al. in a recent paper, follow the example below
import torch
from ema_pytorch import PostHocEMA
# your neural network as a pytorch module
net = torch.nn.Linear(512, 512)
# wrap your neural network, specify the sigma_rels or gammas
emas = PostHocEMA(
net,
sigma_rels = (0.05, 0.3), # a tuple with the hyperparameter for the multiple EMAs. you need at least 2 here to synthesize a new one
update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call)
checkpoint_every_num_steps = 10,
checkpoint_folder = './post-hoc-ema-checkpoints' # the folder of saved checkpoints for each sigma_rel (gamma) across timesteps with the hparam above, used for synthesizing a new EMA model after training
)
net.train()
for _ in range(1000):
    # mutate your network, with SGD or otherwise
    with torch.no_grad():
        net.weight.copy_(torch.randn_like(net.weight))
        net.bias.copy_(torch.randn_like(net.bias))
    # you will call the update function on your moving average wrapper
    emas.update()
# now that you have a few checkpoints
# you can synthesize an EMA model with a different sigma_rel (say 0.15)
synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)
# output with synthesized EMA
data = torch.randn(1, 512)
synthesized_ema_output = synthesized_ema(data)
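The `sigma_rel` values above map to `gamma` through the cubic solved in `sigma_rel_to_gamma`. Rearranging that cubic gives sigma_rel = sqrt((gamma + 1) / ((gamma + 2)^2 (gamma + 3))), so the mapping can be cross-checked without numpy. The sketch below is an illustrative alternative using bisection, not the library's method. Both function names are hypothetical, and the lower bracket (sqrt(5) - 3) / 2 is where the forward formula peaks, at a sigma_rel of roughly 0.3003.

```python
import math

def gamma_to_sigma_rel(gamma):
    """Forward map obtained by rearranging the cubic in sigma_rel_to_gamma."""
    return math.sqrt((gamma + 1) / ((gamma + 2) ** 2 * (gamma + 3)))

def sigma_rel_to_gamma_bisect(sigma_rel, lo=(math.sqrt(5) - 3) / 2, hi=1e6, iters=200):
    """Invert the forward map by bisection (illustrative alternative to np.roots)."""
    # sigma_rel decreases monotonically in gamma on [lo, hi], so bisection applies
    assert gamma_to_sigma_rel(hi) < sigma_rel <= gamma_to_sigma_rel(lo)
    for _ in range(iters):
        mid = (lo + hi) / 2
        if gamma_to_sigma_rel(mid) > sigma_rel:
            lo = mid   # value still too large: the root lies to the right
        else:
            hi = mid
    return (lo + hi) / 2

for sigma_rel in (0.05, 0.15, 0.3):
    gamma = sigma_rel_to_gamma_bisect(sigma_rel)
    print(sigma_rel, gamma, gamma_to_sigma_rel(gamma))
```

A small `sigma_rel` corresponds to a large `gamma`, that is, an average focused on the most recent weights, which is why the two tracked EMAs in the example use well-separated values.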
Citations
@article{Karras2023AnalyzingAI,
title = {Analyzing and Improving the Training Dynamics of Diffusion Models},
author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
journal = {ArXiv},
year = {2023},
volume = {abs/2312.02696},
url = {https://api.semanticscholar.org/CorpusID:265659032}
}
.\lucidrains\ema-pytorch\setup.py
# Import the setup helper and the package discovery function
from setuptools import setup, find_packages

# Package metadata
setup(
    name = 'ema-pytorch',                            # package name
    packages = find_packages(exclude=[]),            # discover all packages
    version = '0.4.3',                               # version number
    license='MIT',                                   # license
    description = 'Easy way to keep track of exponential moving average version of your pytorch module',  # short description
    author = 'Phil Wang',                            # author
    author_email = 'lucidrains@gmail.com',           # author email
    long_description_content_type = 'text/markdown', # format of the long description
    url = 'https://github.com/lucidrains/ema-pytorch',  # project URL
    # keywords
    keywords = [
        'artificial intelligence',
        'deep learning',
        'exponential moving average'
    ],
    # runtime dependencies
    install_requires=[
        'beartype',
        'torch>=1.6',
    ],
    # PyPI classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\En-transformer\denoise.py
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam

from einops import rearrange, repeat

import sidechainnet as scn
from en_transformer.en_transformer import EnTransformer

# Use float64 as the default tensor dtype
torch.set_default_dtype(torch.float64)

BATCH_SIZE = 1
# Number of micro-batches whose gradients are accumulated before each optimizer step
GRADIENT_ACCUMULATE_EVERY = 16

# Endless batch generator that skips sequences longer than len_thres
def cycle(loader, len_thres = 200):
    while True:
        for data in loader:
            if data.seqs.shape[1] > len_thres:
                continue
            yield data

# Build the EnTransformer model
transformer = EnTransformer(
    num_tokens = 21,
    dim = 32,
    dim_head = 64,
    heads = 4,
    depth = 4,
    rel_pos_emb = True, # there is inherent order in the sequence (backbone atoms of the amino acid chain)
    neighbors = 16
)

# Load the dataset
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

# Wrap the training dataloader in the endless generator
dl = cycle(data['train'])
# Adam optimizer over the model parameters
optim = Adam(transformer.parameters(), lr=1e-3)
# Move the model to the GPU
transformer = transformer.cuda()

# Training loop
for _ in range(10000):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # Move to the GPU; argmax turns the one-hot sequence into residue indices
        seqs = seqs.cuda().argmax(dim = -1)
        coords = coords.cuda().type(torch.float64)
        masks = masks.cuda().bool()

        l = seqs.shape[1]
        # Split the flat atom axis into (residue, 14 atoms per residue)
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

        # Keep only the backbone coordinates (first 3 atoms of each residue)
        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')

        # Repeat each residue token and mask entry once per backbone atom
        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)

        # Add Gaussian noise to the coordinates
        noised_coords = coords + torch.randn_like(coords)

        # The model returns updated features and denoised coordinates
        feats, denoised_coords = transformer(seq, noised_coords, mask = masks)

        # Mean squared error on the unmasked positions
        loss = F.mse_loss(denoised_coords[masks], coords[masks])

        # Scale the loss so that the accumulated gradients average over the micro-batches
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # Print the loss of the last micro-batch
    print('loss:', loss.item())
    # Apply the accumulated gradients, then reset them
    optim.step()
    optim.zero_grad()
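The training loop above divides each micro-batch loss by GRADIENT_ACCUMULATE_EVERY before calling backward(), so the gradients summed across the inner loop equal the gradient of the mean loss. The identity can be sketched in plain Python with a hypothetical one-parameter least-squares model; the names `grad`, `accumulate` and `grad_of_mean_loss` are illustrative only and do not appear in the script.

```python
# Hypothetical scalar model: prediction = w * x, loss = (w * x - y) ** 2

def grad(w, x, y):
    """Gradient of (w * x - y) ** 2 with respect to w."""
    return 2 * (w * x - y) * x

def accumulate(w, batches):
    """Sum the gradients of (loss / K) over K micro-batches, as repeated backward() calls do."""
    k = len(batches)
    total = 0.0
    for x, y in batches:
        total += grad(w, x, y) / k
    return total

def grad_of_mean_loss(w, batches):
    """Gradient of the mean loss computed in one shot over all micro-batches."""
    return sum(grad(w, x, y) for x, y in batches) / len(batches)

batches = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]
accumulated = accumulate(0.5, batches)
reference = grad_of_mean_loss(0.5, batches)
print(abs(accumulated - reference) < 1e-12)
```

Note that the script prints only the loss of the last micro-batch in each accumulation window, not the average over the window.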
.\lucidrains\En-transformer\en_transformer\en_transformer.py
import torch
import torch.nn.functional as F
from torch import nn, einsum

from torch.utils.checkpoint import checkpoint_sequential

from einx import get_at

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from taylor_series_linear_attention import TaylorSeriesLinearAttn

# helper functions

# True if the value is not None
def exists(val):
    return val is not None

# Most negative finite value representable in the tensor's dtype
def max_neg_value(t):
    return -torch.finfo(t.dtype).max

# Return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# L2-normalize along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# Small-scale initialization for an nn.Linear layer
def small_init_(t: nn.Linear):
    nn.init.normal_(t.weight, std = 0.02)
    nn.init.zeros_(t.bias)
# dynamic positional bias

class DynamicPositionBias(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads,
        depth,
        dim_head,
        input_dim = 1,
        norm = True
    ):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.mlp = nn.ModuleList([])

        # First layer projects the raw positional input to the hidden dimension
        self.mlp.append(nn.Sequential(
            nn.Linear(input_dim, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            nn.SiLU()
        ))

        # Remaining hidden layers
        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.SiLU()
            ))

        self.heads = heads
        # One output head biases the query-key logits, the other is added to the values
        self.qk_pos_head = nn.Linear(dim, heads)
        self.value_pos_head = nn.Linear(dim, dim_head * heads)

    def forward(self, pos):
        for layer in self.mlp:
            pos = layer(pos)

        qk_pos = self.qk_pos_head(pos)
        value_pos = self.value_pos_head(pos)

        qk_pos = rearrange(qk_pos, 'b 1 i j h -> b h i j')
        value_pos = rearrange(value_pos, 'b 1 i j (h d) -> b h i j d', h = self.heads)
        return qk_pos, value_pos
# classes

# this follows the same normalization strategy as SE3 Transformers
# https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95

# Layer norm with a learned gain and a fixed (non-learned) zero bias
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# Coordinate norm: rescale each vector to unit length, then apply a learned scale
class CoorsNorm(nn.Module):
    def __init__(self, eps = 1e-8, scale_init = 1.):
        super().__init__()
        self.eps = eps
        scale = torch.zeros(1).fill_(scale_init)
        self.scale = nn.Parameter(scale)

    def forward(self, coors):
        norm = coors.norm(dim = -1, keepdim = True)
        normed_coors = coors / norm.clamp(min = self.eps)
        return normed_coors * self.scale

# Residual wrapper for functions that return (feature update, coordinate delta)
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, feats, coors, **kwargs):
        feats_out, coors_delta = self.fn(feats, coors, **kwargs)
        return feats + feats_out, coors + coors_delta

# GEGLU activation: half of the channels gate the other half through GELU
class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

# Feedforward network; it leaves the coordinates untouched (delta of 0)
class FeedForward(nn.Module):
    def __init__(
        self,
        *,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = int(dim * mult * 2 / 3)

        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, inner_dim * 2, bias = False),
            GEGLU(),
            LayerNorm(inner_dim),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim, bias = False)
        )

    def forward(self, feats, coors):
        return self.net(feats), 0
# Equivariant attention layer
class EquivariantAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,                                    # input feature dimension
        dim_head = 64,                          # dimension per attention head
        heads = 4,                              # number of attention heads
        edge_dim = 0,                           # edge feature dimension
        coors_hidden_dim = 16,                  # hidden dimension of the coordinate MLP
        neighbors = 0,                          # number of nearest neighbors to attend to
        only_sparse_neighbors = False,          # attend only to neighbors given by the adjacency matrix
        valid_neighbor_radius = float('inf'),   # maximum distance for a valid neighbor
        init_eps = 1e-3,                        # std used when initializing linear weights
        rel_pos_emb = None,                     # whether relative sequence positions are used
        edge_mlp_mult = 2,                      # hidden-size multiplier of the edge MLP
        norm_rel_coors = True,                  # whether to normalize relative coordinates
        norm_coors_scale_init = 1.,             # initial scale of the coordinate norm
        use_cross_product = False,              # whether to add cross-product vectors
        talking_heads = False,                  # whether to use talking-heads attention
        dropout = 0.,                           # dropout probability
        num_global_linear_attn_heads = 0,       # number of global linear attention heads
        linear_attn_dim_head = 8,               # dimension per linear attention head
        gate_outputs = True,                    # whether to gate the outputs
        gate_init_bias = 10.                    # initial bias of the output gates
    ):
        super().__init__()
        # Attention scaling factor
        self.scale = dim_head ** -0.5
        # Layer norm applied to the input features
        self.norm = LayerNorm(dim)

        # Neighbor selection settings
        self.neighbors = neighbors
        self.only_sparse_neighbors = only_sparse_neighbors
        self.valid_neighbor_radius = valid_neighbor_radius

        # Inner dimension of the attention
        attn_inner_dim = heads * dim_head
        self.heads = heads

        # Whether global linear attention heads are enabled
        self.has_linear_attn = num_global_linear_attn_heads > 0

        # Global linear attention
        self.linear_attn = TaylorSeriesLinearAttn(
            dim = dim,
            dim_head = linear_attn_dim_head,
            heads = num_global_linear_attn_heads,
            gate_value_heads = True,
            combine_heads = False
        )

        # Projection of the input to queries, keys and values
        self.to_qkv = nn.Linear(dim, attn_inner_dim * 3, bias = False)
        # Projection of the concatenated attention outputs back to the model dimension
        self.to_out = nn.Linear(attn_inner_dim + self.linear_attn.dim_hidden, dim)

        # Optional output gating
        self.gate_outputs = gate_outputs
        if gate_outputs:
            # Zero weights and a large positive bias make the gates start almost fully open
            gate_linear = nn.Linear(dim, 2 * heads)
            nn.init.zeros_(gate_linear.weight)
            nn.init.constant_(gate_linear.bias, gate_init_bias)

            self.to_output_gates = nn.Sequential(
                gate_linear,
                nn.Sigmoid(),
                Rearrange('b n (l h) -> l b h n 1', h = heads)
            )

        # Optional talking heads (1x1 convolution across heads)
        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None

        # Edge MLP, only built when edge features are provided
        self.edge_mlp = None
        has_edges = edge_dim > 0

        if has_edges:
            edge_input_dim = heads + edge_dim
            edge_hidden = edge_input_dim * edge_mlp_mult

            self.edge_mlp = nn.Sequential(
                nn.Linear(edge_input_dim, edge_hidden, bias = False),
                nn.GELU(),
                nn.Linear(edge_hidden, heads, bias = False)
            )

            # Coordinate MLP applied after the edge MLP
            self.coors_mlp = nn.Sequential(
                nn.GELU(),
                nn.Linear(heads, heads, bias = False)
            )
        else:
            # Coordinate MLP when there are no edge features
            self.coors_mlp = nn.Sequential(
                nn.Linear(heads, coors_hidden_dim, bias = False),
                nn.GELU(),
                nn.Linear(coors_hidden_dim, heads, bias = False)
            )

        # Coordinate gate
        self.coors_gate = nn.Linear(heads, heads)
        small_init_(self.coors_gate)

        # Optional cross-product branch
        self.use_cross_product = use_cross_product
        if use_cross_product:
            self.cross_coors_mlp = nn.Sequential(
                nn.Linear(heads, coors_hidden_dim, bias = False),
                nn.GELU(),
                nn.Linear(coors_hidden_dim, heads * 2, bias = False)
            )

            self.cross_coors_gate_i = nn.Linear(heads, heads)
            self.cross_coors_gate_j = nn.Linear(heads, heads)

            small_init_(self.cross_coors_gate_i)
            small_init_(self.cross_coors_gate_j)

        # Normalization of the relative coordinates
        self.norm_rel_coors = CoorsNorm(scale_init = norm_coors_scale_init) if norm_rel_coors else nn.Identity()

        # Learned weights that combine the per-head coordinate updates
        num_coors_combine_heads = (2 if use_cross_product else 1) * heads
        self.coors_combine = nn.Parameter(torch.randn(num_coors_combine_heads))

        # positional embedding
        # for both along the sequence (if specified by rel_pos_emb) and the relative distance between each residue / atom
        self.rel_pos_emb = rel_pos_emb

        # MLP that produces the dynamic position bias
        self.dynamic_pos_bias_mlp = DynamicPositionBias(
            dim = dim // 2,
            heads = heads,
            dim_head = dim_head,
            depth = 3,
            input_dim = (2 if rel_pos_emb else 1)
        )

        # Dropouts
        self.node_dropout = nn.Dropout(dropout)
        self.coor_dropout = nn.Dropout(dropout)

        # Initialization
        self.init_eps = init_eps
        self.apply(self.init_)
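    # Weight initialization applied to every submodule
    def init_(self, module):
        if type(module) in {nn.Linear}:
            # Initialize linear weights with a small standard deviation
            nn.init.normal_(module.weight, std = self.init_eps)

    # Forward pass (the method body is not included in this excerpt)
    def forward(
        self,
        feats,
        coors,
        edges = None,
        mask = None,
        adj_mat = None
    ):
        ...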
# Transformer block: attention followed by a feedforward
class Block(nn.Module):
    def __init__(self, attn, ff):
        super().__init__()
        self.attn = attn
        self.ff = ff

    # Takes and returns the tuple (feats, coors, mask, edges, adj_mat),
    # so that blocks can be chained by checkpoint_sequential
    def forward(self, inp, coor_changes = None):
        feats, coors, mask, edges, adj_mat = inp
        feats, coors = self.attn(feats, coors, edges = edges, mask = mask, adj_mat = adj_mat)
        feats, coors = self.ff(feats, coors)
        return (feats, coors, mask, edges, adj_mat)
# The E(n)-equivariant transformer
class EnTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens = None,
        rel_pos_emb = False,
        dim_head = 64,
        heads = 8,
        num_edge_tokens = None,
        edge_dim = 0,
        coors_hidden_dim = 16,
        neighbors = 0,
        only_sparse_neighbors = False,
        num_adj_degrees = None,
        adj_dim = 0,
        valid_neighbor_radius = float('inf'),
        init_eps = 1e-3,
        norm_rel_coors = True,
        norm_coors_scale_init = 1.,
        use_cross_product = False,
        talking_heads = False,
        checkpoint = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        num_global_linear_attn_heads = 0,
        gate_outputs = True
    ):
        super().__init__()
        # The dimension per head must be at least 32
        assert dim_head >= 32, 'your dimension per head should be greater than 32 for rotary embeddings to work well'
        # If given, the number of adjacency degrees must be at least 1
        assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is greater than 1'

        # With sparse neighbors only, default the number of adjacency degrees to 1
        if only_sparse_neighbors:
            num_adj_degrees = default(num_adj_degrees, 1)

        # Optional token and edge-type embeddings
        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
        self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None

        # Optional embedding of the adjacency degree
        self.num_adj_degrees = num_adj_degrees
        self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None
        adj_dim = adj_dim if exists(num_adj_degrees) else 0

        self.checkpoint = checkpoint
        self.layers = nn.ModuleList([])

        # Build one Block (attention + feedforward) per layer
        for ind in range(depth):
            self.layers.append(Block(
                Residual(EquivariantAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    coors_hidden_dim = coors_hidden_dim,
                    edge_dim = (edge_dim + adj_dim),
                    neighbors = neighbors,
                    only_sparse_neighbors = only_sparse_neighbors,
                    valid_neighbor_radius = valid_neighbor_radius,
                    init_eps = init_eps,
                    rel_pos_emb = rel_pos_emb,
                    norm_rel_coors = norm_rel_coors,
                    norm_coors_scale_init = norm_coors_scale_init,
                    use_cross_product = use_cross_product,
                    talking_heads = talking_heads,
                    dropout = attn_dropout,
                    num_global_linear_attn_heads = num_global_linear_attn_heads,
                    gate_outputs = gate_outputs
                )),
                Residual(FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                ))
            ))
    # Forward pass over features, coordinates, and optional edges, mask and adjacency matrix
    def forward(
        self,
        feats,
        coors,
        edges = None,
        mask = None,
        adj_mat = None,
        return_coor_changes = False,
        **kwargs
    ):
        # Batch size
        b = feats.shape[0]

        # Embed token ids if a token embedding was configured
        if exists(self.token_emb):
            feats = self.token_emb(feats)

        # Embed edge types if an edge embedding was configured
        if exists(self.edge_emb):
            assert exists(edges), 'edges must be passed in as (batch x seq x seq) indicating edge type'
            edges = self.edge_emb(edges)

        # An adjacency matrix may only be passed in when num_adj_degrees is set and positive
        assert not (exists(adj_mat) and (not exists(self.num_adj_degrees) or self.num_adj_degrees == 0)), 'num_adj_degrees must be greater than 0 if you are passing in an adjacency matrix'

        if exists(self.num_adj_degrees):
            assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            # Broadcast a 2D adjacency matrix across the batch
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

            # Degree labels start as the first-degree adjacency (0 or 1)
            adj_indices = adj_mat.clone().long()

            # Derive higher-degree neighbors, num_adj_degrees - 1 times
            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                # Pairs reachable after squaring the current adjacency matrix
                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                # Pairs that became reachable at this degree
                next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            # Embed the degree labels and append them to the edge features
            if exists(self.adj_emb):
                adj_emb = self.adj_emb(adj_indices)
                edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

        # Coordinate changes can only be returned in eval mode
        assert not (return_coor_changes and self.training), 'you must be eval mode in order to return coordinates'

        # go through layers

        coor_changes = [coors]
        inp = (feats, coors, mask, edges, adj_mat)

        # if in training mode and checkpointing is designated, use checkpointing across blocks to save memory

        if self.training and self.checkpoint:
            inp = checkpoint_sequential(self.layers, len(self.layers), inp)
        else:
            # iterate through blocks
            for layer in self.layers:
                inp = layer(inp)
                coor_changes.append(inp[1]) # append coordinates for visualization

        # return

        feats, coors, *_ = inp

        # Optionally return the coordinates after every block as well
        if return_coor_changes:
            return feats, coors, coor_changes

        return feats, coors
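The FeedForward class above sizes its hidden layer as int(dim * mult * 2 / 3) rather than dim * mult. A common rationale for gated feedforwards, which the source does not state and is therefore an assumption here, is that the first projection is twice as wide (value and gate), so shrinking the hidden size by 2/3 keeps the number of linear weights close to that of a plain feedforward. The sketch below checks that arithmetic; `ff_params` and `geglu_ff_params` are hypothetical helper names.

```python
def ff_params(dim, mult=4):
    """Linear weights of a plain feedforward: dim -> dim * mult -> dim (no biases)."""
    hidden = dim * mult
    return dim * hidden + hidden * dim

def geglu_ff_params(dim, mult=4):
    """Linear weights of the gated variant: dim -> 2 * inner (value and gate) -> dim (no biases)."""
    inner = int(dim * mult * 2 / 3)      # same formula as FeedForward.__init__
    return dim * (inner * 2) + inner * dim

for dim in (32, 128, 512):
    plain, gated = ff_params(dim), geglu_ff_params(dim)
    print(dim, plain, gated, round(gated / plain, 4))
```

The counts ignore the two LayerNorm modules, which add only a gain vector each.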
.\lucidrains\En-transformer\en_transformer\utils.py
import torch
from torch import sin, cos, atan2, acos

# Rotation matrix about the z axis by angle gamma
def rot_z(gamma):
    return torch.tensor([
        [cos(gamma), -sin(gamma), 0],
        [sin(gamma), cos(gamma), 0],
        [0, 0, 1]
    ], dtype = gamma.dtype)

# Rotation matrix about the y axis by angle beta
def rot_y(beta):
    return torch.tensor([
        [cos(beta), 0, sin(beta)],
        [0, 1, 0],
        [-sin(beta), 0, cos(beta)]
    ], dtype = beta.dtype)

# General 3D rotation composed from three angles as rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
def rot(alpha, beta, gamma):
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
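As a sanity check on the z-y-z composition used by rot, the product of the three matrices should be a proper rotation: orthogonal, with determinant 1. The following is a pure-Python re-implementation for illustration only (lists instead of tensors; `matmul`, `transpose` and `det` are helper names introduced here, not part of utils.py).

```python
import math

def rot_z(g):
    c, s = math.cos(g), math.sin(g)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def rot_y(b):
    c, s = math.cos(b), math.sin(b)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]

def rot(alpha, beta, gamma):
    # Same composition order as en_transformer.utils.rot
    return matmul(matmul(rot_z(alpha), rot_y(beta)), rot_z(gamma))

def transpose(m):
    return [list(row) for row in zip(*m)]

def det(m):
    return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
          - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
          + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))

R = rot(0.3, 1.1, -0.7)
RtR = matmul(transpose(R), R)
is_orthogonal = all(
    abs(RtR[i][j] - (1.0 if i == j else 0.0)) < 1e-12
    for i in range(3) for j in range(3)
)
print(is_orthogonal, abs(det(R) - 1.0) < 1e-12)
```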
.\lucidrains\En-transformer\en_transformer\__init__.py
# Re-export EquivariantAttention and EnTransformer from en_transformer.en_transformer
from en_transformer.en_transformer import EquivariantAttention, EnTransformer
E(n)-Equivariant Transformer
Implementation of E(n)-Equivariant Transformer, which extends the ideas from Welling's E(n)-Equivariant Graph Neural Network with attention mechanisms and ideas from transformer architecture.
Update: Used for designing CDR loops in antibodies!
Install
$ pip install En-transformer
Usage
import torch
from en_transformer import EnTransformer
model = EnTransformer(
dim = 512,
depth = 4, # depth
dim_head = 64, # dimension per head
heads = 8, # number of heads
edge_dim = 4, # dimension of edge feature
neighbors = 64, # only do attention between coordinates N nearest neighbors - set to 0 to turn off
talking_heads = True, # use Shazeer's talking heads https://arxiv.org/abs/2003.02436
checkpoint = True, # use checkpointing so one can increase depth at little memory cost (and increase neighbors attended to)
use_cross_product = True, # use cross product vectors (idea by @MattMcPartlon)
num_global_linear_attn_heads = 4 # if your number of neighbors above is low, you can assign a certain number of attention heads to weakly attend globally to all other nodes through linear attention (https://arxiv.org/abs/1812.01243)
)
feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)
mask = torch.ones(1, 1024).bool()
feats, coors = model(feats, coors, edges, mask = mask) # (1, 1024, 512), (1, 1024, 3)
Letting the network take care of both atomic and bond type embeddings
import torch
from en_transformer import EnTransformer
model = EnTransformer(
num_tokens = 10, # number of unique nodes, say atoms
rel_pos_emb = True, # set this to true if your sequence is not an unordered set. it will accelerate convergence
num_edge_tokens = 5, # number of unique edges, say bond types
dim = 128,
edge_dim = 16,
depth = 3,
heads = 4,
dim_head = 32,
neighbors = 8
)
atoms = torch.randint(0, 10, (1, 16)) # 10 different types of atoms
bonds = torch.randint(0, 5, (1, 16, 16)) # 5 different types of bonds (n x n)
coors = torch.randn(1, 16, 3) # atomic spatial coordinates
feats_out, coors_out = model(atoms, coors, edges = bonds) # (1, 16, 128), (1, 16, 3)
If you would like to only attend to sparse neighbors, as defined by an adjacency matrix (say for atoms), you have to set one more flag and then pass in the N x N adjacency matrix.
import torch
from en_transformer import EnTransformer
model = EnTransformer(
num_tokens = 10,
dim = 512,
depth = 1,
heads = 4,
dim_head = 32,
neighbors = 0,
only_sparse_neighbors = True, # must be set to true
num_adj_degrees = 2, # the number of degrees to derive from 1st degree neighbors passed in
adj_dim = 8 # whether to pass the adjacency degree information as an edge embedding
)
atoms = torch.randint(0, 10, (1, 16))
coors = torch.randn(1, 16, 3)
# naively assume a single chain of atoms
i = torch.arange(atoms.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
# adjacency matrix must be passed in
feats_out, coors_out = model(atoms, coors, adj_mat = adj_mat) # (1, 16, 512), (1, 16, 3)
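To make num_adj_degrees more concrete, here is a plain-Python sketch of how degree labels can be derived from a first-degree adjacency matrix. It mirrors the loop in EnTransformer.forward (square the boolean matrix, then label the pairs that became reachable) but uses nested lists instead of tensors; the function names are hypothetical, and it assumes the adjacency matrix includes self-loops, as the chain example above does.

```python
def chain_adjacency(n):
    """Boolean adjacency of a linear chain, self-loops included (|i - j| <= 1)."""
    return [[abs(i - j) <= 1 for j in range(n)] for i in range(n)]

def bool_matmul(a, b):
    """Boolean matrix product: entry (i, j) is True if some k links i to j."""
    n = len(a)
    return [[any(a[i][k] and b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]

def adjacency_degrees(adj, num_adj_degrees):
    """Repeatedly square the adjacency matrix and label the newly reachable pairs."""
    n = len(adj)
    indices = [[int(adj[i][j]) for j in range(n)] for i in range(n)]
    for degree in range(2, num_adj_degrees + 1):
        reachable = bool_matmul(adj, adj)
        for i in range(n):
            for j in range(n):
                if reachable[i][j] and not adj[i][j]:
                    indices[i][j] = degree
        adj = reachable
    return indices

labels = adjacency_degrees(chain_adjacency(5), num_adj_degrees=2)
for row in labels:
    print(row)
```

With num_adj_degrees = 2, each pair is labelled 1 for direct neighbors (and self), 2 for atoms two bonds apart, and 0 otherwise; these labels are what the adj_dim embedding turns into edge features.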
Edges
If you need to pass in continuous edges
import torch
from en_transformer import EnTransformer
from en_transformer.utils import rot
model = EnTransformer(
dim = 512,
depth = 1,
heads = 4,
dim_head = 32,
edge_dim = 4,
neighbors = 0,
only_sparse_neighbors = True
)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)
i = torch.arange(feats.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
feats1, coors1 = model(feats, coors, adj_mat = adj_mat, edges = edges)
Example
To run a protein backbone coordinate denoising toy task, first install sidechainnet
$ pip install sidechainnet
Then
$ python denoise.py
Todo
Citations
@misc{satorras2021en,
title = {E(n) Equivariant Graph Neural Networks},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
year = {2021},
eprint = {2102.09844},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{shazeer2020talkingheads,
title = {Talking-Heads Attention},
author = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
year = {2020},
eprint = {2003.02436},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{liu2021swin,
title = {Swin Transformer V2: Scaling Up Capacity and Resolution},
author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
year = {2021},
eprint = {2111.09883},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@inproceedings{Kim2020TheLC,
title = {The Lipschitz Constant of Self-Attention},
author = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
booktitle = {International Conference on Machine Learning},
year = {2020},
url = {https://api.semanticscholar.org/CorpusID:219530837}
}
@article {Mahajan2023.07.15.549154,
author = {Sai Pooja Mahajan and Jeffrey A. Ruffolo and Jeffrey J. Gray},
title = {Contextual protein and antibody encodings from equivariant graph transformers},
elocation-id = {2023.07.15.549154},
year = {2023},
doi = {10.1101/2023.07.15.549154},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2023/07/29/2023.07.15.549154},
eprint = {https://www.biorxiv.org/content/early/2023/07/29/2023.07.15.549154.full.pdf},
journal = {bioRxiv}
}
@article{Bondarenko2023QuantizableTR,
title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.12929},
url = {https://api.semanticscholar.org/CorpusID:259224568}
}
@inproceedings{Arora2023ZoologyMA,
title = {Zoology: Measuring and Improving Recall in Efficient Language Models},
author = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher R'e},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:266149332}
}
.\lucidrains\En-transformer\setup.py
# Import the setup helper and the package-discovery function
from setuptools import setup, find_packages

# Package metadata
setup(
  name = 'En-transformer',                 # package name
  packages = find_packages(),              # include every package that is found
  version = '1.6.5',
  license='MIT',
  description = 'E(n)-Equivariant Transformer',
  long_description_content_type = 'text/markdown',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/En-transformer',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'equivariance',
    'transformer'
  ],
  install_requires=[                       # runtime dependencies
    'einops>=0.3',
    'einx',
    'taylor-series-linear-attention>=0.1.4',
    'torch>=1.7'
  ],
  setup_requires=[                         # needed to run setup.py itself
    'pytest-runner',
  ],
  tests_require=[                          # needed to run the tests
    'pytest'
  ],
  classifiers=[                            # PyPI trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\En-transformer\tests\test_equivariance.py
import torch

from en_transformer.utils import rot
from en_transformer import EnTransformer

# Use float64 so that the equivariance checks can use a tight tolerance
torch.set_default_dtype(torch.float64)

# Smoke test based on the README example
def test_readme():
    model = EnTransformer(
        dim = 512,
        depth = 1,
        dim_head = 64,
        heads = 8,
        edge_dim = 4,
        neighbors = 6
    )

    # Random input features, coordinates and edges
    feats = torch.randn(1, 32, 512)
    coors = torch.randn(1, 32, 3)
    edges = torch.randn(1, 32, 1024, 4)

    mask = torch.ones(1, 32).bool()

    # Forward pass
    feats, coors = model(feats, coors, edges, mask = mask)
    assert True, 'it runs'
# Equivariance test
def test_equivariance():
    model = EnTransformer(
        dim = 512,
        depth = 1,
        edge_dim = 4,
        rel_pos_emb = True
    )

    # Random rotation matrix R and translation vector T
    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)

    # Random input features, coordinates and edges
    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 3)
    edges = torch.randn(1, 16, 16, 4)

    # Run the model on the transformed and on the original coordinates
    feats1, coors1 = model(feats, coors @ R + T, edges)
    feats2, coors2 = model(feats, coors, edges)

    # Features must be invariant to the transformation
    assert torch.allclose(feats1, feats2, atol = 1e-6), 'type 0 features are invariant'
    # Coordinates must be equivariant to the transformation
    assert torch.allclose(coors1, (coors2 @ R + T), atol = 1e-6), 'type 1 features are equivariant'
def test_equivariance_with_cross_product():
    model = EnTransformer(
        dim = 512,
        depth = 1,
        edge_dim = 4,
        rel_pos_emb = True,
        use_cross_product = True
    )

    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)

    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 3)
    edges = torch.randn(1, 16, 16, 4)

    feats1, coors1 = model(feats, coors @ R + T, edges)
    feats2, coors2 = model(feats, coors, edges)

    assert torch.allclose(feats1, feats2, atol = 1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors1, (coors2 @ R + T), atol = 1e-6), 'type 1 features are equivariant'

def test_equivariance_with_nearest_neighbors():
    model = EnTransformer(
        dim = 512,
        depth = 1,
        edge_dim = 4,
        neighbors = 5
    )

    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)

    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 3)
    edges = torch.randn(1, 16, 16, 4)

    feats1, coors1 = model(feats, coors @ R + T, edges)
    feats2, coors2 = model(feats, coors, edges)

    assert torch.allclose(feats1, feats2, atol = 1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors1, (coors2 @ R + T), atol = 1e-6), 'type 1 features are equivariant'

def test_equivariance_with_sparse_neighbors():
    model = EnTransformer(
        dim = 512,
        depth = 1,
        heads = 4,
        dim_head = 32,
        neighbors = 0,
        only_sparse_neighbors = True
    )

    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)

    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 3)

    i = torch.arange(feats.shape[1])
    adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

    feats1, coors1 = model(feats, coors @ R + T, adj_mat = adj_mat)
    feats2, coors2 = model(feats, coors, adj_mat = adj_mat)

    assert torch.allclose(feats1, feats2, atol = 1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors1, (coors2 @ R + T), atol = 1e-6), 'type 1 features are equivariant'

def test_depth():
    model = EnTransformer(
        dim = 8,
        depth = 12,
        edge_dim = 4,
        neighbors = 16
    )

    feats = torch.randn(1, 128, 8)
    coors = torch.randn(1, 128, 3)
    edges = torch.randn(1, 128, 128, 4)

    feats, coors = model(feats, coors, edges)

    assert not torch.any(torch.isnan(feats)), 'no NaN in features'
    assert not torch.any(torch.isnan(coors)), 'no NaN in coordinates'
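The equivariance tests above transform the inputs with `coors @ R + T` and expect the features to stay the same. One reason such invariance is attainable is that pairwise distances do not change under a rotation plus translation, and layers in the E(n)-equivariant family (as in the EGNN paper cited in the README) condition on relative geometry of that kind. A minimal pure-Python illustration, with hypothetical helper names and a single z-axis rotation for brevity:

```python
import math

def rot_z(g):
    c, s = math.cos(g), math.sin(g)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def transform(points, R, t):
    """Row-vector convention, matching `coors @ R + T` in the tests."""
    return [[sum(p[k] * R[k][j] for k in range(3)) + t[j] for j in range(3)] for p in points]

def pairwise_distances(points):
    return [[math.dist(p, q) for q in points] for p in points]

points = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [1.0, 1.0, 3.0]]
moved = transform(points, rot_z(0.9), [5.0, -2.0, 0.5])

before = pairwise_distances(points)
after = pairwise_distances(moved)
max_diff = max(abs(a - b) for ra, rb in zip(before, after) for a, b in zip(ra, rb))
print(max_diff < 1e-9)
```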
.\lucidrains\enformer-pytorch\enformer_pytorch\config_enformer.py
# Import the PretrainedConfig base class from transformers
from transformers import PretrainedConfig

# Configuration class for Enformer
class EnformerConfig(PretrainedConfig):
    model_type = "enformer"

    def __init__(
        self,
        dim = 1536,                 # model dimension
        depth = 11,                 # number of transformer layers
        heads = 8,                  # number of attention heads
        output_heads = dict(human = 5313, mouse= 1643),  # output tracks per species
        target_length = 896,        # target output length
        attn_dim_key = 64,          # attention key dimension
        dropout_rate = 0.4,         # dropout rate
        attn_dropout = 0.05,        # attention dropout
        pos_dropout = 0.01,         # positional dropout
        use_checkpointing = False,  # whether to use gradient checkpointing
        use_convnext = False,       # whether to use ConvNeXt blocks
        num_downsamples = 7,        # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
        dim_divisible_by = 128,     # dimensions are kept divisible by this value
        use_tf_gamma = False,       # whether to use the TensorFlow gamma values
        **kwargs,                   # forwarded to PretrainedConfig
    ):
        # Store the settings on the config object
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.output_heads = output_heads
        self.target_length = target_length
        self.attn_dim_key = attn_dim_key
        self.dropout_rate = dropout_rate
        self.attn_dropout = attn_dropout
        self.pos_dropout = pos_dropout
        self.use_checkpointing = use_checkpointing
        self.num_downsamples = num_downsamples
        self.dim_divisible_by = dim_divisible_by
        self.use_tf_gamma = use_tf_gamma

        # Let the parent class handle the remaining keyword arguments
        super().__init__(**kwargs)
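The num_downsamples and target_length defaults can be related with a little arithmetic. The sketch below assumes an input of 196,608 base pairs (the sequence length reported for Enformer, not something stated in this config file) and assumes the output is center-cropped to target_length; `output_bins` is a hypothetical helper, not part of enformer-pytorch.

```python
def output_bins(seq_len, num_downsamples=7, target_length=896):
    """Return (bins before cropping, bins trimmed from each side, base pairs per bin)."""
    bp_per_bin = 2 ** num_downsamples          # 2 ** 7 == 128 with the default config
    bins = seq_len // bp_per_bin
    trim = (bins - target_length) // 2         # assumes a symmetric center crop
    return bins, trim, bp_per_bin

bins, trim, bp_per_bin = output_bins(196_608)
print(bins, trim, bp_per_bin)
```

Lowering num_downsamples by one halves the base pairs per bin and doubles the number of bins for the same input length.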
.\lucidrains\enformer-pytorch\enformer_pytorch\data.py
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

import polars as pl
import numpy as np
from random import randrange, random
from pathlib import Path
from pyfaidx import Fasta

# helper functions

# True if the value is not None
def exists(val):
    return val is not None

# Return the input unchanged
def identity(t):
    return t

# Wrap the input in a list unless it already is one
def cast_list(t):
    return t if isinstance(t, list) else [t]

# Random boolean
def coin_flip():
    return random() > 0.5

# genomic function transforms

# Lookup table from ASCII code to nucleotide index ('.' marks padding and maps to -1)
seq_indices_embed = torch.zeros(256).long()
seq_indices_embed[ord('a')] = 0
seq_indices_embed[ord('c')] = 1
seq_indices_embed[ord('g')] = 2
seq_indices_embed[ord('t')] = 3
seq_indices_embed[ord('n')] = 4
seq_indices_embed[ord('A')] = 0
seq_indices_embed[ord('C')] = 1
seq_indices_embed[ord('G')] = 2
seq_indices_embed[ord('T')] = 3
seq_indices_embed[ord('N')] = 4
seq_indices_embed[ord('.')] = -1

# Lookup table from ASCII code to one-hot vector (N is all zeros, padding '.' is uniform 0.25)
one_hot_embed = torch.zeros(256, 4)
one_hot_embed[ord('a')] = torch.Tensor([1., 0., 0., 0.])
one_hot_embed[ord('c')] = torch.Tensor([0., 1., 0., 0.])
one_hot_embed[ord('g')] = torch.Tensor([0., 0., 1., 0.])
one_hot_embed[ord('t')] = torch.Tensor([0., 0., 0., 1.])
one_hot_embed[ord('n')] = torch.Tensor([0., 0., 0., 0.])
one_hot_embed[ord('A')] = torch.Tensor([1., 0., 0., 0.])
one_hot_embed[ord('C')] = torch.Tensor([0., 1., 0., 0.])
one_hot_embed[ord('G')] = torch.Tensor([0., 0., 1., 0.])
one_hot_embed[ord('T')] = torch.Tensor([0., 0., 0., 1.])
one_hot_embed[ord('N')] = torch.Tensor([0., 0., 0., 0.])
one_hot_embed[ord('.')] = torch.Tensor([0.25, 0.25, 0.25, 0.25])

# Index mapping for the complement: A <-> T, C <-> G, N stays N
reverse_complement_map = torch.Tensor([3, 2, 1, 0, 4]).long()

# Convert a string (or a list of strings) to a tensor of byte values
def torch_fromstring(seq_strs):
    batched = not isinstance(seq_strs, str)
    seq_strs = cast_list(seq_strs)
    np_seq_chrs = list(map(lambda t: np.fromstring(t, dtype = np.uint8), seq_strs))
    seq_chrs = list(map(torch.from_numpy, np_seq_chrs))
    return torch.stack(seq_chrs) if batched else seq_chrs[0]

# Convert a string to nucleotide indices
def str_to_seq_indices(seq_strs):
    seq_chrs = torch_fromstring(seq_strs)
    return seq_indices_embed[seq_chrs.long()]

# Convert a string to a one-hot encoding
def str_to_one_hot(seq_strs):
    seq_chrs = torch_fromstring(seq_strs)
    return one_hot_embed[seq_chrs.long()]

# Convert nucleotide indices to a one-hot encoding; padding positions become uniform 0.25
def seq_indices_to_one_hot(t, padding = -1):
    is_padding = t == padding
    t = t.clamp(min = 0)
    one_hot = F.one_hot(t, num_classes = 5)
    out = one_hot[..., :4].float()
    out = out.masked_fill(is_padding[..., None], 0.25)
    return out

# augmentations

# Reverse complement of nucleotide indices
def seq_indices_reverse_complement(seq_indices):
    complement = reverse_complement_map[seq_indices.long()]
    return torch.flip(complement, dims = (-1,))

# Reverse complement of a one-hot encoding: flip both the sequence and the channel axes
def one_hot_reverse_complement(one_hot):
    *_, n, d = one_hot.shape
    assert d == 4, 'must be one hot encoding with last dimension equal to 4'
    return torch.flip(one_hot, (-1, -2))
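The two reverse-complement helpers above work on different encodings but should agree with each other: with channels ordered A, C, G, T, reversing the channel axis swaps A with T and C with G, which is exactly what the index table [3, 2, 1, 0, 4] does. A small pure-Python check of that equivalence, using lists instead of tensors and hypothetical function names:

```python
COMPLEMENT = [3, 2, 1, 0, 4]   # A<->T, C<->G, N stays N (same table as reverse_complement_map)

def rc_indices(indices):
    """Reverse the sequence and complement each nucleotide index."""
    return [COMPLEMENT[i] for i in reversed(indices)]

def one_hot(indices):
    """One-hot over A, C, G, T; index 4 (N) becomes an all-zero row."""
    return [[1.0 if i == k else 0.0 for k in range(4)] for i in indices]

def rc_one_hot(rows):
    """Flip both the sequence axis and the channel axis."""
    return [list(reversed(row)) for row in reversed(rows)]

seq = [0, 0, 1, 2, 4, 3]       # A A C G N T
rc = rc_indices(seq)           # expected: A N C G T T
print(rc)
print(one_hot(rc) == rc_one_hot(one_hot(seq)))
```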
# processing bed files

# Fetches (optionally augmented) sequences for genomic intervals from a FASTA file
class FastaInterval():
    def __init__(
        self,
        *,
        fasta_file,
        context_length = None,
        return_seq_indices = False,
        shift_augs = None,
        rc_aug = False
    ):
        fasta_file = Path(fasta_file)
        assert fasta_file.exists(), 'path to fasta file must exist'

        self.seqs = Fasta(str(fasta_file))
        self.return_seq_indices = return_seq_indices
        self.context_length = context_length
        self.shift_augs = shift_augs
        self.rc_aug = rc_aug

    # Return the sequence for the interval [start, end) on the given chromosome
    def __call__(self, chr_name, start, end, return_augs = False):
        interval_length = end - start
        chromosome = self.seqs[chr_name]
        chromosome_length = len(chromosome)

        # Optional random shift augmentation
        if exists(self.shift_augs):
            min_shift, max_shift = self.shift_augs
            max_shift += 1

            # Clamp the shift range so the interval stays inside the chromosome
            min_shift = max(start + min_shift, 0) - start
            max_shift = min(end + max_shift, chromosome_length) - end

            rand_shift = randrange(min_shift, max_shift)
            start += rand_shift
            end += rand_shift

        left_padding = right_padding = 0

        # Widen the interval symmetrically if it is shorter than the context length
        if exists(self.context_length) and interval_length < self.context_length:
            extra_seq = self.context_length - interval_length

            extra_left_seq = extra_seq // 2
            extra_right_seq = extra_seq - extra_left_seq

            start -= extra_left_seq
            end += extra_right_seq

        # Overflow past the left end of the chromosome becomes left padding
        if start < 0:
            left_padding = -start
            start = 0

        # Overflow past the right end of the chromosome becomes right padding
        if end > chromosome_length:
            right_padding = end - chromosome_length
            end = chromosome_length

        # Fetch the sequence and pad it with '.' characters
        seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding)

        # Decide whether to apply the reverse-complement augmentation
        should_rc_aug = self.rc_aug and coin_flip()

        # Return nucleotide indices if requested
        if self.return_seq_indices:
            seq = str_to_seq_indices(seq)

            if should_rc_aug:
                seq = seq_indices_reverse_complement(seq)

            return seq

        # Otherwise return a one-hot encoding
        one_hot = str_to_one_hot(seq)

        if should_rc_aug:
            one_hot = one_hot_reverse_complement(one_hot)

        if not return_augs:
            return one_hot

        # returns the shift integer as well as the bool (for whether reverse complement was activated)

        rand_shift_tensor = torch.tensor([rand_shift])
        rand_aug_bool_tensor = torch.tensor([should_rc_aug])

        return one_hot, rand_shift_tensor, rand_aug_bool_tensor
# GenomeIntervalDataset: a Dataset that yields one sequence per interval listed in a .bed file
class GenomeIntervalDataset(Dataset):
    def __init__(
        self,
        bed_file,
        fasta_file,
        filter_df_fn = identity,
        chr_bed_to_fasta_map = dict(),
        context_length = None,
        return_seq_indices = False,
        shift_augs = None,
        rc_aug = False,
        return_augs = False
    ):
        super().__init__()
        # the .bed file must exist
        bed_path = Path(bed_file)
        assert bed_path.exists(), 'path to .bed file must exist'

        # read the tab-separated .bed file into a DataFrame, then apply the user filter
        df = pl.read_csv(str(bed_path), separator = '\t', has_header = False)
        df = filter_df_fn(df)
        self.df = df

        # if chromosome names in the .bed file differ from the keys in the fasta file,
        # they can be remapped on the fly
        self.chr_bed_to_fasta_map = chr_bed_to_fasta_map

        # FastaInterval does the actual sequence lookup, padding and augmentation
        self.fasta = FastaInterval(
            fasta_file = fasta_file,
            context_length = context_length,
            return_seq_indices = return_seq_indices,
            shift_augs = shift_augs,
            rc_aug = rc_aug
        )

        # whether to also return the augmentation info
        self.return_augs = return_augs

    # number of intervals in the dataset
    def __len__(self):
        return len(self.df)

    # fetch the sequence for the interval at the given index
    def __getitem__(self, ind):
        interval = self.df.row(ind)
        # the first three columns are chromosome name, start and end
        chr_name, start, end = (interval[0], interval[1], interval[2])
        # remap the chromosome name if a mapping was given
        chr_name = self.chr_bed_to_fasta_map.get(chr_name, chr_name)
        return self.fasta(chr_name, start, end, return_augs = self.return_augs)
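The context expansion and boundary padding performed in `FastaInterval.__call__` can be tried without any genome files. The following is a minimal sketch in plain Python: `expand_and_pad` is a hypothetical helper written for illustration (it is not part of enformer-pytorch), and an ordinary string stands in for the chromosome.

```python
def expand_and_pad(chromosome, start, end, context_length, pad_char = '.'):
    """Widen [start, end) to context_length and pad where the window leaves the chromosome."""
    interval_length = end - start
    left_padding = right_padding = 0

    # widen symmetrically; the right side receives the odd extra base
    if interval_length < context_length:
        extra = context_length - interval_length
        extra_left = extra // 2
        extra_right = extra - extra_left
        start -= extra_left
        end += extra_right

    # clip at the left boundary
    if start < 0:
        left_padding = -start
        start = 0

    # clip at the right boundary
    if end > len(chromosome):
        right_padding = end - len(chromosome)
        end = len(chromosome)

    return pad_char * left_padding + chromosome[start:end] + pad_char * right_padding

chromosome = 'ACGTACGTAC'                           # toy 10-base "chromosome"
inner = expand_and_pad(chromosome, 4, 6, 6)         # window fits inside
left_edge = expand_and_pad(chromosome, 0, 2, 6)     # window runs off the left end
right_edge = expand_and_pad(chromosome, 8, 10, 6)   # window runs off the right end
print(inner, left_edge, right_edge)
```

Every result has exactly `context_length` characters, which is what lets the dataset batch intervals that sit near chromosome ends.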
.\lucidrains\enformer-pytorch\enformer_pytorch\finetune.py
# imports
import torch
from typing import Optional
from copy import deepcopy
from contextlib import contextmanager
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# the Enformer model and its Poisson loss
from enformer_pytorch.modeling_enformer import Enformer, poisson_loss
# discrete key / value bottleneck used by the wrappers below
from discrete_key_value_bottleneck_pytorch import DiscreteKeyValueBottleneck

# helper: whether a value is not None
def exists(val):
    return val is not None

# helper: return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# a context manager that does nothing
@contextmanager
def null_context():
    yield

# better sequential: an nn.Sequential that skips modules that are None
def Sequential(*modules):
    return nn.Sequential(*filter(exists, modules))
# controlling which layers are frozen

# set requires_grad on every parameter of a module
def set_module_requires_grad_(module, requires_grad):
    for param in module.parameters():
        param.requires_grad = requires_grad

# freeze all layers
def freeze_all_layers_(module):
    set_module_requires_grad_(module, False)

# unfreeze all layers
def unfreeze_all_layers_(module):
    set_module_requires_grad_(module, True)

# freeze the batchnorm layers: eval mode, no running-stat updates, no gradients
def freeze_batchnorms_(model):
    bns = [m for m in model.modules() if isinstance(m, nn.BatchNorm1d)]

    for bn in bns:
        bn.eval()
        bn.track_running_stats = False
        set_module_requires_grad_(bn, False)

# freeze everything except the layernorms
def freeze_all_but_layernorms_(model):
    for m in model.modules():
        set_module_requires_grad_(m, isinstance(m, nn.LayerNorm))

# freeze everything except the last N transformer blocks
def freeze_all_but_last_n_layers_(enformer, n):
    assert isinstance(enformer, Enformer)
    freeze_all_layers_(enformer)

    transformer_blocks = enformer.transformer

    for module in transformer_blocks[-n:]:
        set_module_requires_grad_(module, True)
# get the Enformer embeddings, with optional freezing of all or part of the model
def get_enformer_embeddings(
    model,
    seq,
    freeze = False,
    train_layernorms_only = False,
    train_last_n_layers_only = None,
    enformer_kwargs: dict = {}
):
    # batchnorms are always frozen during fine-tuning
    freeze_batchnorms_(model)

    if train_layernorms_only:
        assert not freeze, 'you set the intent to train the layernorms of the enformer, yet also indicated you wanted to freeze the entire model'
        freeze_all_but_layernorms_(model)

    if exists(train_last_n_layers_only):
        assert not freeze, 'you set the intent to train last N layers of enformer, but also indicated you wanted to freeze the entire network'
        freeze_all_but_last_n_layers_(model, train_last_n_layers_only)

    # run under no_grad when the whole model is frozen
    enformer_context = null_context() if not freeze else torch.no_grad()

    with enformer_context:
        embeddings = model(seq, return_only_embeddings = True, **enformer_kwargs)

        if freeze:
            embeddings.detach_()

    return embeddings
# fine-tune wrapper classes

# extra head projection, akin to how the human and mouse tracks were trained
class HeadAdapterWrapper(nn.Module):
    def __init__(
        self,
        *,
        enformer,
        num_tracks,
        post_transformer_embed = False, # whether to take the embeddings right after the transformer instead of after the final pointwise convolution - this adds another layernorm
        discrete_key_value_bottleneck = False,
        bottleneck_num_memories = 256,
        bottleneck_num_codebooks = 4,
        bottleneck_decay = 0.9,
        transformer_embed_fn: nn.Module = nn.Identity(),
        output_activation: Optional[nn.Module] = nn.Softplus(),
        auto_set_target_length = True
    ):
        super().__init__()
        assert isinstance(enformer, Enformer)

        # the final pointwise convolution doubles the dimension, so multiply by 2
        # unless the embeddings are taken right after the transformer
        enformer_hidden_dim = enformer.dim * (2 if not post_transformer_embed else 1)

        self.discrete_key_value_bottleneck = discrete_key_value_bottleneck

        # optionally wrap the Enformer in a discrete key / value bottleneck
        if discrete_key_value_bottleneck:
            enformer = DiscreteKeyValueBottleneck(
                encoder = enformer,
                dim = enformer_hidden_dim,
                num_memory_codebooks = bottleneck_num_codebooks,
                num_memories = bottleneck_num_memories,
                dim_memory = enformer_hidden_dim // bottleneck_num_codebooks,
                decay = bottleneck_decay,
            )

        self.post_transformer_embed = post_transformer_embed

        self.enformer = enformer

        self.auto_set_target_length = auto_set_target_length

        if post_transformer_embed:
            # work on a deep copy so the original model is left untouched
            self.enformer = deepcopy(enformer)
            # replace the last trunk module and the final pointwise block with identities
            self.enformer._trunk[-1] = nn.Identity()
            self.enformer.final_pointwise = nn.Identity()

        # optional transform of the embeddings, plus a layernorm when post_transformer_embed is set
        self.post_embed_transform = Sequential(
            transformer_embed_fn,
            nn.LayerNorm(enformer_hidden_dim) if post_transformer_embed else None
        )

        # projection from the embeddings to the tracks
        self.to_tracks = Sequential(
            nn.Linear(enformer_hidden_dim, num_tracks),
            output_activation
        )

    def forward(
        self,
        seq,
        *,
        target = None,
        freeze_enformer = False,
        finetune_enformer_ln_only = False,
        finetune_last_n_layers_only = None
    ):
        enformer_kwargs = dict()

        # take the target length from the target tensor if one was passed in
        if exists(target) and self.auto_set_target_length:
            enformer_kwargs = dict(target_length = target.shape[-2])

        if self.discrete_key_value_bottleneck:
            # the bottleneck wraps the Enformer and is called directly
            embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs)
        else:
            embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs)

        # project the embeddings to track predictions
        preds = self.to_tracks(embeddings)

        # without a target, return the predictions
        if not exists(target):
            return preds

        # otherwise return the Poisson loss
        return poisson_loss(preds, target)
# wrapper that allows one to supply a context vector per track
# the context embedding is projected into the weights and bias of the head linear projection (a hypernetwork)
class ContextAdapterWrapper(nn.Module):
    def __init__(
        self,
        *,
        enformer,                                   # the Enformer model
        context_dim,                                # dimension of the context vectors
        discrete_key_value_bottleneck = False,      # whether to use a discrete key / value bottleneck
        bottleneck_num_memories = 256,              # number of memories in the bottleneck
        bottleneck_num_codebooks = 4,               # number of codebooks in the bottleneck
        bottleneck_decay = 0.9,                     # decay rate of the bottleneck
        auto_set_target_length = True,              # whether to set the target length from the target
        output_activation: Optional[nn.Module] = nn.Softplus()  # output activation, Softplus by default
    ):
        super().__init__()
        assert isinstance(enformer, Enformer)
        enformer_hidden_dim = enformer.dim * 2

        self.discrete_key_value_bottleneck = discrete_key_value_bottleneck

        if discrete_key_value_bottleneck:
            enformer = DiscreteKeyValueBottleneck(
                encoder = enformer,
                dim = enformer_hidden_dim,
                num_memory_codebooks = bottleneck_num_codebooks,
                num_memories = bottleneck_num_memories,
                dim_memory = enformer_hidden_dim // bottleneck_num_codebooks,
                decay = bottleneck_decay,
            )

        self.enformer = enformer

        self.auto_set_target_length = auto_set_target_length

        # parameters that map a context vector to the head weights and bias
        self.to_context_weights = nn.Parameter(torch.randn(context_dim, enformer_hidden_dim))
        self.to_context_bias = nn.Parameter(torch.randn(context_dim))

        self.activation = default(output_activation, nn.Identity())

    def forward(
        self,
        seq,                                 # input sequence
        *,
        context,                             # one context vector per track
        target = None,                       # optional target
        freeze_enformer = False,             # whether to freeze the Enformer
        finetune_enformer_ln_only = False,   # whether to fine-tune only the Enformer layernorms
        finetune_last_n_layers_only = None   # fine-tune only the last n layers
    ):
        enformer_kwargs = dict()

        if exists(target) and self.auto_set_target_length:
            enformer_kwargs = dict(target_length = target.shape[-2])

        if self.discrete_key_value_bottleneck:
            embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs)
        else:
            embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs)

        # per-track head weights and bias derived from the context
        weights = einsum('t d, d e -> t e', context, self.to_context_weights)
        bias = einsum('t d, d -> t', context, self.to_context_bias)

        # apply the per-track linear heads to the embeddings
        pred = einsum('b n d, t d -> b n t', embeddings, weights) + bias

        pred = self.activation(pred)

        if not exists(target):
            return pred

        return poisson_loss(pred, target)
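The three einsums in `ContextAdapterWrapper.forward` amount to generating one linear head per track from that track's context vector. The following is a minimal sketch of that arithmetic in plain Python for a single unbatched sequence: `context_head` is a hypothetical name, the tiny matrices are made up for illustration, and the output activation is left out.

```python
def context_head(embeddings, context, to_weights, to_bias):
    """
    embeddings: n x e   (one row per sequence position)
    context:    t x d   (one context vector per track)
    to_weights: d x e,  to_bias: d
    returns     n x t   predictions
    """
    # 't d, d e -> t e': each track's head weights are a mix of the rows of to_weights
    weights = [
        [sum(c * to_weights[d][e] for d, c in enumerate(ctx)) for e in range(len(to_weights[0]))]
        for ctx in context
    ]
    # 't d, d -> t': each track's bias
    bias = [sum(c * b for c, b in zip(ctx, to_bias)) for ctx in context]
    # 'n e, t e -> n t', plus the bias
    return [
        [sum(x * w for x, w in zip(emb, w_row)) + bias[t] for t, w_row in enumerate(weights)]
        for emb in embeddings
    ]

embeddings = [[1.0, 2.0, 3.0]]              # one position, hidden dimension 3
context = [[1.0, 0.0], [0.0, 1.0]]          # two tracks with one-hot contexts
to_weights = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
to_bias = [0.5, -1.0]

pred = context_head(embeddings, context, to_weights, to_bias)
print(pred)
```

With one-hot contexts each track simply selects one row of `to_weights` as its head; dense context vectors let related tracks share structure in their heads.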
# wrapper that does attention aggregation of the context, which can be a list of tokens (batch x seq x dim)
class ContextAttentionAdapterWrapper(nn.Module):
    def __init__(
        self,
        *,
        enformer,                                   # the Enformer model
        context_dim,                                # dimension of the context tokens
        heads = 8,                                  # number of attention heads
        dim_head = 64,                              # dimension per head
        discrete_key_value_bottleneck = False,      # whether to use a discrete key / value bottleneck
        bottleneck_num_memories = 256,              # number of memories in the bottleneck
        bottleneck_num_codebooks = 4,               # number of codebooks in the bottleneck
        bottleneck_decay = 0.9,                     # decay rate of the bottleneck
        auto_set_target_length = True,              # whether to set the target length from the target
        output_activation: Optional[nn.Module] = nn.Softplus()  # output activation, Softplus by default
    ):
        super().__init__()
        assert isinstance(enformer, Enformer)
        enformer_hidden_dim = enformer.dim * 2

        self.discrete_key_value_bottleneck = discrete_key_value_bottleneck

        if discrete_key_value_bottleneck:
            enformer = DiscreteKeyValueBottleneck(
                encoder = enformer,
                dim = enformer_hidden_dim,
                num_memory_codebooks = bottleneck_num_codebooks,
                num_memories = bottleneck_num_memories,
                dim_memory = enformer_hidden_dim // bottleneck_num_codebooks,
                decay = bottleneck_decay,
            )

        self.enformer = enformer

        self.auto_set_target_length = auto_set_target_length

        # layernorms for the queries (embeddings) and the keys / values (context)
        self.query_norm = nn.LayerNorm(enformer_hidden_dim)
        self.key_values_norm = nn.LayerNorm(context_dim)

        # attention scale and head count
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head

        # projection to queries
        self.to_queries = nn.Linear(enformer_hidden_dim, inner_dim, bias = False)

        # learned null key / value, so attention always has something to attend to
        self.null_key = nn.Parameter(torch.randn(inner_dim))
        self.null_value = nn.Parameter(torch.randn(inner_dim))

        # projection to keys and values
        self.to_key_values = nn.Linear(context_dim, inner_dim * 2, bias = False)

        # output projection
        self.to_out = nn.Linear(inner_dim, enformer_hidden_dim)

        # projection to the prediction
        self.to_pred = Sequential(
            nn.Linear(enformer_hidden_dim, 1),
            Rearrange('b c ... 1 -> b ... c'),
            output_activation
        )

    def forward(
        self,
        seq,
        *,
        context,
        context_mask = None,
        target = None,
        freeze_enformer = False,
        finetune_enformer_ln_only = False,
        finetune_last_n_layers_only = None
    ):
        """
        b - batch
        n - sequence length
        c - number of contexts (tracks)
        d - dimension
        i - sequence length (query embeddings)
        j - sequence length (keys / values contexts)
        h - attention heads
        """

        h = self.heads

        enformer_kwargs = dict()

        # take the target length from the second-to-last dimension of the target, if given
        if exists(target) and self.auto_set_target_length:
            enformer_kwargs = dict(target_length = target.shape[-2])

        # call the bottleneck-wrapped Enformer directly, otherwise go through get_enformer_embeddings
        if self.discrete_key_value_bottleneck:
            embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs)
        else:
            embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs)

        # perform cross attention from genetic -> context

        # a 2-dimensional context is treated as a single token per context
        if context.ndim == 2:
            context = rearrange(context, 'b d -> b 1 d')

        # queries, keys and values
        q = self.to_queries(self.query_norm(embeddings))
        k, v = self.to_key_values(self.key_values_norm(context)).chunk(2, dim = -1)

        # repeat the null key / value across the contexts and prepend them
        null_k, null_v = map(lambda t: repeat(t, 'd -> b 1 d', b = context.shape[0]), (self.null_key, self.null_value))

        k = torch.cat((null_k, k), dim = 1)
        v = torch.cat((null_v, v), dim = 1)

        # split out heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        sim = einsum('b h i d, c h j d -> b c h i j', q, k) * self.scale

        # masking
        if exists(context_mask):
            context_mask = F.pad(context_mask, (1, 0), value = True)
            context_mask = rearrange(context_mask, 'b j -> b 1 1 1 j')
            sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim = -1)

        # aggregate
        out = einsum('b c h i j, c h j d -> b c h i d', attn, v)
        out = rearrange(out, 'b c h n d -> b c n (h d)', h = h)

        # combine heads
        branch_out = self.to_out(out)

        # residual
        embeddings = embeddings + branch_out

        # to prediction
        pred = self.to_pred(embeddings)

        # without a target return the prediction, otherwise the Poisson loss
        if not exists(target):
            return pred

        return poisson_loss(pred, target)
.\lucidrains\enformer-pytorch\enformer_pytorch\metrics.py
from torchmetrics import Metric
from typing import Optional
import torch

# custom Metric that computes the Pearson correlation coefficient for each channel
class MeanPearsonCorrCoefPerChannel(Metric):
    # the metric is not differentiable
    is_differentiable: Optional[bool] = False
    # higher values are better
    higher_is_better: Optional[bool] = True

    def __init__(self, n_channels:int, dist_sync_on_step=False):
        """Calculates the mean pearson correlation across channels aggregated over regions"""
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # reduce over the batch and sequence dimensions, keeping the channel dimension
        self.reduce_dims=(0, 1)
        # running sums per channel: product, target, target squared, prediction, prediction squared, count
        self.add_state("product", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("true", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("true_squared", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("pred", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("pred_squared", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("count", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # predictions and targets must have the same shape
        assert preds.shape == target.shape

        # accumulate the running sums
        self.product += torch.sum(preds * target, dim=self.reduce_dims)
        self.true += torch.sum(target, dim=self.reduce_dims)
        self.true_squared += torch.sum(torch.square(target), dim=self.reduce_dims)
        self.pred += torch.sum(preds, dim=self.reduce_dims)
        self.pred_squared += torch.sum(torch.square(preds), dim=self.reduce_dims)
        self.count += torch.sum(torch.ones_like(target), dim=self.reduce_dims)

    def compute(self):
        # per-channel means of the targets and predictions
        true_mean = self.true / self.count
        pred_mean = self.pred / self.count

        # (unnormalized) covariance and variances, recovered from the running sums
        covariance = (self.product
                    - true_mean * self.pred
                    - pred_mean * self.true
                    + self.count * true_mean * pred_mean)

        true_var = self.true_squared - self.count * torch.square(true_mean)
        pred_var = self.pred_squared - self.count * torch.square(pred_mean)
        # product of the two standard deviations, then the correlation
        tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var)
        correlation = covariance / tp_var
        return correlation
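The metric above never stores predictions; it only keeps six running sums per channel and recovers Pearson's r from them in `compute`. The following is a minimal single-channel sketch in plain Python that checks this identity against the textbook formula. `RunningPearson` and `pearson_direct` are hypothetical names used only for this illustration.

```python
import math

class RunningPearson:
    """Single-channel version of the running-sum bookkeeping."""
    def __init__(self):
        self.product = self.true = self.true_squared = 0.0
        self.pred = self.pred_squared = self.count = 0.0

    def update(self, preds, target):
        for p, t in zip(preds, target):
            self.product += p * t
            self.true += t
            self.true_squared += t * t
            self.pred += p
            self.pred_squared += p * p
            self.count += 1

    def compute(self):
        true_mean = self.true / self.count
        pred_mean = self.pred / self.count
        covariance = (self.product
                      - true_mean * self.pred
                      - pred_mean * self.true
                      + self.count * true_mean * pred_mean)
        true_var = self.true_squared - self.count * true_mean ** 2
        pred_var = self.pred_squared - self.count * pred_mean ** 2
        return covariance / (math.sqrt(true_var) * math.sqrt(pred_var))

def pearson_direct(xs, ys):
    """Textbook Pearson r over the full data."""
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(vx * vy)

preds = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
target = [2.0, 1.0, 4.0, 3.0, 6.0, 5.0]

metric = RunningPearson()
metric.update(preds[:3], target[:3])   # first batch
metric.update(preds[3:], target[3:])   # second batch
streamed = metric.compute()
direct = pearson_direct(preds, target)
print(streamed, direct)
```

Because only sums are stored, the states can be added across batches and across processes, which is why each state is registered with `dist_reduce_fx="sum"`.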
.\lucidrains\enformer-pytorch\enformer_pytorch\modeling_enformer.py
# imports
import math
from pathlib import Path
import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.checkpoint import checkpoint_sequential
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from enformer_pytorch.data import str_to_one_hot, seq_indices_to_one_hot
from enformer_pytorch.config_enformer import EnformerConfig
from transformers import PreTrainedModel

# constants
SEQUENCE_LENGTH = 196_608
TARGET_LENGTH = 896

# gamma positions precomputed from TensorFlow
# addresses a difference in xlogy results between TensorFlow and PyTorch
# solution came from @johahi
DIR = Path(__file__).parents[0]
TF_GAMMAS = torch.load(str(DIR / "precomputed" / "tf_gammas.pt"))

# helpers

# whether a value is not None
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# return a function that always returns the given value
def always(val):
    def inner(*args, **kwargs):
        return val
    return inner

# apply a function to every value of a dictionary
def map_values(fn, d):
    return {key: fn(values) for key, values in d.items()}

# integers spaced exponentially from start to end, each rounded to a multiple of divisible_by
def exponential_linspace_int(start, end, num, divisible_by = 1):
    def _round(x):
        return int(round(x / divisible_by) * divisible_by)

    base = math.exp(math.log(end / start) / (num - 1))
    return [_round(start * base**i) for i in range(num)]
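To see what `exponential_linspace_int` produces for the convolutional tower, the function can be run on its own. The following worked example copies the function as shown and assumes `dim = 1536` (the value used in the README), together with `num_downsamples = 7` and `dim_divisible_by = 128`; the last two are assumptions here, since the config file is not part of this section.

```python
import math

def exponential_linspace_int(start, end, num, divisible_by = 1):
    def _round(x):
        return int(round(x / divisible_by) * divisible_by)

    base = math.exp(math.log(end / start) / (num - 1))
    return [_round(start * base**i) for i in range(num)]

dim = 1536                 # from the README example
num_downsamples = 7        # assumed config value
dim_divisible_by = 128     # assumed config value

half_dim = dim // 2
filter_list = exponential_linspace_int(
    half_dim, dim,
    num = num_downsamples - 1,
    divisible_by = dim_divisible_by
)
print(filter_list)
```

Under these assumptions the channel count grows geometrically from `dim // 2` to `dim` across the tower, snapped to multiples of 128.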
# logarithm with the input clamped from below, to avoid log(0)
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# use SyncBatchNorm when training is distributed, otherwise BatchNorm1d
def MaybeSyncBatchnorm(is_distributed = None):
    is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d

# losses and metrics

# Poisson loss
def poisson_loss(pred, target):
    return (pred - target * log(pred)).mean()

# Pearson correlation coefficient, as the cosine similarity of the centered inputs
def pearson_corr_coef(x, y, dim = 1, reduce_dims = (-1,)):
    x_centered = x - x.mean(dim = dim, keepdim = True)
    y_centered = y - y.mean(dim = dim, keepdim = True)
    return F.cosine_similarity(x_centered, y_centered, dim = dim).mean(dim = reduce_dims)
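The expression `pred - target * log(pred)` is the Poisson negative log-likelihood with the term that depends only on the target dropped, and for a fixed positive target it is smallest when the prediction equals the target. The following is a minimal numeric check in plain Python; `poisson_nll` is a hypothetical stand-in for the tensor version above, and the numbers are made up.

```python
import math

def poisson_nll(pred, target, eps = 1e-20):
    """Mean of pred - target * log(pred), with pred clamped from below like the helper above."""
    terms = [p - t * math.log(max(p, eps)) for p, t in zip(pred, target)]
    return sum(terms) / len(terms)

target = [3.0, 1.0, 7.0]

loss_exact = poisson_nll([3.0, 1.0, 7.0], target)   # predictions equal to the targets
loss_over = poisson_nll([4.0, 2.0, 8.0], target)    # predictions too high
loss_under = poisson_nll([2.0, 0.5, 6.0], target)   # predictions too low
print(loss_exact, loss_over, loss_under)
```

The loss value itself can be negative, so it is only meaningful relative to other predictions for the same targets.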
# relative positional encoding functions

# exponentially decaying positional features, one per half-life
def get_positional_features_exponential(positions, features, seq_len, min_half_life = 3., dtype = torch.float):
    max_range = math.log(seq_len) / math.log(2.)
    half_life = 2 ** torch.linspace(min_half_life, max_range, features, device = positions.device)
    half_life = half_life[None, ...]
    positions = positions.abs()[..., None]
    return torch.exp(-math.log(2.) / half_life * positions)
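Each exponential feature decays with its own half-life, with the half-lives spaced evenly in log2 space between `2 ** min_half_life` and `seq_len`. The following is a minimal sketch for a single position in plain Python; `exponential_features` is a hypothetical scalar rewrite of the tensor function above, and the feature count of 4 is arbitrary.

```python
import math

def exponential_features(position, num_features, seq_len, min_half_life = 3.0):
    max_range = math.log(seq_len) / math.log(2.0)
    # evenly spaced exponents from min_half_life to log2(seq_len)
    if num_features == 1:
        exponents = [min_half_life]
    else:
        step = (max_range - min_half_life) / (num_features - 1)
        exponents = [min_half_life + i * step for i in range(num_features)]
    half_lives = [2 ** e for e in exponents]
    # each feature halves every half_life positions away from the query
    return [math.exp(-math.log(2.0) / h * abs(position)) for h in half_lives]

at_zero = exponential_features(0, 4, 1536)
at_eight = exponential_features(8, 4, 1536)
at_minus_eight = exponential_features(-8, 4, 1536)
print(at_eight)
```

With `min_half_life = 3` the shortest half-life is 8 positions, so the first feature is 0.5 at a distance of 8, while the features with longer half-lives have barely decayed.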
# central mask positional features: 1 inside a centered window of growing width, 0 outside
def get_positional_features_central_mask(positions, features, seq_len, dtype = torch.float):
    center_widths = 2 ** torch.arange(1, features + 1, device = positions.device).to(dtype)
    center_widths = center_widths - 1
    return (center_widths[None, ...] > positions.abs()[..., None]).to(dtype)

# probability density function of the gamma distribution
def gamma_pdf(x, concentration, rate):
    log_unnormalized_prob = torch.xlogy(concentration - 1., x) - rate * x
    log_normalization = (torch.lgamma(concentration) - concentration * torch.log(rate))
    return torch.exp(log_unnormalized_prob - log_normalization)

# gamma positional features
def get_positional_features_gamma(positions, features, seq_len, stddev = None, start_mean = None, eps = 1e-8, dtype = torch.float):
    if not exists(stddev):
        stddev = seq_len / (2 * features)

    if not exists(start_mean):
        start_mean = seq_len / features

    mean = torch.linspace(start_mean, seq_len, features, device = positions.device)

    mean = mean[None, ...]
    # convert (mean, stddev) into the (concentration, rate) parameterization
    concentration = (mean / stddev) ** 2
    rate = mean / stddev ** 2

    probabilities = gamma_pdf(positions.to(dtype).abs()[..., None], concentration, rate)
    probabilities = probabilities + eps

    # normalize by the maximum over the last dimension
    outputs = probabilities / torch.amax(probabilities, dim = -1, keepdim = True)
    return outputs
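The gamma features are specified by a mean and a standard deviation but evaluated with a concentration and a rate. The following is a minimal check in plain Python that the conversion used above preserves both moments. `gamma_params` and `gamma_pdf_scalar` are hypothetical scalar helpers, and the choice of 16 features over 1536 positions is made up for illustration.

```python
import math

def gamma_params(mean, stddev):
    """(mean, stddev) -> (concentration, rate), as in get_positional_features_gamma."""
    concentration = (mean / stddev) ** 2
    rate = mean / stddev ** 2
    return concentration, rate

def gamma_pdf_scalar(x, concentration, rate):
    """Gamma density for x > 0, computed in log space like gamma_pdf."""
    log_unnormalized = (concentration - 1.0) * math.log(x) - rate * x
    log_normalization = math.lgamma(concentration) - concentration * math.log(rate)
    return math.exp(log_unnormalized - log_normalization)

seq_len, features = 1536, 16
stddev = seq_len / (2 * features)    # 48.0
start_mean = seq_len / features      # 96.0

concentration, rate = gamma_params(start_mean, stddev)
recovered_mean = concentration / rate        # mean of a gamma distribution
recovered_var = concentration / rate ** 2    # variance of a gamma distribution
mode = (concentration - 1.0) / rate          # where the density peaks (concentration > 1)
print(concentration, rate, mode)
```

For the first feature the density peaks at a distance of 72, somewhat before its mean of 96, because the gamma distribution is right-skewed.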
# build the relative positional embedding
def get_positional_embed(seq_len, feature_size, device, use_tf_gamma, dtype = torch.float):
    distances = torch.arange(-seq_len + 1, seq_len, device = device)

    assert not use_tf_gamma or seq_len == 1536, 'if using tf gamma, only sequence length of 1536 allowed for now'

    # feature functions: exponential, central mask, and gamma
    # (when use_tf_gamma is set, the precomputed TF_GAMMAS replace the gamma function)
    feature_functions = [
        get_positional_features_exponential,
        get_positional_features_central_mask,
        get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
    ]

    # each function contributes a symmetric and an antisymmetric component
    num_components = len(feature_functions) * 2

    # the feature size must be divisible by the number of components
    if (feature_size % num_components) != 0:
        raise ValueError(f'feature size is not divisible by number of components ({num_components})')

    # number of basis functions per feature class
    num_basis_per_class = feature_size // num_components

    embeddings = []
    for fn in feature_functions:
        embeddings.append(fn(distances, num_basis_per_class, seq_len, dtype = dtype))

    # concatenate all features along the last dimension
    embeddings = torch.cat(embeddings, dim = -1)
    # append the same features multiplied by the sign of the distance
    embeddings = torch.cat((embeddings, torch.sign(distances)[..., None] * embeddings), dim = -1)
    return embeddings.to(dtype)

def relative_shift(x):
    # a zero column with the same shape as x, except size 1 in the last dimension
    to_pad = torch.zeros_like(x[..., :1])
    # prepend the zero column along the last dimension
    x = torch.cat((to_pad, x), dim=-1)
    _, h, t1, t2 = x.shape
    # reinterpret the last two dimensions with their sizes swapped
    x = x.reshape(-1, h, t2, t1)
    # drop the first row
    x = x[:, :, 1:, :]
    # reshape back, so that each row is shifted by one relative to the row before it
    x = x.reshape(-1, h, t1, t2 - 1)
    # keep the first half of the last dimension
    return x[..., :((t2 + 1) // 2)]
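The pad, reshape, slice sequence in `relative_shift` converts logits indexed by (query, relative distance) into logits indexed by (query, key). The following is a minimal sketch of the same trick on nested lists in plain Python, for a single head; `relative_shift_rows` is a hypothetical name, and each input entry is simply set to the relative distance it represents so the result is easy to read.

```python
def relative_shift_rows(x):
    """x has n rows, each holding 2n - 1 values for distances -(n - 1) .. (n - 1)."""
    t1 = len(x)
    padded = [[0] + row for row in x]               # prepend a zero column
    t2 = len(padded[0])
    flat = [v for row in padded for v in row]       # row-major flatten
    flat = flat[t1:]                                # view as (t2, t1) and drop the first row
    width = t2 - 1
    reshaped = [flat[r * width:(r + 1) * width] for r in range(t1)]   # back to (t1, t2 - 1)
    return [row[:(t2 + 1) // 2] for row in reshaped]                  # keep the first n columns

n = 3
distances = list(range(-(n - 1), n))                # [-2, -1, 0, 1, 2]
rel_logits = [list(distances) for _ in range(n)]    # every query sees the same distance table

shifted = relative_shift_rows(rel_logits)
for row in shifted:
    print(row)
```

After the shift, entry `[i][j]` holds the value for distance `j - i`, which is what is needed before adding the relative logits to the content logits.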
# classes

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # residual connection
        return self.fn(x, **kwargs) + x

class GELU(nn.Module):
    def forward(self, x):
        # sigmoid approximation of the GELU activation
        return torch.sigmoid(1.702 * x) * x

class AttentionPool(nn.Module):
    def __init__(self, dim, pool_size=2):
        super().__init__()
        self.pool_size = pool_size
        # split the sequence into windows of pool_size
        self.pool_fn = Rearrange('b d (n p) -> b d n p', p=pool_size)

        # 1x1 convolution that produces the attention logits
        self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias=False)

        # initialize the convolution as an identity mapping
        nn.init.dirac_(self.to_attn_logits.weight)

        # then scale the weights by 2
        with torch.no_grad():
            self.to_attn_logits.weight.mul_(2)

    def forward(self, x):
        b, _, n = x.shape
        remainder = n % self.pool_size
        needs_padding = remainder > 0

        if needs_padding:
            # pad the input and build a mask marking the padded positions
            x = F.pad(x, (0, remainder), value=0)
            mask = torch.zeros((b, 1, n), dtype=torch.bool, device=x.device)
            mask = F.pad(mask, (0, remainder), value=True)

        # split into pooling windows
        x = self.pool_fn(x)
        # attention logits per window position
        logits = self.to_attn_logits(x)

        if needs_padding:
            mask_value = -torch.finfo(logits.dtype).max
            logits = logits.masked_fill(self.pool_fn(mask), mask_value)

        # softmax within each window, then the weighted sum
        attn = logits.softmax(dim=-1)

        return (x * attn).sum(dim=-1)
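At initialization the logit convolution in `AttentionPool` is an identity scaled by 2, so each channel is pooled by a softmax over `2 * x` within its window. The following is a minimal sketch of that initial behaviour for one channel in plain Python; `attention_pool_1d` is a hypothetical name, and after training the logits would come from a learned mix of channels rather than from `2 * x`.

```python
import math

def attention_pool_1d(values, pool_size = 2, logit_scale = 2.0):
    """Softmax-weighted pooling over non-overlapping windows of one channel."""
    assert len(values) % pool_size == 0, 'this sketch leaves out the padding branch'
    pooled = []
    for i in range(0, len(values), pool_size):
        window = values[i:i + pool_size]
        logits = [logit_scale * v for v in window]
        top = max(logits)                                # subtract the max for numerical stability
        weights = [math.exp(l - top) for l in logits]
        total = sum(weights)
        pooled.append(sum(v * w / total for v, w in zip(window, weights)))
    return pooled

pooled = attention_pool_1d([0.0, 0.0, 1.0, 3.0])
print(pooled)
```

The result sits between average pooling and max pooling: equal values are averaged, while unequal values are pulled toward the larger one.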
class TargetLengthCrop(nn.Module):
    def __init__(self, target_length):
        super().__init__()
        self.target_length = target_length

    def forward(self, x):
        seq_len, target_len = x.shape[-2], self.target_length

        # a target length of -1 disables cropping
        if target_len == -1:
            return x

        if seq_len < target_len:
            raise ValueError(f'sequence length {seq_len} is less than target length {target_len}')

        # trim is negative or zero: the number of positions to cut from each side, negated
        trim = (target_len - seq_len) // 2

        if trim == 0:
            return x

        # keep the centered window
        return x[:, -trim:trim]
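The constants `SEQUENCE_LENGTH = 196_608` and `TARGET_LENGTH = 896` are tied together by the pooling layers and this crop. The following is a minimal sketch of the arithmetic in plain Python, assuming seven halvings in total (one in the stem and six in the convolutional tower); that count is an assumption, chosen because it matches the length of 1536 that `get_positional_embed` asserts for the TensorFlow gammas. `center_crop` is a hypothetical list-based version of `TargetLengthCrop.forward`.

```python
SEQUENCE_LENGTH = 196_608
TARGET_LENGTH = 896
NUM_DOWNSAMPLES = 7   # assumed: one pool in the stem plus six in the conv tower

def center_crop(bins, target_len):
    seq_len = len(bins)
    if target_len == -1:
        return bins
    if seq_len < target_len:
        raise ValueError(f'sequence length {seq_len} is less than target length {target_len}')
    trim = (target_len - seq_len) // 2   # negative or zero
    if trim == 0:
        return bins
    return bins[-trim:trim]

bin_size = 2 ** NUM_DOWNSAMPLES              # base pairs per output position
num_bins = SEQUENCE_LENGTH // bin_size       # positions entering the transformer
cropped = center_crop(list(range(num_bins)), TARGET_LENGTH)
print(bin_size, num_bins, len(cropped), cropped[0], cropped[-1])
```

Under this assumption each output position covers 128 base pairs, and 320 positions are discarded on each side, so predictions are made only for the central part of the input window.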
def ConvBlock(dim, dim_out=None, kernel_size=1, is_distributed=None):
    batchnorm_klass = MaybeSyncBatchnorm(is_distributed=is_distributed)

    return nn.Sequential(
        batchnorm_klass(dim),
        GELU(),
        nn.Conv1d(dim, default(dim_out, dim), kernel_size, padding=kernel_size // 2)
    )

# attention classes

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_rel_pos_features,
        heads=8,
        dim_key=64,
        dim_value=64,
        dropout=0.,
        pos_dropout=0.,
        use_tf_gamma=False
    ):
        super().__init__()
        self.scale = dim_key ** -0.5
        self.heads = heads

        # projections to queries, keys and values
        self.to_q = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_value * heads, bias=False)

        # output projection, initialized to zero
        self.to_out = nn.Linear(dim_value * heads, dim)
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

        # relative positional encoding
        self.num_rel_pos_features = num_rel_pos_features

        self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
        self.rel_content_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
        self.rel_pos_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))

        # dropouts
        self.pos_dropout = nn.Dropout(pos_dropout)
        self.attn_dropout = nn.Dropout(dropout)

        # whether to use the precomputed tf gammas
        self.use_tf_gamma = use_tf_gamma

    def forward(self, x):
        # sequence length, number of heads and device
        n, h, device = x.shape[-2], self.heads, x.device

        # project the input to queries, keys and values
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # split out the heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale the queries
        q = q * self.scale

        # content logits
        content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k)

        # relative positional embedding, projected to relative keys
        positions = get_positional_embed(n, self.num_rel_pos_features, device, use_tf_gamma = self.use_tf_gamma, dtype = self.to_rel_k.weight.dtype)
        positions = self.pos_dropout(positions)
        rel_k = self.to_rel_k(positions)

        # split out the heads of the relative keys
        rel_k = rearrange(rel_k, 'n (h d) -> h n d', h = h)
        # relative position logits, shifted so that they line up with the key positions
        rel_logits = einsum('b h i d, h j d -> b h i j', q + self.rel_pos_bias, rel_k)
        rel_logits = relative_shift(rel_logits)

        # combine content and relative position logits
        logits = content_logits + rel_logits
        # attention weights
        attn = logits.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate the values and merge the heads
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# main class: Enformer, a subclass of PreTrainedModel
class Enformer(PreTrainedModel):
    # config class and base model prefix
    config_class = EnformerConfig
    base_model_prefix = "enformer"

    # static method that builds an Enformer from hyperparameters
    @staticmethod
    def from_hparams(**kwargs):
        return Enformer(EnformerConfig(**kwargs))

    def __init__(self, config):
        super().__init__(config)
        self.dim = config.dim
        half_dim = config.dim // 2
        twice_dim = config.dim * 2

        # create stem
        self.stem = nn.Sequential(
            nn.Conv1d(4, half_dim, 15, padding=7),
            Residual(ConvBlock(half_dim)),
            AttentionPool(half_dim, pool_size=2)
        )

        # create conv tower
        filter_list = exponential_linspace_int(half_dim, config.dim, num=(config.num_downsamples - 1), divisible_by=config.dim_divisible_by)
        filter_list = [half_dim, *filter_list]

        conv_layers = []
        for dim_in, dim_out in zip(filter_list[:-1], filter_list[1:]):
            conv_layers.append(nn.Sequential(
                ConvBlock(dim_in, dim_out, kernel_size=5),
                Residual(ConvBlock(dim_out, dim_out, 1)),
                AttentionPool(dim_out, pool_size=2)
            ))

        self.conv_tower = nn.Sequential(*conv_layers)

        # whether to use the tensorflow gamma positions
        use_tf_gamma = config.use_tf_gamma
        self.use_tf_gamma = use_tf_gamma

        # transformer
        transformer = []
        for _ in range(config.depth):
            transformer.append(nn.Sequential(
                Residual(nn.Sequential(
                    nn.LayerNorm(config.dim),
                    Attention(
                        config.dim,
                        heads=config.heads,
                        dim_key=config.attn_dim_key,
                        dim_value=config.dim // config.heads,
                        dropout=config.attn_dropout,
                        pos_dropout=config.pos_dropout,
                        num_rel_pos_features=config.dim // config.heads,
                        use_tf_gamma=use_tf_gamma
                    ),
                    nn.Dropout(config.dropout_rate)
                )),
                Residual(nn.Sequential(
                    nn.LayerNorm(config.dim),
                    nn.Linear(config.dim, config.dim * 2),
                    nn.Dropout(config.dropout_rate),
                    nn.ReLU(),
                    nn.Linear(config.dim * 2, config.dim),
                    nn.Dropout(config.dropout_rate)
                ))
            ))

        self.transformer = nn.Sequential(*transformer)

        # target cropping
        self.target_length = config.target_length
        self.crop_final = TargetLengthCrop(config.target_length)

        # final pointwise
        self.final_pointwise = nn.Sequential(
            Rearrange('b n d -> b d n'),
            ConvBlock(filter_list[-1], twice_dim, 1),
            Rearrange('b d n -> b n d'),
            nn.Dropout(config.dropout_rate / 8),
            GELU()
        )

        # create trunk sequential module
        self._trunk = nn.Sequential(
            Rearrange('b n d -> b d n'),
            self.stem,
            self.conv_tower,
            Rearrange('b d n -> b n d'),
            self.transformer,
            self.crop_final,
            self.final_pointwise
        )

        # create final heads for human and mouse
        self.add_heads(**config.output_heads)

        # use checkpointing on the transformer trunk
        self.use_checkpointing = config.use_checkpointing
    # add the output heads, one linear + softplus per entry
    def add_heads(self, **kwargs):
        self.output_heads = kwargs

        self._heads = nn.ModuleDict(map_values(lambda features: nn.Sequential(
            nn.Linear(self.dim * 2, features),
            nn.Softplus()
        ), kwargs))

    # set the target length on the crop module (second to last module of the trunk)
    def set_target_length(self, target_length):
        crop_module = self._trunk[-2]
        crop_module.target_length = target_length

    # the trunk
    @property
    def trunk(self):
        return self._trunk

    # the output heads
    @property
    def heads(self):
        return self._heads

    # same computation as the trunk, but with gradient checkpointing over the transformer blocks
    def trunk_checkpointed(self, x):
        x = rearrange(x, 'b n d -> b d n')
        x = self.stem(x)
        x = self.conv_tower(x)
        x = rearrange(x, 'b d n -> b n d')
        # checkpoint every transformer block, trading compute for memory
        x = checkpoint_sequential(self.transformer, len(self.transformer), x)
        x = self.crop_final(x)
        x = self.final_pointwise(x)
        return x
    # forward pass
    def forward(
        self,
        x,
        target = None,
        return_corr_coef = False,
        return_embeddings = False,
        return_only_embeddings = False,
        head = None,
        target_length = None
    ):
        # a list of strings is converted to one-hot encodings
        if isinstance(x, list):
            x = str_to_one_hot(x)

        # a long tensor of sequence indices is converted to one-hot encodings
        elif type(x) == torch.Tensor and x.dtype == torch.long:
            x = seq_indices_to_one_hot(x)

        # move the input to the device of the model
        x = x.to(self.device)

        # add a batch dimension if the input has none
        no_batch = x.ndim == 2

        if no_batch:
            x = rearrange(x, '... -> () ...')

        # set the target length if one was given
        if exists(target_length):
            self.set_target_length(target_length)

        # use the checkpointed trunk if checkpointing is enabled
        trunk_fn = self.trunk_checkpointed if self.use_checkpointing else self._trunk
        x = trunk_fn(x)

        # remove the batch dimension again if it was added
        if no_batch:
            x = rearrange(x, '() ... -> ...')

        # return the embeddings alone if requested
        if return_only_embeddings:
            return x

        # apply every head to the embeddings
        out = map_values(lambda fn: fn(x), self._heads)

        # select a single head if one was named
        if exists(head):
            assert head in self._heads, f'head {head} not found'
            out = out[head]

        # with a target, return the correlation coefficient or the Poisson loss
        if exists(target):
            assert exists(head), 'head must be passed in if one were to calculate loss directly with targets'

            if return_corr_coef:
                return pearson_corr_coef(out, target)

            return poisson_loss(out, target)

        # return the head outputs together with the embeddings if requested
        if return_embeddings:
            return out, x

        return out
# load a pretrained model
def from_pretrained(name, use_tf_gamma = None, **kwargs):
    enformer = Enformer.from_pretrained(name, **kwargs)

    # for the official checkpoint, default to the precomputed tf gammas
    if name == 'EleutherAI/enformer-official-rough':
        use_tf_gamma = default(use_tf_gamma, True)

        # propagate the setting to every Attention module
        for module in enformer.modules():
            if isinstance(module, Attention):
                module.use_tf_gamma = use_tf_gamma

    return enformer
.\lucidrains\enformer-pytorch\enformer_pytorch\__init__.py
# the configuration class
from enformer_pytorch.config_enformer import EnformerConfig
# the model, the pretrained loader, the input sequence length constant, and the attention pooling module
from enformer_pytorch.modeling_enformer import Enformer, from_pretrained, SEQUENCE_LENGTH, AttentionPool
# one-hot conversion helpers and the dataset classes
from enformer_pytorch.data import seq_indices_to_one_hot, str_to_one_hot, GenomeIntervalDataset, FastaInterval
Enformer - Pytorch
Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch. This repository also contains the means to fine tune pretrained models for your downstream tasks. The original tensorflow sonnet code can be found here.
Update: finetuned for predicting pseudobulk chromatin accessibility here
Install
$ pip install enformer-pytorch
Usage
import torch
from enformer_pytorch import Enformer
model = Enformer.from_hparams(
dim = 1536,
depth = 11,
heads = 8,
output_heads = dict(human = 5313, mouse = 1643),
target_length = 896,
)
seq = torch.randint(0, 5, (1, 196_608)) # for ACGTN, in that order (-1 for padding)
output = model(seq)
output['human'] # (1, 896, 5313)
output['mouse'] # (1, 896, 1643)
You can also directly pass in the sequence as one-hot encodings, which must be float values
import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot
model = Enformer.from_hparams(
dim = 1536,
depth = 11,
heads = 8,
output_heads = dict(human = 5313, mouse = 1643),
target_length = 896,
)
seq = torch.randint(0, 5, (1, 196_608))
one_hot = seq_indices_to_one_hot(seq)
output = model(one_hot)
output['human'] # (1, 896, 5313)
output['mouse'] # (1, 896, 1643)
Finally, one can fetch the embeddings, for fine-tuning and otherwise, by setting the return_embeddings flag to True on forward
import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot
model = Enformer.from_hparams(
dim = 1536,
depth = 11,
heads = 8,
output_heads = dict(human = 5313, mouse = 1643),
target_length = 896,
)
seq = torch.randint(0, 5, (1, 196_608))
one_hot = seq_indices_to_one_hot(seq)
output, embeddings = model(one_hot, return_embeddings = True)
embeddings # (1, 896, 3072)
For training, you can directly pass the head and target in to get the poisson loss
import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot
model = Enformer.from_hparams(
dim = 1536,
depth = 11,
heads = 8,
output_heads = dict(human = 5313, mouse = 1643),
target_length = 200,
).cuda()
seq = torch.randint(0, 5, (196_608 // 2,)).cuda()
target = torch.randn(200, 5313).cuda()
loss = model(
seq,
head = 'human',
target = target
)
loss.backward()
# after much training
corr_coef = model(
seq,
head = 'human',
target = target,
return_corr_coef = True
)
corr_coef # pearson R, used as a metric in the paper
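The two quantities returned above can be illustrated without the model. The sketch below computes a Poisson negative log-likelihood (dropping the term that depends only on the target) and a Pearson correlation coefficient over plain Python lists. The library's exact reduction, epsilon handling, and axis conventions may differ, so treat this as an illustration of the formulas only. `poisson_loss` and `pearson_r` are hypothetical names.

```python
import math

def poisson_loss(pred, target):
    """Mean of pred - target * log(pred); predictions must be positive."""
    return sum(p - t * math.log(p) for p, t in zip(pred, target)) / len(pred)

def pearson_r(x, y):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    std_x = math.sqrt(sum((a - mean_x) ** 2 for a in x))
    std_y = math.sqrt(sum((b - mean_y) ** 2 for b in y))
    return cov / (std_x * std_y)

# The loss is smallest when the prediction matches the target count.
print(poisson_loss([2.0], [2.0]), poisson_loss([3.0], [2.0]))
print(pearson_r([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))
```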
Pretrained Model
Deepmind has released the weights for their tensorflow sonnet Enformer model! I have ported it over to Pytorch and uploaded it to 🤗 Huggingface (~1GB). There are still some rounding errors that seem to be accruing across the layers, resulting in an absolute error as high as 0.5. However, the correlation coefficients look good, so I am releasing the 'rough'ly working version. Will keep working on figuring out where the numerical errors are happening (it may be the attention pooling module, as I noticed the attention logits are pretty high).
Update: John St. John did some work and found that the enformer-official-rough model hits the reported marks in the paper - human pearson R of 0.625 for validation, and 0.65 for test.
Update: As of version 0.8.0, if one were to use the from_pretrained function to load the pretrained model, it should automatically use precomputed gamma positions to address a difference between tensorflow and pytorch xlogy. This should resolve the numerical discrepancy above. If you fine-tune further without using the from_pretrained function, please make sure to set use_tf_gamma = True when using .from_hparams to instantiate the Enformer
$ pip install "enformer-pytorch>=0.5"
Loading the model
from enformer_pytorch import from_pretrained
enformer = from_pretrained('EleutherAI/enformer-official-rough')
Quick sanity check on a single human validation point
$ python test_pretrained.py
# 0.5963 correlation coefficient on a validation sample
This is all made possible thanks to HuggingFace's custom model feature.
You can also load, with overriding of the target_length parameter, if you are working with shorter sequence lengths
from enformer_pytorch import from_pretrained
model = from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)
# do your fine-tuning
To save on memory during fine-tuning a large Enformer model
from enformer_pytorch import from_pretrained
enformer = from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
# finetune enformer on a limited budget
Fine-tuning
This repository will also allow for easy fine-tuning of Enformer.
Fine-tuning on new tracks
import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import HeadAdapterWrapper
enformer = from_pretrained('EleutherAI/enformer-official-rough')
model = HeadAdapterWrapper(
enformer = enformer,
num_tracks = 128,
post_transformer_embed = False # by default, embeddings are taken from after the final pointwise block w/ conv -> gelu - but if you'd like the embeddings right after the transformer block with a learned layernorm, set this to True
).cuda()
seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
target = torch.randn(1, 200, 128).cuda() # 128 tracks
loss = model(seq, target = target)
loss.backward()
Finetuning on contextual data (cell type, transcription factor, etc)
import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import ContextAdapterWrapper
enformer = from_pretrained('EleutherAI/enformer-official-rough')
model = ContextAdapterWrapper(
enformer = enformer,
context_dim = 1024
).cuda()
seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
target = torch.randn(1, 200, 4).cuda() # 4 tracks
context = torch.randn(4, 1024).cuda() # 4 contexts for the different 'tracks'
loss = model(
seq,
context = context,
target = target
)
loss.backward()
Finally, there is also a way to use attention aggregation from a set of context embeddings (or a single context embedding). Simply use the ContextAttentionAdapterWrapper
import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import ContextAttentionAdapterWrapper
enformer = from_pretrained('EleutherAI/enformer-official-rough')
model = ContextAttentionAdapterWrapper(
enformer = enformer,
context_dim = 1024,
heads = 8, # number of heads in the cross attention
dim_head = 64 # dimension per head
).cuda()
seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
target = torch.randn(1, 200, 4).cuda() # 4 tracks
context = torch.randn(4, 16, 1024).cuda() # 4 contexts for the different 'tracks', each with 16 tokens
context_mask = torch.ones(4, 16).bool().cuda() # optional context mask, in example, include all context tokens
loss = model(
seq,
context = context,
context_mask = context_mask,
target = target
)
loss.backward()
Data
You can use the GenomeIntervalDataset to easily fetch sequences of any length from a .bed file, with greater context length dynamically computed if specified
import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset
filter_train = lambda df: df.filter(pl.col('column_4') == 'train')
ds = GenomeIntervalDataset(
bed_file = './sequences.bed', # bed file - columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
fasta_file = './hg38.ml.fa', # path to fasta file
filter_df_fn = filter_train, # filter dataframe function
return_seq_indices = True, # return nucleotide indices (ACGTN) or one hot encodings
shift_augs = (-2, 2), # random shift augmentations from -2 to +2 basepairs
context_length = 196_608,
# this can be longer than the interval designated in the .bed file,
# in which case it will take care of lengthening the interval on either sides
# as well as proper padding if at the end of the chromosomes
chr_bed_to_fasta_map = {
'chr1': 'chromosome1', # if the chromosome name in the .bed file is different than the key name in the fasta file, you can rename them on the fly
'chr2': 'chromosome2',
'chr3': 'chromosome3',
# etc etc
}
)
model = Enformer.from_hparams(
dim = 1536,
depth = 11,
heads = 8,
output_heads = dict(human = 5313, mouse = 1643),
target_length = 896,
)
seq = ds[0] # (196608,)
pred = model(seq, head = 'human') # (896, 5313)
To return the random shift value, as well as whether reverse complement was activated (in case you need to reverse the corresponding chip-seq target data), just set return_augs = True when initializing the GenomeIntervalDataset
import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset
filter_train = lambda df: df.filter(pl.col('column_4') == 'train')
ds = GenomeIntervalDataset(
bed_file = './sequences.bed', # bed file - columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
fasta_file = './hg38.ml.fa', # path to fasta file
filter_df_fn = filter_train, # filter dataframe function
return_seq_indices = True, # return nucleotide indices (ACGTN) or one hot encodings
shift_augs = (-2, 2), # random shift augmentations from -2 to +2 basepairs
rc_aug = True, # use reverse complement augmentation with 50% probability
context_length = 196_608,
return_augs = True # return the augmentation meta data
)
seq, rand_shift_val, rc_bool = ds[0] # (196608,), (1,), (1,)
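When the reverse-complement flag comes back as true, the target bins usually need to be flipped along the length axis so they stay aligned with the sequence. The sketch below shows both operations on plain lists, assuming the `ACGTN` index ordering used above. `COMPLEMENT`, `reverse_complement`, and `align_target` are hypothetical names, and whether flipping is appropriate depends on your target data.

```python
# Complement map over ACGTN indices: A<->T, C<->G, N and padding unchanged.
COMPLEMENT = {0: 3, 1: 2, 2: 1, 3: 0, 4: 4, -1: -1}

def reverse_complement(indices):
    """Reverse the sequence and swap each base for its complement."""
    return [COMPLEMENT[i] for i in reversed(indices)]

def align_target(target_bins, rc):
    """Flip the target bins when the sequence was reverse complemented."""
    return target_bins[::-1] if rc else target_bins

print(reverse_complement([0, 0, 1, 4]))      # A A C N -> N G T T
print(align_target([0.1, 0.5, 0.9], True))
```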
Appreciation
Special thanks goes out to EleutherAI for providing the resources to retrain the model, during a time when the official model from Deepmind had not been released yet.
Thanks also goes out to @johahi for finding out that there are numerical differences between the torch and tensorflow implementations of xlogy. He provided a fix for this difference, which is adopted in this repository in v0.8.0
Todo
Citations
@article {Avsec2021.04.07.438649,
author = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
title = {Effective gene expression prediction from sequence by integrating long-range interactions},
elocation-id = {2021.04.07.438649},
year = {2021},
doi = {10.1101/2021.04.07.438649},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649},
eprint = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf},
journal = {bioRxiv}
}
@misc{liu2022convnet,
title = {A ConvNet for the 2020s},
author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
year = {2022},
eprint = {2201.03545},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
.\lucidrains\enformer-pytorch\scripts\tf_to_torch.py
# Import numpy and torch (both are used by get_tf_vars below) and rearrange from einops
import numpy as np
import torch
from einops import rearrange
# Copy the parameters of a BatchNorm layer into the PyTorch module
def copy_bn(mod, vars, path):
# Fetch the BatchNorm offset and scale parameters
bn_offset = vars[f'{path}offset:0']
bn_scale = vars[f'{path}scale:0']
# Fetch the BatchNorm moving-average statistics
ema_path = '/'.join(path.split('/')[:-1]) + '/'
bn_running_mean = vars[f'{ema_path}moving_mean/average:0']
bn_running_var = vars[f'{ema_path}moving_variance/average:0']
# Copy the scale into the weight
mod.weight.data.copy_(bn_scale)
# Copy the offset into the bias
mod.bias.data.copy_(bn_offset)
# Copy the moving variance into running_var
mod.running_var.data.copy_(rearrange(bn_running_var, '1 1 d -> d'))
# Copy the moving mean into running_mean
mod.running_mean.data.copy_(rearrange(bn_running_mean, '1 1 d -> d'))
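# Copy the parameters of a convolution layer into the PyTorch module
def copy_conv(mod, vars, path):
# Fetch the convolution bias and weight
bias = vars[f'{path}b:0']
weight = vars[f'{path}w:0']
# Copy the weight, rearranged from (kernel, in, out) to (out, in, kernel)
mod.weight.data.copy_(rearrange(weight, 'k i o -> o i k'))
# Copy the bias
mod.bias.data.copy_(bias)
# Copy the parameters of an attention pooling layer into the PyTorch module
def copy_attn_pool(mod, vars, path):
# Fetch the attention pooling projection
attn_pool_proj = vars[path]
# Copy the projection into the weight
mod.to_attn_logits.weight.data.copy_(rearrange(attn_pool_proj, 'i o -> o i 1 1'))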
# Copy the parameters of a linear layer into the PyTorch module
def copy_linear(mod, vars, path, has_bias = True):
# Fetch the linear layer weight
weight = vars[f'{path}w:0']
# Copy the weight, transposed from (in, out) to (out, in)
mod.weight.data.copy_(rearrange(weight, 'i o -> o i'))
# Return early if the layer has no bias
if not has_bias:
return
# Fetch the linear layer bias
bias = vars[f'{path}b:0']
# Copy the bias
mod.bias.data.copy_(bias)
# Copy the parameters of a LayerNorm layer into the PyTorch module
def copy_ln(mod, vars, path):
# Fetch the LayerNorm scale and offset
weight = vars[f'{path}scale:0']
bias = vars[f'{path}offset:0']
# Copy the scale into the weight
mod.weight.data.copy_(weight)
# Copy the offset into the bias
mod.bias.data.copy_(bias)
# Collect the variables of the TensorFlow model as torch tensors, keyed by name
def get_tf_vars(tf_model):
return {v.name: (torch.from_numpy(v.numpy()) if isinstance(v.numpy(), np.ndarray) else None) for v in tf_model.variables}
# Copy the parameters of the TensorFlow model into the PyTorch model
def copy_tf_to_pytorch(tf_model, pytorch_model):
# Collect the TensorFlow model variables
tf_vars = get_tf_vars(tf_model)
# Get the stem modules of the PyTorch model
stem_conv = pytorch_model.stem[0]
stem_point_bn = pytorch_model.stem[1].fn[0]
stem_point_conv = pytorch_model.stem[1].fn[2]
stem_attn_pool = pytorch_model.stem[2]
# Copy the stem parameters
copy_conv(stem_conv, tf_vars, 'enformer/trunk/stem/conv1_d/')
copy_bn(stem_point_bn, tf_vars, 'enformer/trunk/stem/pointwise_conv_block/cross_replica_batch_norm/')
copy_conv(stem_point_conv, tf_vars, 'enformer/trunk/stem/pointwise_conv_block/conv1_d/')
copy_attn_pool(stem_attn_pool, tf_vars, 'enformer/trunk/stem/softmax_pooling/linear/w:0')
# Iterate over the conv_tower blocks
for ind, tower_block in enumerate(pytorch_model.conv_tower):
tower_bn = tower_block[0][0]
tower_conv = tower_block[0][2]
tower_point_bn = tower_block[1].fn[0]
tower_point_conv = tower_block[1].fn[2]
tower_attn_pool = tower_block[2]
# Build the variable paths
conv_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/conv_block/conv1_d/'
bn_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/conv_block/cross_replica_batch_norm/'
point_conv_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/pointwise_conv_block/conv1_d/'
point_bn_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/pointwise_conv_block/cross_replica_batch_norm/'
attn_pool_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/softmax_pooling/linear/w:0'
# Copy the conv_tower parameters
copy_bn(tower_bn, tf_vars, bn_path)
copy_conv(tower_conv, tf_vars, conv_path)
copy_bn(tower_point_bn, tf_vars, point_bn_path)
copy_conv(tower_point_conv, tf_vars, point_conv_path)
copy_attn_pool(tower_attn_pool, tf_vars, attn_pool_path)
# Iterate over the transformer blocks of the PyTorch model
for ind, transformer_block in enumerate(pytorch_model.transformer):
# Build the variable paths for the attention layer
attn_ln_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/layer_norm/'
attn_q_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/q_layer/'
attn_k_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/k_layer/'
attn_r_k_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_k_layer/'
attn_v_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/v_layer/'
attn_out_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/embedding_layer/'
attn_content_bias_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_w_bias:0'
attn_rel_bias_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_r_bias:0'
ff_ln_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/layer_norm/'
# These names may need to be edited so that the variables are accessible
ff_linear1_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/project_in/'
ff_linear2_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/project_out/'
# Get the attention block, its layer norm, and the multi-head attention module
attn = transformer_block[0]
attn_ln = attn.fn[0]
mha = attn.fn[1]
# Copy the linear layer parameters
copy_linear(mha.to_q, tf_vars, attn_q_path, has_bias = False)
copy_linear(mha.to_k, tf_vars, attn_k_path, has_bias = False)
copy_linear(mha.to_rel_k, tf_vars, attn_r_k_path, has_bias = False)
copy_linear(mha.to_v, tf_vars, attn_v_path, has_bias = False)
copy_linear(mha.to_out, tf_vars, attn_out_path)
# Copy the bias parameters of the attention layer
mha.rel_content_bias.data.copy_(tf_vars[attn_content_bias_path])
mha.rel_pos_bias.data.copy_(tf_vars[attn_rel_bias_path])
# Get the feedforward block and its linear layers
ff = transformer_block[-1]
ff_ln = ff.fn[0]
ff_linear1 = ff.fn[1]
ff_linear2 = ff.fn[4]
# Copy the layer norm and feedforward parameters
copy_ln(attn_ln, tf_vars, attn_ln_path)
copy_ln(ff_ln, tf_vars, ff_ln_path)
copy_linear(ff_linear1, tf_vars, ff_linear1_path)
copy_linear(ff_linear2, tf_vars, ff_linear2_path)
# Get the final batch norm and convolution layers
final_bn = pytorch_model.final_pointwise[1][0]
final_conv = pytorch_model.final_pointwise[1][2]
# Copy the batch norm and convolution parameters
copy_bn(final_bn, tf_vars, 'enformer/trunk/final_pointwise/conv_block/cross_replica_batch_norm/')
copy_conv(final_conv, tf_vars, 'enformer/trunk/final_pointwise/conv_block/conv1_d/')
# Get the linear layers of the output heads
human_linear = pytorch_model._heads['human'][0]
mouse_linear = pytorch_model._heads['mouse'][0]
# Copy the parameters of the head linear layers
copy_linear(human_linear, tf_vars, 'enformer/heads/head_human/linear/')
copy_linear(mouse_linear, tf_vars, 'enformer/heads/head_mouse/linear/')
# Print a success message
print('success')
.\lucidrains\enformer-pytorch\setup.py
# Import the setup and find_packages functions
from setuptools import setup, find_packages
# Set the package metadata
setup(
name = 'enformer-pytorch', # package name
packages = find_packages(exclude=[]), # find and include all packages
include_package_data = True, # include all data files
version = '0.8.8', # version number
license='MIT', # license
description = 'Enformer - Pytorch', # description
author = 'Phil Wang', # author
author_email = 'lucidrains@gmail.com', # author email
long_description_content_type = 'text/markdown', # content type of the long description
url = 'https://github.com/lucidrains/enformer-pytorch', # project URL
keywords = [ # keywords
'artificial intelligence',
'transformer',
'gene-expression'
],
install_requires=[ # install dependencies
'discrete-key-value-bottleneck-pytorch>=0.0.8',
'einops>=0.3',
'numpy',
'torch>=1.6',
'torchmetrics',
'polars',
'pyfaidx',
'pyyaml',
'transformers[torch]',
],
classifiers=[ # classifiers
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\enformer-pytorch\test_pretrained.py
# Import torch
import torch
# Import the from_pretrained function from enformer_pytorch
from enformer_pytorch import from_pretrained
# Load the pretrained model 'EleutherAI/enformer-official-rough' with use_tf_gamma = False and move it to the GPU
enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = False).cuda()
# Put the model in evaluation mode
enformer.eval()
# Load the sample from './data/test-sample.pt'
data = torch.load('./data/test-sample.pt')
# Move the 'sequence' and 'target' entries to the GPU
seq, target = data['sequence'].cuda(), data['target'].cuda()
# Disable gradient computation
with torch.no_grad():
# Run inference with the enformer model and compute the correlation coefficient
corr_coef = enformer(
seq,
target = target,
return_corr_coef = True,
head = 'human'
)
# Print the correlation coefficient
print(corr_coef)
# Assert that the correlation coefficient is greater than 0.1
assert corr_coef > 0.1
.\lucidrains\enformer-tensorflow-sonnet-training-script\create_tfrecords.py
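# Import the required modules (os, json, and functools are used by the helpers below)
import os
import json
import functools
from itertools import islice
from functools import partial
import tensorflow as tf
from sequence import get_dna_sample # assumed to come from sequence.py in this directory
# Old get_dataset functions, but returning only the labels so they can be zipped with the new, longer sequences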
def organism_path(organism):
return os.path.join(f'gs://basenji_barnyard/data', organism)
# Build the dataset
def get_dataset(organism, subset, num_threads=8):
# Load the metadata
metadata = get_metadata(organism)
# List the TFRecord files
files = tfrecord_files(organism, subset)
# Create the TFRecord dataset
dataset = tf.data.TFRecordDataset(files, compression_type='ZLIB', num_parallel_reads=None)
# Deserialize each record
dataset = dataset.map(functools.partial(deserialize, metadata=metadata), num_parallel_calls=num_threads)
return dataset
# Load the metadata
def get_metadata(organism):
path = os.path.join(organism_path(organism), 'statistics.json')
with tf.io.gfile.GFile(path, 'r') as f:
return json.load(f)
# List the TFRecord files, sorted by their numeric shard index
def tfrecord_files(organism, subset):
return sorted(tf.io.gfile.glob(os.path.join(organism_path(organism), 'tfrecords', f'{subset}-*.tfr')), key=lambda x: int(x.split('-')[-1].split('.')[0]))
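# The sort key above matters because a plain lexicographic sort puts shard 10 before shard 2.
# A minimal sketch of the same key on made-up file names (the names and shard_index are
# hypothetical, chosen only to match the '{subset}-*.tfr' pattern):

```python
def shard_index(path):
    """Extract the integer between the last '-' and the first following '.'."""
    return int(path.split('-')[-1].split('.')[0])

files = ['train-0-10.tfr', 'train-0-2.tfr', 'train-0-1.tfr']  # hypothetical names

lexicographic = sorted(files)
numeric = sorted(files, key=shard_index)

print(lexicographic)  # shard 10 sorts before shard 2
print(numeric)        # shards in numeric order
```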
# Deserialize an example, returning only the target
def deserialize(serialized_example, metadata):
feature_map = {
'sequence': tf.io.FixedLenFeature([], tf.string),
'target': tf.io.FixedLenFeature([], tf.string),
}
example = tf.io.parse_example(serialized_example, feature_map)
target = tf.io.decode_raw(example['target'], tf.float16)
target = tf.reshape(target, (metadata['target_length'], metadata['num_targets']))
target = tf.cast(target, tf.float32)
return target
# Split an iterable into fixed-size chunks
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
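# chunk relies on the two-argument form of iter: the lambda is called repeatedly until it
# returns the sentinel (), which happens once islice has nothing left to take.
# A small self-contained illustration:

```python
from itertools import islice

def chunk(it, size):
    it = iter(it)
    # Call the lambda until it returns the empty tuple, i.e. the iterator is exhausted.
    return iter(lambda: tuple(islice(it, size)), ())

chunks = list(chunk(range(5), 2))
print(chunks)  # [(0, 1), (2, 3), (4,)]
```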
# Create a float feature
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
# Build a tf.train.Example from a single (seq, target) pair
def parse_single_example(seq, target):
seq = seq.numpy()
target = target.numpy()
data = {
'seq' : _float_feature(seq.flatten()),
'target' : _float_feature(target.flatten()),
}
out = tf.train.Example(features=tf.train.Features(feature=data))
return out
# Number of tracks per species
NUM_TRACKS_CONFIG = dict(human = 5313, mouse = 1643)
# Parse the sequence and target from a serialized example
def map_seq_target(
element,
seq_len,
species, # 'human' or 'mouse'
shifts = None
):
assert species in NUM_TRACKS_CONFIG, f'{species} not found in config'
num_tracks = NUM_TRACKS_CONFIG[species]
num_shifts = 0 if shifts is None else len(list(range(shifts[0], shifts[1] + 1)))
data = {
'seq':tf.io.FixedLenFeature([(seq_len + num_shifts) * 4], tf.float32),
'target':tf.io.FixedLenFeature([896 * num_tracks], tf.float32),
}
content = tf.io.parse_single_example(element, data)
return content
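# The flat feature lengths used by map_seq_target follow directly from the shapes.
# A worked sketch of the arithmetic (feature_lengths is a hypothetical helper):

```python
NUM_TRACKS_CONFIG = dict(human = 5313, mouse = 1643)
TARGET_BINS = 896  # number of target bins, as hard-coded in map_seq_target

def feature_lengths(seq_len, species, shifts=None):
    """Return (flattened sequence length, flattened target length)."""
    num_shifts = 0 if shifts is None else len(range(shifts[0], shifts[1] + 1))
    seq_features = (seq_len + num_shifts) * 4  # 4 one-hot channels per position
    target_features = TARGET_BINS * NUM_TRACKS_CONFIG[species]
    return seq_features, target_features

print(feature_lengths(196_608, 'human', shifts=(-2, 2)))  # (786452, 4760448)
```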
# Write the TFRecord files
def create_tfrecords(ds, path = './', chunk_size = 256):
for ind, batch in enumerate(chunk(iter(ds), chunk_size)):
writer = tf.io.TFRecordWriter(f'{path}{ind}.tfrecord', 'ZLIB')
for seq, target in batch:
features = parse_single_example(seq, target)
writer.write(features.SerializeToString())
writer.close()
if __name__ == '__main__':
# Writing example
generator_fn = get_dna_sample(
bed_file = './human-sequences.bed',
fasta_file = './hg38.ml.fa',
filter_type = 'train',
context_length = 196_608
)
seq_ds = tf.data.Dataset.from_generator(generator_fn, tf.float32)
label_ds = get_dataset('human', 'train')
zipped_ds = tf.data.Dataset.zip((seq_ds, label_ds))
create_tfrecords(zipped_ds, 'gs://enformer-new-data-path/')
# Reading example
dataset = tf.data.TFRecordDataset(['./0.tfrecord', './1.tfrecord'], compression_type = 'ZLIB')
map_element_fn = partial(map_seq_target, seq_len = 196608, species = 'human', shifts = (-2, 2))
dataset = dataset.map(map_element_fn)
Enformer TPU training script (wip)
The full training script for Enformer (Tensorflow Sonnet) on TPU clusters, in an effort to migrate the model to pytorch.
This was pieced together from the Deepmind Enformer repository, the colab training notebook, as well as Basenji sequence augmentation code.
It accounts for:
- distributed TPU training
- distributed datasets
- distributed validation
- gradient clipping
- cross replica batchnorms
- dataset augmentation
Training takes about 3 days on v3-64
Downloading sequence data for extending context length to 196,608
$ gsutil cp gs://basenji_barnyard/hg38.ml.fa.gz ./ && gunzip hg38.ml.fa.gz
$ gsutil cp gs://basenji_barnyard/mm10.ml.fa.gz ./ && gunzip mm10.ml.fa.gz
$ gsutil cp gs://basenji_barnyard/data/human/sequences.bed ./human-sequences.bed
$ gsutil cp gs://basenji_barnyard/data/mouse/sequences.bed ./mouse-sequences.bed
Todo
Citations
@article {Avsec2021.04.07.438649,
author = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
title = {Effective gene expression prediction from sequence by integrating long-range interactions},
elocation-id = {2021.04.07.438649},
year = {2021},
doi = {10.1101/2021.04.07.438649},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649},
eprint = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf},
journal = {bioRxiv}
}
.\lucidrains\enformer-tensorflow-sonnet-training-script\sequence.py
# Import the required libraries
import tensorflow as tf
import numpy as np
import pandas as pd
from pyfaidx import Fasta
from functools import partial
from random import randrange
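# Build an embedding table that holds the one-hot encoding of each DNA character
# Adapted from https://gist.github.com/hannes-brt/54ca5d4094b3d96237fa2e820c0945dd
# The table covers the full ASCII range: the lowercase entries below index up to ord('t') == 116, which would overflow a table of 89 rows
embed = np.zeros([128, 4], np.float32)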
embed[ord('A')] = np.array([1, 0, 0, 0])
embed[ord('C')] = np.array([0, 1, 0, 0])
embed[ord('G')] = np.array([0, 0, 1, 0])
embed[ord('T')] = np.array([0, 0, 0, 1])
embed[ord('a')] = np.array([1, 0, 0, 0])
embed[ord('c')] = np.array([0, 1, 0, 0])
embed[ord('g')] = np.array([0, 0, 1, 0])
embed[ord('t')] = np.array([0, 0, 0, 1])
embed[ord('.')] = np.array([.25, .25, .25, .25])
# Convert the embedding table to a TensorFlow tensor
embedding_table = tf.convert_to_tensor(embed)
# One-hot encode a DNA sequence
def one_hot_encode_seq(dna_input, embed, name = "encode_seq"):
with tf.name_scope(name):
# Convert the DNA sequence to a byte array
b = bytearray()
b.extend(map(ord, str(dna_input)))
t = tf.convert_to_tensor(b)
t = tf.cast(t, tf.int32)
# Look up the one-hot encoding in the embedding table
encoded_dna = tf.nn.embedding_lookup(embedding_table, t)
return encoded_dna
# Fetch longer context based on the fasta file and pyfaidx
def get_datum(
ind,
fasta_ref,
bed_df,
context_length = None,
rand_shift_range = None
):
# Get the row from the bed dataframe
row = bed_df.iloc[ind]
chrname, start, end, t = bed_df.iloc[ind].tolist()
interval_length = end - start
chromosome = fasta_ref[chrname]
chromosome_length = len(chromosome)
if rand_shift_range is not None:
min_shift, max_shift = rand_shift_range
adj_min_shift = max(start + min_shift, 0) - start
adj_max_shift = min(end + max_shift, chromosome_length) - end
left_padding = adj_min_shift - min_shift
right_padding = max_shift - adj_max_shift
start += adj_min_shift
end += adj_max_shift
if context_length is None or context_length <= interval_length:
seq = chromosome[start:end]
return one_hot_encode_seq(seq, embed)
left_padding = right_padding = 0
extra_seq = context_length - interval_length
extra_left_seq = extra_seq // 2
extra_right_seq = extra_seq - extra_left_seq
start -= extra_left_seq
end += extra_right_seq
if start < 0:
left_padding = -start
start = 0
if end > chromosome_length:
right_padding = end - chromosome_length
end = chromosome_length
seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding)
return one_hot_encode_seq(seq, embed)
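# The context-extension step in get_datum can be isolated as pure arithmetic: grow the
# interval symmetrically to context_length, clamp it to the chromosome, and record how
# much '.' padding is needed on each side. extend_interval is a hypothetical helper that
# mirrors that logic without the shift augmentation:

```python
def extend_interval(start, end, context_length, chromosome_length):
    """Return (start, end, left_padding, right_padding) for the extended interval."""
    extra = context_length - (end - start)
    extra_left = extra // 2
    extra_right = extra - extra_left
    start -= extra_left
    end += extra_right
    left_padding = max(-start, 0)                    # overhang before position 0
    right_padding = max(end - chromosome_length, 0)  # overhang past the chromosome end
    return max(start, 0), min(end, chromosome_length), left_padding, right_padding

print(extend_interval(10, 20, 40, 1000))     # (0, 35, 5, 0)
print(extend_interval(990, 1000, 20, 1000))  # (985, 1000, 0, 5)
```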
# Build a generator function that yields DNA samples
def get_dna_sample(
bed_file,
fasta_file,
filter_type = None,
context_length = None,
rand_shift_range = (-2, 2)
):
# Read the bed file
df = pd.read_csv(bed_file, sep = '\t', header = None)
if filter_type is not None:
df = df[df[3] == filter_type]
# Open the fasta file
fasta = Fasta(fasta_file, sequence_always_upper = True)
yield_data_fn = partial(get_datum, fasta_ref = fasta, bed_df = df, context_length = context_length, rand_shift_range = rand_shift_range)
def inner():
for ind in range(len(df)):
yield yield_data_fn(ind)
return inner
# Main entry point
if __name__ == '__main__':
# Get the DNA sample generator function
generator_fn = get_dna_sample(
bed_file = './human-sequences.bed',
fasta_file = './hg38.ml.fa',
filter_type = 'valid',
context_length = 196_608
)
# Create the TensorFlow dataset
dataset = tf.data.Dataset.from_generator(generator_fn, tf.float32)
# Print the shape of the first element in the dataset
print(next(iter(dataset)).shape)
.\lucidrains\enformer-tensorflow-sonnet-training-script\train.py
# Copyright notice, stating who owns the code
# Import the required libraries and modules
import time
import os
import glob
import json
import functools
import inspect
from pathlib import Path
import tensorflow as tf
from tqdm import tqdm
import numpy as np
import pandas as pd
from typing import Any, Callable, Dict, Optional, Text, Union, Iterable, List, Sequence
import sonnet as snt
from sonnet.src import base, once, types, utils
from sonnet.src.optimizers import optimizer_utils
import wandb
# attribute
# Based on the Enformer tensorflow code, modified for distributed training
# https://github.com/deepmind/deepmind-research/tree/master/enformer
# Based on the genetic augmentation code
# https://github.com/calico/basenji/blob/84c681a4b02f592a3de90799cee7f17d96f81ef8/basenji/archive/augmentation.py
# constants
NUM_CORES_ENFORCE = 64 # using v3-64
SEQUENCE_LENGTH = 196_608
TARGET_LENGTH = 896
BIN_SIZE = 128
# assert TPUs
# Configure the TPU environment
tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='enformer')
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = snt.distribute.TpuReplicator(tpu)
num_cores = tpu_strategy.num_replicas_in_sync
# Assert that the number of cores matches the expected value
assert num_cores == NUM_CORES_ENFORCE, f'must be training on {NUM_CORES_ENFORCE} cores, but found {num_cores}'
# optimizer
# Update function for the Adam optimizer
def adam_update(g, alpha, beta_1, beta_2, epsilon, t, m, v):
"""Implements 'Algorithm 1' from :cite:`kingma2014adam`."""
m = beta_1 * m + (1. - beta_1) * g # Biased first moment estimate.
v = beta_2 * v + (1. - beta_2) * g * g # Biased second raw moment estimate.
m_hat = m / (1. - tf.pow(beta_1, t)) # Bias corrected 1st moment estimate.
v_hat = v / (1. - tf.pow(beta_2, t)) # Bias corrected 2nd moment estimate.
update = alpha * m_hat / (tf.sqrt(v_hat) + epsilon)
return update, m, v
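# To see what adam_update does numerically, the same formulas can be run on plain floats.
# On the first step (t = 1, m = v = 0) the bias correction recovers m_hat = g and
# v_hat = g * g, so the update magnitude is approximately alpha regardless of the gradient
# scale. adam_step is a hypothetical scalar re-statement, not the TensorFlow function above:

```python
import math

def adam_step(g, alpha, beta_1, beta_2, epsilon, t, m, v):
    m = beta_1 * m + (1. - beta_1) * g        # biased first moment estimate
    v = beta_2 * v + (1. - beta_2) * g * g    # biased second raw moment estimate
    m_hat = m / (1. - beta_1 ** t)            # bias corrected first moment
    v_hat = v / (1. - beta_2 ** t)            # bias corrected second moment
    update = alpha * m_hat / (math.sqrt(v_hat) + epsilon)
    return update, m, v

update, m, v = adam_step(g=0.5, alpha=1e-3, beta_1=0.9, beta_2=0.999,
                         epsilon=1e-8, t=1, m=0.0, v=0.0)
print(update, m, v)
```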
# Custom Adam optimizer class with weight decay
class Adam(base.Optimizer):
def __init__(self,
learning_rate: Union[types.FloatLike, tf.Variable] = 0.001,
beta1: Union[types.FloatLike, tf.Variable] = 0.9,
beta2: Union[types.FloatLike, tf.Variable] = 0.999,
epsilon: Union[types.FloatLike, tf.Variable] = 1e-8,
weight_decay: Union[types.FloatLike, tf.Variable] = 1e-4,
name: Optional[str] = None):
super().__init__(name=name)
self.learning_rate = learning_rate
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.weight_decay = weight_decay
# Initialize the step counter
self.step = tf.Variable(0, trainable=False, name="t", dtype=tf.int64)
self.m = []
self.v = []
@once.once
def _initialize(self, parameters: Sequence[tf.Variable]):
"""First and second order moments are initialized to zero."""
zero_var = lambda p: utils.variable_like(p, trainable=False)
with tf.name_scope("m"):
self.m.extend(zero_var(p) for p in parameters)
with tf.name_scope("v"):
self.v.extend(zero_var(p) for p in parameters)
def apply(self, updates: Sequence[types.ParameterUpdate],
parameters: Sequence[tf.Variable]):
optimizer_utils.check_distribution_strategy()
optimizer_utils.check_updates_parameters(updates, parameters)
self._initialize(parameters)
self.step.assign_add(1)
# Iterate over updates, parameters, self.m, and self.v in parallel
for update, param, m_var, v_var in zip(updates, parameters, self.m, self.v):
# Skip parameters that have no update
if update is None:
continue
# Check that update and param share the same dtype
optimizer_utils.check_same_dtype(update, param)
# Cast the learning rate to the dtype of update
learning_rate = tf.cast(self.learning_rate, update.dtype)
# Cast beta1 to the dtype of update
beta_1 = tf.cast(self.beta1, update.dtype)
# Cast beta2 to the dtype of update
beta_2 = tf.cast(self.beta2, update.dtype)
# Cast epsilon to the dtype of update
epsilon = tf.cast(self.epsilon, update.dtype)
# Cast step to the dtype of update
step = tf.cast(self.step, update.dtype)
# Compute the new update, m, and v with adam_update
update, m, v = adam_update(
g=update, alpha=learning_rate, beta_1=beta_1, beta_2=beta_2,
epsilon=epsilon, t=step, m=m_var, v=v_var)
# Compute the weight decay update, applied only to variables whose name contains 'w:0' (this excludes biases)
weight_decay_update = (param * self.weight_decay * learning_rate) if 'w:0' in param.name else tf.zeros_like(param)
# Apply the Adam update to param
param.assign_sub(update)
# Apply the weight decay term to param
param.assign_sub(weight_decay_update)
# Store the new first moment
m_var.assign(m)
# Store the new second moment
v_var.assign(v)
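# The apply method above subtracts the weight decay term separately from the Adam update,
# and only for variables whose name contains 'w:0'. A scalar sketch of that parameter step
# (apply_update and the variable names are hypothetical, for illustration only):

```python
def apply_update(param, adam_update, name, learning_rate, weight_decay):
    """Subtract the Adam update, then a decay term for weight variables only."""
    decay = param * weight_decay * learning_rate if 'w:0' in name else 0.0
    return param - adam_update - decay

new_weight = apply_update(1.0, 0.1, 'linear/w:0', learning_rate=0.01, weight_decay=0.1)
new_bias = apply_update(1.0, 0.1, 'linear/b:0', learning_rate=0.01, weight_decay=0.1)
print(new_weight, new_bias)  # the weight shrinks slightly more than the bias
```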
# MultiheadAttention class implementing multi-head attention
class MultiheadAttention(snt.Module):
"""Multi-head attention."""
def __init__(self,
value_size: int,
key_size: int,
num_heads: int,
scaling: bool = True,
attention_dropout_rate: float = 0.1,
relative_positions: bool = False,
relative_position_symmetric: bool = False,
relative_position_functions: Optional[List[str]] = None,
num_relative_position_features: Optional[int] = None,
positional_dropout_rate: float = 0.1,
zero_initialize: bool = True,
initializer: Optional[snt.initializers.Initializer] = None,
name: str = None):
"""Creates a MultiheadAttention module.
Args:
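value_size: The size of each value embedding per head.
key_size: The size of each key and query embedding per head.
num_heads: The number of independent queries per timestep.
scaling: Whether to scale the attention logits.
attention_dropout_rate: Dropout rate for attention logits.
relative_positions: Whether to use TransformerXL style relative attention.
relative_position_symmetric: If True, the symmetric version of basis functions will be used.
If False, both symmetric and asymmetric versions will be used.
relative_position_functions: List of function names used for relative positional biases.
num_relative_position_features: Number of relative positional features to compute.
If None, `value_size * num_heads` is used.
positional_dropout_rate: Dropout rate for the positional encodings if relative positions are used.
zero_initialize: If True, the final linear layer will be 0 initialized.
initializer: Initializer for the projection layers. If unspecified, VarianceScaling is used with scale = 2.0.
name: Name of the module.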
"""
super().__init__(name=name)
self._value_size = value_size
self._key_size = key_size
self._num_heads = num_heads
self._attention_dropout_rate = attention_dropout_rate
self._scaling = scaling
self._relative_positions = relative_positions
self._relative_position_symmetric = relative_position_symmetric
self._relative_position_functions = relative_position_functions
if num_relative_position_features is None:
# num_relative_position_features needs to be divisible by the number of relative positional functions * 2 (for the symmetric and asymmetric versions).
divisible_by = 2 * len(self._relative_position_functions)
self._num_relative_position_features = (
(self._value_size // divisible_by) * divisible_by)
else:
self._num_relative_position_features = num_relative_position_features
self._positional_dropout_rate = positional_dropout_rate
self._initializer = initializer
if self._initializer is None:
self._initializer = snt.initializers.VarianceScaling(scale=2.0)
key_proj_size = self._key_size * self._num_heads
embedding_size = self._value_size * self._num_heads
# Linear layers for the query, key, and value projections
self._q_layer = snt.Linear(
key_proj_size,
name='q_layer',
with_bias=False,
w_init=self._initializer)
self._k_layer = snt.Linear(
key_proj_size,
name='k_layer',
with_bias=False,
w_init=self._initializer)
self._v_layer = snt.Linear(
embedding_size,
name='v_layer',
with_bias=False,
w_init=self._initializer)
w_init = snt.initializers.Constant(1e-8) if zero_initialize else self._initializer
# Linear layer for the output embedding
self._embedding_layer = snt.Linear(
embedding_size,
name='embedding_layer',
w_init=w_init,
b_init= snt.initializers.Constant(1e-8))
# Create additional layers if using relative positions.
if self._relative_positions:
# Linear layer projecting the relative positional encodings to keys.
self._r_k_layer = snt.Linear(
key_proj_size,
name='r_k_layer',
with_bias=False,
w_init=self._initializer)
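# Bias terms for relative attention (r_w_bias is added to the query for
# the content logits, r_r_bias for the relative logits).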
self._r_w_bias = tf.Variable(
self._initializer([1, self._num_heads, 1, self._key_size],
dtype=tf.float32),
name='r_w_bias')
self._r_r_bias = tf.Variable(
self._initializer([1, self._num_heads, 1, self._key_size],
dtype=tf.float32),
name='r_r_bias')
def _multihead_output(self, linear, inputs):
"""Applies a standard linear to inputs and returns multihead output."""
# Apply the linear layer to the inputs.
output = snt.BatchApply(linear)(inputs) # [B, T, H * KV]
num_kv_channels = output.shape[-1] // self._num_heads
# Split H * Channels into separate axes.
output = snt.reshape(output,
output_shape=[-1, self._num_heads, num_kv_channels])
# [B, T, H, KV] -> [B, H, T, KV]
return tf.transpose(output, [0, 2, 1, 3])
def __call__(self,
inputs,
is_training=False):
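# Size of the concatenated per-head outputs fed to the final projection.
embedding_size = self._value_size * self._num_heads
seq_len = inputs.shape[1]
# Compute q, k and v as multi-headed projections of the inputs.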
q = self._multihead_output(self._q_layer, inputs) # [B, H, T, K]
k = self._multihead_output(self._k_layer, inputs) # [B, H, T, K]
v = self._multihead_output(self._v_layer, inputs) # [B, H, T, V]
# Scale the query by the square root of the key size.
if self._scaling:
q *= self._key_size**-0.5
if self._relative_positions:
# For relative positions, we project positions to form relative keys.
distances = tf.range(-seq_len + 1, seq_len, dtype=tf.float32)[tf.newaxis]
positional_encodings = positional_features_all(
positions=distances,
feature_size=self._num_relative_position_features,
seq_length=seq_len,
feature_functions=self._relative_position_functions,
symmetric=self._relative_position_symmetric)
# [1, 2T-1, Cr]
if is_training:
positional_encodings = tf.nn.dropout(
positional_encodings, rate=self._positional_dropout_rate)
# [1, H, 2T-1, K]
r_k = self._multihead_output(self._r_k_layer, positional_encodings)
# Add the shifted relative logits to the content logits.
# [B, H, T', T]
content_logits = tf.matmul(q + self._r_w_bias, k, transpose_b=True)
# [B, H, T', 2T-1]
relative_logits = tf.matmul(
q + self._r_r_bias, r_k, transpose_b=True)
# [B, H, T', T]
relative_logits = relative_shift(relative_logits)
logits = content_logits + relative_logits
else:
# [B, H, T', T]
logits = tf.matmul(q, k, transpose_b=True)
weights = tf.nn.softmax(logits)
# Dropout on the attention weights.
if is_training:
weights = tf.nn.dropout(weights, rate=self._attention_dropout_rate)
# Transpose and reshape the output.
output = tf.matmul(weights, v) # [B, H, T', V]
output_transpose = tf.transpose(output, [0, 2, 1, 3]) # [B, T', H, V]
# Final linear layer.
attended_inputs = snt.reshape(
output_transpose, output_shape=[embedding_size], preserve_dims=2)
output = self._embedding_layer(attended_inputs)
return output
def relative_shift(x):
"""Shift the relative logits like in TransformerXL."""
# Prepend zeros on the final timescale dimension.
to_pad = tf.zeros_like(x[..., :1])
x = tf.concat([to_pad, x], -1)
_, num_heads, t1, t2 = x.shape
x = tf.reshape(x, [-1, num_heads, t2, t1])
x = tf.slice(x, [0, 0, 1, 0], [-1, -1, -1, -1])
x = tf.reshape(x, [-1, num_heads, t1, t2 - 1])
x = tf.slice(x, [0, 0, 0, 0], [-1, -1, -1, (t2 + 1) // 2])
return x
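The index arithmetic in `relative_shift` is easy to lose track of. The following NumPy sketch is a port written for illustration only (`relative_shift_np` is a hypothetical name, not part of the original script). It fills the relative logits with the distance each column stands for, so the shifted result should hold `j - i` at query row `i` and key column `j`.

```python
import numpy as np

def relative_shift_np(x):
    """NumPy port of the shift: [B, H, T, 2T-1] -> [B, H, T, T]."""
    # Prepend one zero column, as the TensorFlow version does.
    to_pad = np.zeros_like(x[..., :1])
    x = np.concatenate([to_pad, x], axis=-1)        # [B, H, T, 2T]
    b, h, t1, t2 = x.shape
    # Reinterpret the buffer so that dropping one row skews every query row.
    x = x.reshape(b, h, t2, t1)
    x = x[:, :, 1:, :]
    x = x.reshape(b, h, t1, t2 - 1)
    # Keep the first T columns, one per key position.
    return x[..., :(t2 + 1) // 2]

T = 4
# Column r of the input corresponds to relative distance r - (T - 1).
distances = np.arange(-T + 1, T, dtype=np.float32)
x = np.broadcast_to(distances, (1, 1, T, 2 * T - 1)).copy()
shifted = relative_shift_np(x)

# Each entry now equals (key index - query index).
print(shifted[0, 0])
```

After the shift, the logit for query `i` and key `j` is the one that was computed for relative distance `j - i`, which is what allows it to be added directly to the content logits.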
# Available feature functions:
def get_positional_feature_function(name):
"""Returns positional feature functions."""
available = {
'positional_features_exponential': positional_features_exponential,
'positional_features_central_mask': positional_features_central_mask,
'positional_features_gamma': positional_features_gamma
}
if name not in available:
raise ValueError(f'Function {name} not available in {available.keys()}')
return available[name]
def positional_features_all(positions: tf.Tensor,
feature_size: int,
seq_length: Optional[int] = None,
bin_size: Optional[int] = None,
feature_functions: Optional[List[str]] = None,
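symmetric=False):
"""Compute relative positional encodings/features.
Each positional feature function will compute/provide the same fraction of
features, making up the total of feature_size.
Args:
positions: Tensor of relative positions of arbitrary shape.
feature_size: Total number of basis functions.
seq_length: Sequence length denoting the characteristic length that
the individual positional features can use. This is required since the
parametrization of the input features should be independent of `positions`
while it could still require to use the total number of features.
bin_size: Bin size used to partition the sequence. This can be used to
compute features on the absolute scale relative to the genome.
feature_functions: List of different feature functions to use. Each function
will take as argument: positions, sequence length and number of features
to compute.
symmetric: If True, the resulting features will be symmetric across the
relative position of 0 (i.e. only the absolute value of positions will
matter). If False, then both the symmetric and asymmetric versions
(symmetric multiplied by sign(positions)) of the features will be used.
Returns:
Tensor of shape: `positions.shape + (feature_size,)`.
"""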
if feature_functions is None:
feature_functions = ['positional_features_exponential',
'positional_features_central_mask',
'positional_features_gamma']
num_components = len(feature_functions) # 1 per each basis function
if not symmetric:
num_components = 2 * num_components
# For now, we do not allow odd sized embeddings.
if feature_size % num_components != 0:
raise ValueError(
f'feature_size has to be divisible by {num_components}')
feature_functions = [get_positional_feature_function(f)
for f in feature_functions]
num_basis_per_class = feature_size // num_components
embeddings = tf.concat([f(tf.abs(positions), num_basis_per_class,
seq_length, bin_size)
for f in feature_functions],
axis=-1)
if not symmetric:
embeddings = tf.concat([embeddings,
tf.sign(positions)[..., tf.newaxis] * embeddings],
axis=-1)
tf.TensorShape(embeddings.shape).assert_is_compatible_with(
positions.shape + [feature_size])
return embeddings
def _prepend_dims(x, num_dims):
return tf.reshape(x, shape=[1] * num_dims + x.shape)
def positional_features_exponential(positions: tf.Tensor,
feature_size: int,
seq_length: Optional[int] = None,
bin_size: Optional[int] = None,
min_half_life: Optional[float] = 3.0):
"""Create exponentially decaying positional weights.
Args:
positions: Position tensor (arbitrary shape).
feature_size: Number of basis functions to use.
seq_length: Sequence length.
bin_size: (unused). See `positional_features_all`.
min_half_life: Smallest exponential half life in the grid of half lives.
Returns:
A Tensor with shape [2 * seq_length - 1, feature_size].
"""
del bin_size # Unused.
# If no sequence length is given, use the largest absolute position plus 1.
if seq_length is None:
seq_length = tf.reduce_max(tf.abs(positions)) + 1
# Half-lives are powers of two, spaced between 2**min_half_life and seq_length.
seq_length = tf.cast(seq_length, dtype=tf.float32)
max_range = tf.math.log(seq_length) / tf.math.log(2.0)
half_life = tf.pow(2.0, tf.linspace(min_half_life, max_range, feature_size))
half_life = _prepend_dims(half_life, positions.shape.rank)
positions = tf.abs(positions)
# Exponentially decaying weights: 0.5 ** (|position| / half_life).
outputs = tf.exp(-tf.math.log(2.0) / half_life * positions[..., tf.newaxis])
# Check that the output shape is as expected.
tf.TensorShape(outputs.shape).assert_is_compatible_with(
positions.shape + [feature_size])
return outputs
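To make the half-life grid concrete, here is a minimal NumPy sketch of the same computation (an illustrative port; `exponential_features` is a hypothetical name and the sequence length of 1024 is chosen only to give round numbers). Note that `min_half_life` acts as an exponent, so the default of 3.0 gives a smallest half-life of 2**3 = 8 positions.

```python
import numpy as np

def exponential_features(positions, feature_size, seq_length, min_half_life=3.0):
    """Exponentially decaying features with half-lives from 2**min_half_life to seq_length."""
    max_range = np.log(seq_length) / np.log(2.0)
    half_life = 2.0 ** np.linspace(min_half_life, max_range, feature_size)
    # Broadcast positions [..., 1] against half_life [feature_size].
    return np.exp(-np.log(2.0) / half_life * np.abs(positions)[..., None])

positions = np.array([0.0, 8.0, 1024.0])
features = exponential_features(positions, feature_size=8, seq_length=1024)

print(features.shape)        # one row per position, one column per half-life
print(features[1, 0])        # distance 8 with half-life 8
print(features[2, -1])       # distance 1024 with half-life 1024
```

A feature decays to one half exactly when the distance equals its half-life, so short half-lives resolve nearby positions and long ones cover the whole sequence.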
def positional_features_central_mask(positions: tf.Tensor,
feature_size: int,
seq_length: Optional[int] = None,
bin_size: Optional[int] = None):
"""Positional features using a central mask (allow only central features)."""
del seq_length # Unused.
del bin_size # Unused.
# Widths of the central masks: 2**i - 1 for i = 1..feature_size.
center_widths = tf.pow(2.0, tf.range(1, feature_size + 1, dtype=tf.float32))
center_widths = center_widths - 1
center_widths = _prepend_dims(center_widths, positions.shape.rank)
# Build the masks: 1 where |position| falls inside the central window.
outputs = tf.cast(center_widths > tf.abs(positions)[..., tf.newaxis],
tf.float32)
# Check that the output shape is as expected.
tf.TensorShape(outputs.shape).assert_is_compatible_with(
positions.shape + [feature_size])
return outputs
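The central mask features are simple step functions. A small NumPy sketch of the same rule follows (an illustrative port; `central_mask_features` is a hypothetical name and the positions are arbitrary).

```python
import numpy as np

def central_mask_features(positions, feature_size):
    """Binary features: 1 where |position| < 2**i - 1, for i = 1..feature_size."""
    center_widths = 2.0 ** np.arange(1, feature_size + 1) - 1.0   # 1, 3, 7, 15, ...
    return (center_widths > np.abs(positions)[..., None]).astype(np.float32)

positions = np.array([-3.0, 0.0, 2.0, 7.0])
masks = central_mask_features(positions, feature_size=4)
print(masks)
```

With four features the window widths are 1, 3, 7 and 15, so a distance of 0 switches on every mask while a distance of 7 only switches on the widest one.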
def gamma_pdf(x, concentration, rate):
"""Gamma probability distribution function: p(x|concentration, rate)."""
# Evaluate the gamma density in log space, then exponentiate.
log_unnormalized_prob = tf.math.xlogy(concentration - 1., x) - rate * x
log_normalization = (tf.math.lgamma(concentration) -
concentration * tf.math.log(rate))
return tf.exp(log_unnormalized_prob - log_normalization)
def positional_features_gamma(positions: tf.Tensor,
feature_size: int,
seq_length: Optional[int] = None,
bin_size: Optional[int] = None,
stddev=None,
start_mean=None):
"""Positional features computed using the gamma distributions."""
del bin_size # Unused.
# If no sequence length is given, use the largest absolute position plus 1.
if seq_length is None:
seq_length = tf.reduce_max(tf.abs(positions)) + 1
# Default standard deviation if none is given.
if stddev is None:
stddev = seq_length / (2 * feature_size)
# Default starting mean if none is given.
if start_mean is None:
start_mean = seq_length / feature_size
# Compute the means, concentrations and rates of the gamma distributions.
mean = tf.linspace(start_mean, seq_length, num=feature_size)
mean = _prepend_dims(mean, positions.shape.rank)
concentration = (mean / stddev)**2
rate = mean / stddev**2
# Evaluate the gamma densities at the absolute positions.
probabilities = gamma_pdf(
tf.abs(tf.cast(positions, dtype=tf.float32))[..., tf.newaxis],
concentration, rate)
probabilities += 1e-8 # To ensure numerical stability.
outputs = probabilities / tf.reduce_max(probabilities)
# Check that the output shape is as expected.
tf.TensorShape(outputs.shape).assert_is_compatible_with(
positions.shape + [feature_size])
return outputs
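The gamma features are parameterized by a mean and a standard deviation rather than by concentration and rate. The sketch below (illustrative only; `gamma_pdf_np` and the values 50 and 10 are assumptions made for the example) checks numerically that `concentration = (mean / stddev)**2` and `rate = mean / stddev**2` do give a density with that mean and standard deviation.

```python
import math
import numpy as np

def gamma_pdf_np(x, concentration, rate):
    """Gamma density p(x | concentration, rate) for x > 0, computed in log space."""
    log_unnormalized = (concentration - 1.0) * np.log(x) - rate * x
    log_normalization = math.lgamma(concentration) - concentration * math.log(rate)
    return np.exp(log_unnormalized - log_normalization)

mean, stddev = 50.0, 10.0               # illustrative values
concentration = (mean / stddev) ** 2    # shape parameter
rate = mean / stddev ** 2               # inverse scale

# Riemann-sum estimates of the total mass, mean and standard deviation.
x = np.linspace(0.01, 200.0, 20000)
dx = x[1] - x[0]
pdf = gamma_pdf_np(x, concentration, rate)
total_mass = float(np.sum(pdf) * dx)
empirical_mean = float(np.sum(x * pdf) * dx)
empirical_std = float(np.sqrt(np.sum((x - empirical_mean) ** 2 * pdf) * dx))

print(total_mass, empirical_mean, empirical_std)
```

In `positional_features_gamma` the means are spread linearly from `start_mean` to `seq_length`, so each feature peaks at a different distance.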
class Enformer(snt.Module):
"""Main model."""
def __init__(self,
channels: int = 1536,
num_transformer_layers: int = 11,
num_heads: int = 8,
pooling_type: str = 'attention',
use_convnext: bool = False,
name: str = 'enformer'):
"""Enformer model.
Args:
channels: Number of convolutional filters and the overall 'width' of the
model.
num_transformer_layers: Number of transformer layers.
num_heads: Number of attention heads.
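pooling_type: Which pooling function to use. Options: 'attention' or 'max'.
use_convnext: If True, selects `convnext_block` as `conv_block_fn` (the
layers built below still call `conv_block` directly).
name: Name of sonnet module.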
"""
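# Initialize the Enformer model.
super().__init__(name=name)
# Number of output channels (tracks) per head.
heads_channels = {'human': 5313, 'mouse': 1643}
# Dropout rate.
dropout_rate = 0.4
# Check that channels is divisible by the number of heads.
assert channels % num_heads == 0, ('channels needs to be divisible '
f'by {num_heads}')
# Keyword arguments shared by all attention layers.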
whole_attention_kwargs = {
'attention_dropout_rate': 0.05,
'initializer': None,
'key_size': 64,
'num_heads': num_heads,
'num_relative_position_features': channels // num_heads,
'positional_dropout_rate': 0.01,
'relative_position_functions': [
'positional_features_exponential',
'positional_features_central_mask',
'positional_features_gamma'
],
'relative_positions': True,
'scaling': True,
'value_size': channels // num_heads,
'zero_initialize': True
}
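# Open the 'trunk' name scope (closed explicitly once the trunk is built).
trunk_name_scope = tf.name_scope('trunk')
trunk_name_scope.__enter__()
# Convolutional block: batch norm, GELU, then Conv1D.
# It relies on the `moving_averages` module having been imported.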
def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
with tf.name_scope(name or "batch_norm"):
moving_mean = moving_averages.ExponentialMovingAverage(
0.9, name="moving_mean")
moving_variance = moving_averages.ExponentialMovingAverage(
0.9, name="moving_variance")
return Sequential(lambda: [
snt.distribute.CrossReplicaBatchNorm(create_scale=True,
create_offset=True,
moving_mean = moving_mean,
moving_variance = moving_variance,
scale_init=snt.initializers.Ones()),
gelu,
snt.Conv1D(filters, width, w_init=w_init, **kwargs)
], name=name)
# ConvNeXt-style block.
def convnext_block(filters, width=1, mult = 4, ds_conv_kernel_size = 7, w_init=None, name='convnext_block', **kwargs):
return Sequential(lambda: [
ExpandDims(2),
snt.DepthwiseConv2D((ds_conv_kernel_size, 1), name ='convnext_ds_conv'),
Squeeze(2),
snt.LayerNorm(axis=-1, create_scale=True, create_offset=True),
snt.Linear(filters * mult, name='convnext_project_in'),
tf.nn.relu,
snt.Linear(filters, name='convnext_project_out')
], name=name)
# Select the block function depending on whether ConvNeXt is requested.
conv_block_fn = convnext_block if use_convnext else conv_block
# Stem module.
stem = Sequential(lambda: [
snt.Conv1D(channels // 2, 15),
Residual(conv_block(channels // 2, 1, name='pointwise_conv_block')),
pooling_module(pooling_type, pool_size=2),
], name='stem')
# Filter sizes for the conv tower.
filter_list = exponential_linspace_int(start=channels // 2, end=channels,
num=6, divisible_by=128)
# Conv tower module.
conv_tower = Sequential(lambda: [
Sequential(lambda: [
conv_block(num_filters, 5),
Residual(conv_block(num_filters, 1, name='pointwise_conv_block')),
pooling_module(pooling_type, pool_size=2),
],
name=f'conv_tower_block_{i}')
for i, num_filters in enumerate(filter_list)], name='conv_tower')
# Transformer.
# Feed-forward (MLP) sub-block of each transformer layer.
def transformer_mlp():
return Sequential(lambda: [
# Layer normalization of the inputs.
snt.LayerNorm(axis=-1, create_scale=True, create_offset=True),
# Linear projection expanding the width to channels * 2.
snt.Linear(channels * 2, name = 'project_in'),
# Dropout for regularization.
snt.Dropout(dropout_rate),
# ReLU activation.
tf.nn.relu,
# Linear projection back down to channels.
snt.Linear(channels, name = 'project_out'),
# Dropout for regularization.
snt.Dropout(dropout_rate)], name='mlp')
# Stack of transformer blocks.
transformer = Sequential(lambda: [
Sequential(lambda: [
# Residual branch: LayerNorm, multi-head attention, dropout.
Residual(Sequential(lambda: [
snt.LayerNorm(axis=-1,
create_scale=True, create_offset=True,
scale_init=snt.initializers.Ones()),
MultiheadAttention(**whole_attention_kwargs,
name=f'attention_{i}'),
snt.Dropout(dropout_rate),
], name='mha')),
# Residual branch wrapping the MLP block.
Residual(transformer_mlp())], name=f'transformer_block_{i}')
for i in range(num_transformer_layers)], name='transformer')
# Crop the sequence to the target length.
crop_final = TargetLengthCrop1D(TARGET_LENGTH, name='target_input')
# Final pointwise convolution block.
final_pointwise = Sequential(lambda: [
# Pointwise conv block expanding the width to channels * 2.
conv_block(channels * 2, 1),
# Dropout for regularization.
snt.Dropout(dropout_rate / 8),
# GELU activation.
gelu], name='final_pointwise')
# Assemble the trunk of the model.
self._trunk = Sequential([stem,
conv_tower,
transformer,
crop_final,
final_pointwise],
name='trunk')
trunk_name_scope.__exit__(None, None, None)
# Build the output heads.
with tf.name_scope('heads'):
self._heads = {
head: Sequential(
lambda: [snt.Linear(num_channels), tf.nn.softplus],
name=f'head_{head}')
for head, num_channels in heads_channels.items()
}
# pylint: enable=g-complex-comprehension,g-long-lambda,cell-var-from-loop
@property
def trunk(self):
return self._trunk
@property
def heads(self):
return self._heads
# Forward pass of the model.
def __call__(self, inputs: tf.Tensor,
is_training: bool) -> Dict[str, tf.Tensor]:
# Compute the trunk embedding.
trunk_embedding = self.trunk(inputs, is_training=is_training)
# Return the output of each head.
return {
head: head_module(trunk_embedding, is_training=is_training)
for head, head_module in self.heads.items()
}
# Prediction on a batch of inputs, used for SavedModel export.
@tf.function(input_signature=[
tf.TensorSpec([None, SEQUENCE_LENGTH, 4], tf.float32)])
def predict_on_batch(self, x):
"""Method for SavedModel."""
return self(x, is_training=False)
class TargetLengthCrop1D(snt.Module):
"""Crop sequence to match the desired target length."""
def __init__(self, target_length: int, name='target_length_crop'):
super().__init__(name=name)
self._target_length = target_length
def __call__(self, inputs):
# Number of positions to trim from each side to reach the target length.
trim = (inputs.shape[-2] - self._target_length) // 2
if trim < 0:
raise ValueError('inputs shorter than target length')
elif trim == 0:
# Nothing to crop; `inputs[..., 0:-0, :]` would return an empty tensor.
return inputs
# Crop the sequence symmetrically to match the target length.
return inputs[..., trim:-trim, :]
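The cropping rule can be checked on a plain Python list. This sketch assumes a 196,608 bp input that has been downsampled 128x by the stem and conv tower (seven pooling steps of 2), giving 1536 bins, and the target length of 896 used elsewhere in this script. `center_crop` is a hypothetical helper written for the example.

```python
def center_crop(seq, target_length):
    """Trim the same number of positions from both ends of seq."""
    trim = (len(seq) - target_length) // 2
    if trim < 0:
        raise ValueError('sequence shorter than target length')
    if trim == 0:
        # seq[0:-0] would be empty, so return the input unchanged.
        return seq
    return seq[trim:-trim]

bins = list(range(196_608 // 128))   # 1536 bin indices (assumed downsampling)
cropped = center_crop(bins, 896)

print(len(cropped), cropped[0], cropped[-1])
```

Under these assumptions 320 bins are removed from each side, so only the central part of the input contributes prediction targets.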
class ExpandDims(snt.Module):
def __init__(self, dim: int, name='expand_dims'):
super().__init__(name=name)
self._dim = dim
def __call__(self, inputs):
# Expand the dimensions of the input tensor at the specified dimension
return tf.expand_dims(inputs, self._dim)
class Squeeze(snt.Module):
def __init__(self, dim: int, name='squeeze'):
super().__init__(name=name)
self._dim = dim
def __call__(self, inputs):
# Remove dimensions of size 1 from the input tensor at the specified dimension
return tf.squeeze(inputs, self._dim)
class Sequential(snt.Module):
"""snt.Sequential automatically passing is_training where it exists."""
def __init__(self,
layers: Optional[Union[Callable[[], Iterable[snt.Module]],
Iterable[Callable[..., Any]]]] = None,
name: Optional[Text] = None):
super().__init__(name=name)
if layers is None:
self._layers = []
else:
# layers wrapped in a lambda function to have a common namespace.
if hasattr(layers, '__call__'):
with tf.name_scope(name):
layers = layers()
self._layers = [layer for layer in layers if layer is not None]
def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):
outputs = inputs
for _, mod in enumerate(self._layers):
if accepts_is_training(mod):
outputs = mod(outputs, is_training=is_training, **kwargs)
else:
outputs = mod(outputs, **kwargs)
return outputs
def pooling_module(kind, pool_size):
"""Pooling module wrapper."""
if kind == 'attention':
return SoftmaxPooling1D(pool_size=pool_size, per_channel=True,
w_init_scale=2.0)
elif kind == 'max':
return tf.keras.layers.MaxPool1D(pool_size=pool_size, padding='same')
else:
raise ValueError(f'Invalid pooling kind: {kind}.')
class SoftmaxPooling1D(snt.Module):
"""Pooling operation with optional weights."""
def __init__(self,
pool_size: int = 2,
per_channel: bool = False,
w_init_scale: float = 0.0,
name: str = 'softmax_pooling'):
"""Softmax pooling.
Args:
pool_size: Pooling size, same as in Max/AvgPooling.
per_channel: If True, the logits/softmax weights will be computed for
each channel separately. If False, same weights will be used across all
channels.
w_init_scale: When 0.0 is equivalent to avg pooling, and when
~2.0 and `per_channel=False` it's equivalent to max pooling.
name: Module name.
"""
super().__init__(name=name)
self._pool_size = pool_size
self._per_channel = per_channel
self._w_init_scale = w_init_scale
self._logit_linear = None
@snt.once
def _initialize(self, num_features):
# Initialize the linear layer for computing logits
self._logit_linear = snt.Linear(
output_size=num_features if self._per_channel else 1,
with_bias=False, # Softmax is agnostic to shifts.
w_init=snt.initializers.Identity(self._w_init_scale))
def __call__(self, inputs):
_, length, num_features = inputs.shape
self._initialize(num_features)
# Reshape the inputs for pooling operation
inputs = tf.reshape(
inputs,
(-1, length // self._pool_size, self._pool_size, num_features))
# Perform softmax pooling operation
return tf.reduce_sum(
inputs * tf.nn.softmax(self._logit_linear(inputs), axis=-2),
axis=-2)
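The docstring's claim that the weight scale interpolates between average and max pooling can be sketched in NumPy. This is an illustrative port, not the Sonnet module: `softmax_pool_np` is a hypothetical name, the weight matrix is passed in explicitly as `scale * identity`, and the scale of 50 is exaggerated to make the max-like behaviour obvious.

```python
import numpy as np

def softmax_pool_np(x, w, pool_size=2):
    """Softmax-weighted pooling over non-overlapping windows of pool_size."""
    batch, length, channels = x.shape
    x = x.reshape(batch, length // pool_size, pool_size, channels)
    logits = x @ w                                           # [B, L/p, p, C]
    logits = logits - logits.max(axis=-2, keepdims=True)     # stable softmax
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-2, keepdims=True)  # softmax within each window
    return (x * weights).sum(axis=-2)

x = np.array([[[0.0], [1.0], [4.0], [2.0]]])   # [batch=1, length=4, channels=1]
avg_like = softmax_pool_np(x, 0.0 * np.eye(1))   # zero weights: uniform softmax
max_like = softmax_pool_np(x, 50.0 * np.eye(1))  # large weights: softmax picks the max

print(avg_like.ravel(), max_like.ravel())
```

With zero weights every position in a window gets the same softmax weight, which is average pooling. As the scale grows the largest value dominates, and in the module the weights are trainable so the model can settle anywhere in between.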
class Residual(snt.Module):
"""Residual block."""
def __init__(self, module: snt.Module, name='residual'):
super().__init__(name=name)
self._module = module
def __call__(self, inputs: tf.Tensor, is_training: bool, *args,
**kwargs) -> tf.Tensor:
# Return the sum of the inputs and the module output.
return inputs + self._module(inputs, is_training, *args, **kwargs)
# GELU (Gaussian error linear unit) activation function.
def gelu(x: tf.Tensor) -> tf.Tensor:
"""Applies the Gaussian error linear unit (GELU) activation function.
Using approximation in section 2 of the original paper:
https://arxiv.org/abs/1606.08415
Args:
x: Input tensor to apply gelu activation.
Returns:
Tensor with gelu activation applied to it.
"""
return tf.nn.sigmoid(1.702 * x) * x
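For a sense of how close the sigmoid form is to the exact GELU, the following standard-library sketch compares the two on a small grid (`gelu_sigmoid` and `gelu_exact` are hypothetical names for this example; the exact form uses the Gaussian CDF via `math.erf`).

```python
import math

def gelu_sigmoid(x):
    """Sigmoid approximation used above: x * sigmoid(1.702 * x)."""
    return x / (1.0 + math.exp(-1.702 * x))

def gelu_exact(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

xs = [k / 4.0 for k in range(-16, 17)]          # -4.0 to 4.0 in steps of 0.25
max_error = max(abs(gelu_sigmoid(x) - gelu_exact(x)) for x in xs)
print(max_error)
```

On this grid the two curves differ by a few hundredths at most, which is why the cheaper sigmoid form is a common substitute.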
# One-hot encode a sequence.
def one_hot_encode(sequence: str,
alphabet: str = 'ACGT',
neutral_alphabet: str = 'N',
neutral_value: Any = 0,
dtype=np.float32) -> np.ndarray:
"""One-hot encode sequence."""
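# Convert a string to an array of uint8 ASCII codes.
def to_uint8(string):
return np.frombuffer(string.encode('ascii'), dtype=np.uint8)
# Lookup table mapping each ASCII code to its one-hot row (zeros by default).
hash_table = np.zeros((np.iinfo(np.uint8).max, len(alphabet)), dtype=dtype)
# Fill in the one-hot rows for the alphabet characters.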
hash_table[to_uint8(alphabet)] = np.eye(len(alphabet), dtype=dtype)
hash_table[to_uint8(neutral_alphabet)] = neutral_value
hash_table = hash_table.astype(dtype)
return hash_table[to_uint8(sequence)]
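A short usage example of the lookup-table approach, with the function repeated so the snippet runs on its own (the input string `'ACGTN'` is just an illustration):

```python
import numpy as np

def one_hot_encode(sequence, alphabet='ACGT', neutral_alphabet='N',
                   neutral_value=0, dtype=np.float32):
    """One-hot encode a DNA string through an ASCII-indexed lookup table."""
    def to_uint8(string):
        return np.frombuffer(string.encode('ascii'), dtype=np.uint8)
    hash_table = np.zeros((np.iinfo(np.uint8).max, len(alphabet)), dtype=dtype)
    hash_table[to_uint8(alphabet)] = np.eye(len(alphabet), dtype=dtype)
    hash_table[to_uint8(neutral_alphabet)] = neutral_value
    return hash_table[to_uint8(sequence)]

encoded = one_hot_encode('ACGTN')
print(encoded.shape)
print(encoded)
```

Each of A, C, G and T maps to its own unit vector, while `N` (and any other character not in the alphabet) maps to an all-zero row.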
# Generate an exponentially increasing sequence of integers.
def exponential_linspace_int(start, end, num, divisible_by=1):
"""Exponentially increasing values of integers."""
def _round(x):
return int(np.round(x / divisible_by) * divisible_by)
base = np.exp(np.log(end / start) / (num - 1))
return [_round(start * base**i) for i in range(num)]
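As a worked example, the call made when building the conv tower with the default `channels = 1536` can be reproduced with the standard library alone. This sketch swaps `np.round` for Python's built-in `round`, which agrees here because no value falls exactly on a rounding tie.

```python
import math

def exponential_linspace_int(start, end, num, divisible_by=1):
    """Exponentially spaced integers from start to end, rounded to a multiple."""
    def _round(x):
        return int(round(x / divisible_by) * divisible_by)
    base = math.exp(math.log(end / start) / (num - 1))
    return [_round(start * base ** i) for i in range(num)]

# channels // 2 = 768 up to channels = 1536, in 6 steps, multiples of 128.
filter_list = exponential_linspace_int(start=768, end=1536, num=6, divisible_by=128)
print(filter_list)
```

The filter counts grow by a constant factor per block (about 1.15x here) before being snapped to the nearest multiple of 128.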
# Check whether a module's __call__ accepts an is_training argument.
def accepts_is_training(module):
return 'is_training' in list(inspect.signature(module.__call__).parameters)
# Fetch the target (track) table for the given organism.
def get_targets(organism):
targets_txt = f'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_{organism}.txt'
return pd.read_csv(targets_txt, sep='\t')
# Reverse complement of a one-hot encoded sequence.
def reverse_complement_transform(seq):
"""Reverse complement of a one-hot seq; axis 0 is reversed, so seq is expected as [length, 4]."""
# Swap A<->T and C<->G, then reverse along the sequence axis.
seq_rc = tf.gather(seq, [3, 2, 1, 0], axis=-1)
seq_rc = tf.reverse(seq_rc, axis=[0])
return seq_rc
# Shift a sequence left or right by a given number of positions.
def shift_sequence(seq, shift_amount, pad_value=0.25):
"""Shift a sequence left or right by shift_amount.
Args:
seq: a [sequence_length, sequence_depth] sequence to shift (unbatched,
since the shift is applied along axis 0)
shift_amount: the signed amount to shift (tf.int32 or int)
pad_value: value to fill the padding (primitive or scalar tf.Tensor)
"""
input_shape = seq.shape
pad = pad_value * tf.ones_like(seq[0:tf.abs(shift_amount), :])
def _shift_right(_seq):
sliced_seq = _seq[:-shift_amount:, :]
return tf.concat([pad, sliced_seq], axis=0)
def _shift_left(_seq):
sliced_seq = _seq[-shift_amount:, :]
return tf.concat([sliced_seq, pad], axis=0)
output = tf.cond(
tf.greater(shift_amount, 0), lambda: _shift_right(seq),
lambda: _shift_left(seq))
output.set_shape(input_shape)
return output
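The shift-and-pad behaviour is easier to see on a tiny array. Below is a NumPy sketch of the same logic for a single unbatched `[length, depth]` sequence (`shift_sequence_np` is a hypothetical name; the zero-shift branch is added here for completeness, whereas the script above skips the call when the shift is zero).

```python
import numpy as np

def shift_sequence_np(seq, shift_amount, pad_value=0.25):
    """Shift seq along axis 0, filling the vacated rows with pad_value."""
    if shift_amount == 0:
        return seq.copy()
    pad = np.full_like(seq[:abs(shift_amount)], pad_value)
    if shift_amount > 0:
        # Shift right: pad at the start, drop the tail.
        return np.concatenate([pad, seq[:-shift_amount]], axis=0)
    # Shift left: drop the head, pad at the end.
    return np.concatenate([seq[-shift_amount:], pad], axis=0)

seq = np.eye(4, dtype=np.float32)        # toy one-hot sequence "ACGT"
right = shift_sequence_np(seq, 1)
left = shift_sequence_np(seq, -1)
print(right)
print(left)
```

The pad value of 0.25 corresponds to a uniform distribution over the four bases, so padded positions carry no base-specific information.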
# Apply a stochastic shift augmentation.
def augment_stochastic_shifts(seq, augment_shifts):
"""Apply a stochastic shift augmentation.
Args:
seq: input sequence of size [length, depth] (unbatched; see shift_sequence)
augment_shifts: list of int offsets to sample from
Returns:
shifted and padded sequence of size [length, depth]
"""
shift_index = tf.random.uniform(shape=[], minval=0,
maxval=len(augment_shifts), dtype=tf.int64)
shift_value = tf.gather(tf.constant(augment_shifts), shift_index)
seq = tf.cond(tf.not_equal(shift_value, 0),
lambda: shift_sequence(seq, shift_value),
lambda: seq)
return seq
# Dataset map function applying the stochastic shift augmentation.
def augment_stochastic_shifts_map_fn(datum):
augment_shifts = [-2, -1, 0, 1, 2]
return dict(
sequence = augment_stochastic_shifts(datum['sequence'], augment_shifts),
target = datum['target']
)
# Dataset map function applying the stochastic reverse-complement augmentation.
def augment_stochastic_rc_map_fn(datum):
sequence, target = (datum['sequence'], datum['target'])
augment = tf.random.uniform(shape=[]) > 0.5
sequence, target = tf.cond(augment, lambda: (sequence[::-1, ::-1], target[::-1, :]),
lambda: (sequence, target))
return dict(sequence = sequence, target = target)
# Path for the given organism.
def organism_path(organism):
# Google Cloud Storage path holding the data for this organism.
return os.path.join('gs://basenji_barnyard/data', organism)
def get_dataset(organism, subset, num_threads=8, shuffle=True, rotate = 0, augment = False):
# Load the metadata for the organism.
metadata = get_metadata(organism)
# List the TFRecord files for the organism and subset.
files = tfrecord_files(organism, subset)
# Rotate the file list so that reading starts `rotate` files in.
files = files[rotate:] + files[:rotate]
# Create the TFRecord dataset.
dataset = tf.data.TFRecordDataset(files,
compression_type='ZLIB',
num_parallel_reads=num_threads)
if shuffle:
# Repeat indefinitely, then shuffle with a 5000-element buffer.
dataset = dataset.repeat()
dataset = dataset.shuffle(5000, seed = 42)
# Deserialize each element.
dataset = dataset.map(functools.partial(deserialize, metadata=metadata),
num_parallel_calls=num_threads)
if augment:
# Apply the shift and reverse-complement augmentations.
dataset = dataset.map(augment_stochastic_shifts_map_fn, num_parallel_calls=num_threads)
dataset = dataset.map(augment_stochastic_rc_map_fn, num_parallel_calls=num_threads)
return dataset
def get_metadata(organism):
# Load the statistics (metadata) file for the organism.
path = os.path.join(organism_path(organism), 'statistics.json')
with tf.io.gfile.GFile(path, 'r') as f:
return json.load(f)
def tfrecord_files(organism, subset):
# TFRecord files for the organism and subset, sorted by the number in the file name.
return sorted(tf.io.gfile.glob(os.path.join(
organism_path(organism), 'tfrecords', f'{subset}-*.tfr'
)), key=lambda x: int(x.split('-')[-1].split('.')[0]))
def deserialize(serialized_example, metadata):
"""Deserialize bytes stored in TFRecordFile."""
# Feature spec of the serialized examples.
feature_map = {
'sequence': tf.io.FixedLenFeature([], tf.string),
'target': tf.io.FixedLenFeature([], tf.string),
}
# Parse the sequence and target features.
example = tf.io.parse_example(serialized_example, feature_map)
# Decode the sequence and convert it to [seq_length, 4] float32.
sequence = tf.io.decode_raw(example['sequence'], tf.bool)
sequence = tf.reshape(sequence, (metadata['seq_length'], 4))
sequence = tf.cast(sequence, tf.float32)
# Decode the target and convert it to [target_length, num_targets] float32.
target = tf.io.decode_raw(example['target'], tf.float16)
target = tf.reshape(target,
(metadata['target_length'], metadata['num_targets']))
target = tf.cast(target, tf.float32)
return {'sequence': sequence,
'target': target}
# New get_dataset functions, for sequences that are actually 196_608 long.
NEW_TFRECORD_LOCATIONS = dict(
human = dict(
train = 'gs://enformer-human-train/',
valid = 'gs://enformer-human-valid/'
),
mouse = dict(
train = 'gs://enformer-mouse-train/',
valid = 'gs://enformer-mouse-valid/'
)
)
NUM_TRACKS_CONFIG = dict(human = 5313, mouse = 1643)
def new_dataset_map_seq_target(
element,
seq_len,
species, # 'human' or 'mouse'
target_length = 896,
shifts = None,
augment_rc = False
):
assert species in NUM_TRACKS_CONFIG, f'{species} not found in config'
num_tracks = NUM_TRACKS_CONFIG[species]
num_shifts = 0 if shifts is None else len(list(range(shifts[0], shifts[1] + 1)))
data = {
'seq': tf.io.FixedLenFeature([(seq_len + num_shifts) * 4], tf.float32),
'target': tf.io.FixedLenFeature([target_length * num_tracks], tf.float32),
}
content = tf.io.parse_single_example(element, data)
content['sequence'] = content.pop('seq')
content['sequence'] = tf.reshape(content['sequence'], (-1, 4))
content['target'] = tf.reshape(content['target'], (target_length, -1))
# Handle the shift augmentation.
shifts = tf.pad(tf.random.uniform(shape = [1], minval = 0, maxval = num_shifts, dtype = tf.int64), [[0, 1]])
content['sequence'] = tf.slice(content['sequence'], shifts, (seq_len, -1))
if augment_rc:
content = augment_stochastic_rc_map_fn(content)
content['sequence'].set_shape(tf.TensorShape([seq_len, 4]))
content['target'].set_shape(tf.TensorShape([target_length, num_tracks]))
return content
def get_dataset_new(
organism,
datatype,
shifts = (-2, 2),
augment_rc = False,
num_threads = 8
):
# TFRecord location for the given organism and data type.
gcs_path = NEW_TFRECORD_LOCATIONS[organism][datatype]
# All files ending in .tfrecord under that path, sorted by file name.
files = sorted(tf.io.gfile.glob(f'{gcs_path}*.tfrecord'))
# Create the TFRecord dataset (ZLIB compression, num_threads parallel reads).
dataset = tf.data.TFRecordDataset(files, compression_type='ZLIB', num_parallel_reads=num_threads)
# Partially apply the per-element parsing function.
map_element_fn = partial(new_dataset_map_seq_target, seq_len=SEQUENCE_LENGTH, species=organism, shifts=shifts, augment_rc=augment_rc)
dataset = dataset.map(map_element_fn)
# Return the processed dataset.
return dataset
# Pearson correlation coefficient.
def corr_coef(x, y, eps=0):
# Element-wise squares and product.
x2 = tf.math.square(x)
y2 = tf.math.square(y)
xy = x * y
# Means over axis 1 (the sequence axis).
ex = tf.reduce_mean(x, axis=1)
ey = tf.reduce_mean(y, axis=1)
exy = tf.reduce_mean(xy, axis=1)
ex2 = tf.reduce_mean(x2, axis=1)
ey2 = tf.reduce_mean(y2, axis=1)
# Pearson r = cov(x, y) / (std(x) * std(y)), with eps for numerical stability.
r = (exy - ex * ey) / ((tf.math.sqrt(ex2 - tf.math.square(ex) + eps) * tf.math.sqrt(ey2 - tf.math.square(ey) + eps)) + eps)
# Average r over the last axis (tracks).
return tf.reduce_mean(r, axis=-1)
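The moment-based formula above can be cross-checked against `np.corrcoef`. The sketch below is a NumPy port for illustration (`corr_coef_np` is a hypothetical name; the random data and the shapes `[batch=2, length=50, tracks=3]` are arbitrary) that returns the per-track correlations before the final average.

```python
import numpy as np

def corr_coef_np(x, y, eps=0.0):
    """Pearson r along axis 1 for inputs of shape [batch, length, tracks]."""
    ex, ey = x.mean(axis=1), y.mean(axis=1)
    exy = (x * y).mean(axis=1)
    ex2, ey2 = (x ** 2).mean(axis=1), (y ** 2).mean(axis=1)
    denom = np.sqrt(ex2 - ex ** 2 + eps) * np.sqrt(ey2 - ey ** 2 + eps) + eps
    return (exy - ex * ey) / denom          # [batch, tracks]

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 50, 3))
y = 2.0 * x + rng.normal(size=(2, 50, 3))  # correlated with x, plus noise

r = corr_coef_np(x, y)
reference = np.array([[np.corrcoef(x[b, :, t], y[b, :, t])[0, 1]
                       for t in range(3)] for b in range(2)])
per_example = r.mean(axis=-1)               # what corr_coef returns: one value per example
print(np.abs(r - reference).max())
```

The correlation is computed along the sequence axis for each track separately and only then averaged over tracks, so every track contributes equally regardless of its scale.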
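# Create the evaluation step function.
def create_eval_step(model, head):
@tf.function
def predict(seq, target):
# Run the model in inference mode.
pred = model(seq, is_training=False)[head]
# Return the correlation coefficient between predictions and targets.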
return corr_coef(pred, target)
return predict
# Create the training step function.
def create_step_function(model, optimizer, head, clip_grad_norm=1.0, weight_decay=0.0001):
@tf.function
def train_step(batch_seq, batch_target):
with tf.GradientTape() as tape:
with snt.mixed_precision.scope(tf.float16):
outputs = model(batch_seq, is_training=True)[head]
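# Correlation-coefficient loss (computed here but not added to the total loss).
corr_coef_loss = 1 - corr_coef(outputs, batch_target, eps=1e-8)
# Poisson loss.
poisson = tf.reduce_mean(tf.keras.losses.poisson(batch_target, outputs))
# The total loss is the Poisson loss only.
loss = poisson
# Compute the gradients.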
gradients = tape.gradient(loss, model.trainable_variables, unconnected_gradients=tf.UnconnectedGradients.ZERO)
gradients = [tf.clip_by_norm(grad, clip_grad_norm) for grad in gradients]
ctx = tf.distribute.get_replica_context()
gradients = ctx.all_reduce("mean", gradients)
optimizer.apply(gradients, model.trainable_variables)
return loss
return train_step
# Instantiate the model and the train / eval functions.
with tpu_strategy.scope():
# Create the Enformer model.
model = Enformer(channels=1536, num_heads=8, num_transformer_layers=11)
# Learning rate variable (updated during warmup).
learning_rate = tf.Variable(0., trainable=False, name='learning_rate')
# Adam optimizer.
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)
# Training step functions for the human and mouse heads.
train_step_human = create_step_function(model, optimizer, 'human')
train_step_mouse = create_step_function(model, optimizer, 'mouse')
# Evaluation step functions for the human and mouse heads.
eval_step_human = create_eval_step(model, 'human')
eval_step_mouse = create_eval_step(model, 'mouse')
# Experiment tracking.
wandb.init(project='enformer')
wandb.run.save()
# Training configuration.
num_steps = int(2e6)
num_warmup_steps = 5000
target_learning_rate = 5e-4
checkpoint_every = 2500
max_eval_steps = 25
eval_every = 500
# Global step variable.
global_step = tf.Variable(0, name='global_step', trainable=False)
# Checkpointing.
checkpoint_root = "gs://enformer/"
checkpoint_name = "enformer"
save_prefix = os.path.join(checkpoint_root, checkpoint_name)
checkpoint = tf.train.Checkpoint(module=model, step=global_step, optimizer=optimizer)
# Restore from the latest checkpoint, if there is one.
latest = tf.train.latest_checkpoint(checkpoint_root)
if latest is not None:
checkpoint.restore(latest)
@tf.function
def step():
global_step.assign(global_step + 1)
batch_human, batch_mouse = next(data_it)
loss_human = tpu_strategy.run(train_step_human, args=(batch_human['sequence'], batch_human['target']))
loss_mouse = tpu_strategy.run(train_step_mouse, args=(batch_mouse['sequence'], batch_mouse['target']))
loss_human = tpu_strategy.reduce('mean', loss_human, axis=None)
loss_mouse = tpu_strategy.reduce('mean', loss_mouse, axis=None)
learning_rate_frac = tf.math.minimum(1.0, tf.cast(global_step, tf.float32) / tf.math.maximum(1.0, float(num_warmup_steps)))
learning_rate.assign(target_learning_rate * learning_rate_frac)
return loss_human, loss_mouse
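The learning-rate update inside `step` is a linear warmup followed by a constant rate. A plain-Python sketch of the same schedule, using the values configured in this script (`warmup_lr` is a hypothetical helper name):

```python
def warmup_lr(step, target_lr=5e-4, num_warmup_steps=5000):
    """Linear warmup from 0 to target_lr, then constant."""
    frac = min(1.0, step / max(1.0, float(num_warmup_steps)))
    return target_lr * frac

for s in (0, 2500, 5000, 10000):
    print(s, warmup_lr(s))
```

The rate climbs linearly to the target during the first 5000 steps and stays there, and the `max(1.0, ...)` guard avoids a division by zero if warmup is disabled.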
# Evaluation step: runs each head on one validation batch.
@tf.function
def eval_step():
# Next human and mouse validation batches.
batch_human = next(valid_human_data_it)
batch_mouse = next(valid_mouse_data_it)
# Run the evaluation steps under the TPU strategy.
human_r = tpu_strategy.run(eval_step_human, args = (batch_human['sequence'], batch_human['target']))
mouse_r = tpu_strategy.run(eval_step_mouse, args = (batch_mouse['sequence'], batch_mouse['target']))
# Mean-reduce the results across replicas.
human_r = tpu_strategy.reduce('mean', human_r, axis = 0)
mouse_r = tpu_strategy.reduce('mean', mouse_r, axis = 0)
# Return the human and mouse evaluation results.
return human_r, mouse_r
# Current global step.
i = global_step.numpy()
# Total number of mouse and human training examples.
total_mice = 114 * 256 + 111
total_human = 132 * 256 + 229
bucket_size = 256
num_seen = i * num_cores
# Number of files to skip for each organism when resuming.
human_file_skip = (num_seen % total_human) // bucket_size
mouse_file_skip = (num_seen % total_mice) // bucket_size
# Build the human and mouse training datasets, starting from the skipped files.
human_dataset = get_dataset('human', 'train', rotate = human_file_skip).batch(num_cores, drop_remainder = True)
mouse_dataset = get_dataset('mouse', 'train', rotate = mouse_file_skip).batch(num_cores, drop_remainder = True)
# Zip the two datasets and prefetch.
human_mouse_dataset = tf.data.Dataset.zip((human_dataset, mouse_dataset)).prefetch(2)
# Human and mouse validation datasets.
human_valid_dataset = get_dataset('human', 'valid', shuffle = False).repeat().batch(num_cores)
mouse_valid_dataset = get_dataset('mouse', 'valid', shuffle = False).repeat().batch(num_cores)
# Create the dataset iterators.
data_it = iter(tpu_strategy.experimental_distribute_dataset(human_mouse_dataset))
valid_human_data_it = iter(tpu_strategy.experimental_distribute_dataset(human_valid_dataset))
valid_mouse_data_it = iter(tpu_strategy.experimental_distribute_dataset(mouse_valid_dataset))
# Print the starting step.
print(f'starting from {i}')
# Training loop.
while i < num_steps:
print(f'processing step {i}')
# Run a training step and fetch the human and mouse losses.
loss_human, loss_mouse = step()
loss_human = loss_human.numpy()
loss_mouse = loss_mouse.numpy()
learning_rate_numpy = learning_rate.numpy()
print(f'completed step {i}')
# Log the losses and the learning rate.
log = {
'loss_human': loss_human,
'loss_mouse': loss_mouse,
'learning_rate': learning_rate_numpy
}
# Evaluate every `eval_every` steps.
if i and not i % eval_every:
print('evaluating')
# Run an evaluation step to get the human and mouse Pearson correlations.
human_pearson_r, mouse_pearson_r = eval_step()
human_pearson_r = human_pearson_r.numpy()
mouse_pearson_r = mouse_pearson_r.numpy()
# Add them to the log.
log = {
**log,
'human_pearson_r': human_pearson_r,
'mouse_pearson_r': mouse_pearson_r
}
# Write the log to Weights & Biases.
wandb.log(log, step = i)
# Save a checkpoint every `checkpoint_every` steps.
if not i % checkpoint_every:
print('checkpointing')
checkpoint.save(save_prefix)
# Advance the step counter.
i += 1
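When training resumes from a checkpoint, the script rotates the file list so that reading starts roughly where it left off. The arithmetic can be sketched as follows, using the dataset sizes given above. The step count of 10,000 and the 64 cores are hypothetical values for the example, and `files_to_skip` is a hypothetical helper name.

```python
def files_to_skip(step, num_cores, total_examples, bucket_size=256):
    """Number of TFRecord files already consumed in the current pass over the data."""
    num_seen = step * num_cores              # one example per core per step
    return (num_seen % total_examples) // bucket_size

total_human = 132 * 256 + 229                # 34021 training examples
total_mice = 114 * 256 + 111                 # 29295 training examples

human_skip = files_to_skip(10_000, 64, total_human)
mouse_skip = files_to_skip(10_000, 64, total_mice)
print(human_skip, mouse_skip)
```

Each file holds 256 examples (the last one fewer), so the position within the current epoch divided by 256 gives the number of whole files to rotate past. Because the dataset is shuffled after reading, this only approximates the exact resume point.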