pytorch 的 scatter 函数中参数 index 是如何指定散布位置

基本逻辑

  • dim=0 表示 scatter 操作将在行方向上进行。dim指定index参数的值。即 index[i,j]=v 的 v 的值。
  • index 张量的形状应该与 src 张量的形状匹配,或者能够广播到相同的形状。
  • 对于每一个 index 的位置 (i, j), src[i, j] 中的值会被放置到 res[index[i, j], j] 中。

举例:若 index[2,3]=1, 代表分 3 步的任务。

首先,根据index的索引[2,3]取得src同索引的值,即 src[2,3],假设src[2,3]=10

其次,根据 dim 指定的值,作为索引的位置,用index的值取替换,如,当dim=0,则替换索引(2,3)的2,这样索引就变为(1,3)

最后,合并上两部的到的结果,把第2步得到的索引(1,3)用于res的索引,把第1步得到的src的值用于res的值,最终得到res[1,3]=10

一句话:若 dim=0, index[i,j]=v 中,取值 scr[i,j]=r,赋值 res[v,j]=r
若 dim = 1, index[i,j]=v ,则有:src[i,j]=r, res[i,v]=r

套用pytorch的文档:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

把self替换为res,就都一样了。

感悟:理解一个函数的功能,一定要找出这个功能背后的计算公式,才算是真明白,否则那就是一知半解,换一个数就不会了。

具体的例子

在 PyTorch 的 scatter 函数中,当你指定 dim=0 时,index 张量确实是指定了 src 中每个元素应该放置在 res 张量的哪一行。不过,列号的选择是由 index 张量的形状和 src 张量的形状共同决定的。

1. 基本逻辑

  • dim=0 表示 scatter 操作将在行方向上进行。
  • index 张量的形状应该与 src 张量的形状匹配,或者能够广播到相同的形状。
  • 对于每一个 index 的位置 (i, j), src[i, j] 中的值会被放置到 res[index[i, j], j] 中。

2. 列号是如何确定的

列号实际上是通过遍历 src 张量的第二个维度(即列的维度)隐含确定的。

假设 src 的形状为 (m, n)index 的形状与 src 相同,那么:

  • src 中的每一列 j(从 0n-1)会被映射到 res 张量的相同列 j 中。
  • 行号是由 index 张量提供的,对应于 src 张量的每个元素在 res 张量中的行位置。

3. 结合例子解释

src = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
index = torch.tensor([[1, 1, 1]])
res = torch.zeros(2, 3).scatter_(0, index, src)

1. 初始情况

  • src 张量的形状是 (2, 3),意味着它有 2 行 3 列。
  • index 张量的形状是 (1, 3),这意味着 index 为每一列指定了 src 中对应元素应该放置到 res 张量的哪个行中。

2. 逐列分析

对于第 j 列:

  • index[0, j] 决定了 src[0, j]src[1, j] 要放在 res 的哪个行。
  • 因为 dim=0scatter_ 只在行维度(即第 0 维)上进行操作,列维度保持不变。

列 0 (j=0):

  • index[0, 0] = 1
  • src[0, 0] = 1. 将被放置在 res[1, 0]
  • src[1, 0] = 4. 也将被放置在 res[1, 0],最终覆盖 1.

列 1 (j=1):

  • index[0, 1] = 1
  • src[0, 1] = 2. 将被放置在 res[1, 1]
  • src[1, 1] = 5. 也将被放置在 res[1, 1],最终覆盖 2.

列 2 (j=2):

  • index[0, 2] = 1
  • src[0, 2] = 3. 将被放置在 res[1, 2]
  • src[1, 2] = 6. 也将被放置在 res[1, 2],最终覆盖 3.

4. 总结

  • 行号:由 index 指定,表示 src 中的每个元素应该放置在 res 的哪一行。
  • 列号:隐式决定,src 张量的每一列直接对应 res 张量的相同列。即 src[i, j] 的值会被放到 res[index[i, j], j]

所以在你的例子中,index 指定了所有元素应当放置在 res 张量的第 1 行,而列号则是隐式地从 src 中继承下来,与 src 的列号相同。

posted @ 2024-08-21 09:55  立体风  阅读(26)  评论(0编辑  收藏  举报