PyTorch 记录
设有两个 tensor A 和 B。注意：torch.cat 要求除拼接所在维度外，其余各维的大小必须相同。
C = torch.cat((A,B), 0)  # 按维度0拼接（竖着拼，行数相加）
C = torch.cat((A,B), 1)  # 按维度1拼接（横着拼，列数相加）
A = torch.ones(2,3)
B = torch.ones(4,3)
out = torch.cat((A,B),0)   # (2,3) 与 (4,3) 沿维度0拼接 → (6,3)
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
C = torch.ones(2,5)
out = torch.cat((A,C),1)   # (2,3) 与 (2,5) 沿维度1拼接 → (2,8)
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
max_test = torch.Tensor([[5,8,1],[3,1,9]])
tensor([[5., 8., 1.],
        [3., 1., 9.]])
# max(dim) 返回具名元组 (values, indices)；keepdim=True 时保留被归约的那一维（大小变为1）
max_test.max(1,keepdim=True)   # 沿维度1取最大，即每行的最大值
torch.return_types.max(
values=tensor([[8.],
        [9.]]),
indices=tensor([[1],
        [2]]))
max_test.max(1)
torch.return_types.max(
values=tensor([8., 9.]),
indices=tensor([1, 2]))
max_test.max(0)                # 沿维度0取最大，即每列的最大值
torch.return_types.max(
values=tensor([5., 8., 9.]),
indices=tensor([0, 0, 1]))
max_test.max(0,keepdim=True)
torch.return_types.max(
values=tensor([[5., 8., 9.]]),
indices=tensor([[0, 0, 1]]))
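valid_idx = torch.tensor([True, False, True, False, False])  # 小写的 torch.tensor 按数据推断 dtype，这里是 torch.bool（布尔掩码），不是 long
a = torch.tensor([1,2,3,4,5])  # 整数数据推断为 torch.int64（即 long）
idx_filter = a[valid_idx]      # 布尔掩码索引：只保留掩码为 True 的位置
tensor([1, 3])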
b = torch.Tensor([[1,2,3]])
b.squeeze(0)    # 非原地操作：返回新张量，b 本身不变
b
tensor([[1., 2., 3.]])
b.squeeze_(0)   # 带下划线的是原地操作：直接修改 b
b
tensor([1., 2., 3.])
a = torch.ones(3,5)
index = torch.tensor([0,2])
a.index_fill_(0,index,100)   # 原地操作：沿维度0，把第0、2行填成100
tensor([[100., 100., 100., 100., 100.],
        [  1.,   1.,   1.,   1.,   1.],
        [100., 100., 100., 100., 100.]])
b = torch.ones(3,5)
b.index_fill(1,index,200)    # 非原地操作：沿维度1把第0、2列填成200，返回新张量，b 本身不变
tensor([[200.,   1., 200.,   1.,   1.],
        [200.,   1., 200.,   1.,   1.],
        [200.,   1., 200.,   1.,   1.]])
labels = torch.rand(5,4)
tensor([[0.2833, 0.7600, 0.6912, 0.5421],
        [0.3498, 0.0440, 0.3356, 0.5975],
        [0.9071, 0.2023, 0.9391, 0.2516],
        [0.9536, 0.0939, 0.4833, 0.7402],
        [0.2392, 0.7111, 0.9192, 0.5417]])
best_idx = torch.tensor([3,3,3,0,0,0,0])
labels[best_idx]   # 用整型索引张量沿维度0取行，同一行可重复取，结果形状为 (7,4)
tensor([[0.9536, 0.0939, 0.4833, 0.7402],
        [0.9536, 0.0939, 0.4833, 0.7402],
        [0.9536, 0.0939, 0.4833, 0.7402],
        [0.2833, 0.7600, 0.6912, 0.5421],
        [0.2833, 0.7600, 0.6912, 0.5421],
        [0.2833, 0.7600, 0.6912, 0.5421],
        [0.2833, 0.7600, 0.6912, 0.5421]])