8 Torch 中 view() & size() 用法
Torch 中 view() & size()
在阅读论文源码过程中,经常会看到如下的命令:
x = x.view(x.size(0), -1) # 改变 tensor 的形态
下面,本文简单介绍一下 view() 和 size() 函数的作用:
view()
import torch # 用法一
a = torch.ones(2, 3, 4)
b = a.view(3, 8)
b
Output:
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1.]])
import torch # 用法二
a = torch.ones(2, 3, 4)
b = a.view(4, -1)
b
Output:
tensor([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]) # b = a.view(-1)
以上展示了两种用法:
1、torch.view(x, y, z, ……)
将原 tensor 以参数设置的维度重排
2、torch.view(x, -1) & torch.view(-1)
将原 tensor 以参数 x 设置第一维度重排,第二维度自动补齐;当没有参数 x 时,直接重排为一维的 tensor
size()
import torch
a = torch.ones(2, 3, 4)
a.size()
Output:
torch.Size([2, 3, 4])
a.size(0)
Output:
2
a.size(1)
Output:
3
a.size(2)
Output:
4
综上,torch.size(x) 即返回 tensor第 x 维的长度