pytorch里的矩阵乘法探究

一直没完全搞清楚pytorch的乘法是怎么样计算的,今天来完整地实验一下。

广播(broadcast)的概念

👉官方文档

如果两个tensor可广播,那么需要满足如下的规则:

  1. 每个tensor至少有一个维度
  2. 当按照维度尺寸迭代时,从最后的维度开始迭代,维度尺寸需要满足:相等。其中一个是1或者不存在的情况。

广播的规则:

  1. 如果维度数量不同,那么就补1,让两个张量的维度数量相同。
  2. 然后对每个维度,每个维度的大小取决于两个张量在该维度的最值。

torch.matmul

👉官方文档

一维乘一维

  • 如果两个tensor都是一维,就返回点积。
  • 而且,两个一维的tensor需要长度相同才可以正常运算。
  • 在下面这个例子里,oneD(torch.Size([100]))无法和oneD2(torch.Size([200]))相乘,但可以和oneD3(torch.Size([100]))相乘。
>>> import torch
>>> oneD = torch.randn(100)
>>> oneD.shape
torch.Size([100])
>>> oneD2 = torch.randn(200)
>>> oneD2.shape
torch.Size([200])
>>> torch.matmul(oneD, oneD2).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: inconsistent tensor size, expected tensor [100] and src [200] to have the same number of elements, but got 100 and 200 elements respectively
>>> oneD3 = torch.randn(100)
>>> torch.matmul(oneD, oneD3).shape
torch.Size([])
>>> torch.matmul(oneD, oneD3)
tensor(-7.9475)

二维乘二维

  • 如果两个tensor都是二维,就返回矩阵乘法得到的结果。
  • 这个是最容易理解的了,只要满足矩阵乘法的要求就可以正常运算。
>>> twoD = torch.rand(100,200)
>>> twoD2 = torch.rand(200,300)
>>> torch.matmul(twoD, twoD2).shape
torch.Size([100, 300])

一维乘二维

  • 将一维数组转变为二维数组,即转化成了二维乘二维的问题。
  • 在计算时,首先将oneD转化成二维张量torch.Size([1,100]),然后再和twoD相乘,得到torch.Size([1,200]),最后在移除之前添加的第一维,得到torch.Size([200])。
>>> oneD.shape
torch.Size([100])
>>> twoD.shape
torch.Size([100, 200])
>>> torch.matmul(oneD,twoD).shape
torch.Size([200])

二维乘一维

  • 和一维乘二维相同,首先把一维tensor扩展到二维,然后问题转化为二维成二维的情况。
>>> twoD.shape
torch.Size([100, 200])
>>> oneD2.shape
torch.Size([200])
>>> torch.matmul(twoD,oneD2).shape
torch.Size([100])
>>> oneD3.shape
torch.Size([100])
>>> torch.matmul(twoD,oneD3).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: size mismatch, got 100, 100x200,100

多维相乘的情况

如果都是至少一维并且至少有一个大于二维的张量相乘,那么会返回批矩阵乘法(batched matrix multiply)的结果。
会应用broadcast,但是并不完全。当元素相乘时,第一维看作batch,broadcast只会应用在比较batch的维度上。

  • [600,300,400]和[400,300]相乘得到[600, 300, 300]
>>> import torch
>>> A = torch.randn(600,300,400)
>>> B = torch.randn(400,300)
>>> torch.matmul(A,B).shape
torch.Size([600, 300, 300])
  1. [300,600]和[600,300,400]就无法相乘
>>> C = torch.randn(300,600)
>>> torch.matmul(C,A)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: mat1 and mat2 shapes cannot be multiplied (240000x300 and 600x300)

torch.mm

👉官方文档

  • 仅支持二维乘二维的张量。
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

torch.bmm

👉官方文档

  • batch matrix-matrix product
  • 三维乘三维的张量,第一维是batch,要保持相同,不可广播
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
posted @   阿莱慢慢来  阅读(296)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
点击右上角即可分享
微信分享提示