Lucidrains Series Projects Source Code Analysis (Part 13)

.\lucidrains\diffusion-policy\setup.py

# import setup utilities and the package discovery helper
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'diffusion-policy', # package name
  packages = find_packages(exclude=[]), # find all packages
  version = '0.0.1', # version number
  license='MIT', # license
  description = 'Diffusion Policy', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # long description content type
  url = 'https://github.com/lucidrains/diffusion-policy', # URL
  keywords = [ # keyword list
    'artificial intelligence',
    'deep learning',
    'robotics',
    'denoising diffusion',
    'policy network',
    'transformers'
  ],
  install_requires=[ # install dependencies
    'accelerate',
    'beartype',
    'einops>=0.7.0',
    'ema-pytorch',
    'torch>=2.1',
    'torchvision'
  ],
  classifiers=[ # classifier list
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\discrete-key-value-bottleneck-pytorch\discrete_key_value_bottleneck_pytorch\discrete_key_value_bottleneck.py

# import torch
import torch
# import nn and einsum from torch
from torch import nn, einsum
# import rearrange, repeat, reduce from einops
from einops import rearrange, repeat, reduce
# import the VectorQuantize class from vector_quantize_pytorch
from vector_quantize_pytorch import VectorQuantize

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    return val if exists(val) else d

# main class

class DiscreteKeyValueBottleneck(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_memories,
        dim_embed = None,
        num_memory_codebooks = 1,
        encoder = None,
        dim_memory = None,
        average_pool_memories = True,
        **kwargs
    ):
        super().__init__()
        self.encoder = encoder
        dim_embed = default(dim_embed, dim)
        self.dim_embed = dim_embed

        # create the VectorQuantize module
        self.vq = VectorQuantize(
            dim = dim * num_memory_codebooks,
            codebook_size = num_memories,
            heads = num_memory_codebooks,
            separate_codebook_per_head = True,
            **kwargs
        )

        dim_memory = default(dim_memory, dim)
        # the learnable values, one table of memories per codebook, stored as an nn.Parameter
        self.values = nn.Parameter(torch.randn(num_memory_codebooks, num_memories, dim_memory))

        # create a random projection matrix, one per codebook
        rand_proj = torch.empty(num_memory_codebooks, dim_embed, dim)
        nn.init.xavier_normal_(rand_proj)

        # register the random projection as a buffer
        self.register_buffer('rand_proj', rand_proj)
        self.average_pool_memories = average_pool_memories

    def forward(
        self,
        x,
        return_intermediates = False,
        average_pool_memories = None,
        **kwargs
    ):
        average_pool_memories = default(average_pool_memories, self.average_pool_memories)

        if exists(self.encoder):
            self.encoder.eval()
            with torch.no_grad():
                x = self.encoder(x, **kwargs)
                x.detach_()

        # check that the final dimension of the encoding matches dim_embed
        assert x.shape[-1] == self.dim_embed, f'encoding has a dimension of {x.shape[-1]} but dim_embed (defaults to dim) is set to {self.dim_embed} on init'

        # project the embedding with a separate random projection per codebook head
        x = einsum('b n d, c d e -> b n c e', x, self.rand_proj)
        # flatten the heads back into the feature dimension
        x = rearrange(x, 'b n c e -> b n (c e)')

        # vector quantize x
        vq_out = self.vq(x)

        quantized, memory_indices, commit_loss = vq_out

        if memory_indices.ndim == 2:
            memory_indices = rearrange(memory_indices, '... -> ... 1')

        memory_indices = rearrange(memory_indices, 'b n h -> b h n')

        values = repeat(self.values, 'h n d -> b h n d', b = memory_indices.shape[0])
        memory_indices = repeat(memory_indices, 'b h n -> b h n d', d = values.shape[-1])

        memories = values.gather(2, memory_indices)

        if average_pool_memories:
            memories = reduce(memories, 'b h n d -> b n d', 'mean')

        if return_intermediates:
            return memories, vq_out

        return memories
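
The heart of the forward pass above is the key-to-value lookup: the vector quantizer returns one memory index per codebook head, and those indices gather rows out of the learned `values` table. As a standalone sketch of just that gather, with hypothetical small sizes (2 heads, 4 memories, value dimension 8, sequence length 16), the einops shuffling looks like this:

import torch
from einops import rearrange, repeat, reduce

values = torch.randn(2, 4, 8)                      # (heads, num_memories, dim_memory), like self.values
memory_indices = torch.randint(0, 4, (1, 16, 2))   # (batch, seq, heads), like the quantizer's indices

memory_indices = rearrange(memory_indices, 'b n h -> b h n')
values = repeat(values, 'h n d -> b h n d', b = memory_indices.shape[0])
memory_indices = repeat(memory_indices, 'b h n -> b h n d', d = values.shape[-1])

memories = values.gather(2, memory_indices)              # (1, 2, 16, 8)
memories = reduce(memories, 'b h n d -> b n d', 'mean')  # (1, 16, 8) after average pooling the heads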

.\lucidrains\discrete-key-value-bottleneck-pytorch\discrete_key_value_bottleneck_pytorch\__init__.py

# import the DiscreteKeyValueBottleneck class from the discrete_key_value_bottleneck_pytorch module
from discrete_key_value_bottleneck_pytorch.discrete_key_value_bottleneck import DiscreteKeyValueBottleneck

Discrete Key / Value Bottleneck - Pytorch

Implementation of Discrete Key / Value Bottleneck, in Pytorch.

Install

$ pip install discrete-key-value-bottleneck-pytorch

Usage

import torch
from discrete_key_value_bottleneck_pytorch import DiscreteKeyValueBottleneck

key_value_bottleneck = DiscreteKeyValueBottleneck(
    dim = 512,                  # input dimension
    dim_memory = 512,           # output dimension - or the dimension of each memory, shared across heads (defaults to same as input)
    num_memory_codebooks = 2,   # number of memory codebooks; the embedding is split into 2 pieces of 256, each quantized separately, and the outputs are flattened back together to 512
    num_memories = 256,         # number of memories
    decay = 0.9,                # the exponential moving average decay, lower means the keys will change faster
)

embeds = torch.randn(1, 1024, 512)  # from pretrained encoder

memories = key_value_bottleneck(embeds)

memories.shape # (1, 1024, 512)  # (batch, seq, memory / values dimension)

# now you can use the memories for the downstream decoder

You can also pass the pretrained encoder to the bottleneck and it will automatically invoke it. Example with vit-pytorch library

$ pip install vit-pytorch

Then

import torch

# import vision transformer

from vit_pytorch import SimpleViT
from vit_pytorch.extractor import Extractor

vit = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# train vit, or load pretrained

vit = Extractor(vit, return_embeddings_only = True)

# then

from discrete_key_value_bottleneck_pytorch import DiscreteKeyValueBottleneck

enc_with_bottleneck = DiscreteKeyValueBottleneck(
    encoder = vit,         # pass the frozen encoder into the bottleneck
    dim = 512,             # input dimension
    num_memories = 256,    # number of memories
    dim_memory = 2048,     # dimension of the output memories
    decay = 0.9,           # the exponential moving average decay, lower means the keys will change faster
)

images = torch.randn(1, 3, 256, 256)  # input to encoder

memories = enc_with_bottleneck(images) # (1, 64, 2048)   # (64 patches)

Todo

Citations

@inproceedings{Trauble2022DiscreteKB,
    title   = {Discrete Key-Value Bottleneck},
    author  = {Frederik Trauble and Anirudh Goyal and Nasim Rahaman and Michael Curtis Mozer and Kenji Kawaguchi and Yoshua Bengio and Bernhard Scholkopf},
    year    = {2022}
}

.\lucidrains\discrete-key-value-bottleneck-pytorch\setup.py

# import setup utilities and the package discovery helper
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'discrete-key-value-bottleneck-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.1.1',
  # license type
  license='MIT',
  # description
  description = 'Discrete Key / Value Bottleneck - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # long description content type
  long_description_content_type = 'text/markdown',
  # project URL
  url = 'https://github.com/lucidrains/discrete-key-value-bottleneck-pytorch',
  # keyword list
  keywords = [
    'artificial intelligence',
    'deep learning',
    'quantization',
    'memory',
    'transfer learning'
  ],
  # install dependencies
  install_requires=[
    'einops>=0.6',
    'vector-quantize-pytorch>=1.6.28',
    'torch>=1.6',
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Distilling Knowledge from Reader to Retriever

Implementation of the retriever distillation procedure as outlined in the paper Distilling Knowledge from Reader to Retriever in Pytorch. They propose to train the retriever using the cross attention scores as pseudo-labels. SOTA on QA.

Update: The BM25 gains actually do not look as impressive as the BERT gains. Also, it seems like distilling with BERT as the starting point never gets to the same level as BM25.

I am wondering whether it makes more sense to modify Marge (https://github.com/lucidrains/marge-pytorch) so that, during training, one minimizes a loss between an extra prediction head on top of the retriever and the cross-attention scores.
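
For reference, a minimal sketch of the kind of distillation objective described in the paper, assuming `retriever_scores` holds the retriever's relevance scores for the retrieved passages and `cross_attn_scores` holds the reader's cross-attention scores already aggregated per passage (the KL-divergence variant; the exact aggregation and objective follow the paper, not code in this repository):

import torch
import torch.nn.functional as F

def retriever_distillation_loss(retriever_scores, cross_attn_scores, temperature = 1.):
    # reader cross-attention scores, normalized into a target distribution over passages
    targets = F.softmax(cross_attn_scores / temperature, dim = -1)
    # retriever scores as log-probabilities over the same passages
    log_preds = F.log_softmax(retriever_scores / temperature, dim = -1)
    # KL(targets || predictions), averaged over the batch
    return F.kl_div(log_preds, targets, reduction = 'batchmean')

retriever_scores = torch.randn(2, 10)   # e.g. dot products between question and passage embeddings
cross_attn_scores = torch.rand(2, 10)   # aggregated reader cross-attention per retrieved passage
loss = retriever_distillation_loss(retriever_scores, cross_attn_scores)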

Citations

@misc{izacard2020distilling,
    title={Distilling Knowledge from Reader to Retriever for Question Answering}, 
    author={Gautier Izacard and Edouard Grave},
    year={2020},
    eprint={2012.04584},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

Dreamcraft3d - Pytorch (wip)

Implementation of Dreamcraft3D, 3D content generation in Pytorch

Citations

@inproceedings{Sun2023DreamCraft3DH3,
    title   = {DreamCraft3D: Hierarchical 3D Generation with Bootstrapped Diffusion Prior},
    author  = {Jingxiang Sun and Bo Zhang and Ruizhi Shao and Lizhen Wang and Wen Liu and Zhenda Xie and Yebin Liu},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:264452015}
}

.\lucidrains\egnn-pytorch\denoise_sparse.py

# import required libraries
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam

# import rearrange and repeat from einops
from einops import rearrange, repeat

# import sidechainnet and the EGNN network
import sidechainnet as scn
from egnn_pytorch.egnn_pytorch import EGNN_Network

# set the default tensor dtype to float64
torch.set_default_dtype(torch.float64)

# batch size and gradient accumulation steps
BATCH_SIZE = 1
GRADIENT_ACCUMULATE_EVERY = 16

# infinite generator over the dataloader, skipping overly long sequences
def cycle(loader, len_thres = 200):
    while True:
        for data in loader:
            # skip sequences longer than the threshold
            if data.seqs.shape[1] > len_thres:
                continue
            yield data

# create the EGNN network
net = EGNN_Network(
    num_tokens = 21,
    num_positions = 200 * 3,   # maximum number of positions - absolute positional embedding, since the sequence has an inherent order
    depth = 5,
    dim = 8,
    num_nearest_neighbors = 16,
    fourier_features = 2,
    norm_coors = True,
    coor_weights_clamp_value = 2.
).cuda()

# load the dataset
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

# create the cycled dataloader
dl = cycle(data['train'])
# initialize the optimizer
optim = Adam(net.parameters(), lr=1e-3)

# training loop
for _ in range(10000):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        # get a batch of data
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # move the sequences to GPU and argmax the one-hot encoding into token indices
        seqs = seqs.cuda().argmax(dim = -1)
        coords = coords.cuda().type(torch.float64)
        masks = masks.cuda().bool()

        # get the sequence length and rearrange coordinates per residue
        l = seqs.shape[1]
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

        # keep only the first three (backbone) atoms of each residue
        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')

        # repeat the sequence tokens and masks for each backbone atom
        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)

        # build the chain adjacency matrix
        i = torch.arange(seq.shape[-1], device = seq.device)
        adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

        # add noise to the coordinates
        noised_coords = coords + torch.randn_like(coords)

        # forward pass through the EGNN network
        feats, denoised_coords = net(seq, noised_coords, adj_mat = adj_mat, mask = masks)

        # compute the denoising loss
        loss = F.mse_loss(denoised_coords[masks], coords[masks])

        # backpropagate, scaling the loss for gradient accumulation
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # print the loss
    print('loss:', loss.item())
    # optimizer step
    optim.step()
    # zero out the gradients
    optim.zero_grad()
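
For clarity, the adjacency construction inside the loop above produces a banded (tridiagonal) matrix that connects each backbone atom to itself and to its immediate chain neighbors. A small worked example with a hypothetical length of 5:

import torch

i = torch.arange(5)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
# adj_mat.long() gives:
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1]])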

.\lucidrains\egnn-pytorch\egnn_pytorch\egnn_pytorch.py

import torch
from torch import nn, einsum, broadcast_tensors
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# safe division, avoiding division by zero
def safe_div(num, den, eps = 1e-8):
    res = num.div(den.clamp(min = eps))
    res.masked_fill_(den == 0, 0.)
    return res

# batched index select along a given dimension
def batched_index_select(values, indices, dim = 1):
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)

# fourier encode distances
def fourier_encode_dist(x, num_encodings = 4, include_self = True):
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x
    scales = 2 ** torch.arange(num_encodings, device = device, dtype = dtype)
    x = x / scales
    x = torch.cat([x.sin(), x.cos()], dim=-1)
    x = torch.cat((x, orig_x), dim = -1) if include_self else x
    return x

# embed tokens
def embedd_token(x, dims, layers):
    stop_concat = -len(dims)
    to_embedd = x[:, stop_concat:].long()
    for i,emb_layer in enumerate(layers):
        # the portion corresponding to `to_embedd` is dropped and replaced by its embedding
        x = torch.cat([ x[:, :stop_concat], 
                        emb_layer( to_embedd[:, i] ) 
                      ], dim=-1)
        stop_concat = x.shape[-1]
    return x

# Swish activation fallback (for older torch without nn.SiLU)
class Swish_(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

SiLU = nn.SiLU if hasattr(nn, 'SiLU') else Swish_

# helper classes

# this follows the same normalization strategy as in the SE(3) Transformers
class CoorsNorm(nn.Module):
    def __init__(self, eps = 1e-8, scale_init = 1.):
        super().__init__()
        self.eps = eps
        scale = torch.zeros(1).fill_(scale_init)
        self.scale = nn.Parameter(scale)

    def forward(self, coors):
        norm = coors.norm(dim = -1, keepdim = True)
        normed_coors = coors / norm.clamp(min = self.eps)
        return normed_coors * self.scale

# global linear attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context, mask = None):
        h = self.heads

        q = self.to_q(x)
        kv = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(mask):
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, 'b n -> b () () n')
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

class GlobalLinearAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64
    ):
        # call the parent constructor
        super().__init__()
        # LayerNorm modules to normalize the inputs
        self.norm_seq = nn.LayerNorm(dim)
        self.norm_queries = nn.LayerNorm(dim)
        # two Attention modules for computing attention
        self.attn1 = Attention(dim, heads, dim_head)
        self.attn2 = Attention(dim, heads, dim_head)

        # feedforward network
        self.ff = nn.Sequential(
            nn.LayerNorm(dim),  # normalize the input
            nn.Linear(dim, dim * 4),  # linear projection
            nn.GELU(),  # GELU activation
            nn.Linear(dim * 4, dim)  # linear projection
        )

    # forward pass
    def forward(self, x, queries, mask = None):
        # keep the original inputs as residuals
        res_x, res_queries = x, queries
        # normalize the inputs
        x, queries = self.norm_seq(x), self.norm_queries(queries)

        # first attention: the queries attend to the sequence, producing induced tokens
        induced = self.attn1(queries, x, mask = mask)
        # second attention: the sequence attends to the induced tokens
        out     = self.attn2(x, induced)

        # add the attention outputs to the residuals
        x =  out + res_x
        queries = induced + res_queries

        # feedforward with residual
        x = self.ff(x) + x
        # return the results
        return x, queries

# EGNN class
class EGNN(nn.Module):
    # initialization function
    def __init__(
        self,
        dim,
        edge_dim = 0,
        m_dim = 16,
        fourier_features = 0,
        num_nearest_neighbors = 0,
        dropout = 0.0,
        init_eps = 1e-3,
        norm_feats = False,
        norm_coors = False,
        norm_coors_scale_init = 1e-2,
        update_feats = True,
        update_coors = True,
        only_sparse_neighbors = False,
        valid_radius = float('inf'),
        m_pool_method = 'sum',
        soft_edges = False,
        coor_weights_clamp_value = None
    ):
        # call the parent constructor
        super().__init__()
        # check that the pooling method is valid
        assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean'
        # must update features, coordinates, or both
        assert update_feats or update_coors, 'you must update either features, coordinates, or both'

        # number of fourier features
        self.fourier_features = fourier_features

        # compute the edge input dimension
        edge_input_dim = (fourier_features * 2) + (dim * 2) + edge_dim + 1
        # create a Dropout layer, or an identity mapping if dropout is 0
        dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # edge MLP
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            dropout,
            SiLU(),
            nn.Linear(edge_input_dim * 2, m_dim),
            SiLU()
        )

        # if using soft edges, define the edge gating network
        self.edge_gate = nn.Sequential(
            nn.Linear(m_dim, 1),
            nn.Sigmoid()
        ) if soft_edges else None

        # LayerNorm for node features if requested, otherwise identity
        self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity()
        # CoorsNorm for coordinates if requested, otherwise identity
        self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity()

        # pooling method
        self.m_pool_method = m_pool_method

        # node MLP, only if features are updated
        self.node_mlp = nn.Sequential(
            nn.Linear(dim + m_dim, dim * 2),
            dropout,
            SiLU(),
            nn.Linear(dim * 2, dim),
        ) if update_feats else None

        # coordinate MLP, only if coordinates are updated
        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            dropout,
            SiLU(),
            nn.Linear(m_dim * 4, 1)
        ) if update_coors else None

        # number of nearest neighbors, whether to use only sparse neighbors, and the valid radius
        self.num_nearest_neighbors = num_nearest_neighbors
        self.only_sparse_neighbors = only_sparse_neighbors
        self.valid_radius = valid_radius

        # clamp value for the coordinate weights
        self.coor_weights_clamp_value = coor_weights_clamp_value

        # weight initialization scale
        self.init_eps = init_eps
        # apply the initialization
        self.apply(self.init_)

    # weight initialization
    def init_(self, module):
        # if the module is a linear layer
        if type(module) in {nn.Linear}:
            # initialize weights with a small std, to keep deeper networks from producing NaNs
            nn.init.normal_(module.weight, std = self.init_eps)

# EGNN_Network class
class EGNN_Network(nn.Module):
    # initialization function
    def __init__(
        self,
        *,
        depth,
        dim,
        num_tokens = None,
        num_edge_tokens = None,
        num_positions = None,
        edge_dim = 0,
        num_adj_degrees = None,
        adj_dim = 0,
        global_linear_attn_every = 0,
        global_linear_attn_heads = 8,
        global_linear_attn_dim_head = 64,
        num_global_tokens = 4,
        **kwargs
    ):
        # call the parent constructor
        super().__init__()
        # assert that the number of adjacency degrees, if given, is at least 1
        assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is greater than 1'
        # number of positions
        self.num_positions = num_positions

        # token embedding, if a number of tokens is given
        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
        # positional embedding, if a number of positions is given
        self.pos_emb = nn.Embedding(num_positions, dim) if exists(num_positions) else None
        # edge embedding, if a number of edge tokens is given
        self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
        # whether edges are used
        self.has_edges = edge_dim > 0

        # number of adjacency degrees
        self.num_adj_degrees = num_adj_degrees
        # adjacency embedding, if adjacency degrees are used and adj_dim > 0
        self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None

        # keep edge_dim only if edges are used, otherwise 0
        edge_dim = edge_dim if self.has_edges else 0
        # keep adj_dim only if adjacency degrees are used, otherwise 0
        adj_dim = adj_dim if exists(num_adj_degrees) else 0

        # whether global attention is used
        has_global_attn = global_linear_attn_every > 0
        # global tokens
        self.global_tokens = None
        # if using global attention, initialize the global tokens
        if has_global_attn:
            self.global_tokens = nn.Parameter(torch.randn(num_global_tokens, dim))

        # list of layers
        self.layers = nn.ModuleList([])
        # iterate over the depth
        for ind in range(depth):
            # whether this is a global attention layer
            is_global_layer = has_global_attn and (ind % global_linear_attn_every) == 0

            # append a global linear attention layer (or None) and an EGNN layer
            self.layers.append(nn.ModuleList([
                GlobalLinearAttention(dim = dim, heads = global_linear_attn_heads, dim_head = global_linear_attn_dim_head) if is_global_layer else None,
                EGNN(dim = dim, edge_dim = (edge_dim + adj_dim), norm_feats = True, **kwargs),
            ]))

    def forward(
        self,
        feats,
        coors,
        adj_mat = None,
        edges = None,
        mask = None,
        return_coor_changes = False
    ):
        # get the batch size and device
        b, device = feats.shape[0], feats.device

        # if a token embedding exists, embed the features
        if exists(self.token_emb):
            feats = self.token_emb(feats)

        # if a positional embedding exists, add it to the features
        if exists(self.pos_emb):
            n = feats.shape[1]
            # the sequence length must not exceed the number of positions set at init
            assert n <= self.num_positions, f'given sequence length {n} must be less than the number of positions {self.num_positions} set at init'
            pos_emb = self.pos_emb(torch.arange(n, device = device))
            feats += rearrange(pos_emb, 'n d -> () n d')

        # if edges are passed in and an edge embedding exists, embed the edges
        if exists(edges) and exists(self.edge_emb):
            edges = self.edge_emb(edges)

        # create the N-degree adjacency matrix from the 1st-degree connections
        if exists(self.num_adj_degrees):
            assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

            adj_indices = adj_mat.clone().long()

            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            if exists(self.adj_emb):
                adj_emb = self.adj_emb(adj_indices)
                edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

        # set up the global attention tokens
        global_tokens = None
        if exists(self.global_tokens):
            global_tokens = repeat(self.global_tokens, 'n d -> b n d', b = b)

        # go through the layers
        coor_changes = [coors]

        for global_attn, egnn in self.layers:
            if exists(global_attn):
                feats, global_tokens = global_attn(feats, global_tokens, mask = mask)

            feats, coors = egnn(feats, coors, adj_mat = adj_mat, edges = edges, mask = mask)
            coor_changes.append(coors)

        # if coordinate changes are requested, return features, coordinates and the coordinate changes
        if return_coor_changes:
            return feats, coors, coor_changes

        # otherwise return only features and coordinates
        return feats, coors

.\lucidrains\egnn-pytorch\egnn_pytorch\egnn_pytorch_geometric.py

# import torch
import torch
# import nn, einsum, broadcast_tensors from torch
from torch import nn, einsum, broadcast_tensors
# import torch.nn.functional as F
import torch.nn.functional as F

# import rearrange, repeat from einops
from einops import rearrange, repeat
# import Rearrange from einops.layers.torch
from einops.layers.torch import Rearrange

# typing imports
from typing import Optional, List, Union

# try to import torch_geometric
try:
    import torch_geometric
    # import MessagePassing from torch_geometric.nn
    from torch_geometric.nn import MessagePassing
    # import Adj, Size, OptTensor, Tensor from torch_geometric.typing
    from torch_geometric.typing import Adj, Size, OptTensor, Tensor
except:
    # if the import fails, fall back to plain object types
    Tensor = OptTensor = Adj = MessagePassing = Size = object
    # mark pytorch geometric as unavailable
    PYG_AVAILABLE = False

    # set the types to object so that type hints do not error out
    Adj = object
    Size = object
    OptTensor = object
    Tensor = object

# import everything from the egnn_pytorch module in this package
from .egnn_pytorch import *

# sparse global linear attention class
class GlobalLinearAttention_Sparse(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        # layer norms for the sequence (norm_seq) and for the queries (norm_queries)
        self.norm_seq = torch_geometric.nn.norm.LayerNorm(dim)
        self.norm_queries = torch_geometric.nn.norm.LayerNorm(dim)
        # two sparse attention layers attn1 and attn2
        self.attn1 = Attention_Sparse(dim, heads, dim_head)
        self.attn2 = Attention_Sparse(dim, heads, dim_head)

        # can't concat pyg norms with torch sequentials
        # feedforward norm ff_norm
        self.ff_norm = torch_geometric.nn.norm.LayerNorm(dim)
        # feedforward network ff
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    # forward pass
    def forward(self, x, queries, batch=None, batch_uniques=None, mask = None):
        res_x, res_queries = x, queries
        # normalize x and queries
        x, queries = self.norm_seq(x, batch=batch), self.norm_queries(queries, batch=batch)

        # compute the induced tokens
        induced = self.attn1.sparse_forward(queries, x, batch=batch, batch_uniques=batch_uniques, mask = mask)
        # compute the output
        out = self.attn2.sparse_forward(x, induced, batch=batch, batch_uniques=batch_uniques)

        # update x and queries with residuals
        x =  out + res_x
        queries = induced + res_queries

        # feedforward norm on x
        x_norm = self.ff_norm(x, batch=batch)
        # feedforward with residual
        x = self.ff(x_norm) + x_norm
        return x, queries

# EGNN_Sparse class, inheriting from MessagePassing
class EGNN_Sparse(MessagePassing):
    """ Different from the above since it separates the edge assignment
        from the computation (this allows for great reduction in time and 
        computations when the graph is locally or sparse connected).
        * aggr: one of ["add", "mean", "max"]
    """
    # initialization, sets the model parameters
    def __init__(
        self,
        feats_dim,
        pos_dim=3,
        edge_attr_dim = 0,
        m_dim = 16,
        fourier_features = 0,
        soft_edge = 0,
        norm_feats = False,
        norm_coors = False,
        norm_coors_scale_init = 1e-2,
        update_feats = True,
        update_coors = True, 
        dropout = 0.,
        coor_weights_clamp_value = None, 
        aggr = "add",
        **kwargs
    ):
        # check that the aggregation method is valid
        assert aggr in {'add', 'sum', 'max', 'mean'}, 'pool method must be a valid option'
        # must update features, coordinates, or both
        assert update_feats or update_coors, 'you must update either features, coordinates, or both'
        # set the default aggregation method
        kwargs.setdefault('aggr', aggr)
        # call the parent constructor
        super(EGNN_Sparse, self).__init__(**kwargs)
        # model parameters
        self.fourier_features = fourier_features
        self.feats_dim = feats_dim
        self.pos_dim = pos_dim
        self.m_dim = m_dim
        self.soft_edge = soft_edge
        self.norm_feats = norm_feats
        self.norm_coors = norm_coors
        self.update_coors = update_coors
        self.update_feats = update_feats
        self.coor_weights_clamp_value = coor_weights_clamp_value

        # compute the edge input dimension
        self.edge_input_dim = (fourier_features * 2) + edge_attr_dim + 1 + (feats_dim * 2)
        # create a Dropout layer, or an Identity layer if dropout is 0
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # edge MLP
        self.edge_mlp = nn.Sequential(
            nn.Linear(self.edge_input_dim, self.edge_input_dim * 2),
            self.dropout,
            SiLU(),
            nn.Linear(self.edge_input_dim * 2, m_dim),
            SiLU()
        )

        # if soft_edge is set, create the edge weighting network
        self.edge_weight = nn.Sequential(nn.Linear(m_dim, 1), 
                                         nn.Sigmoid()
        ) if soft_edge else None

        # LayerNorm for the nodes, if requested
        self.node_norm = torch_geometric.nn.norm.LayerNorm(feats_dim) if norm_feats else None
        # CoorsNorm for the coordinates, or Identity
        self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity()

        # node MLP
        self.node_mlp = nn.Sequential(
            nn.Linear(feats_dim + m_dim, feats_dim * 2),
            self.dropout,
            SiLU(),
            nn.Linear(feats_dim * 2, feats_dim),
        ) if update_feats else None

        # coordinate MLP
        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            self.dropout,
            SiLU(),
            nn.Linear(self.m_dim * 4, 1)
        ) if update_coors else None

        # initialize the model parameters
        self.apply(self.init_)

    # parameter initialization
    def init_(self, module):
        # if the module is a linear layer
        if type(module) in {nn.Linear}:
            # xavier_normal_ for the weights, zeros for the bias
            nn.init.xavier_normal_(module.weight)
            nn.init.zeros_(module.bias)
    def forward(self, x: Tensor, edge_index: Adj,
                edge_attr: OptTensor = None, batch: Adj = None, 
                angle_data: List = None,  size: Size = None) -> Tensor:
        """ Inputs: 
            * x: (n_points, d) where d is pos_dims + feat_dims
            * edge_index: (2, n_edges)
            * edge_attr: tensor (n_edges, n_feats) excluding basic distance feats.
            * batch: (n_points,) long tensor. specifies xloud belonging for each point
            * angle_data: list of tensors (levels, n_edges_i, n_length_path) long tensor.
            * size: None
        """
        # split the input x into coordinates and features
        coors, feats = x[:, :self.pos_dim], x[:, self.pos_dim:]

        # compute relative coordinates and relative (squared) distances
        rel_coors = coors[edge_index[0]] - coors[edge_index[1]]
        rel_dist  = (rel_coors ** 2).sum(dim=-1, keepdim=True)

        # if using fourier features
        if self.fourier_features > 0:
            # fourier encode the relative distances
            rel_dist = fourier_encode_dist(rel_dist, num_encodings = self.fourier_features)
            rel_dist = rearrange(rel_dist, 'n () d -> n d')

        # if edge attributes are given, concatenate them with the relative distances
        if exists(edge_attr):
            edge_attr_feats = torch.cat([edge_attr, rel_dist], dim=-1)
        else:
            edge_attr_feats = rel_dist

        # message passing and node update
        hidden_out, coors_out = self.propagate(edge_index, x=feats, edge_attr=edge_attr_feats,
                                                           coors=coors, rel_coors=rel_coors, 
                                                           batch=batch)
        # return the concatenation of the updated coordinates and hidden features
        return torch.cat([coors_out, hidden_out], dim=-1)


    def message(self, x_i, x_j, edge_attr) -> Tensor:
        # compute the messages from the edge attributes and node features
        m_ij = self.edge_mlp( torch.cat([x_i, x_j, edge_attr], dim=-1) )
        return m_ij

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        """The initial call to start propagating messages.
            Args:
            `edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional) if none, the size will be inferred
                and assumed to be quadratic.
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """
        # check the input and collect the data
        size = self._check_input(edge_index, size)
        coll_dict = self._collect(self._user_args,
                                     edge_index, size, kwargs)
        msg_kwargs = self.inspector.distribute('message', coll_dict)
        aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
        update_kwargs = self.inspector.distribute('update', coll_dict)

        # get the messages
        m_ij = self.message(**msg_kwargs)

        # if updating coordinates
        if self.update_coors:
            coor_wij = self.coors_mlp(m_ij)
            # if a clamp value is set, clamp the coordinate weights
            if self.coor_weights_clamp_value:
                clamp_value = self.coor_weights_clamp_value
                coor_wij.clamp_(min = -clamp_value, max = clamp_value)

            # normalize the relative coordinates if requested
            kwargs["rel_coors"] = self.coors_norm(kwargs["rel_coors"])

            mhat_i = self.aggregate(coor_wij * kwargs["rel_coors"], **aggr_kwargs)
            coors_out = kwargs["coors"] + mhat_i
        else:
            coors_out = kwargs["coors"]

        # if updating features
        if self.update_feats:
            # if soft edges are used, weight the messages
            if self.soft_edge:
                m_ij = m_ij * self.edge_weight(m_ij)
            m_i = self.aggregate(m_ij, **aggr_kwargs)

            hidden_feats = self.node_norm(kwargs["x"], kwargs["batch"]) if self.node_norm else kwargs["x"]
            hidden_out = self.node_mlp( torch.cat([hidden_feats, m_i], dim = -1) )
            hidden_out = kwargs["x"] + hidden_out
        else:
            hidden_out = kwargs["x"]

        # return the updated node information
        return self.update((hidden_out, coors_out), **update_kwargs)
    # string representation of the object
    def __repr__(self):
        # placeholder dict (unused)
        dict_print = {}
        # return the string representation, including the object's attribute dict
        return "E(n)-GNN Layer for Graphs " + str(self.__dict__)
class EGNN_Sparse_Network(nn.Module):
    r"""Sample GNN model architecture that uses the EGNN-Sparse
        message passing layer to learn over point clouds. 
        Main MPNN layer introduced in https://arxiv.org/abs/2102.09844v1

        Inputs will be standard GNN: x, edge_index, edge_attr, batch, ...

        Args:
        * n_layers: int. number of MPNN layers
        * ... : same interpretation as the base layer.
        * embedding_nums: list. number of unique keys to embedd. for points
                          1 entry per embedding needed. 
        * embedding_dims: list. point - number of dimensions of
                          the resulting embedding. 1 entry per embedding needed. 
        * edge_embedding_nums: list. number of unique keys to embedd. for edges.
                               1 entry per embedding needed. 
        * edge_embedding_dims: list. point - number of dimensions of
                               the resulting embedding. 1 entry per embedding needed. 
        * recalc: int. Recalculate edge feats every `recalc` MPNN layers. 0 for no recalc
        * verbose: bool. verbosity level.
        -----
        Diff with normal layer: one has to do preprocessing before (radius, global token, ...)
    """
    def forward(self, x, edge_index, batch, edge_attr,
                bsize=None, recalc_edge=None, verbose=0):
        """ Recalculate edge features every `self.recalc_edge` with the
            `recalc_edge` function if self.recalc_edge is set.

            * x: (N, pos_dim+feats_dim) will be unpacked into coors, feats.
        """
        # NODES - Embedd each dim to its target dimensions:
        x = embedd_token(x, self.embedding_dims, self.emb_layers)

        # regulates whether to embed edges each layer
        edges_need_embedding = True  
        for i,layer in enumerate(self.mpnn_layers):
            
            # EDGES - Embedd each dim to its target dimensions:
            if edges_need_embedding:
                edge_attr = embedd_token(edge_attr, self.edge_embedding_dims, self.edge_emb_layers)
                edges_need_embedding = False

            # attn tokens
            global_tokens = None
            if exists(self.global_tokens):
                unique, amounts = torch.unique(batch, return_counts=True)
                num_idxs = torch.cat([torch.arange(num_idxs_i) for num_idxs_i in amounts], dim=-1)
                global_tokens = self.global_tokens[num_idxs]

            # pass layers
            is_global_layer = self.has_global_attn and (i % self.global_linear_attn_every) == 0
            if not is_global_layer:
                x = layer(x, edge_index, edge_attr, batch=batch, size=bsize)
            else: 
                # only pass feats to the attn layer
                x_attn = layer[0](x[:, self.pos_dim:], global_tokens)
                # merge attn-ed feats and coords
                x = torch.cat( (x[:, :self.pos_dim], x_attn), dim=-1)
                x = layer[-1](x, edge_index, edge_attr, batch=batch, size=bsize)

            # recalculate edge info - not needed if last layer
            if self.recalc and ((i%self.recalc == 0) and not (i == len(self.mpnn_layers)-1)) :
                edge_index, edge_attr, _ = recalc_edge(x) # returns idx, attr, any_other_info
                edges_need_embedding = True
            
        return x

    def __repr__(self):
        return 'EGNN_Sparse_Network of: {0} layers'.format(len(self.mpnn_layers))

.\lucidrains\egnn-pytorch\egnn_pytorch\utils.py

# import torch
import torch
# import sin, cos, atan2, acos from torch
from torch import sin, cos, atan2, acos

# rotation about the z axis by angle gamma
def rot_z(gamma):
    # return the rotation matrix about the z axis
    return torch.tensor([
        [cos(gamma), -sin(gamma), 0],
        [sin(gamma), cos(gamma), 0],
        [0, 0, 1]
    ], dtype=gamma.dtype)

# rotation about the y axis by angle beta
def rot_y(beta):
    # return the rotation matrix about the y axis
    return torch.tensor([
        [cos(beta), 0, sin(beta)],
        [0, 1, 0],
        [-sin(beta), 0, cos(beta)]
    ], dtype=beta.dtype)

# general rotation built from the three Euler angles alpha, beta, gamma
def rot(alpha, beta, gamma):
    # compose the rotations about z (alpha), y (beta) and z (gamma)
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
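
As a quick sanity check (not part of the file), the matrix returned by `rot` should be a proper rotation, i.e. orthogonal with determinant 1, which is exactly the property the equivariance tests below rely on:

import torch
from egnn_pytorch.utils import rot

R = rot(*torch.rand(3))
assert torch.allclose(R @ R.t(), torch.eye(3), atol = 1e-5)          # R times its transpose is the identity
assert torch.allclose(torch.det(R), torch.tensor(1.), atol = 1e-5)   # determinant is 1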

.\lucidrains\egnn-pytorch\egnn_pytorch\__init__.py

# import the EGNN and EGNN_Network classes from the egnn_pytorch module
from egnn_pytorch.egnn_pytorch import EGNN, EGNN_Network
# import the EGNN_Sparse and EGNN_Sparse_Network classes from the geometric module
from egnn_pytorch.egnn_pytorch_geometric import EGNN_Sparse, EGNN_Sparse_Network

** A bug has been discovered with the neighbor selection in the presence of masking. If you ran any experiments prior to 0.1.12 that had masking, please rerun them. 🙏 **

EGNN - Pytorch

Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch. May be eventually used for Alphafold2 replication. This technique went for simple invariant features, and ended up beating all previous methods (including SE3 Transformer and Lie Conv) in both accuracy and performance. SOTA in dynamical system models, molecular activity prediction tasks, etc.

Install

$ pip install egnn-pytorch

Usage

import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)

With edges

import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)

A full EGNN network

import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    num_positions = 1024,           # unless what you are passing in is an unordered set, set this to the maximum sequence length
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 8,
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the number of nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)

Only attend to sparse neighbors, given to the network as an adjacency matrix.

import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)

You can also have the network automatically determine the Nth-order neighbors, and pass in an adjacency embedding (depending on the order) to be used as an edge, with two extra keyword arguments

import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_adj_degrees = 3,           # fetch up to 3rd degree neighbors
    adj_dim = 8,                   # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)

Edges

If you need to pass in continuous edges

import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    edge_dim = 4,
    num_nearest_neighbors = 3
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

continuous_edges = torch.randn(1, 1024, 1024, 4)

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)

Stability

The initial architecture for EGNN suffered from instability when there was a high number of neighbors. Thankfully, there seem to be two solutions that largely mitigate this.

import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 32,
    norm_coors = True,              # normalize the relative coordinates
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the number of nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)

All parameters

import torch
from egnn_pytorch import EGNN

model = EGNN(
    dim = 512,                         # input dimension
    edge_dim = 0,                      # dimension of the edges, if exists, should be > 0
    m_dim = 16,                        # hidden model dimension
    fourier_features = 0,              # number of fourier features for encoding of relative distance - defaults to none as in paper
    num_nearest_neighbors = 0,         # cap the number of neighbors doing message passing by relative distance
    dropout = 0.0,                     # dropout
    norm_feats = False,                # whether to layernorm the features
    norm_coors = False,                # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper    
    update_feats = True,               # whether to update features - you can build a layer that only updates one or the other
    update_coors = True,               # whether to update coordinates
    only_sparse_neighbors = False,     # using this would only allow message passing along adjacent neighbors, using the adjacency matrix passed in 
    valid_radius = float('inf'),       # the valid radius each node considers for message passing
    m_pool_method = 'sum',             # whether to mean or sum pool for output node representation
    soft_edges = False,                # extra GLU on the edges, purportedly helps stabilize the network in updated version of the paper
    coor_weights_clamp_value = None    # clamping of the coordinate updates, again, for stabilization purposes
)

Examples

To run the protein backbone denoising example, first install sidechainnet

$ pip install sidechainnet

Then

$ python denoise_sparse.py

Tests

Make sure you have pytorch geometric installed locally

$ python setup.py test

Citations

@misc{satorras2021en,
    title 	= {E(n) Equivariant Graph Neural Networks}, 
    author 	= {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year 	= {2021},
    eprint 	= {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\egnn-pytorch\setup.py

# import setup utilities and the package discovery helper
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'egnn-pytorch',  # package name
  packages = find_packages(),  # find all packages
  version = '0.2.7',  # version number
  license='MIT',  # license
  description = 'E(n)-Equivariant Graph Neural Network - Pytorch',  # description
  long_description_content_type = 'text/markdown',  # long description content type
  author = 'Phil Wang, Eric Alcaide',  # authors
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/egnn-pytorch',  # project URL
  keywords = [  # keyword list
    'artificial intelligence',
    'deep learning',
    'equivariance',
    'graph neural network'
  ],
  install_requires=[  # install dependencies
    'einops>=0.3',
    'numba',
    'numpy',
    'torch>=1.6'
  ],
  setup_requires=[  # setup dependencies
    'pytest-runner',
  ],
  tests_require=[  # test dependencies
    'pytest'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\egnn-pytorch\tests\test_equivariance.py

import torch  # import PyTorch

from egnn_pytorch import EGNN, EGNN_Sparse  # import the EGNN and EGNN_Sparse classes
from egnn_pytorch.utils import rot  # import the rot helper

torch.set_default_dtype(torch.float64)  # set the default dtype to float64

def test_egnn_equivariance():  # test EGNN equivariance
    layer = EGNN(dim=512, edge_dim=4)  # create an EGNN layer with the given feature and edge dimensions

    R = rot(*torch.rand(3))  # random rotation matrix R
    T = torch.randn(1, 1, 3)  # random translation vector T

    feats = torch.randn(1, 16, 512)  # random feature tensor
    coors = torch.randn(1, 16, 3)  # random coordinate tensor
    edges = torch.randn(1, 16, 16, 4)  # random edge tensor
    mask = torch.ones(1, 16).bool()  # all-True mask

    # cache the features of the first two nodes
    node1 = feats[:, 0, :]  # features of the first node
    node2 = feats[:, 1, :]  # features of the second node

    # swap the first and second nodes
    feats_permuted_row_wise = feats.clone().detach()  # clone the feature tensor
    feats_permuted_row_wise[:, 0, :] = node2  # replace the first node's features with the second's
    feats_permuted_row_wise[:, 1, :] = node1  # replace the second node's features with the first's

    feats1, coors1 = layer(feats, coors @ R + T, edges, mask=mask)  # forward pass on transformed coordinates
    feats2, coors2 = layer(feats, coors, edges, mask=mask)  # forward pass on original coordinates
    feats3, coors3 = layer(feats_permuted_row_wise, coors, edges, mask=mask)  # forward pass on permuted features

    assert torch.allclose(feats1, feats2, atol=1e-6), 'type 0 features are invariant'  # features must be invariant
    assert torch.allclose(coors1, (coors2 @ R + T), atol=1e-6), 'type 1 features are equivariant'  # coordinates must be equivariant
    assert not torch.allclose(feats1, feats3, atol=1e-6), 'layer must be equivariant to permutations of node order'  # permuting the nodes must change the output

def test_higher_dimension():  # test with higher-dimensional coordinates
    layer = EGNN(dim=512, edge_dim=4)  # create an EGNN layer

    feats = torch.randn(1, 16, 512)  # random feature tensor
    coors = torch.randn(1, 16, 5)  # random 5-dimensional coordinate tensor
    edges = torch.randn(1, 16, 16, 4)  # random edge tensor
    mask = torch.ones(1, 16).bool()  # all-True mask

    feats, coors = layer(feats, coors, edges, mask=mask)  # forward pass
    assert True  # just check that it runs

def test_egnn_equivariance_with_nearest_neighbors():  # test equivariance with nearest-neighbor selection
    layer = EGNN(dim=512, edge_dim=1, num_nearest_neighbors=8)  # EGNN layer with 8 nearest neighbors

    R = rot(*torch.rand(3))  # random rotation matrix R
    T = torch.randn(1, 1, 3)  # random translation vector T

    feats = torch.randn(1, 256, 512)  # random feature tensor
    coors = torch.randn(1, 256, 3)  # random coordinate tensor
    edges = torch.randn(1, 256, 256, 1)  # random edge tensor
    mask = torch.ones(1, 256).bool()  # all-True mask

    # cache the features of the first two nodes
    node1 = feats[:, 0, :]  # features of the first node
    node2 = feats[:, 1, :]  # features of the second node

    # swap the first and second nodes
    feats_permuted_row_wise = feats.clone().detach()  # clone the feature tensor
    feats_permuted_row_wise[:, 0, :] = node2  # replace the first node's features with the second's
    feats_permuted_row_wise[:, 1, :] = node1  # replace the second node's features with the first's

    feats1, coors1 = layer(feats, coors @ R + T, edges, mask=mask)  # forward pass on transformed coordinates
    feats2, coors2 = layer(feats, coors, edges, mask=mask)  # forward pass on original coordinates
    feats3, coors3 = layer(feats_permuted_row_wise, coors, edges, mask=mask)  # forward pass on permuted features

    assert torch.allclose(feats1, feats2, atol=1e-6), 'type 0 features are invariant'  # features must be invariant
    assert torch.allclose(coors1, (coors2 @ R + T), atol=1e-6), 'type 1 features are equivariant'  # coordinates must be equivariant
    assert not torch.allclose(feats1, feats3, atol=1e-6), 'layer must be equivariant to permutations of node order'  # permuting the nodes must change the output

def test_egnn_equivariance_with_coord_norm():  # test equivariance with coordinate normalization
    layer = EGNN(dim=512, edge_dim=1, num_nearest_neighbors=8, norm_coors=True)  # EGNN layer with coordinate normalization enabled

    R = rot(*torch.rand(3))  # random rotation matrix R
    T = torch.randn(1, 1, 3)  # random translation vector T

    feats = torch.randn(1, 256, 512)  # random feature tensor
    coors = torch.randn(1, 256, 3)  # random coordinate tensor
    edges = torch.randn(1, 256, 256, 1)  # random edge tensor
    mask = torch.ones(1, 256).bool()  # all-True mask

    # cache the features of the first two nodes
    node1 = feats[:, 0, :]  # features of the first node
    node2 = feats[:, 1, :]  # features of the second node

    # swap the first and second nodes
    feats_permuted_row_wise = feats.clone().detach()  # clone the feature tensor
    feats_permuted_row_wise[:, 0, :] = node2  # replace the first node's features with the second's
    feats_permuted_row_wise[:, 1, :] = node1  # replace the second node's features with the first's

    feats1, coors1 = layer(feats, coors @ R + T, edges, mask=mask)  # forward pass on transformed coordinates
    feats2, coors2 = layer(feats, coors, edges, mask=mask)  # forward pass on original coordinates
    feats3, coors3 = layer(feats_permuted_row_wise, coors, edges, mask=mask)  # forward pass on permuted features

    assert torch.allclose(feats1, feats2, atol=1e-6), 'type 0 features are invariant'  # features must be invariant
    assert torch.allclose(coors1, (coors2 @ R + T), atol=1e-6), 'type 1 features are equivariant'  # coordinates must be equivariant
    assert not torch.allclose(feats1, feats3, atol=1e-6), 'layer must be equivariant to permutations of node order'  # permuting the nodes must change the output

def test_egnn_sparse_equivariance():  # test EGNN_Sparse equivariance
    layer = EGNN_Sparse(feats_dim=1, m_dim=16, fourier_features=4)  # sparse EGNN layer with feature dim, message dim and fourier features

    R = rot(*torch.rand(3))  # random rotation matrix R
    T = torch.randn(1, 1, 3)  # random translation vector T
    apply_action = lambda t: (t @ R + T).squeeze()  # helper that applies the rotation and translation
    # random 16 x 1 tensor of node features
    feats = torch.randn(16, 1)
    # random 16 x 3 tensor of node coordinates
    coors = torch.randn(16, 3)
    # random 2 x 20 integer tensor of edge indices
    edge_idxs = (torch.rand(2, 20) * 16).long()

    # cache the features of the first and second nodes
    node1 = feats[0, :]
    node2 = feats[1, :]

    # swap the first and second nodes to build a permuted feature tensor
    feats_permuted_row_wise = feats.clone().detach()
    feats_permuted_row_wise[0, :] = node2
    feats_permuted_row_wise[1, :] = node1

    # concatenate the coordinates and features into input x1
    x1 = torch.cat([coors, feats], dim=-1)
    # concatenate the transformed coordinates and the features into input x2
    x2 = torch.cat([apply_action(coors), feats], dim=-1)
    # concatenate the transformed coordinates and the permuted features into input x3
    x3 = torch.cat([apply_action(coors), feats_permuted_row_wise], dim=-1)

    # run the layer on x1 to get out1
    out1 = layer(x=x1, edge_index=edge_idxs)
    # run the layer on x2 to get out2
    out2 = layer(x=x2, edge_index=edge_idxs)
    # run the layer on x3 to get out3
    out3 = layer(x=x3, edge_index=edge_idxs)

    # split out1 into features and coordinates
    feats1, coors1 = out1[:, 3:], out1[:, :3]
    # split out2 into features and coordinates
    feats2, coors2 = out2[:, 3:], out2[:, :3]
    # split out3 into features and coordinates
    feats3, coors3 = out3[:, 3:], out3[:, :3]

    # print the difference between feats1 and feats2
    print(feats1 - feats2)
    # print the difference between apply_action(coors1) and coors2
    print(apply_action(coors1) - coors2)
    # feats1 and feats2 must be close, otherwise raise
    assert torch.allclose(feats1, feats2), 'features must be invariant'
    # apply_action(coors1) and coors2 must be close, otherwise raise
    assert torch.allclose(apply_action(coors1), coors2), 'coordinates must be equivariant'
    # feats1 and feats3 must not be close, otherwise raise
    assert not torch.allclose(feats1, feats3, atol=1e-6), 'layer must be equivariant to permutations of node order'
# test that the sparse layer preserves the shape of its input
def test_geom_equivalence():
    # create an EGNN_Sparse layer with feature dim 128, edge attribute dim 4, message dim 16 and 4 fourier features
    layer = EGNN_Sparse(feats_dim=128,
                        edge_attr_dim=4,
                        m_dim=16,
                        fourier_features=4)

    # random 16 x 128 feature tensor
    feats = torch.randn(16, 128)
    # random 16 x 3 coordinate tensor
    coors = torch.randn(16, 3)
    # concatenate coordinates and features along the last dimension
    x = torch.cat([coors, feats], dim=-1)
    # random 2 x 20 integer tensor of edge indices
    edge_idxs = (torch.rand(2, 20) * 16).long()
    # random 16 x 16 x 4 edge attribute tensor
    edges_attrs = torch.randn(16, 16, 4)
    # select the edge attributes corresponding to the edge indices
    edges_attrs = edges_attrs[edge_idxs[0], edge_idxs[1]]

    # the output of the forward pass must have the same shape as the input x
    assert layer.forward(x, edge_idxs, edge_attr=edges_attrs).shape == x.shape

.\lucidrains\einops-exts\einops_exts\einops_exts.py

# import required modules
import re
from torch import nn
from functools import wraps, partial

# import rearrange, reduce, repeat from einops
from einops import rearrange, reduce, repeat

# checking shape
# @nils-werner
# https://github.com/arogozhnikov/einops/issues/168#issuecomment-1042933838

# check_shape: verify that a tensor matches the given pattern
def check_shape(tensor, pattern, **kwargs):
    return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)

# do same einops operations on a list of tensors

# decorator that applies the same einops operation to a collection of tensors
def _many(fn):
    @wraps(fn)
    def inner(tensors, pattern, **kwargs):
        return (fn(tensor, pattern, **kwargs) for tensor in tensors)
    return inner

# do einops with unflattening of anonymously named dimensions
# (...flattened) ->  ...flattened

# decorator that expands anonymously named dimension lists (e.g. '...name') before calling einops
def _with_anon_dims(fn):
    @wraps(fn)
    def inner(tensor, pattern, **kwargs):
        regex = r'(\.\.\.[a-zA-Z]+)'
        matches = re.findall(regex, pattern)
        get_anon_dim_name = lambda t: t.lstrip('...')
        dim_prefixes = tuple(map(get_anon_dim_name, set(matches)))

        update_kwargs_dict = dict()

        for prefix in dim_prefixes:
            assert prefix in kwargs, f'dimension list "{prefix}" was not passed in'
            dim_list = kwargs[prefix]
            assert isinstance(dim_list, (list, tuple)), f'dimension list "{prefix}" needs to be a tuple of list of dimensions'
            dim_names = list(map(lambda ind: f'{prefix}{ind}', range(len(dim_list))))
            update_kwargs_dict[prefix] = dict(zip(dim_names, dim_list))

        def sub_with_anonymous_dims(t):
            dim_name_prefix = get_anon_dim_name(t.groups()[0])
            return ' '.join(update_kwargs_dict[dim_name_prefix].keys())

        pattern_new = re.sub(regex, sub_with_anonymous_dims, pattern)

        for prefix, update_dict in update_kwargs_dict.items():
            del kwargs[prefix]
            kwargs.update(update_dict)

        return fn(tensor, pattern_new, **kwargs)
    return inner

# generate all helper functions

# 生成对多个张量执行 rearrange 操作的函数 rearrange_many
rearrange_many = _many(rearrange)
# 生成对多个张量执行 repeat 操作的函数 repeat_many
repeat_many = _many(repeat)
# 生成对多个张量执行 reduce 操作的函数 reduce_many
reduce_many = _many(reduce)

# 生成带匿名维度展开的 rearrange/repeat/reduce 函数
rearrange_with_anon_dims = _with_anon_dims(rearrange)
repeat_with_anon_dims = _with_anon_dims(repeat)
reduce_with_anon_dims = _with_anon_dims(reduce)
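
# 下面是这些辅助函数的一个最小用法示例(张量均为假设的随机数据)
import torch

q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)

# 对多个张量执行同一个 rearrange:把 head 维折叠进 batch 维
q, k = rearrange_many((q, k), 'b h n d -> (b h) n d')
# q.shape, k.shape -> torch.Size([16, 16, 64])

# 匿名维度展开:'...dims' 会按传入的 dims 列表展开成 dims0 dims1 ...
x = torch.randn(2, 3 * 4, 64)
x = rearrange_with_anon_dims(x, 'b (...dims) d -> b ...dims d', dims = (3, 4))
# x.shape -> torch.Size([2, 3, 4, 64])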

.\lucidrains\einops-exts\einops_exts\torch.py

# 导入 torch 中的 nn 模块
# 导入 einops 中的 rearrange 函数
from torch import nn
from einops import rearrange

# 定义一个用于转换和重组数据的类 EinopsToAndFrom
class EinopsToAndFrom(nn.Module):
    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        # 初始化类的属性
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

        # 检查 from_einops 中是否包含 '...'
        if '...' in from_einops:
            # 如果包含 '...',则将其分割成 before 和 after 两部分
            before, after = [part.strip().split() for part in from_einops.split('...')]
            # 生成重组键值对,包括 before 和 after 部分
            self.reconstitute_keys = tuple(zip(before, range(len(before)))) + tuple(zip(after, range(-len(after), 0)))
        else:
            # 如果不包含 '...',则直接按空格分割成键值对
            split = from_einops.strip().split()
            self.reconstitute_keys = tuple(zip(split, range(len(split))))

    # 定义前向传播函数
    def forward(self, x, **kwargs):
        # 获取输入 x 的形状
        shape = x.shape
        # 根据 reconstitute_keys 生成重组参数字典
        reconstitute_kwargs = {key: shape[position] for key, position in self.reconstitute_keys}
        # 对输入 x 进行从 from_einops 到 to_einops 的重组
        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
        # 对重组后的 x 进行处理
        x = self.fn(x, **kwargs)
        # 将处理后的 x 重新从 to_einops 重组回 from_einops
        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
        # 返回处理后的 x
        return x
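
# 一个最小用法示例(这里用 nn.Linear 充当假设的占位模块,实际中常用来包装 attention 层):
# 先把 (b, c, h, w) 的特征图展平成序列 (b, h*w, c),经内部模块处理后再还原回原形状
import torch

to_and_from = EinopsToAndFrom('b c h w', 'b (h w) c', nn.Linear(32, 32))
feats = torch.randn(1, 32, 8, 8)
out = to_and_from(feats)
# out.shape -> torch.Size([1, 32, 8, 8])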

.\lucidrains\einops-exts\einops_exts\__init__.py

# 从 einops_exts.einops_exts 模块中导入 check_shape 函数
from einops_exts.einops_exts import check_shape
# 从 einops_exts.einops_exts 模块中导入 rearrange_many, repeat_many, reduce_many 函数
from einops_exts.einops_exts import rearrange_many, repeat_many, reduce_many
# 从 einops_exts.einops_exts 模块中导入 rearrange_with_anon_dims, repeat_with_anon_dims, reduce_with_anon_dims 函数
from einops_exts.einops_exts import rearrange_with_anon_dims, repeat_with_anon_dims, reduce_with_anon_dims

Einops Extensions

Implementation of some personal helper functions for Einops, my most favorite tensor manipulation library ❤️

Citations

@inproceedings{rogozhnikov2022einops,
  title     = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
  author    = {Alex Rogozhnikov},
  booktitle = {International Conference on Learning Representations},
  year      = {2022},
  url       = {https://openreview.net/forum?id=oapKSVM2bcj}
}

.\lucidrains\einops-exts\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'einops-exts',
  # 查找所有包,不排除任何包
  packages = find_packages(exclude=[]),
  # 版本号
  version = '0.0.4',
  # 许可证类型
  license='MIT',
  # 描述信息
  description = 'Einops Extensions',
  # 长描述内容类型为 Markdown
  long_description_content_type = 'text/markdown',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/einops-exts',
  # 关键词列表
  keywords = [
    'artificial intelligence',
    'deep learning',
    'tensor manipulation'
  ],
  # 安装依赖项
  install_requires=[
    'einops>=0.4',
  ],
  # 分类标签
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\electra-pytorch\electra_pytorch\electra_pytorch.py

# 导入数学库
import math
# 导入 reduce 函数
from functools import reduce
# 导入 namedtuple 类
from collections import namedtuple

# 导入 torch 库
import torch
# 导入 torch 中的 nn 模块
from torch import nn
# 导入 torch 中的 functional 模块
import torch.nn.functional as F

# 定义一个命名元组 Results,包含多个字段
Results = namedtuple('Results', [
    'loss',
    'mlm_loss',
    'disc_loss',
    'gen_acc',
    'disc_acc',
    'disc_labels',
    'disc_predictions'
])

# 定义一些辅助函数

# 计算输入张量的自然对数
def log(t, eps=1e-9):
    return torch.log(t + eps)

# 生成 Gumbel 噪声
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# 从 Gumbel 分布中采样
def gumbel_sample(t, temperature = 1.):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim=-1)
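
# 一个最小示例:用 Gumbel-Max 技巧从 logits 中采样离散 token id(数据为假设的随机值)
example_logits = torch.randn(4, 10)                      # 4 个位置、词表大小为 10
sampled_ids = gumbel_sample(example_logits, temperature = 1.)
# sampled_ids.shape -> torch.Size([4]),每个位置各采样出一个 id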

# 根据概率生成掩码
def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob

# 使用特定的标记生成掩码
def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

# 根据概率获取子集掩码
def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()
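
# 一个小示例(输入为假设数据):在允许的位置中按比例挑选要掩码的子集
candidate_mask = torch.ones(2, 8, dtype = torch.bool)
candidate_mask[:, -2:] = False                           # 假设最后两列是 padding,不可掩码
subset = get_mask_subset_with_prob(candidate_mask, prob = 0.25)
# subset 与 candidate_mask 同形状,每行至多 ceil(0.25 * 8) = 2 个 True,且只落在可掩码位置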

# 隐藏层提取器类,用于为语言模型添加适配器以进行预训练

class HiddenLayerExtractor(nn.Module):
    def __init__(self, net, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.hidden = None
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, __, output):
        self.hidden = output

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    def forward(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden = self.hidden
        self.hidden = None
        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden
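
# 一个最小用法示例(toy_net 为假设的玩具网络):通过 forward hook 抓取倒数第二个子模块的输出
toy_net = nn.Sequential(
    nn.Embedding(100, 32),
    nn.Linear(32, 64),
    nn.Linear(64, 10),
)
extractor = HiddenLayerExtractor(toy_net, layer = -2)
hidden = extractor(torch.randint(0, 100, (2, 16)))
# hidden.shape -> torch.Size([2, 16, 64]),即 nn.Linear(32, 64) 的输出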

# Electra 主类

class Electra(nn.Module):
    # 初始化函数,接受生成器、判别器等参数
    def __init__(
        self,
        generator,
        discriminator,
        *,
        num_tokens = None,  # 可选参数:标记数量,默认为 None
        discr_dim = -1,  # 判别器维度,默认为 -1
        discr_layer = -1,  # 判别器层,默认为 -1
        mask_prob = 0.15,  # 掩码概率,默认为 0.15
        replace_prob = 0.85,  # 替换概率,默认为 0.85
        random_token_prob = 0.,  # 随机标记概率,默认为 0
        mask_token_id = 2,  # 掩码标记 ID,默认为 2
        pad_token_id = 0,  # 填充标记 ID,默认为 0
        mask_ignore_token_ids = [],  # 忽略的掩码标记 ID 列表,默认为空
        disc_weight = 50.,  # 判别器权重,默认为 50
        gen_weight = 1.,  # 生成器权重,默认为 1
        temperature = 1.):  # 温度参数,默认为 1
        super().__init__()  # 调用父类的初始化函数

        self.generator = generator  # 初始化生成器
        self.discriminator = discriminator  # 初始化判别器

        if discr_dim > 0:  # 如果判别器维度大于 0
            self.discriminator = nn.Sequential(  # 使用判别器的特定层
                HiddenLayerExtractor(discriminator, layer = discr_layer),  # 提取特定层的隐藏层
                nn.Linear(discr_dim, 1)  # 添加线性层
            )

        # mlm 相关概率
        self.mask_prob = mask_prob  # 掩码概率
        self.replace_prob = replace_prob  # 替换概率

        self.num_tokens = num_tokens  # 标记数量
        self.random_token_prob = random_token_prob  # 随机标记概率

        # 标记 ID
        self.pad_token_id = pad_token_id  # 填充标记 ID
        self.mask_token_id = mask_token_id  # 掩码标记 ID
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])  # 忽略的掩码标记 ID 集合

        # 采样温度
        self.temperature = temperature  # 温度参数

        # 损失权重
        self.disc_weight = disc_weight  # 判别器权重
        self.gen_weight = gen_weight  # 生成器权重
    # 定义前向传播函数,接受输入和其他参数
    def forward(self, input, **kwargs):
        # 获取输入张量的形状
        b, t = input.shape

        # 根据输入张量生成一个与其形状相同的概率掩码,用于替换概率
        replace_prob = prob_mask_like(input, self.replace_prob)

        # 创建一个不需要掩码的标记列表,包括 [pad] 标记和其他指定排除的标记(如 [cls], [sep])
        no_mask = mask_with_tokens(input, self.mask_ignore_token_ids)
        # 根据概率获取需要掩码的子集
        mask = get_mask_subset_with_prob(~no_mask, self.mask_prob)

        # 获取需要掩码的索引
        mask_indices = torch.nonzero(mask, as_tuple=True)

        # 使用掩码标记的标记替换为 [mask] 标记,保留标记不变
        masked_input = input.clone().detach()

        # 将掩码的标记替换为填充标记,用于生成标签
        gen_labels = input.masked_fill(~mask, self.pad_token_id)

        # 克隆掩码,用于可能的随机标记修改
        masking_mask = mask.clone()

        # 如果随机标记概率大于0,用于 MLM
        if self.random_token_prob > 0:
            assert self.num_tokens is not None, 'Number of tokens (num_tokens) must be passed to Electra for randomizing tokens during masked language modeling'

            # 根据概率生成随机标记
            random_token_prob = prob_mask_like(input, self.random_token_prob)
            random_tokens = torch.randint(0, self.num_tokens, input.shape, device=input.device)
            random_no_mask = mask_with_tokens(random_tokens, self.mask_ignore_token_ids)
            random_token_prob &= ~random_no_mask
            masked_input = torch.where(random_token_prob, random_tokens, masked_input)

            # 从掩码中移除随机标记概率掩码
            masking_mask = masking_mask & ~random_token_prob

        # 将掩码的标记替换为 [mask] 标记
        masked_input = masked_input.masked_fill(masking_mask * replace_prob, self.mask_token_id)

        # 获取生成器输出和 MLM 损失
        logits = self.generator(masked_input, **kwargs)

        mlm_loss = F.cross_entropy(
            logits.transpose(1, 2),
            gen_labels,
            ignore_index = self.pad_token_id
        )

        # 使用之前的掩码选择需要采样的 logits
        sample_logits = logits[mask_indices]

        # 采样
        sampled = gumbel_sample(sample_logits, temperature = self.temperature)

        # 将采样值散布回输入
        disc_input = input.clone()
        disc_input[mask_indices] = sampled.detach()

        # 生成鉴别器标签,替换为 True,原始为 False
        disc_labels = (input != disc_input).float().detach()

        # 获取替换/原始的鉴别器预测
        non_padded_indices = torch.nonzero(input != self.pad_token_id, as_tuple=True)

        # 获取鉴别器输出和二元交叉熵损失
        disc_logits = self.discriminator(disc_input, **kwargs)
        disc_logits = disc_logits.reshape_as(disc_labels)

        disc_loss = F.binary_cross_entropy_with_logits(
            disc_logits[non_padded_indices],
            disc_labels[non_padded_indices]
        )

        # 收集指标
        with torch.no_grad():
            gen_predictions = torch.argmax(logits, dim=-1)
            disc_predictions = torch.round((torch.sign(disc_logits) + 1.0) * 0.5)
            gen_acc = (gen_labels[mask] == gen_predictions[mask]).float().mean()
            disc_acc = 0.5 * (disc_labels[mask] == disc_predictions[mask]).float().mean() + 0.5 * (disc_labels[~mask] == disc_predictions[~mask]).float().mean()

        # 返回加权损失的结果
        return Results(self.gen_weight * mlm_loss + self.disc_weight * disc_loss, mlm_loss, disc_loss, gen_acc, disc_acc, disc_labels, disc_predictions)
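
# 为了直观展示 Electra 封装类如何把生成器与判别器的损失结合起来,下面给出一个最小示意。
# 其中 ToyGenerator / ToyDiscriminator / NUM_TOKENS 均为假设的玩具组件,并非本仓库提供;
# 实际使用时通常用共享词嵌入的 transformer 作为生成器和判别器
NUM_TOKENS, DIM = 20000, 128

class ToyGenerator(nn.Module):
    # 玩具生成器:输出每个位置的词表 logits,形状为 (b, t, NUM_TOKENS)
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(NUM_TOKENS, DIM)
        self.to_logits = nn.Linear(DIM, NUM_TOKENS)

    def forward(self, x):
        return self.to_logits(self.embed(x))

class ToyDiscriminator(nn.Module):
    # 玩具判别器:输出每个位置"是否被替换"的 logit,形状为 (b, t)
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(NUM_TOKENS, DIM)
        self.to_pred = nn.Linear(DIM, 1)

    def forward(self, x):
        return self.to_pred(self.embed(x)).squeeze(-1)

trainer = Electra(ToyGenerator(), ToyDiscriminator(), num_tokens = NUM_TOKENS)
data = torch.randint(3, NUM_TOKENS, (4, 256))    # 避开 pad(0)与 mask(2)等特殊 id
results = trainer(data)
results.loss.backward()                          # Results 中还包含 mlm_loss、disc_loss 等指标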

.\lucidrains\electra-pytorch\electra_pytorch\__init__.py

# 从 electra_pytorch 模块中导入 Electra 类
from electra_pytorch.electra_pytorch import Electra

.\lucidrains\electra-pytorch\examples\glue\download.py

# 导入下载与解压所需的标准库模块
import argparse
import os
import sys
import urllib.request
import zipfile

# 下载和提取数据集的函数(TASKS、TASK2PATH、MRPC_TRAIN、MRPC_TEST 等常量在原脚本开头定义,本摘录未展示)
def download_and_extract(task, data_dir):
    # 打印提示信息,指示正在下载和解压缩特定任务的数据
    print("Downloading and extracting %s..." % task)
    # 构建数据文件名,将任务名称与.zip拼接起来
    data_file = "%s.zip" % task
    # 使用 urllib 库下载指定任务的数据文件到本地
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    # 使用 zipfile 库打开下载的数据文件
    with zipfile.ZipFile(data_file) as zip_ref:
        # 解压缩数据文件中的所有内容到指定的数据目录
        zip_ref.extractall(data_dir)
    # 删除已解压缩的数据文件
    os.remove(data_file)
    # 打印提示信息,指示任务数据下载和解压缩完成
    print("\tCompleted!")
# 格式化 MRPC 数据集
def format_mrpc(data_dir, path_to_data):
    # 打印处理 MRPC 数据集的信息
    print("Processing MRPC...")
    # 创建 MRPC 数据集目录
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)
    # 检查是否提供了数据路径
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        # 如果未提供本地 MRPC 数据路径,则从指定 URL 下载数据
        print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
        urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
    # 确保训练和测试数据文件存在
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
    # 下载 MRPC 数据集的 dev_ids.tsv 文件
    urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))

    # 读取 dev_ids.tsv 文件中的内容
    dev_ids = []
    with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split('\t'))

    # 处理训练数据和开发数据
    with open(mrpc_train_file, encoding="utf8") as data_fh, \
         open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
         open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    # 处理测试数据
    with open(mrpc_test_file, encoding="utf8") as data_fh, \
            open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
        header = data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
    # 打印处理完成信息
    print("\tCompleted!")

# 下载和提取诊断数据集
def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    # 创建诊断数据集目录
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
        os.mkdir(os.path.join(data_dir, "diagnostic"))
    data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
    # 下载诊断数据集文件
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
    # 打印下载和提取完成信息
    print("\tCompleted!")
    return

# 获取指定任务的数据集
def get_tasks(task_names):
    task_names = task_names.split(',')
    if "all" in task_names:
        tasks = TASKS
    else:
        tasks = []
        for task_name in task_names:
            assert task_name in TASKS, "Task %s not found!" % task_name
            tasks.append(task_name)
    return tasks

# 主函数
def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='./data/glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    # 如果数据保存目录不存在,则创建
    if not os.path.exists(args.data_dir):
        os.makedirs(args.data_dir)
    # 获取需要下载数据的任务列表
    tasks = get_tasks(args.tasks)

    # 遍历任务列表,处理每个任务的数据集
    for task in tasks:
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)


if __name__ == '__main__':
    # 解析命令行参数并执行主函数
    sys.exit(main(sys.argv[1:]))

.\lucidrains\electra-pytorch\examples\glue\metrics.py

# 设置文件编码为 UTF-8
# 版权声明,版权归 Google AI Language Team 作者和 HuggingFace Inc. 团队所有,以及 NVIDIA 公司所有
# 根据 Apache 许可证 2.0 版本,除非符合许可证,否则不得使用此文件
# 可以在以下网址获取许可证副本:http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面同意,否则按“原样”分发软件,不提供任何明示或暗示的担保或条件
# 请查看许可证以获取有关特定语言的权限和限制

# 尝试导入所需的库,如果导入失败则将 _has_sklearn 设置为 False
try:
    from scipy.stats import pearsonr, spearmanr
    from sklearn.metrics import matthews_corrcoef, f1_score

    _has_sklearn = True
except (AttributeError, ImportError):
    _has_sklearn = False

# 检查是否有 sklearn 库可用
def is_sklearn_available():
    return _has_sklearn

# 如果有 sklearn 库可用,则定义以下函数
if _has_sklearn:

    # 计算简单准确率
    def simple_accuracy(preds, labels):
        return (preds == labels).mean()

    # 计算准确率和 F1 分数
    def acc_and_f1(preds, labels):
        acc = simple_accuracy(preds, labels)
        f1 = f1_score(y_true=labels, y_pred=preds)
        return {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
        }

    # 计算 Pearson 相关系数和 Spearman 秩相关系数
    def pearson_and_spearman(preds, labels):
        pearson_corr = pearsonr(preds, labels)[0]
        spearman_corr = spearmanr(preds, labels)[0]
        return {
            "pearson": pearson_corr,
            "spearmanr": spearman_corr,
            "corr": (pearson_corr + spearman_corr) / 2,
        }

    # 计算 GLUE 任务的评估指标
    def glue_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "cola":
            return {"mcc": matthews_corrcoef(labels, preds)}
        elif task_name == "sst-2":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mrpc":
            return acc_and_f1(preds, labels)
        elif task_name == "sts-b":
            return pearson_and_spearman(preds, labels)
        elif task_name == "qqp":
            return acc_and_f1(preds, labels)
        elif task_name == "mnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mnli-mm":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "qnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "rte":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "wnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "hans":
            return {"acc": simple_accuracy(preds, labels)}
        else:
            raise KeyError(task_name)

    # 计算 XNLI 任务的评估指标
    def xnli_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "xnli":
            return {"acc": simple_accuracy(preds, labels)}
        else:
            raise KeyError(task_name)
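
# 一个最小示例:对 MRPC 任务(accuracy + F1)计算评估指标(数据为假设的玩具预测)
import numpy as np

example_preds  = np.array([1, 0, 1, 1, 0])
example_labels = np.array([1, 0, 0, 1, 0])
print(glue_compute_metrics("mrpc", example_preds, example_labels))
# -> {'acc': 0.8, 'f1': 0.8, 'acc_and_f1': 0.8}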

.\lucidrains\electra-pytorch\examples\glue\processors.py

# 设置文件编码为 UTF-8
# 版权声明,包括作者和团队信息
# 版权声明,版权所有,保留所有权利
# 根据 Apache 许可证 2.0 版本,除非符合许可证,否则不得使用此文件
# 可以在以下网址获取许可证副本
# http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是基于“原样”分发的,没有任何明示或暗示的保证或条件
# 请查看许可证以获取特定语言的权限和限制
""" GLUE processors and helpers """

# 导入日志记录模块
import logging
# 导入操作系统模块
import os

# 导入自定义模块
# from ...file_utils import is_tf_available
from utils import DataProcessor, InputExample, InputFeatures

# 定义一个 lambda 函数,用于检查 TensorFlow 是否可用
is_tf_available = lambda: False

# 如果 TensorFlow 可用,则导入 TensorFlow 模块
if is_tf_available():
    import tensorflow as tf

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)

# 定义函数,将示例转换为特征
def glue_convert_examples_to_features(
    examples,
    tokenizer,
    max_length=512,
    task=None,
    label_list=None,
    output_mode=None,
    pad_on_left=False,
    pad_token=0,
    pad_token_segment_id=0,
    mask_padding_with_zero=True,
):
    """
    Loads a data file into a list of ``InputFeatures``

    Args:
        examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples.
        tokenizer: Instance of a tokenizer that will tokenize the examples
        max_length: Maximum example length
        task: GLUE task
        label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method
        output_mode: String indicating the output mode. Either ``regression`` or ``classification``
        pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
        pad_token: Padding token
        pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet where it is 4)
        mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
            and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
            actual values)

    Returns:
        If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset``
        containing the task-specific features. If the input is a list of ``InputExamples``, will return
        a list of task-specific ``InputFeatures`` which can be fed to the model.

    """
    # 初始化变量,用于检查是否为 TensorFlow 数据集
    is_tf_dataset = False
    # 如果 TensorFlow 可用且 examples 是 tf.data.Dataset 类型,则设置为 True
    if is_tf_available() and isinstance(examples, tf.data.Dataset):
        is_tf_dataset = True

    # 如果指定了任务,则创建对应的处理器
    if task is not None:
        processor = glue_processors[task]()
        # 如果标签列表为空,则从处理器中获取标签列表
        if label_list is None:
            label_list = processor.get_labels()
            logger.info("Using label list %s for task %s" % (label_list, task))
        # 如果输出模式为空,则从 GLUE 输出模式中获取
        if output_mode is None:
            output_mode = glue_output_modes[task]
            logger.info("Using output mode %s for task %s" % (output_mode, task))

    # 创建标签映射字典
    label_map = {label: i for i, label in enumerate(label_list)}

    # 初始化特征列表
    features = []
    # 遍历所有的例子,并获取索引和例子内容
    for (ex_index, example) in enumerate(examples):
        # 初始化例子的数量
        len_examples = 0
        # 如果是 TensorFlow 数据集
        if is_tf_dataset:
            # 从张量字典中获取例子
            example = processor.get_example_from_tensor_dict(example)
            # 对例子进行 TFDS 映射
            example = processor.tfds_map(example)
            # 获取例子的数量
            len_examples = tf.data.experimental.cardinality(examples)
        else:
            # 获取例子的数量
            len_examples = len(examples)
        # 每处理 10000 个例子输出日志信息
        if ex_index % 10000 == 0:
            logger.info("Writing example %d/%d" % (ex_index, len_examples))

        # 使用分词器对文本进行编码
        inputs = tokenizer.encode_plus(
            example.text_a, example.text_b, add_special_tokens=True, max_length=max_length, return_token_type_ids=True,
        )
        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

        # 生成注意力掩码,用于指示哪些是真实标记,哪些是填充标记
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # 对序列进行零填充
        padding_length = max_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
            token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        # 断言输入长度与最大长度相等
        assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
            len(attention_mask), max_length
        )
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
            len(token_type_ids), max_length
        )

        # 根据输出模式处理标签
        if output_mode == "classification":
            label = label_map[example.label]
        elif output_mode == "regression":
            label = float(example.label)
        else:
            raise KeyError(output_mode)

        # 输出前5个例子的信息
        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
            logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
            logger.info("label: %s (id = %d)" % (example.label, label))

        # 将特征添加到列表中
        features.append(
            InputFeatures(
                input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
            )
        )

    # 如果 TensorFlow 可用且是 TensorFlow 数据集
    if is_tf_available() and is_tf_dataset:

        # 生成器函数,用于生成数据集
        def gen():
            for ex in features:
                yield (
                    {
                        "input_ids": ex.input_ids,
                        "attention_mask": ex.attention_mask,
                        "token_type_ids": ex.token_type_ids,
                    },
                    ex.label,
                )

        # 从生成器创建 TensorFlow 数据集
        return tf.data.Dataset.from_generator(
            gen,
            ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
            (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "token_type_ids": tf.TensorShape([None]),
                },
                tf.TensorShape([]),
            ),
        )

    # 返回特征列表
    return features
class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """从张量字典中获取示例。"""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """获取训练集示例。"""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """获取开发集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """获取标签列表。"""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """为训练集和开发集创建示例。"""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[3]
            text_b = line[4]
            label = line[0]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """从张量字典中获取示例。"""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["premise"].numpy().decode("utf-8"),
            tensor_dict["hypothesis"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """获取训练集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """获取开发集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")

    def get_labels(self):
        """获取标签列表。"""
        return ["contradiction", "entailment", "neutral"]

    def _create_examples(self, lines, set_type):
        """为训练集和开发集创建示例。"""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[8]
            text_b = line[9]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class MnliMismatchedProcessor(MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def get_dev_examples(self, data_dir):
        """获取开发集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")


class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """从张量字典中获取示例。"""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """获取训练集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """获取开发集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """获取标签列表。"""
        return ["0", "1"]
    # 创建训练集和开发集的示例
    def _create_examples(self, lines, set_type):
        # 初始化示例列表
        examples = []
        # 遍历每一行数据
        for (i, line) in enumerate(lines):
            # 生成示例的唯一标识符
            guid = "%s-%s" % (set_type, i)
            # 获取文本 A 的内容
            text_a = line[3]
            # 获取标签
            label = line[1]
            # 将示例添加到示例列表中
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        # 返回示例列表
        return examples
# 定义处理 SST-2 数据集的 Processor 类
class Sst2Processor(DataProcessor):
    """Processor for the SST-2 data set (GLUE version)."""

    # 从张量字典中获取示例
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    # 获取训练集示例
    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    # 获取验证集示例
    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    # 获取标签
    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    # 创建训练集和验证集示例
    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[0]
            label = line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


# 定义处理 STS-B 数据集的 Processor 类
class StsbProcessor(DataProcessor):
    """Processor for the STS-B data set (GLUE version)."""

    # 从张量字典中获取示例
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    # 获取训练集示例
    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    # 获取验证集示例
    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    # 获取标签
    def get_labels(self):
        """See base class."""
        return [None]

    # 创建训练集和验证集示例
    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[7]
            text_b = line[8]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


# 定义处理 QQP 数据集的 Processor 类
class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version)."""

    # 从张量字典中获取示例
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question1"].numpy().decode("utf-8"),
            tensor_dict["question2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    # 获取训练集示例
    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    # 获取验证集示例
    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    # 获取标签
    def get_labels(self):
        """See base class."""
        return ["0", "1"]
    # 创建训练集和开发集的示例
    def _create_examples(self, lines, set_type):
        # 初始化示例列表
        examples = []
        # 遍历每一行数据
        for (i, line) in enumerate(lines):
            # 跳过第一行数据
            if i == 0:
                continue
            # 生成示例的唯一标识符
            guid = "%s-%s" % (set_type, line[0])
            # 尝试获取文本A、文本B和标签信息
            try:
                text_a = line[3]
                text_b = line[4]
                label = line[5]
            # 如果索引超出范围,则跳过该行数据
            except IndexError:
                continue
            # 将示例添加到示例列表中
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        # 返回示例列表
        return examples
class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """从张量字典中获取示例。"""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question"].numpy().decode("utf-8"),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """获取训练集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """获取开发集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")

    def get_labels(self):
        """获取标签列表。"""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """为训练集和开发集创建示例。"""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class RteProcessor(DataProcessor):
    """Processor for the RTE data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """从张量字典中获取示例。"""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """获取训练集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """获取开发集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """获取标签列表。"""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """为训练集和开发集创建示例。"""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class WnliProcessor(DataProcessor):
    """Processor for the WNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """从张量字典中获取示例。"""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """获取训练集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """获取开发集示例。"""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """获取标签列表。"""
        return ["0", "1"]
    # 创建训练集和验证集的示例
    def _create_examples(self, lines, set_type):
        # 初始化示例列表
        examples = []
        # 遍历每一行数据
        for (i, line) in enumerate(lines):
            # 跳过第一行数据
            if i == 0:
                continue
            # 生成示例的唯一标识符
            guid = "%s-%s" % (set_type, line[0])
            # 获取文本 A
            text_a = line[1]
            # 获取文本 B
            text_b = line[2]
            # 获取标签
            label = line[-1]
            # 创建输入示例对象并添加到示例列表中
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        # 返回示例列表
        return examples
# 定义每个 GLUE 任务对应的标签数量
glue_tasks_num_labels = {
    "cola": 2,  # CoLA 任务有2个标签
    "mnli": 3,  # MNLI 任务有3个标签
    "mrpc": 2,  # MRPC 任务有2个标签
    "sst-2": 2,  # SST-2 任务有2个标签
    "sts-b": 1,  # STS-B 任务有1个标签
    "qqp": 2,  # QQP 任务有2个标签
    "qnli": 2,  # QNLI 任务有2个标签
    "rte": 2,  # RTE 任务有2个标签
    "wnli": 2,  # WNLI 任务有2个标签
}

# 定义每个 GLUE 任务对应的处理器类
glue_processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
}

# 定义每个 GLUE 任务对应的输出模式
glue_output_modes = {
    "cola": "classification",  # CoLA 任务的输出模式为分类
    "mnli": "classification",  # MNLI 任务的输出模式为分类
    "mnli-mm": "classification",  # MNLI-MM 任务的输出模式为分类
    "mrpc": "classification",  # MRPC 任务的输出模式为分类
    "sst-2": "classification",  # SST-2 任务的输出模式为分类
    "sts-b": "regression",  # STS-B 任务的输出模式为回归
    "qqp": "classification",  # QQP 任务的输出模式为分类
    "qnli": "classification",  # QNLI 任务的输出模式为分类
    "rte": "classification",  # RTE 任务的输出模式为分类
    "wnli": "classification",  # WNLI 任务的输出模式为分类
}
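
# 一个最小示例:按任务名取出对应的 processor、标签列表与输出模式
task = "mrpc"
processor = glue_processors[task]()
print(processor.get_labels())        # ['0', '1']
print(glue_output_modes[task])       # classification
print(glue_tasks_num_labels[task])   # 2
# 若本地已准备好 data/glue_data/MRPC/train.tsv(路径为假设),可进一步读取训练样本:
# train_examples = processor.get_train_examples("data/glue_data/MRPC")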

.\lucidrains\electra-pytorch\examples\glue\run.py

# 设置文件编码为 UTF-8
# 版权声明,版权归 Google AI Language Team Authors 和 HuggingFace Inc. 团队所有,以及 NVIDIA 公司所有
# 根据 Apache 许可证 2.0 版本使用此文件,详细信息请访问 http://www.apache.org/licenses/LICENSE-2.0
# 除非符合许可证规定或书面同意,否则不得使用此文件
# 根据许可证规定,软件按"原样"分发,不提供任何明示或暗示的担保或条件
# 请查看许可证以获取有关特定语言的权限和限制

""" 在 GLUE 上对库模型进行序列分类微调(Bert、XLM、XLNet、RoBERTa、Albert、XLM-RoBERTa)。"""

# 导入所需的库
import argparse
import glob
import json
import logging
import os
import random

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

# 导入自定义的计算指标函数
from metrics import glue_compute_metrics as compute_metrics
# 导入数据处理函数
from processors import glue_convert_examples_to_features as convert_examples_to_features
# 导入输出模式
from processors import glue_output_modes as output_modes
# 导入处理器
from processors import glue_processors as processors
# 导入任务标签数量
from processors import glue_tasks_num_labels as task_num_labels

# 设置日志记录器
logger = logging.getLogger(__name__)

##################################################
# 适配 Google 风格的 GLUE 代码

# Tokenizer 适配器类
class TokenizerAdapter:
    def __init__(self, tokenizer, pad_token, cls_token="[CLS]", sep_token="[SEP]"):
        self.tokenizer = tokenizer
        self.pad_token = pad_token
        self.cls_token = cls_token
        self.sep_token = sep_token

    # 将 tokens 转换为 ids
    def convert_tokens_to_ids(self, tokens):
        return self.tokenizer.convert_tokens_to_ids(tokens)

    # 截断序列
    def truncate_sequences(
        self,
        ids,
        pair_ids,
        num_tokens_to_remove,
        truncation_strategy,
        stride,
    ):
        # 确保 ids 的长度大于要移除的 tokens 数量
        assert len(ids) > num_tokens_to_remove
        # 计算窗口长度
        window_len = min(len(ids), stride + num_tokens_to_remove)
        # 获取溢出的 tokens
        overflowing_tokens = ids[-window_len:]
        # 截断 ids
        ids = ids[:-num_tokens_to_remove]

        return (ids, pair_ids, overflowing_tokens)
    # 对输入文本进行编码,生成输入的 token ids 和 token type ids
    def encode_plus(self, text, text_pair, add_special_tokens, max_length, return_token_type_ids):

        # 对第一个文本进行 tokenization,转换成 token ids
        token_ids_0 = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
        len_ids = len(token_ids_0)
        # 如果有第二个文本,则对其进行 tokenization,转换成 token ids
        if text_pair:
            token_ids_1 = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text_pair))
            len_pair_ids = len(token_ids_1)
        else:
            token_ids_1 = None
            len_pair_ids = 0

 
        # 截断文本
        assert add_special_tokens
        num_special_tokens_to_add = (2 if not text_pair else 3)
        total_len = len_ids + len_pair_ids + num_special_tokens_to_add
        # 如果总长度超过最大长度,则进行截断
        if max_length and total_len > max_length:
            token_ids_0, token_ids_1, overflowing_tokens = self.truncate_sequences(
                token_ids_0,
                pair_ids=token_ids_1,
                num_tokens_to_remove=total_len - max_length,
                truncation_strategy='only_first', # TODO(nijkamp): is this the correct truncation strategy for all GLUE tasks?
                stride=0,
            )


        # 添加特殊 token
        cls = [self.tokenizer.vocab[self.cls_token]]
        sep = [self.tokenizer.vocab[self.sep_token]]

        if not text_pair:

            input_ids = cls + token_ids_0 + sep
            token_type_ids = len(cls + token_ids_0 + sep) * [0]

        else:

            input_ids = cls + token_ids_0 + sep + token_ids_1 + sep
            token_type_ids = len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

        assert len(input_ids) <= max_length

        # 返回编码结果
        return {"input_ids": input_ids, "token_type_ids": token_type_ids}

    # 返回 tokenizer 的词汇表长度
    def __len__(self):
        return len(self.tokenizer.vocab)

    # 保存预训练模型到指定目录
    def save_pretrained(self, outputdir):
        pass
# 将给定的 tokenizer 和 pad_token 封装成 TokenizerAdapter 对象并返回
def wrap_tokenizer(tokenizer, pad_token):
    return TokenizerAdapter(tokenizer, pad_token)
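
# 一个假设的用法示意:用 HuggingFace 的 BertTokenizer(慢速版,具有 .vocab / .tokenize /
# .convert_tokens_to_ids 属性)包装成上面 GLUE 代码所期望的接口
from transformers import BertTokenizer

bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
adapted = wrap_tokenizer(bert_tokenizer, pad_token="[PAD]")
enc = adapted.encode_plus(
    "the cat sat on the mat",
    "a cat is sitting on a mat",
    add_special_tokens=True,
    max_length=32,
    return_token_type_ids=True,
)
# enc["input_ids"] 以 [CLS] 开头、句间与结尾各有一个 [SEP];enc["token_type_ids"] 区分两句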


##################################################
# distilled Google-like/HF glue code

# 设置随机种子,确保实验的可重复性
def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

# 创建一个学习率调度器,包括线性增加和线性减少学习率
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases linearly after
    linearly increasing during a warmup period.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
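
# 一个最小示例(模型与步数均为假设值):预热 100 步线性升温,之后线性衰减到 0
example_model = nn.Linear(10, 2)
example_optimizer = AdamW(example_model.parameters(), lr=2e-5)
example_scheduler = get_linear_schedule_with_warmup(
    example_optimizer, num_warmup_steps=100, num_training_steps=1000
)
for _ in range(3):
    example_optimizer.step()        # 正常训练中应在反向传播之后调用
    example_scheduler.step()
    # example_scheduler.get_last_lr() 会显示学习率在预热阶段逐步上升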

# 训练模型
def train(args, train_dataset, model, tokenizer):
    """ Train the model """

    # 设置训练批次大小
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # 创建训练数据采样器
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    # 创建训练数据加载器
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # 计算总的训练步数
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # 准备优化器和调度器(线性增加和减少学习率)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # 检查是否存在保存的优化器或调度器状态
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(args.model_name_or_path, "scheduler.pt")
    ):
        # 加载优化器和调度器状态
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # 如果启用混合精度训练
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # 多 GPU 训练(应在 apex fp16 初始化之后)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # 分布式训练(应在 apex fp16 初始化之后)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
        )

    # 开始训练
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    # 打印训练批次的总大小(包括并行、分布式和累积),根据参数计算得出
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    # 打印梯度累积步数
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    # 打印总优化步数
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # 检查是否从检查点继续训练
    if os.path.exists(args.model_name_or_path):
        # 将 global_step 设置为模型路径中最后一个保存检查点的 global_step
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    # 将模型梯度置零
    model.zero_grad()
    # 创建训练迭代器
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # 为了可重现性而添加在这里
    # 返回 global_step 和 tr_loss/global_step
    return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, prefix=""):
    # 循环处理 MNLI 双重评估(匹配,不匹配)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # 注意 DistributedSampler 会随机采样
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # 多 GPU 评估
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # 评估!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
                    )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
            print(preds)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results


def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # 确保在分布式训练中只有第一个进程处理数据集,其他进程将使用缓存

    processor = processors[task]()
    output_mode = output_modes[task]
    # 从缓存或数据集文件加载数据特征
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
            str(task),
        ),
    )
    # 检查缓存文件是否存在且不覆盖缓存时
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        # 输出日志信息,加载缓存文件中的特征
        logger.info("Loading features from cached file %s", cached_features_file)
        # 从缓存文件中加载特征数据
        features = torch.load(cached_features_file)
    else:
        # 输出日志信息,从数据集文件中创建特征
        logger.info("Creating features from dataset file at %s", args.data_dir)
        # 获取标签列表
        label_list = processor.get_labels()
        # 如果任务是 mnli 或 mnli-mm 且模型类型是 roberta 或 xlmroberta
        if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
            # HACK(在 RoBERTa 预训练模型中交换标签索引)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        # 获取示例数据
        examples = (
            processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
        )
        # 将示例转换为特征
        features = convert_examples_to_features(
            examples,
            tokenizer,
            label_list=label_list,
            max_length=args.max_seq_length,
            output_mode=output_mode,
            pad_on_left=False,  # 在 xlnet 中左侧填充
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0,
        )
        # 如果本地进程的索引为 -1 或 0
        if args.local_rank in [-1, 0]:
            # 输出日志信息,将特征保存到缓存文件中
            logger.info("Saving features into cached file %s", cached_features_file)
            # 将特征保存到缓存文件中
            torch.save(features, cached_features_file)

    # 如果本地进程的索引为 0 且不是评估模式
    if args.local_rank == 0 and not evaluate:
        # 确保只有分布式训练中的第一个进程处理数据集,其他进程将使用缓存
        torch.distributed.barrier()

    # 转换为张量并构建数据集
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    # 如果输出模式是分类
    if output_mode == "classification":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    # 如果输出模式是回归
    elif output_mode == "regression":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

    # 构建张量数据集
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    return dataset
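
# A minimal usage sketch (hedged): the dataset returned above is a plain TensorDataset whose
# tensors follow the (input_ids, attention_mask, token_type_ids, labels) order built in this
# function. The helper name `_demo_eval_loader` and the batch size are illustrative, and `args`
# is assumed to carry the same fields that main() below parses.
def _demo_eval_loader(args, tokenizer, batch_size=8):
    from torch.utils.data import DataLoader, SequentialSampler

    eval_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=True)
    # a sequential sampler is enough for evaluation; each batch unpacks into four tensors
    return DataLoader(eval_dataset, sampler=SequentialSampler(eval_dataset), batch_size=batch_size)
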
# 定义主函数,设置默认参数task='MRPC', seed=42, ckpt='output/pretrain/2020-08-28-02-41-37/ckpt/60000'
def main(task='MRPC', seed=42, ckpt='output/pretrain/2020-08-28-02-41-37/ckpt/60000'):
    # 创建参数解析器
    parser = argparse.ArgumentParser()

    # 必需参数
    # 指定输入数据目录,应包含任务的.tsv文件(或其他数据文件)
    parser.add_argument(
        "--data_dir",
        default=f'data/glue_data/{task}',
        type=str,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    # 模型类型,默认为"bert"
    parser.add_argument(
        "--model_type",
        default="bert",
        type=str,
    )
    # 模型名称或路径,默认为ckpt
    parser.add_argument(
        "--model_name_or_path",
        default=ckpt,
        type=str,
    )
    # 词汇表路径,默认为'data/vocab.txt'
    parser.add_argument(
        "--vocab_path",
        default='data/vocab.txt',
        type=str,
    )
    # 任务名称,默认为task
    parser.add_argument(
        "--task_name",
        default=task,
        type=str,
        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
    )
    # 输出目录,默认为'output/glue'
    parser.add_argument(
        "--output_dir",
        default='output/glue',
        type=str,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # 其他参数
    # 缓存目录,默认为空字符串
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    # 最大序列长度,默认为128
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    # 是否进行训练,默认为True
    parser.add_argument("--do_train", default=True, help="Whether to run training.")
    # 是否在开发集上进行评估,默认为True
    parser.add_argument("--do_eval", default=True, help="Whether to run eval on the dev set.")
    # 训练期间是否进行评估,默认为True
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
    )
    # 是否使用小写模型,默认为True
    parser.add_argument(
        "--do_lower_case", default=True, help="Set this flag if you are using an uncased model.",
    )

    # 每个GPU/CPU的训练批次大小,默认为32
    parser.add_argument(
        "--per_gpu_train_batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.",
    )
    # 每个GPU/CPU的评估批次大小,默认为8
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
    )
    # 累积梯度更新的步数,默认为1
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    # Adam优化器的初始学习率,默认为2e-5
    parser.add_argument("--learning_rate", default=2e-5, type=float, help="The initial learning rate for Adam.")
    # 权重衰减,默认为0.0
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    # Adam优化器的epsilon值,默认为1e-8
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    # 最大梯度范数,默认为1.0
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    # 总训练周期数,默认为3.0
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
    )
    # 最大步数,默认为-1
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    # 线性预热步数,默认为0
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    # 每X次更新步骤记录一次日志,默认为500
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    # 每X次更新步骤保存一次检查点,默认为500
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    # 添加一个参数,用于评估所有具有与 model_name 相同前缀和以步数结尾的检查点
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    # 添加一个参数,用于避免在可用时使用 CUDA
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    # 添加一个参数,用于覆盖输出目录的内容
    parser.add_argument(
        "--overwrite_output_dir", default=True, help="Overwrite the content of the output directory",
    )
    # 添加一个参数,用于覆盖缓存的训练和评估集
    parser.add_argument(
        "--overwrite_cache", default=True, help="Overwrite the cached training and evaluation sets",
    )
    # 添加一个参数,用于初始化随机种子
    parser.add_argument("--seed", type=int, default=seed, help="random seed for initialization")

    # 添加一个参数,用于指定是否使用 16 位(混合)精度
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    # 添加一个参数,用于指定 fp16 的优化级别
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    # 添加一个参数,用于分布式训练中的本地排名
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    # 添加一个参数,用于远程调试的服务器 IP 地址
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    # 添加一个参数,用于远程调试的服务器端口
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
    # 解析参数
    args = parser.parse_args()

    # 如果输出目录已存在且不为空,并且需要训练且不覆盖输出目录,则引发 ValueError
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # 如果需要远程调试,则设置远程调试
    if args.server_ip and args.server_port:
        # 远程调试 - 参考 https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # 设置 CUDA、GPU 和分布式训练
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 1
    args.device = device

    # 设置日志记录
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # 设置随机种子
    set_seed(args)

    # 准备 GLUE 任务
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # 加载预训练模型和分词器
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # 确保只有分布式训练中的第一个进程会下载模型和词汇表

    from transformers import AutoConfig, AutoModelForSequenceClassification
    args.model_type = args.model_type.lower()
    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    # 从预训练模型中加载自动序列分类模型
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    # 导入自定义的新标记器
    from pretraining.openwebtext.dataset import new_tokenizer
    # 使用新标记器包装标记器,并设置填充标记
    tokenizer = wrap_tokenizer(new_tokenizer(args.vocab_path), pad_token='[PAD]')

    # 如果本地进程的排名为0,则执行分布式训练中的同步操作
    if args.local_rank == 0:
        torch.distributed.barrier()  # 确保只有分布式训练中的第一个进程会下载模型和词汇表

    # 将模型移动到指定设备
    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # 训练
    if args.do_train:
        # 加载并缓存训练数据集示例
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        # 训练模型并获取全局步数和训练损失
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # 保存最佳实践:如果使用默认名称为模型,则可以使用from_pretrained()重新加载它
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # 如果需要,创建输出目录
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # 保存训练后的模型、配置和标记器使用`save_pretrained()`方法
        # 可以使用`from_pretrained()`重新加载它们
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # 处理分布式/并行训练
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # 良好的实践:将训练参数与训练后的模型一起保存
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # 加载已经微调的训练模型和词汇表
        model = model_to_save
        # TODO(nijkamp): 我们忽略模型序列化
        # model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
        # tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
        model.to(args.device)

    # 评估
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        # TODO(nijkamp): 我们忽略模型序列化
        # tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # 减少日志记录
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            # TODO(nijkamp): 我们忽略模型序列化
            # model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=prefix)
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)

    return results
# 如果当前脚本被直接执行,则调用主函数
if __name__ == "__main__":
    main()

.\lucidrains\electra-pytorch\examples\glue\utils.py

# 设置文件编码为 utf-8
# 版权声明,包括作者和团队信息
# 版权声明,版权所有,保留所有权利
# 根据 Apache 许可证 2.0 版本,只有在遵守许可证的情况下才能使用此文件
# 可以在以下网址获取许可证的副本
# http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面同意,否则按原样分发软件
# 分发的软件基于“原样”基础,没有任何明示或暗示的保证或条件
# 请查看许可证以获取特定语言的权限和限制

# 导入必要的库
import copy
import csv
import dataclasses
import json
import logging
from dataclasses import dataclass
from typing import Optional

# 定义函数 is_torch_available 和 is_tf_available
is_torch_available = lambda: True
is_tf_available = lambda: False

# 获取 logger 对象
logger = logging.getLogger(__name__)

# 定义一个数据类 InputExample,用于表示单个训练/测试示例
@dataclass(frozen=True)
class InputExample:
    """
    A single training/test example for simple sequence classification.

    Args:
        guid: Unique id for the example.
        text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
        text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
        label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
    """

    guid: str
    text_a: str
    text_b: Optional[str] = None
    label: Optional[str] = None

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"

# 定义一个类 InputFeatures,表示单个数据特征集
class InputFeatures(object):
    """
    A single set of features of data.

    Args:
        input_ids: Indices of input sequence tokens in the vocabulary.
        attention_mask: Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            Usually  ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens.
        token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        label: Label corresponding to the input
    """

    def __init__(self, input_ids, attention_mask=None, token_type_ids=None, label=None):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.label = label

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
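
# A minimal sketch (hedged) of how the two containers above fit together; the guid, text and
# token ids are made-up illustrative values, and `_demo_input_containers` is not part of the
# original module.
def _demo_input_containers():
    example = InputExample(guid="train-0", text_a="a quick demo sentence", label="1")
    features = InputFeatures(
        input_ids=[101, 170, 4248, 9262, 5650, 102],  # token ids including [CLS]/[SEP]
        attention_mask=[1, 1, 1, 1, 1, 1],
        token_type_ids=[0, 0, 0, 0, 0, 0],
        label=1,
    )
    # both containers serialize to JSON, which is handy for logging and caching
    return example.to_json_string(), features.to_json_string()
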

# 定义一个数据处理类 DataProcessor,用于序列分类数据集的数据转换器
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """Gets an example from a dict with tensorflow tensors
        Args:
            tensor_dict: Keys and values should match the corresponding Glue
                tensorflow_dataset examples.
        """
        raise NotImplementedError()

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()
    # 将给定的示例转换为正确的格式,以适应 GLUE 数据集的要求
    def tfds_map(self, example):
        """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are.
        This method converts examples to the correct format."""
        # 如果标签数量大于1,则将示例的标签转换为对应的标签值
        if len(self.get_labels()) > 1:
            example.label = self.get_labels()[int(example.label)]
        # 返回转换后的示例
        return example

    # 读取一个以制表符分隔的值文件
    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        # 打开文件并以 UTF-8 编码读取
        with open(input_file, "r", encoding="utf-8-sig") as f:
            # 使用 csv 模块读取文件内容,以制表符为分隔符
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
class SingleSentenceClassificationProcessor(DataProcessor):
    """ Generic processor for a single sentence classification data set."""

    def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
        # 初始化函数,设置标签、示例、模式和详细信息
        self.labels = [] if labels is None else labels
        self.examples = [] if examples is None else examples
        self.mode = mode
        self.verbose = verbose

    def __len__(self):
        # 返回示例的数量
        return len(self.examples)

    def __getitem__(self, idx):
        # 获取指定索引的示例
        if isinstance(idx, slice):
            return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
        return self.examples[idx]

    @classmethod
    def create_from_csv(
        cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
    ):
        # 从CSV文件创建处理器
        processor = cls(**kwargs)
        processor.add_examples_from_csv(
            file_name,
            split_name=split_name,
            column_label=column_label,
            column_text=column_text,
            column_id=column_id,
            skip_first_row=skip_first_row,
            overwrite_labels=True,
            overwrite_examples=True,
        )
        return processor

    @classmethod
    def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
        # 从示例创建处理器
        processor = cls(**kwargs)
        processor.add_examples(texts_or_text_and_labels, labels=labels)
        return processor

    def add_examples_from_csv(
        self,
        file_name,
        split_name="",
        column_label=0,
        column_text=1,
        column_id=None,
        skip_first_row=False,
        overwrite_labels=False,
        overwrite_examples=False,
    ):
        # 从CSV文件中添加示例
        lines = self._read_tsv(file_name)
        if skip_first_row:
            lines = lines[1:]
        texts = []
        labels = []
        ids = []
        for (i, line) in enumerate(lines):
            texts.append(line[column_text])
            labels.append(line[column_label])
            if column_id is not None:
                ids.append(line[column_id])
            else:
                guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
                ids.append(guid)

        return self.add_examples(
            texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
        )

    def add_examples(
        self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
    ):
        # 添加示例
        assert labels is None or len(texts_or_text_and_labels) == len(labels)
        assert ids is None or len(texts_or_text_and_labels) == len(ids)
        if ids is None:
            ids = [None] * len(texts_or_text_and_labels)
        if labels is None:
            labels = [None] * len(texts_or_text_and_labels)
        examples = []
        added_labels = set()
        for (text_or_text_and_label, label, guid) in zip(texts_or_text_and_labels, labels, ids):
            if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
                text, label = text_or_text_and_label
            else:
                text = text_or_text_and_label
            added_labels.add(label)
            examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))

        # Update examples
        if overwrite_examples:
            self.examples = examples
        else:
            self.examples.extend(examples)

        # Update labels
        if overwrite_labels:
            self.labels = list(added_labels)
        else:
            self.labels = list(set(self.labels).union(added_labels))

        return self.examples

    def get_features(
        self,
        tokenizer,
        max_length=None,
        pad_on_left=False,
        pad_token=0,
        mask_padding_with_zero=True,
        return_tensors=None,

.\lucidrains\electra-pytorch\pretraining\openwebtext\arg.py

# 导入必要的模块
import argparse
import dataclasses

# 定义公开的类
__all__ = ('Arg', 'Int', 'Float', 'Bool', 'Str', 'Choice', 'parse_to')

# 定义参数类
class Arg:
    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

# 定义整数参数类
class Int(Arg):
    def __init__(self, **kwargs):
        super().__init__(type=int, **kwargs)

# 定义浮点数参数类
class Float(Arg):
    def __init__(self, **kwargs):
        super().__init__(type=float, **kwargs)

# 定义布尔参数类
class Bool(Arg):
    def __init__(self, **kwargs):
        super().__init__(type=bool, **kwargs)

# 定义字符串参数类
class Str(Arg):
    def __init__(self, **kwargs):
        super().__init__(type=str, **kwargs)

# 定义选择参数类
class _MetaChoice(type):
    def __getitem__(self, item):
        return self(choices=list(item), type=item)

# 定义选择参数类
class Choice(Arg, metaclass=_MetaChoice):
    def __init__(self, choices, **kwargs):
        super().__init__(choices=choices, **kwargs)

# 解析参数并填充到指定的容器类中
def parse_to(container_class, **kwargs):
    # 将字段名转换为命令行参数格式
    def mangle_name(name):
        return '--' + name.replace('_', '-')

    # 创建参数解析器
    parser = argparse.ArgumentParser(description=container_class.__doc__)
    # 遍历容器类的字段
    for field in dataclasses.fields(container_class):
        name = field.name
        default = field.default
        value_or_class = field.type
        # 如果字段类型是类,则使用默认值创建实例
        if isinstance(value_or_class, type):
            value = value_or_class(default=default)
        else:
            value = value_or_class
            value.kwargs['default'] = default
        # 添加参数到参数解析器
        parser.add_argument(
            mangle_name(name), **value.kwargs)

    # 解析参数并返回填充后的容器类实例
    arg_dict = parser.parse_args(**kwargs)
    return container_class(**vars(arg_dict))
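
# A minimal sketch (hedged) of how these helpers are combined with a dataclass of defaults,
# the same pattern preprocess.py and pretrain.py use below; the Config class, its fields and
# `_demo_parse_to` are illustrative, not part of this module.
def _demo_parse_to():
    @dataclasses.dataclass(frozen=True)
    class Config:
        vocab_file: Str = 'data/vocab.txt'
        max_seq_length: Int = 128

    # field names map to flags, e.g. `--vocab-file other.txt --max-seq-length 256`
    return parse_to(Config)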

.\lucidrains\electra-pytorch\pretraining\openwebtext\dataset.py

import math
import os
import random
from dataclasses import dataclass
from itertools import chain
from functools import partial
from pathlib import Path

import numpy as np

import torch
import torch.utils.data

from openwebtext import tokenization


class ExampleBuilder:
    """Given a stream of input text, creates pretraining examples."""

    def __init__(self, vocab, max_length):
        # 初始化 ExampleBuilder 类,传入词汇表和最大长度参数
        self._vocab = vocab
        self._current_sentences = []  # 当前正在构建的例子的句子列表
        self._current_length = 0  # 当前正在构建的例子的长度
        self._max_length = max_length  # 最大长度
        self._target_length = max_length  # 目标长度

    def add_line(self, bert_tokids):
        """Adds a line of text to the current example being built."""
        # 将一行文本添加到当前正在构建的例子中
        self._current_sentences.append(bert_tokids)  # 将 BERT token ids 添加到当前句子列表中
        self._current_length += len(bert_tokids)  # 更新当前例子的长度
        if self._current_length >= self._target_length:
            return self._create_example()  # 如果当前长度达到目标长度,则创建一个例子
        return None

    def _create_example(self):
        """Creates a pre-training example from the current list of sentences."""
        # 有很小的概率只有一个段落,类似分类任务
        if random.random() < 0.1:
            first_segment_target_length = 100000
        else:
            # -3 是因为输入文本中尚未有 [CLS]/[SEP] 标记
            first_segment_target_length = (self._target_length - 3) // 2

        first_segment = []  # 第一个段落
        second_segment = []  # 第二个段落
        for sentence in self._current_sentences:
            # 如果第一个段落为空,或者加入当前句子不会超过目标长度,或者50%的概率加入当前句子会超过目标长度但第二个段落为空
            if (len(first_segment) == 0 or
                len(first_segment) + len(sentence) < first_segment_target_length or
                (len(second_segment) == 0 and
                len(first_segment) < first_segment_target_length and
                random.random() < 0.5)):
                first_segment += sentence  # 将当前句子加入第一个段落
            else:
                second_segment += sentence  # 将当前句子加入第二个段落

        # 裁剪到最大长度,考虑尚未添加的 [CLS]/[SEP] 标记
        first_segment = first_segment[:self._max_length - 2]
        second_segment = second_segment[:max(0, self._max_length - len(first_segment) - 3)]

        # 准备开始构建下一个例子
        self._current_sentences = []  # 清空当前句子列表
        self._current_length = 0  # 重置当前长度
        # 有很小的概率选择随机长度而不是最大长度
        if random.random() < 0.05:
            self._target_length = random.randint(5, self._max_length)
        else:
            self._target_length = self._max_length

        return self._make_tf_example(first_segment, second_segment)  # 创建 TF 格式的例子
    def _make_tf_example(self, first_segment, second_segment):
        """将两个文本“段”转换为tf.train.Example。"""
        # 获取词汇表
        vocab = self._vocab
        # 构建输入文本的token id序列,包括[CLS]和[SEP]标记
        input_ids = [vocab["[CLS]"]] + first_segment + [vocab["[SEP]"]]
        # 初始化段落标识符,全部为0
        segment_ids = [0] * len(input_ids)
        # 如果存在第二个文本段
        if second_segment:
            # 添加第二个文本段的token id序列和段落标识符
            input_ids += second_segment + [vocab["[SEP]"]]
            segment_ids += [1] * (len(second_segment) + 1)
        # 初始化输入掩码,全部为1
        input_mask = [1] * len(input_ids)
        # 将输入文本的token id序列、输入掩码和段落标识符填充至最大长度
        input_ids += [0] * (self._max_length - len(input_ids))
        input_mask += [0] * (self._max_length - len(input_mask))
        segment_ids += [0] * (self._max_length - len(segment_ids))

        # 定义创建整数特征的函数
        def create_int_feature(tensors):
            return torch.tensor(tensors)

        # 构建tf.train.Example对象
        tf_example = {
            "input_ids": create_int_feature(input_ids),
            "input_mask": create_int_feature(input_mask),
            "segment_ids": create_int_feature(segment_ids)
        }
        return tf_example
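
# A minimal sketch (hedged) of the builder above in isolation, using a toy vocab; the vocab
# entries and token ids are illustrative, and real usage goes through wrap_example_builder
# further down in this file.
def _demo_example_builder():
    toy_vocab = {"[PAD]": 0, "[CLS]": 101, "[SEP]": 102, "hello": 7, "world": 8}
    builder = ExampleBuilder(toy_vocab, max_length=8)
    example = None
    while example is None:
        # keep feeding token-id "lines" until enough tokens accumulate for one example
        example = builder.add_line([7, 8, 7, 8])
    # example is a dict of fixed-length tensors: input_ids, input_mask, segment_ids
    return {key: value.shape for key, value in example.items()}
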
# 定义一个继承自torch.utils.data.IterableDataset的OpenWebTextDataset类
class OpenWebTextDataset(torch.utils.data.IterableDataset):
    # 初始化方法,接收feature_set_paths和n_tensors_per_file两个参数
    def __init__(self, feature_set_paths, n_tensors_per_file):
        # 将feature_set_paths赋值给实例变量feature_set_paths
        self.feature_set_paths = feature_set_paths
        # store n_tensors_per_file on the instance (used by __len__ below)
        self.n_tensors_per_file = n_tensors_per_file

    # 静态方法,用于解析文件,接收file_index作为参数
    @staticmethod
    def parse_file(file_index):
        # 尝试加载文件内容为features
        try:
            features = torch.load(str(file_index))
            # 生成器,逐个返回features中的元素
            yield from features
        # 捕获RuntimeError异常
        except RuntimeError:
            # 抛出带有文件索引信息的RuntimeError异常
            raise RuntimeError(f'Corrupted file {file_index}')

    # 返回数据集的长度
    def __len__(self):
        return len(self.feature_set_paths) * self.n_tensors_per_file

    # 迭代器方法,返回一个可迭代对象
    def __iter__(self):
        # 使用map函数将parse_file应用于feature_set_paths中的每个元素,然后使用chain.from_iterable将结果展平
        return chain.from_iterable(map(self.parse_file, self.feature_set_paths))


# 定义一个继承自torch.utils.data.IterableDataset的ExampleBuilderDataset类
class ExampleBuilderDataset(torch.utils.data.IterableDataset):
    # 初始化方法,接收dataset和builder两个参数
    def __init__(self, dataset, builder):
        # 将dataset赋值给实例变量dataset
        self.dataset = dataset
        # store the ExampleBuilder on the instance (used by __iter__ below)
        self.builder = builder

    # 返回数据集的长度
    def __len__(self):
        return len(self.dataset)

    # 迭代器方法,返回一个可迭代对象
    def __iter__(self):
        # 定义一个内部函数create_example
        def create_example():
            # 无限循环
            while True:
                # 获取下一个dataset元素,转换为CPU上的numpy数组,然后转换为列表
                token_ids = list(next(self.dataset).cpu().numpy())
                # 使用builder的add_line方法添加token_ids,如果返回了example,则返回该example
                example = self.builder.add_line(token_ids)
                if example:
                    return example

        # 无限循环
        while True:
            # 生成器,逐个返回create_example函数的结果
            yield create_example()


# 定义一个循环生成器函数cycle
def cycle(iterable):
    # 无限循环
    while True:
        # 遍历可迭代对象iterable,逐个返回元素
        for x in iterable:
            yield x


# 定义一个函数new_tokenizer,接收vocab_file和do_lower_case两个参数
def new_tokenizer(vocab_file, do_lower_case=True):
    # 返回一个FullTokenizer对象,传入vocab_file和do_lower_case参数
    return tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)


# 定义一个函数parse_tokenizer,接收tokenizer和text两个参数
def parse_tokenizer(tokenizer, text):
    # 将text转换为token ids并返回
    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))


# 定义一个函数create_tokenizer,接收vocab_file和do_lower_case两个参数
def create_tokenizer(vocab_file, do_lower_case=True):
    # 创建一个FullTokenizer对象
    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
    # 返回一个partial对象,传入parse_tokenizer函数和tokenizer参数
    return partial(parse_tokenizer, tokenizer)


# 定义一个函数load_owt,接收owt_dir和n_tensors_per_file两个参数
def load_owt(owt_dir, n_tensors_per_file):
    # 将owt_dir转换为Path对象
    owt_dir_path = Path(owt_dir)
    # 获取owt_dir_path目录下的所有文件路径,随机打乱顺序
    feature_set_paths = [owt_dir_path / feature_set_path for feature_set_path in os.listdir(owt_dir_path)]
    np.random.shuffle(feature_set_paths)
    # 断言feature_set_paths长度大于0
    assert len(feature_set_paths) > 0
    # 返回一个OpenWebTextDataset对象,传入feature_set_paths和n_tensors_per_file参数
    return OpenWebTextDataset(feature_set_paths, n_tensors_per_file=n_tensors_per_file)


# 定义一个函数wrap_example_builder,接收dataset、vocab和max_length三个参数
def wrap_example_builder(dataset, vocab, max_length):
    # 返回一个ExampleBuilderDataset对象,传入循环生成器cycle(iter(dataset))和ExampleBuilder(vocab, max_length)参数
    return ExampleBuilderDataset(cycle(iter(dataset)), ExampleBuilder(vocab, max_length))
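
# A minimal end-to-end sketch (hedged) of this module, mirroring how pretrain.py wires it up;
# the paths and sizes are illustrative defaults and assume the feature files produced by
# preprocess.py already exist under owt_dir.
def _demo_pretraining_stream(owt_dir='data/openwebtext_features', vocab_file='data/vocab.txt',
                             n_tensors_per_file=2048, max_length=128):
    tokenizer = new_tokenizer(vocab_file)
    ds = wrap_example_builder(
        dataset=load_owt(owt_dir, n_tensors_per_file=n_tensors_per_file),
        vocab=tokenizer.vocab,
        max_length=max_length,
    )
    # each item is a dict with 'input_ids', 'input_mask' and 'segment_ids' tensors of length max_length
    example = next(iter(ds))
    return example['input_ids'].shape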

.\lucidrains\electra-pytorch\pretraining\openwebtext\preprocess.py

import logging
import math
import multiprocessing
import os
import random
import tarfile
from dataclasses import dataclass
from itertools import chain
from functools import partial
from pathlib import Path

import numpy as np

import torch
import torch.utils.data

from pretraining.openwebtext import arg
from pretraining.openwebtext import tokenization


logger = logging.getLogger(__name__)


def parse_tokenizer(tokenizer, text):
    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))


def create_tokenizer(vocab_file, do_lower_case=True):
    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
    return partial(parse_tokenizer, tokenizer)


def preprocess_owt(tokenizer, src_dir, tmp_dir, trg_dir, n_dataset_building_processes, n_tensors_per_file, max_seq_length=None):
    # Preamble
    logger.info(f'Writing features to {trg_dir}.')
    os.makedirs(trg_dir, exist_ok=False)

    # Crunch files
    trg_dir = Path(trg_dir)
    src_dir = Path(src_dir)
    tmp_dir = Path(tmp_dir)
    archives = os.listdir(src_dir)
    n_archives_per_job = math.ceil(len(archives) / n_dataset_building_processes)
    job_archives = [
        archives[i * n_archives_per_job : (i + 1) * n_archives_per_job]
        for i in range(n_dataset_building_processes)
    ]

    logger.info(f'Processing {len(archives)} archives.')
    assert len(archives) > 0

    if n_dataset_building_processes == 1:
        feature_set_paths = preprocess_owt_job(tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length, job_id=0)
    else:
        pool = multiprocessing.Pool(processes=n_dataset_building_processes)
        preprocess_owt_job_partial = partial(preprocess_owt_job, tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length)
        feature_sets = pool.map(preprocess_owt_job_partial, range(n_dataset_building_processes))
        feature_set_paths = [file_path for feature_set in feature_sets for file_path in feature_set]

    return feature_set_paths


def preprocess_owt_job(tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length, job_id=0):
    '''
    OpenWebText is saved under the following format:
    openwebtext.zip
        |-> archive_xxx.zip
            |-> file_xxx.txt
            |-> file_xxz.txt
            ...
        |-> archive_xxz.zip
            |-> file_xxy.txt
            ...
        ...
    '''

    # Preamble
    os.makedirs(tmp_dir, exist_ok=True)

    # Process
    feature_index = 0
    feature_set_paths = []
    features = []
    for archive_id, archive in enumerate(job_archives[job_id]):
        if os.path.isdir(src_dir / archive):
            logger.info(f'Ignoring rogue directory {src_dir / archive}.')
            continue

        logger.info(f'Job {job_id}: Processing {archive_id}/{len(job_archives[job_id])} {src_dir / archive}.')

        with tarfile.open(src_dir / archive) as t:
            extracted_archive = tmp_dir / f'{archive}-extracted'
            t.extractall(extracted_archive)

        for file in os.listdir(extracted_archive):
            file_path = extracted_archive / file

            with open(file_path, 'r') as f:
                for line in f.readlines():
                    line = line.strip()
                    if len(line) > 2:
                        encoding = tokenizer(line)
                        features.append(torch.tensor(encoding))

        while len(features) > n_tensors_per_file:
            feature_set_path = trg_dir / f'feature_set_{job_id}_{feature_index}.pt'
            torch.save(features[:n_tensors_per_file], feature_set_path)
            features = features[n_tensors_per_file:]
            feature_index += 1
            feature_set_paths.append(feature_set_path)

    # Serialize
    # 如果特征列表不为空
    if len(features) > 0:
        # 构建特征集路径,包含作业ID和特征索引
        feature_set_path = trg_dir / f'feature_set_{job_id}_{feature_index}.pt'
        # 使用torch保存特征到指定路径
        torch.save(features, feature_set_path)
        # 将特征集路径添加到列表中
        feature_set_paths.append(feature_set_path)

    # 返回特征集路径列表
    return feature_set_paths
# 使用 dataclass 装饰器创建一个不可变的参数类 Args,包含默认参数值
@dataclass(frozen=True)
class Args:
    src_dir: arg.Str = 'data/openwebtext'  # 源目录路径参数,默认值为'data/openwebtext'
    trg_dir: arg.Str = 'data/openwebtext_features'  # 目标目录路径参数,默认值为'data/openwebtext_features'
    tmp_dir: arg.Str = '/tmp/owt'  # 临时目录路径参数,默认值为'/tmp/owt'
    vocab_file: arg.Str = 'data/vocab.txt'  # 词汇表文件路径参数,默认值为'data/vocab.txt'
    n_dataset_building_processes: arg.Int = 32  # 数据集构建进程数参数,默认值为32
    n_tensors_per_file: arg.Int = 2048  # 每个文件的张量数参数,默认值为2048
    max_seq_length: arg.Int = 128  # 最大序列长度参数,默认值为128

# 主函数
def main():
    # 解析参数并赋值给 args
    args = arg.parse_to(Args)

    # 配置日志记录器,设置日志格式和级别
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO
    )

    # 创建分词器对象
    tokenizer = create_tokenizer(args.vocab_file)
    # 预处理 openwebtext 数据集
    preprocess_owt(tokenizer=tokenizer, src_dir=args.src_dir, tmp_dir=args.tmp_dir, trg_dir=args.trg_dir, n_dataset_building_processes=args.n_dataset_building_processes, n_tensors_per_file=args.n_tensors_per_file, max_seq_length=args.max_seq_length)

# 如果当前脚本作为主程序运行,则调用主函数
if __name__ == '__main__':
    main()
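
# A minimal programmatic sketch (hedged), roughly equivalent to running this script with its
# defaults; the demo paths and the small per-file tensor count are illustrative, and trg_dir
# must not exist yet because preprocess_owt creates it with exist_ok=False.
def _demo_preprocess_small(src_dir='data/openwebtext', trg_dir='data/openwebtext_features_demo'):
    tokenizer = create_tokenizer('data/vocab.txt')
    return preprocess_owt(
        tokenizer=tokenizer,
        src_dir=src_dir,
        tmp_dir='/tmp/owt_demo',
        trg_dir=trg_dir,
        n_dataset_building_processes=1,  # a single process keeps the demo in the current interpreter
        n_tensors_per_file=256,
    )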

.\lucidrains\electra-pytorch\pretraining\openwebtext\pretrain.py

# 导入必要的库
import os
import sys

# 获取当前文件所在目录的绝对路径
dir_path = os.path.dirname(os.path.realpath(__file__))
# 获取当前文件所在目录的父目录的绝对路径
parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir))
# 将父目录的路径插入到系统路径中
sys.path.insert(0, parent_dir_path)

# 导入其他必要的库
import random
import logging
from time import time
from dataclasses import dataclass

import numpy as np

import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

from electra_pytorch import Electra

from openwebtext import arg
from openwebtext.dataset import load_owt, new_tokenizer, wrap_example_builder

logger = logging.getLogger(__name__)

########################################################################################################
## args

# 定义参数类
@dataclass
class Args:
    data_dir: arg.Str = 'data/openwebtext_features'
    data_vocab_file: arg.Str = 'data/vocab.txt'
    data_n_tensors_per_file: arg.Int = 2048
    data_max_seq_length: arg.Int = 128

    gpu: arg.Int = 0
    gpu_enabled: arg.Bool = True
    gpu_deterministic: arg.Bool = False
    gpu_mixed_precision: arg.Bool = False
    distributed_port: arg.Int = 8888
    distributed_enabled: arg.Bool = True
    distributed_world_size: arg.Int = 4

    model_generator: arg.Str = 'pretraining/openwebtext/small_generator.json'
    model_discriminator: arg.Str = 'pretraining/openwebtext/small_discriminator.json'
    model_mask_prob: arg.Float = 0.15

    opt_lr: arg.Float = 5e-4
    opt_batch_size: arg.Int = 128 // (distributed_world_size if distributed_enabled else 1)
    opt_warmup_steps: arg.Int = 10_000
    opt_num_training_steps: arg.Int = 200_000

    step_log: arg.Int = 10
    step_ckpt: arg.Int = 10_000


########################################################################################################
## train

# 定义训练函数
def train(rank, args):

    #######################
    ## distributed

    # 如果启用分布式训练,则初始化进程组
    if args.distributed_enabled:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=args.distributed_world_size,
            rank=rank)
    # 如果启用 GPU,则选择对应的设备
    if args.gpu_enabled:
        device = torch.device('cuda:{}'.format(rank))
    else:
        device = torch.device('cpu')

    # 判断当前进程是否为主进程
    is_master = True if not args.distributed_enabled else args.distributed_enabled and rank == 0


    #######################
    ## preamble

    # 设置 GPU
    set_gpus(rank)
    # 设置随机种子
    set_seed(rank)
    # 设置 CUDA
    set_cuda(deterministic=args.gpu_deterministic)

    # 创建输出目录
    output_dir = f'{args.output_dir}/{rank}'
    os.makedirs(output_dir, exist_ok=False)

    # 设置日志记录
    setup_logging(filename=f'{output_dir}/output.log', console=is_master)


    #######################
    ## dataset

    # 创建分词器
    tokenizer = new_tokenizer(vocab_file=args.data_vocab_file)
    vocab_size = len(tokenizer.vocab)
    # 加载数据集
    ds_train = wrap_example_builder(dataset=load_owt(owt_dir=args.data_dir, n_tensors_per_file=args.data_n_tensors_per_file), vocab=tokenizer.vocab, max_length=args.data_max_seq_length)

    # 获取特殊标记的 ID
    pad_token_id = tokenizer.vocab['[PAD]']
    mask_token_id = tokenizer.vocab['[MASK]']
    cls_token_id = tokenizer.vocab['[CLS]']
    sep_token_id = tokenizer.vocab['[SEP]']

    # 断言特殊标记的 ID 符合预期
    assert pad_token_id == 0
    assert cls_token_id == 101
    assert sep_token_id == 102
    assert mask_token_id == 103

    # 定义数据加载函数
    def collate_batch(examples):
        input_ids = torch.nn.utils.rnn.pad_sequence([example['input_ids'] for example in examples], batch_first=True, padding_value=pad_token_id)
        input_mask = torch.nn.utils.rnn.pad_sequence([example['input_mask'] for example in examples], batch_first=True, padding_value=pad_token_id)
        segment_ids = torch.nn.utils.rnn.pad_sequence([example['segment_ids'] for example in examples], batch_first=True, padding_value=pad_token_id)
        return input_ids, input_mask, segment_ids

    # 定义数据集加载器
    def cycle(iterable):
        while True:
            for x in iterable:
                yield x

    ds_train_loader = iter(cycle(DataLoader(ds_train, batch_size=args.opt_batch_size, collate_fn=collate_batch)))


    #######################
    ## model
    # 如果分布式模式未启用,则返回原始模型;否则返回使用分布式数据并行的模型
    def to_distributed_model(model):
        return model if not args.distributed_enabled else torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)

    # 将生成器和鉴别器的权重绑定在一起
    def tie_weights(generator, discriminator):
        generator.electra.embeddings.word_embeddings = discriminator.electra.embeddings.word_embeddings
        generator.electra.embeddings.position_embeddings = discriminator.electra.embeddings.position_embeddings
        generator.electra.embeddings.token_type_embeddings = discriminator.electra.embeddings.token_type_embeddings

    # 定义一个适配器类,用于调整模型输出的格式
    class LogitsAdapter(torch.nn.Module):
        def __init__(self, adaptee):
            super().__init__()
            self.adaptee = adaptee

        def forward(self, *args, **kwargs):
            return self.adaptee(*args, **kwargs)[0]

    # 导入所需的库和模型配置
    from transformers import AutoConfig, ElectraForMaskedLM, ElectraForPreTraining

    # 创建生成器和鉴别器模型
    generator = ElectraForMaskedLM(AutoConfig.from_pretrained(args.model_generator))
    discriminator = ElectraForPreTraining(AutoConfig.from_pretrained(args.model_discriminator))

    # 将生成器和鉴别器的权重绑定在一起
    tie_weights(generator, discriminator)

    # 创建分布式模型,并设置相关参数
    model = to_distributed_model(Electra(
        LogitsAdapter(generator),
        LogitsAdapter(discriminator),
        num_tokens = vocab_size,
        mask_token_id = mask_token_id,
        pad_token_id = pad_token_id,
        mask_prob = args.model_mask_prob,
        mask_ignore_token_ids = [tokenizer.vocab['[CLS]'], tokenizer.vocab['[SEP]']],
        random_token_prob = 0.0).to(device))

    #######################
    ## optimizer

    # 定义一个带有热身阶段的线性学习率调度器
    def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
        def lr_lambda(current_step):
            learning_rate = max(0.0, 1. - (float(current_step) / float(num_training_steps)))
            learning_rate *= min(1.0, float(current_step) / float(num_warmup_steps))
            return learning_rate
        return LambdaLR(optimizer, lr_lambda, last_epoch)

    # 获取不需要权重衰减的参数
    def get_params_without_weight_decay_ln(named_params, weight_decay):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
                'weight_decay': weight_decay,
            },
            {
                'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]
        return optimizer_grouped_parameters

    # 创建优化器和学习率调度器
    optimizer = torch.optim.AdamW(get_params_without_weight_decay_ln(model.named_parameters(), weight_decay=0.1), lr=args.opt_lr, betas=(0.9, 0.999), eps=1e-08)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.opt_warmup_steps, num_training_steps=args.opt_num_training_steps)
    scaler = torch.cuda.amp.GradScaler(enabled=args.gpu_mixed_precision)

    #######################
    ## train

    # 记录训练开始时间,步长速度和预计完成时间
    t, steps_s, eta_m = time(), 0., 0
    # 循环执行训练步骤,包括优化器更新、梯度裁剪、学习率调整等
    for step in range(args.opt_num_training_steps+1):
        # 从训练数据加载下一个批次的输入数据
        input_ids, input_mask, segment_ids = next(ds_train_loader)

        # 将输入数据移动到指定设备上
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)

        # 断言输入数据的序列长度不超过设定的最大长度
        assert input_ids.shape[1] <= args.data_max_seq_length

        # 梯度清零
        optimizer.zero_grad()

        # 使用混合精度训练,计算损失和准确率
        with torch.cuda.amp.autocast(enabled=args.gpu_mixed_precision):
            loss, loss_mlm, loss_disc, acc_gen, acc_disc, disc_labels, disc_pred = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)

        # 反向传播并调整优化器参数
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # 记录训练指标
        metrics = {
            'step': (step, '{:8d}'),
            'loss': (loss.item(), '{:8.5f}'),
            'loss_mlm': (loss_mlm.item(), '{:8.5f}'),
            'loss_disc': (loss_disc.item(), '{:8.5f}'),
            'acc_gen': (acc_gen.item(), '{:5.3f}'),
            'acc_disc': (acc_disc.item(), '{:5.3f}'),
            'lr': (scheduler.get_last_lr()[0], '{:8.7f}'),
            'steps': (steps_s, '{:4.1f}/s'),
            'eta': (eta_m, '{:4d}m'),
        }

        # 每隔一定步数打印训练指标信息
        if step % args.step_log == 0:
            sep = ' ' * 2
            logger.info(sep.join([f'{k}: {v[1].format(v[0])}' for (k, v) in metrics.items()]))

        # 每隔一定步数计算训练速度和预计剩余时间
        if step > 0 and step % 100 == 0:
            t2 = time()
            steps_s = 100. / (t2 - t)
            eta_m = int(((args.opt_num_training_steps - step) / steps_s) // 60)
            t = t2

        # 每隔一定步数打印部分标签和预测结果
        if step % 200 == 0:
            logger.info(np.array2string(disc_labels[0].cpu().numpy(), threshold=sys.maxsize, max_line_width=sys.maxsize))
            logger.info(np.array2string(disc_pred[0].cpu().numpy(), threshold=sys.maxsize, max_line_width=sys.maxsize))

        # 每隔一定步数保存模型检查点
        if step > 0 and step % args.step_ckpt == 0 and is_master:
            discriminator.electra.save_pretrained(f'{args.output_dir}/ckpt/{step}')
# 设置程序在哪块 GPU 上运行
def set_gpus(gpu):
    torch.cuda.set_device(gpu)

# 设置随机种子
def set_seed(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # if CUDA is available, also seed the CUDA RNGs
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

# 设置 CUDA 是否确定性
def set_cuda(deterministic=True):
    # if CUDA is available, toggle cuDNN determinism vs. benchmarking
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic

# 获取实验 ID
def get_exp_id(file):
    # return the file's base name without its extension
    return os.path.splitext(os.path.basename(file))[0]

# 获取输出目录
def get_output_dir(exp_id):
    # import datetime locally (it is not imported at the top of this file)
    import datetime
    # build a timestamped output directory path
    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    output_dir = os.path.join('output/' + exp_id, t)
    # create the output directory if it does not exist
    os.makedirs(output_dir, exist_ok=True)
    # return the output directory path
    return output_dir

# 设置日志记录
def setup_logging(filename, console=True):
    # set the log message format
    log_format = logging.Formatter("%(asctime)s : %(message)s")
    # get the root logger and clear any existing handlers
    logger = logging.getLogger()
    logger.handlers = []
    # create a file handler, attach the format and add it to the logger
    file_handler = logging.FileHandler(filename)
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)
    # optionally mirror the log output to the console
    if console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(log_format)
        logger.addHandler(console_handler)
        logger.setLevel(logging.INFO)
    # return the configured logger
    return logger

# 复制源文件到输出目录
def copy_source(file, output_dir):
    # import shutil locally (it is not imported at the top of this file)
    import shutil
    # copy the source file into the output directory
    shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))

# 主函数
def main():

    # preamble
    # derive the experiment id from this file name
    exp_id = get_exp_id(__file__)
    # create a timestamped output directory and a checkpoint subdirectory
    output_dir = get_output_dir(exp_id)
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(f'{output_dir}/ckpt', exist_ok=False)
    # copy this script into the output directory for reproducibility
    copy_source(__file__, output_dir)

    # args
    # parse the command line arguments into the Args dataclass
    args = arg.parse_to(Args)
    # attach the output directory and experiment id to the parsed args
    args.output_dir = output_dir
    args.exp_id = exp_id

    # distributed
    # if distributed training is enabled, set the master address/port and spawn one process per GPU
    if args.distributed_enabled:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = str(args.distributed_port)
        torch.multiprocessing.spawn(train, nprocs=args.distributed_world_size, args=(args,))
    # otherwise train in a single process
    else:
        train(rank=args.gpu, args=args)

# 如果当前脚本作为主程序运行,则调用主函数
if __name__ == '__main__':
    main()
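
# A small standalone sketch (hedged) of the learning-rate multiplier used inside train() above:
# a linear warmup factor multiplied by a linear decay. It only restates the lr_lambda closure so
# the schedule shape can be inspected without building a model or an optimizer.
def _demo_lr_multiplier(step, num_warmup_steps=10_000, num_training_steps=200_000):
    decay = max(0.0, 1. - float(step) / float(num_training_steps))
    warmup = min(1.0, float(step) / float(num_warmup_steps))
    return decay * warmup  # e.g. 0.0 at step 0, 0.95 at step 10_000, 0.0 at step 200_000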

.\lucidrains\electra-pytorch\pretraining\openwebtext\tokenization.py

# 设置文件编码为 utf-8
# 版权声明
# 根据 Apache 许可证 2.0 版本授权
# 获取许可证的链接
# 根据适用法律或书面同意,软件按"原样"分发,不提供任何明示或暗示的担保或条件
# 请查看许可证以获取特定语言的权限和限制

"""Tokenization classes, the same as used for BERT."""

# 导入必要的库
import collections
import unicodedata

# 将输入文本转换为 Unicode 编码(如果尚未转换),假定输入为 utf-8
def convert_to_unicode(text):
    if isinstance(text, str):
        return text
    elif isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    else:
        raise ValueError("Unsupported string type: %s" % (type(text)))

# 返回适合打印的文本编码方式
def printable_text(text):
    if isinstance(text, str):
        return text
    elif isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    else:
        raise ValueError("Unsupported string type: %s" % (type(text)))

# 加载词汇表文件到字典中
def load_vocab(vocab_file):
    vocab = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r") as reader:
        while True:
            token = convert_to_unicode(reader.readline())
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab

# 使用词汇表将序列 [tokens|ids] 转换
def convert_by_vocab(vocab, items):
    output = []
    for item in items:
        output.append(vocab[item])
    return output

# 将 tokens 转换为 ids
def convert_tokens_to_ids(vocab, tokens):
    return convert_by_vocab(vocab, tokens)

# 将 ids 转换为 tokens
def convert_ids_to_tokens(inv_vocab, ids):
    return convert_by_vocab(inv_vocab, ids)

# 基本的空格分词函数
def whitespace_tokenize(text):
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens

# 完整的分词器类
class FullTokenizer(object):
    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)

# 基本的分词器类
class BasicTokenizer(object):
    def __init__(self, do_lower_case=True):
        self.do_lower_case = do_lower_case
    def tokenize(self, text):
        """Tokenizes a piece of text."""
        # 将文本转换为 Unicode 格式
        text = convert_to_unicode(text)
        # 清理文本数据
        text = self._clean_text(text)

        # 为多语言和中文模型添加的功能,对中文字符进行处理
        text = self._tokenize_chinese_chars(text)

        # 使用空格分隔文本,得到原始 token 列表
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            # 如果需要转换为小写,则将 token 转换为小写
            if self.do_lower_case:
                token = token.lower()
                # 去除 token 中的重音符号
                token = self._run_strip_accents(token)
            # 根据标点符号分割 token
            split_tokens.extend(self._run_split_on_punc(token))

        # 使用空格分隔 token 列表,得到最终的输出 token 列表
        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # 将文本中的重音符号去除
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # 判断字符是否为中文字符
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            # 移除无效字符和空白字符
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
class WordpieceTokenizer(object):
    """Runs WordPiece tokenziation."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        # 初始化 WordpieceTokenizer 类,设置词汇表、未知标记和单词最大字符数
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
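
# A runnable version (hedged) of the docstring example above, with a toy vocab that is just
# large enough to cover the pieces; real usage goes through FullTokenizer with a full BERT vocab.
def _demo_wordpiece_tokenizer():
    toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
    tokenizer = WordpieceTokenizer(vocab=toy_vocab)
    assert tokenizer.tokenize("unaffable") == ["un", "##aff", "##able"]
    assert tokenizer.tokenize("xyz") == ["[UNK]"]  # out-of-vocab words fall back to the unk token
    return tokenizer.tokenize("unaffable")
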


def _is_whitespace(char):
    """Checks whether `chars` is a whitespace character."""
    # \t, \n, and \r are technically contorl characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `chars` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `chars` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False