对 TensorFlow 中 attention encoder-decoder 模型的调试分析

  1 #-*-coding:utf8-*-
  2 
  3 __author = "buyizhiyou"
  4 __date = "2017-11-21"
  5 
  6 
  7 import random, time, os, decoder
  8 from PIL import Image
  9 import numpy as np
 10 import tensorflow as tf
 11 import pdb
 12 import decoder
 13 import random
 14 
 15 '''
 16 在汉字ocr项目中,利用基于attention的encoder-decoder(seq2seq)模型进行端对端的训练
 17 单步调试,追踪tensorflow 对 attention-seq2seq模型的实现方式
 18 python 中seq2seq.py的接口:tf.nn.seq2seq.embedding_attention_seq2seq()
 19 把用到的部分取出来单独调试
 20 '''
 21 
 22 batch_size = 16
 23 dec_seq_len = 8#图片对应的汉字数8
 24 enc_lstm_dim = 256
 25 dec_lstm_dim = 512
 26 vocab_size = 1002
 27 embedding_size = 100
 28 lr = 0.01
 29 global_step = tf.Variable(0)
 30 
 31 cnn = tf.truncated_normal([16,10,35,64],mean=0,stddev=1.0,dtype=tf.float32)#模拟初始化一个cnn提取特征后的图片
 32 #(batch_size,height,width,channels)(16, 10, 35, 64)
 33 true_labels = []
 34 #随即生成batch中图片对应的序列,无需embedding
 35 for i in range(batch_size):
 36     seq_label = []
 37     for j in range(dec_seq_len):
 38         seq_label.append(random.randint(0,1000))
 39     true_labels.append(seq_label)
 40 
 41 
 42 #编码
 43 def encoder(inp):#inp:shape=(16, 35, 64)
 44     #pdb.set_trace()
 45     enc_init_shape = [batch_size, enc_lstm_dim]#[16,256]
 46     with tf.variable_scope('encoder_rnn'):
 47         with tf.variable_scope('forward'):
 48             lstm_cell_fw = tf.nn.rnn_cell.LSTMCell(enc_lstm_dim)
 49             init_fw = tf.nn.rnn_cell.LSTMStateTuple(\
 50                                 tf.get_variable("enc_fw_c", enc_init_shape),\
 51                                 tf.get_variable("enc_fw_h", enc_init_shape)
 52                                 )
 53         with tf.variable_scope('backward'):
 54             lstm_cell_bw = tf.nn.rnn_cell.LSTMCell(enc_lstm_dim)
 55             init_bw = tf.nn.rnn_cell.LSTMStateTuple(\
 56                                 tf.get_variable("enc_bw_c", enc_init_shape),\
 57                                 tf.get_variable("enc_bw_h", enc_init_shape)
 58                                 )
 59         output, _ = tf.nn.bidirectional_dynamic_rnn(lstm_cell_fw, \
 60                                                     lstm_cell_bw, \
 61                                                     inp, \
 62                                                     sequence_length = tf.fill([batch_size],\
 63                                                     tf.shape(inp)[1]), #(35,35,35...,35,35,35)
 64                                                     initial_state_fw = init_fw, \
 65                                                     initial_state_bw = init_bw \
 66                                                     )#shape=(16, 35, 256)
 67     return tf.concat(2,output)##shape=(16, 35, 512)
 68 
 69 encoder = tf.make_template('fun', encoder)
 70 # shape is (batch size, rows, columns, features)
 71 # swap axes so rows are first. map splits tensor on first axis, so encoder will be applied to tensors
 72 # of shape (batch_size,time_steps,feat_size)
 73 rows_first = tf.transpose(cnn,[1,0,2,3])#shape=(10, 16, 35, 64)
 74 res = tf.map_fn(encoder, rows_first, dtype=tf.float32)#shape=(10, 16, 35, 512)
 75 encoder_output = tf.transpose(res,[1,0,2,3])#shape=(16, 10, 35, 512)
 76 
 77 dec_lstm_cell = tf.nn.rnn_cell.LSTMCell(dec_lstm_dim)
 78 dec_init_shape = [batch_size, dec_lstm_dim]
 79 dec_init_state = tf.nn.rnn_cell.LSTMStateTuple( tf.truncated_normal(dec_init_shape),\
 80                                                 tf.truncated_normal(dec_init_shape) )
 81 
 82 init_words = np.zeros([batch_size,1,vocab_size])#(16, 1, 1002)
 83 
 84 
 85 #pdb.set_trace()
 86 (output,state) = decoder.embedding_attention_decoder(dec_init_state,#[16, 512]第一个解码cell的state=[c,h]
 87                                                     tf.reshape(encoder_output,[batch_size, -1,2*enc_lstm_dim]),
 88                                                     #encoder输出reshape为 attention states作为attention模块的输入 shape=(16,350,512)
 89                                                     dec_lstm_cell,#lstm单元,作为解码层
 90                                                     vocab_size,#1002
 91                                                     dec_seq_len,#8
 92                                                     batch_size,#16
 93                                                     embedding_size,#100
 94                                                     feed_previous=True)#dec_seq_len = num_words = time_steps
 95 pdb.set_trace()
 96 cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(output,true_labels))
 97 learning_rate = tf.train.exponential_decay(lr, global_step, 50, 0.9)
 98 train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy,global_step=global_step)
 99 correct_prediction = tf.equal(tf.to_int32(tf.argmax( output, 2)), true_labels)
