tf.nn.embedding_lookup()

1
2
3
4
5
6
7
8
tf.nn.embedding_lookup(
    params,
    ids,
    partition_strategy='mod',
    name=None,
    validate_indices=True,
    max_norm=None
)

功能:选取一个张量里面索引对应的行的向量

TensorFlow链接:https://tensorflow.google.cn/api_docs/python/tf/nn/embedding_lookup?hl=en

参数:

  • params:张量或数组;
  • id:对应的索引
  • partition_strategy:partition_strategy是用于当len(params) > 1,params的元素分割不能整分的话,则前(max_id + 1) % len(params)多分一个id.
    • 当partition_strategy = 'mod'的时候,13个ids划分为5个分区:[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]],也就是是按照数据列进行映射,然后再进行look_up操作。默认是mod
    • 当partition_strategy = 'div'的时候,13个ids划分为5个分区:[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]],也就是是按照数据先后进行排序标序,然后再进行look_up操作。

 

 

(图来自https://www.jianshu.com/p/abea0d9d2436

 

举例:

1
2
3
4
5
6
7
8
9
10
import numpy as np
A = tf.convert_to_tensor(np.array([[[1],[2]],[[3],[4]],[[5],[6]]]))
B = tf.nn.embedding_lookup(A, [[0,1],[1,0],[0,0]])
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('A',sess.run(A))
    print('A shape',A.shape)
    print('B',sess.run(B))
    print('B shape',B.shape)

结果:

  

  

参考文献:

【1】tf.nn.embedding_lookup记录

posted @   nxf_rabbit75  阅读(587)  评论(0编辑  收藏  举报
编辑推荐:
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 使用C#创建一个MCP客户端
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 按钮权限的设计及实现
点击右上角即可分享
微信分享提示