【505】NLP实战系列(二)—— keras 中的 Embedding 层
1. Embedding 层语法
1 | keras.layers.Embedding(input_dim, output_dim, embeddings_initializer = 'uniform' , embeddings_regularizer = None , activity_regularizer = None , embeddings_constraint = None , mask_zero = False , input_length = None ) |
将正整数(索引值)转换为固定尺寸的稠密向量。 例如: [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]。该层只能用作模型中的第一层。
2. 参数说明
- input_dim: int > 0。词汇表大小, 即,最大整数 index + 1。
- output_dim: int >= 0。词向量的维度。
- embeddings_initializer: embeddings 矩阵的初始化方法 (详见 initializers)。
- embeddings_regularizer: embeddings matrix 的正则化方法 (详见 regularizer)。
- embeddings_constraint: embeddings matrix 的约束函数 (详见 constraints)。
- mask_zero: 是否把 0 看作为一个应该被遮蔽的特殊的 "padding" 值。 这对于可变长的 循环神经网络层 十分有用。 如果设定为 True,那么接下来的所有层都必须支持 masking,否则就会抛出异常。 如果 mask_zero 为 True,作为结果,索引 0 就不能被用于词汇表中 (input_dim 应该与 vocabulary + 1 大小相同)。
- input_length: 输入序列的长度,当它是固定的时。 如果你需要连接 Flatten 和 Dense 层,则这个参数是必须的 (没有它,dense 层的输出尺寸就无法计算)。
标记红色的是比较重要的参数,一般来说是需要具体赋值的。
3. 输入尺寸
尺寸为 (batch_size, sequence_length)
的 2D 张量。
- batch_size:每个批次的字符串数量
- sequence_length:字符串长度,多了截断,少了补0
4. 输出尺寸
尺寸为 (batch_size, sequence_length, output_dim)
的 3D 张量。
- batch_size:每个批次的字符串数量
- sequence_length:字符串长度,多了截断,少了补0
- output_dim:稠密矩阵维度
5. 举例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | model = Sequential() model.add(Embedding( 1000 , 64 , input_length = 10 )) # 模型将输入一个大小为 (batch, input_length) 的整数矩阵。 # 输入中最大的整数(即词索引)不应该大于 999 (词汇表大小) # 64 表示稠密矩阵的维度 # input_length=10 表示字符串长度 # 现在 model.output_shape == (None, 10, 64),其中 None 是 batch 的维度。 input_array = np.random.randint( 1000 , size = ( 32 , 10 )) # 新建一个输入数据 # 32 表示字符串数量 # 10 表示字符串长度 # 整体都是一些小于1000的整数表示,每一个数字对应于一个单词 model. compile ( 'rmsprop' , 'mse' ) output_array = model.predict(input_array) assert output_array.shape = = ( 32 , 10 , 64 ) # 没有提示错误,说明维度输出是正确的 # 32 表示字符串数量 # 10 表示字符串长度 # 64 表示稠密绝阵的维度 |
分类:
AI Related / NLP
, AI Related
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)