10.2.2 平均汇聚

torch.repeat_interleave 用于按指定规则重复张量的元素,支持按维度扩展或自定义每个元素的重复次数。以下是详细说明和示例:

作用

  • 功能:沿特定维度重复张量的元素,支持两种模式:
    1. 统一重复次数:所有元素重复相同次数。
    2. 自定义重复次数:每个元素按单独指定的次数重复。
  • torch.repeat 的区别
    • repeat 复制整个张量(如 [1,2] → [1,2,1,2])。
    • repeat_interleave 逐个元素连续重复(如 [1,2] → [1,1,2,2])。

语法

torch.repeat_interleave(
    input,          # 输入张量
    repeats,        # 重复次数(整数或张量)
    dim=None,       # 操作的维度(默认为展平后处理)
    output_size=None # 预分配输出大小(可选)
)

参数详解

  1. repeats
    • 若为整数:所有元素重复该次数。
    • 若为张量:形状需与 inputdim 维度一致,每个元素对应重复次数。
  2. dim
    • 指定操作的维度。若为 None,先展平输入张量再处理。
  3. output_size
    • 预定义输出大小,用于性能优化(通常无需手动指定)。

示例

示例 1:一维张量,统一重复次数

x = torch.tensor([1, 2, 3])
y = torch.repeat_interleave(x, repeats=2)
print(y)  # 输出: tensor([1, 1, 2, 2, 3, 3])

示例 2:二维张量,沿指定维度重复

x = torch.tensor([[1, 2], [3, 4]])
# 沿 dim=0 重复(每行重复2次)
y = torch.repeat_interleave(x, repeats=2, dim=0)
print(y)
# 输出:
# tensor([[1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4]])

示例 3:自定义每个元素的重复次数

x = torch.tensor([1, 2, 3])
repeats = torch.tensor([2, 3, 1])  # 每个元素重复2、3、1次
y = torch.repeat_interleave(x, repeats=repeats)
print(y)  # 输出: tensor([1, 1, 2, 2, 2, 3])

示例 4:沿不同维度自定义重复次数

x = torch.tensor([[1, 2], [3, 4]])
repeats = torch.tensor([1, 2])  # 沿 dim=1 的每个元素重复1、2次
y = torch.repeat_interleave(x, repeats=repeats, dim=1)
print(y)
# 输出:
# tensor([[1, 2, 2],
#         [3, 4, 4]])

形状计算规则

  • 统一重复次数
    • 若输入形状为 (D1, D2, ..., Dn),沿 dim=k 重复 R 次,则输出形状为 (D1, ..., Dk*R, ..., Dn)
  • 自定义重复次数
    • 沿 dim=k 的维度大小变为 sum(repeats),其他维度不变。

注意事项

  1. 广播限制repeats 张量必须与输入在 dim 维度长度一致。
  2. 性能优化:若已知输出大小,可用 output_size 避免额外内存分配。
  3. 默认维度dim=None 时,输入会被展平为 1D 再处理。

通过灵活设置 repeatsdim,可以实现复杂的数据扩展需求。

posted @   最爱丁珰  阅读(7)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
历史上的今天:
2024-02-22 Replace on Segment
2024-02-22 Berserk Monsters
点击右上角即可分享
微信分享提示