torch.sum()

import torch

a = torch.arange(2 * 3).view(2, 3)         # 2x3 tensor holding 0..5
a_sum = torch.sum(a, 0)                    # sum along dim 0 (down each column) -> shape (3,)
b = torch.arange(2 * 3 * 4).view(2, 3, 4)  # 2x3x4 tensor holding 0..23
b_sum = torch.sum(b, (2, 1))  # sum over dims 2 and 1; equivalent to b_sum1 = torch.sum(b, 2), then b_sum = torch.sum(b_sum1, 1)
print(a)
print(a_sum)
print('-----------------------------')
print(b)
print(b_sum)

Output:
a:     ([[0, 1, 2],
         [3, 4, 5]])
a_sum: ([3, 5, 7])
-----------------------------
b:     ([[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],

         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]])
b_sum: ([ 66, 210])
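The sketch below checks the claim that summing over the tuple (2, 1) matches summing one dimension at a time. It reuses the tensor b from the example above. Reducing dim 2 first leaves dims 0 and 1 where they were, so the original dim 1 is still dim 1 after the first step. If dim 1 is reduced first, the original dim 2 shifts down to position 1. The keepdim=True argument (a documented torch.sum parameter) keeps the reduced dimensions as size-1 axes. This is an illustrative check, not taken from the original post.

```python
import torch

b = torch.arange(2 * 3 * 4).view(2, 3, 4)  # same 2x3x4 tensor as above

# One call, reducing over a tuple of dims (their order in the tuple does not matter)
b_sum = torch.sum(b, (2, 1))
b_sum_swapped = torch.sum(b, (1, 2))

# Sequential version: reduce dim 2 first, so the original dim 1 stays at index 1
step1 = torch.sum(b, 2)            # shape (2, 3): per-row sums of the last axis
b_sum_seq = torch.sum(step1, 1)    # shape (2,)

# Reducing dim 1 first: the original dim 2 moves to index 1
b_sum_other = torch.sum(torch.sum(b, 1), 1)

# keepdim=True keeps the reduced dims as size-1 axes instead of dropping them
b_sum_keep = torch.sum(b, (2, 1), keepdim=True)

print(b_sum.tolist())      # [66, 210]
print(step1.tolist())      # [[6, 22, 38], [54, 70, 86]]
print(b_sum_keep.shape)    # torch.Size([2, 1, 1])
```

keepdim=True is handy when the result must broadcast back against the original tensor, for example b / torch.sum(b, (2, 1), keepdim=True).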

Official documentation: https://pytorch.org/docs/stable/generated/torch.sum.html?highlight=torch sum#torch.sum

posted @ 2021-04-26 09:46  id_ning