torch.mul() 和 torch.mm() 的区别
torch.mul(a, b)
是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵
torch.mm(a, b)
是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵
import torch a = torch.rand(1, 2) b = torch.rand(1, 2) c = torch.rand(2, 3) print(torch.mul(a, b)) # 返回 1*2 的tensor print(torch.mm(a, c)) # 返回 1*3 的tensor print(torch.mul(a, c)) # 由于a、b维度不同,报错
https://blog.csdn.net/Real_Brilliant/article/details/85756477