ERNIE代码解析

©原创作者 |疯狂的Max

ERNIE代码解读

考虑到ERNIE使用BRET作为基础模型,为了让没有基础的NLPer也能够理解代码,笔者将先为大家简略的解读BERT模型的结构,完整代码可以参见[1]。

01 BERT的结构组成

BERT的代码最主要的是由分词模块、训练数据预处理、模型结构模块等几部分组成。

1.1 分词模块

模型在训练之前,需要对输入文本进行切分,并将切分的子词转换为对应的ID。这一功能主要由BertTokenizer来实现,主要在
/models/bert/tokenization_bert.py实现。

BertTokenizer 是基于BasicTokenizer和WordPieceTokenizer 的分词器:

BasicTokenizer负责按标点、空格等分割句子,并处理是否统一小写,以及清理非法字符。

WordPieceTokenizer在词的基础上,进一步将词分解为子词(subword)。

具有以下使用方法:

  • from_pretrained:从包含词表文件(vocab.txt)的目录中初始化一个分词器;
  • tokenize:将文本分解为子词列表;
  • convert_tokens_to_ids:将子词转化为子词对应的下标;
  • convert_ids_to_tokens :将对应下标转化为子词;
  • encode:对于单个句子,分解词并加入特殊词形成“[CLS], x, [SEP]”的结构并转换为词表对应下标的列表;
  • decode:将encode的输出转换为句子。

1.2 训练数据预处理

训练数据的构建主要取决于预训练的任务,由于BERT的预训练任务包括预测上下句和掩码词预测是否为连续句,那么其训练数据就需要随机替换连续的语句和其中的分词,这部分功能由run_pretraining.py中的函数
create_instances_from_document实现。

该部分首先构建上下句,拼接[cls]和[sep]等特殊符号的id,构建长度为512的列表,然后根据论文中所使用的指定概率选择要掩码的子词,这部分由函数
create_masked_lm_predictions实现。

1.3 模型结构

BERT模型主要由BertEmbeddings类、BertEncoder类组成,前者负责将子词、位置和上下句标识(segment)投影成向量,后者实现文本的编码。

编码器BertEncoder又由12层相同的编码块BertLayer组成。每一层都由自注意力层BertSelfAttention和前馈神经网络层BertIntermediate以及输出层BertOutput构成,在
/models/bert/modeling_bert.py中实现。

每一层编码层的结构和功能如下:

  • BertSelfAttention:负责实现子词之间的相互关注。注意,多头自注意力机制的实现是通过将维度为hidden_size 的表示向量切分成n个维度为hidden_size / n的向量,再对切分的向量分别进行编码,最后拼接编码后的向量实现的;
  • BertIntermediate:将批次数据(三维张量)做矩阵相乘和非线性变化;
  • BertOutput :实现归一化和残差连接;

工程小技巧: 如果模型在学习表示向量的过程中需要使用不同的编码方式,以结合图神经网络层和Transformer编码层为例,笔者建议尽量使用相同的参数初始化方式,两者都使用残差连接,这能够避免模型训练时出现梯度爆炸的问题。

此外是否需要对注意力权重进行大小的变化,如Transformer会除以向量维度的开方,则取决于图神经网络的层数,一般而言,仅使用两层或以下的图神经网络层,则无需对注意力权重做变化。

具体可以通过观察图神经网络层生成的表示向量的大小是否和Transformer编码层生成的向量大小在同一个数量级来决定,如果在同一个数量级则无需改变注意力权重,如果出现梯度爆炸的现象,那么则可以缩小注意力的权重。

02 从BERT到ERNIE

由于ERNIE是在BERT的基础上进行改进,在数据层面需要构建与文本对应的实体序列,在预训练层面加入了新的预训练任务,那么在代码上就对应着训练数据预处理和模型结构这两方面的改动。因此笔者也将重点针对这两个方面进行讲解,完整代码参见[2]。

其代码结构主要包含两大模块,训练数据预处理模块和模型构建模块。

2.1 训练数据预处理模块

ERNIE模型的知识注入依赖于找到文本中存在的实体,这些实体是指具有意义的抽象或者具象的单个名词或名词短语,我们可以将其称为文本指称项(mention)。一个实体可以有多个别名,也就意味着一个实体可以对应着文本中的多个指称项。

为了能够找到文本语料中实体,作者使用维基百科作为ERNIE的训练语料,将维基百科中具有超链接的名词或者短语作为实体,利用这一现有资源能够大大的简化检索实体的难度。

