jax:dynamic_slice,slice,pad

def dynamic_slice(operand: Array, start_indices: Sequence[Array],
                  slice_sizes: Shape) -> Array:
  """Wraps XLA's `DynamicSlice
  <https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
  operator.

  Args:
    operand: an array to slice.
    start_indices: a list of scalar indices, one per dimension. These values
      may be dynamic.
    slice_sizes: the size of the slice. Must be a sequence of non-negative
      integers with length equal to `ndim(operand)`. Inside a JIT compiled
      function, only static values are supported (all JAX arrays inside JIT
      must have statically known size).

  Returns:
    An array containing the slice.

  Examples:
    Here is a simple two-dimensional dynamic slice:

    >>> x = jnp.arange(12).reshape(3, 4)
    >>> x
    DeviceArray([[ 0,  1,  2,  3],
                 [ 4,  5,  6,  7],
                 [ 8,  9, 10, 11]], dtype=int32)

    >>> dynamic_slice(x, (1, 1), (2, 3))
    DeviceArray([[ 5,  6,  7],
                 [ 9, 10, 11]], dtype=int32)

    Note the potentially surprising behavior for the case where the requested slice
    overruns the bounds of the array; in this case the start index is adjusted to
    return a slice of the requested size:

    >>> dynamic_slice(x, (1, 1), (2, 4))
    DeviceArray([[ 4,  5,  6,  7],
                 [ 8,  9, 10, 11]], dtype=int32)

2. tf.slice  tf.gather

    tf.slice(input_, begin, size, name=None):按照指定的下标范围抽取连续区域的子集
 
    tf.gather(params, indices, validate_indices=None, name=None):按照指定的下标集合从axis=0中抽取子集,适合抽取不连续区域的子集
输出:
input = [[[1, 1, 1], [2, 2, 2]],
         [[3, 3, 3], [4, 4, 4]],
         [[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
                                            [4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
                                           [[5, 5, 5]]]
                                           
tf.gather(input, [0, 2]) ==> [[[1, 1, 1], [2, 2, 2]],
                              [[5, 5, 5], [6, 6, 6]]]
假设我们要从input中抽取[[[3, 3, 3]]],这个输出在inputaxis=0的下标是1,axis=1的下标是0,axis=2的下标是0-2,所以begin=[1,0,0],size=[1,1,3]。
 
假设我们要从input中抽取[[[3, 3, 3], [4, 4, 4]]],这个输出在inputaxis=0的下标是1,axis=1的下标是0-1,axis=2的下标是0-2,所以begin=[1,0,0],size=[1,2,3]。
 
假设我们要从input中抽取[[[3, 3, 3], [5, 5, 5]]],这个输出在inputaxis=0的下标是1-2,axis=1的下标是0,axis=2的下标是0-2,所以begin=[1,0,0],size=[2,1,3]。
 
假设我们要从input中抽取[[[1, 1, 1], [2, 2, 2]],[[5, 5, 5], [6, 6, 6]]],这个输出在input的axis=0的下标是[0, 2],不连续,可以用tf.gather抽取。input[0]和input[2]

3. jax.lax.slice

jax.lax.slice同tf.slice一样,只是在死三个参数中的size变成了end。

 

 4. jnp.pad

import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)
print(x)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
y=jax.lax.slice(x,(0,0),(3,3))
print(y)
print('')
z=jnp.pad(y,((0,0),(0,0)))
print(z)

y=jax.lax.slice(x,(0,0),(3,3))
print(y)
print('')
z=jnp.pad(y,((0,0),(1,0)))
print(z)

 

posted @ 2021-11-13 17:29  为红颜  阅读(344)  评论(0编辑  收藏  举报