pytorch 中改变tensor维度(transpose)、拼接(cat)、压缩(squeeze)详解
具体示例如下,注意观察维度的变化
1.改变tensor维度的操作:transpose、view、permute、t()、expand、repeat
#coding=utf-8 import torch def change_tensor_shape(): x=torch.randn(2,4,3) s=x.transpose(1,2) #shape=[2,3,4] y=x.view(2,3,4) #shape=[2,3,4] z=x.permute(0,2,1) #shape=[2,3,4] #tensor.t()只能转化 a 2D tensor m=torch.randn(2,3)#shape=[2,3] n=m.t()#shape=[3,2] print(m) print(n) #返回当前张量在某个维度为1扩展为更大的张量 x = torch.Tensor([[1], [2], [3]])#shape=[3,1] t=x.expand(3, 4) print(t) ''' tensor([[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]]) ''' #沿着特定的维度重复这个张量 x=torch.Tensor([[1,2,3]]) t=x.repeat(3, 2) print(t) ''' tensor([[1., 2., 3., 1., 2., 3.], [1., 2., 3., 1., 2., 3.], [1., 2., 3., 1., 2., 3.]]) ''' x = torch.randn(2, 3, 4) t=x.repeat(2, 1, 3) #shape=[4, 3, 12] if __name__=='__main__': change_tensor_shape()
2.tensor的拼接:cat、stack
除了要拼接的维度可以不相等,其他维度必须相等
#coding=utf-8 import torch def cat_and_stack(): x = torch.randn(2,3,6) y = torch.randn(2,4,6) c=torch.cat((x,y),1) #c=(2*7*6) print(c.size) """ 而stack则会增加新的维度。 如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。 """ a = torch.rand((1, 2)) b = torch.rand((1, 2)) c = torch.stack((a, b), 0) print(c.size()) if __name__=='__main__': cat_and_stack()
3.压缩和扩展维度:改变tensor中只有1个维度的tensor
torch.squeeze(input, dim=None, out=None) → Tensor
除去输入张量input中数值为1的维度,并返回新的张量。如果输入张量的形状为(A×1×B×C×1×D) 那么输出张量的形状为(A×B×C×D)
当通过dim参数指定维度时,维度压缩操作只会在指定的维度上进行。如果输入向量的形状为(A×1×B),
squeeze(input, 0)会保持张量的维度不变,只有在执行squeeze(input, 1)时,输入张量的形状会被压缩至(A×B) 。
如果一个张量只有1个维度,那么它不会受到上述方法的影响。
#coding=utf-8 import torch def squeeze_tensor(): x = torch.Tensor(1,3) y=torch.squeeze(x, 0) print("y:",y) y=torch.unsqueeze(y, 1) print("y:",y) if __name__=='__main__': squeeze_tensor()