Pytorch的scatter()函数用法

scatter(dim, index, src)的三个参数为:

(1)dim:沿着哪个维度进行索引

(2)index: 用来scatter的元素索引

(3)src: 用来scatter的源元素,可以使一个标量也可以是一个张量

注:带_表示在原张量上修改。

二维例子如下:

1 y = y.scatter(dim,index,src)
2  
3 y [ index[i][j] ] [j] = src[i][j] #if dim==0
4 y[i] [ index[i][j] ]  = src[i][j] #if dim==1

实例如下:

 1 x = torch.rand(2, 5)
 2 
 3 #tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
 4 #        [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
 5 
 6 y = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
 7 
 8 #tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
 9 #        [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
10 #        [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

说明:

需要根据index(即 torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])) 来查找src的元素(即x ),从而得到结果y。

一开始进行 self[index[0][0]][0],其中 index[0][0] 的值是0,所以执行 self[0][0]=x[0][0]=0.1940 ,self[index[i][j]][j]=src[i][j]
再比如self[index[1][0]][0],其中 index[1][0] 的值是2,所以执行 self[2][0]=x[1][0]=0.2078 

 

如何确定最终需要修改y中的哪些元素呢?

个人认为根据index中的值及其索引。因为index有10个元素,所以最终y中有10个元素会被修改,具体如下:

 

scatter() 一般可以用来对标签进行 one-hot 编码,一个典型的用标量来修改张量的例子如下:

 1 import torch
 2  
 3 mini_batch = 4
 4 out_planes = 6
 5 out_put = torch.rand(mini_batch, out_planes)
 6 softmax = torch.nn.Softmax(dim=1)
 7 out_put = softmax(out_put)
 8  
 9 print(out_put)
10 label = torch.tensor([1,3,3,5])
11 one_hot_label = torch.zeros(mini_batch, out_planes).scatter_(1,label.unsqueeze(1),1)
12 print(one_hot_label)

 

1 tensor([[0.1202, 0.2120, 0.1252, 0.1127, 0.2314, 0.1985],
2         [0.1707, 0.1227, 0.2282, 0.0918, 0.1845, 0.2021],
3         [0.1629, 0.1936, 0.1277, 0.1204, 0.1845, 0.2109],
4         [0.1226, 0.1524, 0.2315, 0.2027, 0.1907, 0.1001]])
5 tensor([1, 3, 3, 5])
6 tensor([[0., 1., 0., 0., 0., 0.],
7         [0., 0., 0., 1., 0., 0.],
8         [0., 0., 0., 1., 0., 0.],
9         [0., 0., 0., 0., 0., 1.]])

 

参考:https://www.cnblogs.com/dogecheng/p/11938009.html

           https://blog.csdn.net/t20134297/article/details/105755817

posted @ 2020-12-18 09:03  vv_869  阅读(741)  评论(0编辑  收藏  举报