PyTorch multi-dimensional tensor multiplication and broadcasting examples

Example: for two 3-D tensors, the @ operator (torch.matmul) performs a batched matrix multiplication over the last two dimensions, so a (2, 3, 4) tensor times a (2, 4, 1) tensor gives a (2, 3, 1) result:

import torch 
box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]],

        [[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32)

wh = torch.tensor([[[200.],
         [400.],
         [200.],
         [400.]],

        [[200.],
         [400.],
         [200.],
         [400.]]]).to(torch.float32)

print(box.shape)  # (2, 3, 4)
print(wh.shape)  # (2, 4, 1)

result = box @ wh
print(result.shape)  # (2, 3, 1)
print(result)
# tensor([[[320.],
#          [900.],
#          [180.]],

#         [[320.],
#          [900.],
#          [180.]]])
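
A quick way to check what @ does on 3-D tensors is to compare it with torch.bmm and with an explicit per-batch loop. The sketch below reuses the shapes from the example above but fills the tensors with random values; the variable names and random data are only for illustration:

import torch

box = torch.rand(2, 3, 4)   # batch of 2 matrices, each 3x4
wh = torch.rand(2, 4, 1)    # batch of 2 matrices, each 4x1

# on 3-D tensors, @ (torch.matmul) multiplies the last two dims batch-wise
out_matmul = box @ wh                      # shape (2, 3, 1)

# torch.bmm does the same thing when the batch sizes already match
out_bmm = torch.bmm(box, wh)               # shape (2, 3, 1)

# explicit per-batch loop for comparison
out_loop = torch.stack([box[i] @ wh[i] for i in range(box.shape[0])])

print(torch.allclose(out_matmul, out_bmm))   # True
print(torch.allclose(out_matmul, out_loop))  # True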

  

The following example makes use of the broadcasting mechanism:

import torch 
box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]],

        [[0.1000, 0.2000, 0.5000, 0.3000],
         [0.6000, 0.6000, 0.9000, 0.9000],
         [0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32)

wh = torch.tensor([[[200.],
         [400.],
         [200.],
         [400.]]]).to(torch.float32)

print(box.shape)  # (2, 3, 4)
print(wh.shape)  # (1, 4, 1)   note that dim 0 of wh has size 1

result = box @ wh  # broadcasting is applied along dim 0 here
print(result.shape)  # (2, 3, 1)
print(result)
# tensor([[[320.],
#          [900.],
#          [180.]],

#         [[320.],
#          [900.],
#          [180.]]])
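
The broadcast along dim 0 is equivalent to expanding wh to the batch size of box before taking the product. A minimal sketch, again with random data for illustration; note that torch.bmm would not accept these shapes, because bmm does not broadcast:

import torch

box = torch.rand(2, 3, 4)   # batch size 2
wh = torch.rand(1, 4, 1)    # batch size 1, broadcast along dim 0

# torch.matmul broadcasts the leading (batch) dimensions
out_broadcast = box @ wh                              # shape (2, 3, 1)

# same result by expanding wh explicitly before multiplying
out_expand = box @ wh.expand(box.shape[0], -1, -1)    # wh viewed as (2, 4, 1)

print(torch.allclose(out_broadcast, out_expand))  # True

# torch.bmm(box, wh)  # would raise an error: bmm requires equal batch sizes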

  

 
