pytorch transpose
pytorch transpose
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893, 0.5809],
[-0.1669, 0.7299, 0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
[-0.9893, 0.7299],
[ 0.5809, 0.4942]])
pytorch中的transpose方法的作用是交换矩阵的两个维度,transpose(dim0, dim1) → Tensor,其和torch.transpose()函数作用一样。
torch.transpose():
torch.transpose(input, dim0, dim1) → Tensor
Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.
The resulting out tensor shares it’s underlying storage with the input tensor, so changing the content of one would change the content of the other.
第二条是说输出和输入是共享一块内存的,所以两者同时改变。
Parameters
input (Tensor) – the input tensor.
dim0 (int) – the first dimension to be transposed
dim1 (int) – the second dimension to be transposed
例:
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893, 0.5809],
[-0.1669, 0.7299, 0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
[-0.9893, 0.7299],
[ 0.5809, 0.4942]])
需要注意的几点:
1、transpose中的两个维度参数的顺序是可以交换位置的,即transpose(x, 0, 1,) 和transpose(x, 1, 0)效果是相同的。如下:
>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[-0.4343, 0.4643, -1.1345],
[-0.3667, -1.9913, 1.3485]])
>>> torch.transpose(x, 1, 0)
tensor([[-0.4343, -0.3667],
[ 0.4643, -1.9913],
[-1.1345, 1.3485]])
>>> torch.transpose(x, 0, 1)
tensor([[-0.4343, -0.3667],
[ 0.4643, -1.9913],
[-1.1345, 1.3485]])
2、transpose.()中只有两个参数,而torch.transpose()函数中有三个参数。
————————————————
链接:https://blog.csdn.net/a250225/article/details/102636425