tf.expand_dims()和tf.squeeze()
分类:
TensorFlow
1.tf.expand_dims()
1 | tf.expand_dims( input , axis = None , name = None , dim = None ) |
作用:给定张量,输入形状的维度索引轴处插入1的尺寸。 尺寸索引轴从零开始; 如果为指定的轴为负数,则从末尾开始算起。
参数:
- input:张量。
- aixs:0-D(标量),指定扩大输入形状的维度索引。
- name:输出名称Tensor。
- dim:0-D(标量), 等同于轴,不推荐使用。
返回:具有与输入相同数据的张量,但其形状添加了尺寸为1的附加尺寸。
如果要将批次尺寸添加到单个元素,此操作很有用。 例如,如果您有一个形状为[[height,width,channels]`的图像,则可以将其与具有`expand_dims(image,0)`的1张图像一起批处理,这将使形状为[1,height ,width,channels]。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | # 't' is a tensor of shape [2] tf.shape(tf.expand_dims(t, 0 )) # [1, 2] tf.shape(tf.expand_dims(t, 1 )) # [2, 1] tf.shape(tf.expand_dims(t, - 1 )) # [2, 1] # 't2' is a tensor of shape [2, 3, 5] tf.shape(tf.expand_dims(t2, 0 )) # [1, 2, 3, 5] tf.shape(tf.expand_dims(t2, 2 )) # [2, 3, 1, 5] tf.shape(tf.expand_dims(t2, 3 )) # [2, 3, 5, 1] ``` This operation requires that: ` - 1 - input .dims() < = dim < = input .dims()` This operation is related to `squeeze()`, which removes dimensions of size 1. Args: input : A `Tensor`. axis: 0 - D (scalar). Specifies the dimension index at which to expand the shape of ` input `. Must be in the range `[ - rank( input ) - 1 , rank( input )]`. name: The name of the output `Tensor`. dim: 0 - D (scalar). Equivalent to `axis`, to be deprecated. Returns: A `Tensor` with the same data as ` input `, but its shape has an additional dimension of size 1 added. Raises: ValueError: if both `dim` and `axis` are specified. |
bert中源码:
1 2 3 4 | # 该函数默认输入的形状为【batch_size, seq_length, input_num】 # 如果输入为2D的【batch_size, seq_length】,则扩展到【batch_size, seq_length, 1】 if input_ids.shape.ndims = = 2 : input_ids = tf.expand_dims(input_ids, axis = [ - 1 ]) |
2.tf.squeeze()
1 | tf.squeeze( input , squeeze_dims = None , name = None ) |
作用:给定张量输入,此操作返回相同类型的张量,并删除所有尺寸为1的维度。 如果不想删除所有尺寸为1的维度,可以通过指定squeeze_dims来删除特定尺寸的维度。
参数:
- input:要挤压的张量
- squeeze_dims:
- 可选的ints列表, 默认为[]。
- 如果指定,只能挤压列出的维度。
- 维度索引从0开始,挤压不是1的维度是一个错误
- name:操作的名称(可选)
返回:与输入的类型相同。 包含与输入相同的数据,但具有一个或多个尺寸为1的维度被删除。
举例:
import tensorflow as tf
sess = tf.InteractiveSession()
t1 = tf.constant([1,2,3,4,5,6],shape=[1,2,3,1])
print('t1')
print(t1.eval())
t2 = tf.squeeze(t1)
print('t2')
print(t2.eval())
t3 = tf.squeeze(t1,[3])
print('t3')
print(t3.eval())
参考文献:
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 使用C#创建一个MCP客户端
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 按钮权限的设计及实现