QA系统Match-LSTM代码研读
QA系统Match-LSTM代码研读
背景
在QA模型中,Match-LSTM是较早提出的,使用Prt-Net边界模型。本文是对阅读其实现代码的总结。主要思路是对照着论文和代码,对论文中模型的关键结构,查看代码中的具体实现。参考代码是MurtyShikhar实现的。
模型简介
模型的输入是(Passage, Question),模型的输出是(start_idx, end_idx)。对于输入,Passage是QA任务中的正文,输入给模型时已经转化为经过Padding的id-list;Question是QA任务中的问题,输入给模型时已经转化为经过Padding的id-list。对于输出,start_idx是答案在正文的起始位置,end_idx是答案在正文的结束位置。
用于QA的Match-LSTM模型主要由三层构成:
- LSTM预处理层。
分别将Passage和Question通过LSTM进行处理,使每个位置的表示都带有一些上下文信息。- Match-LSTM层。
Match-LSTM最早用于文本蕴含,输入一个前提,一个猜测,判断前提是否能蕴含猜测。在用于QA任务时,Question被当做前提,Passage被当做猜测。依次处理Passage的每个位置,计算Passage每个位置对Question的Attention,进而求出对Question的Attend Vector。该Attend Vector与第一层的输出拼接起来,输入给一个LSTM进行处理,这整个流程被称作Match-LSTM。
其中Attention选择BahdanauAttention,Attention的输入(Query)由上一时刻Match-LSTM的输出及Passage在当前位置的表示拼接,Attention的key是Question每个位置的表示,Attention的value也是Question每个位置的表示。根据Attention的alignment对Attention Value加权求和计算出Attend Vector。
所以,Match-LSTM本质上由一个LSTM单元和一个Attention单元组成。LSTM单元的输出作为Match-LSTM层的输出,LSTM单元的状态和下一个位置的输入拼接起来作为Attention单元的输入(Query),Attention单元的输出(Attend Vector)与当前位置的输入拼接起来作为LSTM单元的输入。也可以理解为在LSTM的基础上增加Attention,改变LSTM的输入,在LSTM的原始输入上增加当前位置对于Question的Attention。- Pointer-Net层。
Pointer-Net层在代码实现上,与Match-LSTM十分接近。只在涉及输入、输出的地方有几处不同。从原理上看,Pointer-Net层也是一个序列化迭代的Attention过程,首先用zero_state作为query对Match-LSTM层的所有输出计算attention,作为回答第一个符号的logit。然后以AttentionWrapper的输出作为下一时刻的query,对Match-LSTM层的所有输出计算attention,如此迭代进行。对于边界模型,秩序计算start_index和end_index,这个迭代过程秩序进行两次。
接下来的几部分对照论文及代码中模型关键结构实现。
模型
模型图构建的入口在qa_model.py文件中class QASystem
类的def setup_system(self)
方法内。这一节主要就是对该方法的细节展开解读。
LSTM预处理层
所有逻辑都包含在qa_model.py文件中,入口位于class QASystem
类的def setup_system(self)
方法内,具体逻辑位于class Encoder
的def encode(self, inputs, masks, encoder_state_input = None)
方法内。
在def setup_system(self)
方法内,通过以下语句调用class Encoder
的def encode(self, inputs, masks, encoder_state_input = None)
方法。
encoder = self.encoder
decoder = self.decoder
encoded_question, encoded_passage, q_rep, p_rep = encoder.encode([self.question, self.passage],
[self.question_lengths, self.passage_lengths], encoder_state_input = None)
再看一下encode
方法的实现。
def encode(self, inputs, masks, encoder_state_input = None):
"""
:param inputs: vector representations of question and passage (a tuple)
:param masks: masking sequences for both question and passage (a tuple)
:param encoder_state_input: (Optional) pass this as initial hidden state to tf.nn.dynamic_rnn to build conditional representations
:return: an encoded representation of the question and passage.
"""
question, passage = inputs
masks_question, masks_passage = masks
# read passage conditioned upon the question
with tf.variable_scope("encoded_question"):
lstm_cell_question = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
encoded_question, (q_rep, _) = tf.nn.dynamic_rnn(lstm_cell_question, question, masks_question, dtype=tf.float32) # (-1,
Q, H)
with tf.variable_scope("encoded_passage"):
lstm_cell_passage = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
encoded_passage, (p_rep, _) = tf.nn.dynamic_rnn(lstm_cell_passage, passage, masks_passage, dtype=tf.float32) # (-1, P,
H)
# outputs beyond sequence lengths are masked with 0s
return encoded_question, encoded_passage , q_rep, p_rep
从代码可以看出,对Passage和Question的预处理就是分别经过两个单向LSTM层(不共享参数),LSTM每个位置的输出作为预处理后的表示。
Match-LSTM层
Match-LSTM的逻辑主要在qa_model.py和attention_wrapper.py两个文件中。虽然tensorflow的contrib库中现在也有attention_wrapper这个模块,但是两者在具体实现上不太相同。入口位于qa_model.py文件class Decoder
类中decode
方法内。
首先,看一下最外层的入口,与LSTM预处理层一样,位于class QASystem
类的def setup_system(self)
方法内。
if self.config.use_match:
self.logger.info("\n========Using Match LSTM=========\n")
logits= decoder.decode([encoded_question, encoded_passage], q_rep, [self.question_lengths, self.passage_lengths], self.
labels)
接下来,进入class Decoder
类中decode
方法。函数逻辑非常清晰,先通过Match-LSTM层,再通过Ptr-Net层。
def decode(self, encoded_rep, q_rep, masks, labels):
output_attender = self.run_match_lstm(encoded_rep, masks)
logits = self.run_answer_ptr(output_attender, masks, labels)
return logits
然后进入run_match_lstm
方法。
def run_match_lstm(self, encoded_rep, masks):
encoded_question, encoded_passage = encoded_rep
masks_question, masks_passage = masks
match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
query_depth = encoded_question.get_shape()[-1]
# output attention is false because we want to output the cell output and not the attention values
with tf.variable_scope("match_lstm_attender"):
attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)
cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
lstm_attender = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)
# we don't mask the passage because masking the memories will be handled by the pointerNet
reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0)
output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn")
output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn")
output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)
output_attender = tf.concat([output_attender_fw, output_attender_bw], axis = -1) # (-1, P, 2*H)
return output_attender
该方法的输入
encoded_rep
是一个tuple
,包含Passage和Question的表示;masks
也是一个tuple
,包含Passage和Question的长度。match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
这条语句定义了
Match-LSTM
单元中AttentionMechanism
的输入函数,作为参数该函数被传递给AttentionWrapper
的构造函数,作为attention_input_fn
。AttentionWrapper
本身也是一个RNN
,它组合了一个RNN
和一个AttentionMechanism
,形成一个高级的RNN
单元。该函数就是定义了用于Attention机制的Query是如何生成的,由当前时刻的输入拼接上一个时刻的state,形成Attention的Query。attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)
这条语句定义了一个
AttentionMechanism
,也就是一个Attention单元,该类包含一个__call__
方法,调用该对象可以计算出alignments
,调用该类对象的参数如方法定义所示def __call__(self, query, previous_alignments)
。联系上面一起来看,这里的query
就是上面所说的Attention的Query。
至于BahdanauAttention
是如何实现的,暂时不做过详细的介绍,目前该类位于tf.contrib.seq2seq.BahdanauAttention
,已经是tensorflow库的一部分。cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
这条语句定义一个普通的LSTM单元。
lstm_attender = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)
这条语句将上面两步定义的
AttentionMechanism
及LSTM单元组装为一个高级RNN单元。参数还包括了在run_match_lstm
方法一开头顶一个的一个函数,该函数用来生成AttentionMechanism
的query
。reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0) output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn") output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn") output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)
分别正向、反向对Passage的表示应用Match-LSTM,再将输出沿最后一个维度拼接起来作为Match-LSTM层的输出。
我们还可以再近距离看一下LSTM单元和AttentionMechanism
是如何配合工作的,这需要深入到AttentionWrapper
的call
方法,这也是所有RNN单元都需要实现的一个方法。
def call(self, inputs, state):
output_prev_step = state.cell_state.h # get hr_(i-1)
attention_input = self._attention_input_fn(inputs, output_prev_step) # get input to BahdanauAttention to get alpha_i
alignments, raw_scores = self._attention_mechanism(
attention_input, previous_alignments=state.alignments)
expanded_alignments = array_ops.expand_dims(alignments, 1)
attention_mechanism_values = self._attention_mechanism.values
context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
context = array_ops.squeeze(context, [1])
cell_inputs = self._cell_input_fn(inputs, context) #concatenate input with alpha*memory and feed into root LSTM
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
if self._attention_layer is not None:
attention = self._attention_layer(
array_ops.concat([cell_output, context], 1))
else:
attention = context
if self._alignment_history:
alignment_history = state.alignment_history.write(
state.time, alignments)
else:
alignment_history = ()
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
alignments=alignments,
alignment_history=alignment_history)
if self._output_attention:
return raw_scores, next_state
else:
return cell_output, next_state
output_prev_step = state.cell_state.h # get hr_(i-1) attention_input = self._attention_input_fn(inputs, output_prev_step)
取LSTM单元上一时刻的状态,与
AttentionWrapper
当前时刻的输入,通过self._attention_input_fn
函数生成attention的Query。这里的self._attention_input_fn
就是上面AttentionWrapper
构造函数的参数attention_input_fn
。alignments, raw_scores = self._attention_mechanism(attention_input, previous_alignments=state.alignments)
调用
AttentionMechaism
对象,计算Attention的alignments。这里的self._attention_mechanism
就是AttentionWrapper
构造函数的参数attention_mechanism_match_lstm
,也就是BahdanauAttention
的一个对象。expanded_alignments = array_ops.expand_dims(alignments, 1) # [batch_size, 1, ques_size] attention_mechanism_values = self._attention_mechanism.values # [batch_size, ques_size, value_dims] context = math_ops.matmul(expanded_alignments, attention_mechanism_values) # [batch_size, 1, value_dims] context = array_ops.squeeze(context, [1]) # [batch_size, value_dims]
通过alignments和attention的Values,计算attend vector,就是对values以alignments为权重求和。
cell_inputs = self._cell_input_fn(inputs, context) #concatenate input with alpha*memory and feed into root LSTM cell_state = state.cell_state cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
通过
_cell_input_fn
将当前时刻的输入,和attend vector组合起来,成为当前时刻LSTM的输入。然后调用LSTM单元计算当前时刻LSTM单元的输出和状态。if self._attention_layer is not None: attention = self._attention_layer( array_ops.concat([cell_output, context], 1)) else: attention = context
是否需要对attend vector再进行一次线性变换,作为attention,在本例中未做变换,直接用attend vector作为attention。
next_state = AttentionWrapperState( time=state.time + 1, cell_state=next_cell_state, attention=attention, alignments=alignments, alignment_history=alignment_history)
作为RNN的
AttentionWrapper
的下一时刻状态。if self._output_attention: return raw_scores, next_state else: return cell_output, next_state
根据构造函数的参数,决定
AttentionWrapper
的输出是attention score还是LSTM的输出,attention score的意义是求alignments概率之前的那个东西。
Pointer-Net层
以下代码是Pointer-Net层的逻辑,与Match-LSTM层的逻辑非常接近,但是在一些细节上有所区别。相似的部分是,Pointer-Net层的主体也是通过一个AttentionWrapper
完成的,也是组装了一个LSTM
单元和一个BahdanauAttention
单元。与Match-LSTM不同的地方是,LSTM
单元及BahdanauAttention
单元的输入函数不同,AttentionWrapper
的输出内容不同,并且Pointer-Net层使用一个静态rnn。
def run_answer_ptr(self, output_attender, masks, labels):
batch_size = tf.shape(output_attender)[0]
masks_question, masks_passage = masks
labels = tf.unstack(labels, axis=1)
#labels = tf.ones([batch_size, 2, 1])
answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question
query_depth_answer_ptr = output_attender.get_shape()[-1]
with tf.variable_scope("answer_ptr_attender"):
attention_mechanism_answer_ptr = BahdanauAttention(query_depth_answer_ptr , output_attender, memory_sequence_length = masks_passage)
# output attention is true because we want to output the attention values
cell_answer_ptr = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True )
answer_ptr_attender = AttentionWrapper(cell_answer_ptr, attention_mechanism_answer_ptr, cell_input_fn = answer_ptr_cell_input_fn)
logits, _ = tf.nn.static_rnn(answer_ptr_attender, labels, dtype = tf.float32)
return logits
接下来具体看一下上面这段代码。
batch_size = tf.shape(output_attender)[0] # [batch_size, passage_length, 2 * hidden_size] masks_question, masks_passage = masks labels = tf.unstack(labels, axis=1) # labels : [batch_size, 2]
output_attender
是上一层,也就是Match-LSTM层的输出,形状为[batch_size, passage_length, 2 * hidden_size]
。labels
的形状为[batch_size, 2]
。masks_question
和masks_passage
分别为问题的长度和文章的长度。answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question query_depth_answer_ptr = output_attender.get_shape()[-1]
answer_ptr_cell_input_fn
定义了AttentionWrapper
中LSTM
单元的输入函数。query_depth_answer_ptr
从变量名的字面含义看,是Answer-Ptr层的attention单元的query的维度。with tf.variable_scope("answer_ptr_attender"): attention_mechanism_answer_ptr = BahdanauAttention(query_depth_answer_ptr , output_attender, memory_sequence_length = masks_passage) # output attention is true because we want to output the attention values cell_answer_ptr = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True ) answer_ptr_attender = AttentionWrapper(cell_answer_ptr, attention_mechanism_answer_ptr, cell_input_fn = answer_ptr_cell_input_fn)
接下来是装配
AttentionWrapper
,这里与Match-LSTM层有区别。在Match-LSTM层的定义中,没有显式地为AttentionWrapper
指定cell_input_fn
参数,而是使用了默认函数。在Match-LSTM层的定义中,显式指定了attention_input_fn
,但是这里没有指定,使用了默认函数。另外一个区别,在Match-LSTM层的定义中,AttentionWrapper
的output_attention
参数是False
,在这里该参数用默认的True
。
对比Match-LSTM层与Pointer-Net层cell_input_fn
的区别。
默认的
cell_input_fn
的定义如下,这是Match-LSTM层采用的。逻辑是将attention的输出和当前的输入拼接起来,作为LSTM
单元的输入。if cell_input_fn is None: cell_input_fn = ( lambda inputs, attention: array_ops.concat([inputs, attention], -1))
Pointer-Net层使用的
cell_input_fn
在上面的代码中已经给出,这里对比一下。只用Attention单元的输出,作为LSTM
单元的输入。这样,LSTM
单元的输入,就与RNN的输入无关了。answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question
对比Match-LSTM层与Pointer-Net层attention_input_fn
的区别。
Match-LSTM层采用的
attention_input_fn
是非默认的,在上一节中已经给出,这里对比一下。match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
Pointer-Net层的
attention_input_fn
是默认的,定义如下。if attention_input_fn is None: attention_input_fn = ( lambda _, state: state)
可以看出,在Match-LSTM层,
attention
单元的输入是上一时刻状态与当前输入的拼接。在Pointer-Net层,attention
单元的输入仅仅是上一时刻的状态,与当前时刻的输入无关。
综上两处,可以看出区别。在Match-LSTM层,无论Attention
单元还是LSTM
单元,其输入都要拼接当前时刻输入。而在Pointer-Net层,无论Attention
单元还是LSTM
单元,其输入都与当前时刻的输入无关。这也解释了我最早看代码时的疑惑,为什么计算logits
的函数需要labels
作为参数,labels
不是只有在计算loss
的时候才需要吗?其实虽然这里有labels
这个参数,但是没有实际使用其内容,对于预测过程,只需传一个同样形状的tensor就可以。
再对比最后一个区别,Match-LSTM层与Pointer-Net层在output_attention参数上的区别。
if self._output_attention: return raw_scores, next_state else: return cell_output, next_state
raw_scores
是attention
单元的原始输出,即通过softmax
计算alignments
前的那个输出。cell_output
是LSTM
单元的输出,也就是状态h
。在Match-LSTM层,AttentionWrapper
输出的是其内部LSTM
单元的输出。在Pointer-Net层,AttentionWrapper
输出的是其内部attention
单元的raw_scores
。
logits, _ = tf.nn.static_rnn(answer_ptr_attender, labels, dtype = tf.float32)
最后是计算
logits
。因为labels
是个长度为2的list
,logits
也是长度为2的list
。但是,这两个list
中元素的shape
是不一样的,labels
中的元素的shape
是[batch_size, 1]
,logits
中的元素的shape
是[batch_size, passage_length]
。
从代码层面来理解,首先是以zero_state
为query去计算attention,attention
单元的key和value都是Match-LSTM层的输出,attention
计算的raw_score
就是第一个输出的logit
。attention
计算出的alignments
与values
计算attend vector
,以其为输入计算LSTM
单元的输出,作为下一时刻的query去计算attention。这样,就计算出了两个logits
。
至此,计算出logits
,预测部分就已经完成了。logits
是一个长度为2的list
,其中每个元素是一个shape
为[batch_size, passage_length]
的tensor。
损失函数
有了logits
,就可以计算损失函数了。
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits[0], labels=self.labels[:,0])
losses += tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits[1], labels=self.labels[:,1])
self.loss = tf.reduce_mean(losses)
这里只需要理解一个函数即可tf.nn.sparse_softmax_cross_entropy_with_logits
,该函数logits
参数的rank
比labels
多1,多出的那个axis
的维度是num_classes
。labels
以稀疏形式表示,每个元素都是整数,小于num_classes
。
由于之前已经知道,Pointer-Net层求出logits
是一个list
,每个元素的形状是[batch_size, passage_length]
,而输入的labels
的形状是[batch_size, 2]
。因此按照上面代码的方式调用可求出损失函数。