stack, concat, unstack, and split in TensorFlow (tf)

concat and stack

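tf.concat requires the input tensors to have the same rank and identical sizes in every dimension except the one being concatenated.
tf.stack requires all input tensors to have exactly the same shape.
In tf.concat, the axis argument selects the existing dimension to concatenate along.
In tf.stack, the axis argument gives the position of the newly created dimension.
In other words, tf.stack adds a new dimension, while tf.concat merges the inputs along an existing one.
For example, tf.stack(..., axis=0) puts the new dimension first.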
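A minimal sketch of these two shape rules, assuming TensorFlow 2.x in eager mode. The helper `try_op` and the tensor names are illustrative only, not part of TensorFlow. The exact exception TensorFlow raises for a shape mismatch may differ between versions, so the sketch catches both `tf.errors.InvalidArgumentError` and `ValueError`.

```python
import tensorflow as tf

def try_op(fn):
    """Hypothetical helper: run fn() and return the result's shape, or None if TF rejects it."""
    try:
        return tuple(fn().shape)
    except (tf.errors.InvalidArgumentError, ValueError):
        return None

x = tf.ones([4, 35, 6])
y = tf.ones([2, 35, 6])   # differs from x only in axis 0
z = tf.ones([4, 35, 6])   # same shape as x

# concat: only the concatenation axis may differ.
concat_ok = try_op(lambda: tf.concat([x, y], axis=0))   # axis 0 differs and is the concat axis
concat_bad = try_op(lambda: tf.concat([x, y], axis=1))  # axis 0 differs but is NOT the concat axis

# stack: every dimension must match.
stack_ok = try_op(lambda: tf.stack([x, z], axis=0))
stack_bad = try_op(lambda: tf.stack([x, y], axis=0))

print(concat_ok, concat_bad)  # (6, 35, 6) None
print(stack_ok, stack_bad)    # (2, 4, 35, 6) None
```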

import tensorflow as tf
a = tf.ones([4, 35, 6])
b = tf.ones([2, 35, 6])

c = tf.concat([a, b], axis=0)
print(c.shape) # [6,35,6]

a_1 = tf.ones([4, 32, 6])
b_1 = tf.ones([4, 3, 6])
c_1 = tf.concat([a_1, b_1], axis=1)
print(c_1.shape) # [4, 35, 6]

a_0 = tf.ones([4, 35, 6])
b_0 = tf.ones(a_0.shape)
s_0 = tf.stack([a_0, b_0], axis=0)
print(s_0.shape) # [2, 4, 35, 6]
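One way to see how the two operations relate: stacking along `axis` gives the same result as inserting a length-1 dimension at `axis` in every input (with `tf.expand_dims`) and then concatenating along that new dimension. The sketch below, assuming TensorFlow 2.x eager execution, checks this equivalence and records where the new dimension lands for a few `axis` values; the variable names are illustrative.

```python
import tensorflow as tf

a = tf.random.uniform([4, 35, 6])
b = tf.random.uniform([4, 35, 6])

shapes = {}
for axis in (0, 1, 3, -1):
    stacked = tf.stack([a, b], axis=axis)
    # Same result built by hand: add a length-1 axis to each input, then concat along it.
    manual = tf.concat([tf.expand_dims(a, axis), tf.expand_dims(b, axis)], axis=axis)
    assert bool(tf.reduce_all(tf.equal(stacked, manual)))
    shapes[axis] = tuple(stacked.shape)

print(shapes)
# {0: (2, 4, 35, 6), 1: (4, 2, 35, 6), 3: (4, 35, 6, 2), -1: (4, 35, 6, 2)}
```

Negative values count from the end of the output rank, so for rank-3 inputs `axis=-1` and `axis=3` both append the new dimension last.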

unstack and split

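unstack splits a tensor along the specified axis into as many tensors as that axis has elements. The axis itself disappears from the results: each output is a single slice of the original along that axis.
split keeps the axis and lets you choose how much of it each output receives: num_or_size_splits is either the number of equal parts or a list giving the size of each part along that axis.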
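To make the difference concrete: `tf.unstack` along an axis of length n returns n tensors with that axis removed, while `tf.split` with `num_or_size_splits=n` returns n tensors that keep the axis with length 1. A minimal sketch, assuming TensorFlow 2.x eager mode; it also checks that each unstacked piece equals the matching split piece once its length-1 axis is squeezed out.

```python
import tensorflow as tf

x = tf.reshape(tf.range(2 * 3 * 4), [2, 3, 4])  # values 0..23, shape (2, 3, 4)

unstacked = tf.unstack(x, axis=1)                   # 3 tensors of shape (2, 4)
pieces = tf.split(x, num_or_size_splits=3, axis=1)  # 3 tensors of shape (2, 1, 4)

print([tuple(t.shape) for t in unstacked])  # [(2, 4), (2, 4), (2, 4)]
print([tuple(t.shape) for t in pieces])     # [(2, 1, 4), (2, 1, 4), (2, 1, 4)]

# Squeezing the kept length-1 axis out of each split piece gives the unstacked piece.
same = all(
    bool(tf.reduce_all(tf.equal(u, tf.squeeze(p, axis=1))))
    for u, p in zip(unstacked, pieces)
)
print(same)  # True
```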

a_2 = tf.ones([4, 31, 6])
b_2 = tf.ones([4, 31, 6])
c_2 = tf.stack([a_2, b_2], axis=0)
print("c_2.shape:",c_2.shape) #[2, 4, 31, 6]
a_2_1, b_2_1 = tf.unstack(c_2, axis=0)
print("a_2_1,b_2_1.shape",a_2_1.shape, b_2_1.shape) #[4, 31, 6] [4,31,6]

# Note: the unstacked axis disappears; each output is a single slice along it
res = tf.unstack(c_2, axis=3)
for it in res:
    print(it.shape) # (2, 4, 31), 6 tensors in total; access them as res[0], res[1], ...
res_split = tf.split(c_2, num_or_size_splits=[1,2,3],axis=3)
for it in res_split:
    print(it.shape)
'''
(2, 4, 31, 1)
(2, 4, 31, 2)
(2, 4, 31, 3)
'''
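The operations above also undo each other. Assuming TensorFlow 2.x in eager mode, this sketch checks that `tf.stack` reverses `tf.unstack` and `tf.concat` reverses `tf.split` when the same axis is used. Here `c` is a fresh random tensor with the same shape as `c_2`, so the check does not rely on all-ones data.

```python
import tensorflow as tf

c = tf.random.uniform([2, 4, 31, 6])

# unstack along axis 3, then stack along the same axis: the original tensor comes back.
restored_stack = tf.stack(tf.unstack(c, axis=3), axis=3)

# split along axis 3 into sizes 1, 2, 3, then concat along axis 3: also the original.
restored_concat = tf.concat(tf.split(c, [1, 2, 3], axis=3), axis=3)

stack_roundtrip = bool(tf.reduce_all(tf.equal(restored_stack, c)))
concat_roundtrip = bool(tf.reduce_all(tf.equal(restored_concat, c)))
print(stack_roundtrip, concat_roundtrip)  # True True
```

The pairing matters: stack restores the axis that unstack removed, while concat joins pieces along the axis that split kept.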

unstack

Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.

  Unpacks `num` tensors from `value` by chipping it along the `axis` dimension.
  If `num` is not specified (the default), it is inferred from `value`'s shape.
  If `value.shape[axis]` is not known, `ValueError` is raised.

  For example, given a tensor of shape `(A, B, C, D)`;

  If `axis == 0` then the i'th tensor in `output` is the slice
    `value[i, :, :, :]` and each tensor in `output` will have shape `(B, C, D)`.
    (Note that the dimension unpacked along is gone, unlike `split`).

  If `axis == 1` then the i'th tensor in `output` is the slice
    `value[:, i, :, :]` and each tensor in `output` will have shape `(A, C, D)`.
  Etc.

  This is the opposite of stack.

  Args:
    value: A rank `R > 0` `Tensor` to be unstacked.
    num: An `int`. The length of the dimension `axis`. Automatically inferred if
      `None` (the default).
    axis: An `int`. The axis to unstack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-R, R)`.
    name: A name for the operation (optional).

  Returns:
    The list of `Tensor` objects unstacked from `value`.

  Raises:
    ValueError: If `num` is unspecified and cannot be inferred.
    ValueError: If `axis` is out of the range [-R, R).
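A short check of the behaviour this docstring describes, assuming TensorFlow 2.x eager mode: with `axis == 1`, output `i` equals the slice `value[:, i, :, :]`; a negative axis wraps around; and an axis outside `[-R, R)` raises `ValueError`. The shape `(2, 3, 4, 5)` is just an example of `(A, B, C, D)`.

```python
import tensorflow as tf

value = tf.random.uniform([2, 3, 4, 5])  # (A, B, C, D) = (2, 3, 4, 5)

# axis == 1: output[i] is value[:, i, :, :] with shape (A, C, D).
out = tf.unstack(value, axis=1)
slices_match = all(
    bool(tf.reduce_all(tf.equal(t, value[:, i, :, :]))) for i, t in enumerate(out)
)

# Negative axes wrap around: axis=-1 means axis=3 for this rank-4 tensor.
last = tf.unstack(value, axis=-1)

# An axis outside [-R, R) is rejected with ValueError.
try:
    tf.unstack(value, axis=4)
    out_of_range_error = False
except ValueError:
    out_of_range_error = True

print(len(out), tuple(out[0].shape), slices_match)  # 3 (2, 4, 5) True
print(len(last), tuple(last[0].shape))              # 5 (2, 3, 4)
print(out_of_range_error)                           # True
```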

split

Splits a tensor `value` into a list of sub tensors.

  See also `tf.unstack`.

  If `num_or_size_splits` is an integer,  then `value` is split along the
  dimension `axis` into `num_or_size_splits` smaller tensors. This requires that
  `value.shape[axis]` is divisible by `num_or_size_splits`.

  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
  `len(num_or_size_splits)` elements. The shape of the `i`-th
  element has the same size as the `value` except along dimension `axis` where
  the size is `num_or_size_splits[i]`.

  For example:

  >>> x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
  >>>
  >>> # Split `x` into 3 tensors along dimension 1
  >>> s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
  >>> tf.shape(s0).numpy()
  array([ 5, 10], dtype=int32)
  >>>
  >>> # Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
  >>> split0, split1, split2 = tf.split(x, [4, 15, 11], 1)
  >>> tf.shape(split0).numpy()
  array([5, 4], dtype=int32)
  >>> tf.shape(split1).numpy()
  array([ 5, 15], dtype=int32)
  >>> tf.shape(split2).numpy()
  array([ 5, 11], dtype=int32)

  Args:
    value: The `Tensor` to split.
    num_or_size_splits: Either an integer indicating the number of splits along
      `axis` or a 1-D integer `Tensor` or Python list containing the sizes of
      each output tensor along `axis`. If a scalar, then it must evenly divide
      `value.shape[axis]`; otherwise the sum of sizes along the split axis
      must match that of the `value`.
    axis: An integer or scalar `int32` `Tensor`. The dimension along which to
      split. Must be in the range `[-rank(value), rank(value))`. Defaults to 0.
    num: Optional, used to specify the number of outputs when it cannot be
      inferred from the shape of `size_splits`.
    name: A name for the operation (optional).

  Returns:
    if `num_or_size_splits` is a scalar returns a list of `num_or_size_splits`
    `Tensor` objects; if `num_or_size_splits` is a 1-D Tensor returns
    `num_or_size_splits.get_shape()[0]` `Tensor` objects resulting from splitting
    `value`.

  Raises:
    ValueError: If `num` is unspecified and cannot be inferred.
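A small sketch of the two constraints in this docstring, assuming TensorFlow 2.x eager mode: an integer `num_or_size_splits` must evenly divide `value.shape[axis]`, and a list of sizes must sum to it. The helper `split_shapes` is hypothetical; because the exact exception type for rejected arguments may vary between versions, it catches both `tf.errors.InvalidArgumentError` and `ValueError`.

```python
import tensorflow as tf

x = tf.random.uniform([5, 30], -1, 1)

def split_shapes(*args, **kwargs):
    """Hypothetical helper: output shapes of tf.split(x, ...), or None if TF rejects the arguments."""
    try:
        return [tuple(t.shape) for t in tf.split(x, *args, **kwargs)]
    except (tf.errors.InvalidArgumentError, ValueError):
        return None

even = split_shapes(3, axis=1)                 # 30 is divisible by 3
uneven = split_shapes(7, axis=1)               # 30 is not divisible by 7
sizes = split_shapes([4, 15, 11], axis=1)      # 4 + 15 + 11 == 30
bad_sizes = split_shapes([4, 15, 10], axis=1)  # sums to 29, not 30

print(even)       # [(5, 10), (5, 10), (5, 10)]
print(uneven)     # None
print(sizes)      # [(5, 4), (5, 15), (5, 11)]
print(bad_sizes)  # None
```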
posted @ 2021-03-11 22:23  cyssmile