pytorch里的矩阵乘法探究
一直没完全搞清楚pytorch的乘法是怎么样计算的,今天来完整地实验一下。
广播(broadcast)的概念
👉官方文档
如果两个tensor可广播,那么需要满足如下的规则:
- 每个tensor至少有一个维度
- 当按照维度尺寸迭代时,从最后的维度开始迭代,维度尺寸需要满足:相等。其中一个是1或者不存在的情况。
广播的规则:
- 如果维度数量不同,那么就补1,让两个张量的维度数量相同。
- 然后对每个维度,每个维度的大小取决于两个张量在该维度的最值。
torch.matmul
👉官方文档
一维乘一维
- 如果两个tensor都是一维,就返回点积。
- 而且,两个一维的tensor需要长度相同才可以正常运算。
- 在下面这个例子里,oneD(torch.Size([100]))无法和oneD2(torch.Size([200]))相乘,但可以和oneD3(torch.Size([100]))相乘。
>>> import torch
>>> oneD = torch.randn(100)
>>> oneD.shape
torch.Size([100])
>>> oneD2 = torch.randn(200)
>>> oneD2.shape
torch.Size([200])
>>> torch.matmul(oneD, oneD2).shape
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: inconsistent tensor size, expected tensor [100] and src [200] to have the same number of elements, but got 100 and 200 elements respectively
>>> oneD3 = torch.randn(100)
>>> torch.matmul(oneD, oneD3).shape
torch.Size([])
>>> torch.matmul(oneD, oneD3)
tensor(-7.9475)
二维乘二维
- 如果两个tensor都是二维,就返回矩阵乘法得到的结果。
- 这个是最容易理解的了,只要满足矩阵乘法的要求就可以正常运算。
>>> twoD = torch.rand(100,200)
>>> twoD2 = torch.rand(200,300)
>>> torch.matmul(twoD, twoD2).shape
torch.Size([100, 300])
一维乘二维
- 将一维数组转变为二维数组,即转化成了二维乘二维的问题。
- 在计算时,首先将oneD转化成二维张量torch.Size([1,100]),然后再和twoD相乘,得到torch.Size([1,200]),最后在移除之前添加的第一维,得到torch.Size([200])。
>>> oneD.shape
torch.Size([100])
>>> twoD.shape
torch.Size([100, 200])
>>> torch.matmul(oneD,twoD).shape
torch.Size([200])
二维乘一维
- 和一维乘二维相同,首先把一维tensor扩展到二维,然后问题转化为二维成二维的情况。
>>> twoD.shape
torch.Size([100, 200])
>>> oneD2.shape
torch.Size([200])
>>> torch.matmul(twoD,oneD2).shape
torch.Size([100])
>>> oneD3.shape
torch.Size([100])
>>> torch.matmul(twoD,oneD3).shape
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: size mismatch, got 100, 100x200,100
多维相乘的情况
如果都是至少一维并且至少有一个大于二维的张量相乘,那么会返回批矩阵乘法(batched matrix multiply)的结果。
会应用broadcast,但是并不完全。当元素相乘时,第一维看作batch,broadcast只会应用在比较batch的维度上。
- [600,300,400]和[400,300]相乘得到[600, 300, 300]
>>> import torch
>>> A = torch.randn(600,300,400)
>>> B = torch.randn(400,300)
>>> torch.matmul(A,B).shape
torch.Size([600, 300, 300])
- [300,600]和[600,300,400]就无法相乘
>>> C = torch.randn(300,600)
>>> torch.matmul(C,A)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: mat1 and mat2 shapes cannot be multiplied (240000x300 and 600x300)
torch.mm
👉官方文档
- 仅支持二维乘二维的张量。
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851, 0.5037, -0.3633],
[-0.0760, -3.6705, 2.4784]])
torch.bmm
👉官方文档
- batch matrix-matrix product
- 三维乘三维的张量,第一维是batch,要保持相同,不可广播
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人