PyTorch 记录

设有两个 tensor：A 和 B

C = torch.cat( (A,B),0 ) #按维数0拼接(竖着拼)
C = torch.cat( (A,B),1 ) #按维数1拼接(横着拼)
#注意: 除拼接的那一维外, 其余各维的大小必须相同
A = torch.ones(2,3)
B = torch.ones(4,3)
out=torch.cat((A,B),0)
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])


C = torch.ones(2,5)
out = torch.cat((A,C),1)
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
max_test = torch.Tensor([[5,8,1],[3,1,9]])
tensor([[5., 8., 1.],
        [3., 1., 9.]])

max_test.max(1,keepdim=True)
torch.return_types.max(
values=tensor([[8.],
        [9.]]),
indices=tensor([[1],
        [2]]))

max_test.max(1)
torch.return_types.max(
values=tensor([8., 9.]),
indices=tensor([1, 2]))

max_test.max(0)
torch.return_types.max(
values=tensor([5., 8., 9.]),
indices=tensor([0, 0, 1]))

max_test.max(0,keepdim=True)
torch.return_types.max(
values=tensor([[5., 8., 9.]]),
indices=tensor([[0, 0, 1]]))
valid_idx = torch.tensor([True, False, True, False, False]) #小写的torch.tensor会根据数据推断dtype, 这里是torch.bool类型, 用作布尔掩码
a = torch.tensor([1,2,3,4,5])
idx_filter = a[valid_idx]
tensor([1, 3])
b = torch.Tensor([[1,2,3]])
b.squeeze(0) #不带下划线: 返回新的tensor, b本身不变
b
tensor([[1., 2., 3.]])

b.squeeze_(0) #带下划线: 原地(in-place)操作, 直接修改b
b
tensor([1., 2., 3.])
a = torch.ones(3,5)
index = torch.tensor([0,2])
a.index_fill_(0,index,100)
tensor([[100., 100., 100., 100., 100.],
        [  1.,   1.,   1.,   1.,   1.],
        [100., 100., 100., 100., 100.]])


b = torch.ones(3,5)
b.index_fill(1,index,200) #不带下划线: 返回新的tensor, 下面是返回值, b本身不变
tensor([[200.,   1., 200.,   1.,   1.],
        [200.,   1., 200.,   1.,   1.],
        [200.,   1., 200.,   1.,   1.]])

 

labels= torch.rand(5,4)
tensor([[0.2833, 0.7600, 0.6912, 0.5421],
        [0.3498, 0.0440, 0.3356, 0.5975],
        [0.9071, 0.2023, 0.9391, 0.2516],
        [0.9536, 0.0939, 0.4833, 0.7402],
        [0.2392, 0.7111, 0.9192, 0.5417]])
best_idx = torch.tensor([3,3,3,0,0,0,0])
labels[best_idx]
tensor([[0.9536, 0.0939, 0.4833, 0.7402],
        [0.9536, 0.0939, 0.4833, 0.7402],
        [0.9536, 0.0939, 0.4833, 0.7402],
        [0.2833, 0.7600, 0.6912, 0.5421],
        [0.2833, 0.7600, 0.6912, 0.5421],
        [0.2833, 0.7600, 0.6912, 0.5421],
        [0.2833, 0.7600, 0.6912, 0.5421]])

 

posted @ 2021-04-21 18:20  crazybird123  阅读(60)  评论(0)  编辑  收藏  举报