tf.nn.embedding_lookup
import os
from distutils.version import LooseVersion

# Set the log level before importing TensorFlow so that it reliably takes effect
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf

# Check TensorFlow Version
# str.format usage: https://www.runoob.com/python/att-string-format.html
assert LooseVersion(tf.__version__) >= LooseVersion('1.1'), 'Please use TensorFlow version 1.1 or newer'
print('TensorFlow Version: {}'.format(tf.__version__))

# decoding_layer
target_vocab_size = 30
decoding_embedding_size = 15

# Create a matrix variable of shape [target_vocab_size, decoding_embedding_size]
decoder_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))

decoder_input = tf.constant([[2, 4, 5, 20, 20, 22], [2, 17, 19, 28, 8, 7]])

# decoder_input acts as the indices: each index selects the matching
# row vector from the decoder_embeddings matrix
decoder_embed_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_input)

with tf.Session() as sess:
    # Initialize the variables
    sess.run(tf.global_variables_initializer())
    print(sess.run(decoder_input))
    print(sess.run(decoder_embed_input))
    print(sess.run(decoder_embed_input).shape)
    print(sess.run(decoder_embeddings).shape)

'''
TensorFlow Version: 1.1.0
[[ 2  4  5 20 20 22]
 [ 2 17 19 28  8  7]]
[[[0.7545215 0.7695402 0.8238114 0.5432198 0.9996183 0.9811146 0.95969343 0.41114593 0.97545445 0.24203181 0.09990311 0.95584977 0.01549327 0.24147344 0.77837694]
  [0.3278563 0.15792835 0.6561059 0.05010188 0.6810814 0.48657227 0.76693904 0.3541503 0.24678373 0.6569611 0.7002362 0.8788489 0.55558705 0.8038074 0.9971179 ]
  [0.47802067 0.4191296 0.99486816 0.41066968 0.23289478 0.32609868 0.9676993 0.15804064 0.530162 0.27542043 0.1686151 0.32158124 0.9871446 0.2646426 0.04092526]
  [0.18767893 0.35398638 0.68607545 0.65941226 0.6620586 0.8647306 0.7390516 0.869087 0.43624723 0.17690945 0.05664539 0.71465147 0.931615 0.6130588 0.00999928]
  [0.18767893 0.35398638 0.68607545 0.65941226 0.6620586 0.8647306 0.7390516 0.869087 0.43624723 0.17690945 0.05664539 0.71465147 0.931615 0.6130588 0.00999928]
  [0.26353955 0.7629268 0.8845804 0.33571935 0.7586707 0.3451711 0.94198895 0.27516353 0.80296195 0.35592806 0.10672879 0.4347086 0.9473572 0.04584897 0.5173352 ]]

 [[0.7545215 0.7695402 0.8238114 0.5432198 0.9996183 0.9811146 0.95969343 0.41114593 0.97545445 0.24203181 0.09990311 0.95584977 0.01549327 0.24147344 0.77837694]
  [0.15764415 0.07040286 0.2844795 0.17439246 0.01639402 0.39553535 0.61776114 0.8033254 0.32655883 0.5642803 0.9243225 0.27921832 0.8107116 0.99436224 0.29784715]
  [0.49179244 0.09336936 0.5070219 0.21457541 0.5522537 0.7257378 0.7425264 0.46288037 0.47577012 0.4681779 0.35275757 0.106884 0.04049754 0.6626127 0.51448214]
  [0.9727278 0.3141979 0.5706855 0.75443506 0.47404313 0.6312864 0.5409869 0.11424744 0.02585125 0.6820954 0.17008471 0.8503103 0.02040458 0.8472682 0.06770897]
  [0.01118135 0.9363662 0.63658035 0.76509845 0.9903203 0.49527347 0.5959027 0.81918335 0.06886601 0.4056344 0.7938701 0.01046228 0.3069656 0.23374438 0.86642563]
  [0.21021092 0.8584006 0.32006896 0.05085099 0.5072923 0.9867519 0.7337296 0.937829 0.90734327 0.13784957 0.36768234 0.31802237 0.62072766 0.9816464 0.5022781 ]]]
(2, 6, 15)
(30, 15)
'''
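How should we read this? We start from the vocabulary size of the target sequence and create a randomly initialized variable (the matrix decoder_embeddings) with one row per vocabulary entry. In this example its shape is 30 × 15.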
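Now for what tf.nn.embedding_lookup() does: it selects the elements of a tensor that correspond to the given indices. For an embedding matrix, each index picks out one row.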
tf.nn.embedding_lookup(params, ids): params can be a tensor, an array, and so on; ids holds the indices to look up. The other parameters are not covered here.
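To make the lookup semantics concrete, here is a minimal pure-Python sketch. It is only an illustration, not TensorFlow's implementation: the function `embedding_lookup` below and the names `table` and `looked_up` are hypothetical stand-ins, and the sketch assumes params is a single table stored as a list of rows.

```python
def embedding_lookup(params, ids):
    """Replace every integer id in a (possibly nested) list with params[id].

    Hypothetical stand-in for illustration, not the TensorFlow function.
    """
    if isinstance(ids, int):
        # Copy the row so the result does not alias the table.
        return list(params[ids])
    return [embedding_lookup(params, item) for item in ids]


# A tiny table: 4 ids, each mapped to a 3-dimensional vector.
table = [
    [0.0, 0.1, 0.2],
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
    [3.0, 3.1, 3.2],
]
ids = [[0, 2], [3, 0]]

looked_up = embedding_lookup(table, ids)
print(looked_up)
# [[[0.0, 0.1, 0.2], [2.0, 2.1, 2.2]], [[3.0, 3.1, 3.2], [0.0, 0.1, 0.2]]]
```

The result keeps the layout of ids and adds one trailing dimension for the vector, so ids of shape (2, 2) with 3-dimensional rows give a result of shape (2, 2, 3).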
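Because decoder_input is the target sequence already converted to integer ids, it can be used directly as ids to fetch the corresponding vectors from the decoder_embeddings matrix. With ids of shape (2, 6) and 15-dimensional embeddings, the result has shape (2, 6, 15).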
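The shapes in the printed output can be reproduced without a TensorFlow session. The sketch below uses NumPy integer-array indexing as a stand-in for the lookup, assuming a single dense embedding matrix; `rng`, `embeddings`, and `embedded` are illustrative names, and the random values will differ from the ones printed above.

```python
import numpy as np

target_vocab_size = 30
decoding_embedding_size = 15

# Stand-in for decoder_embeddings: random values in [0, 1), one row per id.
rng = np.random.default_rng(0)
embeddings = rng.random((target_vocab_size, decoding_embedding_size))

decoder_input = np.array([[2, 4, 5, 20, 20, 22],
                          [2, 17, 19, 28, 8, 7]])

# Integer-array indexing picks one row of the matrix per id.
embedded = embeddings[decoder_input]

print(embedded.shape)  # (2, 6, 15)
# Positions 3 and 4 of the first sequence both hold id 20, so the rows match.
print(np.array_equal(embedded[0, 3], embedded[0, 4]))  # True
# Both sequences start with id 2, so their first vectors match as well.
print(np.array_equal(embedded[0, 0], embedded[1, 0]))  # True
```

The same pattern is visible in the TensorFlow output above: the fourth and fifth vectors of the first sequence are identical, and both sequences begin with the same vector.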