Pytorch | view()函数的使用
函数简介
Pytorch中的view函数主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensor。
根据上面的描述可知,view函数的操作对象应该是Tensor类型。如果不是Tensor类型,可以通过tensor = torch.tensor(data)
来转换。
普通用法 (手动调整size)
view(参数a,参数b,…)
,其中,总的参数个数表示将张量重构后的维度。
view()
相当于reshape
、resize
,重新调整Tensor的形状。
import torch
a1 = torch.arange(0,16)
print(a1) # tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
a2 = a1.view(8, 2) # 将a1的维度改为8*2
a3 = a1.view(2, 8) # 将a1的维度改为2*8
a4 = a1.view(4, 4) # 将a1的维度改为4*4
# a5 = a1.view(2,2,1,4)
# 更多的维度也没有问题,只要保证维度改变前后的元素个数相同就行,即 2*2*1*4=16。
print(a2)
print(a3)
print(a4)
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
特殊用法:参数-1 (自动调整size)
view(参数a,参数b,…)
中一个参数定为-1,代表自动调整这个维度上的元素个数,则表示该维度取决于其它维度,由Pytorch自己补充,以保证元素的总数不变。
import torch
a1 = torch.arange(0,16)
print(a1) # tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
a2 = a1.view(-1, 16)
a3 = a1.view(-1, 8)
a4 = a1.view(-1, 4)
a5 = a1.view(-1, 2)
a6 = a1.view(4*4, -1)
a7 = a1.view(1*4, -1)
a8 = a1.view(2*4, -1)
print(a2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a7)
print(a8)
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
view(-1)
表示将Tensor转为一维Tensor。
a9 = a1.view(-1)
print(a1)
print(a9) # 因此,转变后还是一维,没什么变换
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
到此这篇关于pytorch中的 .view()函数的用法介绍的文章就介绍到这了,更多相关pytorch .view()函数内容请去pytorch官网文档查看。