100                                                

decoder.py

  1 #-*-coding:utf8-*-
  2 
  3 
  4 """
  5 截取自tensorflow seq2seq.py 文件
  6 """
  7 import numpy as np
  8 import tensorflow as tf
  9 import pdb
 10 from tensorflow.python import shape
 11 from tensorflow.python.framework import dtypes
 12 from tensorflow.python.framework import ops
 13 from tensorflow.python.ops import array_ops
 14 from tensorflow.python.ops import control_flow_ops
 15 from tensorflow.python.ops import embedding_ops
 16 from tensorflow.python.ops import math_ops
 17 from tensorflow.python.ops import nn_ops
 18 from tensorflow.python.ops import rnn
 19 from tensorflow.python.ops import rnn_cell
 20 from tensorflow.python.ops import variable_scope
 21 from tensorflow.python.util import nest
 22 
 23 linear = rnn_cell._linear    # pylint: disable=protected-access
 24 
 25 def attention_decoder(initial_state,#(16, 512)
 26                       attention_states,#shape=(16, 350, 512)
 27                       cell,
 28                       vocab_size,#1002
 29                       time_steps,#num_words,8
 30                       batch_size,#16
 31                       output_size=None,#512
 32                       loop_function=None,
 33                       dtype=None,
 34                       scope=None):
 35     pdb.set_trace()
 36     if attention_states.get_shape()[2].value is None:#tf 张量 get_shape()方法获取size
 37         raise ValueError("Shape[2] of attention_states must be known: %s"
 38                                          % attention_states.get_shape())
 39     if output_size is None:
 40         output_size = cell.output_size#512
 41 
 42     with variable_scope.variable_scope(scope or "attention_decoder", dtype=dtype) as scope:
 43         dtype = scope.dtype
 44 
 45         attn_length = attention_states.get_shape()[1].value #350
 46         if attn_length is None:
 47             attn_length = shape(attention_states)[1]
 48         attn_size = attention_states.get_shape()[2].value#512
 49 
 50         # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
 51         hidden = array_ops.reshape(attention_states, [-1, attn_length, 1, attn_size])#shape=(16, 350, 1, 512) 
 52         attention_vec_size = attn_size    # Size of query vectors for attention.   512
 53         k = variable_scope.get_variable("AttnW",[1, 1, attn_size, attention_vec_size])#shape=(1,1,512,512)
 54         hidden_features = nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")#(16 ,350, 1, 512) w_1*h_j
 55         v = variable_scope.get_variable("AttnV", [attention_vec_size])
 56 
 57 
 58         def attention(query):
 59             #LSTMStateTuple(c= shape=(16, 512) dtype=float32>, h=< shape=16, 512) dtype=float32>)
 60             """Put attention masks on hidden using hidden_features and query."""
 61             if nest.is_sequence(query):    # If the query is a tuple, flatten it.
 62                 query_list = nest.flatten(query) #[c,h],第一个随即初始化,以后调用之前计算的
 63                 for q in query_list:    # Check that ndims == 2 if specified.
 64                     ndims = q.get_shape().ndims
 65                     if ndims:
 66                         assert ndims == 2
 67                 query = array_ops.concat(1, query_list)# shape=(16, 1024)
 68             with variable_scope.variable_scope("Attention_0"):
 69                 y = linear(query, attention_vec_size, True)# shape=(16, 512) w_2*s_t
 70                 y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) # shape=(16, 1, 1, 512)
 71                 s = math_ops.reduce_sum(
 72                         v * math_ops.tanh(hidden_features + y), [2, 3])  #!!!!!!!!!!!公式(3)shape=(16, 350)
 73                 a = nn_ops.softmax(s)#  公式(2)shape=(16, 350)
 74                 # Now calculate the attention-weighted vector d.
 75                 d = math_ops.reduce_sum(
 76                         array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden,#公式(1)
 77                         [1, 2])#shape=(16, 512) 
 78                 ds = array_ops.reshape(d, [-1, attn_size])#shape=(16, 512) #!!!!!!!!!!!!以上是attention model中三个关键公式的实现
 79             return ds
 80         #pdb.set_trace()
 81         prev = array_ops.zeros([batch_size,output_size])# shape=(16, 512) cell层第一个cell启动计算所需输入,
 82                                                         #随机初始化,以后的cell调用之前的计算结果
 83         batch_attn_size = array_ops.pack([batch_size, attn_size]) #(2,?)
 84         attn = array_ops.zeros(batch_attn_size, dtype=dtype)#shape=(16, 512)
 85         attn.set_shape([None, attn_size])#(16,512)
 86 
 87         def cond(time_step, prev_o_t, prev_softmax_input, state_c, state_h, outputs2):
 88             return time_step < time_steps
 89 
 90         def body(time_step, prev_o_t, prev_softmax_input, state_c, state_h, outputs2):#prev_o_t=prev:shape=(16,512) 
 91                                                 #outputs:shape=(16, ?, 1002) prev_softmax_input=init_word:shape=(16, 1002)
 92             state = tf.nn.rnn_cell.LSTMStateTuple(state_c,state_h)#第一次随机初始状态,之后调用之前的
 93             pdb.set_trace()
 94             with variable_scope.variable_scope("loop_function", reuse=True):
 95                 inp = loop_function(prev_softmax_input, time_step)#shape=(16,100) inp用来做什么 作为每个cell单元从下而
 96                 #来的输入??而prev_o_t则为从左而来的输入??而且Inp和上一个cell单元的softmax_input(最终进softmax之前的cell输出)有关(prev_softmax_input)
 97 
 98             input_size = inp.get_shape().with_rank(2)[1]#100
 99             if input_size.value is None:
