pytorch 的 scatter 函数中参数 index 是如何指定散布位置
基本逻辑
dim=0
表示scatter
操作将在行方向上进行。dim指定index参数的值。即 index[i,j]=v 的 v 的值。index
张量的形状应该与src
张量的形状匹配,或者能够广播到相同的形状。- 对于每一个
index
的位置(i, j)
,src[i, j]
中的值会被放置到res[index[i, j], j]
中。
举例:若 index[2,3]=1, 代表分 3 步的任务。
首先,根据index的索引[2,3]取得src同索引的值,即 src[2,3],假设src[2,3]=10
其次,根据 dim 指定的值,作为索引的位置,用index的值取替换,如,当dim=0,则替换索引(2,3)的2,这样索引就变为(1,3)
最后,合并上两部的到的结果,把第2步得到的索引(1,3)用于res的索引,把第1步得到的src的值用于res的值,最终得到res[1,3]=10
一句话:若 dim=0, index[i,j]=v 中,取值 scr[i,j]=r,赋值 res[v,j]=r
若 dim = 1, index[i,j]=v ,则有:src[i,j]=r, res[i,v]=r
套用pytorch的文档:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
把self替换为res,就都一样了。
感悟:理解一个函数的功能,一定要找出这个功能背后的计算公式,才算是真明白,否则那就是一知半解,换一个数就不会了。
具体的例子
在 PyTorch 的 scatter
函数中,当你指定 dim=0
时,index
张量确实是指定了 src
中每个元素应该放置在 res
张量的哪一行。不过,列号的选择是由 index
张量的形状和 src
张量的形状共同决定的。
1. 基本逻辑
dim=0
表示scatter
操作将在行方向上进行。index
张量的形状应该与src
张量的形状匹配,或者能够广播到相同的形状。- 对于每一个
index
的位置(i, j)
,src[i, j]
中的值会被放置到res[index[i, j], j]
中。
2. 列号是如何确定的
列号实际上是通过遍历 src
张量的第二个维度(即列的维度)隐含确定的。
假设 src
的形状为 (m, n)
,index
的形状与 src
相同,那么:
src
中的每一列j
(从0
到n-1
)会被映射到res
张量的相同列j
中。- 行号是由
index
张量提供的,对应于src
张量的每个元素在res
张量中的行位置。
3. 结合例子解释
src = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
index = torch.tensor([[1, 1, 1]])
res = torch.zeros(2, 3).scatter_(0, index, src)
1. 初始情况
src
张量的形状是(2, 3)
,意味着它有 2 行 3 列。index
张量的形状是(1, 3)
,这意味着index
为每一列指定了src
中对应元素应该放置到res
张量的哪个行中。
2. 逐列分析
对于第 j
列:
index[0, j]
决定了src[0, j]
和src[1, j]
要放在res
的哪个行。- 因为
dim=0
,scatter_
只在行维度(即第 0 维)上进行操作,列维度保持不变。
列 0 (j=0
):
index[0, 0] = 1
src[0, 0] = 1.
将被放置在res[1, 0]
src[1, 0] = 4.
也将被放置在res[1, 0]
,最终覆盖1.
列 1 (j=1
):
index[0, 1] = 1
src[0, 1] = 2.
将被放置在res[1, 1]
src[1, 1] = 5.
也将被放置在res[1, 1]
,最终覆盖2.
列 2 (j=2
):
index[0, 2] = 1
src[0, 2] = 3.
将被放置在res[1, 2]
src[1, 2] = 6.
也将被放置在res[1, 2]
,最终覆盖3.
4. 总结
- 行号:由
index
指定,表示src
中的每个元素应该放置在res
的哪一行。 - 列号:隐式决定,
src
张量的每一列直接对应res
张量的相同列。即src[i, j]
的值会被放到res[index[i, j], j]
。
所以在你的例子中,index
指定了所有元素应当放置在 res
张量的第 1
行,而列号则是隐式地从 src
中继承下来,与 src
的列号相同。