pytorch维度变化
▪ View/reshape(这个是维度的变化)
▪ Squeeze/unsqueeze(维度的增加/减少)
▪ Transpose/t/permute(维度交换)
▪ Expand/repeat(维度的扩展)
view/reshape
这两个API,都是通用的a.reshape()和a.view()
但是有一个条件就是,变化之前和之后的dim相乘的数值相等
a=torch.rand(4,1,28,28)
a.reshape(4,28*28).shape
# torch.Size([4, 784])
a.view(4,28*28).shape
# torch.Size([4, 784])
a.view(4*28,28).shape
# torch.Size([112, 28])
a.reshape(4*1,28,28).shape
# torch.Size([4, 28, 28])
b=a.view(4,784)
b.view(4,28,28,1).shape
# torch.Size([4, 28, 28, 1])
squeeze/unsqueeze
一个是减少维度一个是增加维度
unsqueeze
.unsqueeze(k)
这个k的范围是[-a.dim()-1,a.dim()+1)
就是比如说[2,1,28,28],那么这个k的范围就是[-5,5)
如果这个k<0的话,就是在dim=k的后面加一个维度,就是比如说k=-1的话,就是在最后一个维度的后面加一个维度
如果k>=0的话,就是在dim=k的前面加一个维度,就是比如说k=0的话,就是在第1个维度的前面加一个维度。
我们可以看这个例子:
k的范围为[-5,4),k<0,在后面加,k>=0在前面加
a.shape
# torch.Size([4, 1, 28, 28])
a.unsqueeze(0).shape
# torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(-1).shape
# torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(4).shape
# torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(-4).shape
# torch.Size([4, 1, 1, 28, 28])
a.unsqueeze(-5).shape
# torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(5).shape
# IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
然后我们看看这些例子:
然后我们看一个具体的例子:
b=torch.rand(32)
f=torch.rand(4,32,14,14)
b=b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
b.shape
# torch.Size([1, 32, 1, 1])
(b+f).shape
# torch.Size([4, 32, 14, 14])
squeeze
b.squeeze(dim)
这个参数如果不加的话,那就挤压全部能挤压的维度(shape为1的维度)
这里需要注意只有shape=1的那个维度才能squeeze
b.shape
# torch.Size([1, 32, 1, 1])
b.squeeze().shape
# torch.Size([32])
# 压缩全部能压缩的
b.squeeze(0).shape
# torch.Size([32, 1, 1])
b.squeeze(-1).shape
# torch.Size([1, 32, 1])
b.squeeze(1).shape
# torch.Size([1, 32, 1, 1])
b.squeeze(-4).shape
# torch.Size([32, 1, 1])
expand / repeat
expand
expand:broadcasting
squeeze可以将shape=1的那个维度去掉,Expand则可以将shape=1那个维度进行扩展,只能是shape=1的维度进行扩展
如果某一个维度不进行扩张的话,可以写上-1,进行占位。
a=torch.rand(4,32,14,14)
b.shape
# torch.Size([1, 32, 1, 1])
b.expand(4,32,14,14).shape
# torch.Size([4, 32, 14, 14])
b.expand(-1,32,-1,-1).shape
# torch.Size([1, 32, 1, 1])
# 这个就是指的是第0,2,3维度不变
repeat
这个是内存的复制,比较不建议使用这个,建议使用expand
repeat(a,b,c,d)
这里有一个和上面的不一样的就是这个是第0维度复制a倍,第1维度复制b倍。。。,第2个维度复制c倍,第三个维度复制d倍,如果某一个维度不变的话,就将该维度值=1。
还有一点不一样的就是这个维度不用为1
b.shape
# torch.Size([1, 32, 1, 1])
b.repeat(4,32,1,1).shape
# torch.Size([4, 1024, 1, 1])
b.repeat(4,1,1,1).shape
# torch.Size([4, 32, 1, 1])
b.repeat(4,1,32,32).shape
# torch.Size([4, 32, 32, 32])
Transpose/permute
Transpose
.teanspose(a,b)
这个就是代表这个将维度a和维度b进行交换
a=torch.rand(4,3,32,32)
a.shape
# torch.Size([4, 3, 32, 32])
a.transpose(1,3).shape
# torch.Size([4, 32, 32, 3])
但是这是有一个问题的,就是维度会混乱
a1=a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)
# 就是[b,c,h,w]->[b,w,h,c]->[b,w*h*c],这样之后我们我们如果在想变回[b,c,w,h]会出错
这时候我们需要一个函数.contiguous(),就是使它变得连续之后再变化,
a1=a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
# [b,c,h,w]->[b,w,h,c]->[b,w*h*c]->[b,c,w,h]
# 这样变化之后和最初的a不一样了
下面这样变化可以使得它变成a一样的
a2=a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
# [b,c,h,w]->[b,w,h,c]->[b,w*h*c]->[b,w,h,c]->[b,c,h,w]
permute
permute()
这个里面是索引。
上面那只能两个两个的变换,这个可以一下改变所有的
b=torch.rand(4,3,28,32)
b.shape
# torch.Size([4, 3, 28, 32])
b.transpose(1,3).transpose(1,2).shape
# torch.Size([4, 28, 32, 3])
b.permute(0,2,3,1).shape
# torch.Size([4, 28, 32, 3])
这里一步就能全部换掉