pytorch高阶操作

pytorch高阶操作

where函数

torch.where(condition,x,y)

可能新生成的tensor一部分来自x,一部分来自y,但是是没有规律的

例子:假设一个tensor表示识别概率,大于0.5表示1,小于0.5表示0

a = torch.rand(2,2)
print(a)

tensor([[0.9872, 0.9270],
        [0.6795, 0.0959]])


aa = torch.zeros(2,2)
bb = torch.ones(2,2)

answer = torch.where(a>0.5,aa,bb)
print(answer)

tensor([[0., 0.],
        [0., 1.]])

gather函数

实际就是一个查表的函数

比如像手写数字的识别,【4,10】4张图片,最后识别出每张图片中10个概率最大的index(一般index为几这个数字就是几),但是如果我们的标签不是1~10,而是另外有一张表来对应,不同的index对应不同的标签,这时就可以使用gather函数

例子:

prob = torch.rand(4,10)

idx = prob.topk(3,dim=1)
idx1 = idx[1]

print(idx1)

tensor([[1, 3, 4],
        [2, 0, 3],
        [5, 4, 2],
        [9, 4, 5]])

label = torch.arange(10)+100#为了方面随便初始化的label

print(torch.gather(label.expand(4,10),dim=1,index=idx1.long()))


tensor([[101, 103, 104],
        [102, 100, 103],
        [105, 104, 102],
        [109, 104, 105]])
posted @ 2020-09-02 18:05  Jason66661010  阅读(199)  评论(0编辑  收藏  举报