2.1.1 训练数据构建

在利用现有抽取工具获得语料和实体名文件后,通过
pretrain_data/create_insts.py构建训练数据。

我们知道在训练之前,首先需要对语料进行分词(tokenize),获得子词(tokens),然后根据词典得到子词的索引ID,模型在接收索引后将其投影成向量。从BERT的代码中我们可以知道,BERT首先构建用于下一句预测(next sentences prediction)所需要的上下句,并从中随机选择掩码词,生成用于自注意力阶段的掩码列表。

那么为了能够注入语句中对应的实体,ERNIE就需要在这一过程中创建和训练语料等长的实体ID张量,以及对应的掩码列表。

作者仅仅对文本指称项第一个子词所对应的位置标注实体ID,这也就意味模型仅使用第一个子词向量预测实体。这种做法能够直接复用BERT的代码,而无需单独针对实体序列再构建训练数据,减轻了工程实现的工作量。

for i, x in enumerate(vec):
    if x == "#UNK#":
        vec[i] = -1
    elif x[0] == "Q":
        if x in d:
           vec[i] = d[x]
           if i != 0 and vec[i] == vec[i-1]:
           # 以某个实体为例,Q123 Q123 Q123 -> d[Q123] -1 -1,仅在第一个子词中记录实体的ID,其他位置标志为-1
               vec[i] = -1  
           else:
               vec[i] = -1
#函数 create_instances_from_document
    // 获取句子a和b的实体和子词   
    tokens = [101] + tokens_a + [102] + tokens_b + [102]
    entity = [-1] + entity_a + [-1] + entity_b + [-1]
    // 构造用于为数据构建索引的对象ds,并将对应的输入语料id列表及掩码列表,实体id列表和掩码列表等训练数据存入ds。
    ds.add_item(torch.IntTensor(input_ids+input_mask+segment_ids
            +masked_lm_labels+entity+entity_mask+[next_sentence_label]))

2.1.2 实体向量加载

BERT由于具有经过预训练的向量表,子词的ID值可以利用nn.embedding模块获取投影向量。

那么实体的向量是经过TransE表示学习获得的,又应该如何让模型获取其投影向量呢?作者在code/iteration.py中自定义数据迭代器对象,该对象在返回数据时会调用
torch.utils.data.DataLoader,通过在该函数中传入负责投影实体向量的函数collate_fn,能够让模型在加载数据时获取实体的表示向量。

#类 EpochBatchIterator(object):
    return CountingIterator(torch.utils.data.DataLoader(
            self.dataset,
            # collate_fn是传入实体向量的关键
            collate_fn=self.collate_fn,
            batch_sampler=batches,
        ))
#函数collate_fn:
def collate_fn(x):
    x = torch.LongTensor([xx for xx in x])
    entity_idx = x[:, 4*args.max_seq_length:5*args.max_seq_length]
    # embed = torch.nn.Embedding.from_pretrained(embed)
    # embed为加载了经过预训练的二维实体张量
    uniq_idx = np.unique(entity_idx.numpy())
    ent_candidate = embed(torch.LongTensor(uniq_idx+1))

2.2 模型结构模块

在模型方面,作者依旧使用12层Transformer编码层作为模型结构,与BERT所不同的是,在前6层沿用BERT的Transformer编码层,但在第7层自定义知识融合层BertLayerMix,首次对经过对齐的实体向量和指称项向量求和,并将其分别传输给知识编码模块和文本编码模块,在剩下5层自定义知识编码层BertLayer,分别对经过融合了两者信息的实体序列和文本序列使用自注意力机制编码。

模型的前5层就是论文所指的文本编码器,后面的7层编码层则构成了论文中的知识编码器。

对于BERT的Transformer编码层,由于第一部分已经介绍过,就不再赘述。下文主要针对作者自定义的编码层做详细解读。

2.2.1 知识融合层BertLayerMix

具体来说,知识融合层BertLayerMix由自注意力层BertAttention_simple、融合层BertIntermediate以及输出层BertOutput构成。

