pytorch多维张量相乘和广播机制示例

示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]],
 
        [[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32)
 
wh = torch.tensor([[[200.],
         [400.],
         [200.],
         [400.]],
 
        [[200.],
         [400.],
         [200.],
         [400.]]]).to(torch.float32)
 
print(box.shape)  # (2, 3 ,4)
print(wh.shape)  # (2, 4, 1)
 
result = box @ wh
print(result.shape)  # (2, 3, 1)
print(result)
# tensor([[[320.],
#          [900.],
#          [180.]],
 
#         [[320.],
#          [900.],
#          [180.]]])

  

下面这个示例用到了广播机制:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]],
 
        [[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32)
 
wh = torch.tensor([[[200.],
         [400.],
         [200.],
         [400.]]]).to(torch.float32)
 
print(box.shape)  # (2, 3 ,4)
print(wh.shape)  # (1, 4, 1)   注意这个wh的第0维度的大小是1
 
result = box @ wh  # 这里在第0维度会使用广播机制
print(result.shape)  # (2, 3, 1)
print(result)
# tensor([[[320.],
#          [900.],
#          [180.]],
 
#         [[320.],
#          [900.],
#          [180.]]])

  

 

posted @   Picassooo  阅读(46)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
历史上的今天:
2023-10-26 python读取和写入txt等文件,文件打开模式,文件对象常用函数
点击右上角即可分享
微信分享提示