einsum详解
参考 http://www.elecfans.com/d/779631.html
https://blog.csdn.net/ashome123/article/details/117110042
import torch
from torch import einsum
a = torch.arange(3) # [0, 1, 2]
b = torch.arange(3, 6) # [3, 4, 5]
c = torch.arange(0, 6).reshape(2, 3) # [[0, 1, 2], [3, 4, 5]]
d = torch.arange(0, 6).reshape(3, 2) # [[0, 1], [2, 3], [4, 5]]
# 转置
# print(einsum('ij->ji', c)) # tensor([0, 1, 2])
# 求和
# print(einsum('i', a)) # tensor([0, 1, 2])
# print(einsum('ij->i', c)) # 按行tensor([ 3, 12])
# print(einsum('ij->j', c)) # 按列tensor([ 3, 5, 7])
# 矩阵与向量乘法
# 参数加不加[ ] 好像无所谓
# print(torch.einsum('ik,k->i', c, a)) # tensor([ 5, 14]) 按行乘以一行
# print(torch.einsum('ik,i->k', [d, a])) # tensor([ 10, 13]) 按列乘以一行
# 矩阵与矩阵乘法
# print(torch.einsum('ik,kj->ij', [c, d])) # tensor([[10, 13],[28, 40]])
# 点积
# print(einsum('i,i->', [a, b])) # tensor(14)
# print(torch.einsum('ij,ij->', [c, c])) # tensor(55)
# print(torch.einsum('ij,ij-> ij', [c, c])) # tensor([[ 0, 1, 4], [ 9, 16, 25]])
# 外积
# tensor([[ 0, 0, 0],
# [ 3, 4, 5],
# [ 6, 8, 10]])
# print(torch.einsum('i,j->ij', [a, b])) # tensor(55)
# batch矩阵相乘
# x = torch.arange(12).reshape(2,2,3)
# xx = torch.arange(12).reshape(2,3,2)
# print(x)
# print(xx)
# print(torch.einsum('ijk,ikl->ijl', x, xx))
# tensor([[[ 0, 1, 2],
# [ 3, 4, 5]],
# [[ 6, 7, 8],
# [ 9, 10, 11]]])
#
# tensor([[[ 0, 1],
# [ 2, 3],
# [ 4, 5]],
# [[ 6, 7],
# [ 8, 9],
# [10, 11]]])
#
# tensor([[[ 10, 13],
# [ 28, 40]],
# [[172, 193],
# [244, 274]]])
# 张量缩约
# aa = torch.randn(2,3,5,7)
# bb = torch.randn(11,13,3,17,5)
# print(torch.einsum('pqrs,tuqvr->pstuv', [aa, bb]).shape)
# torch.Size([2, 7, 11, 13, 17])
# 双线性变换
# a = torch.randint(1,5, size=(2, 3))
# b = torch.randint(1,5, size=(4, 3, 6))
# c = torch.randint(1,5, size=(2, 6))
# print(a)
# print(b)
# print(c)
# print(torch.einsum('ik,jkl,il->ij', [a, b, c]))
# tensor([[4, 4, 4],
# [2, 2, 4]])
# tensor([[[3, 3, 2, 2, 2, 4],
# [1, 2, 4, 3, 1, 1],
# [4, 3, 3, 3, 1, 1]],
# [[1, 2, 2, 2, 2, 3],
# [3, 2, 3, 3, 1, 1],
# [1, 3, 3, 3, 3, 3]],
# [[1, 3, 4, 3, 1, 2],
# [4, 4, 4, 1, 3, 4],
# [4, 4, 3, 3, 1, 3]],
# [[2, 1, 1, 4, 1, 2],
# [4, 1, 1, 2, 4, 3],
# [2, 3, 4, 2, 3, 3]]])
# tensor([[2, 2, 1, 3, 2, 4],
# [3, 2, 1, 2, 1, 2]])
# tensor([[388, 384, 472, 416],
# [222, 200, 266, 218]])
种一棵树最好的时间是十年前,其次是现在。