class BertLayerMix(nn.Module):
    def __init__(self, config):
        super(BertLayerMix, self).__init__()
        self.attention = BertAttention_simple(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
     # 该编码层仅针对文本进行自注意力操作、矩阵相乘和残差连接
    def forward(self, hidden_states, attention_mask, hidden_states_ent, attention_mask_ent, ent_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        attention_output_ent = hidden_states_ent * ent_mask
        # intermediate层负责实体和文本向量求和,并对求和向量非线性变化
        intermediate_output = self.intermediate(attention_output, attention_output_ent)
        # 然后通过输出层output再次归一化和残差连接
        layer_output, layer_output_ent = self.output(intermediate_output, attention_output, attention_output_ent)
        return layer_output, layer_output_ent

自注意力层BertAttention_simple由BertSelfAttention和BertSelfOutput构成,前者负责对文本进行自注意力操作,实现上与BERT的自注意力操作相同,就不再展示代码。后者则用于对向量进行矩阵变化和残差连接,生成attention_output

class BertAttention_simple(nn.Module):
    def __init__(self, config):
        super(BertAttention_simple, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)


    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)


    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

前馈神经网络层BertIntermediate负责将两者进行线性变化转换为同一维度,求和并做非线性变化。

class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dense_ent = nn.Linear(100, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
    def forward(self, hidden_states, hidden_states_ent):
        # 线性变化转换为同一维度
        hidden_states_ = self.dense(hidden_states)
        hidden_states_ent_ = self.dense_ent(hidden_states_ent)
        # 求和并使用intermediate_act_fn做非线性变化
        hidden_states = self.intermediate_act_fn(hidden_states_+hidden_states_ent_)
        return hidden_states

最终使用BertOutput分别对文本向量和实体向量做矩阵相乘,将经过融合的向量和两者残差连接,并做归一化操作。

class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dense_ent = nn.Linear(config.intermediate_size, 100)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.LayerNorm_ent = BertLayerNorm(100, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    def forward(self, hidden_states_, input_tensor, input_tensor_ent):
        # 针对文本向量矩阵相乘
        hidden_states = self.dense(hidden_states_)
        hidden_states = self.dropout(hidden_states)
        # 针对文本向量残差连接和归一化
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # 针对实体向量的矩阵相乘、残差连接和归一化
        hidden_states_ent = self.dense_ent(hidden_states_)
        hidden_states_ent = self.dropout(hidden_states_ent)
        hidden_states_ent = self.LayerNorm_ent(hidden_states_ent + input_tensor_ent)
        return hidden_states, hidden_states_ent

2.2.2 知识编码层BertLayer

该编码层针对融合后的实体向量和文本向量分别进行自注意力编码,从而使实体序列中的所有实体也能够实现相互关注。

再次基础上实体向量将和对应位置的文本向量求和,将实体信息传递给文本向量,从而使整个文本序列在下一个编码层中实现对实体序列的关注。

class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)


    def forward(self, hidden_states, attention_mask, hidden_states_ent, attention_mask_ent, ent_mask):
        attention_output, attention_output_ent = self.attention(hidden_states, attention_mask, hidden_states_ent, attention_mask_ent)
        attention_output_ent = attention_output_ent * ent_mask
        intermediate_output = self.intermediate(attention_output, attention_output_ent)
        layer_output, layer_output_ent = self.output(intermediate_output, attention_output, attention_output_ent)
        # layer_output_ent = layer_output_ent * ent_mask
        return layer_output, layer_output_ent

这一编码层自定义了自注意力层,其中针对实体的自注意力层仅使用4个注意力头。

class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        config_ent = copy.deepcopy(config)
        config_ent.hidden_size = 100
        config_ent.num_attention_heads = 4
        self.self_ent = BertSelfAttention(config_ent)
        self.output_ent = BertSelfOutput(config_ent)
    def forward(self, input_tensor, attention_mask, input_tensor_ent, attention_mask_ent):
        # BertSelfAttention对文本向量进行自注意力操作
        self_output = self.self(input_tensor, attention_mask)
        self_output_ent = self.self_ent(input_tensor_ent, attention_mask_ent)
        # BertSelfAttention对实体向量进行自注意力操作
        attention_output = self.output(self_output, input_tensor)
        attention_output_ent = self.output_ent(self_output_ent, input_tensor_ent)
        return attention_output, attention_output_ent

输出层同知识融合层一样,都是使用BERToutput实现归一化和残差连接。

03 源代码参考

[1] https://github.com/google-research/bert

[2] https://github.com/thunlp/ERNIE

 

私信我领取目标检测与R-CNN/数据分析的应用/电商数据分析/数据分析在医疗领域的应用/NLP学员项目展示/中文NLP的介绍与实际应用/NLP系列直播课/NLP前沿模型训练营等干货学习资源。

posted @ 2022-01-28 13:54  NLP论文解读  阅读(718)  评论(0编辑  收藏  举报