pytorch-tensor高级OP
Tensor advanced operation
▪ Where
▪ Gather
where
返回的最终的tensor中的每一个值有可能来着A,也有可能来自B。
torch.where(condition,A,B)->tensor
满足condition条件则该位置为A[..],否则为B[..]。
这个condition也是一个相同shape的tensor
比如说:torch.where(cond>0,a,b)
\( \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ \end{bmatrix} \) =>\( \begin{bmatrix} A & B \\ A & B \\ \end{bmatrix} \)
cond=torch.tensor([[0.6769,0.7271],[0.8884,0.4163]])
cond
# tensor([[0.6769, 0.7271],
# [0.8884, 0.4163]])
a=torch.ones(2,2)
a
# tensor([[1., 1.],
# [1., 1.]])
b=torch.zeros(2,2)
b
# tensor([[0., 0.],
# [0., 0.]])
torch.where(cond>0.5,a,b)
# tensor([[1., 1.],
# [1., 0.]])
gather
这个API设计的初衷就是这样的,下面有一个场景。
下面有三类动物,编号分别为0,1,2
\[\begin{bmatrix}
dog \\
cat \\
pig \\
\end{bmatrix}
\begin{matrix}
0 \\
1 \\
2
\end{matrix}
\]
然后我们识别之后的结果是一些编号,然后我们希望将这个结果编号变为类别
\[\begin{bmatrix}
1 \\
0 \\
1 \\
2
\end{bmatrix}
=>
\begin{bmatrix}
cat \\
dog \\
cat \\
pig
\end{bmatrix}
\]
所以这个API中input和index都是tensor。
然后我们看一个具体的例子:
idx
# tensor([[8, 2, 0],
# [4, 8, 1],
# [0, 4, 9],
# [0, 9, 2]])
label
# tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
torch.gather(label.expand(4,10),dim=1,index=idx.long())
# tensor([[108, 102, 100],
# [104, 108, 101],
# [100, 104, 109],
# [100, 109, 102]])
我们来分析一下torch.gather(label.expand(4,10),dim=1,index=idx.long())
首先这个label.expand(4,10)是这样的
我们来看看这个label.expand(4,10),是将input的shape变为(4,10),然后idx的每一行都可以按照label变化之后每一行的下标输出了,所以这个dim=1,就是按照10这个下标输出的。