pytorch基础问题
本文将自己在pytorch学习中遇见的各种问题整理起来,并且持续更新。
1:torch.Tensor和torch.tensor的区别
开始使用torch.tensor和torch.Tensor的时候发现结果都是一样的。都能生成新的张量。但根源上是有差别的。
import torch n=torch.tensor([[3,4],[1,2]]) x=torch.Tensor([[3,4],[1,2]]) print(n,'|||',x) print(n.shape,'|||',x.shape) print(n.type(),'|||',x.type()) ''' tensor([[3, 4], [1, 2]]) ||| tensor([[3., 4.], [1., 2.]]) torch.Size([2, 2]) ||| torch.Size([2, 2]) torch.LongTensor ||| torch.FloatTensor '''
torch.Tensor()是Python类,更明确的说,是默认张量类型torch.FloatTensor()的别名,调用Tensor类的构造函数__init__,生成单精度浮点类型的张量。
orch.tensor()仅仅是Python的函数,函数原型:
torch.tensor(data, dtype=None, device=None, requires_grad=False)
data可以是:list, tuple, array, scalar等类型。拷贝data中的数据部分,根据原始数据类型生成相应的torch.LongTensor,torch.FloatTensor,torch.DoubleTenso
注:torch.tensor不能直接定义维度:torch.tensor(5,3)是错误的。