MxNet中NDArray的reshape用法总结
reshape没有特殊值时的常规用法就不用细说了,比如
>>> from mxnet import nd >>> a = nd.array([1,2,3,4]) >>> a.shape (4L,) >>> a.reshape(2,2) [[1. 2.] [3. 4.]] <NDArray 2x2 @cpu(0)>
下面详细讲下带有特殊值,需要推倒的情况:
1. 0 表示复用输入中的维度值
>>> from mxnet import nd >>> a = nd.random.uniform(shape=(2,3,4)) >>> a.shape (2L, 3L, 4L) >>> b = a.reshape(4, 0, 2) >>> b.shape (4L, 3L, 2L)
上面例子中,输出的第2个维度值与输入的第2个维度值保持不变
2. -1 表示根据输入与输出中元素数量守恒的原则,根据已知的维度值,推导值为-1的位置的维度值
>>> from mxnet import nd >>> a = nd.random.uniform(shape=(2,3,4)) >>> a.shape (2L, 3L, 4L) >>> b = a.reshape(-1, 12) >>> b.shape (2L, 12L)
3. -2 表示复用所有该位置之后的维度
>>> from mxnet import nd >>> a = nd.random.uniform(shape=(2,3,4)) >>> a.shape (2L, 3L, 4L) >>> b = a.reshape(2, -2) >>> b.shape (2L, 3L, 4L)
4. -3 表示将连续两个维度相乘作为新的维度
>>> from mxnet import nd >>> a = nd.random.uniform(shape=(2,3,4)) >>> a.shape (2L, 3L, 4L) >>> b = a.reshape(2, -3) >>> b.shape (2L, 12L)
5. -4 表示把当前位置的维度拆分为后面两个维度,这后面两个数的乘积等于当前维度的输入值
>>> a = nd.random.uniform(shape=(2,3,4)) >>> a.shape (2L, 3L, 4L) >>> b = a.reshape(-4, 1, 2, -2) >>> b.shape (1L, 2L, 3L, 4L)
reverse=True时,表示按照从右往左的顺序进行推导,这个推导的技巧是把原维度和reshape的参数右侧对齐,从右往左依次推导 比如:
reverse=False的情况(缺省)
>>> from mxnet import nd >>> a = nd.random.uniform(shape=(10, 5, 4)) >>> a.shape (10L, 5L, 4L) >>> b = a.reshape(-1, 0) >>> b.shape (40L, 5L)
reverse=True的情况下:
>>> from mxnet import nd >>> a = nd.random.uniform(shape=(10, 5, 4)) >>> a.shape (10L, 5L, 4L) >>> b = a.reshape(-1, 0, reverse=True) >>> b.shape (50L, 4L)