pytorch的expand函数
PyTorch 中的 expand
函数用于扩展张量的形状,使其在某些维度上“看起来”像被复制了多次,但实际上它不会复制数据,从而节省内存和计算资源。扩展后的张量共享原始张量的内存空间,因此原始张量和扩展后的张量是同一个数据的视图。
以下是 torch.expand
函数的一些基本用法:
1. 扩展一维张量: 将一维张量扩展到更高维度,例如,将一维张量扩展为二维张量的第一维度。
import torch x = torch.tensor([1, 2, 3]) y = x.expand(2, -1) # 扩展为 2x3 的矩阵 print(y) # 输出: # tensor([[1, 2, 3], # [1, 2, 3]])
-1
表示该维度的大小与原始张量相同。
2. 使用 -1
自动扩展: 使用 -1
可以自动扩展张量到与输入张量相同的大小。
x = torch.tensor([1, 2, 3]) y = x.expand(-1, 2) # 扩展为 3x2 的矩阵 print(y) # 输出: # tensor([[1, 1], # [2, 2], # [3, 3]])
3. 扩展多维张量: 扩展多维张量时,可以指定多个维度进行扩展。
x = torch.tensor([[1, 2], [3, 4]]) y = x.expand(2, -1, -1) # 扩展为 2x2x2 的张量 print(y) # 输出: # tensor([[[1, 2], # [1, 2]], # [[3, 4], # [3, 4]]])
4. 扩展与广播的区别: expand
与 broadcast
相似,但 broadcast
在进行操作时会复制数据,而 expand
不会。expand
更适用于减少内存使用。
5. 扩展与复制: 如果你需要一个实际复制了数据的新张量,可以使用 expand
后跟 clone
。
x = torch.tensor([1, 2, 3]) y = x.expand(2, -1).clone() # 扩展后复制数据 print(y) # 输出: # tensor([[1, 2, 3], # [1, 2, 3]])
使用 expand
时,扩展的维度大小可以是具体的数值,也可以是 -1
,表示该维度的大小与原始张量相同。如果扩展的维度大小大于原始张量,PyTorch 会抛出错误。
(摘自kimi)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
2023-07-15 python实现iou计算
2022-07-15 xmind快捷键
2022-07-15 转:python中 os._exit() 和 sys.exit(), exit(0)和exit(1) 的用法和区别
2021-07-15 long-tail datasets