pytorch中tensor.mean(axis, keepdim)

复制代码
 1 import numpy as np
 2 import torch
 3 
 4 x=[
 5 [[1,2,3,4],
 6  [5,6,7,8],
 7  [9,10,11,12]],
 8 
 9 [[13,14,15,16],
10  [17,18,19,20],
11  [21,22,23,24]]
12 ]
13 x=torch.tensor(x).float()
14 #
15 print("shape of x:")  ##[2,3,4]
16 print(x.shape)
17 #
18 print("shape of x.mean(axis=0,keepdim=True):")          #[1, 3, 4]
19 print(x.mean(axis=0,keepdim=True).shape)
20 print(x.mean(axis=0,keepdim=True))
21 #
22 print("shape of x.mean(axis=0,keepdim=False):")         #[3, 4]
23 print(x.mean(axis=0,keepdim=False).shape)
24 print(x.mean(axis=0,keepdim=False))
25 #
26 print("shape of x.mean(axis=1,keepdim=True):")          #[2, 1, 4]
27 print(x.mean(axis=1,keepdim=True).shape)
28 print(x.mean(axis=1,keepdim=True))
29 #
30 print("shape of x.mean(axis=1,keepdim=False):")         #[2, 4]
31 print(x.mean(axis=1,keepdim=False).shape)
32 print(x.mean(axis=1,keepdim=False))
33 #
34 print("shape of x.mean(axis=2,keepdim=True):")          #[2, 3, 1]
35 print(x.mean(axis=2,keepdim=True).shape)
36 print(x.mean(axis=2,keepdim=True))
37 #
38 print("shape of x.mean(axis=2,keepdim=False):")         #[2, 3]
39 print(x.mean(axis=2,keepdim=False).shape)
40 print(x.mean(axis=2,keepdim=False))
复制代码

 

复制代码
shape of x:
torch.Size([2, 3, 4])
shape of x.mean(axis=0,keepdim=True):
torch.Size([1, 3, 4])
tensor([[[ 7.,  8.,  9., 10.],
         [11., 12., 13., 14.],
         [15., 16., 17., 18.]]])
shape of x.mean(axis=0,keepdim=False):
torch.Size([3, 4])
tensor([[ 7.,  8.,  9., 10.],
        [11., 12., 13., 14.],
        [15., 16., 17., 18.]])
shape of x.mean(axis=1,keepdim=True):
torch.Size([2, 1, 4])
tensor([[[ 5.,  6.,  7.,  8.]],

        [[17., 18., 19., 20.]]])
shape of x.mean(axis=1,keepdim=False):
torch.Size([2, 4])
tensor([[ 5.,  6.,  7.,  8.],
        [17., 18., 19., 20.]])
shape of x.mean(axis=2,keepdim=True):
torch.Size([2, 3, 1])
tensor([[[ 2.5000],
         [ 6.5000],
         [10.5000]],

        [[14.5000],
         [18.5000],
         [22.5000]]])
shape of x.mean(axis=2,keepdim=False):
torch.Size([2, 3])
tensor([[ 2.5000,  6.5000, 10.5000],
        [14.5000, 18.5000, 22.5000]])
复制代码

 

keepdim=True
运算完之后的维度和原来一样,原来是三维数组现在还是三维数组(不过某一维度变成了1);

keepdim=False
运算完之后一般少一维度,求平均变为1的那一维没有了;

axis=k
按第k维运算,其他维度不遍,第k维变为1

# print(x.mean().shape)
# print(x.mean())

shape of x:
torch.Size([2, 3, 4])
torch.Size([])
tensor(12.5000)#所有值的平均值

posted on   cltt  阅读(2380)  评论(0编辑  收藏  举报

编辑推荐:
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 25岁的心里话
· 按钮权限的设计及实现
< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5

导航

统计

点击右上角即可分享
微信分享提示