PyTorch 中 tensor.mean(axis, keepdim) 的用法

 1 import numpy as np
 2 import torch
 3 
 4 x=[
 5 [[1,2,3,4],
 6  [5,6,7,8],
 7  [9,10,11,12]],
 8 
 9 [[13,14,15,16],
10  [17,18,19,20],
11  [21,22,23,24]]
12 ]
13 x=torch.tensor(x).float()
14 #
15 print("shape of x:")  ##[2,3,4]
16 print(x.shape)
17 #
18 print("shape of x.mean(axis=0,keepdim=True):")          #[1, 3, 4]
19 print(x.mean(axis=0,keepdim=True).shape)
20 print(x.mean(axis=0,keepdim=True))
21 #
22 print("shape of x.mean(axis=0,keepdim=False):")         #[3, 4]
23 print(x.mean(axis=0,keepdim=False).shape)
24 print(x.mean(axis=0,keepdim=False))
25 #
26 print("shape of x.mean(axis=1,keepdim=True):")          #[2, 1, 4]
27 print(x.mean(axis=1,keepdim=True).shape)
28 print(x.mean(axis=1,keepdim=True))
29 #
30 print("shape of x.mean(axis=1,keepdim=False):")         #[2, 4]
31 print(x.mean(axis=1,keepdim=False).shape)
32 print(x.mean(axis=1,keepdim=False))
33 #
34 print("shape of x.mean(axis=2,keepdim=True):")          #[2, 3, 1]
35 print(x.mean(axis=2,keepdim=True).shape)
36 print(x.mean(axis=2,keepdim=True))
37 #
38 print("shape of x.mean(axis=2,keepdim=False):")         #[2, 3]
39 print(x.mean(axis=2,keepdim=False).shape)
40 print(x.mean(axis=2,keepdim=False))

 

shape of x:
torch.Size([2, 3, 4])
shape of x.mean(axis=0,keepdim=True):
torch.Size([1, 3, 4])
tensor([[[ 7.,  8.,  9., 10.],
         [11., 12., 13., 14.],
         [15., 16., 17., 18.]]])
shape of x.mean(axis=0,keepdim=False):
torch.Size([3, 4])
tensor([[ 7.,  8.,  9., 10.],
        [11., 12., 13., 14.],
        [15., 16., 17., 18.]])
shape of x.mean(axis=1,keepdim=True):
torch.Size([2, 1, 4])
tensor([[[ 5.,  6.,  7.,  8.]],

        [[17., 18., 19., 20.]]])
shape of x.mean(axis=1,keepdim=False):
torch.Size([2, 4])
tensor([[ 5.,  6.,  7.,  8.],
        [17., 18., 19., 20.]])
shape of x.mean(axis=2,keepdim=True):
torch.Size([2, 3, 1])
tensor([[[ 2.5000],
         [ 6.5000],
         [10.5000]],

        [[14.5000],
         [18.5000],
         [22.5000]]])
shape of x.mean(axis=2,keepdim=False):
torch.Size([2, 3])
tensor([[ 2.5000,  6.5000, 10.5000],
        [14.5000, 18.5000, 22.5000]])

 

keepdim=True
运算完之后张量的维数和原来一样：原来是三维张量，结果仍是三维张量（只是被求平均的那一维长度变成了 1）；

keepdim=False
运算完之后少一个维度：被求平均的那一维被直接去掉（不再保留长度为 1 的维度）；

axis=k
沿第 k 维（从 0 开始计数）求平均，其他维度不变；第 k 维在 keepdim=True 时长度变为 1，在 keepdim=False 时被去掉。

# print(x.mean().shape)
# print(x.mean())

shape of x:
torch.Size([2, 3, 4])
torch.Size([])
tensor(12.5000)  # 不指定维度时对所有元素求平均，结果是 0 维张量（shape 为 torch.Size([])）

posted on 2020-09-05 10:55  cltt  阅读(2366)  评论(0)  编辑  收藏  举报

导航