PyTorch中的dim

PyTorch中对tensor的很多操作如sum,softmax等都可以设置dim参数用来指定操作在哪一维进行。PyTorch中的dim类似于numpy中的axis。

dim与方括号的关系

创建一个矩阵

a = torch.tensor([[1, 2], [3, 4]])
print(a)

输出:

tensor([[1, 2],
        [3, 4]])

因为a是一个矩阵,所以a的左边有2个括号

括号之间是嵌套关系,代表了不同的维度。从左往右数,两个括号代表的维度分别是0和1,在第0维遍历得到向量,在

第1维遍历得到标量

同样地,对于3维tensor

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)

输出

tensor([[[3, 2],
         [1, 4]],

        [[5, 6],
         [7, 8]]])

则3个括号代表的维度从左往右分别为0,1,2,在第0维遍历得到矩阵,在第1维遍历得到向量,在第2维遍历得到标量

更详细一点

在指定的维度上进行操作

在某一维度求和(或者进行其他操作)就是对该维度中的元素进行求和。

对于矩阵a

a = torch.tensor([[1, 2], [3, 4]])
print(a)

输出

tensor([[1, 2],
        [3, 4]])

求a在第0维的和,因为第0维代表最外边的括号,括号中的元素为向量 [1,2] , [3,4],第0维的和就是第0维中的元素相加,也就是两个向量 [1,2] , [3,4] 相加,所以结果为

[1 , 2 ] + [3 , 4 ] = [4 , 6]

s = torch.sum(a, dim=0)
print(s)

输出

tensor([4, 6])

可以看到,a是2维矩阵,而相加的结果为1维向量,可以使用参数keepdim = True来保证维度数目不变。

s = torch.sum(a, dim=0, keepdim=True)
print(s)

输出

tensor([[4, 6]])

在a的第0维求和,就是对第0维中的元素(向量)进行相加。同样的,对a第1维求和,就是对a第1维中的元素(标量)进行相加,a的第1维元素为标量1,2和3,4,则结果为

[1 + 2 ] = [3] ,[ 3 + 4 ] = [7]

s = torch.sum(a, dim=1)
print(s)

输出

tensor([3, 7])

保持维度不变

s = torch.sum(a, dim=1, keepdim=True)
print(s)

输出

tensor([[3],
        [7]])

对3维tensor的操作也是这样

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)

输出

tensor([[[3, 2],
         [1, 4]],

        [[5, 6],
         [7, 8]]])

将b在第0维相加,第0维为最外层括号,最外层括号中的元素为矩阵[ [3 , 2], [1 , 4] ]和[ [5, 6] ,[7, 8] ]。在第0维求和,

就是将第0维的元素(矩阵)相加

s = torch.sum(b, dim=0)
print(s)

输出

tensor([[ 8,  8],
        [ 8, 12]])

求b在第1维的和,就是将b第1维中的元素[ 3, 2] 和[ 1 , 4 ],[ 5 , 6]和 [7 , 8 ]相加,所以

s = torch.sum(b, dim=1)
print(s)

输出

tensor([[ 4,  6],
        [12, 14]])

则在b的第2维求和,就是对标量3和2,1和4,5和6,7和8求和

s = torch.sum(b, dim=2)
print(s)

结果为

tensor([[ 5,  5],
        [11, 15]])

除了求和,其他操作也是类似的,如求b在指定维度上的最大值

m = torch.max(b, dim=0)
print(m)

b在第0维的最大值是第0维中的元素(两个矩阵[[3,2],[1,4]]和[[5,6],[7,8]])的最大值,取矩阵对应位置最大值即可

结果为

torch.return_types.max(
values=tensor([[5, 6],
        [7, 8]]),
indices=tensor([[1, 1],
        [1, 1]]))

b在第1维的最大值就是第1维元素(4个(2对)向量)的最大值

m = torch.max(b, dim=1)
print(m)

输出为

torch.return_types.max(
values=tensor([[3, 4],
        [7, 8]]),
indices=tensor([[0, 1],
        [1, 1]]))

b在第0维的最大值就是第0维元素(8个(4对)标量)的最大值

m = torch.max(b, dim=2)
print(m)

输出

torch.return_types.max(
values=tensor([[3, 4],
        [6, 8]]),
indices=tensor([[0, 1],
        [1, 1]]))

总结

在tensor的指定维度操作就是对指定维度包含的元素进行操作,如果想要保持结果的维度不变,设置参数keepdim = True即可。

 

原文链接:https://www.cnblogs.com/flix/p/11262606.html

posted @ 2023-03-05 22:21  sqsq  阅读(103)  评论(0编辑  收藏  举报