Bert源码解读(三)之预训练部分
一、Masked LM
get_masked_lm_output
函数用于计算「任务#1」的训练 loss。输入为 BertModel 的最后一层 sequence_output 输出([batch_size, seq_length, hidden_size]),先找出输出结果中masked掉的词,然后构建一层全连接网络,接着构建一层节点数为vocab_size的softmax输出,从而与真实label计算损失。
def get_masked_lm_output(bert_config, input_tensor, #BertModel的最后一层sequence_output输出model.get_sequence_output()[batch_size, seq_length, hidden_size] output_weights,#输入是model.get_embedding_table(),[vocab_size,hidden_size] positions, #mask词的位置 label_ids, #label,真实值结果 label_weights): """Get loss and log probs for the masked LM.""" # 根据positions位置获取masked词在Transformer的输出结果,即要预测的那些位置的encoder input_tensor = gather_indexes(input_tensor, positions)#[batch_size*max_pred_pre_seq,hidden_size] with tf.variable_scope("cls/predictions"): # 在输出之前添加一个带激活函数的全连接神经网络,只在预训练阶段起作用 with tf.variable_scope("transform"): input_tensor = tf.layers.dense( input_tensor, units=bert_config.hidden_size, activation=modeling.get_activation(bert_config.hidden_act), kernel_initializer=modeling.create_initializer( bert_config.initializer_range)) input_tensor = modeling.layer_norm(input_tensor) # output_weights是和传入的word embedding一样的,这里再添加一个bias output_bias = tf.get_variable( "output_bias", shape=[bert_config.vocab_size], initializer=tf.zeros_initializer()) logits = tf.matmul(input_tensor, output_weights, transpose_b=True) #[batch_size*max_pred_pre_seq,vocab_size] logits = tf.nn.bias_add(logits, output_bias) log_probs = tf.nn.log_softmax(logits, axis=-1)#得出masked词的softmax结果,[batch_size*max_pred_pre_seq,vocab_size] # label_ids表示mask掉的Token的id,下面这部分就是根据真实值计算loss了。 label_ids = tf.reshape(label_ids, [-1])#[batch_size*max_pred_per_seq] label_weights = tf.reshape(label_weights, [-1]) one_hot_labels = tf.one_hot( label_ids, depth=bert_config.vocab_size, dtype=tf.float32)#[batch_size*max_pred_per_seq,vocab_size] # 但是由于实际MASK的可能不到20,比如只MASK18,那么label_ids有2个0(padding),而label_weights=[1, 1, ...., 0, 0],说明后面两个label_id是padding的,计算loss要去掉,label_weights就是起一个标记作用 per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])#[batch_size*max_pred_per_seq] numerator = tf.reduce_sum(label_weights * per_example_loss) #一个batch的loss denominator = tf.reduce_sum(label_weights) + 1e-5 loss = numerator / denominator #平均loss return (loss, per_example_loss, log_probs)
重要补充:预训练中的随机MASK函数
核心思想:每个输入序列,只有最多15%的token被mask,而其中80%的机会被替换成[MASK],10%的机会保持原词不变,10%的机会随机替换为字典中的任意词。代码如何实现呢?先获取每个token的索引位置,然后随机打乱索引位置,接着取前15%的token进行替换即可。在替换中,再次利用随机函数,实现80%替换为[MASK]等,代码层面利用random函数还是比较巧妙的。
def create_masked_lm_predictions(tokens, #list存放的sequence,例如[CLS,今, 天, 举, 行, 的, 国, 家, 发, 展, 改, 革, 委, 新, 闻, 发, 布, 会, SEP] masked_lm_prob, #代码中是0.15 max_predictions_per_seq, #代码中20 vocab_words, rng): #rng=random.Random() cand_indexes = [] # [CLS]和[SEP]不能用于MASK for (i, token) in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue cand_indexes.append(i) #随机打乱索引顺序 rng.shuffle(cand_indexes) output_tokens = list(tokens) #masked token数量,从最大mask配置数和seq长度*mask比例中取一个最小数,作为这个seq最终的mask数量 num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))) masked_lms = [] #covered_indexes存放被mask token的索引位置 covered_indexes = set() for index in cand_indexes: #达到mask的数量,就停止 if len(masked_lms) >= num_to_predict: break if index in covered_indexes: continue covered_indexes.add(index) masked_token = None # 80% of the time, replace with [MASK],替换为[MASK] if rng.random() < 0.8: masked_token = "[MASK]" else: # 10% of the time, keep original,保持原词 if rng.random() < 0.5: masked_token = tokens[index] # 10% of the time, replace with random word,随机替换 else: masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] #将masked_token替换覆盖原token output_tokens[index] = masked_token #保存masked token的原索引位置,及真实的label token masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) # 按照下标重排,保证是原来句子中出现的顺序 masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lm_positions = [] masked_lm_labels = [] for p in masked_lms: masked_lm_positions.append(p.index) masked_lm_labels.append(p.label) #返回带mask的sequence tokens,被masked token的原索引位置,及原来的真实label token ,以便计算loss return (output_tokens, masked_lm_positions, masked_lm_labels)
举例实现随机替换的思想:
import random,collections MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) #返回的是一个Random对象,每次再调用rng.random()都返回一个0~1的随机数,这里与bert原代码保持一致,种子都是12345 rng=random.Random(12345)#这里rng一定要放在函数外面,这样相当于在外部完成初始化,每次调用函数才会随机生成不断变化的结果 def create_mask_sample(sequence="",mask_prob=0.15,vocab_words=[],rng=None): tokens=[] cand_indexes = [] for i,w in enumerate(sequence): cand_indexes.append(i) tokens.append(w) #随机打乱索引顺序 rng.shuffle(cand_indexes) #mask后输出tokens output_tokens = list(tokens) #一个输入序列中需要mask的数量 num_to_predict = int(len(tokens)*mask_prob) masked_lms = [] #covered_indexes存放被mask token的索引位置 covered_indexes = set() for index in cand_indexes: #达到mask的数量,就停止 if len(masked_lms) >= num_to_predict: break if index in covered_indexes: continue covered_indexes.add(index) masked_token = None # 80% of the time, replace with [MASK],替换为[MASK] if rng.random() < 0.8: #这里有80%的概率是满足<0.8 masked_token = "[MASK]" else: #如果是>=0.8情况呢,这里有20%的概率 # 剩下的概率一半保持原词,也就是10% of the time, keep original,保持原词 if rng.random() < 0.5: masked_token = tokens[index] # 10% of the time, replace with random word,随机替换 else: masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] #将masked_token替换覆盖原token output_tokens[index] = masked_token #保存masked token的原索引位置,及真实的label token masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) # 按照下标重排,保证是原来句子中出现的顺序 masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lm_positions = [] masked_lm_labels = [] for p in masked_lms: masked_lm_positions.append(p.index) masked_lm_labels.append(p.label) #返回带mask的sequence tokens,被masked token的原索引位置,及原来的真实label token ,以便计算loss return (output_tokens, masked_lm_positions, masked_lm_labels)
#举例子测试
seq='今天下午举行的市新冠肺炎疫情防控工作领导小组新闻发布会透露:近期,多个国家和地区出现新冠肺炎确诊病例,数量持续攀升。鉴于当前境外疫情防控形势,结合上海实际,市防控工作领导小组及相关部门综合研判,进一步明确了涉外疫情防控和入境人员健康管理措施。' v_words=['自', '今', '年', '3', '月', '1', '日', '起', ',', '重', '新', '调', '整', '存', '量', '房', '贷', '利', '率', ',', '存', '量', '浮', '动', '利', '率', '贷', '款', '客', '户', '可', '以', '有', '两', '个', '选', '择', ',', '原', '则', '上', '转', '换', '工', '作', '应', '于', '今', '年', '8', '月', '底', '前', '完', '成', '。', '目', '前', ',', '已', '有', '至', '少', '2', '4', '家', '主', '要', '银', '行', '发', '布', '了', '相', '关', '公', '告', ',', '多', '家', '银', '行', '称', '还', '将', '陆', '续', '发', '送', '一', '对', '一', '短', '信'] output_tokens,masked_lm_positions,masked_lm_labels=create_mask_sample(sequence=seq,mask_prob=0.1,vocab_words=v_words,rng=rng) print(len(output_tokens)) print(''.join(output_tokens)) print(masked_lm_positions) print(masked_lm_labels)
out:
121 今天下午[MASK]行的市新冠肺炎疫情防控工[MASK]领导小组新闻发布会透露:[MASK]期,多个国家和地区出现新冠[MASK]炎确诊病例[MASK]数量持[MASK]攀升[MASK]鉴于当[MASK]境外疫情防控形势,结合上海实际,市防控工[MASK]领导小组及相关部门[MASK]合研判,进一步明[MASK]了涉外疫情防控和入境人员[MASK]康管理措施。 [4, 17, 30, 44, 50, 54, 57, 61, 82, 92, 101, 114] ['举', '作', '近', '肺', ',', '续', '。', '前', '作', '综', '确', '健']
注意:同一段话,每调用一次都会随机生成不同的mask结果,达到随机mask目的。
二、 Next Sentence Prediction
get_next_sentence_output函数用于计算「任务#2」的训练 loss,这部分比较简单,只需要再额外加一层softmax输出即可。输入为 BertModel 的最后一层 pooled_output 输出([batch_size, hidden_size]),因为该任务属于二分类问题,所以只需要每个序列的第一个 token【CLS】即可。
def get_next_sentence_output(bert_config, input_tensor,#pooled_output 输出,shape=[batch_size, hidden_size] labels): """Get loss and log probs for the next sentence prediction.""" # 标签0表示 下一个句子关系成立;标签1表示 下一个句子关系不成立。这个分类器的参数在实际Fine-tuning阶段会丢弃掉 with tf.variable_scope("cls/seq_relationship"): #初始化权重参数,最终的分类结果是只有2个,所以shape=[2,hidden_size] output_weights = tf.get_variable( "output_weights", shape=[2, bert_config.hidden_size], initializer=modeling.create_initializer(bert_config.initializer_range)) output_bias = tf.get_variable( "output_bias", shape=[2], initializer=tf.zeros_initializer()) logits = tf.matmul(input_tensor, output_weights, transpose_b=True)#输入与权重相乘,shape=[batch_size,2] logits = tf.nn.bias_add(logits, output_bias) log_probs = tf.nn.log_softmax(logits, axis=-1)#softmax输出:shape=[batch_size,2] #下面这部分就是根据真实值计算损失loss了 labels = tf.reshape(labels, [-1]) one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) loss = tf.reduce_mean(per_example_loss) return (loss, per_example_loss, log_probs)