einsum in detail

Reference: http://www.elecfans.com/d/779631.html

https://blog.csdn.net/ashome123/article/details/117110042

import torch
from torch import einsum

a = torch.arange(3)  # [0, 1, 2]
b = torch.arange(3, 6)  # [3, 4, 5]

c = torch.arange(0, 6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
d = torch.arange(0, 6).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]

# Transpose
# print(einsum('ij->ji', c))  # tensor([[0, 3], [1, 4], [2, 5]])

# Sum
# print(einsum('i->', a))  # tensor(3)  (note: einsum('i', a) without '->' just returns a copy of a)

# print(einsum('ij->i', c))  # row sums: tensor([ 3, 12])

# print(einsum('ij->j', c))  # column sums: tensor([3, 5, 7])
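
# For reference, the same results with ordinary tensor ops (these should match the einsum calls above):
# print(c.t())         # transpose, same as 'ij->ji'
# print(a.sum())       # same as 'i->'
# print(c.sum(dim=1))  # row sums, same as 'ij->i'
# print(c.sum(dim=0))  # column sums, same as 'ij->j'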

# Matrix-vector multiplication
# The operands can be passed directly or wrapped in a list [ ]; both forms work
# print(torch.einsum('ik,k->i', c, a))  # tensor([ 5, 14]) each row of c dotted with a

# print(torch.einsum('ik,i->k', [d, a]))  # tensor([10, 13]) each column of d dotted with a
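
# For reference, the same matrix-vector products written with @ (these should match the einsum calls above):
# print(c @ a)      # same as 'ik,k->i'
# print(d.t() @ a)  # same as 'ik,i->k'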

# Matrix-matrix multiplication
# print(torch.einsum('ik,kj->ij', [c, d]))  # tensor([[10, 13],[28, 40]])
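
# For reference, 'ik,kj->ij' is ordinary matrix multiplication (this should match the einsum call above):
# print(torch.mm(c, d))  # tensor([[10, 13], [28, 40]])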

# Dot product
# print(einsum('i,i->', [a, b])) # tensor(14)

# print(torch.einsum('ij,ij->', [c, c])) # tensor(55)

# print(torch.einsum('ij,ij-> ij', [c, c]))  # tensor([[ 0,  1,  4], [ 9, 16, 25]])
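
# For reference, the same results with built-in ops (these should match the three einsum calls above):
# print(torch.dot(a, b))  # same as 'i,i->'
# print((c * c).sum())    # same as 'ij,ij->'
# print(c * c)            # same as 'ij,ij->ij' (element-wise product)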

# Outer product
# print(torch.einsum('i,j->ij', [a, b]))
# tensor([[ 0,  0,  0],
#         [ 3,  4,  5],
#         [ 6,  8, 10]])
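
# For reference, torch.outer (torch.ger in older PyTorch versions) computes the same outer product:
# print(torch.outer(a, b))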

# Batched matrix multiplication
# x = torch.arange(12).reshape(2,2,3)
# xx = torch.arange(12).reshape(2,3,2)
# print(x)
# print(xx)
# print(torch.einsum('ijk,ikl->ijl', x, xx))
# tensor([[[ 0,  1,  2],
#          [ 3,  4,  5]],
#         [[ 6,  7,  8],
#          [ 9, 10, 11]]])
#
# tensor([[[ 0,  1],
#          [ 2,  3],
#          [ 4,  5]],
#         [[ 6,  7],
#          [ 8,  9],
#          [10, 11]]])
#
# tensor([[[ 10,  13],
#          [ 28,  40]],
#         [[172, 193],
#          [244, 274]]])
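
# For reference, 'ijk,ikl->ijl' is batched matrix multiplication, so torch.bmm should give the same result:
# print(torch.bmm(x, xx))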

# Tensor contraction
# aa = torch.randn(2,3,5,7)
# bb = torch.randn(11,13,3,17,5)
# print(torch.einsum('pqrs,tuqvr->pstuv', [aa, bb]).shape)
# torch.Size([2, 7, 11, 13, 17])
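
# For reference, the same contraction with torch.tensordot, pairing aa's dims q=1, r=2
# with bb's dims q=2, r=4; the free dims come out in pstuv order, so the result should
# match the einsum above:
# print(torch.tensordot(aa, bb, dims=([1, 2], [2, 4])).shape)
# torch.Size([2, 7, 11, 13, 17])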


# Bilinear transformation
# a = torch.randint(1,5, size=(2, 3))
# b = torch.randint(1,5, size=(4, 3, 6))
# c = torch.randint(1,5, size=(2, 6))
# print(a)
# print(b)
# print(c)
# print(torch.einsum('ik,jkl,il->ij', [a, b, c]))
# tensor([[4, 4, 4],
#         [2, 2, 4]])
# tensor([[[3, 3, 2, 2, 2, 4],
#          [1, 2, 4, 3, 1, 1],
#          [4, 3, 3, 3, 1, 1]],
#         [[1, 2, 2, 2, 2, 3],
#          [3, 2, 3, 3, 1, 1],
#          [1, 3, 3, 3, 3, 3]],
#         [[1, 3, 4, 3, 1, 2],
#          [4, 4, 4, 1, 3, 4],
#          [4, 4, 3, 3, 1, 3]],
#         [[2, 1, 1, 4, 1, 2],
#          [4, 1, 1, 2, 4, 3],
#          [2, 3, 4, 2, 3, 3]]])
# tensor([[2, 2, 1, 3, 2, 4],
#         [3, 2, 1, 2, 1, 2]])
# tensor([[388, 384, 472, 416],
#         [222, 200, 266, 218]])
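
# For reference, 'ik,jkl,il->ij' is the computation behind torch.nn.functional.bilinear
# (weight of shape out_features x in1_features x in2_features, no bias). The tensors are
# cast to float here since F.bilinear is normally used with floating-point inputs; the
# values should match the einsum result above:
# import torch.nn.functional as F
# print(F.bilinear(a.float(), c.float(), b.float()))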