Key points of NumPy and torch - dimension operations and broadcasting

Tensor data are higher-dimensional arrays, and operations involving their axes are always hard to grasp. I experimented in a Jupyter Notebook and summarized some key cases here to aid understanding. (The outputs are too long, so mostly only the code is shown; import numpy as np and import torch are all you need to run it.)

1: Take elements from specified rows and columns of a matrix, in a given index order.

arr = np.arange(32).reshape((8, 4))
print("arr:\n", arr)

# This only yields a vector of 4 elements: [1,5,7,2] are the row indices and [0,3,1,2] the column indices, paired element-wise
print(arr[[1,5,7,2], [0, 3, 1, 2]])  # return a vector
# print(arr[[1, 0],[5, 3],[7, 1],[2, 2]])  # error

# get a rectangular area
# take a 2D sub-array using the specified indices

# Wrong example: this just selects rows twice. It first takes the rows [1,5,7,2], then treats [0,3,1,2] as row indices again and reorders the rows of that result.
print("arr[[1,5,7,2,]][[0,3,1,2]]:\n",arr[[1,5,7,2]][[0,3,1,2]])

# Correct example: first take the rows in the specified order, then from that result take the columns in the specified order for every row.
# That is, [:, [0,3,1,2]] supplies the column indices, and : means the whole first dimension (all rows).
print("arr[[1,5,7,2,]][:,[0,3,1,2]]:\n",arr[[1,5,7,2,]][:,[0,3,1,2]])
# A simpler way
print("arr[np.ix_([],[])]:\n", arr[np.ix_([1,5,7,2],[0,3,1,2])])

2 Key operation: swapping axes of high-dimensional data. The classic case is the matrix transpose, which swaps axis 0 and axis 1, but what about higher dimensions (i.e. the tensor data structure)?

# # 2 dim data, matrix
# arr = np.arange(15).reshape((3, 5))
# print("arr:\n", arr)
# print("arr.T\n", arr.T, "\ntranspose(arr):\n", np.transpose(arr))

# 3 dim data
# axis 0: 2 matrices; axis 1: 3 rows per matrix; axis 2: 4 elements per row
arr = np.arange(24).reshape((2,3,4))
print("arr:\n", arr)
# swap axes 0 and 1 --> axis 0: 3 matrices; axis 1: 2 rows per matrix; axis 2: 4 elements per row
# In effect, take the same row from each original matrix to form a 2x4 matrix; doing this row by row yields 3 such matrices in total.
print("arr.transpose((1,0,2)):\n", arr.transpose((1,0,2)))  # swap the order of axes
print("arr.swapaxes(1,2):\n", arr.swapaxes(1,2))  # swap the order of axes

3 Similar operations on torch tensors

(1) reshape on torch tensors

a = torch.randint(-5, 5, (4, 2, 3, 5))
print("a.shape:", a.shape)
print(a.reshape(8, 3, 5).shape)
print(a.reshape(8, 3, 5).ndim)

a = torch.randint(-5, 5, (3, 5))
print("a.shape:", a.shape)
print("a:\n",a)

# As shown, elements are read row by row (row-major order) and laid out into one row or one column; the flattened result still has ndim 2
a_1_15 = a.reshape(1, 15)
print("a_1_15 ndim:", a_1_15.ndim)
print(a_1_15.shape)
# print(a_1_15)

a_15_1 = a.reshape(15, 1)
print("a_15_1.ndim:",a_15_1.ndim)
print(a_15_1.shape)
# print(a_15_1)

"""
a.shape: torch.Size([4, 2, 3, 5])
torch.Size([8, 3, 5])
3
a.shape: torch.Size([3, 5])
a:
 tensor([[-2, -1, -2, -2, -2],
        [-2, -3, -2, -3, -1],
        [ 2,  3,  4,  2, -2]])
a_1_15 ndim: 2
torch.Size([1, 15])
a_15_1.ndim: 2
torch.Size([15, 1])
"""

(2) view on torch tensors behaves almost exactly like reshape; the difference is that view never copies and therefore needs contiguous memory, while reshape falls back to copying when a view is impossible.
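
A minimal sketch of that difference (the tensor t below is my own example; the transpose makes it non-contiguous):

# Sketch: view requires contiguous memory, reshape copies when needed
t = torch.arange(6).reshape(2, 3).t()  # transpose -> non-contiguous
print(t.is_contiguous())               # False
print(t.reshape(6))                    # works (makes a copy here)
# t.view(6)                            # would raise a RuntimeError on this non-contiguous tensor
print(t.contiguous().view(6))          # works after making the memory contiguous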

(3) squeeze and unsqueeze: the latter inserts a new axis at the specified position (which must lie within the valid dimension range), the former removes the specified axis (only if that axis has length 1).

# torch.unsqueeze
# Insert a new dimension, similar to newaxis / reshape in NumPy
a = torch.randint(-5, 5, (4, 2, 3, 5))
print("a.ndim:",a.ndim)
print("a.shape:", a.shape)

print(a.shape)
print("a.unsqueeze(0).shape:",a.unsqueeze(0).shape)
print("a.unsqueeze(-1).shape:",a.unsqueeze(-1).shape)
print("a.unsqueeze(3).shape:",a.unsqueeze(3).shape)
print("a.unsqueeze(4).shape:", a.unsqueeze(4).shape)
print("a.unsqueeze(-4).shape:",a.unsqueeze(-4).shape)
print("a.unsqueeze(-5).shape:",a.unsqueeze(-5).shape)
print("a.unsqueeze(5).shape:")
print(a.unsqueeze(5).shape)
"""
a.ndim: 4
a.shape: torch.Size([4, 2, 3, 5])
torch.Size([4, 2, 3, 5])
a.unsqueeze(0).shape: torch.Size([1, 4, 2, 3, 5])
a.unsqueeze(-1).shape: torch.Size([4, 2, 3, 5, 1])
a.unsqueeze(3).shape: torch.Size([4, 2, 3, 1, 5])
a.unsqueeze(4).shape: torch.Size([4, 2, 3, 5, 1])
a.unsqueeze(-4).shape: torch.Size([4, 1, 2, 3, 5])
a.unsqueeze(-5).shape: torch.Size([1, 4, 2, 3, 5])
a.unsqueeze(5).shape:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[45], line 15
     13 print("a.unsqueeze(-5).shape:",a.unsqueeze(-5).shape)
     14 print("a.unsqueeze(5).shape:")
---> 15 print(a.unsqueeze(5).shape)

IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
"""

# Remove redundant dimensions, i.e. dimensions whose length is 1
# torch.Tensor(...) creates an uninitialized tensor of the given shape (garbage values)
a = torch.Tensor(1,4,1,3,1,9)
print(a.shape)
# print("a\n", a)
print(a.squeeze().shape) # remove all dimensions of size 1
print(a.squeeze(0).shape) # remove dimension 0: works
print(a.squeeze(2).shape) # remove dimension 2: works
print(a.squeeze(3).shape) # dimension 3 has size 3, not 1, so nothing is removed
"""
torch.Size([1, 4, 1, 3, 1, 9])
torch.Size([4, 3, 9])
torch.Size([4, 1, 3, 1, 9])
torch.Size([1, 4, 3, 1, 9])
torch.Size([1, 4, 1, 3, 1, 9])
"""

(4) expand

a = torch.rand(5)
a = a.unsqueeze(0)
print(a)
print(a.expand(3, -1))
"""
tensor([[0.5667, 0.6940, 0.1985, 0.3665, 0.4311]])
tensor([[0.5667, 0.6940, 0.1985, 0.3665, 0.4311],
        [0.5667, 0.6940, 0.1985, 0.3665, 0.4311],
        [0.5667, 0.6940, 0.1985, 0.3665, 0.4311]])
"""

a = torch.rand(5)
a1 = a.unsqueeze(1)
print(a1)
print(a1.expand(-1, 3))
"""
tensor([[0.0968],
        [0.6465],
        [0.3258],
        [0.3345],
        [0.3632]])
tensor([[0.0968, 0.0968, 0.0968],
        [0.6465, 0.6465, 0.6465],
        [0.3258, 0.3258, 0.3258],
        [0.3345, 0.3345, 0.3345],
        [0.3632, 0.3632, 0.3632]])
"""

bias = torch.rand(32)
data = torch.rand(4, 32, 14, 14)

# We want to add bias onto data
# First insert the extra dimensions
bias = bias.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(bias.shape)

# Then expand along those dimensions
bias = bias.expand(4, -1, 14, 14)  # -1 means keep this dimension as is; writing 32 here works too
print(bias.shape)
"""
torch.Size([1, 32, 1, 1])
torch.Size([4, 32, 14, 14])
"""

(5) transpose and permute

e1 = torch.rand(2,3,5)
print("e1:\n",e1)
# Keep axis 0 unchanged and swap axes 1 and 2, i.e. transpose each matrix
e2 = e1.transpose(1,2).contiguous()
print("e2:\n",e2)

# batch channel height width
e1 = torch.rand(5, 3, 14, 14)
print("e1.shape:\n",e1.shape)
# Keep axis 0 (batch) and move the channel axis to the end: NCHW -> NHWC
e2 = e1.permute(0, 2, 3, 1)
print("e2:\n",e2.shape)

4 Broadcasting and axes

The broadcasting condition, stated briefly: align the two shapes starting from the trailing (rightmost) dimension; each pair of dimensions is compatible if the sizes are equal or one of them is 1 (a missing leading dimension is treated as 1). There is no divisibility requirement; a size-1 dimension is simply stretched to match the other.
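
A few examples of the rule (the shapes are my own, for illustration only):

# Sketch: which shape pairs broadcast
print((np.ones((3, 1)) + np.ones((1, 4))).shape)  # (3, 4): size-1 dims are stretched on both sides
print((np.ones((4, 3)) + np.ones(3)).shape)       # (4, 3): (3,) aligns with the trailing dimension
# np.ones((4, 3)) + np.ones(4)                    # error: trailing dims 3 and 4 do not match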

# ### very important, crucial, vital, momentous, critical !!
# Observe the principles of broadcasting mechanisms

# Broadcasting with 2D arrays
arr = np.random.randn(4,3)
print("arr:\n", arr)
print("col mean, arr.mean(0):", arr.mean(0))
print("row mean, arr.mean(1):", arr.mean(1))
# The result of mean(1) must be reshaped to (4,1) so that it aligns with the original shape (4,3).
# Otherwise mean(1) is a length-4 vector (shape (4,)); against a 4x3 matrix the trailing dimensions (4 vs 3) do not match, so broadcasting fails.
print("arr - arr.mean(1).reshape((4,1))",arr - arr.mean(1).reshape((4,1)))

# mean(0) has shape (3,), which naturally satisfies the broadcast condition against (4,3)
demeaned = arr - arr.mean(0)
print("arr - arr.mean(0):\n", demeaned)
# print("demeaned.mean(0):", demeaned.mean(0))

# # 3 dim data: broadcasting with 3D arrays
# arr = np.random.randint(-5, 6, (3, 4, 2))
# # print("arr:\n",arr)
# print("arr.mean(1):\n", arr.mean(1))
#  (3,1,2) broadcasts over (3,4,2), which satisfies the broadcast condition
# print("arr - arr.mean(1).reshape((3,1,2))\n")
# print(arr - arr.mean(1).reshape((3,1,2)))

# print("arr.mean(1):\n", arr.mean(1))

# The problem with reshape is that it cannot handle arbitrary cases, because you must hand-build a shape tuple to pass to reshape.
# The ideal approach is simply to insert an axis, without knowing the concrete size of every dimension, so the operation can be automated.
# Key point: inserting new axes so that broadcasting applies (see the sketch after this block).
# arr2d = np.random.randint(-5 ,6, (3, 2))
# arr_3d = arr2d[:, np.newaxis, :]
# print("arr_3d shape:", arr_3d.shape)
# print(arr_3d)

# arr_1d = np.random.normal(size=3)
# print("arr_1d:", arr_1d)
# arr_2d = arr_1d[:,np.newaxis]
# print(arr_2d)

# arr = np.zeros((4,3))
# col = np.array([3,5,7,5])
# arr[:] = col[:, np.newaxis]
# print("arr:\n", arr)
# arr[:2] = np.array([[-2], [-3]])
# print(arr)
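
Here is a minimal sketch of that idea (demean_axis is my own hypothetical helper, not from the original code): re-insert the reduced axis with np.expand_dims instead of hand-building a reshape tuple.

# Sketch: demean along any axis without knowing the concrete shape
def demean_axis(x, axis=0):
    means = x.mean(axis)                    # reduced array, the given axis is gone
    return x - np.expand_dims(means, axis)  # put the axis back with size 1, then broadcast

arr3d = np.random.randint(-5, 6, (3, 4, 2))
print(demean_axis(arr3d, axis=1).mean(1))   # all (close to) zero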

5 Common operations: reshape and flatten/ravel. A vector can be reshaped into a tensor of another shape, and a high-dimensional tensor can be flattened/raveled into a 1D vector.

arr = np.arange(15)
print("arr.reshape((3, 5))")
arr1 = arr.reshape((3, 5))
print("arr1, arr.reshape((3, 5)):\n", arr1)
arr2 = arr1.ravel()
print("arr2, arr1.ravel():\n", arr2)
arr3 = arr1.flatten()
print("arr3, after flatten():\n", arr3)

6 Common operations: concatenate and split, and their relatives (hstack, vstack, np.c_, np.r_).

# ### very important, crucial, vital, momentous, critical !!
arr1 = np.array([[1,2,3],[4,5,6]])
arr2 = np.array([[7,8,9],[10,11,12]])

print("concatenate rows, align the column(axis=0)")
print("np.concatenate([arr1, arr2], axis=0)\n",np.concatenate([arr1, arr2], axis=0))
print("np.vstack((arr1, arr2))\n",np.vstack((arr1, arr2)))
print("np.r_[arr1, arr2]:\n", np.r_[arr1, arr2])

print("concatenate columns, align the row   (axis=1)")
print("np.concatenate([arr1, arr2], axis=1)\n",np.concatenate([arr1, arr2], axis=1))
arr3 = np.hstack((arr1, arr2))
print("np.hstack((arr1, arr2))\n",np.hstack((arr1, arr2)))
print("np.c_[arr1, arr2]:\n", np.c_[arr1, arr2])

f, s, t = np.split(arr3, [1, 3], axis=1)  # split columns at indices 1 and 3: pieces are columns [0:1], [1:3], [3:]
print("f:\n{0}\ns:\n{1}\nt:\n{2}".format(f, s, t))

 

posted @ 2023-08-03 00:03  倦鸟已归时