tf.strided_slice、tf.fill 与 tf.concat
tf.strided_slice,tf.fill,tf.concat使用实例
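其中，我们需要对张量 data 进行切片。下面的 process_decoder_input 先用 tf.strided_slice 去掉每个序列的最后一个字符，再用 tf.fill 生成一列 <GO>（本例中对应数值为 2），最后用 tf.concat 在列维度（axis=1）上把两者拼接起来，使每个序列都以 <GO> 开头。tf.strided_slice 的详细用法请参考 TensorFlow 官方 API 文档中 tf.strided_slice 的说明。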
import os

# Set the log level BEFORE importing TensorFlow: messages emitted while the
# library loads may otherwise not be filtered.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402  (must follow the environment tweak above)


# Token -> id table for the target vocabulary; '<GO>' maps to 2.
target_letter_to_int = {
    '<PAD>': 0, '<UNK>': 1, '<GO>': 2, '<EOS>': 3,
    'a': 4, 'b': 5, 'c': 6, 'd': 7, 'e': 8, 'f': 9, 'g': 10, 'h': 11,
    'i': 12, 'j': 13, 'k': 14, 'l': 15, 'm': 16, 'n': 17, 'o': 18,
    'p': 19, 'q': 20, 'r': 21, 's': 22, 't': 23, 'u': 24, 'v': 25,
    'w': 26, 'x': 27, 'y': 28, 'z': 29,
}


def process_decoder_input(data, vocab_to_int, batch_size=None):
    """Prepend the <GO> id to each sequence and drop its last token.

    Preprocesses target data into decoder input: a row
    ``[t0, t1, ..., tn]`` becomes ``[<GO>, t0, ..., t(n-1)]``.

    Args:
        data: 2-D integer tensor of token ids, one sequence per row.
        vocab_to_int: Mapping from token to id; must contain ``'<GO>'``
            (a missing key raises ``KeyError``).
        batch_size: Number of leading rows of ``data`` to process. Rows
            past ``batch_size`` are dropped by the slice. ``None`` (the
            default) uses every row, taking the count from
            ``tf.shape(data)[0]``.

    Returns:
        A tuple ``(ending, fill, decoder_input)``:
        ``ending`` is ``data`` without its last column,
        ``fill`` is a ``(batch_size, 1)`` column of the <GO> id, and
        ``decoder_input`` is ``fill`` concatenated before ``ending``.
    """
    if batch_size is None:
        # Dynamic row count, so the function also works when it is unknown
        # until the graph is run.
        batch_size = tf.shape(data)[0]
    # Rows [0, batch_size), columns [0, -1): cut off each sequence's last token.
    ending = tf.strided_slice(data, [0, 0], [batch_size, -1], [1, 1])
    # One <GO> id per row.
    fill = tf.fill([batch_size, 1], vocab_to_int['<GO>'])
    # Join along the column axis so every sequence starts with <GO>.
    decoder_input = tf.concat([fill, ending], 1)
    return ending, fill, decoder_input


def main():
    """Run process_decoder_input on a small batch and print the results."""
    # Each row is one target sequence ending with 3, the '<EOS>' id.
    data = tf.constant([
        [4, 5, 20, 20, 22, 3],
        [17, 19, 28, 8, 7, 3],
        [5, 13, 15, 24, 26, 3],
        [5, 20, 25, 4, 5, 3],
        [4, 12, 14, 15, 5, 3],
        [4, 7, 7, 16, 23, 3],
        [7, 8, 10, 13, 19, 3],
    ])
    # NOTE: `data` has 7 rows but only the first `batch_size` (6) are sliced,
    # so the last row is dropped. Pass batch_size=None to keep every row.
    batch_size = 6
    ending, fill, decoder_input = process_decoder_input(
        data, target_letter_to_int, batch_size)

    # NOTE: tf.Session is the TensorFlow 1.x graph API; under TF 2.x use
    # tf.compat.v1.Session with eager execution disabled.
    # No variables exist in this graph, so no initializer is needed.
    with tf.Session() as sess:
        ending_val, fill_val, decoder_input_val = sess.run(
            [ending, fill, decoder_input])
    print('ending:\n', ending_val)
    print('fill:\n', fill_val)
    print('decoder_input:\n', decoder_input_val)


if __name__ == '__main__':
    main()
结果如下:
'''
ending:
 [[ 4  5 20 20 22]
 [17 19 28  8  7]
 [ 5 13 15 24 26]
 [ 5 20 25  4  5]
 [ 4 12 14 15  5]
 [ 4  7  7 16 23]]
fill:
 [[2]
 [2]
 [2]
 [2]
 [2]
 [2]]
decoder_input:
 [[ 2  4  5 20 20 22]
 [ 2 17 19 28  8  7]
 [ 2  5 13 15 24 26]
 [ 2  5 20 25  4  5]
 [ 2  4 12 14 15  5]
 [ 2  4  7  7 16 23]]
'''