pytorch的expand函数

PyTorch 中的 expand 函数用于扩展张量的形状,使其在某些维度上“看起来”像被复制了多次,但实际上它不会复制数据,从而节省内存和计算资源。扩展后的张量共享原始张量的内存空间,因此原始张量和扩展后的张量是同一个数据的视图。

以下是 torch.expand 函数的一些基本用法:

1. 扩展一维张量: 将一维张量扩展到更高维度,例如,将一维张量扩展为二维张量的第一维度。

import torch

x = torch.tensor([1, 2, 3])
y = x.expand(2, -1)  # 扩展为 2x3 的矩阵
print(y)
# 输出:
# tensor([[1, 2, 3],
#         [1, 2, 3]])

 -1 表示该维度的大小与原始张量相同。

 

2. 使用 -1 自动扩展: 使用 -1 可以自动扩展张量到与输入张量相同的大小。

x = torch.tensor([1, 2, 3])
y = x.expand(-1, 2)  # 扩展为 3x2 的矩阵
print(y)
# 输出:
# tensor([[1, 1],
#         [2, 2],
#         [3, 3]])

 

3. 扩展多维张量: 扩展多维张量时,可以指定多个维度进行扩展。

x = torch.tensor([[1, 2], [3, 4]])
y = x.expand(2, -1, -1)  # 扩展为 2x2x2 的张量
print(y)
# 输出:
# tensor([[[1, 2],
#          [1, 2]],
#         [[3, 4],
#          [3, 4]]])

 

4. 扩展与广播的区别: expandbroadcast 相似,但 broadcast 在进行操作时会复制数据,而 expand 不会。expand 更适用于减少内存使用。

5. 扩展与复制: 如果你需要一个实际复制了数据的新张量,可以使用 expand 后跟 clone

x = torch.tensor([1, 2, 3])
y = x.expand(2, -1).clone()  # 扩展后复制数据
print(y)
# 输出:
# tensor([[1, 2, 3],
#         [1, 2, 3]])

 

使用 expand 时,扩展的维度大小可以是具体的数值,也可以是 -1,表示该维度的大小与原始张量相同。如果扩展的维度大小大于原始张量,PyTorch 会抛出错误。

(摘自kimi)

 

posted @ 2024-07-15 09:22  Picassooo  阅读(45)  评论(0编辑  收藏  举报