8 Torch 中 view() & size() 用法

Torch 中 view() & size()

在阅读论文源码过程中,经常会看到如下的命令:

x = x.view(x.size(0), -1)	# 改变 tensor 的形态

下面,本文简单介绍一下 view() 和 size() 函数的作用:

view()

import torch	# 用法一
a = torch.ones(2, 3, 4)
b = a.view(3, 8)
b

Output:
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

import torch	# 用法二
a = torch.ones(2, 3, 4)
b = a.view(4, -1)
b

Output:
tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])	# b = a.view(-1)

以上展示了两种用法:

1、torch.view(x, y, z, ……)

将原 tensor 以参数设置的维度重排

2、torch.view(x, -1) & torch.view(-1)

将原 tensor 以参数 x 设置第一维度重排,第二维度自动补齐;当没有参数 x 时,直接重排为一维的 tensor

size()

import torch
a = torch.ones(2, 3, 4)
a.size()
Output:
torch.Size([2, 3, 4])

a.size(0)
Output:
2

a.size(1)
Output:
3

a.size(2)
Output:
4

综上,torch.size(x) 即返回 tensor第 x 维的长度

posted @ 2021-10-21 11:18  SethDeng  阅读(703)  评论(0编辑  收藏  举报