PyTorch tensor的scatter_函数
TORCH.TENSOR.SCATTER_
Tensor.scatter_(dim, index, src, reduce=None) → Tensor
把src里面的元素按照index和dim参数给出的条件,放置到目标tensor里面,在这里是self。下面为了讨论方便,目标tensor和self在交换使用的时候,请大家知道,在这里指的是同一个tensor.
注意:这里self, index, src三个张量的纬度必须是一致的(但每个纬度上的size不一定一致,请大家体会)。
只有src是个例外,可以是标量,即单个数字。
这个时候,就是把这单个数字,根据参数的条件, 放置到self的不同位置。
那么怎么放呢?根据PyTorch的文档,对于一个3-D的tensor,放置方法如下:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
由上面的公式很容易推断出,对于一个2-D的tensor,放置方法如下:
self[index[i][j]][j] = src[i][j] # if dim == 0
self[i][index[i][j]] = src[i][j] # if dim == 1
对于一个1-D的tensor,放置方法如下:
self[index[i]] = src[i] # if dim == 0
是不是有点晕?我们来解释一下。
1. 当dim为0的时候
我们把src里面的元素放置到self里面的时候,假设是放置src的第[i][j][k]个元素,那么放置到self里面的位置(三个纬度的值)分别如下:
- index[i][j][k]
- j
- k
对于第一个纬度的位置,就是把i,替换为index[i][j][k]
那么这里有个问题,如果index的size比src的size要小的话,怎么办? 那就是对于在index里面,找不到的值,就不再处理,self里面原来是什么还是什么。
为了更加方便说明,这里假设src是1-D的,即一个1维数组,那么dim只有一个值可以设置,即0(当然也可以说是有两个值,-1也是可以的,但是-1和0实际上指的都是第一个纬度)。那么这个时候self和index按照上面的规则,也必须都是一维的(参见上面的注意)。那么我们直接来看一段示例代码和输出来进行解释:
a = torch.arange(1, 6).long()
print(a)
i = torch.LongTensor([4,3,2])
t = torch.zeros(10).long()
t.scatter_(0, i, a)
print(t)
输出为:
tensor([1, 2, 3, 4, 5])
tensor([0, 0, 3, 2, 1, 0, 0, 0, 0, 0])
可以看到这里,源tensor(a)是一个一维,包含5个元素。目标tensor(t),在这里是一个10个元素的tensor,为了大家看得方便,我们先把所有元素设置为0,然后再把源tensor里面的元素搬过来放进目标tensor里面的时候,就很容易看到,被index tensor里面的信息所影响到的元素是非0的,如果没受到影响的是0。
这里源tensor只有5个元素,那么都搬过来,目标tensor(t)里面的元素也还是有10-5=5个元素是不会受到影响的,即为0。
那么为什么上面看到目标tensor里面的非0元素的个数只有3个,而不是5个(等于源tensor的个数)? 回顾一下对于3-D的tensor,当dim=0的时候,元素设置的公式:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
显然,对于1-D的tensor,上面的公式简化为:
self[index[i]] = src[i] # if dim == 0
因为这里index只有三个元素[4,3,2],那么意味着,再把源tensor(a)里面的5个元素放置到目标tensor(t)的过程中,只有i取值为0,1,2的,才能使用index里面的值,其余2个(在a里面的位置分别为4,5),就不再般src里面的元素了。我们来逐个元素说明一下:
- 当 i == 0时,self[index[0]] = src[0],即self[4] = src[0],也就是把src里面的第1个元素设置到self的第4个元素,这里src[0] 即是 a[0],是1,而self[4],即t[4]被设置为了1.
- 当 i == 1时,self[index[1]] = src[1],即self[3] = src[1],也就是把src里面的第2个元素设置到self的第3个元素,这里src[1] 即是 a[1],是2,而self[3],即t[3]被设置为了2.
- 当 i == 2时,self[index[2]] = src[2],即self[2] = src[2],也就是把src里面的第3个元素设置到self的第2个元素,这里src[2] 即是 a[2],是3,而self[2],即t[2]被设置为了3.
- 当 i == 3 和4的是,index里面已经没有对应的数值了,这些元素就不处理了。
2. 当dim为1的时候
说明,src,目标tensor和index都至少是2-D的,如果设置dim = 1,将会导致PyTorch报错。错误信息如下(对于1-D的index):
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
对于一个3-D的tensor,我们把src里面的元素放置到self里面的时候,假设是放置src的第[i][j][k]个元素,那么放置到self里面的位置(三个纬度的值)分别如下:
- i
- index[i][j][k]
- k
对于第一个纬度的位置,就是i,元素在src里面的位置是什么,在self里面也是相同的。
对于第二个纬度的位置,就是j,元素在self里面的位置变成了index[i][j][k]。
那么同样地,如果index的size比src的size要小的话,怎么办?那就是对于在index里面,找不到的值,就不在处理,即self里面是什么值还是什么值,不会变化。
对于一个2-D的tensor,我们把src里面的元素放置到self里面的时候,假设是放置src的第[i][j]个元素,那么放置到self里面的位置(个纬度的值)分别如下:
- i
- index[i][j]
为了更加方便理解,这里假设src是2-D的,即一个2维数组,且dim==1的情况下:
a = torch.arange(1, 11).long().reshape(2,5)
print(a)
i = torch.LongTensor([[4], [3]])
t = torch.zeros(10).long().reshape(2, 5)
t.scatter_(1, i, a)
print(t)
输出如下:
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
tensor([[0, 0, 0, 0, 1],
[0, 0, 0, 6, 0]])
在这里,index里面只有两个元素,那么也就是最终会有两个元素的值从src里面取出,设置到a里面去。在index里面仅有的两个元素是index[0][0]和index[1][0],这两个对应的src的元素是a[0][0]和a[1][0],对应的目的tensor(t)里面的t[0][index[0][0]]和t[1][index[1][0]]元素,即t[0][4]将会被设置为a[0][0],t[1][3]将会被设置为a[1][0],即:
- t[0][4] = 1
- t[1][3] = 6
其他目的tensor(t)里面的值都不会变。
3. 当dim为2的时候
大家可以按照上面说明的规则,自己进行推导,就不在这里赘述了。
总结:
scatter或者scatter_函数的作用就是把src里面的元素按照index和dim参数给出的条件,放置到目标tensor里面去。index有几个元素,就会有几个元素被从src里面放到目标tensor里面,其余目标tensor里面的元素不受影响。