pytorch函数

squeeze() 和 unsqueeze()函数
torch.squeeze(A,N)
torch.unsqueeze()函数的作用减少数组A指定位置N的维度,如果不指定位置参数N,如果数组A的维度为(1,1,3)那么执行 torch.squeeze(A,1) 后A的维度变为 (1,3),中间的维度被删除

点击查看代码
注:
1. 如果指定的维度大于1,那么将操作无效
2. 如果不指定维度N,那么将删除所有维度为1的维度

torch.unsqueeze(A,N)
torch.unsqueeze()函数的作用增加数组A指定位置N的维度,例如两行三列的数组A维度为(2,3),那么这个数组就有三个位置可以增加维度,分别是( [位置0] 2,[位置1] 3 [位置2] )或者是 ( [位置-3] 2,[位置-2] 3 [位置-1] ),如果执行 torch.unsqueeze(A,1),数据的维度就变为了 (2,1,3)

scatter(output, dim, index, src) → Tensor

总结:scatter函数就是把src数组中的数据重新分配到output数组当中,index数组中表示了要把src数组中的数据分配到output数组中的位置,若未指定,则填充0.

scatter

import torch
 
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)


tensor([[ 1.4782, -1.1345, -1.1457, -0.6050],
        [-0.4183, -0.0229,  1.2361, -1.7747]])
tensor([[-0.6050, -1.1345, -1.1457,  1.4782,  0.0000],
        [ 1.2361, -0.4183, -0.0229, -1.7747,  0.0000]])

一般scatter用于生成onehot向量,如下所示:

index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)

tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.]])
posted @   ziv80  阅读(14)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
点击右上角即可分享
微信分享提示