pytorch中tensor.mean(axis, keepdim)
1 import numpy as np 2 import torch 3 4 x=[ 5 [[1,2,3,4], 6 [5,6,7,8], 7 [9,10,11,12]], 8 9 [[13,14,15,16], 10 [17,18,19,20], 11 [21,22,23,24]] 12 ] 13 x=torch.tensor(x).float() 14 # 15 print("shape of x:") ##[2,3,4] 16 print(x.shape) 17 # 18 print("shape of x.mean(axis=0,keepdim=True):") #[1, 3, 4] 19 print(x.mean(axis=0,keepdim=True).shape) 20 print(x.mean(axis=0,keepdim=True)) 21 # 22 print("shape of x.mean(axis=0,keepdim=False):") #[3, 4] 23 print(x.mean(axis=0,keepdim=False).shape) 24 print(x.mean(axis=0,keepdim=False)) 25 # 26 print("shape of x.mean(axis=1,keepdim=True):") #[2, 1, 4] 27 print(x.mean(axis=1,keepdim=True).shape) 28 print(x.mean(axis=1,keepdim=True)) 29 # 30 print("shape of x.mean(axis=1,keepdim=False):") #[2, 4] 31 print(x.mean(axis=1,keepdim=False).shape) 32 print(x.mean(axis=1,keepdim=False)) 33 # 34 print("shape of x.mean(axis=2,keepdim=True):") #[2, 3, 1] 35 print(x.mean(axis=2,keepdim=True).shape) 36 print(x.mean(axis=2,keepdim=True)) 37 # 38 print("shape of x.mean(axis=2,keepdim=False):") #[2, 3] 39 print(x.mean(axis=2,keepdim=False).shape) 40 print(x.mean(axis=2,keepdim=False))
shape of x: torch.Size([2, 3, 4]) shape of x.mean(axis=0,keepdim=True): torch.Size([1, 3, 4]) tensor([[[ 7., 8., 9., 10.], [11., 12., 13., 14.], [15., 16., 17., 18.]]]) shape of x.mean(axis=0,keepdim=False): torch.Size([3, 4]) tensor([[ 7., 8., 9., 10.], [11., 12., 13., 14.], [15., 16., 17., 18.]]) shape of x.mean(axis=1,keepdim=True): torch.Size([2, 1, 4]) tensor([[[ 5., 6., 7., 8.]], [[17., 18., 19., 20.]]]) shape of x.mean(axis=1,keepdim=False): torch.Size([2, 4]) tensor([[ 5., 6., 7., 8.], [17., 18., 19., 20.]]) shape of x.mean(axis=2,keepdim=True): torch.Size([2, 3, 1]) tensor([[[ 2.5000], [ 6.5000], [10.5000]], [[14.5000], [18.5000], [22.5000]]]) shape of x.mean(axis=2,keepdim=False): torch.Size([2, 3]) tensor([[ 2.5000, 6.5000, 10.5000], [14.5000, 18.5000, 22.5000]])
keepdim=True
运算完之后的维度和原来一样,原来是三维数组现在还是三维数组(不过某一维度变成了1);
keepdim=False
运算完之后一般少一维度,求平均变为1的那一维没有了;
axis=k
按第k维运算,其他维度不遍,第k维变为1
# print(x.mean().shape)
# print(x.mean())
shape of x:
torch.Size([2, 3, 4])
torch.Size([])
tensor(12.5000)#所有值的平均值
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 25岁的心里话
· 按钮权限的设计及实现