torch.chunk()数组拆分,torch.squeeze()

torch.chunk(input,chunks,dim=0)

功能:将数组拆分为特定数量的块

输入:

input:待拆分的数组

chunks:拆分的块数,指定为几,就拆成几

dim:拆分的维度,默认沿第1维度拆分

注意:

函数最后返回的是元组类型,包含拆分后的数组

如果输入的数组在指定的维度下不能整除,则拆分得到的最后一块数组的dim维度大小将小于前面所有的数组dim维度大小

chunks有最大值限制,如果指定的块数超过最大值,则最终只能拆分成最大值数量

chunks最大值的计算,input数组在dim维度上大小为a

样例

import torch
a=torch.arange(20).view(4,5)
b=torch.chunk(a,chunks=2,dim=0)
c=torch.chunk(a,chunks=2,dim=1)
print(type(b))
print(a.shape)
print(len(b))
print(len(c))
print(a)
for i in range(len(b)):
    print(b[i])
    print(b[i].shape)
    # 输出拆分后的形状
for i in range(len(c)):
    print(c[i])
    print(c[i].shape)

输出

<class 'tuple'>
torch.Size([4, 5])
2
2
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
torch.Size([2, 5])
tensor([[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
torch.Size([2, 5])
tensor([[ 0,  1,  2],
        [ 5,  6,  7],
        [10, 11, 12],
        [15, 16, 17]])
torch.Size([4, 3])
tensor([[ 3,  4],
        [ 8,  9],
        [13, 14],
        [18, 19]])
torch.Size([4, 2])

换句话说:除了最后一个数组dim维度上的大小可以为1,前面的数组dim维度上的大小至少是2

官方文档

torch.chunk():https://pytorch.org/docs/stable/generated/torch.chunk.html#torch.chunk

 

博客链接:https://blog.csdn.net/qq_50001789/article/details/120352480

 

torch.squeeze()

torch.squeeze(input, dim=None, out=None) 

squeeze()函数的功能是维度压缩。返回一个tensor,其中input中大小为1的所有维都已删除。

举个例子:如果input的形状为(A×1×B×C×1×D),那么返回的tensor的形状则为(A×B×C×D)

当给定dim时,那么只在给定的维度上进行压缩操作。

举个例子:如果input的形状为(A×1×B),squeeze(input,0)后,返回的tensor不变;squeeze(input,1)后,返回的tensor将被压缩为(A×B)

官方文档:https://pytorch.org/docs/stable/generated/torch.squeeze.html?highlight=squeeze#torch.squeeze

 

博客链接:

https://blog.csdn.net/qq_40305043/article/details/107767652

posted @ 2023-03-06 15:03  sqsq  阅读(209)  评论(0编辑  收藏  举报