pytorch方法整理--gather&squeeze&unsqueeze&其他一些函数
gather
pytorch中gather源码形式:torch.gather(input, dim, index, *, sparse_grad = False, out = None)
然后在pytorch官方文档中,写了这样的一个例子,这个例子是三维的
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
刚开始比较难理解,不知道什么意思,于是试了几个例子
一维:
>>> array1 = torch.tensor([1,2,3])
>>> torch.gather(array1, 0, torch.tensor([0,1]))
tensor([1, 2])
上述例子中,array1的矩阵形式为array1 = [1,2,3]
, 按维度0取值(对于一维的情况,顶多也为0), 将[array1[0],array1[1]]
作为输出结果,也就是[1,2]
二维
>>> array2 = torch.tensor([[1,2,3],[4,5,6]])
>>> torch.gather(array2, 0, torch.tensor([[0, 1]]))
tensor([[1, 5]])
在上述二维的例子中,array2的形式为array2 = [[1,2,3],[4,5,6]]
, 按维度0取值, 将[array2[0][0], array2[1][1]]
为输出结果,也就是[1,5]
看了上面两个例子,我们根据torch.gather(input, dim, index, *, sparse_grad = False, out = None)
看下公式:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
明确一点的是,输出的size是和index的size是一样的。
对于一维的,假设index的大小为n, 那么输出结果为 [input[index[0]], input[index[1]], .... , input[index[n]]], 也就是我们上个例子中的[1,2]
对于二维的,如果dim为0,假设index的大小为m*n, 那么输出结果为
[[input[index[0][0]][0], input[index[0][1]][1], ... , input[index[0][n]][n],
[input[index[1][0]][0], input[index[1][1]][1], ... , input[index[1][n]][n],
....
[input[index[m][0]][0], input[index[m][1]][1], ...., input[index[m][n]][n]
所以上面的例子我们输出为[1,5]
如果dim为1呢, 同样假设index的大小为m*n,那么输出结果为:
[[input[0][index[0][0]], input[1][index[0][1]], ... , input[n][index[0][n]],
[input[0][index[1][0]], input[1][index[1][1]], ... , input[n][index[1][n]],
....
[input[0][index[m][0]], input[1][index[m][1]], ...., input[n][index[m][n]],
看到这里那是不是就对官网给出的公式有点理解了呢?
再来看几个3维的例子
array = torch.tensor(np.arange(24)).view(2,3,4)
array
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]], dtype=torch.int32)
torch.gather(array, 0, torch.tensor([[[0,1],[1,1]]]))
输出结果为
[[[array[0][0][0],array[1][0][1],
[array[1][1][0],array[1][1][1]
]]
tensor([[[ 0, 13],
[16, 17]]], dtype=torch.int32)
看到这里是不是有点理解了,不理解的平时多试试就比较清楚了。
squeeze
torch.squeeze中函数形式
torch.squeeze(input, dim = None, *, out = None) -> Tensor
,默认dim
参数为None
,
官网也描述了它的作用:Returns a tensor with all the dimensions of input of size 1 removed.
就是移除所有size为1的维度,比如说输入一个array,它的shape为(1,2,3,1,2),那么他的output的size为(2,3,2)
具体看一下例子:
>>> array = torch.zeros(1,2,3,1,2)
>>> torch.squeeze(array).size()
torch.Size([2, 3, 2])
如果某个维度的size不为1,那就不移除。
另外还有一种写法可以移除特定size为1的维度,写法torch.squeeze(array, dim)
例如:
>>> array = torch.zeros(1,2,3,1,2)
>>> torch.squeeze(array,0).size()
torch.Size([2, 3, 1, 2])
这里第四维度的1就没有移除掉。
下面我们再来看下unsqueeze方法
unsqueeze
torch官网中描述的方法 torch.unsqueeze(input, dim)->Tensor
作用是返回一个在特定维度插入size为1的tensor
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0).size()
torch.Size([1, 4])
这个dim可以为多少呢?官方也是做出了解释A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.
就是说,dim的输入只能在[-input.dim()-1, input.dim() + 1]范围内。在上面的例子中,维度限制在[-2, 1]之间。
如果是负数怎么处理呢? dim = dim + input.dim() + 1
, 也就是说,如果输入-2, 那么应该输出dim = 0
,
其实从这个公式,和list中里面的选取元素差不多,
例如list = [1,2,3,4]; list[0]= 1, list[-1] = 4,相当于 dim为-1 就是在最高维插入size为1, 而当dim为-input.dim() - 1相当于在维度0处插入size为1。
例子
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, -2).size()
torch.Size([1, 4])
>>> torch.unsqueeze(x, -1).size()
torch.Size([4, 1])
>>> torch.unsqueeze(x, 0).size()
torch.Size([1, 4])
看到这里是不是就理解了呢?
torch.full()
看一遍就知道这个函数的意思了,例子如下:
>>> torch.full((3,2), 3.1415926)
tensor([[3.1416, 3.1416],
[3.1416, 3.1416],
[3.1416, 3.1416]])
torch.permute()
pytorch 官方文档的描述: Returns a view of the original tensor input with its dimensions permuted.
意思就是返回原始张量输入的大小,其尺寸已经被置换;
例子:
torch.repeat()
pytorch 官方文档描述:Repeats this tensor along the specified dimensions.
意思就是,沿着指定的维度重复此张量
例子:
torch.chunk()
pytorch 官方文档 介绍 Attempts to split a tensor into the specified number of chunks. Each chunk is a view of the input tensor.
如果沿给定维度 dim 的张量大小可被块整除,则所有返回的块将具有相同的大小。 如果沿给定维度 dim 的张量大小不能被块整除,则所有返回的块将具有相同的大小,除了最后一个。 如果这样的划分是不可能的,这个函数可能会返回少于指定数量的块。
例子:
像第二个,不能划分为6个,因为前五个tensor的size必须要一样,所以只能分成5份(如果前五个size为2的话,最后一个为3,超过了2)。
torch.topk()
topk的函数参数torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
返回在给定维度下,最大的k个元素。
例子:
(其他方法后续跟新)