np.expand_dims: AxisError: axis 4 is out of bounds for array of dimension 4
np.expand_dims
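For a 2-D array (rows of elements), axis decides where the new pair of brackets goes:
axis = 0: [] is added around the whole array (the outermost level), so shape (rows, cols) becomes (1, rows, cols)
axis = 1: [] is added around each row, so the shape becomes (rows, 1, cols)
axis = 2: [] is added around each element, so the shape becomes (rows, cols, 1)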
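A minimal sketch of the three cases above on a small 2-D array. The array `a` is an illustrative example, not data from the original note; `tolist()` is used so the added brackets are easy to see.

```python
import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6]])  # 2-D, shape (2, 3)

# For a 2-D input the non-negative axes accepted by expand_dims are 0, 1, 2
for ax in range(a.ndim + 1):
    b = np.expand_dims(a, axis=ax)
    print(ax, b.shape, b.tolist())

# 0 (1, 2, 3) [[[1, 2, 3], [4, 5, 6]]]          brackets around the whole array
# 1 (2, 1, 3) [[[1, 2, 3]], [[4, 5, 6]]]        brackets around each row
# 2 (2, 3, 1) [[[1], [2], [3]], [[4], [5], [6]]]  brackets around each element
```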
x_train = np.expand_dims(X, axis=4)
---------------------------------------------------------------------------
AxisError Traceback (most recent call last)
Cell In[5], line 10
8 #X[:, [0, 2], :] = X[:, [2, 0], :]
9 X, Y = shuffle(X, Y, random_state=0)
---> 10 x_train = np.expand_dims(X, axis=4)
11 y_train = Y
13 #calculate class weights
File <__array_function__ internals>:180, in expand_dims(*args, **kwargs)
File /home/software/anaconda3/envs/mydlenv/lib/python3.8/site-packages/numpy/lib/shape_base.py:597, in expand_dims(a, axis)
594 axis = (axis,)
596 out_ndim = len(axis) + a.ndim
--> 597 axis = normalize_axis_tuple(axis, out_ndim)
599 shape_it = iter(a.shape)
600 shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]
File /home/software/anaconda3/envs/mydlenv/lib/python3.8/site-packages/numpy/core/numeric.py:1397, in normalize_axis_tuple(axis, ndim, argname, allow_duplicate)
1395 pass
1396 # Going via an iterator directly is slower than via list comprehension.
-> 1397 axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
1398 if not allow_duplicate and len(set(axis)) != len(axis):
1399 if argname:
File /home/software/anaconda3/envs/mydlenv/lib/python3.8/site-packages/numpy/core/numeric.py:1397, in <listcomp>(.0)
1395 pass
1396 # Going via an iterator directly is slower than via list comprehension.
-> 1397 axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
1398 if not allow_duplicate and len(set(axis)) != len(axis):
1399 if argname:
AxisError: axis 4 is out of bounds for array of dimension 4
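Why this fails: the "dimension 4" in the message is the dimension of the *result*, not of X. The traceback shows `out_ndim = len(axis) + a.ndim`, so with one new axis X itself is 3-D, and the valid values for axis are -4..3. A sketch of the fix, assuming the goal was to append a trailing size-1 dimension; the shape of `X` below is a hypothetical stand-in because the note does not show the real one.

```python
import numpy as np

# Hypothetical stand-in for X: only its 3-D-ness is implied by the error message
X = np.zeros((10, 4, 5), dtype=np.float32)

try:
    np.expand_dims(X, axis=4)  # result would have X.ndim + 1 = 4 axes, so 4 is out of range
except ValueError as err:  # AxisError subclasses ValueError (and IndexError)
    print(type(err).__name__, err)

# Any axis in [-(X.ndim + 1), X.ndim] is valid; -1 appends a trailing dimension
x_train = np.expand_dims(X, axis=-1)  # same result as axis=3
print(x_train.shape)  # (10, 4, 5, 1)
```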
http://www.xavierdupre.fr/app/mlprodict/helpsphinx/onnxops/onnx__Unsqueeze.html
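import numpy as np

x = np.random.randn(3, 4, 5).astype(np.float32)
for i in range(x.ndim):  # i = 0, 1, 2; axis=3 (== x.ndim) would also be valid
    axes = np.array([i]).astype(np.int64)  # ONNX-style "axes" tensor; np.expand_dims does not use it
    y = np.expand_dims(x, axis=i)
    print(i, y.shape)  # (1, 3, 4, 5), then (3, 1, 4, 5), then (3, 4, 1, 5)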
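import numpy as np

x = np.random.randn(3, 4, 5).astype(np.float32)
print(x.ndim)  # 3
# np.expand_dims(x, axis=4) raises AxisError: the result has x.ndim + 1 = 4 axes, so valid values are -4..3
y = np.expand_dims(x, axis=3)  # or axis=-1
print(y.shape)  # (3, 4, 5, 1)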
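The shape rule visible in the traceback (shape_base.py lines 594-600: wrap an int axis in a tuple, compute `out_ndim = len(axis) + a.ndim`, normalize each axis against `out_ndim`, then put a 1 at each new position) can be reproduced in a few lines. `expand_dims_shape` is a hypothetical helper written for illustration, not a NumPy function, and its error wording only imitates NumPy's. Because an int is wrapped into a tuple, a tuple of axes is accepted too, as the last case shows.

```python
import numpy as np


def expand_dims_shape(shape, axis):
    """Hypothetical helper: predict the shape np.expand_dims(a, axis) returns."""
    if isinstance(axis, int):
        axis = (axis,)
    out_ndim = len(axis) + len(shape)  # axes are checked against the OUTPUT ndim
    normalized = []
    for ax in axis:
        if not -out_ndim <= ax < out_ndim:
            raise ValueError(f"axis {ax} is out of bounds for array of dimension {out_ndim}")
        normalized.append(ax % out_ndim)  # turn negative axes into positive ones
    if len(set(normalized)) != len(normalized):
        raise ValueError("repeated axis")
    shape_it = iter(shape)
    # New positions get size 1; every other position takes the next original size
    return tuple(1 if ax in normalized else next(shape_it) for ax in range(out_ndim))


x = np.zeros((3, 4, 5))
for axis in (0, 3, -1, (0, 4)):
    print(axis, expand_dims_shape(x.shape, axis), np.expand_dims(x, axis).shape)
```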
References
https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html