TF里面的embedding
参考这篇文章:
https://www.jianshu.com/p/e8986d0ff4ff
《详解TF中的Embedding操作!》
什么是embedding?
embedding,我们可以简单的理解为,将一个特征转换为一个向量。
对于离散特征,我们一般的做法是将其转换为one-hot,但对于itemid这种离散特征,转换成one-hot之后维度非常高,但里面只有一个是1,其余都为0。这种情况下,我们的通常做法就是将其转换为embedding。
embedding的过程是什么样子的呢?它其实就是一层全连接的神经网络,如下图所示:

假设一个特征共有5个取值,也就是说one-hot之后会变成5维,我们想将其转换为embedding表示,其实就是接入了一层全连接神经网络。由于只有一个位置是1,其余位置是0,因此得到的embedding就是与其相连的图中红线上的权重。
tf1.x中的embedding实现
在tf1.x中,我们使用embedding_lookup函数来实现emedding# embedding embedding = tf.constant( [[0.21,0.41,0.51,0.11]], [0.22,0.42,0.52,0.12], [0.23,0.43,0.53,0.13], [0.24,0.44,0.54,0.14]],dtype=tf.float32) feature_batch = tf.constant([2,3,1,0]) get_embedding1 = tf.nn.embedding_lookup(embedding,feature_batch) 等价于: embedding = tf.constant( [ [0.21,0.41,0.51,0.11], [0.22,0.42,0.52,0.12], [0.23,0.43,0.53,0.13], [0.24,0.44,0.54,0.14] ],dtype=tf.float32) feature_batch = tf.constant([2,3,1,0]) feature_batch_one_hot = tf.one_hot(feature_batch,depth=4) get_embedding2 = tf.matmul(feature_batch_one_hot,embedding)
验证一下:
with tf.Session() as sess: sess.run(tf.global_variables_initializer()) embedding1,embedding2 = sess.run([get_embedding1,get_embedding2]) print(embedding1) print(embedding2)
tf1.x中与embedding类似操作
通过查看embedding_lookup函数的源码,不难发现,它是gather函数的一种特殊形式:
等价于:
mbedding = tf.constant( [ [0.21,0.41,0.51,0.11], [0.22,0.42,0.52,0.12], [0.23,0.43,0.53,0.13], [0.24,0.44,0.54,0.14] ],dtype=tf.float32) index_a = tf.Variable([2,3,1,0]) gather_a = tf.gather(embedding, index_a) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(gather_a))
tf1.x中多值离散特征处理
上面所说的embedding_lookup函数,只能处理一个离散特征有一个取值的情况,但实际中,有的离散特征可能有两三个取值,如一个人的爱好,可能既喜欢篮球又喜欢羽毛球,这样转成one-hot的时候,有两个地方为1(这里应该不叫one-hot,确切来说是multi-hot)。我们称这种情况为多值离散特征。这种情况下,如何处理呢?我们使用tf.nn.embedding_lookup_sparse函数。
# sparse embedding a = tf.SparseTensor(indices=[[0, 0],[1, 2],[1,3]], values=[1, 2, 3], dense_shape=[2, 4]) embedding = tf.constant( [ [0.21,0.41,0.51,0.11], [0.22,0.42,0.52,0.12], [0.23,0.43,0.53,0.13], [0.24,0.44,0.54,0.14] ],dtype=tf.float32) embedding_sparse = tf.nn.embedding_lookup_sparse(embedding, sp_ids=a, sp_weights=None) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(embedding_sparse))
我们一步步来看上面的过程,首先,我们需要有一个SparseTensor,这个tensor的shape是[2, 4]的,其中不为0的地方呢一共有三个,即[0, 0],[1, 2],[1,3],这三处的value分别是1,2,3,这个SparseTensor其实长下面这个样子:

b = tf.sparse_tensor_to_dense(a)
接下来,在embedding_lookup_sparse中我们提供了三个参数,第一个不解释了,第二个sp_ids即我们定义的SparseTensor,第三个参数sp_weights=None代表的每一个取值的权重,如果是None的话,所有权重都是1,也就是相当于取了平均。如果不是None的话,我们需要同样传入一个SparseTensor。
输出为:
这样,结果的第一行没什么好解释的了,即取了我们定义的embedding这个tensor的index=1的那一行,结果的第二行相当于取了我们定义的embedding这个tensor的index=2和index=3这两行的平均值。
tf2.0中embedding实现
在tf2.0中,embedding同样可以通过embedding_lookup来实现,不过不同的是,我们不需要通过sess.run来获取结果了,可以直接运行结果,并转换为numpy。
embedding = tf.constant( [ [0.21,0.41,0.51,0.11], [0.22,0.42,0.52,0.12], [0.23,0.43,0.53,0.13], [0.24,0.44,0.54,0.14] ],dtype=tf.float32) feature_batch = tf.constant([2,3,1,0]) get_embedding1 = tf.nn.embedding_lookup(embedding,feature_batch) feature_batch_one_hot = tf.one_hot(feature_batch,depth=4) get_embedding2 = tf.matmul(feature_batch_one_hot,embedding) print(get_embedding1.numpy().tolist())
keras的用法:
num_classes=10 input_x = tf.keras.Input(shape=(None,),) embedding_x = layers.Embedding(num_classes, 10)(input_x) hidden1 = layers.Dense(50,activation='relu')(embedding_x) output = layers.Dense(2,activation='softmax')(hidden1) x_train = [2,3,4,5,8,1,6,7,2,3,4,5,8,1,6,7,2,3,4,5,8,1,6,7,2,3,4,5,8,1,6,7,2,3,4,5,8,1,6,7,2,3,4,5,8,1,6,7,2,3,4,5,8,1,6,7,2,3,4,5,8,1,6,7] y_train = [0,1,0,1,1,0,0,1,0,1,0,1,1,0,0,1,0,1,0,1,1,0,0,1,0,1,0,1,1,0,0,1,0,1,0,1,1,0,0,1,0,1,0,1,1,0,0,1,0,1,0,1,1,0,0,1,0,1,0,1,1,0,0,1] model2 = tf.keras.Model(inputs = input_x,outputs = output) model2.compile(optimizer=tf.keras.optimizers.Adam(0.001), #loss=tf.keras.losses.SparseCategoricalCrossentropy(), loss='sparse_categorical_crossentropy', metrics=['accuracy']) history = model2.fit(x_train, y_train, batch_size=4, epochs=1000, verbose=0)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· Docker 太简单,K8s 太复杂?w7panel 让容器管理更轻松!
2018-03-09 这一篇LR的推导写的很好,值得细看
2018-03-09 对于wide-deep model,觉得还是很有意思的,通过下面几篇文章加强一下
2018-03-09 学一下Yoshua的这个PPT,很不错的 - 多维灾难,深度的好处,最后还有几个问题可以好好思考
2017-03-09 链表旋转的题目
2017-03-09 模拟Linux路径解析,题目不难,但是字符分段的方法真的太好了 & 字符串split的好方法
2017-03-09 这道题目开始我做错了,校验二叉查找树BST的正确性
2017-03-09 数字组合之后的排序