pytorch多维张量相乘和广播机制示例
示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | import torch box = torch.tensor([[[ 0.1000 , 0.2000 , 0.5000 , 0.3000 ], [ 0.6000 , 0.6000 , 0.9000 , 0.9000 ], [ 0.1000 , 0.1000 , 0.2000 , 0.2000 ]], [[ 0.1000 , 0.2000 , 0.5000 , 0.3000 ], [ 0.6000 , 0.6000 , 0.9000 , 0.9000 ], [ 0.1000 , 0.1000 , 0.2000 , 0.2000 ]]]).to(torch.float32) wh = torch.tensor([[[ 200. ], [ 400. ], [ 200. ], [ 400. ]], [[ 200. ], [ 400. ], [ 200. ], [ 400. ]]]).to(torch.float32) print (box.shape) # (2, 3 ,4) print (wh.shape) # (2, 4, 1) result = box @ wh print (result.shape) # (2, 3, 1) print (result) # tensor([[[320.], # [900.], # [180.]], # [[320.], # [900.], # [180.]]]) |
下面这个示例用到了广播机制:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | import torch box = torch.tensor([[[ 0.1000 , 0.2000 , 0.5000 , 0.3000 ], [ 0.6000 , 0.6000 , 0.9000 , 0.9000 ], [ 0.1000 , 0.1000 , 0.2000 , 0.2000 ]], [[ 0.1000 , 0.2000 , 0.5000 , 0.3000 ], [ 0.6000 , 0.6000 , 0.9000 , 0.9000 ], [ 0.1000 , 0.1000 , 0.2000 , 0.2000 ]]]).to(torch.float32) wh = torch.tensor([[[ 200. ], [ 400. ], [ 200. ], [ 400. ]]]).to(torch.float32) print (box.shape) # (2, 3 ,4) print (wh.shape) # (1, 4, 1) 注意这个wh的第0维度的大小是1 result = box @ wh # 这里在第0维度会使用广播机制 print (result.shape) # (2, 3, 1) print (result) # tensor([[[320.], # [900.], # [180.]], # [[320.], # [900.], # [180.]]]) |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
2023-10-26 python读取和写入txt等文件,文件打开模式,文件对象常用函数