Source: https://zhuanlan.zhihu.com/p/474153365
torch.repeat
torch.repeat copies a tensor along the given dimensions. It can not only replicate existing dimensions but also add new leading dimensions:

import torch

x = torch.randn(2, 4)


# 1. Copy along existing dimensions
x.repeat(1, 1).size()  # torch.Size([2, 4])

x.repeat(2, 1).size()  # torch.Size([4, 4])

x.repeat(1, 2).size()  # torch.Size([2, 8])


# 2. Passing more values than x has dimensions adds new leading dimensions
x.repeat(1, 1, 1).size()  # torch.Size([1, 2, 4])

x.repeat(2, 1, 1).size()  # torch.Size([2, 2, 4])

x.repeat(1, 1, 1, 1).size()  # torch.Size([1, 1, 2, 4])


# 3. repeat must receive at least as many values as x has dimensions
x.repeat(1)  # raises an error
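A minimal sketch of what case 3 looks like in practice; the exact error message may vary across PyTorch versions:

import torch

x = torch.randn(2, 4)
try:
    x.repeat(1)  # fewer repeat values (1) than tensor dimensions (2)
except RuntimeError as e:
    print(e)  # complains that the number of repeat dims cannot be smaller than the tensor's dims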
| torch.repeat_interleave |
torch.repeat_interleave behaves like numpy.repeat, which is different from torch.repeat. Again, code makes this clearest:
| |
| import torch |
| x = torch.randn(2, 2) |
| |
| print(x) |
| >>> tensor([[ 0.4332, 0.1172], |
| [ 0.8808, -1.7127]]) |
| |
| print(x.repeat(2, 1)) |
| >>> tensor([[ 0.4332, 0.1172], |
| [ 0.8808, -1.7127], |
| [ 0.4332, 0.1172], |
| [ 0.8808, -1.7127]]) |
| |
| print(x.repeat_interleave(2, dim=0)) |
| >>> tensor([[ 0.4332, 0.1172], |
| [ 0.4332, 0.1172], |
| [ 0.8808, -1.7127], |
| [ 0.8808, -1.7127]]) |
| |
| print(x.repeat_interleave(2, dim=1)) |
| >>> tensor([[ 0.4332, 0.4332, 0.1172, 0.1172], |
| [ 0.8808, 0.8808, -1.7127, -1.7127]]) |
| |
# Without dim, the input is flattened and each element is repeated
| print(x.repeat_interleave(2)) |
| >>> tensor([ 0.4332, 0.4332, 0.1172, 0.1172, 0.8808, 0.8808, -1.7127, -1.7127]) |
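For comparison, numpy.repeat (which repeat_interleave mirrors) produces the same element-wise layout; a quick sketch:

import numpy as np

a = np.array([[1, 2], [3, 4]])

print(np.repeat(a, 2, axis=0))
# [[1 2]
#  [1 2]
#  [3 4]
#  [3 4]]

print(np.repeat(a, 2))  # no axis: flatten, then repeat each element
# [1 1 2 2 3 3 4 4]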
These examples show that torch.repeat copies the tensor as a whole, while torch.repeat_interleave copies individual elements. In addition, torch.repeat_interleave accepts a 1-D torch.Tensor specifying how many times to repeat each element:
| |
| import torch |
| x = torch.tensor([[1, 2], [3, 4]]) |
| |
| result = torch.repeat_interleave(x, torch.tensor([1, 3]), dim=0) |
| print(result) |
| >>> tensor([[1, 2], |
| [3, 4], |
| [3, 4], |
| [3, 4]]) |
| torch.tile |
torch.tile is another element-copying function, and like torch.repeat it treats input as a whole. It differs from torch.repeat in how it handles arguments: if tile receives fewer values than the tensor has dimensions, ones are prepended, so a single value applies to the last dimension (the columns of a 2-D tensor):
| |
| import torch |
| x = torch.tensor([[1, 2], [3, 4]]) |
| |
# Pass a single value
| print(x.tile((2, ))) |
| >>> tensor([[1, 2, 1, 2], |
| [3, 4, 3, 4]]) |
| |
| print(x.repeat(1, 2)) |
| >>> tensor([[1, 2, 1, 2], |
| [3, 4, 3, 4]]) |
When a full tuple is passed to torch.tile, for a 2-D tensor it means (row repeats, column repeats):
| |
| import torch |
| x = torch.tensor([[1, 2], [3, 4]]) |
| |
| print(x.tile((2, 2))) |
| >>> tensor([[1, 2, 1, 2], |
| [3, 4, 3, 4], |
| [1, 2, 1, 2], |
| [3, 4, 3, 4]]) |
| |
| print(x.repeat(2, 2)) |
| >>> tensor([[1, 2, 1, 2], |
| [3, 4, 3, 4], |
| [1, 2, 1, 2], |
| [3, 4, 3, 4]]) |
When fewer values are passed than the tensor has dimensions, ones are prepended: for a tensor of shape (2, 2, 2), a tile argument of (2, 2) is treated as (1, 2, 2):
| |
| import torch |
| x = torch.randn(2, 2, 2) |
| print(x) |
| >>> tensor([[[ 0.8517, 0.8721], |
| [-1.1591, -0.2000]], |
| |
| [[ 0.3888, -0.8365], |
| [-1.6383, -0.1539]]]) |
| |
| print(x.tile((2, 2))) |
| >>> tensor([[[ 0.8517, 0.8721, 0.8517, 0.8721], |
| [-1.1591, -0.2000, -1.1591, -0.2000], |
| [ 0.8517, 0.8721, 0.8517, 0.8721], |
| [-1.1591, -0.2000, -1.1591, -0.2000]], |
| |
| [[ 0.3888, -0.8365, 0.3888, -0.8365], |
| [-1.6383, -0.1539, -1.6383, -0.1539], |
| [ 0.3888, -0.8365, 0.3888, -0.8365], |
| [-1.6383, -0.1539, -1.6383, -0.1539]]]) |
When more values are passed than the tensor has dimensions, new leading dimensions are added:
| |
| import torch |
| x = torch.randn(2, 2) |
| print(x) |
| >>> tensor([[ 1.1165, -0.5559], |
| [-0.6341, 0.5215]]) |
| |
| print(x.tile((2, 2, 2))) |
| >>> tensor([[[ 1.1165, -0.5559, 1.1165, -0.5559], |
| [-0.6341, 0.5215, -0.6341, 0.5215], |
| [ 1.1165, -0.5559, 1.1165, -0.5559], |
| [-0.6341, 0.5215, -0.6341, 0.5215]], |
| |
| [[ 1.1165, -0.5559, 1.1165, -0.5559], |
| [-0.6341, 0.5215, -0.6341, 0.5215], |
| [ 1.1165, -0.5559, 1.1165, -0.5559], |
| [-0.6341, 0.5215, -0.6341, 0.5215]]]) |
| |
| |
Using tile and reshape to replace repeat_interleave
| import torch |
| |
| x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: (2, 3) |
| |
| y = torch.repeat_interleave(x, repeats=3, dim=0) |
| |
| print(y) |
| >>> tensor([[1, 2, 3], |
| [1, 2, 3], |
| [1, 2, 3], |
| [4, 5, 6], |
| [4, 5, 6], |
| [4, 5, 6]]) |
| |
# Using tile directly does not give the same result
| z = torch.tile(x, (3, )) |
| print(z) |
| >>> tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3], |
| [4, 5, 6, 4, 5, 6, 4, 5, 6]]) |
| |
| z = torch.tile(x, (3, 1)) |
| print(z) |
| >>> tensor([[1, 2, 3], |
| [4, 5, 6], |
| [1, 2, 3], |
| [4, 5, 6], |
| [1, 2, 3], |
| [4, 5, 6]]) |
| |
# tile + reshape is needed to reproduce the repeat_interleave result
z = torch.tile(x, (3, ))
print(z.shape)  # torch.Size([2, 9])
print(z.reshape(6, 3))  # same output as y
| >>> tensor([[1, 2, 3], |
| [1, 2, 3], |
| [1, 2, 3], |
| [4, 5, 6], |
| [4, 5, 6], |
| [4, 5, 6]]) |