torch.sum()
import torch
a = torch.arange(2 * 3).view(2, 3)
a_sum = torch.sum(a, 0)
b = torch.arange(2 * 3 * 4).view(2, 3, 4)
b_sum = torch.sum(b, (2, 1))  # equivalent to tmp = torch.sum(b, 2), followed by b_sum = torch.sum(tmp, 1)
print(a)
print(a_sum)
print('-----------------------------')
print(b)
print(b_sum)
Result:
a: ([[0, 1, 2],
     [3, 4, 5]])
a_sum: ([3, 5, 7])
-----------------------------
b: ([[[ 0,  1,  2,  3],
      [ 4,  5,  6,  7],
      [ 8,  9, 10, 11]],
     [[12, 13, 14, 15],
      [16, 17, 18, 19],
      [20, 21, 22, 23]]])
b_sum: ([ 66, 210])
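As a quick check of the tuple-dim behaviour above, the sketch below compares the one-call reduction with the two-step one. It also shows two things the original example does not cover: the `keepdim` argument of `torch.sum`, and calling `torch.sum` with no `dim` at all. The variable names (`both`, `step`, `kept`, `total`) are only illustrative.

```python
import torch

b = torch.arange(2 * 3 * 4).view(2, 3, 4)

# Reduce over dims 2 and 1 in a single call ...
both = torch.sum(b, (2, 1))
# ... and as two successive single-dim reductions.
step = torch.sum(torch.sum(b, 2), 1)
print(torch.equal(both, step))  # True

# keepdim=True keeps the reduced dims as size-1 axes instead of dropping them.
kept = torch.sum(b, (2, 1), keepdim=True)
print(kept.shape)  # torch.Size([2, 1, 1])

# With no dim given, all elements are summed into a 0-dim tensor.
total = torch.sum(b)
print(total.item())  # 276
```

Keeping the reduced axes can be handy when the result has to broadcast against the original tensor, for example `b / kept`.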
Official documentation: https://pytorch.org/docs/stable/generated/torch.sum.html?highlight=torch sum#torch.sum