[Pytorch笔记] scatter_
https://blog.csdn.net/qq_16234613/article/details/79827006
scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中.
1 >>> x = torch.rand(2, 5) 2 >>> x 3 4 0.4319 0.6500 0.4080 0.8760 0.2355 5 0.2609 0.4711 0.8486 0.8573 0.1029 6 [torch.FloatTensor of size 2x5]
1) dim = 0,分别对每列填充:
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) 0.4319 0.4711 0.8486 0.8760 0.2355 0.0000 0.6500 0.0000 0.8573 0.0000 0.2609 0.0000 0.4080 0.0000 0.1029 [torch.FloatTensor of size 3x5]
实现原理:
对于LoneTensor内的矩阵,暂且称为 tmp = [[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]];将最终的 3*5的矩阵,暂且称为result。result初始为全0,需要经过scatter_处理。
举例:
对于tmp[0][0] = 0 -> 取x中x[0][0] = 0.4319,将其插入到result第0列的第0个位置,result[0][0] = 0.4319;
对于tmp[0][1] = 1 -> 取x中x[0][1] = 0.6500,将其插入到result第1列的第1个位置,result[1][1] = 0.6500;
对于tmp[0][2] = 2 -> 取x中x[0][1] = 0.4080,将其插入到result第2列的第2个位置,result[2][2] = 0.4080;
......
对于tmp[1][0] = 2 -> 取x中x[1][0] = 0.2609,将其插入到result第0列的第2个位置,result[2][0] = 0.2609;
对于tmp[1][1] = 0 -> 取x中x[1][1] = 0.4711,将其插入到result第1列的第0个位置,result[0][1] = 0.4711。
......
2) dim = 1,分别对每行填充
1 >>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23) 2 >>> z 3 4 0.0000 0.0000 1.2300 0.0000 5 0.0000 0.0000 0.0000 1.2300 6 [torch.FloatTensor of size 2x4]
tmp = [[2], [3]]
tmp[0][0] = 2 -> 取x中x[0][0] = 0.4319,将其插入到result第0行的第2个位置,result[0][2] = 0.4319;
......