PyTorch 中的 tensor.mean(axis, keepdim)
import numpy as np  # not used by this demo; left in place
import torch


def show_mean(t, axis, keepdim, name="x"):
    """Print the shape and value of ``t.mean(axis, keepdim=keepdim)``.

    The header line has the same format as the hand-written demo, e.g.
    ``shape of x.mean(axis=0,keepdim=True):``.

    Args:
        t: Floating-point tensor to reduce. ``torch.mean`` requires a
            floating-point (or complex) input, which is why ``x`` below is
            converted with ``.float()``.
        axis: Dimension to average over. It is passed as ``dim``, PyTorch's
            canonical keyword; ``axis`` is a NumPy-compatible alias.
        keepdim: If True, the reduced dimension is kept with size 1. If False,
            it is removed, so the result has one fewer dimension.
        name: Label used for the tensor in the printed header.

    Returns:
        The reduced tensor.
    """
    # Compute the reduction once and reuse it for both the shape and the value.
    reduced = t.mean(dim=axis, keepdim=keepdim)
    print(f"shape of {name}.mean(axis={axis},keepdim={keepdim}):")
    print(reduced.shape)
    print(reduced)
    return reduced


# A 2 x 3 x 4 tensor holding the values 1..24.
x = torch.tensor(
    [
        [[1, 2, 3, 4],
         [5, 6, 7, 8],
         [9, 10, 11, 12]],

        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]],
    ]
).float()

print("shape of x:")  # [2, 3, 4]
print(x.shape)

# Reduce along every axis: first keep the reduced dimension (size 1), then
# drop it. Expected shapes:
#   axis=0: [1, 3, 4] / [3, 4]
#   axis=1: [2, 1, 4] / [2, 4]
#   axis=2: [2, 3, 1] / [2, 3]
for axis in range(x.dim()):
    for keepdim in (True, False):
        show_mean(x, axis, keepdim)
shape of x: torch.Size([2, 3, 4]) shape of x.mean(axis=0,keepdim=True): torch.Size([1, 3, 4]) tensor([[[ 7., 8., 9., 10.], [11., 12., 13., 14.], [15., 16., 17., 18.]]]) shape of x.mean(axis=0,keepdim=False): torch.Size([3, 4]) tensor([[ 7., 8., 9., 10.], [11., 12., 13., 14.], [15., 16., 17., 18.]]) shape of x.mean(axis=1,keepdim=True): torch.Size([2, 1, 4]) tensor([[[ 5., 6., 7., 8.]], [[17., 18., 19., 20.]]]) shape of x.mean(axis=1,keepdim=False): torch.Size([2, 4]) tensor([[ 5., 6., 7., 8.], [17., 18., 19., 20.]]) shape of x.mean(axis=2,keepdim=True): torch.Size([2, 3, 1]) tensor([[[ 2.5000], [ 6.5000], [10.5000]], [[14.5000], [18.5000], [22.5000]]]) shape of x.mean(axis=2,keepdim=False): torch.Size([2, 3]) tensor([[ 2.5000, 6.5000, 10.5000], [14.5000, 18.5000, 22.5000]])
keepdim=True
运算后张量的维数与原来相同：原来是三维数组，现在仍是三维数组（只是被求平均的那一维大小变成了 1）；
keepdim=False
运算后张量少一维：被求平均（大小变为 1）的那一维被直接去掉；
axis=k
沿第 k 维求平均，其他维度不变；第 k 维在 keepdim=True 时大小变为 1，在 keepdim=False 时被去掉。
# Averaging over all elements (no axis given) reduces x to a 0-dim tensor;
# the output below shows torch.Size([]) and tensor(12.5000).
# print(x.mean().shape)
# print(x.mean())
shape of x:
torch.Size([2, 3, 4])
torch.Size([])
tensor(12.5000)  # 所有元素的平均值