100                 raise ValueError("Could not infer input size from input: %s" % inp.name)
101             x = tf.concat(1,[inp,prev_o_t])#shape=(16, 612)  这个地方inp ,prev_o_t = loop_function(softmax_output),output
102             # Run the RNN.
103             cell_output, state = cell(x, state)#decoder层512个lstm单元 cell_output:shape=(16, 512) state:shape=(16, 512)
104             # Run the attention mechanism.
105             attn = attention(state)#shape=(16, 512) attenion模块的输出,C_i
106 
107             with variable_scope.variable_scope("AttnOutputProjection"):
108                 output = math_ops.tanh(linear([cell_output, attn], output_size, False))#shape=(16, 512) y_i = f(C_i,S_i)
109                 with variable_scope.variable_scope("FinalSoftmax"):
110                     softmax_input = linear(output,vocab_size,False)#shape=(16, 1002) #decoder层后加一层softmax??作为softmax_input
111 
112             new_outputs = tf.concat(1, [outputs2,tf.expand_dims(softmax_input,1)])#shape=(16, ?, 1002)[,...y_t-1,y_t,...]
113             return (time_step + tf.constant(1, dtype=tf.int32),\
114                             output, softmax_input, state.c, state.h, new_outputs)#既是输出,又是下一轮的输入
115 
116         time_step = tf.constant(0, dtype=tf.int32)
117         shape_invariants = [time_step.get_shape(),\
118                             prev.get_shape(),\
119                             tf.TensorShape([batch_size, vocab_size]),\
120                             tf.TensorShape([batch_size,512]),\
121                             tf.TensorShape([batch_size,512]),\
122                             tf.TensorShape([batch_size, None, vocab_size])]
123 
124 # START keyword is 0
125         init_word = np.zeros([batch_size, vocab_size])#shape=(16,1002)
126 
127         loop_vars = [time_step,\
128                      prev,\
129                      tf.constant(init_word, dtype=tf.float32),\
130                      initial_state.c,initial_state.h,\
131                      tf.zeros([batch_size,1,vocab_size])] 
136 
137         outputs = tf.while_loop(cond, body, loop_vars, shape_invariants)##shape=(16, ?, 1002)
138         '''
139         loop_vars = [...]
140         while cond(*loop_vars):
141             loop_vars = body(*loop_vars)   
142         '''
143 
144     return outputs[-1][:,1:], tf.nn.rnn_cell.LSTMStateTuple(outputs[-3],outputs[-2])
145 
def embedding_attention_decoder(initial_state,
                                attention_states,
                                cell,
                                num_symbols,
                                time_steps,
                                batch_size,
                                embedding_size,
                                output_size=None,
                                output_projection=None,
                                feed_previous=False,
                                update_embedding_for_previous=True,
                                dtype=None,
                                scope=None):
    """Create a symbol embedding and run attention_decoder on top of it.

    Args:
        initial_state: first decoder state; (16, 512) per part in the calling script.
        attention_states: encoder outputs, e.g. (16, 350, 512).
        cell: the decoder RNN cell.
        num_symbols: vocabulary size, e.g. 1002.
        time_steps: number of decoding steps.
        batch_size: static batch size.
        embedding_size: width of each symbol embedding, e.g. 100.
        output_size: attention output width; defaults to ``cell.output_size``.
        output_projection: optional (weights, biases) pair; only the bias shape
            is checked here, and the pair is passed to the loop function.
        feed_previous: when True, build a loop function that takes the argmax
            of the previous logits and embeds it. When False no loop function
            is built, and attention_decoder has no other source of step inputs,
            so callers effectively must pass True.
        update_embedding_for_previous: forwarded to the loop-function builder.
        dtype: dtype for the variable scope.
        scope: variable scope name; defaults to "embedding_attention_decoder".

    Returns:
        Whatever attention_decoder returns: (logits, final_state).
    """
    output_size = cell.output_size if output_size is None else output_size
    if output_projection is not None:
        biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
        biases.get_shape().assert_is_compatible_with([num_symbols])

    with variable_scope.variable_scope(scope or "embedding_attention_decoder", dtype=dtype):
        embedding = variable_scope.get_variable("embedding", [num_symbols, embedding_size])
        # (batch, num_symbols) logits -> argmax id -> (batch, embedding_size) input.
        loop_function = None
        if feed_previous:
            loop_function = tf.nn.seq2seq._extract_argmax_and_embed(
                embedding, output_projection, update_embedding_for_previous)
        return attention_decoder(initial_state,
                                 attention_states,
                                 cell,
                                 num_symbols,
                                 time_steps,
                                 batch_size,
                                 output_size=output_size,
                                 loop_function=loop_function)

 

