TensorFlow — Beam Search
GitHub: https://github.com/zle1992/Seq2Seq-Chatbot
1、注意在 infer 阶段需要 reuse(即在与训练阶段相同的 variable_scope 上设置 reuse=True,以复用训练阶段创建的变量)。
2、If you are using the `BeamSearchDecoder` with a cell wrapped in `AttentionWrapper`, then you must ensure that:
- The encoder output has been tiled to `beam_width` via `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
- The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `zero_state` above contains a `cell_state` value containing properly tiled final state from the encoder.
"""Seq2seq with Luong attention: teacher-forced training plus beam-search inference.

Training and inference are built in the same graph and share weights.
The shared weights are created once under the variable scopes
'shared_attention_mechanism' and 'decode_with_shared_attention'.
The inference branch re-enters those scopes with reuse=True.
"""
import tensorflow as tf
from tensorflow.python.layers.core import Dense


BEAM_WIDTH = 5
BATCH_SIZE = 128
NUM_UNITS = 128       # LSTM size, embedding size and attention depth
X_VOCAB_SIZE = 10000  # source vocabulary size used by embed_sequence
START_TOKEN = 1       # start-token id passed to BeamSearchDecoder
END_TOKEN = 2         # end-token id passed to BeamSearchDecoder


# INPUTS
X = tf.placeholder(tf.int32, [BATCH_SIZE, None])
# NOTE(review): Y is fed directly as the teacher-forcing decoder input,
# so it presumably already starts with START_TOKEN -- confirm in data prep.
Y = tf.placeholder(tf.int32, [BATCH_SIZE, None])
X_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])
Y_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])


# ENCODER
encoder_out, encoder_state = tf.nn.dynamic_rnn(
    cell=tf.nn.rnn_cell.BasicLSTMCell(NUM_UNITS),
    inputs=tf.contrib.layers.embed_sequence(X, X_VOCAB_SIZE, NUM_UNITS),
    sequence_length=X_seq_len,
    dtype=tf.float32)


# DECODER COMPONENTS (shared by training and inference)
Y_vocab_size = 10000
decoder_embedding = tf.Variable(
    tf.random_uniform([Y_vocab_size, NUM_UNITS], -1.0, 1.0))
projection_layer = Dense(Y_vocab_size)


# ATTENTION (TRAINING)
with tf.variable_scope('shared_attention_mechanism'):
    # Luong (multiplicative) attention scores the decoder output against the
    # projected memory, so num_units must equal the decoder cell's output size.
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units=NUM_UNITS,
        memory=encoder_out,
        memory_sequence_length=X_seq_len)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell=tf.nn.rnn_cell.BasicLSTMCell(NUM_UNITS),
    attention_mechanism=attention_mechanism,
    attention_layer_size=NUM_UNITS)


# DECODER (TRAINING)
training_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=tf.nn.embedding_lookup(decoder_embedding, Y),
    sequence_length=Y_seq_len,
    time_major=False)
training_decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=decoder_cell,
    helper=training_helper,
    # Start the decoder from the encoder's final state.
    initial_state=decoder_cell.zero_state(BATCH_SIZE, tf.float32).clone(
        cell_state=encoder_state),
    output_layer=projection_layer)
with tf.variable_scope('decode_with_shared_attention'):
    training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=training_decoder,
        impute_finished=True,
        maximum_iterations=tf.reduce_max(Y_seq_len))
training_logits = training_decoder_output.rnn_output


# BEAM SEARCH TILE
# Every encoder-side tensor the beam-search decoder sees must be repeated
# BEAM_WIDTH times per example with tile_batch (NOT tf.tile).
# Distinct names are used on purpose. Rebinding X_seq_len here would shadow
# the placeholder, so feed_dict={X_seq_len: ...} would feed the tiled tensor
# and leave the real placeholder unfed.
tiled_encoder_out = tf.contrib.seq2seq.tile_batch(
    encoder_out, multiplier=BEAM_WIDTH)
tiled_X_seq_len = tf.contrib.seq2seq.tile_batch(
    X_seq_len, multiplier=BEAM_WIDTH)
tiled_encoder_state = tf.contrib.seq2seq.tile_batch(
    encoder_state, multiplier=BEAM_WIDTH)


# ATTENTION (PREDICTING)
# Same scope name with reuse=True, so the attention weights created for
# training are reused instead of being created a second time.
with tf.variable_scope('shared_attention_mechanism', reuse=True):
    predicting_attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units=NUM_UNITS,
        memory=tiled_encoder_out,
        memory_sequence_length=tiled_X_seq_len)

predicting_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell=tf.nn.rnn_cell.BasicLSTMCell(NUM_UNITS),
    attention_mechanism=predicting_attention_mechanism,
    attention_layer_size=NUM_UNITS)


# DECODER (PREDICTING)
predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=predicting_decoder_cell,
    embedding=decoder_embedding,
    # One start token per *original* example; the decoder expands to beams.
    start_tokens=tf.fill([BATCH_SIZE], START_TOKEN),
    end_token=END_TOKEN,
    # zero_state must be sized true_batch_size * beam_width, and its
    # cell_state must be the tiled encoder state.
    initial_state=predicting_decoder_cell.zero_state(
        BATCH_SIZE * BEAM_WIDTH, tf.float32).clone(
            cell_state=tiled_encoder_state),
    beam_width=BEAM_WIDTH,
    output_layer=projection_layer,
    length_penalty_weight=0.0)
# reuse=True: the LSTM and attention-layer weights created during training
# decoding are looked up here instead of being re-created.
with tf.variable_scope('decode_with_shared_attention', reuse=True):
    predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=predicting_decoder,
        # Beam search reorders beams every step, so finished-state
        # imputation is not used here.
        impute_finished=False,
        # Bound decoding by the *source* length. Target lengths
        # (Y_seq_len) are not available at inference time.
        maximum_iterations=2 * tf.reduce_max(X_seq_len))
# predicted_ids is [batch_size, time, beam_width]; keep beam 0 (top-scoring).
predicting_ids = predicting_decoder_output.predicted_ids[:, :, 0]
predicting_logits = predicting_ids  # old name kept for compatibility (holds ids)

print('successful')
参考:
https://gist.github.com/higepon/eb81ba0f6663a57ff1908442ce753084
https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/BeamSearchDecoder
https://github.com/tensorflow/nmt#beam-search