pytorch 中tensor的加减和mul、matmul、bmm
如下是tensor乘法与加减法,对应位相乘或相加减,可以一对多
import torch def add_and_mul(): x = torch.Tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) y = torch.Tensor([1, 2, 3]) y = y - x print(y) ''' tensor([[[ 0., 0., 0.], [-3., -3., -3.]], [[-6., -6., -6.], [-9., -9., -9.]]]) ''' t = 1. - x.sum(dim=1) print(t) ''' tensor([[ -4., -6., -8.], [-16., -18., -20.]]) ''' y = torch.Tensor([[1, 2, 3], [4, 5, 6]]) y = torch.mul(y,x) #等价于此方法 y*x print(y) ''' tensor([[[ 1., 4., 9.], [16., 25., 36.]], [[ 7., 16., 27.], [40., 55., 72.]]]) ''' z = x ** 2 print(z) """ tensor([[[ 1., 4., 9.], [ 16., 25., 36.]], [[ 49., 64., 81.], [100., 121., 144.]]]) """ if __name__=='__main__': add_and_mul()
矩阵的乘法,matmul和bmm的具体代码
import torch def matmul_and_bmm(): # a=(2*3*4) a = torch.Tensor([[[1, 2, 3, 4], [4, 0, 6, 0], [3, 2, 1, 4]], [[3, 2, 1, 0], [0, 3, 2, 2], [1, 2, 1, 0]]]) # b=(2,2,4) b = torch.Tensor([[[1, 2, 3, 4], [4, 0, 6, 0]], [[3, 2, 1, 0], [1, 2, 1, 0]]]) b=b.transpose(1, 2) # res=(2,3,2),对于a*b,是第一维度不变,而后[3,4] x [4,2]=[3,2] #res[0,:]=a[0,:] x b[0,;]; res[1,:]=a[1,:] x b[1,;] 其中x表示矩阵乘法 res = torch.matmul(a, b) # 维度res=[2,3,2] res2 = torch.bmm(a, b) # 维度res2=[2,3,2] print(res) # res2的值等于res """ tensor([[[30., 22.], [22., 52.], [26., 18.]], [[14., 8.], [ 8., 8.], [ 8., 6.]]]) """ if __name__=='__main__': matmul_and_bmm()