torch.chunk()数组拆分,torch.squeeze()
torch.chunk(input,chunks,dim=0)
功能:将数组拆分为特定数量的块
输入:
input:待拆分的数组
chunks:拆分的块数,指定为几,就拆成几
dim:拆分的维度,默认沿第1维度拆分
注意:
函数最后返回的是元组类型,包含拆分后的数组
如果输入的数组在指定的维度下不能整除,则拆分得到的最后一块数组的dim维度大小将小于前面所有的数组dim维度大小
chunks有最大值限制,如果指定的块数超过最大值,则最终只能拆分成最大值数量
chunks最大值的计算,input数组在dim维度上大小为a
样例
import torch a=torch.arange(20).view(4,5) b=torch.chunk(a,chunks=2,dim=0) c=torch.chunk(a,chunks=2,dim=1) print(type(b)) print(a.shape) print(len(b)) print(len(c)) print(a) for i in range(len(b)): print(b[i]) print(b[i].shape) # 输出拆分后的形状 for i in range(len(c)): print(c[i]) print(c[i].shape)
输出
<class 'tuple'> torch.Size([4, 5]) 2 2 tensor([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]) tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) torch.Size([2, 5]) tensor([[10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]) torch.Size([2, 5]) tensor([[ 0, 1, 2], [ 5, 6, 7], [10, 11, 12], [15, 16, 17]]) torch.Size([4, 3]) tensor([[ 3, 4], [ 8, 9], [13, 14], [18, 19]]) torch.Size([4, 2])
换句话说:除了最后一个数组dim维度上的大小可以为1,前面的数组dim维度上的大小至少是2
官方文档
torch.chunk():https://pytorch.org/docs/stable/generated/torch.chunk.html#torch.chunk
博客链接:https://blog.csdn.net/qq_50001789/article/details/120352480
torch.squeeze()
torch.squeeze(input, dim=None, out=None)
squeeze()函数的功能是维度压缩。返回一个tensor,其中input中大小为1的所有维都已删除。
举个例子:如果input的形状为(A×1×B×C×1×D),那么返回的tensor的形状则为(A×B×C×D)
当给定dim时,那么只在给定的维度上进行压缩操作。
举个例子:如果input的形状为(A×1×B),squeeze(input,0)后,返回的tensor不变;squeeze(input,1)后,返回的tensor将被压缩为(A×B)
官方文档:https://pytorch.org/docs/stable/generated/torch.squeeze.html?highlight=squeeze#torch.squeeze
博客链接:
https://blog.csdn.net/qq_40305043/article/details/107767652