浅谈Pytorch广播机制
广播机制 从后往前依次检查维度,如果两个张量对应的维度上数目相等,则会按照该维度相加
若其中一个维度数目为1,则会应用广播机制.
如:
a = torch.arange(3).reshape([1, 3, 1]) b = torch.arange(3).reshape([1, 3, 1]) a + b #维度均相等
输出为:
tensor([[[0],
[2],
[4]]])
维度不相等_1:
a = torch.arange(3).reshape([1, 3, 1]) b = torch.arange(6).reshape([1, 3, 2]) a + b #维度不相等_1
理解为:
a是tensor([[[0],[1],[2]]]) 增加维度,广播为: tensor([[0, 0], [1, 1], [2, 2]]) b为tensor([[[0, 1], [2, 3], [4, 5]]])
则 a + b = tensor([[[0, 1], [3, 4], [6, 7]]])
同理,可以获得如下结果:
a = torch.arange(3).reshape([1, 3, 1]) b = torch.arange(6).reshape([1, 3, 2]) a + b >>>tensor([[[0, 1], [3, 4], [6, 7]]])
a = torch.arange(3).reshape([1, 3, 1]) b = torch.arange(2).reshape([1, 1, 2]) a + b >>>tensor([[[0, 1], [1, 2], [2, 3]]])
a = torch.arange(3).reshape([1, 3, 1]) b = torch.arange(6).reshape([3, 1, 2]) a + b >>>tensor([[[0, 1], [1, 2], [2, 3]], [[2, 3], [3, 4], [4, 5]], [[4, 5], [5, 6], [6, 7]]])

浙公网安备 33010602011771号