关于 embedding 接口（tf.nn.embedding_lookup）：

测试如下:

 1 #-*-coding:utf8-*-
 2 
 3 __author = "buyizhiyou"
 4 __date = "2017-11-21"
 5 
 6 import tensorflow as tf
 7 import numpy as np
 8 
 9 '''
10 测试embedding接口
11 '''
# With an identity matrix as the embedding table, row i is the one-hot vector
# for id i, so the lookup result shows directly which row each input id picks.
embedding = tf.Variable(np.identity(5, dtype=np.int32))
inputs = tf.placeholder(dtype=tf.int32, shape=[None])
input_embedding = tf.nn.embedding_lookup(embedding, inputs)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The expected outputs are comments inside the `with` block. A bare '''
    # string at column 0 would close the block, making the next indented print
    # an IndentationError (and running it after the session was closed).
    print(sess.run(embedding))
    # [[1 0 0 0 0]
    #  [0 1 0 0 0]
    #  [0 0 1 0 0]
    #  [0 0 0 1 0]
    #  [0 0 0 0 1]]
    print(sess.run(input_embedding, feed_dict={inputs: [1, 2, 3, 0, 3, 2, 1]}))
    # [[0 1 0 0 0]
    #  [0 0 1 0 0]
    #  [0 0 0 1 0]
    #  [1 0 0 0 0]
    #  [0 0 0 1 0]
    #  [0 0 1 0 0]
    #  [0 1 0 0 0]]

 

posted @ 2017-11-21 17:13  阿夏z  阅读(5435)  评论(2)  编辑  收藏  举报