stack, concat, unstack, and split in TensorFlow
concat and stack
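tf.concat requires all inputs to have the same shape except along the axis being concatenated.
tf.stack requires all inputs to have exactly the same shape.
For tf.concat, the axis argument selects the existing dimension to concatenate along; the rank of the result is unchanged.
For tf.stack, the axis argument gives the position of the newly created dimension; for example, axis=0 puts it first.
In short, tf.stack creates a new dimension whose size is the number of stacked tensors, while tf.concat merges the inputs along a dimension that already exists.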
import tensorflow as tf
a = tf.ones([4, 35, 6])
b = tf.ones([2, 35, 6])
c = tf.concat([a, b], axis=0)
print(c.shape) # [6,35,6]
a_1 = tf.ones([4, 32, 6])
b_1 = tf.ones([4, 3, 6])
c_1 = tf.concat([a_1, b_1], axis=1)
print(c_1.shape) # [4, 35, 6]
a_0 = tf.ones([4, 35, 6])
b_0 = tf.ones(a_0.shape)
s_0 = tf.stack([a_0, b_0], axis=0)
print(s_0.shape) # [2, 4, 35, 6]
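To make the role of axis in tf.stack more concrete, here is a small sketch that places the new dimension at different positions and compares the result with tf.concat. The tensor names (x, y, front, middle, back) are illustrative only and do not come from the notes above.

```python
import tensorflow as tf

x = tf.ones([4, 35, 6])
y = tf.zeros([4, 35, 6])

# stack inserts a new axis of length len(inputs) at position `axis`
front = tf.stack([x, y], axis=0)    # new axis comes first
middle = tf.stack([x, y], axis=1)   # new axis comes second
back = tf.stack([x, y], axis=-1)    # new axis comes last

print(front.shape)   # (2, 4, 35, 6)
print(middle.shape)  # (4, 2, 35, 6)
print(back.shape)    # (4, 35, 6, 2)

# concat keeps the rank and only grows the chosen axis
joined = tf.concat([x, y], axis=-1)
print(joined.shape)  # (4, 35, 12)

# stack on axis 0 matches expand_dims on every input followed by concat
expanded = tf.concat([tf.expand_dims(x, 0), tf.expand_dims(y, 0)], axis=0)
same_as_stack = bool(tf.reduce_all(tf.equal(expanded, front)))
print(same_as_stack)  # True
```

The last check shows one way to think about stack: give every input an extra axis of size 1, then concatenate along that axis.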
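unstack and split
tf.unstack splits a tensor along the given axis into as many pieces as that axis is long. The axis itself disappears from the result, so each piece is a single slice of the original with rank one lower.
tf.split keeps the axis and lets you choose how much of it each output receives: num_or_size_splits is either an integer (the number of equal pieces) or a list giving the size of each piece along axis.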
a_2 = tf.ones([4, 31, 6])
b_2 = tf.ones([4, 31, 6])
c_2 = tf.stack([a_2, b_2], axis=0)
print("c_2.shape:",c_2.shape) #[2, 4, 31, 6]
a_2_1, b_2_1 = tf.unstack(c_2, axis=0)
print("a_2_1,b_2_1.shape",a_2_1.shape, b_2_1.shape) #[4, 31, 6] [4,31,6]
# Note: the unstacked axis disappears; each output is a single slice along it
res = tf.unstack(c_2, axis=3)
for it in res:
    print(it.shape)  # (2, 4, 31), 6 tensors in total; access them as res[0], res[1], ...
res_split = tf.split(c_2, num_or_size_splits=[1,2,3],axis=3)
for it in res_split:
    print(it.shape)
'''
(2, 4, 31, 1)
(2, 4, 31, 2)
(2, 4, 31, 3)
'''
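The relationship between the two operations can be checked directly. The following sketch, using a hypothetical tensor t, shows the integer form of num_or_size_splits and verifies that splitting into size-1 pieces and then squeezing the axis gives the same tensors as tf.unstack.

```python
import tensorflow as tf

t = tf.reshape(tf.range(2 * 3 * 6), [2, 3, 6])

# Integer form: three equal pieces along the last axis; the axis is kept
pieces = tf.split(t, num_or_size_splits=3, axis=-1)
print([p.shape.as_list() for p in pieces])  # [[2, 3, 2], [2, 3, 2], [2, 3, 2]]

# unstack: one slice per index; the axis is removed
slices = tf.unstack(t, axis=-1)
print(len(slices), slices[0].shape.as_list())  # 6 [2, 3]

# Splitting into size-1 pieces and squeezing that axis reproduces unstack
size_one = tf.split(t, num_or_size_splits=6, axis=-1)
squeezed = [tf.squeeze(p, axis=-1) for p in size_one]
same = all(
    bool(tf.reduce_all(tf.equal(a, b))) for a, b in zip(slices, squeezed)
)
print(same)  # True
```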
unstack
Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
Unpacks `num` tensors from `value` by chipping it along the `axis` dimension.
If `num` is not specified (the default), it is inferred from `value`'s shape.
If `value.shape[axis]` is not known, `ValueError` is raised.
For example, given a tensor of shape `(A, B, C, D)`:
If `axis == 0` then the i'th tensor in `output` is the slice
`value[i, :, :, :]` and each tensor in `output` will have shape `(B, C, D)`.
(Note that the dimension unpacked along is gone, unlike `split`).
If `axis == 1` then the i'th tensor in `output` is the slice
`value[:, i, :, :]` and each tensor in `output` will have shape `(A, C, D)`.
Etc.
This is the opposite of stack.
Args:
  value: A rank `R > 0` `Tensor` to be unstacked.
  num: An `int`. The length of the dimension `axis`. Automatically inferred if
    `None` (the default).
  axis: An `int`. The axis to unstack along. Defaults to the first dimension.
    Negative values wrap around, so the valid range is `[-R, R)`.
  name: A name for the operation (optional).
Returns:
  The list of `Tensor` objects unstacked from `value`.
Raises:
  ValueError: If `num` is unspecified and cannot be inferred.
  ValueError: If `axis` is out of the range `[-R, R)`.
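As a quick check of the behaviour described in this docstring, the sketch below unstacks a small tensor along a negative axis, compares one output with the corresponding slice, and stacks the pieces back together. The tensor value and its shape are chosen only for illustration.

```python
import tensorflow as tf

value = tf.reshape(tf.range(24), [2, 3, 4])

# For a rank-3 tensor, axis=-1 refers to the same dimension as axis=2
parts = tf.unstack(value, axis=-1)
print(len(parts), parts[0].shape.as_list())  # 4 [2, 3]

# The i-th output is the slice value[:, :, i]
matches_slice = bool(tf.reduce_all(tf.equal(parts[1], value[:, :, 1])))
print(matches_slice)  # True

# Stacking on the same axis restores the original tensor
restored = tf.stack(parts, axis=-1)
round_trip = bool(tf.reduce_all(tf.equal(restored, value)))
print(round_trip)  # True

# num may be passed explicitly; it has to equal the size of the chosen axis
parts_num = tf.unstack(value, num=3, axis=1)
print(len(parts_num), parts_num[0].shape.as_list())  # 3 [2, 4]
```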
split
Splits a tensor `value` into a list of sub tensors.
See also `tf.unstack`.
If `num_or_size_splits` is an integer, then `value` is split along the
dimension `axis` into `num_or_size_splits` smaller tensors. This requires that
`value.shape[axis]` is divisible by `num_or_size_splits`.
If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
`len(num_or_size_splits)` elements. The shape of the `i`-th
element has the same size as the `value` except along dimension `axis` where
the size is `num_or_size_splits[i]`.
For example:
>>> x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
>>>
>>> # Split `x` into 3 tensors along dimension 1
>>> s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
>>> tf.shape(s0).numpy()
array([ 5, 10], dtype=int32)
>>>
>>> # Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
>>> split0, split1, split2 = tf.split(x, [4, 15, 11], 1)
>>> tf.shape(split0).numpy()
array([5, 4], dtype=int32)
>>> tf.shape(split1).numpy()
array([ 5, 15], dtype=int32)
>>> tf.shape(split2).numpy()
array([ 5, 11], dtype=int32)
Args:
  value: The `Tensor` to split.
  num_or_size_splits: Either an integer indicating the number of splits along
    `axis` or a 1-D integer `Tensor` or Python list containing the sizes of
    each output tensor along `axis`. If a scalar, then it must evenly divide
    `value.shape[axis]`; otherwise the sum of sizes along the split axis
    must match that of the `value`.
  axis: An integer or scalar `int32` `Tensor`. The dimension along which to
    split. Must be in the range `[-rank(value), rank(value))`. Defaults to 0.
  num: Optional, used to specify the number of outputs when it cannot be
    inferred from the shape of `num_or_size_splits`.
  name: A name for the operation (optional).
Returns:
  If `num_or_size_splits` is a scalar, returns a list of `num_or_size_splits`
  `Tensor` objects; if `num_or_size_splits` is a 1-D Tensor, returns
  `num_or_size_splits.shape[0]` `Tensor` objects resulting from splitting
  `value`.
Raises:
  ValueError: If `num` is unspecified and cannot be inferred.
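A short sketch of the list form, reusing the sizes from the docstring example: each output is a contiguous slice of the input along axis, and tf.concat on the same axis undoes the split. The variable names are illustrative.

```python
import tensorflow as tf

x = tf.reshape(tf.range(5 * 30), [5, 30])
sizes = [4, 15, 11]  # must sum to x.shape[1], which is 30

parts = tf.split(x, num_or_size_splits=sizes, axis=1)
print([p.shape.as_list() for p in parts])  # [[5, 4], [5, 15], [5, 11]]

# Each part is a contiguous slice; the second covers columns 4 through 18
second_ok = bool(tf.reduce_all(tf.equal(parts[1], x[:, 4:19])))
print(second_ok)  # True

# concat on the same axis undoes the split
restored = tf.concat(parts, axis=1)
round_trip = bool(tf.reduce_all(tf.equal(restored, x)))
print(round_trip)  # True
```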